flash: Add Gemmini-accelerated kernel

This commit is contained in:
Hansung Kim
2024-09-07 22:40:50 -07:00
parent b3be271b88
commit 2e1485877d
4 changed files with 689 additions and 3 deletions

View File

@@ -320,12 +320,13 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
// float *smem_O_row_scale_consume =
// (tile_k % 2) ? smem_O_row_scale_1 : smem_O_row_scale_0;
asm volatile("gemm_qk_start_%=:" ::);
constexpr bool skip_gemm_qk = false;
if constexpr (!skip_gemm_qk) {
// GEMM I: S = Q*K
//
// FIXME: deduplicate this code between GEMM I and GEMM II
asm volatile("gemm_qk_start_%=:" ::);
if constexpr (!WARP_SPECIALIZED) {
// clear out accumulators before GEMM
initialize_accum_regs<0>();
@@ -587,6 +588,7 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
threadblock_barrier(global_barrier_id, warps_per_threadblock_per_core);
// Oi rescale
// TODO: move this back to after softmax for better load-balancing
thread_block_O_rescale(smem_O, smem_O /*in-place*/,
smem_O_row_scale, tid_in_warpgroup,
threads_per_warpgroup, warpgroup_id_in_cluster);