diff --git a/tests/regression/sgemm_tcore/kernel.cpp b/tests/regression/sgemm_tcore/kernel.cpp index ecc02d22..0294a8b6 100644 --- a/tests/regression/sgemm_tcore/kernel.cpp +++ b/tests/regression/sgemm_tcore/kernel.cpp @@ -380,12 +380,16 @@ inline void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg, GEMMINI_CISC_CMD_R((dim_n << 16) | (dim_k << 8) | 8); // gemmini_fence(); - // TODO: branch is probably slow - if (block_k & 1) { - GEMMINI_CISC_CMD_I(12); - } else { // block_k == 0 is here - GEMMINI_CISC_CMD_I(13); - } + // block_k is even: opcode 13 (write to local_a_buf) + // block_k is odd: opcode 12 (write to local_a) + const uint32_t opcode = 13 - (block_k & 1); + GEMMINI_CISC_CMD_R(opcode); + // // TODO: branch is probably slow + // if (block_k & 1) { + // GEMMINI_CISC_CMD_I(12); + // } else { // block_k == 0 is here + // GEMMINI_CISC_CMD_I(13); + // } // configure loop iteration bounds // FIXME: shouldn't be necessary @@ -404,22 +408,26 @@ inline void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg, // k_LOOP_WS) // gemmini_fence(); +#if 0 + uint32_t spad_a_produce; + uint32_t spad_b_produce; + const uint32_t mask_odd = (block_k & 1) << 31 >> 31; + const uint32_t mask_even = ((block_k & 1) ^ 1) << 31 >> 31; + spad_a_produce = + ((mask_odd & (SPAD_ADDR_Q0)) | (mask_even & (SPAD_ADDR_Q2))); + spad_b_produce = + ((mask_odd & (SPAD_ADDR_Q1)) | (mask_even & (SPAD_ADDR_Q3))); // sp_tiled_matmul_full_spad_ws includes CONFIG_BOUNDS // FIXME: block_k is 0 for two times -// sp_tiled_matmul_full_spad_ws( -// #if 1 -// SPAD_ADDR_Q2, -// SPAD_ADDR_Q3, -// #else -// (/*block_k:*/ 0 & 1) ? SPAD_ADDR_Q2 : SPAD_ADDR_Q0, -// (/*block_k:*/ 0 & 1) ? SPAD_ADDR_Q3 : SPAD_ADDR_Q1, -// #endif -// /*spad_D=*/0, /*spad_C=*/SPAD_ADDR_Q1, -// /*I=*/BM / DIM, /*J=*/BN / DIM, /*K=*/BK / DIM, /*pad_I=*/0, -// /*pad_J=*/0, /*pad_K=*/0, -// /*a_transpose=*/1, /*b_transpose=*/0, /*full_C=*/0, /*low_D=*/0, -// /*acc=*/0, /*act=*/NO_ACTIVATION, /*skips=*/skips) -// gemmini_fence(); + sp_tiled_matmul_full_spad_ws( + spad_a_produce, + spad_b_produce, + /*spad_D=*/0, /*spad_C=*/SPAD_ADDR_Q1, + /*I=*/BM / DIM, /*J=*/BN / DIM, /*K=*/BK / DIM, /*pad_I=*/0, + /*pad_J=*/0, /*pad_K=*/0, + /*a_transpose=*/0, /*b_transpose=*/0, /*full_C=*/0, /*low_D=*/0, + /*acc=*/0, /*act=*/NO_ACTIVATION, /*skips=*/skips) +#endif } #else global_dmem_load(dim_n, dim_k, block_k * BK, A, B, local_a, local_b, @@ -431,6 +439,27 @@ inline void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg, // consumer code: SMEM->RF and compute // ---------------------------------------------------------------------- // @perf: this loop spills to stack a lot because of all the flws in + const volatile float *local_a_consume; + const volatile float *local_b_consume; + if constexpr (GEMMINI_DMA) { + // local_a_consume = (k_index % 2) ? local_a_buf : local_a; + // local_b_consume = (k_index % 2) ? local_b_buf : local_b; + // FIXME: swap multiply with bitshifts + // const uint32_t mask_odd = (block_k & 1) << 31 >> 31; + // const uint32_t mask_even = ((block_k & 1) ^ 1) << 31 >> 31; + // local_a_consume = reinterpret_cast( + // (mask_odd & reinterpret_cast(local_a_buf)) | + // (mask_even & reinterpret_cast(local_a))); + // local_b_consume = reinterpret_cast( + // (mask_odd & reinterpret_cast(local_b_buf)) | + // (mask_even & reinterpret_cast(local_b))); + local_a_consume = local_a + (block_k & 1) * (local_a_elems + local_b_elems); + local_b_consume = local_b + (block_k & 1) * (local_a_elems + local_b_elems); + } else { + local_a_consume = local_a; + local_b_consume = local_b; + } + #pragma GCC unroll 1 for (int i = 0; i < BK_LOOP; i++) { #pragma GCC unroll 4 @@ -438,11 +467,11 @@ inline void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg, #pragma GCC unroll 2 for (int wn_iter = 0; wn_iter < WNITER; wn_iter++) { // SMEM -> RF - vx_wmma_load_b(local_b, local_k, warp_col, wn_iter, tid_in_warp); + vx_wmma_load_b(local_b_consume, local_k, warp_col, wn_iter, tid_in_warp); #pragma GCC unroll 2 for (int wm_iter = 0; wm_iter < WMITER; wm_iter++) { // SMEM -> RF - vx_wmma_load_a(local_a, local_k, warp_row, wm_iter, + vx_wmma_load_a(local_a_consume, local_k, warp_row, wm_iter, tid_in_warp); // perform mma vx_wmma(wm_iter); @@ -513,7 +542,9 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { // "static" shared memory allocation. This would determine threadblock // occupancy of a single cluster float *sharedmem_per_threadblock = - (float *)DEV_SMEM_START_ADDR + (2 * BM * BK) * threadblock_id_in_cluster; + (float *)DEV_SMEM_START_ADDR + (GEMMINI_DMA ? 2 /*double-buffer*/ : 1) * + (2 * BM * BK) * + threadblock_id_in_cluster; thread_block_gemm(arg, tid_in_threadblock, threads_per_threadblock, threadblock_dim_y,