diff --git a/tests/kernel/tensor/main.cpp b/tests/kernel/tensor/main.cpp index c373507a..05a80454 100644 --- a/tests/kernel/tensor/main.cpp +++ b/tests/kernel/tensor/main.cpp @@ -93,6 +93,23 @@ inline constexpr void map_c_8lanes(const int tid, int &row, int &col) { col += ((tid % 4) / 2) * 2; } +inline constexpr void map_c_rowmajor_8lanes(const int tid, int &row, int &col) { + const int tg = tid / 4; + + // A (row major) + // row 0~ 3: threadgroup 0 + // row 4~ 7: threadgroup 1 + row = tid % 4; + row += tg * 4; + + // B (column major) + // col 0~ 3: threadgroup 0 + // col 4~ 7: threadgroup 1 + col = tid % 4; + col += tg * 4; +} + + void vx_wmma_load() { int tid = vx_thread_id(); int tg = tid / 4; @@ -174,11 +191,31 @@ void store_wmma_result() { int row = 0; int col = 0; - map_c_8lanes(tid, row, col); + // map_c_8lanes(tid, row, col); + map_c_rowmajor_8lanes(tid, row, col); // store C float *const results_wid = results + (DIM_M * DIM_N * wid); - // uncomment to have two accum buffers in rf + + // asm volatile("fsw f16, %0" ::"m"(results_wid[DIM_N * 0 + col])); + // asm volatile("fsw f17, %0" ::"m"(results_wid[DIM_N * 1 + col])); + // asm volatile("fsw f18, %0" ::"m"(results_wid[DIM_N * 2 + col])); + // asm volatile("fsw f19, %0" ::"m"(results_wid[DIM_N * 3 + col])); + // asm volatile("fsw f20, %0" ::"m"(results_wid[DIM_N * 4 + col])); + // asm volatile("fsw f21, %0" ::"m"(results_wid[DIM_N * 5 + col])); + // asm volatile("fsw f22, %0" ::"m"(results_wid[DIM_N * 6 + col])); + // asm volatile("fsw f23, %0" ::"m"(results_wid[DIM_N * 7 + col])); + asm volatile("fsw f24, %0" ::"m"(results_wid[DIM_N * 0 + col])); + asm volatile("fsw f25, %0" ::"m"(results_wid[DIM_N * 1 + col])); + asm volatile("fsw f26, %0" ::"m"(results_wid[DIM_N * 2 + col])); + asm volatile("fsw f27, %0" ::"m"(results_wid[DIM_N * 3 + col])); + asm volatile("fsw f28, %0" ::"m"(results_wid[DIM_N * 4 + col])); + asm volatile("fsw f29, %0" ::"m"(results_wid[DIM_N * 5 + col])); + asm volatile("fsw f30, %0" ::"m"(results_wid[DIM_N * 6 + col])); + asm volatile("fsw f31, %0" ::"m"(results_wid[DIM_N * 7 + col])); + + + // 1x2 jagged mapping // asm volatile("fsw f16, %0" ::"m"(results_wid[DIM_N * (row + 0) + (col + 0)])); // asm volatile("fsw f17, %0" ::"m"(results_wid[DIM_N * (row + 0) + (col + 1)])); // asm volatile("fsw f18, %0" ::"m"(results_wid[DIM_N * (row + 2) + (col + 0)])); @@ -187,14 +224,14 @@ void store_wmma_result() { // asm volatile("fsw f21, %0" ::"m"(results_wid[DIM_N * (row + 0) + (col + 5)])); // asm volatile("fsw f22, %0" ::"m"(results_wid[DIM_N * (row + 2) + (col + 4)])); // asm volatile("fsw f23, %0" ::"m"(results_wid[DIM_N * (row + 2) + (col + 5)])); - asm volatile("fsw f24, %0" ::"m"(results_wid[DIM_N * (row + 0) + (col + 0)])); - asm volatile("fsw f25, %0" ::"m"(results_wid[DIM_N * (row + 0) + (col + 1)])); - asm volatile("fsw f26, %0" ::"m"(results_wid[DIM_N * (row + 2) + (col + 0)])); - asm volatile("fsw f27, %0" ::"m"(results_wid[DIM_N * (row + 2) + (col + 1)])); - asm volatile("fsw f28, %0" ::"m"(results_wid[DIM_N * (row + 0) + (col + 4)])); - asm volatile("fsw f29, %0" ::"m"(results_wid[DIM_N * (row + 0) + (col + 5)])); - asm volatile("fsw f30, %0" ::"m"(results_wid[DIM_N * (row + 2) + (col + 4)])); - asm volatile("fsw f31, %0" ::"m"(results_wid[DIM_N * (row + 2) + (col + 5)])); + // asm volatile("fsw f24, %0" ::"m"(results_wid[DIM_N * (row + 0) + (col + 0)])); + // asm volatile("fsw f25, %0" ::"m"(results_wid[DIM_N * (row + 0) + (col + 1)])); + // asm volatile("fsw f26, %0" ::"m"(results_wid[DIM_N * (row + 2) + (col + 0)])); + // asm volatile("fsw f27, %0" ::"m"(results_wid[DIM_N * (row + 2) + (col + 1)])); + // asm volatile("fsw f28, %0" ::"m"(results_wid[DIM_N * (row + 0) + (col + 4)])); + // asm volatile("fsw f29, %0" ::"m"(results_wid[DIM_N * (row + 0) + (col + 5)])); + // asm volatile("fsw f30, %0" ::"m"(results_wid[DIM_N * (row + 2) + (col + 4)])); + // asm volatile("fsw f31, %0" ::"m"(results_wid[DIM_N * (row + 2) + (col + 5)])); } void print_wmma_result() {