sgemm_tcore: Use if constexpr

This commit is contained in:
Hansung Kim
2024-06-06 13:43:57 -07:00
parent deb6e5eba2
commit 7f6f096191

View File

@@ -378,7 +378,7 @@ global_dmem_load(const uint32_t dim_n, const uint32_t dim_k, const uint32_t k,
local_a_tmp += BK * row_stride_a;
}
} else {
#if !GMEM_COALESCED_A
if constexpr (!GMEM_COALESCED_A) {
constexpr uint32_t row_stride_as = threads_in_warpgroup / BM_d;
const uint32_t global_a_row = BM_d * threadblock_id_y + local_as_col;
const float *global_a = A + dim_k * global_a_row + (k + local_as_row);
@@ -434,7 +434,7 @@ global_dmem_load(const uint32_t dim_n, const uint32_t dim_k, const uint32_t k,
asm volatile ("fsw ft7, %0(%1)" :: "i"(BM * row_stride_as * 7 * sizeof(float)), "r"(local_a_tmp));
local_a_tmp += BM * row_stride_as * 8;
}
#else
} else {
constexpr uint32_t row_stride_a = threads_in_warpgroup / BK;
const uint32_t global_a_row = BM_d * threadblock_id_y + local_a_row;
const float *global_a = A + dim_k * global_a_row + (k + local_a_col);
@@ -448,7 +448,7 @@ global_dmem_load(const uint32_t dim_n, const uint32_t dim_k, const uint32_t k,
(BM_d % (row_stride_a * 8)) == 0,
"manual loop unrolling condition not met; BM should be power-of-two");
#pragma GCC unroll 2
#pragma GCC unroll 4
for (uint32_t local_row_offset = 0; local_row_offset < BM_d;
local_row_offset += row_stride_a * 8) {
// const uint32_t global_a_offset =
@@ -457,7 +457,6 @@ global_dmem_load(const uint32_t dim_n, const uint32_t dim_k, const uint32_t k,
// local_a[BM_d * (local_a_col) + local_a_row + local_row_offset] =
// A[global_a_offset];
// *local_a_tmp = *global_a;
asm volatile ("flw ft0, (%0)" :: "r"(global_a));
global_a += dim_k * row_stride_a;
asm volatile ("flw ft1, (%0)" :: "r"(global_a));
@@ -486,7 +485,7 @@ global_dmem_load(const uint32_t dim_n, const uint32_t dim_k, const uint32_t k,
asm volatile ("fsw ft7, %0(%1)" :: "i"(row_stride_a * 7 * sizeof(float)), "r"(local_a_tmp));
local_a_tmp += row_stride_a * 8;
}
#endif
}
}
constexpr uint32_t row_stride_b = threads_in_warpgroup / BN_d;