Commit Graph

2659 Commits

Author SHA1 Message Date
Hansung Kim
a967c262b1 sgemm_impl: Add new block-row-major layout for DMA 2024-09-07 16:38:22 -07:00
Hansung Kim
ed9bf6f73e common.mk: Switch to -Os to prevent branch code duplication
Prevents erroneous stalls at vx_bar.  See comment in kernel.cpp
2024-09-07 15:49:19 -07:00
Hansung Kim
d2f086344d flash: Fix DMA addr stride, stop at S=Q*K 2024-09-07 15:48:37 -07:00
Hansung Kim
9f067acdb9 sgemm_impl: Remove #if 0, FP_SIZE 16 2024-09-05 19:55:36 -07:00
Hansung Kim
a832fa7b84 sgemm_impl: 128x64 tile; fix unrolled asm, comment out actual gemm 2024-09-05 16:23:32 -07:00
Hansung Kim
137df9bee2 WIP: flash: Use Gemmini DMA
Paused by the barrier bug in warp-divergent branches (tmask not being
considered)
2024-09-05 16:23:32 -07:00
Hansung Kim
87a1c2bbfc Cores per cluster 4 to 8 2024-09-05 16:23:32 -07:00
Hansung Kim
bde6f0ea2e py: Write P_expected, don't rewrite vars 2024-09-05 16:23:32 -07:00
Hansung Kim
dcd69ea304 Increase SMEM size to 256KB 2024-09-05 16:23:32 -07:00
Hansung Kim
81924b601a sgemm_impl: Rewrite tile param constraint 2024-09-05 16:23:32 -07:00
Hansung Kim
bfb414c4eb flash: Add DMA config logic 2024-09-05 16:23:32 -07:00
Richard Yan
741bb80fe8 Merge branch 'kernels' of https://github.com/hansungk/vortex-private into kernels 2024-09-05 16:22:43 -07:00
Richard Yan
dd3244fba0 large fp16 kernel 2024-09-05 16:22:38 -07:00
Hansung Kim
ced98a6ff4 sgemm_impl: Refactor DMA layout remap logic into constexpr func 2024-09-03 16:20:31 -07:00
Hansung Kim
58fa2a3e91 sgemm_impl: Switch for allowing MN-major with DMA 2024-09-03 15:12:58 -07:00
Hansung Kim
f028a97f75 sgemm_tcore: Verify wo DMA; warn untested against K-major A + DMA 2024-09-03 14:42:19 -07:00
Hansung Kim
7aa0e6cbe4 sgemm_tcore: Fix correctness for GEMMINI_DMA
Remap the logical SMEM row/col coordinates to the DMA's two-level
block-row-major layout.
2024-09-02 23:46:50 -07:00
Hansung Kim
dd1b408f56 sgemm_tcore: Add debug mode with tile copy-out 2024-09-02 21:55:55 -07:00
Hansung Kim
9d71fa44a7 sgemm_tcore: Fix invocation with compile time threadblock size 2024-09-02 17:03:46 -07:00
Hansung Kim
70273fd00d flash: Cleanup debug code 2024-09-02 00:40:05 -07:00
Hansung Kim
8125192846 flash: Specify leading_dim for split QK GEMM; fix uninit'd RF before GEMM 2024-09-02 00:15:57 -07:00
Hansung Kim
bdd955836d sgemm_impl: Specify leading dimension to wmma load
This is necessary for when loading a subtile from a full tile in SMEM
into RF, but that subtile is split by non-major dimension.
2024-09-02 00:14:35 -07:00
Hansung Kim
602fe4a400 flash: Change timing for QKV move
Verified with warp_specialized false; true remains to be fixed.
2024-09-01 22:06:46 -07:00
Hansung Kim
aea257349a flash: Correct schedule with inter-warpgroup barriers 2024-09-01 20:40:51 -07:00
Hansung Kim
e5e65312d2 flash: Restructure to inter-warpgroup parallelism
This is similar to
https://tridao.me/blog/2024/flash3/#inter-warpgroup-overlapping-with-pingpong-scheduling
2024-09-01 19:58:33 -07:00
Hansung Kim
f7603b18d3 flash.py: Write V to file 2024-09-01 18:17:05 -07:00
Hansung Kim
6cc1b5ca37 flash: Reduce smem_scratchpad alloc size 2024-09-01 16:02:06 -07:00
Hansung Kim
817cc9a5a5 flash: Fix overlap in smem alloc for P tile 2024-08-31 16:24:28 -07:00
Hansung Kim
bdd6e6a9ce flash: Double-buffer between online softmax and GEMM II
TODO: O_after_PV at the last stage is incorrect.
2024-08-30 22:47:55 -07:00
Hansung Kim
042b47ff19 flash: Restructure for warp-specialization 2024-08-30 22:14:45 -07:00
Hansung Kim
1cfab40711 flash: Do Oi rescale with PV
Since Oi rescale has data dependency to previous Oi which gets produced
at the PV GEMM, both rescale+GEMM needs to be in a single pipeline stage
or otherwise it requires a stall.  So instead, compute only the
rescale factor in the online softmax stage and apply rescaling right
before PV.
2024-08-30 20:11:07 -07:00
Hansung Kim
adf717eb14 common.mk: Embed operand.c to ELF, track header dependency 2024-08-30 17:24:43 -07:00
Hansung Kim
986d507223 flash: Fix single-tile GEMM for warp-specialized
With 4 warps, we can only do 32x64 GEMM; serialize 64x64 into 2 32x64
GEMM calls by split by the row.
2024-08-30 17:12:46 -07:00
Hansung Kim
72b6004e24 flash: Fix online softmax for warp-specialized
Note: now that threads_per_threadblock is passed as compile-time
constant, the compiler likes to completely loop unroll which can cause a
lot of stack spills.

