Hansung Kim
9d71fa44a7
sgemm_tcore: Fix invocation with compile time threadblock size
2024-09-02 17:03:46 -07:00
Hansung Kim
70273fd00d
flash: Cleanup debug code
2024-09-02 00:40:05 -07:00
Hansung Kim
8125192846
flash: Specify leading_dim for split QK GEMM; fix uninit'd RF before GEMM
2024-09-02 00:15:57 -07:00
Hansung Kim
bdd955836d
sgemm_impl: Specify leading dimension to wmma load
...
This is necessary for when loading a subtile from a full tile in SMEM
into RF, but that subtile is split by non-major dimension.
2024-09-02 00:14:35 -07:00
Hansung Kim
602fe4a400
flash: Change timing for QKV move
...
Verified with warp_specialized false; true remains to be fixed.
2024-09-01 22:06:46 -07:00
Hansung Kim
aea257349a
flash: Correct schedule with inter-warpgroup barriers
2024-09-01 20:40:51 -07:00
Hansung Kim
e5e65312d2
flash: Restructure to inter-warpgroup parallelism
...
This is similar to
https://tridao.me/blog/2024/flash3/#inter-warpgroup-overlapping-with-pingpong-scheduling
2024-09-01 19:58:33 -07:00
Hansung Kim
f7603b18d3
flash.py: Write V to file
2024-09-01 18:17:05 -07:00
Hansung Kim
6cc1b5ca37
flash: Reduce smem_scratchpad alloc size
2024-09-01 16:02:06 -07:00
Hansung Kim
817cc9a5a5
flash: Fix overlap in smem alloc for P tile
2024-08-31 16:24:28 -07:00
Hansung Kim
bdd6e6a9ce
flash: Double-buffer between online softmax and GEMM II
...
TODO: O_after_PV at the last stage is incorrect.
2024-08-30 22:47:55 -07:00
Hansung Kim
042b47ff19
flash: Restructure for warp-specialization
2024-08-30 22:14:45 -07:00
Hansung Kim
1cfab40711
flash: Do Oi rescale with PV
...
Since Oi rescale has data dependency to previous Oi which gets produced
at the PV GEMM, both rescale+GEMM needs to be in a single pipeline stage
or otherwise it requires a stall. So instead, compute only the
rescale factor in the online softmax stage and apply rescaling right
before PV.
2024-08-30 20:11:07 -07:00
Hansung Kim
adf717eb14
common.mk: Embed operand.c to ELF, track header dependency
2024-08-30 17:24:43 -07:00
Hansung Kim
986d507223
flash: Fix single-tile GEMM for warp-specialized
...
With 4 warps, we can only do 32x64 GEMM; serialize 64x64 into 2 32x64
GEMM calls by split by the row.
2024-08-30 17:12:46 -07:00
Hansung Kim
72b6004e24
flash: Fix online softmax for warp-specialized
...
Note: now that threads_per_threadblock is passed as compile-time
constant, the compiler likes to completely loop unroll which can cause a
lot of stack spills.
todo fix GEMM part.
2024-08-29 21:50:02 -07:00
Hansung Kim
ee0295cbef
sgemm_impl: Accept threads_per_threadblock in load_tile_to_smem
...
Needed for warp-specialized kernels.
2024-08-29 21:43:57 -07:00
Hansung Kim
fd1ab358fa
flash: Add DOUBLE_BUF compile-time param (wip)
2024-08-29 14:18:56 -07:00
Hansung Kim
5ba06dfd9d
flash: Incomplete parallel stage-2 rowmax
2024-08-29 13:29:00 -07:00
Hansung Kim
4260bf7d6e
Generate S matrix, pull out FA stuff from basic script
2024-08-28 16:13:38 -07:00
Hansung Kim
3f20dd59c0
flash: Supply correct tile dims to single_tile
2024-08-20 19:50:45 -07:00
Hansung Kim
091f40c365
sgemm_impl: Parameterize BM/BN/BK in single_tile
2024-08-20 19:41:34 -07:00
Hansung Kim
dde0372769
flash: Enable skipping Q*K for larger dimensions
2024-08-20 19:15:16 -07:00
Hansung Kim
526c2bd334
sgemm_impl: load_tile: accept k_index for consistency + fix gmem addr gen
2024-08-20 17:46:35 -07:00
Hansung Kim
60aec1de8d
flash.py: Fix row-wise scaling of O, col_to_save
2024-08-20 14:49:25 -07:00
Hansung Kim
615d36a5c2
flash: Reduce smem use for rowmax; verify result
2024-08-20 14:45:34 -07:00
Hansung Kim
d8d5df64e6
flash: Fix load addr for V tile; test with seqlen=128
2024-08-20 14:34:09 -07:00
Hansung Kim
df3c41aa0d
flash: data copy func for easy debugging
2024-08-19 21:41:37 -07:00
Hansung Kim
2f7fb372f1
Fix range for
...
i'm a python noob
2024-08-19 21:19:36 -07:00
Hansung Kim
09afd43904
More flash in generate_matrix
2024-08-19 21:16:37 -07:00
Hansung Kim
351e17c849
Separate golden script for flashattn
2024-08-19 21:16:13 -07:00
Hansung Kim
4080dec9d6
flash: Do exponential approx to rowsum and Oi as well
2024-08-19 20:52:57 -07:00
Hansung Kim
f6cc61241b
flash: 2nd-order taylor approx of exponential for P
2024-08-19 20:12:37 -07:00
Hansung Kim
68eb271916
Add operand.c to the link script
2024-08-19 18:09:16 -07:00
Hansung Kim
64e48de8af
flash: Do accumulation of PV into O using the single_tile API
2024-08-19 18:08:58 -07:00
Hansung Kim
03c61d72ff
sgemm_impl: Add param to load accumulation tile in single_tile
2024-08-19 18:08:58 -07:00
Hansung Kim
134ba825de
sgemm_impl: Fix typo bug for BK_adjusted
2024-08-19 18:02:00 -07:00
Hansung Kim
3f4abc542c
tensor: Fix dimensions and makefile
2024-08-19 17:37:26 -07:00
Hansung Kim
a98da9e3ca
flash: Add missing accum reg init and fix barrier count
2024-08-19 16:15:46 -07:00
Hansung Kim
7ac038fadf
sgemm_impl: Rename initialize_C
2024-08-19 16:12:35 -07:00
Hansung Kim
4aba018733
sgemm_impl: Fix wrong barrier count; add barrier for write_to_smem
2024-08-19 15:33:23 -07:00
Hansung Kim
e93e54cdec
sgemm_impl: Drop volatile quanitifier
...
doesn't seem to do much & creates excessive type errors.
2024-08-19 15:19:53 -07:00
Hansung Kim
1e042af571
flash: Write and verify O = O + PV step
2024-08-19 13:18:27 -07:00
Hansung Kim
42ddb9a48e
sgemm_impl: Accept layout template param at gemm_single_tile and wmma_load
2024-08-19 13:16:51 -07:00
Hansung Kim
1b133e7b5c
sgemm_impl: Rename dmem load function
2024-08-18 22:26:49 -07:00
Hansung Kim
46b5047775
sgemm_impl: Remove GMEM_COALESCED_A option
...
Uncoalesced GMEM accesses is verified to yield slow performance and the
relevant code is not used anymore; remove the cruft
2024-08-18 22:26:02 -07:00
Hansung Kim
04643fa64d
sgemm_impl: Refactor dmem_load into one unified logic
...
Replace the confusing logic that had slightly different use of BM/BN/BK
for A and B, into one logic that accepts matrix memory layout as a
proper argument & does compile-time logic to determine the right
dimensions.
TODO: !GMEM_COALESCED_A is not updated yet
2024-08-18 22:05:22 -07:00
Hansung Kim
b44b202a21
sgemm_impl: Rename to wmma
2024-08-18 16:21:22 -07:00
Hansung Kim
b978bf8757
sgemm_impl: Split tile offset addr gen from wmma store
...
& add an option to write to smem in gemm_single_tile.
2024-08-18 16:10:29 -07:00
Hansung Kim
90f6effa97
flash: Pass smem_P arg to softmax func
2024-08-18 15:21:05 -07:00