sgemm_tcore: Verify without DMA; warn untested against K-major A + DMA

This commit is contained in:
Hansung Kim
2024-09-03 14:42:19 -07:00
parent 7aa0e6cbe4
commit f028a97f75
2 changed files with 3 additions and 2 deletions

View File

@@ -120,7 +120,6 @@ inline void thread_block_copy_tile(const float *src, float *dest,
constexpr uint32_t per_row_iter = B_COL / NUM_THREADS;
uint32_t thread_offset = first_thread_offset + tid_in_warp;
float per_thread_max = FLT_MIN;
#pragma GCC unroll
for (int i = 0; i < per_row_iter; i++) {
dest[thread_offset] = src[thread_offset];

View File

@@ -232,13 +232,15 @@ inline void wmma_load_a(volatile const T *smem_A, const int local_k,
static_assert(!GEMMINI_DMA || (layout == MemLayout::K_major),
"GEMMINI_DMA only supported for K-major A tile");
static_assert((layout != MemLayout::K_major) || (FP_SIZE == 32),
"fp16 is not really tested for K-major A layout");
if constexpr (layout == MemLayout::K_major) {
constexpr int smem_A_cols = leading_dim;
// f8-f15 stores a single row of A
const uint32_t smem_logical_row = WM * warp_row + TCM * wm_iter + row;
const uint32_t smem_logical_col = local_k + 0; /* FIXME: adjust for fp16? */
const uint32_t smem_logical_col = local_k_adjusted + 0; /* FIXME: fp16 adjust necessary? */
uint32_t smem_row;
uint32_t smem_col;
if constexpr (GEMMINI_DMA) {