flash: Reduce fence calls to improve util

This commit is contained in:
Hansung Kim
2024-11-09 16:44:17 -08:00
parent 6990fcc1e6
commit ad75561efe

View File

@@ -10,7 +10,9 @@
#define FENCE_GEMM_II
#define GEMMINI_NEW_CISC
#define GEMMINI_NEW_CISC 1
static_assert(GEMMINI_NEW_CISC, "NOTE: old non-CISC code is untested; look for "
"any misalignment of fields in ciscArgs.");
constexpr bool DEBUG = false;
@@ -282,6 +284,8 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
/*a_transpose=*/0, /*b_transpose=*/0, /*full_C=*/0, /*low_D=*/0,
/*acc=*/0, /*act=*/NO_ACTIVATION, /*skips=*/skips);
#endif
// block until DMA complete
gemmini_fence();
// also move Q to spad_addr_Q1 for the second iteration
@@ -309,12 +313,12 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
/*acc=*/0, /*act=*/NO_ACTIVATION, /*skips=*/skips);
#endif
gemmini_fence();
gemmini_fence();
gemmini_fence();
// block until DMA complete
gemmini_fence();
// re-configure DMA for K and V load that will later happen in the loop
// FIXME: not sure necessary with new CISC
//
// GMEM addr stride for K
gemmini_extended3_config_ld(dim_seqlen * sizeof(elem_t),
MVIN_SCALE_IDENTITY, false, 0);
@@ -424,9 +428,6 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
// FIXME: perf: prevent GMEM->SMEM load for O tile
gemmini_fence();
gemmini_fence();
gemmini_fence();
gemmini_fence();
#ifdef GEMMINI_NEW_CISC
gemmini_tile_compute</*store_to_spad=*/true>(
spad_hex_P_consume, spad_hex_V_consume, spad_hex_O,
@@ -458,16 +459,17 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
if (tid_in_warpgroup == 0) {
// fence to GEMM II completion
gemmini_fence();
gemmini_fence();
gemmini_fence();
gemmini_fence();
#ifdef FENCE_GEMM_II
asm volatile("rescale_fence_write_start_%=:" ::);
// signal that GEMM II is finished to O rescale step
*smem_O_flag = 1;
vx_fence();
asm volatile("rescale_fence_write_end_%=:" ::);
#endif
// Kick off GEMM I
//
// 0,2,.: opcode 0 (quartile 0/2, no accum)
// 1,3,.: opcode 3 (quartile 1/3, no accum)
// const uint32_t opcode = 3 * (tile_k & 1);
@@ -485,10 +487,6 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
/*acc=*/0, /*act=*/NO_ACTIVATION, /*skips=*/skips_matmul);
#endif
// gemmini_fence();
// gemmini_fence();
// gemmini_fence();
// gemmini_fence();
asm volatile("gemm_qk_finish_%=:" ::);
// data move for K and V
@@ -511,7 +509,6 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
(uint64_t)(gmem_V_tile),
k_LOOP_WS_CONFIG_ADDRS_AB)
#endif
// gemmini_fence();
// do DMA
if (tile_k == 0) {
@@ -554,9 +551,6 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
// fence everything before going to the next tile
gemmini_fence();
gemmini_fence();
gemmini_fence();
gemmini_fence();
}
// threadblock_barrier(warpgroup_id_in_cluster,
@@ -625,6 +619,7 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
}
#ifdef FENCE_GEMM_II
asm volatile("rescale_fence_read_start_%=:" ::);
// check flag to make sure GEMM II finished and read-after-write
// dependency on O tile is settled for rescale
if (tid_in_warpgroup_simt == 0) {
@@ -634,6 +629,7 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
*smem_O_flag = 0;
vx_fence();
}
asm volatile("rescale_fence_read_end_%=:" ::);
#endif
#if 0