tensor: Change B in-memory layout to column-major
This commit is contained in:
@@ -14,7 +14,8 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
TOOLDIR=${TOOLDIR:=/opt}
|
||||
TOOLDIR=${TOOLDIR:=/scratch/hansung/build/vortex-toolchain-prebuilt}
|
||||
export TOOLDIR
|
||||
|
||||
export VERILATOR_ROOT=$TOOLDIR/verilator
|
||||
export PATH=$VERILATOR_ROOT/bin:$PATH
|
||||
@@ -25,6 +26,8 @@ export PATH=$SV2V_PATH/bin:$PATH
|
||||
export YOSYS_PATH=$TOOLDIR/yosys
|
||||
export PATH=$YOSYS_PATH/bin:$PATH
|
||||
|
||||
export LLVM_VORTEX=$TOOLDIR/llvm-vortex
|
||||
export POCL_CC_PATH=$TOOLDIR/pocl/compiler
|
||||
export POCL_RT_PATH=$TOOLDIR/pocl/runtime
|
||||
# LLVM_POCL seems to be only used in tests/opencl
|
||||
export LLVM_POCL=/scratch/hansung/build/llvm-vortex2
|
||||
export LLVM_VORTEX=/scratch/hansung/build/llvm-vortex2
|
||||
export POCL_CC_PATH=/scratch/hansung/build/pocl-vortex2/compiler
|
||||
export POCL_RT_PATH=/scratch/hansung/build/pocl-vortex2/runtime
|
||||
|
||||
@@ -122,14 +122,24 @@ void vx_wmma_load() {
|
||||
// load B
|
||||
// B is stored row-major in the memory,
|
||||
// loaded column-major into the RF.
|
||||
asm volatile("flw f8 , %0" ::"m"(B[DIM_N * 0 + col]));
|
||||
asm volatile("flw f9 , %0" ::"m"(B[DIM_N * 1 + col]));
|
||||
asm volatile("flw f10, %0" ::"m"(B[DIM_N * 2 + col]));
|
||||
asm volatile("flw f11, %0" ::"m"(B[DIM_N * 3 + col]));
|
||||
asm volatile("flw f12, %0" ::"m"(B[DIM_N * 4 + col]));
|
||||
asm volatile("flw f13, %0" ::"m"(B[DIM_N * 5 + col]));
|
||||
asm volatile("flw f14, %0" ::"m"(B[DIM_N * 6 + col]));
|
||||
asm volatile("flw f15, %0" ::"m"(B[DIM_N * 7 + col]));
|
||||
// asm volatile("flw f8 , %0" ::"m"(B[DIM_N * 0 + col]));
|
||||
// asm volatile("flw f9 , %0" ::"m"(B[DIM_N * 1 + col]));
|
||||
// asm volatile("flw f10, %0" ::"m"(B[DIM_N * 2 + col]));
|
||||
// asm volatile("flw f11, %0" ::"m"(B[DIM_N * 3 + col]));
|
||||
// asm volatile("flw f12, %0" ::"m"(B[DIM_N * 4 + col]));
|
||||
// asm volatile("flw f13, %0" ::"m"(B[DIM_N * 5 + col]));
|
||||
// asm volatile("flw f14, %0" ::"m"(B[DIM_N * 6 + col]));
|
||||
// asm volatile("flw f15, %0" ::"m"(B[DIM_N * 7 + col]));
|
||||
// B is stored column-major in the memory,
|
||||
// loaded column-major into the RF.
|
||||
asm volatile("flw f8 , %0" ::"m"(B[DIM_K * row + 0]));
|
||||
asm volatile("flw f9 , %0" ::"m"(B[DIM_K * row + 1]));
|
||||
asm volatile("flw f10, %0" ::"m"(B[DIM_K * row + 2]));
|
||||
asm volatile("flw f11, %0" ::"m"(B[DIM_K * row + 3]));
|
||||
asm volatile("flw f12, %0" ::"m"(B[DIM_K * row + 4]));
|
||||
asm volatile("flw f13, %0" ::"m"(B[DIM_K * row + 5]));
|
||||
asm volatile("flw f14, %0" ::"m"(B[DIM_K * row + 6]));
|
||||
asm volatile("flw f15, %0" ::"m"(B[DIM_K * row + 7]));
|
||||
|
||||
map_c_8lanes(tid, row, col);
|
||||
|
||||
|
||||
@@ -49,7 +49,10 @@ VX_CP = $(LLVM_VORTEX)/bin/llvm-objcopy
|
||||
#VX_CP = $(RISCV_TOOLCHAIN_PATH)/bin/$(RISCV_PREFIX)-objcopy
|
||||
|
||||
VX_CFLAGS += -v -O3 -std=c++17
|
||||
VX_CFLAGS += -mcmodel=medany -fno-rtti -fno-exceptions -nostartfiles -fdata-sections -ffunction-sections -mllvm -inline-threshold=8192
|
||||
VX_CFLAGS += -mcmodel=medany -fno-rtti -fno-exceptions -nostartfiles -fdata-sections -ffunction-sections
|
||||
# comment out below for regression/basic, which uses GCC that doesn't
|
||||
# understand these flags
|
||||
VX_CFLAGS += -mllvm -inline-threshold=8192
|
||||
VX_CFLAGS += -I$(VORTEX_KN_PATH)/include -I$(VORTEX_KN_PATH)/../hw -I$(GEMMINI_SW_PATH)
|
||||
VX_CFLAGS += -DNDEBUG -DLLVM_VORTEX
|
||||
|
||||
@@ -104,23 +107,24 @@ kernel.bin: kernel.elf kernel.radiance.elf
|
||||
|
||||
OBJCOPY ?= $(RISCV_TOOLCHAIN_PATH)/bin/$(RISCV_PREFIX)-objcopy
|
||||
OBJCOPY_FLAGS ?= "LOAD,ALLOC,DATA,CONTENTS"
|
||||
kernel.elf: $(VX_SRCS)
|
||||
BINFILES := args.bin input.a.bin input.b.bin
|
||||
kernel.elf: $(VX_SRCS) $(BINFILES)
|
||||
$(VX_CXX) $(VX_CFLAGS) $(VX_SRCS) $(VX_LDFLAGS) -o $@
|
||||
$(OBJCOPY) --set-section-flags .operand.a=$(OBJCOPY_FLAGS) $@
|
||||
$(OBJCOPY) --set-section-flags .operand.b=$(OBJCOPY_FLAGS) $@
|
||||
$(OBJCOPY) --set-section-flags .args=$(OBJCOPY_FLAGS) $@
|
||||
$(OBJCOPY) --update-section .operand.a=input.a.bin $@
|
||||
$(OBJCOPY) --update-section .operand.b=input.b.bin $@
|
||||
$(OBJCOPY) --update-section .args=args.bin $@
|
||||
$(OBJCOPY) --update-section .operand.a=input.a.bin $@ || true
|
||||
$(OBJCOPY) --update-section .operand.b=input.b.bin $@ || true
|
||||
$(OBJCOPY) --update-section .args=args.bin $@ || true
|
||||
|
||||
kernel.radiance.elf: $(VX_SRCS)
|
||||
kernel.radiance.elf: $(VX_SRCS) $(BINFILES)
|
||||
$(VX_CXX) $(VX_CFLAGS) $(VX_SRCS) $(VX_LDFLAGS) -DRADIANCE -o $@
|
||||
$(OBJCOPY) --set-section-flags .operand.a=$(OBJCOPY_FLAGS) $@
|
||||
$(OBJCOPY) --set-section-flags .operand.b=$(OBJCOPY_FLAGS) $@
|
||||
$(OBJCOPY) --set-section-flags .args=$(OBJCOPY_FLAGS) $@
|
||||
$(OBJCOPY) --update-section .operand.a=input.a.bin $@
|
||||
$(OBJCOPY) --update-section .operand.b=input.b.bin $@
|
||||
$(OBJCOPY) --update-section .args=args.bin $@
|
||||
$(OBJCOPY) --update-section .operand.a=input.a.bin $@ || true
|
||||
$(OBJCOPY) --update-section .operand.b=input.b.bin $@ || true
|
||||
$(OBJCOPY) --update-section .args=args.bin $@ || true
|
||||
|
||||
ifneq ($(CONFIG),)
|
||||
kernel$(CONFIGEXT).elf: kernel.elf
|
||||
|
||||
@@ -572,7 +572,7 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
||||
const uint32_t problem_size = (dim_m * dim_n) / (ELEM_PER_THREAD);
|
||||
const uint32_t num_threadblocks = problem_size / threads_per_threadblock;
|
||||
|
||||
using float_type = float;
|
||||
using float_type = float16_t;
|
||||
|
||||
// "static" shared memory allocation. This would determine threadblock
|
||||
// occupancy of a single cluster
|
||||
|
||||
@@ -173,7 +173,8 @@ int main(int argc, char *argv[]) {
|
||||
uint32_t dim_n = 64;
|
||||
uint32_t dim_k = 64;
|
||||
|
||||
using float_type = float;
|
||||
using float_type = half;
|
||||
|
||||
generate_source_matrix<float_type>(dim_m, dim_n, dim_k);
|
||||
generate_reference_matmul<float_type>(dim_m, dim_n, dim_k);
|
||||
|
||||
|
||||
@@ -20,7 +20,7 @@
|
||||
// BM <= BK*TM*TN
|
||||
#define BM 64
|
||||
#define BN 64
|
||||
#define BK 64
|
||||
#define BK 128
|
||||
#define WM 16
|
||||
#define WN 8
|
||||
#define TCM 8
|
||||
|
||||
Reference in New Issue
Block a user