// kernels/tests/regression/sgemm_wg/kernel.cpp
// Captured from a file viewer ("Files"): 2024-02-12 22:22:28 -08:00, 39 lines, 1.1 KiB, C++

#include <stdint.h>
#include <vx_intrinsics.h>
#include <vx_spawn.h>
#include "common.h"
/*
 * Per-thread body of an SGEMM regression kernel that stages data through the
 * memory region starting at DEV_SMEM_START_ADDR.
 *
 * Thread (warp `row`, lane `col`) owns element [row][col] of a row-major
 * dim x dim output matrix. Assumes the launch has exactly dim warps of dim
 * threads (NT == NW == matrix_dim); row/col are NOT bounds-checked against dim.
 *
 * Layout at DEV_SMEM_START_ADDR, in floats:
 *   [0,         dim*dim)   local_c - accumulator for C
 *   [dim*dim, 2*dim*dim)   local_a - copy of A
 *
 * NOTE(review): both operands of the inner product are read from local_a and
 * no B matrix is loaded, so this computes C = A * A. Presumably intentional
 * for this regression test - confirm against the host-side reference check.
 *
 * @param task_id  Unused; the element is selected by warp/thread id instead.
 * @param arg      Kernel arguments; addr_a, addr_c and matrix_dim are read.
 */
void kernel_body(int task_id, kernel_arg_t* __UNIFORM__ arg) {
  (void)task_id;

  const float *global_a = (const float *)arg->addr_a;
  float *global_c = (float *)arg->addr_c;

  const uint32_t dim = arg->matrix_dim;
  const uint32_t row = vx_warp_id();
  const uint32_t col = vx_thread_id();
  const uint32_t idx = dim * row + col;  // this thread's element

  float *local_c = (float *)DEV_SMEM_START_ADDR;
  float *local_a = local_c + dim * dim;

  // Stage this thread's element of A and clear its accumulator.
  local_a[idx] = global_a[idx];
  local_c[idx] = 0.0f;

  // Row k of local_a is written by warp k, and the inner product below reads
  // every row, so all vx_num_warps() warps must finish staging first.
  vx_barrier(0, vx_num_warps());

  for (uint32_t k = 0; k < dim; k++) {
    local_c[idx] += local_a[dim * row + k] * local_a[dim * k + col];
  }

  // Each thread reads back only its own local_c element, so this barrier is
  // not required for that data; kept as in the original (presumably to
  // exercise a second barrier in this regression test).
  vx_barrier(0, vx_num_warps());

  global_c[idx] = local_c[idx];
}
int main() {
kernel_arg_t* arg = (kernel_arg_t*)KERNEL_ARG_DEV_MEM_ADDR;
int threads_per_core = vx_num_warps() * vx_num_threads();
vx_spawn_tasks(threads_per_core, (vx_spawn_tasks_cb)kernel_body, arg);
return 0;
}