tensor: Attempt row-major mapping for C store (WIP)

Doesn't work because 1x2 jagged mapping is required to achieve
throughput for storing the bigger C matrix (2x4, vs. 2x2 in A).
This commit is contained in:
Hansung Kim
2024-10-02 15:14:55 -07:00
parent 3490294626
commit 34d0956cd5

View File

@@ -93,6 +93,23 @@ inline constexpr void map_c_8lanes(const int tid, int &row, int &col) {
col += ((tid % 4) / 2) * 2;
}
inline constexpr void map_c_rowmajor_8lanes(const int tid, int &row, int &col) {
const int tg = tid / 4;
// A (row major)
// row 0~ 3: threadgroup 0
// row 4~ 7: threadgroup 1
row = tid % 4;
row += tg * 4;
// B (column major)
// col 0~ 3: threadgroup 0
// col 4~ 7: threadgroup 1
col = tid % 4;
col += tg * 4;
}
void vx_wmma_load() {
int tid = vx_thread_id();
int tg = tid / 4;
@@ -174,11 +191,31 @@ void store_wmma_result() {
int row = 0;
int col = 0;
map_c_8lanes(tid, row, col);
// map_c_8lanes(tid, row, col);
map_c_rowmajor_8lanes(tid, row, col);
// store C
float *const results_wid = results + (DIM_M * DIM_N * wid);
// uncomment to have two accum buffers in rf
// asm volatile("fsw f16, %0" ::"m"(results_wid[DIM_N * 0 + col]));
// asm volatile("fsw f17, %0" ::"m"(results_wid[DIM_N * 1 + col]));
// asm volatile("fsw f18, %0" ::"m"(results_wid[DIM_N * 2 + col]));
// asm volatile("fsw f19, %0" ::"m"(results_wid[DIM_N * 3 + col]));
// asm volatile("fsw f20, %0" ::"m"(results_wid[DIM_N * 4 + col]));
// asm volatile("fsw f21, %0" ::"m"(results_wid[DIM_N * 5 + col]));
// asm volatile("fsw f22, %0" ::"m"(results_wid[DIM_N * 6 + col]));
// asm volatile("fsw f23, %0" ::"m"(results_wid[DIM_N * 7 + col]));
asm volatile("fsw f24, %0" ::"m"(results_wid[DIM_N * 0 + col]));
asm volatile("fsw f25, %0" ::"m"(results_wid[DIM_N * 1 + col]));
asm volatile("fsw f26, %0" ::"m"(results_wid[DIM_N * 2 + col]));
asm volatile("fsw f27, %0" ::"m"(results_wid[DIM_N * 3 + col]));
asm volatile("fsw f28, %0" ::"m"(results_wid[DIM_N * 4 + col]));
asm volatile("fsw f29, %0" ::"m"(results_wid[DIM_N * 5 + col]));
asm volatile("fsw f30, %0" ::"m"(results_wid[DIM_N * 6 + col]));
asm volatile("fsw f31, %0" ::"m"(results_wid[DIM_N * 7 + col]));
// 1x2 jagged mapping
// asm volatile("fsw f16, %0" ::"m"(results_wid[DIM_N * (row + 0) + (col + 0)]));
// asm volatile("fsw f17, %0" ::"m"(results_wid[DIM_N * (row + 0) + (col + 1)]));
// asm volatile("fsw f18, %0" ::"m"(results_wid[DIM_N * (row + 2) + (col + 0)]));
@@ -187,14 +224,14 @@ void store_wmma_result() {
// asm volatile("fsw f21, %0" ::"m"(results_wid[DIM_N * (row + 0) + (col + 5)]));
// asm volatile("fsw f22, %0" ::"m"(results_wid[DIM_N * (row + 2) + (col + 4)]));
// asm volatile("fsw f23, %0" ::"m"(results_wid[DIM_N * (row + 2) + (col + 5)]));
asm volatile("fsw f24, %0" ::"m"(results_wid[DIM_N * (row + 0) + (col + 0)]));
asm volatile("fsw f25, %0" ::"m"(results_wid[DIM_N * (row + 0) + (col + 1)]));
asm volatile("fsw f26, %0" ::"m"(results_wid[DIM_N * (row + 2) + (col + 0)]));
asm volatile("fsw f27, %0" ::"m"(results_wid[DIM_N * (row + 2) + (col + 1)]));
asm volatile("fsw f28, %0" ::"m"(results_wid[DIM_N * (row + 0) + (col + 4)]));
asm volatile("fsw f29, %0" ::"m"(results_wid[DIM_N * (row + 0) + (col + 5)]));
asm volatile("fsw f30, %0" ::"m"(results_wid[DIM_N * (row + 2) + (col + 4)]));
asm volatile("fsw f31, %0" ::"m"(results_wid[DIM_N * (row + 2) + (col + 5)]));
// asm volatile("fsw f24, %0" ::"m"(results_wid[DIM_N * (row + 0) + (col + 0)]));
// asm volatile("fsw f25, %0" ::"m"(results_wid[DIM_N * (row + 0) + (col + 1)]));
// asm volatile("fsw f26, %0" ::"m"(results_wid[DIM_N * (row + 2) + (col + 0)]));
// asm volatile("fsw f27, %0" ::"m"(results_wid[DIM_N * (row + 2) + (col + 1)]));
// asm volatile("fsw f28, %0" ::"m"(results_wid[DIM_N * (row + 0) + (col + 4)]));
// asm volatile("fsw f29, %0" ::"m"(results_wid[DIM_N * (row + 0) + (col + 5)]));
// asm volatile("fsw f30, %0" ::"m"(results_wid[DIM_N * (row + 2) + (col + 4)]));
// asm volatile("fsw f31, %0" ::"m"(results_wid[DIM_N * (row + 2) + (col + 5)]));
}
void print_wmma_result() {