flash: Fix load addr for V tile; test with seqlen=128
This commit is contained in:
@@ -475,8 +475,11 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
||||
initialize_accum_regs<0>();
|
||||
initialize_accum_regs<1>();
|
||||
|
||||
load_tile_to_smem<float, MemLayout::MN_major, MemLayout::MN_major, BN, BK>(
|
||||
B_COL, 0, 0, gmem_V, smem_V, tid_in_threadblock);
|
||||
// V dimension is [seqlen, headdim], stored N(headdim)-major
|
||||
load_tile_to_smem<float, MemLayout::MN_major, MemLayout::MN_major, B_COL,
|
||||
HEADDIM>(
|
||||
HEADDIM, 0 /* 0 because always reads the full N-dimension */,
|
||||
tile_k * B_COL, gmem_V, smem_V, tid_in_threadblock);
|
||||
|
||||
threadblock_barrier(threadblock_id_in_cluster,
|
||||
warps_per_threadblock_per_core);
|
||||
|
||||
@@ -99,7 +99,7 @@ int main(int argc, char *argv[]) {
|
||||
std::cout << "open device connection" << std::endl;
|
||||
RT_CHECK(vx_dev_open(&device));
|
||||
|
||||
uint32_t dim_seqlen = 64;
|
||||
uint32_t dim_seqlen = 128;
|
||||
uint32_t dim_headdim = 64;
|
||||
|
||||
using float_type = half;
|
||||
|
||||
Reference in New Issue
Block a user