diff --git a/tests/regression/sgemm_tcore/sgemm_impl.hpp b/tests/regression/sgemm_tcore/sgemm_impl.hpp index 933f835c..2014b507 100644 --- a/tests/regression/sgemm_tcore/sgemm_impl.hpp +++ b/tests/regression/sgemm_tcore/sgemm_impl.hpp @@ -29,7 +29,7 @@ using float_type = float16_t; // (BM*BN) / (TM*TN) == threadblock size >= NT * CORES_PER_CLUSTER // * Combining BM * BK >= (BM*BN) / (TM*TN) == threadblock yields // BM <= BK*TM*TN -#define BM 64 +#define BM 128 #define BN 64 #if (FP_SIZE == 32) #define BK 64 @@ -72,7 +72,7 @@ static_assert(WMITER * WNITER * TCM * TCN * NUM_WARPS * CORES_PER_CLUSTER == #define TRANSPOSE_AT_PRODUCE 0 #define TRANSPOSE_AT_CONSUME 0 -#define GEMMINI_DMA 1 +#define GEMMINI_DMA 0 #define GEMMINI_DMA_MN_MAJOR 1 #if SMEM_SIZE == 0x4000 #define SMEM_ADDR_Q0 ((float * const) 0xff000000) @@ -299,14 +299,23 @@ inline void wmma_load_a(volatile const T *smem_A, const int local_k, (WM * warp_row + TCM * wm_iter) + row]); // f8-f15 stores a single row of A // threads read from different columns; no bank conflicts + // asm volatile("flw f0, %0(%1)" :: "i"(smem_AS_cols * 0 * sizeof(float)), "r"(smem_addr)); + // asm volatile("flw f1, %0(%1)" :: "i"(smem_AS_cols * 1 * sizeof(float)), "r"(smem_addr)); + // asm volatile("flw f2, %0(%1)" :: "i"(smem_AS_cols * 2 * sizeof(float)), "r"(smem_addr)); + // asm volatile("flw f3, %0(%1)" :: "i"(smem_AS_cols * 3 * sizeof(float)), "r"(smem_addr)); + // asm volatile("flw f4, %0(%1)" :: "i"(smem_AS_cols * 4 * sizeof(float)), "r"(smem_addr)); + // asm volatile("flw f5, %0(%1)" :: "i"(smem_AS_cols * 5 * sizeof(float)), "r"(smem_addr)); + // asm volatile("flw f6, %0(%1)" :: "i"(smem_AS_cols * 6 * sizeof(float)), "r"(smem_addr)); + // asm volatile("flw f7, %0(%1)" :: "i"(smem_AS_cols * 7 * sizeof(float)), "r"(smem_addr)); asm volatile("flw f0, %0(%1)" :: "i"(smem_AS_cols * 0 * sizeof(float)), "r"(smem_addr)); asm volatile("flw f1, %0(%1)" :: "i"(smem_AS_cols * 1 * sizeof(float)), "r"(smem_addr)); asm volatile("flw f2, %0(%1)" :: "i"(smem_AS_cols * 2 * sizeof(float)), "r"(smem_addr)); asm volatile("flw f3, %0(%1)" :: "i"(smem_AS_cols * 3 * sizeof(float)), "r"(smem_addr)); - asm volatile("flw f4, %0(%1)" :: "i"(smem_AS_cols * 4 * sizeof(float)), "r"(smem_addr)); - asm volatile("flw f5, %0(%1)" :: "i"(smem_AS_cols * 5 * sizeof(float)), "r"(smem_addr)); - asm volatile("flw f6, %0(%1)" :: "i"(smem_AS_cols * 6 * sizeof(float)), "r"(smem_addr)); - asm volatile("flw f7, %0(%1)" :: "i"(smem_AS_cols * 7 * sizeof(float)), "r"(smem_addr)); + smem_addr += smem_AS_cols * 4 * sizeof(float); + asm volatile("flw f4, %0(%1)" :: "i"(smem_AS_cols * 0 * sizeof(float)), "r"(smem_addr)); + asm volatile("flw f5, %0(%1)" :: "i"(smem_AS_cols * 1 * sizeof(float)), "r"(smem_addr)); + asm volatile("flw f6, %0(%1)" :: "i"(smem_AS_cols * 2 * sizeof(float)), "r"(smem_addr)); + asm volatile("flw f7, %0(%1)" :: "i"(smem_AS_cols * 3 * sizeof(float)), "r"(smem_addr)); } else { static_assert(layout == MemLayout::K_major /* fake cond that is always false */, @@ -638,34 +647,67 @@ load_tile_to_smem(const uint32_t dim_major, const uint32_t mn_index, // need to branch because address offset constant in the inline assembly // cannot be larger than a certain limit if constexpr (!transposed_write) { + // asm volatile("fsw ft0, %0(%1)" ::"i"(smem_dim_col * row_stride * 0 * + // sizeof(float)), + // "r"(local)); + // asm volatile("fsw ft1, %0(%1)" ::"i"(smem_dim_col * row_stride * 1 * + // sizeof(float)), + // "r"(local)); + // local += smem_dim_col * row_stride * 2; + // asm volatile("fsw ft2, %0(%1)" ::"i"(smem_dim_col * row_stride * 0 * + // sizeof(float)), + // "r"(local)); + // asm volatile("fsw ft3, %0(%1)" ::"i"(smem_dim_col * row_stride * 1 * + // sizeof(float)), + // "r"(local)); + // local += smem_dim_col * row_stride * 2; + // asm volatile("fsw ft4, %0(%1)" ::"i"(smem_dim_col * row_stride * 0 * + // sizeof(float)), + // "r"(local)); + // asm volatile("fsw ft5, %0(%1)" ::"i"(smem_dim_col * row_stride * 1 * + // sizeof(float)), + // "r"(local)); + // local += smem_dim_col * row_stride * 2; + // asm volatile("fsw ft6, %0(%1)" ::"i"(smem_dim_col * row_stride * 0 * + // sizeof(float)), + // "r"(local)); + // asm volatile("fsw ft7, %0(%1)" ::"i"(smem_dim_col * row_stride * 1 * + // sizeof(float)), + // "r"(local)); + // local += smem_dim_col * row_stride * 2; + asm volatile("fsw ft0, %0(%1)" ::"i"(smem_dim_col * row_stride * 0 * sizeof(float)), "r"(local)); - asm volatile("fsw ft1, %0(%1)" ::"i"(smem_dim_col * row_stride * 1 * + local += smem_dim_col * row_stride; + asm volatile("fsw ft1, %0(%1)" ::"i"(smem_dim_col * row_stride * 0 * sizeof(float)), "r"(local)); - local += smem_dim_col * row_stride * 2; + local += smem_dim_col * row_stride; asm volatile("fsw ft2, %0(%1)" ::"i"(smem_dim_col * row_stride * 0 * sizeof(float)), "r"(local)); - asm volatile("fsw ft3, %0(%1)" ::"i"(smem_dim_col * row_stride * 1 * + local += smem_dim_col * row_stride; + asm volatile("fsw ft3, %0(%1)" ::"i"(smem_dim_col * row_stride * 0 * sizeof(float)), "r"(local)); - local += smem_dim_col * row_stride * 2; + local += smem_dim_col * row_stride; asm volatile("fsw ft4, %0(%1)" ::"i"(smem_dim_col * row_stride * 0 * sizeof(float)), "r"(local)); - asm volatile("fsw ft5, %0(%1)" ::"i"(smem_dim_col * row_stride * 1 * + local += smem_dim_col * row_stride; + asm volatile("fsw ft5, %0(%1)" ::"i"(smem_dim_col * row_stride * 0 * sizeof(float)), "r"(local)); - local += smem_dim_col * row_stride * 2; + local += smem_dim_col * row_stride; asm volatile("fsw ft6, %0(%1)" ::"i"(smem_dim_col * row_stride * 0 * sizeof(float)), "r"(local)); - asm volatile("fsw ft7, %0(%1)" ::"i"(smem_dim_col * row_stride * 1 * + local += smem_dim_col * row_stride; + asm volatile("fsw ft7, %0(%1)" ::"i"(smem_dim_col * row_stride * 0 * sizeof(float)), "r"(local)); - local += smem_dim_col * row_stride * 2; + local += smem_dim_col * row_stride; } else { // currently, tensor core hardware only supports MN-major SMEM tile // layout for correct results @@ -996,6 +1038,7 @@ inline void thread_block_gemm(const T *A, const T *B, float *C, warps_per_threadblock_per_core); #endif +#if 0 // consumer code: SMEM->RF and compute // ---------------------------------------------------------------------- // @perf: this loop spills to stack a lot because of all the flws in @@ -1044,6 +1087,7 @@ inline void thread_block_gemm(const T *A, const T *B, float *C, threadblock_barrier(threadblock_id_in_cluster, warps_per_threadblock_per_core); +#endif } if constexpr (write_to_gmem) {