tensor: Test with multiple accumulators

This commit is contained in:
Hansung Kim
2024-06-07 18:19:20 -07:00
parent 080923e869
commit 800d9801b5

View File

@@ -11,6 +11,10 @@ inline void vx_wmma() {
asm volatile (".insn r %0, 0, 0, x0, x0, x0" :: "i"(RISCV_CUSTOM3));
}
inline void vx_wmma_new() {
asm volatile (".insn r %0, 0, 0, x1, x0, x0" :: "i"(RISCV_CUSTOM3));
}
#include "test_data.h"
inline constexpr void map_operand_32lanes(const int tid, int &row, int &col) {
@@ -122,6 +126,14 @@ void vx_wmma_load() {
asm volatile ("flw f21, %0" :: "m"(C[row+0][col+5]));
asm volatile ("flw f22, %0" :: "m"(C[row+2][col+4]));
asm volatile ("flw f23, %0" :: "m"(C[row+2][col+5]));
asm volatile ("flw f24, %0" :: "m"(C[row+0][col+0]));
asm volatile ("flw f25, %0" :: "m"(C[row+0][col+1]));
asm volatile ("flw f26, %0" :: "m"(C[row+2][col+0]));
asm volatile ("flw f27, %0" :: "m"(C[row+2][col+1]));
asm volatile ("flw f28, %0" :: "m"(C[row+0][col+4]));
asm volatile ("flw f29, %0" :: "m"(C[row+0][col+5]));
asm volatile ("flw f30, %0" :: "m"(C[row+2][col+4]));
asm volatile ("flw f31, %0" :: "m"(C[row+2][col+5]));
}
// float results[32*8];
@@ -149,14 +161,22 @@ void store_wmma_result() {
float *const results_wid = results + (DIM_M * DIM_M * wid);
asm volatile("fsw f16, %0" ::"m"(results_wid[DIM_M * (row + 0) + (col + 0)]));
asm volatile("fsw f17, %0" ::"m"(results_wid[DIM_M * (row + 0) + (col + 1)]));
asm volatile("fsw f18, %0" ::"m"(results_wid[DIM_M * (row + 2) + (col + 0)]));
asm volatile("fsw f19, %0" ::"m"(results_wid[DIM_M * (row + 2) + (col + 1)]));
asm volatile("fsw f20, %0" ::"m"(results_wid[DIM_M * (row + 0) + (col + 4)]));
asm volatile("fsw f21, %0" ::"m"(results_wid[DIM_M * (row + 0) + (col + 5)]));
asm volatile("fsw f22, %0" ::"m"(results_wid[DIM_M * (row + 2) + (col + 4)]));
asm volatile("fsw f23, %0" ::"m"(results_wid[DIM_M * (row + 2) + (col + 5)]));
// asm volatile("fsw f16, %0" ::"m"(results_wid[DIM_M * (row + 0) + (col + 0)]));
// asm volatile("fsw f17, %0" ::"m"(results_wid[DIM_M * (row + 0) + (col + 1)]));
// asm volatile("fsw f18, %0" ::"m"(results_wid[DIM_M * (row + 2) + (col + 0)]));
// asm volatile("fsw f19, %0" ::"m"(results_wid[DIM_M * (row + 2) + (col + 1)]));
// asm volatile("fsw f20, %0" ::"m"(results_wid[DIM_M * (row + 0) + (col + 4)]));
// asm volatile("fsw f21, %0" ::"m"(results_wid[DIM_M * (row + 0) + (col + 5)]));
// asm volatile("fsw f22, %0" ::"m"(results_wid[DIM_M * (row + 2) + (col + 4)]));
// asm volatile("fsw f23, %0" ::"m"(results_wid[DIM_M * (row + 2) + (col + 5)]));
asm volatile("fsw f24, %0" ::"m"(results_wid[DIM_M * (row + 0) + (col + 0)]));
asm volatile("fsw f25, %0" ::"m"(results_wid[DIM_M * (row + 0) + (col + 1)]));
asm volatile("fsw f26, %0" ::"m"(results_wid[DIM_M * (row + 2) + (col + 0)]));
asm volatile("fsw f27, %0" ::"m"(results_wid[DIM_M * (row + 2) + (col + 1)]));
asm volatile("fsw f28, %0" ::"m"(results_wid[DIM_M * (row + 0) + (col + 4)]));
asm volatile("fsw f29, %0" ::"m"(results_wid[DIM_M * (row + 0) + (col + 5)]));
asm volatile("fsw f30, %0" ::"m"(results_wid[DIM_M * (row + 2) + (col + 4)]));
asm volatile("fsw f31, %0" ::"m"(results_wid[DIM_M * (row + 2) + (col + 5)]));
}
void print_wmma_result() {
@@ -184,7 +204,7 @@ void wmma() {
// for (int i = 0; i < 100; i++) {
// vx_wmma();
// }
vx_wmma();
vx_wmma_new();
store_wmma_result();
// print_wmma_result();