sgemm_wg: Prevent run-ahead using ternary flags; reduce mem accesses

This commit is contained in:
Hansung Kim
2024-03-13 21:32:57 -07:00
parent 510a834db5
commit 2036d37840

View File

@@ -13,29 +13,61 @@
#define TN 2
#define DEV_BARRIER_MMIO_BASE_ADDR 0xff003f00UL
#define CORES_PER_CLUSTER 4
#define CORES_PER_CLUSTER 2
#define BARRIER_STRIDE 4
void threadblock_barrier(unsigned int barrier_id, unsigned int count) {
void threadblock_barrier(unsigned int tid_in_threadblock, unsigned int barrier_id, unsigned int count) {
vx_barrier(barrier_id, count);
vx_fence();
#if CORES_PER_CLUSTER != 1
if (vx_thread_id() == 0) {
// vx_printf("========== barrier! barrier_id=%u, count=%u\n", barrier_id, count);
#if CORES_PER_CLUSTER != 0
// this code doesn't work without the memory-mapped register implemented in
// hardware, hence the #ifdef.
if (tid_in_threadblock == 0) {
volatile uint32_t *mmio = (volatile uint32_t *)(DEV_BARRIER_MMIO_BASE_ADDR);
int core_id = vx_core_id();
const uint32_t barrier_stride = CORES_PER_CLUSTER;
// FIXME: hardcoded
const uint32_t barrier_stride = BARRIER_STRIDE;
const uint32_t barrier_offset = barrier_stride * barrier_id;
// 1 : 0x00 is reserved for mmio read reg
// wait for the barrier to be initialized
while (mmio[barrier_offset + 1 + core_id] != 0);
// signal internal-core synchronization done
mmio[barrier_offset + 1 + core_id] = 1;
vx_printf("========== barrier written! barrier_id=%u, count=%u\n", barrier_id, count);
// wait for other cores in the cluster to finish by waiting on the
// all-synced read-only mmio reg
while (mmio[barrier_offset] == 0);
// reset per-core flag back to zero for the next barrier
mmio[barrier_offset + 1 + core_id] = 0;
// need to signal that this core passed the barrier; otherwise, if we
// reset this to 0 right away, the other core still waiting for the
// barrier might never see the all-sync mmio reg as 1.
mmio[barrier_offset + 1 + core_id] = 2;
// // if this core is the last one passing the barrier, reset all per-core
// // flags to 0 to get ready for the next barrier
// bool all_passed = true;
// for (int i = 0; i < CORES_PER_CLUSTER; i++) {
// // if (i == core_id) continue;
// // NOTE: this requires coherent access of store-to-load to the same
// // address
// if (mmio[barrier_offset + 1 + i] != 2) {
// all_passed = false;
// break;
// }
// }
// if (all_passed) {
// for (int i = 0; i < CORES_PER_CLUSTER; i++) {
// mmio[barrier_offset + 1 + i] = 0;
// }
// }
}
vx_barrier(barrier_id, count);
#endif
}
@@ -101,7 +133,8 @@ void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg,
B[global_b_offset];
}
threadblock_barrier(threadblock_id_in_core, threadblock_dim_y);
threadblock_barrier(tid_in_threadblock, threadblock_id_in_core,
threadblock_dim_y);
for (uint32_t local_k = 0; local_k < BK; local_k++) {
#pragma GCC unroll TM
@@ -130,7 +163,8 @@ void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg,
}
}
threadblock_barrier(threadblock_id_in_core, threadblock_dim_y);
threadblock_barrier(tid_in_threadblock, threadblock_id_in_core,
threadblock_dim_y);
}
#pragma GCC unroll TM
@@ -164,11 +198,6 @@ void kernel_body(int task_id, kernel_arg_t* __UNIFORM__ arg) {
const int threadblock_id_x = threadblock_id % dim_n_in_blocks;
const int threadblock_id_y = threadblock_id / dim_n_in_blocks;
// initialize barrier MMIO
volatile uint32_t *barrier_mmio = (volatile uint32_t *)(DEV_BARRIER_MMIO_BASE_ADDR);
*barrier_mmio = 0;
vx_fence();
float *sharedmem_per_threadblock =
(float *)DEV_SMEM_START_ADDR +
(2 * BM * BK) * threadblock_id_in_core;