Hansung Kim
|
5ba06dfd9d
|
flash: Incomplete parallel stage-2 rowmax
|
2024-08-29 13:29:00 -07:00 |
|
Hansung Kim
|
3f20dd59c0
|
flash: Supply correct tile dims to single_tile
|
2024-08-20 19:50:45 -07:00 |
|
Hansung Kim
|
dde0372769
|
flash: Enable skipping Q*K for larger dimensions
|
2024-08-20 19:15:16 -07:00 |
|
Hansung Kim
|
615d36a5c2
|
flash: Reduce smem use for rowmax; verify result
|
2024-08-20 14:45:34 -07:00 |
|
Hansung Kim
|
d8d5df64e6
|
flash: Fix load addr for V tile; test with seqlen=128
|
2024-08-20 14:34:09 -07:00 |
|
Hansung Kim
|
df3c41aa0d
|
flash: data copy func for easy debugging
|
2024-08-19 21:41:37 -07:00 |
|
Hansung Kim
|
4080dec9d6
|
flash: Do exponential approx to rowsum and Oi as well
|
2024-08-19 20:52:57 -07:00 |
|
Hansung Kim
|
f6cc61241b
|
flash: 2nd-order taylor approx of exponential for P
|
2024-08-19 20:12:37 -07:00 |
|
Hansung Kim
|
64e48de8af
|
flash: Do accumulation of PV into O using the single_tile API
|
2024-08-19 18:08:58 -07:00 |
|
Hansung Kim
|
a98da9e3ca
|
flash: Add missing accum reg init and fix barrier count
|
2024-08-19 16:15:46 -07:00 |
|
Hansung Kim
|
1e042af571
|
flash: Write and verify O = O + PV step
|
2024-08-19 13:18:27 -07:00 |
|
Hansung Kim
|
b44b202a21
|
sgemm_impl: Rename to wmma
|
2024-08-18 16:21:22 -07:00 |
|
Hansung Kim
|
90f6effa97
|
flash: Pass smem_P arg to softmax func
|
2024-08-18 15:21:05 -07:00 |
|
Hansung Kim
|
d3de1b674a
|
flash: Compute exponents using prev/next/this rowmax values
maybe there is a better way than storing all three in sharedmem?
|
2024-08-15 22:10:02 -07:00 |
|
Hansung Kim
|
be08204e65
|
flash: Do proper allocation and init of QK/V/O tile
|
2024-08-15 21:26:14 -07:00 |
|
Hansung Kim
|
e0daf226ef
|
flash: Change kernel arg to contain qkv; strip stimulus gen from host code
test data is now generated by the python script instead of the host
binary.
|
2024-08-15 21:03:02 -07:00 |
|
Hansung Kim
|
ac44633b39
|
flash: Compile time flag for skipping GEMM
|
2024-08-15 17:40:32 -07:00 |
|
Hansung Kim
|
f844d96eea
|
flash: Initialize rowmax/rowsum cache in sharedmem
|
2024-08-15 17:28:36 -07:00 |
|
Hansung Kim
|
745aa098ed
|
flash: Optimize spad use, fix rowsum
|
2024-08-15 16:54:56 -07:00 |
|
Hansung Kim
|
e809d25305
|
flash: Fix rowsum and write fake exp
GEMM part is disabled for faster debugging, the kernel reads the result
of A*B directly from input binary.
|
2024-08-15 16:32:21 -07:00 |
|
Hansung Kim
|
53dfc690b9
|
flash: Allocate smem properly for rowsum and scratch
|
2024-08-14 21:50:20 -07:00 |
|
Hansung Kim
|
9cabe3413b
|
Fix overlapping smem in rowmax
|
2024-08-14 21:09:53 -07:00 |
|
Hansung Kim
|
692d028afd
|
Add flash attention kernel skeleton
|
2024-08-14 20:46:09 -07:00 |
|