sgemm_tcore: Unpack arg params, remove threadblock_dim_y
thread_block_gemm is meant to be reusable, so it shouldn't assume what the kernel arg struct looks like. threadblock_dim_y was ambiguous and didn't match the literal name either (it was used as # of warps that participate in a barrier).
This commit is contained in:
@@ -31,7 +31,6 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
|||||||
const uint32_t threadblocks_per_cluster =
|
const uint32_t threadblocks_per_cluster =
|
||||||
hw_threads_per_cluster / threads_per_threadblock;
|
hw_threads_per_cluster / threads_per_threadblock;
|
||||||
|
|
||||||
const uint32_t threadblock_dim_y = vx_num_warps() / threadblocks_per_cluster;
|
|
||||||
const int threadblock_id = task_id / threads_per_threadblock;
|
const int threadblock_id = task_id / threads_per_threadblock;
|
||||||
const int threadblock_id_in_cluster =
|
const int threadblock_id_in_cluster =
|
||||||
threadblock_id % threadblocks_per_cluster;
|
threadblock_id % threadblocks_per_cluster;
|
||||||
@@ -51,11 +50,10 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
|||||||
DEV_SMEM_START_ADDR + sizeof(float_type) * 2 /*overkill for non-dma*/ *
|
DEV_SMEM_START_ADDR + sizeof(float_type) * 2 /*overkill for non-dma*/ *
|
||||||
(2 * BM * BK) * threadblock_id_in_cluster);
|
(2 * BM * BK) * threadblock_id_in_cluster);
|
||||||
|
|
||||||
thread_block_gemm<float_type>(
|
thread_block_gemm<float_type, /*write_to_gmem=*/true>(
|
||||||
arg, tid_in_threadblock, threads_per_threadblock, threadblock_dim_y,
|
(const float_type *)arg->addr_a, (const float_type *)arg->addr_b,
|
||||||
/*threadblock_id_x, threadblock_id_y,*/
|
(float *)arg->addr_c, arg->dim_m, arg->dim_n, arg->dim_k,
|
||||||
threadblocks_per_cluster,
|
tid_in_threadblock, threads_per_threadblock, threadblocks_per_cluster,
|
||||||
// threadblock_id,
|
|
||||||
threadblock_id_in_cluster, sharedmem_per_threadblock);
|
threadblock_id_in_cluster, sharedmem_per_threadblock);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -28,7 +28,7 @@
|
|||||||
#define TCK 16
|
#define TCK 16
|
||||||
#define WMITER (WM / TCM)
|
#define WMITER (WM / TCM)
|
||||||
#define WNITER (WN / TCN)
|
#define WNITER (WN / TCN)
|
||||||
#define ELEM_PER_THREAD (WMITER * WNITER * (TCM * TCN) / NUM_THREADS)
|
#define ELEM_PER_THREAD (WM * WN / NUM_THREADS)
|
||||||
// FIXME: NUM_THREADS and NUM_WARPS hardcoded
|
// FIXME: NUM_THREADS and NUM_WARPS hardcoded
|
||||||
#if ((BM * BN / ELEM_PER_THREAD) > (CORES_PER_CLUSTER * 8 * 8))
|
#if ((BM * BN / ELEM_PER_THREAD) > (CORES_PER_CLUSTER * 8 * 8))
|
||||||
#error "threadblock size too big for cluster"
|
#error "threadblock size too big for cluster"
|
||||||
@@ -433,8 +433,6 @@ inline void global_dmem_load(const uint32_t dim_m, const uint32_t dim_n, const u
|
|||||||
const uint32_t global_a_row = k_adjusted + local_as_row;
|
const uint32_t global_a_row = k_adjusted + local_as_row;
|
||||||
const uint32_t global_a_col = BM * block_m + local_as_col;
|
const uint32_t global_a_col = BM * block_m + local_as_col;
|
||||||
// number of rows a full TB can read at a time
|
// number of rows a full TB can read at a time
|
||||||
// this is equivalent to threadblock_dim_y (assuming threadblock_dim_x ==
|
|
||||||
// BK)
|
|
||||||
constexpr uint32_t row_stride_as = threads_in_threadblock / BM;
|
constexpr uint32_t row_stride_as = threads_in_threadblock / BM;
|
||||||
const float *global_a = reinterpret_cast<const float *>(A) +
|
const float *global_a = reinterpret_cast<const float *>(A) +
|
||||||
dim_m * global_a_row + global_a_col;
|
dim_m * global_a_row + global_a_col;
|
||||||
@@ -628,23 +626,15 @@ inline void global_dmem_load(const uint32_t dim_m, const uint32_t dim_n, const u
|
|||||||
}
|
}
|
||||||
|
|
||||||
template <typename T, bool write_to_gmem = true>
|
template <typename T, bool write_to_gmem = true>
|
||||||
inline void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg,
|
inline void thread_block_gemm(const T *A, const T *B, float *C,
|
||||||
|
const uint32_t dim_m,
|
||||||
|
const uint32_t dim_n,
|
||||||
|
const uint32_t dim_k,
|
||||||
const uint32_t tid_in_threadblock,
|
const uint32_t tid_in_threadblock,
|
||||||
const uint32_t threads_per_threadblock,
|
const uint32_t threads_per_threadblock,
|
||||||
const uint32_t threadblock_dim_y,
|
|
||||||
/*const uint32_t threadblock_id_x,
|
|
||||||
const uint32_t threadblock_id_y,*/
|
|
||||||
const uint32_t threadblocks_per_cluster,
|
const uint32_t threadblocks_per_cluster,
|
||||||
const uint32_t threadblock_id_in_cluster,
|
const uint32_t threadblock_id_in_cluster,
|
||||||
uint8_t *sharedmem_per_threadblock) {
|
uint8_t *sharedmem_per_threadblock) {
|
||||||
const T *A = (const T *)arg->addr_a;
|
|
||||||
const T *B = (const T *)arg->addr_b;
|
|
||||||
float *C = (float *)arg->addr_c;
|
|
||||||
|
|
||||||
const uint32_t dim_m = arg->dim_m;
|
|
||||||
const uint32_t dim_n = arg->dim_n;
|
|
||||||
const uint32_t dim_k = arg->dim_k;
|
|
||||||
|
|
||||||
const uint32_t local_a_row = tid_in_threadblock / BK;
|
const uint32_t local_a_row = tid_in_threadblock / BK;
|
||||||
const uint32_t local_a_col = tid_in_threadblock % BK;
|
const uint32_t local_a_col = tid_in_threadblock % BK;
|
||||||
const uint32_t local_as_row = tid_in_threadblock / BM;
|
const uint32_t local_as_row = tid_in_threadblock / BM;
|
||||||
@@ -658,6 +648,7 @@ inline void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg,
|
|||||||
const uint32_t warp_row = warp_id_in_warpgroup / (BN / WN);
|
const uint32_t warp_row = warp_id_in_warpgroup / (BN / WN);
|
||||||
const uint32_t warp_col = warp_id_in_warpgroup % (BN / WN);
|
const uint32_t warp_col = warp_id_in_warpgroup % (BN / WN);
|
||||||
const uint32_t tid_in_warp = tid_in_threadblock % NUM_THREADS;
|
const uint32_t tid_in_warp = tid_in_threadblock % NUM_THREADS;
|
||||||
|
const uint32_t warps_per_threadblock_per_core = NUM_WARPS / threads_per_threadblock;
|
||||||
|
|
||||||
volatile T *local_a = reinterpret_cast<T *>(sharedmem_per_threadblock);
|
volatile T *local_a = reinterpret_cast<T *>(sharedmem_per_threadblock);
|
||||||
constexpr size_t local_a_elems = (BM * BK);
|
constexpr size_t local_a_elems = (BM * BK);
|
||||||
@@ -737,7 +728,8 @@ inline void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg,
|
|||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
threadblock_barrier(threadblock_id_in_cluster, threadblock_dim_y);
|
threadblock_barrier(threadblock_id_in_cluster,
|
||||||
|
warps_per_threadblock_per_core);
|
||||||
}
|
}
|
||||||
|
|
||||||
#pragma GCC unroll 1
|
#pragma GCC unroll 1
|
||||||
@@ -814,7 +806,8 @@ inline void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg,
|
|||||||
global_dmem_load<T>(dim_m, dim_n, dim_k, block_k * BK, A, B, local_a,
|
global_dmem_load<T>(dim_m, dim_n, dim_k, block_k * BK, A, B, local_a,
|
||||||
local_b, tid_in_threadblock, block_n, block_m);
|
local_b, tid_in_threadblock, block_n, block_m);
|
||||||
|
|
||||||
threadblock_barrier(threadblock_id_in_cluster, threadblock_dim_y);
|
threadblock_barrier(threadblock_id_in_cluster,
|
||||||
|
warps_per_threadblock_per_core);
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
// consumer code: SMEM->RF and compute
|
// consumer code: SMEM->RF and compute
|
||||||
@@ -865,14 +858,15 @@ inline void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg,
|
|||||||
|
|
||||||
if constexpr (GEMMINI_DMA) {
|
if constexpr (GEMMINI_DMA) {
|
||||||
// Call gemmini fence at the end of the loop to overlap dma & wmma.
|
// Call gemmini fence at the end of the loop to overlap dma & wmma.
|
||||||
// Hopefully by this time, dma would have finished so that this is a
|
// Usually, by this time, dma has finished the copy so that this
|
||||||
// no-op
|
// becomes a no-op.
|
||||||
if (tid_in_threadblock == 0) {
|
if (tid_in_threadblock == 0) {
|
||||||
gemmini_fence();
|
gemmini_fence();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
threadblock_barrier(threadblock_id_in_cluster, threadblock_dim_y);
|
threadblock_barrier(threadblock_id_in_cluster,
|
||||||
|
warps_per_threadblock_per_core);
|
||||||
}
|
}
|
||||||
|
|
||||||
if constexpr (write_to_gmem) {
|
if constexpr (write_to_gmem) {
|
||||||
|
|||||||
Reference in New Issue
Block a user