diff --git a/kernel/include/gemmini_mmio.h b/kernel/include/gemmini_mmio.h index 894d5fc6..6043e0be 100644 --- a/kernel/include/gemmini_mmio.h +++ b/kernel/include/gemmini_mmio.h @@ -9,10 +9,10 @@ // #define SMEM_SIZE 0x4000 // 64KB // #define SMEM_SIZE 0x10000 -// 128KB +// 128KB (FP16 GEMM) // #define SMEM_SIZE 0x20000 -// 256KB -#define SMEM_SIZE 0x20000 +// 256KB (FlashAttention) +#define SMEM_SIZE 0x40000 #define SMEM_MASK (SMEM_SIZE - 1) #define SMEM_ADDR_END (SMEM_BASE + SMEM_SIZE) diff --git a/tests/regression/flash_attention/flash_impl.hpp b/tests/regression/flash_attention/flash_impl.hpp index 47e21c70..ec5e6f9c 100644 --- a/tests/regression/flash_attention/flash_impl.hpp +++ b/tests/regression/flash_attention/flash_impl.hpp @@ -4,6 +4,9 @@ #include #include +#define MARK_BEG() asm volatile ("slti x0, x1, -1047") +#define MARK_END() asm volatile ("slti x0, x1, -499") + #define B_ROW 64 #define B_COL 64 #define HEADDIM 64 @@ -11,8 +14,10 @@ #define ROW_REMAINDER_LOGIC constexpr uint32_t ROWMAX_SETS = 3; -constexpr bool WARP_SPECIALIZED = true; -constexpr bool TENSOR_CORE = true; +// constexpr bool WARP_SPECIALIZED = true; +// constexpr bool TENSOR_CORE = true; +constexpr bool WARP_SPECIALIZED = false; +constexpr bool TENSOR_CORE = false; // temporary safety stop for wrong configs static_assert(NUM_CORES == 4); diff --git a/tests/regression/flash_attention/kernel.cpp b/tests/regression/flash_attention/kernel.cpp index 3c2d463c..c3298a61 100644 --- a/tests/regression/flash_attention/kernel.cpp +++ b/tests/regression/flash_attention/kernel.cpp @@ -219,6 +219,8 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { loop_matmul_skips(/*skip_lda=*/0, /*skip_ldb=*/0, /*skip_ldd=*/1, /*skip_ex=*/1, /*skip_stc=*/1); + MARK_BEG(); + if constexpr (GEMMINI_DMA) { if (tid_in_warpgroup == 0) { gemmini_extended_config_ex(WEIGHT_STATIONARY, 0, 0, 1, 0, 0); @@ -259,7 +261,7 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { 
(uint64_t)(gmem_K_tile), k_LOOP_WS_CONFIG_ADDRS_AB) // configure address strides for the DMA - GEMMINI_CISC_CMD_R((dim_seqlen << 16) | (HEADDIM << 8) | + GEMMINI_CISC_CMD_R((dim_seqlen << 20) | (HEADDIM << 8) | 8 /*k_LOOP_WS_CONFIG_STRIDES_AB*/); gemmini_fence(); @@ -549,7 +551,7 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { k_LOOP_WS_CONFIG_ADDRS_AB) // configure address strides for the DMA // FIXME: unnecessary? - GEMMINI_CISC_CMD_R((HEADDIM /*V*/ << 16) | (dim_seqlen /*KT*/ << 8) | + GEMMINI_CISC_CMD_R((HEADDIM /*V*/ << 20) | (dim_seqlen /*KT*/ << 8) | 8 /*k_LOOP_WS_CONFIG_STRIDES_AB*/); gemmini_fence(); @@ -813,8 +815,6 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { warps_per_warpgroup_per_core); } } -#if 0 -#endif } asm volatile ("tile_loop_finish_%=:" :: ); @@ -824,6 +824,8 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { if (warpgroup_id == 0) { threadblock_barrier(global_barrier_id, warps_per_threadblock_per_core); } + + MARK_END(); } int main() { diff --git a/tests/regression/flash_attention/kernel.gemmini.cpp b/tests/regression/flash_attention/kernel.gemmini.cpp index ac3788d4..090233cb 100644 --- a/tests/regression/flash_attention/kernel.gemmini.cpp +++ b/tests/regression/flash_attention/kernel.gemmini.cpp @@ -212,6 +212,8 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { loop_matmul_skips(/*skip_lda=*/1, /*skip_ldb=*/1, /*skip_ldd=*/0, /*skip_ex=*/0, /*skip_stc=*/1); + MARK_BEG(); + if (tid_in_warpgroup == 0) { gemmini_extended_config_ex(WEIGHT_STATIONARY, 0, 0, 1, 0, 0); @@ -336,7 +338,11 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { // "inner loop" along the columns of K^T const uint32_t k_tiles = (dim_seqlen / B_COL); for (uint32_t tile_k = 0; - tile_k < (4 /*for perf measurement*/ * k_tiles) + 2 /*pipeline latency*/; + tile_k < (4 /*for perf measurement*/ * + // virgo kernel is fully pipelined around (2 GEMMs | softmax); + // requires two loop iterations to 
finish one tile compute + (2 * k_tiles)) + + 2 /*pipeline latency*/; tile_k++) { if constexpr (DEBUG || true) { threadblock_barrier(global_barrier_id, warps_per_threadblock_per_core); @@ -677,6 +683,8 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { } asm volatile ("tile_loop_finish_%=:" :: ); + + MARK_END(); } int main() { diff --git a/tests/regression/sgemm_tcore/sgemm_impl.hpp b/tests/regression/sgemm_tcore/sgemm_impl.hpp index afc37542..989c5df9 100644 --- a/tests/regression/sgemm_tcore/sgemm_impl.hpp +++ b/tests/regression/sgemm_tcore/sgemm_impl.hpp @@ -6,7 +6,7 @@ #include "include/gemmini.h" #include "gemmini_mmio.h" -#define FP_SIZE 16 +#define FP_SIZE 32 // "fake" fp16 type that only has the correct data width. using float16_t = uint16_t; @@ -19,7 +19,7 @@ using float_type = float16_t; // Generate kernel for the Hopper-style SMEM-decoupled tensor core. This uses // asynchronous HGMMA and HGMMA_WAIT instructions. -#define TENSOR_HOPPER 1 +#define TENSOR_HOPPER 0 // Constraints on parameters: // * Memory: @@ -104,6 +104,12 @@ static_assert(WMITER * WNITER * TCM * TCN * NUM_WARPS * CORES_PER_CLUSTER == #define TRANSPOSE_AT_PRODUCE 0 #define TRANSPOSE_AT_CONSUME 0 +// if 1, wmma_store() will not respect the register <-> matrix fragment mapping +// scheme and instead do fast coalesced GMEM writes for move out. This +// doesn't necessarily mean breaking correctness; it means that the final +// result matrix will be stored in a swizzled form in the global memory. +#define WMMA_STORE_FAST 1 + #define GEMMINI_DMA 1 #define GEMMINI_DMA_FAST 1 #if SMEM_SIZE == 0x4000 @@ -213,10 +219,9 @@ inline constexpr void map_c_8lanes(const int tid, int &row, int &col) { col += ((tid % 4) / 2) * 2; } -inline constexpr void map_c_8lanes_hopper(const int tid, int &row, int &col) { +inline constexpr void map_c_8lanes_coalesced(const int tid, int &row, int &col) { const int tg = tid / 2; - // FIXME wrong!!!
row = 0; col = tid; } @@ -225,8 +230,8 @@ inline constexpr void map_c(const int tid, int &row, int &col) { if constexpr (NUM_THREADS == 32) { map_c_32lanes(tid, row, col); } else if constexpr (NUM_THREADS == 8) { - if constexpr (TENSOR_HOPPER) { - map_c_8lanes_hopper(tid, row, col); + if constexpr (TENSOR_HOPPER || WMMA_STORE_FAST) { + map_c_8lanes_coalesced(tid, row, col); } else { map_c_8lanes(tid, row, col); } @@ -664,26 +669,48 @@ wmma_store(const int thread_in_warp, const int warp_col, const int warp_row, volatile uint8_t *addr = reinterpret_cast( &write_addr[dim_n * (local_row + 0) + (local_col + 0)]); volatile uint8_t *addr_tworow = addr + (2 * dim_n) * sizeof(float); - asm volatile("fsw f16, %0(%1)" ::"i"(0 * sizeof(float)), "r"(addr)); - asm volatile("fsw f17, %0(%1)" ::"i"(1 * sizeof(float)), "r"(addr)); - asm volatile("fsw f18, %0(%1)" ::"i"(0 * sizeof(float)), "r"(addr_tworow)); - asm volatile("fsw f19, %0(%1)" ::"i"(1 * sizeof(float)), "r"(addr_tworow)); - asm volatile("fsw f20, %0(%1)" ::"i"(4 * sizeof(float)), "r"(addr)); - asm volatile("fsw f21, %0(%1)" ::"i"(5 * sizeof(float)), "r"(addr)); - asm volatile("fsw f22, %0(%1)" ::"i"(4 * sizeof(float)), "r"(addr_tworow)); - asm volatile("fsw f23, %0(%1)" ::"i"(5 * sizeof(float)), "r"(addr_tworow)); + if constexpr (!WMMA_STORE_FAST) { + asm volatile("fsw f16, %0(%1)" ::"i"(0 * sizeof(float)), "r"(addr)); + asm volatile("fsw f17, %0(%1)" ::"i"(1 * sizeof(float)), "r"(addr)); + asm volatile("fsw f18, %0(%1)" ::"i"(0 * sizeof(float)), "r"(addr_tworow)); + asm volatile("fsw f19, %0(%1)" ::"i"(1 * sizeof(float)), "r"(addr_tworow)); + asm volatile("fsw f20, %0(%1)" ::"i"(4 * sizeof(float)), "r"(addr)); + asm volatile("fsw f21, %0(%1)" ::"i"(5 * sizeof(float)), "r"(addr)); + asm volatile("fsw f22, %0(%1)" ::"i"(4 * sizeof(float)), "r"(addr_tworow)); + asm volatile("fsw f23, %0(%1)" ::"i"(5 * sizeof(float)), "r"(addr_tworow)); + } else { + asm volatile("fsw f16, %0(%1)" ::"i"(0 * WN * sizeof(float)), "r"(addr)); 
+ asm volatile("fsw f17, %0(%1)" ::"i"(1 * WN * sizeof(float)), "r"(addr)); + asm volatile("fsw f18, %0(%1)" ::"i"(2 * WN * sizeof(float)), "r"(addr)); + asm volatile("fsw f19, %0(%1)" ::"i"(3 * WN * sizeof(float)), "r"(addr)); + asm volatile("fsw f20, %0(%1)" ::"i"(4 * WN * sizeof(float)), "r"(addr)); + asm volatile("fsw f21, %0(%1)" ::"i"(5 * WN * sizeof(float)), "r"(addr)); + asm volatile("fsw f22, %0(%1)" ::"i"(6 * WN * sizeof(float)), "r"(addr)); + asm volatile("fsw f23, %0(%1)" ::"i"(7 * WN * sizeof(float)), "r"(addr)); + } } else { volatile uint8_t *addr = reinterpret_cast( &write_addr[dim_n * (local_row + 0) + (local_col + 0)]); volatile uint8_t *addr_tworow = addr + (2 * dim_n) * sizeof(float); - asm volatile("fsw f24, %0(%1)" ::"i"(0 * sizeof(float)), "r"(addr)); - asm volatile("fsw f25, %0(%1)" ::"i"(1 * sizeof(float)), "r"(addr)); - asm volatile("fsw f26, %0(%1)" ::"i"(0 * sizeof(float)), "r"(addr_tworow)); - asm volatile("fsw f27, %0(%1)" ::"i"(1 * sizeof(float)), "r"(addr_tworow)); - asm volatile("fsw f28, %0(%1)" ::"i"(4 * sizeof(float)), "r"(addr)); - asm volatile("fsw f29, %0(%1)" ::"i"(5 * sizeof(float)), "r"(addr)); - asm volatile("fsw f30, %0(%1)" ::"i"(4 * sizeof(float)), "r"(addr_tworow)); - asm volatile("fsw f31, %0(%1)" ::"i"(5 * sizeof(float)), "r"(addr_tworow)); + if constexpr (!WMMA_STORE_FAST) { + asm volatile("fsw f24, %0(%1)" ::"i"(0 * sizeof(float)), "r"(addr)); + asm volatile("fsw f25, %0(%1)" ::"i"(1 * sizeof(float)), "r"(addr)); + asm volatile("fsw f26, %0(%1)" ::"i"(0 * sizeof(float)), "r"(addr_tworow)); + asm volatile("fsw f27, %0(%1)" ::"i"(1 * sizeof(float)), "r"(addr_tworow)); + asm volatile("fsw f28, %0(%1)" ::"i"(4 * sizeof(float)), "r"(addr)); + asm volatile("fsw f29, %0(%1)" ::"i"(5 * sizeof(float)), "r"(addr)); + asm volatile("fsw f30, %0(%1)" ::"i"(4 * sizeof(float)), "r"(addr_tworow)); + asm volatile("fsw f31, %0(%1)" ::"i"(5 * sizeof(float)), "r"(addr_tworow)); + } else { + asm volatile("fsw f24, %0(%1)" ::"i"(0 * WN *
sizeof(float)), "r"(addr)); + asm volatile("fsw f25, %0(%1)" ::"i"(1 * WN * sizeof(float)), "r"(addr)); + asm volatile("fsw f26, %0(%1)" ::"i"(2 * WN * sizeof(float)), "r"(addr)); + asm volatile("fsw f27, %0(%1)" ::"i"(3 * WN * sizeof(float)), "r"(addr)); + asm volatile("fsw f28, %0(%1)" ::"i"(4 * WN * sizeof(float)), "r"(addr)); + asm volatile("fsw f29, %0(%1)" ::"i"(5 * WN * sizeof(float)), "r"(addr)); + asm volatile("fsw f30, %0(%1)" ::"i"(6 * WN * sizeof(float)), "r"(addr)); + asm volatile("fsw f31, %0(%1)" ::"i"(7 * WN * sizeof(float)), "r"(addr)); + } } asm volatile ("wmma_store_finish_%=:" :: ); @@ -1150,19 +1177,20 @@ inline void thread_block_gemm(const T *A, const T *B, float *C, if constexpr (GEMMINI_DMA) { // pipeline initiation - if (tid_in_threadblock == 0) { - // configure dma gmem address to load from - ROCC_INSTRUCTION_RS1_RS2( - XCUSTOM_ACC, - (uint64_t)(A + block_m * BM * dim_k + /*block_k:*/0 * BK), - (uint64_t)(B + /*block_k:*/0 * BK * dim_n + block_n * BN), - k_LOOP_WS_CONFIG_ADDRS_AB) - // GEMMINI_CISC(8) does k_LOOP_WS_CONFIG_STRIDES_AB - GEMMINI_CISC_CMD_R((dim_n << 20) | (dim_k << 8) | 8); - gemmini_fence(); + if (block_m == 0 && block_n == 0) { + if (tid_in_threadblock == 0) { + // configure dma gmem address to load from + ROCC_INSTRUCTION_RS1_RS2( + XCUSTOM_ACC, + (uint64_t)(A + block_m * BM * dim_k + /*block_k:*/ 0 * BK), + (uint64_t)(B + /*block_k:*/ 0 * BK * dim_n + block_n * BN), + k_LOOP_WS_CONFIG_ADDRS_AB) + // GEMMINI_CISC(8) does k_LOOP_WS_CONFIG_STRIDES_AB + GEMMINI_CISC_CMD_R((dim_n << 20) | (dim_k << 8) | 8); + gemmini_fence(); - GEMMINI_CISC_CMD_I(10); - gemmini_fence(); + GEMMINI_CISC_CMD_I(10); + gemmini_fence(); #if 0 // sp_tiled_matmul_full_spad_ws includes CONFIG_BOUNDS @@ -1181,10 +1209,11 @@ inline void thread_block_gemm(const T *A, const T *B, float *C, /*acc=*/0, /*act=*/NO_ACTIVATION, /*skips=*/skips) gemmini_fence(); #endif - } + } - threadblock_barrier(threadblock_id_in_cluster, - warps_per_threadblock_per_core); + 
threadblock_barrier(threadblock_id_in_cluster, + warps_per_threadblock_per_core); + } } #pragma GCC unroll 1 @@ -1197,12 +1226,28 @@ inline void thread_block_gemm(const T *A, const T *B, float *C, // this is either done using DMA or SIMT cores depending on GEMMINI_DMA #if (GEMMINI_DMA == 1) - if ((tid_in_threadblock == 0) && ((block_k * BK) != (dim_k - BK))) { + if (tid_in_threadblock == 0) { + asm volatile("next_index_start_%=:" ::); + + const uint32_t next_block_k = + ((block_k + 1) * BK == dim_k) ? 0 : block_k + 1; + const uint32_t next_block_n = + (next_block_k == 0) + ? (((block_n + 1) * BN == dim_n) ? 0 : block_n + 1) + : block_n; + const uint32_t next_block_m = + (next_block_k == 0 && next_block_n == 0) + ? (((block_m + 1) == block_m_end) ? block_m_start /*unused*/ + : block_m + 1) + : block_m; + + asm volatile("next_index_end_%=:" ::); + // configure dma gmem address to load from ROCC_INSTRUCTION_RS1_RS2( XCUSTOM_ACC, - (uint64_t)(A + block_m * BM * dim_k + (block_k + 1/*runahead*/) * BK), - (uint64_t)(B + (block_k + 1/*runahead*/) * BK * dim_n + block_n * BN), + (uint64_t)(A + next_block_m * BM * dim_k + next_block_k * BK), + (uint64_t)(B + next_block_k * BK * dim_n + next_block_n * BN), k_LOOP_WS_CONFIG_ADDRS_AB) // GEMMINI_CISC(8) does k_LOOP_WS_CONFIG_STRIDES_AB GEMMINI_CISC_CMD_R((dim_n << 20) | (dim_k << 8) | 8); @@ -1210,6 +1255,11 @@ inline void thread_block_gemm(const T *A, const T *B, float *C, // block_k is even: opcode 11 (write to local_a_buf) // block_k is odd: opcode 10 (write to local_a) + // + // FIXME: This depends on (dim_k / BK) being an even number, since + // the last iteration of the k-loop is prefetching for the first + // iteration of the n-loop. The ping-pong indexing has to match for + // the two loop ends to connect.
const uint32_t opcode = 11 - (block_k & 1); GEMMINI_CISC_CMD_I(opcode); // // TODO: branch is probably slow @@ -1349,6 +1399,8 @@ inline void thread_block_gemm(const T *A, const T *B, float *C, } if constexpr (write_to_gmem) { + asm volatile("move_out_start_%=:" ::); + if constexpr (TENSOR_HOPPER) { // wait until all results are accumulated into the RF vx_wgmma_wait(); @@ -1367,6 +1419,8 @@ inline void thread_block_gemm(const T *A, const T *B, float *C, } } } + + asm volatile("move_out_end_%=:" ::); } } asm volatile("loop_mn_end_%=:" ::); diff --git a/tests/regression/sgemm_tcore/switch_args_input.sh b/tests/regression/sgemm_tcore/switch_args_input.sh new file mode 100755 index 00000000..392794e4 --- /dev/null +++ b/tests/regression/sgemm_tcore/switch_args_input.sh @@ -0,0 +1,51 @@ +#!/bin/bash +# +# Updates symlink to args.bin, input.a.bin, input.b.bin to point to the right +# binary according to the dimension size given as the argument. + +if [ "$#" != "2" ]; then + echo "usage: $0 DIMENSION 1|0" + echo "second argument indicates using DMA or not." + exit 1 +fi + +dim="$1" +dma="$2" +if [ "$dma" = "1" ]; then + layout_a="row.swizzle_fp16" + layout_b="row" +else + layout_a="col.swizzle_fp16" + layout_b="row.swizzle_fp16" +fi + +check_exists() { + if ! [ -f "$1" ]; then + echo "error: looked for file $1 that does not exist." + exit 1 + fi +} + +args="args.m$1n$1k$1.bin" +input_a="input.a.rand01.fp16.m$1n$1k$1.$layout_a.bin" +input_b="input.b.rand01.fp16.m$1n$1k$1.$layout_b.bin" +check_exists "$args" +check_exists "$input_a" +check_exists "$input_b" + +echo "will symlink:" +echo "args.bin -> $args" +echo "input.a.bin -> $input_a" +echo "input.b.bin -> $input_b" +echo "continue? (Y/N)" +read -r -s -n 1 answer +if [ "$answer" != "Y" ] && [ "$answer" != "y" ]; then + echo "exiting..." + exit 1 +fi + +ln -sf -v "$args" "args.bin" +ln -sf -v "$input_a" "input.a.bin" +ln -sf -v "$input_b" "input.b.bin" + +echo "done."