sgemm_tcore: More asserts on manual unrolling

This commit is contained in:
Hansung Kim
2024-06-06 12:26:07 -07:00
parent a42fa6a113
commit ab4d525970

View File

@@ -384,6 +384,10 @@ global_dmem_load(const uint32_t dim_n, const uint32_t dim_k, const uint32_t k,
static_assert(
row_stride_as * 8 <= BK,
"manual loop unrolling condition not met; consider increasing BK");
static_assert(
(BK % (row_stride_as * 8)) == 0,
"manual loop unrolling condition not met; BK should be power-of-two");
#pragma GCC unroll 2
for (uint32_t local_row_offset = 0; local_row_offset < BK;
@@ -435,6 +439,9 @@ global_dmem_load(const uint32_t dim_n, const uint32_t dim_k, const uint32_t k,
static_assert(
row_stride_b * 8 <= BK,
"manual loop unrolling condition not met; consider increasing BK");
static_assert(
(BK % (row_stride_b * 8)) == 0,
"manual loop unrolling condition not met; BK should be power-of-two");
#pragma GCC unroll 2
for (uint32_t load_offset = 0; load_offset < BK;
@@ -546,6 +553,7 @@ void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg,
if (warpgroup_id == 0) {
// TODO: bring initiation pipeline here
// NOTE: this *should* be signed integer to trigger arithmetic right-shift
int32_t k_index = 0;
#pragma GCC unroll 1
for (uint32_t k = 0; k < dim_k - BK; k += BK) {