tensor: Attempt row-major mapping for C store (WIP)
Doesn't work because 1x2 jagged mapping is required to achieve throughput for storing the bigger C matrix (2x4, vs. 2x2 in A).
This commit is contained in:
@@ -93,6 +93,23 @@ inline constexpr void map_c_8lanes(const int tid, int &row, int &col) {
|
||||
col += ((tid % 4) / 2) * 2;
|
||||
}
|
||||
|
||||
inline constexpr void map_c_rowmajor_8lanes(const int tid, int &row, int &col) {
|
||||
const int tg = tid / 4;
|
||||
|
||||
// A (row major)
|
||||
// row 0~ 3: threadgroup 0
|
||||
// row 4~ 7: threadgroup 1
|
||||
row = tid % 4;
|
||||
row += tg * 4;
|
||||
|
||||
// B (column major)
|
||||
// col 0~ 3: threadgroup 0
|
||||
// col 4~ 7: threadgroup 1
|
||||
col = tid % 4;
|
||||
col += tg * 4;
|
||||
}
|
||||
|
||||
|
||||
void vx_wmma_load() {
|
||||
int tid = vx_thread_id();
|
||||
int tg = tid / 4;
|
||||
@@ -174,11 +191,31 @@ void store_wmma_result() {
|
||||
int row = 0;
|
||||
int col = 0;
|
||||
|
||||
map_c_8lanes(tid, row, col);
|
||||
// map_c_8lanes(tid, row, col);
|
||||
map_c_rowmajor_8lanes(tid, row, col);
|
||||
|
||||
// store C
|
||||
float *const results_wid = results + (DIM_M * DIM_N * wid);
|
||||
// uncomment to have two accum buffers in rf
|
||||
|
||||
// asm volatile("fsw f16, %0" ::"m"(results_wid[DIM_N * 0 + col]));
|
||||
// asm volatile("fsw f17, %0" ::"m"(results_wid[DIM_N * 1 + col]));
|
||||
// asm volatile("fsw f18, %0" ::"m"(results_wid[DIM_N * 2 + col]));
|
||||
// asm volatile("fsw f19, %0" ::"m"(results_wid[DIM_N * 3 + col]));
|
||||
// asm volatile("fsw f20, %0" ::"m"(results_wid[DIM_N * 4 + col]));
|
||||
// asm volatile("fsw f21, %0" ::"m"(results_wid[DIM_N * 5 + col]));
|
||||
// asm volatile("fsw f22, %0" ::"m"(results_wid[DIM_N * 6 + col]));
|
||||
// asm volatile("fsw f23, %0" ::"m"(results_wid[DIM_N * 7 + col]));
|
||||
asm volatile("fsw f24, %0" ::"m"(results_wid[DIM_N * 0 + col]));
|
||||
asm volatile("fsw f25, %0" ::"m"(results_wid[DIM_N * 1 + col]));
|
||||
asm volatile("fsw f26, %0" ::"m"(results_wid[DIM_N * 2 + col]));
|
||||
asm volatile("fsw f27, %0" ::"m"(results_wid[DIM_N * 3 + col]));
|
||||
asm volatile("fsw f28, %0" ::"m"(results_wid[DIM_N * 4 + col]));
|
||||
asm volatile("fsw f29, %0" ::"m"(results_wid[DIM_N * 5 + col]));
|
||||
asm volatile("fsw f30, %0" ::"m"(results_wid[DIM_N * 6 + col]));
|
||||
asm volatile("fsw f31, %0" ::"m"(results_wid[DIM_N * 7 + col]));
|
||||
|
||||
|
||||
// 1x2 jagged mapping
|
||||
// asm volatile("fsw f16, %0" ::"m"(results_wid[DIM_N * (row + 0) + (col + 0)]));
|
||||
// asm volatile("fsw f17, %0" ::"m"(results_wid[DIM_N * (row + 0) + (col + 1)]));
|
||||
// asm volatile("fsw f18, %0" ::"m"(results_wid[DIM_N * (row + 2) + (col + 0)]));
|
||||
@@ -187,14 +224,14 @@ void store_wmma_result() {
|
||||
// asm volatile("fsw f21, %0" ::"m"(results_wid[DIM_N * (row + 0) + (col + 5)]));
|
||||
// asm volatile("fsw f22, %0" ::"m"(results_wid[DIM_N * (row + 2) + (col + 4)]));
|
||||
// asm volatile("fsw f23, %0" ::"m"(results_wid[DIM_N * (row + 2) + (col + 5)]));
|
||||
asm volatile("fsw f24, %0" ::"m"(results_wid[DIM_N * (row + 0) + (col + 0)]));
|
||||
asm volatile("fsw f25, %0" ::"m"(results_wid[DIM_N * (row + 0) + (col + 1)]));
|
||||
asm volatile("fsw f26, %0" ::"m"(results_wid[DIM_N * (row + 2) + (col + 0)]));
|
||||
asm volatile("fsw f27, %0" ::"m"(results_wid[DIM_N * (row + 2) + (col + 1)]));
|
||||
asm volatile("fsw f28, %0" ::"m"(results_wid[DIM_N * (row + 0) + (col + 4)]));
|
||||
asm volatile("fsw f29, %0" ::"m"(results_wid[DIM_N * (row + 0) + (col + 5)]));
|
||||
asm volatile("fsw f30, %0" ::"m"(results_wid[DIM_N * (row + 2) + (col + 4)]));
|
||||
asm volatile("fsw f31, %0" ::"m"(results_wid[DIM_N * (row + 2) + (col + 5)]));
|
||||
// asm volatile("fsw f24, %0" ::"m"(results_wid[DIM_N * (row + 0) + (col + 0)]));
|
||||
// asm volatile("fsw f25, %0" ::"m"(results_wid[DIM_N * (row + 0) + (col + 1)]));
|
||||
// asm volatile("fsw f26, %0" ::"m"(results_wid[DIM_N * (row + 2) + (col + 0)]));
|
||||
// asm volatile("fsw f27, %0" ::"m"(results_wid[DIM_N * (row + 2) + (col + 1)]));
|
||||
// asm volatile("fsw f28, %0" ::"m"(results_wid[DIM_N * (row + 0) + (col + 4)]));
|
||||
// asm volatile("fsw f29, %0" ::"m"(results_wid[DIM_N * (row + 0) + (col + 5)]));
|
||||
// asm volatile("fsw f30, %0" ::"m"(results_wid[DIM_N * (row + 2) + (col + 4)]));
|
||||
// asm volatile("fsw f31, %0" ::"m"(results_wid[DIM_N * (row + 2) + (col + 5)]));
|
||||
}
|
||||
|
||||
void print_wmma_result() {
|
||||
|
||||
Reference in New Issue
Block a user