sgemm_wg: Cleanup & proper unroll
This commit is contained in:
@@ -40,30 +40,30 @@ void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg,
|
||||
const uint32_t global_a_row = BM * threadblock_id_y + local_a_row;
|
||||
const uint32_t global_b_col = BN * threadblock_id_x + local_b_col;
|
||||
|
||||
// each thread generates one output element
|
||||
// each thread generates TM output elements
|
||||
float reg_c[TM] = { 0.0f };
|
||||
|
||||
for (uint32_t k = 0; k < dim_k; k += BK) {
|
||||
float *local_a = sharedmem_per_threadblock;
|
||||
size_t local_a_elems = threadblock_dim_x * threadblock_dim_y;
|
||||
float *local_b = sharedmem_per_threadblock + local_a_elems;
|
||||
volatile float *local_a = sharedmem_per_threadblock;
|
||||
const size_t local_a_elems = threadblock_dim_x * threadblock_dim_y;
|
||||
volatile float *local_b = sharedmem_per_threadblock + local_a_elems;
|
||||
|
||||
for (uint32_t k = 0; k < dim_k; k += BK) {
|
||||
uint32_t global_a_offset = dim_k * global_a_row + (k + local_a_col);
|
||||
uint32_t global_b_offset = dim_n * (k + local_b_row) + global_b_col;
|
||||
|
||||
// NOTE: local_b is transposed to column-major to facilitate better memory
|
||||
// access.
|
||||
local_a[BK * local_a_row + local_a_col] = A[global_a_offset];
|
||||
local_b[BN * local_b_row + local_b_col] = B[global_b_offset];
|
||||
|
||||
vx_barrier(threadblock_id_in_core, threadblock_dim_y);
|
||||
vx_fence();
|
||||
|
||||
#pragma GCC unroll TM
|
||||
for (uint32_t local_k = 0; local_k < BK; local_k++) {
|
||||
// Compute multiple result elements (TM) per thread
|
||||
const float local_b_tmp = local_b[BN * local_k + local_b_col];
|
||||
#pragma GCC unroll 4
|
||||
#pragma GCC unroll TM
|
||||
for (uint32_t result_idx = 0; result_idx < TM; result_idx++) {
|
||||
// NOTE use of local_b_row
|
||||
reg_c[result_idx] +=
|
||||
local_a[BK * (TM * local_b_row + result_idx) + local_k] *
|
||||
local_b_tmp;
|
||||
@@ -74,8 +74,9 @@ void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg,
|
||||
vx_fence();
|
||||
}
|
||||
|
||||
#pragma GCC unroll 4
|
||||
#pragma GCC unroll TM
|
||||
for (uint32_t result_idx = 0; result_idx < TM; result_idx++) {
|
||||
// NOTE use of local_b_row and global_b_col here
|
||||
C[dim_n * (BM * threadblock_id_y + TM * local_b_row + result_idx) +
|
||||
global_b_col] = reg_c[result_idx];
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user