sgemm_tcore: Use arithmetic instead of branch for double-buffered addr
This commit is contained in:
@@ -385,7 +385,7 @@ global_dmem_load(const uint32_t dim_n, const uint32_t dim_k, const uint32_t k,
|
||||
row_stride_as * 8 <= BK,
|
||||
"manual loop unrolling condition not met; consider increasing BK");
|
||||
|
||||
#pragma GCC ivdep
|
||||
#pragma GCC unroll 2
|
||||
for (uint32_t local_row_offset = 0; local_row_offset < BK;
|
||||
local_row_offset += row_stride_as * 8) {
|
||||
// @perf: bank conflicts here
|
||||
@@ -436,7 +436,7 @@ global_dmem_load(const uint32_t dim_n, const uint32_t dim_k, const uint32_t k,
|
||||
row_stride_b * 8 <= BK,
|
||||
"manual loop unrolling condition not met; consider increasing BK");
|
||||
|
||||
#pragma GCC ivdep
|
||||
#pragma GCC unroll 2
|
||||
for (uint32_t load_offset = 0; load_offset < BK;
|
||||
load_offset += row_stride_b * 8) {
|
||||
// const uint32_t global_b_offset =
|
||||
@@ -551,18 +551,18 @@ void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg,
|
||||
for (uint32_t k = 0; k < dim_k - BK; k += BK) {
|
||||
volatile float *local_a_produce;
|
||||
volatile float *local_b_produce;
|
||||
volatile float *local_a_consume;
|
||||
volatile float *local_b_consume;
|
||||
if constexpr (DOUBLE_BUFFER) {
|
||||
local_a_produce = (k_index % 2) ? local_a : local_a_buf;
|
||||
local_b_produce = (k_index % 2) ? local_b : local_b_buf;
|
||||
local_a_consume = (k_index % 2) ? local_a_buf : local_a;
|
||||
local_b_consume = (k_index % 2) ? local_b_buf : local_b;
|
||||
local_a_produce = reinterpret_cast<volatile float *>(
|
||||
((k_index & 1) & 1) * reinterpret_cast<uint32_t>(local_a) +
|
||||
((k_index & 1) ^ 1) * reinterpret_cast<uint32_t>(local_a_buf));
|
||||
local_b_produce = reinterpret_cast<volatile float *>(
|
||||
((k_index & 1) & 1) * reinterpret_cast<uint32_t>(local_b) +
|
||||
((k_index & 1) ^ 1) * reinterpret_cast<uint32_t>(local_b_buf));
|
||||
} else {
|
||||
local_a_produce = local_a;
|
||||
local_b_produce = local_b;
|
||||
local_a_consume = local_a;
|
||||
local_b_consume = local_b;
|
||||
}
|
||||
k_index++;
|
||||
|
||||
@@ -578,18 +578,19 @@ void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg,
|
||||
uint32_t k_index = 0;
|
||||
#pragma GCC unroll 1
|
||||
for (uint32_t k = 0; k < dim_k; k += BK) {
|
||||
volatile float *local_a_produce;
|
||||
volatile float *local_b_produce;
|
||||
volatile float *local_a_consume;
|
||||
volatile float *local_b_consume;
|
||||
if constexpr (DOUBLE_BUFFER) {
|
||||
local_a_produce = (k_index % 2) ? local_a : local_a_buf;
|
||||
local_b_produce = (k_index % 2) ? local_b : local_b_buf;
|
||||
local_a_consume = (k_index % 2) ? local_a_buf : local_a;
|
||||
local_b_consume = (k_index % 2) ? local_b_buf : local_b;
|
||||
// local_a_consume = (k_index % 2) ? local_a_buf : local_a;
|
||||
// local_b_consume = (k_index % 2) ? local_b_buf : local_b;
|
||||
// FIXME: swap multiply with bitshifts
|
||||
local_a_consume = reinterpret_cast<volatile float *>(
|
||||
((k_index & 1) & 1) * reinterpret_cast<uint32_t>(local_a_buf) +
|
||||
((k_index & 1) ^ 1) * reinterpret_cast<uint32_t>(local_a));
|
||||
local_b_consume = reinterpret_cast<volatile float *>(
|
||||
((k_index & 1) & 1) * reinterpret_cast<uint32_t>(local_b_buf) +
|
||||
((k_index & 1) ^ 1) * reinterpret_cast<uint32_t>(local_b));
|
||||
} else {
|
||||
local_a_produce = local_a;
|
||||
local_b_produce = local_b;
|
||||
local_a_consume = local_a;
|
||||
local_b_consume = local_b;
|
||||
}
|
||||
@@ -600,7 +601,7 @@ void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg,
|
||||
// vx_wmma_load
|
||||
#pragma GCC unroll 1
|
||||
for (int i = 0; i < BK_LOOP; i++) {
|
||||
#pragma GCC unroll 1
|
||||
#pragma GCC unroll 10
|
||||
for (uint32_t local_k = 0; local_k < BK; local_k += TCK) {
|
||||
// perform wmma
|
||||
// vx_wmma_load(local_a_consume, local_b_consume, warp_x, warp_y,
|
||||
@@ -612,7 +613,7 @@ void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg,
|
||||
vx_wmma_load_b(local_b_consume, local_k, warp_col, wn_iter,
|
||||
tid_in_warp);
|
||||
// vx_wmma_load_b(local_b_consume, 0, 0, 0, tid_in_warp);
|
||||
#pragma GCC unroll 1
|
||||
#pragma GCC unroll 2
|
||||
for (int wm_iter = 0; wm_iter < WMITER; wm_iter++) {
|
||||
#if TC_SINGLE_WARP
|
||||
if (warp_in_warpgroup == 0) {
|
||||
|
||||
Reference in New Issue
Block a user