flash: Optimize smem alloc for tcore for 8banks
Divide SMEM into a first half and a last half for warpgroups 0 and 1, and allocate Q/K and P/V in different banks for parallel access.
This commit is contained in:
@@ -11,8 +11,8 @@
|
||||
#define ROW_REMAINDER_LOGIC
|
||||
|
||||
constexpr uint32_t ROWMAX_SETS = 3;
|
||||
constexpr bool WARP_SPECIALIZED = false;
|
||||
constexpr bool TENSOR_CORE = false;
|
||||
constexpr bool WARP_SPECIALIZED = true;
|
||||
constexpr bool TENSOR_CORE = true;
|
||||
|
||||
// temporary safety stop for wrong configs
|
||||
static_assert(NUM_CORES == 4);
|
||||
|
||||
@@ -11,6 +11,9 @@
|
||||
constexpr bool DEBUG = false;
|
||||
constexpr bool Q_IS_K_MAJOR = true;
|
||||
|
||||
// temporary safety stop
|
||||
static_assert(TENSOR_CORE && WARP_SPECIALIZED);
|
||||
|
||||
void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
||||
// @perf: All threads are running these compute whose result is mostly same
|
||||
// across the threadblock
|
||||
@@ -90,80 +93,78 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
||||
"flashattention kernel assumes 1 threadblock occupancy per cluster");
|
||||
uint8_t *smem_per_threadblock = reinterpret_cast<uint8_t *>(
|
||||
DEV_SMEM_START_ADDR);
|
||||
float *smem_cursor = reinterpret_cast<float *>(smem_per_threadblock);
|
||||
// constexpr uint32_t DEV_FAKE_SMEM_START_ADDR = 0xf0000000;
|
||||
// float *smem_cursor = reinterpret_cast<float *>(DEV_FAKE_SMEM_START_ADDR);
|
||||
float *smem_Q0 = smem_cursor;
|
||||
smem_cursor += smem_Q_size;
|
||||
float *smem_Q1 = smem_cursor;
|
||||
smem_cursor += smem_Q_size;
|
||||
float *smem_K0 = smem_cursor;
|
||||
smem_cursor += smem_K_size;
|
||||
float *smem_K1 = smem_cursor;
|
||||
smem_cursor += smem_K_size;
|
||||
float *smem_V0 = smem_cursor;
|
||||
smem_cursor += smem_V_size;
|
||||
float *smem_V1 = smem_cursor;
|
||||
smem_cursor += smem_V_size;
|
||||
float *smem_S0 = smem_cursor;
|
||||
smem_cursor += smem_QK_size;
|
||||
float *smem_S1 = smem_cursor;
|
||||
smem_cursor += smem_QK_size;
|
||||
float *smem_P0 = smem_S0; // in-place update
|
||||
float *smem_P1 = smem_S1; // in-place update
|
||||
float *smem_O0 = smem_cursor;
|
||||
smem_cursor += smem_O_size;
|
||||
float *smem_O1 = smem_cursor;
|
||||
smem_cursor += smem_O_size;
|
||||
constexpr uint32_t smem_start = DEV_SMEM_START_ADDR;
|
||||
constexpr uint32_t smem_octet0 = 0 * (SMEM_SIZE / 8);
|
||||
constexpr uint32_t smem_octet1 = 1 * (SMEM_SIZE / 8);
|
||||
constexpr uint32_t smem_octet2 = 2 * (SMEM_SIZE / 8);
|
||||
constexpr uint32_t smem_octet3 = 3 * (SMEM_SIZE / 8);
|
||||
constexpr uint32_t smem_octet4 = 4 * (SMEM_SIZE / 8);
|
||||
constexpr uint32_t smem_octet5 = 5 * (SMEM_SIZE / 8);
|
||||
constexpr uint32_t smem_octet6 = 6 * (SMEM_SIZE / 8);
|
||||
constexpr uint32_t smem_octet7 = 7 * (SMEM_SIZE / 8);
|
||||
|
||||
// NOTE: this has to match with smem_*
|
||||
static_assert(sizeof(elem_t) == sizeof(float));
|
||||
constexpr uint32_t spad_addr_factor = DIM * sizeof(elem_t);
|
||||
constexpr uint32_t spad_addr_Q0 = 0;
|
||||
constexpr uint32_t spad_addr_Q1 =
|
||||
spad_addr_Q0 + (smem_Q_size * sizeof(float) / spad_addr_factor);
|
||||
constexpr uint32_t spad_addr_K0 =
|
||||
spad_addr_Q1 + (smem_Q_size * sizeof(float) / spad_addr_factor);
|
||||
constexpr uint32_t spad_addr_K1 =
|
||||
spad_addr_K0 + (smem_K_size * sizeof(float) / spad_addr_factor);
|
||||
constexpr uint32_t spad_addr_V0 =
|
||||
spad_addr_K1 + (smem_K_size * sizeof(float) / spad_addr_factor);
|
||||
constexpr uint32_t spad_addr_V1 =
|
||||
spad_addr_V0 + (smem_V_size * sizeof(float) / spad_addr_factor);
|
||||
constexpr uint32_t spad_addr_S0 =
|
||||
spad_addr_V1 + (smem_V_size * sizeof(float) / spad_addr_factor);
|
||||
constexpr uint32_t spad_addr_S1 =
|
||||
spad_addr_S0 + (smem_QK_size * sizeof(float) / spad_addr_factor);
|
||||
// allocation strategy: since the two warpgroups only access *0 and *1
|
||||
// buffers each, allocate *0 in the first half of SMEM, and *1 in the latter
|
||||
// half
|
||||
// at the same time, make sure Q and K are in different banks so that they
|
||||
// can be accessed in parallel for GEMM; same for P and V
|
||||
constexpr uint32_t smem_Q0_offset = smem_octet0;
|
||||
constexpr uint32_t smem_Q1_offset = smem_octet4;
|
||||
constexpr uint32_t smem_K0_offset = smem_octet1;
|
||||
constexpr uint32_t smem_K1_offset = smem_octet5;
|
||||
constexpr uint32_t smem_V0_offset = smem_K0_offset + smem_K_size * sizeof(float);
|
||||
constexpr uint32_t smem_V1_offset = smem_K1_offset + smem_K_size * sizeof(float);
|
||||
constexpr uint32_t smem_S0_offset = smem_octet2;
|
||||
constexpr uint32_t smem_S1_offset = smem_octet6;
|
||||
constexpr uint32_t smem_P0_offset = smem_Q0_offset + smem_Q_size * sizeof(float);
|
||||
constexpr uint32_t smem_P1_offset = smem_Q1_offset + smem_Q_size * sizeof(float);
|
||||
constexpr uint32_t smem_O0_offset = smem_octet3;
|
||||
constexpr uint32_t smem_O1_offset = smem_octet7;
|
||||
|
||||
float *smem_Q0 = reinterpret_cast<float *>(smem_start + smem_Q0_offset);
|
||||
float *smem_Q1 = reinterpret_cast<float *>(smem_start + smem_Q1_offset);
|
||||
float *smem_K0 = reinterpret_cast<float *>(smem_start + smem_K0_offset);
|
||||
float *smem_K1 = reinterpret_cast<float *>(smem_start + smem_K1_offset);
|
||||
float *smem_V0 = reinterpret_cast<float *>(smem_start + smem_V0_offset);
|
||||
float *smem_V1 = reinterpret_cast<float *>(smem_start + smem_V1_offset);
|
||||
float *smem_S0 = reinterpret_cast<float *>(smem_start + smem_S0_offset);
|
||||
float *smem_S1 = reinterpret_cast<float *>(smem_start + smem_S1_offset);
|
||||
float *smem_P0 = reinterpret_cast<float *>(smem_start + smem_P0_offset);
|
||||
float *smem_P1 = reinterpret_cast<float *>(smem_start + smem_P1_offset);
|
||||
float *smem_O0 = reinterpret_cast<float *>(smem_start + smem_O0_offset);
|
||||
float *smem_O1 = reinterpret_cast<float *>(smem_start + smem_O1_offset);
|
||||
|
||||
// allocate rowmax/rowsum storage at the end of the sharedmem address space
|
||||
constexpr uint32_t smem_rowmax_size = B_ROW * ROWMAX_SETS;
|
||||
constexpr uint32_t smem_rowsum_size = B_ROW;
|
||||
constexpr uint32_t smem_O_row_scale_size = B_ROW;
|
||||
// FIXME: dangerous
|
||||
smem_cursor = reinterpret_cast<float *>(0xff038000);
|
||||
|
||||
float *smem_rowmax_0 = smem_cursor;
|
||||
smem_cursor += smem_rowmax_size;
|
||||
float *smem_rowmax_1 = smem_cursor;
|
||||
smem_cursor += smem_rowmax_size;
|
||||
float *smem_rowsum_0 = smem_cursor;
|
||||
smem_cursor += smem_rowsum_size;
|
||||
float *smem_rowsum_1 = smem_cursor;
|
||||
smem_cursor += smem_rowsum_size;
|
||||
float *smem_O_row_scale_0 = smem_cursor;
|
||||
smem_cursor += smem_O_row_scale_size;
|
||||
float *smem_O_row_scale_1 = smem_cursor;
|
||||
smem_cursor += smem_O_row_scale_size;
|
||||
float *smem_cursor_0 = smem_O0 + smem_O_size;
|
||||
float *smem_cursor_1 = smem_O1 + smem_O_size;
|
||||
// // FIXME: dangerous
|
||||
// smem_cursor = reinterpret_cast<float *>(0xff038000);
|
||||
float *smem_rowmax_0 = smem_cursor_0;
|
||||
smem_cursor_0 += smem_rowmax_size;
|
||||
float *smem_rowmax_1 = smem_cursor_1;
|
||||
smem_cursor_1 += smem_rowmax_size;
|
||||
float *smem_rowsum_0 = smem_cursor_0;
|
||||
smem_cursor_0 += smem_rowsum_size;
|
||||
float *smem_rowsum_1 = smem_cursor_1;
|
||||
smem_cursor_1 += smem_rowsum_size;
|
||||
float *smem_O_row_scale_0 = smem_cursor_0;
|
||||
smem_cursor_0 += smem_O_row_scale_size;
|
||||
float *smem_O_row_scale_1 = smem_cursor_1;
|
||||
smem_cursor_1 += smem_O_row_scale_size;
|
||||
|
||||
// sharedmem "scratchpad" area to put temporary data, e.g. for tree reduction
|
||||
// in rowsum
|
||||
// NOTE: out-of bounds is not checked
|
||||
constexpr uint32_t smem_scratchpad_size =
|
||||
threads_per_warpgroup * 2 /*arbitrary slack*/;
|
||||
float *smem_scratchpad_0 = smem_cursor;
|
||||
smem_cursor += smem_scratchpad_size;
|
||||
float *smem_scratchpad_1 = smem_cursor;
|
||||
smem_cursor += smem_scratchpad_size;
|
||||
float *smem_scratchpad_0 = smem_cursor_0;
|
||||
smem_cursor_0 += smem_scratchpad_size;
|
||||
float *smem_scratchpad_1 = smem_cursor_1;
|
||||
smem_cursor_1 += smem_scratchpad_size;
|
||||
|
||||
// select the correct buffer by warpgroup
|
||||
float *smem_Q = (warpgroup_id % 2) ? smem_Q1 : smem_Q0;
|
||||
@@ -179,6 +180,21 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
||||
float *smem_scratchpad =
|
||||
(warpgroup_id % 2) ? smem_scratchpad_1 : smem_scratchpad_0;
|
||||
|
||||
static_assert(sizeof(elem_t) == sizeof(float));
|
||||
constexpr uint32_t spad_addr_factor = DIM * sizeof(elem_t);
|
||||
constexpr uint32_t spad_addr_Q0 = smem_Q0_offset / spad_addr_factor;
|
||||
constexpr uint32_t spad_addr_Q1 = smem_Q1_offset / spad_addr_factor;
|
||||
constexpr uint32_t spad_addr_K0 = smem_K0_offset / spad_addr_factor;
|
||||
constexpr uint32_t spad_addr_K1 = smem_K1_offset / spad_addr_factor;
|
||||
constexpr uint32_t spad_addr_V0 = smem_V0_offset / spad_addr_factor;
|
||||
constexpr uint32_t spad_addr_V1 = smem_V1_offset / spad_addr_factor;
|
||||
constexpr uint32_t spad_addr_S0 = smem_S0_offset / spad_addr_factor;
|
||||
constexpr uint32_t spad_addr_S1 = smem_S1_offset / spad_addr_factor;
|
||||
constexpr uint32_t spad_addr_P0 = smem_P0_offset / spad_addr_factor;
|
||||
constexpr uint32_t spad_addr_P1 = smem_P1_offset / spad_addr_factor;
|
||||
constexpr uint32_t spad_addr_O0 = smem_O0_offset / spad_addr_factor;
|
||||
constexpr uint32_t spad_addr_O1 = smem_O1_offset / spad_addr_factor;
|
||||
|
||||
const auto spad_addr_Q = (warpgroup_id % 2) ? spad_addr_Q1 : spad_addr_Q0;
|
||||
const auto spad_addr_K = (warpgroup_id % 2) ? spad_addr_K1 : spad_addr_K0;
|
||||
const auto spad_addr_V = (warpgroup_id % 2) ? spad_addr_V1 : spad_addr_V0;
|
||||
|
||||
Reference in New Issue
Block a user