From 8a521a1de89b1d93cdb9c66d8b8fbc551e0db7df Mon Sep 17 00:00:00 2001
From: Hansung Kim
Date: Fri, 10 May 2024 23:23:11 -0700
Subject: [PATCH] Add 8-lane operand mapping

---
Reviewer notes (free-form text after the "---" separator; not part of the diff)

What the patch does:
  * Adds DIM_M and four constexpr helpers that turn a thread id into matrix
    coordinates: map_operand_32lanes / map_operand_8lanes (row of A, column
    of B) and map_c_32lanes / map_c_8lanes (base row/column in C).
  * vx_wmma_load() calls map_operand_32lanes() and map_c_32lanes() instead of
    computing the indices inline.
  * `results` becomes a file-scope pointer to address 0xc0000000, and
    store_wmma_result() writes f16..f23 at DIM_M * row + col instead of at
    tid*8 + reg.
  * print_wmma_result() iterates up to vx_num_threads() instead of 32.
  * main() issues one vx_wmma() instead of the 100x unrolled loop.

Review items:
  * store_wmma_result(): the added `int tg = tid / 4;` is not used anywhere
    in the visible hunk. The `tg` kept in vx_wmma_load() also has no
    remaining use in the visible hunks - confirm against the full file.
  * print_wmma_result() still reads results[tid*8+reg], while
    store_wmma_result() now writes DIM_M * row + col, so the
    "thread %d, f%d" labels no longer describe the value printed. Its call
    in main() is commented out.
  * map_operand_8lanes() and map_c_8lanes() have no caller in this patch.
  * DIM_M is used as the row stride of `results`; presumably the C tile is
    square (16 columns) - confirm.

Recovery notes:
  * This text was rebuilt from a copy in which newlines were collapsed to
    spaces and angle-bracketed spans were removed. Blank context lines were
    placed so that every hunk matches its @@ line counts and the diffstat;
    the blank line inside main() is a best guess.
  * Indentation (tabs vs. spaces) could not be recovered, so some removed
    and added lines look identical here although they differed originally.
  * The targets of the first two #include context lines are missing.
  * `reinterpret_cast<float *>` was restored from the declared pointer type.

 tests/kernel/tensor/main.cpp | 166 ++++++++++++++++++++++++++---------
 1 file changed, 124 insertions(+), 42 deletions(-)

diff --git a/tests/kernel/tensor/main.cpp b/tests/kernel/tensor/main.cpp
index 0fc4274d..2fadbbd9 100644
--- a/tests/kernel/tensor/main.cpp
+++ b/tests/kernel/tensor/main.cpp
@@ -4,35 +4,103 @@
 #include
 #include
 
+constexpr int DIM_M = 16;
+
 inline void vx_wmma() {
   asm volatile (".insn r %0, 0, 0, x0, x0, x0" :: "i"(RISCV_CUSTOM3));
 }
 
 #include "test_data.h"
 
