flash: Fix load addr for V tile; test with seqlen=128

This commit is contained in:
Hansung Kim
2024-08-20 14:34:09 -07:00
parent df3c41aa0d
commit d8d5df64e6
2 changed files with 6 additions and 3 deletions

View File

@@ -475,8 +475,11 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
initialize_accum_regs<0>();
initialize_accum_regs<1>();
- load_tile_to_smem<float, MemLayout::MN_major, MemLayout::MN_major, BN, BK>(
- B_COL, 0, 0, gmem_V, smem_V, tid_in_threadblock);
+ // V dimension is [seqlen, headdim], stored N(headdim)-major
+ load_tile_to_smem<float, MemLayout::MN_major, MemLayout::MN_major, B_COL,
+ HEADDIM>(
+ HEADDIM, 0 /* 0 because always reads the full N-dimension */,
+ tile_k * B_COL, gmem_V, smem_V, tid_in_threadblock);
threadblock_barrier(threadblock_id_in_cluster,
warps_per_threadblock_per_core);

View File

@@ -99,7 +99,7 @@ int main(int argc, char *argv[]) {
std::cout << "open device connection" << std::endl;
RT_CHECK(vx_dev_open(&device));
- uint32_t dim_seqlen = 64;
+ uint32_t dim_seqlen = 128;
uint32_t dim_headdim = 64;
using float_type = half;