Commit Graph

2515 Commits

Author SHA1 Message Date
Hansung Kim
d8d5df64e6 flash: Fix load addr for V tile; test with seqlen=128 2024-08-20 14:34:09 -07:00
Hansung Kim
df3c41aa0d flash: data copy func for easy debugging 2024-08-19 21:41:37 -07:00
Hansung Kim
2f7fb372f1 Fix range for
i'm a python noob
2024-08-19 21:19:36 -07:00
Hansung Kim
09afd43904 More flash in generate_matrix 2024-08-19 21:16:37 -07:00
Hansung Kim
351e17c849 Separate golden script for flashattn 2024-08-19 21:16:13 -07:00
Hansung Kim
4080dec9d6 flash: Do exponential approx to rowsum and Oi as well 2024-08-19 20:52:57 -07:00
Hansung Kim
f6cc61241b flash: 2nd-order taylor approx of exponential for P 2024-08-19 20:12:37 -07:00
Hansung Kim
68eb271916 Add operand.c to the link script 2024-08-19 18:09:16 -07:00
Hansung Kim
64e48de8af flash: Do accumulation of PV into O using the single_tile API 2024-08-19 18:08:58 -07:00
Hansung Kim
03c61d72ff sgemm_impl: Add param to load accumulation tile in single_tile 2024-08-19 18:08:58 -07:00
Hansung Kim
134ba825de sgemm_impl: Fix typo bug for BK_adjusted 2024-08-19 18:02:00 -07:00
Hansung Kim
3f4abc542c tensor: Fix dimensions and makefile 2024-08-19 17:37:26 -07:00
Hansung Kim
a98da9e3ca flash: Add missing accum reg init and fix barrier count 2024-08-19 16:15:46 -07:00
Hansung Kim
7ac038fadf sgemm_impl: Rename initialize_C 2024-08-19 16:12:35 -07:00
Hansung Kim
4aba018733 sgemm_impl: Fix wrong barrier count; add barrier for write_to_smem 2024-08-19 15:33:23 -07:00
Hansung Kim
e93e54cdec sgemm_impl: Drop volatile quanitifier
doesn't seem to do much & creates excessive type errors.
2024-08-19 15:19:53 -07:00
Hansung Kim
1e042af571 flash: Write and verify O = O + PV step 2024-08-19 13:18:27 -07:00
Hansung Kim
42ddb9a48e sgemm_impl: Accept layout template param at gemm_single_tile and wmma_load 2024-08-19 13:16:51 -07:00
Hansung Kim
1b133e7b5c sgemm_impl: Rename dmem load function 2024-08-18 22:26:49 -07:00
Hansung Kim
46b5047775 sgemm_impl: Remove GMEM_COALESCED_A option
Uncoalesced GMEM accesses is verified to yield slow performance and the
relevant code is not used anymore; remove the cruft
2024-08-18 22:26:02 -07:00
Hansung Kim
04643fa64d sgemm_impl: Refactor dmem_load into one unified logic
Replace the confusing logic that had slightly different use of BM/BN/BK
for A and B, into one logic that accepts matrix memory layout as a
proper argument & does compile-time logic to determine the right
dimensions.

