From bfb414c4eb8118fba12ec58de2662fb91abf1786 Mon Sep 17 00:00:00 2001 From: Hansung Kim Date: Tue, 3 Sep 2024 16:21:28 -0700 Subject: [PATCH] flash: Add DMA config logic --- tests/regression/flash_attention/kernel.cpp | 26 ++++++++++++++++++--- 1 file changed, 23 insertions(+), 3 deletions(-) diff --git a/tests/regression/flash_attention/kernel.cpp b/tests/regression/flash_attention/kernel.cpp index 430817eb..b1a4139b 100644 --- a/tests/regression/flash_attention/kernel.cpp +++ b/tests/regression/flash_attention/kernel.cpp @@ -74,9 +74,9 @@ inline void thread_block_init_sharedmem(const uint32_t tid_in_threadblock, } inline void thread_block_copy_rowmax(const float *src, float *dest, - const uint32_t tid_in_threadblock, - const uint32_t threads_per_threadblock, - const uint32_t threadblock_id_in_cluster) { + const uint32_t tid_in_threadblock, + const uint32_t threads_per_threadblock, + const uint32_t threadblock_id_in_cluster) { asm volatile("threadblock_copy_rowmax_start_%=:" ::); const uint32_t tid_in_warp = tid_in_threadblock % NUM_THREADS; @@ -617,6 +617,26 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { threadblock_barrier(global_barrier_id, warps_per_threadblock_per_core); } + if constexpr (GEMMINI_DMA) { + if (tid_in_threadblock == 0) { + gemmini_extended_config_ex(WEIGHT_STATIONARY, 0, 0, 1, 0, 0); + // gemmini_extended_config_ex(dataflow, act & 3, 0, 1, a_transpose, + // b_transpose); + + // configure DMA for Q tile + gemmini_extended3_config_ld(HEADDIM * sizeof(elem_t), MVIN_SCALE_IDENTITY, + false, 0); + // configure DMA for K tile + gemmini_extended3_config_ld(B_COL * sizeof(elem_t), MVIN_SCALE_IDENTITY, + false, 1); + // configure DMA for Q*K store + gemmini_extended_config_st(B_COL * sizeof(elem_t), 0, + MVIN_SCALE_IDENTITY); + + gemmini_fence(); + } + } + // read Q and K into SMEM before the loop starts // static_assert(B_ROW == B_COL, "currently only supports square tiles");