diff --git a/tests/regression/sgemm_tcore/kernel.cpp b/tests/regression/sgemm_tcore/kernel.cpp index 1daed02d..59fd7194 100644 --- a/tests/regression/sgemm_tcore/kernel.cpp +++ b/tests/regression/sgemm_tcore/kernel.cpp @@ -84,11 +84,15 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { // "static" shared memory allocation. This would determine threadblock // occupancy of a single cluster uint8_t *sharedmem_per_threadblock = reinterpret_cast( - DEV_SMEM_START_ADDR + sizeof(float_type) * 2 /*overkill for non-dma*/ * - (2 * BM * BK) * threadblock_id_in_cluster); + DEV_SMEM_START_ADDR + + sizeof(float_type) * 2 * (2 * BM * BK) * threadblock_id_in_cluster); thread_block_gemm( + /*write_to_gmem=*/true, + /*smem_a_offset=*/0, + /*smem_a_dbuf_offset=*/0, + /*smem_b_offset=*/2 * BM * BK * sizeof(float), + /*smem_b_dbuf_offset=*/2 * BM * BK * sizeof(float)>( (const float_type *)arg->addr_a, (const float_type *)arg->addr_b, (float *)arg->addr_c, arg->dim_m, arg->dim_n, arg->dim_k, tid_in_threadblock, threadblocks_per_cluster, threadblock_id_in_cluster, diff --git a/tests/regression/sgemm_tcore/sgemm_impl.hpp b/tests/regression/sgemm_tcore/sgemm_impl.hpp index faa2c382..e0b1c2e1 100644 --- a/tests/regression/sgemm_tcore/sgemm_impl.hpp +++ b/tests/regression/sgemm_tcore/sgemm_impl.hpp @@ -70,7 +70,7 @@ using float_type = float16_t; // To model the case where the A matrix is already stored column-major in GMEM, // set both to 0. #define TRANSPOSE_AT_PRODUCE 0 -#define TRANSPOSE_AT_CONSUME 0 +#define TRANSPOSE_AT_CONSUME 1 #define GEMMINI_DMA 1 #if SMEM_SIZE == 0x4000 @@ -230,19 +230,42 @@ inline void wmma_load_a(volatile const T *smem_A, const int local_k, constexpr int packed_factor = (std::is_same_v ? 2 : 1); const int local_k_adjusted = local_k / packed_factor; + static_assert(!GEMMINI_DMA || (layout == MemLayout::K_major), + "GEMMINI_DMA only supported for K-major A tile"); + if constexpr (layout == MemLayout::K_major) { constexpr int smem_A_cols = leading_dim; - // int A_offset = (WM * warp_row + TCM * wm_iter + row) * smem_A_cols; - // f8-f15 stores a single row of A + const uint32_t smem_logical_row = WM * warp_row + TCM * wm_iter + row; + const uint32_t smem_logical_col = local_k + 0; /* FIXME: adjust for fp16? */ + uint32_t smem_row; + uint32_t smem_col; + if constexpr (GEMMINI_DMA) { + // if using Gemmini DMA, remap logical row/col to Gemmini's 2-level + // block-row-major layout + static_assert( + DIM == 8, + "GEMMINI_DMA layout remapping code only written for DIM == 8"); + constexpr int dim_blocks_in_row = (smem_A_cols / DIM); + smem_row = (smem_logical_row / dim_blocks_in_row) * DIM + + (smem_logical_col / DIM); + smem_col = (smem_logical_row % dim_blocks_in_row) * DIM + + (smem_logical_col % DIM); + } else { + smem_row = smem_logical_row; + smem_col = smem_logical_col; + } + const volatile uint8_t *smem_addr; smem_addr = reinterpret_cast( &reinterpret_cast( - smem_A)[(WM * warp_row + TCM * wm_iter + row) * smem_A_cols + - local_k /* FIXME: adjust for fp16? */]); + smem_A)[smem_A_cols * smem_row + smem_col]); // step to the next column // @perf: bank conflicts; threads read from different rows + // below is correct for GEMMINI_DMA; smem_col is always a multiple of 8, + // and the next 7 elements in the row are guaranteed to be consecutive in + // the memory asm volatile("flw f0, %0(%1)" ::"i"(0 * sizeof(float)), "r"(smem_addr)); asm volatile("flw f1, %0(%1)" ::"i"(1 * sizeof(float)), "r"(smem_addr)); asm volatile("flw f2, %0(%1)" ::"i"(2 * sizeof(float)), "r"(smem_addr)); @@ -325,24 +348,53 @@ inline void wmma_load_b(const volatile T *smem_B, const int local_k, const int local_k_adjusted = local_k / packed_factor; // B is stored N-major in smem - constexpr int smem_B_rows = tile_dim_k_adjusted; constexpr int smem_B_cols = tile_dim_n; + const uint32_t smem_logical_row = local_k_adjusted + 0; + const uint32_t smem_logical_col = (WN * warp_col + TCN * wn_iter) + col; + uint32_t smem_row; + uint32_t smem_col; + if constexpr (GEMMINI_DMA) { + // if using Gemmini DMA, remap logical row/col to Gemmini's 2-level + // block-row-major layout + constexpr int dim_blocks_in_row = (smem_B_cols / DIM); + smem_row = + (smem_logical_row / dim_blocks_in_row) * DIM + (smem_logical_col / DIM); + smem_col = + (smem_logical_row % dim_blocks_in_row) * DIM + (smem_logical_col % DIM); + } else { + smem_row = smem_logical_row; + smem_col = smem_logical_col; + } + const volatile uint8_t *smem_addr; smem_addr = reinterpret_cast( &reinterpret_cast( - smem_B)[((local_k_adjusted + 0) * smem_B_cols) + - (WN * warp_col + TCN * wn_iter) + col]); + smem_B)[smem_B_cols * smem_row + smem_col]); // f8-f15 stores a single column of B // threads read from different columns; no bank conflicts - asm volatile("flw f8, %0(%1)" :: "i"(smem_B_cols * 0 * sizeof(float)), "r"(smem_addr)); - asm volatile("flw f9, %0(%1)" :: "i"(smem_B_cols * 1 * sizeof(float)), "r"(smem_addr)); - asm volatile("flw f10, %0(%1)" :: "i"(smem_B_cols * 2 * sizeof(float)), "r"(smem_addr)); - asm volatile("flw f11, %0(%1)" :: "i"(smem_B_cols * 3 * sizeof(float)), "r"(smem_addr)); - asm volatile("flw f12, %0(%1)" :: "i"(smem_B_cols * 4 * sizeof(float)), "r"(smem_addr)); - asm volatile("flw f13, %0(%1)" :: "i"(smem_B_cols * 5 * sizeof(float)), "r"(smem_addr)); - asm volatile("flw f14, %0(%1)" :: "i"(smem_B_cols * 6 * sizeof(float)), "r"(smem_addr)); - asm volatile("flw f15, %0(%1)" :: "i"(smem_B_cols * 7 * sizeof(float)), "r"(smem_addr)); + if constexpr (GEMMINI_DMA) { + // for GEMMINI_DMA, moving rows for the next 7 elements in the same column + // is the same as moving DIM elements forward in the memory because of the + // block-row-major layout + asm volatile("flw f8, %0(%1)" :: "i"(DIM * 0 * sizeof(float)), "r"(smem_addr)); + asm volatile("flw f9, %0(%1)" :: "i"(DIM * 1 * sizeof(float)), "r"(smem_addr)); + asm volatile("flw f10, %0(%1)" :: "i"(DIM * 2 * sizeof(float)), "r"(smem_addr)); + asm volatile("flw f11, %0(%1)" :: "i"(DIM * 3 * sizeof(float)), "r"(smem_addr)); + asm volatile("flw f12, %0(%1)" :: "i"(DIM * 4 * sizeof(float)), "r"(smem_addr)); + asm volatile("flw f13, %0(%1)" :: "i"(DIM * 5 * sizeof(float)), "r"(smem_addr)); + asm volatile("flw f14, %0(%1)" :: "i"(DIM * 6 * sizeof(float)), "r"(smem_addr)); + asm volatile("flw f15, %0(%1)" :: "i"(DIM * 7 * sizeof(float)), "r"(smem_addr)); + } else { + asm volatile("flw f8, %0(%1)" :: "i"(smem_B_cols * 0 * sizeof(float)), "r"(smem_addr)); + asm volatile("flw f9, %0(%1)" :: "i"(smem_B_cols * 1 * sizeof(float)), "r"(smem_addr)); + asm volatile("flw f10, %0(%1)" :: "i"(smem_B_cols * 2 * sizeof(float)), "r"(smem_addr)); + asm volatile("flw f11, %0(%1)" :: "i"(smem_B_cols * 3 * sizeof(float)), "r"(smem_addr)); + asm volatile("flw f12, %0(%1)" :: "i"(smem_B_cols * 4 * sizeof(float)), "r"(smem_addr)); + asm volatile("flw f13, %0(%1)" :: "i"(smem_B_cols * 5 * sizeof(float)), "r"(smem_addr)); + asm volatile("flw f14, %0(%1)" :: "i"(smem_B_cols * 6 * sizeof(float)), "r"(smem_addr)); + asm volatile("flw f15, %0(%1)" :: "i"(smem_B_cols * 7 * sizeof(float)), "r"(smem_addr)); + } asm volatile ("wmma_load_b_finish_%=:" :: ); }