sgemm_tcore: Unpack arg params, remove threadblock_dim_y
thread_block_gemm is meant to be reusable, so it shouldn't assume what the kernel arg struct looks like. threadblock_dim_y was ambiguous and didn't match the literal name either (it was used as # of warps that participate in a barrier).
This commit is contained in:
@@ -31,7 +31,6 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
||||
const uint32_t threadblocks_per_cluster =
|
||||
hw_threads_per_cluster / threads_per_threadblock;
|
||||
|
||||
const uint32_t threadblock_dim_y = vx_num_warps() / threadblocks_per_cluster;
|
||||
const int threadblock_id = task_id / threads_per_threadblock;
|
||||
const int threadblock_id_in_cluster =
|
||||
threadblock_id % threadblocks_per_cluster;
|
||||
@@ -51,11 +50,10 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
||||
DEV_SMEM_START_ADDR + sizeof(float_type) * 2 /*overkill for non-dma*/ *
|
||||
(2 * BM * BK) * threadblock_id_in_cluster);
|
||||
|
||||
thread_block_gemm<float_type>(
|
||||
arg, tid_in_threadblock, threads_per_threadblock, threadblock_dim_y,
|
||||
/*threadblock_id_x, threadblock_id_y,*/
|
||||
threadblocks_per_cluster,
|
||||
// threadblock_id,
|
||||
thread_block_gemm<float_type, /*write_to_gmem=*/true>(
|
||||
(const float_type *)arg->addr_a, (const float_type *)arg->addr_b,
|
||||
(float *)arg->addr_c, arg->dim_m, arg->dim_n, arg->dim_k,
|
||||
tid_in_threadblock, threads_per_threadblock, threadblocks_per_cluster,
|
||||
threadblock_id_in_cluster, sharedmem_per_threadblock);
|
||||
}
|
||||
|
||||
|
||||
@@ -28,7 +28,7 @@
|
||||
#define TCK 16
|
||||
#define WMITER (WM / TCM)
|
||||
#define WNITER (WN / TCN)
|
||||
#define ELEM_PER_THREAD (WMITER * WNITER * (TCM * TCN) / NUM_THREADS)
|
||||
#define ELEM_PER_THREAD (WM * WN / NUM_THREADS)
|
||||
// FIXME: NUM_THREADS and NUM_WARPS hardcoded
|
||||
#if ((BM * BN / ELEM_PER_THREAD) > (CORES_PER_CLUSTER * 8 * 8))
|
||||
#error "threadblock size too big for cluster"
|
||||
@@ -433,8 +433,6 @@ inline void global_dmem_load(const uint32_t dim_m, const uint32_t dim_n, const u
|
||||
const uint32_t global_a_row = k_adjusted + local_as_row;
|
||||
const uint32_t global_a_col = BM * block_m + local_as_col;
|
||||
// number of rows a full TB can read at a time
|
||||
// this is equivalent to threadblock_dim_y (assuming threadblock_dim_x ==
|
||||
// BK)
|
||||
constexpr uint32_t row_stride_as = threads_in_threadblock / BM;
|
||||
const float *global_a = reinterpret_cast<const float *>(A) +
|
||||
dim_m * global_a_row + global_a_col;
|
||||
@@ -628,23 +626,15 @@ inline void global_dmem_load(const uint32_t dim_m, const uint32_t dim_n, const u
|
||||
}
|
||||
|
||||
template <typename T, bool write_to_gmem = true>
|
||||
inline void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg,
|
||||
inline void thread_block_gemm(const T *A, const T *B, float *C,
|
||||
const uint32_t dim_m,
|
||||
const uint32_t dim_n,
|
||||
const uint32_t dim_k,
|
||||
const uint32_t tid_in_threadblock,
|
||||
const uint32_t threads_per_threadblock,
|
||||
const uint32_t threadblock_dim_y,
|
||||
/*const uint32_t threadblock_id_x,
|
||||
const uint32_t threadblock_id_y,*/
|
||||
const uint32_t threadblocks_per_cluster,
|
||||
const uint32_t threadblock_id_in_cluster,
|
||||
uint8_t *sharedmem_per_threadblock) {
|
||||
const T *A = (const T *)arg->addr_a;
|
||||
const T *B = (const T *)arg->addr_b;
|
||||
float *C = (float *)arg->addr_c;
|
||||
|
||||
const uint32_t dim_m = arg->dim_m;
|
||||
const uint32_t dim_n = arg->dim_n;
|
||||
const uint32_t dim_k = arg->dim_k;
|
||||
|
||||
const uint32_t local_a_row = tid_in_threadblock / BK;
|
||||
const uint32_t local_a_col = tid_in_threadblock % BK;
|
||||
const uint32_t local_as_row = tid_in_threadblock / BM;
|
||||
@@ -658,6 +648,7 @@ inline void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg,
|
||||
const uint32_t warp_row = warp_id_in_warpgroup / (BN / WN);
|
||||
const uint32_t warp_col = warp_id_in_warpgroup % (BN / WN);
|
||||
const uint32_t tid_in_warp = tid_in_threadblock % NUM_THREADS;
|
||||
const uint32_t warps_per_threadblock_per_core = NUM_WARPS / threads_per_threadblock;
|
||||
|
||||
volatile T *local_a = reinterpret_cast<T *>(sharedmem_per_threadblock);
|
||||
constexpr size_t local_a_elems = (BM * BK);
|
||||
@@ -737,7 +728,8 @@ inline void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg,
|
||||
#endif
|
||||
}
|
||||
|
||||
threadblock_barrier(threadblock_id_in_cluster, threadblock_dim_y);
|
||||
threadblock_barrier(threadblock_id_in_cluster,
|
||||
warps_per_threadblock_per_core);
|
||||
}
|
||||
|
||||
#pragma GCC unroll 1
|
||||
@@ -814,7 +806,8 @@ inline void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg,
|
||||
global_dmem_load<T>(dim_m, dim_n, dim_k, block_k * BK, A, B, local_a,
|
||||
local_b, tid_in_threadblock, block_n, block_m);
|
||||
|
||||
threadblock_barrier(threadblock_id_in_cluster, threadblock_dim_y);
|
||||
threadblock_barrier(threadblock_id_in_cluster,
|
||||
warps_per_threadblock_per_core);
|
||||
#endif
|
||||
|
||||
// consumer code: SMEM->RF and compute
|
||||
@@ -865,14 +858,15 @@ inline void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg,
|
||||
|
||||
if constexpr (GEMMINI_DMA) {
|
||||
// Call gemmini fence at the end of the loop to overlap dma & wmma.
|
||||
// Hopefully by this time, dma would have finished so that this is a
|
||||
// no-op
|
||||
// Usually, by this time, dma has finished the copy so that this
|
||||
// becomes a no-op.
|
||||
if (tid_in_threadblock == 0) {
|
||||
gemmini_fence();
|
||||
}
|
||||
}
|
||||
|
||||
threadblock_barrier(threadblock_id_in_cluster, threadblock_dim_y);
|
||||
threadblock_barrier(threadblock_id_in_cluster,
|
||||
warps_per_threadblock_per_core);
|
||||
}
|
||||
|
||||
if constexpr (write_to_gmem) {
|
||||
|
||||
Reference in New Issue
Block a user