TODO: !GMEM_COALESCED_A is not updated yet
2024-08-18 22:05:22 -07:00
Hansung Kim
b44b202a21 sgemm_impl: Rename to wmma 2024-08-18 16:21:22 -07:00
Hansung Kim
b978bf8757 sgemm_impl: Split tile offset addr gen from wmma store
& add an option to write to smem in gemm_single_tile.
2024-08-18 16:10:29 -07:00
Hansung Kim
90f6effa97 flash: Pass smem_P arg to softmax func 2024-08-18 15:21:05 -07:00
Hansung Kim
d0809d292a sgemm: Specify A/B tile SMEM address via template args
& split single-time GEMM into a separate function.
2024-08-16 18:01:57 -07:00
Hansung Kim
64b9717064 sgemm_tcore: Remove duplicate float_type decl 2024-08-16 16:26:18 -07:00
Hansung Kim
d3de1b674a flash: Compute exponents using prev/next/this rowmax values
maybe there is a better way than storing all three in sharedmem?
2024-08-15 22:10:02 -07:00
Hansung Kim
be08204e65 flash: Do proper allocation and init of QK/V/O tile 2024-08-15 21:26:14 -07:00
Hansung Kim
0ea27dd15a flash: gitignore 2024-08-15 21:04:59 -07:00
Hansung Kim
e0daf226ef flash: Change kernel arg to contain qkv; strip stimulus gen from host code
test data is now generated by the python script instead of the host
binary.
2024-08-15 21:03:02 -07:00
Hansung Kim
a1858e0c80 sgemm_impl: Parameterize BK/TCK by FP_SIZE 2024-08-15 20:33:33 -07:00
Hansung Kim
fd2ff6208d Generate golden data for flash in generate_matrix.py 2024-08-15 17:41:57 -07:00
Hansung Kim
ac44633b39 flash: Compile time flag for skipping GEMM 2024-08-15 17:40:32 -07:00
Hansung Kim
f844d96eea flash: Initialize rowmax/rowsum cache in sharedmem 2024-08-15 17:28:36 -07:00
Hansung Kim
745aa098ed flash: Optimize spad use, fix rowsum 2024-08-15 16:54:56 -07:00
Hansung Kim
e809d25305 flash: Fix rowsum and write fake exp
GEMM part is disabled for faster debugging, the kernel reads the result
of A*B directly from input binary.
2024-08-15 16:32:21 -07:00
Hansung Kim
53dfc690b9 flash: Allocate smem properly for rowsum and scratch 2024-08-14 21:50:20 -07:00
Hansung Kim
9cabe3413b Fix overlapping smem in rowmax 2024-08-14 21:09:53 -07:00
Hansung Kim
692d028afd Add flash attention kernel skeleton 2024-08-14 20:46:09 -07:00
Hansung Kim
014f7cd06f sgemm_tcore: Unpack arg params, remove threadblock_dim_y
thread_block_gemm is meant to be reusable, so it shouldn't assume what
the kernel arg struct looks like.

threadblock_dim_y was ambiguous and didn't match the literal name either
(it was used as # of warps that participate in a barrier).
2024-08-14 20:34:49 -07:00
Hansung Kim
70919c39c9 Encode dependency to sgemm header in makefile 2024-08-14 20:03:07 -07:00
Hansung Kim
1b1264207b sgemm_tcore: Add compile-time write_to_gmem param to thread_block_gemm 2024-08-14 17:48:31 -07:00
Hansung Kim
ee6339a35f sgemm_tcore: Split all impl code into sgemm_impl.hpp
This is to make thread_block_gemm a re-usable library function for GEMM
operations for use in other kernels.
2024-08-14 16:24:48 -07:00
Hansung Kim
0534e5d1f6 sgemm_tcore: Fix addr gen for GMEM->SMEM for M-major A
This fixes correctness for TRANSPOSE_AT_PRODUCE/COLUMN=0/0, provided the
matrices are already stored in the correct layout in GMEM.
2024-08-14 15:35:35 -07:00
Hansung Kim
409424b032 sgemm_tcore: Fix fp16 addr gen in vx_wmma_load 2024-08-14 13:48:03 -07:00
Hansung Kim
e69fbea83a sgemm_tcore: Fix casting error 2024-08-12 17:57:50 -07:00
Hansung Kim
95e3e96c6c tensor: Change B in-memory layout to column-major 2024-08-12 15:22:07 -07:00
Hansung Kim
07dd9e35a0 tensor: Fix dimensions for fp16 in script 2024-08-12 15:22:07 -07:00
Hansung Kim
c1906ebb4f tensor: Embed binary instead of hardcoding literals
the C compiler doesn't support fp16
2024-08-12 15:22:07 -07:00
Hansung Kim
1b5daccac9 tensor: Generate fp16-packed matrix in script 2024-08-12 15:22:07 -07:00