diff --git a/tests/kernel/tensor/main.cpp b/tests/kernel/tensor/main.cpp index 2fadbbd9..7fa759a8 100644 --- a/tests/kernel/tensor/main.cpp +++ b/tests/kernel/tensor/main.cpp @@ -1,10 +1,11 @@ #define RISCV_CUSTOM3 0x7B +#include #include #include #include -constexpr int DIM_M = 16; +constexpr int DIM_M = 8; inline void vx_wmma() { asm volatile (".insn r %0, 0, 0, x0, x0, x0" :: "i"(RISCV_CUSTOM3)); @@ -86,7 +87,7 @@ void vx_wmma_load() { int row = 0; int col = 0; - map_operand_32lanes(tid, row, col); + map_operand_8lanes(tid, row, col); // load A // each operand element is read twice by two threadgroups (Sec. III-B); @@ -110,49 +111,52 @@ void vx_wmma_load() { asm volatile ("flw f14, %0" :: "m"(B[6][col])); asm volatile ("flw f15, %0" :: "m"(B[7][col])); - map_c_32lanes(tid, row, col); + map_c_8lanes(tid, row, col); // load C - asm volatile ("flw f16, %0" :: "m"(C[row+0][col+0])); - asm volatile ("flw f17, %0" :: "m"(C[row+0][col+1])); - asm volatile ("flw f18, %0" :: "m"(C[row+2][col+0])); - asm volatile ("flw f19, %0" :: "m"(C[row+2][col+1])); - asm volatile ("flw f20, %0" :: "m"(C[row+0][col+4])); - asm volatile ("flw f21, %0" :: "m"(C[row+0][col+5])); - asm volatile ("flw f22, %0" :: "m"(C[row+2][col+4])); - asm volatile ("flw f23, %0" :: "m"(C[row+2][col+5])); + asm volatile ("flw f16, %0" :: "m"(C[row+0][col+0])); + asm volatile ("flw f17, %0" :: "m"(C[row+0][col+1])); + asm volatile ("flw f18, %0" :: "m"(C[row+2][col+0])); + asm volatile ("flw f19, %0" :: "m"(C[row+2][col+1])); + asm volatile ("flw f20, %0" :: "m"(C[row+0][col+4])); + asm volatile ("flw f21, %0" :: "m"(C[row+0][col+5])); + asm volatile ("flw f22, %0" :: "m"(C[row+2][col+4])); + asm volatile ("flw f23, %0" :: "m"(C[row+2][col+5])); } // float results[32*8]; float *const results = reinterpret_cast(0xc0000000UL); void store_wmma_result() { - int tid = vx_thread_id(); - int tg = tid / 4; + int wid = vx_warp_id(); + int tid = vx_thread_id(); + int tg = tid / 4; - int row = 0; - int col = 0; + int row = 0; + int col = 0; - map_c_32lanes(tid, row, col); + map_c_8lanes(tid, row, col); - // store C - // asm volatile ("fsw f16, %0" :: "m"(results[tid*8+0])); - // asm volatile ("fsw f17, %0" :: "m"(results[tid*8+1])); - // asm volatile ("fsw f18, %0" :: "m"(results[tid*8+2])); - // asm volatile ("fsw f19, %0" :: "m"(results[tid*8+3])); - // asm volatile ("fsw f20, %0" :: "m"(results[tid*8+4])); - // asm volatile ("fsw f21, %0" :: "m"(results[tid*8+5])); - // asm volatile ("fsw f22, %0" :: "m"(results[tid*8+6])); - // asm volatile ("fsw f23, %0" :: "m"(results[tid*8+7])); + // store C + // asm volatile ("fsw f16, %0" :: "m"(results[tid*8+0])); + // asm volatile ("fsw f17, %0" :: "m"(results[tid*8+1])); + // asm volatile ("fsw f18, %0" :: "m"(results[tid*8+2])); + // asm volatile ("fsw f19, %0" :: "m"(results[tid*8+3])); + // asm volatile ("fsw f20, %0" :: "m"(results[tid*8+4])); + // asm volatile ("fsw f21, %0" :: "m"(results[tid*8+5])); + // asm volatile ("fsw f22, %0" :: "m"(results[tid*8+6])); + // asm volatile ("fsw f23, %0" :: "m"(results[tid*8+7])); - asm volatile ("fsw f16, %0" :: "m"(results[DIM_M * (row + 0) + (col + 0)])); - asm volatile ("fsw f17, %0" :: "m"(results[DIM_M * (row + 0) + (col + 1)])); - asm volatile ("fsw f18, %0" :: "m"(results[DIM_M * (row + 2) + (col + 0)])); - asm volatile ("fsw f19, %0" :: "m"(results[DIM_M * (row + 2) + (col + 1)])); - asm volatile ("fsw f20, %0" :: "m"(results[DIM_M * (row + 0) + (col + 4)])); - asm volatile ("fsw f21, %0" :: "m"(results[DIM_M * (row + 0) + (col + 5)])); - asm volatile ("fsw f22, %0" :: "m"(results[DIM_M * (row + 2) + (col + 4)])); - asm volatile ("fsw f23, %0" :: "m"(results[DIM_M * (row + 2) + (col + 5)])); + float *const results_wid = results + (DIM_M * DIM_M * wid); + + asm volatile("fsw f16, %0" ::"m"(results_wid[DIM_M * (row + 0) + (col + 0)])); + asm volatile("fsw f17, %0" ::"m"(results_wid[DIM_M * (row + 0) + (col + 1)])); + asm volatile("fsw f18, %0" ::"m"(results_wid[DIM_M * (row + 2) + (col + 0)])); + asm volatile("fsw f19, %0" ::"m"(results_wid[DIM_M * (row + 2) + (col + 1)])); + asm volatile("fsw f20, %0" ::"m"(results_wid[DIM_M * (row + 0) + (col + 4)])); + asm volatile("fsw f21, %0" ::"m"(results_wid[DIM_M * (row + 0) + (col + 5)])); + asm volatile("fsw f22, %0" ::"m"(results_wid[DIM_M * (row + 2) + (col + 4)])); + asm volatile("fsw f23, %0" ::"m"(results_wid[DIM_M * (row + 2) + (col + 5)])); } void print_wmma_result() { @@ -160,23 +164,39 @@ void print_wmma_result() { for (int tid = 0; tid < num_threads; tid += 1) { for (int reg = 0; reg < 8; reg += 1) { - vx_printf("thread %d, f%d: %x\n", tid, 16+reg, *((int*) &results[tid*8+reg])); + vx_printf("thread %d, f%d: %x\n", tid, 16 + reg, + *((int *)&results[tid * 8 + reg])); } } } -int main() -{ - vx_tmc(-1); - vx_wmma_load(); -// #pragma GCC unroll 100 -// for (int i = 0; i < 100; i++) { -// vx_wmma(); -// } - vx_wmma(); - store_wmma_result(); - vx_tmc(1); - // print_wmma_result(); - - return 0; +void wmma() { + vx_tmc(-1); + + // if (vx_warp_id() == 1) { + // for (int i = 0; i < 100; i++) { + // asm volatile ("nop"); + // } + // } + + vx_wmma_load(); + // #pragma GCC unroll 100 + // for (int i = 0; i < 100; i++) { + // vx_wmma(); + // } + vx_wmma(); + + store_wmma_result(); + // print_wmma_result(); + vx_tmc(1); +} + +int main() { + const int num_warps = vx_num_warps(); + + vx_wspawn(num_warps, wmma); + wmma(); + vx_wspawn_wait(); + + return 0; }