From 65c653afded4eb3fba8816be5a2027a5bd64cac3 Mon Sep 17 00:00:00 2001 From: Hansung Kim Date: Wed, 5 Jun 2024 18:03:08 -0700 Subject: [PATCH] sgemm_tcore: Use arithmetic instead of branch for double-buffered addr --- tests/regression/sgemm_tcore/kernel.cpp | 37 +++++++++++++------------ 1 file changed, 19 insertions(+), 18 deletions(-) diff --git a/tests/regression/sgemm_tcore/kernel.cpp b/tests/regression/sgemm_tcore/kernel.cpp index 7a05f0d4..8f73e0f4 100644 --- a/tests/regression/sgemm_tcore/kernel.cpp +++ b/tests/regression/sgemm_tcore/kernel.cpp @@ -385,7 +385,7 @@ global_dmem_load(const uint32_t dim_n, const uint32_t dim_k, const uint32_t k, row_stride_as * 8 <= BK, "manual loop unrolling condition not met; consider increasing BK"); -#pragma GCC ivdep +#pragma GCC unroll 2 for (uint32_t local_row_offset = 0; local_row_offset < BK; local_row_offset += row_stride_as * 8) { // @perf: bank conflicts here @@ -436,7 +436,7 @@ global_dmem_load(const uint32_t dim_n, const uint32_t dim_k, const uint32_t k, row_stride_b * 8 <= BK, "manual loop unrolling condition not met; consider increasing BK"); -#pragma GCC ivdep +#pragma GCC unroll 2 for (uint32_t load_offset = 0; load_offset < BK; load_offset += row_stride_b * 8) { // const uint32_t global_b_offset = @@ -551,18 +551,18 @@ void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg, for (uint32_t k = 0; k < dim_k - BK; k += BK) { volatile float *local_a_produce; volatile float *local_b_produce; - volatile float *local_a_consume; - volatile float *local_b_consume; if constexpr (DOUBLE_BUFFER) { local_a_produce = (k_index % 2) ? local_a : local_a_buf; local_b_produce = (k_index % 2) ? local_b : local_b_buf; - local_a_consume = (k_index % 2) ? local_a_buf : local_a; - local_b_consume = (k_index % 2) ? local_b_buf : local_b; + local_a_produce = reinterpret_cast( + ((k_index & 1) & 1) * reinterpret_cast(local_a) + + ((k_index & 1) ^ 1) * reinterpret_cast(local_a_buf)); + local_b_produce = reinterpret_cast( + ((k_index & 1) & 1) * reinterpret_cast(local_b) + + ((k_index & 1) ^ 1) * reinterpret_cast(local_b_buf)); } else { local_a_produce = local_a; local_b_produce = local_b; - local_a_consume = local_a; - local_b_consume = local_b; } k_index++; @@ -578,18 +578,19 @@ void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg, uint32_t k_index = 0; #pragma GCC unroll 1 for (uint32_t k = 0; k < dim_k; k += BK) { - volatile float *local_a_produce; - volatile float *local_b_produce; volatile float *local_a_consume; volatile float *local_b_consume; if constexpr (DOUBLE_BUFFER) { - local_a_produce = (k_index % 2) ? local_a : local_a_buf; - local_b_produce = (k_index % 2) ? local_b : local_b_buf; - local_a_consume = (k_index % 2) ? local_a_buf : local_a; - local_b_consume = (k_index % 2) ? local_b_buf : local_b; + // local_a_consume = (k_index % 2) ? local_a_buf : local_a; + // local_b_consume = (k_index % 2) ? local_b_buf : local_b; + // FIXME: swap multiply with bitshifts + local_a_consume = reinterpret_cast( + ((k_index & 1) & 1) * reinterpret_cast(local_a_buf) + + ((k_index & 1) ^ 1) * reinterpret_cast(local_a)); + local_b_consume = reinterpret_cast( + ((k_index & 1) & 1) * reinterpret_cast(local_b_buf) + + ((k_index & 1) ^ 1) * reinterpret_cast(local_b)); } else { - local_a_produce = local_a; - local_b_produce = local_b; local_a_consume = local_a; local_b_consume = local_b; } @@ -600,7 +601,7 @@ void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg, // vx_wmma_load #pragma GCC unroll 1 for (int i = 0; i < BK_LOOP; i++) { -#pragma GCC unroll 1 +#pragma GCC unroll 10 for (uint32_t local_k = 0; local_k < BK; local_k += TCK) { // perform wmma // vx_wmma_load(local_a_consume, local_b_consume, warp_x, warp_y, @@ -612,7 +613,7 @@ void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg, vx_wmma_load_b(local_b_consume, local_k, warp_col, wn_iter, tid_in_warp); // vx_wmma_load_b(local_b_consume, 0, 0, 0, tid_in_warp); -#pragma GCC unroll 1 +#pragma GCC unroll 2 for (int wm_iter = 0; wm_iter < WMITER; wm_iter++) { #if TC_SINGLE_WARP if (warp_in_warpgroup == 0) {