diff --git a/tests/regression/flash_attention/kernel.cpp b/tests/regression/flash_attention/kernel.cpp index 92955e2f..5ee64f75 100644 --- a/tests/regression/flash_attention/kernel.cpp +++ b/tests/regression/flash_attention/kernel.cpp @@ -54,8 +54,9 @@ inline void thread_block_init_sharedmem(const uint32_t tid_in_threadblock, } inline void thread_block_online_softmax( - float *smem_S, float *smem_O, float *smem_P, + const float *smem_S, float *smem_O, float *smem_P, const uint32_t tid_in_threadblock, const uint32_t threads_per_threadblock, + const uint32_t threadblocks_per_cluster, const uint32_t threadblock_id_in_cluster, float *smem_scratchpad, float *smem_rowmax, float *smem_rowsum) { asm volatile("thread_block_flashattn_start_%=:" ::); @@ -64,7 +65,7 @@ inline void thread_block_online_softmax( const uint32_t warp_id = tid_in_threadblock / NUM_THREADS; const uint32_t warps_in_threadblock = threads_per_threadblock / NUM_THREADS; const uint32_t warps_per_threadblock_per_core = - NUM_WARPS / threads_per_threadblock; + NUM_WARPS / threadblocks_per_cluster; // float ft[8]; // asm volatile("fmv.s %0, f16" : "=f"(ft[0])); @@ -148,7 +149,6 @@ inline void thread_block_online_softmax( : "=f"(rowmax) : "f"(rowmax), "f"(prev_rowmax)); smem_rowmax_new[row] = rowmax; - gmem_tmp0[row] = rowmax; } #endif @@ -177,18 +177,16 @@ inline void thread_block_online_softmax( #pragma GCC unroll for (int i = 0; i < exp_per_row_iter; i++) { float f0 = smem_S[thread_offset]; - // float f1 = S[thread_offset + 1]; + + // check Q*K result + gmem_tmp0[thread_offset] = f0;; // FIXME: placeholder for proper exp f0 -= rowmax_new; - // f1 -= rowmax_new; - // float16_t h0 = NN_float_to_half(f0); - // float16_t h1 = NN_float_to_half(f1); // Store S transposed to the shared memory smem_P[thread_offset] = f0; - // S[thread_offset + 1] = f1; gmem_tmp1[thread_offset] = f0; thread_offset += NUM_THREADS; @@ -261,7 +259,6 @@ inline void thread_block_online_softmax( threadblock_barrier(threadblock_id_in_cluster, warps_per_threadblock_per_core); - } asm volatile("thread_block_flashattn_finish_%=:" ::); @@ -299,15 +296,19 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { // "static" shared memory allocation. This would determine maximum // threadblock occupancy in a cluster - const uint32_t smem_QK_size = B_ROW * B_COL; - const uint32_t smem_V_size = B_COL * HEADDIM; - const uint32_t smem_O_size = B_COL * HEADDIM; + constexpr uint32_t smem_Q_size = B_ROW * HEADDIM; + constexpr uint32_t smem_QK_size = B_ROW * B_COL; + constexpr uint32_t smem_V_size = B_COL * HEADDIM; + constexpr uint32_t smem_O_size = B_COL * HEADDIM; uint8_t *smem_per_threadblock = reinterpret_cast( DEV_SMEM_START_ADDR + sizeof(float_type) * (smem_QK_size + smem_V_size + smem_O_size) * threadblock_id_in_cluster); + float *smem_Q = reinterpret_cast(smem_per_threadblock); + float *smem_K = smem_Q + smem_Q_size; + // in-place multiplication of QK into Q float *smem_S = reinterpret_cast(smem_per_threadblock); float *smem_P = smem_S; // in-place update from S to P float *smem_V = @@ -330,42 +331,73 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { float *smem_scratchpad = smem_rowmax - smem_scratchpad_size; const uint32_t warps_per_threadblock_per_core = - NUM_WARPS / threads_per_threadblock; + NUM_WARPS / threadblocks_per_cluster; // initialize rowmax/rowsum values in sharedmem thread_block_init_sharedmem(tid_in_threadblock, threads_per_threadblock, smem_O, smem_rowmax, smem_rowsum); -#define SKIP_GEMM + const float *gmem_Q = reinterpret_cast(arg->addr_q); + const float *gmem_K = reinterpret_cast(arg->addr_k); + const float *gmem_V = reinterpret_cast(arg->addr_v); + float *gmem_O = reinterpret_cast(arg->addr_o); + float *gmem_tmp0 = reinterpret_cast(0xd0000000UL); + +// #define SKIP_GEMM #ifndef SKIP_GEMM +#if 0 thread_block_gemm( (const float_type *)arg->addr_q, (const float_type *)arg->addr_k, - (float *)smem_S /*write result to SMEM */, arg->dim_m, arg->dim_n, - arg->dim_k, tid_in_threadblock, threads_per_threadblock, + (float *)smem_S /*write result to SMEM */, B_ROW, B_COL, + HEADDIM, tid_in_threadblock, threads_per_threadblock, threadblocks_per_cluster, threadblock_id_in_cluster, smem_per_threadblock); - // protect writes of GEMM results before softmax +#else + + // clear out accumulators + initialize_accum_regs<0>(); + initialize_accum_regs<1>(); + + // load Q + static_assert(B_ROW == B_COL, "currently only supports square tiles"); + load_tile_to_smem(B_ROW, 0, 0, gmem_Q, smem_Q, tid_in_threadblock); + // load K + load_tile_to_smem(B_COL, 0, 0, gmem_K, smem_K, tid_in_threadblock); + + // GMEM->SMEM and compute barrier threadblock_barrier(threadblock_id_in_cluster, warps_per_threadblock_per_core); - float *tile_S = (float *)smem_S; + thread_block_gemm_single_tile( + smem_Q, smem_K, smem_S, tid_in_threadblock, threads_per_threadblock, + threadblocks_per_cluster, threadblock_id_in_cluster); +#endif + + // protect GEMM result writes (smem_S) before softmax + threadblock_barrier(threadblock_id_in_cluster, + warps_per_threadblock_per_core); + + const float *tile_S = (float *)smem_S; #else float *tile_S = (float *)arg->addr_q; #endif - // FIXME: V is stored in d0000000 for debugging purpose - const float *gmem_V = reinterpret_cast(arg->addr_k); - - thread_block_online_softmax( - tile_S, smem_O, smem_P, tid_in_threadblock, threads_per_threadblock, - threadblock_id_in_cluster, smem_scratchpad, smem_rowmax, smem_rowsum); + thread_block_online_softmax(tile_S, smem_O, smem_P, tid_in_threadblock, + threads_per_threadblock, threadblocks_per_cluster, + threadblock_id_in_cluster, smem_scratchpad, + smem_rowmax, smem_rowsum); // FIXME unnecessary? threadblock_barrier(threadblock_id_in_cluster, warps_per_threadblock_per_core); - float *gmem_tmp2 = reinterpret_cast(0xf0000000UL); + // clear out accumulators + initialize_accum_regs<0>(); + initialize_accum_regs<1>(); load_tile_to_smem( B_COL, 0 /*FIXME*/, 0 /*FIXME*/, gmem_V, smem_V, tid_in_threadblock); @@ -376,8 +408,9 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { // FIXME: support MN_major for A for ideal performance thread_block_gemm_single_tile( - smem_P, smem_V, gmem_tmp2 /*smem_O*/, tid_in_threadblock, - threads_per_threadblock); + smem_P, smem_V, gmem_O /*smem_O*/, tid_in_threadblock, + threads_per_threadblock, threadblocks_per_cluster, + threadblock_id_in_cluster); threadblock_barrier(threadblock_id_in_cluster, warps_per_threadblock_per_core);