todo fix GEMM part.
2024-08-29 21:50:02 -07:00
Hansung Kim
ee0295cbef sgemm_impl: Accept threads_per_threadblock in load_tile_to_smem
Needed for warp-specialized kernels.
2024-08-29 21:43:57 -07:00
Hansung Kim
fd1ab358fa flash: Add DOUBLE_BUF compile-time param (wip) 2024-08-29 14:18:56 -07:00
Hansung Kim
5ba06dfd9d flash: Incomplete parallel stage-2 rowmax 2024-08-29 13:29:00 -07:00
Hansung Kim
4260bf7d6e Generate S matrix, pull out FA stuff from basic script 2024-08-28 16:13:38 -07:00
Hansung Kim
3f20dd59c0 flash: Supply correct tile dims to single_tile 2024-08-20 19:50:45 -07:00
Hansung Kim
091f40c365 sgemm_impl: Parameterize BM/BN/BK in single_tile 2024-08-20 19:41:34 -07:00
Hansung Kim
dde0372769 flash: Enable skipping Q*K for larger dimensions 2024-08-20 19:15:16 -07:00
Hansung Kim
526c2bd334 sgemm_impl: load_tile: accept k_index for consistency + fix gmem addr gen 2024-08-20 17:46:35 -07:00
Hansung Kim
60aec1de8d flash.py: Fix row-wise scaling of O, col_to_save 2024-08-20 14:49:25 -07:00
Hansung Kim
615d36a5c2 flash: Reduce smem use for rowmax; verify result 2024-08-20 14:45:34 -07:00
Hansung Kim
d8d5df64e6 flash: Fix load addr for V tile; test with seqlen=128 2024-08-20 14:34:09 -07:00
Hansung Kim
df3c41aa0d flash: data copy func for easy debugging 2024-08-19 21:41:37 -07:00
Hansung Kim
2f7fb372f1 Fix range for
i'm a python noob
2024-08-19 21:19:36 -07:00
Hansung Kim
09afd43904 More flash in generate_matrix 2024-08-19 21:16:37 -07:00
Hansung Kim
351e17c849 Separate golden script for flashattn 2024-08-19 21:16:13 -07:00
Hansung Kim
4080dec9d6 flash: Do exponential approx to rowsum and Oi as well 2024-08-19 20:52:57 -07:00