diff --git a/tests/regression/sgemm_tcore/kernel.cpp b/tests/regression/sgemm_tcore/kernel.cpp index fd2e73da..d26bae36 100644 --- a/tests/regression/sgemm_tcore/kernel.cpp +++ b/tests/regression/sgemm_tcore/kernel.cpp @@ -378,115 +378,114 @@ global_dmem_load(const uint32_t dim_n, const uint32_t dim_k, const uint32_t k, local_a_tmp += BK * row_stride_a; } } else { -#if !GMEM_COALESCED_A - constexpr uint32_t row_stride_as = threads_in_warpgroup / BM_d; - const uint32_t global_a_row = BM_d * threadblock_id_y + local_as_col; - const float *global_a = A + dim_k * global_a_row + (k + local_as_row); - // FIXME experimenting with global coalescing - // const uint32_t global_a_row = BM_d * threadblock_id_y + local_as_row; - // const float *global_a = A + dim_k * global_a_row + (k + local_as_col); - volatile float *local_a_tmp = local_a + BM_d * local_as_row + local_as_col; - - static_assert( - row_stride_as * 8 <= BK, - "manual loop unrolling condition not met; consider increasing BK"); - static_assert( - (BK % (row_stride_as * 8)) == 0, - "manual loop unrolling condition not met; BK should be power-of-two"); - -#pragma GCC unroll 2 - for (uint32_t local_row_offset = 0; local_row_offset < BK; - local_row_offset += row_stride_as * 8) { - // @perf: bank conflicts here - // const uint32_t global_a_offset = - // dim_k * (global_a_row) + (k + local_as_row + local_row_offset); + if constexpr (!GMEM_COALESCED_A) { + constexpr uint32_t row_stride_as = threads_in_warpgroup / BM_d; + const uint32_t global_a_row = BM_d * threadblock_id_y + local_as_col; + const float *global_a = A + dim_k * global_a_row + (k + local_as_row); // FIXME experimenting with global coalescing - // const uint32_t global_a_offset = - // dim_k * (global_a_row + local_row_offset) + (k + local_as_col); - // local_a[BM_d * (local_as_row + local_row_offset) + local_as_col] = - // A[global_a_offset]; + // const uint32_t global_a_row = BM_d * threadblock_id_y + local_as_row; + // const float *global_a = A + dim_k * global_a_row + (k + local_as_col); + volatile float *local_a_tmp = local_a + BM_d * local_as_row + local_as_col; - // *local_a_tmp = *global_a; - asm volatile ("flw ft0, (%0)" :: "r"(global_a)); - global_a += row_stride_as; - asm volatile ("flw ft1, (%0)" :: "r"(global_a)); - global_a += row_stride_as; - asm volatile ("flw ft2, (%0)" :: "r"(global_a)); - global_a += row_stride_as; - asm volatile ("flw ft3, (%0)" :: "r"(global_a)); - global_a += row_stride_as; - asm volatile ("flw ft4, (%0)" :: "r"(global_a)); - global_a += row_stride_as; - asm volatile ("flw ft5, (%0)" :: "r"(global_a)); - global_a += row_stride_as; - asm volatile ("flw ft6, (%0)" :: "r"(global_a)); - global_a += row_stride_as; - asm volatile ("flw ft7, (%0)" :: "r"(global_a)); - global_a += row_stride_as; - - asm volatile ("fsw ft0, %0(%1)" :: "i"(BM * row_stride_as * 0 * sizeof(float)), "r"(local_a_tmp)); - asm volatile ("fsw ft1, %0(%1)" :: "i"(BM * row_stride_as * 1 * sizeof(float)), "r"(local_a_tmp)); - asm volatile ("fsw ft2, %0(%1)" :: "i"(BM * row_stride_as * 2 * sizeof(float)), "r"(local_a_tmp)); - asm volatile ("fsw ft3, %0(%1)" :: "i"(BM * row_stride_as * 3 * sizeof(float)), "r"(local_a_tmp)); - asm volatile ("fsw ft4, %0(%1)" :: "i"(BM * row_stride_as * 4 * sizeof(float)), "r"(local_a_tmp)); - asm volatile ("fsw ft5, %0(%1)" :: "i"(BM * row_stride_as * 5 * sizeof(float)), "r"(local_a_tmp)); - asm volatile ("fsw ft6, %0(%1)" :: "i"(BM * row_stride_as * 6 * sizeof(float)), "r"(local_a_tmp)); - asm volatile ("fsw ft7, %0(%1)" :: "i"(BM * row_stride_as * 7 * sizeof(float)), "r"(local_a_tmp)); - local_a_tmp += BM * row_stride_as * 8; - } -#else - constexpr uint32_t row_stride_a = threads_in_warpgroup / BK; - const uint32_t global_a_row = BM_d * threadblock_id_y + local_a_row; - const float *global_a = A + dim_k * global_a_row + (k + local_a_col); - // NOTE that SMEM writes are transposed - volatile float *local_a_tmp = local_a + BM_d * local_a_col + local_a_row; - - static_assert( - row_stride_a * 8 <= BM_d, - "manual loop unrolling condition not met; consider increasing BM"); - static_assert( - (BM_d % (row_stride_a * 8)) == 0, - "manual loop unrolling condition not met; BM should be power-of-two"); + static_assert( + row_stride_as * 8 <= BK, + "manual loop unrolling condition not met; consider increasing BK"); + static_assert( + (BK % (row_stride_as * 8)) == 0, + "manual loop unrolling condition not met; BK should be power-of-two"); #pragma GCC unroll 2 - for (uint32_t local_row_offset = 0; local_row_offset < BM_d; - local_row_offset += row_stride_a * 8) { - // const uint32_t global_a_offset = - // dim_k * (global_a_row + local_row_offset) + (k + local_a_col); + for (uint32_t local_row_offset = 0; local_row_offset < BK; + local_row_offset += row_stride_as * 8) { + // @perf: bank conflicts here + // const uint32_t global_a_offset = + // dim_k * (global_a_row) + (k + local_as_row + local_row_offset); + // FIXME experimenting with global coalescing + // const uint32_t global_a_offset = + // dim_k * (global_a_row + local_row_offset) + (k + local_as_col); + // local_a[BM_d * (local_as_row + local_row_offset) + local_as_col] = + // A[global_a_offset]; + + // *local_a_tmp = *global_a; + asm volatile ("flw ft0, (%0)" :: "r"(global_a)); + global_a += row_stride_as; + asm volatile ("flw ft1, (%0)" :: "r"(global_a)); + global_a += row_stride_as; + asm volatile ("flw ft2, (%0)" :: "r"(global_a)); + global_a += row_stride_as; + asm volatile ("flw ft3, (%0)" :: "r"(global_a)); + global_a += row_stride_as; + asm volatile ("flw ft4, (%0)" :: "r"(global_a)); + global_a += row_stride_as; + asm volatile ("flw ft5, (%0)" :: "r"(global_a)); + global_a += row_stride_as; + asm volatile ("flw ft6, (%0)" :: "r"(global_a)); + global_a += row_stride_as; + asm volatile ("flw ft7, (%0)" :: "r"(global_a)); + global_a += row_stride_as; + + asm volatile ("fsw ft0, %0(%1)" :: "i"(BM * row_stride_as * 0 * sizeof(float)), "r"(local_a_tmp)); + asm volatile ("fsw ft1, %0(%1)" :: "i"(BM * row_stride_as * 1 * sizeof(float)), "r"(local_a_tmp)); + asm volatile ("fsw ft2, %0(%1)" :: "i"(BM * row_stride_as * 2 * sizeof(float)), "r"(local_a_tmp)); + asm volatile ("fsw ft3, %0(%1)" :: "i"(BM * row_stride_as * 3 * sizeof(float)), "r"(local_a_tmp)); + asm volatile ("fsw ft4, %0(%1)" :: "i"(BM * row_stride_as * 4 * sizeof(float)), "r"(local_a_tmp)); + asm volatile ("fsw ft5, %0(%1)" :: "i"(BM * row_stride_as * 5 * sizeof(float)), "r"(local_a_tmp)); + asm volatile ("fsw ft6, %0(%1)" :: "i"(BM * row_stride_as * 6 * sizeof(float)), "r"(local_a_tmp)); + asm volatile ("fsw ft7, %0(%1)" :: "i"(BM * row_stride_as * 7 * sizeof(float)), "r"(local_a_tmp)); + local_a_tmp += BM * row_stride_as * 8; + } + } else { + constexpr uint32_t row_stride_a = threads_in_warpgroup / BK; + const uint32_t global_a_row = BM_d * threadblock_id_y + local_a_row; + const float *global_a = A + dim_k * global_a_row + (k + local_a_col); // NOTE that SMEM writes are transposed - // local_a[BM_d * (local_a_col) + local_a_row + local_row_offset] = - // A[global_a_offset]; + volatile float *local_a_tmp = local_a + BM_d * local_a_col + local_a_row; - // *local_a_tmp = *global_a; - asm volatile ("flw ft0, (%0)" :: "r"(global_a)); - global_a += dim_k * row_stride_a; - asm volatile ("flw ft1, (%0)" :: "r"(global_a)); - global_a += dim_k * row_stride_a; - asm volatile ("flw ft2, (%0)" :: "r"(global_a)); - global_a += dim_k * row_stride_a; - asm volatile ("flw ft3, (%0)" :: "r"(global_a)); - global_a += dim_k * row_stride_a; - asm volatile ("flw ft4, (%0)" :: "r"(global_a)); - global_a += dim_k * row_stride_a; - asm volatile ("flw ft5, (%0)" :: "r"(global_a)); - global_a += dim_k * row_stride_a; - asm volatile ("flw ft6, (%0)" :: "r"(global_a)); - global_a += dim_k * row_stride_a; - asm volatile ("flw ft7, (%0)" :: "r"(global_a)); - global_a += dim_k * row_stride_a; + static_assert( + row_stride_a * 8 <= BM_d, + "manual loop unrolling condition not met; consider increasing BM"); + static_assert( + (BM_d % (row_stride_a * 8)) == 0, + "manual loop unrolling condition not met; BM should be power-of-two"); - // stride along columns - asm volatile ("fsw ft0, %0(%1)" :: "i"(row_stride_a * 0 * sizeof(float)), "r"(local_a_tmp)); - asm volatile ("fsw ft1, %0(%1)" :: "i"(row_stride_a * 1 * sizeof(float)), "r"(local_a_tmp)); - asm volatile ("fsw ft2, %0(%1)" :: "i"(row_stride_a * 2 * sizeof(float)), "r"(local_a_tmp)); - asm volatile ("fsw ft3, %0(%1)" :: "i"(row_stride_a * 3 * sizeof(float)), "r"(local_a_tmp)); - asm volatile ("fsw ft4, %0(%1)" :: "i"(row_stride_a * 4 * sizeof(float)), "r"(local_a_tmp)); - asm volatile ("fsw ft5, %0(%1)" :: "i"(row_stride_a * 5 * sizeof(float)), "r"(local_a_tmp)); - asm volatile ("fsw ft6, %0(%1)" :: "i"(row_stride_a * 6 * sizeof(float)), "r"(local_a_tmp)); - asm volatile ("fsw ft7, %0(%1)" :: "i"(row_stride_a * 7 * sizeof(float)), "r"(local_a_tmp)); - local_a_tmp += row_stride_a * 8; +#pragma GCC unroll 4 + for (uint32_t local_row_offset = 0; local_row_offset < BM_d; + local_row_offset += row_stride_a * 8) { + // const uint32_t global_a_offset = + // dim_k * (global_a_row + local_row_offset) + (k + local_a_col); + // NOTE that SMEM writes are transposed + // local_a[BM_d * (local_a_col) + local_a_row + local_row_offset] = + // A[global_a_offset]; + + asm volatile ("flw ft0, (%0)" :: "r"(global_a)); + global_a += dim_k * row_stride_a; + asm volatile ("flw ft1, (%0)" :: "r"(global_a)); + global_a += dim_k * row_stride_a; + asm volatile ("flw ft2, (%0)" :: "r"(global_a)); + global_a += dim_k * row_stride_a; + asm volatile ("flw ft3, (%0)" :: "r"(global_a)); + global_a += dim_k * row_stride_a; + asm volatile ("flw ft4, (%0)" :: "r"(global_a)); + global_a += dim_k * row_stride_a; + asm volatile ("flw ft5, (%0)" :: "r"(global_a)); + global_a += dim_k * row_stride_a; + asm volatile ("flw ft6, (%0)" :: "r"(global_a)); + global_a += dim_k * row_stride_a; + asm volatile ("flw ft7, (%0)" :: "r"(global_a)); + global_a += dim_k * row_stride_a; + + // stride along columns + asm volatile ("fsw ft0, %0(%1)" :: "i"(row_stride_a * 0 * sizeof(float)), "r"(local_a_tmp)); + asm volatile ("fsw ft1, %0(%1)" :: "i"(row_stride_a * 1 * sizeof(float)), "r"(local_a_tmp)); + asm volatile ("fsw ft2, %0(%1)" :: "i"(row_stride_a * 2 * sizeof(float)), "r"(local_a_tmp)); + asm volatile ("fsw ft3, %0(%1)" :: "i"(row_stride_a * 3 * sizeof(float)), "r"(local_a_tmp)); + asm volatile ("fsw ft4, %0(%1)" :: "i"(row_stride_a * 4 * sizeof(float)), "r"(local_a_tmp)); + asm volatile ("fsw ft5, %0(%1)" :: "i"(row_stride_a * 5 * sizeof(float)), "r"(local_a_tmp)); + asm volatile ("fsw ft6, %0(%1)" :: "i"(row_stride_a * 6 * sizeof(float)), "r"(local_a_tmp)); + asm volatile ("fsw ft7, %0(%1)" :: "i"(row_stride_a * 7 * sizeof(float)), "r"(local_a_tmp)); + local_a_tmp += row_stride_a * 8; + } } -#endif } constexpr uint32_t row_stride_b = threads_in_warpgroup / BN_d;