diff --git a/tests/regression/sgemm_tcore/kernel.cpp b/tests/regression/sgemm_tcore/kernel.cpp index d748a909..ecc02d22 100644 --- a/tests/regression/sgemm_tcore/kernel.cpp +++ b/tests/regression/sgemm_tcore/kernel.cpp @@ -7,6 +7,36 @@ #include "include/gemmini.h" #include "gemmini_mmio.h" +#define GEMMINI_DMA 1 +#if SMEM_SIZE == 0x4000 +#define SMEM_ADDR_Q0 ((float * const) 0xff000000) +#define SMEM_ADDR_Q1 ((float * const) 0xff001000) +#define SMEM_ADDR_Q2 ((float * const) 0xff002000) +#define SMEM_ADDR_Q3 ((float * const) 0xff003000) +#define SPAD_ADDR_Q0 0x0 +#define SPAD_ADDR_Q1 0x80 +#define SPAD_ADDR_Q2 0x100 +#define SPAD_ADDR_Q3 0x180 +#define BOUND_INST 0x400040004ULL +#elif SMEM_SIZE == 0x10000 +#define SMEM_ADDR_Q0 ((float * const) 0xff000000) +#define SMEM_ADDR_Q1 ((float * const) 0xff004000) +#define SMEM_ADDR_Q2 ((float * const) 0xff008000) +#define SMEM_ADDR_Q3 ((float * const) 0xff00c000) +#define SPAD_ADDR_Q0 0x0 +#define SPAD_ADDR_Q1 0x200 +#define SPAD_ADDR_Q2 0x400 +#define SPAD_ADDR_Q3 0x600 +#define BOUND_INST 0x800080008ULL +#else +#error Unsupported smem size +#endif + +// FIXME: NUM_THREADS and NUM_WARPS hardcoded +#if ((BM * BN / ELEM_PER_THREAD) > (CORES_PER_CLUSTER * 8 * 8)) +#error "threadblock size too big for cluster" +#endif + inline void global_dmem_load(const uint32_t dim_n, const uint32_t dim_k, const uint32_t k, const float *A, const float *B, volatile float *local_a, volatile float *local_b, @@ -204,14 +234,16 @@ inline void global_dmem_load(const uint32_t dim_n, const uint32_t dim_k, asm volatile ("fsw ft0, %0(%1)" :: "i"(BN * row_stride_b * 0 * sizeof(float)), "r"(local_b_tmp)); asm volatile ("fsw ft1, %0(%1)" :: "i"(BN * row_stride_b * 1 * sizeof(float)), "r"(local_b_tmp)); - asm volatile ("fsw ft2, %0(%1)" :: "i"(BN * row_stride_b * 2 * sizeof(float)), "r"(local_b_tmp)); - asm volatile ("fsw ft3, %0(%1)" :: "i"(BN * row_stride_b * 3 * sizeof(float)), "r"(local_b_tmp)); - local_b_tmp += BN * row_stride_b * 4; + local_b_tmp += BN * row_stride_b * 2; + asm volatile ("fsw ft2, %0(%1)" :: "i"(BN * row_stride_b * 0 * sizeof(float)), "r"(local_b_tmp)); + asm volatile ("fsw ft3, %0(%1)" :: "i"(BN * row_stride_b * 1 * sizeof(float)), "r"(local_b_tmp)); + local_b_tmp += BN * row_stride_b * 2; asm volatile ("fsw ft4, %0(%1)" :: "i"(BN * row_stride_b * 0 * sizeof(float)), "r"(local_b_tmp)); asm volatile ("fsw ft5, %0(%1)" :: "i"(BN * row_stride_b * 1 * sizeof(float)), "r"(local_b_tmp)); - asm volatile ("fsw ft6, %0(%1)" :: "i"(BN * row_stride_b * 2 * sizeof(float)), "r"(local_b_tmp)); - asm volatile ("fsw ft7, %0(%1)" :: "i"(BN * row_stride_b * 3 * sizeof(float)), "r"(local_b_tmp)); - local_b_tmp += BN * row_stride_b * 4; + local_b_tmp += BN * row_stride_b * 2; + asm volatile ("fsw ft6, %0(%1)" :: "i"(BN * row_stride_b * 0 * sizeof(float)), "r"(local_b_tmp)); + asm volatile ("fsw ft7, %0(%1)" :: "i"(BN * row_stride_b * 1 * sizeof(float)), "r"(local_b_tmp)); + local_b_tmp += BN * row_stride_b * 2; } } @@ -221,8 +253,7 @@ inline void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg, const uint32_t threadblock_dim_y, /*const uint32_t threadblock_id_x, const uint32_t threadblock_id_y,*/ - const uint32_t num_threadblocks, - const uint32_t threadblock_id, + const uint32_t threadblocks_per_cluster, const uint32_t threadblock_id_in_cluster, float *sharedmem_per_threadblock) { const float *A = (const float *)arg->addr_a; @@ -276,8 +307,8 @@ inline void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg, #endif // divide rows (M) by the number of threadblocks - const uint32_t dim_m_range = (dim_m / num_threadblocks); - const uint32_t dim_m_start = dim_m_range * threadblock_id; + const uint32_t dim_m_range = (dim_m / threadblocks_per_cluster); + const uint32_t dim_m_start = dim_m_range * threadblock_id_in_cluster; const uint32_t block_m_start = dim_m_start / BM; const uint32_t block_m_end = (dim_m_start + dim_m_range) / BM; @@ -303,9 +334,10 @@ inline void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg, GEMMINI_CISC_CMD_R((dim_n << 16) | (dim_k << 8) | 8); gemmini_fence(); - // GEMMINI_CISC_CMD_I(12); - // gemmini_fence(); + GEMMINI_CISC_CMD_I(12); + gemmini_fence(); +#if 0 // sp_tiled_matmul_full_spad_ws includes CONFIG_BOUNDS // FIXME: block_k is 0 for two times sp_tiled_matmul_full_spad_ws( @@ -321,6 +353,7 @@ inline void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg, /*a_transpose=*/0, /*b_transpose=*/0, /*full_C=*/0, /*low_D=*/0, /*acc=*/0, /*act=*/NO_ACTIVATION, /*skips=*/skips) gemmini_fence(); +#endif } threadblock_barrier(threadblock_id_in_cluster, threadblock_dim_y); @@ -340,23 +373,22 @@ inline void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg, // FIXME: block_k is wrong ROCC_INSTRUCTION_RS1_RS2( XCUSTOM_ACC, - (uint64_t)(A + block_m * BM * dim_k + block_k * BK), - (uint64_t)(B + block_k * BK * dim_n + block_n * BN), + (uint64_t)(A + block_m * BM * dim_k + (block_k + 1/*runahead*/) * BK), + (uint64_t)(B + (block_k + 1/*runahead*/) * BK * dim_n + block_n * BN), k_LOOP_WS_CONFIG_ADDRS_AB) // GEMMINI_CISC(8) does k_LOOP_WS_CONFIG_STRIDES_AB GEMMINI_CISC_CMD_R((dim_n << 16) | (dim_k << 8) | 8); // gemmini_fence(); - // TODO: this is probably slow - // if (block_k & 1) { - // GEMMINI_CISC_CMD_I(12); - // } else { // block_k == 0 is here - // GEMMINI_CISC_CMD_I(13); - // } + // TODO: branch is probably slow + if (block_k & 1) { + GEMMINI_CISC_CMD_I(12); + } else { // block_k == 0 is here + GEMMINI_CISC_CMD_I(13); + } // configure loop iteration bounds // FIXME: shouldn't be necessary - // #define BOUND_INST 0x400040004ULL // ROCC_INSTRUCTION_RS1_RS2(XCUSTOM_ACC, 0, BOUND_INST, // k_LOOP_WS_CONFIG_BOUNDS) ROCC_INSTRUCTION_RS1_RS2(XCUSTOM_ACC, // SPAD_ADDR_Q0, SPAD_ADDR_Q1, k_LOOP_WS_CONFIG_SPAD_AB) @@ -483,12 +515,11 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { float *sharedmem_per_threadblock = (float *)DEV_SMEM_START_ADDR + (2 * BM * BK) * threadblock_id_in_cluster; - const int warp_id = vx_warp_id(); thread_block_gemm(arg, tid_in_threadblock, threads_per_threadblock, threadblock_dim_y, /*threadblock_id_x, threadblock_id_y,*/ - num_threadblocks, - threadblock_id, + threadblocks_per_cluster, + // threadblock_id, threadblock_id_in_cluster, sharedmem_per_threadblock); } diff --git a/tests/regression/sgemm_tcore/kernel.warpspecial.cpp b/tests/regression/sgemm_tcore/kernel.warpspecial.cpp index d8764bb1..7da03a99 100644 --- a/tests/regression/sgemm_tcore/kernel.warpspecial.cpp +++ b/tests/regression/sgemm_tcore/kernel.warpspecial.cpp @@ -11,6 +11,11 @@ #undef ELEM_PER_THREAD #define ELEM_PER_THREAD (WMITER * WNITER * ((TCM * TCN) / NUM_LANES) / (DOUBLE_BUFFER ? 2 : 1)) +// FIXME: NUM_THREADS and NUM_WARPS hardcoded +#if ((BM * BN / ELEM_PER_THREAD) > (CORES_PER_CLUSTER * 8 * 8)) +#error "threadblock size too big for cluster" +#endif + inline void global_dmem_load(const uint32_t dim_n, const uint32_t dim_k, const uint32_t k, const float *A, const float *B, volatile float *local_a, volatile float *local_b, @@ -85,11 +90,12 @@ inline void global_dmem_load(const uint32_t dim_n, const uint32_t dim_k, asm volatile ("fsw ft1, %0(%1)" :: "i"(BK * row_stride_a * 1 * sizeof(float)), "r"(local_a_tmp)); asm volatile ("fsw ft2, %0(%1)" :: "i"(BK * row_stride_a * 2 * sizeof(float)), "r"(local_a_tmp)); asm volatile ("fsw ft3, %0(%1)" :: "i"(BK * row_stride_a * 3 * sizeof(float)), "r"(local_a_tmp)); - asm volatile ("fsw ft4, %0(%1)" :: "i"(BK * row_stride_a * 4 * sizeof(float)), "r"(local_a_tmp)); - asm volatile ("fsw ft5, %0(%1)" :: "i"(BK * row_stride_a * 5 * sizeof(float)), "r"(local_a_tmp)); - asm volatile ("fsw ft6, %0(%1)" :: "i"(BK * row_stride_a * 6 * sizeof(float)), "r"(local_a_tmp)); - asm volatile ("fsw ft7, %0(%1)" :: "i"(BK * row_stride_a * 7 * sizeof(float)), "r"(local_a_tmp)); - local_a_tmp += BK * row_stride_a * 8; + local_a_tmp += BK * row_stride_a * 4; + asm volatile ("fsw ft4, %0(%1)" :: "i"(BK * row_stride_a * 0 * sizeof(float)), "r"(local_a_tmp)); + asm volatile ("fsw ft5, %0(%1)" :: "i"(BK * row_stride_a * 1 * sizeof(float)), "r"(local_a_tmp)); + asm volatile ("fsw ft6, %0(%1)" :: "i"(BK * row_stride_a * 2 * sizeof(float)), "r"(local_a_tmp)); + asm volatile ("fsw ft7, %0(%1)" :: "i"(BK * row_stride_a * 3 * sizeof(float)), "r"(local_a_tmp)); + local_a_tmp += BK * row_stride_a * 4; } } else { if constexpr (!GMEM_COALESCED_A) { @@ -245,13 +251,16 @@ inline void global_dmem_load(const uint32_t dim_n, const uint32_t dim_k, asm volatile ("fsw ft0, %0(%1)" :: "i"(BN * row_stride_b * 0 * sizeof(float)), "r"(local_b_tmp)); asm volatile ("fsw ft1, %0(%1)" :: "i"(BN * row_stride_b * 1 * sizeof(float)), "r"(local_b_tmp)); - asm volatile ("fsw ft2, %0(%1)" :: "i"(BN * row_stride_b * 2 * sizeof(float)), "r"(local_b_tmp)); - asm volatile ("fsw ft3, %0(%1)" :: "i"(BN * row_stride_b * 3 * sizeof(float)), "r"(local_b_tmp)); - asm volatile ("fsw ft4, %0(%1)" :: "i"(BN * row_stride_b * 4 * sizeof(float)), "r"(local_b_tmp)); - asm volatile ("fsw ft5, %0(%1)" :: "i"(BN * row_stride_b * 5 * sizeof(float)), "r"(local_b_tmp)); - asm volatile ("fsw ft6, %0(%1)" :: "i"(BN * row_stride_b * 6 * sizeof(float)), "r"(local_b_tmp)); - asm volatile ("fsw ft7, %0(%1)" :: "i"(BN * row_stride_b * 7 * sizeof(float)), "r"(local_b_tmp)); - local_b_tmp += BN * row_stride_b * 8; + local_b_tmp += BN * row_stride_b * 2; + asm volatile ("fsw ft2, %0(%1)" :: "i"(BN * row_stride_b * 0 * sizeof(float)), "r"(local_b_tmp)); + asm volatile ("fsw ft3, %0(%1)" :: "i"(BN * row_stride_b * 1 * sizeof(float)), "r"(local_b_tmp)); + local_b_tmp += BN * row_stride_b * 2; + asm volatile ("fsw ft4, %0(%1)" :: "i"(BN * row_stride_b * 0 * sizeof(float)), "r"(local_b_tmp)); + asm volatile ("fsw ft5, %0(%1)" :: "i"(BN * row_stride_b * 1 * sizeof(float)), "r"(local_b_tmp)); + local_b_tmp += BN * row_stride_b * 2; + asm volatile ("fsw ft6, %0(%1)" :: "i"(BN * row_stride_b * 0 * sizeof(float)), "r"(local_b_tmp)); + asm volatile ("fsw ft7, %0(%1)" :: "i"(BN * row_stride_b * 1 * sizeof(float)), "r"(local_b_tmp)); + local_b_tmp += BN * row_stride_b * 2; } } diff --git a/tests/regression/sgemm_tcore/util.hpp b/tests/regression/sgemm_tcore/util.hpp index a601d22c..b4634b2b 100644 --- a/tests/regression/sgemm_tcore/util.hpp +++ b/tests/regression/sgemm_tcore/util.hpp @@ -20,9 +20,9 @@ // (BM*BN) / (TM*TN) == threadblock size >= NT * CORES_PER_CLUSTER // * Combining BM * BK >= (BM*BN) / (TM*TN) == threadblock yields // BM <= BK*TM*TN -#define BM 32 -#define BN 32 -#define BK 32 +#define BM 64 +#define BN 64 +#define BK 64 #define WM 16 #define WN 8 #define TCM 8 @@ -42,29 +42,12 @@ // For correctness, only one of either should be 1. To model the case where // the entire A matrix is already stored transposed in GMEM ("TN" kernel), set // both to 0. -#define TRANSPOSE_AT_PRODUCE 0 +#define TRANSPOSE_AT_PRODUCE 1 #define TRANSPOSE_AT_CONSUME 0 // GMEM_COALESCED sets bank conflict-free accesses for // 1: GMEM loads of A matrix // 0: SMEM stores of A matrix #define GMEM_COALESCED_A 1 -#define GEMMINI_DMA 0 -#if SMEM_SIZE != 0x4000 -#error Currently only supports 16K spad -#endif -#define SMEM_ADDR_Q0 ((float * const) 0xff000000) -#define SMEM_ADDR_Q1 ((float * const) 0xff001000) -#define SMEM_ADDR_Q2 ((float * const) 0xff002000) -#define SMEM_ADDR_Q3 ((float * const) 0xff003000) -#define SPAD_ADDR_Q0 0x0 -#define SPAD_ADDR_Q1 0x80 -#define SPAD_ADDR_Q2 0x100 -#define SPAD_ADDR_Q3 0x180 - -// FIXME: NUM_THREADS and NUM_WARPS hardcoded -#if ((BM * BN / ELEM_PER_THREAD) > (CORES_PER_CLUSTER * 8 * 8)) -#error "threadblock size too big for cluster" -#endif inline constexpr void map_operand_32lanes(const int tid, int &row, int &col) { const int tg = tid / 4;