diff --git a/tests/regression/flash_attention/kernel.gemmini.cpp b/tests/regression/flash_attention/kernel.gemmini.cpp index ae0aaf07..82a83aa4 100644 --- a/tests/regression/flash_attention/kernel.gemmini.cpp +++ b/tests/regression/flash_attention/kernel.gemmini.cpp @@ -8,7 +8,7 @@ #include "gemmini_mmio.h" #include "flash_impl.hpp" -constexpr bool DEBUG = true; +constexpr bool DEBUG = false; static_assert(GEMMINI_DMA && !WARP_SPECIALIZED, "GEMMINI_DMA should be set and WARP_SPECIALIZED unset"); @@ -90,69 +90,48 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { static_assert( threads_per_threadblock == NUM_WARPS * NUM_THREADS * CORES_PER_CLUSTER, "flashattention kernel assumes 1 threadblock occupancy per cluster"); - uint8_t *smem_per_threadblock = reinterpret_cast( - DEV_SMEM_START_ADDR); - float *smem_cursor = reinterpret_cast(smem_per_threadblock); - // float *smem_cursor = reinterpret_cast(DEV_FAKE_SMEM_START_ADDR); - float *smem_Q0 = smem_cursor; - smem_cursor += smem_Q_size; - float *smem_Q1 = smem_cursor; - smem_cursor += smem_Q_size; - float *smem_K0 = smem_cursor; - smem_cursor += smem_K_size; - float *smem_K1 = smem_cursor; - smem_cursor += smem_K_size; - float *smem_V0 = smem_cursor; - smem_cursor += smem_V_size; - float *smem_V1 = smem_cursor; - smem_cursor += smem_V_size; - float *smem_S0 = smem_cursor; - smem_cursor += smem_QK_size; - float *smem_S1 = smem_cursor; - smem_cursor += smem_QK_size; - float *smem_P0 = smem_cursor; - smem_cursor += smem_QK_size; - float *smem_P1 = smem_cursor; - smem_cursor += smem_QK_size; - float *smem_O0 = smem_cursor; - smem_cursor += smem_O_size; - float *smem_O1 = smem_cursor; - smem_cursor += smem_O_size; + uint8_t *smem_per_threadblock = reinterpret_cast(DEV_SMEM_START_ADDR); + constexpr uint32_t smem_start = DEV_SMEM_START_ADDR; + constexpr uint32_t smem_quart0 = 0 * (SMEM_SIZE / 4); + constexpr uint32_t smem_quart1 = 1 * (SMEM_SIZE / 4); + constexpr uint32_t smem_quart2 = 2 * (SMEM_SIZE / 4); + constexpr uint32_t smem_quart3 = 3 * (SMEM_SIZE / 4); - // NOTE: this has to match with smem_* - static_assert(sizeof(elem_t) == sizeof(float)); - constexpr uint32_t spad_addr_factor = DIM * sizeof(elem_t); - constexpr uint32_t spad_addr_Q0 = 0; - constexpr uint32_t spad_addr_Q1 = - spad_addr_Q0 + (smem_Q_size * sizeof(float) / spad_addr_factor); - constexpr uint32_t spad_addr_K0 = - spad_addr_Q1 + (smem_Q_size * sizeof(float) / spad_addr_factor); - constexpr uint32_t spad_addr_K1 = - spad_addr_K0 + (smem_K_size * sizeof(float) / spad_addr_factor); - constexpr uint32_t spad_addr_V0 = - spad_addr_K1 + (smem_K_size * sizeof(float) / spad_addr_factor); - constexpr uint32_t spad_addr_V1 = - spad_addr_V0 + (smem_V_size * sizeof(float) / spad_addr_factor); - constexpr uint32_t spad_addr_S0 = - spad_addr_V1 + (smem_V_size * sizeof(float) / spad_addr_factor); - constexpr uint32_t spad_addr_S1 = - spad_addr_S0 + (smem_QK_size * sizeof(float) / spad_addr_factor); - constexpr uint32_t spad_addr_P0 = - spad_addr_S1 + (smem_QK_size * sizeof(float) / spad_addr_factor); - constexpr uint32_t spad_addr_P1 = - spad_addr_P0 + (smem_QK_size * sizeof(float) / spad_addr_factor); - constexpr uint32_t spad_addr_O0 = - spad_addr_P1 + (smem_QK_size * sizeof(float) / spad_addr_factor); - constexpr uint32_t spad_addr_O1 = - spad_addr_O0 + (smem_O_size * sizeof(float) / spad_addr_factor); + // Q/V/S in quart0/1, K/P/O in quart2/3 + constexpr uint32_t smem_Q0_offset = smem_quart0; + constexpr uint32_t smem_Q1_offset = smem_quart1; + constexpr uint32_t smem_K0_offset = smem_quart2; + constexpr uint32_t smem_K1_offset = smem_quart3; + constexpr uint32_t smem_V0_offset = smem_Q0_offset + smem_Q_size * sizeof(float); + constexpr uint32_t smem_V1_offset = smem_Q1_offset + smem_Q_size * sizeof(float); + constexpr uint32_t smem_S0_offset = smem_V0_offset + smem_V_size * sizeof(float); + constexpr uint32_t smem_S1_offset = smem_V1_offset + smem_V_size * sizeof(float); + constexpr uint32_t smem_P0_offset = smem_K0_offset + smem_K_size * sizeof(float); + constexpr uint32_t smem_P1_offset = smem_K1_offset + smem_K_size * sizeof(float); + constexpr uint32_t smem_O0_offset = smem_P1_offset + smem_QK_size * sizeof(float); + constexpr uint32_t smem_O1_offset = smem_P0_offset + smem_QK_size * sizeof(float); // unused + + float *smem_Q0 = reinterpret_cast(smem_start + smem_Q0_offset); + float *smem_Q1 = reinterpret_cast(smem_start + smem_Q1_offset); + float *smem_K0 = reinterpret_cast(smem_start + smem_K0_offset); + float *smem_K1 = reinterpret_cast(smem_start + smem_K1_offset); + float *smem_V0 = reinterpret_cast(smem_start + smem_V0_offset); + float *smem_V1 = reinterpret_cast(smem_start + smem_V1_offset); + float *smem_S0 = reinterpret_cast(smem_start + smem_S0_offset); + float *smem_S1 = reinterpret_cast(smem_start + smem_S1_offset); + float *smem_P0 = reinterpret_cast(smem_start + smem_P0_offset); + float *smem_P1 = reinterpret_cast(smem_start + smem_P1_offset); + float *smem_O0 = reinterpret_cast(smem_start + smem_O0_offset); + float *smem_O1 = reinterpret_cast(smem_start + smem_O1_offset); // allocate rowmax/rowsum storage at the end of the sharedmem address space constexpr uint32_t smem_rowmax_size = B_ROW * ROWMAX_SETS; constexpr uint32_t smem_rowsum_size = B_ROW; constexpr uint32_t smem_O_row_scale_size = B_ROW; - // FIXME: dangerous - smem_cursor = reinterpret_cast(0xff038000); + float *smem_cursor = smem_O1 + smem_O_size; + // // FIXME: dangerous + // smem_cursor = reinterpret_cast(0xff038000); float *smem_rowmax_0 = smem_cursor; smem_cursor += smem_rowmax_size; float *smem_rowmax_1 = smem_cursor; @@ -176,6 +155,21 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { float *smem_scratchpad_1 = smem_cursor; smem_cursor += smem_scratchpad_size; + static_assert(sizeof(elem_t) == sizeof(float)); + constexpr uint32_t spad_addr_factor = DIM * sizeof(elem_t); + constexpr uint32_t spad_addr_Q0 = smem_Q0_offset / spad_addr_factor; + constexpr uint32_t spad_addr_Q1 = smem_Q1_offset / spad_addr_factor; + constexpr uint32_t spad_addr_K0 = smem_K0_offset / spad_addr_factor; + constexpr uint32_t spad_addr_K1 = smem_K1_offset / spad_addr_factor; + constexpr uint32_t spad_addr_V0 = smem_V0_offset / spad_addr_factor; + constexpr uint32_t spad_addr_V1 = smem_V1_offset / spad_addr_factor; + constexpr uint32_t spad_addr_S0 = smem_S0_offset / spad_addr_factor; + constexpr uint32_t spad_addr_S1 = smem_S1_offset / spad_addr_factor; + constexpr uint32_t spad_addr_P0 = smem_P0_offset / spad_addr_factor; + constexpr uint32_t spad_addr_P1 = smem_P1_offset / spad_addr_factor; + constexpr uint32_t spad_addr_O0 = smem_O0_offset / spad_addr_factor; + constexpr uint32_t spad_addr_O1 = smem_O1_offset / spad_addr_factor; + // initialize rowmax/rowsum values in sharedmem thread_block_init_sharedmem(tid_in_warpgroup, threads_per_warpgroup, smem_O0, smem_rowmax_0, smem_rowsum_0, smem_O_row_scale_0); @@ -184,11 +178,6 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { constexpr uint32_t global_barrier_id = NUM_WARPS - 1; // arbitrary - // // delay warpgroup 0 by 1 iteration to do ping-pong scheduling - // if (WARP_SPECIALIZED && warpgroup_id == 1) { - // threadblock_barrier(global_barrier_id, warps_per_threadblock_per_core); - // } - static_assert(!GEMMINI_DMA || Q_IS_K_MAJOR, "DMA code assumes Q matrix is stored K-major"); @@ -207,7 +196,7 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { /*skip_ex=*/1, /*skip_stc=*/0); constexpr uint32_t skips_matmul = loop_matmul_skips(/*skip_lda=*/1, /*skip_ldb=*/1, /*skip_ldd=*/1, - /*skip_ex=*/0, /*skip_stc=*/1); + /*skip_ex=*/0, /*skip_stc=*/0); constexpr uint32_t skips_matmul_preload = loop_matmul_skips(/*skip_lda=*/1, /*skip_ldb=*/1, /*skip_ldd=*/0, /*skip_ex=*/0, /*skip_stc=*/1); @@ -327,9 +316,9 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { const uint32_t k_tiles = (dim_seqlen / B_COL); for (uint32_t tile_k = 0; tile_k < k_tiles + 2 /*pipeline latency*/; tile_k++) { - if constexpr (DEBUG) { + if constexpr (DEBUG || true) { // barrier for debugging - threadblock_barrier(global_barrier_id, warps_per_threadblock_per_core); + // threadblock_barrier(global_barrier_id, warps_per_threadblock_per_core); } // select the correct double buffer by tile iteration @@ -394,6 +383,7 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { // DMA knows the full matrix dimensions // FIXME: perf: prevent GMEM->SMEM load for O tile gemmini_fence(); + gemmini_fence(); sp_tiled_matmul_full_spad_ws( spad_addr_P_consume, spad_addr_V_consume, /*spad_D=*/spad_addr_O, /*spad_C=*/spad_addr_O, @@ -449,11 +439,14 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { } } + // fence GEMM II to make sure dependency on O tile is settled if (tid_in_warpgroup == 0) { - // fence GEMM-II to make sure dependency on O tile is settled + gemmini_fence(); + gemmini_fence(); gemmini_fence(); gemmini_fence(); +#if 1 // mvout to SMEM // GEMMINI_CISC_CMD_I(9); sp_tiled_matmul_full_spad_ws( @@ -463,6 +456,7 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { /*pad_I=*/0, /*pad_J=*/0, /*pad_K=*/0, /*a_transpose=*/0, /*b_transpose=*/0, /*full_C=*/0, /*low_D=*/0, /*acc=*/0, /*act=*/NO_ACTIVATION, /*skips=*/skips_mvout_spad); +#endif } // reconverge from mmio divergence @@ -497,6 +491,9 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { if (tid_in_warpgroup == 0) { gemmini_fence(); + gemmini_fence(); + gemmini_fence(); + // 0,2,.: opcode 0 (quartile 0/2, no accum) // 1,3,.: opcode 3 (quartile 1/3, no accum) const uint32_t opcode = 3 * (tile_k & 1); @@ -574,6 +571,7 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { gemmini_fence(); gemmini_fence(); +#if 1 // mvout to SMEM // GEMMINI_CISC_CMD_I(9); sp_tiled_matmul_full_spad_ws( @@ -584,7 +582,7 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { /*pad_I=*/0, /*pad_J=*/0, /*pad_K=*/0, /*a_transpose=*/0, /*b_transpose=*/0, /*full_C=*/0, /*low_D=*/0, /*acc=*/0, /*act=*/NO_ACTIVATION, /*skips=*/skips_mvout_spad); - +#endif } // reconverge from mmio divergence @@ -668,12 +666,6 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { } asm volatile ("tile_loop_finish_%=:" :: ); - - // // wait for warpgroup 1 to finish, which called the global barrier before - // // entering the loop - // if (warpgroup_id == 0) { - // threadblock_barrier(global_barrier_id, warps_per_threadblock_per_core); - // } } int main() {