sgemm_tcore: Verify without DMA; warn untested against K-major A + DMA
This commit is contained in:
@@ -120,7 +120,6 @@ inline void thread_block_copy_tile(const float *src, float *dest,
|
||||
|
||||
constexpr uint32_t per_row_iter = B_COL / NUM_THREADS;
|
||||
uint32_t thread_offset = first_thread_offset + tid_in_warp;
|
||||
float per_thread_max = FLT_MIN;
|
||||
#pragma GCC unroll
|
||||
for (int i = 0; i < per_row_iter; i++) {
|
||||
dest[thread_offset] = src[thread_offset];
|
||||
|
||||
@@ -232,13 +232,15 @@ inline void wmma_load_a(volatile const T *smem_A, const int local_k,
|
||||
|
||||
static_assert(!GEMMINI_DMA || (layout == MemLayout::K_major),
|
||||
"GEMMINI_DMA only supported for K-major A tile");
|
||||
static_assert((layout != MemLayout::K_major) || (FP_SIZE == 32),
|
||||
"fp16 is not really tested for K-major A layout");
|
||||
|
||||
if constexpr (layout == MemLayout::K_major) {
|
||||
constexpr int smem_A_cols = leading_dim;
|
||||
|
||||
// f8-f15 stores a single row of A
|
||||
const uint32_t smem_logical_row = WM * warp_row + TCM * wm_iter + row;
|
||||
const uint32_t smem_logical_col = local_k + 0; /* FIXME: adjust for fp16? */
|
||||
const uint32_t smem_logical_col = local_k_adjusted + 0; /* FIXME: fp16 adjust necessary? */
|
||||
uint32_t smem_row;
|
||||
uint32_t smem_col;
|
||||
if constexpr (GEMMINI_DMA) {
|
||||
|
||||
Reference in New Issue
Block a user