sgemm_tcore: Fix mem addr stride to 4
Otherwise incurs misaligned accesses not supported in lsu.
This commit is contained in:
@@ -37,6 +37,10 @@
|
|||||||
#error "threadblock size too big for cluster"
|
#error "threadblock size too big for cluster"
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
// "fake" fp16 type that only has the correct word size. Proper conversion to
|
||||||
|
// fp32 need to be done in a custom function.
|
||||||
|
using float16_t = uint16_t;
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
inline void global_dmem_load(const uint32_t dim_n, const uint32_t dim_k,
|
inline void global_dmem_load(const uint32_t dim_n, const uint32_t dim_k,
|
||||||
const uint32_t k, const T *A, const T *B,
|
const uint32_t k, const T *A, const T *B,
|
||||||
@@ -127,14 +131,18 @@ inline void global_dmem_load(const uint32_t dim_n, const uint32_t dim_k,
|
|||||||
asm volatile ("flw ft7, (%0)" :: "r"(global_a));
|
asm volatile ("flw ft7, (%0)" :: "r"(global_a));
|
||||||
global_a += row_stride_as;
|
global_a += row_stride_as;
|
||||||
|
|
||||||
asm volatile ("fsw ft0, %0(%1)" :: "i"(BM * row_stride_as * 0 * sizeof(T)), "r"(local_a_tmp));
|
// NOTE: stride is fixed to word size , i.e. sizeof(float) = 4,
|
||||||
asm volatile ("fsw ft1, %0(%1)" :: "i"(BM * row_stride_as * 1 * sizeof(T)), "r"(local_a_tmp));
|
// regardless of fp16 or fp32. Since Vortex core does not support fp16,
|
||||||
asm volatile ("fsw ft2, %0(%1)" :: "i"(BM * row_stride_as * 2 * sizeof(T)), "r"(local_a_tmp));
|
// load things at word granularity and reinterpret bits inside the
|
||||||
asm volatile ("fsw ft3, %0(%1)" :: "i"(BM * row_stride_as * 3 * sizeof(T)), "r"(local_a_tmp));
|
// tensor core.
|
||||||
asm volatile ("fsw ft4, %0(%1)" :: "i"(BM * row_stride_as * 4 * sizeof(T)), "r"(local_a_tmp));
|
asm volatile ("fsw ft0, %0(%1)" :: "i"(BM * row_stride_as * 0 * sizeof(float)), "r"(local_a_tmp));
|
||||||
asm volatile ("fsw ft5, %0(%1)" :: "i"(BM * row_stride_as * 5 * sizeof(T)), "r"(local_a_tmp));
|
asm volatile ("fsw ft1, %0(%1)" :: "i"(BM * row_stride_as * 1 * sizeof(float)), "r"(local_a_tmp));
|
||||||
asm volatile ("fsw ft6, %0(%1)" :: "i"(BM * row_stride_as * 6 * sizeof(T)), "r"(local_a_tmp));
|
asm volatile ("fsw ft2, %0(%1)" :: "i"(BM * row_stride_as * 2 * sizeof(float)), "r"(local_a_tmp));
|
||||||
asm volatile ("fsw ft7, %0(%1)" :: "i"(BM * row_stride_as * 7 * sizeof(T)), "r"(local_a_tmp));
|
asm volatile ("fsw ft3, %0(%1)" :: "i"(BM * row_stride_as * 3 * sizeof(float)), "r"(local_a_tmp));
|
||||||
|
asm volatile ("fsw ft4, %0(%1)" :: "i"(BM * row_stride_as * 4 * sizeof(float)), "r"(local_a_tmp));
|
||||||
|
asm volatile ("fsw ft5, %0(%1)" :: "i"(BM * row_stride_as * 5 * sizeof(float)), "r"(local_a_tmp));
|
||||||
|
asm volatile ("fsw ft6, %0(%1)" :: "i"(BM * row_stride_as * 6 * sizeof(float)), "r"(local_a_tmp));
|
||||||
|
asm volatile ("fsw ft7, %0(%1)" :: "i"(BM * row_stride_as * 7 * sizeof(float)), "r"(local_a_tmp));
|
||||||
local_a_tmp += BM * row_stride_as * 8;
|
local_a_tmp += BM * row_stride_as * 8;
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
@@ -178,14 +186,14 @@ inline void global_dmem_load(const uint32_t dim_n, const uint32_t dim_k,
|
|||||||
global_a += dim_k * row_stride_a;
|
global_a += dim_k * row_stride_a;
|
||||||
|
|
||||||
// stride along columns
|
// stride along columns
|
||||||
asm volatile ("fsw ft0, %0(%1)" :: "i"(row_stride_a * 0 * sizeof(T)), "r"(local_a_tmp));
|
asm volatile ("fsw ft0, %0(%1)" :: "i"(row_stride_a * 0 * sizeof(float)), "r"(local_a_tmp));
|
||||||
asm volatile ("fsw ft1, %0(%1)" :: "i"(row_stride_a * 1 * sizeof(T)), "r"(local_a_tmp));
|
asm volatile ("fsw ft1, %0(%1)" :: "i"(row_stride_a * 1 * sizeof(float)), "r"(local_a_tmp));
|
||||||
asm volatile ("fsw ft2, %0(%1)" :: "i"(row_stride_a * 2 * sizeof(T)), "r"(local_a_tmp));
|
asm volatile ("fsw ft2, %0(%1)" :: "i"(row_stride_a * 2 * sizeof(float)), "r"(local_a_tmp));
|
||||||
asm volatile ("fsw ft3, %0(%1)" :: "i"(row_stride_a * 3 * sizeof(T)), "r"(local_a_tmp));
|
asm volatile ("fsw ft3, %0(%1)" :: "i"(row_stride_a * 3 * sizeof(float)), "r"(local_a_tmp));
|
||||||
asm volatile ("fsw ft4, %0(%1)" :: "i"(row_stride_a * 4 * sizeof(T)), "r"(local_a_tmp));
|
asm volatile ("fsw ft4, %0(%1)" :: "i"(row_stride_a * 4 * sizeof(float)), "r"(local_a_tmp));
|
||||||
asm volatile ("fsw ft5, %0(%1)" :: "i"(row_stride_a * 5 * sizeof(T)), "r"(local_a_tmp));
|
asm volatile ("fsw ft5, %0(%1)" :: "i"(row_stride_a * 5 * sizeof(float)), "r"(local_a_tmp));
|
||||||
asm volatile ("fsw ft6, %0(%1)" :: "i"(row_stride_a * 6 * sizeof(T)), "r"(local_a_tmp));
|
asm volatile ("fsw ft6, %0(%1)" :: "i"(row_stride_a * 6 * sizeof(float)), "r"(local_a_tmp));
|
||||||
asm volatile ("fsw ft7, %0(%1)" :: "i"(row_stride_a * 7 * sizeof(T)), "r"(local_a_tmp));
|
asm volatile ("fsw ft7, %0(%1)" :: "i"(row_stride_a * 7 * sizeof(float)), "r"(local_a_tmp));
|
||||||
local_a_tmp += row_stride_a * 8;
|
local_a_tmp += row_stride_a * 8;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -233,17 +241,17 @@ inline void global_dmem_load(const uint32_t dim_n, const uint32_t dim_k,
|
|||||||
asm volatile ("flw ft7, (%0)" :: "r"(global_b));
|
asm volatile ("flw ft7, (%0)" :: "r"(global_b));
|
||||||
global_b += dim_n * row_stride_b;
|
global_b += dim_n * row_stride_b;
|
||||||
|
|
||||||
asm volatile ("fsw ft0, %0(%1)" :: "i"(BN * row_stride_b * 0 * sizeof(T)), "r"(local_b_tmp));
|
asm volatile ("fsw ft0, %0(%1)" :: "i"(BN * row_stride_b * 0 * sizeof(float)), "r"(local_b_tmp));
|
||||||
asm volatile ("fsw ft1, %0(%1)" :: "i"(BN * row_stride_b * 1 * sizeof(T)), "r"(local_b_tmp));
|
asm volatile ("fsw ft1, %0(%1)" :: "i"(BN * row_stride_b * 1 * sizeof(float)), "r"(local_b_tmp));
|
||||||
local_b_tmp += BN * row_stride_b * 2;
|
local_b_tmp += BN * row_stride_b * 2;
|
||||||
asm volatile ("fsw ft2, %0(%1)" :: "i"(BN * row_stride_b * 0 * sizeof(T)), "r"(local_b_tmp));
|
asm volatile ("fsw ft2, %0(%1)" :: "i"(BN * row_stride_b * 0 * sizeof(float)), "r"(local_b_tmp));
|
||||||
asm volatile ("fsw ft3, %0(%1)" :: "i"(BN * row_stride_b * 1 * sizeof(T)), "r"(local_b_tmp));
|
asm volatile ("fsw ft3, %0(%1)" :: "i"(BN * row_stride_b * 1 * sizeof(float)), "r"(local_b_tmp));
|
||||||
local_b_tmp += BN * row_stride_b * 2;
|
local_b_tmp += BN * row_stride_b * 2;
|
||||||
asm volatile ("fsw ft4, %0(%1)" :: "i"(BN * row_stride_b * 0 * sizeof(T)), "r"(local_b_tmp));
|
asm volatile ("fsw ft4, %0(%1)" :: "i"(BN * row_stride_b * 0 * sizeof(float)), "r"(local_b_tmp));
|
||||||
asm volatile ("fsw ft5, %0(%1)" :: "i"(BN * row_stride_b * 1 * sizeof(T)), "r"(local_b_tmp));
|
asm volatile ("fsw ft5, %0(%1)" :: "i"(BN * row_stride_b * 1 * sizeof(float)), "r"(local_b_tmp));
|
||||||
local_b_tmp += BN * row_stride_b * 2;
|
local_b_tmp += BN * row_stride_b * 2;
|
||||||
asm volatile ("fsw ft6, %0(%1)" :: "i"(BN * row_stride_b * 0 * sizeof(T)), "r"(local_b_tmp));
|
asm volatile ("fsw ft6, %0(%1)" :: "i"(BN * row_stride_b * 0 * sizeof(float)), "r"(local_b_tmp));
|
||||||
asm volatile ("fsw ft7, %0(%1)" :: "i"(BN * row_stride_b * 1 * sizeof(T)), "r"(local_b_tmp));
|
asm volatile ("fsw ft7, %0(%1)" :: "i"(BN * row_stride_b * 1 * sizeof(float)), "r"(local_b_tmp));
|
||||||
local_b_tmp += BN * row_stride_b * 2;
|
local_b_tmp += BN * row_stride_b * 2;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -260,7 +268,7 @@ inline void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg,
|
|||||||
uint8_t *sharedmem_per_threadblock) {
|
uint8_t *sharedmem_per_threadblock) {
|
||||||
const T *A = (const T *)arg->addr_a;
|
const T *A = (const T *)arg->addr_a;
|
||||||
const T *B = (const T *)arg->addr_b;
|
const T *B = (const T *)arg->addr_b;
|
||||||
T *C = (T *)arg->addr_c;
|
float *C = (float *)arg->addr_c;
|
||||||
|
|
||||||
const uint32_t dim_m = arg->dim_m;
|
const uint32_t dim_m = arg->dim_m;
|
||||||
const uint32_t dim_n = arg->dim_n;
|
const uint32_t dim_n = arg->dim_n;
|
||||||
@@ -498,8 +506,8 @@ inline void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg,
|
|||||||
for (int wm_iter = 0; wm_iter < WMITER; wm_iter++) {
|
for (int wm_iter = 0; wm_iter < WMITER; wm_iter++) {
|
||||||
#pragma GCC unroll 2
|
#pragma GCC unroll 2
|
||||||
for (int wn_iter = 0; wn_iter < WNITER; wn_iter++) {
|
for (int wn_iter = 0; wn_iter < WNITER; wn_iter++) {
|
||||||
write_results(tid_in_warp, warp_col, warp_row, wn_iter, wm_iter,
|
write_results<float>(tid_in_warp, warp_col, warp_row, wn_iter,
|
||||||
dim_n, C, block_n, block_m);
|
wm_iter, dim_n, C, block_n, block_m);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -168,16 +168,21 @@ inline void vx_wmma_load_a(volatile const T *smem_A, const int local_k,
|
|||||||
|
|
||||||
// @perf: bank conflicts
|
// @perf: bank conflicts
|
||||||
// f8-f15 stores a single row of A
|
// f8-f15 stores a single row of A
|
||||||
const volatile T *smem_addr;
|
const volatile uint8_t *smem_addr;
|
||||||
smem_addr = &smem_A[(WM * warp_row + TCM * wm_iter + row) * smem_A_cols + local_k];
|
smem_addr = reinterpret_cast<const volatile uint8_t *>(
|
||||||
asm volatile("flw f0, %0(%1)" ::"i"(0 * sizeof(T)), "r"(smem_addr));
|
&smem_A[(WM * warp_row + TCM * wm_iter + row) * smem_A_cols + local_k]);
|
||||||
asm volatile("flw f1, %0(%1)" ::"i"(1 * sizeof(T)), "r"(smem_addr));
|
// NOTE: stride is fixed to word size , i.e. sizeof(float) = 4,
|
||||||
asm volatile("flw f2, %0(%1)" ::"i"(2 * sizeof(T)), "r"(smem_addr));
|
// regardless of fp16 or fp32. Since Vortex core does not support fp16,
|
||||||
asm volatile("flw f3, %0(%1)" ::"i"(3 * sizeof(T)), "r"(smem_addr));
|
// load things at word granularity and reinterpret bits inside the tensor
|
||||||
asm volatile("flw f4, %0(%1)" ::"i"(4 * sizeof(T)), "r"(smem_addr));
|
// core.
|
||||||
asm volatile("flw f5, %0(%1)" ::"i"(5 * sizeof(T)), "r"(smem_addr));
|
asm volatile("flw f0, %0(%1)" ::"i"(0 * sizeof(float)), "r"(smem_addr));
|
||||||
asm volatile("flw f6, %0(%1)" ::"i"(6 * sizeof(T)), "r"(smem_addr));
|
asm volatile("flw f1, %0(%1)" ::"i"(1 * sizeof(float)), "r"(smem_addr));
|
||||||
asm volatile("flw f7, %0(%1)" ::"i"(7 * sizeof(T)), "r"(smem_addr));
|
asm volatile("flw f2, %0(%1)" ::"i"(2 * sizeof(float)), "r"(smem_addr));
|
||||||
|
asm volatile("flw f3, %0(%1)" ::"i"(3 * sizeof(float)), "r"(smem_addr));
|
||||||
|
asm volatile("flw f4, %0(%1)" ::"i"(4 * sizeof(float)), "r"(smem_addr));
|
||||||
|
asm volatile("flw f5, %0(%1)" ::"i"(5 * sizeof(float)), "r"(smem_addr));
|
||||||
|
asm volatile("flw f6, %0(%1)" ::"i"(6 * sizeof(float)), "r"(smem_addr));
|
||||||
|
asm volatile("flw f7, %0(%1)" ::"i"(7 * sizeof(float)), "r"(smem_addr));
|
||||||
// asm volatile("flw f0, %0" ::"m"(smem_A[A_offset + (local_k + 0)]));
|
// asm volatile("flw f0, %0" ::"m"(smem_A[A_offset + (local_k + 0)]));
|
||||||
// asm volatile("flw f1, %0" ::"m"(smem_A[A_offset + (local_k + 1)]));
|
// asm volatile("flw f1, %0" ::"m"(smem_A[A_offset + (local_k + 1)]));
|
||||||
// asm volatile("flw f2, %0" ::"m"(smem_A[A_offset + (local_k + 2)]));
|
// asm volatile("flw f2, %0" ::"m"(smem_A[A_offset + (local_k + 2)]));
|
||||||
@@ -189,16 +194,18 @@ inline void vx_wmma_load_a(volatile const T *smem_A, const int local_k,
|
|||||||
} else {
|
} else {
|
||||||
// read smem A tile as-is; bank-conflict-free AS load
|
// read smem A tile as-is; bank-conflict-free AS load
|
||||||
// f8-f15 stores a single row of A
|
// f8-f15 stores a single row of A
|
||||||
const volatile T *smem_addr;
|
const volatile uint8_t *smem_addr;
|
||||||
smem_addr = &smem_A[((local_k + 0) * smem_AS_cols) + (WM * warp_row + TCM * wm_iter) + row];
|
smem_addr = reinterpret_cast<const volatile uint8_t *>(
|
||||||
asm volatile("flw f0, %0(%1)" :: "i"(smem_AS_cols * 0 * sizeof(T)), "r"(smem_addr));
|
&smem_A[((local_k + 0) * smem_AS_cols) +
|
||||||
asm volatile("flw f1, %0(%1)" :: "i"(smem_AS_cols * 1 * sizeof(T)), "r"(smem_addr));
|
(WM * warp_row + TCM * wm_iter) + row]);
|
||||||
asm volatile("flw f2, %0(%1)" :: "i"(smem_AS_cols * 2 * sizeof(T)), "r"(smem_addr));
|
asm volatile("flw f0, %0(%1)" :: "i"(smem_AS_cols * 0 * sizeof(float)), "r"(smem_addr));
|
||||||
asm volatile("flw f3, %0(%1)" :: "i"(smem_AS_cols * 3 * sizeof(T)), "r"(smem_addr));
|
asm volatile("flw f1, %0(%1)" :: "i"(smem_AS_cols * 1 * sizeof(float)), "r"(smem_addr));
|
||||||
asm volatile("flw f4, %0(%1)" :: "i"(smem_AS_cols * 4 * sizeof(T)), "r"(smem_addr));
|
asm volatile("flw f2, %0(%1)" :: "i"(smem_AS_cols * 2 * sizeof(float)), "r"(smem_addr));
|
||||||
asm volatile("flw f5, %0(%1)" :: "i"(smem_AS_cols * 5 * sizeof(T)), "r"(smem_addr));
|
asm volatile("flw f3, %0(%1)" :: "i"(smem_AS_cols * 3 * sizeof(float)), "r"(smem_addr));
|
||||||
asm volatile("flw f6, %0(%1)" :: "i"(smem_AS_cols * 6 * sizeof(T)), "r"(smem_addr));
|
asm volatile("flw f4, %0(%1)" :: "i"(smem_AS_cols * 4 * sizeof(float)), "r"(smem_addr));
|
||||||
asm volatile("flw f7, %0(%1)" :: "i"(smem_AS_cols * 7 * sizeof(T)), "r"(smem_addr));
|
asm volatile("flw f5, %0(%1)" :: "i"(smem_AS_cols * 5 * sizeof(float)), "r"(smem_addr));
|
||||||
|
asm volatile("flw f6, %0(%1)" :: "i"(smem_AS_cols * 6 * sizeof(float)), "r"(smem_addr));
|
||||||
|
asm volatile("flw f7, %0(%1)" :: "i"(smem_AS_cols * 7 * sizeof(float)), "r"(smem_addr));
|
||||||
|
|
||||||
// asm volatile("flw f0, %0" ::"m"(smem_A[((local_k + 0) * smem_AS_cols) + (WM * warp_row + TCM * wm_iter) + row]));
|
// asm volatile("flw f0, %0" ::"m"(smem_A[((local_k + 0) * smem_AS_cols) + (WM * warp_row + TCM * wm_iter) + row]));
|
||||||
// asm volatile("flw f1, %0" ::"m"(smem_A[((local_k + 1) * smem_AS_cols) + (WM * warp_row + TCM * wm_iter) + row]));
|
// asm volatile("flw f1, %0" ::"m"(smem_A[((local_k + 1) * smem_AS_cols) + (WM * warp_row + TCM * wm_iter) + row]));
|
||||||
@@ -227,16 +234,18 @@ inline void vx_wmma_load_b(const volatile T *smem_B, const int local_k,
|
|||||||
constexpr int smem_B_cols = BN;
|
constexpr int smem_B_cols = BN;
|
||||||
|
|
||||||
// f8-f15 stores a single column of B
|
// f8-f15 stores a single column of B
|
||||||
const volatile T *smem_addr;
|
const volatile uint8_t *smem_addr;
|
||||||
smem_addr = &smem_B[((local_k + 0) * smem_B_cols) + (WN * warp_col + TCN * wn_iter) + col];
|
smem_addr = reinterpret_cast<const volatile uint8_t *>(
|
||||||
asm volatile("flw f8, %0(%1)" :: "i"(smem_B_cols * 0 * sizeof(T)), "r"(smem_addr));
|
&smem_B[((local_k + 0) * smem_B_cols) + (WN * warp_col + TCN * wn_iter) +
|
||||||
asm volatile("flw f9, %0(%1)" :: "i"(smem_B_cols * 1 * sizeof(T)), "r"(smem_addr));
|
col]);
|
||||||
asm volatile("flw f10, %0(%1)" :: "i"(smem_B_cols * 2 * sizeof(T)), "r"(smem_addr));
|
asm volatile("flw f8, %0(%1)" :: "i"(smem_B_cols * 0 * sizeof(float)), "r"(smem_addr));
|
||||||
asm volatile("flw f11, %0(%1)" :: "i"(smem_B_cols * 3 * sizeof(T)), "r"(smem_addr));
|
asm volatile("flw f9, %0(%1)" :: "i"(smem_B_cols * 1 * sizeof(float)), "r"(smem_addr));
|
||||||
asm volatile("flw f12, %0(%1)" :: "i"(smem_B_cols * 4 * sizeof(T)), "r"(smem_addr));
|
asm volatile("flw f10, %0(%1)" :: "i"(smem_B_cols * 2 * sizeof(float)), "r"(smem_addr));
|
||||||
asm volatile("flw f13, %0(%1)" :: "i"(smem_B_cols * 5 * sizeof(T)), "r"(smem_addr));
|
asm volatile("flw f11, %0(%1)" :: "i"(smem_B_cols * 3 * sizeof(float)), "r"(smem_addr));
|
||||||
asm volatile("flw f14, %0(%1)" :: "i"(smem_B_cols * 6 * sizeof(T)), "r"(smem_addr));
|
asm volatile("flw f12, %0(%1)" :: "i"(smem_B_cols * 4 * sizeof(float)), "r"(smem_addr));
|
||||||
asm volatile("flw f15, %0(%1)" :: "i"(smem_B_cols * 7 * sizeof(T)), "r"(smem_addr));
|
asm volatile("flw f13, %0(%1)" :: "i"(smem_B_cols * 5 * sizeof(float)), "r"(smem_addr));
|
||||||
|
asm volatile("flw f14, %0(%1)" :: "i"(smem_B_cols * 6 * sizeof(float)), "r"(smem_addr));
|
||||||
|
asm volatile("flw f15, %0(%1)" :: "i"(smem_B_cols * 7 * sizeof(float)), "r"(smem_addr));
|
||||||
|
|
||||||
// asm volatile("flw f8, %0" ::"m"(smem_B[((local_k + 0) * smem_B_cols) + (WN * warp_col + TCN * wn_iter) + col]));
|
// asm volatile("flw f8, %0" ::"m"(smem_B[((local_k + 0) * smem_B_cols) + (WN * warp_col + TCN * wn_iter) + col]));
|
||||||
// asm volatile("flw f9, %0" ::"m"(smem_B[((local_k + 1) * smem_B_cols) + (WN * warp_col + TCN * wn_iter) + col]));
|
// asm volatile("flw f9, %0" ::"m"(smem_B[((local_k + 1) * smem_B_cols) + (WN * warp_col + TCN * wn_iter) + col]));
|
||||||
@@ -287,22 +296,22 @@ inline void write_results(const int thread_in_warp, const int warp_col,
|
|||||||
int local_row = (WM * warp_row + TCM * wm_iter) + tid_row;
|
int local_row = (WM * warp_row + TCM * wm_iter) + tid_row;
|
||||||
int local_col = (WN * warp_col + TCN * wn_iter) + tid_col;
|
int local_col = (WN * warp_col + TCN * wn_iter) + tid_col;
|
||||||
|
|
||||||
T *global_offset_C = C +
|
T *global_offset_C =
|
||||||
(BM * threadblock_id_y) * dim_n +
|
C + (BM * threadblock_id_y) * dim_n + BN * threadblock_id_x;
|
||||||
BN * threadblock_id_x;
|
|
||||||
|
|
||||||
// @perf: this likely causes a lot of gmem bank conflicts
|
// @perf: this likely causes a lot of gmem bank conflicts
|
||||||
if (wm_iter == 0) {
|
if (wm_iter == 0) {
|
||||||
volatile T *gmem_addr = &global_offset_C[dim_n * (local_row + 0) + (local_col + 0)];
|
volatile uint8_t *gmem_addr = reinterpret_cast<volatile uint8_t *>(
|
||||||
volatile T *gmem_addr_tmp = gmem_addr + (2 * dim_n);
|
&global_offset_C[dim_n * (local_row + 0) + (local_col + 0)]);
|
||||||
asm volatile ("fsw f16, %0(%1)" :: "i"(0 * sizeof(T)), "r"(gmem_addr));
|
volatile uint8_t *gmem_addr_tmp = gmem_addr + (2 * dim_n) * sizeof(T);
|
||||||
asm volatile ("fsw f17, %0(%1)" :: "i"(1 * sizeof(T)), "r"(gmem_addr));
|
asm volatile ("fsw f16, %0(%1)" :: "i"(0 * sizeof(float)), "r"(gmem_addr));
|
||||||
asm volatile ("fsw f18, %0(%1)" :: "i"(0 * sizeof(T)), "r"(gmem_addr_tmp));
|
asm volatile ("fsw f17, %0(%1)" :: "i"(1 * sizeof(float)), "r"(gmem_addr));
|
||||||
asm volatile ("fsw f19, %0(%1)" :: "i"(1 * sizeof(T)), "r"(gmem_addr_tmp));
|
asm volatile ("fsw f18, %0(%1)" :: "i"(0 * sizeof(float)), "r"(gmem_addr_tmp));
|
||||||
asm volatile ("fsw f20, %0(%1)" :: "i"(4 * sizeof(T)), "r"(gmem_addr));
|
asm volatile ("fsw f19, %0(%1)" :: "i"(1 * sizeof(float)), "r"(gmem_addr_tmp));
|
||||||
asm volatile ("fsw f21, %0(%1)" :: "i"(5 * sizeof(T)), "r"(gmem_addr));
|
asm volatile ("fsw f20, %0(%1)" :: "i"(4 * sizeof(float)), "r"(gmem_addr));
|
||||||
asm volatile ("fsw f22, %0(%1)" :: "i"(4 * sizeof(T)), "r"(gmem_addr_tmp));
|
asm volatile ("fsw f21, %0(%1)" :: "i"(5 * sizeof(float)), "r"(gmem_addr));
|
||||||
asm volatile ("fsw f23, %0(%1)" :: "i"(5 * sizeof(T)), "r"(gmem_addr_tmp));
|
asm volatile ("fsw f22, %0(%1)" :: "i"(4 * sizeof(float)), "r"(gmem_addr_tmp));
|
||||||
|
asm volatile ("fsw f23, %0(%1)" :: "i"(5 * sizeof(float)), "r"(gmem_addr_tmp));
|
||||||
// asm volatile ("fsw f16, %0" :: "m"(global_offset_C[dim_n * (local_row + 0) + (local_col + 0)]));
|
// asm volatile ("fsw f16, %0" :: "m"(global_offset_C[dim_n * (local_row + 0) + (local_col + 0)]));
|
||||||
// asm volatile ("fsw f17, %0" :: "m"(global_offset_C[dim_n * (local_row + 0) + (local_col + 1)]));
|
// asm volatile ("fsw f17, %0" :: "m"(global_offset_C[dim_n * (local_row + 0) + (local_col + 1)]));
|
||||||
// asm volatile ("fsw f18, %0" :: "m"(global_offset_C[dim_n * (local_row + 2) + (local_col + 0)]));
|
// asm volatile ("fsw f18, %0" :: "m"(global_offset_C[dim_n * (local_row + 2) + (local_col + 0)]));
|
||||||
@@ -312,16 +321,17 @@ inline void write_results(const int thread_in_warp, const int warp_col,
|
|||||||
// asm volatile ("fsw f22, %0" :: "m"(global_offset_C[dim_n * (local_row + 2) + (local_col + 4)]));
|
// asm volatile ("fsw f22, %0" :: "m"(global_offset_C[dim_n * (local_row + 2) + (local_col + 4)]));
|
||||||
// asm volatile ("fsw f23, %0" :: "m"(global_offset_C[dim_n * (local_row + 2) + (local_col + 5)]));
|
// asm volatile ("fsw f23, %0" :: "m"(global_offset_C[dim_n * (local_row + 2) + (local_col + 5)]));
|
||||||
} else {
|
} else {
|
||||||
volatile T *gmem_addr = &global_offset_C[dim_n * (local_row + 0) + (local_col + 0)];
|
volatile uint8_t *gmem_addr = reinterpret_cast<volatile uint8_t *>(
|
||||||
volatile T *gmem_addr_tmp = gmem_addr + (2 * dim_n);
|
&global_offset_C[dim_n * (local_row + 0) + (local_col + 0)]);
|
||||||
asm volatile ("fsw f24, %0(%1)" :: "i"(0 * sizeof(T)), "r"(gmem_addr));
|
volatile uint8_t *gmem_addr_tmp = gmem_addr + (2 * dim_n) * sizeof(T);
|
||||||
asm volatile ("fsw f25, %0(%1)" :: "i"(1 * sizeof(T)), "r"(gmem_addr));
|
asm volatile ("fsw f24, %0(%1)" :: "i"(0 * sizeof(float)), "r"(gmem_addr));
|
||||||
asm volatile ("fsw f26, %0(%1)" :: "i"(0 * sizeof(T)), "r"(gmem_addr_tmp));
|
asm volatile ("fsw f25, %0(%1)" :: "i"(1 * sizeof(float)), "r"(gmem_addr));
|
||||||
asm volatile ("fsw f27, %0(%1)" :: "i"(1 * sizeof(T)), "r"(gmem_addr_tmp));
|
asm volatile ("fsw f26, %0(%1)" :: "i"(0 * sizeof(float)), "r"(gmem_addr_tmp));
|
||||||
asm volatile ("fsw f28, %0(%1)" :: "i"(4 * sizeof(T)), "r"(gmem_addr));
|
asm volatile ("fsw f27, %0(%1)" :: "i"(1 * sizeof(float)), "r"(gmem_addr_tmp));
|
||||||
asm volatile ("fsw f29, %0(%1)" :: "i"(5 * sizeof(T)), "r"(gmem_addr));
|
asm volatile ("fsw f28, %0(%1)" :: "i"(4 * sizeof(float)), "r"(gmem_addr));
|
||||||
asm volatile ("fsw f30, %0(%1)" :: "i"(4 * sizeof(T)), "r"(gmem_addr_tmp));
|
asm volatile ("fsw f29, %0(%1)" :: "i"(5 * sizeof(float)), "r"(gmem_addr));
|
||||||
asm volatile ("fsw f31, %0(%1)" :: "i"(5 * sizeof(T)), "r"(gmem_addr_tmp));
|
asm volatile ("fsw f30, %0(%1)" :: "i"(4 * sizeof(float)), "r"(gmem_addr_tmp));
|
||||||
|
asm volatile ("fsw f31, %0(%1)" :: "i"(5 * sizeof(float)), "r"(gmem_addr_tmp));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user