+inline constexpr void map_operand_32lanes(const int tid, int &row, int &col) {
+  const int tg = tid / 4;
+
+  // A (row major)
+  // Figure 7(a) in paper
+  // row  0~ 3: threadgroups 0 and 2
+  // row  4~ 7: threadgroups 4 and 6
+  // row  8~11: threadgroups 1 and 3
+  // row 12~15: threadgroups 5 and 7
+  row = tid % 4;
+  row += (tg * 8) % 16;
+  row += (tg / 4) * 4;
+
+  // B (column major)
+  // NOTE: Matrix B mapping in Figure 7(a) is incorrect; below is the
+  // corrected mapping:
+  // col  0~ 3: threadgroups 0 and 1
+  // col  4~ 7: threadgroups 4 and 5
+  // col  8~11: threadgroups 2 and 3
+  // col 12~15: threadgroups 6 and 7
+  col = tid % 4;
+  col += ((tg % 4) / 2) * 8;
+  col += (tg / 4) * 4;
+}
+
+inline constexpr void map_operand_8lanes(const int tid, int &row, int &col) {
+  const int tg = tid / 4;
+
+  // A (row major)
+  // row  0~ 3: threadgroup 0
+  // row  4~ 7: threadgroup 1
+  row = tid % 4;
+  row += tg * 4;
+
+  // B (column major)
+  // col  0~ 3: threadgroup 0
+  // col  4~ 7: threadgroup 1
+  col = tid % 4;
+  col += tg * 4;
+}
+
+inline constexpr void map_c_32lanes(const int tid, int &row, int &col) {
+  const int tg = tid / 4;
+
+  // C
+  // Figure 7(b), left
+  col = ((tg % 4) / 2) * 8;
+  row = (tg * 8) % 16;
+  row += (tg / 4) * 4;
+
+  // Figure 7(b), right
+  row += (tid % 4) % 2;
+  col += ((tid % 4) / 2) * 2;
+}
+
+inline constexpr void map_c_8lanes(const int tid, int &row, int &col) {
+  const int tg = tid / 4;
+
+  // C
+  col = 0;
+  row = tg * 4;
+
+  // Figure 7(b), right
+  row += (tid % 4) % 2;
+  col += ((tid % 4) / 2) * 2;
+}
+
 void vx_wmma_load() {
   int tid = vx_thread_id();
   int tg = tid / 4;
 
-  // load A
-  int row = tid % 4;
-  row += (tg * 8) % 16;
-  row += (tg / 4) * 4;
+  int row = 0;
+  int col = 0;
+  map_operand_32lanes(tid, row, col);
+
+  // load A
+  // each operand element is read twice by two threadgroups (Sec. III-B);
+  // i.e. 8 regs * 32 lanes = 256 fp32 elements = 2 * (16 * 8) elements
 
   asm volatile ("flw f0, %0" :: "m"(A[row][0]));
   asm volatile ("flw f1, %0" :: "m"(A[row][1]));
-  asm volatile ("flw f2, %0" :: "m"(A[row][2]));
-  asm volatile ("flw f3, %0" :: "m"(A[row][3]));
-  asm volatile ("flw f4, %0" :: "m"(A[row][4]));
-  asm volatile ("flw f5, %0" :: "m"(A[row][5]));
-  asm volatile ("flw f6, %0" :: "m"(A[row][6]));
-  asm volatile ("flw f7, %0" :: "m"(A[row][7]));
-
-  // load B
-  int col = tid % 4;
-  col += ((tg % 4) / 2) * 8;
-  col += (tg / 4) * 4;
+  asm volatile ("flw f2, %0" :: "m"(A[row][2]));
+  asm volatile ("flw f3, %0" :: "m"(A[row][3]));
+  asm volatile ("flw f4, %0" :: "m"(A[row][4]));
+  asm volatile ("flw f5, %0" :: "m"(A[row][5]));
+  asm volatile ("flw f6, %0" :: "m"(A[row][6]));
+  asm volatile ("flw f7, %0" :: "m"(A[row][7]));
 
+  // load B
   asm volatile ("flw f8 , %0" :: "m"(B[0][col]));
   asm volatile ("flw f9 , %0" :: "m"(B[1][col]));
   asm volatile ("flw f10, %0" :: "m"(B[2][col]));
@@ -42,14 +110,9 @@ void vx_wmma_load() {
   asm volatile ("flw f14, %0" :: "m"(B[6][col]));
   asm volatile ("flw f15, %0" :: "m"(B[7][col]));
 
-  // load C
-  col = ((tg % 4) / 2) * 8;
-  row = (tg * 8) % 16;
-  row += (tg / 4) * 4;
-
-  row += (tid % 4) % 2;
-  col += ((tid % 4) / 2) * 2;
+  map_c_32lanes(tid, row, col);
 
+  // load C
   asm volatile ("flw f16, %0" :: "m"(C[row+0][col+0]));
   asm volatile ("flw f17, %0" :: "m"(C[row+0][col+1]));
   asm volatile ("flw f18, %0" :: "m"(C[row+2][col+0]));
@@ -60,38 +123,57 @@ void vx_wmma_load() {
   asm volatile ("flw f23, %0" :: "m"(C[row+2][col+5]));
 }
 
-float results[32*8];
+// float results[32*8];
+float *const results = reinterpret_cast<float *>(0xc0000000UL);
 
 void store_wmma_result() {
   int tid = vx_thread_id();
-
-  float *results = reinterpret_cast<float *>(0xc0000000UL);
-  asm volatile ("fsw f16, %0" :: "m"(results[tid*8+0]));
-  asm volatile ("fsw f17, %0" :: "m"(results[tid*8+1]));
-  asm volatile ("fsw f18, %0" :: "m"(results[tid*8+2]));
-  asm volatile ("fsw f19, %0" :: "m"(results[tid*8+3]));
-  asm volatile ("fsw f20, %0" :: "m"(results[tid*8+4]));
-  asm volatile ("fsw f21, %0" :: "m"(results[tid*8+5]));
-  asm volatile ("fsw f22, %0" :: "m"(results[tid*8+6]));
-  asm volatile ("fsw f23, %0" :: "m"(results[tid*8+7]));
+  int tg = tid / 4;
+
+  int row = 0;
+  int col = 0;
+
+  map_c_32lanes(tid, row, col);
+
+  // store C
+  // asm volatile ("fsw f16, %0" :: "m"(results[tid*8+0]));
+  // asm volatile ("fsw f17, %0" :: "m"(results[tid*8+1]));
+  // asm volatile ("fsw f18, %0" :: "m"(results[tid*8+2]));
+  // asm volatile ("fsw f19, %0" :: "m"(results[tid*8+3]));
+  // asm volatile ("fsw f20, %0" :: "m"(results[tid*8+4]));
+  // asm volatile ("fsw f21, %0" :: "m"(results[tid*8+5]));
+  // asm volatile ("fsw f22, %0" :: "m"(results[tid*8+6]));
+  // asm volatile ("fsw f23, %0" :: "m"(results[tid*8+7]));
+
+  asm volatile ("fsw f16, %0" :: "m"(results[DIM_M * (row + 0) + (col + 0)]));
+  asm volatile ("fsw f17, %0" :: "m"(results[DIM_M * (row + 0) + (col + 1)]));
+  asm volatile ("fsw f18, %0" :: "m"(results[DIM_M * (row + 2) + (col + 0)]));
+  asm volatile ("fsw f19, %0" :: "m"(results[DIM_M * (row + 2) + (col + 1)]));
+  asm volatile ("fsw f20, %0" :: "m"(results[DIM_M * (row + 0) + (col + 4)]));
+  asm volatile ("fsw f21, %0" :: "m"(results[DIM_M * (row + 0) + (col + 5)]));
+  asm volatile ("fsw f22, %0" :: "m"(results[DIM_M * (row + 2) + (col + 4)]));
+  asm volatile ("fsw f23, %0" :: "m"(results[DIM_M * (row + 2) + (col + 5)]));
 }
 
 void print_wmma_result() {
-  for (int tid = 0; tid < 32; tid += 1) {
-    for (int reg = 0; reg < 8; reg += 1) {
-      vx_printf("thread %d, f%d: %x\n", tid, 16+reg, *((int*) &results[tid*8+reg]));
-    }
-  }
+  const int num_threads = vx_num_threads();
+
+  for (int tid = 0; tid < num_threads; tid += 1) {
+    for (int reg = 0; reg < 8; reg += 1) {
+      vx_printf("thread %d, f%d: %x\n", tid, 16+reg, *((int*) &results[tid*8+reg]));
+    }
+  }
 }
 
 int main() {
   vx_tmc(-1);
 
   vx_wmma_load();
-#pragma GCC unroll 100
-  for (int i = 0; i < 100; i++) {
-    vx_wmma();
-  }
+// #pragma GCC unroll 100
+//   for (int i = 0; i < 100; i++) {
+//     vx_wmma();
+//   }
+  vx_wmma();
   store_wmma_result();
   vx_tmc(1);
   // print_wmma_result();