From 8dadbdd42d6180eeef14fbefcec48145e6cfe9a8 Mon Sep 17 00:00:00 2001 From: Hansung Kim Date: Tue, 29 Oct 2024 19:43:22 -0700 Subject: [PATCH 1/8] tensor: Do DMA mvin for next m/n loop at the last k iter This increases util by pulling the DMA wait time out of the K-loop wraparound (next N) and overlapping it with the last K iter. --- tests/regression/sgemm_tcore/sgemm_impl.hpp | 62 +++++++++++++++------ 1 file changed, 44 insertions(+), 18 deletions(-) diff --git a/tests/regression/sgemm_tcore/sgemm_impl.hpp b/tests/regression/sgemm_tcore/sgemm_impl.hpp index afc37542..13226a76 100644 --- a/tests/regression/sgemm_tcore/sgemm_impl.hpp +++ b/tests/regression/sgemm_tcore/sgemm_impl.hpp @@ -1150,19 +1150,20 @@ inline void thread_block_gemm(const T *A, const T *B, float *C, if constexpr (GEMMINI_DMA) { // pipeline initiation - if (tid_in_threadblock == 0) { - // configure dma gmem address to load from - ROCC_INSTRUCTION_RS1_RS2( - XCUSTOM_ACC, - (uint64_t)(A + block_m * BM * dim_k + /*block_k:*/0 * BK), - (uint64_t)(B + /*block_k:*/0 * BK * dim_n + block_n * BN), - k_LOOP_WS_CONFIG_ADDRS_AB) - // GEMMINI_CISC(8) does k_LOOP_WS_CONFIG_STRIDES_AB - GEMMINI_CISC_CMD_R((dim_n << 20) | (dim_k << 8) | 8); - gemmini_fence(); + if (block_m == 0 && block_n == 0) { + if (tid_in_threadblock == 0) { + // configure dma gmem address to load from + ROCC_INSTRUCTION_RS1_RS2( + XCUSTOM_ACC, + (uint64_t)(A + block_m * BM * dim_k + /*block_k:*/ 0 * BK), + (uint64_t)(B + /*block_k:*/ 0 * BK * dim_n + block_n * BN), + k_LOOP_WS_CONFIG_ADDRS_AB) + // GEMMINI_CISC(8) does k_LOOP_WS_CONFIG_STRIDES_AB + GEMMINI_CISC_CMD_R((dim_n << 20) | (dim_k << 8) | 8); + gemmini_fence(); - GEMMINI_CISC_CMD_I(10); - gemmini_fence(); + GEMMINI_CISC_CMD_I(10); + gemmini_fence(); #if 0 // sp_tiled_matmul_full_spad_ws includes CONFIG_BOUNDS @@ -1181,10 +1182,11 @@ inline void thread_block_gemm(const T *A, const T *B, float *C, /*acc=*/0, /*act=*/NO_ACTIVATION, /*skips=*/skips) gemmini_fence(); #endif - } + } - threadblock_barrier(threadblock_id_in_cluster, - warps_per_threadblock_per_core); + threadblock_barrier(threadblock_id_in_cluster, + warps_per_threadblock_per_core); + } } #pragma GCC unroll 1 @@ -1197,12 +1199,27 @@ inline void thread_block_gemm(const T *A, const T *B, float *C, // this is either done using DMA or SIMT cores depending on GEMMINI_DMA #if (GEMMINI_DMA == 1) - if ((tid_in_threadblock == 0) && ((block_k * BK) != (dim_k - BK))) { + if (tid_in_threadblock == 0) { + asm volatile("next_index_start_%=:" ::); + + const uint32_t next_block_k = + ((block_k + 1) * BK == dim_k) ? 0 : block_k + 1; + const uint32_t next_block_n = + (next_block_k == 0) + ? (((block_n + 1) * BN == dim_n) ? 0 : block_n + 1) + : block_n; + const uint32_t next_block_m = + (next_block_n == 0) + ? ((block_m == block_m_end) ? 0 : block_n + 1) + : block_m; + + asm volatile("next_index_end_%=:" ::); + // configure dma gmem address to load from ROCC_INSTRUCTION_RS1_RS2( XCUSTOM_ACC, - (uint64_t)(A + block_m * BM * dim_k + (block_k + 1/*runahead*/) * BK), - (uint64_t)(B + (block_k + 1/*runahead*/) * BK * dim_n + block_n * BN), + (uint64_t)(A + next_block_m * BM * dim_k + next_block_k * BK), + (uint64_t)(B + next_block_k * BK * dim_n + next_block_n * BN), k_LOOP_WS_CONFIG_ADDRS_AB) // GEMMINI_CISC(8) does k_LOOP_WS_CONFIG_STRIDES_AB GEMMINI_CISC_CMD_R((dim_n << 20) | (dim_k << 8) | 8); @@ -1210,6 +1227,11 @@ inline void thread_block_gemm(const T *A, const T *B, float *C, // block_k is even: opcode 11 (write to local_a_buf) // block_k is odd: opcode 10 (write to local_a) + // + // FIXME: This depends on (dim_k / BK) being an even number, since + // the last iteration of the k-loop is prefetching for the first + // iteration of the n-loop. The ping-poing indexing has to match for + // the two loop end to connect. const uint32_t opcode = 11 - (block_k & 1); GEMMINI_CISC_CMD_I(opcode); // // TODO: branch is probably slow @@ -1349,6 +1371,8 @@ inline void thread_block_gemm(const T *A, const T *B, float *C, } if constexpr (write_to_gmem) { + asm volatile("move_out_start_%=:" ::); + if constexpr (TENSOR_HOPPER) { // wait until all results are accumulated into the RF vx_wgmma_wait(); @@ -1367,6 +1391,8 @@ inline void thread_block_gemm(const T *A, const T *B, float *C, } } } + + asm volatile("move_out_end_%=:" ::); } } asm volatile("loop_mn_end_%=:" ::); From 6b39a6fe703ff9ba3759d608000c8abb204e6571 Mon Sep 17 00:00:00 2001 From: Hansung Kim Date: Tue, 29 Oct 2024 20:14:33 -0700 Subject: [PATCH 2/8] Add convenience script for switching input/args binaries --- .../sgemm_tcore/switch_args_input.sh | 51 +++++++++++++++++++ 1 file changed, 51 insertions(+) create mode 100755 tests/regression/sgemm_tcore/switch_args_input.sh diff --git a/tests/regression/sgemm_tcore/switch_args_input.sh b/tests/regression/sgemm_tcore/switch_args_input.sh new file mode 100755 index 00000000..392794e4 --- /dev/null +++ b/tests/regression/sgemm_tcore/switch_args_input.sh @@ -0,0 +1,51 @@ +#!/bin/sh +# +# Updates symlink to args.bin, input.a.bin, input.b.bin to point to the right +# binary according to the dimension size given as the argument. + +if [ "$#" != "2" ]; then + echo "usage: $0 DIMENSION 1|0" + echo "second argument indicates using DMA or not." + exit 1 +fi + +dim="$1" +dma="$2" +if [ "$2" == "1" ]; then + layout_a="row.swizzle_fp16" + layout_b="row" +else + layout_a="col.swizzle_fp16" + layout_b="row.swizzle_fp16" +fi + +check_exists() { + if ! [ -f "$1" ]; then + echo "error: looked for file $1 that does not exist." + exit 1 + fi +} + +args="args.m$1n$1k$1.bin" +input_a="input.a.rand01.fp16.m$1n$1k$1.$layout_a.bin" +input_b="input.b.rand01.fp16.m$1n$1k$1.$layout_b.bin" +check_exists "$args" +check_exists "$input_a" +check_exists "$input_b" + +echo "will symlink:" +echo "args.bin -> $args" +echo "input.a.bin -> $input_a" +echo "input.b.bin -> $input_b" +echo "continue? (Y/N)" +read -r -s -n 1 answer +if [ "$answer" != "Y" ] && [ "$answer" != "y" ]; then + echo "exiting..." + exit 1 +fi + +ln -sf -v "$args" "args.bin" +ln -sf -v "$input_a" "input.a.bin" +ln -sf -v "$input_b" "input.b.bin" + +echo "done." From 21b6655c101e18c4fdb20f10331be4946b6530ef Mon Sep 17 00:00:00 2001 From: Hansung Kim Date: Tue, 29 Oct 2024 22:34:22 -0700 Subject: [PATCH 3/8] sgemm_impl: Implement fast coalesced wmma_store Enables a fairer comparison between core-coupled tensor core to Hopper tensor core, where the latter benefits from coalesced full-throughput moveout to GMEM because it does not use the 1x2 interleaved register mapping. This means the result matrix will be stored swizzled in the GMEM, without breaking correctness. --- tests/regression/sgemm_tcore/sgemm_impl.hpp | 69 ++++++++++++++------- 1 file changed, 48 insertions(+), 21 deletions(-) diff --git a/tests/regression/sgemm_tcore/sgemm_impl.hpp b/tests/regression/sgemm_tcore/sgemm_impl.hpp index 13226a76..5bd694dd 100644 --- a/tests/regression/sgemm_tcore/sgemm_impl.hpp +++ b/tests/regression/sgemm_tcore/sgemm_impl.hpp @@ -104,7 +104,13 @@ static_assert(WMITER * WNITER * TCM * TCN * NUM_WARPS * CORES_PER_CLUSTER == #define TRANSPOSE_AT_PRODUCE 0 #define TRANSPOSE_AT_CONSUME 0 -#define GEMMINI_DMA 1 +// if 1, wmma_store() will not respect the register <-> matrix fragment mapping +// scheme and instead do a fast coalesced GMEM writes for move out. This +// doesn't necessarily mean breaking correctness; it means that the final +// result matrix will be stored in a swizzled form in the global memory. +#define WMMA_STORE_FAST 1 + +#define GEMMINI_DMA 0 #define GEMMINI_DMA_FAST 1 #if SMEM_SIZE == 0x4000 #define SMEM_ADDR_Q0 ((float * const) 0xff000000) @@ -213,10 +219,9 @@ inline constexpr void map_c_8lanes(const int tid, int &row, int &col) { col += ((tid % 4) / 2) * 2; } -inline constexpr void map_c_8lanes_hopper(const int tid, int &row, int &col) { +inline constexpr void map_c_8lanes_coalesced(const int tid, int &row, int &col) { const int tg = tid / 2; - // FIXME wrong!!! row = 0; col = tid; } @@ -225,8 +230,8 @@ inline constexpr void map_c(const int tid, int &row, int &col) { if constexpr (NUM_THREADS == 32) { map_c_32lanes(tid, row, col); } else if constexpr (NUM_THREADS == 8) { - if constexpr (TENSOR_HOPPER) { - map_c_8lanes_hopper(tid, row, col); + if constexpr (TENSOR_HOPPER || WMMA_STORE_FAST) { + map_c_8lanes_coalesced(tid, row, col); } else { map_c_8lanes(tid, row, col); } @@ -664,26 +669,48 @@ wmma_store(const int thread_in_warp, const int warp_col, const int warp_row, volatile uint8_t *addr = reinterpret_cast( &write_addr[dim_n * (local_row + 0) + (local_col + 0)]); volatile uint8_t *addr_tworow = addr + (2 * dim_n) * sizeof(float); - asm volatile("fsw f16, %0(%1)" ::"i"(0 * sizeof(float)), "r"(addr)); - asm volatile("fsw f17, %0(%1)" ::"i"(1 * sizeof(float)), "r"(addr)); - asm volatile("fsw f18, %0(%1)" ::"i"(0 * sizeof(float)), "r"(addr_tworow)); - asm volatile("fsw f19, %0(%1)" ::"i"(1 * sizeof(float)), "r"(addr_tworow)); - asm volatile("fsw f20, %0(%1)" ::"i"(4 * sizeof(float)), "r"(addr)); - asm volatile("fsw f21, %0(%1)" ::"i"(5 * sizeof(float)), "r"(addr)); - asm volatile("fsw f22, %0(%1)" ::"i"(4 * sizeof(float)), "r"(addr_tworow)); - asm volatile("fsw f23, %0(%1)" ::"i"(5 * sizeof(float)), "r"(addr_tworow)); + if constexpr (!WMMA_STORE_FAST) { + asm volatile("fsw f16, %0(%1)" ::"i"(0 * sizeof(float)), "r"(addr)); + asm volatile("fsw f17, %0(%1)" ::"i"(1 * sizeof(float)), "r"(addr)); + asm volatile("fsw f18, %0(%1)" ::"i"(0 * sizeof(float)), "r"(addr_tworow)); + asm volatile("fsw f19, %0(%1)" ::"i"(1 * sizeof(float)), "r"(addr_tworow)); + asm volatile("fsw f20, %0(%1)" ::"i"(4 * sizeof(float)), "r"(addr)); + asm volatile("fsw f21, %0(%1)" ::"i"(5 * sizeof(float)), "r"(addr)); + asm volatile("fsw f22, %0(%1)" ::"i"(4 * sizeof(float)), "r"(addr_tworow)); + asm volatile("fsw f23, %0(%1)" ::"i"(5 * sizeof(float)), "r"(addr_tworow)); + } else { + asm volatile("fsw f16, %0(%1)" ::"i"(0 * WN * sizeof(float)), "r"(addr)); + asm volatile("fsw f17, %0(%1)" ::"i"(1 * WN * sizeof(float)), "r"(addr)); + asm volatile("fsw f18, %0(%1)" ::"i"(2 * WN * sizeof(float)), "r"(addr)); + asm volatile("fsw f19, %0(%1)" ::"i"(3 * WN * sizeof(float)), "r"(addr)); + asm volatile("fsw f20, %0(%1)" ::"i"(4 * WN * sizeof(float)), "r"(addr)); + asm volatile("fsw f21, %0(%1)" ::"i"(5 * WN * sizeof(float)), "r"(addr)); + asm volatile("fsw f22, %0(%1)" ::"i"(6 * WN * sizeof(float)), "r"(addr)); + asm volatile("fsw f23, %0(%1)" ::"i"(7 * WN * sizeof(float)), "r"(addr)); + } } else { volatile uint8_t *addr = reinterpret_cast( &write_addr[dim_n * (local_row + 0) + (local_col + 0)]); volatile uint8_t *addr_tworow = addr + (2 * dim_n) * sizeof(float); - asm volatile("fsw f24, %0(%1)" ::"i"(0 * sizeof(float)), "r"(addr)); - asm volatile("fsw f25, %0(%1)" ::"i"(1 * sizeof(float)), "r"(addr)); - asm volatile("fsw f26, %0(%1)" ::"i"(0 * sizeof(float)), "r"(addr_tworow)); - asm volatile("fsw f27, %0(%1)" ::"i"(1 * sizeof(float)), "r"(addr_tworow)); - asm volatile("fsw f28, %0(%1)" ::"i"(4 * sizeof(float)), "r"(addr)); - asm volatile("fsw f29, %0(%1)" ::"i"(5 * sizeof(float)), "r"(addr)); - asm volatile("fsw f30, %0(%1)" ::"i"(4 * sizeof(float)), "r"(addr_tworow)); - asm volatile("fsw f31, %0(%1)" ::"i"(5 * sizeof(float)), "r"(addr_tworow)); + if constexpr (!WMMA_STORE_FAST) { + asm volatile("fsw f24, %0(%1)" ::"i"(0 * sizeof(float)), "r"(addr)); + asm volatile("fsw f25, %0(%1)" ::"i"(1 * sizeof(float)), "r"(addr)); + asm volatile("fsw f26, %0(%1)" ::"i"(2 * sizeof(float)), "r"(addr_tworow)); + asm volatile("fsw f27, %0(%1)" ::"i"(3 * sizeof(float)), "r"(addr_tworow)); + asm volatile("fsw f28, %0(%1)" ::"i"(4 * sizeof(float)), "r"(addr)); + asm volatile("fsw f29, %0(%1)" ::"i"(5 * sizeof(float)), "r"(addr)); + asm volatile("fsw f30, %0(%1)" ::"i"(6 * sizeof(float)), "r"(addr_tworow)); + asm volatile("fsw f31, %0(%1)" ::"i"(7 * sizeof(float)), "r"(addr_tworow)); + } else { + asm volatile("fsw f24, %0(%1)" ::"i"(0 * WN * sizeof(float)), "r"(addr)); + asm volatile("fsw f25, %0(%1)" ::"i"(1 * WN * sizeof(float)), "r"(addr)); + asm volatile("fsw f26, %0(%1)" ::"i"(2 * WN * sizeof(float)), "r"(addr)); + asm volatile("fsw f27, %0(%1)" ::"i"(3 * WN * sizeof(float)), "r"(addr)); + asm volatile("fsw f28, %0(%1)" ::"i"(4 * WN * sizeof(float)), "r"(addr)); + asm volatile("fsw f29, %0(%1)" ::"i"(5 * WN * sizeof(float)), "r"(addr)); + asm volatile("fsw f30, %0(%1)" ::"i"(6 * WN * sizeof(float)), "r"(addr)); + asm volatile("fsw f31, %0(%1)" ::"i"(7 * WN * sizeof(float)), "r"(addr)); + } } asm volatile ("wmma_store_finish_%=:" :: ); From c001618fb91adf6debcab52dd1cee984a723541b Mon Sep 17 00:00:00 2001 From: Hansung Kim Date: Tue, 29 Oct 2024 22:35:56 -0700 Subject: [PATCH 4/8] sgemm_impl: Fix wrong next block_m logic for DMA --- tests/regression/sgemm_tcore/sgemm_impl.hpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/regression/sgemm_tcore/sgemm_impl.hpp b/tests/regression/sgemm_tcore/sgemm_impl.hpp index 5bd694dd..aaf66492 100644 --- a/tests/regression/sgemm_tcore/sgemm_impl.hpp +++ b/tests/regression/sgemm_tcore/sgemm_impl.hpp @@ -1237,7 +1237,8 @@ inline void thread_block_gemm(const T *A, const T *B, float *C, : block_n; const uint32_t next_block_m = (next_block_n == 0) - ? ((block_m == block_m_end) ? 0 : block_n + 1) + ? (((block_m + 1) == block_m_end) ? block_m_start /*unused*/ + : block_m + 1) : block_m; asm volatile("next_index_end_%=:" ::); From 405525501822c069bcc56155789c3b95fab853bd Mon Sep 17 00:00:00 2001 From: Hansung Kim Date: Fri, 8 Nov 2024 16:40:16 -0800 Subject: [PATCH 5/8] flash: Fix tcore kernel for CISC arg field changes --- tests/regression/flash_attention/kernel.cpp | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/tests/regression/flash_attention/kernel.cpp b/tests/regression/flash_attention/kernel.cpp index 3c2d463c..17fcf91f 100644 --- a/tests/regression/flash_attention/kernel.cpp +++ b/tests/regression/flash_attention/kernel.cpp @@ -259,7 +259,7 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { (uint64_t)(gmem_K_tile), k_LOOP_WS_CONFIG_ADDRS_AB) // configure address strides for the DMA - GEMMINI_CISC_CMD_R((dim_seqlen << 16) | (HEADDIM << 8) | + GEMMINI_CISC_CMD_R((dim_seqlen << 20) | (HEADDIM << 8) | 8 /*k_LOOP_WS_CONFIG_STRIDES_AB*/); gemmini_fence(); @@ -549,7 +549,7 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { k_LOOP_WS_CONFIG_ADDRS_AB) // configure address strides for the DMA // FIXME: unnecessary? - GEMMINI_CISC_CMD_R((HEADDIM /*V*/ << 16) | (dim_seqlen /*KT*/ << 8) | + GEMMINI_CISC_CMD_R((HEADDIM /*V*/ << 20) | (dim_seqlen /*KT*/ << 8) | 8 /*k_LOOP_WS_CONFIG_STRIDES_AB*/); gemmini_fence(); @@ -813,8 +813,6 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { warps_per_warpgroup_per_core); } } -#if 0 -#endif } asm volatile ("tile_loop_finish_%=:" :: ); From 4e087a8aab3c3ba8fd51f359c9a24f462352f097 Mon Sep 17 00:00:00 2001 From: Hansung Kim Date: Fri, 8 Nov 2024 16:43:08 -0800 Subject: [PATCH 6/8] flash: Fix loop iteration for gemmini Kernel is software-pipelined around 2 GEMMs and softmax; it requires two iterations to fully complete a tile. --- tests/regression/flash_attention/kernel.gemmini.cpp | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/tests/regression/flash_attention/kernel.gemmini.cpp b/tests/regression/flash_attention/kernel.gemmini.cpp index ac3788d4..089f0a3f 100644 --- a/tests/regression/flash_attention/kernel.gemmini.cpp +++ b/tests/regression/flash_attention/kernel.gemmini.cpp @@ -336,7 +336,11 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { // "inner loop" along the columns of K^T const uint32_t k_tiles = (dim_seqlen / B_COL); for (uint32_t tile_k = 0; - tile_k < (4 /*for perf measurement*/ * k_tiles) + 2 /*pipeline latency*/; + tile_k < (4 /*for perf measurement*/ * + // virgo kernel is fully pipelined around (2 GEMMs | softmax); + // requires two loop iterations to finish one tile compute + (2 * k_tiles)) + + 2 /*pipeline latency*/; tile_k++) { if constexpr (DEBUG || true) { threadblock_barrier(global_barrier_id, warps_per_threadblock_per_core); From 1e3d476e70d8e521efe7082fcd73fdc52a4cfaa8 Mon Sep 17 00:00:00 2001 From: Hansung Kim Date: Fri, 8 Nov 2024 21:56:26 -0800 Subject: [PATCH 7/8] Switch header configs to flash --- kernel/include/gemmini_mmio.h | 4 ++-- tests/regression/flash_attention/flash_impl.hpp | 6 ++++-- tests/regression/sgemm_tcore/sgemm_impl.hpp | 6 +++--- 3 files changed, 9 insertions(+), 7 deletions(-) diff --git a/kernel/include/gemmini_mmio.h b/kernel/include/gemmini_mmio.h index ed55236c..0dec66b0 100644 --- a/kernel/include/gemmini_mmio.h +++ b/kernel/include/gemmini_mmio.h @@ -9,9 +9,9 @@ // #define SMEM_SIZE 0x4000 // 64KB // #define SMEM_SIZE 0x10000 -// 128KB +// 128KB (FP16 GEMM) // #define SMEM_SIZE 0x20000 -// 256KB +// 256KB (FlashAttention) #define SMEM_SIZE 0x40000 #define SMEM_MASK (SMEM_SIZE - 1) diff --git a/tests/regression/flash_attention/flash_impl.hpp b/tests/regression/flash_attention/flash_impl.hpp index 47e21c70..2e2bd693 100644 --- a/tests/regression/flash_attention/flash_impl.hpp +++ b/tests/regression/flash_attention/flash_impl.hpp @@ -11,8 +11,10 @@ #define ROW_REMAINDER_LOGIC constexpr uint32_t ROWMAX_SETS = 3; -constexpr bool WARP_SPECIALIZED = true; -constexpr bool TENSOR_CORE = true; +// constexpr bool WARP_SPECIALIZED = true; +// constexpr bool TENSOR_CORE = true; +constexpr bool WARP_SPECIALIZED = false; +constexpr bool TENSOR_CORE = false; // temporary safety stop for wrong configs static_assert(NUM_CORES == 4); diff --git a/tests/regression/sgemm_tcore/sgemm_impl.hpp b/tests/regression/sgemm_tcore/sgemm_impl.hpp index aaf66492..989c5df9 100644 --- a/tests/regression/sgemm_tcore/sgemm_impl.hpp +++ b/tests/regression/sgemm_tcore/sgemm_impl.hpp @@ -6,7 +6,7 @@ #include "include/gemmini.h" #include "gemmini_mmio.h" -#define FP_SIZE 16 +#define FP_SIZE 32 // "fake" fp16 type that only has the correct data width. using float16_t = uint16_t; @@ -19,7 +19,7 @@ using float_type = float16_t; // Generate kernel for the Hopper-style SMEM-decoupled tensor core. This uses // asynchronous HGMMA and HGMMA_WAIT instructions. -#define TENSOR_HOPPER 1 +#define TENSOR_HOPPER 0 // Constraints on parameters: // * Memory: @@ -110,7 +110,7 @@ static_assert(WMITER * WNITER * TCM * TCN * NUM_WARPS * CORES_PER_CLUSTER == // result matrix will be stored in a swizzled form in the global memory. #define WMMA_STORE_FAST 1 -#define GEMMINI_DMA 0 +#define GEMMINI_DMA 1 #define GEMMINI_DMA_FAST 1 #if SMEM_SIZE == 0x4000 #define SMEM_ADDR_Q0 ((float * const) 0xff000000) From 365b1d8e67628f70e9cc764802446508b940187b Mon Sep 17 00:00:00 2001 From: Hansung Kim Date: Sat, 9 Nov 2024 10:16:40 -0800 Subject: [PATCH 8/8] flash: Add begin end markers --- tests/regression/flash_attention/flash_impl.hpp | 3 +++ tests/regression/flash_attention/kernel.cpp | 4 ++++ tests/regression/flash_attention/kernel.gemmini.cpp | 4 ++++ 3 files changed, 11 insertions(+) diff --git a/tests/regression/flash_attention/flash_impl.hpp b/tests/regression/flash_attention/flash_impl.hpp index 2e2bd693..ec5e6f9c 100644 --- a/tests/regression/flash_attention/flash_impl.hpp +++ b/tests/regression/flash_attention/flash_impl.hpp @@ -4,6 +4,9 @@ #include #include +#define MARK_BEG() asm volatile ("slti x0, x1, -1047") +#define MARK_END() asm volatile ("slti x0, x1, -499") + #define B_ROW 64 #define B_COL 64 #define HEADDIM 64 diff --git a/tests/regression/flash_attention/kernel.cpp b/tests/regression/flash_attention/kernel.cpp index 17fcf91f..c3298a61 100644 --- a/tests/regression/flash_attention/kernel.cpp +++ b/tests/regression/flash_attention/kernel.cpp @@ -219,6 +219,8 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { loop_matmul_skips(/*skip_lda=*/0, /*skip_ldb=*/0, /*skip_ldd=*/1, /*skip_ex=*/1, /*skip_stc=*/1); + MARK_BEG(); + if constexpr (GEMMINI_DMA) { if (tid_in_warpgroup == 0) { gemmini_extended_config_ex(WEIGHT_STATIONARY, 0, 0, 1, 0, 0); @@ -822,6 +824,8 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { if (warpgroup_id == 0) { threadblock_barrier(global_barrier_id, warps_per_threadblock_per_core); } + + MARK_END(); } int main() { diff --git a/tests/regression/flash_attention/kernel.gemmini.cpp b/tests/regression/flash_attention/kernel.gemmini.cpp index 089f0a3f..090233cb 100644 --- a/tests/regression/flash_attention/kernel.gemmini.cpp +++ b/tests/regression/flash_attention/kernel.gemmini.cpp @@ -212,6 +212,8 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { loop_matmul_skips(/*skip_lda=*/1, /*skip_ldb=*/1, /*skip_ldd=*/0, /*skip_ex=*/0, /*skip_stc=*/1); + MARK_BEG(); + if (tid_in_warpgroup == 0) { gemmini_extended_config_ex(WEIGHT_STATIONARY, 0, 0, 1, 0, 0); @@ -681,6 +683,8 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { } asm volatile ("tile_loop_finish_%=:" :: ); + + MARK_END(); } int main() {