diff --git a/hw/VX_config.h b/hw/VX_config.h index e7a6b559..63be0e93 100644 --- a/hw/VX_config.h +++ b/hw/VX_config.h @@ -84,7 +84,7 @@ #endif #ifndef NUM_CORES -#define NUM_CORES 8 +#define NUM_CORES 4 #endif #ifndef NUM_WARPS diff --git a/kernel/include/gemmini_mmio.h b/kernel/include/gemmini_mmio.h index ebd3a5ba..ed55236c 100644 --- a/kernel/include/gemmini_mmio.h +++ b/kernel/include/gemmini_mmio.h @@ -9,6 +9,8 @@ // #define SMEM_SIZE 0x4000 // 64KB // #define SMEM_SIZE 0x10000 +// 128KB +// #define SMEM_SIZE 0x20000 // 256KB #define SMEM_SIZE 0x40000 diff --git a/kernel/include/vx_intrinsics.h b/kernel/include/vx_intrinsics.h index f6cfbf58..f51601f7 100644 --- a/kernel/include/vx_intrinsics.h +++ b/kernel/include/vx_intrinsics.h @@ -149,6 +149,7 @@ inline void vx_join(unsigned stack_ptr) { } // Warp Barrier +__attribute__((convergent)) inline void vx_barrier(unsigned barried_id, unsigned num_warps) { asm volatile (".insn r %0, 4, 0, x0, %1, %2" :: "i"(RISCV_CUSTOM0), "r"(barried_id), "r"(num_warps)); } diff --git a/kernel/include/vx_spawn.h b/kernel/include/vx_spawn.h index 83052f30..db77e683 100644 --- a/kernel/include/vx_spawn.h +++ b/kernel/include/vx_spawn.h @@ -18,7 +18,7 @@ #include #ifndef CORES_PER_CLUSTER -#define CORES_PER_CLUSTER 8 +#define CORES_PER_CLUSTER 4 #endif #ifdef __cplusplus diff --git a/tests/kernel/tensor/generate_matrix.py b/tests/kernel/tensor/generate_matrix.py index 796a6ea9..c9255465 100644 --- a/tests/kernel/tensor/generate_matrix.py +++ b/tests/kernel/tensor/generate_matrix.py @@ -80,9 +80,13 @@ if __name__ == "__main__": fp16 = True if fp16: A_packed = pack_fp16_by_row(A_array) + A_swizzled = A_packed.reshape([-1, M * 2]) + A_swizzled.astype('float16').tofile("input.a.row.bin") AT_packed = A_packed.transpose([1, 0, 2]) AT_swizzled = AT_packed.reshape([-1, M * 2]) AT_swizzled.astype('float16').tofile("input.a.col.bin") + print('A:') + print(A_swizzled) print('AT:') print(AT_swizzled) B_packed = pack_fp16_by_column(B_array) 
diff --git a/tests/kernel/tensor/main.cpp b/tests/kernel/tensor/main.cpp index c373507a..05a80454 100644 --- a/tests/kernel/tensor/main.cpp +++ b/tests/kernel/tensor/main.cpp @@ -93,6 +93,23 @@ inline constexpr void map_c_8lanes(const int tid, int &row, int &col) { col += ((tid % 4) / 2) * 2; } +inline constexpr void map_c_rowmajor_8lanes(const int tid, int &row, int &col) { + const int tg = tid / 4; + + // A (row major) + // row 0~ 3: threadgroup 0 + // row 4~ 7: threadgroup 1 + row = tid % 4; + row += tg * 4; + + // B (column major) + // col 0~ 3: threadgroup 0 + // col 4~ 7: threadgroup 1 + col = tid % 4; + col += tg * 4; +} + + void vx_wmma_load() { int tid = vx_thread_id(); int tg = tid / 4; @@ -174,11 +191,31 @@ void store_wmma_result() { int row = 0; int col = 0; - map_c_8lanes(tid, row, col); + // map_c_8lanes(tid, row, col); + map_c_rowmajor_8lanes(tid, row, col); // store C float *const results_wid = results + (DIM_M * DIM_N * wid); - // uncomment to have two accum buffers in rf + + // asm volatile("fsw f16, %0" ::"m"(results_wid[DIM_N * 0 + col])); + // asm volatile("fsw f17, %0" ::"m"(results_wid[DIM_N * 1 + col])); + // asm volatile("fsw f18, %0" ::"m"(results_wid[DIM_N * 2 + col])); + // asm volatile("fsw f19, %0" ::"m"(results_wid[DIM_N * 3 + col])); + // asm volatile("fsw f20, %0" ::"m"(results_wid[DIM_N * 4 + col])); + // asm volatile("fsw f21, %0" ::"m"(results_wid[DIM_N * 5 + col])); + // asm volatile("fsw f22, %0" ::"m"(results_wid[DIM_N * 6 + col])); + // asm volatile("fsw f23, %0" ::"m"(results_wid[DIM_N * 7 + col])); + asm volatile("fsw f24, %0" ::"m"(results_wid[DIM_N * 0 + col])); + asm volatile("fsw f25, %0" ::"m"(results_wid[DIM_N * 1 + col])); + asm volatile("fsw f26, %0" ::"m"(results_wid[DIM_N * 2 + col])); + asm volatile("fsw f27, %0" ::"m"(results_wid[DIM_N * 3 + col])); + asm volatile("fsw f28, %0" ::"m"(results_wid[DIM_N * 4 + col])); + asm volatile("fsw f29, %0" ::"m"(results_wid[DIM_N * 5 + col])); + asm volatile("fsw f30, %0" 
::"m"(results_wid[DIM_N * 6 + col])); + asm volatile("fsw f31, %0" ::"m"(results_wid[DIM_N * 7 + col])); + + + // 1x2 jagged mapping // asm volatile("fsw f16, %0" ::"m"(results_wid[DIM_N * (row + 0) + (col + 0)])); // asm volatile("fsw f17, %0" ::"m"(results_wid[DIM_N * (row + 0) + (col + 1)])); // asm volatile("fsw f18, %0" ::"m"(results_wid[DIM_N * (row + 2) + (col + 0)])); @@ -187,14 +224,14 @@ void store_wmma_result() { // asm volatile("fsw f21, %0" ::"m"(results_wid[DIM_N * (row + 0) + (col + 5)])); // asm volatile("fsw f22, %0" ::"m"(results_wid[DIM_N * (row + 2) + (col + 4)])); // asm volatile("fsw f23, %0" ::"m"(results_wid[DIM_N * (row + 2) + (col + 5)])); - asm volatile("fsw f24, %0" ::"m"(results_wid[DIM_N * (row + 0) + (col + 0)])); - asm volatile("fsw f25, %0" ::"m"(results_wid[DIM_N * (row + 0) + (col + 1)])); - asm volatile("fsw f26, %0" ::"m"(results_wid[DIM_N * (row + 2) + (col + 0)])); - asm volatile("fsw f27, %0" ::"m"(results_wid[DIM_N * (row + 2) + (col + 1)])); - asm volatile("fsw f28, %0" ::"m"(results_wid[DIM_N * (row + 0) + (col + 4)])); - asm volatile("fsw f29, %0" ::"m"(results_wid[DIM_N * (row + 0) + (col + 5)])); - asm volatile("fsw f30, %0" ::"m"(results_wid[DIM_N * (row + 2) + (col + 4)])); - asm volatile("fsw f31, %0" ::"m"(results_wid[DIM_N * (row + 2) + (col + 5)])); + // asm volatile("fsw f24, %0" ::"m"(results_wid[DIM_N * (row + 0) + (col + 0)])); + // asm volatile("fsw f25, %0" ::"m"(results_wid[DIM_N * (row + 0) + (col + 1)])); + // asm volatile("fsw f26, %0" ::"m"(results_wid[DIM_N * (row + 2) + (col + 0)])); + // asm volatile("fsw f27, %0" ::"m"(results_wid[DIM_N * (row + 2) + (col + 1)])); + // asm volatile("fsw f28, %0" ::"m"(results_wid[DIM_N * (row + 0) + (col + 4)])); + // asm volatile("fsw f29, %0" ::"m"(results_wid[DIM_N * (row + 0) + (col + 5)])); + // asm volatile("fsw f30, %0" ::"m"(results_wid[DIM_N * (row + 2) + (col + 4)])); + // asm volatile("fsw f31, %0" ::"m"(results_wid[DIM_N * (row + 2) + (col + 5)])); } 
void print_wmma_result() { diff --git a/tests/regression/common.mk b/tests/regression/common.mk index 50efc499..f000dcf6 100644 --- a/tests/regression/common.mk +++ b/tests/regression/common.mk @@ -48,7 +48,7 @@ VX_CP = $(LLVM_VORTEX)/bin/llvm-objcopy #VX_DP = $(RISCV_TOOLCHAIN_PATH)/bin/$(RISCV_PREFIX)-objdump #VX_CP = $(RISCV_TOOLCHAIN_PATH)/bin/$(RISCV_PREFIX)-objcopy -VX_CFLAGS += -v -O3 -std=c++17 +VX_CFLAGS += -v -Os -std=c++17 VX_CFLAGS += -mcmodel=medany -fno-rtti -fno-exceptions -nostartfiles -fdata-sections -ffunction-sections # comment out below for regression/basic, which uses GCC that doesn't # understand these flags diff --git a/tests/regression/flash_attention/Makefile b/tests/regression/flash_attention/Makefile index 4f49f927..0456e983 100644 --- a/tests/regression/flash_attention/Makefile +++ b/tests/regression/flash_attention/Makefile @@ -2,8 +2,8 @@ PROJECT = flash_attention SRCS = main.cpp common.h -VX_SRCS = kernel.cpp -VX_INCLUDES = ../sgemm_tcore/sgemm_impl.hpp +VX_SRCS = kernel.gemmini.cpp +VX_INCLUDES = flash_impl.hpp ../sgemm_tcore/sgemm_impl.hpp OPTS ?= -n16 diff --git a/tests/regression/flash_attention/flash_impl.hpp b/tests/regression/flash_attention/flash_impl.hpp new file mode 100644 index 00000000..47e21c70 --- /dev/null +++ b/tests/regression/flash_attention/flash_impl.hpp @@ -0,0 +1,559 @@ +#ifndef _FLASH_IMPL_H_ +#define _FLASH_IMPL_H_ + +#include +#include + +#define B_ROW 64 +#define B_COL 64 +#define HEADDIM 64 + +#define ROW_REMAINDER_LOGIC + +constexpr uint32_t ROWMAX_SETS = 3; +constexpr bool WARP_SPECIALIZED = true; +constexpr bool TENSOR_CORE = true; + +// temporary safety stop for wrong configs +static_assert(NUM_CORES == 4); +static_assert(NUM_THREADS == 8); +static_assert(NUM_WARPS == 8); + +inline void thread_block_init_sharedmem(const uint32_t tid_in_threadblock, + const uint32_t threads_per_threadblock, + float *smem_O, float *smem_rowmax, + float *smem_rowsum, + float *smem_O_row_scale) { + asm 
volatile("threadblock_init_sharedmem_start_%=:" ::); + + const uint32_t tid_in_warp = tid_in_threadblock % NUM_THREADS; + const uint32_t warp_id = tid_in_threadblock / NUM_THREADS; + const uint32_t warps_in_threadblock = threads_per_threadblock / NUM_THREADS; + + static_assert((B_ROW % NUM_THREADS) == 0, + "B_ROW must be a multiple of NUM_THREADS"); + static_assert(B_ROW < (NUM_THREADS * CORES_PER_CLUSTER * + (NUM_WARPS / (WARP_SPECIALIZED ? 2 : 1))), + "not enough warps to initialize rowmax/rowsum"); + + // each thread initializes one element in rowmax/rowsum + // multiple warps participate for the whole vector + constexpr uint32_t needed_warps = B_ROW / NUM_THREADS; + if (warp_id < needed_warps /* more warps in HW than needed? */) { + uint32_t offset = NUM_THREADS * warp_id + tid_in_warp; +#pragma GCC unroll + for (int i = 0; i < ROWMAX_SETS; i++) { + smem_rowmax[offset + i * ROWMAX_SETS] = FLT_MIN; + } + smem_rowsum[offset] = 0.0f; + smem_O_row_scale[offset] = 0.0f; + } + + // each warp clears out a row of smem_O + // FIXME: dedup this pattern +#pragma GCC unroll 1 + for (int row_offset = 0; row_offset < B_COL; + row_offset += warps_in_threadblock) { + const uint32_t row = row_offset + warp_id; +#ifdef ROW_REMAINDER_LOGIC + if (row >= B_ROW) { + // WARNING: the number of barrier calls have to exactly match that in the + // outside of the branch to prevent stalls!! FIXME better proof this. 
+ continue; + } +#endif + + uint32_t thread_offset = HEADDIM * row + tid_in_warp; + constexpr uint32_t per_row_iter = HEADDIM / NUM_THREADS; + const float one = 0.0f; +#pragma GCC unroll + for (int i = 0; i < per_row_iter; i++) { + smem_O[thread_offset] = 0.0f; + thread_offset += NUM_THREADS; + } + } + + asm volatile("threadblock_init_sharedmem_finish_%=:" ::); +} + +inline void thread_block_copy_rowmax(const float *src, float *dest, + const uint32_t tid_in_threadblock, + const uint32_t threads_per_threadblock, + const uint32_t threadblock_id_in_cluster) { + asm volatile("threadblock_copy_rowmax_start_%=:" ::); + + const uint32_t tid_in_warp = tid_in_threadblock % NUM_THREADS; + const uint32_t warp_id = tid_in_threadblock / NUM_THREADS; + const uint32_t warps_in_threadblock = threads_per_threadblock / NUM_THREADS; + const uint32_t warps_per_threadblock_per_core = + warps_in_threadblock / CORES_PER_CLUSTER; + + // each thread copies one element in rowmax + // multiple warps participate for the whole vector + constexpr uint32_t num_warps = B_ROW / NUM_THREADS; + if (warp_id < num_warps) { + uint32_t offset = NUM_THREADS * warp_id + tid_in_warp; + dest[offset] = src[offset]; + } + + if constexpr (!TENSOR_CORE) { + threadblock_barrier(1, 7); + } else { + threadblock_barrier(threadblock_id_in_cluster, + warps_per_threadblock_per_core); + } + + asm volatile("threadblock_copy_rowmax_finish_%=:" ::); +} + +template +inline void thread_block_copy_tile(const float *src, float *dest, + const uint32_t tid_in_threadblock, + const uint32_t threads_per_threadblock, + const uint32_t threadblock_id_in_cluster) { + asm volatile("threadblock_copy_tile_start_%=:" ::); + + const uint32_t tid_in_warp = tid_in_threadblock % NUM_THREADS; + const uint32_t warp_id = tid_in_threadblock / NUM_THREADS; + const uint32_t warps_in_threadblock = threads_per_threadblock / NUM_THREADS; + const uint32_t warps_per_threadblock_per_core = + warps_in_threadblock / CORES_PER_CLUSTER; + + // FIXME: dedup 
this pattern +#pragma GCC unroll 1 + for (int row_offset = 0; row_offset < dim_row; + row_offset += warps_in_threadblock) { + const uint32_t row = row_offset + warp_id; +#ifdef ROW_REMAINDER_LOGIC + if (row >= B_ROW) { + // WARNING: the number of barrier calls have to exactly match that in the + // outside of the branch to prevent stalls!! FIXME better proof this. + if constexpr (!TENSOR_CORE) { + threadblock_barrier(1, 7); + } else { + threadblock_barrier(threadblock_id_in_cluster, + warps_per_threadblock_per_core); + } + continue; + } +#endif + + constexpr uint32_t per_row_iter = dim_col / NUM_THREADS; +#pragma GCC unroll + for (int i = 0; i < per_row_iter; i++) { + const uint32_t col_offset = NUM_THREADS * i; + const uint32_t col = col_offset + tid_in_warp; + const auto [smem_row, smem_col] = + remap_to_gemmini_dma_layout(row, col); + const uint32_t smem_offset = B_COL * smem_row + smem_col; + const uint32_t gmem_offset = B_COL * row + col; + + dest[gmem_offset] = src[smem_offset]; + } + + if constexpr (!TENSOR_CORE) { + threadblock_barrier(1, 7); + } else { + threadblock_barrier(threadblock_id_in_cluster, + warps_per_threadblock_per_core); + } + } + + asm volatile("threadblock_copy_tile_finish_%=:" ::); +} + +template +inline float exponential_taylor_term(const float x) { + asm volatile("exponential_taylor_term_start_%=:" ::); + + float res = 1.0f; + + if constexpr (order == 1) { + res = x; + } else if constexpr (order == 2) { + res = x * x; + res /= 2.0f; + } else if constexpr (order == 3) { + res = x * x * x; + res /= 6.0f; + } + + asm volatile("exponential_taylor_term_end_%=:" ::); + return res; +} + +template +__attribute__((always_inline)) inline void thread_block_online_softmax( + const float *smem_S, float *smem_P, const uint32_t tid_in_threadblock, + const uint32_t threads_per_threadblock, + const uint32_t threadblock_id_in_cluster, float *smem_scratchpad, + float *smem_rowmax, float *smem_rowsum, float *smem_O_row_scale) { + asm 
volatile("thread_block_online_softmax_start_%=:" ::); + + const uint32_t tid_in_warp = tid_in_threadblock % NUM_THREADS; + const uint32_t warp_id = tid_in_threadblock / NUM_THREADS; + const uint32_t warps_in_threadblock = threads_per_threadblock / NUM_THREADS; + const uint32_t warps_per_threadblock_per_core = + warps_in_threadblock / CORES_PER_CLUSTER; + + float *smem_rowmax_this = smem_rowmax + B_ROW; + +#pragma GCC unroll 1 + for (int row_offset = 0; row_offset < B_ROW; + row_offset += warps_in_threadblock) { + const uint32_t row = row_offset + warp_id; +#ifdef ROW_REMAINDER_LOGIC + // if the number of warps doesn't exactly divide the number of rows, + // early-exit to prevent out-of-bounds access + if (row >= B_ROW) { + // WARNING: the number of barrier calls have to exactly match that in the + // outside of the branch to prevent stalls!! FIXME better proof this. + if constexpr (!TENSOR_CORE) { + threadblock_barrier(1, 7); + threadblock_barrier(1, 7); + threadblock_barrier(1, 7); + threadblock_barrier(1, 7); + threadblock_barrier(1, 7); + threadblock_barrier(1, 7); + } else { + threadblock_barrier(threadblock_id_in_cluster, + warps_per_threadblock_per_core); + threadblock_barrier(threadblock_id_in_cluster, + warps_per_threadblock_per_core); + threadblock_barrier(threadblock_id_in_cluster, + warps_per_threadblock_per_core); + threadblock_barrier(threadblock_id_in_cluster, + warps_per_threadblock_per_core); + threadblock_barrier(threadblock_id_in_cluster, + warps_per_threadblock_per_core); + threadblock_barrier(threadblock_id_in_cluster, + warps_per_threadblock_per_core); + } + + continue; + } +#endif + const uint32_t first_thread_offset = B_COL * row; + + // rowmax + // + // two-level tree reduction: reduce each row into NUM_THREADS intermediate + // maxes, then reduce it down to one row max + // one warp handles one row in tile + + constexpr uint32_t per_row_iter = B_COL / NUM_THREADS; + // FIXME: threadblock_id needs to be in here too + float *warp_smem = 
smem_scratchpad + (warp_id * NUM_THREADS); + +// #define DUMB_ROWMAX +#ifdef DUMB_ROWMAX + // FIXME remove + threadblock_barrier(threadblock_id_in_cluster, + warps_per_threadblock_per_core); + + // no tree reduction; a single thread in a warp does serialized max across + // the entire row + if (tid_in_warp == 0) { + float rowmax = smem_S[first_thread_offset]; +#pragma GCC unroll 16 + for (int i = 0; i < B_COL; i++) { + asm volatile("fmax.s %0, %1, %2" + : "=f"(rowmax) + : "f"(rowmax), "f"(smem_S[first_thread_offset + i])); + } + smem_rowmax_this[row] = rowmax; + + // update previous rowmax + // i.e. mi_new = max(mi, mij) + float prev_rowmax = smem_rowmax[row]; + // stage prev rowmax in scratchpad for warp-wide broadcast + warp_smem[0] = prev_rowmax; + asm volatile("fmax.s %0, %1, %2" + : "=f"(rowmax) + : "f"(rowmax), "f"(prev_rowmax)); + smem_rowmax[row] = rowmax; + } + +#else + static_assert((B_COL % NUM_THREADS) == 0, + "B_COL must be a multiple of NUM_THREADS"); + float per_thread_max = FLT_MIN; +#pragma GCC unroll + for (int i = 0; i < per_row_iter; i++) { + const uint32_t col_offset = NUM_THREADS * i; + const uint32_t col = col_offset + tid_in_warp; + const auto [smem_row, smem_col] = + remap_to_gemmini_dma_layout(row, col); + const uint32_t offset = B_COL * smem_row + smem_col; + + const float next = smem_S[offset]; + asm volatile("fmax.s %0, %1, %2" + : "=f"(per_thread_max) + : "f"(per_thread_max), "f"(next)); + } + // stage per-thread max value in smem + warp_smem[tid_in_warp] = per_thread_max; + + // sync writes to warp_smem + if constexpr (!TENSOR_CORE) { + threadblock_barrier(1, 7); + } else { + threadblock_barrier(threadblock_id_in_cluster, + warps_per_threadblock_per_core); + } + +// #define PARALLEL_ROWMAX +#ifndef PARALLEL_ROWMAX + // elect 0-th thread to reduce all other thread's values in the warp + if (tid_in_warp == 0) { + float rowmax = per_thread_max; + for (int i = 1; i < NUM_THREADS; i++) { + float other = warp_smem[i]; + asm volatile("fmax.s 
%0, %1, %2" + : "=f"(rowmax) + : "f"(rowmax), "f"(other)); + } + smem_rowmax_this[row] = rowmax; + + // update previous rowmax + // i.e. mi_new = max(mi, mij) + float prev_rowmax = smem_rowmax[row]; + // stage prev rowmax in scratchpad for warp-wide broadcast + warp_smem[0] = prev_rowmax; + asm volatile("fmax.s %0, %1, %2" + : "=f"(rowmax) + : "f"(rowmax), "f"(prev_rowmax)); + smem_rowmax[row] = rowmax; + } +#else + if (warp_id < warps_in_threadblock / NUM_THREADS) { + const uint32_t row = row_offset + NUM_THREADS * warp_id + tid_in_warp; + float *const thread_smem = smem_scratchpad + (tid_in_warp * NUM_THREADS); + float rowmax = FLT_MIN; +#pragma GCC unroll + for (int i = 0; i < NUM_THREADS; i++) { + const float f = thread_smem[i]; + asm volatile("fmax.s %0, %1, %2" : "=f"(rowmax) : "f"(rowmax), "f"(f)); + } + smem_rowmax_this[row] = rowmax; + + // update previous rowmax + // i.e. mi_new = max(mi, mij) + float prev_rowmax = smem_rowmax[row]; + // stage prev rowmax in scratchpad for warp-wide broadcast + thread_smem[0] = prev_rowmax; + asm volatile("fmax.s %0, %1, %2" + : "=f"(rowmax) + : "f"(rowmax), "f"(prev_rowmax)); + smem_rowmax[row] = rowmax; + } +#endif // PARALLEL_ROWMAX +#endif // DUMB_ROWMAX + + if constexpr (!TENSOR_CORE) { + threadblock_barrier(1, 7); + } else { + threadblock_barrier(threadblock_id_in_cluster, + warps_per_threadblock_per_core); + } + + + // broadcast prev rowmax to all threads in the warp + // NOTE: memory consistency is a little sketchy here + const float rowmax_prev = warp_smem[0]; + const float rowmax_this = smem_rowmax_this[row]; + + // exponential + // + // B_ROW / (B_ROW * B_COL / (exp_elem * threads_per_threadblock)) + // const uint32_t row_stride = + // (exp_elem_per_thread * threads_per_threadblock) / B_COL; + + // broadcast updated rowmax to all threads in the warp + const float rowmax_new = smem_rowmax[row]; + + asm volatile("flashattn_exp_p_start_%=:" ::); + +#pragma GCC unroll + for (int i = 0; i < per_row_iter; i++) { + 
const uint32_t col_offset = NUM_THREADS * i; + const uint32_t col = col_offset + tid_in_warp; + const auto [smem_row, smem_col] = + remap_to_gemmini_dma_layout(row, col); + const uint32_t offset = B_COL * smem_row + smem_col; + + float f0 = smem_S[offset]; + + f0 -= rowmax_new; + + // 2nd-order Taylor approximation + float exp = 1.0f; + exp += exponential_taylor_term<1>(f0); + exp += exponential_taylor_term<2>(f0); + + // Store S transposed to the shared memory + + smem_P[offset] = exp; + } + + asm volatile("flashattn_exp_p_end_%=:" ::); + + if constexpr (!TENSOR_CORE) { + threadblock_barrier(1, 7); + } else { + threadblock_barrier(threadblock_id_in_cluster, + warps_per_threadblock_per_core); + } + + // rowsum + // + // two-level tree reduction, similar to rowmax + + asm volatile("flashattn_rowsum_start_%=:" ::); + + float per_thread_sum = 0.0f; + +#pragma GCC unroll + for (int i = 0; i < per_row_iter; i++) { + const uint32_t col_offset = NUM_THREADS * i; + const uint32_t col = col_offset + tid_in_warp; + const auto [smem_row, smem_col] = + remap_to_gemmini_dma_layout(row, col); + const uint32_t offset = B_COL * smem_row + smem_col; + + per_thread_sum += smem_P[offset]; + } + // stage per-thread sum value in smem + // FIXME: threadblock_id needs to be in here too + warp_smem = smem_scratchpad + (warp_id * NUM_THREADS); + warp_smem[tid_in_warp] = per_thread_sum; + + // sync writes to warp_smem + if constexpr (!TENSOR_CORE) { + threadblock_barrier(1, 7); + } else { + threadblock_barrier(threadblock_id_in_cluster, + warps_per_threadblock_per_core); + } + + // 0-th thread collects all other thread's values in the warp + if (tid_in_warp == 0) { + float rowsum = per_thread_sum; + for (int iter = 1; iter < NUM_THREADS; iter++) { + float other = warp_smem[iter]; + rowsum += other; + } + + const float mi_prev = rowmax_prev; + const float mi_this = rowmax_this; + + const float x = mi_prev - mi_this; + // 2nd-order Taylor approximation + float exp = 1.0f; + exp += 
exponential_taylor_term<1>(x); + exp += exponential_taylor_term<2>(x); + + // update rowsum + const float rowsum_prev = smem_rowsum[row]; + float rowsum_new = exp * rowsum_prev + rowsum; + + smem_rowsum[row] = rowsum_new; + } + + asm volatile("flashattn_rowsum_end_%=:" ::); + + if constexpr (!TENSOR_CORE) { + threadblock_barrier(1, 7); + } else { + threadblock_barrier(threadblock_id_in_cluster, + warps_per_threadblock_per_core); + } + + // compute Oi rescale factor + // FIXME: parallelize this across threads + // + asm volatile("flashattn_rescale_factor_start_%=:" ::); + +#pragma GCC unroll + for (int i = 0; i < per_row_iter; i++) { + const float mi_prev = rowmax_prev; + const float mi_new = rowmax_new; + + const float x = mi_prev - mi_new; + // 2nd-order Taylor approximation + float exp = 1.0f; + exp += exponential_taylor_term<1>(x); + exp += exponential_taylor_term<2>(x); + + // @perf: div vs. expansion on e(-x)? + smem_O_row_scale[row] = 1.0f / exp; + } + + asm volatile("flashattn_rescale_factor_end_%=:" ::); + + if constexpr (!TENSOR_CORE) { + threadblock_barrier(1, 7); + } else { + threadblock_barrier(threadblock_id_in_cluster, + warps_per_threadblock_per_core); + } + } + + asm volatile("thread_block_online_softmax_finish_%=:" ::); +} + +template +__attribute__((always_inline)) inline void thread_block_O_rescale( + const float *smem_O_in, float *smem_O_out, const float *smem_O_row_scale, + const uint32_t tid_in_threadblock, const uint32_t threads_per_threadblock, + const uint32_t threadblock_id_in_cluster) { + asm volatile("thread_block_O_rescale_start_%=:" ::); + + const uint32_t tid_in_warp = tid_in_threadblock % NUM_THREADS; + const uint32_t warp_id = tid_in_threadblock / NUM_THREADS; + const uint32_t warps_in_threadblock = threads_per_threadblock / NUM_THREADS; + const uint32_t warps_per_threadblock_per_core = + warps_in_threadblock / CORES_PER_CLUSTER; + +#pragma GCC unroll 1 + for (int row_offset = 0; row_offset < B_ROW; + row_offset += 
warps_in_threadblock) { + const uint32_t row = row_offset + warp_id; +#ifdef ROW_REMAINDER_LOGIC + if (row >= B_ROW) { + // WARNING: the number of barrier calls have to exactly match that in the + // outside of the branch to prevent stalls!! FIXME better proof this. + continue; + } +#endif + + constexpr uint32_t per_row_iter = HEADDIM / NUM_THREADS; + + // Oi rescale + // +#pragma GCC unroll + for (int i = 0; i < per_row_iter; i++) { + const uint32_t col_offset = NUM_THREADS * i; + const uint32_t col = col_offset + tid_in_warp; + const auto [smem_row, smem_col] = + remap_to_gemmini_dma_layout(row, col); + + const uint32_t offset = HEADDIM * smem_row + smem_col; + const float o = smem_O_in[offset]; + const float scale = smem_O_row_scale[row]; + smem_O_out[offset] = (o * scale); + } + } + + // reconverge after warp divergence + if constexpr (!TENSOR_CORE) { + threadblock_barrier(1, 7); + } else { + threadblock_barrier(threadblock_id_in_cluster, + warps_per_threadblock_per_core); + } + + asm volatile("thread_block_O_rescale_finish_%=:" ::); +} + +#endif diff --git a/tests/regression/flash_attention/kernel.cpp b/tests/regression/flash_attention/kernel.cpp index 78fc8969..3c2d463c 100644 --- a/tests/regression/flash_attention/kernel.cpp +++ b/tests/regression/flash_attention/kernel.cpp @@ -2,457 +2,17 @@ #include #include #include -#include #include "common.h" #include "sgemm_impl.hpp" #include "include/gemmini.h" #include "gemmini_mmio.h" +#include "flash_impl.hpp" -#define B_ROW BM -#define B_COL BN -// FIXME -#define HEADDIM B_COL +constexpr bool DEBUG = false; +constexpr bool Q_IS_K_MAJOR = true; -constexpr uint32_t ROWMAX_SETS = 3; -constexpr bool DEBUG = true; -constexpr bool WARP_SPECIALIZED = false; - -constexpr uint32_t DEV_FAKE_SMEM_START_ADDR = 0xf0000000; - -// temporary safety stop for wrong configs -static_assert(NUM_CORES == 4); -static_assert(NUM_THREADS == 8); -static_assert(NUM_WARPS == 8); - -inline void thread_block_init_sharedmem(const uint32_t 
tid_in_threadblock, - const uint32_t threads_per_threadblock, - float *smem_O, float *smem_rowmax, - float *smem_rowsum, - float *smem_O_row_scale) { - asm volatile("threadblock_init_sharedmem_start_%=:" ::); - - const uint32_t tid_in_warp = tid_in_threadblock % NUM_THREADS; - const uint32_t warp_id = tid_in_threadblock / NUM_THREADS; - const uint32_t warps_in_threadblock = threads_per_threadblock / NUM_THREADS; - - static_assert((B_ROW % NUM_THREADS) == 0, - "B_ROW must be a multiple of NUM_THREADS"); - static_assert(B_ROW < (NUM_THREADS * CORES_PER_CLUSTER * - (NUM_WARPS / (WARP_SPECIALIZED ? 2 : 1))), - "not enough warps to initialize rowmax/rowsum"); - - // each thread initializes one element in rowmax/rowsum - // multiple warps participate for the whole vector - constexpr uint32_t needed_warps = B_ROW / NUM_THREADS; - if (warp_id < needed_warps /* more warps in HW than needed? */) { - uint32_t offset = NUM_THREADS * warp_id + tid_in_warp; -#pragma GCC unroll - for (int i = 0; i < ROWMAX_SETS; i++) { - smem_rowmax[offset + i * ROWMAX_SETS] = FLT_MIN; - } - smem_rowsum[offset] = 0.0f; - smem_O_row_scale[offset] = 0.0f; - } - - // each warp clears out a row of smem_O - // FIXME: dedup this pattern -#pragma GCC unroll 1 - for (int row_offset = 0; row_offset < B_COL; - row_offset += warps_in_threadblock) { - const uint32_t row = row_offset + warp_id; - uint32_t thread_offset = HEADDIM * row + tid_in_warp; - constexpr uint32_t per_row_iter = HEADDIM / NUM_THREADS; - const float one = 0.0f; -#pragma GCC unroll - for (int i = 0; i < per_row_iter; i++) { - smem_O[thread_offset] = 0.0f; - thread_offset += NUM_THREADS; - } - } - - asm volatile("threadblock_init_sharedmem_finish_%=:" ::); -} - -inline void thread_block_copy_rowmax(const float *src, float *dest, - const uint32_t tid_in_threadblock, - const uint32_t threads_per_threadblock, - const uint32_t threadblock_id_in_cluster) { - asm volatile("threadblock_copy_rowmax_start_%=:" ::); - - const uint32_t tid_in_warp = 
tid_in_threadblock % NUM_THREADS; - const uint32_t warp_id = tid_in_threadblock / NUM_THREADS; - const uint32_t warps_in_threadblock = threads_per_threadblock / NUM_THREADS; - const uint32_t warps_per_threadblock_per_core = - warps_in_threadblock / CORES_PER_CLUSTER; - - // each thread copies one element in rowmax - // multiple warps participate for the whole vector - constexpr uint32_t num_warps = B_ROW / NUM_THREADS; - if (warp_id < num_warps) { - uint32_t offset = NUM_THREADS * warp_id + tid_in_warp; - dest[offset] = src[offset]; - } - - threadblock_barrier(threadblock_id_in_cluster, - warps_per_threadblock_per_core); - - asm volatile("threadblock_copy_rowmax_finish_%=:" ::); -} - -inline void thread_block_copy_tile(const float *src, float *dest, - const uint32_t tid_in_threadblock, - const uint32_t threads_per_threadblock, - const uint32_t threadblock_id_in_cluster) { - asm volatile("threadblock_copy_tile_start_%=:" ::); - - const uint32_t tid_in_warp = tid_in_threadblock % NUM_THREADS; - const uint32_t warp_id = tid_in_threadblock / NUM_THREADS; - const uint32_t warps_in_threadblock = threads_per_threadblock / NUM_THREADS; - const uint32_t warps_per_threadblock_per_core = - warps_in_threadblock / CORES_PER_CLUSTER; - - // FIXME: dedup this pattern -#pragma GCC unroll 1 - for (int row_offset = 0; row_offset < B_ROW; - row_offset += warps_in_threadblock) { - const uint32_t row = row_offset + warp_id; - const uint32_t first_thread_offset = B_COL * row; - - constexpr uint32_t per_row_iter = B_COL / NUM_THREADS; - uint32_t thread_offset = first_thread_offset + tid_in_warp; -#pragma GCC unroll - for (int i = 0; i < per_row_iter; i++) { - dest[thread_offset] = src[thread_offset]; - thread_offset += NUM_THREADS; - } - - threadblock_barrier(threadblock_id_in_cluster, - warps_per_threadblock_per_core); - } - - asm volatile("threadblock_copy_tile_finish_%=:" ::); -} - -template -inline float exponential_taylor_term(const float x) { - asm 
volatile("exponential_taylor_term_start_%=:" ::); - - float res = 1.0f; - - if constexpr (order == 1) { - res = x; - } else if constexpr (order == 2) { - res = x * x; - res /= 2.0f; - } else if constexpr (order == 3) { - res = x * x * x; - res /= 6.0f; - } - - asm volatile("exponential_taylor_term_end_%=:" ::); - return res; -} - -__attribute__((always_inline)) inline void thread_block_online_softmax( - const float *smem_S, float *smem_P, const uint32_t tid_in_threadblock, - const uint32_t threads_per_threadblock, - const uint32_t threadblock_id_in_cluster, float *smem_scratchpad, - float *smem_rowmax, float *smem_rowsum, float *smem_O_row_scale) { - asm volatile("thread_block_online_softmax_start_%=:" ::); - - const uint32_t tid_in_warp = tid_in_threadblock % NUM_THREADS; - const uint32_t warp_id = tid_in_threadblock / NUM_THREADS; - const uint32_t warps_in_threadblock = threads_per_threadblock / NUM_THREADS; - const uint32_t warps_per_threadblock_per_core = - warps_in_threadblock / CORES_PER_CLUSTER; - - // float ft[8]; - // asm volatile("fmv.s %0, f16" : "=f"(ft[0])); - // asm volatile("fmv.s %0, f17" : "=f"(ft[1])); - // asm volatile("fmv.s %0, f18" : "=f"(ft[2])); - // asm volatile("fmv.s %0, f19" : "=f"(ft[3])); - // asm volatile("fmv.s %0, f20" : "=f"(ft[4])); - // asm volatile("fmv.s %0, f21" : "=f"(ft[5])); - // asm volatile("fmv.s %0, f22" : "=f"(ft[6])); - // asm volatile("fmv.s %0, f23" : "=f"(ft[7])); - - float *smem_rowmax_this = smem_rowmax + B_ROW; - -#pragma GCC unroll 1 - for (int row_offset = 0; row_offset < B_ROW; - row_offset += warps_in_threadblock) { - const uint32_t row = row_offset + warp_id; - const uint32_t first_thread_offset = B_COL * row; - - // rowmax - // - // two-level tree reduction: reduce each row into NUM_THREADS intermediate - // maxes, then reduce it down to one row max - // one warp handles one row in tile - - constexpr uint32_t per_row_iter = B_COL / NUM_THREADS; - uint32_t thread_offset = first_thread_offset + tid_in_warp; 
- // FIXME: threadblock_id needs to be in here too - float *warp_smem = smem_scratchpad + (warp_id * NUM_THREADS); - -// #define DUMB_ROWMAX -#ifdef DUMB_ROWMAX - // FIXME remove - threadblock_barrier(threadblock_id_in_cluster, - warps_per_threadblock_per_core); - - // no tree reduction; a single thread in a warp does serialized max across - // the entire row - if (tid_in_warp == 0) { - float rowmax = smem_S[first_thread_offset]; -#pragma GCC unroll 16 - for (int i = 0; i < B_COL; i++) { - asm volatile("fmax.s %0, %1, %2" - : "=f"(rowmax) - : "f"(rowmax), "f"(smem_S[first_thread_offset + i])); - } - smem_rowmax_this[row] = rowmax; - - // update previous rowmax - // i.e. mi_new = max(mi, mij) - float prev_rowmax = smem_rowmax[row]; - // stage prev rowmax in scratchpad for warp-wide broadcast - warp_smem[0] = prev_rowmax; - asm volatile("fmax.s %0, %1, %2" - : "=f"(rowmax) - : "f"(rowmax), "f"(prev_rowmax)); - smem_rowmax[row] = rowmax; - } - -#else - static_assert((B_COL % NUM_THREADS) == 0, - "B_COL must be a multiple of NUM_THREADS"); - float per_thread_max = FLT_MIN; -#pragma GCC unroll - for (int i = 0; i < per_row_iter; i++) { - const float next = smem_S[thread_offset]; - asm volatile("fmax.s %0, %1, %2" - : "=f"(per_thread_max) - : "f"(per_thread_max), "f"(next)); - thread_offset += NUM_THREADS; - } - // stage per-thread max value in smem - warp_smem[tid_in_warp] = per_thread_max; - - // sync writes to warp_smem - threadblock_barrier(threadblock_id_in_cluster, - warps_per_threadblock_per_core); - -// #define PARALLEL_ROWMAX -#ifndef PARALLEL_ROWMAX - // elect 0-th thread to reduce all other thread's values in the warp - if (tid_in_warp == 0) { - float rowmax = per_thread_max; - for (int i = 1; i < NUM_THREADS; i++) { - float other = warp_smem[i]; - asm volatile("fmax.s %0, %1, %2" - : "=f"(rowmax) - : "f"(rowmax), "f"(other)); - } - smem_rowmax_this[row] = rowmax; - - // update previous rowmax - // i.e. 
mi_new = max(mi, mij) - float prev_rowmax = smem_rowmax[row]; - // stage prev rowmax in scratchpad for warp-wide broadcast - warp_smem[0] = prev_rowmax; - asm volatile("fmax.s %0, %1, %2" - : "=f"(rowmax) - : "f"(rowmax), "f"(prev_rowmax)); - smem_rowmax[row] = rowmax; - } -#else - if (warp_id < warps_in_threadblock / NUM_THREADS) { - const uint32_t row = row_offset + NUM_THREADS * warp_id + tid_in_warp; - float *const thread_smem = smem_scratchpad + (tid_in_warp * NUM_THREADS); - float rowmax = FLT_MIN; -#pragma GCC unroll - for (int i = 0; i < NUM_THREADS; i++) { - const float f = thread_smem[i]; - asm volatile("fmax.s %0, %1, %2" : "=f"(rowmax) : "f"(rowmax), "f"(f)); - } - smem_rowmax_this[row] = rowmax; - - // update previous rowmax - // i.e. mi_new = max(mi, mij) - float prev_rowmax = smem_rowmax[row]; - // stage prev rowmax in scratchpad for warp-wide broadcast - thread_smem[0] = prev_rowmax; - asm volatile("fmax.s %0, %1, %2" - : "=f"(rowmax) - : "f"(rowmax), "f"(prev_rowmax)); - smem_rowmax[row] = rowmax; - } -#endif // PARALLEL_ROWMAX -#endif // DUMB_ROWMAX - - threadblock_barrier(threadblock_id_in_cluster, - warps_per_threadblock_per_core); - - // broadcast prev rowmax to all threads in the warp - // NOTE: memory consistency is a little sketchy here - const float rowmax_prev = warp_smem[0]; - const float rowmax_this = smem_rowmax_this[row]; - - // exponential - // - // B_ROW / (B_ROW * B_COL / (exp_elem * threads_per_threadblock)) - // const uint32_t row_stride = - // (exp_elem_per_thread * threads_per_threadblock) / B_COL; - - // broadcast updated rowmax to all threads in the warp - const float rowmax_new = smem_rowmax[row]; - - asm volatile("flashattn_exp_p_start_%=:" ::); - - thread_offset = first_thread_offset + tid_in_warp; -#pragma GCC unroll - for (int i = 0; i < per_row_iter; i++) { - float f0 = smem_S[thread_offset]; - - f0 -= rowmax_new; - - // 2nd-order Taylor approximation - float exp = 1.0f; - exp += exponential_taylor_term<1>(f0); - exp += 
exponential_taylor_term<2>(f0); - - // Store S transposed to the shared memory - - smem_P[thread_offset] = exp; - - thread_offset += NUM_THREADS; - } - - asm volatile("flashattn_exp_p_end_%=:" ::); - - - threadblock_barrier(threadblock_id_in_cluster, - warps_per_threadblock_per_core); - - // rowsum - // - // two-level tree reduction, similar to rowmax - - asm volatile("flashattn_rowsum_start_%=:" ::); - - float per_thread_sum = 0.0f; - - thread_offset = first_thread_offset + tid_in_warp; -#pragma GCC unroll - for (int i = 0; i < per_row_iter; i++) { - per_thread_sum += smem_P[thread_offset]; - thread_offset += NUM_THREADS; - } - // stage per-thread sum value in smem - // FIXME: threadblock_id needs to be in here too - warp_smem = smem_scratchpad + (warp_id * NUM_THREADS); - warp_smem[tid_in_warp] = per_thread_sum; - - // sync writes to warp_smem - threadblock_barrier(threadblock_id_in_cluster, - warps_per_threadblock_per_core); - - // 0-th thread collects all other thread's values in the warp - if (tid_in_warp == 0) { - float rowsum = per_thread_sum; - for (int iter = 1; iter < NUM_THREADS; iter++) { - float other = warp_smem[iter]; - rowsum += other; - } - - const float mi_prev = rowmax_prev; - const float mi_this = rowmax_this; - - const float x = mi_prev - mi_this; - // 2nd-order Taylor approximation - float exp = 1.0f; - exp += exponential_taylor_term<1>(x); - exp += exponential_taylor_term<2>(x); - - // update rowsum - const float rowsum_prev = smem_rowsum[row]; - float rowsum_new = exp * rowsum_prev + rowsum; - - smem_rowsum[row] = rowsum_new; - } - - asm volatile("flashattn_rowsum_end_%=:" ::); - - threadblock_barrier(threadblock_id_in_cluster, - warps_per_threadblock_per_core); - - // compute Oi rescale factor - // FIXME: parallelize this across threads - // - asm volatile("flashattn_rescale_factor_start_%=:" ::); - - thread_offset = first_thread_offset + tid_in_warp; -#pragma GCC unroll - for (int i = 0; i < per_row_iter; i++) { - const float mi_prev = 
rowmax_prev; - const float mi_new = rowmax_new; - - const float x = mi_prev - mi_new; - // 2nd-order Taylor approximation - float exp = 1.0f; - exp += exponential_taylor_term<1>(x); - exp += exponential_taylor_term<2>(x); - - // @perf: div vs. expansion on e(-x)? - smem_O_row_scale[row] = 1.0f / exp; - - thread_offset += NUM_THREADS; - } - - asm volatile("flashattn_rescale_factor_end_%=:" ::); - - threadblock_barrier(threadblock_id_in_cluster, - warps_per_threadblock_per_core); - } - - asm volatile("thread_block_online_softmax_finish_%=:" ::); -} - -__attribute__((always_inline)) inline void thread_block_O_rescale( - const float *smem_O_in, float *smem_O_out, const float *smem_O_row_scale, - const uint32_t tid_in_threadblock, const uint32_t threads_per_threadblock, - const uint32_t threadblock_id_in_cluster) { - asm volatile("thread_block_O_rescale_start_%=:" ::); - - const uint32_t tid_in_warp = tid_in_threadblock % NUM_THREADS; - const uint32_t warp_id = tid_in_threadblock / NUM_THREADS; - const uint32_t warps_in_threadblock = threads_per_threadblock / NUM_THREADS; - const uint32_t warps_per_threadblock_per_core = - warps_in_threadblock / CORES_PER_CLUSTER; - -#pragma GCC unroll 1 - for (int row_offset = 0; row_offset < B_ROW; - row_offset += warps_in_threadblock) { - const uint32_t row = row_offset + warp_id; - const uint32_t first_thread_offset = B_COL * row; - constexpr uint32_t per_row_iter = B_COL / NUM_THREADS; - uint32_t thread_offset = first_thread_offset + tid_in_warp; - - // Oi rescale - // -#pragma GCC unroll - for (int i = 0; i < per_row_iter; i++) { - const float o = smem_O_in[thread_offset]; - const float scale = smem_O_row_scale[row]; - smem_O_out[thread_offset] = (o * scale); - - thread_offset += NUM_THREADS; - } - } - - asm volatile("thread_block_O_rescale_finish_%=:" ::); -} +// temporary safety stop +static_assert(TENSOR_CORE && WARP_SPECIALIZED); void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { // @perf: All threads are running 
these compute whose result is mostly same @@ -500,11 +60,6 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { warpgroup_id % warpgroups_per_cluster; const uint32_t tid_in_warpgroup = tid_in_threadblock % threads_per_warpgroup; - // FIXME do proper software pipelining - // if (WARP_SPECIALIZED && warpgroup_id_in_cluster != 1) { - // return; - // } - const uint32_t dim_seqlen = arg->dim_seqlen; const uint32_t dim_headdim = arg->dim_headdim; @@ -533,82 +88,83 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { constexpr uint32_t smem_QK_size = B_ROW * B_COL; constexpr uint32_t smem_V_size = B_COL * HEADDIM; constexpr uint32_t smem_O_size = B_COL * HEADDIM; + static_assert( + threads_per_threadblock == NUM_WARPS * NUM_THREADS * CORES_PER_CLUSTER, + "flashattention kernel assumes 1 threadblock occupancy per cluster"); uint8_t *smem_per_threadblock = reinterpret_cast( - DEV_SMEM_START_ADDR + - sizeof(float_type) * - (smem_QK_size + smem_V_size + smem_O_size) * - threadblock_id_in_cluster); - float *smem_cursor = reinterpret_cast(DEV_FAKE_SMEM_START_ADDR); - float *smem_Q0 = smem_cursor; - smem_cursor += smem_Q_size; - float *smem_Q1 = smem_cursor; - smem_cursor += smem_Q_size; - float *smem_K0 = smem_cursor; - smem_cursor += smem_K_size; - float *smem_K1 = smem_cursor; - smem_cursor += smem_K_size; - float *smem_V0 = smem_cursor; - smem_cursor += smem_V_size; - float *smem_V1 = smem_cursor; - smem_cursor += smem_V_size; - float *smem_S0 = smem_cursor; - smem_cursor += smem_QK_size; - float *smem_S1 = smem_cursor; - smem_cursor += smem_QK_size; - float *smem_P0 = smem_S0; // in-place update - float *smem_P1 = smem_S1; // in-place update - float *smem_O0 = smem_cursor; - smem_cursor += smem_O_size; - float *smem_O1 = smem_cursor; - smem_cursor += smem_O_size; + DEV_SMEM_START_ADDR); + constexpr uint32_t smem_start = DEV_SMEM_START_ADDR; + constexpr uint32_t smem_octet0 = 0 * (SMEM_SIZE / 8); + constexpr uint32_t smem_octet1 = 1 * (SMEM_SIZE / 
8); + constexpr uint32_t smem_octet2 = 2 * (SMEM_SIZE / 8); + constexpr uint32_t smem_octet3 = 3 * (SMEM_SIZE / 8); + constexpr uint32_t smem_octet4 = 4 * (SMEM_SIZE / 8); + constexpr uint32_t smem_octet5 = 5 * (SMEM_SIZE / 8); + constexpr uint32_t smem_octet6 = 6 * (SMEM_SIZE / 8); + constexpr uint32_t smem_octet7 = 7 * (SMEM_SIZE / 8); - // NOTE: this has to match with smem_* - constexpr uint32_t spad_addr_factor = DIM * sizeof(elem_t); - constexpr uint32_t spad_addr_Q0 = 0; - constexpr uint32_t spad_addr_Q1 = - spad_addr_Q0 + (smem_Q_size * sizeof(float) / spad_addr_factor); - constexpr uint32_t spad_addr_K0 = - spad_addr_Q1 + (smem_Q_size * sizeof(float) / spad_addr_factor); - constexpr uint32_t spad_addr_K1 = - spad_addr_K0 + (smem_K_size * sizeof(float) / spad_addr_factor); - constexpr uint32_t spad_addr_V0 = - spad_addr_K1 + (smem_K_size * sizeof(float) / spad_addr_factor); - constexpr uint32_t spad_addr_V1 = - spad_addr_V0 + (smem_V_size * sizeof(float) / spad_addr_factor); - constexpr uint32_t spad_addr_S0 = - spad_addr_V1 + (smem_V_size * sizeof(float) / spad_addr_factor); - constexpr uint32_t spad_addr_S1 = - spad_addr_S0 + (smem_QK_size * sizeof(float) / spad_addr_factor); + // allocation strategy: since the two warpgroups only access *0 and *1 + // buffers each, allocate *0 in the first half of SMEM, and *1 in the latter + // half + // at the same time, make sure Q and K are in different banks so that they + // can be accessed in parallel for GEMM; same for P and V + constexpr uint32_t smem_Q0_offset = smem_octet0; + constexpr uint32_t smem_Q1_offset = smem_octet4; + constexpr uint32_t smem_K0_offset = smem_octet1; + constexpr uint32_t smem_K1_offset = smem_octet5; + constexpr uint32_t smem_V0_offset = smem_K0_offset + smem_K_size * sizeof(float); + constexpr uint32_t smem_V1_offset = smem_K1_offset + smem_K_size * sizeof(float); + constexpr uint32_t smem_S0_offset = smem_octet2; + constexpr uint32_t smem_S1_offset = smem_octet6; + constexpr uint32_t 
smem_P0_offset = smem_Q0_offset + smem_Q_size * sizeof(float); + constexpr uint32_t smem_P1_offset = smem_Q1_offset + smem_Q_size * sizeof(float); + constexpr uint32_t smem_O0_offset = smem_octet3; + constexpr uint32_t smem_O1_offset = smem_octet7; + + float *smem_Q0 = reinterpret_cast(smem_start + smem_Q0_offset); + float *smem_Q1 = reinterpret_cast(smem_start + smem_Q1_offset); + float *smem_K0 = reinterpret_cast(smem_start + smem_K0_offset); + float *smem_K1 = reinterpret_cast(smem_start + smem_K1_offset); + float *smem_V0 = reinterpret_cast(smem_start + smem_V0_offset); + float *smem_V1 = reinterpret_cast(smem_start + smem_V1_offset); + float *smem_S0 = reinterpret_cast(smem_start + smem_S0_offset); + float *smem_S1 = reinterpret_cast(smem_start + smem_S1_offset); + float *smem_P0 = reinterpret_cast(smem_start + smem_P0_offset); + float *smem_P1 = reinterpret_cast(smem_start + smem_P1_offset); + float *smem_O0 = reinterpret_cast(smem_start + smem_O0_offset); + float *smem_O1 = reinterpret_cast(smem_start + smem_O1_offset); // allocate rowmax/rowsum storage at the end of the sharedmem address space constexpr uint32_t smem_rowmax_size = B_ROW * ROWMAX_SETS; constexpr uint32_t smem_rowsum_size = B_ROW; constexpr uint32_t smem_O_row_scale_size = B_ROW; - smem_cursor = reinterpret_cast(SMEM_ADDR_END); - smem_cursor -= smem_rowmax_size; - float *smem_rowmax_0 = smem_cursor; - smem_cursor -= smem_rowmax_size; - float *smem_rowmax_1 = smem_cursor; - smem_cursor -= smem_rowsum_size; - float *smem_rowsum_0 = smem_cursor; - smem_cursor -= smem_rowsum_size; - float *smem_rowsum_1 = smem_cursor; - smem_cursor -= smem_O_row_scale_size; - float *smem_O_row_scale_0 = smem_cursor; - smem_cursor -= smem_O_row_scale_size; - float *smem_O_row_scale_1 = smem_cursor; + float *smem_cursor_0 = smem_O0 + smem_O_size; + float *smem_cursor_1 = smem_O1 + smem_O_size; + // // FIXME: dangerous + // smem_cursor = reinterpret_cast(0xff038000); + float *smem_rowmax_0 = smem_cursor_0; + 
smem_cursor_0 += smem_rowmax_size; + float *smem_rowmax_1 = smem_cursor_1; + smem_cursor_1 += smem_rowmax_size; + float *smem_rowsum_0 = smem_cursor_0; + smem_cursor_0 += smem_rowsum_size; + float *smem_rowsum_1 = smem_cursor_1; + smem_cursor_1 += smem_rowsum_size; + float *smem_O_row_scale_0 = smem_cursor_0; + smem_cursor_0 += smem_O_row_scale_size; + float *smem_O_row_scale_1 = smem_cursor_1; + smem_cursor_1 += smem_O_row_scale_size; // sharedmem "scratchpad" area to put temporary data, e.g. for tree reduction // in rowsum // NOTE: out-of bounds is not checked - // TODO: reduce this from B_ROW to NUM_WARPS constexpr uint32_t smem_scratchpad_size = threads_per_warpgroup * 2 /*arbitrary slack*/; - smem_cursor -= smem_scratchpad_size; - float *smem_scratchpad_0 = smem_cursor; - smem_cursor -= smem_scratchpad_size; - float *smem_scratchpad_1 = smem_cursor; + float *smem_scratchpad_0 = smem_cursor_0; + smem_cursor_0 += smem_scratchpad_size; + float *smem_scratchpad_1 = smem_cursor_1; + smem_cursor_1 += smem_scratchpad_size; // select the correct buffer by warpgroup float *smem_Q = (warpgroup_id % 2) ? smem_Q1 : smem_Q0; @@ -624,26 +180,54 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { float *smem_scratchpad = (warpgroup_id % 2) ? 
smem_scratchpad_1 : smem_scratchpad_0; + static_assert(sizeof(elem_t) == sizeof(float)); + constexpr uint32_t spad_addr_factor = DIM * sizeof(elem_t); + constexpr uint32_t spad_addr_Q0 = smem_Q0_offset / spad_addr_factor; + constexpr uint32_t spad_addr_Q1 = smem_Q1_offset / spad_addr_factor; + constexpr uint32_t spad_addr_K0 = smem_K0_offset / spad_addr_factor; + constexpr uint32_t spad_addr_K1 = smem_K1_offset / spad_addr_factor; + constexpr uint32_t spad_addr_V0 = smem_V0_offset / spad_addr_factor; + constexpr uint32_t spad_addr_V1 = smem_V1_offset / spad_addr_factor; + constexpr uint32_t spad_addr_S0 = smem_S0_offset / spad_addr_factor; + constexpr uint32_t spad_addr_S1 = smem_S1_offset / spad_addr_factor; + constexpr uint32_t spad_addr_P0 = smem_P0_offset / spad_addr_factor; + constexpr uint32_t spad_addr_P1 = smem_P1_offset / spad_addr_factor; + constexpr uint32_t spad_addr_O0 = smem_O0_offset / spad_addr_factor; + constexpr uint32_t spad_addr_O1 = smem_O1_offset / spad_addr_factor; + + const auto spad_addr_Q = (warpgroup_id % 2) ? spad_addr_Q1 : spad_addr_Q0; + const auto spad_addr_K = (warpgroup_id % 2) ? spad_addr_K1 : spad_addr_K0; + const auto spad_addr_V = (warpgroup_id % 2) ? spad_addr_V1 : spad_addr_V0; + const auto spad_addr_S = (warpgroup_id % 2) ? 
spad_addr_S1 : spad_addr_S0; + // initialize rowmax/rowsum values in sharedmem - // thread_block_init_sharedmem(tid_in_warpgroup, threads_per_warpgroup, smem_O, - // smem_rowmax, smem_rowsum, smem_O_row_scale); + thread_block_init_sharedmem(tid_in_warpgroup, threads_per_warpgroup, smem_O, + smem_rowmax, smem_rowsum, smem_O_row_scale); constexpr uint32_t global_barrier_id = NUM_WARPS - 1; // arbitrary // delay warpgroup 0 by 1 iteration to do ping-pong scheduling - // if (warpgroup_id == 1) { - // threadblock_barrier(global_barrier_id, warps_per_threadblock_per_core); - // } + if (WARP_SPECIALIZED && warpgroup_id == 1) { + threadblock_barrier(global_barrier_id, warps_per_threadblock_per_core); + } + + static_assert(!GEMMINI_DMA || Q_IS_K_MAJOR, + "DMA code assumes Q matrix is stored K-major"); + + // skip everything except DMA in the loop FSM + constexpr uint32_t skips = + loop_matmul_skips(/*skip_lda=*/0, /*skip_ldb=*/0, /*skip_ldd=*/1, + /*skip_ex=*/1, /*skip_stc=*/1); if constexpr (GEMMINI_DMA) { if (tid_in_warpgroup == 0) { gemmini_extended_config_ex(WEIGHT_STATIONARY, 0, 0, 1, 0, 0); - // configure DMA for Q tile + // configure DMA for the full Q matrix gemmini_extended3_config_ld(HEADDIM * sizeof(elem_t), MVIN_SCALE_IDENTITY, false, 0); - // configure DMA for K tile - gemmini_extended3_config_ld(B_COL * sizeof(elem_t), MVIN_SCALE_IDENTITY, + // configure DMA for the full K matrix + gemmini_extended3_config_ld(dim_seqlen * sizeof(elem_t), MVIN_SCALE_IDENTITY, false, 1); // configure DMA for Q*K store gemmini_extended_config_st(B_COL * sizeof(elem_t), 0, @@ -652,62 +236,76 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { } } - // NOTE about barriers: placing barriers around thread-divergent branches may - // cause bugs, since the core doesn't check tmask for barriers. 
The compiler - // might decide to replicate vx_bar into both paths of a conditional branch, - // which will get evaluated twice along the split/join process and result in - // a different number of calls w.r.t other non-divergent warps and therefore - // stalls. + // NOTE about barriers: Placing barriers around thread-divergent branches may + // cause bugs, because the Vortex core doesn't check for tmask for barriers. + // The compiler might decide to duplicate vx_bar into both paths of a + // conditional branch, which will get evaluated twice because of the way + // branches are handled in SIMT; this might result in stalls especially when + // other warps behave differently on the branch condition. // threadblock_barrier(warpgroup_id_in_cluster, warps_per_warpgroup_per_core); // move Q and K into SMEM before the loop starts // static_assert(B_ROW == B_COL, "currently only supports square tiles"); - static_assert(warps_per_warpgroup_per_core == 8); // FIXME nocheckin - if constexpr (GEMMINI_DMA) { asm volatile("dma_move_start_%=:" ::); - if (tid_in_threadblock == 0) { + if (tid_in_warpgroup == 0) { + const float *gmem_Q_tile = gmem_Q + HEADDIM * B_ROW * warpgroup_id; + const float *gmem_K_tile = gmem_K; // configure the GMEM addresses for the DMA to read from - ROCC_INSTRUCTION_RS1_RS2(XCUSTOM_ACC, (uint64_t)(gmem_Q), - (uint64_t)(gmem_K), k_LOOP_WS_CONFIG_ADDRS_AB) + ROCC_INSTRUCTION_RS1_RS2(XCUSTOM_ACC, (uint64_t)(gmem_Q_tile), + (uint64_t)(gmem_K_tile), + k_LOOP_WS_CONFIG_ADDRS_AB) // configure address strides for the DMA - GEMMINI_CISC_CMD_R((B_COL << 16) | (HEADDIM << 8) | + GEMMINI_CISC_CMD_R((dim_seqlen << 16) | (HEADDIM << 8) | 8 /*k_LOOP_WS_CONFIG_STRIDES_AB*/); gemmini_fence(); -#define GEMMINI_DMA_CISC +// #define GEMMINI_DMA_CISC #ifdef GEMMINI_DMA_CISC - GEMMINI_CISC_CMD_I(10); + GEMMINI_CISC_CMD_I(9); gemmini_fence(); #else - // skip everything except DMA in the loop FSM - constexpr uint32_t skips = - loop_matmul_skips(/*skip_lda=*/0, /*skip_ldb=*/0, 
/*skip_ldd=*/1, - /*skip_ex=*/1, /*skip_stc=*/1); - + // do DMA + // // among other things, this also configures CONFIG_BOUNDS so that the // DMA knows the full matrix dimensions sp_tiled_matmul_full_spad_ws( - spad_addr_Q0, spad_addr_K0, - /*spad_D=*/0, /*spad_C=*/spad_addr_S0, - /*I=*/(B_ROW / DIM), /*J=*/(HEADDIM / DIM), /*K=*/(B_COL / DIM), + spad_addr_Q, spad_addr_K, + /*spad_D=*/0, /*spad_C=*/spad_addr_S, + /*I=*/(B_ROW / DIM), /*J=*/(B_COL / DIM), /*K=*/(HEADDIM / DIM), /*pad_I=*/0, /*pad_J=*/0, /*pad_K=*/0, /*a_transpose=*/0, /*b_transpose=*/0, /*full_C=*/0, /*low_D=*/0, /*acc=*/0, /*act=*/NO_ACTIVATION, /*skips=*/skips); gemmini_fence(); #endif + + // re-configure DMA for K and V load that will later happen in the loop + // GMEM addr stride for K + gemmini_extended3_config_ld(dim_seqlen * sizeof(elem_t), MVIN_SCALE_IDENTITY, + false, 0); + // GMEM addr stride for V + gemmini_extended3_config_ld(HEADDIM * sizeof(elem_t), MVIN_SCALE_IDENTITY, + false, 1); + gemmini_fence(); } asm volatile("dma_move_end_%=:" ::); } else { // load Q; this stays in SMEM for the entire loop - load_tile_to_smem( - dim_seqlen, warpgroup_id, 0 /* dim_k == headdim */, gmem_Q, smem_Q, - tid_in_warpgroup); + if constexpr (Q_IS_K_MAJOR) { + load_tile_to_smem( + HEADDIM, warpgroup_id, 0 /* dim_k == headdim */, gmem_Q, smem_Q, + tid_in_warpgroup); + } else { + load_tile_to_smem( + dim_seqlen, warpgroup_id, 0 /* dim_k == headdim */, gmem_Q, smem_Q, + tid_in_warpgroup); + } // load K load_tile_to_smem(smem_Q0, gmem_tmp_d0, tid_in_warpgroup, + // threads_per_warpgroup, warpgroup_id_in_cluster); + // thread_block_copy_tile(smem_K0, gmem_tmp_d1, tid_in_warpgroup, + // threads_per_warpgroup, warpgroup_id_in_cluster); - // threadblock_barrier(warpgroup_id_in_cluster, warps_per_warpgroup_per_core); - } + // threadblock_barrier(warpgroup_id_in_cluster, warps_per_warpgroup_per_core); + // } -#if 0 asm volatile ("tile_loop_start_%=:" :: ); // "inner loop" along the columns of K^T const uint32_t 
k_tiles = (dim_seqlen / B_COL); - for (uint32_t tile_k = 0; tile_k < k_tiles; tile_k++) { + for (uint32_t tile_k = 0; tile_k < (4 /* for perf measurement */ * k_tiles); + tile_k++) { // float *smem_P_produce = (tile_k % 2) ? smem_P0 : smem_P1; // float *smem_P_consume = (tile_k % 2) ? smem_P1 : smem_P0; // float *smem_V_produce = (tile_k % 2) ? smem_V0 : smem_V1; @@ -741,6 +341,8 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { // float *smem_O_row_scale_consume = // (tile_k % 2) ? smem_O_row_scale_1 : smem_O_row_scale_0; + asm volatile("gemm_qk_start_%=:" ::); + constexpr bool skip_gemm_qk = false; if constexpr (!skip_gemm_qk) { // GEMM I: S = Q*K @@ -751,25 +353,44 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { initialize_accum_regs<0>(); initialize_accum_regs<1>(); - thread_block_gemm_single_tile< - float, MemLayout::MN_major, MemLayout::MN_major, B_ROW, B_COL, - HEADDIM, /*leading_dim_a=*/0, /*leading_dim_b=*/0, - /*load_accum=*/false, - /*write_to_smem=*/true>( - smem_Q, smem_K, nullptr /*ignore accum*/, smem_S, tid_in_warpgroup, - threads_per_warpgroup, warpgroups_per_cluster, - warpgroup_id_in_cluster); + if constexpr (GEMMINI_DMA) { + thread_block_gemm_single_tile< + float, MemLayout::block_row_major, MemLayout::block_row_major, + B_ROW, B_COL, HEADDIM, /*leading_dim_a=*/0, /*leading_dim_b=*/0, + /*load_accum=*/false, + /*write_to_smem=*/true>( + smem_Q, smem_K, nullptr /*ignore accum*/, smem_S, + tid_in_warpgroup, threads_per_warpgroup, warpgroups_per_cluster, + warpgroup_id_in_cluster); + } else if constexpr (Q_IS_K_MAJOR) { + thread_block_gemm_single_tile< + float, MemLayout::K_major, MemLayout::MN_major, B_ROW, B_COL, + HEADDIM, /*leading_dim_a=*/0, /*leading_dim_b=*/0, + /*load_accum=*/false, + /*write_to_smem=*/true>( + smem_Q, smem_K, nullptr /*ignore accum*/, smem_S, + tid_in_warpgroup, threads_per_warpgroup, warpgroups_per_cluster, + warpgroup_id_in_cluster); + } else { + thread_block_gemm_single_tile< + float, 
MemLayout::MN_major, MemLayout::MN_major, B_ROW, B_COL, + HEADDIM, /*leading_dim_a=*/0, /*leading_dim_b=*/0, + /*load_accum=*/false, + /*write_to_smem=*/true>( + smem_Q, smem_K, nullptr /*ignore accum*/, smem_S, + tid_in_warpgroup, threads_per_warpgroup, warpgroups_per_cluster, + warpgroup_id_in_cluster); + } } else { // when warp-specialized, there's only enough warps to do 64x32 tile // size so we need to do 2 GEMM calls static_assert(B_ROW / 2 == 32, "tile size assumption for warp-specialization not met"); - // assumes smem_Q is K-major - // FIXME: fix this to MN-major float *smem_Q_half0 = smem_Q; - float *smem_Q_half1 = smem_Q + (B_ROW / 2); // MN-major - // float *smem_Q_half1 = smem_Q + (B_ROW / 2) * HEADDIM; // K-major + float *smem_Q_half1 = (Q_IS_K_MAJOR || GEMMINI_DMA) + ? smem_Q + (B_ROW / 2) * HEADDIM + : smem_Q + (B_ROW / 2); float *smem_S_half0 = smem_S; float *smem_S_half1 = smem_S + (B_ROW / 2) * B_COL; @@ -778,26 +399,92 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { initialize_accum_regs<1>(); // split by rows into 2 chunks - thread_block_gemm_single_tile< - float, MemLayout::MN_major, MemLayout::MN_major, B_ROW / 2, - B_COL, HEADDIM, /*leading_dim_a=*/B_ROW, /*leading_dim_b=*/0, - /*load_accum=*/false, - /*write_to_smem=*/true>( - smem_Q_half0, smem_K, nullptr /*ignore accum*/, smem_S_half0, - tid_in_warpgroup, threads_per_warpgroup, warpgroups_per_cluster, - warpgroup_id_in_cluster); + if constexpr (GEMMINI_DMA) { + if constexpr (GEMMINI_DMA_FAST) { + thread_block_gemm_single_tile( + smem_Q_half0, smem_K, nullptr /*ignore accum*/, smem_S_half0, + tid_in_warpgroup, threads_per_warpgroup, warpgroups_per_cluster, + warpgroup_id_in_cluster); + } else { + thread_block_gemm_single_tile( + smem_Q_half0, smem_K, nullptr /*ignore accum*/, smem_S_half0, + tid_in_warpgroup, threads_per_warpgroup, warpgroups_per_cluster, + warpgroup_id_in_cluster); + } + } else if constexpr (Q_IS_K_MAJOR) { + thread_block_gemm_single_tile< + float, 
MemLayout::K_major, MemLayout::MN_major, B_ROW / 2, B_COL, + HEADDIM, /*leading_dim_a=*/0, /*leading_dim_b=*/0, + /*load_accum=*/false, + /*write_to_smem=*/true>( + smem_Q_half0, smem_K, nullptr /*ignore accum*/, smem_S_half0, + tid_in_warpgroup, threads_per_warpgroup, warpgroups_per_cluster, + warpgroup_id_in_cluster); + } else { + thread_block_gemm_single_tile< + float, MemLayout::MN_major, MemLayout::MN_major, B_ROW / 2, B_COL, + HEADDIM, /*leading_dim_a=*/B_ROW, /*leading_dim_b=*/0, + /*load_accum=*/false, + /*write_to_smem=*/true>( + smem_Q_half0, smem_K, nullptr /*ignore accum*/, smem_S_half0, + tid_in_warpgroup, threads_per_warpgroup, warpgroups_per_cluster, + warpgroup_id_in_cluster); + } initialize_accum_regs<0>(); initialize_accum_regs<1>(); - thread_block_gemm_single_tile< - float, MemLayout::MN_major, MemLayout::MN_major, B_ROW / 2, - B_COL, HEADDIM, /*leading_dim_a=*/B_ROW, /*leading_dim_b=*/0, - /*load_accum=*/false, - /*write_to_smem=*/true>( - smem_Q_half1, smem_K, nullptr /*ignore accum*/, smem_S_half1, - tid_in_warpgroup, threads_per_warpgroup, warpgroups_per_cluster, - warpgroup_id_in_cluster); + if constexpr (GEMMINI_DMA) { + if constexpr (GEMMINI_DMA_FAST) { + thread_block_gemm_single_tile( + smem_Q_half1, smem_K, nullptr /*ignore accum*/, smem_S_half1, + tid_in_warpgroup, threads_per_warpgroup, warpgroups_per_cluster, + warpgroup_id_in_cluster); + } else { + thread_block_gemm_single_tile( + smem_Q_half1, smem_K, nullptr /*ignore accum*/, smem_S_half1, + tid_in_warpgroup, threads_per_warpgroup, warpgroups_per_cluster, + warpgroup_id_in_cluster); + } + } else if constexpr (Q_IS_K_MAJOR) { + thread_block_gemm_single_tile< + float, MemLayout::K_major, MemLayout::MN_major, B_ROW / 2, B_COL, + HEADDIM, /*leading_dim_a=*/0, /*leading_dim_b=*/0, + /*load_accum=*/false, + /*write_to_smem=*/true>( + smem_Q_half1, smem_K, nullptr /*ignore accum*/, smem_S_half1, + tid_in_warpgroup, threads_per_warpgroup, warpgroups_per_cluster, + warpgroup_id_in_cluster); 
+ } else { + thread_block_gemm_single_tile< + float, MemLayout::MN_major, MemLayout::MN_major, B_ROW / 2, B_COL, + HEADDIM, /*leading_dim_a=*/B_ROW, /*leading_dim_b=*/0, + /*load_accum=*/false, + /*write_to_smem=*/true>( + smem_Q_half1, smem_K, nullptr /*ignore accum*/, smem_S_half1, + tid_in_warpgroup, threads_per_warpgroup, warpgroups_per_cluster, + warpgroup_id_in_cluster); + } } } else { // load Q*K @@ -810,14 +497,16 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { // protect write to SMEM (smem_S) before softmax threadblock_barrier(warpgroup_id_in_cluster, warps_per_warpgroup_per_core); + asm volatile("gemm_qk_finish_%=:" ::); + if constexpr (DEBUG) { if (warpgroup_id == 0) { if (tile_k == 0) { - thread_block_copy_tile(smem_S, gmem_tmp_d0, + thread_block_copy_tile(smem_S, gmem_tmp_d0, tid_in_warpgroup, threads_per_warpgroup, warpgroup_id_in_cluster); } else if (tile_k == 1) { - thread_block_copy_tile(smem_S, gmem_tmp_d1, + thread_block_copy_tile(smem_S, gmem_tmp_d1, tid_in_warpgroup, threads_per_warpgroup, warpgroup_id_in_cluster); } @@ -837,22 +526,58 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { smem_scratchpad, smem_rowmax, smem_rowsum, smem_O_row_scale); + // FIXME: unnecessary? + threadblock_barrier(warpgroup_id_in_cluster, warps_per_warpgroup_per_core); + // data movement for K and V // // Q stays in SMEM for the entire loop - // - // load K for the next iteration - load_tile_to_smem( - dim_seqlen, tile_k + 1, 0 /* dim_k == headdim */, gmem_K, smem_K, - tid_in_warpgroup); + asm volatile("move_k_v_start_%=:" ::); + if constexpr (GEMMINI_DMA) { + // NOTE: Beware of race conditions; with warp specialization, we need to + // make sure below command code to DMA is not executed simultaneously + // from the two warpgroups (which will result in hardware fault). + // Currently the ping-pong scheduling scheme prevents that. 
+ if (tid_in_warpgroup == 0) { + // configure GMEM addresses for K and V tiles + // load K for the next iteration + const float *gmem_K_tile = gmem_K + (B_COL * (tile_k + 1)); + // load V for the current iteration + const float *gmem_V_tile = gmem_V + (HEADDIM * B_COL * tile_k); + ROCC_INSTRUCTION_RS1_RS2(XCUSTOM_ACC, (uint64_t)(gmem_K_tile), + (uint64_t)(gmem_V_tile), + k_LOOP_WS_CONFIG_ADDRS_AB) + // configure address strides for the DMA + // FIXME: unnecessary? + GEMMINI_CISC_CMD_R((HEADDIM /*V*/ << 16) | (dim_seqlen /*KT*/ << 8) | + 8 /*k_LOOP_WS_CONFIG_STRIDES_AB*/); + gemmini_fence(); - // load V for the current iteration - // V dimension is [seqlen, headdim], stored N(headdim)-major - load_tile_to_smem( - HEADDIM, 0 /* full N-dimension */, tile_k, gmem_V, smem_V, - tid_in_warpgroup); + // do DMA + sp_tiled_matmul_full_spad_ws( + spad_addr_K, spad_addr_V, + /*spad_D=*/0, /*spad_C=*/spad_addr_S, + /*I=*/(HEADDIM / DIM), /*J=*/(HEADDIM / DIM), /*K=*/(B_COL / DIM), + /*pad_I=*/0, /*pad_J=*/0, /*pad_K=*/0, + /*a_transpose=*/0, /*b_transpose=*/0, /*full_C=*/0, /*low_D=*/0, + /*acc=*/0, /*act=*/NO_ACTIVATION, /*skips=*/skips); + gemmini_fence(); + } + } else { + // load K for the next iteration + load_tile_to_smem( + dim_seqlen, tile_k + 1, 0 /* dim_k == headdim */, gmem_K, smem_K, + tid_in_warpgroup); + + // load V for the current iteration + // V dimension is [seqlen, headdim], stored N(headdim)-major + load_tile_to_smem( + HEADDIM, 0 /* full N-dimension */, tile_k, gmem_V, smem_V, + tid_in_warpgroup); + } + asm volatile("move_k_v_finish_%=:" ::); // protect write to SMEM threadblock_barrier(warpgroup_id_in_cluster, warps_per_warpgroup_per_core); @@ -883,9 +608,8 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { // inter-warpgroup barrier before GEMM II threadblock_barrier(global_barrier_id, warps_per_threadblock_per_core); - // GEMM II: O = O + P*V - // Oi rescale + // TODO: move this back to after softmax for better load-balancing 
thread_block_O_rescale(smem_O, smem_O /*in-place*/, smem_O_row_scale, tid_in_warpgroup, threads_per_warpgroup, warpgroup_id_in_cluster); @@ -897,17 +621,17 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { if (warpgroup_id == 0) { // O before PV if (tile_k == 0) { - thread_block_copy_tile(smem_P, gmem_tmp_d2, tid_in_warpgroup, + thread_block_copy_tile(smem_P, gmem_tmp_d2, tid_in_warpgroup, threads_per_warpgroup, warpgroup_id_in_cluster); - thread_block_copy_tile(smem_O, gmem_tmp_d4, tid_in_warpgroup, + thread_block_copy_tile(smem_O, gmem_tmp_d4, tid_in_warpgroup, threads_per_warpgroup, warpgroup_id_in_cluster); } else if (tile_k == 1) { - thread_block_copy_tile(smem_P, gmem_tmp_d3, tid_in_warpgroup, + thread_block_copy_tile(smem_P, gmem_tmp_d3, tid_in_warpgroup, threads_per_warpgroup, warpgroup_id_in_cluster); - thread_block_copy_tile(smem_O, gmem_tmp_d5, tid_in_warpgroup, + thread_block_copy_tile(smem_O, gmem_tmp_d5, tid_in_warpgroup, threads_per_warpgroup, warpgroup_id_in_cluster); } @@ -917,39 +641,57 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { } } + // GEMM II: O = O + P*V + + asm volatile("gemm_pv_start_%=:" ::); + if constexpr (!WARP_SPECIALIZED) { // clear out accumulators before GEMM initialize_accum_regs<0>(); initialize_accum_regs<1>(); - thread_block_gemm_single_tile( - smem_P, smem_V, smem_O /*load accum*/, smem_O, tid_in_warpgroup, - threads_per_warpgroup, warpgroups_per_cluster, - warpgroup_id_in_cluster); - - // FIXME: wrong but fast - // thread_block_gemm_single_tile( - // smem_P, smem_V, smem_O /*load accum*/, smem_O, - // tid_in_warpgroup, threads_per_warpgroup, - // warpgroups_per_cluster, warpgroup_id_in_cluster); + if constexpr (GEMMINI_DMA) { + thread_block_gemm_single_tile< + float, MemLayout::K_major /* P matrix is row-major */, + MemLayout::block_row_major, B_ROW, HEADDIM, B_COL, + /*leading_dim_a=*/0, /*leading_dim_b=*/0, + /*load_accum=*/true, + /*write_to_smem=*/true>( + smem_P, smem_V, smem_O /*load 
accum*/, smem_O, tid_in_warpgroup, + threads_per_warpgroup, warpgroups_per_cluster, + warpgroup_id_in_cluster); + } else { + thread_block_gemm_single_tile( + smem_P, smem_V, smem_O /*load accum*/, smem_O, tid_in_warpgroup, + threads_per_warpgroup, warpgroups_per_cluster, + warpgroup_id_in_cluster); + // FIXME: wrong but fast + // thread_block_gemm_single_tile( + // smem_P, smem_V, smem_O /*load accum*/, smem_O, + // tid_in_warpgroup, threads_per_warpgroup, + // warpgroups_per_cluster, warpgroup_id_in_cluster); + } } else { // when warp-specialized, there's only enough warps to do 64x32 tile // size so we need to do 2 GEMM calls static_assert(B_ROW / 2 == 32, "tile size assumption for warp-specialization not met"); - // assumes smem_P is K-major float *smem_P_half0 = smem_P; - float *smem_P_half1 = smem_P + (B_ROW / 2) * B_COL; + float *smem_P_half1 = (Q_IS_K_MAJOR || GEMMINI_DMA) + ? smem_P + (B_ROW / 2) * B_COL + : smem_P + (B_ROW / 2); float *smem_O_half0 = smem_O; float *smem_O_half1 = smem_O + (B_ROW / 2) * HEADDIM; @@ -958,39 +700,111 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { initialize_accum_regs<1>(); // split by rows into 2 chunks - thread_block_gemm_single_tile< - float, MemLayout::K_major, MemLayout::MN_major, B_ROW / 2, HEADDIM, - B_COL, /*leading_dim_a=*/0, /*leading_dim_b=*/0, - /*load_accum=*/true, - /*write_to_smem=*/true>( - smem_P_half0, smem_V, smem_O_half0 /*load accum*/, smem_O_half0, - tid_in_warpgroup, threads_per_warpgroup, warpgroups_per_cluster, - warpgroup_id_in_cluster); + if constexpr (GEMMINI_DMA) { + if constexpr (GEMMINI_DMA_FAST) { + thread_block_gemm_single_tile( + smem_P_half0, smem_V, smem_O_half0 /*load accum*/, smem_O_half0, + tid_in_warpgroup, threads_per_warpgroup, warpgroups_per_cluster, + warpgroup_id_in_cluster); + } else { + thread_block_gemm_single_tile< + float, MemLayout::K_major /* P matrix is row-major */, + MemLayout::block_row_major, B_ROW / 2, HEADDIM, B_COL, + /*leading_dim_a=*/0, + 
/*leading_dim_b=*/0, + /*load_accum=*/true, + /*write_to_smem=*/true>( + smem_P_half0, smem_V, smem_O_half0 /*load accum*/, smem_O_half0, + tid_in_warpgroup, threads_per_warpgroup, warpgroups_per_cluster, + warpgroup_id_in_cluster); + } + } else if constexpr (Q_IS_K_MAJOR) { + thread_block_gemm_single_tile< + float, MemLayout::K_major, MemLayout::MN_major, B_ROW / 2, HEADDIM, + B_COL, /*leading_dim_a=*/0, /*leading_dim_b=*/0, + /*load_accum=*/true, + /*write_to_smem=*/true>( + smem_P_half0, smem_V, smem_O_half0 /*load accum*/, smem_O_half0, + tid_in_warpgroup, threads_per_warpgroup, warpgroups_per_cluster, + warpgroup_id_in_cluster); + } else { + thread_block_gemm_single_tile< + float, MemLayout::MN_major, MemLayout::MN_major, B_ROW / 2, HEADDIM, + B_COL, /*leading_dim_a=*/B_ROW, /*leading_dim_b=*/0, + /*load_accum=*/true, + /*write_to_smem=*/true>( + smem_P_half0, smem_V, smem_O_half0 /*load accum*/, smem_O_half0, + tid_in_warpgroup, threads_per_warpgroup, warpgroups_per_cluster, + warpgroup_id_in_cluster); + } initialize_accum_regs<0>(); initialize_accum_regs<1>(); - thread_block_gemm_single_tile< - float, MemLayout::K_major, MemLayout::MN_major, B_ROW / 2, HEADDIM, - B_COL, /*leading_dim_a=*/0, /*leading_dim_b=*/0, - /*load_accum=*/true, - /*write_to_smem=*/true>( - smem_P_half1, smem_V, smem_O_half1 /*load accum*/, smem_O_half1, - tid_in_warpgroup, threads_per_warpgroup, warpgroups_per_cluster, - warpgroup_id_in_cluster); + if constexpr (GEMMINI_DMA) { + if constexpr (GEMMINI_DMA_FAST) { + thread_block_gemm_single_tile( + smem_P_half1, smem_V, smem_O_half1 /*load accum*/, smem_O_half1, + tid_in_warpgroup, threads_per_warpgroup, warpgroups_per_cluster, + warpgroup_id_in_cluster); + } else { + thread_block_gemm_single_tile< + float, MemLayout::K_major /* P matrix is row-major */, + MemLayout::block_row_major, B_ROW / 2, HEADDIM, B_COL, + /*leading_dim_a=*/0, + /*leading_dim_b=*/0, + /*load_accum=*/true, + /*write_to_smem=*/true>( + smem_P_half1, smem_V, 
smem_O_half1 /*load accum*/, smem_O_half1, + tid_in_warpgroup, threads_per_warpgroup, warpgroups_per_cluster, + warpgroup_id_in_cluster); + } + } else if constexpr (Q_IS_K_MAJOR) { + thread_block_gemm_single_tile< + float, MemLayout::K_major, MemLayout::MN_major, B_ROW / 2, HEADDIM, + B_COL, /*leading_dim_a=*/0, /*leading_dim_b=*/0, + /*load_accum=*/true, + /*write_to_smem=*/true>( + smem_P_half1, smem_V, smem_O_half1 /*load accum*/, smem_O_half1, + tid_in_warpgroup, threads_per_warpgroup, warpgroups_per_cluster, + warpgroup_id_in_cluster); + } else { + thread_block_gemm_single_tile< + float, MemLayout::MN_major, MemLayout::MN_major, B_ROW / 2, HEADDIM, + B_COL, /*leading_dim_a=*/B_ROW, /*leading_dim_b=*/0, + /*load_accum=*/true, + /*write_to_smem=*/true>( + smem_P_half1, smem_V, smem_O_half1 /*load accum*/, smem_O_half1, + tid_in_warpgroup, threads_per_warpgroup, warpgroups_per_cluster, + warpgroup_id_in_cluster); + } } threadblock_barrier(warpgroup_id_in_cluster, warps_per_warpgroup_per_core); + asm volatile("gemm_pv_finish_%=:" ::); + if constexpr (DEBUG) { if (warpgroup_id == 0) { // O after PV if (tile_k == 0) { - thread_block_copy_tile(smem_O, gmem_tmp_d6, tid_in_warpgroup, + thread_block_copy_tile(smem_O, gmem_tmp_d6, tid_in_warpgroup, threads_per_warpgroup, warpgroup_id_in_cluster); } else if (tile_k == 1) { - thread_block_copy_tile(smem_O, gmem_tmp_d7, tid_in_warpgroup, + thread_block_copy_tile(smem_O, gmem_tmp_d7, tid_in_warpgroup, threads_per_warpgroup, warpgroup_id_in_cluster); } @@ -999,13 +813,8 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { warps_per_warpgroup_per_core); } } - - tile_iter_end: - // synchronize progress of two warpgroups - // threadblock_barrier(threadblock_id_in_cluster, - // warps_per_threadblock_per_core); - // threadblock_barrier(3, // FIXME - // NUM_WARPS); +#if 0 +#endif } asm volatile ("tile_loop_finish_%=:" :: ); @@ -1015,7 +824,6 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { if 
(warpgroup_id == 0) { threadblock_barrier(global_barrier_id, warps_per_threadblock_per_core); } -#endif } int main() { diff --git a/tests/regression/flash_attention/kernel.gemmini.cpp b/tests/regression/flash_attention/kernel.gemmini.cpp new file mode 100644 index 00000000..ac3788d4 --- /dev/null +++ b/tests/regression/flash_attention/kernel.gemmini.cpp @@ -0,0 +1,698 @@ +#include +#include +#include +#include +#include "common.h" +#include "sgemm_impl.hpp" +#include "include/gemmini.h" +#include "gemmini_mmio.h" +#include "flash_impl.hpp" + +#define FENCE_GEMM_II + +constexpr bool DEBUG = false; + +static_assert(GEMMINI_DMA && !WARP_SPECIALIZED, + "GEMMINI_DMA should be set and WARP_SPECIALIZED unset"); + +void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { + // @perf: All threads are running these compute whose result is mostly same + // across the threadblock + +#ifdef RADIANCE + constexpr uint32_t cores_per_cluster = CORES_PER_CLUSTER; +#else + constexpr uint32_t cores_per_cluster = 1; +#endif + + // FIXME: headdim not considered + constexpr uint32_t threads_per_threadblock_theoretical = + (B_ROW * B_COL) / (ELEM_PER_THREAD); + constexpr uint32_t hw_threads_per_cluster = + CORES_PER_CLUSTER * NUM_THREADS * NUM_WARPS; + // cap maximum threadblock size to # of HW threads in cluster, to prevent + // multiple "wave" invocations which slows down the kernel + constexpr uint32_t threads_per_threadblock = + (threads_per_threadblock_theoretical > hw_threads_per_cluster) + ? 
hw_threads_per_cluster + : threads_per_threadblock_theoretical; + constexpr uint32_t threadblocks_per_cluster = + hw_threads_per_cluster / threads_per_threadblock; + constexpr uint32_t warps_per_threadblock_per_core = + NUM_WARPS / threadblocks_per_cluster; + + const uint32_t threadblock_id = task_id / threads_per_threadblock; + const uint32_t threadblock_id_in_cluster = + threadblock_id % threadblocks_per_cluster; + const uint32_t tid_in_threadblock = task_id % threads_per_threadblock; + const uint32_t warp_id = tid_in_threadblock / NUM_THREADS; + constexpr uint32_t warps_in_threadblock = + threads_per_threadblock / NUM_THREADS; + + // warpgroup context + constexpr uint32_t threads_per_warpgroup = + threads_per_threadblock / (WARP_SPECIALIZED ? 2 : 1); + constexpr uint32_t warpgroups_per_cluster = + threadblocks_per_cluster * (WARP_SPECIALIZED ? 2 : 1); + const uint32_t warps_per_warpgroup_per_core = + NUM_WARPS / warpgroups_per_cluster; + const uint32_t warpgroup_id = task_id / threads_per_warpgroup; + const uint32_t warpgroup_id_in_cluster = + warpgroup_id % warpgroups_per_cluster; + const uint32_t tid_in_warpgroup = tid_in_threadblock % threads_per_warpgroup; + // // warpgroup 0: warp 0 + // // warpgroup 1: warp 1~7 + // const uint32_t warpgroup_id = (warp_id != 0); + + const uint32_t dim_seqlen = arg->dim_seqlen; + const uint32_t dim_headdim = arg->dim_headdim; + + // get global memory addresses from kernel arguments + const float *gmem_Q = reinterpret_cast(arg->addr_q); + const float *gmem_K = reinterpret_cast(arg->addr_k); + const float *gmem_V = reinterpret_cast(arg->addr_v); + float *gmem_O = reinterpret_cast(arg->addr_o); + + float *gmem_tmp_d0 = reinterpret_cast(0xd0000000UL); + float *gmem_tmp_d1 = reinterpret_cast(0xd1000000UL); + float *gmem_tmp_d2 = reinterpret_cast(0xd2000000UL); + float *gmem_tmp_d3 = reinterpret_cast(0xd3000000UL); + float *gmem_tmp_d4 = reinterpret_cast(0xd4000000UL); + float *gmem_tmp_d5 = reinterpret_cast(0xd5000000UL); + float 
*gmem_tmp_d6 = reinterpret_cast(0xd6000000UL); + float *gmem_tmp_d7 = reinterpret_cast(0xd7000000UL); + float *gmem_tmp_e0 = reinterpret_cast(0xe0000000UL); + float *gmem_tmp_e1 = reinterpret_cast(0xe1000000UL); + float *gmem_tmp_e2 = reinterpret_cast(0xe2000000UL); + float *gmem_tmp_e3 = reinterpret_cast(0xe3000000UL); + + // static shared memory allocation + // these are in float elements, not bytes + constexpr uint32_t smem_Q_size = B_ROW * HEADDIM; + constexpr uint32_t smem_K_size = B_COL * HEADDIM; + constexpr uint32_t smem_QK_size = B_ROW * B_COL; + constexpr uint32_t smem_V_size = B_COL * HEADDIM; + constexpr uint32_t smem_O_size = B_COL * HEADDIM; + static_assert( + threads_per_threadblock == NUM_WARPS * NUM_THREADS * CORES_PER_CLUSTER, + "flashattention kernel assumes 1 threadblock occupancy per cluster"); + uint8_t *smem_per_threadblock = reinterpret_cast(DEV_SMEM_START_ADDR); + constexpr uint32_t smem_start = DEV_SMEM_START_ADDR; + constexpr uint32_t smem_quart0 = 0 * (SMEM_SIZE / 4); + constexpr uint32_t smem_quart1 = 1 * (SMEM_SIZE / 4); + constexpr uint32_t smem_quart2 = 2 * (SMEM_SIZE / 4); + constexpr uint32_t smem_quart3 = 3 * (SMEM_SIZE / 4); + + // Q/V/S in quart0/1, K/P/O in quart2/3 + constexpr uint32_t smem_Q0_offset = smem_quart0; + constexpr uint32_t smem_Q1_offset = smem_quart1; + constexpr uint32_t smem_K0_offset = smem_quart2; + constexpr uint32_t smem_K1_offset = smem_quart3; + constexpr uint32_t smem_V0_offset = smem_Q0_offset + smem_Q_size * sizeof(float); + constexpr uint32_t smem_V1_offset = smem_Q1_offset + smem_Q_size * sizeof(float); + // put S1/S0 with V0/V1 so that softmax and GEMM-II doesn't cause bank + // conflicts + constexpr uint32_t smem_S0_offset = smem_V1_offset + smem_V_size * sizeof(float); + constexpr uint32_t smem_S1_offset = smem_V0_offset + smem_V_size * sizeof(float); + constexpr uint32_t smem_P0_offset = smem_K0_offset + smem_K_size * sizeof(float); + constexpr uint32_t smem_P1_offset = smem_K1_offset + 
smem_K_size * sizeof(float); + // reversed! + constexpr uint32_t smem_O0_offset = smem_P1_offset + smem_QK_size * sizeof(float); + constexpr uint32_t smem_O1_offset = smem_P0_offset + smem_QK_size * sizeof(float); // unused + + float *smem_Q0 = reinterpret_cast(smem_start + smem_Q0_offset); + float *smem_Q1 = reinterpret_cast(smem_start + smem_Q1_offset); + float *smem_K0 = reinterpret_cast(smem_start + smem_K0_offset); + float *smem_K1 = reinterpret_cast(smem_start + smem_K1_offset); + float *smem_V0 = reinterpret_cast(smem_start + smem_V0_offset); + float *smem_V1 = reinterpret_cast(smem_start + smem_V1_offset); + float *smem_S0 = reinterpret_cast(smem_start + smem_S0_offset); + float *smem_S1 = reinterpret_cast(smem_start + smem_S1_offset); + float *smem_P0 = reinterpret_cast(smem_start + smem_P0_offset); + float *smem_P1 = reinterpret_cast(smem_start + smem_P1_offset); + float *smem_O0 = reinterpret_cast(smem_start + smem_O0_offset); + float *smem_O1 = reinterpret_cast(smem_start + smem_O1_offset); + + // allocate rowmax/rowsum storage at the end of the sharedmem address space + constexpr uint32_t smem_rowmax_size = B_ROW * ROWMAX_SETS; + constexpr uint32_t smem_rowsum_size = B_ROW; + constexpr uint32_t smem_O_row_scale_size = B_ROW; + + float *smem_cursor = smem_O1 + smem_O_size; + // // FIXME: dangerous + // smem_cursor = reinterpret_cast(0xff038000); + float *smem_rowmax_0 = smem_cursor; + smem_cursor += smem_rowmax_size; + float *smem_rowmax_1 = smem_cursor; + smem_cursor += smem_rowmax_size; + float *smem_rowsum_0 = smem_cursor; + smem_cursor += smem_rowsum_size; + float *smem_rowsum_1 = smem_cursor; + smem_cursor += smem_rowsum_size; + float *smem_O_row_scale_0 = smem_cursor; + smem_cursor += smem_O_row_scale_size; + float *smem_O_row_scale_1 = smem_cursor; + smem_cursor += smem_O_row_scale_size; + + // sharedmem "scratchpad" area to put temporary data, e.g. 
for tree reduction + // in rowsum + // NOTE: out-of bounds is not checked + constexpr uint32_t smem_scratchpad_size = + threads_per_warpgroup * 2 /*arbitrary slack*/; + float *smem_scratchpad_0 = smem_cursor; + smem_cursor += smem_scratchpad_size; + float *smem_scratchpad_1 = smem_cursor; + smem_cursor += smem_scratchpad_size; + uint32_t *smem_O_flag = reinterpret_cast(smem_cursor); + smem_cursor += 1 /* 4Byte */; + + static_assert(sizeof(elem_t) == sizeof(float)); + constexpr uint32_t spad_addr_factor = DIM * sizeof(elem_t); + constexpr uint32_t spad_addr_Q0 = smem_Q0_offset / spad_addr_factor; + constexpr uint32_t spad_addr_Q1 = smem_Q1_offset / spad_addr_factor; + constexpr uint32_t spad_addr_K0 = smem_K0_offset / spad_addr_factor; + constexpr uint32_t spad_addr_K1 = smem_K1_offset / spad_addr_factor; + constexpr uint32_t spad_addr_V0 = smem_V0_offset / spad_addr_factor; + constexpr uint32_t spad_addr_V1 = smem_V1_offset / spad_addr_factor; + constexpr uint32_t spad_addr_S0 = smem_S0_offset / spad_addr_factor; + constexpr uint32_t spad_addr_S1 = smem_S1_offset / spad_addr_factor; + constexpr uint32_t spad_addr_P0 = smem_P0_offset / spad_addr_factor; + constexpr uint32_t spad_addr_P1 = smem_P1_offset / spad_addr_factor; + constexpr uint32_t spad_addr_O0 = smem_O0_offset / spad_addr_factor; + constexpr uint32_t spad_addr_O1 = smem_O1_offset / spad_addr_factor; + + constexpr uint32_t global_barrier_id = NUM_WARPS - 1; // arbitrary + static_assert(warps_per_threadblock_per_core == NUM_WARPS); + + // initialize rowmax/rowsum values in sharedmem + thread_block_init_sharedmem(tid_in_warpgroup, threads_per_warpgroup, smem_O0, + smem_rowmax_0, smem_rowsum_0, smem_O_row_scale_0); + thread_block_init_sharedmem(tid_in_warpgroup, threads_per_warpgroup, smem_O1, + smem_rowmax_1, smem_rowsum_1, smem_O_row_scale_1); + + threadblock_barrier(global_barrier_id, warps_per_threadblock_per_core); + + // skip everything except DMA in the loop FSM + constexpr uint32_t skips = + 
loop_matmul_skips(/*skip_lda=*/0, /*skip_ldb=*/0, /*skip_ldd=*/1, + /*skip_ex=*/1, /*skip_stc=*/1); + constexpr uint32_t skips_only_a = + loop_matmul_skips(/*skip_lda=*/0, /*skip_ldb=*/1, /*skip_ldd=*/1, + /*skip_ex=*/1, /*skip_stc=*/1); + constexpr uint32_t skips_only_b = + loop_matmul_skips(/*skip_lda=*/1, /*skip_ldb=*/0, /*skip_ldd=*/1, + /*skip_ex=*/1, /*skip_stc=*/1); + constexpr uint32_t skips_mvout_spad = + loop_matmul_skips(/*skip_lda=*/1, /*skip_ldb=*/1, /*skip_ldd=*/1, + /*skip_ex=*/1, /*skip_stc=*/0); + constexpr uint32_t skips_matmul = + loop_matmul_skips(/*skip_lda=*/1, /*skip_ldb=*/1, /*skip_ldd=*/1, + /*skip_ex=*/0, /*skip_stc=*/0); + constexpr uint32_t skips_matmul_preload = + loop_matmul_skips(/*skip_lda=*/1, /*skip_ldb=*/1, /*skip_ldd=*/0, + /*skip_ex=*/0, /*skip_stc=*/1); + + if (tid_in_warpgroup == 0) { + gemmini_extended_config_ex(WEIGHT_STATIONARY, 0, 0, 1, 0, 0); + + // configure DMA with GMEM address strides + // Q matrix + gemmini_extended3_config_ld(HEADDIM * sizeof(elem_t), MVIN_SCALE_IDENTITY, + false, 0); + // K matrix + gemmini_extended3_config_ld(dim_seqlen * sizeof(elem_t), + MVIN_SCALE_IDENTITY, false, 1); + // configure DMA for Q*K store + gemmini_extended_config_st(B_COL * sizeof(elem_t), 0, MVIN_SCALE_IDENTITY); + gemmini_fence(); + } + + // NOTE about barriers: Placing barriers around thread-divergent branches may + // cause bugs, because the Vortex core doesn't check for tmask for barriers. + // The compiler might decide to duplicate vx_bar into both paths of a + // conditional branch, which will get evaluated twice because of the way + // branches are handled in SIMT; this might result in stalls especially when + // other warps behave differently on the branch condition. 
+ // threadblock_barrier(warpgroup_id_in_cluster, warps_per_warpgroup_per_core); + + static_assert(B_ROW == B_COL, "currently only supports square tiles"); + + // move Q and K into SMEM before the loop starts + // + asm volatile("dma_move_start_%=:" ::); + if (tid_in_warpgroup == 0) { + // make sure to read from the correct row of Q + const float *gmem_Q_tile = gmem_Q + HEADDIM * B_ROW * warpgroup_id; + const float *gmem_K_tile = gmem_K; + // configure the GMEM addresses for the DMA to read from + ROCC_INSTRUCTION_RS1_RS2(XCUSTOM_ACC, (uint64_t)(gmem_Q_tile), + (uint64_t)(gmem_K_tile), k_LOOP_WS_CONFIG_ADDRS_AB) + // configure address strides for the DMA + GEMMINI_CISC_CMD_R((dim_seqlen << 20) | (HEADDIM << 8) | + 8 /*k_LOOP_WS_CONFIG_STRIDES_AB*/); + gemmini_fence(); + +// #define GEMMINI_DMA_CISC +#ifdef GEMMINI_DMA_CISC + // the target addresses of this should match with spad_addr_Q0 and + // spad_addr_K0 set in this kernel + GEMMINI_CISC_CMD_I(10); +#else + // do DMA + // + // among other things, this also configures CONFIG_BOUNDS so that the + // DMA knows the full matrix dimensions + sp_tiled_matmul_full_spad_ws( + spad_addr_Q0, spad_addr_K0, + /*spad_D=*/0, /*spad_C=*/spad_addr_S0/*bogus*/, + /*I=*/(B_ROW / DIM), /*J=*/(B_COL / DIM), /*K=*/(HEADDIM / DIM), + /*pad_I=*/0, /*pad_J=*/0, /*pad_K=*/0, + /*a_transpose=*/0, /*b_transpose=*/0, /*full_C=*/0, /*low_D=*/0, + /*acc=*/0, /*act=*/NO_ACTIVATION, /*skips=*/skips); +#endif + gemmini_fence(); + + // need to also move Q to spad_addr_Q1 for the next iteration + // FIXME: re-configure necessary? 
+ gmem_K_tile = gmem_K + (B_COL * 1); + ROCC_INSTRUCTION_RS1_RS2(XCUSTOM_ACC, (uint64_t)(gmem_Q_tile), + (uint64_t)(gmem_K_tile), k_LOOP_WS_CONFIG_ADDRS_AB) + GEMMINI_CISC_CMD_R((dim_seqlen << 20) | (HEADDIM << 8) | + 8 /*k_LOOP_WS_CONFIG_STRIDES_AB*/); + gemmini_fence(); +#ifdef GEMMINI_DMA_CISC + // GEMMINI_CISC_CMD_I(11); +#else + sp_tiled_matmul_full_spad_ws( + spad_addr_Q1, spad_addr_K1/*bogus*/, + /*spad_D=*/0, /*spad_C=*/spad_addr_S0/*bogus*/, + /*I=*/(B_ROW / DIM), /*J=*/(B_COL / DIM), /*K=*/(HEADDIM / DIM), + /*pad_I=*/0, /*pad_J=*/0, /*pad_K=*/0, + /*a_transpose=*/0, /*b_transpose=*/0, /*full_C=*/0, /*low_D=*/0, + /*acc=*/0, /*act=*/NO_ACTIVATION, /*skips=*/skips); +#endif + + gemmini_fence(); + gemmini_fence(); + gemmini_fence(); + gemmini_fence(); + + // re-configure DMA for K and V load that will later happen in the loop + // GMEM addr stride for K + gemmini_extended3_config_ld(dim_seqlen * sizeof(elem_t), + MVIN_SCALE_IDENTITY, false, 0); + // GMEM addr stride for V + gemmini_extended3_config_ld(HEADDIM * sizeof(elem_t), MVIN_SCALE_IDENTITY, + false, 1); + gemmini_fence(); + } + + asm volatile("dma_move_end_%=:" ::); + + // protect write to SMEM + // threadblock_barrier(warpgroup_id_in_cluster, warps_per_warpgroup_per_core); + threadblock_barrier(global_barrier_id, warps_per_threadblock_per_core); + + // if constexpr (DEBUG) { + // thread_block_copy_tile(smem_Q0, gmem_tmp_d0, tid_in_warpgroup, + // threads_per_warpgroup, warpgroup_id_in_cluster); + // thread_block_copy_tile(smem_K0, gmem_tmp_d1, tid_in_warpgroup, + // threads_per_warpgroup, warpgroup_id_in_cluster); + // threadblock_barrier(warpgroup_id_in_cluster, warps_per_warpgroup_per_core); + // } + + constexpr uint32_t threads_per_warpgroup_simt = + threads_per_warpgroup - + CORES_PER_CLUSTER * NUM_THREADS /*warp 0, 4, 8, 12*/; + constexpr uint32_t warpgroup_id_simt = 1; + constexpr uint32_t barrier_id_simt = 1; + constexpr uint32_t barrier_count_simt = NUM_WARPS - 1; + const uint32_t 
tid_in_warpgroup_simt = + tid_in_warpgroup - (CORES_PER_CLUSTER * NUM_THREADS); + static_assert(barrier_id_simt == 1 && barrier_count_simt == 7); + + asm volatile ("tile_loop_start_%=:" :: ); + + // "inner loop" along the columns of K^T + const uint32_t k_tiles = (dim_seqlen / B_COL); + for (uint32_t tile_k = 0; + tile_k < (4 /*for perf measurement*/ * k_tiles) + 2 /*pipeline latency*/; + tile_k++) { + if constexpr (DEBUG || true) { + threadblock_barrier(global_barrier_id, warps_per_threadblock_per_core); + } + + // select the correct double buffer by tile iteration + // all iterations work on the same Q row tile; no ping-pong necessary + asm volatile ("dbuf_sel_start_%=:" :: ); + // FIXME speedup by doing arithmetic + float *smem_Q = smem_Q0; + float *smem_K_consume = (tile_k & 1) ? smem_K1 : smem_K0; + float *smem_K_produce = (tile_k & 1) ? smem_K0 : smem_K1; + float *smem_V_consume = (tile_k & 1) ? smem_V1 : smem_V0; + float *smem_V_produce = (tile_k & 1) ? smem_V0 : smem_V1; + float *smem_S_consume = (tile_k & 1) ? smem_S1 : smem_S0; + float *smem_S_produce = (tile_k & 1) ? smem_S0 : smem_S1; + float *smem_P_consume = (tile_k & 1) ? smem_P1 : smem_P0; + float *smem_P_produce = (tile_k & 1) ? smem_P0 : smem_P1; + // O, rowmax/rowsum etc. is sequentially updated at every iteration; no + // ping-pong necessary + float *smem_O = smem_O0; + float *smem_O_row_scale = smem_O_row_scale_0; + float *smem_rowmax = smem_rowmax_0; + float *smem_rowsum = smem_rowsum_0; + float *smem_scratchpad = smem_scratchpad_0; + + const auto spad_addr_Q = spad_addr_Q0; + const auto spad_addr_K_consume = (tile_k & 1) ? spad_addr_K1 : spad_addr_K0; + const auto spad_addr_K_produce = (tile_k & 1) ? spad_addr_K0 : spad_addr_K1; + const auto spad_addr_V_consume = (tile_k & 1) ? spad_addr_V1 : spad_addr_V0; + const auto spad_addr_V_produce = (tile_k & 1) ? spad_addr_V0 : spad_addr_V1; + const auto spad_addr_S_consume = (tile_k & 1) ? 
spad_addr_S1 : spad_addr_S0; + const auto spad_addr_S_produce = (tile_k & 1) ? spad_addr_S0 : spad_addr_S1; + const auto spad_addr_P_consume = (tile_k & 1) ? spad_addr_P1 : spad_addr_P0; + const auto spad_addr_P_produce = (tile_k & 1) ? spad_addr_P0 : spad_addr_P1; + const auto spad_addr_O = spad_addr_O0; // NOTE: there's only single O tile + asm volatile ("dbuf_sel_end_%=:" :: ); + + if (vx_warp_id() == 0 /* warp 0 in every core */) { + if (tile_k >= 2) // delay by 2 iters for pipelining + { + const uint32_t tile_k_ = tile_k - 2; + + // GEMM II: O = O + P*V + // -------------------- + // This is done *before* GEMM I in the software pipeline, working on the + // online softmax result tile from the previous iteration + + asm volatile("gemm_pv_start_%=:" ::); + + if (tid_in_warpgroup == 0) { +#if 0 + if (tile_k_ == 0) { + gemmini_fence(); + GEMMINI_CISC_CMD_I(0); + } else if (tile_k_ & 1) { + gemmini_fence(); + GEMMINI_CISC_CMD_I(2); + } else { + gemmini_fence(); + GEMMINI_CISC_CMD_I(1); + } +#else + // kickoff matmul + // among other things, this also configures CONFIG_BOUNDS so that the + // DMA knows the full matrix dimensions + // FIXME: perf: prevent GMEM->SMEM load for O tile + gemmini_fence(); + gemmini_fence(); + gemmini_fence(); + gemmini_fence(); + sp_tiled_matmul_full_spad_ws( + spad_addr_P_consume, spad_addr_V_consume, + /*spad_D=*/spad_addr_O, /*spad_C=*/spad_addr_O, + /*I=*/(B_ROW / DIM), /*J=*/(HEADDIM / DIM), /*K=*/(B_COL / DIM), + /*pad_I=*/0, /*pad_J=*/0, /*pad_K=*/0, + /*a_transpose=*/0, /*b_transpose=*/0, /*full_C=*/0, /*low_D=*/0, + /*acc=*/0, /*act=*/NO_ACTIVATION, /*skips=*/skips_matmul); +#endif + } + + // // reconverge from mmio divergence + // threadblock_barrier(warpgroup_id_in_cluster, + // warps_per_warpgroup_per_core); + + asm volatile("gemm_pv_finish_%=:" ::); + } + + // GEMM I: S = Q*K + // + // kick off asynchronously; fence later + asm volatile("gemm_qk_start_%=:" ::); + + if (tid_in_warpgroup == 0) { + // fence to GEMM II completion 
+ gemmini_fence(); + gemmini_fence(); + gemmini_fence(); + gemmini_fence(); + +#ifdef FENCE_GEMM_II + // signal that GEMM II is finished to O rescale step + *smem_O_flag = 1; + vx_fence(); +#endif + + // 0,2,.: opcode 0 (quartile 0/2, no accum) + // 1,3,.: opcode 3 (quartile 1/3, no accum) + // const uint32_t opcode = 3 * (tile_k & 1); + //GEMMINI_CISC_CMD_I(opcode); + sp_tiled_matmul_full_spad_ws( + spad_addr_Q, spad_addr_K_consume, + /*spad_D=*/0, /*spad_C=*/spad_addr_S_produce, + /*I=*/(B_ROW / DIM), /*J=*/(B_COL / DIM), /*K=*/(HEADDIM / DIM), + /*pad_I=*/0, /*pad_J=*/0, /*pad_K=*/0, + /*a_transpose=*/0, /*b_transpose=*/0, /*full_C=*/0, /*low_D=*/0, + /*acc=*/0, /*act=*/NO_ACTIVATION, /*skips=*/skips_matmul); + + // gemmini_fence(); + // gemmini_fence(); + // gemmini_fence(); + // gemmini_fence(); + asm volatile("gemm_qk_finish_%=:" ::); + + // data move for K and V + // + // Q stays in SMEM for the entire loop + asm volatile("move_k_v_start_%=:" ::); + + // configure GMEM addresses for K and V tiles + // load K for the next iteration + const float *gmem_K_tile = gmem_K + (B_COL * (tile_k + 1 /*runahead*/)); + // load V for the *previous* iteration; this will be consumed 2 + // iterations later + const float *gmem_V_tile = + gmem_V + (HEADDIM * B_COL * (tile_k - 1 /*dragbehind*/)); + +#if 0 + // fence mvout S to SMEM + gemmini_fence(); + ROCC_INSTRUCTION_RS1_RS2(XCUSTOM_ACC, (uint64_t)(gmem_K_tile), + (uint64_t)(gmem_V_tile), + k_LOOP_WS_CONFIG_ADDRS_AB) +#endif + // configure address strides for the DMA + // FIXME: unnecessary? 
+ GEMMINI_CISC_CMD_R((HEADDIM /*V*/ << 20) | (dim_seqlen /*KT*/ << 8) | + 8 /*k_LOOP_WS_CONFIG_STRIDES_AB*/); + // gemmini_fence(); + + // do DMA + if (tile_k == 0) { + // we load (k-1)th tile for V; skip V for the 1st iteration, + // sp_tiled_matmul_full_spad_ws( + // spad_addr_K_produce, spad_addr_V_produce, + // /*spad_D=*/0, /*spad_C=*/0, + // /*I=*/(B_ROW / DIM), /*J=*/(HEADDIM / DIM), /*K=*/(B_COL / DIM), + // /*pad_I=*/0, /*pad_J=*/0, /*pad_K=*/0, + // /*a_transpose=*/0, /*b_transpose=*/0, /*full_C=*/0, /*low_D=*/0, + // /*acc=*/0, /*act=*/NO_ACTIVATION, /*skips=*/skips_only_a); + } else { + sp_tiled_matmul_full_spad_ws( + spad_addr_K_produce, spad_addr_V_produce, + /*spad_D=*/0, /*spad_C=*/0, + /*I=*/(B_ROW / DIM), /*J=*/(HEADDIM / DIM), /*K=*/(B_COL / DIM), + /*pad_I=*/0, /*pad_J=*/0, /*pad_K=*/0, + /*a_transpose=*/0, /*b_transpose=*/0, /*full_C=*/0, /*low_D=*/0, + /*acc=*/0, /*act=*/NO_ACTIVATION, /*skips=*/skips); + } + + // fence everything before going to the next tile + gemmini_fence(); + gemmini_fence(); + gemmini_fence(); + gemmini_fence(); + } + + // threadblock_barrier(warpgroup_id_in_cluster, + // warps_per_warpgroup_per_core); + + asm volatile("move_k_v_finish_%=:" ::); + + // NOTE: cannot put barrier here; thread 1-7 in warp 0 will skip the + // branch and call this barrier earlier than when thread 0 finishes. 
+ // Since tmask is not considered, that will be a barrier resolve done too + // early + // threadblock_barrier(0, 1); + + } else /* warp_id != 0 */ { + + if (tile_k >= 1) // delay by 1 iters for pipelining + { + const uint32_t tile_k_ = tile_k - 1; + + if constexpr (DEBUG) { + // verify S = Q*K before softmax + if (warpgroup_id == 0) { + if (tile_k_ == 0) { + thread_block_copy_tile( + smem_S_consume, gmem_tmp_d0, tid_in_warpgroup_simt, + threads_per_warpgroup_simt, warpgroup_id_simt); + } else if (tile_k_ == 1) { + thread_block_copy_tile( + smem_S_consume, gmem_tmp_d1, tid_in_warpgroup_simt, + threads_per_warpgroup_simt, warpgroup_id_simt); + } + + threadblock_barrier(barrier_id_simt, barrier_count_simt); + } + } + + // Online softmax + // + thread_block_online_softmax( + smem_S_consume, smem_P_produce, tid_in_warpgroup_simt, + threads_per_warpgroup_simt, warpgroup_id_simt, smem_scratchpad, + smem_rowmax, smem_rowsum, smem_O_row_scale); + + threadblock_barrier(barrier_id_simt, barrier_count_simt); + + if constexpr (DEBUG) { + if (warpgroup_id == 0) { + if (tile_k_ == 0) { + thread_block_copy_rowmax( + smem_rowmax, gmem_tmp_e0, tid_in_warpgroup_simt, + threads_per_warpgroup_simt, warpgroup_id_simt); + thread_block_copy_rowmax( + smem_rowsum, gmem_tmp_e2, tid_in_warpgroup_simt, + threads_per_warpgroup_simt, warpgroup_id_simt); + } else if (tile_k_ == 1) { + thread_block_copy_rowmax(smem_rowmax, gmem_tmp_e1, + tid_in_warpgroup_simt, threads_per_warpgroup_simt, + warpgroup_id_simt); + thread_block_copy_rowmax(smem_rowsum, gmem_tmp_e3, + tid_in_warpgroup_simt, threads_per_warpgroup_simt, + warpgroup_id_simt); + } + + threadblock_barrier(barrier_id_simt, barrier_count_simt); + } + } + +#ifdef FENCE_GEMM_II + // check flag to make sure GEMM II finished and read-after-write + // dependency on O tile is settled for rescale + if (tid_in_warpgroup_simt == 0) { + while ((*smem_O_flag) != 1) + ; + // set it back to 0 for the next tile iteration + *smem_O_flag = 0; + 
vx_fence(); + } +#endif + +#if 0 + if (tid_in_warpgroup == 0) { + gemmini_fence(); + gemmini_fence(); + gemmini_fence(); + gemmini_fence(); + } + + // reconverge from mmio divergence + threadblock_barrier(warpgroup_id_in_cluster, + warps_per_warpgroup_per_core); +#endif + + if constexpr (DEBUG) { + if (warpgroup_id == 0) { + gemmini_fence(); + gemmini_fence(); + + // O after PV + if (tile_k_ == 1 /*wait until GEMM II finshes */) { + thread_block_copy_tile( + smem_O, gmem_tmp_d6, tid_in_warpgroup_simt, threads_per_warpgroup_simt, + warpgroup_id_simt); + } else if (tile_k_ == 2) { + thread_block_copy_tile( + smem_O, gmem_tmp_d7, tid_in_warpgroup_simt, threads_per_warpgroup_simt, + warpgroup_id_simt); + } + + threadblock_barrier(barrier_id_simt, barrier_count_simt); + } + } + + // Oi rescale + thread_block_O_rescale( + smem_O, smem_O /*in-place*/, smem_O_row_scale, + tid_in_warpgroup_simt, threads_per_warpgroup_simt, + warpgroup_id_simt); + + // rescale-to-PV-GEMM barrier + threadblock_barrier(barrier_id_simt, barrier_count_simt); + + if constexpr (DEBUG) { + if (warpgroup_id == 0) { + // O before PV + if (tile_k_ == 0) { + thread_block_copy_tile( + smem_P_produce, gmem_tmp_d2, tid_in_warpgroup_simt, + threads_per_warpgroup_simt, warpgroup_id_simt); + thread_block_copy_tile( + smem_O, gmem_tmp_d4, tid_in_warpgroup_simt, + threads_per_warpgroup_simt, warpgroup_id_simt); + } else if (tile_k_ == 1) { + thread_block_copy_tile( + smem_P_produce, gmem_tmp_d3, tid_in_warpgroup_simt, + threads_per_warpgroup_simt, warpgroup_id_simt); + thread_block_copy_tile( + smem_O, gmem_tmp_d5, tid_in_warpgroup_simt, + threads_per_warpgroup_simt, warpgroup_id_simt); + } + + threadblock_barrier(barrier_id_simt, barrier_count_simt); + } + } + } + +#if 0 + // fence GEMM I after Oi rescale + if (tid_in_warpgroup == 0) { + gemmini_fence(); + gemmini_fence(); + gemmini_fence(); + gemmini_fence(); + } + + // reconverge from mmio divergence + threadblock_barrier(warpgroup_id_in_cluster, + 
warps_per_warpgroup_per_core); +#endif + + // intra-warpgroup barrier + threadblock_barrier(barrier_id_simt, barrier_count_simt); + } + } + + asm volatile ("tile_loop_finish_%=:" :: ); +} + +int main() { + kernel_arg_t *arg = (kernel_arg_t *)KERNEL_ARG_DEV_MEM_ADDR; + + const uint32_t hw_threads_per_cluster = + CORES_PER_CLUSTER * vx_num_threads() * vx_num_warps(); + // fix to 1 threadblock per cluster + const uint32_t grid_size = hw_threads_per_cluster; + +#ifdef RADIANCE + vx_spawn_tasks_cluster(grid_size, (vx_spawn_tasks_cb)kernel_body, arg); +#else + // NOTE: This kernel assumes contiguous thread scheduling for efficient shared + // memory allocation, and therefore does not work with original vx_spawn_tasks + vx_spawn_tasks_contiguous(grid_size, (vx_spawn_tasks_cb)kernel_body, arg); +#endif + return 0; +} diff --git a/tests/regression/sgemm_gemmini_dma/kernel.cpp b/tests/regression/sgemm_gemmini_dma/kernel.cpp index 85ca285a..4c37a232 100644 --- a/tests/regression/sgemm_gemmini_dma/kernel.cpp +++ b/tests/regression/sgemm_gemmini_dma/kernel.cpp @@ -141,8 +141,10 @@ void thread_block_matmul_gemmini(kernel_arg_t *__UNIFORM__ arg, } if (tile_k == 0) { + asm volatile("cisc_start_%=:" ::); gemmini_fence(); GEMMINI_CISC_CMD_I(0); + asm volatile("cisc_end_%=:" ::); } else if (tile_k & 1) { gemmini_fence(); GEMMINI_CISC_CMD_I(2); @@ -218,4 +220,4 @@ int main() { vx_spawn_tasks_contiguous(grid_size, (vx_spawn_tasks_cb)kernel_body, arg); #endif return 0; -} \ No newline at end of file +} diff --git a/tests/regression/sgemm_tcore/kernel.cpp b/tests/regression/sgemm_tcore/kernel.cpp index 59fd7194..bb904baf 100644 --- a/tests/regression/sgemm_tcore/kernel.cpp +++ b/tests/regression/sgemm_tcore/kernel.cpp @@ -7,7 +7,7 @@ #include "include/gemmini.h" #include "gemmini_mmio.h" -constexpr bool DEBUG = true; +constexpr bool DEBUG = false; template inline void thread_block_copy_tile(const float *src, float *dest, @@ -90,19 +90,29 @@ void kernel_body(int task_id, kernel_arg_t 
*__UNIFORM__ arg) { thread_block_gemm( - (const float_type *)arg->addr_a, (const float_type *)arg->addr_b, - (float *)arg->addr_c, arg->dim_m, arg->dim_n, arg->dim_k, - tid_in_threadblock, threadblocks_per_cluster, threadblock_id_in_cluster, - sharedmem_per_threadblock); +#if (GEMMINI_DMA == 1) // GEMMINI_DMA is always defined (as 0 or 1), so test its value rather than #ifdef + /*smem_a_dbuf_offset=*/1 * 128 * 128 * sizeof(float_type), + /*smem_b_offset=*/2 * 128 * 128 * sizeof(float_type), + /*smem_b_dbuf_offset=*/3 * 128 * 128 * sizeof(float_type) + // FIXME: above offsets are hardcoded to agree with CISC + // spadQuartile +#else + /*smem_a_dbuf_offset=*/1 * BM * BK * sizeof(float_type), + /*smem_b_offset=*/2 * BM * BK * sizeof(float_type), + /*smem_b_dbuf_offset=*/(2 * BM * BK + BK * BN) * sizeof(float_type) +#endif + >((const float_type *)arg->addr_a, + (const float_type *)arg->addr_b, (float *)arg->addr_c, + arg->dim_m, arg->dim_n, arg->dim_k, tid_in_threadblock, + threadblocks_per_cluster, threadblock_id_in_cluster, + sharedmem_per_threadblock); float *gmem_tmp_d0 = reinterpret_cast(0xd0000000UL); float *gmem_tmp_d1 = reinterpret_cast(0xd1000000UL); const float *smem_A = reinterpret_cast(sharedmem_per_threadblock); - const float *smem_B = smem_A + 2 * BM * BK; + const float *smem_B = reinterpret_cast( + sharedmem_per_threadblock + 2 * BM * BK * sizeof(float_type)); if constexpr (DEBUG) { threadblock_barrier(threadblock_id_in_cluster, diff --git a/tests/regression/sgemm_tcore/sgemm_impl.hpp b/tests/regression/sgemm_tcore/sgemm_impl.hpp index 2014b507..0134d6e5 100644 --- a/tests/regression/sgemm_tcore/sgemm_impl.hpp +++ b/tests/regression/sgemm_tcore/sgemm_impl.hpp @@ -29,7 +29,7 @@ using float_type = float16_t; // (BM*BN) / (TM*TN) == threadblock size >= NT * CORES_PER_CLUSTER // * Combining BM * BK >= (BM*BN) / (TM*TN) == threadblock yields // BM <= BK*TM*TN -#define BM 128 +#define BM ((NUM_CORES == 8) ?
128 : 64) #define BN 64 #if (FP_SIZE == 32) #define BK 64 @@ -62,8 +62,8 @@ static_assert(WMITER * WNITER * TCM * TCN * NUM_WARPS * CORES_PER_CLUSTER == #define BK_LOOP 1 // Whether to transpose smem A tile at GMEM->SMEM (produce), or SMEM->RF // (consume). This is because the tensor core expects the A tile to be stored -// in column-major order in SMEM, whereas it will be ultimately stored in -// row-major in the RF. +// in column-major order in SMEM, so a transpose is necessary if A was stored +// row-major in GMEM. // // For correctness, only one of either should be 1. E.g., PRODUCE 1 CONSUME 0 // generates the NN kernel where both A and B are stored row-major in GMEM. @@ -72,8 +72,8 @@ static_assert(WMITER * WNITER * TCM * TCN * NUM_WARPS * CORES_PER_CLUSTER == #define TRANSPOSE_AT_PRODUCE 0 #define TRANSPOSE_AT_CONSUME 0 -#define GEMMINI_DMA 0 -#define GEMMINI_DMA_MN_MAJOR 1 +#define GEMMINI_DMA 1 +#define GEMMINI_DMA_FAST 1 #if SMEM_SIZE == 0x4000 #define SMEM_ADDR_Q0 ((float * const) 0xff000000) #define SMEM_ADDR_Q1 ((float * const) 0xff001000) @@ -101,6 +101,7 @@ static_assert(WMITER * WNITER * TCM * TCN * NUM_WARPS * CORES_PER_CLUSTER == enum class MemLayout { MN_major, K_major, + block_row_major, // Gemmini DMA }; inline constexpr void map_operand_32lanes(const int tid, int &row, int &col) { @@ -206,7 +207,7 @@ template inline constexpr std::pair remap_to_gemmini_dma_layout(const uint32_t logical_row, const uint32_t logical_col) { - static_assert(DIM == 8, + static_assert(!use_dma || DIM == 8, "GEMMINI_DMA layout remapping code only written for DIM == 8"); if constexpr (use_dma) { @@ -253,13 +254,11 @@ inline void wmma_load_a(volatile const T *smem_A, const int local_k, constexpr int packed_factor = (std::is_same_v ? 
2 : 1); const int local_k_adjusted = local_k / packed_factor; - static_assert(!GEMMINI_DMA || (layout == MemLayout::K_major) || - GEMMINI_DMA_MN_MAJOR, - "GEMMINI_DMA only supported for K-major A tile"); static_assert((layout != MemLayout::K_major) || (FP_SIZE == 32), "fp16 is not really tested for K-major A layout"); - if constexpr (layout == MemLayout::K_major) { + if constexpr (layout == MemLayout::K_major || + layout == MemLayout::block_row_major) { constexpr int smem_A_cols = leading_dim; // f8-f15 stores a single row of A @@ -269,8 +268,9 @@ inline void wmma_load_a(volatile const T *smem_A, const int local_k, // if using Gemmini DMA, remap logical row/col to Gemmini's 2-level // block-row-major layout const auto [smem_row, smem_col] = - remap_to_gemmini_dma_layout(smem_logical_row, - smem_logical_col); + remap_to_gemmini_dma_layout(smem_logical_row, + smem_logical_col); const volatile uint8_t *smem_addr; smem_addr = reinterpret_cast( @@ -356,8 +356,9 @@ inline void wmma_load_b(const volatile T *smem_B, const int local_k, const int thread_in_warp) { asm volatile ("wmma_load_b_start_%=:" :: ); - static_assert(layout == MemLayout::MN_major, - "only N-major layout for the B tile is supported"); + static_assert( + layout == MemLayout::MN_major || layout == MemLayout::block_row_major, + "only N-major or block-row-major layout are supported for the B tile"); const int tid = thread_in_warp; const int tg = tid / 4; @@ -379,8 +380,9 @@ inline void wmma_load_b(const volatile T *smem_B, const int local_k, // if using Gemmini DMA, remap logical row/col to Gemmini's 2-level // block-row-major layout const auto [smem_row, smem_col] = - remap_to_gemmini_dma_layout(smem_logical_row, - smem_logical_col); + remap_to_gemmini_dma_layout(smem_logical_row, + smem_logical_col); const volatile uint8_t *smem_addr; smem_addr = reinterpret_cast( @@ -388,10 +390,10 @@ inline void wmma_load_b(const volatile T *smem_B, const int local_k, smem_B)[smem_B_cols * smem_row + smem_col]); // 
f8-f15 stores a single column of B // threads read from different columns; no bank conflicts - if constexpr (GEMMINI_DMA) { - // for GEMMINI_DMA, moving rows for the next 7 elements in the same column - // is the same as moving DIM elements forward in the memory because of the - // block-row-major layout + if constexpr (layout == MemLayout::block_row_major) { + // for the block-row-major layout, moving rows for the next 7 elements in + // the same column is the same as moving DIM elements forward in the memory + // because of the block-row-major layout asm volatile("flw f8, %0(%1)" :: "i"(DIM * 0 * sizeof(float)), "r"(smem_addr)); asm volatile("flw f9, %0(%1)" :: "i"(DIM * 1 * sizeof(float)), "r"(smem_addr)); asm volatile("flw f10, %0(%1)" :: "i"(DIM * 2 * sizeof(float)), "r"(smem_addr)); @@ -533,7 +535,9 @@ wmma_store(const int thread_in_warp, const int warp_col, const int warp_row, asm volatile ("wmma_store_finish_%=:" :: ); } -inline void threadblock_barrier(const uint32_t barrier_id, const uint32_t count) { +__attribute__((convergent)) inline void +threadblock_barrier(const uint32_t barrier_id, const uint32_t count) { + asm volatile("" ::: "memory"); vx_fence(); vx_barrier(barrier_id, count); } @@ -818,6 +822,10 @@ __attribute__((always_inline)) inline void thread_block_gemm_single_tile( if (tid_in_threadblock == 0) { gemmini_fence(); } + + // reconverge after mmio + threadblock_barrier(threadblock_id_in_cluster, + warps_per_threadblock_per_core); } if constexpr (write_to_mem) { @@ -845,9 +853,9 @@ template < uint32_t smem_a_dbuf_offset = 0, // byte offset of A // double-buffer tile in shared // memory - uint32_t smem_b_offset = sizeof(float) * BM * BK, // byte offset of B tile + uint32_t smem_b_offset = sizeof(T) * BM * BK, // byte offset of B tile // in shared memory - uint32_t smem_b_dbuf_offset = sizeof(float) * BM * + uint32_t smem_b_dbuf_offset = sizeof(T) * BM * BK // byte offset of B double-buffer // tile in shared memory > @@ -903,6 +911,8 @@ inline 
void thread_block_gemm(const T *A, const T *B, float *C, for (uint32_t block_m = block_m_start; block_m < block_m_end; block_m++) { #pragma GCC unroll 1 for (uint32_t block_n = 0; (block_n * BN) < dim_n; block_n++) { + asm volatile ("loop_mn_start_%=:" :: ); + // clear out accumulators initialize_accum_regs<0>(); initialize_accum_regs<1>(); @@ -911,14 +921,13 @@ inline void thread_block_gemm(const T *A, const T *B, float *C, // pipeline initiation if (tid_in_threadblock == 0) { // configure dma gmem address to load from - // FIXME: block_k is wrong ROCC_INSTRUCTION_RS1_RS2( XCUSTOM_ACC, (uint64_t)(A + block_m * BM * dim_k + /*block_k:*/0 * BK), (uint64_t)(B + /*block_k:*/0 * BK * dim_n + block_n * BN), k_LOOP_WS_CONFIG_ADDRS_AB) // GEMMINI_CISC(8) does k_LOOP_WS_CONFIG_STRIDES_AB - GEMMINI_CISC_CMD_R((dim_n << 16) | (dim_k << 8) | 8); + GEMMINI_CISC_CMD_R((dim_n << 20) | (dim_k << 8) | 8); gemmini_fence(); GEMMINI_CISC_CMD_I(10); @@ -949,6 +958,7 @@ inline void thread_block_gemm(const T *A, const T *B, float *C, #pragma GCC unroll 1 for (uint32_t block_k = 0; (block_k * BK) < dim_k; block_k++) { + asm volatile("loop_k_start_%=:" ::); // producer code: GMEM->SMEM memory movement // --------------------------------------------------------------------- @@ -958,20 +968,19 @@ inline void thread_block_gemm(const T *A, const T *B, float *C, #if (GEMMINI_DMA == 1) if ((tid_in_threadblock == 0) && ((block_k * BK) != (dim_k - BK))) { // configure dma gmem address to load from - // FIXME: block_k is wrong ROCC_INSTRUCTION_RS1_RS2( XCUSTOM_ACC, (uint64_t)(A + block_m * BM * dim_k + (block_k + 1/*runahead*/) * BK), (uint64_t)(B + (block_k + 1/*runahead*/) * BK * dim_n + block_n * BN), k_LOOP_WS_CONFIG_ADDRS_AB) // GEMMINI_CISC(8) does k_LOOP_WS_CONFIG_STRIDES_AB - GEMMINI_CISC_CMD_R((dim_n << 16) | (dim_k << 8) | 8); - // gemmini_fence(); + GEMMINI_CISC_CMD_R((dim_n << 20) | (dim_k << 8) | 8); + gemmini_fence(); // block_k is even: opcode 11 (write to local_a_buf) // block_k is 
odd: opcode 10 (write to local_a) const uint32_t opcode = 11 - (block_k & 1); - GEMMINI_CISC_CMD_R(opcode); + GEMMINI_CISC_CMD_I(opcode); // // TODO: branch is probably slow // if (block_k & 1) { // GEMMINI_CISC_CMD_I(12); @@ -1017,6 +1026,10 @@ inline void thread_block_gemm(const T *A, const T *B, float *C, /*acc=*/0, /*act=*/NO_ACTIVATION, /*skips=*/skips) #endif } + + // reconverge after mmio divergence + threadblock_barrier(threadblock_id_in_cluster, + warps_per_threadblock_per_core); #else // move A if constexpr (!TRANSPOSE_AT_PRODUCE) { @@ -1038,10 +1051,10 @@ inline void thread_block_gemm(const T *A, const T *B, float *C, warps_per_threadblock_per_core); #endif -#if 0 // consumer code: SMEM->RF and compute // ---------------------------------------------------------------------- // @perf: this loop spills to stack a lot because of all the flws in + asm volatile("dbuf_sel_start_%=:" ::); const T *local_a_consume; const T *local_b_consume; if constexpr (GEMMINI_DMA) { @@ -1056,17 +1069,29 @@ inline void thread_block_gemm(const T *A, const T *B, float *C, // local_b_consume = reinterpret_cast( // (mask_odd & reinterpret_cast(local_b_buf)) | // (mask_even & reinterpret_cast(local_b))); - local_a_consume = local_a + (block_k & 1) * (BM * BK); - local_b_consume = local_b + (block_k & 1) * (BK * BN); + local_a_consume = local_a + (block_k & 1) * + (smem_a_dbuf_offset - smem_a_offset) / + sizeof(T); + local_b_consume = local_b + (block_k & 1) * + (smem_b_dbuf_offset - smem_b_offset) / + sizeof(T); } else { // no double-buffering without DMA local_a_consume = local_a; local_b_consume = local_b; } + asm volatile("dbuf_sel_end_%=:" ::); constexpr MemLayout layout_a = - TRANSPOSE_AT_CONSUME ? 
MemLayout::K_major : MemLayout::MN_major; - thread_block_gemm_single_tile( @@ -1087,7 +1112,8 @@ inline void thread_block_gemm(const T *A, const T *B, float *C, threadblock_barrier(threadblock_id_in_cluster, warps_per_threadblock_per_core); -#endif + + asm volatile("loop_k_end_%=:" ::); } if constexpr (write_to_gmem) { @@ -1102,6 +1128,7 @@ inline void thread_block_gemm(const T *A, const T *B, float *C, } } } + asm volatile("loop_mn_end_%=:" ::); } }