flash: Add Gemmini-accelerated kernel
This commit is contained in:
@@ -320,12 +320,13 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
||||
// float *smem_O_row_scale_consume =
|
||||
// (tile_k % 2) ? smem_O_row_scale_1 : smem_O_row_scale_0;
|
||||
|
||||
asm volatile("gemm_qk_start_%=:" ::);
|
||||
|
||||
constexpr bool skip_gemm_qk = false;
|
||||
if constexpr (!skip_gemm_qk) {
|
||||
// GEMM I: S = Q*K
|
||||
//
|
||||
// FIXME: deduplicate this between GEMM II
|
||||
asm volatile("gemm_qk_start_%=:" ::);
|
||||
if constexpr (!WARP_SPECIALIZED) {
|
||||
// clear out accumulators before GEMM
|
||||
initialize_accum_regs<0>();
|
||||
@@ -587,6 +588,7 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
||||
threadblock_barrier(global_barrier_id, warps_per_threadblock_per_core);
|
||||
|
||||
// Oi rescale
|
||||
// TODO: move this back to after softmax for better load-balancing
|
||||
thread_block_O_rescale(smem_O, smem_O /*in-place*/,
|
||||
smem_O_row_scale, tid_in_warpgroup,
|
||||
threads_per_warpgroup, warpgroup_id_in_cluster);
|
||||
|
||||
Reference in New Issue
Block a user