flash: Compile-time flag for skipping GEMM

This commit is contained in:
Hansung Kim
2024-08-15 17:40:32 -07:00
parent f844d96eea
commit ac44633b39

View File

@@ -260,26 +260,34 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
uint8_t *sharedmem_scratchpad =
sharedmem_rowmax - sharedmem_scratchpad_size;
const uint32_t warps_per_threadblock_per_core =
NUM_WARPS / threads_per_threadblock;
// initialize rowmax/rowsum values in sharedmem
thread_block_init_sharedmem(tid_in_threadblock, threads_per_threadblock,
(float *)sharedmem_scratchpad,
(float *)sharedmem_rowmax,
(float *)sharedmem_rowsum);
// thread_block_gemm<float_type, /*write_to_gmem=*/true>(
// (const float_type *)arg->addr_a, (const float_type *)arg->addr_b,
// (float *)smem_S /*write result to SMEM */, arg->dim_m, arg->dim_n,
// arg->dim_k, tid_in_threadblock, threads_per_threadblock,
// threadblocks_per_cluster, threadblock_id_in_cluster,
// sharedmem_per_threadblock);
#define SKIP_GEMM
#ifndef SKIP_GEMM
thread_block_gemm<float_type, /*write_to_gmem=*/true>(
(const float_type *)arg->addr_a, (const float_type *)arg->addr_b,
(float *)smem_S /*write result to SMEM */, arg->dim_m, arg->dim_n,
arg->dim_k, tid_in_threadblock, threads_per_threadblock,
threadblocks_per_cluster, threadblock_id_in_cluster,
sharedmem_per_threadblock);
// protect writes of GEMM results before softmax
const uint32_t warps_per_threadblock_per_core =
NUM_WARPS / threads_per_threadblock;
threadblock_barrier(threadblock_id_in_cluster,
warps_per_threadblock_per_core);
thread_block_flashattn((float *)arg->addr_a /* smem_S, */, tid_in_threadblock,
float *tile_S = (float *)smem_S;
#else
float *tile_S = (float *)arg->addr_a;
#endif
thread_block_flashattn(tile_S, tid_in_threadblock,
threads_per_threadblock, threadblock_id_in_cluster,
(float *)sharedmem_scratchpad,
(float *)sharedmem_rowmax, (float *)sharedmem_rowsum);