diff --git a/tests/regression/sgemm_tcore/sgemm_impl.hpp b/tests/regression/sgemm_tcore/sgemm_impl.hpp index 0785c5bd..4b57e28c 100644 --- a/tests/regression/sgemm_tcore/sgemm_impl.hpp +++ b/tests/regression/sgemm_tcore/sgemm_impl.hpp @@ -14,7 +14,6 @@ using float16_t = uint16_t; #if (FP_SIZE == 32) using float_type = float; #elif (FP_SIZE == 16) - using float_type = float16_t; #endif @@ -70,7 +69,7 @@ using float_type = float16_t; // generates the NN kernel where both A and B are stored row-major in GMEM. // To model the case where the A matrix is already stored column-major in GMEM, // set both to 0. -#define TRANSPOSE_AT_PRODUCE 0 +#define TRANSPOSE_AT_PRODUCE 1 #define TRANSPOSE_AT_CONSUME 0 // GMEM_COALESCED: When TRANSPOSE_AT_PRODUCE == 1 (i.e. transpose at // GMEM->SMEM), determines whether we do bank-conflict-free accesses for @@ -393,6 +392,182 @@ inline void threadblock_barrier(const uint32_t barrier_id, const uint32_t count) vx_barrier(barrier_id, count); } +enum class MemLayout { + MN_major, + K_major, +}; + +// Move a single matrix tile from global memory (GMEM) to shared memory (SMEM). +// `dim_col`: column dimension of the global matrix. +template +__attribute__((always_inline)) inline void +global_dmem_load_new(const uint32_t dim_col, const uint32_t mn_index, + const uint32_t k, const T *global_addr, + volatile T *local_addr, + const uint32_t tid_in_threadblock) { + asm volatile("global_dmem_load_start_new_%=:" ::); + + // In fp16 mode, bit-pack two fp16 elements into each fp32 element, and do + // data movement at the fp32 granularity. Assuming that the matrix is stored + // row-major in GMEM, the packed fp16 pairs belong to the same row, + // neighboring columns; therefore, it essentially becomes equivalent to + // moving a fp32 matrix whose column dimensions are compressed by a factor of + // two. + constexpr uint32_t packed_factor = (std::is_same_v ? 2 : 1); + + constexpr uint32_t tile_dim_k_packed = tile_dim_k / packed_factor; + constexpr uint32_t gmem_dim_row = + (gmem_layout == MemLayout::K_major) ? tile_dim_mn : tile_dim_k_packed; + constexpr uint32_t gmem_dim_col = + (gmem_layout == MemLayout::K_major) ? tile_dim_k_packed : tile_dim_mn; + constexpr uint32_t smem_dim_row = + (smem_layout == MemLayout::K_major) ? tile_dim_mn : tile_dim_k_packed; + constexpr uint32_t smem_dim_col = + (smem_layout == MemLayout::K_major) ? tile_dim_k_packed : tile_dim_mn; + + const uint32_t dim_col_ = + (gmem_layout == MemLayout::K_major) ? dim_col / packed_factor : dim_col; + // FIXME: unsure about this + const uint32_t k_ = k / packed_factor; + + // threads in the threadblock always do contiguous accesses in the gmem + const uint32_t local_row_gmem = tid_in_threadblock / gmem_dim_col; + const uint32_t local_col_gmem = tid_in_threadblock % gmem_dim_col; + + constexpr bool transposed_write = (gmem_layout != smem_layout); + // if transposed, threads write to smem in reversed col/row + const uint32_t local_row_smem = + transposed_write ? local_col_gmem : local_row_gmem; + const uint32_t local_col_smem = + transposed_write ? local_row_gmem : local_col_gmem; + + // FIXME: don't hardcode this here + constexpr uint32_t threads_per_threadblock = (BM * BN) / ELEM_PER_THREAD; + + static_assert(gmem_contiguous == true, + "currently only supports contiguous accesses in GMEM"); + + const uint32_t global_row_mn_major = k_ + local_row_gmem; + const uint32_t global_col_mn_major = smem_dim_col * mn_index + local_col_gmem; + const uint32_t global_row_k_major = gmem_dim_row * mn_index + local_row_gmem; + const uint32_t global_col_k_major = k_ + local_col_gmem; + const uint32_t global_row = (gmem_layout == MemLayout::K_major) + ? global_row_k_major + : global_row_mn_major; + const uint32_t global_col = (gmem_layout == MemLayout::K_major) + ? global_col_k_major + : global_col_mn_major; + + const float *global = reinterpret_cast(global_addr) + + dim_col_ * global_row + global_col; + volatile float *local = reinterpret_cast(local_addr) + + smem_dim_col * local_row_smem + local_col_smem; + + constexpr uint32_t row_stride = threads_per_threadblock / gmem_dim_col; + static_assert(row_stride * 8 <= gmem_dim_row, + "manual loop unrolling condition not met; tile row dimension " + "is too shallow"); + static_assert((gmem_dim_row % (row_stride * 8)) == 0, + "manual loop unrolling condition not met; tile row dimension " + "should be power-of-two"); + +#pragma GCC unroll 1 + // loop-unrolled flw/fsw to increase reuse distance and IPC + for (uint32_t load_offset = 0; load_offset < gmem_dim_row; + load_offset += row_stride * 8) { + // equivalent code: + // + // *local = *global; + // global += dim_col * row_stride; + // local += BN * row_stride; + + // read same-column elements into fp registers + asm volatile("flw ft0, (%0)" ::"r"(global)); + global += dim_col_ * row_stride; + asm volatile("flw ft1, (%0)" ::"r"(global)); + global += dim_col_ * row_stride; + asm volatile("flw ft2, (%0)" ::"r"(global)); + global += dim_col_ * row_stride; + asm volatile("flw ft3, (%0)" ::"r"(global)); + global += dim_col_ * row_stride; + asm volatile("flw ft4, (%0)" ::"r"(global)); + global += dim_col_ * row_stride; + asm volatile("flw ft5, (%0)" ::"r"(global)); + global += dim_col_ * row_stride; + asm volatile("flw ft6, (%0)" ::"r"(global)); + global += dim_col_ * row_stride; + asm volatile("flw ft7, (%0)" ::"r"(global)); + global += dim_col_ * row_stride; + + // do we need to do transposed write? + if constexpr (!transposed_write) { + static_assert(gmem_layout == MemLayout::MN_major); + + // if not, do the same along-the-column accesses for registers as we did + // for gmem + asm volatile("fsw ft0, %0(%1)" ::"i"(smem_dim_col * row_stride * 0 * + sizeof(float)), + "r"(local)); + asm volatile("fsw ft1, %0(%1)" ::"i"(smem_dim_col * row_stride * 1 * + sizeof(float)), + "r"(local)); + local += smem_dim_col * row_stride * 2; + asm volatile("fsw ft2, %0(%1)" ::"i"(smem_dim_col * row_stride * 0 * + sizeof(float)), + "r"(local)); + asm volatile("fsw ft3, %0(%1)" ::"i"(smem_dim_col * row_stride * 1 * + sizeof(float)), + "r"(local)); + local += smem_dim_col * row_stride * 2; + asm volatile("fsw ft4, %0(%1)" ::"i"(smem_dim_col * row_stride * 0 * + sizeof(float)), + "r"(local)); + asm volatile("fsw ft5, %0(%1)" ::"i"(smem_dim_col * row_stride * 1 * + sizeof(float)), + "r"(local)); + local += smem_dim_col * row_stride * 2; + asm volatile("fsw ft6, %0(%1)" ::"i"(smem_dim_col * row_stride * 0 * + sizeof(float)), + "r"(local)); + asm volatile("fsw ft7, %0(%1)" ::"i"(smem_dim_col * row_stride * 1 * + sizeof(float)), + "r"(local)); + local += smem_dim_col * row_stride * 2; + } else { + static_assert(gmem_layout == MemLayout::K_major); + static_assert(smem_layout == MemLayout::MN_major); + + // if yes, write the registers along the row, doing a transpose + // @perf: this will incur bank conflicts in smem + asm volatile("fsw ft0, %0(%1)" ::"i"(row_stride * 0 * sizeof(float)), + "r"(local)); + asm volatile("fsw ft1, %0(%1)" ::"i"(row_stride * 1 * sizeof(float)), + "r"(local)); + asm volatile("fsw ft2, %0(%1)" ::"i"(row_stride * 2 * sizeof(float)), + "r"(local)); + asm volatile("fsw ft3, %0(%1)" ::"i"(row_stride * 3 * sizeof(float)), + "r"(local)); + asm volatile("fsw ft4, %0(%1)" ::"i"(row_stride * 4 * sizeof(float)), + "r"(local)); + asm volatile("fsw ft5, %0(%1)" ::"i"(row_stride * 5 * sizeof(float)), + "r"(local)); + asm volatile("fsw ft6, %0(%1)" ::"i"(row_stride * 6 * sizeof(float)), + "r"(local)); + asm volatile("fsw ft7, %0(%1)" ::"i"(row_stride * 7 * sizeof(float)), + "r"(local)); + local += row_stride * 8; + } + } + + asm volatile("global_dmem_load_finish_new_%=:" ::); +} + // TODO: reduce args by passing leading A/B dimensions template __attribute__((always_inline)) @@ -413,16 +588,14 @@ inline void global_dmem_load(const uint32_t dim_m, const uint32_t dim_n, const u constexpr uint32_t packed_factor = (std::is_same_v ? 2 : 1); constexpr uint32_t BK_adjusted = BK / packed_factor; const uint32_t dim_k_adjusted = dim_k / packed_factor; - constexpr uint32_t BN_adjusted = BN / packed_factor; - const uint32_t dim_n_adjusted = dim_n / packed_factor; const uint32_t k_adjusted = k / packed_factor; const uint32_t local_a_row = tid_in_threadblock / BK_adjusted; const uint32_t local_a_col = tid_in_threadblock % BK_adjusted; const uint32_t local_as_row = tid_in_threadblock / BM; const uint32_t local_as_col = tid_in_threadblock % BM; - const uint32_t local_b_row = tid_in_threadblock / BN_adjusted; - const uint32_t local_b_col = tid_in_threadblock % BN_adjusted; + const uint32_t local_b_row = tid_in_threadblock / BN; + const uint32_t local_b_col = tid_in_threadblock % BN; // FIXME: need fix for fp16? constexpr uint32_t threads_per_threadblock = (BM * BN) / ELEM_PER_THREAD; @@ -436,27 +609,8 @@ inline void global_dmem_load(const uint32_t dim_m, const uint32_t dim_n, const u // move A if constexpr (!TRANSPOSE_AT_PRODUCE) { - // A is stored M-major in GMEM; - // no transpose at GMEM->SMEM movement - const uint32_t block_m = threadblock_id_y; - const uint32_t global_a_row = k_adjusted + local_as_row; - const uint32_t global_a_col = BM * block_m + local_as_col; - // number of rows a full TB can read at a time - constexpr uint32_t row_stride_as = threads_per_threadblock / BM; - const float *global_a = reinterpret_cast(A) + - dim_m * global_a_row + global_a_col; - volatile float *local_a_tmp = reinterpret_cast(local_a) + - BM * local_as_row + local_as_col; - -#pragma GCC unroll 1 - for (uint32_t local_row_offset = 0; local_row_offset < BK_adjusted; - local_row_offset += row_stride_as) { - // TODO: the code GCC generates for below seems fine atm, but unroll to - // assembly to be absolutely sure - *local_a_tmp = *global_a; - global_a += dim_m * row_stride_as; - local_a_tmp += BM * row_stride_as; - } + global_dmem_load_new( + dim_m, threadblock_id_y, k, A, local_a, tid_in_threadblock); } else { if constexpr (!GMEM_COALESCED_A) { // !GMEM_COALESCED_A: threads do uncoalesced read from neighboring row in @@ -478,13 +632,12 @@ inline void global_dmem_load(const uint32_t dim_m, const uint32_t dim_n, const u #pragma GCC unroll 1 for (uint32_t local_row_offset = 0; local_row_offset < BK_adjusted; local_row_offset += row_stride_as * 8) { - // @perf: bank conflicts here // const uint32_t global_a_offset = // dim_k_adjusted * (global_a_row) + (k + local_as_row + local_row_offset); // local_a[BM * (local_as_row + local_row_offset) + local_as_col] = // A[global_a_offset]; - // *local_a_tmp = *global_a; + // @perf: bank conflicts asm volatile ("flw ft0, (%0)" :: "r"(global_a)); global_a += row_stride_as; asm volatile ("flw ft1, (%0)" :: "r"(global_a)); @@ -517,119 +670,15 @@ inline void global_dmem_load(const uint32_t dim_m, const uint32_t dim_n, const u local_a_tmp += BM * row_stride_as * 8; } } else { - constexpr uint32_t row_stride_a = threads_per_threadblock / BK_adjusted; - const uint32_t global_a_row = BM * threadblock_id_y + local_a_row; - const float *global_a = reinterpret_cast(A) + - dim_k_adjusted * global_a_row + - (k_adjusted + local_a_col); - // NOTE that SMEM writes are transposed - volatile float *local_a_tmp = - reinterpret_cast(local_a) + BM * local_a_col + - local_a_row; - - static_assert( - row_stride_a * 8 <= BM, - "manual loop unrolling condition not met; consider increasing BM"); - static_assert( - (BM % (row_stride_a * 8)) == 0, - "manual loop unrolling condition not met; BM should be power-of-two"); - -#pragma GCC unroll 1 - for (uint32_t local_row_offset = 0; local_row_offset < BM; - local_row_offset += row_stride_a * 8) { - // const uint32_t global_a_offset = - // dim_k_adjusted * (global_a_row + local_row_offset) + (k + local_a_col); - // NOTE that SMEM writes are transposed - // local_a[BM * (local_a_col) + local_a_row + local_row_offset] = - // A[global_a_offset]; - - asm volatile ("flw ft0, (%0)" :: "r"(global_a)); - global_a += dim_k_adjusted * row_stride_a; - asm volatile ("flw ft1, (%0)" :: "r"(global_a)); - global_a += dim_k_adjusted * row_stride_a; - asm volatile ("flw ft2, (%0)" :: "r"(global_a)); - global_a += dim_k_adjusted * row_stride_a; - asm volatile ("flw ft3, (%0)" :: "r"(global_a)); - global_a += dim_k_adjusted * row_stride_a; - asm volatile ("flw ft4, (%0)" :: "r"(global_a)); - global_a += dim_k_adjusted * row_stride_a; - asm volatile ("flw ft5, (%0)" :: "r"(global_a)); - global_a += dim_k_adjusted * row_stride_a; - asm volatile ("flw ft6, (%0)" :: "r"(global_a)); - global_a += dim_k_adjusted * row_stride_a; - asm volatile ("flw ft7, (%0)" :: "r"(global_a)); - global_a += dim_k_adjusted * row_stride_a; - - // stride along columns - asm volatile ("fsw ft0, %0(%1)" :: "i"(row_stride_a * 0 * sizeof(float)), "r"(local_a_tmp)); - asm volatile ("fsw ft1, %0(%1)" :: "i"(row_stride_a * 1 * sizeof(float)), "r"(local_a_tmp)); - asm volatile ("fsw ft2, %0(%1)" :: "i"(row_stride_a * 2 * sizeof(float)), "r"(local_a_tmp)); - asm volatile ("fsw ft3, %0(%1)" :: "i"(row_stride_a * 3 * sizeof(float)), "r"(local_a_tmp)); - asm volatile ("fsw ft4, %0(%1)" :: "i"(row_stride_a * 4 * sizeof(float)), "r"(local_a_tmp)); - asm volatile ("fsw ft5, %0(%1)" :: "i"(row_stride_a * 5 * sizeof(float)), "r"(local_a_tmp)); - asm volatile ("fsw ft6, %0(%1)" :: "i"(row_stride_a * 6 * sizeof(float)), "r"(local_a_tmp)); - asm volatile ("fsw ft7, %0(%1)" :: "i"(row_stride_a * 7 * sizeof(float)), "r"(local_a_tmp)); - local_a_tmp += row_stride_a * 8; - } + global_dmem_load_new(dim_k, threadblock_id_y, k, A, local_a, + tid_in_threadblock); } } // end move A // move B - constexpr uint32_t row_stride_b = threads_per_threadblock / BN_adjusted; - const uint32_t global_b_col = BN_adjusted * threadblock_id_x + local_b_col; - // NOTE: not k_adjusted here; k is along the row dimension which is not - // compressed for fp16 - const float *global_b = reinterpret_cast(B) + - dim_n_adjusted * (k + local_b_row) + global_b_col; - volatile float *local_b_tmp = reinterpret_cast(local_b) + - BN_adjusted * local_b_row + local_b_col; - - static_assert( - row_stride_b * 8 <= BK_adjusted, - "manual loop unrolling condition not met; consider increasing BK"); - static_assert( - (BK_adjusted % (row_stride_b * 8)) == 0, - "manual loop unrolling condition not met; BK should be power-of-two"); - -#pragma GCC unroll 1 - for (uint32_t load_offset = 0; load_offset < BK; - load_offset += row_stride_b * 8) { - // equivalent code: - // - // *local_b_tmp = *global_b; - // global_b += dim_n * row_stride_b; - // local_b_tmp += BN * row_stride_b; - - asm volatile ("flw ft0, (%0)" :: "r"(global_b)); - global_b += dim_n_adjusted * row_stride_b; - asm volatile ("flw ft1, (%0)" :: "r"(global_b)); - global_b += dim_n_adjusted * row_stride_b; - asm volatile ("flw ft2, (%0)" :: "r"(global_b)); - global_b += dim_n_adjusted * row_stride_b; - asm volatile ("flw ft3, (%0)" :: "r"(global_b)); - global_b += dim_n_adjusted * row_stride_b; - asm volatile ("flw ft4, (%0)" :: "r"(global_b)); - global_b += dim_n_adjusted * row_stride_b; - asm volatile ("flw ft5, (%0)" :: "r"(global_b)); - global_b += dim_n_adjusted * row_stride_b; - asm volatile ("flw ft6, (%0)" :: "r"(global_b)); - global_b += dim_n_adjusted * row_stride_b; - asm volatile ("flw ft7, (%0)" :: "r"(global_b)); - global_b += dim_n_adjusted * row_stride_b; - - asm volatile ("fsw ft0, %0(%1)" :: "i"(BN_adjusted * row_stride_b * 0 * sizeof(float)), "r"(local_b_tmp)); - asm volatile ("fsw ft1, %0(%1)" :: "i"(BN_adjusted * row_stride_b * 1 * sizeof(float)), "r"(local_b_tmp)); - local_b_tmp += BN_adjusted * row_stride_b * 2; - asm volatile ("fsw ft2, %0(%1)" :: "i"(BN_adjusted * row_stride_b * 0 * sizeof(float)), "r"(local_b_tmp)); - asm volatile ("fsw ft3, %0(%1)" :: "i"(BN_adjusted * row_stride_b * 1 * sizeof(float)), "r"(local_b_tmp)); - local_b_tmp += BN_adjusted * row_stride_b * 2; - asm volatile ("fsw ft4, %0(%1)" :: "i"(BN_adjusted * row_stride_b * 0 * sizeof(float)), "r"(local_b_tmp)); - asm volatile ("fsw ft5, %0(%1)" :: "i"(BN_adjusted * row_stride_b * 1 * sizeof(float)), "r"(local_b_tmp)); - local_b_tmp += BN_adjusted * row_stride_b * 2; - asm volatile ("fsw ft6, %0(%1)" :: "i"(BN_adjusted * row_stride_b * 0 * sizeof(float)), "r"(local_b_tmp)); - asm volatile ("fsw ft7, %0(%1)" :: "i"(BN_adjusted * row_stride_b * 1 * sizeof(float)), "r"(local_b_tmp)); - local_b_tmp += BN_adjusted * row_stride_b * 2; - } + global_dmem_load_new( + dim_n, threadblock_id_x, k, B, local_b, tid_in_threadblock); asm volatile ("global_dmem_load_finish_%=:" :: ); }