From beb3dce46dcb48786acd3115730a204510f9f619 Mon Sep 17 00:00:00 2001 From: joshua Date: Wed, 6 Mar 2024 01:39:17 -0800 Subject: [PATCH 01/55] integer reduction unit --- hw/rtl/VX_define.vh | 14 +- hw/rtl/core/VX_alu_unit.sv | 45 ++++- hw/rtl/core/VX_decode.sv | 28 +++ hw/rtl/core/VX_reduce_unit.sv | 283 +++++++++++++++++++++++++++++++ tests/kernel/Makefile | 6 +- tests/kernel/reductions/Makefile | 5 + tests/kernel/reductions/main.cpp | 216 +++++++++++++++++++++++ 7 files changed, 587 insertions(+), 10 deletions(-) create mode 100644 hw/rtl/core/VX_reduce_unit.sv create mode 100644 tests/kernel/reductions/Makefile create mode 100644 tests/kernel/reductions/main.cpp diff --git a/hw/rtl/VX_define.vh b/hw/rtl/VX_define.vh index 996c769d..9ddeeeea 100644 --- a/hw/rtl/VX_define.vh +++ b/hw/rtl/VX_define.vh @@ -115,7 +115,7 @@ /////////////////////////////////////////////////////////////////////////////// `define INST_OP_BITS 4 -`define INST_MOD_BITS 3 +`define INST_MOD_BITS 4 `define INST_FMT_BITS 2 /////////////////////////////////////////////////////////////////////////////// @@ -140,6 +140,7 @@ `define INST_ALU_IS_BR(mod) mod[0] `define INST_ALU_IS_M(mod) mod[1] `define INST_ALU_IS_W(mod) mod[2] +`define INST_ALU_IS_RED(mod) mod[3] `define INST_BR_EQ 4'b0000 `define INST_BR_NE 4'b0010 @@ -176,6 +177,17 @@ `define INST_M_SIGNED_A(op) (op[1:0] != 1) `define INST_M_IS_REM(op) op[1] +`define INST_RED_ADD 4'b0000 +`define INST_RED_ADDU 4'b1000 +`define INST_RED_MIN 4'b0001 +`define INST_RED_MINU 4'b1001 +`define INST_RED_MAX 4'b0010 +`define INST_RED_MAXU 4'b1010 +`define INST_RED_AND 4'b0011 +`define INST_RED_OR 4'b0100 +`define INST_RED_XOR 4'b0101 +`define INST_RED_BITS 4 + `define INST_FMT_B 3'b000 `define INST_FMT_H 3'b001 `define INST_FMT_W 3'b010 diff --git a/hw/rtl/core/VX_alu_unit.sv b/hw/rtl/core/VX_alu_unit.sv index d2b38cf4..7546f4b3 100644 --- a/hw/rtl/core/VX_alu_unit.sv +++ b/hw/rtl/core/VX_alu_unit.sv @@ -33,7 +33,7 @@ module VX_alu_unit #( localparam PID_BITS = `CLOG2(`NUM_THREADS / NUM_LANES); localparam PID_WIDTH = `UP(PID_BITS); localparam RSP_ARB_DATAW= `UUID_WIDTH + `NW_WIDTH + NUM_LANES + `XLEN + `NR_BITS + 1 + NUM_LANES * `XLEN + PID_WIDTH + 1 + 1; - localparam RSP_ARB_SIZE = 1 + `EXT_M_ENABLED; + localparam RSP_ARB_SIZE = 2 + `EXT_M_ENABLED; localparam PARTIAL_BW = (BLOCK_SIZE != `ISSUE_WIDTH) || (NUM_LANES != `NUM_THREADS); VX_execute_if #( @@ -60,12 +60,13 @@ module VX_alu_unit #( for (genvar block_idx = 0; block_idx < BLOCK_SIZE; ++block_idx) begin wire is_muldiv_op; + wire is_reduce_op; VX_execute_if #( .NUM_LANES (NUM_LANES) ) int_execute_if(); - assign int_execute_if.valid = execute_if[block_idx].valid && ~is_muldiv_op; + assign int_execute_if.valid = execute_if[block_idx].valid && ~is_muldiv_op && ~is_reduce_op; assign int_execute_if.data = execute_if[block_idx].data; VX_commit_if #( @@ -86,6 +87,31 @@ module VX_alu_unit #( .commit_if (int_commit_if) ); + assign is_reduce_op = `INST_ALU_IS_RED(execute_if[block_idx].data.op_mod); + + VX_execute_if #( + .NUM_LANES (NUM_LANES) + ) red_execute_if(); + + assign red_execute_if.valid = execute_if[block_idx].valid && is_reduce_op; + assign red_execute_if.data = execute_if[block_idx].data; + + VX_commit_if #( + .NUM_LANES (NUM_LANES) + ) red_commit_if(); + + `RESET_RELAY(red_reset, reset); + + VX_reduce_unit #( + .CORE_ID(CORE_ID), + .NUM_LANES(NUM_LANES) + ) reduce_unit ( + .clk(clk), + .reset(red_reset), + .execute_if(red_execute_if), + .commit_if(red_commit_if) + ); + `ifdef EXT_M_ENABLE assign is_muldiv_op = `INST_ALU_IS_M(execute_if[block_idx].data.op_mod); @@ -96,7 +122,7 @@ module VX_alu_unit #( .NUM_LANES (NUM_LANES) ) mdv_execute_if(); - assign mdv_execute_if.valid = execute_if[block_idx].valid && is_muldiv_op; + assign mdv_execute_if.valid = execute_if[block_idx].valid && is_muldiv_op && ~is_reduce_op; assign mdv_execute_if.data = execute_if[block_idx].data; VX_commit_if #( @@ -113,12 +139,12 @@ module VX_alu_unit #( .commit_if (mdv_commit_if) ); - assign execute_if[block_idx].ready = is_muldiv_op ? mdv_execute_if.ready : int_execute_if.ready; + assign execute_if[block_idx].ready = is_reduce_op ? red_execute_if.ready : (is_muldiv_op ? mdv_execute_if.ready : int_execute_if.ready); `else assign is_muldiv_op = 0; - assign execute_if[block_idx].ready = int_execute_if.ready; + assign execute_if[block_idx].ready = is_reduce_op ? red_execute_if.ready : int_execute_if.ready; `endif @@ -135,19 +161,22 @@ module VX_alu_unit #( `ifdef EXT_M_ENABLE mdv_commit_if.valid, `endif - int_commit_if.valid + int_commit_if.valid, + red_commit_if.valid }), .ready_in ({ `ifdef EXT_M_ENABLE mdv_commit_if.ready, `endif - int_commit_if.ready + int_commit_if.ready, + red_commit_if.ready }), .data_in ({ `ifdef EXT_M_ENABLE mdv_commit_if.data, `endif - int_commit_if.data + int_commit_if.data, + red_commit_if.data }), .data_out (commit_block_if[block_idx].data), .valid_out (commit_block_if[block_idx].valid), diff --git a/hw/rtl/core/VX_decode.sv b/hw/rtl/core/VX_decode.sv index 0a6b00ec..42cd7ffc 100644 --- a/hw/rtl/core/VX_decode.sv +++ b/hw/rtl/core/VX_decode.sv @@ -505,6 +505,34 @@ module VX_decode #( default:; endcase end + `INST_EXT3: begin + ex_type = `EX_ALU; + op_mod[3] = 1; + `USED_IREG(rs1); + `USED_IREG(rd); + + case (func7[5:0]) + 6'h0: begin + op_type = func7[6] ? `INST_RED_ADDU : `INST_RED_ADD; + end + 6'h1: begin + op_type = func7[6] ? `INST_RED_MINU : `INST_RED_MIN; + end + 6'h2: begin + op_type = func7[6] ? `INST_RED_MAXU : `INST_RED_MAX; + end + 6'h3: begin + op_type = `INST_RED_AND; + end + 6'h4: begin + op_type = `INST_RED_OR; + end + 6'h5: begin + op_type = `INST_RED_XOR; + end + default:; + endcase + end default:; endcase end diff --git a/hw/rtl/core/VX_reduce_unit.sv b/hw/rtl/core/VX_reduce_unit.sv new file mode 100644 index 00000000..37610bcf --- /dev/null +++ b/hw/rtl/core/VX_reduce_unit.sv @@ -0,0 +1,283 @@ +`include "VX_define.vh" +`include "VX_platform.vh" + + +// Copyright © 2019-2023 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +`include "VX_platform.vh" + +module VX_reduce_ext #( + parameter DATAW_IN = 1, + parameter DATAW_OUT = DATAW_IN, + parameter N = 1 +) ( + input wire [N-1:0][DATAW_IN-1:0] data_in, + input wire [N-1:0] mask, + input wire [`INST_RED_BITS-1:0] op_type, + output wire [DATAW_OUT-1:0] data_out +); + if (N == 1) begin + `UNUSED_VAR(op_type) + `UNUSED_VAR(mask) + assign data_out = DATAW_OUT'(data_in[0]); + end else begin + localparam int N_A = N / 2; + localparam int N_B = N - N_A; + + wire [N_A-1:0][DATAW_IN-1:0] in_A; + wire [N_B-1:0][DATAW_IN-1:0] in_B; + wire [DATAW_OUT-1:0] out_A, out_B; + + wire [N_A-1:0] mask_A; + wire [N_B-1:0] mask_B; + wire any_A, any_B; + + for (genvar i = 0; i < N_A; i++) begin + assign in_A[i] = data_in[i]; + end + + for (genvar i = 0; i < N_B; i++) begin + assign in_B[i] = data_in[N_A + i]; + end + + assign mask_A = mask[N_A-1:0]; + assign mask_B = mask[N-1:N_A]; + assign any_A = |mask_A; + assign any_B = |mask_B; + + VX_reduce_ext #( + .DATAW_IN (DATAW_IN), + .DATAW_OUT (DATAW_OUT), + .N (N_A) + ) reduce_A ( + .data_in (in_A), + .mask(mask_A), + .op_type(op_type), + .data_out (out_A) + ); + + VX_reduce_ext #( + .DATAW_IN (DATAW_IN), + .DATAW_OUT (DATAW_OUT), + .N (N_B) + ) reduce_B ( + .data_in (in_B), + .mask(mask_B), + .op_type(op_type), + .data_out (out_B) + ); + + logic [DATAW_OUT-1:0] _data_out; + + always @(*) begin + case (op_type) + `INST_RED_ADD: _data_out = out_A + out_B; + `INST_RED_ADDU: _data_out = out_A + out_B; + `INST_RED_MIN: _data_out = ($signed(out_A) < $signed(out_B)) ? out_A : out_B; + `INST_RED_MINU: _data_out = (out_A < out_B) ? out_A : out_B; + `INST_RED_MAX: _data_out = ($signed(out_A) < $signed(out_B)) ? out_B : out_A; + `INST_RED_MAXU: _data_out = (out_A < out_B) ? out_B : out_A; + `INST_RED_AND: _data_out = out_A & out_B; + `INST_RED_OR: _data_out = out_A | out_B; + `INST_RED_XOR: _data_out = out_A ^ out_B; + default: _data_out = out_A; + endcase + end + + // if both sides are masked out, then it doesn't matter what we output + assign data_out = (any_A && any_B) ? _data_out : (any_A ? out_A : out_B); + + end + +endmodule + +module VX_reduce_unit #( + parameter CORE_ID = 0, + parameter NUM_LANES = 1 +) ( + input wire clk, + input wire reset, + + VX_execute_if.slave execute_if, + VX_commit_if.master commit_if +); + `UNUSED_PARAM(CORE_ID) + + localparam NUM_PACKETS = `NUM_THREADS / NUM_LANES; + localparam PID_BITS = `CLOG2(`NUM_THREADS / NUM_LANES); + localparam PID_WIDTH = `UP(PID_BITS); + + logic [`XLEN-1:0] accumulator, accumulator_n, reduced_accumulator; + wire [(NUM_LANES * `XLEN)-1:0] broadcasted_accumulator; + + assign broadcasted_accumulator = {NUM_LANES{accumulator}}; + + wire eop; + wire [NUM_LANES-1:0][`XLEN-1:0] data_in; + wire [`XLEN-1:0] data_out; + + assign eop = execute_if.data.eop; + assign data_in = execute_if.data.rs1_data; + + logic execute_if_valid; + logic execute_if_ready; + logic commit_if_valid; + logic commit_if_ready; + + wire execute_if_fire; + wire commit_if_fire; + + assign execute_if_valid = execute_if.valid; + assign execute_if.ready = execute_if_ready; + + assign execute_if_fire = execute_if.ready && execute_if.valid; + assign commit_if_fire = commit_if_ready && commit_if_valid; + + logic store_tmask_pid; + logic read_tmask_pid; + wire [PID_WIDTH-1:0] stored_pid; + wire [NUM_LANES-1:0] stored_tmask; + wire stored_sop; + wire stored_eop; + + logic [PID_BITS:0] size, size_n; + + // 1. idle state - wait for execute_if to be valid + // 2. accumulate - continue accumulating until eop, store packet id + thread mask for broadcast phase + // 3. broadcast - broadcast to rds + localparam IDLE = 2'b00; + localparam ACCUMULATE = 2'b01; + localparam BROADCAST = 2'b10; + localparam FINISH = 2'b11; + + logic [1:0] state, state_n; + + always @(*) begin + state_n = state; + accumulator_n = accumulator; + execute_if_ready = '0; + commit_if_valid = '0; + store_tmask_pid = '0; + read_tmask_pid = '0; + size_n = store_tmask_pid ? size + 1 : (read_tmask_pid ? size - 1 : size); + + case (state) + IDLE: begin + if (execute_if_valid) begin + accumulator_n = data_out; + store_tmask_pid = '1; + if (eop) begin + state_n = BROADCAST; + end + else begin + execute_if_ready = '1; + state_n = ACCUMULATE; + end + end + end + ACCUMULATE: begin + execute_if_ready = '1; + if (eop) begin + execute_if_ready = '0; + state_n = BROADCAST; + end + if (eop || execute_if_fire) begin + accumulator_n = reduced_accumulator; + store_tmask_pid = '1; + end + end + BROADCAST: begin + execute_if_ready = '0; + commit_if_valid = '1; + + if (commit_if_fire) begin + read_tmask_pid = '1; + end + if (size_n == '0) begin + state_n = FINISH; + end + end + FINISH: begin + execute_if_ready = '1; + if (execute_if_fire) begin + state_n = IDLE; + end + end + endcase + end + + always @(posedge clk) begin + if (reset) begin + accumulator <= '0; + state <= IDLE; + size <= '0; + end + else begin + accumulator <= accumulator_n; + state <= state_n; + size <= size_n; + end + end + + VX_reduce_ext #( + .DATAW_IN(`XLEN), + .N(NUM_LANES) + ) reducer ( + .data_in(data_in), + .mask(execute_if.data.tmask), + .op_type(execute_if.data.op_type), + .data_out(data_out) + ); + + VX_reduce_ext #( + .DATAW_IN(`XLEN), + .N(2) + ) accumulator_reducer ( + .data_in({accumulator, data_out}), + .mask(2'b11), + .op_type(execute_if.data.op_type), + .data_out(reduced_accumulator) + ); + + VX_elastic_buffer #( + .DATAW(NUM_LANES + PID_WIDTH + 1 + 1), + .SIZE(NUM_PACKETS), + ) tmask_pid_store ( + .clk(clk), + .reset(reset), + + .valid_in(store_tmask_pid), + `UNUSED_PIN(ready_in), + .data_in({execute_if.data.tmask, execute_if.data.pid, execute_if.data.sop, execute_if.data.eop}), + + .data_out({stored_tmask, stored_pid, stored_sop, stored_eop}), + .ready_out(read_tmask_pid), + `UNUSED_PIN(valid_out) + ); + + VX_elastic_buffer #( + .DATAW(`UUID_WIDTH + `NW_WIDTH + NUM_LANES + `XLEN + 1 + `NR_BITS + (`XLEN * NUM_LANES) + PID_WIDTH + 1 + 1) + ) output_buffer ( + .clk(clk), + .reset(reset), + .valid_in(commit_if_valid), + .ready_in(commit_if_ready), + .data_in({execute_if.data.uuid, execute_if.data.wid, stored_tmask, execute_if.data.PC, execute_if.data.wb, execute_if.data.rd, broadcasted_accumulator, stored_pid, stored_sop, stored_eop}), + + .data_out({commit_if.data.uuid, commit_if.data.wid, commit_if.data.tmask, commit_if.data.PC, commit_if.data.wb, commit_if.data.rd, commit_if.data.data, commit_if.data.pid, commit_if.data.sop, commit_if.data.eop}), + .ready_out(commit_if.ready), + .valid_out(commit_if.valid) + ); + +endmodule diff --git a/tests/kernel/Makefile b/tests/kernel/Makefile index ab4fdd07..f7c46754 100644 --- a/tests/kernel/Makefile +++ b/tests/kernel/Makefile @@ -1,19 +1,23 @@ all: $(MAKE) -C conform $(MAKE) -C hello - $(MAKE) -C fibonacci + $(MAKE) -C fibonacci + $(MAKE) -C reductions run-simx: $(MAKE) -C conform run-simx $(MAKE) -C hello run-simx $(MAKE) -C fibonacci run-simx + $(MAKE) -C reductions run-simx run-rtlsim: $(MAKE) -C conform run-rtlsim $(MAKE) -C hello run-rtlsim $(MAKE) -C fibonacci run-rtlsim + $(MAKE) -C reductions run-rtlsim clean: $(MAKE) -C conform clean $(MAKE) -C hello clean $(MAKE) -C fibonacci clean + $(MAKE) -C reductions clean diff --git a/tests/kernel/reductions/Makefile b/tests/kernel/reductions/Makefile new file mode 100644 index 00000000..76e96c46 --- /dev/null +++ b/tests/kernel/reductions/Makefile @@ -0,0 +1,5 @@ +PROJECT = reductions + +SRCS = main.cpp + +include ../common.mk diff --git a/tests/kernel/reductions/main.cpp b/tests/kernel/reductions/main.cpp new file mode 100644 index 00000000..edde1da4 --- /dev/null +++ b/tests/kernel/reductions/main.cpp @@ -0,0 +1,216 @@ +#define RISCV_CUSTOM2 0x5B +#define ADD_FUNC7 0b0000000 +#define ADDU_FUNC7 0b1000000 +#define MIN_FUNC7 0b0000001 +#define MINU_FUNC7 0b1000001 +#define MAX_FUNC7 0b0000010 +#define MAXU_FUNC7 0b1000010 +#define AND_FUNC7 0b0000011 +#define OR_FUNC7 0b0000100 +#define XOR_FUNC7 0b0000101 + +/* + 6'h0: begin + op_type = func7[6] ? `INST_RED_ADDU : `INST_RED_ADD; + end + 6'h1: begin + op_type = func7[6] ? `INST_RED_MINU : `INST_RED_MIN; + end + 6'h2: begin + op_type = func7[6] ? `INST_RED_MAXU : `INST_RED_MAX; + end + 6'h3: begin + op_type = `INST_RED_AND; + end + 6'h4: begin + op_type = `INST_RED_OR; + end + 6'h5: begin + op_type = `INST_RED_XOR; + end +*/ + +#include +#include +#include + +int x[4] = {3, 7, 2, 5}; +int y = -1; + +inline int vx_add_reduce(int v) { + int ret; + asm volatile (".insn r %2, 0, %3, %0, %1, x0" : "=r"(ret) : "r"(v), "i"(RISCV_CUSTOM2), "i"(ADD_FUNC7)); + return ret; +} + +inline int vx_min_reduce(int v) { + int ret; + asm volatile (".insn r %2, 0, %3, %0, %1, x0" : "=r"(ret) : "r"(v), "i"(RISCV_CUSTOM2), "i"(MIN_FUNC7)); + return ret; +} + +inline unsigned vx_minu_reduce(unsigned v) { + unsigned ret; + asm volatile (".insn r %2, 0, %3, %0, %1, x0" : "=r"(ret) : "r"(v), "i"(RISCV_CUSTOM2), "i"(MINU_FUNC7)); + return ret; +} + +inline int vx_max_reduce(int v) { + int ret; + asm volatile (".insn r %2, 0, %3, %0, %1, x0" : "=r"(ret) : "r"(v), "i"(RISCV_CUSTOM2), "i"(MAX_FUNC7)); + return ret; +} + +inline unsigned vx_maxu_reduce(unsigned v) { + unsigned ret; + asm volatile (".insn r %2, 0, %3, %0, %1, x0" : "=r"(ret) : "r"(v), "i"(RISCV_CUSTOM2), "i"(MAXU_FUNC7)); + return ret; +} + + +inline unsigned vx_and_reduce(unsigned v) { + unsigned ret; + asm volatile (".insn r %2, 0, %3, %0, %1, x0" : "=r"(ret) : "r"(v), "i"(RISCV_CUSTOM2), "i"(AND_FUNC7)); + return ret; +} + +inline unsigned vx_or_reduce(unsigned v) { + unsigned ret; + asm volatile (".insn r %2, 0, %3, %0, %1, x0" : "=r"(ret) : "r"(v), "i"(RISCV_CUSTOM2), "i"(OR_FUNC7)); + return ret; +} + +inline unsigned vx_xor_reduce(unsigned v) { + unsigned ret; + asm volatile (".insn r %2, 0, %3, %0, %1, x0" : "=r"(ret) : "r"(v), "i"(RISCV_CUSTOM2), "i"(XOR_FUNC7)); + return ret; +} + +void test_add_reduce() { + vx_tmc(-1); + int tid = vx_thread_id(); + int v = x[tid]; + int reduced = vx_add_reduce(v); + vx_tmc(1); + + y = reduced; +} + +unsigned unsigned_vector[4] = {(unsigned)-1, 0, (unsigned)-2, 5}; + +void test_min_reduce() { + vx_tmc(-1); + int tid = vx_thread_id(); + int v = unsigned_vector[tid]; + int reduced = vx_min_reduce(v); + vx_tmc(1); + + y = reduced; +} + +void test_max_reduce() { + vx_tmc(-1); + int tid = vx_thread_id(); + int v = unsigned_vector[tid]; + int reduced = vx_max_reduce(v); + vx_tmc(1); + + y = reduced; +} + +void test_minu_reduce() { + vx_tmc(-1); + int tid = vx_thread_id(); + unsigned v = unsigned_vector[tid]; + unsigned reduced = vx_minu_reduce(v); + vx_tmc(1); + + y = reduced; +} + +void test_maxu_reduce() { + vx_tmc(-1); + int tid = vx_thread_id(); + unsigned v = unsigned_vector[tid]; + unsigned reduced = vx_maxu_reduce(v); + vx_tmc(1); + + y = reduced; +} + +unsigned bit_vectors[4] = {0b11010110000111001100010100100110, 0b10010100011010001010000000001110, 0b10001001010111110001110000000010, 0b00010011010100101101110111001111}; + +void test_and_reduce() { + vx_tmc(-1); + int tid = vx_thread_id(); + unsigned v = bit_vectors[tid]; + unsigned reduced = vx_and_reduce(v); + vx_tmc(1); + + y = reduced; +} + +void test_or_reduce() { + vx_tmc(-1); + int tid = vx_thread_id(); + unsigned v = bit_vectors[tid]; + unsigned reduced = vx_or_reduce(v); + vx_tmc(1); + + y = reduced; +} + +void test_xor_reduce() { + vx_tmc(-1); + int tid = vx_thread_id(); + unsigned v = bit_vectors[tid]; + unsigned reduced = vx_xor_reduce(v); + vx_tmc(1); + + y = reduced; +} + +int main() +{ + int expected; + + test_add_reduce(); + vx_printf("add reduce result: %d\n", y); + vx_printf("expected: %d\n", x[0] + x[1] + x[2] + x[3]); + + test_min_reduce(); + vx_printf("min reduce result: %d\n", y); + expected = MIN((int)unsigned_vector[0], MIN((int)unsigned_vector[1], MIN((int)unsigned_vector[2], (int)unsigned_vector[3]))); + vx_printf("expected: %d\n", expected); + + test_max_reduce(); + vx_printf("max reduce result: %d\n", y); + expected = MAX((int)unsigned_vector[0], MAX((int)unsigned_vector[1], MAX((int)unsigned_vector[2], (int)unsigned_vector[3]))); + vx_printf("expected: %d\n", expected); + + test_minu_reduce(); + vx_printf("minu reduce result: %d\n", y); + expected = MIN(unsigned_vector[0], MIN(unsigned_vector[1], MIN(unsigned_vector[2], unsigned_vector[3]))); + vx_printf("expected: %d\n", expected); + + test_maxu_reduce(); + vx_printf("maxu reduce result: %d\n", y); + expected = MAX(unsigned_vector[0], MAX(unsigned_vector[1], MAX(unsigned_vector[2], unsigned_vector[3]))); + vx_printf("expected: %d\n", expected); + + test_and_reduce(); + vx_printf("and reduce result: %d\n", y); + vx_printf("expected: %d\n", bit_vectors[0] & bit_vectors[1] & bit_vectors[2] & bit_vectors[3]); + + + test_or_reduce(); + vx_printf("or reduce result: %d\n", y); + vx_printf("expected: %d\n", bit_vectors[0] | bit_vectors[1] | bit_vectors[2] | bit_vectors[3]); + + test_xor_reduce(); + vx_printf("xor reduce result: %d\n", y); + vx_printf("expected: %d\n", bit_vectors[0] ^ bit_vectors[1] ^ bit_vectors[2] ^ bit_vectors[3]); + + + return 0; +} \ No newline at end of file From 978dd3bdfecdf8c5c326025bb57ff1048e70ba78 Mon Sep 17 00:00:00 2001 From: joshua Date: Tue, 19 Mar 2024 17:56:59 -0700 Subject: [PATCH 02/55] seemingly working fp32 implementation --- hw/dpi/float_dpi.cpp | 76 + hw/dpi/float_dpi.vh | 2 + hw/dpi/half.hpp | 4018 ++++++++++++++++++++++++++++++++++ hw/rtl/VX_config.vh | 5 + hw/rtl/fpu/VX_tensor_core.sv | 0 hw/rtl/fpu/VX_tensor_dpu.sv | 35 + hw/rtl/fpu/VX_tensor_tb.sv | 28 + hw/unittest/tensor/Makefile | 89 + hw/unittest/tensor/main.cpp | 197 ++ 9 files changed, 4450 insertions(+) create mode 100644 hw/dpi/half.hpp create mode 100644 hw/rtl/fpu/VX_tensor_core.sv create mode 100644 hw/rtl/fpu/VX_tensor_dpu.sv create mode 100644 hw/rtl/fpu/VX_tensor_tb.sv create mode 100644 hw/unittest/tensor/Makefile create mode 100644 hw/unittest/tensor/main.cpp diff --git a/hw/dpi/float_dpi.cpp b/hw/dpi/float_dpi.cpp index 340b258d..d5209bed 100644 --- a/hw/dpi/float_dpi.cpp +++ b/hw/dpi/float_dpi.cpp @@ -23,6 +23,9 @@ #include "verilated_vpi.h" #include "VX_config.h" +#include +#include "half.hpp" + extern "C" { void dpi_fadd(bool enable, int dst_fmt, int64_t a, int64_t b, const svBitVecVal* frm, int64_t* result, svBitVecVal* fflags); void dpi_fsub(bool enable, int dst_fmt, int64_t a, int64_t b, const svBitVecVal* frm, int64_t* result, svBitVecVal* fflags); @@ -51,6 +54,8 @@ extern "C" { void dpi_feq(bool enable, int dst_fmt, int64_t a, int64_t b, int64_t* result, svBitVecVal* fflags); void dpi_fmin(bool enable, int dst_fmt, int64_t a, int64_t b, int64_t* result, svBitVecVal* fflags); void dpi_fmax(bool enable, int dst_fmt, int64_t a, int64_t b, int64_t* result, svBitVecVal* fflags); + + void dpi_hmma(bool enable, const svBitVecVal* A_tile, const svBitVecVal* B_tile, const svBitVecVal* C_tile, svBitVecVal* D_tile); } inline uint64_t nan_box(uint32_t value) { @@ -337,4 +342,75 @@ void dpi_fmax(bool enable, int dst_fmt, int64_t a, int64_t b, int64_t* result, s } else { *result = nan_box(rv_fmax_s(check_boxing(a), check_boxing(b), fflags)); } +} + +// A is M * K, B is K * K * M, C is M * M, D is M * M +#define M 4 +#define K 2 + +// all row major +float c_A_tile[M][K]; +float c_B_tile[K][M]; +float c_C_tile[M][M]; +float c_D_tile[M][M]; + +// code assumes that svBitVecVal is basically a uint32_t +static_assert(sizeof(svBitVecVal) == 4); + +void fill_float_array(const svBitVecVal* sv_tile, float* c_tile, int rows, int cols) { + + for (int i = 0; i < rows; i += 1) { + for (int j = 0; j < cols; j += 1) { + int index = i * cols + j; + svBitVecVal sv_val = sv_tile[index]; + + uint32_t c_val = sv_val; + float c_float; + + memcpy(&c_float, &c_val, sizeof(c_float)); + c_tile[index] = c_float; + + // std::cout << c_float << " "; + } + // std::cout << std::endl; + } +} + +void write_float_array(svBitVecVal* sv_tile, float* c_tile, int rows, int cols) { + for (int i = 0; i < rows; i += 1) { + for (int j = 0; j < cols; j += 1) { + int index = i * cols + j; + svBitVecVal* sv_val = &sv_tile[index]; + + float c_float = c_tile[index]; + memcpy(sv_val, &c_float, sizeof(c_float)); + + // std::cout << c_float << " "; + } + // std::cout << std::endl; + } +} + +void dpi_hmma(bool enable, const svBitVecVal* A_tile, const svBitVecVal* B_tile, const svBitVecVal* C_tile, svBitVecVal* D_tile) { + if (!enable) { + return; + } + // std::cout << "A: " << std::endl; + fill_float_array(A_tile, &c_A_tile[0][0], M, K); + // std::cout << "B: " << std::endl; + fill_float_array(B_tile, &c_B_tile[0][0], K, M); + // std::cout << "C: " << std::endl; + fill_float_array(C_tile, &c_C_tile[0][0], M, M); + + for (int i = 0; i < M; i += 1) { + for (int j = 0; j < M; j += 1) { + float accum = c_C_tile[i][j]; + for (int k = 0; k < K; k += 1) { + accum += c_A_tile[i][k] * c_B_tile[k][j]; + } + c_D_tile[i][j] = accum; + } + } + + write_float_array(D_tile, &c_D_tile[0][0], M, M); } \ No newline at end of file diff --git a/hw/dpi/float_dpi.vh b/hw/dpi/float_dpi.vh index 13580765..c8e7c9cb 100644 --- a/hw/dpi/float_dpi.vh +++ b/hw/dpi/float_dpi.vh @@ -44,4 +44,6 @@ import "DPI-C" function void dpi_feq(input logic enable, input int dst_fmt, inpu import "DPI-C" function void dpi_fmin(input logic enable, input int dst_fmt, input longint a, input longint b, output longint result, output bit[4:0] fflags); import "DPI-C" function void dpi_fmax(input logic enable, input int dst_fmt, input longint a, input longint b, output longint result, output bit[4:0] fflags); +import "DPI-C" function void dpi_hmma(input logic enable, input bit[3:0][1:0][31:0] A_tile, input bit[1:0][3:0][31:0] B_tile, input bit[3:0][3:0][31:0] C_tile, output bit[3:0][3:0][31:0] D_tile); + `endif diff --git a/hw/dpi/half.hpp b/hw/dpi/half.hpp new file mode 100644 index 00000000..18596f97 --- /dev/null +++ b/hw/dpi/half.hpp @@ -0,0 +1,4018 @@ +// half - IEEE 754-based half-precision floating-point library. +// +// Copyright (c) 2012-2019 Christian Rau +// Copyright (c) 2020 0xBYTESHIFT +// +// Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation +// files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, +// modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the +// Software is furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE +// WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR +// COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, +// ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + +/// \file +/// Main header file for half-precision functionality. + +#pragma once + +#define HALF_TWOS_COMPLEMENT_INT 1 + +// any error throwing C++ exceptions? +#if defined(HALF_ERRHANDLING_THROW_INVALID) || defined(HALF_ERRHANDLING_THROW_DIVBYZERO) || defined(HALF_ERRHANDLING_THROW_OVERFLOW) || defined(HALF_ERRHANDLING_THROW_UNDERFLOW) || defined(HALF_ERRHANDLING_THROW_INEXACT) +#define HALF_ERRHANDLING_THROWS 1 +#endif + +// any error handling enabled? +#define HALF_ERRHANDLING (HALF_ERRHANDLING_FLAGS||HALF_ERRHANDLING_ERRNO||HALF_ERRHANDLING_FENV||HALF_ERRHANDLING_THROWS) + +#if HALF_ERRHANDLING + #define HALF_UNUSED_NOERR(name) name +#else + #define HALF_UNUSED_NOERR(name) +#endif + +// support constexpr +#if HALF_ERRHANDLING + #define constexpr_NOERR +#else + #define constexpr_NOERR constexpr +#endif + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#if HALF_ERRHANDLING_ERRNO + #include +#endif +#include +#include + +#ifndef HALF_ENABLE_F16C_INTRINSICS + /// Enable F16C intruction set intrinsics. + /// Defining this to 1 enables the use of [F16C compiler intrinsics](https://en.wikipedia.org/wiki/F16C) for converting between + /// half-precision and single-precision values which may result in improved performance. This will not perform additional checks + /// for support of the F16C instruction set, so an appropriate target platform is required when enabling this feature. + /// + /// Unless predefined it will be enabled automatically when the `__F16C__` symbol is defined, which some compilers do on supporting platforms. + #define HALF_ENABLE_F16C_INTRINSICS __F16C__ +#endif + +#if HALF_ENABLE_F16C_INTRINSICS + #include +#endif + +#ifndef HALF_ERRHANDLING_OVERFLOW_TO_INEXACT +/// Raise INEXACT exception on overflow. +/// Defining this to 1 (default) causes overflow errors to automatically raise inexact exceptions in addition. +/// These will be raised after any possible handling of the underflow exception. +#define HALF_ERRHANDLING_OVERFLOW_TO_INEXACT 1 +#endif + +#ifndef HALF_ERRHANDLING_UNDERFLOW_TO_INEXACT +/// Raise INEXACT exception on underflow. +/// Defining this to 1 (default) causes underflow errors to automatically raise inexact exceptions in addition. +/// These will be raised after any possible handling of the underflow exception. +/// +/// **Note:** This will actually cause underflow (and the accompanying inexact) exceptions to be raised *only* when the result +/// is inexact, while if disabled bare underflow errors will be raised for *any* (possibly exact) subnormal result. +#define HALF_ERRHANDLING_UNDERFLOW_TO_INEXACT 1 +#endif + +/// Default rounding mode. +/// This specifies the rounding mode used for all conversions between [half](\ref half_float::half)s and more precise types +/// (unless using half_cast() and specifying the rounding mode directly) as well as in arithmetic operations and mathematical +/// functions. It can be redefined (before including half.hpp) to one of the standard rounding modes using their respective +/// constants or the equivalent values of +/// [std::float_round_style](https://en.cppreference.com/w/cpp/types/numeric_limits/float_round_style): +/// +/// `std::float_round_style` | value | rounding +/// ---------------------------------|-------|------------------------- +/// `std::round_indeterminate` | -1 | fastest +/// `std::round_toward_zero` | 0 | toward zero +/// `std::round_to_nearest` | 1 | to nearest (default) +/// `std::round_toward_infinity` | 2 | toward positive infinity +/// `std::round_toward_neg_infinity` | 3 | toward negative infinity +/// +/// By default this is set to `1` (`std::round_to_nearest`), which rounds results to the nearest representable value. It can even +/// be set to [std::numeric_limits::round_style](https://en.cppreference.com/w/cpp/types/numeric_limits/round_style) to synchronize +/// the rounding mode with that of the built-in single-precision implementation (which is likely `std::round_to_nearest`, though). +#ifndef HALF_ROUND_STYLE + #define HALF_ROUND_STYLE 1 // = std::round_to_nearest +#endif + +/// Value signaling overflow. +/// In correspondence with `HUGE_VAL[F|L]` from `` this symbol expands to a positive value signaling the overflow of an +/// operation, in particular it just evaluates to positive infinity. +/// +/// **See also:** Documentation for [HUGE_VAL](https://en.cppreference.com/w/cpp/numeric/math/HUGE_VAL) +#define HUGE_VALH std::numeric_limits::infinity() + +/// Fast half-precision fma function. +/// This symbol is defined if the fma() function generally executes as fast as, or faster than, a separate +/// half-precision multiplication followed by an addition, which is always the case. +/// +/// **See also:** Documentation for [FP_FAST_FMA](https://en.cppreference.com/w/cpp/numeric/math/fma) +#define FP_FAST_FMAH 1 + +/// Half rounding mode. +/// In correspondence with `FLT_ROUNDS` from `` this symbol expands to the rounding mode used for +/// half-precision operations. It is an alias for [HALF_ROUND_STYLE](\ref HALF_ROUND_STYLE). +/// +/// **See also:** Documentation for [FLT_ROUNDS](https://en.cppreference.com/w/cpp/types/climits/FLT_ROUNDS) +#define HLF_ROUNDS HALF_ROUND_STYLE + +#ifndef FP_ILOGB0 + #define FP_ILOGB0 INT_MIN +#endif +#ifndef FP_ILOGBNAN + #define FP_ILOGBNAN INT_MAX +#endif +#ifndef FP_SUBNORMAL + #define FP_SUBNORMAL 0 +#endif +#ifndef FP_ZERO + #define FP_ZERO 1 +#endif +#ifndef FP_NAN + #define FP_NAN 2 +#endif +#ifndef FP_INFINITE + #define FP_INFINITE 3 +#endif +#ifndef FP_NORMAL + #define FP_NORMAL 4 +#endif + +#if !defined(FE_ALL_EXCEPT) + #define FE_INVALID 0x10 + #define FE_DIVBYZERO 0x08 + #define FE_OVERFLOW 0x04 + #define FE_UNDERFLOW 0x02 + #define FE_INEXACT 0x01 + #define FE_ALL_EXCEPT (FE_INVALID|FE_DIVBYZERO|FE_OVERFLOW|FE_UNDERFLOW|FE_INEXACT) +#endif + + +/// Main namespace for half-precision functionality. +/// This namespace contains all the functionality provided by the library. +namespace half_float { + class half; + + /// Library-defined half-precision literals. + /// Import this namespace to enable half-precision floating-point literals: + /// ~~~~{.cpp} + /// using namespace half_float::literal; + /// half_float::half = 4.2_h; + /// ~~~~ + namespace literal { + half operator "" _h(long double); + } + + /// \internal + /// \brief Implementation details. + namespace detail { + /// Conditional type. + template struct conditional : std::conditional {}; + + /// Helper for tag dispatching. + template struct bool_type : std::integral_constant {}; + using std::true_type; + using std::false_type; + + /// Type traits for floating-point types. + template struct is_float : std::is_floating_point {}; + + /// Type traits for floating-point bits. + template struct bits { using type = unsigned char; }; + template struct bits : bits {}; + template struct bits : bits {}; + template struct bits : bits {}; + + /// Unsigned integer of (at least) 16 bits width. + using uint16 = std::uint_least16_t; + + /// Fastest unsigned integer of (at least) 32 bits width. + using uint32 = std::uint_fast32_t; + + /// Fastest signed integer of (at least) 32 bits width. + using int32 = std::int_fast32_t; + + /// Unsigned integer of (at least) 32 bits width. + template<> struct bits { using type = std::uint_least32_t; }; + + /// Unsigned integer of (at least) 64 bits width. + template<> struct bits { using type = std::uint_least64_t; }; + template using bits_t = typename bits::type; + + #ifdef HALF_ARITHMETIC_TYPE + /// Type to use for arithmetic computations and mathematic functions internally. + typedef HALF_ARITHMETIC_TYPE internal_t; + #endif + + /// Tag type for binary construction. + struct binary_t {}; + + /// Tag for binary construction. + constexpr binary_t binary = binary_t(); + + /// \name Implementation defined classification and arithmetic + /// \{ + + /// Check for infinity. + /// \tparam T argument type (builtin floating-point type) + /// \param arg value to query + /// \retval true if infinity + /// \retval false else + template bool builtin_isinf(T arg) { return std::isinf(arg); } + + /// Check for NaN. + /// \tparam T argument type (builtin floating-point type) + /// \param arg value to query + /// \retval true if not a number + /// \retval false else + template bool builtin_isnan(T arg) { return std::isnan(arg); } + + /// Check sign. + /// \tparam T argument type (builtin floating-point type) + /// \param arg value to query + /// \retval true if signbit set + /// \retval false else + template bool builtin_signbit(T arg) { return std::signbit(arg); } + + /// Platform-independent sign mask. + /// \param arg integer value in two's complement + /// \retval -1 if \a arg negative + /// \retval 0 if \a arg positive + inline uint32 sign_mask(uint32 arg) { + static const int N = std::numeric_limits::digits - 1; + #if HALF_TWOS_COMPLEMENT_INT + return static_cast(arg) >> N; + #else + return -((arg>>N)&1); + #endif + } + + /// Platform-independent arithmetic right shift. + /// \param arg integer value in two's complement + /// \param i shift amount (at most 31) + /// \return \a arg right shifted for \a i bits with possible sign extension + inline uint32 arithmetic_shift(uint32 arg, int i) { + #if HALF_TWOS_COMPLEMENT_INT + return static_cast(arg) >> i; + #else + return static_cast(arg)/(static_cast(1)<>(std::numeric_limits::digits-1))&1); + #endif + } + + /// \} + /// \name Error handling + /// \{ + + /// Internal exception flags. + /// \return reference to global exception flags + inline int& errflags() { thread_local int flags = 0; return flags; } + + /// Raise floating-point exception. + /// \param flags exceptions to raise + /// \param cond condition to raise exceptions for + inline void raise(int HALF_UNUSED_NOERR(flags), bool HALF_UNUSED_NOERR(cond) = true) { + #if HALF_ERRHANDLING + if(!cond) + return; + #if HALF_ERRHANDLING_FLAGS + errflags() |= flags; + #endif + #if HALF_ERRHANDLING_ERRNO + if(flags & FE_INVALID) + errno = EDOM; + else if(flags & (FE_DIVBYZERO|FE_OVERFLOW|FE_UNDERFLOW)) + errno = ERANGE; + #endif + #if HALF_ERRHANDLING_FENV + std::feraiseexcept(flags); + #endif + #ifdef HALF_ERRHANDLING_THROW_INVALID + if(flags & FE_INVALID) + throw std::domain_error(HALF_ERRHANDLING_THROW_INVALID); + #endif + #ifdef HALF_ERRHANDLING_THROW_DIVBYZERO + if(flags & FE_DIVBYZERO) + throw std::domain_error(HALF_ERRHANDLING_THROW_DIVBYZERO); + #endif + #ifdef HALF_ERRHANDLING_THROW_OVERFLOW + if(flags & FE_OVERFLOW) + throw std::overflow_error(HALF_ERRHANDLING_THROW_OVERFLOW); + #endif + #ifdef HALF_ERRHANDLING_THROW_UNDERFLOW + if(flags & FE_UNDERFLOW) + throw std::underflow_error(HALF_ERRHANDLING_THROW_UNDERFLOW); + #endif + #ifdef HALF_ERRHANDLING_THROW_INEXACT + if(flags & FE_INEXACT) + throw std::range_error(HALF_ERRHANDLING_THROW_INEXACT); + #endif + #if HALF_ERRHANDLING_UNDERFLOW_TO_INEXACT + if((flags & FE_UNDERFLOW) && !(flags & FE_INEXACT)) + raise(FE_INEXACT); + #endif + #if HALF_ERRHANDLING_OVERFLOW_TO_INEXACT + if((flags & FE_OVERFLOW) && !(flags & FE_INEXACT)) + raise(FE_INEXACT); + #endif + #endif + } + + /// Check and signal for any NaN. + /// \param x first half-precision value to check + /// \param y second half-precision value to check + /// \retval true if either \a x or \a y is NaN + /// \retval false else + /// \exception FE_INVALID if \a x or \a y is NaN + inline constexpr_NOERR bool compsignal(unsigned int x, unsigned int y) { + #if HALF_ERRHANDLING + raise(FE_INVALID, (x&0x7FFF)>0x7C00 || (y&0x7FFF)>0x7C00); + #endif + return (x&0x7FFF) > 0x7C00 || (y&0x7FFF) > 0x7C00; + } + + /// Signal and silence signaling NaN. + /// \param nan half-precision NaN value + /// \return quiet NaN + /// \exception FE_INVALID if \a nan is signaling NaN + inline constexpr_NOERR unsigned int signal(unsigned int nan) { + #if HALF_ERRHANDLING + raise(FE_INVALID, !(nan&0x200)); + #endif + return nan | 0x200; + } + + /// Signal and silence signaling NaNs. + /// \param x first half-precision value to check + /// \param y second half-precision value to check + /// \return quiet NaN + /// \exception FE_INVALID if \a x or \a y is signaling NaN + inline constexpr_NOERR unsigned int signal(unsigned int x, unsigned int y) { + #if HALF_ERRHANDLING + raise(FE_INVALID, ((x&0x7FFF)>0x7C00 && !(x&0x200)) || ((y&0x7FFF)>0x7C00 && !(y&0x200))); + #endif + return ((x&0x7FFF)>0x7C00) ? (x|0x200) : (y|0x200); + } + + /// Signal and silence signaling NaNs. + /// \param x first half-precision value to check + /// \param y second half-precision value to check + /// \param z third half-precision value to check + /// \return quiet NaN + /// \exception FE_INVALID if \a x, \a y or \a z is signaling NaN + inline constexpr_NOERR unsigned int signal(unsigned int x, unsigned int y, unsigned int z) { + #if HALF_ERRHANDLING + raise(FE_INVALID, ((x&0x7FFF)>0x7C00 && !(x&0x200)) || ((y&0x7FFF)>0x7C00 && !(y&0x200)) || ((z&0x7FFF)>0x7C00 && !(z&0x200))); + #endif + return ((x&0x7FFF)>0x7C00) ? (x|0x200) : ((y&0x7FFF)>0x7C00) ? (y|0x200) : (z|0x200); + } + + /// Select value or signaling NaN. + /// \param x preferred half-precision value + /// \param y ignored half-precision value except for signaling NaN + /// \return \a y if signaling NaN, \a x otherwise + /// \exception FE_INVALID if \a y is signaling NaN + inline constexpr_NOERR unsigned int select(unsigned int x, unsigned int HALF_UNUSED_NOERR(y)) { + #if HALF_ERRHANDLING + return (((y&0x7FFF)>0x7C00) && !(y&0x200)) ? signal(y) : x; + #else + return x; + #endif + } + + /// Raise domain error and return NaN. + /// return quiet NaN + /// \exception FE_INVALID + inline constexpr_NOERR unsigned int invalid() { + #if HALF_ERRHANDLING + raise(FE_INVALID); + #endif + return 0x7FFF; + } + + /// Raise pole error and return infinity. + /// \param sign half-precision value with sign bit only + /// \return half-precision infinity with sign of \a sign + /// \exception FE_DIVBYZERO + inline constexpr_NOERR unsigned int pole(unsigned int sign = 0) { + #if HALF_ERRHANDLING + raise(FE_DIVBYZERO); + #endif + return sign | 0x7C00; + } + + /// Check value for underflow. + /// \param arg non-zero half-precision value to check + /// \return \a arg + /// \exception FE_UNDERFLOW if arg is subnormal + inline constexpr_NOERR unsigned int check_underflow(unsigned int arg) { + #if HALF_ERRHANDLING && !HALF_ERRHANDLING_UNDERFLOW_TO_INEXACT + raise(FE_UNDERFLOW, !(arg&0x7C00)); + #endif + return arg; + } + + /// \} + /// \name Conversion and rounding + /// \{ + + /// Half-precision overflow. + /// \tparam R rounding mode to use + /// \param sign half-precision value with sign bit only + /// \return rounded overflowing half-precision value + /// \exception FE_OVERFLOW + template constexpr_NOERR unsigned int overflow(unsigned int sign = 0) { + #if HALF_ERRHANDLING + raise(FE_OVERFLOW); + #endif + return (R==std::round_toward_infinity) ? (sign+0x7C00-(sign>>15)) : + (R==std::round_toward_neg_infinity) ? (sign+0x7BFF+(sign>>15)) : + (R==std::round_toward_zero) ? (sign|0x7BFF) : + (sign|0x7C00); + } + + /// Half-precision underflow. + /// \tparam R rounding mode to use + /// \param sign half-precision value with sign bit only + /// \return rounded underflowing half-precision value + /// \exception FE_UNDERFLOW + template constexpr_NOERR unsigned int underflow(unsigned int sign = 0) { + #if HALF_ERRHANDLING + raise(FE_UNDERFLOW); + #endif + return (R==std::round_toward_infinity) ? (sign+1-(sign>>15)) : + (R==std::round_toward_neg_infinity) ? (sign+(sign>>15)) : + sign; + } + + /// Round half-precision number. + /// \tparam R rounding mode to use + /// \tparam I `true` to always raise INEXACT exception, `false` to raise only for rounded results + /// \param value finite half-precision number to round + /// \param g guard bit (most significant discarded bit) + /// \param s sticky bit (or of all but the most significant discarded bits) + /// \return rounded half-precision value + /// \exception FE_OVERFLOW on overflows + /// \exception FE_UNDERFLOW on underflows + /// \exception FE_INEXACT if value had to be rounded or \a I is `true` + template constexpr_NOERR unsigned int rounded(unsigned int value, int g, int s) { + #if HALF_ERRHANDLING + value += (R==std::round_to_nearest) ? (g&(s|value)) : + (R==std::round_toward_infinity) ? (~(value>>15)&(g|s)) : + (R==std::round_toward_neg_infinity) ? ((value>>15)&(g|s)) : 0; + if((value&0x7C00) == 0x7C00) + raise(FE_OVERFLOW); + else if(value & 0x7C00) + raise(FE_INEXACT, I || (g|s)!=0); + else + raise(FE_UNDERFLOW, !(HALF_ERRHANDLING_UNDERFLOW_TO_INEXACT) || I || (g|s)!=0); + return value; + #else + return (R==std::round_to_nearest) ? (value+(g&(s|value))) : + (R==std::round_toward_infinity) ? (value+(~(value>>15)&(g|s))) : + (R==std::round_toward_neg_infinity) ? (value+((value>>15)&(g|s))) : + value; + #endif + } + + /// Round half-precision number to nearest integer value. + /// \tparam R rounding mode to use + /// \tparam E `true` for round to even, `false` for round away from zero + /// \tparam I `true` to raise INEXACT exception (if inexact), `false` to never raise it + /// \param value half-precision value to round + /// \return half-precision bits for nearest integral value + /// \exception FE_INVALID for signaling NaN + /// \exception FE_INEXACT if value had to be rounded and \a I is `true` + template unsigned int integral(unsigned int value) { + unsigned int abs = value & 0x7FFF; + if(abs < 0x3C00) { + raise(FE_INEXACT, I); + return ((R==std::round_to_nearest) ? (0x3C00&-static_cast(abs>=(0x3800+E))) : + (R==std::round_toward_infinity) ? (0x3C00&-(~(value>>15)&(abs!=0))) : + (R==std::round_toward_neg_infinity) ? (0x3C00&-static_cast(value>0x8000)) : + 0) | (value&0x8000); + } + if(abs >= 0x6400) + return (abs>0x7C00) ? signal(value) : value; + unsigned int exp = 25 - (abs>>10), mask = (1<>exp)&E)) : + (R==std::round_toward_infinity) ? (mask&((value>>15)-1)) : + (R==std::round_toward_neg_infinity) ? (mask&-(value>>15)) : + 0) + value) & ~mask; + } + + /// Convert fixed point to half-precision floating-point. + /// \tparam R rounding mode to use + /// \tparam F number of fractional bits (at least 11) + /// \tparam S `true` for signed, `false` for unsigned + /// \tparam N `true` for additional normalization step, `false` if already normalized to 1.F + /// \tparam I `true` to always raise INEXACT exception, `false` to raise only for rounded results + /// \param m mantissa in Q1.F fixed point format + /// \param exp exponent + /// \param sign half-precision value with sign bit only + /// \param s sticky bit (or of all but the most significant already discarded bits) + /// \return value converted to half-precision + /// \exception FE_OVERFLOW on overflows + /// \exception FE_UNDERFLOW on underflows + /// \exception FE_INEXACT if value had to be rounded or \a I is `true` + template unsigned int fixed2half(uint32 m, int exp = 14, unsigned int sign = 0, int s = 0) { + if(S) { + uint32 msign = sign_mask(m); + m = (m^msign) - msign; + sign = msign & 0x8000; + } + if(N) + for(; m<(static_cast(1)<(sign+(m>>(F-10-exp)), (m>>(F-11-exp))&1, s|((m&((static_cast(1)<<(F-11-exp))-1))!=0)); + return rounded(sign+(exp<<10)+(m>>(F-10)), (m>>(F-11))&1, s|((m&((static_cast(1)<<(F-11))-1))!=0)); + } + + /// Convert IEEE single-precision to half-precision. + /// Credit for this goes to [Jeroen van der Zijp](ftp://ftp.fox-toolkit.org/pub/fasthalffloatconversion.pdf). + /// \tparam R rounding mode to use + /// \param value single-precision value to convert + /// \return rounded half-precision value + /// \exception FE_OVERFLOW on overflows + /// \exception FE_UNDERFLOW on underflows + /// \exception FE_INEXACT if value had to be rounded + template unsigned int float2half_impl(float value, true_type) { + #if HALF_ENABLE_F16C_INTRINSICS + return _mm_cvtsi128_si32(_mm_cvtps_ph(_mm_set_ss(value), + (R==std::round_to_nearest) ? _MM_FROUND_TO_NEAREST_INT : + (R==std::round_toward_zero) ? _MM_FROUND_TO_ZERO : + (R==std::round_toward_infinity) ? _MM_FROUND_TO_POS_INF : + (R==std::round_toward_neg_infinity) ? _MM_FROUND_TO_NEG_INF : + _MM_FROUND_CUR_DIRECTION)); + #else + bits_t fbits; + std::memcpy(&fbits, &value, sizeof(float)); + #if 1 + unsigned int sign = (fbits>>16) & 0x8000; + fbits &= 0x7FFFFFFF; + if(fbits >= 0x7F800000) + return sign | 0x7C00 | ((fbits>0x7F800000) ? (0x200|((fbits>>13)&0x3FF)) : 0); + if(fbits >= 0x47800000) + return overflow(sign); + if(fbits >= 0x38800000) + return rounded(sign|(((fbits>>23)-112)<<10)|((fbits>>13)&0x3FF), (fbits>>12)&1, (fbits&0xFFF)!=0); + if(fbits >= 0x33000000) + { + int i = 125 - (fbits>>23); + fbits = (fbits&0x7FFFFF) | 0x800000; + return rounded(sign|(fbits>>(i+1)), (fbits>>i)&1, (fbits&((static_cast(1)<(sign); + return sign; + #else + static const uint16 base_table[512] = { + 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, + 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, + 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, + 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, + 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, + 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, + 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0001, 0x0002, 0x0004, 0x0008, 0x0010, 0x0020, 0x0040, 0x0080, 0x0100, + 0x0200, 0x0400, 0x0800, 0x0C00, 0x1000, 0x1400, 0x1800, 0x1C00, 0x2000, 0x2400, 0x2800, 0x2C00, 0x3000, 0x3400, 0x3800, 0x3C00, + 0x4000, 0x4400, 0x4800, 0x4C00, 0x5000, 0x5400, 0x5800, 0x5C00, 0x6000, 0x6400, 0x6800, 0x6C00, 0x7000, 0x7400, 0x7800, 0x7BFF, + 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, + 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, + 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, + 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, + 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, + 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, + 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7C00, + 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, + 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, + 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, + 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, + 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, + 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, + 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8001, 0x8002, 0x8004, 0x8008, 0x8010, 0x8020, 0x8040, 0x8080, 0x8100, + 0x8200, 0x8400, 0x8800, 0x8C00, 0x9000, 0x9400, 0x9800, 0x9C00, 0xA000, 0xA400, 0xA800, 0xAC00, 0xB000, 0xB400, 0xB800, 0xBC00, + 0xC000, 0xC400, 0xC800, 0xCC00, 0xD000, 0xD400, 0xD800, 0xDC00, 0xE000, 0xE400, 0xE800, 0xEC00, 0xF000, 0xF400, 0xF800, 0xFBFF, + 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, + 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, + 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, + 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, + 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, + 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, + 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFC00 }; + static const unsigned char shift_table[256] = { + 24, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, + 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, + 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, + 25, 25, 25, 25, 25, 25, 24, 23, 22, 21, 20, 19, 18, 17, 16, 15, 14, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, + 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, + 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, + 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, + 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 13 }; + int sexp = fbits >> 23, exp = sexp & 0xFF, i = shift_table[exp]; + fbits &= 0x7FFFFF; + uint32 m = (fbits|((exp!=0)<<23)) & -static_cast(exp!=0xFF); + return rounded(base_table[sexp]+(fbits>>i), (m>>(i-1))&1, (((static_cast(1)<<(i-1))-1)&m)!=0); + #endif + #endif + } + + /// Convert IEEE double-precision to half-precision. + /// \tparam R rounding mode to use + /// \param value double-precision value to convert + /// \return rounded half-precision value + /// \exception FE_OVERFLOW on overflows + /// \exception FE_UNDERFLOW on underflows + /// \exception FE_INEXACT if value had to be rounded + template unsigned int float2half_impl(double value, true_type) { + #if HALF_ENABLE_F16C_INTRINSICS + if(R == std::round_indeterminate) + return _mm_cvtsi128_si32(_mm_cvtps_ph(_mm_cvtpd_ps(_mm_set_sd(value)), _MM_FROUND_CUR_DIRECTION)); + #endif + bits_t dbits; + std::memcpy(&dbits, &value, sizeof(double)); + uint32 hi = dbits >> 32, lo = dbits & 0xFFFFFFFF; + unsigned int sign = (hi>>16) & 0x8000; + hi &= 0x7FFFFFFF; + if(hi >= 0x7FF00000) + return sign | 0x7C00 | ((dbits&0xFFFFFFFFFFFFF) ? (0x200|((hi>>10)&0x3FF)) : 0); + if(hi >= 0x40F00000) + return overflow(sign); + if(hi >= 0x3F100000) + return rounded(sign|(((hi>>20)-1008)<<10)|((hi>>10)&0x3FF), (hi>>9)&1, ((hi&0x1FF)|lo)!=0); + if(hi >= 0x3E600000) { + int i = 1018 - (hi>>20); + hi = (hi&0xFFFFF) | 0x100000; + return rounded(sign|(hi>>(i+1)), (hi>>i)&1, ((hi&((static_cast(1)<(sign); + return sign; + } + + /// Convert non-IEEE floating-point to half-precision. + /// \tparam R rounding mode to use + /// \tparam T source type (builtin floating-point type) + /// \param value floating-point value to convert + /// \return rounded half-precision value + /// \exception FE_OVERFLOW on overflows + /// \exception FE_UNDERFLOW on underflows + /// \exception FE_INEXACT if value had to be rounded + template unsigned int float2half_impl(T value, ...) { + unsigned int hbits = static_cast(builtin_signbit(value)) << 15; + if(value == T()) + return hbits; + if(builtin_isnan(value)) + return hbits | 0x7FFF; + if(builtin_isinf(value)) + return hbits | 0x7C00; + int exp; + std::frexp(value, &exp); + if(exp > 16) + return overflow(hbits); + if(exp < -13) + value = std::ldexp(value, 25); + else { + value = std::ldexp(value, 12-exp); + hbits |= ((exp+13)<<10); + } + T ival, frac = std::modf(value, &ival); + int m = std::abs(static_cast(ival)); + return rounded(hbits+(m>>1), m&1, frac!=T()); + } + + /// Convert floating-point to half-precision. + /// \tparam R rounding mode to use + /// \tparam T source type (builtin floating-point type) + /// \param value floating-point value to convert + /// \return rounded half-precision value + /// \exception FE_OVERFLOW on overflows + /// \exception FE_UNDERFLOW on underflows + /// \exception FE_INEXACT if value had to be rounded + template unsigned int float2half(T value) { + return float2half_impl(value, bool_type::is_iec559&&sizeof(bits_t)==sizeof(T)>()); + } + template unsigned int float2half(T value) { + return float2half_impl<(std::float_round_style)(HALF_ROUND_STYLE)>(value, bool_type::is_iec559&&sizeof(bits_t)==sizeof(T)>()); + } + + /// Convert integer to half-precision floating-point. + /// \tparam R rounding mode to use + /// \tparam T type to convert (builtin integer type) + /// \param value integral value to convert + /// \return rounded half-precision value + /// \exception FE_OVERFLOW on overflows + /// \exception FE_INEXACT if value had to be rounded + template unsigned int int2half(T value) { + unsigned int bits = static_cast(value<0) << 15; + if(!value) + return bits; + if(bits) + value = -value; + if(value > 0xFFFF) + return overflow(bits); + unsigned int m = static_cast(value), exp = 24; + for(; m<0x400; m<<=1,--exp) ; + for(; m>0x7FF; m>>=1,++exp) ; + bits |= (exp<<10) + m; + return (exp>24) ? rounded(bits, (value>>(exp-25))&1, (((1<<(exp-25))-1)&value)!=0) : bits; + } + + /// Convert half-precision to IEEE single-precision. + /// Credit for this goes to [Jeroen van der Zijp](ftp://ftp.fox-toolkit.org/pub/fasthalffloatconversion.pdf). + /// \param value half-precision value to convert + /// \return single-precision value + inline float half2float_impl(unsigned int value, float, true_type) { + #if HALF_ENABLE_F16C_INTRINSICS + return _mm_cvtss_f32(_mm_cvtph_ps(_mm_cvtsi32_si128(value))); + #else + #if 0 + bits_t fbits = static_cast>(value&0x8000) << 16; + int abs = value & 0x7FFF; + if(abs) + { + fbits |= 0x38000000 << static_cast(abs>=0x7C00); + for(; abs<0x400; abs<<=1,fbits-=0x800000) ; + fbits += static_cast>(abs) << 13; + } + #else + static const bits_t mantissa_table[2048] = { + 0x00000000, 0x33800000, 0x34000000, 0x34400000, 0x34800000, 0x34A00000, 0x34C00000, 0x34E00000, 0x35000000, 0x35100000, 0x35200000, 0x35300000, 0x35400000, 0x35500000, 0x35600000, 0x35700000, + 0x35800000, 0x35880000, 0x35900000, 0x35980000, 0x35A00000, 0x35A80000, 0x35B00000, 0x35B80000, 0x35C00000, 0x35C80000, 0x35D00000, 0x35D80000, 0x35E00000, 0x35E80000, 0x35F00000, 0x35F80000, + 0x36000000, 0x36040000, 0x36080000, 0x360C0000, 0x36100000, 0x36140000, 0x36180000, 0x361C0000, 0x36200000, 0x36240000, 0x36280000, 0x362C0000, 0x36300000, 0x36340000, 0x36380000, 0x363C0000, + 0x36400000, 0x36440000, 0x36480000, 0x364C0000, 0x36500000, 0x36540000, 0x36580000, 0x365C0000, 0x36600000, 0x36640000, 0x36680000, 0x366C0000, 0x36700000, 0x36740000, 0x36780000, 0x367C0000, + 0x36800000, 0x36820000, 0x36840000, 0x36860000, 0x36880000, 0x368A0000, 0x368C0000, 0x368E0000, 0x36900000, 0x36920000, 0x36940000, 0x36960000, 0x36980000, 0x369A0000, 0x369C0000, 0x369E0000, + 0x36A00000, 0x36A20000, 0x36A40000, 0x36A60000, 0x36A80000, 0x36AA0000, 0x36AC0000, 0x36AE0000, 0x36B00000, 0x36B20000, 0x36B40000, 0x36B60000, 0x36B80000, 0x36BA0000, 0x36BC0000, 0x36BE0000, + 0x36C00000, 0x36C20000, 0x36C40000, 0x36C60000, 0x36C80000, 0x36CA0000, 0x36CC0000, 0x36CE0000, 0x36D00000, 0x36D20000, 0x36D40000, 0x36D60000, 0x36D80000, 0x36DA0000, 0x36DC0000, 0x36DE0000, + 0x36E00000, 0x36E20000, 0x36E40000, 0x36E60000, 0x36E80000, 0x36EA0000, 0x36EC0000, 0x36EE0000, 0x36F00000, 0x36F20000, 0x36F40000, 0x36F60000, 0x36F80000, 0x36FA0000, 0x36FC0000, 0x36FE0000, + 0x37000000, 0x37010000, 0x37020000, 0x37030000, 0x37040000, 0x37050000, 0x37060000, 0x37070000, 0x37080000, 0x37090000, 0x370A0000, 0x370B0000, 0x370C0000, 0x370D0000, 0x370E0000, 0x370F0000, + 0x37100000, 0x37110000, 0x37120000, 0x37130000, 0x37140000, 0x37150000, 0x37160000, 0x37170000, 0x37180000, 0x37190000, 0x371A0000, 0x371B0000, 0x371C0000, 0x371D0000, 0x371E0000, 0x371F0000, + 0x37200000, 0x37210000, 0x37220000, 0x37230000, 0x37240000, 0x37250000, 0x37260000, 0x37270000, 0x37280000, 0x37290000, 0x372A0000, 0x372B0000, 0x372C0000, 0x372D0000, 0x372E0000, 0x372F0000, + 0x37300000, 0x37310000, 0x37320000, 0x37330000, 0x37340000, 0x37350000, 0x37360000, 0x37370000, 0x37380000, 0x37390000, 0x373A0000, 0x373B0000, 0x373C0000, 0x373D0000, 0x373E0000, 0x373F0000, + 0x37400000, 0x37410000, 0x37420000, 0x37430000, 0x37440000, 0x37450000, 0x37460000, 0x37470000, 0x37480000, 0x37490000, 0x374A0000, 0x374B0000, 0x374C0000, 0x374D0000, 0x374E0000, 0x374F0000, + 0x37500000, 0x37510000, 0x37520000, 0x37530000, 0x37540000, 0x37550000, 0x37560000, 0x37570000, 0x37580000, 0x37590000, 0x375A0000, 0x375B0000, 0x375C0000, 0x375D0000, 0x375E0000, 0x375F0000, + 0x37600000, 0x37610000, 0x37620000, 0x37630000, 0x37640000, 0x37650000, 0x37660000, 0x37670000, 0x37680000, 0x37690000, 0x376A0000, 0x376B0000, 0x376C0000, 0x376D0000, 0x376E0000, 0x376F0000, + 0x37700000, 0x37710000, 0x37720000, 0x37730000, 0x37740000, 0x37750000, 0x37760000, 0x37770000, 0x37780000, 0x37790000, 0x377A0000, 0x377B0000, 0x377C0000, 0x377D0000, 0x377E0000, 0x377F0000, + 0x37800000, 0x37808000, 0x37810000, 0x37818000, 0x37820000, 0x37828000, 0x37830000, 0x37838000, 0x37840000, 0x37848000, 0x37850000, 0x37858000, 0x37860000, 0x37868000, 0x37870000, 0x37878000, + 0x37880000, 0x37888000, 0x37890000, 0x37898000, 0x378A0000, 0x378A8000, 0x378B0000, 0x378B8000, 0x378C0000, 0x378C8000, 0x378D0000, 0x378D8000, 0x378E0000, 0x378E8000, 0x378F0000, 0x378F8000, + 0x37900000, 0x37908000, 0x37910000, 0x37918000, 0x37920000, 0x37928000, 0x37930000, 0x37938000, 0x37940000, 0x37948000, 0x37950000, 0x37958000, 0x37960000, 0x37968000, 0x37970000, 0x37978000, + 0x37980000, 0x37988000, 0x37990000, 0x37998000, 0x379A0000, 0x379A8000, 0x379B0000, 0x379B8000, 0x379C0000, 0x379C8000, 0x379D0000, 0x379D8000, 0x379E0000, 0x379E8000, 0x379F0000, 0x379F8000, + 0x37A00000, 0x37A08000, 0x37A10000, 0x37A18000, 0x37A20000, 0x37A28000, 0x37A30000, 0x37A38000, 0x37A40000, 0x37A48000, 0x37A50000, 0x37A58000, 0x37A60000, 0x37A68000, 0x37A70000, 0x37A78000, + 0x37A80000, 0x37A88000, 0x37A90000, 0x37A98000, 0x37AA0000, 0x37AA8000, 0x37AB0000, 0x37AB8000, 0x37AC0000, 0x37AC8000, 0x37AD0000, 0x37AD8000, 0x37AE0000, 0x37AE8000, 0x37AF0000, 0x37AF8000, + 0x37B00000, 0x37B08000, 0x37B10000, 0x37B18000, 0x37B20000, 0x37B28000, 0x37B30000, 0x37B38000, 0x37B40000, 0x37B48000, 0x37B50000, 0x37B58000, 0x37B60000, 0x37B68000, 0x37B70000, 0x37B78000, + 0x37B80000, 0x37B88000, 0x37B90000, 0x37B98000, 0x37BA0000, 0x37BA8000, 0x37BB0000, 0x37BB8000, 0x37BC0000, 0x37BC8000, 0x37BD0000, 0x37BD8000, 0x37BE0000, 0x37BE8000, 0x37BF0000, 0x37BF8000, + 0x37C00000, 0x37C08000, 0x37C10000, 0x37C18000, 0x37C20000, 0x37C28000, 0x37C30000, 0x37C38000, 0x37C40000, 0x37C48000, 0x37C50000, 0x37C58000, 0x37C60000, 0x37C68000, 0x37C70000, 0x37C78000, + 0x37C80000, 0x37C88000, 0x37C90000, 0x37C98000, 0x37CA0000, 0x37CA8000, 0x37CB0000, 0x37CB8000, 0x37CC0000, 0x37CC8000, 0x37CD0000, 0x37CD8000, 0x37CE0000, 0x37CE8000, 0x37CF0000, 0x37CF8000, + 0x37D00000, 0x37D08000, 0x37D10000, 0x37D18000, 0x37D20000, 0x37D28000, 0x37D30000, 0x37D38000, 0x37D40000, 0x37D48000, 0x37D50000, 0x37D58000, 0x37D60000, 0x37D68000, 0x37D70000, 0x37D78000, + 0x37D80000, 0x37D88000, 0x37D90000, 0x37D98000, 0x37DA0000, 0x37DA8000, 0x37DB0000, 0x37DB8000, 0x37DC0000, 0x37DC8000, 0x37DD0000, 0x37DD8000, 0x37DE0000, 0x37DE8000, 0x37DF0000, 0x37DF8000, + 0x37E00000, 0x37E08000, 0x37E10000, 0x37E18000, 0x37E20000, 0x37E28000, 0x37E30000, 0x37E38000, 0x37E40000, 0x37E48000, 0x37E50000, 0x37E58000, 0x37E60000, 0x37E68000, 0x37E70000, 0x37E78000, + 0x37E80000, 0x37E88000, 0x37E90000, 0x37E98000, 0x37EA0000, 0x37EA8000, 0x37EB0000, 0x37EB8000, 0x37EC0000, 0x37EC8000, 0x37ED0000, 0x37ED8000, 0x37EE0000, 0x37EE8000, 0x37EF0000, 0x37EF8000, + 0x37F00000, 0x37F08000, 0x37F10000, 0x37F18000, 0x37F20000, 0x37F28000, 0x37F30000, 0x37F38000, 0x37F40000, 0x37F48000, 0x37F50000, 0x37F58000, 0x37F60000, 0x37F68000, 0x37F70000, 0x37F78000, + 0x37F80000, 0x37F88000, 0x37F90000, 0x37F98000, 0x37FA0000, 0x37FA8000, 0x37FB0000, 0x37FB8000, 0x37FC0000, 0x37FC8000, 0x37FD0000, 0x37FD8000, 0x37FE0000, 0x37FE8000, 0x37FF0000, 0x37FF8000, + 0x38000000, 0x38004000, 0x38008000, 0x3800C000, 0x38010000, 0x38014000, 0x38018000, 0x3801C000, 0x38020000, 0x38024000, 0x38028000, 0x3802C000, 0x38030000, 0x38034000, 0x38038000, 0x3803C000, + 0x38040000, 0x38044000, 0x38048000, 0x3804C000, 0x38050000, 0x38054000, 0x38058000, 0x3805C000, 0x38060000, 0x38064000, 0x38068000, 0x3806C000, 0x38070000, 0x38074000, 0x38078000, 0x3807C000, + 0x38080000, 0x38084000, 0x38088000, 0x3808C000, 0x38090000, 0x38094000, 0x38098000, 0x3809C000, 0x380A0000, 0x380A4000, 0x380A8000, 0x380AC000, 0x380B0000, 0x380B4000, 0x380B8000, 0x380BC000, + 0x380C0000, 0x380C4000, 0x380C8000, 0x380CC000, 0x380D0000, 0x380D4000, 0x380D8000, 0x380DC000, 0x380E0000, 0x380E4000, 0x380E8000, 0x380EC000, 0x380F0000, 0x380F4000, 0x380F8000, 0x380FC000, + 0x38100000, 0x38104000, 0x38108000, 0x3810C000, 0x38110000, 0x38114000, 0x38118000, 0x3811C000, 0x38120000, 0x38124000, 0x38128000, 0x3812C000, 0x38130000, 0x38134000, 0x38138000, 0x3813C000, + 0x38140000, 0x38144000, 0x38148000, 0x3814C000, 0x38150000, 0x38154000, 0x38158000, 0x3815C000, 0x38160000, 0x38164000, 0x38168000, 0x3816C000, 0x38170000, 0x38174000, 0x38178000, 0x3817C000, + 0x38180000, 0x38184000, 0x38188000, 0x3818C000, 0x38190000, 0x38194000, 0x38198000, 0x3819C000, 0x381A0000, 0x381A4000, 0x381A8000, 0x381AC000, 0x381B0000, 0x381B4000, 0x381B8000, 0x381BC000, + 0x381C0000, 0x381C4000, 0x381C8000, 0x381CC000, 0x381D0000, 0x381D4000, 0x381D8000, 0x381DC000, 0x381E0000, 0x381E4000, 0x381E8000, 0x381EC000, 0x381F0000, 0x381F4000, 0x381F8000, 0x381FC000, + 0x38200000, 0x38204000, 0x38208000, 0x3820C000, 0x38210000, 0x38214000, 0x38218000, 0x3821C000, 0x38220000, 0x38224000, 0x38228000, 0x3822C000, 0x38230000, 0x38234000, 0x38238000, 0x3823C000, + 0x38240000, 0x38244000, 0x38248000, 0x3824C000, 0x38250000, 0x38254000, 0x38258000, 0x3825C000, 0x38260000, 0x38264000, 0x38268000, 0x3826C000, 0x38270000, 0x38274000, 0x38278000, 0x3827C000, + 0x38280000, 0x38284000, 0x38288000, 0x3828C000, 0x38290000, 0x38294000, 0x38298000, 0x3829C000, 0x382A0000, 0x382A4000, 0x382A8000, 0x382AC000, 0x382B0000, 0x382B4000, 0x382B8000, 0x382BC000, + 0x382C0000, 0x382C4000, 0x382C8000, 0x382CC000, 0x382D0000, 0x382D4000, 0x382D8000, 0x382DC000, 0x382E0000, 0x382E4000, 0x382E8000, 0x382EC000, 0x382F0000, 0x382F4000, 0x382F8000, 0x382FC000, + 0x38300000, 0x38304000, 0x38308000, 0x3830C000, 0x38310000, 0x38314000, 0x38318000, 0x3831C000, 0x38320000, 0x38324000, 0x38328000, 0x3832C000, 0x38330000, 0x38334000, 0x38338000, 0x3833C000, + 0x38340000, 0x38344000, 0x38348000, 0x3834C000, 0x38350000, 0x38354000, 0x38358000, 0x3835C000, 0x38360000, 0x38364000, 0x38368000, 0x3836C000, 0x38370000, 0x38374000, 0x38378000, 0x3837C000, + 0x38380000, 0x38384000, 0x38388000, 0x3838C000, 0x38390000, 0x38394000, 0x38398000, 0x3839C000, 0x383A0000, 0x383A4000, 0x383A8000, 0x383AC000, 0x383B0000, 0x383B4000, 0x383B8000, 0x383BC000, + 0x383C0000, 0x383C4000, 0x383C8000, 0x383CC000, 0x383D0000, 0x383D4000, 0x383D8000, 0x383DC000, 0x383E0000, 0x383E4000, 0x383E8000, 0x383EC000, 0x383F0000, 0x383F4000, 0x383F8000, 0x383FC000, + 0x38400000, 0x38404000, 0x38408000, 0x3840C000, 0x38410000, 0x38414000, 0x38418000, 0x3841C000, 0x38420000, 0x38424000, 0x38428000, 0x3842C000, 0x38430000, 0x38434000, 0x38438000, 0x3843C000, + 0x38440000, 0x38444000, 0x38448000, 0x3844C000, 0x38450000, 0x38454000, 0x38458000, 0x3845C000, 0x38460000, 0x38464000, 0x38468000, 0x3846C000, 0x38470000, 0x38474000, 0x38478000, 0x3847C000, + 0x38480000, 0x38484000, 0x38488000, 0x3848C000, 0x38490000, 0x38494000, 0x38498000, 0x3849C000, 0x384A0000, 0x384A4000, 0x384A8000, 0x384AC000, 0x384B0000, 0x384B4000, 0x384B8000, 0x384BC000, + 0x384C0000, 0x384C4000, 0x384C8000, 0x384CC000, 0x384D0000, 0x384D4000, 0x384D8000, 0x384DC000, 0x384E0000, 0x384E4000, 0x384E8000, 0x384EC000, 0x384F0000, 0x384F4000, 0x384F8000, 0x384FC000, + 0x38500000, 0x38504000, 0x38508000, 0x3850C000, 0x38510000, 0x38514000, 0x38518000, 0x3851C000, 0x38520000, 0x38524000, 0x38528000, 0x3852C000, 0x38530000, 0x38534000, 0x38538000, 0x3853C000, + 0x38540000, 0x38544000, 0x38548000, 0x3854C000, 0x38550000, 0x38554000, 0x38558000, 0x3855C000, 0x38560000, 0x38564000, 0x38568000, 0x3856C000, 0x38570000, 0x38574000, 0x38578000, 0x3857C000, + 0x38580000, 0x38584000, 0x38588000, 0x3858C000, 0x38590000, 0x38594000, 0x38598000, 0x3859C000, 0x385A0000, 0x385A4000, 0x385A8000, 0x385AC000, 0x385B0000, 0x385B4000, 0x385B8000, 0x385BC000, + 0x385C0000, 0x385C4000, 0x385C8000, 0x385CC000, 0x385D0000, 0x385D4000, 0x385D8000, 0x385DC000, 0x385E0000, 0x385E4000, 0x385E8000, 0x385EC000, 0x385F0000, 0x385F4000, 0x385F8000, 0x385FC000, + 0x38600000, 0x38604000, 0x38608000, 0x3860C000, 0x38610000, 0x38614000, 0x38618000, 0x3861C000, 0x38620000, 0x38624000, 0x38628000, 0x3862C000, 0x38630000, 0x38634000, 0x38638000, 0x3863C000, + 0x38640000, 0x38644000, 0x38648000, 0x3864C000, 0x38650000, 0x38654000, 0x38658000, 0x3865C000, 0x38660000, 0x38664000, 0x38668000, 0x3866C000, 0x38670000, 0x38674000, 0x38678000, 0x3867C000, + 0x38680000, 0x38684000, 0x38688000, 0x3868C000, 0x38690000, 0x38694000, 0x38698000, 0x3869C000, 0x386A0000, 0x386A4000, 0x386A8000, 0x386AC000, 0x386B0000, 0x386B4000, 0x386B8000, 0x386BC000, + 0x386C0000, 0x386C4000, 0x386C8000, 0x386CC000, 0x386D0000, 0x386D4000, 0x386D8000, 0x386DC000, 0x386E0000, 0x386E4000, 0x386E8000, 0x386EC000, 0x386F0000, 0x386F4000, 0x386F8000, 0x386FC000, + 0x38700000, 0x38704000, 0x38708000, 0x3870C000, 0x38710000, 0x38714000, 0x38718000, 0x3871C000, 0x38720000, 0x38724000, 0x38728000, 0x3872C000, 0x38730000, 0x38734000, 0x38738000, 0x3873C000, + 0x38740000, 0x38744000, 0x38748000, 0x3874C000, 0x38750000, 0x38754000, 0x38758000, 0x3875C000, 0x38760000, 0x38764000, 0x38768000, 0x3876C000, 0x38770000, 0x38774000, 0x38778000, 0x3877C000, + 0x38780000, 0x38784000, 0x38788000, 0x3878C000, 0x38790000, 0x38794000, 0x38798000, 0x3879C000, 0x387A0000, 0x387A4000, 0x387A8000, 0x387AC000, 0x387B0000, 0x387B4000, 0x387B8000, 0x387BC000, + 0x387C0000, 0x387C4000, 0x387C8000, 0x387CC000, 0x387D0000, 0x387D4000, 0x387D8000, 0x387DC000, 0x387E0000, 0x387E4000, 0x387E8000, 0x387EC000, 0x387F0000, 0x387F4000, 0x387F8000, 0x387FC000, + 0x38000000, 0x38002000, 0x38004000, 0x38006000, 0x38008000, 0x3800A000, 0x3800C000, 0x3800E000, 0x38010000, 0x38012000, 0x38014000, 0x38016000, 0x38018000, 0x3801A000, 0x3801C000, 0x3801E000, + 0x38020000, 0x38022000, 0x38024000, 0x38026000, 0x38028000, 0x3802A000, 0x3802C000, 0x3802E000, 0x38030000, 0x38032000, 0x38034000, 0x38036000, 0x38038000, 0x3803A000, 0x3803C000, 0x3803E000, + 0x38040000, 0x38042000, 0x38044000, 0x38046000, 0x38048000, 0x3804A000, 0x3804C000, 0x3804E000, 0x38050000, 0x38052000, 0x38054000, 0x38056000, 0x38058000, 0x3805A000, 0x3805C000, 0x3805E000, + 0x38060000, 0x38062000, 0x38064000, 0x38066000, 0x38068000, 0x3806A000, 0x3806C000, 0x3806E000, 0x38070000, 0x38072000, 0x38074000, 0x38076000, 0x38078000, 0x3807A000, 0x3807C000, 0x3807E000, + 0x38080000, 0x38082000, 0x38084000, 0x38086000, 0x38088000, 0x3808A000, 0x3808C000, 0x3808E000, 0x38090000, 0x38092000, 0x38094000, 0x38096000, 0x38098000, 0x3809A000, 0x3809C000, 0x3809E000, + 0x380A0000, 0x380A2000, 0x380A4000, 0x380A6000, 0x380A8000, 0x380AA000, 0x380AC000, 0x380AE000, 0x380B0000, 0x380B2000, 0x380B4000, 0x380B6000, 0x380B8000, 0x380BA000, 0x380BC000, 0x380BE000, + 0x380C0000, 0x380C2000, 0x380C4000, 0x380C6000, 0x380C8000, 0x380CA000, 0x380CC000, 0x380CE000, 0x380D0000, 0x380D2000, 0x380D4000, 0x380D6000, 0x380D8000, 0x380DA000, 0x380DC000, 0x380DE000, + 0x380E0000, 0x380E2000, 0x380E4000, 0x380E6000, 0x380E8000, 0x380EA000, 0x380EC000, 0x380EE000, 0x380F0000, 0x380F2000, 0x380F4000, 0x380F6000, 0x380F8000, 0x380FA000, 0x380FC000, 0x380FE000, + 0x38100000, 0x38102000, 0x38104000, 0x38106000, 0x38108000, 0x3810A000, 0x3810C000, 0x3810E000, 0x38110000, 0x38112000, 0x38114000, 0x38116000, 0x38118000, 0x3811A000, 0x3811C000, 0x3811E000, + 0x38120000, 0x38122000, 0x38124000, 0x38126000, 0x38128000, 0x3812A000, 0x3812C000, 0x3812E000, 0x38130000, 0x38132000, 0x38134000, 0x38136000, 0x38138000, 0x3813A000, 0x3813C000, 0x3813E000, + 0x38140000, 0x38142000, 0x38144000, 0x38146000, 0x38148000, 0x3814A000, 0x3814C000, 0x3814E000, 0x38150000, 0x38152000, 0x38154000, 0x38156000, 0x38158000, 0x3815A000, 0x3815C000, 0x3815E000, + 0x38160000, 0x38162000, 0x38164000, 0x38166000, 0x38168000, 0x3816A000, 0x3816C000, 0x3816E000, 0x38170000, 0x38172000, 0x38174000, 0x38176000, 0x38178000, 0x3817A000, 0x3817C000, 0x3817E000, + 0x38180000, 0x38182000, 0x38184000, 0x38186000, 0x38188000, 0x3818A000, 0x3818C000, 0x3818E000, 0x38190000, 0x38192000, 0x38194000, 0x38196000, 0x38198000, 0x3819A000, 0x3819C000, 0x3819E000, + 0x381A0000, 0x381A2000, 0x381A4000, 0x381A6000, 0x381A8000, 0x381AA000, 0x381AC000, 0x381AE000, 0x381B0000, 0x381B2000, 0x381B4000, 0x381B6000, 0x381B8000, 0x381BA000, 0x381BC000, 0x381BE000, + 0x381C0000, 0x381C2000, 0x381C4000, 0x381C6000, 0x381C8000, 0x381CA000, 0x381CC000, 0x381CE000, 0x381D0000, 0x381D2000, 0x381D4000, 0x381D6000, 0x381D8000, 0x381DA000, 0x381DC000, 0x381DE000, + 0x381E0000, 0x381E2000, 0x381E4000, 0x381E6000, 0x381E8000, 0x381EA000, 0x381EC000, 0x381EE000, 0x381F0000, 0x381F2000, 0x381F4000, 0x381F6000, 0x381F8000, 0x381FA000, 0x381FC000, 0x381FE000, + 0x38200000, 0x38202000, 0x38204000, 0x38206000, 0x38208000, 0x3820A000, 0x3820C000, 0x3820E000, 0x38210000, 0x38212000, 0x38214000, 0x38216000, 0x38218000, 0x3821A000, 0x3821C000, 0x3821E000, + 0x38220000, 0x38222000, 0x38224000, 0x38226000, 0x38228000, 0x3822A000, 0x3822C000, 0x3822E000, 0x38230000, 0x38232000, 0x38234000, 0x38236000, 0x38238000, 0x3823A000, 0x3823C000, 0x3823E000, + 0x38240000, 0x38242000, 0x38244000, 0x38246000, 0x38248000, 0x3824A000, 0x3824C000, 0x3824E000, 0x38250000, 0x38252000, 0x38254000, 0x38256000, 0x38258000, 0x3825A000, 0x3825C000, 0x3825E000, + 0x38260000, 0x38262000, 0x38264000, 0x38266000, 0x38268000, 0x3826A000, 0x3826C000, 0x3826E000, 0x38270000, 0x38272000, 0x38274000, 0x38276000, 0x38278000, 0x3827A000, 0x3827C000, 0x3827E000, + 0x38280000, 0x38282000, 0x38284000, 0x38286000, 0x38288000, 0x3828A000, 0x3828C000, 0x3828E000, 0x38290000, 0x38292000, 0x38294000, 0x38296000, 0x38298000, 0x3829A000, 0x3829C000, 0x3829E000, + 0x382A0000, 0x382A2000, 0x382A4000, 0x382A6000, 0x382A8000, 0x382AA000, 0x382AC000, 0x382AE000, 0x382B0000, 0x382B2000, 0x382B4000, 0x382B6000, 0x382B8000, 0x382BA000, 0x382BC000, 0x382BE000, + 0x382C0000, 0x382C2000, 0x382C4000, 0x382C6000, 0x382C8000, 0x382CA000, 0x382CC000, 0x382CE000, 0x382D0000, 0x382D2000, 0x382D4000, 0x382D6000, 0x382D8000, 0x382DA000, 0x382DC000, 0x382DE000, + 0x382E0000, 0x382E2000, 0x382E4000, 0x382E6000, 0x382E8000, 0x382EA000, 0x382EC000, 0x382EE000, 0x382F0000, 0x382F2000, 0x382F4000, 0x382F6000, 0x382F8000, 0x382FA000, 0x382FC000, 0x382FE000, + 0x38300000, 0x38302000, 0x38304000, 0x38306000, 0x38308000, 0x3830A000, 0x3830C000, 0x3830E000, 0x38310000, 0x38312000, 0x38314000, 0x38316000, 0x38318000, 0x3831A000, 0x3831C000, 0x3831E000, + 0x38320000, 0x38322000, 0x38324000, 0x38326000, 0x38328000, 0x3832A000, 0x3832C000, 0x3832E000, 0x38330000, 0x38332000, 0x38334000, 0x38336000, 0x38338000, 0x3833A000, 0x3833C000, 0x3833E000, + 0x38340000, 0x38342000, 0x38344000, 0x38346000, 0x38348000, 0x3834A000, 0x3834C000, 0x3834E000, 0x38350000, 0x38352000, 0x38354000, 0x38356000, 0x38358000, 0x3835A000, 0x3835C000, 0x3835E000, + 0x38360000, 0x38362000, 0x38364000, 0x38366000, 0x38368000, 0x3836A000, 0x3836C000, 0x3836E000, 0x38370000, 0x38372000, 0x38374000, 0x38376000, 0x38378000, 0x3837A000, 0x3837C000, 0x3837E000, + 0x38380000, 0x38382000, 0x38384000, 0x38386000, 0x38388000, 0x3838A000, 0x3838C000, 0x3838E000, 0x38390000, 0x38392000, 0x38394000, 0x38396000, 0x38398000, 0x3839A000, 0x3839C000, 0x3839E000, + 0x383A0000, 0x383A2000, 0x383A4000, 0x383A6000, 0x383A8000, 0x383AA000, 0x383AC000, 0x383AE000, 0x383B0000, 0x383B2000, 0x383B4000, 0x383B6000, 0x383B8000, 0x383BA000, 0x383BC000, 0x383BE000, + 0x383C0000, 0x383C2000, 0x383C4000, 0x383C6000, 0x383C8000, 0x383CA000, 0x383CC000, 0x383CE000, 0x383D0000, 0x383D2000, 0x383D4000, 0x383D6000, 0x383D8000, 0x383DA000, 0x383DC000, 0x383DE000, + 0x383E0000, 0x383E2000, 0x383E4000, 0x383E6000, 0x383E8000, 0x383EA000, 0x383EC000, 0x383EE000, 0x383F0000, 0x383F2000, 0x383F4000, 0x383F6000, 0x383F8000, 0x383FA000, 0x383FC000, 0x383FE000, + 0x38400000, 0x38402000, 0x38404000, 0x38406000, 0x38408000, 0x3840A000, 0x3840C000, 0x3840E000, 0x38410000, 0x38412000, 0x38414000, 0x38416000, 0x38418000, 0x3841A000, 0x3841C000, 0x3841E000, + 0x38420000, 0x38422000, 0x38424000, 0x38426000, 0x38428000, 0x3842A000, 0x3842C000, 0x3842E000, 0x38430000, 0x38432000, 0x38434000, 0x38436000, 0x38438000, 0x3843A000, 0x3843C000, 0x3843E000, + 0x38440000, 0x38442000, 0x38444000, 0x38446000, 0x38448000, 0x3844A000, 0x3844C000, 0x3844E000, 0x38450000, 0x38452000, 0x38454000, 0x38456000, 0x38458000, 0x3845A000, 0x3845C000, 0x3845E000, + 0x38460000, 0x38462000, 0x38464000, 0x38466000, 0x38468000, 0x3846A000, 0x3846C000, 0x3846E000, 0x38470000, 0x38472000, 0x38474000, 0x38476000, 0x38478000, 0x3847A000, 0x3847C000, 0x3847E000, + 0x38480000, 0x38482000, 0x38484000, 0x38486000, 0x38488000, 0x3848A000, 0x3848C000, 0x3848E000, 0x38490000, 0x38492000, 0x38494000, 0x38496000, 0x38498000, 0x3849A000, 0x3849C000, 0x3849E000, + 0x384A0000, 0x384A2000, 0x384A4000, 0x384A6000, 0x384A8000, 0x384AA000, 0x384AC000, 0x384AE000, 0x384B0000, 0x384B2000, 0x384B4000, 0x384B6000, 0x384B8000, 0x384BA000, 0x384BC000, 0x384BE000, + 0x384C0000, 0x384C2000, 0x384C4000, 0x384C6000, 0x384C8000, 0x384CA000, 0x384CC000, 0x384CE000, 0x384D0000, 0x384D2000, 0x384D4000, 0x384D6000, 0x384D8000, 0x384DA000, 0x384DC000, 0x384DE000, + 0x384E0000, 0x384E2000, 0x384E4000, 0x384E6000, 0x384E8000, 0x384EA000, 0x384EC000, 0x384EE000, 0x384F0000, 0x384F2000, 0x384F4000, 0x384F6000, 0x384F8000, 0x384FA000, 0x384FC000, 0x384FE000, + 0x38500000, 0x38502000, 0x38504000, 0x38506000, 0x38508000, 0x3850A000, 0x3850C000, 0x3850E000, 0x38510000, 0x38512000, 0x38514000, 0x38516000, 0x38518000, 0x3851A000, 0x3851C000, 0x3851E000, + 0x38520000, 0x38522000, 0x38524000, 0x38526000, 0x38528000, 0x3852A000, 0x3852C000, 0x3852E000, 0x38530000, 0x38532000, 0x38534000, 0x38536000, 0x38538000, 0x3853A000, 0x3853C000, 0x3853E000, + 0x38540000, 0x38542000, 0x38544000, 0x38546000, 0x38548000, 0x3854A000, 0x3854C000, 0x3854E000, 0x38550000, 0x38552000, 0x38554000, 0x38556000, 0x38558000, 0x3855A000, 0x3855C000, 0x3855E000, + 0x38560000, 0x38562000, 0x38564000, 0x38566000, 0x38568000, 0x3856A000, 0x3856C000, 0x3856E000, 0x38570000, 0x38572000, 0x38574000, 0x38576000, 0x38578000, 0x3857A000, 0x3857C000, 0x3857E000, + 0x38580000, 0x38582000, 0x38584000, 0x38586000, 0x38588000, 0x3858A000, 0x3858C000, 0x3858E000, 0x38590000, 0x38592000, 0x38594000, 0x38596000, 0x38598000, 0x3859A000, 0x3859C000, 0x3859E000, + 0x385A0000, 0x385A2000, 0x385A4000, 0x385A6000, 0x385A8000, 0x385AA000, 0x385AC000, 0x385AE000, 0x385B0000, 0x385B2000, 0x385B4000, 0x385B6000, 0x385B8000, 0x385BA000, 0x385BC000, 0x385BE000, + 0x385C0000, 0x385C2000, 0x385C4000, 0x385C6000, 0x385C8000, 0x385CA000, 0x385CC000, 0x385CE000, 0x385D0000, 0x385D2000, 0x385D4000, 0x385D6000, 0x385D8000, 0x385DA000, 0x385DC000, 0x385DE000, + 0x385E0000, 0x385E2000, 0x385E4000, 0x385E6000, 0x385E8000, 0x385EA000, 0x385EC000, 0x385EE000, 0x385F0000, 0x385F2000, 0x385F4000, 0x385F6000, 0x385F8000, 0x385FA000, 0x385FC000, 0x385FE000, + 0x38600000, 0x38602000, 0x38604000, 0x38606000, 0x38608000, 0x3860A000, 0x3860C000, 0x3860E000, 0x38610000, 0x38612000, 0x38614000, 0x38616000, 0x38618000, 0x3861A000, 0x3861C000, 0x3861E000, + 0x38620000, 0x38622000, 0x38624000, 0x38626000, 0x38628000, 0x3862A000, 0x3862C000, 0x3862E000, 0x38630000, 0x38632000, 0x38634000, 0x38636000, 0x38638000, 0x3863A000, 0x3863C000, 0x3863E000, + 0x38640000, 0x38642000, 0x38644000, 0x38646000, 0x38648000, 0x3864A000, 0x3864C000, 0x3864E000, 0x38650000, 0x38652000, 0x38654000, 0x38656000, 0x38658000, 0x3865A000, 0x3865C000, 0x3865E000, + 0x38660000, 0x38662000, 0x38664000, 0x38666000, 0x38668000, 0x3866A000, 0x3866C000, 0x3866E000, 0x38670000, 0x38672000, 0x38674000, 0x38676000, 0x38678000, 0x3867A000, 0x3867C000, 0x3867E000, + 0x38680000, 0x38682000, 0x38684000, 0x38686000, 0x38688000, 0x3868A000, 0x3868C000, 0x3868E000, 0x38690000, 0x38692000, 0x38694000, 0x38696000, 0x38698000, 0x3869A000, 0x3869C000, 0x3869E000, + 0x386A0000, 0x386A2000, 0x386A4000, 0x386A6000, 0x386A8000, 0x386AA000, 0x386AC000, 0x386AE000, 0x386B0000, 0x386B2000, 0x386B4000, 0x386B6000, 0x386B8000, 0x386BA000, 0x386BC000, 0x386BE000, + 0x386C0000, 0x386C2000, 0x386C4000, 0x386C6000, 0x386C8000, 0x386CA000, 0x386CC000, 0x386CE000, 0x386D0000, 0x386D2000, 0x386D4000, 0x386D6000, 0x386D8000, 0x386DA000, 0x386DC000, 0x386DE000, + 0x386E0000, 0x386E2000, 0x386E4000, 0x386E6000, 0x386E8000, 0x386EA000, 0x386EC000, 0x386EE000, 0x386F0000, 0x386F2000, 0x386F4000, 0x386F6000, 0x386F8000, 0x386FA000, 0x386FC000, 0x386FE000, + 0x38700000, 0x38702000, 0x38704000, 0x38706000, 0x38708000, 0x3870A000, 0x3870C000, 0x3870E000, 0x38710000, 0x38712000, 0x38714000, 0x38716000, 0x38718000, 0x3871A000, 0x3871C000, 0x3871E000, + 0x38720000, 0x38722000, 0x38724000, 0x38726000, 0x38728000, 0x3872A000, 0x3872C000, 0x3872E000, 0x38730000, 0x38732000, 0x38734000, 0x38736000, 0x38738000, 0x3873A000, 0x3873C000, 0x3873E000, + 0x38740000, 0x38742000, 0x38744000, 0x38746000, 0x38748000, 0x3874A000, 0x3874C000, 0x3874E000, 0x38750000, 0x38752000, 0x38754000, 0x38756000, 0x38758000, 0x3875A000, 0x3875C000, 0x3875E000, + 0x38760000, 0x38762000, 0x38764000, 0x38766000, 0x38768000, 0x3876A000, 0x3876C000, 0x3876E000, 0x38770000, 0x38772000, 0x38774000, 0x38776000, 0x38778000, 0x3877A000, 0x3877C000, 0x3877E000, + 0x38780000, 0x38782000, 0x38784000, 0x38786000, 0x38788000, 0x3878A000, 0x3878C000, 0x3878E000, 0x38790000, 0x38792000, 0x38794000, 0x38796000, 0x38798000, 0x3879A000, 0x3879C000, 0x3879E000, + 0x387A0000, 0x387A2000, 0x387A4000, 0x387A6000, 0x387A8000, 0x387AA000, 0x387AC000, 0x387AE000, 0x387B0000, 0x387B2000, 0x387B4000, 0x387B6000, 0x387B8000, 0x387BA000, 0x387BC000, 0x387BE000, + 0x387C0000, 0x387C2000, 0x387C4000, 0x387C6000, 0x387C8000, 0x387CA000, 0x387CC000, 0x387CE000, 0x387D0000, 0x387D2000, 0x387D4000, 0x387D6000, 0x387D8000, 0x387DA000, 0x387DC000, 0x387DE000, + 0x387E0000, 0x387E2000, 0x387E4000, 0x387E6000, 0x387E8000, 0x387EA000, 0x387EC000, 0x387EE000, 0x387F0000, 0x387F2000, 0x387F4000, 0x387F6000, 0x387F8000, 0x387FA000, 0x387FC000, 0x387FE000 }; + static const bits_t exponent_table[64] = { + 0x00000000, 0x00800000, 0x01000000, 0x01800000, 0x02000000, 0x02800000, 0x03000000, 0x03800000, 0x04000000, 0x04800000, 0x05000000, 0x05800000, 0x06000000, 0x06800000, 0x07000000, 0x07800000, + 0x08000000, 0x08800000, 0x09000000, 0x09800000, 0x0A000000, 0x0A800000, 0x0B000000, 0x0B800000, 0x0C000000, 0x0C800000, 0x0D000000, 0x0D800000, 0x0E000000, 0x0E800000, 0x0F000000, 0x47800000, + 0x80000000, 0x80800000, 0x81000000, 0x81800000, 0x82000000, 0x82800000, 0x83000000, 0x83800000, 0x84000000, 0x84800000, 0x85000000, 0x85800000, 0x86000000, 0x86800000, 0x87000000, 0x87800000, + 0x88000000, 0x88800000, 0x89000000, 0x89800000, 0x8A000000, 0x8A800000, 0x8B000000, 0x8B800000, 0x8C000000, 0x8C800000, 0x8D000000, 0x8D800000, 0x8E000000, 0x8E800000, 0x8F000000, 0xC7800000 }; + static const unsigned short offset_table[64] = { + 0, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, + 0, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024 }; + bits_t fbits = mantissa_table[offset_table[value>>10]+(value&0x3FF)] + exponent_table[value>>10]; + #endif + float out; + std::memcpy(&out, &fbits, sizeof(float)); + return out; + #endif + } + + /// Convert half-precision to IEEE double-precision. + /// \param value half-precision value to convert + /// \return double-precision value + inline double half2float_impl(unsigned int value, double, true_type) { + #if HALF_ENABLE_F16C_INTRINSICS + return _mm_cvtsd_f64(_mm_cvtps_pd(_mm_cvtph_ps(_mm_cvtsi32_si128(value)))); + #else + uint32 hi = static_cast(value&0x8000) << 16; + unsigned int abs = value & 0x7FFF; + if(abs) { + hi |= 0x3F000000 << static_cast(abs>=0x7C00); + for(; abs<0x400; abs<<=1,hi-=0x100000) ; + hi += static_cast(abs) << 10; + } + bits_t dbits = static_cast>(hi) << 32; + double out; + std::memcpy(&out, &dbits, sizeof(double)); + return out; + #endif + } + + /// Convert half-precision to non-IEEE floating-point. + /// \tparam T type to convert to (builtin integer type) + /// \param value half-precision value to convert + /// \return floating-point value + template T half2float_impl(unsigned int value, T, ...) { + T out; + unsigned int abs = value & 0x7FFF; + if(abs > 0x7C00) + out = (std::numeric_limits::has_signaling_NaN && !(abs&0x200)) ? std::numeric_limits::signaling_NaN() : + std::numeric_limits::has_quiet_NaN ? std::numeric_limits::quiet_NaN() : T(); + else if(abs == 0x7C00) + out = std::numeric_limits::has_infinity ? std::numeric_limits::infinity() : std::numeric_limits::max(); + else if(abs > 0x3FF) + out = std::ldexp(static_cast((abs&0x3FF)|0x400), (abs>>10)-25); + else + out = std::ldexp(static_cast(abs), -24); + return (value&0x8000) ? -out : out; + } + + /// Convert half-precision to floating-point. + /// \tparam T type to convert to (builtin integer type) + /// \param value half-precision value to convert + /// \return floating-point value + template T half2float(unsigned int value) { + return half2float_impl(value, T(), bool_type::is_iec559&&sizeof(bits_t)==sizeof(T)>()); + } + + /// Convert half-precision floating-point to integer. + /// \tparam R rounding mode to use + /// \tparam E `true` for round to even, `false` for round away from zero + /// \tparam I `true` to raise INEXACT exception (if inexact), `false` to never raise it + /// \tparam T type to convert to (buitlin integer type with at least 16 bits precision, excluding any implicit sign bits) + /// \param value half-precision value to convert + /// \return rounded integer value + /// \exception FE_INVALID if value is not representable in type \a T + /// \exception FE_INEXACT if value had to be rounded and \a I is `true` + template T half2int(unsigned int value) { + unsigned int abs = value & 0x7FFF; + if(abs >= 0x7C00) { + raise(FE_INVALID); + return (value&0x8000) ? std::numeric_limits::min() : std::numeric_limits::max(); + } + if(abs < 0x3800) { + raise(FE_INEXACT, I); + return (R==std::round_toward_infinity) ? T(~(value>>15)&(abs!=0)) : + (R==std::round_toward_neg_infinity) ? -T(value>0x8000) : + T(); + } + int exp = 25 - (abs>>10); + unsigned int m = (value&0x3FF) | 0x400; + int32 i = static_cast((exp<=0) ? (m<<-exp) : ((m+( + (R==std::round_to_nearest) ? ((1<<(exp-1))-(~(m>>exp)&E)) : + (R==std::round_toward_infinity) ? (((1<>15)-1)) : + (R==std::round_toward_neg_infinity) ? (((1<>15)) : 0))>>exp)); + if((!std::numeric_limits::is_signed && (value&0x8000)) || (std::numeric_limits::digits<16 && + ((value&0x8000) ? (-i::min()) : (i>std::numeric_limits::max())))) + raise(FE_INVALID); + else if(I && exp > 0 && (m&((1<((value&0x8000) ? -i : i); + } + + /// \} + /// \name Mathematics + /// \{ + + /// upper part of 64-bit multiplication. + /// \tparam R rounding mode to use + /// \param x first factor + /// \param y second factor + /// \return upper 32 bit of \a x * \a y + template uint32 mulhi(uint32 x, uint32 y) { + uint32 xy = (x>>16) * (y&0xFFFF), yx = (x&0xFFFF) * (y>>16), c = (xy&0xFFFF) + (yx&0xFFFF) + (((x&0xFFFF)*(y&0xFFFF))>>16); + return (x>>16)*(y>>16) + (xy>>16) + (yx>>16) + (c>>16) + + ((R==std::round_to_nearest) ? ((c>>15)&1) : (R==std::round_toward_infinity) ? ((c&0xFFFF)!=0) : 0); + } + + /// 64-bit multiplication. + /// \param x first factor + /// \param y second factor + /// \return upper 32 bit of \a x * \a y rounded to nearest + inline uint32 multiply64(uint32 x, uint32 y) { + return static_cast((static_cast(x)*static_cast(y)+0x80000000)>>32); + } + + /// 64-bit division. + /// \param x upper 32 bit of dividend + /// \param y divisor + /// \param s variable to store sticky bit for rounding + /// \return (\a x << 32) / \a y + inline uint32 divide64(uint32 x, uint32 y, int &s) { + unsigned long long xx = static_cast(x) << 32; + return s = (xx%y!=0), static_cast(xx/y); + } + + /// Half precision positive modulus. + /// \tparam Q `true` to compute full quotient, `false` else + /// \tparam R `true` to compute signed remainder, `false` for positive remainder + /// \param x first operand as positive finite half-precision value + /// \param y second operand as positive finite half-precision value + /// \param quo adress to store quotient at, `nullptr` if \a Q `false` + /// \return modulus of \a x / \a y + template unsigned int mod(unsigned int x, unsigned int y, int *quo = NULL) { + unsigned int q = 0; + if(x > y) { + int absx = x, absy = y, expx = 0, expy = 0; + for(; absx<0x400; absx<<=1,--expx) ; + for(; absy<0x400; absy<<=1,--expy) ; + expx += absx >> 10; + expy += absy >> 10; + int mx = (absx&0x3FF) | 0x400, my = (absy&0x3FF) | 0x400; + for(int d=expx-expy; d; --d) { + if(!Q && mx == my) + return 0; + if(mx >= my) { + mx -= my; + q += Q; + } + mx <<= 1; + q <<= static_cast(Q); + } + if(!Q && mx == my) + return 0; + if(mx >= my) { + mx -= my; + ++q; + } + if(Q) { + q &= (1<<(std::numeric_limits::digits-1)) - 1; + if(!mx) + return *quo = q, 0; + } + for(; mx<0x400; mx<<=1,--expy) ; + x = (expy>0) ? ((expy<<10)|(mx&0x3FF)) : (mx>>(1-expy)); + } + if(R) { + unsigned int a, b; + if(y < 0x800) { + a = (x<0x400) ? (x<<1) : (x+0x400); + b = y; + } else { + a = x; + b = y - 0x400; + } + if(a > b || (a == b && (q&1))) { + int exp = (y>>10) + (y<=0x3FF), d = exp - (x>>10) - (x<=0x3FF); + int m = (((y&0x3FF)|((y>0x3FF)<<10))<<1) - (((x&0x3FF)|((x>0x3FF)<<10))<<(1-d)); + for(; m<0x800 && exp>1; m<<=1,--exp) ; + x = 0x8000 + ((exp-1)<<10) + (m>>1); + q += Q; + } + } + if(Q) + *quo = q; + return x; + } + + /// Fixed point square root. + /// \tparam F number of fractional bits + /// \param r radicand in Q1.F fixed point format + /// \param exp exponent + /// \return square root as Q1.F/2 + template uint32 sqrt(uint32 &r, int &exp) { + int i = exp & 1; + r <<= i; + exp = (exp-i) / 2; + uint32 m = 0; + for(uint32 bit=static_cast(1)<>=2) { + if(r < m+bit) + m >>= 1; + else { + r -= m + bit; + m = (m>>1) + bit; + } + } + return m; + } + + /// Fixed point binary exponential. + /// This uses the BKM algorithm in E-mode. + /// \param m exponent in [0,1) as Q0.31 + /// \param n number of iterations (at most 32) + /// \return 2 ^ \a m as Q1.31 + inline uint32 exp2(uint32 m, unsigned int n = 32) { + static const uint32 logs[] = { + 0x80000000, 0x4AE00D1D, 0x2934F098, 0x15C01A3A, 0x0B31FB7D, 0x05AEB4DD, 0x02DCF2D1, 0x016FE50B, + 0x00B84E23, 0x005C3E10, 0x002E24CA, 0x001713D6, 0x000B8A47, 0x0005C53B, 0x0002E2A3, 0x00017153, + 0x0000B8AA, 0x00005C55, 0x00002E2B, 0x00001715, 0x00000B8B, 0x000005C5, 0x000002E3, 0x00000171, + 0x000000B9, 0x0000005C, 0x0000002E, 0x00000017, 0x0000000C, 0x00000006, 0x00000003, 0x00000001 }; + if(!m) + return 0x80000000; + uint32 mx = 0x80000000, my = 0; + for(unsigned int i=1; i> i; + } + } + return mx; + } + + /// Fixed point binary logarithm. + /// This uses the BKM algorithm in L-mode. + /// \param m mantissa in [1,2) as Q1.30 + /// \param n number of iterations (at most 32) + /// \return log2(\a m) as Q0.31 + inline uint32 log2(uint32 m, unsigned int n = 32) { + static const uint32 logs[] = { + 0x80000000, 0x4AE00D1D, 0x2934F098, 0x15C01A3A, 0x0B31FB7D, 0x05AEB4DD, 0x02DCF2D1, 0x016FE50B, + 0x00B84E23, 0x005C3E10, 0x002E24CA, 0x001713D6, 0x000B8A47, 0x0005C53B, 0x0002E2A3, 0x00017153, + 0x0000B8AA, 0x00005C55, 0x00002E2B, 0x00001715, 0x00000B8B, 0x000005C5, 0x000002E3, 0x00000171, + 0x000000B9, 0x0000005C, 0x0000002E, 0x00000017, 0x0000000C, 0x00000006, 0x00000003, 0x00000001 }; + if(m == 0x40000000) + return 0; + uint32 mx = 0x40000000, my = 0; + for(unsigned int i=1; i>i); + if(mz <= m) { + mx = mz; + my += logs[i]; + } + } + return my; + } + + /// Fixed point sine and cosine. + /// This uses the CORDIC algorithm in rotation mode. + /// \param mz angle in [-pi/2,pi/2] as Q1.30 + /// \param n number of iterations (at most 31) + /// \return sine and cosine of \a mz as Q1.30 + inline std::pair sincos(uint32 mz, unsigned int n = 31) { + static const uint32 angles[] = { + 0x3243F6A9, 0x1DAC6705, 0x0FADBAFD, 0x07F56EA7, 0x03FEAB77, 0x01FFD55C, 0x00FFFAAB, 0x007FFF55, + 0x003FFFEB, 0x001FFFFD, 0x00100000, 0x00080000, 0x00040000, 0x00020000, 0x00010000, 0x00008000, + 0x00004000, 0x00002000, 0x00001000, 0x00000800, 0x00000400, 0x00000200, 0x00000100, 0x00000080, + 0x00000040, 0x00000020, 0x00000010, 0x00000008, 0x00000004, 0x00000002, 0x00000001 }; + uint32 mx = 0x26DD3B6A, my = 0; + for(unsigned int i=0; i0x3FF)<<10); + int exp = (abs>>10) + (abs<=0x3FF) - 15; + if(abs < 0x3A48) + return k = 0, m << (exp+20); + unsigned long long y = m * 0xA2F9836E4E442, mask = (1ULL<<(62-exp)) - 1, yi = (y+(mask>>1)) & ~mask, f = y - yi; + uint32 sign = -static_cast(f>>63); + k = static_cast(yi>>(62-exp)); + return (multiply64(static_cast((sign ? -f : f)>>(31-exp)), 0xC90FDAA2)^sign) - sign; + } + + /// Get arguments for atan2 function. + /// \param abs half-precision floating-point value + /// \return \a abs and sqrt(1 - \a abs^2) as Q0.30 + inline std::pair atan2_args(unsigned int abs) { + int exp = -15; + for(; abs<0x400; abs<<=1,--exp) ; + exp += abs >> 10; + uint32 my = ((abs&0x3FF)|0x400) << 5, r = my * my; + int rexp = 2 * exp; + r = 0x40000000 - ((rexp>-31) ? ((r>>-rexp)|((r&((static_cast(1)<<-rexp)-1))!=0)) : 1); + for(rexp=0; r<0x40000000; r<<=1,--rexp) ; + uint32 mx = sqrt<30>(r, rexp); + int d = exp - rexp; + if(d < 0) + return std::make_pair((d<-14) ? ((my>>(-d-14))+((my>>(-d-15))&1)) : (my<<(14+d)), (mx<<14)+(r<<13)/mx); + if(d > 0) + return std::make_pair(my<<14, (d>14) ? ((mx>>(d-14))+((mx>>(d-15))&1)) : ((d==14) ? mx : ((mx<<(14-d))+(r<<(13-d))/mx))); + return std::make_pair(my<<13, (mx<<13)+(r<<12)/mx); + } + + /// Get exponentials for hyperbolic computation + /// \param abs half-precision floating-point value + /// \param exp variable to take unbiased exponent of larger result + /// \param n number of BKM iterations (at most 32) + /// \return exp(abs) and exp(-\a abs) as Q1.31 with same exponent + inline std::pair hyperbolic_args(unsigned int abs, int &exp, unsigned int n = 32) { + uint32 mx = detail::multiply64(static_cast((abs&0x3FF)+((abs>0x3FF)<<10))<<21, 0xB8AA3B29), my; + int e = (abs>>10) + (abs<=0x3FF); + if(e < 14) { + exp = 0; + mx >>= 14 - e; + } else { + exp = mx >> (45-e); + mx = (mx<<(e-14)) & 0x7FFFFFFF; + } + mx = exp2(mx, n); + int d = exp << 1, s; + if(mx > 0x80000000) { + my = divide64(0x80000000, mx, s); + my |= s; + ++d; + } else + my = mx; + return std::make_pair(mx, (d<31) ? ((my>>d)|((my&((static_cast(1)< unsigned int exp2_post(uint32 m, int exp, bool esign, unsigned int sign = 0) { + int s = 0; + if(esign) { + if(m > 0x80000000) { + m = divide64(0x80000000, m, s); + ++exp; + } + if(exp > 25) + return underflow(sign); + else if(exp == 25) + return rounded(sign, 1, (m&0x7FFFFFFF)!=0); + exp = -exp; + } else if(exp > 15) + return overflow(sign); + return fixed2half(m, exp+14, sign, s); + } + + /// Postprocessing for binary logarithm. + /// \tparam R rounding mode to use + /// \tparam L logarithm for base transformation as Q1.31 + /// \param m fractional part of logarithm as Q0.31 + /// \param ilog signed integer part of logarithm + /// \param exp biased exponent of result + /// \param sign sign bit of result + /// \return value base-transformed and converted to half-precision + /// \exception FE_OVERFLOW on overflows + /// \exception FE_UNDERFLOW on underflows + /// \exception FE_INEXACT if no other exception occurred + template unsigned int log2_post(uint32 m, int ilog, int exp, unsigned int sign = 0) { + uint32 msign = sign_mask(ilog); + m = (((static_cast(ilog)<<27)+(m>>4))^msign) - msign; + if(!m) + return 0; + for(; m<0x80000000; m<<=1,--exp) ; + int i = m >= L, s; + exp += i; + m >>= 1 + i; + sign ^= msign & 0x8000; + if(exp < -11) + return underflow(sign); + m = divide64(m, L, s); + return fixed2half(m, exp, sign, 1); + } + + /// Hypotenuse square root and postprocessing. + /// \tparam R rounding mode to use + /// \param r mantissa as Q2.30 + /// \param exp unbiased exponent + /// \return square root converted to half-precision + /// \exception FE_OVERFLOW on overflows + /// \exception FE_UNDERFLOW on underflows + /// \exception FE_INEXACT if value had to be rounded + template unsigned int hypot_post(uint32 r, int exp) { + int i = r >> 31; + if((exp+=i) > 46) + return overflow(); + if(exp < -34) + return underflow(); + r = (r>>i) | (r&i); + uint32 m = sqrt<30>(r, exp+=15); + return fixed2half(m, exp-1, 0, r!=0); + } + + /// Division and postprocessing for tangents. + /// \tparam R rounding mode to use + /// \param my dividend as Q1.31 + /// \param mx divisor as Q1.31 + /// \param exp biased exponent of result + /// \param sign sign bit of result + /// \return quotient converted to half-precision + /// \exception FE_OVERFLOW on overflows + /// \exception FE_UNDERFLOW on underflows + /// \exception FE_INEXACT if no other exception occurred + template unsigned int tangent_post(uint32 my, uint32 mx, int exp, unsigned int sign = 0) { + int i = my >= mx, s; + exp += i; + if(exp > 29) + return overflow(sign); + if(exp < -11) + return underflow(sign); + uint32 m = divide64(my>>(i+1), mx, s); + return fixed2half(m, exp, sign, s); + } + + /// Area function and postprocessing. + /// This computes the value directly in Q2.30 using the representation `asinh|acosh(x) = log(x+sqrt(x^2+|-1))`. + /// \tparam R rounding mode to use + /// \tparam S `true` for asinh, `false` for acosh + /// \param arg half-precision argument + /// \return asinh|acosh(\a arg) converted to half-precision + /// \exception FE_OVERFLOW on overflows + /// \exception FE_UNDERFLOW on underflows + /// \exception FE_INEXACT if no other exception occurred + template unsigned int area(unsigned int arg) { + int abs = arg & 0x7FFF, expx = (abs>>10) + (abs<=0x3FF) - 15, expy = -15, ilog, i; + uint32 mx = static_cast((abs&0x3FF)|((abs>0x3FF)<<10)) << 20, my, r; + for(; abs<0x400; abs<<=1,--expy) ; + expy += abs >> 10; + r = ((abs&0x3FF)|0x400) << 5; + r *= r; + i = r >> 31; + expy = 2*expy + i; + r >>= i; + if(S) { + if(expy < 0) { + r = 0x40000000 + ((expy>-30) ? ((r>>-expy)|((r&((static_cast(1)<<-expy)-1))!=0)) : 1); + expy = 0; + } else { + r += 0x40000000 >> expy; + i = r >> 31; + r = (r>>i) | (r&i); + expy += i; + } + } else { + r -= 0x40000000 >> expy; + for(; r<0x40000000; r<<=1,--expy) ; + } + my = sqrt<30>(r, expy); + my = (my<<15) + (r<<14)/my; + if(S) { + mx >>= expy - expx; + ilog = expy; + } else { + my >>= expx - expy; + ilog = expx; + } + my += mx; + i = my >> 31; + static const int G = S && (R==std::round_to_nearest); + return log2_post(log2(my>>i, 26+S+G)+(G<<3), ilog+i, 17, arg&(static_cast(S)<<15)); + } + + /// Class for 1.31 unsigned floating-point computation + struct f31 { + /// Constructor. + /// \param mant mantissa as 1.31 + /// \param e exponent + constexpr f31(uint32 mant, int e) : m(mant), exp(e) {} + + /// Constructor. + /// \param abs unsigned half-precision value + f31(unsigned int abs) : exp(-15) { + for(; abs<0x400; abs<<=1,--exp) ; + m = static_cast((abs&0x3FF)|0x400) << 21; + exp += (abs>>10); + } + + /// Addition operator. + /// \param a first operand + /// \param b second operand + /// \return \a a + \a b + friend f31 operator+(f31 a, f31 b) { + if(b.exp > a.exp) + std::swap(a, b); + int d = a.exp - b.exp; + uint32 m = a.m + ((d<32) ? (b.m>>d) : 0); + int i = (m&0xFFFFFFFF) < a.m; + return f31(((m+i)>>i)|0x80000000, a.exp+i); + } + + /// Subtraction operator. + /// \param a first operand + /// \param b second operand + /// \return \a a - \a b + friend f31 operator-(f31 a, f31 b) { + int d = a.exp - b.exp, exp = a.exp; + uint32 m = a.m - ((d<32) ? (b.m>>d) : 0); + if(!m) + return f31(0, -32); + for(; m<0x80000000; m<<=1,--exp) ; + return f31(m, exp); + } + + /// Multiplication operator. + /// \param a first operand + /// \param b second operand + /// \return \a a * \a b + friend f31 operator*(f31 a, f31 b) { + uint32 m = multiply64(a.m, b.m); + int i = m >> 31; + return f31(m<<(1-i), a.exp + b.exp + i); + } + + /// Division operator. + /// \param a first operand + /// \param b second operand + /// \return \a a / \a b + friend f31 operator/(f31 a, f31 b) { + int i = a.m >= b.m, s; + uint32 m = divide64((a.m+i)>>i, b.m, s); + return f31(m, a.exp - b.exp + i - 1); + } + + uint32 m; ///< mantissa as 1.31. + int exp; ///< exponent. + }; + + /// Error function and postprocessing. + /// This computes the value directly in Q1.31 using the approximations given + /// [here](https://en.wikipedia.org/wiki/Error_function#Approximation_with_elementary_functions). + /// \tparam R rounding mode to use + /// \tparam C `true` for comlementary error function, `false` else + /// \param arg half-precision function argument + /// \return approximated value of error function in half-precision + /// \exception FE_OVERFLOW on overflows + /// \exception FE_UNDERFLOW on underflows + /// \exception FE_INEXACT if no other exception occurred + template unsigned int erf(unsigned int arg) { + unsigned int abs = arg & 0x7FFF, sign = arg & 0x8000; + f31 x(abs), x2 = x * x * f31(0xB8AA3B29, 0), t = f31(0x80000000, 0) / (f31(0x80000000, 0)+f31(0xA7BA054A, -2)*x), t2 = t * t; + f31 e = ((f31(0x87DC2213, 0)*t2+f31(0xB5F0E2AE, 0))*t2+f31(0x82790637, -2)-(f31(0xBA00E2B8, 0)*t2+f31(0x91A98E62, -2))*t) * t / + ((x2.exp<0) ? f31(exp2((x2.exp>-32) ? (x2.m>>-x2.exp) : 0, 30), 0) : f31(exp2((x2.m<>(31-x2.exp))); + return (!C || sign) ? fixed2half(0x80000000-(e.m>>(C-e.exp)), 14+C, sign&(C-1U)) : + (e.exp<-25) ? underflow() : fixed2half(e.m>>1, e.exp+14, 0, e.m&1); + } + + /// Gamma function and postprocessing. + /// This approximates the value of either the gamma function or its logarithm directly in Q1.31. + /// \tparam R rounding mode to use + /// \tparam L `true` for lograithm of gamma function, `false` for gamma function + /// \param arg half-precision floating-point value + /// \return lgamma/tgamma(\a arg) in half-precision + /// \exception FE_OVERFLOW on overflows + /// \exception FE_UNDERFLOW on underflows + /// \exception FE_INEXACT if \a arg is not a positive integer + template unsigned int gamma(unsigned int arg) { +/* static const double p[] ={ 2.50662827563479526904, 225.525584619175212544, -268.295973841304927459, 80.9030806934622512966, -5.00757863970517583837, 0.0114684895434781459556 }; + double t = arg + 4.65, s = p[0]; + for(unsigned int i=0; i<5; ++i) + s += p[i+1] / (arg+i); + return std::log(s) + (arg-0.5)*std::log(t) - t; +*/ static const f31 pi(0xC90FDAA2, 1), lbe(0xB8AA3B29, 0); + unsigned int abs = arg & 0x7FFF, sign = arg & 0x8000; + bool bsign = sign != 0; + f31 z(abs), x = sign ? (z+f31(0x80000000, 0)) : z, t = x + f31(0x94CCCCCD, 2), s = + f31(0xA06C9901, 1) + f31(0xBBE654E2, -7)/(x+f31(0x80000000, 2)) + f31(0xA1CE6098, 6)/(x+f31(0x80000000, 1)) + + f31(0xE1868CB7, 7)/x - f31(0x8625E279, 8)/(x+f31(0x80000000, 0)) - f31(0xA03E158F, 2)/(x+f31(0xC0000000, 1)); + int i = (s.exp>=2) + (s.exp>=4) + (s.exp>=8) + (s.exp>=16); + s = f31((static_cast(s.exp)<<(31-i))+(log2(s.m>>1, 28)>>i), i) / lbe; + if(x.exp != -1 || x.m != 0x80000000) { + i = (t.exp>=2) + (t.exp>=4) + (t.exp>=8); + f31 l = f31((static_cast(t.exp)<<(31-i))+(log2(t.m>>1, 30)>>i), i) / lbe; + s = (x.exp<-1) ? (s-(f31(0x80000000, -1)-x)*l) : (s+(x-f31(0x80000000, -1))*l); + } + s = x.exp ? (s-t) : (t-s); + if(bsign) { + if(z.exp >= 0) { + sign &= (L|((z.m>>(31-z.exp))&1)) - 1; + for(z=f31((z.m<<(1+z.exp))&0xFFFFFFFF, -1); z.m<0x80000000; z.m<<=1,--z.exp) ; + } + if(z.exp == -1) + z = f31(0x80000000, 0) - z; + if(z.exp < -1) { + z = z * pi; + z.m = sincos(z.m>>(1-z.exp), 30).first; + for(z.exp=1; z.m<0x80000000; z.m<<=1,--z.exp) ; + } + else + z = f31(0x80000000, 0); + } if(L) { + if(bsign) { + f31 l(0x92868247, 0); + if(z.exp < 0) { + uint32 m = log2((z.m+1)>>1, 27); + z = f31(-((static_cast(z.exp)<<26)+(m>>5)), 5); + for(; z.m<0x80000000; z.m<<=1,--z.exp) ; + l = l + z / lbe; + } + sign = static_cast(x.exp&&(l.exp(x.exp==0) << 15; + if(s.exp < -24) + return underflow(sign); + if(s.exp > 15) + return overflow(sign); + } + } else { + s = s * lbe; + uint32 m; + if(s.exp < 0) { + m = s.m >> -s.exp; + s.exp = 0; + } else { + m = (s.m<>(31-s.exp)); + } + s.m = exp2(m, 27); + if(!x.exp) + s = f31(0x80000000, 0) / s; + if(bsign) { + if(z.exp < 0) + s = s * z; + s = pi / s; + if(s.exp < -24) + return underflow(sign); + } else if(z.exp > 0 && !(z.m&((1<<(31-z.exp))-1))) + return ((s.exp+14)<<10) + (s.m>>21); + if(s.exp > 15) + return overflow(sign); + } + return fixed2half(s.m, s.exp+14, sign); + } + /// \} + + template struct half_caster; + } + + /// Half-precision floating-point type. + /// This class implements an IEEE-conformant half-precision floating-point type with the usual arithmetic + /// operators and conversions. It is implicitly convertible to single-precision floating-point, which makes artihmetic + /// expressions and functions with mixed-type operands to be of the most precise operand type. + /// + /// According to the C++98/03 definition, the half type is not a POD type. But according to C++11's less strict and + /// extended definitions it is both a standard layout type and a trivially copyable type (even if not a POD type), which + /// means it can be standard-conformantly copied using raw binary copies. But in this context some more words about the + /// actual size of the type. Although the half is representing an IEEE 16-bit type, it does not neccessarily have to be of + /// exactly 16-bits size. But on any reasonable implementation the actual binary representation of this type will most + /// probably not ivolve any additional "magic" or padding beyond the simple binary representation of the underlying 16-bit + /// IEEE number, even if not strictly guaranteed by the standard. But even then it only has an actual size of 16 bits if + /// your C++ implementation supports an unsigned integer type of exactly 16 bits width. But this should be the case on + /// nearly any reasonable platform. + /// + /// So if your C++ implementation is not totally exotic or imposes special alignment requirements, it is a reasonable + /// assumption that the data of a half is just comprised of the 2 bytes of the underlying IEEE representation. + class half { + public: + /// \name Construction and assignment + /// \{ + + /// Default constructor. + /// This initializes the half to 0. Although this does not match the builtin types' default-initialization semantics + /// and may be less efficient than no initialization, it is needed to provide proper value-initialization semantics. + constexpr half() noexcept : data_() {} + + /// Conversion constructor. + /// \param rhs float to convert + /// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding + //explicit half(float rhs) : data_(static_cast(detail::float2half(rhs))) {} + + /// Conversion constructor. + /// \param rhs float to convert + /// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding + template + half(T rhs) : data_(static_cast(detail::float2half(static_cast(rhs)))) {} + + /// Conversion to single-precision. + /// \return single precision value representing expression value + operator float() const { return detail::half2float(data_); } + + /// Assignment operator. + /// \param rhs single-precision value to copy from + /// \return reference to this half + /// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding + half& operator=(const float &rhs) { data_ = static_cast(detail::float2half(rhs)); return *this; } + + template + half& operator=(const T &rhs) { return *this = static_cast(rhs); } + + /// \} + /// \name Arithmetic updates + /// \{ + + /// Arithmetic assignment. + /// \tparam T type of concrete half expression + /// \param rhs half expression to add + /// \return reference to this half + /// \exception FE_... according to operator+(half,half) + half& operator+=(half rhs) { return *this = *this + rhs; } + + /// Arithmetic assignment. + /// \tparam T type of concrete half expression + /// \param rhs half expression to subtract + /// \return reference to this half + /// \exception FE_... according to operator-(half,half) + half& operator-=(half rhs) { return *this = *this - rhs; } + + /// Arithmetic assignment. + /// \tparam T type of concrete half expression + /// \param rhs half expression to multiply with + /// \return reference to this half + /// \exception FE_... according to operator*(half,half) + half& operator*=(half rhs) { return *this = *this * rhs; } + + /// Arithmetic assignment. + /// \tparam T type of concrete half expression + /// \param rhs half expression to divide by + /// \return reference to this half + /// \exception FE_... according to operator/(half,half) + half& operator/=(half rhs) { return *this = *this / rhs; } + + /* + /// Arithmetic assignment. + /// \param rhs single-precision value to add + /// \return reference to this half + /// \exception FE_... according to operator=() + half& operator+=(float rhs) { return *this = *this + rhs; } + + /// Arithmetic assignment. + /// \param rhs single-precision value to subtract + /// \return reference to this half + /// \exception FE_... according to operator=() + half& operator-=(float rhs) { return *this = *this - rhs; } + + /// Arithmetic assignment. + /// \param rhs single-precision value to multiply with + /// \return reference to this half + /// \exception FE_... according to operator=() + half& operator*=(float rhs) { return *this = *this * rhs; } + + /// Arithmetic assignment. + /// \param rhs single-precision value to divide by + /// \return reference to this half + /// \exception FE_... according to operator=() + half& operator/=(float rhs) { return *this = *this / rhs; } + */ + + /// \} + /// \name Increment and decrement + /// \{ + + /// Prefix increment. + /// \return incremented half value + /// \exception FE_... according to operator+(half,half) + half& operator++() { return *this = *this + half(detail::binary, 0x3C00); } + + /// Prefix decrement. + /// \return decremented half value + /// \exception FE_... according to operator-(half,half) + half& operator--() { return *this = *this + half(detail::binary, 0xBC00); } + + /// Postfix increment. + /// \return non-incremented half value + /// \exception FE_... according to operator+(half,half) + half operator++(int) { half out(*this); ++*this; return out; } + + /// Postfix decrement. + /// \return non-decremented half value + /// \exception FE_... according to operator-(half,half) + half operator--(int) { half out(*this); --*this; return out; } + /// \} + detail::uint16 get_data()const{ return data_; } + + private: + /// Rounding mode to use + static const std::float_round_style round_style = (std::float_round_style)(HALF_ROUND_STYLE); + + /// Constructor. + /// \param bits binary representation to set half to + constexpr half(detail::binary_t, unsigned int bits) noexcept : data_(static_cast(bits)) {} + + /// Internal binary representation + detail::uint16 data_; + + friend constexpr_NOERR bool operator==(half, half); + template friend constexpr_NOERR bool operator==(half, T); + template friend constexpr_NOERR bool operator==(T, half); + friend constexpr_NOERR bool operator!=(half, half); + template friend constexpr_NOERR bool operator!=(half, T); + template friend constexpr_NOERR bool operator!=(T, half); + friend constexpr_NOERR bool operator<(half, half); + template friend constexpr_NOERR bool operator<(half, T); + template friend constexpr_NOERR bool operator<(T, half); + friend constexpr_NOERR bool operator>(half, half); + template friend constexpr_NOERR bool operator>(half, T); + template friend constexpr_NOERR bool operator>(T, half); + friend constexpr_NOERR bool operator<=(half, half); + template friend constexpr_NOERR bool operator<=(half, T); + template friend constexpr_NOERR bool operator<=(T, half); + friend constexpr_NOERR bool operator>=(half, half); + template friend constexpr_NOERR bool operator>=(half, T); + template friend constexpr_NOERR bool operator>=(T, half); + friend constexpr half operator+(half); + friend constexpr half operator-(half); + friend half operator+(half, half); + template friend half operator+(half, T); + template friend half operator+(T, half); + friend half operator-(half, half); + template friend half operator-(half, T); + template friend half operator-(T, half); + friend half operator*(half, half); + template friend half operator*(half, T); + template friend half operator*(T, half); + friend half operator/(half, half); + template friend half operator/(half, T); + template friend half operator/(T, half); + template friend std::basic_ostream& operator<<(std::basic_ostream&, half); + template friend std::basic_istream& operator>>(std::basic_istream&, half&); + friend constexpr half fabs(half); + friend half fmod(half, half); + friend half remainder(half, half); + friend half remquo(half, half, int*); + friend half fma(half, half, half); + friend constexpr_NOERR half fmax(half, half); + friend constexpr_NOERR half fmin(half, half); + friend half fdim(half, half); + friend half nanh(const char*); + friend half exp(half); + friend half exp2(half); + friend half expm1(half); + friend half log(half); + friend half log10(half); + friend half log2(half); + friend half log1p(half); + friend half sqrt(half); + friend half cbrt(half); + friend half hypot(half, half); + friend half hypot(half, half, half); + friend half pow(half, half); + friend void sincos(half, half*, half*); + friend half sin(half); + friend half cos(half); + friend half tan(half); + friend half asin(half); + friend half acos(half); + friend half atan(half); + friend half atan2(half, half); + friend half sinh(half); + friend half cosh(half); + friend half tanh(half); + friend half asinh(half); + friend half acosh(half); + friend half atanh(half); + friend half erf(half); + friend half erfc(half); + friend half lgamma(half); + friend half tgamma(half); + friend half ceil(half); + friend half floor(half); + friend half trunc(half); + friend half round(half); + friend long lround(half); + friend half rint(half); + friend long lrint(half); + friend half nearbyint(half); + friend long long llround(half); + friend long long llrint(half); + friend half frexp(half, int*); + friend half scalbln(half, long); + friend half modf(half, half*); + friend int ilogb(half); + friend half logb(half); + friend half nextafter(half, half); + friend half nexttoward(half, long double); + friend constexpr half copysign(half, half); + friend constexpr int fpclassify(half); + friend constexpr bool isfinite(half); + friend constexpr bool isinf(half); + friend constexpr bool isnan(half); + friend constexpr bool isnormal(half); + friend constexpr bool signbit(half); + friend constexpr bool isgreater(half, half); + friend constexpr bool isgreaterequal(half, half); + friend constexpr bool isless(half, half); + friend constexpr bool islessequal(half, half); + friend constexpr bool islessgreater(half, half); + template friend struct detail::half_caster; + friend class std::numeric_limits; + friend struct std::hash; + friend half literal::operator "" _h(long double); + }; + + namespace literal { + /// Half literal. + /// While this returns a properly rounded half-precision value, half literals can unfortunately not be constant + /// expressions due to rather involved conversions. So don't expect this to be a literal literal without involving + /// conversion operations at runtime. It is a convenience feature, not a performance optimization. + /// \param value literal value + /// \return half with of given value (possibly rounded) + /// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding + inline half operator "" _h(long double value) { return half(detail::binary, detail::float2half(value)); } + } + + namespace detail { + /// Helper class for half casts. + /// This class template has to be specialized for all valid cast arguments to define an appropriate static + /// `cast` member function and a corresponding `type` member denoting its return type. + /// \tparam T destination type + /// \tparam U source type + /// \tparam R rounding mode to use + template struct half_caster {}; + template struct half_caster { + static_assert(std::is_arithmetic::value, "half_cast from non-arithmetic type unsupported"); + static half cast(U arg) { return cast_impl(arg, is_float()); }; + private: + static half cast_impl(U arg, true_type) { return half(binary, float2half(arg)); } + static half cast_impl(U arg, false_type) { return half(binary, int2half(arg)); } + }; + template struct half_caster { + static_assert(std::is_arithmetic::value, "half_cast to non-arithmetic type unsupported"); + static T cast(half arg) { return cast_impl(arg, is_float()); } + private: + static T cast_impl(half arg, true_type) { return half2float(arg.data_); } + static T cast_impl(half arg, false_type) { return half2int(arg.data_); } + }; + template struct half_caster { + static half cast(half arg) { return arg; } + }; + } +} + +/// Extensions to the C++ standard library. +namespace std { + /// Numeric limits for half-precision floats. + /// **See also:** Documentation for [std::numeric_limits](https://en.cppreference.com/w/cpp/types/numeric_limits) + template<> class numeric_limits { + public: + /// Is template specialization. + static constexpr bool is_specialized = true; + + /// Supports signed values. + static constexpr bool is_signed = true; + + /// Is not an integer type. + static constexpr bool is_integer = false; + + /// Is not exact. + static constexpr bool is_exact = false; + + /// Doesn't provide modulo arithmetic. + static constexpr bool is_modulo = false; + + /// Has a finite set of values. + static constexpr bool is_bounded = true; + + /// IEEE conformant. + static constexpr bool is_iec559 = true; + + /// Supports infinity. + static constexpr bool has_infinity = true; + + /// Supports quiet NaNs. + static constexpr bool has_quiet_NaN = true; + + /// Supports signaling NaNs. + static constexpr bool has_signaling_NaN = true; + + /// Supports subnormal values. + static constexpr float_denorm_style has_denorm = denorm_present; + + /// Supports no denormalization detection. + static constexpr bool has_denorm_loss = false; + + #if HALF_ERRHANDLING_THROWS + static constexpr bool traps = true; + #else + /// Traps only if [HALF_ERRHANDLING_THROW_...](\ref HALF_ERRHANDLING_THROW_INVALID) is acitvated. + static constexpr bool traps = false; + #endif + + /// Does not support no pre-rounding underflow detection. + static constexpr bool tinyness_before = false; + + /// Rounding mode. + static constexpr float_round_style round_style = half_float::half::round_style; + + /// Significant digits. + static constexpr int digits = 11; + + /// Significant decimal digits. + static constexpr int digits10 = 3; + + /// Required decimal digits to represent all possible values. + static constexpr int max_digits10 = 5; + + /// Number base. + static constexpr int radix = 2; + + /// One more than smallest exponent. + static constexpr int min_exponent = -13; + + /// Smallest normalized representable power of 10. + static constexpr int min_exponent10 = -4; + + /// One more than largest exponent + static constexpr int max_exponent = 16; + + /// Largest finitely representable power of 10. + static constexpr int max_exponent10 = 4; + + /// Smallest positive normal value. + static constexpr half_float::half min() noexcept { return half_float::half(half_float::detail::binary, 0x0400); } + + /// Smallest finite value. + static constexpr half_float::half lowest() noexcept { return half_float::half(half_float::detail::binary, 0xFBFF); } + + /// Largest finite value. + static constexpr half_float::half max() noexcept { return half_float::half(half_float::detail::binary, 0x7BFF); } + + /// Difference between 1 and next representable value. + static constexpr half_float::half epsilon() noexcept { return half_float::half(half_float::detail::binary, 0x1400); } + + /// Maximum rounding error in ULP (units in the last place). + static constexpr half_float::half round_error() noexcept + { return half_float::half(half_float::detail::binary, (round_style==std::round_to_nearest) ? 0x3800 : 0x3C00); } + + /// Positive infinity. + static constexpr half_float::half infinity() noexcept { return half_float::half(half_float::detail::binary, 0x7C00); } + + /// Quiet NaN. + static constexpr half_float::half quiet_NaN() noexcept { return half_float::half(half_float::detail::binary, 0x7FFF); } + + /// Signaling NaN. + static constexpr half_float::half signaling_NaN() noexcept { return half_float::half(half_float::detail::binary, 0x7DFF); } + + /// Smallest positive subnormal value. + static constexpr half_float::half denorm_min() noexcept { return half_float::half(half_float::detail::binary, 0x0001); } + }; + + /// Hash function for half-precision floats. + /// **See also:** Documentation for [std::hash](https://en.cppreference.com/w/cpp/utility/hash) + template<> struct hash { + /// Type of function argument. + typedef half_float::half argument_type; + + /// Function return type. + typedef size_t result_type; + + /// Compute hash function. + /// \param arg half to hash + /// \return hash value + result_type operator()(argument_type arg) const { return hash()(arg.data_&-static_cast(arg.data_!=0x8000)); } + }; +} + +namespace half_float { + /// \anchor compop + /// \name Comparison operators + /// \{ + + /// Comparison for equality. + /// \param x first operand + /// \param y second operand + /// \retval true if operands equal + /// \retval false else + /// \exception FE_INVALID if \a x or \a y is NaN + inline constexpr_NOERR bool operator==(half x, half y) { + return !detail::compsignal(x.data_, y.data_) && (x.data_==y.data_ || !((x.data_|y.data_)&0x7FFF)); + } + template + inline constexpr_NOERR bool operator==(half x, T y) { return x == static_cast(y); } + template + inline constexpr_NOERR bool operator==(T x, half y) { return static_cast(x) == y; } + + /// Comparison for inequality. + /// \param x first operand + /// \param y second operand + /// \retval true if operands not equal + /// \retval false else + /// \exception FE_INVALID if \a x or \a y is NaN + inline constexpr_NOERR bool operator!=(half x, half y) { + return detail::compsignal(x.data_, y.data_) || (x.data_!=y.data_ && ((x.data_|y.data_)&0x7FFF)); + } + template + inline constexpr_NOERR bool operator!=(half x, T y) { return x != static_cast(y); } + template + inline constexpr_NOERR bool operator!=(T x, half y) { return static_cast(x) != y; } + + /// Comparison for less than. + /// \param x first operand + /// \param y second operand + /// \retval true if \a x less than \a y + /// \retval false else + /// \exception FE_INVALID if \a x or \a y is NaN + inline constexpr_NOERR bool operator<(half x, half y) { + return !detail::compsignal(x.data_, y.data_) && + ((x.data_^(0x8000|(0x8000-(x.data_>>15))))+(x.data_>>15)) < ((y.data_^(0x8000|(0x8000-(y.data_>>15))))+(y.data_>>15)); + } + template + inline constexpr_NOERR bool operator<(half x, T y) { return x < static_cast(y); } + template + inline constexpr_NOERR bool operator<(T x, half y) { return static_cast(x) < y; } + + /// Comparison for greater than. + /// \param x first operand + /// \param y second operand + /// \retval true if \a x greater than \a y + /// \retval false else + /// \exception FE_INVALID if \a x or \a y is NaN + inline constexpr_NOERR bool operator>(half x, half y) { + return !detail::compsignal(x.data_, y.data_) && + ((x.data_^(0x8000|(0x8000-(x.data_>>15))))+(x.data_>>15)) > ((y.data_^(0x8000|(0x8000-(y.data_>>15))))+(y.data_>>15)); + } + template + inline constexpr_NOERR bool operator>(half x, T y) { return x > static_cast(y); } + template + inline constexpr_NOERR bool operator>(T x, half y) { return static_cast(x) > y; } + + /// Comparison for less equal. + /// \param x first operand + /// \param y second operand + /// \retval true if \a x less equal \a y + /// \retval false else + /// \exception FE_INVALID if \a x or \a y is NaN + inline constexpr_NOERR bool operator<=(half x, half y) { + return !detail::compsignal(x.data_, y.data_) && + ((x.data_^(0x8000|(0x8000-(x.data_>>15))))+(x.data_>>15)) <= ((y.data_^(0x8000|(0x8000-(y.data_>>15))))+(y.data_>>15)); + } + template + inline constexpr_NOERR bool operator<=(half x, T y) { return x <= static_cast(y); } + template + inline constexpr_NOERR bool operator<=(T x, half y) { return static_cast(x) <= y; } + + /// Comparison for greater equal. + /// \param x first operand + /// \param y second operand + /// \retval true if \a x greater equal \a y + /// \retval false else + /// \exception FE_INVALID if \a x or \a y is NaN + inline constexpr_NOERR bool operator>=(half x, half y) { + return !detail::compsignal(x.data_, y.data_) && + ((x.data_^(0x8000|(0x8000-(x.data_>>15))))+(x.data_>>15)) >= ((y.data_^(0x8000|(0x8000-(y.data_>>15))))+(y.data_>>15)); + } + template + inline constexpr_NOERR bool operator>=(half x, T y) { return x >= static_cast(y); } + template + inline constexpr_NOERR bool operator>=(T x, half y) { return static_cast(x) >= y; } + + /// \} + /// \anchor arithmetics + /// \name Arithmetic operators + /// \{ + + /// Identity. + /// \param arg operand + /// \return unchanged operand + inline constexpr half operator+(half arg) { return arg; } + + /// Negation. + /// \param arg operand + /// \return negated operand + inline constexpr half operator-(half arg) { return half(detail::binary, arg.data_^0x8000); } + + /// Addition. + /// This operation is exact to rounding for all rounding modes. + /// \param x left operand + /// \param y right operand + /// \return sum of half expressions + /// \exception FE_INVALID if \a x and \a y are infinities with different signs or signaling NaNs + /// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding + inline half operator+(half x, half y) { + #ifdef HALF_ARITHMETIC_TYPE + return half(detail::binary, detail::float2half(detail::half2float(x.data_)+detail::half2float(y.data_))); + #else + int absx = x.data_ & 0x7FFF, absy = y.data_ & 0x7FFF; + bool sub = ((x.data_^y.data_)&0x8000) != 0; + if(absx >= 0x7C00 || absy >= 0x7C00) + return half(detail::binary, (absx>0x7C00 || absy>0x7C00) ? detail::signal(x.data_, y.data_) : (absy!=0x7C00) ? x.data_ : + (sub && absx==0x7C00) ? detail::invalid() : y.data_); + if(!absx) + return absy ? y : half(detail::binary, (half::round_style==std::round_toward_neg_infinity) ? (x.data_|y.data_) : (x.data_&y.data_)); + if(!absy) + return x; + unsigned int sign = ((sub && absy>absx) ? y.data_ : x.data_) & 0x8000; + if(absy > absx) + std::swap(absx, absy); + int exp = (absx>>10) + (absx<=0x3FF), d = exp - (absy>>10) - (absy<=0x3FF), mx = ((absx&0x3FF)|((absx>0x3FF)<<10)) << 3, my; + if(d < 13) { + my = ((absy&0x3FF)|((absy>0x3FF)<<10)) << 3; + my = (my>>d) | ((my&((1<(half::round_style==std::round_toward_neg_infinity)<<15); + for(; mx<0x2000 && exp>1; mx<<=1,--exp) ; + } else { + mx += my; + int i = mx >> 14; + if((exp+=i) > 30) + return half(detail::binary, detail::overflow(sign)); + mx = (mx>>i) | (mx&i); + } + return half(detail::binary, detail::rounded(sign+((exp-1)<<10)+(mx>>3), (mx>>2)&1, (mx&0x3)!=0)); + #endif + } + template + inline half operator+(half x, T y) { return x + static_cast(y); } + template + inline half operator+(T x, half y) { return static_cast(x) + y; } + + /// Subtraction. + /// This operation is exact to rounding for all rounding modes. + /// \param x left operand + /// \param y right operand + /// \return difference of half expressions + /// \exception FE_INVALID if \a x and \a y are infinities with equal signs or signaling NaNs + /// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding + inline half operator-(half x, half y) { + #ifdef HALF_ARITHMETIC_TYPE + return half(detail::binary, detail::float2half(detail::half2float(x.data_)-detail::half2float(y.data_))); + #else + return x + (-y); + #endif + } + template + inline half operator-(half x, T y) { return x - static_cast(y); } + template + inline half operator-(T x, half y) { return static_cast(x) - y; } + + /// Multiplication. + /// This operation is exact to rounding for all rounding modes. + /// \param x left operand + /// \param y right operand + /// \return product of half expressions + /// \exception FE_INVALID if multiplying 0 with infinity or if \a x or \a y is signaling NaN + /// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding + inline half operator*(half x, half y) { + #ifdef HALF_ARITHMETIC_TYPE + return half(detail::binary, detail::float2half(detail::half2float(x.data_)*detail::half2float(y.data_))); + #else + int absx = x.data_ & 0x7FFF, absy = y.data_ & 0x7FFF, exp = -16; + unsigned int sign = (x.data_^y.data_) & 0x8000; + if(absx >= 0x7C00 || absy >= 0x7C00) + return half(detail::binary, (absx>0x7C00 || absy>0x7C00) ? detail::signal(x.data_, y.data_) : + ((absx==0x7C00 && !absy)||(absy==0x7C00 && !absx)) ? detail::invalid() : (sign|0x7C00)); + if(!absx || !absy) + return half(detail::binary, sign); + for(; absx<0x400; absx<<=1,--exp) ; + for(; absy<0x400; absy<<=1,--exp) ; + detail::uint32 m = static_cast((absx&0x3FF)|0x400) * static_cast((absy&0x3FF)|0x400); + int i = m >> 21, s = m & i; + exp += (absx>>10) + (absy>>10) + i; + if(exp > 29) + return half(detail::binary, detail::overflow(sign)); + else if(exp < -11) + return half(detail::binary, detail::underflow(sign)); + return half(detail::binary, detail::fixed2half(m>>i, exp, sign, s)); + #endif + } + template + inline half operator*(half x, T y) { return x * static_cast(y); } + template + inline half operator*(T x, half y) { return static_cast(x) * y; } + + /// Division. + /// This operation is exact to rounding for all rounding modes. + /// \param x left operand + /// \param y right operand + /// \return quotient of half expressions + /// \exception FE_INVALID if dividing 0s or infinities with each other or if \a x or \a y is signaling NaN + /// \exception FE_DIVBYZERO if dividing finite value by 0 + /// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding + inline half operator/(half x, half y) { + #ifdef HALF_ARITHMETIC_TYPE + return half(detail::binary, detail::float2half(detail::half2float(x.data_)/detail::half2float(y.data_))); + #else + int absx = x.data_ & 0x7FFF, absy = y.data_ & 0x7FFF, exp = 14; + unsigned int sign = (x.data_^y.data_) & 0x8000; + if(absx >= 0x7C00 || absy >= 0x7C00) + return half(detail::binary, (absx>0x7C00 || absy>0x7C00) ? detail::signal(x.data_, y.data_) : + (absx==absy) ? detail::invalid() : (sign|((absx==0x7C00) ? 0x7C00 : 0))); + if(!absx) + return half(detail::binary, absy ? sign : detail::invalid()); + if(!absy) + return half(detail::binary, detail::pole(sign)); + for(; absx<0x400; absx<<=1,--exp) ; + for(; absy<0x400; absy<<=1,++exp) ; + detail::uint32 mx = (absx&0x3FF) | 0x400, my = (absy&0x3FF) | 0x400; + int i = mx < my; + exp += (absx>>10) - (absy>>10) - i; + if(exp > 29) + return half(detail::binary, detail::overflow(sign)); + else if(exp < -11) + return half(detail::binary, detail::underflow(sign)); + mx <<= 12 + i; + my <<= 1; + return half(detail::binary, detail::fixed2half(mx/my, exp, sign, mx%my!=0)); + #endif + } + template + inline half operator/(half x, T y) { return x / static_cast(y); } + template + inline half operator/(T x, half y) { return static_cast(x) / y; } + + /// \} + /// \anchor streaming + /// \name Input and output + /// \{ + + /// Output operator. + /// This uses the built-in functionality for streaming out floating-point numbers. + /// \param out output stream to write into + /// \param arg half expression to write + /// \return reference to output stream + template std::basic_ostream& operator<<(std::basic_ostream &out, half arg) { + #ifdef HALF_ARITHMETIC_TYPE + return out << detail::half2float(arg.data_); + #else + return out << detail::half2float(arg.data_); + #endif + } + + /// Input operator. + /// This uses the built-in functionality for streaming in floating-point numbers, specifically double precision floating + /// point numbers (unless overridden with [HALF_ARITHMETIC_TYPE](\ref HALF_ARITHMETIC_TYPE)). So the input string is first + /// rounded to double precision using the underlying platform's current floating-point rounding mode before being rounded + /// to half-precision using the library's half-precision rounding mode. + /// \param in input stream to read from + /// \param arg half to read into + /// \return reference to input stream + /// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding + template std::basic_istream& operator>>(std::basic_istream &in, half &arg) { + #ifdef HALF_ARITHMETIC_TYPE + detail::internal_t f; + #else + double f; + #endif + if(in >> f) + arg.data_ = detail::float2half(f); + return in; + } + + /// \} + /// \anchor basic + /// \name Basic mathematical operations + /// \{ + + /// Absolute value. + /// **See also:** Documentation for [std::fabs](https://en.cppreference.com/w/cpp/numeric/math/fabs). + /// \param arg operand + /// \return absolute value of \a arg + inline constexpr half fabs(half arg) { return half(detail::binary, arg.data_&0x7FFF); } + + /// Absolute value. + /// **See also:** Documentation for [std::abs](https://en.cppreference.com/w/cpp/numeric/math/fabs). + /// \param arg operand + /// \return absolute value of \a arg + inline constexpr half abs(half arg) { return fabs(arg); } + + /// Remainder of division. + /// **See also:** Documentation for [std::fmod](https://en.cppreference.com/w/cpp/numeric/math/fmod). + /// \param x first operand + /// \param y second operand + /// \return remainder of floating-point division. + /// \exception FE_INVALID if \a x is infinite or \a y is 0 or if \a x or \a y is signaling NaN + inline half fmod(half x, half y) { + unsigned int absx = x.data_ & 0x7FFF, absy = y.data_ & 0x7FFF, sign = x.data_ & 0x8000; + if(absx >= 0x7C00 || absy >= 0x7C00) + return half(detail::binary, (absx>0x7C00 || absy>0x7C00) ? detail::signal(x.data_, y.data_) : + (absx==0x7C00) ? detail::invalid() : x.data_); + if(!absy) + return half(detail::binary, detail::invalid()); + if(!absx) + return x; + if(absx == absy) + return half(detail::binary, sign); + return half(detail::binary, sign|detail::mod(absx, absy)); + } + + /// Remainder of division. + /// **See also:** Documentation for [std::remainder](https://en.cppreference.com/w/cpp/numeric/math/remainder). + /// \param x first operand + /// \param y second operand + /// \return remainder of floating-point division. + /// \exception FE_INVALID if \a x is infinite or \a y is 0 or if \a x or \a y is signaling NaN + inline half remainder(half x, half y) { + unsigned int absx = x.data_ & 0x7FFF, absy = y.data_ & 0x7FFF, sign = x.data_ & 0x8000; + if(absx >= 0x7C00 || absy >= 0x7C00) + return half(detail::binary, (absx>0x7C00 || absy>0x7C00) ? detail::signal(x.data_, y.data_) : + (absx==0x7C00) ? detail::invalid() : x.data_); + if(!absy) + return half(detail::binary, detail::invalid()); + if(absx == absy) + return half(detail::binary, sign); + return half(detail::binary, sign^detail::mod(absx, absy)); + } + + /// Remainder of division. + /// **See also:** Documentation for [std::remquo](https://en.cppreference.com/w/cpp/numeric/math/remquo). + /// \param x first operand + /// \param y second operand + /// \param quo address to store some bits of quotient at + /// \return remainder of floating-point division. + /// \exception FE_INVALID if \a x is infinite or \a y is 0 or if \a x or \a y is signaling NaN + inline half remquo(half x, half y, int *quo) { + unsigned int absx = x.data_ & 0x7FFF, absy = y.data_ & 0x7FFF, value = x.data_ & 0x8000; + if(absx >= 0x7C00 || absy >= 0x7C00) + return half(detail::binary, (absx>0x7C00 || absy>0x7C00) ? detail::signal(x.data_, y.data_) : + (absx==0x7C00) ? detail::invalid() : (*quo = 0, x.data_)); + if(!absy) + return half(detail::binary, detail::invalid()); + bool qsign = ((value^y.data_)&0x8000) != 0; + int q = 1; + if(absx != absy) + value ^= detail::mod(absx, absy, &q); + return *quo = qsign ? -q : q, half(detail::binary, value); + } + + /// Fused multiply add. + /// This function is exact to rounding for all rounding modes. + /// + /// **See also:** Documentation for [std::fma](https://en.cppreference.com/w/cpp/numeric/math/fma). + /// \param x first operand + /// \param y second operand + /// \param z third operand + /// \return ( \a x * \a y ) + \a z rounded as one operation. + /// \exception FE_INVALID according to operator*() and operator+() unless any argument is a quiet NaN and no argument is a signaling NaN + /// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding the final addition + inline half fma(half x, half y, half z) { + #ifdef HALF_ARITHMETIC_TYPE + detail::internal_t fx = detail::half2float(x.data_), fy = detail::half2float(y.data_), fz = detail::half2float(z.data_); + #if FP_FAST_FMA + return half(detail::binary, detail::float2half(std::fma(fx, fy, fz))); + #else + return half(detail::binary, detail::float2half(fx*fy+fz)); + #endif + #else + int absx = x.data_ & 0x7FFF, absy = y.data_ & 0x7FFF, absz = z.data_ & 0x7FFF, exp = -15; + unsigned int sign = (x.data_^y.data_) & 0x8000; + bool sub = ((sign^z.data_)&0x8000) != 0; + if(absx >= 0x7C00 || absy >= 0x7C00 || absz >= 0x7C00) + return (absx>0x7C00 || absy>0x7C00 || absz>0x7C00) ? half(detail::binary, detail::signal(x.data_, y.data_, z.data_)) : + (absx==0x7C00) ? half(detail::binary, (!absy || (sub && absz==0x7C00)) ? detail::invalid() : (sign|0x7C00)) : + (absy==0x7C00) ? half(detail::binary, (!absx || (sub && absz==0x7C00)) ? detail::invalid() : (sign|0x7C00)) : z; + if(!absx || !absy) + return absz ? z : half(detail::binary, (half::round_style==std::round_toward_neg_infinity) ? (z.data_|sign) : (z.data_&sign)); + for(; absx<0x400; absx<<=1,--exp) ; + for(; absy<0x400; absy<<=1,--exp) ; + detail::uint32 m = static_cast((absx&0x3FF)|0x400) * static_cast((absy&0x3FF)|0x400); + int i = m >> 21; + exp += (absx>>10) + (absy>>10) + i; + m <<= 3 - i; + if(absz) { + int expz = 0; + for(; absz<0x400; absz<<=1,--expz) ; + expz += absz >> 10; + detail::uint32 mz = static_cast((absz&0x3FF)|0x400) << 13; + if(expz > exp || (expz == exp && mz > m)) { + std::swap(m, mz); + std::swap(exp, expz); + if(sub) + sign = z.data_ & 0x8000; + } + int d = exp - expz; + mz = (d<23) ? ((mz>>d)|((mz&((static_cast(1)<(half::round_style==std::round_toward_neg_infinity)<<15); + for(; m<0x800000; m<<=1,--exp) ; + } else { + m += mz; + i = m >> 24; + m = (m>>i) | (m&i); + exp += i; + } + } + if(exp > 30) + return half(detail::binary, detail::overflow(sign)); + else if(exp < -10) + return half(detail::binary, detail::underflow(sign)); + return half(detail::binary, detail::fixed2half(m, exp-1, sign)); + #endif + } + + /// Maximum of half expressions. + /// **See also:** Documentation for [std::fmax](https://en.cppreference.com/w/cpp/numeric/math/fmax). + /// \param x first operand + /// \param y second operand + /// \return maximum of operands, ignoring quiet NaNs + /// \exception FE_INVALID if \a x or \a y is signaling NaN + inline constexpr_NOERR half fmax(half x, half y) { + return half(detail::binary, (!isnan(y) && (isnan(x) || (x.data_^(0x8000|(0x8000-(x.data_>>15)))) < + (y.data_^(0x8000|(0x8000-(y.data_>>15)))))) ? detail::select(y.data_, x.data_) : detail::select(x.data_, y.data_)); + } + + /// Minimum of half expressions. + /// **See also:** Documentation for [std::fmin](https://en.cppreference.com/w/cpp/numeric/math/fmin). + /// \param x first operand + /// \param y second operand + /// \return minimum of operands, ignoring quiet NaNs + /// \exception FE_INVALID if \a x or \a y is signaling NaN + inline constexpr_NOERR half fmin(half x, half y) { + return half(detail::binary, (!isnan(y) && (isnan(x) || (x.data_^(0x8000|(0x8000-(x.data_>>15)))) > + (y.data_^(0x8000|(0x8000-(y.data_>>15)))))) ? detail::select(y.data_, x.data_) : detail::select(x.data_, y.data_)); + } + + /// Positive difference. + /// This function is exact to rounding for all rounding modes. + /// + /// **See also:** Documentation for [std::fdim](https://en.cppreference.com/w/cpp/numeric/math/fdim). + /// \param x first operand + /// \param y second operand + /// \return \a x - \a y or 0 if difference negative + /// \exception FE_... according to operator-(half,half) + inline half fdim(half x, half y) { + if(isnan(x) || isnan(y)) + return half(detail::binary, detail::signal(x.data_, y.data_)); + return (x.data_^(0x8000|(0x8000-(x.data_>>15)))) <= (y.data_^(0x8000|(0x8000-(y.data_>>15)))) ? half(detail::binary, 0) : (x-y); + } + + /// Get NaN value. + /// **See also:** Documentation for [std::nan](https://en.cppreference.com/w/cpp/numeric/math/nan). + /// \param arg string code + /// \return quiet NaN + inline half nanh(const char *arg) { + unsigned int value = 0x7FFF; + while(*arg) + value ^= static_cast(*arg++) & 0xFF; + return half(detail::binary, value); + } + + /// \} + /// \anchor exponential + /// \name Exponential functions + /// \{ + + /// Exponential function. + /// This function is exact to rounding for all rounding modes. + /// + /// **See also:** Documentation for [std::exp](https://en.cppreference.com/w/cpp/numeric/math/exp). + /// \param arg function argument + /// \return e raised to \a arg + /// \exception FE_INVALID for signaling NaN + /// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding + inline half exp(half arg) { + #ifdef HALF_ARITHMETIC_TYPE + return half(detail::binary, detail::float2half(std::exp(detail::half2float(arg.data_)))); + #else + int abs = arg.data_ & 0x7FFF; + if(!abs) + return half(detail::binary, 0x3C00); + if(abs >= 0x7C00) + return half(detail::binary, (abs==0x7C00) ? (0x7C00&((arg.data_>>15)-1U)) : detail::signal(arg.data_)); + if(abs >= 0x4C80) + return half(detail::binary, (arg.data_&0x8000) ? detail::underflow() : detail::overflow()); + detail::uint32 m = detail::multiply64(static_cast((abs&0x3FF)+((abs>0x3FF)<<10))<<21, 0xB8AA3B29); + int e = (abs>>10) + (abs<=0x3FF), exp; + if(e < 14) { + exp = 0; + m >>= 14 - e; + } else { + exp = m >> (45-e); + m = (m<<(e-14)) & 0x7FFFFFFF; + } + return half(detail::binary, detail::exp2_post(detail::exp2(m, 26), exp, (arg.data_&0x8000)!=0)); + #endif + } + + /// Binary exponential. + /// This function is exact to rounding for all rounding modes. + /// + /// **See also:** Documentation for [std::exp2](https://en.cppreference.com/w/cpp/numeric/math/exp2). + /// \param arg function argument + /// \return 2 raised to \a arg + /// \exception FE_INVALID for signaling NaN + /// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding + inline half exp2(half arg) { + #if defined(HALF_ARITHMETIC_TYPE) + return half(detail::binary, detail::float2half(std::exp2(detail::half2float(arg.data_)))); + #else + int abs = arg.data_ & 0x7FFF; + if(!abs) + return half(detail::binary, 0x3C00); + if(abs >= 0x7C00) + return half(detail::binary, (abs==0x7C00) ? (0x7C00&((arg.data_>>15)-1U)) : detail::signal(arg.data_)); + if(abs >= 0x4E40) + return half(detail::binary, (arg.data_&0x8000) ? detail::underflow() : detail::overflow()); + int e = (abs>>10) + (abs<=0x3FF), exp = (abs&0x3FF) + ((abs>0x3FF)<<10); + detail::uint32 m = detail::exp2((static_cast(exp)<<(6+e))&0x7FFFFFFF, 28); + exp >>= 25 - e; + if(m == 0x80000000) { + if(arg.data_&0x8000) + exp = -exp; + else if(exp > 15) + return half(detail::binary, detail::overflow()); + return half(detail::binary, detail::fixed2half(m, exp+14)); + } + return half(detail::binary, detail::exp2_post(m, exp, (arg.data_&0x8000)!=0)); + #endif + } + + /// Exponential minus one. + /// This function may be 1 ULP off the correctly rounded exact result in <0.05% of inputs for `std::round_to_nearest` + /// and in <1% of inputs for any other rounding mode. + /// + /// **See also:** Documentation for [std::expm1](https://en.cppreference.com/w/cpp/numeric/math/expm1). + /// \param arg function argument + /// \return e raised to \a arg and subtracted by 1 + /// \exception FE_INVALID for signaling NaN + /// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding + inline half expm1(half arg) { + #if defined(HALF_ARITHMETIC_TYPE) + return half(detail::binary, detail::float2half(std::expm1(detail::half2float(arg.data_)))); + #else + unsigned int abs = arg.data_ & 0x7FFF, sign = arg.data_ & 0x8000; + if(!abs) + return arg; + if(abs >= 0x7C00) + return half(detail::binary, (abs==0x7C00) ? (0x7C00+(sign>>1)) : detail::signal(arg.data_)); + if(abs >= 0x4A00) + return half(detail::binary, (arg.data_&0x8000) ? detail::rounded(0xBBFF, 1, 1) : detail::overflow()); + detail::uint32 m = detail::multiply64(static_cast((abs&0x3FF)+((abs>0x3FF)<<10))<<21, 0xB8AA3B29); + int e = (abs>>10) + (abs<=0x3FF), exp; + if(e < 14) { + exp = 0; + m >>= 14 - e; + } else { + exp = m >> (45-e); + m = (m<<(e-14)) & 0x7FFFFFFF; + } + m = detail::exp2(m); + if(sign) { + int s = 0; + if(m > 0x80000000) { + ++exp; + m = detail::divide64(0x80000000, m, s); + } + m = 0x80000000 - ((m>>exp)|((m&((static_cast(1)<>exp) : 1; + for(exp+=14; m<0x80000000 && exp; m<<=1,--exp) ; + if(exp > 29) + return half(detail::binary, detail::overflow()); + return half(detail::binary, detail::rounded(sign+(exp<<10)+(m>>21), (m>>20)&1, (m&0xFFFFF)!=0)); + #endif + } + + /// Natural logarithm. + /// This function is exact to rounding for all rounding modes. + /// + /// **See also:** Documentation for [std::log](https://en.cppreference.com/w/cpp/numeric/math/log). + /// \param arg function argument + /// \return logarithm of \a arg to base e + /// \exception FE_INVALID for signaling NaN or negative argument + /// \exception FE_DIVBYZERO for 0 + /// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding + inline half log(half arg) { + #ifdef HALF_ARITHMETIC_TYPE + return half(detail::binary, detail::float2half(std::log(detail::half2float(arg.data_)))); + #else + int abs = arg.data_ & 0x7FFF, exp = -15; + if(!abs) + return half(detail::binary, detail::pole(0x8000)); + if(arg.data_ & 0x8000) + return half(detail::binary, (arg.data_<=0xFC00) ? detail::invalid() : detail::signal(arg.data_)); + if(abs >= 0x7C00) + return (abs==0x7C00) ? arg : half(detail::binary, detail::signal(arg.data_)); + for(; abs<0x400; abs<<=1,--exp) ; + exp += abs >> 10; + return half(detail::binary, detail::log2_post( + detail::log2(static_cast((abs&0x3FF)|0x400)<<20, 27)+8, exp, 17)); + #endif + } + + /// Common logarithm. + /// This function is exact to rounding for all rounding modes. + /// + /// **See also:** Documentation for [std::log10](https://en.cppreference.com/w/cpp/numeric/math/log10). + /// \param arg function argument + /// \return logarithm of \a arg to base 10 + /// \exception FE_INVALID for signaling NaN or negative argument + /// \exception FE_DIVBYZERO for 0 + /// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding + inline half log10(half arg) { + #ifdef HALF_ARITHMETIC_TYPE + return half(detail::binary, detail::float2half(std::log10(detail::half2float(arg.data_)))); + #else + int abs = arg.data_ & 0x7FFF, exp = -15; + if(!abs) + return half(detail::binary, detail::pole(0x8000)); + if(arg.data_ & 0x8000) + return half(detail::binary, (arg.data_<=0xFC00) ? detail::invalid() : detail::signal(arg.data_)); + if(abs >= 0x7C00) + return (abs==0x7C00) ? arg : half(detail::binary, detail::signal(arg.data_)); + switch(abs) { + case 0x4900: return half(detail::binary, 0x3C00); + case 0x5640: return half(detail::binary, 0x4000); + case 0x63D0: return half(detail::binary, 0x4200); + case 0x70E2: return half(detail::binary, 0x4400); + } + for(; abs<0x400; abs<<=1,--exp) ; + exp += abs >> 10; + return half(detail::binary, detail::log2_post( + detail::log2(static_cast((abs&0x3FF)|0x400)<<20, 27)+8, exp, 16)); + #endif + } + + /// Binary logarithm. + /// This function is exact to rounding for all rounding modes. + /// + /// **See also:** Documentation for [std::log2](https://en.cppreference.com/w/cpp/numeric/math/log2). + /// \param arg function argument + /// \return logarithm of \a arg to base 2 + /// \exception FE_INVALID for signaling NaN or negative argument + /// \exception FE_DIVBYZERO for 0 + /// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding + inline half log2(half arg) { + #if defined(HALF_ARITHMETIC_TYPE) + return half(detail::binary, detail::float2half(std::log2(detail::half2float(arg.data_)))); + #else + int abs = arg.data_ & 0x7FFF, exp = -15, s = 0; + if(!abs) + return half(detail::binary, detail::pole(0x8000)); + if(arg.data_ & 0x8000) + return half(detail::binary, (arg.data_<=0xFC00) ? detail::invalid() : detail::signal(arg.data_)); + if(abs >= 0x7C00) + return (abs==0x7C00) ? arg : half(detail::binary, detail::signal(arg.data_)); + if(abs == 0x3C00) + return half(detail::binary, 0); + for(; abs<0x400; abs<<=1,--exp) ; + exp += (abs>>10); + if(!(abs&0x3FF)) { + unsigned int value = static_cast(exp<0) << 15, m = std::abs(exp) << 6; + for(exp=18; m<0x400; m<<=1,--exp) ; + return half(detail::binary, value+(exp<<10)+m); + } + detail::uint32 ilog = exp, sign = detail::sign_mask(ilog), m = + (((ilog<<27)+(detail::log2(static_cast((abs&0x3FF)|0x400)<<20, 28)>>4))^sign) - sign; + if(!m) + return half(detail::binary, 0); + for(exp=14; m<0x8000000 && exp; m<<=1,--exp) ; + for(; m>0xFFFFFFF; m>>=1,++exp) + s |= m & 1; + return half(detail::binary, detail::fixed2half(m, exp, sign&0x8000, s)); + #endif + } + + /// Natural logarithm plus one. + /// This function may be 1 ULP off the correctly rounded exact result in <0.05% of inputs for `std::round_to_nearest` + /// and in ~1% of inputs for any other rounding mode. + /// + /// **See also:** Documentation for [std::log1p](https://en.cppreference.com/w/cpp/numeric/math/log1p). + /// \param arg function argument + /// \return logarithm of \a arg plus 1 to base e + /// \exception FE_INVALID for signaling NaN or argument <-1 + /// \exception FE_DIVBYZERO for -1 + /// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding + inline half log1p(half arg) { + #if defined(HALF_ARITHMETIC_TYPE) + return half(detail::binary, detail::float2half(std::log1p(detail::half2float(arg.data_)))); + #else + if(arg.data_ >= 0xBC00) + return half(detail::binary, (arg.data_==0xBC00) ? detail::pole(0x8000) : (arg.data_<=0xFC00) ? detail::invalid() : detail::signal(arg.data_)); + int abs = arg.data_ & 0x7FFF, exp = -15; + if(!abs || abs >= 0x7C00) + return (abs>0x7C00) ? half(detail::binary, detail::signal(arg.data_)) : arg; + for(; abs<0x400; abs<<=1,--exp) ; + exp += abs >> 10; + detail::uint32 m = static_cast((abs&0x3FF)|0x400) << 20; + if(arg.data_ & 0x8000) { + m = 0x40000000 - (m>>-exp); + for(exp=0; m<0x40000000; m<<=1,--exp) ; + } else { + if(exp < 0) { + m = 0x40000000 + (m>>-exp); + exp = 0; + } else { + m += 0x40000000 >> exp; + int i = m >> 31; + m >>= i; + exp += i; + } + } + return half(detail::binary, detail::log2_post(detail::log2(m), exp, 17)); + #endif + } + + /// \} + /// \anchor power + /// \name Power functions + /// \{ + + /// Square root. + /// This function is exact to rounding for all rounding modes. + /// + /// **See also:** Documentation for [std::sqrt](https://en.cppreference.com/w/cpp/numeric/math/sqrt). + /// \param arg function argument + /// \return square root of \a arg + /// \exception FE_INVALID for signaling NaN and negative arguments + /// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding + inline half sqrt(half arg) { + #ifdef HALF_ARITHMETIC_TYPE + return half(detail::binary, detail::float2half(std::sqrt(detail::half2float(arg.data_)))); + #else + int abs = arg.data_ & 0x7FFF, exp = 15; + if(!abs || arg.data_ >= 0x7C00) + return half(detail::binary, (abs>0x7C00) ? detail::signal(arg.data_) : (arg.data_>0x8000) ? detail::invalid() : arg.data_); + for(; abs<0x400; abs<<=1,--exp) ; + detail::uint32 r = static_cast((abs&0x3FF)|0x400) << 10, m = detail::sqrt<20>(r, exp+=abs>>10); + return half(detail::binary, detail::rounded((exp<<10)+(m&0x3FF), r>m, r!=0)); + #endif + } + + /// Cubic root. + /// This function is exact to rounding for all rounding modes. + /// + /// **See also:** Documentation for [std::cbrt](https://en.cppreference.com/w/cpp/numeric/math/cbrt). + /// \param arg function argument + /// \return cubic root of \a arg + /// \exception FE_INVALID for signaling NaN + /// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding + inline half cbrt(half arg) { + #if defined(HALF_ARITHMETIC_TYPE) + return half(detail::binary, detail::float2half(std::cbrt(detail::half2float(arg.data_)))); + #else + int abs = arg.data_ & 0x7FFF, exp = -15; + if(!abs || abs == 0x3C00 || abs >= 0x7C00) + return (abs>0x7C00) ? half(detail::binary, detail::signal(arg.data_)) : arg; + for(; abs<0x400; abs<<=1, --exp); + detail::uint32 ilog = exp + (abs>>10), sign = detail::sign_mask(ilog), f, m = + (((ilog<<27)+(detail::log2(static_cast((abs&0x3FF)|0x400)<<20, 24)>>4))^sign) - sign; + for(exp=2; m<0x80000000; m<<=1,--exp) ; + m = detail::multiply64(m, 0xAAAAAAAB); + int i = m >> 31, s; + exp += i; + m <<= 1 - i; + if(exp < 0) { + f = m >> -exp; + exp = 0; + } else { + f = (m<> (31-exp); + } + m = detail::exp2(f, (half::round_style==std::round_to_nearest) ? 29 : 26); + if(sign) { + if(m > 0x80000000) { + m = detail::divide64(0x80000000, m, s); + ++exp; + } + exp = -exp; + } + return half(detail::binary, (half::round_style==std::round_to_nearest) ? + detail::fixed2half(m, exp+14, arg.data_&0x8000) : + detail::fixed2half((m+0x80)>>8, exp+14, arg.data_&0x8000)); + #endif + } + + /// Hypotenuse function. + /// This function is exact to rounding for all rounding modes. + /// + /// **See also:** Documentation for [std::hypot](https://en.cppreference.com/w/cpp/numeric/math/hypot). + /// \param x first argument + /// \param y second argument + /// \return square root of sum of squares without internal over- or underflows + /// \exception FE_INVALID if \a x or \a y is signaling NaN + /// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding of the final square root + inline half hypot(half x, half y) { + #ifdef HALF_ARITHMETIC_TYPE + detail::internal_t fx = detail::half2float(x.data_), fy = detail::half2float(y.data_); + return half(detail::binary, detail::float2half(std::hypot(fx, fy))); + #else + int absx = x.data_ & 0x7FFF, absy = y.data_ & 0x7FFF, expx = 0, expy = 0; + if(absx >= 0x7C00 || absy >= 0x7C00) + return half(detail::binary, (absx==0x7C00) ? detail::select(0x7C00, y.data_) : + (absy==0x7C00) ? detail::select(0x7C00, x.data_) : detail::signal(x.data_, y.data_)); + if(!absx) + return half(detail::binary, absy ? detail::check_underflow(absy) : 0); + if(!absy) + return half(detail::binary, detail::check_underflow(absx)); + if(absy > absx) + std::swap(absx, absy); + for(; absx<0x400; absx<<=1,--expx) ; + for(; absy<0x400; absy<<=1,--expy) ; + detail::uint32 mx = (absx&0x3FF) | 0x400, my = (absy&0x3FF) | 0x400; + mx *= mx; + my *= my; + int ix = mx >> 21, iy = my >> 21; + expx = 2*(expx+(absx>>10)) - 15 + ix; + expy = 2*(expy+(absy>>10)) - 15 + iy; + mx <<= 10 - ix; + my <<= 10 - iy; + int d = expx - expy; + my = (d<30) ? ((my>>d)|((my&((static_cast(1)<(mx+my, expx)); + #endif + } + + /// Hypotenuse function. + /// This function is exact to rounding for all rounding modes. + /// + /// **See also:** Documentation for [std::hypot](https://en.cppreference.com/w/cpp/numeric/math/hypot). + /// \param x first argument + /// \param y second argument + /// \param z third argument + /// \return square root of sum of squares without internal over- or underflows + /// \exception FE_INVALID if \a x, \a y or \a z is signaling NaN + /// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding of the final square root + inline half hypot(half x, half y, half z) { + #ifdef HALF_ARITHMETIC_TYPE + detail::internal_t fx = detail::half2float(x.data_), fy = detail::half2float(y.data_), fz = detail::half2float(z.data_); + return half(detail::binary, detail::float2half(std::sqrt(fx*fx+fy*fy+fz*fz))); + #else + int absx = x.data_ & 0x7FFF, absy = y.data_ & 0x7FFF, absz = z.data_ & 0x7FFF, expx = 0, expy = 0, expz = 0; + if(!absx) + return hypot(y, z); + if(!absy) + return hypot(x, z); + if(!absz) + return hypot(x, y); + if(absx >= 0x7C00 || absy >= 0x7C00 || absz >= 0x7C00) + return half(detail::binary, (absx==0x7C00) ? detail::select(0x7C00, detail::select(y.data_, z.data_)) : + (absy==0x7C00) ? detail::select(0x7C00, detail::select(x.data_, z.data_)) : + (absz==0x7C00) ? detail::select(0x7C00, detail::select(x.data_, y.data_)) : + detail::signal(x.data_, y.data_, z.data_)); + if(absz > absy) + std::swap(absy, absz); + if(absy > absx) + std::swap(absx, absy); + if(absz > absy) + std::swap(absy, absz); + for(; absx<0x400; absx<<=1,--expx) ; + for(; absy<0x400; absy<<=1,--expy) ; + for(; absz<0x400; absz<<=1,--expz) ; + detail::uint32 mx = (absx&0x3FF) | 0x400, my = (absy&0x3FF) | 0x400, mz = (absz&0x3FF) | 0x400; + mx *= mx; + my *= my; + mz *= mz; + int ix = mx >> 21, iy = my >> 21, iz = mz >> 21; + expx = 2*(expx+(absx>>10)) - 15 + ix; + expy = 2*(expy+(absy>>10)) - 15 + iy; + expz = 2*(expz+(absz>>10)) - 15 + iz; + mx <<= 10 - ix; + my <<= 10 - iy; + mz <<= 10 - iz; + int d = expy - expz; + mz = (d<30) ? ((mz>>d)|((mz&((static_cast(1)<>1) | (my&1); + if(++expy > expx) { + std::swap(mx, my); + std::swap(expx, expy); + } + } + d = expx - expy; + my = (d<30) ? ((my>>d)|((my&((static_cast(1)<(mx+my, expx)); + #endif + } + + /// Power function. + /// This function may be 1 ULP off the correctly rounded exact result for any rounding mode in ~0.00025% of inputs. + /// + /// **See also:** Documentation for [std::pow](https://en.cppreference.com/w/cpp/numeric/math/pow). + /// \param x base + /// \param y exponent + /// \return \a x raised to \a y + /// \exception FE_INVALID if \a x or \a y is signaling NaN or if \a x is finite an negative and \a y is finite and not integral + /// \exception FE_DIVBYZERO if \a x is 0 and \a y is negative + /// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding + inline half pow(half x, half y) { + #ifdef HALF_ARITHMETIC_TYPE + return half(detail::binary, detail::float2half(std::pow(detail::half2float(x.data_), detail::half2float(y.data_)))); + #else + int absx = x.data_ & 0x7FFF, absy = y.data_ & 0x7FFF, exp = -15; + if(!absy || x.data_ == 0x3C00) + return half(detail::binary, detail::select(0x3C00, (x.data_==0x3C00) ? y.data_ : x.data_)); + bool is_int = absy >= 0x6400 || (absy>=0x3C00 && !(absy&((1<<(25-(absy>>10)))-1))); + unsigned int sign = x.data_ & (static_cast((absy<0x6800)&&is_int&&((absy>>(25-(absy>>10)))&1))<<15); + if(absx >= 0x7C00 || absy >= 0x7C00) + return half(detail::binary, (absx>0x7C00 || absy>0x7C00) ? detail::signal(x.data_, y.data_) : + (absy==0x7C00) ? ((absx==0x3C00) ? 0x3C00 : (!absx && y.data_==0xFC00) ? detail::pole() : + (0x7C00&-((y.data_>>15)^(absx>0x3C00)))) : (sign|(0x7C00&((y.data_>>15)-1U)))); + if(!absx) + return half(detail::binary, (y.data_&0x8000) ? detail::pole(sign) : sign); + if((x.data_&0x8000) && !is_int) + return half(detail::binary, detail::invalid()); + if(x.data_ == 0xBC00) + return half(detail::binary, sign|0x3C00); + if(y.data_ == 0x3800) + return sqrt(x); + if(y.data_ == 0x3C00) + return half(detail::binary, detail::check_underflow(x.data_)); + if(y.data_ == 0x4000) + return x * x; + for(; absx<0x400; absx<<=1,--exp) ; + detail::uint32 ilog = exp + (absx>>10), msign = detail::sign_mask(ilog), f, m = + (((ilog<<27)+((detail::log2(static_cast((absx&0x3FF)|0x400)<<20)+8)>>4))^msign) - msign; + for(exp=-11; m<0x80000000; m<<=1,--exp) ; + for(; absy<0x400; absy<<=1,--exp) ; + m = detail::multiply64(m, static_cast((absy&0x3FF)|0x400)<<21); + int i = m >> 31; + exp += (absy>>10) + i; + m <<= 1 - i; + if(exp < 0) { + f = m >> -exp; + exp = 0; + } else { + f = (m<> (31-exp); + } + return half(detail::binary, detail::exp2_post(detail::exp2(f), exp, ((msign&1)^(y.data_>>15))!=0, sign)); + #endif + } + + /// \} + /// \anchor trigonometric + /// \name Trigonometric functions + /// \{ + + /// Compute sine and cosine simultaneously. + /// This returns the same results as sin() and cos() but is faster than calling each function individually. + /// + /// This function is exact to rounding for all rounding modes. + /// \param arg function argument + /// \param sin variable to take sine of \a arg + /// \param cos variable to take cosine of \a arg + /// \exception FE_INVALID for signaling NaN or infinity + /// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding + inline void sincos(half arg, half *sin, half *cos) { + #ifdef HALF_ARITHMETIC_TYPE + detail::internal_t f = detail::half2float(arg.data_); + *sin = half(detail::binary, detail::float2half(std::sin(f))); + *cos = half(detail::binary, detail::float2half(std::cos(f))); + #else + int abs = arg.data_ & 0x7FFF, sign = arg.data_ >> 15, k; + if(abs >= 0x7C00) + *sin = *cos = half(detail::binary, (abs==0x7C00) ? detail::invalid() : detail::signal(arg.data_)); + else if(!abs) { + *sin = arg; + *cos = half(detail::binary, 0x3C00); + } else if(abs < 0x2500) { + *sin = half(detail::binary, detail::rounded(arg.data_-1, 1, 1)); + *cos = half(detail::binary, detail::rounded(0x3BFF, 1, 1)); + } else { + if(half::round_style != std::round_to_nearest) { + switch(abs) { + case 0x48B7: + *sin = half(detail::binary, detail::rounded((~arg.data_&0x8000)|0x1D07, 1, 1)); + *cos = half(detail::binary, detail::rounded(0xBBFF, 1, 1)); + return; + case 0x598C: + *sin = half(detail::binary, detail::rounded((arg.data_&0x8000)|0x3BFF, 1, 1)); + *cos = half(detail::binary, detail::rounded(0x80FC, 1, 1)); + return; + case 0x6A64: + *sin = half(detail::binary, detail::rounded((~arg.data_&0x8000)|0x3BFE, 1, 1)); + *cos = half(detail::binary, detail::rounded(0x27FF, 1, 1)); + return; + case 0x6D8C: + *sin = half(detail::binary, detail::rounded((arg.data_&0x8000)|0x0FE6, 1, 1)); + *cos = half(detail::binary, detail::rounded(0x3BFF, 1, 1)); + return; + } + } + std::pair sc = detail::sincos(detail::angle_arg(abs, k), 28); + switch(k & 3) { + case 1: sc = std::make_pair(sc.second, -sc.first); break; + case 2: sc = std::make_pair(-sc.first, -sc.second); break; + case 3: sc = std::make_pair(-sc.second, sc.first); break; + } + *sin = half(detail::binary, detail::fixed2half((sc.first^-static_cast(sign))+sign)); + *cos = half(detail::binary, detail::fixed2half(sc.second)); + } + #endif + } + + /// Sine function. + /// This function is exact to rounding for all rounding modes. + /// + /// **See also:** Documentation for [std::sin](https://en.cppreference.com/w/cpp/numeric/math/sin). + /// \param arg function argument + /// \return sine value of \a arg + /// \exception FE_INVALID for signaling NaN or infinity + /// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding + inline half sin(half arg) { + #ifdef HALF_ARITHMETIC_TYPE + return half(detail::binary, detail::float2half(std::sin(detail::half2float(arg.data_)))); + #else + int abs = arg.data_ & 0x7FFF, k; + if(!abs) + return arg; + if(abs >= 0x7C00) + return half(detail::binary, (abs==0x7C00) ? detail::invalid() : detail::signal(arg.data_)); + if(abs < 0x2900) + return half(detail::binary, detail::rounded(arg.data_-1, 1, 1)); + if(half::round_style != std::round_to_nearest) + switch(abs) { + case 0x48B7: return half(detail::binary, detail::rounded((~arg.data_&0x8000)|0x1D07, 1, 1)); + case 0x6A64: return half(detail::binary, detail::rounded((~arg.data_&0x8000)|0x3BFE, 1, 1)); + case 0x6D8C: return half(detail::binary, detail::rounded((arg.data_&0x8000)|0x0FE6, 1, 1)); + } + std::pair sc = detail::sincos(detail::angle_arg(abs, k), 28); + detail::uint32 sign = -static_cast(((k>>1)&1)^(arg.data_>>15)); + return half(detail::binary, detail::fixed2half((((k&1) ? sc.second : sc.first)^sign) - sign)); + #endif + } + + /// Cosine function. + /// This function is exact to rounding for all rounding modes. + /// + /// **See also:** Documentation for [std::cos](https://en.cppreference.com/w/cpp/numeric/math/cos). + /// \param arg function argument + /// \return cosine value of \a arg + /// \exception FE_INVALID for signaling NaN or infinity + /// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding + inline half cos(half arg) { + #ifdef HALF_ARITHMETIC_TYPE + return half(detail::binary, detail::float2half(std::cos(detail::half2float(arg.data_)))); + #else + int abs = arg.data_ & 0x7FFF, k; + if(!abs) + return half(detail::binary, 0x3C00); + if(abs >= 0x7C00) + return half(detail::binary, (abs==0x7C00) ? detail::invalid() : detail::signal(arg.data_)); + if(abs < 0x2500) + return half(detail::binary, detail::rounded(0x3BFF, 1, 1)); + if(half::round_style != std::round_to_nearest && abs == 0x598C) + return half(detail::binary, detail::rounded(0x80FC, 1, 1)); + std::pair sc = detail::sincos(detail::angle_arg(abs, k), 28); + detail::uint32 sign = -static_cast(((k>>1)^k)&1); + return half(detail::binary, detail::fixed2half((((k&1) ? sc.first : sc.second)^sign) - sign)); + #endif + } + + /// Tangent function. + /// This function is exact to rounding for all rounding modes. + /// + /// **See also:** Documentation for [std::tan](https://en.cppreference.com/w/cpp/numeric/math/tan). + /// \param arg function argument + /// \return tangent value of \a arg + /// \exception FE_INVALID for signaling NaN or infinity + /// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding + inline half tan(half arg) { + #ifdef HALF_ARITHMETIC_TYPE + return half(detail::binary, detail::float2half(std::tan(detail::half2float(arg.data_)))); + #else + int abs = arg.data_ & 0x7FFF, exp = 13, k; + if(!abs) + return arg; + if(abs >= 0x7C00) + return half(detail::binary, (abs==0x7C00) ? detail::invalid() : detail::signal(arg.data_)); + if(abs < 0x2700) + return half(detail::binary, detail::rounded(arg.data_, 0, 1)); + if(half::round_style != std::round_to_nearest) + switch(abs) { + case 0x658C: return half(detail::binary, detail::rounded((arg.data_&0x8000)|0x07E6, 1, 1)); + case 0x7330: return half(detail::binary, detail::rounded((~arg.data_&0x8000)|0x4B62, 1, 1)); + } + std::pair sc = detail::sincos(detail::angle_arg(abs, k), 30); + if(k & 1) + sc = std::make_pair(-sc.second, sc.first); + detail::uint32 signy = detail::sign_mask(sc.first), signx = detail::sign_mask(sc.second); + detail::uint32 my = (sc.first^signy) - signy, mx = (sc.second^signx) - signx; + for(; my<0x80000000; my<<=1,--exp) ; + for(; mx<0x80000000; mx<<=1,++exp) ; + return half(detail::binary, detail::tangent_post(my, mx, exp, (signy^signx^arg.data_)&0x8000)); + #endif + } + + /// Arc sine. + /// This function is exact to rounding for all rounding modes. + /// + /// **See also:** Documentation for [std::asin](https://en.cppreference.com/w/cpp/numeric/math/asin). + /// \param arg function argument + /// \return arc sine value of \a arg + /// \exception FE_INVALID for signaling NaN or if abs(\a arg) > 1 + /// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding + inline half asin(half arg) { + #ifdef HALF_ARITHMETIC_TYPE + return half(detail::binary, detail::float2half(std::asin(detail::half2float(arg.data_)))); + #else + unsigned int abs = arg.data_ & 0x7FFF, sign = arg.data_ & 0x8000; + if(!abs) + return arg; + if(abs >= 0x3C00) + return half(detail::binary, (abs>0x7C00) ? detail::signal(arg.data_) : (abs>0x3C00) ? detail::invalid() : + detail::rounded(sign|0x3E48, 0, 1)); + if(abs < 0x2900) + return half(detail::binary, detail::rounded(arg.data_, 0, 1)); + if(half::round_style != std::round_to_nearest && (abs == 0x2B44 || abs == 0x2DC3)) + return half(detail::binary, detail::rounded(arg.data_+1, 1, 1)); + std::pair sc = detail::atan2_args(abs); + detail::uint32 m = detail::atan2(sc.first, sc.second, (half::round_style==std::round_to_nearest) ? 27 : 26); + return half(detail::binary, detail::fixed2half(m, 14, sign)); + #endif + } + + /// Arc cosine function. + /// This function is exact to rounding for all rounding modes. + /// + /// **See also:** Documentation for [std::acos](https://en.cppreference.com/w/cpp/numeric/math/acos). + /// \param arg function argument + /// \return arc cosine value of \a arg + /// \exception FE_INVALID for signaling NaN or if abs(\a arg) > 1 + /// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding + inline half acos(half arg) { + #ifdef HALF_ARITHMETIC_TYPE + return half(detail::binary, detail::float2half(std::acos(detail::half2float(arg.data_)))); + #else + unsigned int abs = arg.data_ & 0x7FFF, sign = arg.data_ >> 15; + if(!abs) + return half(detail::binary, detail::rounded(0x3E48, 0, 1)); + if(abs >= 0x3C00) + return half(detail::binary, (abs>0x7C00) ? detail::signal(arg.data_) : (abs>0x3C00) ? detail::invalid() : + sign ? detail::rounded(0x4248, 0, 1) : 0); + std::pair cs = detail::atan2_args(abs); + detail::uint32 m = detail::atan2(cs.second, cs.first, 28); + return half(detail::binary, detail::fixed2half(sign ? (0xC90FDAA2-m) : m, 15, 0, sign)); + #endif + } + + /// Arc tangent function. + /// This function is exact to rounding for all rounding modes. + /// + /// **See also:** Documentation for [std::atan](https://en.cppreference.com/w/cpp/numeric/math/atan). + /// \param arg function argument + /// \return arc tangent value of \a arg + /// \exception FE_INVALID for signaling NaN + /// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding + inline half atan(half arg) { + #ifdef HALF_ARITHMETIC_TYPE + return half(detail::binary, detail::float2half(std::atan(detail::half2float(arg.data_)))); + #else + unsigned int abs = arg.data_ & 0x7FFF, sign = arg.data_ & 0x8000; + if(!abs) + return arg; + if(abs >= 0x7C00) + return half(detail::binary, (abs==0x7C00) ? detail::rounded(sign|0x3E48, 0, 1) : detail::signal(arg.data_)); + if(abs <= 0x2700) + return half(detail::binary, detail::rounded(arg.data_-1, 1, 1)); + int exp = (abs>>10) + (abs<=0x3FF); + detail::uint32 my = (abs&0x3FF) | ((abs>0x3FF)<<10); + detail::uint32 m = (exp>15) ? detail::atan2(my<<19, 0x20000000>>(exp-15), (half::round_style==std::round_to_nearest) ? 26 : 24) : + detail::atan2(my<<(exp+4), 0x20000000, (half::round_style==std::round_to_nearest) ? 30 : 28); + return half(detail::binary, detail::fixed2half(m, 14, sign)); + #endif + } + + /// Arc tangent function. + /// This function may be 1 ULP off the correctly rounded exact result in ~0.005% of inputs for `std::round_to_nearest`, + /// in ~0.1% of inputs for `std::round_toward_zero` and in ~0.02% of inputs for any other rounding mode. + /// + /// **See also:** Documentation for [std::atan2](https://en.cppreference.com/w/cpp/numeric/math/atan2). + /// \param y numerator + /// \param x denominator + /// \return arc tangent value + /// \exception FE_INVALID if \a x or \a y is signaling NaN + /// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding + inline half atan2(half y, half x) { + #ifdef HALF_ARITHMETIC_TYPE + return half(detail::binary, detail::float2half(std::atan2(detail::half2float(y.data_), detail::half2float(x.data_)))); + #else + unsigned int absx = x.data_ & 0x7FFF, absy = y.data_ & 0x7FFF, signx = x.data_ >> 15, signy = y.data_ & 0x8000; + if(absx >= 0x7C00 || absy >= 0x7C00) { + if(absx > 0x7C00 || absy > 0x7C00) + return half(detail::binary, detail::signal(x.data_, y.data_)); + if(absy == 0x7C00) + return half(detail::binary, (absx<0x7C00) ? detail::rounded(signy|0x3E48, 0, 1) : + signx ? detail::rounded(signy|0x40B6, 0, 1) : + detail::rounded(signy|0x3A48, 0, 1)); + return (x.data_==0x7C00) ? half(detail::binary, signy) : half(detail::binary, detail::rounded(signy|0x4248, 0, 1)); + } + if(!absy) + return signx ? half(detail::binary, detail::rounded(signy|0x4248, 0, 1)) : y; + if(!absx) + return half(detail::binary, detail::rounded(signy|0x3E48, 0, 1)); + int d = (absy>>10) + (absy<=0x3FF) - (absx>>10) - (absx<=0x3FF); + if(d > (signx ? 18 : 12)) + return half(detail::binary, detail::rounded(signy|0x3E48, 0, 1)); + if(signx && d < -11) + return half(detail::binary, detail::rounded(signy|0x4248, 0, 1)); + if(!signx && d < ((half::round_style==std::round_toward_zero) ? -15 : -9)) { + for(; absy<0x400; absy<<=1,--d) ; + detail::uint32 mx = ((absx<<1)&0x7FF) | 0x800, my = ((absy<<1)&0x7FF) | 0x800; + int i = my < mx; + d -= i; + if(d < -25) + return half(detail::binary, detail::underflow(signy)); + my <<= 11 + i; + return half(detail::binary, detail::fixed2half(my/mx, d+14, signy, my%mx!=0)); + } + detail::uint32 m = detail::atan2( ((absy&0x3FF)|((absy>0x3FF)<<10))<<(19+((d<0) ? d : (d>0) ? 0 : -1)), + ((absx&0x3FF)|((absx>0x3FF)<<10))<<(19-((d>0) ? d : (d<0) ? 0 : 1))); + return half(detail::binary, detail::fixed2half(signx ? (0xC90FDAA2-m) : m, 15, signy, signx)); + #endif + } + + /// \} + /// \anchor hyperbolic + /// \name Hyperbolic functions + /// \{ + + /// Hyperbolic sine. + /// This function is exact to rounding for all rounding modes. + /// + /// **See also:** Documentation for [std::sinh](https://en.cppreference.com/w/cpp/numeric/math/sinh). + /// \param arg function argument + /// \return hyperbolic sine value of \a arg + /// \exception FE_INVALID for signaling NaN + /// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding + inline half sinh(half arg) { + #ifdef HALF_ARITHMETIC_TYPE + return half(detail::binary, detail::float2half(std::sinh(detail::half2float(arg.data_)))); + #else + int abs = arg.data_ & 0x7FFF, exp; + if(!abs || abs >= 0x7C00) + return (abs>0x7C00) ? half(detail::binary, detail::signal(arg.data_)) : arg; + if(abs <= 0x2900) + return half(detail::binary, detail::rounded(arg.data_, 0, 1)); + std::pair mm = detail::hyperbolic_args(abs, exp, (half::round_style==std::round_to_nearest) ? 29 : 27); + detail::uint32 m = mm.first - mm.second; + for(exp+=13; m<0x80000000 && exp; m<<=1,--exp) ; + unsigned int sign = arg.data_ & 0x8000; + if(exp > 29) + return half(detail::binary, detail::overflow(sign)); + return half(detail::binary, detail::fixed2half(m, exp, sign)); + #endif + } + + /// Hyperbolic cosine. + /// This function is exact to rounding for all rounding modes. + /// + /// **See also:** Documentation for [std::cosh](https://en.cppreference.com/w/cpp/numeric/math/cosh). + /// \param arg function argument + /// \return hyperbolic cosine value of \a arg + /// \exception FE_INVALID for signaling NaN + /// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding + inline half cosh(half arg) { + #ifdef HALF_ARITHMETIC_TYPE + return half(detail::binary, detail::float2half(std::cosh(detail::half2float(arg.data_)))); + #else + int abs = arg.data_ & 0x7FFF, exp; + if(!abs) + return half(detail::binary, 0x3C00); + if(abs >= 0x7C00) + return half(detail::binary, (abs>0x7C00) ? detail::signal(arg.data_) : 0x7C00); + std::pair mm = detail::hyperbolic_args(abs, exp, (half::round_style==std::round_to_nearest) ? 23 : 26); + detail::uint32 m = mm.first + mm.second, i = (~m&0xFFFFFFFF) >> 31; + m = (m>>i) | (m&i) | 0x80000000; + if((exp+=13+i) > 29) + return half(detail::binary, detail::overflow()); + return half(detail::binary, detail::fixed2half(m, exp)); + #endif + } + + /// Hyperbolic tangent. + /// This function is exact to rounding for all rounding modes. + /// + /// **See also:** Documentation for [std::tanh](https://en.cppreference.com/w/cpp/numeric/math/tanh). + /// \param arg function argument + /// \return hyperbolic tangent value of \a arg + /// \exception FE_INVALID for signaling NaN + /// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding + inline half tanh(half arg) { + #ifdef HALF_ARITHMETIC_TYPE + return half(detail::binary, detail::float2half(std::tanh(detail::half2float(arg.data_)))); + #else + int abs = arg.data_ & 0x7FFF, exp; + if(!abs) + return arg; + if(abs >= 0x7C00) + return half(detail::binary, (abs>0x7C00) ? detail::signal(arg.data_) : (arg.data_-0x4000)); + if(abs >= 0x4500) + return half(detail::binary, detail::rounded((arg.data_&0x8000)|0x3BFF, 1, 1)); + if(abs < 0x2700) + return half(detail::binary, detail::rounded(arg.data_-1, 1, 1)); + if(half::round_style != std::round_to_nearest && abs == 0x2D3F) + return half(detail::binary, detail::rounded(arg.data_-3, 0, 1)); + std::pair mm = detail::hyperbolic_args(abs, exp, 27); + detail::uint32 my = mm.first - mm.second - (half::round_style!=std::round_to_nearest), mx = mm.first + mm.second, i = (~mx&0xFFFFFFFF) >> 31; + for(exp=13; my<0x80000000; my<<=1,--exp) ; + mx = (mx>>i) | 0x80000000; + return half(detail::binary, detail::tangent_post(my, mx, exp-i, arg.data_&0x8000)); + #endif + } + + /// Hyperbolic area sine. + /// This function is exact to rounding for all rounding modes. + /// + /// **See also:** Documentation for [std::asinh](https://en.cppreference.com/w/cpp/numeric/math/asinh). + /// \param arg function argument + /// \return area sine value of \a arg + /// \exception FE_INVALID for signaling NaN + /// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding + inline half asinh(half arg) { + #if defined(HALF_ARITHMETIC_TYPE) + return half(detail::binary, detail::float2half(std::asinh(detail::half2float(arg.data_)))); + #else + int abs = arg.data_ & 0x7FFF; + if(!abs || abs >= 0x7C00) + return (abs>0x7C00) ? half(detail::binary, detail::signal(arg.data_)) : arg; + if(abs <= 0x2900) + return half(detail::binary, detail::rounded(arg.data_-1, 1, 1)); + if(half::round_style != std::round_to_nearest) + switch(abs) + { + case 0x32D4: return half(detail::binary, detail::rounded(arg.data_-13, 1, 1)); + case 0x3B5B: return half(detail::binary, detail::rounded(arg.data_-197, 1, 1)); + } + return half(detail::binary, detail::area(arg.data_)); + #endif + } + + /// Hyperbolic area cosine. + /// This function is exact to rounding for all rounding modes. + /// + /// **See also:** Documentation for [std::acosh](https://en.cppreference.com/w/cpp/numeric/math/acosh). + /// \param arg function argument + /// \return area cosine value of \a arg + /// \exception FE_INVALID for signaling NaN or arguments <1 + /// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding + inline half acosh(half arg) { + #if defined(HALF_ARITHMETIC_TYPE) + return half(detail::binary, detail::float2half(std::acosh(detail::half2float(arg.data_)))); + #else + int abs = arg.data_ & 0x7FFF; + if((arg.data_&0x8000) || abs < 0x3C00) + return half(detail::binary, (abs<=0x7C00) ? detail::invalid() : detail::signal(arg.data_)); + if(abs == 0x3C00) + return half(detail::binary, 0); + if(arg.data_ >= 0x7C00) + return (abs>0x7C00) ? half(detail::binary, detail::signal(arg.data_)) : arg; + return half(detail::binary, detail::area(arg.data_)); + #endif + } + + /// Hyperbolic area tangent. + /// This function is exact to rounding for all rounding modes. + /// + /// **See also:** Documentation for [std::atanh](https://en.cppreference.com/w/cpp/numeric/math/atanh). + /// \param arg function argument + /// \return area tangent value of \a arg + /// \exception FE_INVALID for signaling NaN or if abs(\a arg) > 1 + /// \exception FE_DIVBYZERO for +/-1 + /// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding + inline half atanh(half arg) { + #if defined(HALF_ARITHMETIC_TYPE) + return half(detail::binary, detail::float2half(std::atanh(detail::half2float(arg.data_)))); + #else + int abs = arg.data_ & 0x7FFF, exp = 0; + if(!abs) + return arg; + if(abs >= 0x3C00) + return half(detail::binary, (abs==0x3C00) ? detail::pole(arg.data_&0x8000) : (abs<=0x7C00) ? detail::invalid() : detail::signal(arg.data_)); + if(abs < 0x2700) + return half(detail::binary, detail::rounded(arg.data_, 0, 1)); + detail::uint32 m = static_cast((abs&0x3FF)|((abs>0x3FF)<<10)) << ((abs>>10)+(abs<=0x3FF)+6), my = 0x80000000 + m, mx = 0x80000000 - m; + for(; mx<0x80000000; mx<<=1,++exp) ; + int i = my >= mx, s; + return half(detail::binary, detail::log2_post(detail::log2( + (detail::divide64(my>>i, mx, s)+1)>>1, 27)+0x10, exp+i-1, 16, arg.data_&0x8000)); + #endif + } + + /// \} + /// \anchor special + /// \name Error and gamma functions + /// \{ + + /// Error function. + /// This function may be 1 ULP off the correctly rounded exact result for any rounding mode in <0.5% of inputs. + /// + /// **See also:** Documentation for [std::erf](https://en.cppreference.com/w/cpp/numeric/math/erf). + /// \param arg function argument + /// \return error function value of \a arg + /// \exception FE_INVALID for signaling NaN + /// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding + inline half erf(half arg) { + #if defined(HALF_ARITHMETIC_TYPE) + return half(detail::binary, detail::float2half(std::erf(detail::half2float(arg.data_)))); + #else + unsigned int abs = arg.data_ & 0x7FFF; + if(!abs || abs >= 0x7C00) + return (abs>=0x7C00) ? half(detail::binary, (abs==0x7C00) ? (arg.data_-0x4000) : detail::signal(arg.data_)) : arg; + if(abs >= 0x4200) + return half(detail::binary, detail::rounded((arg.data_&0x8000)|0x3BFF, 1, 1)); + return half(detail::binary, detail::erf(arg.data_)); + #endif + } + + /// Complementary error function. + /// This function may be 1 ULP off the correctly rounded exact result for any rounding mode in <0.5% of inputs. + /// + /// **See also:** Documentation for [std::erfc](https://en.cppreference.com/w/cpp/numeric/math/erfc). + /// \param arg function argument + /// \return 1 minus error function value of \a arg + /// \exception FE_INVALID for signaling NaN + /// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding + inline half erfc(half arg) { + #if defined(HALF_ARITHMETIC_TYPE) + return half(detail::binary, detail::float2half(std::erfc(detail::half2float(arg.data_)))); + #else + unsigned int abs = arg.data_ & 0x7FFF, sign = arg.data_ & 0x8000; + if(abs >= 0x7C00) + return (abs>=0x7C00) ? half(detail::binary, (abs==0x7C00) ? (sign>>1) : detail::signal(arg.data_)) : arg; + if(!abs) + return half(detail::binary, 0x3C00); + if(abs >= 0x4400) + return half(detail::binary, detail::rounded((sign>>1)-(sign>>15), sign>>15, 1)); + return half(detail::binary, detail::erf(arg.data_)); + #endif + } + + /// Natural logarithm of gamma function. + /// This function may be 1 ULP off the correctly rounded exact result for any rounding mode in ~0.025% of inputs. + /// + /// **See also:** Documentation for [std::lgamma](https://en.cppreference.com/w/cpp/numeric/math/lgamma). + /// \param arg function argument + /// \return natural logarith of gamma function for \a arg + /// \exception FE_INVALID for signaling NaN + /// \exception FE_DIVBYZERO for 0 or negative integer arguments + /// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding + inline half lgamma(half arg) { + #if defined(HALF_ARITHMETIC_TYPE) + return half(detail::binary, detail::float2half(std::lgamma(detail::half2float(arg.data_)))); + #else + int abs = arg.data_ & 0x7FFF; + if(abs >= 0x7C00) + return half(detail::binary, (abs==0x7C00) ? 0x7C00 : detail::signal(arg.data_)); + if(!abs || arg.data_ >= 0xE400 || (arg.data_ >= 0xBC00 && !(abs&((1<<(25-(abs>>10)))-1)))) + return half(detail::binary, detail::pole()); + if(arg.data_ == 0x3C00 || arg.data_ == 0x4000) + return half(detail::binary, 0); + return half(detail::binary, detail::gamma(arg.data_)); + #endif + } + + /// Gamma function. + /// This function may be 1 ULP off the correctly rounded exact result for any rounding mode in <0.25% of inputs. + /// + /// **See also:** Documentation for [std::tgamma](https://en.cppreference.com/w/cpp/numeric/math/tgamma). + /// \param arg function argument + /// \return gamma function value of \a arg + /// \exception FE_INVALID for signaling NaN, negative infinity or negative integer arguments + /// \exception FE_DIVBYZERO for 0 + /// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding + inline half tgamma(half arg) { + #if defined(HALF_ARITHMETIC_TYPE) + return half(detail::binary, detail::float2half(std::tgamma(detail::half2float(arg.data_)))); + #else + unsigned int abs = arg.data_ & 0x7FFF; + if(!abs) + return half(detail::binary, detail::pole(arg.data_)); + if(abs >= 0x7C00) + return (arg.data_==0x7C00) ? arg : half(detail::binary, detail::signal(arg.data_)); + if(arg.data_ >= 0xE400 || (arg.data_ >= 0xBC00 && !(abs&((1<<(25-(abs>>10)))-1)))) + return half(detail::binary, detail::invalid()); + if(arg.data_ >= 0xCA80) + return half(detail::binary, detail::underflow((1-((abs>>(25-(abs>>10)))&1))<<15)); + if(arg.data_ <= 0x100 || (arg.data_ >= 0x4900 && arg.data_ < 0x8000)) + return half(detail::binary, detail::overflow()); + if(arg.data_ == 0x3C00) + return arg; + return half(detail::binary, detail::gamma(arg.data_)); + #endif + } + + /// \} + /// \anchor rounding + /// \name Rounding + /// \{ + + /// Nearest integer not less than half value. + /// **See also:** Documentation for [std::ceil](https://en.cppreference.com/w/cpp/numeric/math/ceil). + /// \param arg half to round + /// \return nearest integer not less than \a arg + /// \exception FE_INVALID for signaling NaN + /// \exception FE_INEXACT if value had to be rounded + inline half ceil(half arg) { return half(detail::binary, detail::integral(arg.data_)); } + + /// Nearest integer not greater than half value. + /// **See also:** Documentation for [std::floor](https://en.cppreference.com/w/cpp/numeric/math/floor). + /// \param arg half to round + /// \return nearest integer not greater than \a arg + /// \exception FE_INVALID for signaling NaN + /// \exception FE_INEXACT if value had to be rounded + inline half floor(half arg) { return half(detail::binary, detail::integral(arg.data_)); } + + /// Nearest integer not greater in magnitude than half value. + /// **See also:** Documentation for [std::trunc](https://en.cppreference.com/w/cpp/numeric/math/trunc). + /// \param arg half to round + /// \return nearest integer not greater in magnitude than \a arg + /// \exception FE_INVALID for signaling NaN + /// \exception FE_INEXACT if value had to be rounded + inline half trunc(half arg) { return half(detail::binary, detail::integral(arg.data_)); } + + /// Nearest integer. + /// **See also:** Documentation for [std::round](https://en.cppreference.com/w/cpp/numeric/math/round). + /// \param arg half to round + /// \return nearest integer, rounded away from zero in half-way cases + /// \exception FE_INVALID for signaling NaN + /// \exception FE_INEXACT if value had to be rounded + inline half round(half arg) { return half(detail::binary, detail::integral(arg.data_)); } + + /// Nearest integer. + /// **See also:** Documentation for [std::lround](https://en.cppreference.com/w/cpp/numeric/math/round). + /// \param arg half to round + /// \return nearest integer, rounded away from zero in half-way cases + /// \exception FE_INVALID if value is not representable as `long` + inline long lround(half arg) { return detail::half2int(arg.data_); } + + /// Nearest integer using half's internal rounding mode. + /// **See also:** Documentation for [std::rint](https://en.cppreference.com/w/cpp/numeric/math/rint). + /// \param arg half expression to round + /// \return nearest integer using default rounding mode + /// \exception FE_INVALID for signaling NaN + /// \exception FE_INEXACT if value had to be rounded + inline half rint(half arg) { return half(detail::binary, detail::integral(arg.data_)); } + + /// Nearest integer using half's internal rounding mode. + /// **See also:** Documentation for [std::lrint](https://en.cppreference.com/w/cpp/numeric/math/rint). + /// \param arg half expression to round + /// \return nearest integer using default rounding mode + /// \exception FE_INVALID if value is not representable as `long` + /// \exception FE_INEXACT if value had to be rounded + inline long lrint(half arg) { return detail::half2int(arg.data_); } + + /// Nearest integer using half's internal rounding mode. + /// **See also:** Documentation for [std::nearbyint](https://en.cppreference.com/w/cpp/numeric/math/nearbyint). + /// \param arg half expression to round + /// \return nearest integer using default rounding mode + /// \exception FE_INVALID for signaling NaN + inline half nearbyint(half arg) { return half(detail::binary, detail::integral(arg.data_)); } + /// Nearest integer. + /// **See also:** Documentation for [std::llround](https://en.cppreference.com/w/cpp/numeric/math/round). + /// \param arg half to round + /// \return nearest integer, rounded away from zero in half-way cases + /// \exception FE_INVALID if value is not representable as `long long` + inline long long llround(half arg) { return detail::half2int(arg.data_); } + + /// Nearest integer using half's internal rounding mode. + /// **See also:** Documentation for [std::llrint](https://en.cppreference.com/w/cpp/numeric/math/rint). + /// \param arg half expression to round + /// \return nearest integer using default rounding mode + /// \exception FE_INVALID if value is not representable as `long long` + /// \exception FE_INEXACT if value had to be rounded + inline long long llrint(half arg) { return detail::half2int(arg.data_); } + + /// \} + /// \anchor float + /// \name Floating point manipulation + /// \{ + + /// Decompress floating-point number. + /// **See also:** Documentation for [std::frexp](https://en.cppreference.com/w/cpp/numeric/math/frexp). + /// \param arg number to decompress + /// \param exp address to store exponent at + /// \return significant in range [0.5, 1) + /// \exception FE_INVALID for signaling NaN + inline half frexp(half arg, int *exp) { + *exp = 0; + unsigned int abs = arg.data_ & 0x7FFF; + if(abs >= 0x7C00 || !abs) + return (abs>0x7C00) ? half(detail::binary, detail::signal(arg.data_)) : arg; + for(; abs<0x400; abs<<=1,--*exp) ; + *exp += (abs>>10) - 14; + return half(detail::binary, (arg.data_&0x8000)|0x3800|(abs&0x3FF)); + } + + /// Multiply by power of two. + /// This function is exact to rounding for all rounding modes. + /// + /// **See also:** Documentation for [std::scalbln](https://en.cppreference.com/w/cpp/numeric/math/scalbn). + /// \param arg number to modify + /// \param exp power of two to multiply with + /// \return \a arg multplied by 2 raised to \a exp + /// \exception FE_INVALID for signaling NaN + /// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding + inline half scalbln(half arg, long exp) { + unsigned int abs = arg.data_ & 0x7FFF, sign = arg.data_ & 0x8000; + if(abs >= 0x7C00 || !abs) + return (abs>0x7C00) ? half(detail::binary, detail::signal(arg.data_)) : arg; + for(; abs<0x400; abs<<=1,--exp) ; + exp += abs >> 10; + if(exp > 30) + return half(detail::binary, detail::overflow(sign)); + else if(exp < -10) + return half(detail::binary, detail::underflow(sign)); + else if(exp > 0) + return half(detail::binary, sign|(exp<<10)|(abs&0x3FF)); + unsigned int m = (abs&0x3FF) | 0x400; + return half(detail::binary, detail::rounded(sign|(m>>(1-exp)), (m>>-exp)&1, (m&((1<<-exp)-1))!=0)); + } + + /// Multiply by power of two. + /// This function is exact to rounding for all rounding modes. + /// + /// **See also:** Documentation for [std::scalbn](https://en.cppreference.com/w/cpp/numeric/math/scalbn). + /// \param arg number to modify + /// \param exp power of two to multiply with + /// \return \a arg multplied by 2 raised to \a exp + /// \exception FE_INVALID for signaling NaN + /// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding + inline half scalbn(half arg, int exp) { return scalbln(arg, exp); } + + /// Multiply by power of two. + /// This function is exact to rounding for all rounding modes. + /// + /// **See also:** Documentation for [std::ldexp](https://en.cppreference.com/w/cpp/numeric/math/ldexp). + /// \param arg number to modify + /// \param exp power of two to multiply with + /// \return \a arg multplied by 2 raised to \a exp + /// \exception FE_INVALID for signaling NaN + /// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding + inline half ldexp(half arg, int exp) { return scalbln(arg, exp); } + + /// Extract integer and fractional parts. + /// **See also:** Documentation for [std::modf](https://en.cppreference.com/w/cpp/numeric/math/modf). + /// \param arg number to decompress + /// \param iptr address to store integer part at + /// \return fractional part + /// \exception FE_INVALID for signaling NaN + inline half modf(half arg, half *iptr) { + unsigned int abs = arg.data_ & 0x7FFF; + if(abs > 0x7C00) { + arg = half(detail::binary, detail::signal(arg.data_)); + return *iptr = arg, arg; + } + if(abs >= 0x6400) + return *iptr = arg, half(detail::binary, arg.data_&0x8000); + if(abs < 0x3C00) + return iptr->data_ = arg.data_ & 0x8000, arg; + unsigned int exp = abs >> 10, mask = (1<<(25-exp)) - 1, m = arg.data_ & mask; + iptr->data_ = arg.data_ & ~mask; + if(!m) + return half(detail::binary, arg.data_&0x8000); + for(; m<0x400; m<<=1,--exp) ; + return half(detail::binary, (arg.data_&0x8000)|(exp<<10)|(m&0x3FF)); + } + + /// Extract exponent. + /// **See also:** Documentation for [std::ilogb](https://en.cppreference.com/w/cpp/numeric/math/ilogb). + /// \param arg number to query + /// \return floating-point exponent + /// \retval FP_ILOGB0 for zero + /// \retval FP_ILOGBNAN for NaN + /// \retval INT_MAX for infinity + /// \exception FE_INVALID for 0 or infinite values + inline int ilogb(half arg) { + int abs = arg.data_ & 0x7FFF, exp; + if(!abs || abs >= 0x7C00) { + detail::raise(FE_INVALID); + return !abs ? FP_ILOGB0 : (abs==0x7C00) ? INT_MAX : FP_ILOGBNAN; + } + for(exp=(abs>>10)-15; abs<0x200; abs<<=1,--exp) ; + return exp; + } + + /// Extract exponent. + /// **See also:** Documentation for [std::logb](https://en.cppreference.com/w/cpp/numeric/math/logb). + /// \param arg number to query + /// \return floating-point exponent + /// \exception FE_INVALID for signaling NaN + /// \exception FE_DIVBYZERO for 0 + inline half logb(half arg) { + int abs = arg.data_ & 0x7FFF, exp; + if(!abs) + return half(detail::binary, detail::pole(0x8000)); + if(abs >= 0x7C00) + return half(detail::binary, (abs==0x7C00) ? 0x7C00 : detail::signal(arg.data_)); + for(exp=(abs>>10)-15; abs<0x200; abs<<=1,--exp) ; + unsigned int value = static_cast(exp<0) << 15; + if(exp) { + unsigned int m = std::abs(exp) << 6; + for(exp=18; m<0x400; m<<=1,--exp) ; + value |= (exp<<10) + m; + } + return half(detail::binary, value); + } + + /// Next representable value. + /// **See also:** Documentation for [std::nextafter](https://en.cppreference.com/w/cpp/numeric/math/nextafter). + /// \param from value to compute next representable value for + /// \param to direction towards which to compute next value + /// \return next representable value after \a from in direction towards \a to + /// \exception FE_INVALID for signaling NaN + /// \exception FE_OVERFLOW for infinite result from finite argument + /// \exception FE_UNDERFLOW for subnormal result + inline half nextafter(half from, half to) { + int fabs = from.data_ & 0x7FFF, tabs = to.data_ & 0x7FFF; + if(fabs > 0x7C00 || tabs > 0x7C00) + return half(detail::binary, detail::signal(from.data_, to.data_)); + if(from.data_ == to.data_ || !(fabs|tabs)) + return to; + if(!fabs) { + detail::raise(FE_UNDERFLOW, !HALF_ERRHANDLING_UNDERFLOW_TO_INEXACT); + return half(detail::binary, (to.data_&0x8000)+1); + } + unsigned int out = from.data_ + (((from.data_>>15)^static_cast( + (from.data_^(0x8000|(0x8000-(from.data_>>15))))<(to.data_^(0x8000|(0x8000-(to.data_>>15))))))<<1) - 1; + detail::raise(FE_OVERFLOW, fabs<0x7C00 && (out&0x7C00)==0x7C00); + detail::raise(FE_UNDERFLOW, !HALF_ERRHANDLING_UNDERFLOW_TO_INEXACT && (out&0x7C00)<0x400); + return half(detail::binary, out); + } + + /// Next representable value. + /// **See also:** Documentation for [std::nexttoward](https://en.cppreference.com/w/cpp/numeric/math/nexttoward). + /// \param from value to compute next representable value for + /// \param to direction towards which to compute next value + /// \return next representable value after \a from in direction towards \a to + /// \exception FE_INVALID for signaling NaN + /// \exception FE_OVERFLOW for infinite result from finite argument + /// \exception FE_UNDERFLOW for subnormal result + inline half nexttoward(half from, long double to) { + int fabs = from.data_ & 0x7FFF; + if(fabs > 0x7C00) + return half(detail::binary, detail::signal(from.data_)); + long double lfrom = static_cast(from); + if(detail::builtin_isnan(to) || lfrom == to) + return half(static_cast(to)); + if(!fabs) { + detail::raise(FE_UNDERFLOW, !HALF_ERRHANDLING_UNDERFLOW_TO_INEXACT); + return half(detail::binary, (static_cast(detail::builtin_signbit(to))<<15)+1); + } + unsigned int out = from.data_ + (((from.data_>>15)^static_cast(lfrom 0x7C00; } + + /// Check if normal number. + /// **See also:** Documentation for [std::isnormal](https://en.cppreference.com/w/cpp/numeric/math/isnormal). + /// \param arg number to check + /// \retval true if normal number + /// \retval false if either subnormal, zero, infinity or NaN + inline constexpr bool isnormal(half arg) { return ((arg.data_&0x7C00)!=0) & ((arg.data_&0x7C00)!=0x7C00); } + + /// Check sign. + /// **See also:** Documentation for [std::signbit](https://en.cppreference.com/w/cpp/numeric/math/signbit). + /// \param arg number to check + /// \retval true for negative number + /// \retval false for positive number + inline constexpr bool signbit(half arg) { return (arg.data_&0x8000) != 0; } + + /// \} + /// \anchor compfunc + /// \name Comparison + /// \{ + + /// Quiet comparison for greater than. + /// **See also:** Documentation for [std::isgreater](https://en.cppreference.com/w/cpp/numeric/math/isgreater). + /// \param x first operand + /// \param y second operand + /// \retval true if \a x greater than \a y + /// \retval false else + inline constexpr bool isgreater(half x, half y) { + return ((x.data_^(0x8000|(0x8000-(x.data_>>15))))+(x.data_>>15)) > ((y.data_^(0x8000|(0x8000-(y.data_>>15))))+(y.data_>>15)) && !isnan(x) && !isnan(y); + } + + /// Quiet comparison for greater equal. + /// **See also:** Documentation for [std::isgreaterequal](https://en.cppreference.com/w/cpp/numeric/math/isgreaterequal). + /// \param x first operand + /// \param y second operand + /// \retval true if \a x greater equal \a y + /// \retval false else + inline constexpr bool isgreaterequal(half x, half y) { + return ((x.data_^(0x8000|(0x8000-(x.data_>>15))))+(x.data_>>15)) >= ((y.data_^(0x8000|(0x8000-(y.data_>>15))))+(y.data_>>15)) && !isnan(x) && !isnan(y); + } + + /// Quiet comparison for less than. + /// **See also:** Documentation for [std::isless](https://en.cppreference.com/w/cpp/numeric/math/isless). + /// \param x first operand + /// \param y second operand + /// \retval true if \a x less than \a y + /// \retval false else + inline constexpr bool isless(half x, half y) { + return ((x.data_^(0x8000|(0x8000-(x.data_>>15))))+(x.data_>>15)) < ((y.data_^(0x8000|(0x8000-(y.data_>>15))))+(y.data_>>15)) && !isnan(x) && !isnan(y); + } + + /// Quiet comparison for less equal. + /// **See also:** Documentation for [std::islessequal](https://en.cppreference.com/w/cpp/numeric/math/islessequal). + /// \param x first operand + /// \param y second operand + /// \retval true if \a x less equal \a y + /// \retval false else + inline constexpr bool islessequal(half x, half y) { + return ((x.data_^(0x8000|(0x8000-(x.data_>>15))))+(x.data_>>15)) <= ((y.data_^(0x8000|(0x8000-(y.data_>>15))))+(y.data_>>15)) && !isnan(x) && !isnan(y); + } + + /// Quiet comarison for less or greater. + /// **See also:** Documentation for [std::islessgreater](https://en.cppreference.com/w/cpp/numeric/math/islessgreater). + /// \param x first operand + /// \param y second operand + /// \retval true if either less or greater + /// \retval false else + inline constexpr bool islessgreater(half x, half y) { + return x.data_!=y.data_ && ((x.data_|y.data_)&0x7FFF) && !isnan(x) && !isnan(y); + } + + /// Quiet check if unordered. + /// **See also:** Documentation for [std::isunordered](https://en.cppreference.com/w/cpp/numeric/math/isunordered). + /// \param x first operand + /// \param y second operand + /// \retval true if unordered (one or two NaN operands) + /// \retval false else + inline constexpr bool isunordered(half x, half y) { return isnan(x) || isnan(y); } + + /// \} + /// \anchor casting + /// \name Casting + /// \{ + + /// Cast to or from half-precision floating-point number. + /// This casts between [half](\ref half_float::half) and any built-in arithmetic type. The values are converted + /// directly using the default rounding mode, without any roundtrip over `float` that a `static_cast` would otherwise do. + /// + /// Using this cast with neither of the two types being a [half](\ref half_float::half) or with any of the two types + /// not being a built-in arithmetic type (apart from [half](\ref half_float::half), of course) results in a compiler + /// error and casting between [half](\ref half_float::half)s returns the argument unmodified. + /// \tparam T destination type (half or built-in arithmetic type) + /// \tparam U source type (half or built-in arithmetic type) + /// \param arg value to cast + /// \return \a arg converted to destination type + /// \exception FE_INVALID if \a T is integer type and result is not representable as \a T + /// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding + template T half_cast(U arg) { return detail::half_caster::cast(arg); } + + /// Cast to or from half-precision floating-point number. + /// This casts between [half](\ref half_float::half) and any built-in arithmetic type. The values are converted + /// directly using the specified rounding mode, without any roundtrip over `float` that a `static_cast` would otherwise do. + /// + /// Using this cast with neither of the two types being a [half](\ref half_float::half) or with any of the two types + /// not being a built-in arithmetic type (apart from [half](\ref half_float::half), of course) results in a compiler + /// error and casting between [half](\ref half_float::half)s returns the argument unmodified. + /// \tparam T destination type (half or built-in arithmetic type) + /// \tparam R rounding mode to use. + /// \tparam U source type (half or built-in arithmetic type) + /// \param arg value to cast + /// \return \a arg converted to destination type + /// \exception FE_INVALID if \a T is integer type and result is not representable as \a T + /// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding + template T half_cast(U arg) { return detail::half_caster::cast(arg); } + /// \} + + /// \} + /// \anchor errors + /// \name Error handling + /// \{ + + /// Clear exception flags. + /// This function works even if [automatic exception flag handling](\ref HALF_ERRHANDLING_FLAGS) is disabled, + /// but in that case manual flag management is the only way to raise flags. + /// + /// **See also:** Documentation for [std::feclearexcept](https://en.cppreference.com/w/cpp/numeric/fenv/feclearexcept). + /// \param excepts OR of exceptions to clear + /// \retval 0 all selected flags cleared successfully + inline int feclearexcept(int excepts) { detail::errflags() &= ~excepts; return 0; } + + /// Test exception flags. + /// This function works even if [automatic exception flag handling](\ref HALF_ERRHANDLING_FLAGS) is disabled, + /// but in that case manual flag management is the only way to raise flags. + /// + /// **See also:** Documentation for [std::fetestexcept](https://en.cppreference.com/w/cpp/numeric/fenv/fetestexcept). + /// \param excepts OR of exceptions to test + /// \return OR of selected exceptions if raised + inline int fetestexcept(int excepts) { return detail::errflags() & excepts; } + + /// Raise exception flags. + /// This raises the specified floating point exceptions and also invokes any additional automatic exception handling as + /// configured with the [HALF_ERRHANDLIG_...](\ref HALF_ERRHANDLING_ERRNO) preprocessor symbols. + /// This function works even if [automatic exception flag handling](\ref HALF_ERRHANDLING_FLAGS) is disabled, + /// but in that case manual flag management is the only way to raise flags. + /// + /// **See also:** Documentation for [std::feraiseexcept](https://en.cppreference.com/w/cpp/numeric/fenv/feraiseexcept). + /// \param excepts OR of exceptions to raise + /// \retval 0 all selected exceptions raised successfully + inline int feraiseexcept(int excepts) { detail::errflags() |= excepts; detail::raise(excepts); return 0; } + + /// Save exception flags. + /// This function works even if [automatic exception flag handling](\ref HALF_ERRHANDLING_FLAGS) is disabled, + /// but in that case manual flag management is the only way to raise flags. + /// + /// **See also:** Documentation for [std::fegetexceptflag](https://en.cppreference.com/w/cpp/numeric/fenv/feexceptflag). + /// \param flagp adress to store flag state at + /// \param excepts OR of flags to save + /// \retval 0 for success + inline int fegetexceptflag(int *flagp, int excepts) { *flagp = detail::errflags() & excepts; return 0; } + + /// Restore exception flags. + /// This only copies the specified exception state (including unset flags) without incurring any additional exception handling. + /// This function works even if [automatic exception flag handling](\ref HALF_ERRHANDLING_FLAGS) is disabled, + /// but in that case manual flag management is the only way to raise flags. + /// + /// **See also:** Documentation for [std::fesetexceptflag](https://en.cppreference.com/w/cpp/numeric/fenv/feexceptflag). + /// \param flagp adress to take flag state from + /// \param excepts OR of flags to restore + /// \retval 0 for success + inline int fesetexceptflag(const int *flagp, int excepts) { detail::errflags() = (detail::errflags()|(*flagp&excepts)) & (*flagp|~excepts); return 0; } + + /// Throw C++ exceptions based on set exception flags. + /// This function manually throws a corresponding C++ exception if one of the specified flags is set, + /// no matter if automatic throwing (via [HALF_ERRHANDLING_THROW_...](\ref HALF_ERRHANDLING_THROW_INVALID)) is enabled or not. + /// This function works even if [automatic exception flag handling](\ref HALF_ERRHANDLING_FLAGS) is disabled, + /// but in that case manual flag management is the only way to raise flags. + /// \param excepts OR of exceptions to test + /// \param msg error message to use for exception description + /// \throw std::domain_error if `FE_INVALID` or `FE_DIVBYZERO` is selected and set + /// \throw std::overflow_error if `FE_OVERFLOW` is selected and set + /// \throw std::underflow_error if `FE_UNDERFLOW` is selected and set + /// \throw std::range_error if `FE_INEXACT` is selected and set + inline void fethrowexcept(int excepts, const char *msg = "") { + excepts &= detail::errflags(); + if(excepts & (FE_INVALID|FE_DIVBYZERO)) + throw std::domain_error(msg); + if(excepts & FE_OVERFLOW) + throw std::overflow_error(msg); + if(excepts & FE_UNDERFLOW) + throw std::underflow_error(msg); + if(excepts & FE_INEXACT) + throw std::range_error(msg); + } + /// \} +} + + +#undef HALF_UNUSED_NOERR +#undef constexpr_NOERR +#undef HALF_TWOS_COMPLEMENT_INT +#ifdef HALF_POP_WARNINGS + #pragma warning(pop) + #undef HALF_POP_WARNINGS +#endif \ No newline at end of file diff --git a/hw/rtl/VX_config.vh b/hw/rtl/VX_config.vh index eed3cf54..8529105d 100644 --- a/hw/rtl/VX_config.vh +++ b/hw/rtl/VX_config.vh @@ -385,6 +385,11 @@ `define LATENCY_FCVT 5 `endif +// Tensor Core Latency +`ifndef LATENCY_HMMA +`define LATENCY_HMMA 4 +`endif + // Icache Configurable Knobs ////////////////////////////////////////////////// // Cache Enable diff --git a/hw/rtl/fpu/VX_tensor_core.sv b/hw/rtl/fpu/VX_tensor_core.sv new file mode 100644 index 00000000..e69de29b diff --git a/hw/rtl/fpu/VX_tensor_dpu.sv b/hw/rtl/fpu/VX_tensor_dpu.sv new file mode 100644 index 00000000..f9147c9d --- /dev/null +++ b/hw/rtl/fpu/VX_tensor_dpu.sv @@ -0,0 +1,35 @@ +`include "VX_fpu_define.vh" + +module VX_tensor_dpu #( + +) ( + input clk, + input reset, + + input valid_in, + input [3:0][1:0][31:0] A_tile, + input [1:0][3:0][31:0] B_tile, + input [3:0][3:0][31:0] C_tile, + + output valid_out, + output [3:0][3:0][31:0] D_tile +); + logic [3:0][3:0][31:0] result_hmma; + + always @(*) begin + dpi_hmma(valid_in, A_tile, B_tile, C_tile, result_hmma); + end + + + VX_shift_register #( + .DATAW (1 + $bits(D_tile)), + .DEPTH (`LATENCY_HMMA), + .RESETW (1) + ) shift_reg ( + .clk (clk), + .reset (reset), + .enable (1'b1), + .data_in ({valid_in, result_hmma}), + .data_out ({valid_out, D_tile}) + ); +endmodule diff --git a/hw/rtl/fpu/VX_tensor_tb.sv b/hw/rtl/fpu/VX_tensor_tb.sv new file mode 100644 index 00000000..9fa9fa41 --- /dev/null +++ b/hw/rtl/fpu/VX_tensor_tb.sv @@ -0,0 +1,28 @@ +`include "VX_fpu_define.vh" + +module VX_tensor_tb( + input clk, + input reset, + + input valid_in, + input [3:0][1:0][31:0] A_tile, + input [1:0][3:0][31:0] B_tile, + input [3:0][3:0][31:0] C_tile, + + output valid_out, + output [3:0][3:0][31:0] D_tile +); + + VX_tensor_dpu #() tensor_core ( + .clk(clk), + .reset(reset), + + .valid_in(valid_in), + .A_tile(A_tile), + .B_tile(B_tile), + .C_tile(C_tile), + + .valid_out(valid_out), + .D_tile(D_tile) + ); +endmodule diff --git a/hw/unittest/tensor/Makefile b/hw/unittest/tensor/Makefile new file mode 100644 index 00000000..021b7dcb --- /dev/null +++ b/hw/unittest/tensor/Makefile @@ -0,0 +1,89 @@ +DESTDIR ?= . +RTL_DIR = ../../rtl +DPI_DIR = $(abspath ../../dpi) +SIM_DIR = ../../../sim +THIRD_PARTY_DIR = $(abspath ../../../third_party) + +CONFIGS += +PARAMS += + +CXXFLAGS += -std=c++17 -Wall -Wextra -Wfatal-errors -Wno-array-bounds +CXXFLAGS += -fPIC -Wno-maybe-uninitialized +CXXFLAGS += -fcoroutines +CXXFLAGS += -I../../.. -I../../common -I../../../../sim/common +CXXFLAGS += -I/$(THIRD_PARTY_DIR)/softfloat/source/include +CXXFLAGS += -I/$(DPI_DIR) +CXXFLAGS += $(CONFIGS) + +LDFLAGS += $(THIRD_PARTY_DIR)/softfloat/build/Linux-x86_64-GCC/softfloat.a + +# control RTL debug tracing states +DBG_TRACE_FLAGS += -DDBG_TRACE_CACHE_BANK +DBG_TRACE_FLAGS += -DDBG_TRACE_CACHE_MSHR +DBG_TRACE_FLAGS += -DDBG_TRACE_CACHE_TAG +DBG_TRACE_FLAGS += -DDBG_TRACE_CACHE_DATA + +DBG_FLAGS += -DDEBUG_LEVEL=$(DEBUG) -DVCD_OUTPUT $(DBG_TRACE_FLAGS) + +RTL_PKGS = $(RTL_DIR)/VX_gpu_pkg.sv + +RTL_INCLUDE = -I$(RTL_DIR) -I$(DPI_DIR) -I$(RTL_DIR)/libs -I$(RTL_DIR)/fpu + +# SRCS = cachesim.cpp testbench.cpp +SRCS += $(DPI_DIR)/util_dpi.cpp +SRCS += $(DPI_DIR)/float_dpi.cpp +SRCS += $(SIM_DIR)/common/rvfloats.cpp +SRCS += ./main.cpp + +RTL_SRCS += $(RTL_DIR)/fpu/VX_tensor_core.sv +RTL_SRCS += $(RTL_DIR)/fpu/VX_tensor_tb.sv + +TOP = VX_tensor_tb + +VL_FLAGS = --exe +VL_FLAGS += --language 1800-2009 # -Wall -Wpedantic # --assert +VL_FLAGS += -Wno-DECLFILENAME -Wno-REDEFMACRO +VL_FLAGS += --x-initial unique --x-assign unique +VL_FLAGS += -DSIMULATION -DSV_DPI +VL_FLAGS += $(CONFIGS) +VL_FLAGS += $(PARAMS) +VL_FLAGS += $(RTL_INCLUDE) +VL_FLAGS += $(RTL_PKGS) +VL_FLAGS += --cc $(TOP) --top-module $(TOP) +VL_FLAGS += --timing + +# Enable Verilator multithreaded simulation +THREADS ?= $(shell python -c 'import multiprocessing as mp; print(mp.cpu_count())') +VL_FLAGS += -j $(THREADS) +#VL_FLAGS += --threads $(THREADS) + +# Debugigng +ifdef DEBUG + VL_FLAGS += --trace --trace-structs $(DBG_FLAGS) + CXXFLAGS += -g -O0 $(DBG_FLAGS) +else + VL_FLAGS += -DNDEBUG + CXXFLAGS += -O2 -DNDEBUG +endif + +# Enable perf counters +ifdef PERF + VL_FLAGS += -DPERF_ENABLE + CXXFLAGS += -DPERF_ENABLE +endif + +PROJECT = tensor + +all: $(DESTDIR)/$(PROJECT) + +$(DESTDIR)/$(PROJECT): $(SRCS) $(RTL_SRCS) + verilator --build $(VL_FLAGS) $(SRCS) -CFLAGS '$(CXXFLAGS)' -LDFLAGS '$(LDFLAGS)' -o ../$@ + +run: $(DESTDIR)/$(PROJECT) + $(DESTDIR)/$(PROJECT) + +waves: trace.vcd + gtkwave -o trace.vcd + +clean: + rm -rf obj_dir $(DESTDIR)/$(PROJECT) diff --git a/hw/unittest/tensor/main.cpp b/hw/unittest/tensor/main.cpp new file mode 100644 index 00000000..ffc7c064 --- /dev/null +++ b/hw/unittest/tensor/main.cpp @@ -0,0 +1,197 @@ +// Copyright © 2019-2023 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "vl_simulator.h" +#include "VVX_tensor_tb.h" +#include + +#include + +#define MAX_TICKS 20 + +#ifndef TRACE_START_TIME +#define TRACE_START_TIME 0ull +#endif + +#ifndef TRACE_STOP_TIME +#define TRACE_STOP_TIME -1ull +#endif + +#define CHECK(x) \ + do { \ + if (x) \ + break; \ + std::cout << "FAILED: " << #x << std::endl; \ + std::abort(); \ + } while (false) + +static uint64_t timestamp = 0; +static bool trace_enabled = false; +static uint64_t trace_start_time = TRACE_START_TIME; +static uint64_t trace_stop_time = TRACE_STOP_TIME; + +double sc_time_stamp() { + return timestamp; +} + +bool sim_trace_enabled() { + if (timestamp >= trace_start_time + && timestamp < trace_stop_time) + return true; + return trace_enabled; +} + +void sim_trace_enable(bool enable) { + trace_enabled = enable; +} + +using Device = VVX_tensor_tb; +using half_float::half; + +static_assert(sizeof(half) == 2); +uint32_t half2bits(half h) { + uint16_t half_bits; + memcpy(&half_bits, &h, sizeof(half)); + return half_bits; +} + +uint32_t float2bits(float f) { + uint32_t float_bits; + memcpy(&float_bits, &f, sizeof(f)); + return float_bits; +} + +float bits2float(uint32_t b) { + float f; + memcpy(&f, &b, sizeof(b)); + return f; +} + +// A is M * K, B is K * K * M, C is M * M, D is M * M +#define M 4 +#define K 2 + +// row, column +float A_tile[M][K]; +float B_tile[K][M]; +float C_tile[M][M]; +float D_tile[M][M]; + +void initialize_test_data() { + for (int i = 0; i < M; i += 1) { + for (int j = 0; j < K; j += 1) { + A_tile[i][j] = (float) (i * K + j); + } + } + + for (int i = 0; i < K; i += 1) { + for (int j = 0; j < M; j += 1) { + B_tile[i][j] = (float) (j * K + i); + } + } + for (int i = 0; i < M; i += 1) { + for (int j = 0; j < M; j += 1) { + C_tile[i][j] = (float) (i * j); + } + } +} + +void write_test_data(vl_simulator& sim) { + for (int i = 0; i < M; i += 1) { + for (int j = 0; j < K; j += 1) { + int index = (i * K + j); + uint32_t A_bits = float2bits(A_tile[i][j]); + sim->A_tile[index] = A_bits; + } + } + + for (int i = 0; i < K; i += 1) { + for (int j = 0; j < M; j += 1) { + int index = (i * M + j); + uint32_t B_bits = float2bits(B_tile[i][j]); + sim->B_tile[index] = B_bits; + } + } + + for (int i = 0; i < M; i += 1) { + for (int j = 0; j < M; j += 1) { + int index = (i * M + j); + uint32_t C_bits = float2bits(C_tile[i][j]); + sim->C_tile[index] = C_bits; + } + } +} + +void read_result(vl_simulator& sim) { + for (int i = 0; i < M; i += 1) { + for (int j = 0; j < M; j += 1) { + int index = (i * M + j); + + uint32_t D_bits = sim->D_tile[index]; + float f = bits2float(D_bits); + D_tile[i][j] = f; + std::cout << f << " "; + } + std::cout << std::endl; + } +} + +void expected() { + for (int i = 0; i < M; i += 1) { + for (int j = 0; j < M; j += 1) { + float accum = C_tile[i][j]; + for (int k = 0; k < K; k += 1) { + accum += A_tile[i][k] * B_tile[k][j]; + } + + std::cout << accum << " "; + } + std::cout << std::endl; + } +} + +int main(int argc, char **argv) { + // Initialize Verilators variables + Verilated::commandArgs(argc, argv); + + vl_simulator sim; + + initialize_test_data(); + // run test + timestamp = sim.reset(0); + + + // advance clock + timestamp = sim.step(timestamp, 10); + sim->valid_in = 1; + write_test_data(sim); + timestamp = sim.step(timestamp, 2); + CHECK(sim->valid_out == 0); + sim->valid_in = 0; + timestamp = sim.step(timestamp, 2); + CHECK(sim->valid_out == 0); + timestamp = sim.step(timestamp, 2); + CHECK(sim->valid_out == 0); + timestamp = sim.step(timestamp, 2); + CHECK(sim->valid_out == 1); + read_result(sim); + timestamp = sim.step(timestamp, 2); + CHECK(sim->valid_out == 0); + + expected(); + + std::cout << "PASSED!" << std::endl; + std::cout << "Simulation time: " << std::dec << timestamp/2 << " cycles" << std::endl; + + return 0; +} \ No newline at end of file From f9b4509936e7828f69b6e989ee7e97af93fd53bc Mon Sep 17 00:00:00 2001 From: joshua Date: Wed, 20 Mar 2024 02:46:00 -0700 Subject: [PATCH 03/55] initial tensor core --- hw/rtl/VX_config.vh | 10 ++ hw/rtl/VX_define.vh | 5 +- hw/rtl/core/VX_commit.sv | 26 ++++- hw/rtl/core/VX_core.sv | 14 +++ hw/rtl/core/VX_decode.sv | 6 + hw/rtl/core/VX_dispatch.sv | 35 ++++++ hw/rtl/core/VX_execute.sv | 19 +++- hw/rtl/core/VX_ibuffer.sv | 23 +++- hw/rtl/core/VX_issue.sv | 6 + hw/rtl/core/VX_tensor_core.sv | 15 +++ hw/rtl/core/VX_uop_sequencer.sv | 187 ++++++++++++++++++++++++++++++++ hw/rtl/fpu/VX_tensor_core.sv | 0 12 files changed, 338 insertions(+), 8 deletions(-) create mode 100644 hw/rtl/core/VX_tensor_core.sv create mode 100644 hw/rtl/core/VX_uop_sequencer.sv delete mode 100644 hw/rtl/fpu/VX_tensor_core.sv diff --git a/hw/rtl/VX_config.vh b/hw/rtl/VX_config.vh index 8529105d..e8bb56fc 100644 --- a/hw/rtl/VX_config.vh +++ b/hw/rtl/VX_config.vh @@ -40,6 +40,10 @@ `define EXT_F_ENABLE `endif +`ifndef EXT_T_DISABLE +`define EXT_T_ENABLE +`endif + `ifndef XLEN_32 `ifndef XLEN_64 `define XLEN_32 @@ -618,6 +622,12 @@ `define EXT_F_ENABLED 0 `endif +`ifdef EXT_T_ENABLE + `define EXT_T_ENABLED 1 +`else + `define EXT_T_ENABLED 0 +`endif + `ifdef EXT_M_ENABLE `define EXT_M_ENABLED 1 `else diff --git a/hw/rtl/VX_define.vh b/hw/rtl/VX_define.vh index 9ddeeeea..bb96a149 100644 --- a/hw/rtl/VX_define.vh +++ b/hw/rtl/VX_define.vh @@ -58,8 +58,9 @@ `define EX_LSU 1 `define EX_SFU 2 `define EX_FPU (`EX_SFU + `EXT_F_ENABLED) +`define EX_TENSOR (`EX_FPU + `EXT_T_ENABLED) -`define NUM_EX_UNITS (3 + `EXT_F_ENABLED) +`define NUM_EX_UNITS (3 + `EXT_F_ENABLED + `EXT_T_ENABLED) `define EX_BITS `CLOG2(`NUM_EX_UNITS) `define EX_WIDTH `UP(`EX_BITS) @@ -253,6 +254,8 @@ `define INST_SFU_IS_WCTL(op) (op <= 5) `define INST_SFU_IS_CSR(op) (op >= 6 && op <= 8) +`define INST_TENSOR_HMMA 4'b0000 + /////////////////////////////////////////////////////////////////////////////// // non-cacheable tag bits diff --git a/hw/rtl/core/VX_commit.sv b/hw/rtl/core/VX_commit.sv index 09667d11..227104df 100644 --- a/hw/rtl/core/VX_commit.sv +++ b/hw/rtl/core/VX_commit.sv @@ -27,6 +27,10 @@ module VX_commit import VX_gpu_pkg::*; #( `endif VX_commit_if.slave sfu_commit_if [`ISSUE_WIDTH], +`ifdef EXT_T_ENABLE + VX_commit_if.slave tensor_commit_if [`ISSUE_WIDTH], +`endif + // outputs VX_writeback_if.master writeback_if [`ISSUE_WIDTH], VX_commit_csr_if.master commit_csr_if, @@ -65,6 +69,9 @@ module VX_commit import VX_gpu_pkg::*; #( sfu_commit_if[i].valid, `ifdef EXT_F_ENABLE fpu_commit_if[i].valid, + `endif + `ifdef EXT_T_ENABLE + tensor_commit_if[i].valid, `endif alu_commit_if[i].valid, lsu_commit_if[i].valid @@ -73,6 +80,9 @@ module VX_commit import VX_gpu_pkg::*; #( sfu_commit_if[i].ready, `ifdef EXT_F_ENABLE fpu_commit_if[i].ready, + `endif + `ifdef EXT_T_ENABLE + tensor_commit_if[i].ready, `endif alu_commit_if[i].ready, lsu_commit_if[i].ready @@ -81,6 +91,9 @@ module VX_commit import VX_gpu_pkg::*; #( sfu_commit_if[i].data, `ifdef EXT_F_ENABLE fpu_commit_if[i].data, + `endif + `ifdef EXT_T_ENABLE + tensor_commit_if[i].data, `endif alu_commit_if[i].data, lsu_commit_if[i].data @@ -157,7 +170,18 @@ module VX_commit import VX_gpu_pkg::*; #( // Committed instructions - wire [`ISSUE_WIDTH-1:0] committed = commit_fire & commit_eop; + // temporary hack to not underflow the pending instructions buffer + wire [`ISSUE_WIDTH-1:0] final_hmma; +`ifdef EXT_T_ENABLE + for (genvar i = 0; i < `ISSUE_WIDTH; ++i) begin + assign final_hmma[i] = ~(tensor_commit_if[i].ready && tensor_commit_if[i].valid) || (tensor_commit_if[i].data.rd == `NR_BITS'(32 + 23)); + end +`else + assign final_hmma = '1; +`endif + + + wire [`ISSUE_WIDTH-1:0] committed = (commit_fire & commit_eop) & final_hmma; VX_pipe_register #( .DATAW (`ISSUE_WIDTH * (1 + `NW_WIDTH)), diff --git a/hw/rtl/core/VX_core.sv b/hw/rtl/core/VX_core.sv index dde085a8..41f54b95 100644 --- a/hw/rtl/core/VX_core.sv +++ b/hw/rtl/core/VX_core.sv @@ -65,6 +65,10 @@ module VX_core import VX_gpu_pkg::*; #( `ifdef EXT_F_ENABLE VX_dispatch_if fpu_dispatch_if[`ISSUE_WIDTH](); VX_commit_if fpu_commit_if[`ISSUE_WIDTH](); +`endif +`ifdef EXT_T_ENABLE + VX_dispatch_if tensor_dispatch_if[`ISSUE_WIDTH](); + VX_commit_if tensor_commit_if[`ISSUE_WIDTH](); `endif VX_dispatch_if sfu_dispatch_if[`ISSUE_WIDTH](); VX_commit_if sfu_commit_if[`ISSUE_WIDTH](); @@ -172,6 +176,9 @@ module VX_core import VX_gpu_pkg::*; #( .lsu_dispatch_if(lsu_dispatch_if), `ifdef EXT_F_ENABLE .fpu_dispatch_if(fpu_dispatch_if), + `endif + `ifdef EXT_T_ENABLE + .tensor_dispatch_if(tensor_dispatch_if), `endif .sfu_dispatch_if(sfu_dispatch_if) ); @@ -197,6 +204,10 @@ module VX_core import VX_gpu_pkg::*; #( .fpu_dispatch_if(fpu_dispatch_if), .fpu_commit_if (fpu_commit_if), `endif + `ifdef EXT_T_ENABLE + .tensor_dispatch_if (tensor_dispatch_if), + .tensor_commit_if (tensor_commit_if), + `endif .commit_csr_if (commit_csr_if), .sched_csr_if (sched_csr_if), @@ -227,6 +238,9 @@ module VX_core import VX_gpu_pkg::*; #( .fpu_commit_if (fpu_commit_if), `endif .sfu_commit_if (sfu_commit_if), + `ifdef EXT_T_ENABLE + .tensor_commit_if (tensor_commit_if), + `endif .writeback_if (writeback_if), diff --git a/hw/rtl/core/VX_decode.sv b/hw/rtl/core/VX_decode.sv index 42cd7ffc..1d38c0b2 100644 --- a/hw/rtl/core/VX_decode.sv +++ b/hw/rtl/core/VX_decode.sv @@ -533,6 +533,12 @@ module VX_decode #( default:; endcase end + `ifdef EXT_T_ENABLE + `INST_EXT4: begin + ex_type = `EX_TENSOR; + op_type = `INST_TENSOR_HMMA; + end + `endif default:; endcase end diff --git a/hw/rtl/core/VX_dispatch.sv b/hw/rtl/core/VX_dispatch.sv index 61d857c5..b8288529 100644 --- a/hw/rtl/core/VX_dispatch.sv +++ b/hw/rtl/core/VX_dispatch.sv @@ -31,6 +31,9 @@ module VX_dispatch import VX_gpu_pkg::*; #( VX_dispatch_if.master lsu_dispatch_if [`ISSUE_WIDTH], `ifdef EXT_F_ENABLE VX_dispatch_if.master fpu_dispatch_if [`ISSUE_WIDTH], +`endif +`ifdef EXT_T_ENABLE + VX_dispatch_if.master tensor_dispatch_if [`ISSUE_WIDTH], `endif VX_dispatch_if.master sfu_dispatch_if [`ISSUE_WIDTH] ); @@ -139,6 +142,35 @@ module VX_dispatch import VX_gpu_pkg::*; #( end `endif + // Tensor Core dispatch + +`ifdef EXT_T_ENABLE + + VX_operands_if tensor_operands_if[`ISSUE_WIDTH](); + + for (genvar i = 0; i < `ISSUE_WIDTH; ++i) begin + assign tensor_operands_if[i].valid = operands_if[i].valid && (operands_if[i].data.ex_type == `EX_TENSOR); + assign tensor_operands_if[i].data = operands_if[i].data; + + `RESET_RELAY (tensor_reset, reset); + + VX_elastic_buffer #( + .DATAW (DATAW), + .SIZE (2), + .OUT_REG (2) + ) tensor_buffer ( + .clk (clk), + .reset (tensor_reset), + .valid_in (tensor_operands_if[i].valid), + .ready_in (tensor_operands_if[i].ready), + .data_in (`TO_DISPATCH_DATA(tensor_operands_if[i].data, last_active_tid[i])), + .data_out (tensor_dispatch_if[i].data), + .valid_out (tensor_dispatch_if[i].valid), + .ready_out (tensor_dispatch_if[i].ready) + ); + end +`endif + // SFU dispatch VX_operands_if sfu_operands_if[`ISSUE_WIDTH](); @@ -171,6 +203,9 @@ module VX_dispatch import VX_gpu_pkg::*; #( || (lsu_operands_if[i].ready && (operands_if[i].data.ex_type == `EX_LSU)) `ifdef EXT_F_ENABLE || (fpu_operands_if[i].ready && (operands_if[i].data.ex_type == `EX_FPU)) + `endif + `ifdef EXT_T_ENABLE + || (tensor_operands_if[i].ready && (operands_if[i].data.ex_type == `EX_TENSOR)) `endif || (sfu_operands_if[i].ready && (operands_if[i].data.ex_type == `EX_SFU)); end diff --git a/hw/rtl/core/VX_execute.sv b/hw/rtl/core/VX_execute.sv index f1ea2675..cdf17a31 100644 --- a/hw/rtl/core/VX_execute.sv +++ b/hw/rtl/core/VX_execute.sv @@ -41,7 +41,7 @@ module VX_execute import VX_gpu_pkg::*; #( VX_dispatch_if.slave fpu_dispatch_if [`ISSUE_WIDTH], VX_commit_if.master fpu_commit_if [`ISSUE_WIDTH], `endif - + VX_dispatch_if.slave alu_dispatch_if [`ISSUE_WIDTH], VX_commit_if.master alu_commit_if [`ISSUE_WIDTH], VX_branch_ctl_if.master branch_ctl_if [`NUM_ALU_BLOCKS], @@ -53,6 +53,11 @@ module VX_execute import VX_gpu_pkg::*; #( VX_commit_if.master sfu_commit_if [`ISSUE_WIDTH], VX_warp_ctl_if.master warp_ctl_if, +`ifdef EXT_T_ENABLE + VX_dispatch_if.slave tensor_dispatch_if [`ISSUE_WIDTH], + VX_commit_if.master tensor_commit_if [`ISSUE_WIDTH], +`endif + // simulation helper signals output wire sim_ebreak ); @@ -127,6 +132,18 @@ module VX_execute import VX_gpu_pkg::*; #( .commit_if (sfu_commit_if) ); +`ifdef EXT_T_ENABLE + VX_tensor_core #( + + ) tensor_core ( + .clk(clk), + .reset(reset), + + .dispatch_if(tensor_dispatch_if), + .commit_if(tensor_commit_if) + ); +`endif + // simulation helper signal to get RISC-V tests Pass/Fail status assign sim_ebreak = alu_dispatch_if[0].valid && alu_dispatch_if[0].ready && alu_dispatch_if[0].data.wis == 0 diff --git a/hw/rtl/core/VX_ibuffer.sv b/hw/rtl/core/VX_ibuffer.sv index b465c195..c81d48c4 100644 --- a/hw/rtl/core/VX_ibuffer.sv +++ b/hw/rtl/core/VX_ibuffer.sv @@ -36,6 +36,8 @@ module VX_ibuffer import VX_gpu_pkg::*; #( assign decode_if.ready = ibuf_ready_in[decode_isw]; + VX_ibuffer_if uop_sequencer_if [`ISSUE_WIDTH]; + for (genvar i = 0; i < `ISSUE_WIDTH; ++i) begin VX_elastic_buffer #( .DATAW (DATAW), @@ -62,13 +64,24 @@ module VX_ibuffer import VX_gpu_pkg::*; #( decode_if.data.rs1, decode_if.data.rs2, decode_if.data.rs3}), - .data_out(ibuffer_if[i].data), - .valid_out (ibuffer_if[i].valid), - .ready_out(ibuffer_if[i].ready) - ); + + .data_out (uop_sequencer_if[i].data), + .valid_out (uop_sequencer_if[i].valid), + .ready_out (uop_sequencer_if[i].ready) + ); + `ifndef L1_ENABLE - assign decode_if.ibuf_pop[i] = ibuffer_if[i].valid && ibuffer_if[i].ready; + assign decode_if.ibuf_pop[i] = uop_sequencer_if[i].valid && uop_sequencer_if[i].ready; `endif + + VX_uop_sequencer uop_sequencer ( + .clk(clk), + .reset(reset), + + .uop_sequencer_if(uop_sequencer_if[i]), + .ibuffer_if(ibuffer_if[i]) + ); + end endmodule diff --git a/hw/rtl/core/VX_issue.sv b/hw/rtl/core/VX_issue.sv index 1ba4ca28..614451c2 100644 --- a/hw/rtl/core/VX_issue.sv +++ b/hw/rtl/core/VX_issue.sv @@ -33,6 +33,9 @@ module VX_issue #( VX_dispatch_if.master lsu_dispatch_if [`ISSUE_WIDTH], `ifdef EXT_F_ENABLE VX_dispatch_if.master fpu_dispatch_if [`ISSUE_WIDTH], +`endif +`ifdef EXT_T_ENABLE + VX_dispatch_if.master tensor_dispatch_if [`ISSUE_WIDTH], `endif VX_dispatch_if.master sfu_dispatch_if [`ISSUE_WIDTH] ); @@ -92,6 +95,9 @@ module VX_issue #( .lsu_dispatch_if(lsu_dispatch_if), `ifdef EXT_F_ENABLE .fpu_dispatch_if(fpu_dispatch_if), + `endif + `ifdef EXT_T_ENABLE + .tensor_dispatch_if(tensor_dispatch_if), `endif .sfu_dispatch_if(sfu_dispatch_if) ); diff --git a/hw/rtl/core/VX_tensor_core.sv b/hw/rtl/core/VX_tensor_core.sv new file mode 100644 index 00000000..c31f3f9f --- /dev/null +++ b/hw/rtl/core/VX_tensor_core.sv @@ -0,0 +1,15 @@ +`include "VX_fpu_define.vh" + +module VX_tensor_core #( + +) ( + input clk, + input reset, + + VX_dispatch_if.slave dispatch_if [`ISSUE_WIDTH], + VX_commit_if.master commit_if [`ISSUE_WIDTH] +); + `STATIC_ASSERT(`NUM_THREADS == 32, ("tensor core requires # of threads in a warp to be 32")); + `UNUSED_VAR(clk); + `UNUSED_VAR(reset); +endmodule diff --git a/hw/rtl/core/VX_uop_sequencer.sv b/hw/rtl/core/VX_uop_sequencer.sv new file mode 100644 index 00000000..f18e473e --- /dev/null +++ b/hw/rtl/core/VX_uop_sequencer.sv @@ -0,0 +1,187 @@ +`include "VX_define.vh" + +`define FREG(x) {1'b1, `NRI_BITS'(`CLOG2(x))} + +module VX_uop_sequencer import VX_gpu_pkg::*; ( + input clk, + input reset, + + VX_ibuffer_if.slave uop_sequencer_if, + VX_ibuffer_if.master ibuffer_if +); + +`ifdef EXT_T_ENABLE + localparam UOP_TABLE_SIZE = 64; + localparam UPC_BITS = `CLOG2(UOP_TABLE_SIZE); + + localparam NEXT = 2'b00; + localparam FINISH = 2'b01; + + localparam UBR_BITS = 2; + + // uop metadata (sequencing, next state), execution metadata (EX_TYPE, OP_TYPE, OP_MOD), wb, use pc, use imm, pc, imm, rd, rs1, rs2, rs3 + localparam UOP_TABLE_WIDTH = UBR_BITS + UPC_BITS + `EX_BITS + `INST_OP_BITS + `INST_MOD_BITS + 1 + 1 + 1 + `XLEN + `XLEN + (`NR_BITS * 4); + localparam IBUFFER_IF_DATAW = `UUID_WIDTH + ISSUE_WIS_W + `NUM_THREADS + `XLEN + 1 + `EX_BITS + `INST_OP_BITS + `INST_MOD_BITS + 1 + 1 + `XLEN + (`NR_BITS * 4); + + logic [UOP_TABLE_WIDTH-1:0] uop; + + // reserve space at start of table for more uop sequences + localparam HMMA_SET0_STEP0_0 = UPC_BITS'(0); + localparam HMMA_SET0_STEP0_1 = UPC_BITS'(8); + /* + localparam HMMA_SET0_STEP1_0 = UPC_BITS'(9); + localparam HMMA_SET0_STEP1_1 = UPC_BITS'(10); + localparam HMMA_SET0_STEP2_0 = UPC_BITS'(11); + localparam HMMA_SET0_STEP2_1 = UPC_BITS'(12); + localparam HMMA_SET0_STEP3_0 = UPC_BITS'(13); + localparam HMMA_SET0_STEP3_1 = UPC_BITS'(14); + + localparam HMMA_SET1_STEP0_0 = UPC_BITS'(15); + localparam HMMA_SET1_STEP0_1 = UPC_BITS'(16); + localparam HMMA_SET1_STEP1_0 = UPC_BITS'(17); + localparam HMMA_SET1_STEP1_1 = UPC_BITS'(18); + localparam HMMA_SET1_STEP2_0 = UPC_BITS'(19); + localparam HMMA_SET1_STEP2_1 = UPC_BITS'(20); + localparam HMMA_SET1_STEP3_0 = UPC_BITS'(21); + localparam HMMA_SET1_STEP3_1 = UPC_BITS'(22); + + localparam HMMA_SET2_STEP0_0 = UPC_BITS'(23); + localparam HMMA_SET2_STEP0_1 = UPC_BITS'(24); + localparam HMMA_SET2_STEP1_0 = UPC_BITS'(25); + localparam HMMA_SET2_STEP1_1 = UPC_BITS'(26); + localparam HMMA_SET2_STEP2_0 = UPC_BITS'(27); + localparam HMMA_SET2_STEP2_1 = UPC_BITS'(28); + localparam HMMA_SET2_STEP3_0 = UPC_BITS'(29); + localparam HMMA_SET2_STEP3_1 = UPC_BITS'(30); + + localparam HMMA_SET3_STEP0_0 = UPC_BITS'(31); + localparam HMMA_SET3_STEP0_1 = UPC_BITS'(32); + localparam HMMA_SET3_STEP1_0 = UPC_BITS'(33); + localparam HMMA_SET3_STEP1_1 = UPC_BITS'(34); + localparam HMMA_SET3_STEP2_0 = UPC_BITS'(35); + localparam HMMA_SET3_STEP2_1 = UPC_BITS'(36); + localparam HMMA_SET3_STEP3_0 = UPC_BITS'(37); + localparam HMMA_SET3_STEP3_1 = UPC_BITS'(38); + */ + // register layout: f0-f7 used for A, f8-f15 used for B, f16-f23 used for C + + + + always @(*) begin + case (upc) + HMMA_SET0_STEP0_0: begin + uop = { + NEXT, + HMMA_SET0_STEP0_1, + `EX_BITS'(`EX_TENSOR), + `INST_OP_BITS'(0), // denotes that the first half is being computed + `INST_MOD_BITS'(0), // field is unused for HMMA + 1'b1, // write back + 1'b0, // don't use PC + 1'b0, // don't use immediate + 32'b0, // PC is unused - TODO: don't send a bogus PC down the pipeline as it is very confusing in trace + 32'b0, // immediate is unused + `FREG(16), // rd=f16 + `FREG(0), // rs1=f0, + `FREG(8), // rs2=f8 + `FREG(16) // rs3=f16 + }; + end + HMMA_SET0_STEP0_1: begin + uop = { + FINISH, + HMMA_SET0_STEP0_0, + `EX_BITS'(`EX_TENSOR), + `INST_OP_BITS'(1), // denotes that the second half is being computed + `INST_MOD_BITS'(0), // field is unused for HMMA + 1'b1, // write back + 1'b0, // don't use PC + 1'b0, // don't use immediate + 32'b0, // PC is unused - TODO: don't send a bogus PC down the pipeline as it is very confusing in trace + 32'b0, // immediate is unused + `FREG(17), // rd=f17 + `FREG(1), // rs1=f1, + `FREG(9), // rs2=f9 + `FREG(17) // rs3=f17 + }; + end + default: begin + uop = '0; + end + endcase + end + + logic [UPC_BITS-1:0] upc, upc_r, upc_n; + + logic [UBR_BITS-1:0] ubr = uop[UOP_TABLE_WIDTH-1:UOP_TABLE_WIDTH-UBR_BITS]; + logic [UPC_BITS-1:0] next_upc = uop[UOP_TABLE_WIDTH-UBR_BITS-1:UOP_TABLE_WIDTH-UBR_BITS-UPC_BITS]; + + logic uop_fire = use_uop && ibuffer_if.valid && ibuffer_if.ready; + logic uop_start = ~use_uop_1d && use_uop; + logic uop_finish = use_uop && uop_sequencer_if.valid && uop_sequencer_if.ready; + logic use_uop, use_uop_1d; + + // merging the 2 always blocks leads to spurious UNOPTFLAT verilator lint, but conceptually they should be linked + always @(*) begin + use_uop = uop_sequencer_if.valid && uop_sequencer_if.data.ex_type == `EX_TENSOR; + + if (uop_start) begin + // 1st cycle of microcoded operation, use op_type to determine entry point into microcode table + upc_n = UPC_BITS'(uop_sequencer_if.data.op_type); + end + else begin + upc_n = upc; + end + + if (uop_fire) begin + upc_n = next_upc; + end + end + + always @(*) begin + if (uop_start) begin + // 1st cycle of microcoded operation, use op_type to determine entry point into microcode table + upc = UPC_BITS'(uop_sequencer_if.data.op_type); + end + else begin + upc = upc_r; + end + end + + // copy UUID, wis, tmask from microcoded instruction + logic [IBUFFER_IF_DATAW-1:0] ibuffer_output = { + uop_sequencer_if.data.uuid, + uop_sequencer_if.data.wis, + uop_sequencer_if.data.tmask, + uop[UOP_TABLE_WIDTH-UBR_BITS-UPC_BITS-1:0] + }; + + assign ibuffer_if.valid = use_uop ? 1'b1 : uop_sequencer_if.valid; + assign uop_sequencer_if.ready = use_uop ? (uop_fire && ubr == FINISH) : ibuffer_if.ready; + assign ibuffer_if.data = use_uop ? ibuffer_output : uop_sequencer_if.data; + + always @(posedge clk) begin + if (reset) begin + upc_r <= '0; + use_uop_1d <= '0; + end + else begin + upc_r <= upc_n; + if (uop_finish) begin + use_uop_1d <= 1'b0; // allow microcoded instructions to start immediately after eachother + end + else begin + use_uop_1d <= use_uop; + end + end + end +`else + `UNUSED_VAR(clk); + `UNUSED_VAR(reset); + assign ibuffer_if.valid = uop_sequencer_if.valid; + assign uop_sequencer_if.ready = ibuffer_if.ready; + assign ibuffer_if.data = uop_sequencer_if.data; +`endif + + +endmodule diff --git a/hw/rtl/fpu/VX_tensor_core.sv b/hw/rtl/fpu/VX_tensor_core.sv deleted file mode 100644 index e69de29b..00000000 From b254281295cecf7f90c2d6061d4a8dab69b86bff Mon Sep 17 00:00:00 2001 From: joshua Date: Thu, 21 Mar 2024 01:29:38 -0700 Subject: [PATCH 04/55] initial tcore impl --- hw/rtl/VX_config.vh | 2 +- hw/rtl/core/VX_tensor_core.sv | 289 +++++++++++++++++++++++++++++++- hw/rtl/core/VX_uop_sequencer.sv | 14 +- hw/rtl/fpu/VX_tensor_dpu.sv | 4 +- hw/rtl/fpu/VX_tensor_tb.sv | 2 + 5 files changed, 303 insertions(+), 8 deletions(-) diff --git a/hw/rtl/VX_config.vh b/hw/rtl/VX_config.vh index e8bb56fc..d741da8d 100644 --- a/hw/rtl/VX_config.vh +++ b/hw/rtl/VX_config.vh @@ -391,7 +391,7 @@ // Tensor Core Latency `ifndef LATENCY_HMMA -`define LATENCY_HMMA 4 +`define LATENCY_HMMA 8 `endif // Icache Configurable Knobs ////////////////////////////////////////////////// diff --git a/hw/rtl/core/VX_tensor_core.sv b/hw/rtl/core/VX_tensor_core.sv index c31f3f9f..a9419c66 100644 --- a/hw/rtl/core/VX_tensor_core.sv +++ b/hw/rtl/core/VX_tensor_core.sv @@ -10,6 +10,291 @@ module VX_tensor_core #( VX_commit_if.master commit_if [`ISSUE_WIDTH] ); `STATIC_ASSERT(`NUM_THREADS == 32, ("tensor core requires # of threads in a warp to be 32")); - `UNUSED_VAR(clk); - `UNUSED_VAR(reset); + + for (genvar i = 0; i < `ISSUE_WIDTH; ++i) begin + VX_tensor_core_warp #( + .ISW(i) + ) tensor_core ( + .clk(clk), + .reset(reset), + + .dispatch_if(dispatch_if[i]), + .commit_if(commit_if[i]) + ); + end + +endmodule + +module VX_tensor_core_warp import VX_gpu_pkg::*; #( + parameter ISW +) ( + input clk, + input reset, + + VX_dispatch_if.slave dispatch_if, + VX_commit_if.master commit_if +); + logic [1:0] step = 2'(dispatch_if.data.op_type); + logic [3:0] octet_results_valid; + logic [3:0] octet_results_ready; + logic [`NUM_THREADS-1:0][`XLEN-1:0] wb_data_0; + logic [`NUM_THREADS-1:0][`XLEN-1:0] wb_data_1; + + for (genvar i = 0; i < 4; ++i) begin + logic [7:0][31:0] octet_A = { + dispatch_if.data.rs1_data[4*i +: 4], dispatch_if.data.rs1_data[16+4*i +: 4] + }; + logic [7:0][31:0] octet_B = { + dispatch_if.data.rs2_data[4*i +: 4], dispatch_if.data.rs2_data[16+4*i +: 4] + }; + logic [7:0][31:0] octet_C = { + dispatch_if.data.rs3_data[4*i +: 4], dispatch_if.data.rs3_data[16+4*i +: 4] + }; + + logic [3:0][3:0][31:0] octet_D; + logic result_valid; + logic result_ready; + VX_tensor_octet #( + + ) octet ( + .clk(clk), + .reset(reset), + + .A_in(octet_A), + .B_in(octet_B), + .C_in(octet_C), + .operands_valid(dispatch_if.valid), + .operands_ready(dispatch_if.ready), + + .step(step), + + .D_out(octet_D), + .result_valid(result_valid), + .result_ready(result_ready) + ); + + // these should always be in lockstep + assign octet_results_valid[i] = result_valid; + assign result_ready = octet_results_ready[i]; + + assign wb_data_0[4*i+0] = octet_D[0][0]; + assign wb_data_0[4*i+1] = octet_D[1][0]; + assign wb_data_0[4*i+2] = octet_D[0][2]; + assign wb_data_0[4*i+3] = octet_D[1][2]; + + assign wb_data_1[4*i+0] = octet_D[0][1]; + assign wb_data_1[4*i+1] = octet_D[1][1]; + assign wb_data_1[4*i+2] = octet_D[0][3]; + assign wb_data_1[4*i+3] = octet_D[1][3]; + + assign wb_data_0[4*i+16+0] = octet_D[2][0]; + assign wb_data_0[4*i+16+1] = octet_D[3][0]; + assign wb_data_0[4*i+16+2] = octet_D[2][2]; + assign wb_data_0[4*i+16+3] = octet_D[3][2]; + + assign wb_data_1[4*i+16+0] = octet_D[2][1]; + assign wb_data_1[4*i+16+1] = octet_D[3][1]; + assign wb_data_1[4*i+16+2] = octet_D[2][3]; + assign wb_data_1[4*i+16+3] = octet_D[3][3]; + end + + /* commit_if.data_t parts that we need to keep around: + - uuid + - wid + - tmask + - PC + - wb + - rd + */ + + localparam DATAW = `UUID_WIDTH + `NW_WIDTH + `NUM_THREADS + `XLEN + 1 + `NR_BITS; + + wire dispatch_if_fire = dispatch_if.valid && dispatch_if.ready; + wire commit_if_fire = commit_if.valid && commit_if.ready; + wire [DATAW-1:0] dispatch_if_data_enq = { + dispatch_if.data.uuid, + wis_to_wid(dispatch_if.data.wis, ISW), + dispatch_if.data.tmask, + dispatch_if.data.PC, + dispatch_if.data.wb, + dispatch_if.data.rd + }; + + wire [DATAW-1:0] dispatch_if_data_deq; + + // this is probably a little oversized + VX_fifo_queue #( + .DATAW(DATAW), + .DEPTH(8) + ) pending_uops ( + .clk(clk), + .reset(reset), + .push(dispatch_if_fire), + .pop(commit_if_fire), + .data_in(dispatch_if_data_enq), + .data_out(dispatch_if_data_deq), + `UNUSED_PIN(empty), + `UNUSED_PIN(alm_empty), + `UNUSED_PIN(full), // should be impossible to overflow + `UNUSED_PIN(alm_full), + `UNUSED_PIN(size) + ); + + logic subcommit, subcommit_n; + logic all_valid = (& octet_results_valid); + assign commit_if.valid = all_valid; + + localparam COMMIT_DATAW = `UUID_WIDTH + `NW_WIDTH + `NUM_THREADS + `XLEN + 1 + `NR_BITS + (`NUM_THREADS * `XLEN) + 1 + 1 + 1; + logic [COMMIT_DATAW-1:0] commit_if_data = { + dispatch_if_data_deq, + subcommit == 1'b0 ? wb_data_0 : wb_data_1, + 1'b0, + 1'b1, + 1'b1 + }; + + assign commit_if.data = commit_if_data; + + always @(*) begin + subcommit_n = commit_if_fire ? ~subcommit : subcommit; + if (commit_if_fire && subcommit == 1'b1) begin + octet_results_ready = '1; + end + else begin + octet_results_ready = '0; + end + end + + always @(posedge clk) begin + if (reset) begin + subcommit <= '0; + end + else begin + subcommit <= subcommit_n; + end + end + +endmodule + +module VX_tensor_octet #( + +) ( + input clk, + input reset, + + input [7:0][31:0] A_in, + input [7:0][31:0] B_in, + input [7:0][31:0] C_in, + input operands_valid, // we have to backpressure due to there potentially being contention over commit + output operands_ready, + + input [1:0] step, + + output [3:0][3:0][31:0] D_out, + output result_valid, + input result_ready +); + // 512 bits/octet * 4 octets per warp + logic [3:0][31:0] A_buffer, A_buffer_n; + logic [3:0][31:0] B_buffer, B_buffer_n; + logic [7:0][31:0] C_buffer, C_buffer_n; + + // half the inputs are buffered, half are not (instead coming straight from operand bus) + // unlike the real tensor core, the banks are only 32 bit rather than 64 bit + logic [3:0][31:0] A_half; + logic [3:0][31:0] B_half; + logic [7:0][31:0] C_half; + always @(*) begin + case (step) + 2'b00: begin + A_half = { A_in[1:0], A_in[5:4] }; + B_half = B_in[3:0]; + end + 2'b01: begin + A_half = { A_in[3:2], A_in[7:6] }; + B_half = B_in[3:0]; + end + 2'b10: begin + A_half = { A_in[1:0], A_in[5:4] }; + B_half = B_in[7:4]; + end + 2'b11: begin + A_half = { A_in[3:2], A_in[7:6] }; + B_half = B_in[7:4]; + end + endcase + C_half = C_in; + end + + logic substep; + logic substep_n = (operands_ready && operands_valid) ? ~substep : substep; + + always @(*) begin + A_buffer_n = A_buffer; + B_buffer_n = B_buffer; + C_buffer_n = C_buffer; + + if (substep == 1'b0) begin + A_buffer_n = A_half; + B_buffer_n = B_half; + C_buffer_n = C_half; + end + end + + always @(posedge clk) begin + if (reset) begin + A_buffer <= '0; + B_buffer <= '0; + C_buffer <= '0; + substep <= '0; + end + else begin + A_buffer <= A_buffer_n; + B_buffer <= B_buffer_n; + C_buffer <= C_buffer_n; + substep <= substep_n; + end + end + + + wire stall = result_valid && ~result_ready; + assign operands_ready = ~stall; + + logic [3:0][1:0][31:0] A_tile = { + { A_buffer[0], A_half[0] }, + { A_buffer[1], A_half[1] }, + { A_buffer[2], A_half[2] }, + { A_buffer[3], A_half[3] } + }; + logic [1:0][3:0][31:0] B_tile = { + B_buffer, B_half + }; + logic [3:0][3:0][31:0] C_tile; + + always @(*) begin + C_tile = { + C_buffer[0], C_half[0], C_buffer[1], C_half[1], + C_buffer[2], C_half[2], C_buffer[3], C_half[3], + C_buffer[4], C_half[4], C_buffer[5], C_half[5], + C_buffer[6], C_half[6], C_buffer[7], C_half[7] + }; + end + + wire do_hmma = (substep == 1'b1 && operands_valid && operands_ready); + VX_tensor_dpu #( + + ) dpu ( + .clk(clk), + .reset(reset), + + .stall(stall), + + .valid_in(do_hmma), + .A_tile(A_tile), + .B_tile(B_tile), + .C_tile(C_tile), + + .valid_out(result_valid), + .D_tile(D_out) + ); endmodule diff --git a/hw/rtl/core/VX_uop_sequencer.sv b/hw/rtl/core/VX_uop_sequencer.sv index f18e473e..c57ea0ba 100644 --- a/hw/rtl/core/VX_uop_sequencer.sv +++ b/hw/rtl/core/VX_uop_sequencer.sv @@ -74,8 +74,8 @@ module VX_uop_sequencer import VX_gpu_pkg::*; ( NEXT, HMMA_SET0_STEP0_1, `EX_BITS'(`EX_TENSOR), - `INST_OP_BITS'(0), // denotes that the first half is being computed - `INST_MOD_BITS'(0), // field is unused for HMMA + `INST_OP_BITS'(0), // denotes that the first step is being computed + `INST_MOD_BITS'(0), // denotes that this is first substep (tensor core also tracks this) 1'b1, // write back 1'b0, // don't use PC 1'b0, // don't use immediate @@ -92,8 +92,8 @@ module VX_uop_sequencer import VX_gpu_pkg::*; ( FINISH, HMMA_SET0_STEP0_0, `EX_BITS'(`EX_TENSOR), - `INST_OP_BITS'(1), // denotes that the second half is being computed - `INST_MOD_BITS'(0), // field is unused for HMMA + `INST_OP_BITS'(0), // denotes that the first step is being computed + `INST_MOD_BITS'(1), // denotes that this is first substep (tensor core also tracks this) 1'b1, // write back 1'b0, // don't use PC 1'b0, // don't use immediate @@ -161,6 +161,12 @@ module VX_uop_sequencer import VX_gpu_pkg::*; ( assign ibuffer_if.data = use_uop ? ibuffer_output : uop_sequencer_if.data; always @(posedge clk) begin + + if (use_uop) begin + $display("unexpectedly used uop at %d", $time); + end + + if (reset) begin upc_r <= '0; use_uop_1d <= '0; diff --git a/hw/rtl/fpu/VX_tensor_dpu.sv b/hw/rtl/fpu/VX_tensor_dpu.sv index f9147c9d..8108790c 100644 --- a/hw/rtl/fpu/VX_tensor_dpu.sv +++ b/hw/rtl/fpu/VX_tensor_dpu.sv @@ -6,6 +6,8 @@ module VX_tensor_dpu #( input clk, input reset, + input stall, + input valid_in, input [3:0][1:0][31:0] A_tile, input [1:0][3:0][31:0] B_tile, @@ -28,7 +30,7 @@ module VX_tensor_dpu #( ) shift_reg ( .clk (clk), .reset (reset), - .enable (1'b1), + .enable (~stall), .data_in ({valid_in, result_hmma}), .data_out ({valid_out, D_tile}) ); diff --git a/hw/rtl/fpu/VX_tensor_tb.sv b/hw/rtl/fpu/VX_tensor_tb.sv index 9fa9fa41..4a6076b0 100644 --- a/hw/rtl/fpu/VX_tensor_tb.sv +++ b/hw/rtl/fpu/VX_tensor_tb.sv @@ -17,6 +17,8 @@ module VX_tensor_tb( .clk(clk), .reset(reset), + .stall(1'b0), + .valid_in(valid_in), .A_tile(A_tile), .B_tile(B_tile), From e16584ddd9f9d15b2ee72ef6101b8f4b3c7cd0ba Mon Sep 17 00:00:00 2001 From: joshua Date: Wed, 27 Mar 2024 00:26:04 -0700 Subject: [PATCH 05/55] bleh still not work --- hw/rtl/core/VX_commit.sv | 19 ++++- hw/rtl/core/VX_ibuffer.sv | 2 +- hw/rtl/core/VX_operands.sv | 21 ++++++ hw/rtl/core/VX_tensor_core.sv | 20 ++--- hw/rtl/core/VX_tensor_ucode.vh | 96 ++++++++++++++++++++++++ hw/rtl/core/VX_uop_sequencer.sv | 72 ++++++------------ hw/rtl/core/generate_ucode.py | 81 ++++++++++++++++++++ tests/kernel/tensor/Makefile | 8 ++ tests/kernel/tensor/check_correctness.py | 94 +++++++++++++++++++++++ tests/kernel/tensor/create_test_case.py | 29 +++++++ tests/kernel/tensor/main.cpp | 96 ++++++++++++++++++++++++ tests/kernel/tensor/test_data.h | 11 +++ 12 files changed, 485 insertions(+), 64 deletions(-) create mode 100644 hw/rtl/core/VX_tensor_ucode.vh create mode 100644 hw/rtl/core/generate_ucode.py create mode 100644 tests/kernel/tensor/Makefile create mode 100644 tests/kernel/tensor/check_correctness.py create mode 100644 tests/kernel/tensor/create_test_case.py create mode 100644 tests/kernel/tensor/main.cpp create mode 100644 tests/kernel/tensor/test_data.h diff --git a/hw/rtl/core/VX_commit.sv b/hw/rtl/core/VX_commit.sv index 227104df..f0d94925 100644 --- a/hw/rtl/core/VX_commit.sv +++ b/hw/rtl/core/VX_commit.sv @@ -52,6 +52,7 @@ module VX_commit import VX_gpu_pkg::*; #( wire [`ISSUE_WIDTH-1:0][`NW_WIDTH-1:0] commit_wid; wire [`ISSUE_WIDTH-1:0][`NUM_THREADS-1:0] commit_tmask; wire [`ISSUE_WIDTH-1:0] commit_eop; + wire [`ISSUE_WIDTH-1:0][`EX_BITS-1:0] commit_sel; for (genvar i = 0; i < `ISSUE_WIDTH; ++i) begin @@ -101,7 +102,7 @@ module VX_commit import VX_gpu_pkg::*; #( .data_out (commit_if[i].data), .valid_out (commit_if[i].valid), .ready_out (commit_if[i].ready), - `UNUSED_PIN (sel_out) + .sel_out (commit_sel[i]) ); assign commit_fire[i] = commit_if[i].valid && commit_if[i].ready; @@ -171,10 +172,24 @@ module VX_commit import VX_gpu_pkg::*; #( // Committed instructions // temporary hack to not underflow the pending instructions buffer + // relies on 1 cycle delay of arbiter and continuous issuing of tensor instructions, + // so probably want to change this at some point + // (i.e. pass a "don't count this towards pending instructions" signal down the pipeline) + logic [`ISSUE_WIDTH-1:0][4:0] hmma_ctr, hmma_ctr_n; wire [`ISSUE_WIDTH-1:0] final_hmma; `ifdef EXT_T_ENABLE for (genvar i = 0; i < `ISSUE_WIDTH; ++i) begin - assign final_hmma[i] = ~(tensor_commit_if[i].ready && tensor_commit_if[i].valid) || (tensor_commit_if[i].data.rd == `NR_BITS'(32 + 23)); + assign hmma_ctr_n[i] = (tensor_commit_if[i].valid && tensor_commit_if[i].ready) ? hmma_ctr[i] + 5'b1 : hmma_ctr[i]; + assign final_hmma[i] = (commit_sel[i] != `EX_BITS'(2) || hmma_ctr == '0); + end + + always @(posedge clk) begin + if (reset) begin + hmma_ctr <= '0; + end + else begin + hmma_ctr <= hmma_ctr_n; + end end `else assign final_hmma = '1; diff --git a/hw/rtl/core/VX_ibuffer.sv b/hw/rtl/core/VX_ibuffer.sv index c81d48c4..b8dc2a36 100644 --- a/hw/rtl/core/VX_ibuffer.sv +++ b/hw/rtl/core/VX_ibuffer.sv @@ -36,7 +36,7 @@ module VX_ibuffer import VX_gpu_pkg::*; #( assign decode_if.ready = ibuf_ready_in[decode_isw]; - VX_ibuffer_if uop_sequencer_if [`ISSUE_WIDTH]; + VX_ibuffer_if uop_sequencer_if [`ISSUE_WIDTH](); for (genvar i = 0; i < `ISSUE_WIDTH; ++i) begin VX_elastic_buffer #( diff --git a/hw/rtl/core/VX_operands.sv b/hw/rtl/core/VX_operands.sv index 3747502f..3fe1d7f4 100644 --- a/hw/rtl/core/VX_operands.sv +++ b/hw/rtl/core/VX_operands.sv @@ -288,6 +288,27 @@ module VX_operands import VX_gpu_pkg::*; #( .raddr (gpr_rd_addr), .rdata (gpr_rd_data[j]) ); + + // blast read register file because printf is slowge + logic [31:0] cycle, cycle_n; + assign cycle_n = cycle + 32'd1; + always @(posedge clk) begin + if (reset) begin + cycle <= '0; + end + else begin + cycle <= cycle_n; + end + + if (cycle == 32'd25000) begin + for (integer k = 0; k < `NUM_REGS * ISSUE_RATIO; ++k) begin + integer warp = i * ISSUE_RATIO + (k / `NUM_REGS); + integer thread = j; + integer register = k % `NUM_REGS; + $display("warp %0d, thread %0d, register %0d: %0x", warp, thread, register, gpr_ram.ram[k]); + end + end + end end end diff --git a/hw/rtl/core/VX_tensor_core.sv b/hw/rtl/core/VX_tensor_core.sv index a9419c66..28369140 100644 --- a/hw/rtl/core/VX_tensor_core.sv +++ b/hw/rtl/core/VX_tensor_core.sv @@ -9,7 +9,7 @@ module VX_tensor_core #( VX_dispatch_if.slave dispatch_if [`ISSUE_WIDTH], VX_commit_if.master commit_if [`ISSUE_WIDTH] ); - `STATIC_ASSERT(`NUM_THREADS == 32, ("tensor core requires # of threads in a warp to be 32")); + `STATIC_ASSERT(`NUM_THREADS == 32, ("tensor core requires # of threads in a warp to be 32 (try running w/ CONFIGS=\"-DNUM_THREADS=32\")")); for (genvar i = 0; i < `ISSUE_WIDTH; ++i) begin VX_tensor_core_warp #( @@ -34,20 +34,20 @@ module VX_tensor_core_warp import VX_gpu_pkg::*; #( VX_dispatch_if.slave dispatch_if, VX_commit_if.master commit_if ); - logic [1:0] step = 2'(dispatch_if.data.op_type); + wire [1:0] step = 2'(dispatch_if.data.op_type); logic [3:0] octet_results_valid; logic [3:0] octet_results_ready; logic [`NUM_THREADS-1:0][`XLEN-1:0] wb_data_0; logic [`NUM_THREADS-1:0][`XLEN-1:0] wb_data_1; for (genvar i = 0; i < 4; ++i) begin - logic [7:0][31:0] octet_A = { + wire [7:0][31:0] octet_A = { dispatch_if.data.rs1_data[4*i +: 4], dispatch_if.data.rs1_data[16+4*i +: 4] }; - logic [7:0][31:0] octet_B = { + wire [7:0][31:0] octet_B = { dispatch_if.data.rs2_data[4*i +: 4], dispatch_if.data.rs2_data[16+4*i +: 4] }; - logic [7:0][31:0] octet_C = { + wire [7:0][31:0] octet_C = { dispatch_if.data.rs3_data[4*i +: 4], dispatch_if.data.rs3_data[16+4*i +: 4] }; @@ -141,11 +141,11 @@ module VX_tensor_core_warp import VX_gpu_pkg::*; #( ); logic subcommit, subcommit_n; - logic all_valid = (& octet_results_valid); + wire all_valid = (& octet_results_valid); assign commit_if.valid = all_valid; localparam COMMIT_DATAW = `UUID_WIDTH + `NW_WIDTH + `NUM_THREADS + `XLEN + 1 + `NR_BITS + (`NUM_THREADS * `XLEN) + 1 + 1 + 1; - logic [COMMIT_DATAW-1:0] commit_if_data = { + wire [COMMIT_DATAW-1:0] commit_if_data = { dispatch_if_data_deq, subcommit == 1'b0 ? wb_data_0 : wb_data_1, 1'b0, @@ -227,7 +227,7 @@ module VX_tensor_octet #( end logic substep; - logic substep_n = (operands_ready && operands_valid) ? ~substep : substep; + wire substep_n = (operands_ready && operands_valid) ? ~substep : substep; always @(*) begin A_buffer_n = A_buffer; @@ -260,13 +260,13 @@ module VX_tensor_octet #( wire stall = result_valid && ~result_ready; assign operands_ready = ~stall; - logic [3:0][1:0][31:0] A_tile = { + wire [3:0][1:0][31:0] A_tile = { { A_buffer[0], A_half[0] }, { A_buffer[1], A_half[1] }, { A_buffer[2], A_half[2] }, { A_buffer[3], A_half[3] } }; - logic [1:0][3:0][31:0] B_tile = { + wire [1:0][3:0][31:0] B_tile = { B_buffer, B_half }; logic [3:0][3:0][31:0] C_tile; diff --git a/hw/rtl/core/VX_tensor_ucode.vh b/hw/rtl/core/VX_tensor_ucode.vh new file mode 100644 index 00000000..7603b4a1 --- /dev/null +++ b/hw/rtl/core/VX_tensor_ucode.vh @@ -0,0 +1,96 @@ +HMMA_SET0_STEP0_0: begin + uop = {NEXT, HMMA_SET0_STEP0_1, `EX_BITS'(`EX_TENSOR), `INST_OP_BITS'(0), `INST_MOD_BITS'(0), 1'b1, 1'b0, 1'b0, 32'b0, 32'b0, `FREG(16), `FREG(0), `FREG(8), `FREG(16)}; +end +HMMA_SET0_STEP0_1: begin + uop = {NEXT, HMMA_SET0_STEP1_0, `EX_BITS'(`EX_TENSOR), `INST_OP_BITS'(0), `INST_MOD_BITS'(1), 1'b1, 1'b0, 1'b0, 32'b0, 32'b0, `FREG(17), `FREG(1), `FREG(9), `FREG(17)}; +end +HMMA_SET0_STEP1_0: begin + uop = {NEXT, HMMA_SET0_STEP1_1, `EX_BITS'(`EX_TENSOR), `INST_OP_BITS'(1), `INST_MOD_BITS'(0), 1'b1, 1'b0, 1'b0, 32'b0, 32'b0, `FREG(18), `FREG(0), `FREG(8), `FREG(18)}; +end +HMMA_SET0_STEP1_1: begin + uop = {NEXT, HMMA_SET0_STEP2_0, `EX_BITS'(`EX_TENSOR), `INST_OP_BITS'(1), `INST_MOD_BITS'(1), 1'b1, 1'b0, 1'b0, 32'b0, 32'b0, `FREG(19), `FREG(1), `FREG(9), `FREG(19)}; +end +HMMA_SET0_STEP2_0: begin + uop = {NEXT, HMMA_SET0_STEP2_1, `EX_BITS'(`EX_TENSOR), `INST_OP_BITS'(2), `INST_MOD_BITS'(0), 1'b1, 1'b0, 1'b0, 32'b0, 32'b0, `FREG(20), `FREG(0), `FREG(8), `FREG(20)}; +end +HMMA_SET0_STEP2_1: begin + uop = {NEXT, HMMA_SET0_STEP3_0, `EX_BITS'(`EX_TENSOR), `INST_OP_BITS'(2), `INST_MOD_BITS'(1), 1'b1, 1'b0, 1'b0, 32'b0, 32'b0, `FREG(21), `FREG(1), `FREG(9), `FREG(21)}; +end +HMMA_SET0_STEP3_0: begin + uop = {NEXT, HMMA_SET0_STEP3_1, `EX_BITS'(`EX_TENSOR), `INST_OP_BITS'(3), `INST_MOD_BITS'(0), 1'b1, 1'b0, 1'b0, 32'b0, 32'b0, `FREG(22), `FREG(0), `FREG(8), `FREG(22)}; +end +HMMA_SET0_STEP3_1: begin + uop = {NEXT, HMMA_SET1_STEP0_0, `EX_BITS'(`EX_TENSOR), `INST_OP_BITS'(3), `INST_MOD_BITS'(1), 1'b1, 1'b0, 1'b0, 32'b0, 32'b0, `FREG(23), `FREG(1), `FREG(9), `FREG(23)}; +end +HMMA_SET1_STEP0_0: begin + uop = {NEXT, HMMA_SET1_STEP0_1, `EX_BITS'(`EX_TENSOR), `INST_OP_BITS'(0), `INST_MOD_BITS'(0), 1'b1, 1'b0, 1'b0, 32'b0, 32'b0, `FREG(16), `FREG(2), `FREG(10), `FREG(16)}; +end +HMMA_SET1_STEP0_1: begin + uop = {NEXT, HMMA_SET1_STEP1_0, `EX_BITS'(`EX_TENSOR), `INST_OP_BITS'(0), `INST_MOD_BITS'(1), 1'b1, 1'b0, 1'b0, 32'b0, 32'b0, `FREG(17), `FREG(3), `FREG(11), `FREG(17)}; +end +HMMA_SET1_STEP1_0: begin + uop = {NEXT, HMMA_SET1_STEP1_1, `EX_BITS'(`EX_TENSOR), `INST_OP_BITS'(1), `INST_MOD_BITS'(0), 1'b1, 1'b0, 1'b0, 32'b0, 32'b0, `FREG(18), `FREG(2), `FREG(10), `FREG(18)}; +end +HMMA_SET1_STEP1_1: begin + uop = {NEXT, HMMA_SET1_STEP2_0, `EX_BITS'(`EX_TENSOR), `INST_OP_BITS'(1), `INST_MOD_BITS'(1), 1'b1, 1'b0, 1'b0, 32'b0, 32'b0, `FREG(19), `FREG(3), `FREG(11), `FREG(19)}; +end +HMMA_SET1_STEP2_0: begin + uop = {NEXT, HMMA_SET1_STEP2_1, `EX_BITS'(`EX_TENSOR), `INST_OP_BITS'(2), `INST_MOD_BITS'(0), 1'b1, 1'b0, 1'b0, 32'b0, 32'b0, `FREG(20), `FREG(2), `FREG(10), `FREG(20)}; +end +HMMA_SET1_STEP2_1: begin + uop = {NEXT, HMMA_SET1_STEP3_0, `EX_BITS'(`EX_TENSOR), `INST_OP_BITS'(2), `INST_MOD_BITS'(1), 1'b1, 1'b0, 1'b0, 32'b0, 32'b0, `FREG(21), `FREG(3), `FREG(11), `FREG(21)}; +end +HMMA_SET1_STEP3_0: begin + uop = {NEXT, HMMA_SET1_STEP3_1, `EX_BITS'(`EX_TENSOR), `INST_OP_BITS'(3), `INST_MOD_BITS'(0), 1'b1, 1'b0, 1'b0, 32'b0, 32'b0, `FREG(22), `FREG(2), `FREG(10), `FREG(22)}; +end +HMMA_SET1_STEP3_1: begin + uop = {NEXT, HMMA_SET2_STEP0_0, `EX_BITS'(`EX_TENSOR), `INST_OP_BITS'(3), `INST_MOD_BITS'(1), 1'b1, 1'b0, 1'b0, 32'b0, 32'b0, `FREG(23), `FREG(3), `FREG(11), `FREG(23)}; +end +HMMA_SET2_STEP0_0: begin + uop = {NEXT, HMMA_SET2_STEP0_1, `EX_BITS'(`EX_TENSOR), `INST_OP_BITS'(0), `INST_MOD_BITS'(0), 1'b1, 1'b0, 1'b0, 32'b0, 32'b0, `FREG(16), `FREG(4), `FREG(12), `FREG(16)}; +end +HMMA_SET2_STEP0_1: begin + uop = {NEXT, HMMA_SET2_STEP1_0, `EX_BITS'(`EX_TENSOR), `INST_OP_BITS'(0), `INST_MOD_BITS'(1), 1'b1, 1'b0, 1'b0, 32'b0, 32'b0, `FREG(17), `FREG(5), `FREG(13), `FREG(17)}; +end +HMMA_SET2_STEP1_0: begin + uop = {NEXT, HMMA_SET2_STEP1_1, `EX_BITS'(`EX_TENSOR), `INST_OP_BITS'(1), `INST_MOD_BITS'(0), 1'b1, 1'b0, 1'b0, 32'b0, 32'b0, `FREG(18), `FREG(4), `FREG(12), `FREG(18)}; +end +HMMA_SET2_STEP1_1: begin + uop = {NEXT, HMMA_SET2_STEP2_0, `EX_BITS'(`EX_TENSOR), `INST_OP_BITS'(1), `INST_MOD_BITS'(1), 1'b1, 1'b0, 1'b0, 32'b0, 32'b0, `FREG(19), `FREG(5), `FREG(13), `FREG(19)}; +end +HMMA_SET2_STEP2_0: begin + uop = {NEXT, HMMA_SET2_STEP2_1, `EX_BITS'(`EX_TENSOR), `INST_OP_BITS'(2), `INST_MOD_BITS'(0), 1'b1, 1'b0, 1'b0, 32'b0, 32'b0, `FREG(20), `FREG(4), `FREG(12), `FREG(20)}; +end +HMMA_SET2_STEP2_1: begin + uop = {NEXT, HMMA_SET2_STEP3_0, `EX_BITS'(`EX_TENSOR), `INST_OP_BITS'(2), `INST_MOD_BITS'(1), 1'b1, 1'b0, 1'b0, 32'b0, 32'b0, `FREG(21), `FREG(5), `FREG(13), `FREG(21)}; +end +HMMA_SET2_STEP3_0: begin + uop = {NEXT, HMMA_SET2_STEP3_1, `EX_BITS'(`EX_TENSOR), `INST_OP_BITS'(3), `INST_MOD_BITS'(0), 1'b1, 1'b0, 1'b0, 32'b0, 32'b0, `FREG(22), `FREG(4), `FREG(12), `FREG(22)}; +end +HMMA_SET2_STEP3_1: begin + uop = {NEXT, HMMA_SET3_STEP0_0, `EX_BITS'(`EX_TENSOR), `INST_OP_BITS'(3), `INST_MOD_BITS'(1), 1'b1, 1'b0, 1'b0, 32'b0, 32'b0, `FREG(23), `FREG(5), `FREG(13), `FREG(23)}; +end +HMMA_SET3_STEP0_0: begin + uop = {NEXT, HMMA_SET3_STEP0_1, `EX_BITS'(`EX_TENSOR), `INST_OP_BITS'(0), `INST_MOD_BITS'(0), 1'b1, 1'b0, 1'b0, 32'b0, 32'b0, `FREG(16), `FREG(6), `FREG(14), `FREG(16)}; +end +HMMA_SET3_STEP0_1: begin + uop = {NEXT, HMMA_SET3_STEP1_0, `EX_BITS'(`EX_TENSOR), `INST_OP_BITS'(0), `INST_MOD_BITS'(1), 1'b1, 1'b0, 1'b0, 32'b0, 32'b0, `FREG(17), `FREG(7), `FREG(15), `FREG(17)}; +end +HMMA_SET3_STEP1_0: begin + uop = {NEXT, HMMA_SET3_STEP1_1, `EX_BITS'(`EX_TENSOR), `INST_OP_BITS'(1), `INST_MOD_BITS'(0), 1'b1, 1'b0, 1'b0, 32'b0, 32'b0, `FREG(18), `FREG(6), `FREG(14), `FREG(18)}; +end +HMMA_SET3_STEP1_1: begin + uop = {NEXT, HMMA_SET3_STEP2_0, `EX_BITS'(`EX_TENSOR), `INST_OP_BITS'(1), `INST_MOD_BITS'(1), 1'b1, 1'b0, 1'b0, 32'b0, 32'b0, `FREG(19), `FREG(7), `FREG(15), `FREG(19)}; +end +HMMA_SET3_STEP2_0: begin + uop = {NEXT, HMMA_SET3_STEP2_1, `EX_BITS'(`EX_TENSOR), `INST_OP_BITS'(2), `INST_MOD_BITS'(0), 1'b1, 1'b0, 1'b0, 32'b0, 32'b0, `FREG(20), `FREG(6), `FREG(14), `FREG(20)}; +end +HMMA_SET3_STEP2_1: begin + uop = {NEXT, HMMA_SET3_STEP3_0, `EX_BITS'(`EX_TENSOR), `INST_OP_BITS'(2), `INST_MOD_BITS'(1), 1'b1, 1'b0, 1'b0, 32'b0, 32'b0, `FREG(21), `FREG(7), `FREG(15), `FREG(21)}; +end +HMMA_SET3_STEP3_0: begin + uop = {NEXT, HMMA_SET3_STEP3_1, `EX_BITS'(`EX_TENSOR), `INST_OP_BITS'(3), `INST_MOD_BITS'(0), 1'b1, 1'b0, 1'b0, 32'b0, 32'b0, `FREG(22), `FREG(6), `FREG(14), `FREG(22)}; +end +HMMA_SET3_STEP3_1: begin + uop = {FINISH, HMMA_SET0_STEP0_0, `EX_BITS'(`EX_TENSOR), `INST_OP_BITS'(3), `INST_MOD_BITS'(1), 1'b1, 1'b0, 1'b0, 32'b0, 32'b0, `FREG(23), `FREG(7), `FREG(15), `FREG(23)}; +end diff --git a/hw/rtl/core/VX_uop_sequencer.sv b/hw/rtl/core/VX_uop_sequencer.sv index c57ea0ba..3d2bb233 100644 --- a/hw/rtl/core/VX_uop_sequencer.sv +++ b/hw/rtl/core/VX_uop_sequencer.sv @@ -1,6 +1,6 @@ `include "VX_define.vh" -`define FREG(x) {1'b1, `NRI_BITS'(`CLOG2(x))} +`define FREG(x) {1'b1, `NRI_BITS'(x)} module VX_uop_sequencer import VX_gpu_pkg::*; ( input clk, @@ -28,7 +28,6 @@ module VX_uop_sequencer import VX_gpu_pkg::*; ( // reserve space at start of table for more uop sequences localparam HMMA_SET0_STEP0_0 = UPC_BITS'(0); localparam HMMA_SET0_STEP0_1 = UPC_BITS'(8); - /* localparam HMMA_SET0_STEP1_0 = UPC_BITS'(9); localparam HMMA_SET0_STEP1_1 = UPC_BITS'(10); localparam HMMA_SET0_STEP2_0 = UPC_BITS'(11); @@ -62,49 +61,11 @@ module VX_uop_sequencer import VX_gpu_pkg::*; ( localparam HMMA_SET3_STEP2_1 = UPC_BITS'(36); localparam HMMA_SET3_STEP3_0 = UPC_BITS'(37); localparam HMMA_SET3_STEP3_1 = UPC_BITS'(38); - */ // register layout: f0-f7 used for A, f8-f15 used for B, f16-f23 used for C - - always @(*) begin case (upc) - HMMA_SET0_STEP0_0: begin - uop = { - NEXT, - HMMA_SET0_STEP0_1, - `EX_BITS'(`EX_TENSOR), - `INST_OP_BITS'(0), // denotes that the first step is being computed - `INST_MOD_BITS'(0), // denotes that this is first substep (tensor core also tracks this) - 1'b1, // write back - 1'b0, // don't use PC - 1'b0, // don't use immediate - 32'b0, // PC is unused - TODO: don't send a bogus PC down the pipeline as it is very confusing in trace - 32'b0, // immediate is unused - `FREG(16), // rd=f16 - `FREG(0), // rs1=f0, - `FREG(8), // rs2=f8 - `FREG(16) // rs3=f16 - }; - end - HMMA_SET0_STEP0_1: begin - uop = { - FINISH, - HMMA_SET0_STEP0_0, - `EX_BITS'(`EX_TENSOR), - `INST_OP_BITS'(0), // denotes that the first step is being computed - `INST_MOD_BITS'(1), // denotes that this is first substep (tensor core also tracks this) - 1'b1, // write back - 1'b0, // don't use PC - 1'b0, // don't use immediate - 32'b0, // PC is unused - TODO: don't send a bogus PC down the pipeline as it is very confusing in trace - 32'b0, // immediate is unused - `FREG(17), // rd=f17 - `FREG(1), // rs1=f1, - `FREG(9), // rs2=f9 - `FREG(17) // rs3=f17 - }; - end + `include "VX_tensor_ucode.vh" default: begin uop = '0; end @@ -113,13 +74,15 @@ module VX_uop_sequencer import VX_gpu_pkg::*; ( logic [UPC_BITS-1:0] upc, upc_r, upc_n; - logic [UBR_BITS-1:0] ubr = uop[UOP_TABLE_WIDTH-1:UOP_TABLE_WIDTH-UBR_BITS]; - logic [UPC_BITS-1:0] next_upc = uop[UOP_TABLE_WIDTH-UBR_BITS-1:UOP_TABLE_WIDTH-UBR_BITS-UPC_BITS]; + wire [UBR_BITS-1:0] ubr = uop[UOP_TABLE_WIDTH-1:UOP_TABLE_WIDTH-UBR_BITS]; + wire [UPC_BITS-1:0] next_upc = uop[UOP_TABLE_WIDTH-UBR_BITS-1:UOP_TABLE_WIDTH-UBR_BITS-UPC_BITS]; - logic uop_fire = use_uop && ibuffer_if.valid && ibuffer_if.ready; - logic uop_start = ~use_uop_1d && use_uop; - logic uop_finish = use_uop && uop_sequencer_if.valid && uop_sequencer_if.ready; logic use_uop, use_uop_1d; + wire uop_fire = use_uop && ibuffer_if.valid && ibuffer_if.ready; + + wire uop_start = ~use_uop_1d && use_uop; + wire uop_finish = use_uop && uop_sequencer_if.valid && uop_sequencer_if.ready; + // merging the 2 always blocks leads to spurious UNOPTFLAT verilator lint, but conceptually they should be linked always @(*) begin @@ -149,7 +112,7 @@ module VX_uop_sequencer import VX_gpu_pkg::*; ( end // copy UUID, wis, tmask from microcoded instruction - logic [IBUFFER_IF_DATAW-1:0] ibuffer_output = { + wire [IBUFFER_IF_DATAW-1:0] ibuffer_output = { uop_sequencer_if.data.uuid, uop_sequencer_if.data.wis, uop_sequencer_if.data.tmask, @@ -161,11 +124,18 @@ module VX_uop_sequencer import VX_gpu_pkg::*; ( assign ibuffer_if.data = use_uop ? ibuffer_output : uop_sequencer_if.data; always @(posedge clk) begin - - if (use_uop) begin - $display("unexpectedly used uop at %d", $time); + if (uop_start) begin + $display("UOP start @ %t", $time); + $display("use_uop=%0d, use_uop_1d=%0d, uop_start=%0d, ibuffer_if.valid=%0d, ibuffer_if.ready=%0d", use_uop, use_uop_1d, uop_start, ibuffer_if.valid, ibuffer_if.ready); + end + + if (uop_fire) begin + $display("UOP fire @ %t", $time); + end + + if (uop_finish) begin + $display("UOP finish @ %t", $time); end - if (reset) begin upc_r <= '0; diff --git a/hw/rtl/core/generate_ucode.py b/hw/rtl/core/generate_ucode.py new file mode 100644 index 00000000..4cda111f --- /dev/null +++ b/hw/rtl/core/generate_ucode.py @@ -0,0 +1,81 @@ +num_sets = 4 +num_steps = 4 +num_substeps = 2 + + +def set_step_substep(sequence_number): + set_num, step = divmod(sequence_number, num_steps * num_substeps) + step //= num_substeps + substep = sequence_number % 2 + + return set_num, step, substep + +# set + substep -> rs1, rs2 +rs1 = { + (0, 0): 0, + (0, 1): 1, + (1, 0): 2, + (1, 1): 3, + (2, 0): 4, + (2, 1): 5, + (3, 0): 6, + (3, 1): 7 +} + +rs2 = { + (0, 0): 8, + (0, 1): 9, + (1, 0): 10, + (1, 1): 11, + (2, 0): 12, + (2, 1): 13, + (3, 0): 14, + (3, 1): 15 +} + +# step + substep -> rs3, rd +rs3_rd = { + (0, 0): 16, + (0, 1): 17, + (1, 0): 18, + (1, 1): 19, + (2, 0): 20, + (2, 1): 21, + (3, 0): 22, + (3, 1): 23 +} + +with open('VX_tensor_ucode.vh', 'w') as f: + for sequence_number in range(num_sets * num_steps * num_substeps): + set_num, step, substep = set_step_substep(sequence_number) + + + next_sequence_num = (sequence_number + 1) % (num_sets * num_steps * num_substeps) + next_set_num, next_step, next_substep = set_step_substep(next_sequence_num) + finish = (next_sequence_num == 0) + + name = "HMMA_SET{}_STEP{}_{}" + ucode = "{}, HMMA_SET{}_STEP{}_{}, `EX_BITS'(`EX_TENSOR), `INST_OP_BITS'({}), `INST_MOD_BITS'({}), 1'b1, 1'b0, 1'b0, 32'b0, 32'b0, `FREG({}), `FREG({}), `FREG({}), `FREG({})" + + name = name.format( + set_num, step, substep, + ) + + ucode = ucode.format( + "FINISH" if finish else "NEXT", + next_set_num, next_step, next_substep, + step, + substep, + rs3_rd[(step, substep)], + rs1[(set_num, substep)], + rs2[(set_num, substep)], + rs3_rd[(step, substep)], + ) + + entry = f"{name}: begin \n" + entry += "\tuop = {" + entry += ucode + entry += "}; \n" + entry += "end \n" + + f.write(entry) \ No newline at end of file diff --git a/tests/kernel/tensor/Makefile b/tests/kernel/tensor/Makefile new file mode 100644 index 00000000..19cb340e --- /dev/null +++ b/tests/kernel/tensor/Makefile @@ -0,0 +1,8 @@ +PROJECT = tensor + +SRCS = main.cpp +DEPS = a_matrix.h +DEPS += b_matrix.h +DEPS += c_matrix.h + +include ../common.mk diff --git a/tests/kernel/tensor/check_correctness.py b/tests/kernel/tensor/check_correctness.py new file mode 100644 index 00000000..84db43d3 --- /dev/null +++ b/tests/kernel/tensor/check_correctness.py @@ -0,0 +1,94 @@ +import numpy as np +import struct + +A_array = np.zeros((16, 8)) +B_array = np.zeros((8, 16)) +C_array = np.zeros((16, 16)) + +file = input("simulator output filename: ") + +def hex2float(float_hex_str): + # print(float_hex_str.strip()) + return struct.unpack(">f",struct.pack(">i",int(float_hex_str,16)))[0] + +def C_index(threadgroup, thread, register): + """ + col = ((tg % 4) / 2) * 8; + row = (tg * 8) % 16; + row += (tg / 4) * 4; + + asm volatile ("flw f16, %0" :: "m"(C[row+0][col+0])); + asm volatile ("flw f17, %0" :: "m"(C[row+0][col+1])); + asm volatile ("flw f18, %0" :: "m"(C[row+2][col+0])); + asm volatile ("flw f19, %0" :: "m"(C[row+2][col+1])); + asm volatile ("flw f20, %0" :: "m"(C[row+0][col+4])); + asm volatile ("flw f21, %0" :: "m"(C[row+0][col+5])); + asm volatile ("flw f22, %0" :: "m"(C[row+2][col+4])); + asm volatile ("flw f23, %0" :: "m"(C[row+2][col+5])); + """ + + col = ((threadgroup % 4) // 2) * 8 + row = (threadgroup * 8) % 16 + row += (threadgroup // 4) * 4 + offsets = [(0, 0), (0, 1), (2, 0), (2, 1), (0, 4), (0, 5), (2, 4), (2, 5)] + offset = offsets[register-16] + row += offset[0] + col += offset[1] + thread_offsets = [(0, 0), (1, 0), (0, 2), (1, 2)] + thread_offset = thread_offsets[thread % 4] + row += thread_offset[0] + col += thread_offset[1] + if C_array[row, col] != 0: + print("bad") + return (row, col) + + +with open(file) as f: + for line in f.readlines(): + line = line.strip() + if "warp" in line: + a, b, c = line.split(',') + _, a = a.split(' ') + _, b = b.strip().split(' ') + c, d = c.strip().split(':') + _, c = c.split(' ') + warp = int(a) + thread = int(b) + register = int(c) + value = d.strip() + + if warp != 0: + continue + if not (32 <= register < 32+24): + continue + + register = register - 32 + + # threadgroups 0, 4, 1, 5 have all elements of A + threadgroup = thread // 4 + if threadgroup in [0, 4, 1, 5]: + row = [0, 4, 1, 5].index(threadgroup) * 4 + thread % 4 + if 0 <= register < 8: + A_array[row, register] = hex2float(value) + + if threadgroup in [0, 4, 2, 6]: + col = [0, 4, 2, 6].index(threadgroup) * 4 + thread % 4 + if 8 <= register < 16: + B_array[register-8, col] = hex2float(value) + + if 16 <= register < 24: + # print(value) + C_array[C_index(threadgroup, thread, register)] = hex2float(value) + + +expected = np.load("abc.npz") +expected_A = expected['A_array'] +expected_B = expected['B_array'] +expected_C = expected['C_array'] +expected_C = expected_C + expected_A @ expected_B + +print(expected_C - C_array) + +assert np.allclose(expected_A, A_array) +assert np.allclose(expected_B, B_array) +assert np.allclose(expected_C, C_array) \ No newline at end of file diff --git a/tests/kernel/tensor/create_test_case.py b/tests/kernel/tensor/create_test_case.py new file mode 100644 index 00000000..0fbd1583 --- /dev/null +++ b/tests/kernel/tensor/create_test_case.py @@ -0,0 +1,29 @@ +import numpy as np +# A_array = np.random.rand(16, 8) +# B_array = np.random.rand(8, 16) +A_array = np.zeros((16, 8)) +B_array = np.zeros((8, 16)) +A_array[0,:] = 1.0 +B_array[:,0] = 1.0 +C_array = np.random.rand(16, 16) + + +with open('a_matrix.h', 'w') as f: + for i in range(A_array.shape[0]): + for j in range(A_array.shape[1]): + f.write(f'{A_array[i,j]}f, ') + f.write('\n') + +with open('b_matrix.h', 'w') as f: + for i in range(B_array.shape[0]): + for j in range(B_array.shape[1]): + f.write(f'{B_array[i,j]}f, ') + f.write('\n') + +with open('c_matrix.h', 'w') as f: + for i in range(C_array.shape[0]): + for j in range(C_array.shape[1]): + f.write(f'{C_array[i,j]}f, ') + f.write('\n') + +np.savez("abc", A_array=A_array, B_array=B_array, C_array=C_array) \ No newline at end of file diff --git a/tests/kernel/tensor/main.cpp b/tests/kernel/tensor/main.cpp new file mode 100644 index 00000000..2c45fa9b --- /dev/null +++ b/tests/kernel/tensor/main.cpp @@ -0,0 +1,96 @@ +#define RISCV_CUSTOM3 0x7B + +#include +#include +#include + +inline void vx_wmma() { + asm volatile (".insn r %0, 0, 0, x0, x0, x0" :: "i"(RISCV_CUSTOM3)); +} + +#include "test_data.h" + +void vx_wmma_load() { + int tid = vx_thread_id(); + int tg = tid / 4; + + // load A + int row = tid % 4; + row += (tg * 8) % 16; + row += (tg / 4) * 4; + + asm volatile ("flw f0, %0" :: "m"(A[row][0])); + asm volatile ("flw f1, %0" :: "m"(A[row][1])); + asm volatile ("flw f2, %0" :: "m"(A[row][2])); + asm volatile ("flw f3, %0" :: "m"(A[row][3])); + asm volatile ("flw f4, %0" :: "m"(A[row][4])); + asm volatile ("flw f5, %0" :: "m"(A[row][5])); + asm volatile ("flw f6, %0" :: "m"(A[row][6])); + asm volatile ("flw f7, %0" :: "m"(A[row][7])); + + // load B + int col = tid % 4; + col += ((tg % 4) / 2) * 8; + col += (tg / 4) * 4; + + asm volatile ("flw f8 , %0" :: "m"(B[0][col])); + asm volatile ("flw f9 , %0" :: "m"(B[1][col])); + asm volatile ("flw f10, %0" :: "m"(B[2][col])); + asm volatile ("flw f11, %0" :: "m"(B[3][col])); + asm volatile ("flw f12, %0" :: "m"(B[4][col])); + asm volatile ("flw f13, %0" :: "m"(B[5][col])); + asm volatile ("flw f14, %0" :: "m"(B[6][col])); + asm volatile ("flw f15, %0" :: "m"(B[7][col])); + + // load C + col = ((tg % 4) / 2) * 8; + row = (tg * 8) % 16; + row += (tg / 4) * 4; + + row += (tid % 4) % 2; + col += ((tid % 4) / 2) * 2; + + asm volatile ("flw f16, %0" :: "m"(C[row+0][col+0])); + asm volatile ("flw f17, %0" :: "m"(C[row+0][col+1])); + asm volatile ("flw f18, %0" :: "m"(C[row+2][col+0])); + asm volatile ("flw f19, %0" :: "m"(C[row+2][col+1])); + asm volatile ("flw f20, %0" :: "m"(C[row+0][col+4])); + asm volatile ("flw f21, %0" :: "m"(C[row+0][col+5])); + asm volatile ("flw f22, %0" :: "m"(C[row+2][col+4])); + asm volatile ("flw f23, %0" :: "m"(C[row+2][col+5])); +} + +float results[32*8]; + +void store_wmma_result() { + int tid = vx_thread_id(); + + asm volatile ("fsw f16, %0" :: "m"(results[tid*8+0])); + asm volatile ("fsw f17, %0" :: "m"(results[tid*8+1])); + asm volatile ("fsw f18, %0" :: "m"(results[tid*8+2])); + asm volatile ("fsw f19, %0" :: "m"(results[tid*8+3])); + asm volatile ("fsw f20, %0" :: "m"(results[tid*8+4])); + asm volatile ("fsw f21, %0" :: "m"(results[tid*8+5])); + asm volatile ("fsw f22, %0" :: "m"(results[tid*8+6])); + asm volatile ("fsw f23, %0" :: "m"(results[tid*8+7])); +} + +void print_wmma_result() { + for (int tid = 0; tid < 32; tid += 1) { + for (int reg = 0; reg < 8; reg += 1) { + vx_printf("thread %d, f%d: %x\n", tid, 16+reg, *((int*) &results[tid*8+reg])); + } + } +} + +int main() +{ + vx_tmc(-1); + vx_wmma_load(); + vx_wmma(); + store_wmma_result(); + vx_tmc(1); + print_wmma_result(); + + return 0; +} \ No newline at end of file diff --git a/tests/kernel/tensor/test_data.h b/tests/kernel/tensor/test_data.h new file mode 100644 index 00000000..83b05157 --- /dev/null +++ b/tests/kernel/tensor/test_data.h @@ -0,0 +1,11 @@ +float A[16][8] = { + #include "a_matrix.h" +}; + +float B[8][16] = { + #include "b_matrix.h" +}; + +float C[16][16] = { + #include "c_matrix.h" +}; \ No newline at end of file From 08d7721e1163636f705ad869ecb781364431e32c Mon Sep 17 00:00:00 2001 From: joshua Date: Thu, 28 Mar 2024 03:00:15 -0700 Subject: [PATCH 06/55] annoying swizzling problems --- hw/rtl/core/VX_tensor_core.sv | 34 +++++++++++++++++----------------- 1 file changed, 17 insertions(+), 17 deletions(-) diff --git a/hw/rtl/core/VX_tensor_core.sv b/hw/rtl/core/VX_tensor_core.sv index 28369140..1055351d 100644 --- a/hw/rtl/core/VX_tensor_core.sv +++ b/hw/rtl/core/VX_tensor_core.sv @@ -42,13 +42,13 @@ module VX_tensor_core_warp import VX_gpu_pkg::*; #( for (genvar i = 0; i < 4; ++i) begin wire [7:0][31:0] octet_A = { - dispatch_if.data.rs1_data[4*i +: 4], dispatch_if.data.rs1_data[16+4*i +: 4] + dispatch_if.data.rs1_data[16+4*i +: 4], dispatch_if.data.rs1_data[4*i +: 4] }; wire [7:0][31:0] octet_B = { - dispatch_if.data.rs2_data[4*i +: 4], dispatch_if.data.rs2_data[16+4*i +: 4] + dispatch_if.data.rs2_data[16+4*i +: 4], dispatch_if.data.rs2_data[4*i +: 4] }; wire [7:0][31:0] octet_C = { - dispatch_if.data.rs3_data[4*i +: 4], dispatch_if.data.rs3_data[16+4*i +: 4] + dispatch_if.data.rs3_data[16+4*i +: 4], dispatch_if.data.rs3_data[4*i +: 4] }; logic [3:0][3:0][31:0] octet_D; @@ -125,7 +125,7 @@ module VX_tensor_core_warp import VX_gpu_pkg::*; #( // this is probably a little oversized VX_fifo_queue #( .DATAW(DATAW), - .DEPTH(8) + .DEPTH(16) ) pending_uops ( .clk(clk), .reset(reset), @@ -207,19 +207,19 @@ module VX_tensor_octet #( always @(*) begin case (step) 2'b00: begin - A_half = { A_in[1:0], A_in[5:4] }; + A_half = { A_in[5:4], A_in[1:0] }; B_half = B_in[3:0]; end 2'b01: begin - A_half = { A_in[3:2], A_in[7:6] }; + A_half = { A_in[7:6], A_in[3:2] }; B_half = B_in[3:0]; end 2'b10: begin - A_half = { A_in[1:0], A_in[5:4] }; + A_half = { A_in[5:4], A_in[1:0] }; B_half = B_in[7:4]; end 2'b11: begin - A_half = { A_in[3:2], A_in[7:6] }; + A_half = { A_in[7:6], A_in[3:2] }; B_half = B_in[7:4]; end endcase @@ -261,22 +261,22 @@ module VX_tensor_octet #( assign operands_ready = ~stall; wire [3:0][1:0][31:0] A_tile = { - { A_buffer[0], A_half[0] }, - { A_buffer[1], A_half[1] }, - { A_buffer[2], A_half[2] }, - { A_buffer[3], A_half[3] } + { A_half[3], A_buffer[3] }, + { A_half[2], A_buffer[2] }, + { A_half[1], A_buffer[1] }, + { A_half[0], A_buffer[0] } }; wire [1:0][3:0][31:0] B_tile = { - B_buffer, B_half + B_half, B_buffer }; logic [3:0][3:0][31:0] C_tile; always @(*) begin C_tile = { - C_buffer[0], C_half[0], C_buffer[1], C_half[1], - C_buffer[2], C_half[2], C_buffer[3], C_half[3], - C_buffer[4], C_half[4], C_buffer[5], C_half[5], - C_buffer[6], C_half[6], C_buffer[7], C_half[7] + C_half[7], C_buffer[7], C_half[5], C_buffer[5], + C_half[6], C_buffer[6], C_half[4], C_buffer[4], + C_half[3], C_buffer[3], C_half[1], C_buffer[1], + C_half[2], C_buffer[2], C_half[0], C_buffer[0] }; end From d8f9359fae1a93630d5d775da55cb22501345bcc Mon Sep 17 00:00:00 2001 From: joshua Date: Thu, 28 Mar 2024 13:04:02 -0700 Subject: [PATCH 07/55] test case update --- hw/unittest/tensor/Makefile | 2 +- tests/kernel/common.mk | 2 +- tests/kernel/tensor/check_correctness.py | 5 +++-- tests/kernel/tensor/create_test_case.py | 17 ++++++++++------- 4 files changed, 15 insertions(+), 11 deletions(-) diff --git a/hw/unittest/tensor/Makefile b/hw/unittest/tensor/Makefile index 021b7dcb..a968ab14 100644 --- a/hw/unittest/tensor/Makefile +++ b/hw/unittest/tensor/Makefile @@ -35,7 +35,7 @@ SRCS += $(DPI_DIR)/float_dpi.cpp SRCS += $(SIM_DIR)/common/rvfloats.cpp SRCS += ./main.cpp -RTL_SRCS += $(RTL_DIR)/fpu/VX_tensor_core.sv +RTL_SRCS += $(RTL_DIR)/fpu/VX_tensor_dpu.sv RTL_SRCS += $(RTL_DIR)/fpu/VX_tensor_tb.sv TOP = VX_tensor_tb diff --git a/tests/kernel/common.mk b/tests/kernel/common.mk index 7bf4b520..e276c624 100644 --- a/tests/kernel/common.mk +++ b/tests/kernel/common.mk @@ -33,7 +33,7 @@ $(PROJECT).dump: $(PROJECT).elf $(PROJECT).bin: $(PROJECT).elf $(CP) -O binary $(PROJECT).elf $(PROJECT).bin -$(PROJECT).elf: $(SRCS) +$(PROJECT).elf: $(SRCS) $(DEPS) $(CC) $(CFLAGS) $(SRCS) $(LDFLAGS) -o $(PROJECT).elf run-rtlsim: $(PROJECT).bin diff --git a/tests/kernel/tensor/check_correctness.py b/tests/kernel/tensor/check_correctness.py index 84db43d3..c81212d0 100644 --- a/tests/kernel/tensor/check_correctness.py +++ b/tests/kernel/tensor/check_correctness.py @@ -86,8 +86,9 @@ expected_A = expected['A_array'] expected_B = expected['B_array'] expected_C = expected['C_array'] expected_C = expected_C + expected_A @ expected_B - -print(expected_C - C_array) +print(expected_C[0:8, 0:8]) +print(C_array[0:8, 0:8]) +print((expected_C - C_array)[0:8, 0:8]) assert np.allclose(expected_A, A_array) assert np.allclose(expected_B, B_array) diff --git a/tests/kernel/tensor/create_test_case.py b/tests/kernel/tensor/create_test_case.py index 0fbd1583..35ad7d73 100644 --- a/tests/kernel/tensor/create_test_case.py +++ b/tests/kernel/tensor/create_test_case.py @@ -1,12 +1,15 @@ import numpy as np -# A_array = np.random.rand(16, 8) -# B_array = np.random.rand(8, 16) -A_array = np.zeros((16, 8)) -B_array = np.zeros((8, 16)) -A_array[0,:] = 1.0 -B_array[:,0] = 1.0 +A_array = np.random.rand(16, 8) +B_array = np.random.rand(8, 16) C_array = np.random.rand(16, 16) - +# A_array = np.zeros((16, 8)) +# B_array = np.zeros((8, 16)) +# A_array[0,:] = 1.0 +# B_array[:,4] = 1.0 +# C_array = np.zeros((16, 16)) +# for i in range(16): +# for j in range(16): +# C_array[i,j] = i * 16 + j with open('a_matrix.h', 'w') as f: for i in range(A_array.shape[0]): From 5bd25985c6ecabad0eb1277c86474d0969c1d106 Mon Sep 17 00:00:00 2001 From: joshua Date: Sat, 4 May 2024 23:01:47 -0700 Subject: [PATCH 08/55] i kinda forgot most of changes --- hw/dpi/float_dpi.cpp | 152 ++++++++++++- hw/dpi/float_dpi.vh | 1 + hw/rtl/core/VX_commit.sv | 13 +- hw/rtl/core/VX_operands.sv | 16 +- hw/rtl/core/VX_tensor_core.sv | 9 +- hw/rtl/core/VX_tensor_ucode.vh | 2 +- hw/rtl/core/VX_uop_sequencer.sv | 8 +- hw/rtl/core/generate_ucode.py | 6 +- hw/rtl/fpu/VX_tensor_dpu.sv | 9 +- kernel/include/vx_spawn.h | 1 + kernel/src/vx_spawn.c | 105 +++++++++ tests/kernel/tensor/main.cpp | 2 +- tests/regression/sgemm_tcore/Makefile | 9 + tests/regression/sgemm_tcore/common.h | 18 ++ tests/regression/sgemm_tcore/kernel.cpp | 285 ++++++++++++++++++++++++ tests/regression/sgemm_tcore/main.cpp | 270 ++++++++++++++++++++++ 16 files changed, 882 insertions(+), 24 deletions(-) create mode 100644 tests/regression/sgemm_tcore/Makefile create mode 100644 tests/regression/sgemm_tcore/common.h create mode 100644 tests/regression/sgemm_tcore/kernel.cpp create mode 100644 tests/regression/sgemm_tcore/main.cpp diff --git a/hw/dpi/float_dpi.cpp b/hw/dpi/float_dpi.cpp index d5209bed..a9d9eaef 100644 --- a/hw/dpi/float_dpi.cpp +++ b/hw/dpi/float_dpi.cpp @@ -56,6 +56,7 @@ extern "C" { void dpi_fmax(bool enable, int dst_fmt, int64_t a, int64_t b, int64_t* result, svBitVecVal* fflags); void dpi_hmma(bool enable, const svBitVecVal* A_tile, const svBitVecVal* B_tile, const svBitVecVal* C_tile, svBitVecVal* D_tile); + void dpi_print_results(int wid, int octet, const svBitVecVal* A_tile, const svBitVecVal* B_tile, const svBitVecVal* C_tile, const svBitVecVal* D_tile); } inline uint64_t nan_box(uint32_t value) { @@ -413,4 +414,153 @@ void dpi_hmma(bool enable, const svBitVecVal* A_tile, const svBitVecVal* B_tile, } write_float_array(D_tile, &c_D_tile[0][0], M, M); -} \ No newline at end of file +} + +// 1 copy per warp +float A_tile_full[4][16][8]; +float B_tile_full[4][8][16]; +float C_tile_full[4][16][16]; +float D_tile_full[4][16][16]; +int steps[4]; + +void print_array(float* array, int rows, int cols) { + for (int i = 0; i < rows; i += 1) { + for (int j = 0; j < cols; j += 1) { + std::cout << array[i*cols+j] << " "; + } + std::cout << "\n"; + } + std::cout << std::endl; +} + +void dpi_print_results(int wid, int octet, const svBitVecVal* A_tile, const svBitVecVal* B_tile, const svBitVecVal* C_tile, const svBitVecVal* D_tile) { + // std::cout << "A: " << std::endl; + fill_float_array(A_tile, &c_A_tile[0][0], M, K); + // std::cout << "B: " << std::endl; + fill_float_array(B_tile, &c_B_tile[0][0], K, M); + // std::cout << "C: " << std::endl; + fill_float_array(C_tile, &c_C_tile[0][0], M, M); + // for some reason this still holds onto old value? very strange + // std::cout << "D: " << std::endl; + fill_float_array(D_tile, &c_D_tile[0][0], M, M); + + int octet_row_offset; + int octet_col_offset; + switch(octet) { + case 0: + octet_row_offset = 0; + octet_col_offset = 0; + break; + case 1: + octet_row_offset = 8; + octet_col_offset = 0; + break; + case 2: + octet_row_offset = 0; + octet_col_offset = 8; + break; + case 3: + octet_row_offset = 8; + octet_col_offset = 8; + break; + } + + int step_row_offset; + int step_col_offset; + int step = (steps[wid] % 16) / 4; + int set = (steps[wid] / 16); + switch(step) { + case 0: + step_row_offset = 0; + step_col_offset = 0; + break; + case 1: + step_row_offset = 2; + step_col_offset = 0; + break; + case 2: + step_row_offset = 0; + step_col_offset = 4; + break; + case 3: + step_row_offset = 2; + step_col_offset = 4; + break; + } + + if (steps[0] >= 48) { + // std::cout << "octet " << octet << " step " << steps[0] << "\n"; + // print_array(&c_D_tile[0][0], 4, 4); + } + + D_tile_full[wid][octet_row_offset+step_row_offset+0][octet_col_offset+step_col_offset+0] = c_D_tile[0][0]; + D_tile_full[wid][octet_row_offset+step_row_offset+0][octet_col_offset+step_col_offset+1] = c_D_tile[0][1]; + D_tile_full[wid][octet_row_offset+step_row_offset+0][octet_col_offset+step_col_offset+2] = c_D_tile[0][2]; + D_tile_full[wid][octet_row_offset+step_row_offset+0][octet_col_offset+step_col_offset+3] = c_D_tile[0][3]; + D_tile_full[wid][octet_row_offset+step_row_offset+1][octet_col_offset+step_col_offset+0] = c_D_tile[1][0]; + D_tile_full[wid][octet_row_offset+step_row_offset+1][octet_col_offset+step_col_offset+1] = c_D_tile[1][1]; + D_tile_full[wid][octet_row_offset+step_row_offset+1][octet_col_offset+step_col_offset+2] = c_D_tile[1][2]; + D_tile_full[wid][octet_row_offset+step_row_offset+1][octet_col_offset+step_col_offset+3] = c_D_tile[1][3]; + D_tile_full[wid][octet_row_offset+step_row_offset+4][octet_col_offset+step_col_offset+0] = c_D_tile[2][0]; + D_tile_full[wid][octet_row_offset+step_row_offset+4][octet_col_offset+step_col_offset+1] = c_D_tile[2][1]; + D_tile_full[wid][octet_row_offset+step_row_offset+4][octet_col_offset+step_col_offset+2] = c_D_tile[2][2]; + D_tile_full[wid][octet_row_offset+step_row_offset+4][octet_col_offset+step_col_offset+3] = c_D_tile[2][3]; + D_tile_full[wid][octet_row_offset+step_row_offset+5][octet_col_offset+step_col_offset+0] = c_D_tile[3][0]; + D_tile_full[wid][octet_row_offset+step_row_offset+5][octet_col_offset+step_col_offset+1] = c_D_tile[3][1]; + D_tile_full[wid][octet_row_offset+step_row_offset+5][octet_col_offset+step_col_offset+2] = c_D_tile[3][2]; + D_tile_full[wid][octet_row_offset+step_row_offset+5][octet_col_offset+step_col_offset+3] = c_D_tile[3][3]; + + if (octet == 0 || octet == 1) { + octet_row_offset = octet * 8; + if (step == 0) { + step_row_offset = 0; + } + if (step == 1) { + step_row_offset = 2; + } + if (step == 0 || step == 1) { + A_tile_full[wid][octet_row_offset+step_row_offset+0][set*2+0] = c_A_tile[0][0]; + A_tile_full[wid][octet_row_offset+step_row_offset+0][set*2+1] = c_A_tile[0][1]; + A_tile_full[wid][octet_row_offset+step_row_offset+1][set*2+0] = c_A_tile[1][0]; + A_tile_full[wid][octet_row_offset+step_row_offset+1][set*2+1] = c_A_tile[1][1]; + A_tile_full[wid][octet_row_offset+step_row_offset+4][set*2+0] = c_A_tile[2][0]; + A_tile_full[wid][octet_row_offset+step_row_offset+4][set*2+1] = c_A_tile[2][1]; + A_tile_full[wid][octet_row_offset+step_row_offset+5][set*2+0] = c_A_tile[3][0]; + A_tile_full[wid][octet_row_offset+step_row_offset+5][set*2+1] = c_A_tile[3][1]; + } + } + + if (octet == 0 || octet == 2) { + octet_col_offset = octet * 4; + if (step == 0) { + step_col_offset = 0; + } + else if (step == 2) { + step_col_offset = 4; + } + if (step == 0 || step == 2) { + B_tile_full[wid][set*2+0][octet_col_offset+step_col_offset+0] = c_B_tile[0][0]; + B_tile_full[wid][set*2+0][octet_col_offset+step_col_offset+1] = c_B_tile[0][1]; + B_tile_full[wid][set*2+0][octet_col_offset+step_col_offset+2] = c_B_tile[0][2]; + B_tile_full[wid][set*2+0][octet_col_offset+step_col_offset+3] = c_B_tile[0][3]; + B_tile_full[wid][set*2+1][octet_col_offset+step_col_offset+0] = c_B_tile[1][0]; + B_tile_full[wid][set*2+1][octet_col_offset+step_col_offset+1] = c_B_tile[1][1]; + B_tile_full[wid][set*2+1][octet_col_offset+step_col_offset+2] = c_B_tile[1][2]; + B_tile_full[wid][set*2+1][octet_col_offset+step_col_offset+3] = c_B_tile[1][3]; + } + } + + steps[wid] += 1; + if (steps[wid] % 64 == 0) { + steps[wid] = 0; + std::cout << "warp " << wid << " finished wmma\n"; + std::cout << "A tile" << "\n"; + print_array(&A_tile_full[wid][0][0], 16, 8); + std::cout << "B tile" << "\n"; + print_array(&B_tile_full[wid][0][0], 8, 16); + // std::cout << "C tile" << "\n"; + // print_array(&C_tile_full[wid][0][0], 16, 16); + std::cout << "D tile" << "\n"; + print_array(&D_tile_full[wid][0][0], 16, 16); + } +} diff --git a/hw/dpi/float_dpi.vh b/hw/dpi/float_dpi.vh index c8e7c9cb..08abf25b 100644 --- a/hw/dpi/float_dpi.vh +++ b/hw/dpi/float_dpi.vh @@ -45,5 +45,6 @@ import "DPI-C" function void dpi_fmin(input logic enable, input int dst_fmt, inp import "DPI-C" function void dpi_fmax(input logic enable, input int dst_fmt, input longint a, input longint b, output longint result, output bit[4:0] fflags); import "DPI-C" function void dpi_hmma(input logic enable, input bit[3:0][1:0][31:0] A_tile, input bit[1:0][3:0][31:0] B_tile, input bit[3:0][3:0][31:0] C_tile, output bit[3:0][3:0][31:0] D_tile); +import "DPI-C" function void dpi_print_results(input int wid, input int octet, input bit[3:0][1:0][31:0] A_tile, input bit[1:0][3:0][31:0] B_tile, input bit[3:0][3:0][31:0] C_tile, input bit[3:0][3:0][31:0] D_tile); `endif diff --git a/hw/rtl/core/VX_commit.sv b/hw/rtl/core/VX_commit.sv index f0d94925..5fda49f9 100644 --- a/hw/rtl/core/VX_commit.sv +++ b/hw/rtl/core/VX_commit.sv @@ -53,6 +53,7 @@ module VX_commit import VX_gpu_pkg::*; #( wire [`ISSUE_WIDTH-1:0][`NUM_THREADS-1:0] commit_tmask; wire [`ISSUE_WIDTH-1:0] commit_eop; wire [`ISSUE_WIDTH-1:0][`EX_BITS-1:0] commit_sel; + `UNUSED_VAR (commit_sel) for (genvar i = 0; i < `ISSUE_WIDTH; ++i) begin @@ -175,14 +176,17 @@ module VX_commit import VX_gpu_pkg::*; #( // relies on 1 cycle delay of arbiter and continuous issuing of tensor instructions, // so probably want to change this at some point // (i.e. pass a "don't count this towards pending instructions" signal down the pipeline) - logic [`ISSUE_WIDTH-1:0][4:0] hmma_ctr, hmma_ctr_n; + // logic [`ISSUE_WIDTH-1:0][4:0] hmma_ctr, hmma_ctr_n; wire [`ISSUE_WIDTH-1:0] final_hmma; `ifdef EXT_T_ENABLE for (genvar i = 0; i < `ISSUE_WIDTH; ++i) begin - assign hmma_ctr_n[i] = (tensor_commit_if[i].valid && tensor_commit_if[i].ready) ? hmma_ctr[i] + 5'b1 : hmma_ctr[i]; - assign final_hmma[i] = (commit_sel[i] != `EX_BITS'(2) || hmma_ctr == '0); + // assign hmma_ctr_n[i] = (tensor_commit_if[i].valid && tensor_commit_if[i].ready) ? hmma_ctr[i] + 5'b1 : hmma_ctr[i]; + // assign final_hmma[i] = (commit_sel[i] != `EX_BITS'(2) || hmma_ctr == '0); + // i suppose this is now a feature and not a bug + // if PC is 0, this means it is not final step of a wmma, shouldn't be committed + assign final_hmma[i] = (commit_if[i].data.PC != 32'b0); end - + /* always @(posedge clk) begin if (reset) begin hmma_ctr <= '0; @@ -191,6 +195,7 @@ module VX_commit import VX_gpu_pkg::*; #( hmma_ctr <= hmma_ctr_n; end end + */ `else assign final_hmma = '1; `endif diff --git a/hw/rtl/core/VX_operands.sv b/hw/rtl/core/VX_operands.sv index 3fe1d7f4..f4ff1edc 100644 --- a/hw/rtl/core/VX_operands.sv +++ b/hw/rtl/core/VX_operands.sv @@ -300,14 +300,14 @@ module VX_operands import VX_gpu_pkg::*; #( cycle <= cycle_n; end - if (cycle == 32'd25000) begin - for (integer k = 0; k < `NUM_REGS * ISSUE_RATIO; ++k) begin - integer warp = i * ISSUE_RATIO + (k / `NUM_REGS); - integer thread = j; - integer register = k % `NUM_REGS; - $display("warp %0d, thread %0d, register %0d: %0x", warp, thread, register, gpr_ram.ram[k]); - end - end + // if (cycle == 32'd25000) begin + // for (integer k = 0; k < `NUM_REGS * ISSUE_RATIO; ++k) begin + // integer warp = i * ISSUE_RATIO + (k / `NUM_REGS); + // integer thread = j; + // integer register = k % `NUM_REGS; + // $display("warp %0d, thread %0d, register %0d: %0x", warp, thread, register, gpr_ram.ram[k]); + // end + // end end end end diff --git a/hw/rtl/core/VX_tensor_core.sv b/hw/rtl/core/VX_tensor_core.sv index 1055351d..de27f757 100644 --- a/hw/rtl/core/VX_tensor_core.sv +++ b/hw/rtl/core/VX_tensor_core.sv @@ -55,7 +55,8 @@ module VX_tensor_core_warp import VX_gpu_pkg::*; #( logic result_valid; logic result_ready; VX_tensor_octet #( - + .ISW(ISW), + .OCTET(i) ) octet ( .clk(clk), .reset(reset), @@ -177,7 +178,8 @@ module VX_tensor_core_warp import VX_gpu_pkg::*; #( endmodule module VX_tensor_octet #( - + parameter ISW, + parameter OCTET ) ( input clk, input reset, @@ -282,7 +284,8 @@ module VX_tensor_octet #( wire do_hmma = (substep == 1'b1 && operands_valid && operands_ready); VX_tensor_dpu #( - + .ISW(ISW), + .OCTET(OCTET) ) dpu ( .clk(clk), .reset(reset), diff --git a/hw/rtl/core/VX_tensor_ucode.vh b/hw/rtl/core/VX_tensor_ucode.vh index 7603b4a1..8c3243de 100644 --- a/hw/rtl/core/VX_tensor_ucode.vh +++ b/hw/rtl/core/VX_tensor_ucode.vh @@ -92,5 +92,5 @@ HMMA_SET3_STEP3_0: begin uop = {NEXT, HMMA_SET3_STEP3_1, `EX_BITS'(`EX_TENSOR), `INST_OP_BITS'(3), `INST_MOD_BITS'(0), 1'b1, 1'b0, 1'b0, 32'b0, 32'b0, `FREG(22), `FREG(6), `FREG(14), `FREG(22)}; end HMMA_SET3_STEP3_1: begin - uop = {FINISH, HMMA_SET0_STEP0_0, `EX_BITS'(`EX_TENSOR), `INST_OP_BITS'(3), `INST_MOD_BITS'(1), 1'b1, 1'b0, 1'b0, 32'b0, 32'b0, `FREG(23), `FREG(7), `FREG(15), `FREG(23)}; + uop = {FINISH, HMMA_SET0_STEP0_0, `EX_BITS'(`EX_TENSOR), `INST_OP_BITS'(3), `INST_MOD_BITS'(1), 1'b1, 1'b0, 1'b0, 32'b1, 32'b1, `FREG(23), `FREG(7), `FREG(15), `FREG(23)}; end diff --git a/hw/rtl/core/VX_uop_sequencer.sv b/hw/rtl/core/VX_uop_sequencer.sv index 3d2bb233..8fe2f3d9 100644 --- a/hw/rtl/core/VX_uop_sequencer.sv +++ b/hw/rtl/core/VX_uop_sequencer.sv @@ -125,16 +125,16 @@ module VX_uop_sequencer import VX_gpu_pkg::*; ( always @(posedge clk) begin if (uop_start) begin - $display("UOP start @ %t", $time); - $display("use_uop=%0d, use_uop_1d=%0d, uop_start=%0d, ibuffer_if.valid=%0d, ibuffer_if.ready=%0d", use_uop, use_uop_1d, uop_start, ibuffer_if.valid, ibuffer_if.ready); + // $display("UOP start @ %t", $time); + // $display("use_uop=%0d, use_uop_1d=%0d, uop_start=%0d, ibuffer_if.valid=%0d, ibuffer_if.ready=%0d", use_uop, use_uop_1d, uop_start, ibuffer_if.valid, ibuffer_if.ready); end if (uop_fire) begin - $display("UOP fire @ %t", $time); + // $display("UOP fire @ %t", $time); end if (uop_finish) begin - $display("UOP finish @ %t", $time); + // $display("UOP finish @ %t", $time); end if (reset) begin diff --git a/hw/rtl/core/generate_ucode.py b/hw/rtl/core/generate_ucode.py index 4cda111f..671b6c7b 100644 --- a/hw/rtl/core/generate_ucode.py +++ b/hw/rtl/core/generate_ucode.py @@ -55,17 +55,21 @@ with open('VX_tensor_ucode.vh', 'w') as f: finish = (next_sequence_num == 0) name = "HMMA_SET{}_STEP{}_{}" - ucode = "{}, HMMA_SET{}_STEP{}_{}, `EX_BITS'(`EX_TENSOR), `INST_OP_BITS'({}), `INST_MOD_BITS'({}), 1'b1, 1'b0, 1'b0, 32'b0, 32'b0, `FREG({}), `FREG({}), `FREG({}), `FREG({})" + ucode = "{}, HMMA_SET{}_STEP{}_{}, `EX_BITS'(`EX_TENSOR), `INST_OP_BITS'({}), `INST_MOD_BITS'({}), 1'b1, 1'b0, 1'b0, 32'b{}, 32'b{}, `FREG({}), `FREG({}), `FREG({}), `FREG({})" name = name.format( set_num, step, substep, ) + + pc_imm = 1 if finish else 0 ucode = ucode.format( "FINISH" if finish else "NEXT", next_set_num, next_step, next_substep, step, substep, + pc_imm, + pc_imm, rs3_rd[(step, substep)], rs1[(set_num, substep)], rs2[(set_num, substep)], diff --git a/hw/rtl/fpu/VX_tensor_dpu.sv b/hw/rtl/fpu/VX_tensor_dpu.sv index 8108790c..3927e1e3 100644 --- a/hw/rtl/fpu/VX_tensor_dpu.sv +++ b/hw/rtl/fpu/VX_tensor_dpu.sv @@ -1,7 +1,8 @@ `include "VX_fpu_define.vh" module VX_tensor_dpu #( - + parameter ISW, + parameter OCTET ) ( input clk, input reset, @@ -21,6 +22,12 @@ module VX_tensor_dpu #( always @(*) begin dpi_hmma(valid_in, A_tile, B_tile, C_tile, result_hmma); end + + always @(posedge clk) begin + if (~reset && valid_in) begin + dpi_print_results(int'(ISW), int'(OCTET), A_tile, B_tile, C_tile, result_hmma); + end + end VX_shift_register #( diff --git a/kernel/include/vx_spawn.h b/kernel/include/vx_spawn.h index 2584b997..d8797945 100644 --- a/kernel/include/vx_spawn.h +++ b/kernel/include/vx_spawn.h @@ -48,6 +48,7 @@ void vx_wspawn_wait(); void vx_spawn_kernel(context_t * ctx, vx_spawn_kernel_cb callback, void * arg); void vx_spawn_tasks(int num_tasks, vx_spawn_tasks_cb callback, void * arg); +void vx_spawn_tasks_contiguous(int num_tasks, vx_spawn_tasks_cb callback , void * arg); void vx_serial(vx_serial_cb callback, void * arg); diff --git a/kernel/src/vx_spawn.c b/kernel/src/vx_spawn.c index fd8258e1..b1ef7230 100644 --- a/kernel/src/vx_spawn.c +++ b/kernel/src/vx_spawn.c @@ -83,6 +83,38 @@ static void __attribute__ ((noinline)) spawn_tasks_rem_stub() { (p_wspawn_args->callback)(task_id, p_wspawn_args->arg); } +static void __attribute__ ((noinline)) spawn_tasks_contiguous_all_stub() { + int NT = vx_num_threads(); + int NW = vx_num_warps(); + int cid = vx_core_id(); + int wid = vx_warp_id(); + int tid = vx_thread_id(); + + wspawn_tasks_args_t* p_wspawn_args = (wspawn_tasks_args_t*)g_wspawn_args[cid]; + + int waves = p_wspawn_args->NWs + (wid < p_wspawn_args->RWs); + int offset = p_wspawn_args->offset + (NT * wid + tid); + + vx_spawn_tasks_cb callback = p_wspawn_args->callback; + void* arg = p_wspawn_args->arg; + for (int wave_id = 0; wave_id < waves; ++wave_id) { + int task_id = offset + (wave_id * NT * NW); + callback(task_id, arg); + } +} + +static void __attribute__ ((noinline)) spawn_tasks_contiguous_all_cb() { + // activate all threads + vx_tmc(-1); + + // call stub routine + spawn_tasks_contiguous_all_stub(); + + // disable warp + // deadlock here on warps 1, 2, 3 + vx_tmc_zero(); +} + static void __attribute__ ((noinline)) spawn_tasks_all_cb() { // activate all threads vx_tmc(-1); @@ -94,6 +126,79 @@ static void __attribute__ ((noinline)) spawn_tasks_all_cb() { vx_tmc_zero(); } +void vx_spawn_tasks_contiguous(int num_tasks, vx_spawn_tasks_cb callback , void * arg) { + // device specs + int NC = vx_num_cores(); + int NW = vx_num_warps(); + int NT = vx_num_threads(); + + // current core id + int core_id = vx_core_id(); + if (core_id >= NUM_CORES_MAX) + return; + + // calculate necessary active cores + int WT = NW * NT; + int nC = (num_tasks > WT) ? (num_tasks / WT) : 1; + int nc = MIN(nC, NC); + if (core_id >= nc) + return; // terminate extra cores + + // number of tasks per core + int tasks_per_core = num_tasks / nc; + int tasks_per_core_n1 = tasks_per_core; + if (core_id == (nc-1)) { + int rem = num_tasks - (nc * tasks_per_core); + tasks_per_core_n1 += rem; // last core also executes remaining tasks + } + + // number of tasks per warp + int TW = tasks_per_core_n1 / NT; // occupied warps + int rT = tasks_per_core_n1 - TW * NT; // remaining threads + int fW = 1, rW = 0; + if (TW >= NW) { + fW = TW / NW; // full warps iterations + rW = TW - fW * NW; // remaining warps + } + + wspawn_tasks_args_t wspawn_args = { callback, arg, core_id * tasks_per_core, fW, rW }; + g_wspawn_args[core_id] = &wspawn_args; + + if (TW >= 1) { + // execute callback on other warps + int nw = MIN(TW, NW); + vx_wspawn(nw, spawn_tasks_contiguous_all_cb); + + // activate all threads + vx_tmc(-1); + + // call stub routine + spawn_tasks_contiguous_all_stub(); + + // back to single-threaded + vx_tmc_one(); + + // wait for spawn warps to terminate + // deadlock here on warp 0! + vx_wspawn_wait(); + } + + if (rT != 0) { + // adjust offset + wspawn_args.offset += (tasks_per_core_n1 - rT); + + // activate remaining threads + int tmask = (1 << rT) - 1; + vx_tmc(tmask); + + // call stub routine + spawn_tasks_rem_stub(); + + // back to single-threaded + vx_tmc_one(); + } +} + void vx_spawn_tasks(int num_tasks, vx_spawn_tasks_cb callback , void * arg) { // device specs int NC = vx_num_cores(); diff --git a/tests/kernel/tensor/main.cpp b/tests/kernel/tensor/main.cpp index 2c45fa9b..5fc222b2 100644 --- a/tests/kernel/tensor/main.cpp +++ b/tests/kernel/tensor/main.cpp @@ -90,7 +90,7 @@ int main() vx_wmma(); store_wmma_result(); vx_tmc(1); - print_wmma_result(); + // print_wmma_result(); return 0; } \ No newline at end of file diff --git a/tests/regression/sgemm_tcore/Makefile b/tests/regression/sgemm_tcore/Makefile new file mode 100644 index 00000000..0c378af0 --- /dev/null +++ b/tests/regression/sgemm_tcore/Makefile @@ -0,0 +1,9 @@ +PROJECT = sgemm_tcore + +SRCS = main.cpp common.h + +VX_SRCS = kernel.cpp + +OPTS ?= -n16 + +include ../common.mk \ No newline at end of file diff --git a/tests/regression/sgemm_tcore/common.h b/tests/regression/sgemm_tcore/common.h new file mode 100644 index 00000000..d94a270f --- /dev/null +++ b/tests/regression/sgemm_tcore/common.h @@ -0,0 +1,18 @@ +#ifndef _COMMON_H_ +#define _COMMON_H_ + +#include + +#define KERNEL_ARG_DEV_MEM_ADDR 0x7fff0000 +#define DEV_SMEM_START_ADDR 0xff000000 + +typedef struct { + uint32_t dim_m; + uint32_t dim_n; + uint32_t dim_k; + uint64_t addr_a; + uint64_t addr_b; + uint64_t addr_c; +} kernel_arg_t; + +#endif \ No newline at end of file diff --git a/tests/regression/sgemm_tcore/kernel.cpp b/tests/regression/sgemm_tcore/kernel.cpp new file mode 100644 index 00000000..f4e467f4 --- /dev/null +++ b/tests/regression/sgemm_tcore/kernel.cpp @@ -0,0 +1,285 @@ +#define RISCV_CUSTOM3 0x7B + +#include +#include +#include +#include +#include "common.h" + +#define BM 16 +#define BN 16 +#define BK 8 + +inline void vx_wmma() { + asm volatile (".insn r %0, 0, 0, x0, x0, x0" :: "i"(RISCV_CUSTOM3)); +} + +void vx_wmma_load(volatile float *smem_A, volatile float *smem_B, int warp_x, int warp_y, int thread_in_warp) { + int tid = thread_in_warp; + int tg = tid / 4; + + // load A + int row = tid % 4; + row += (tg * 8) % 16; + row += (tg / 4) * 4; + + int smem_A_m = 32; + int smem_A_n = 8; + int smem_B_m = 8; + int smem_B_n = 32; + + int A_offset = (row + BM * warp_y) * smem_A_n; + + asm volatile ("flw f0, %0" :: "m"(smem_A[A_offset + 0])); + asm volatile ("flw f1, %0" :: "m"(smem_A[A_offset + 1])); + asm volatile ("flw f2, %0" :: "m"(smem_A[A_offset + 2])); + asm volatile ("flw f3, %0" :: "m"(smem_A[A_offset + 3])); + asm volatile ("flw f4, %0" :: "m"(smem_A[A_offset + 4])); + asm volatile ("flw f5, %0" :: "m"(smem_A[A_offset + 5])); + asm volatile ("flw f6, %0" :: "m"(smem_A[A_offset + 6])); + asm volatile ("flw f7, %0" :: "m"(smem_A[A_offset + 7])); + + // load B + int col = tid % 4; + col += ((tg % 4) / 2) * 8; + col += (tg / 4) * 4; + + asm volatile ("flw f8 , %0" :: "m"(smem_B[(0 * smem_B_n) + warp_x * BN + col])); + asm volatile ("flw f9 , %0" :: "m"(smem_B[(1 * smem_B_n) + warp_x * BN + col])); + asm volatile ("flw f10, %0" :: "m"(smem_B[(2 * smem_B_n) + warp_x * BN + col])); + asm volatile ("flw f11, %0" :: "m"(smem_B[(3 * smem_B_n) + warp_x * BN + col])); + asm volatile ("flw f12, %0" :: "m"(smem_B[(4 * smem_B_n) + warp_x * BN + col])); + asm volatile ("flw f13, %0" :: "m"(smem_B[(5 * smem_B_n) + warp_x * BN + col])); + asm volatile ("flw f14, %0" :: "m"(smem_B[(6 * smem_B_n) + warp_x * BN + col])); + asm volatile ("flw f15, %0" :: "m"(smem_B[(7 * smem_B_n) + warp_x * BN + col])); +} + +inline void initialize_C() { + // initialize C to zeros + asm volatile ("fmv.w.x f16, x0"); + asm volatile ("fmv.w.x f17, x0"); + asm volatile ("fmv.w.x f18, x0"); + asm volatile ("fmv.w.x f19, x0"); + asm volatile ("fmv.w.x f20, x0"); + asm volatile ("fmv.w.x f21, x0"); + asm volatile ("fmv.w.x f22, x0"); + asm volatile ("fmv.w.x f23, x0"); +} + +inline void write_results( + volatile float *local_warp_results, + int thread_in_warp, + int warp_x, + int warp_y, + int dim_m, + int dim_n, + float *C, + int threadblock_id_x, + int threadblock_id_y +) { + int tid = thread_in_warp; + int tg = tid / 4; + + asm volatile ("fsw f16, %0" :: "m"(local_warp_results[tid*8+0])); + asm volatile ("fsw f17, %0" :: "m"(local_warp_results[tid*8+1])); + asm volatile ("fsw f18, %0" :: "m"(local_warp_results[tid*8+2])); + asm volatile ("fsw f19, %0" :: "m"(local_warp_results[tid*8+3])); + asm volatile ("fsw f20, %0" :: "m"(local_warp_results[tid*8+4])); + asm volatile ("fsw f21, %0" :: "m"(local_warp_results[tid*8+5])); + asm volatile ("fsw f22, %0" :: "m"(local_warp_results[tid*8+6])); + asm volatile ("fsw f23, %0" :: "m"(local_warp_results[tid*8+7])); + + /* + col = ((threadgroup % 4) // 2) * 8 + row = (threadgroup * 8) % 16 + row += (threadgroup // 4) * 4 + offsets = [(0, 0), (0, 1), (2, 0), (2, 1), (0, 4), (0, 5), (2, 4), (2, 5)] + offset = offsets[register-16] + row += offset[0] + col += offset[1] + thread_offsets = [(0, 0), (1, 0), (0, 2), (1, 2)] + thread_offset = thread_offsets[thread % 4] + row += thread_offset[0] + col += thread_offset[1] + return (row, col) + */ + + int local_col = ((tg % 4) / 2) * 8; + int local_row = (tg * 8) % 16; + local_row += (tg / 4) * 4; + + // int row_offsets[8] = {0, 0, 2, 2, 0, 0, 2, 2}; + // int col_offsets[8] = {0, 1, 0, 1, 4, 5, 4, 5}; + + // int thread_row_offsets[4] = {0, 1, 0, 1}; + // int thread_col_offsets[4] = {0, 0, 2, 2}; + int thread_row_offset = (tid % 4) % 2; + int thread_col_offset = ((tid % 4) / 2) * 2; + + float *global_offset_C = C + (threadblock_id_y * BM * 2 + warp_y * BM) * dim_n + threadblock_id_x * BN * 2 + warp_x * BM; + for (int i = 0; i < 8; i += 1) { + int row_offset = ((i / 2) % 2) * 2; + int col_offset = (i / 4) * 4 + i % 2; + + int adjusted_local_row = local_row + thread_row_offset + row_offset; + int adjusted_local_col = local_col + thread_col_offset + col_offset; + + float v = local_warp_results[tid*8+i]; + global_offset_C[adjusted_local_row * dim_n + adjusted_local_col] = v; + } +} + +void threadblock_barrier(unsigned int tid_in_threadblock, unsigned int barrier_id, unsigned int count) { + vx_fence(); + vx_barrier(barrier_id, count); +} + +void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg, + const uint32_t tid_in_threadblock, + const uint32_t threadblock_dim_x, + const uint32_t threadblock_dim_y, + const uint32_t threadblock_id_x, + const uint32_t threadblock_id_y, + const uint32_t threadblock_id, + float *sharedmem_per_threadblock) { + const float *A = (const float *)arg->addr_a; + const float *B = (const float *)arg->addr_b; + float *C = (float *)arg->addr_c; + + const uint32_t dim_m = arg->dim_m; + const uint32_t dim_n = arg->dim_n; + const uint32_t dim_k = arg->dim_k; + + // FIXME: Output block size is assumed to be square, i.e. BM == BN + // const uint32_t BM = threadblock_dim_y; + // const uint32_t BN = threadblock_dim_y; + // const uint32_t BK = threadblock_dim_x; + // constexpr uint32_t BM = 8; + // constexpr uint32_t BN = 8; + // constexpr uint32_t BK = 2; + + const uint32_t warp_in_threadblock = tid_in_threadblock / 32; + const uint32_t tid_in_warp = tid_in_threadblock % 32; + const uint32_t warp_x = warp_in_threadblock % 2; + const uint32_t warp_y = warp_in_threadblock / 2; + + const uint32_t global_a_row = threadblock_dim_y * threadblock_id_y; + + // 32 * 8 block of A, being loaded by 4 warps + const uint32_t local_a_row = warp_y * BM + warp_x * (BM / 2) + (tid_in_warp / BK); + const uint32_t local_a_col = tid_in_warp % BK; + + // 8 * 32 block of B, being loaded by 4 warps + // do a fat coalesced load + const uint32_t global_b_col = threadblock_dim_x * threadblock_id_x; + const uint32_t local_b_row = warp_in_threadblock; + const uint32_t local_b_col = tid_in_warp; + + + volatile float *local_a = sharedmem_per_threadblock; + const size_t local_a_elems = (threadblock_dim_y * BK); + volatile float *local_b = sharedmem_per_threadblock + local_a_elems; + const size_t local_b_elems = (threadblock_dim_x * BK); + volatile float *local_warp_results = local_b + local_b_elems + (warp_in_threadblock * BM * BN); + + // clear out C + initialize_C(); + + for (uint32_t k = 0; k < dim_k; k += BK) { + // Data move from GMEM to SMEM + // + // Make sure global offset values for A and B are contiguous between + // neighboring threads to ensure GMEM coalescing. (not possible for A here, but for B it's doable) + + // coalesced load from global memory -> unit-stride store into shared memory + uint32_t global_a_offset = + dim_k * (global_a_row + local_a_row) + (k + local_a_col); + uint32_t shared_a_offset = + BK * local_a_row + local_a_col; + + local_a[shared_a_offset] = A[global_a_offset]; + + global_a_offset += dim_k * (BM / 4); + shared_a_offset += BK * (BM / 4); + + local_a[shared_a_offset] = A[global_a_offset]; + + uint32_t global_b_offset = + dim_n * (k + local_b_row) + (global_b_col + local_b_col); + uint32_t shared_b_offset = + (BN * 2) * (local_b_row) + local_b_col; + + local_b[shared_b_offset] = B[global_b_offset]; + + global_b_offset += dim_n * (BK / 2); + shared_b_offset += (BN * 2) * (BK / 2); + + local_b[shared_b_offset] = B[global_b_offset]; + + // want all 4 warps to reach barrier before moving on (just use barrier 0 for now) + threadblock_barrier(tid_in_threadblock, 0, 4); + + // perform wmma + vx_wmma_load(local_a, local_b, warp_x, warp_y, tid_in_warp); + vx_wmma(); + + // same as above + threadblock_barrier(tid_in_threadblock, 0, 4); + } + + write_results( + local_warp_results, + tid_in_warp, + warp_x, + warp_y, + dim_m, + dim_n, + C, + threadblock_id_x, + threadblock_id_y + ); +} + +void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { + // @perf: All threads are running these compute whose result is mostly same + // across the threadblock + const int NT = 32; // vx_num_threads(); + const int NW = 4; // vx_num_warps(); + const uint32_t threads_per_threadblock = NT * NW; + + // matches 4 warp capacity + const uint32_t threadblock_dim_x = 2 * BN; + const uint32_t threadblock_dim_y = 2 * BM; + const int threadblock_id = task_id / threads_per_threadblock; + const int tid_in_threadblock = task_id % threads_per_threadblock; + + const uint32_t dim_m = arg->dim_m; + const uint32_t dim_n = arg->dim_n; + const uint32_t dim_n_in_blocks = dim_n / threadblock_dim_x; + const int threadblock_id_x = threadblock_id % dim_n_in_blocks; + const int threadblock_id_y = threadblock_id / dim_n_in_blocks; + + // "static" shared memory allocation. This would determine threadblock + // occupancy of a single cluster + // only 1 threadblock running at a time, so this is ok + float *sharedmem_per_threadblock = + (float *)DEV_SMEM_START_ADDR; // + (2 * BM * BK) + (2 * BN * BK) * threadblock_id; + + thread_block_gemm(arg, tid_in_threadblock, threadblock_dim_x, + threadblock_dim_y, threadblock_id_x, threadblock_id_y, + threadblock_id, sharedmem_per_threadblock); +} + +int main() { + kernel_arg_t *arg = (kernel_arg_t *)KERNEL_ARG_DEV_MEM_ADDR; + int NT = vx_num_threads(); + + // TODO: add support for edge-case (m, n not divisible by 16) + const uint32_t grid_size = arg->dim_m * arg->dim_n * NT / (BM * BN); + + // for now, simplifying assumption of just 1 core + // vx_spawn_tasks_contiguous first runs warps 1 through NW, then NW+1 through 2*NW, etc. + // we can thus treat 1 through NW as a single threadblock for the purposes of the optimization. + vx_spawn_tasks_contiguous(grid_size, (vx_spawn_tasks_cb)kernel_body, arg); + return 0; +} \ No newline at end of file diff --git a/tests/regression/sgemm_tcore/main.cpp b/tests/regression/sgemm_tcore/main.cpp new file mode 100644 index 00000000..5ae65809 --- /dev/null +++ b/tests/regression/sgemm_tcore/main.cpp @@ -0,0 +1,270 @@ +#include +#include +#include +#include +#include +#include +#include "common.h" + +#define RT_CHECK(_expr) \ + do { \ + int _ret = _expr; \ + if (0 == _ret) \ + break; \ + printf("Error: '%s' returned %d!\n", #_expr, (int)_ret); \ + cleanup(); \ + exit(-1); \ + } while (false) + +/////////////////////////////////////////////////////////////////////////////// + +const char* kernel_file = "kernel.bin"; +uint32_t count = 0; + +std::vector src_a_data; +std::vector src_b_data; +std::vector ref_data; + +vx_device_h device = nullptr; +std::vector staging_buf; +kernel_arg_t kernel_arg = {}; + +static void show_usage() { + std::cout << "Vortex Test." << std::endl; + std::cout << "Usage: [-k: kernel] [-n words] [-h: help]" << std::endl; +} + +static void parse_args(int argc, char **argv) { + int c; + while ((c = getopt(argc, argv, "n:k:h?")) != -1) { + switch (c) { + case 'n': + count = atoi(optarg); + break; + case 'k': + kernel_file = optarg; + break; + case 'h': + case '?': { + show_usage(); + exit(0); + } break; + default: + show_usage(); + exit(-1); + } + } +} + +void cleanup() { + if (device) { + vx_mem_free(device, kernel_arg.addr_a); + vx_mem_free(device, kernel_arg.addr_b); + vx_mem_free(device, kernel_arg.addr_c); + vx_dev_close(device); + } +} + +void generate_source_matrix(uint32_t dim_m, uint32_t dim_n, uint32_t dim_k) { + src_a_data.resize(dim_m * dim_k); + src_b_data.resize(dim_k * dim_n); + + for (uint32_t i = 0; i < src_a_data.size(); ++i) { + src_a_data[i] = static_cast(i); + std::cout << "A: " << i << ": value=" << src_a_data[i] << std::endl; + } + for (uint32_t i = 0; i < src_b_data.size(); ++i) { + src_b_data[i] = static_cast(i); + std::cout << "B: " << i << ": value=" << src_b_data[i] << std::endl; + } +} + +void generate_reference_matmul(uint32_t dim_m, uint32_t dim_n, uint32_t dim_k) { + ref_data.resize(dim_m * dim_n); + + for (uint32_t i = 0; i < dim_m; ++i) { + for (uint32_t j = 0; j < dim_n; ++j) { + float ref = 0.0f; + for (uint32_t k = 0; k < dim_k; ++k) { + ref += src_a_data[dim_k * i + k] * src_b_data[dim_n * k + j]; + } + ref_data.at(dim_n * i + j) = ref; + } + } +} + +int run_test(const kernel_arg_t& kernel_arg, + uint32_t buf_size, + uint32_t dim_m, uint32_t dim_n) { + // start device + std::cout << "start device" << std::endl; + RT_CHECK(vx_start(device)); + + // wait for completion + std::cout << "wait for completion" << std::endl; + RT_CHECK(vx_ready_wait(device, VX_MAX_TIMEOUT)); + + // download destination buffer + std::cout << "download destination buffer" << std::endl; + RT_CHECK(vx_copy_from_dev(device, staging_buf.data(), kernel_arg.addr_c, buf_size)); + + // verify result + std::cout << "verify result" << std::endl; + { + int errors = 0; + auto buf_ptr = (float*)staging_buf.data(); + for (uint32_t i = 0; i < dim_m * dim_n; ++i) { + float ref = ref_data.at(i); + float cur = buf_ptr[i]; + if (std::abs((cur - ref) / ref) > 1e-6) { + std::cout << "error at result #" << std::dec << i + << std::hex << ": actual=" << cur << ", expected=" << ref << std::endl; + ++errors; + } + } + if (errors != 0) { + std::cout << "Found " << std::dec << errors << " errors!" << std::endl; + std::cout << "FAILED!" << std::endl; + return 1; + } + } + + return 0; +} + +int main(int argc, char *argv[]) { + // parse command arguments + parse_args(argc, argv); + + if (count == 0) { + count = 1; + } + + std::srand(50); + + // open device connection + std::cout << "open device connection" << std::endl; + RT_CHECK(vx_dev_open(&device)); + + // FIXME: hardcoded + uint32_t dim_m = 64; + uint32_t dim_n = 64; + uint32_t dim_k = 64; + + generate_source_matrix(dim_m, dim_n, dim_k); + generate_reference_matmul(dim_m, dim_n, dim_k); + + uint32_t src_a_buf_size = src_a_data.size() * sizeof(src_a_data[0]); + uint32_t src_b_buf_size = src_b_data.size() * sizeof(src_b_data[0]); + uint32_t dst_buf_size = ref_data.size() * sizeof(src_a_data[0]); + + std::cout << "buffer size: " << dst_buf_size << " bytes" << std::endl; + + // upload program + std::cout << "upload program" << std::endl; + RT_CHECK(vx_upload_kernel_file(device, kernel_file)); + + // allocate device memory + std::cout << "allocate device memory" << std::endl; + RT_CHECK(vx_mem_alloc(device, src_a_buf_size, VX_MEM_TYPE_GLOBAL, &kernel_arg.addr_a)); + RT_CHECK(vx_mem_alloc(device, src_b_buf_size, VX_MEM_TYPE_GLOBAL, &kernel_arg.addr_b)); + RT_CHECK(vx_mem_alloc(device, dst_buf_size, VX_MEM_TYPE_GLOBAL, &kernel_arg.addr_c)); + + kernel_arg.dim_m = dim_m; + kernel_arg.dim_n = dim_n; + kernel_arg.dim_k = dim_k; + + std::cout << "dev_addr_a=0x" << std::hex << kernel_arg.addr_a << std::endl; + std::cout << "dev_addr_b=0x" << std::hex << kernel_arg.addr_b << std::endl; + std::cout << "dev_addr_c=0x" << std::hex << kernel_arg.addr_c << std::endl; + + // allocate staging buffer + { + std::cout << "allocate staging buffer" << std::endl; + uint32_t staging_buf_size = std::max( + src_a_buf_size, + std::max( + src_b_buf_size, + std::max(dst_buf_size, sizeof(kernel_arg_t)))); + staging_buf.resize(staging_buf_size); + } + + // upload kernel argument + { + std::cout << "upload kernel argument" << std::endl; + auto buf_ptr = staging_buf.data(); + memcpy(buf_ptr, &kernel_arg, sizeof(kernel_arg_t)); + RT_CHECK(vx_copy_to_dev(device, KERNEL_ARG_DEV_MEM_ADDR, staging_buf.data(), sizeof(kernel_arg_t))); + + std::cout << "uploading argument buffer to device, device mem address=" + << std::hex << KERNEL_ARG_DEV_MEM_ADDR << ", size=" << std::dec + << sizeof(kernel_arg_t) << " bytes\n"; + std::ofstream file("args.bin", std::ios::binary | std::ios::out); + if (!file) { + std::cerr << "error: failed to open args.bin for writing\n"; + exit(EXIT_FAILURE); + } + file.write(reinterpret_cast(staging_buf.data()), + sizeof(kernel_arg_t)); + file.close(); + } + + // upload source buffer + { + { + auto buf_ptr = staging_buf.data(); + memcpy(buf_ptr, src_a_data.data(), src_a_data.size() * sizeof(float)); + RT_CHECK(vx_copy_to_dev(device, kernel_arg.addr_a, staging_buf.data(), + src_a_buf_size)); + + std::cout << "uploading source A matrix to device, device mem address=" + << std::hex << kernel_arg.addr_a << ", size=" << std::dec + << src_a_buf_size << " bytes\n"; + std::ofstream file("input.a.bin", std::ios::binary | std::ios::out); + if (!file) { + std::cerr << "error: failed to open args.bin for writing\n"; + exit(EXIT_FAILURE); + } + file.write(reinterpret_cast(buf_ptr), src_a_buf_size); + file.close(); + } + { + auto buf_ptr = staging_buf.data(); + memcpy(buf_ptr, src_b_data.data(), src_b_data.size() * sizeof(float)); + RT_CHECK(vx_copy_to_dev(device, kernel_arg.addr_b, staging_buf.data(), + src_b_buf_size)); + + std::cout << "uploading source B matrix to device, device mem address=" + << std::hex << kernel_arg.addr_b << ", size=" << std::dec + << src_b_buf_size << " bytes\n"; + std::ofstream file("input.b.bin", std::ios::binary | std::ios::out); + if (!file) { + std::cerr << "error: failed to open args.bin for writing\n"; + exit(EXIT_FAILURE); + } + file.write(reinterpret_cast(buf_ptr), src_b_buf_size); + file.close(); + } + } + + // clear destination buffer + { + std::cout << "clear destination buffer" << std::endl; + auto buf_ptr = (int32_t*)staging_buf.data(); + for (uint32_t i = 0; i < ref_data.size(); ++i) { + buf_ptr[i] = 0xdeadbeef; + } + RT_CHECK(vx_copy_to_dev(device, kernel_arg.addr_c, staging_buf.data(), dst_buf_size)); + } + + // run tests + std::cout << "run tests" << std::endl; + RT_CHECK(run_test(kernel_arg, dst_buf_size, kernel_arg.dim_m, kernel_arg.dim_n)); + std::cout << "PASSED!" << std::endl; + + // cleanup + std::cout << "cleanup" << std::endl; + cleanup(); + + return 0; +} \ No newline at end of file From b4c812f9f81dca4cc96cc9000d46af9c6ca09fe4 Mon Sep 17 00:00:00 2001 From: Hansung Kim Date: Fri, 3 May 2024 17:27:25 -0700 Subject: [PATCH 09/55] Write expected_C to a binary file --- tests/kernel/tensor/check_correctness.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/tests/kernel/tensor/check_correctness.py b/tests/kernel/tensor/check_correctness.py index c81212d0..de0c976a 100644 --- a/tests/kernel/tensor/check_correctness.py +++ b/tests/kernel/tensor/check_correctness.py @@ -86,10 +86,15 @@ expected_A = expected['A_array'] expected_B = expected['B_array'] expected_C = expected['C_array'] expected_C = expected_C + expected_A @ expected_B +print('expected C:') print(expected_C[0:8, 0:8]) +print('got C:') print(C_array[0:8, 0:8]) +print('diff C:') print((expected_C - C_array)[0:8, 0:8]) +expected_C.astype('float32').tofile("c_expected.bin") + assert np.allclose(expected_A, A_array) assert np.allclose(expected_B, B_array) -assert np.allclose(expected_C, C_array) \ No newline at end of file +assert np.allclose(expected_C, C_array) From 77758308146f2c2086bcb3dc2268068478fa7a8c Mon Sep 17 00:00:00 2001 From: Hansung Kim Date: Tue, 7 May 2024 16:30:30 -0700 Subject: [PATCH 10/55] Hardcode chipyard device addresses --- tests/regression/sgemm_tcore/kernel.cpp | 2 +- tests/regression/sgemm_tcore/main.cpp | 17 ++++++++++------- 2 files changed, 11 insertions(+), 8 deletions(-) diff --git a/tests/regression/sgemm_tcore/kernel.cpp b/tests/regression/sgemm_tcore/kernel.cpp index f4e467f4..11a795df 100644 --- a/tests/regression/sgemm_tcore/kernel.cpp +++ b/tests/regression/sgemm_tcore/kernel.cpp @@ -282,4 +282,4 @@ int main() { // we can thus treat 1 through NW as a single threadblock for the purposes of the optimization. vx_spawn_tasks_contiguous(grid_size, (vx_spawn_tasks_cb)kernel_body, arg); return 0; -} \ No newline at end of file +} diff --git a/tests/regression/sgemm_tcore/main.cpp b/tests/regression/sgemm_tcore/main.cpp index 5ae65809..eb4f55df 100644 --- a/tests/regression/sgemm_tcore/main.cpp +++ b/tests/regression/sgemm_tcore/main.cpp @@ -58,9 +58,9 @@ static void parse_args(int argc, char **argv) { void cleanup() { if (device) { - vx_mem_free(device, kernel_arg.addr_a); - vx_mem_free(device, kernel_arg.addr_b); - vx_mem_free(device, kernel_arg.addr_c); + // vx_mem_free(device, kernel_arg.addr_a); + // vx_mem_free(device, kernel_arg.addr_b); + // vx_mem_free(device, kernel_arg.addr_c); vx_dev_close(device); } } @@ -166,9 +166,12 @@ int main(int argc, char *argv[]) { // allocate device memory std::cout << "allocate device memory" << std::endl; - RT_CHECK(vx_mem_alloc(device, src_a_buf_size, VX_MEM_TYPE_GLOBAL, &kernel_arg.addr_a)); - RT_CHECK(vx_mem_alloc(device, src_b_buf_size, VX_MEM_TYPE_GLOBAL, &kernel_arg.addr_b)); - RT_CHECK(vx_mem_alloc(device, dst_buf_size, VX_MEM_TYPE_GLOBAL, &kernel_arg.addr_c)); + // RT_CHECK(vx_mem_alloc(device, src_a_buf_size, VX_MEM_TYPE_GLOBAL, &kernel_arg.addr_a)); + // RT_CHECK(vx_mem_alloc(device, src_b_buf_size, VX_MEM_TYPE_GLOBAL, &kernel_arg.addr_b)); + // RT_CHECK(vx_mem_alloc(device, dst_buf_size, VX_MEM_TYPE_GLOBAL, &kernel_arg.addr_c)); + kernel_arg.addr_a = 0x20000; + kernel_arg.addr_b = 0x28000; + kernel_arg.addr_c = 0xc0000000; kernel_arg.dim_m = dim_m; kernel_arg.dim_n = dim_n; @@ -267,4 +270,4 @@ int main(int argc, char *argv[]) { cleanup(); return 0; -} \ No newline at end of file +} From 5821bfd10d88566e578a13dc58ccf3a04f0b4d9a Mon Sep 17 00:00:00 2001 From: Hansung Kim Date: Wed, 8 May 2024 13:22:26 -0700 Subject: [PATCH 11/55] Repeat vx_wmma issue & hardcode dst address --- tests/kernel/reductions/main.cpp | 3 ++- tests/kernel/tensor/main.cpp | 8 ++++++-- 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/tests/kernel/reductions/main.cpp b/tests/kernel/reductions/main.cpp index edde1da4..fcadddb6 100644 --- a/tests/kernel/reductions/main.cpp +++ b/tests/kernel/reductions/main.cpp @@ -138,6 +138,7 @@ void test_maxu_reduce() { y = reduced; } +// assumes NUM_THREADS == 4 unsigned bit_vectors[4] = {0b11010110000111001100010100100110, 0b10010100011010001010000000001110, 0b10001001010111110001110000000010, 0b00010011010100101101110111001111}; void test_and_reduce() { @@ -213,4 +214,4 @@ int main() return 0; -} \ No newline at end of file +} diff --git a/tests/kernel/tensor/main.cpp b/tests/kernel/tensor/main.cpp index 5fc222b2..0fc4274d 100644 --- a/tests/kernel/tensor/main.cpp +++ b/tests/kernel/tensor/main.cpp @@ -65,6 +65,7 @@ float results[32*8]; void store_wmma_result() { int tid = vx_thread_id(); + float *results = reinterpret_cast(0xc0000000UL); asm volatile ("fsw f16, %0" :: "m"(results[tid*8+0])); asm volatile ("fsw f17, %0" :: "m"(results[tid*8+1])); asm volatile ("fsw f18, %0" :: "m"(results[tid*8+2])); @@ -87,10 +88,13 @@ int main() { vx_tmc(-1); vx_wmma_load(); - vx_wmma(); +#pragma GCC unroll 100 + for (int i = 0; i < 100; i++) { + vx_wmma(); + } store_wmma_result(); vx_tmc(1); // print_wmma_result(); return 0; -} \ No newline at end of file +} From 6af0c305ea8cc412ead8da695b2fbe52f560c677 Mon Sep 17 00:00:00 2001 From: Hansung Kim Date: Wed, 8 May 2024 13:27:11 -0700 Subject: [PATCH 12/55] Fix path to OBJCOPY --- tests/regression/common.mk | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/regression/common.mk b/tests/regression/common.mk index 24a871eb..f9d71634 100644 --- a/tests/regression/common.mk +++ b/tests/regression/common.mk @@ -101,7 +101,7 @@ kernel.bin: kernel.elf kernel.radiance.elf kernel.elf: $(VX_SRCS) $(VX_CXX) $(VX_CFLAGS) $(VX_SRCS) $(VX_LDFLAGS) -o kernel.elf -OBJCOPY ?= "riscv32-unknown-elf-objcopy" +OBJCOPY ?= $(RISCV_TOOLCHAIN_PATH)/bin/$(RISCV_PREFIX)-objcopy OBJCOPY_FLAGS ?= "LOAD,ALLOC,DATA,CONTENTS" kernel.radiance.elf: kernel.elf $(VX_CXX) $(VX_CFLAGS) $(VX_SRCS) $(VX_LDFLAGS) -DRADIANCE -o kernel.radiance.elf From 8a521a1de89b1d93cdb9c66d8b8fbc551e0db7df Mon Sep 17 00:00:00 2001 From: Hansung Kim Date: Fri, 10 May 2024 23:23:11 -0700 Subject: [PATCH 13/55] Add 8-lane operand mapping --- tests/kernel/tensor/main.cpp | 166 ++++++++++++++++++++++++++--------- 1 file changed, 124 insertions(+), 42 deletions(-) diff --git a/tests/kernel/tensor/main.cpp b/tests/kernel/tensor/main.cpp index 0fc4274d..2fadbbd9 100644 --- a/tests/kernel/tensor/main.cpp +++ b/tests/kernel/tensor/main.cpp @@ -4,35 +4,103 @@ #include #include +constexpr int DIM_M = 16; + inline void vx_wmma() { asm volatile (".insn r %0, 0, 0, x0, x0, x0" :: "i"(RISCV_CUSTOM3)); } #include "test_data.h" +inline constexpr void map_operand_32lanes(const int tid, int &row, int &col) { + const int tg = tid / 4; + + // A (row major) + // Figure 7(a) in paper + // row 0~ 3: threadgroups 0 and 2 + // row 4~ 7: threadgroups 4 and 6 + // row 8~11: threadgroups 1 and 3 + // row 12~15: threadgroups 5 and 7 + row = tid % 4; + row += (tg * 8) % 16; + row += (tg / 4) * 4; + + // B (column major) + // NOTE: Matrix B mapping in Figure 7(a) is incorrect; below is the + // corrected mapping: + // col 0~ 3: threadgroups 0 and 1 + // col 4~ 7: threadgroups 4 and 5 + // col 8~11: threadgroups 2 and 3 + // col 12~15: threadgroups 6 and 7 + col = tid % 4; + col += ((tg % 4) / 2) * 8; + col += (tg / 4) * 4; +} + +inline constexpr void map_operand_8lanes(const int tid, int &row, int &col) { + const int tg = tid / 4; + + // A (row major) + // row 0~ 3: threadgroup 0 + // row 4~ 7: threadgroup 1 + row = tid % 4; + row += tg * 4; + + // B (column major) + // col 0~ 3: threadgroup 0 + // col 4~ 7: threadgroup 1 + col = tid % 4; + col += tg * 4; +} + +inline constexpr void map_c_32lanes(const int tid, int &row, int &col) { + const int tg = tid / 4; + + // C + // Figure 7(b), left + col = ((tg % 4) / 2) * 8; + row = (tg * 8) % 16; + row += (tg / 4) * 4; + + // Figure 7(b), right + row += (tid % 4) % 2; + col += ((tid % 4) / 2) * 2; +} + +inline constexpr void map_c_8lanes(const int tid, int &row, int &col) { + const int tg = tid / 4; + + // C + col = 0; + row = tg * 4; + + // Figure 7(b), right + row += (tid % 4) % 2; + col += ((tid % 4) / 2) * 2; +} + void vx_wmma_load() { int tid = vx_thread_id(); int tg = tid / 4; - // load A - int row = tid % 4; - row += (tg * 8) % 16; - row += (tg / 4) * 4; + int row = 0; + int col = 0; + map_operand_32lanes(tid, row, col); + + // load A + // each operand element is read twice by two threadgroups (Sec. III-B); + // i.e. 8 regs * 32 lanes = 256 fp32 elements = 2 * (16 * 8) elements asm volatile ("flw f0, %0" :: "m"(A[row][0])); asm volatile ("flw f1, %0" :: "m"(A[row][1])); - asm volatile ("flw f2, %0" :: "m"(A[row][2])); - asm volatile ("flw f3, %0" :: "m"(A[row][3])); - asm volatile ("flw f4, %0" :: "m"(A[row][4])); - asm volatile ("flw f5, %0" :: "m"(A[row][5])); - asm volatile ("flw f6, %0" :: "m"(A[row][6])); - asm volatile ("flw f7, %0" :: "m"(A[row][7])); - - // load B - int col = tid % 4; - col += ((tg % 4) / 2) * 8; - col += (tg / 4) * 4; + asm volatile ("flw f2, %0" :: "m"(A[row][2])); + asm volatile ("flw f3, %0" :: "m"(A[row][3])); + asm volatile ("flw f4, %0" :: "m"(A[row][4])); + asm volatile ("flw f5, %0" :: "m"(A[row][5])); + asm volatile ("flw f6, %0" :: "m"(A[row][6])); + asm volatile ("flw f7, %0" :: "m"(A[row][7])); + // load B asm volatile ("flw f8 , %0" :: "m"(B[0][col])); asm volatile ("flw f9 , %0" :: "m"(B[1][col])); asm volatile ("flw f10, %0" :: "m"(B[2][col])); @@ -42,14 +110,9 @@ void vx_wmma_load() { asm volatile ("flw f14, %0" :: "m"(B[6][col])); asm volatile ("flw f15, %0" :: "m"(B[7][col])); - // load C - col = ((tg % 4) / 2) * 8; - row = (tg * 8) % 16; - row += (tg / 4) * 4; - - row += (tid % 4) % 2; - col += ((tid % 4) / 2) * 2; + map_c_32lanes(tid, row, col); + // load C asm volatile ("flw f16, %0" :: "m"(C[row+0][col+0])); asm volatile ("flw f17, %0" :: "m"(C[row+0][col+1])); asm volatile ("flw f18, %0" :: "m"(C[row+2][col+0])); @@ -60,38 +123,57 @@ void vx_wmma_load() { asm volatile ("flw f23, %0" :: "m"(C[row+2][col+5])); } -float results[32*8]; +// float results[32*8]; +float *const results = reinterpret_cast(0xc0000000UL); void store_wmma_result() { int tid = vx_thread_id(); - - float *results = reinterpret_cast(0xc0000000UL); - asm volatile ("fsw f16, %0" :: "m"(results[tid*8+0])); - asm volatile ("fsw f17, %0" :: "m"(results[tid*8+1])); - asm volatile ("fsw f18, %0" :: "m"(results[tid*8+2])); - asm volatile ("fsw f19, %0" :: "m"(results[tid*8+3])); - asm volatile ("fsw f20, %0" :: "m"(results[tid*8+4])); - asm volatile ("fsw f21, %0" :: "m"(results[tid*8+5])); - asm volatile ("fsw f22, %0" :: "m"(results[tid*8+6])); - asm volatile ("fsw f23, %0" :: "m"(results[tid*8+7])); + int tg = tid / 4; + + int row = 0; + int col = 0; + + map_c_32lanes(tid, row, col); + + // store C + // asm volatile ("fsw f16, %0" :: "m"(results[tid*8+0])); + // asm volatile ("fsw f17, %0" :: "m"(results[tid*8+1])); + // asm volatile ("fsw f18, %0" :: "m"(results[tid*8+2])); + // asm volatile ("fsw f19, %0" :: "m"(results[tid*8+3])); + // asm volatile ("fsw f20, %0" :: "m"(results[tid*8+4])); + // asm volatile ("fsw f21, %0" :: "m"(results[tid*8+5])); + // asm volatile ("fsw f22, %0" :: "m"(results[tid*8+6])); + // asm volatile ("fsw f23, %0" :: "m"(results[tid*8+7])); + + asm volatile ("fsw f16, %0" :: "m"(results[DIM_M * (row + 0) + (col + 0)])); + asm volatile ("fsw f17, %0" :: "m"(results[DIM_M * (row + 0) + (col + 1)])); + asm volatile ("fsw f18, %0" :: "m"(results[DIM_M * (row + 2) + (col + 0)])); + asm volatile ("fsw f19, %0" :: "m"(results[DIM_M * (row + 2) + (col + 1)])); + asm volatile ("fsw f20, %0" :: "m"(results[DIM_M * (row + 0) + (col + 4)])); + asm volatile ("fsw f21, %0" :: "m"(results[DIM_M * (row + 0) + (col + 5)])); + asm volatile ("fsw f22, %0" :: "m"(results[DIM_M * (row + 2) + (col + 4)])); + asm volatile ("fsw f23, %0" :: "m"(results[DIM_M * (row + 2) + (col + 5)])); } void print_wmma_result() { - for (int tid = 0; tid < 32; tid += 1) { - for (int reg = 0; reg < 8; reg += 1) { - vx_printf("thread %d, f%d: %x\n", tid, 16+reg, *((int*) &results[tid*8+reg])); - } - } + const int num_threads = vx_num_threads(); + + for (int tid = 0; tid < num_threads; tid += 1) { + for (int reg = 0; reg < 8; reg += 1) { + vx_printf("thread %d, f%d: %x\n", tid, 16+reg, *((int*) &results[tid*8+reg])); + } + } } int main() { vx_tmc(-1); vx_wmma_load(); -#pragma GCC unroll 100 - for (int i = 0; i < 100; i++) { - vx_wmma(); - } +// #pragma GCC unroll 100 +// for (int i = 0; i < 100; i++) { +// vx_wmma(); +// } + vx_wmma(); store_wmma_result(); vx_tmc(1); // print_wmma_result(); From 5c298c81df683ad891a09433ab2a3b4cc89be448 Mon Sep 17 00:00:00 2001 From: Hansung Kim Date: Sun, 12 May 2024 22:22:54 -0700 Subject: [PATCH 14/55] sgemm_tg: Use reg mapping functions --- tests/regression/sgemm_tcore/.gitignore | 1 + tests/regression/sgemm_tcore/kernel.cpp | 222 ++++++++++++++---------- 2 files changed, 136 insertions(+), 87 deletions(-) create mode 100644 tests/regression/sgemm_tcore/.gitignore diff --git a/tests/regression/sgemm_tcore/.gitignore b/tests/regression/sgemm_tcore/.gitignore new file mode 100644 index 00000000..6ef379cc --- /dev/null +++ b/tests/regression/sgemm_tcore/.gitignore @@ -0,0 +1 @@ +sgemm_tcore diff --git a/tests/regression/sgemm_tcore/kernel.cpp b/tests/regression/sgemm_tcore/kernel.cpp index 11a795df..f498f57b 100644 --- a/tests/regression/sgemm_tcore/kernel.cpp +++ b/tests/regression/sgemm_tcore/kernel.cpp @@ -10,18 +10,85 @@ #define BN 16 #define BK 8 -inline void vx_wmma() { - asm volatile (".insn r %0, 0, 0, x0, x0, x0" :: "i"(RISCV_CUSTOM3)); +inline constexpr void map_operand_32lanes(const int tid, int &row, int &col) { + const int tg = tid / 4; + + // A (row major) + // Figure 7(a) in paper + // row 0~ 3: threadgroups 0 and 2 + // row 4~ 7: threadgroups 4 and 6 + // row 8~11: threadgroups 1 and 3 + // row 12~15: threadgroups 5 and 7 + row = tid % 4; + row += (tg * 8) % 16; + row += (tg / 4) * 4; + + // B (column major) + // NOTE: Matrix B mapping in Figure 7(a) is incorrect; below is the + // corrected mapping: + // col 0~ 3: threadgroups 0 and 1 + // col 4~ 7: threadgroups 4 and 5 + // col 8~11: threadgroups 2 and 3 + // col 12~15: threadgroups 6 and 7 + col = tid % 4; + col += ((tg % 4) / 2) * 8; + col += (tg / 4) * 4; } -void vx_wmma_load(volatile float *smem_A, volatile float *smem_B, int warp_x, int warp_y, int thread_in_warp) { - int tid = thread_in_warp; - int tg = tid / 4; +inline constexpr void map_operand_8lanes(const int tid, int &row, int &col) { + const int tg = tid / 4; - // load A - int row = tid % 4; - row += (tg * 8) % 16; - row += (tg / 4) * 4; + // A (row major) + // row 0~ 3: threadgroup 0 + // row 4~ 7: threadgroup 1 + row = tid % 4; + row += tg * 4; + + // B (column major) + // col 0~ 3: threadgroup 0 + // col 4~ 7: threadgroup 1 + col = tid % 4; + col += tg * 4; +} + +inline constexpr void map_c_32lanes(const int tid, int &row, int &col) { + const int tg = tid / 4; + + // C + // Figure 7(b), left + col = ((tg % 4) / 2) * 8; + row = (tg * 8) % 16; + row += (tg / 4) * 4; + + // Figure 7(b), right + row += (tid % 4) % 2; + col += ((tid % 4) / 2) * 2; +} + +inline constexpr void map_c_8lanes(const int tid, int &row, int &col) { + const int tg = tid / 4; + + // C + col = 0; + row = tg * 4; + + // Figure 7(b), right + row += (tid % 4) % 2; + col += ((tid % 4) / 2) * 2; +} + +inline void vx_wmma() { + asm volatile (".insn r %0, 0, 0, x0, x0, x0" :: "i"(RISCV_CUSTOM3)); +} + +void vx_wmma_load(volatile float *smem_A, volatile float *smem_B, int warp_x, + int warp_y, int thread_in_warp) { + int tid = thread_in_warp; + int tg = tid / 4; + + int row = 0; + int col = 0; + map_operand_32lanes(tid, row, col); int smem_A_m = 32; int smem_A_n = 8; @@ -30,101 +97,83 @@ void vx_wmma_load(volatile float *smem_A, volatile float *smem_B, int warp_x, in int A_offset = (row + BM * warp_y) * smem_A_n; - asm volatile ("flw f0, %0" :: "m"(smem_A[A_offset + 0])); - asm volatile ("flw f1, %0" :: "m"(smem_A[A_offset + 1])); - asm volatile ("flw f2, %0" :: "m"(smem_A[A_offset + 2])); - asm volatile ("flw f3, %0" :: "m"(smem_A[A_offset + 3])); - asm volatile ("flw f4, %0" :: "m"(smem_A[A_offset + 4])); - asm volatile ("flw f5, %0" :: "m"(smem_A[A_offset + 5])); - asm volatile ("flw f6, %0" :: "m"(smem_A[A_offset + 6])); - asm volatile ("flw f7, %0" :: "m"(smem_A[A_offset + 7])); + asm volatile("flw f0, %0" ::"m"(smem_A[A_offset + 0])); + asm volatile("flw f1, %0" ::"m"(smem_A[A_offset + 1])); + asm volatile("flw f2, %0" ::"m"(smem_A[A_offset + 2])); + asm volatile("flw f3, %0" ::"m"(smem_A[A_offset + 3])); + asm volatile("flw f4, %0" ::"m"(smem_A[A_offset + 4])); + asm volatile("flw f5, %0" ::"m"(smem_A[A_offset + 5])); + asm volatile("flw f6, %0" ::"m"(smem_A[A_offset + 6])); + asm volatile("flw f7, %0" ::"m"(smem_A[A_offset + 7])); - // load B - int col = tid % 4; - col += ((tg % 4) / 2) * 8; - col += (tg / 4) * 4; - - asm volatile ("flw f8 , %0" :: "m"(smem_B[(0 * smem_B_n) + warp_x * BN + col])); - asm volatile ("flw f9 , %0" :: "m"(smem_B[(1 * smem_B_n) + warp_x * BN + col])); - asm volatile ("flw f10, %0" :: "m"(smem_B[(2 * smem_B_n) + warp_x * BN + col])); - asm volatile ("flw f11, %0" :: "m"(smem_B[(3 * smem_B_n) + warp_x * BN + col])); - asm volatile ("flw f12, %0" :: "m"(smem_B[(4 * smem_B_n) + warp_x * BN + col])); - asm volatile ("flw f13, %0" :: "m"(smem_B[(5 * smem_B_n) + warp_x * BN + col])); - asm volatile ("flw f14, %0" :: "m"(smem_B[(6 * smem_B_n) + warp_x * BN + col])); - asm volatile ("flw f15, %0" :: "m"(smem_B[(7 * smem_B_n) + warp_x * BN + col])); + asm volatile("flw f8 , %0" ::"m"(smem_B[(0 * smem_B_n) + warp_x * BN + col])); + asm volatile("flw f9 , %0" ::"m"(smem_B[(1 * smem_B_n) + warp_x * BN + col])); + asm volatile("flw f10, %0" ::"m"(smem_B[(2 * smem_B_n) + warp_x * BN + col])); + asm volatile("flw f11, %0" ::"m"(smem_B[(3 * smem_B_n) + warp_x * BN + col])); + asm volatile("flw f12, %0" ::"m"(smem_B[(4 * smem_B_n) + warp_x * BN + col])); + asm volatile("flw f13, %0" ::"m"(smem_B[(5 * smem_B_n) + warp_x * BN + col])); + asm volatile("flw f14, %0" ::"m"(smem_B[(6 * smem_B_n) + warp_x * BN + col])); + asm volatile("flw f15, %0" ::"m"(smem_B[(7 * smem_B_n) + warp_x * BN + col])); } inline void initialize_C() { // initialize C to zeros - asm volatile ("fmv.w.x f16, x0"); - asm volatile ("fmv.w.x f17, x0"); - asm volatile ("fmv.w.x f18, x0"); - asm volatile ("fmv.w.x f19, x0"); - asm volatile ("fmv.w.x f20, x0"); - asm volatile ("fmv.w.x f21, x0"); - asm volatile ("fmv.w.x f22, x0"); - asm volatile ("fmv.w.x f23, x0"); + asm volatile("fmv.w.x f16, x0"); + asm volatile("fmv.w.x f17, x0"); + asm volatile("fmv.w.x f18, x0"); + asm volatile("fmv.w.x f19, x0"); + asm volatile("fmv.w.x f20, x0"); + asm volatile("fmv.w.x f21, x0"); + asm volatile("fmv.w.x f22, x0"); + asm volatile("fmv.w.x f23, x0"); } -inline void write_results( - volatile float *local_warp_results, - int thread_in_warp, - int warp_x, - int warp_y, - int dim_m, - int dim_n, - float *C, - int threadblock_id_x, - int threadblock_id_y -) { +inline void write_results(volatile float *local_warp_results, + int thread_in_warp, int warp_x, int warp_y, int dim_m, + int dim_n, float *C, int threadblock_id_x, + int threadblock_id_y) { int tid = thread_in_warp; - int tg = tid / 4; + int tg = tid / 4; - asm volatile ("fsw f16, %0" :: "m"(local_warp_results[tid*8+0])); - asm volatile ("fsw f17, %0" :: "m"(local_warp_results[tid*8+1])); - asm volatile ("fsw f18, %0" :: "m"(local_warp_results[tid*8+2])); - asm volatile ("fsw f19, %0" :: "m"(local_warp_results[tid*8+3])); - asm volatile ("fsw f20, %0" :: "m"(local_warp_results[tid*8+4])); - asm volatile ("fsw f21, %0" :: "m"(local_warp_results[tid*8+5])); - asm volatile ("fsw f22, %0" :: "m"(local_warp_results[tid*8+6])); - asm volatile ("fsw f23, %0" :: "m"(local_warp_results[tid*8+7])); + asm volatile("fsw f16, %0" ::"m"(local_warp_results[tid * 8 + 0])); + asm volatile("fsw f17, %0" ::"m"(local_warp_results[tid * 8 + 1])); + asm volatile("fsw f18, %0" ::"m"(local_warp_results[tid * 8 + 2])); + asm volatile("fsw f19, %0" ::"m"(local_warp_results[tid * 8 + 3])); + asm volatile("fsw f20, %0" ::"m"(local_warp_results[tid * 8 + 4])); + asm volatile("fsw f21, %0" ::"m"(local_warp_results[tid * 8 + 5])); + asm volatile("fsw f22, %0" ::"m"(local_warp_results[tid * 8 + 6])); + asm volatile("fsw f23, %0" ::"m"(local_warp_results[tid * 8 + 7])); /* - col = ((threadgroup % 4) // 2) * 8 - row = (threadgroup * 8) % 16 - row += (threadgroup // 4) * 4 - offsets = [(0, 0), (0, 1), (2, 0), (2, 1), (0, 4), (0, 5), (2, 4), (2, 5)] - offset = offsets[register-16] - row += offset[0] - col += offset[1] - thread_offsets = [(0, 0), (1, 0), (0, 2), (1, 2)] - thread_offset = thread_offsets[thread % 4] - row += thread_offset[0] - col += thread_offset[1] - return (row, col) - */ + col = ((threadgroup % 4) // 2) * 8 + row = (threadgroup * 8) % 16 + row += (threadgroup // 4) * 4 + offsets = [(0, 0), (0, 1), (2, 0), (2, 1), (0, 4), (0, 5), (2, 4), (2, 5)] + offset = offsets[register-16] + row += offset[0] + col += offset[1] + thread_offsets = [(0, 0), (1, 0), (0, 2), (1, 2)] + thread_offset = thread_offsets[thread % 4] + row += thread_offset[0] + col += thread_offset[1] + return (row, col) + */ - int local_col = ((tg % 4) / 2) * 8; - int local_row = (tg * 8) % 16; - local_row += (tg / 4) * 4; + int local_row = 0; + int local_col = 0; + map_c_32lanes(tid, local_row, local_col); - // int row_offsets[8] = {0, 0, 2, 2, 0, 0, 2, 2}; - // int col_offsets[8] = {0, 1, 0, 1, 4, 5, 4, 5}; - - // int thread_row_offsets[4] = {0, 1, 0, 1}; - // int thread_col_offsets[4] = {0, 0, 2, 2}; - int thread_row_offset = (tid % 4) % 2; - int thread_col_offset = ((tid % 4) / 2) * 2; - - float *global_offset_C = C + (threadblock_id_y * BM * 2 + warp_y * BM) * dim_n + threadblock_id_x * BN * 2 + warp_x * BM; + float *global_offset_C = C + + (threadblock_id_y * BM * 2 + warp_y * BM) * dim_n + + threadblock_id_x * BN * 2 + warp_x * BM; for (int i = 0; i < 8; i += 1) { int row_offset = ((i / 2) % 2) * 2; int col_offset = (i / 4) * 4 + i % 2; - int adjusted_local_row = local_row + thread_row_offset + row_offset; - int adjusted_local_col = local_col + thread_col_offset + col_offset; + int adjusted_local_row = local_row + row_offset; + int adjusted_local_col = local_col + col_offset; - float v = local_warp_results[tid*8+i]; + float v = local_warp_results[tid * 8 + i]; global_offset_C[adjusted_local_row * dim_n + adjusted_local_col] = v; } } @@ -174,7 +223,6 @@ void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg, const uint32_t global_b_col = threadblock_dim_x * threadblock_id_x; const uint32_t local_b_row = warp_in_threadblock; const uint32_t local_b_col = tid_in_warp; - volatile float *local_a = sharedmem_per_threadblock; const size_t local_a_elems = (threadblock_dim_y * BK); From 9e60b1834c519e7ef2d22077f5b04c01311126f0 Mon Sep 17 00:00:00 2001 From: Hansung Kim Date: Mon, 13 May 2024 13:22:06 -0700 Subject: [PATCH 15/55] sgemm_tcore: Rewrite with sgemm_Wg parametrization --- .../regression/sgemm_tcore/kernel.4warps.cpp | 333 ++++++++++++++++++ tests/regression/sgemm_tcore/kernel.cpp | 271 ++++++++------ 2 files changed, 503 insertions(+), 101 deletions(-) create mode 100644 tests/regression/sgemm_tcore/kernel.4warps.cpp diff --git a/tests/regression/sgemm_tcore/kernel.4warps.cpp b/tests/regression/sgemm_tcore/kernel.4warps.cpp new file mode 100644 index 00000000..f498f57b --- /dev/null +++ b/tests/regression/sgemm_tcore/kernel.4warps.cpp @@ -0,0 +1,333 @@ +#define RISCV_CUSTOM3 0x7B + +#include +#include +#include +#include +#include "common.h" + +#define BM 16 +#define BN 16 +#define BK 8 + +inline constexpr void map_operand_32lanes(const int tid, int &row, int &col) { + const int tg = tid / 4; + + // A (row major) + // Figure 7(a) in paper + // row 0~ 3: threadgroups 0 and 2 + // row 4~ 7: threadgroups 4 and 6 + // row 8~11: threadgroups 1 and 3 + // row 12~15: threadgroups 5 and 7 + row = tid % 4; + row += (tg * 8) % 16; + row += (tg / 4) * 4; + + // B (column major) + // NOTE: Matrix B mapping in Figure 7(a) is incorrect; below is the + // corrected mapping: + // col 0~ 3: threadgroups 0 and 1 + // col 4~ 7: threadgroups 4 and 5 + // col 8~11: threadgroups 2 and 3 + // col 12~15: threadgroups 6 and 7 + col = tid % 4; + col += ((tg % 4) / 2) * 8; + col += (tg / 4) * 4; +} + +inline constexpr void map_operand_8lanes(const int tid, int &row, int &col) { + const int tg = tid / 4; + + // A (row major) + // row 0~ 3: threadgroup 0 + // row 4~ 7: threadgroup 1 + row = tid % 4; + row += tg * 4; + + // B (column major) + // col 0~ 3: threadgroup 0 + // col 4~ 7: threadgroup 1 + col = tid % 4; + col += tg * 4; +} + +inline constexpr void map_c_32lanes(const int tid, int &row, int &col) { + const int tg = tid / 4; + + // C + // Figure 7(b), left + col = ((tg % 4) / 2) * 8; + row = (tg * 8) % 16; + row += (tg / 4) * 4; + + // Figure 7(b), right + row += (tid % 4) % 2; + col += ((tid % 4) / 2) * 2; +} + +inline constexpr void map_c_8lanes(const int tid, int &row, int &col) { + const int tg = tid / 4; + + // C + col = 0; + row = tg * 4; + + // Figure 7(b), right + row += (tid % 4) % 2; + col += ((tid % 4) / 2) * 2; +} + +inline void vx_wmma() { + asm volatile (".insn r %0, 0, 0, x0, x0, x0" :: "i"(RISCV_CUSTOM3)); +} + +void vx_wmma_load(volatile float *smem_A, volatile float *smem_B, int warp_x, + int warp_y, int thread_in_warp) { + int tid = thread_in_warp; + int tg = tid / 4; + + int row = 0; + int col = 0; + map_operand_32lanes(tid, row, col); + + int smem_A_m = 32; + int smem_A_n = 8; + int smem_B_m = 8; + int smem_B_n = 32; + + int A_offset = (row + BM * warp_y) * smem_A_n; + + asm volatile("flw f0, %0" ::"m"(smem_A[A_offset + 0])); + asm volatile("flw f1, %0" ::"m"(smem_A[A_offset + 1])); + asm volatile("flw f2, %0" ::"m"(smem_A[A_offset + 2])); + asm volatile("flw f3, %0" ::"m"(smem_A[A_offset + 3])); + asm volatile("flw f4, %0" ::"m"(smem_A[A_offset + 4])); + asm volatile("flw f5, %0" ::"m"(smem_A[A_offset + 5])); + asm volatile("flw f6, %0" ::"m"(smem_A[A_offset + 6])); + asm volatile("flw f7, %0" ::"m"(smem_A[A_offset + 7])); + + asm volatile("flw f8 , %0" ::"m"(smem_B[(0 * smem_B_n) + warp_x * BN + col])); + asm volatile("flw f9 , %0" ::"m"(smem_B[(1 * smem_B_n) + warp_x * BN + col])); + asm volatile("flw f10, %0" ::"m"(smem_B[(2 * smem_B_n) + warp_x * BN + col])); + asm volatile("flw f11, %0" ::"m"(smem_B[(3 * smem_B_n) + warp_x * BN + col])); + asm volatile("flw f12, %0" ::"m"(smem_B[(4 * smem_B_n) + warp_x * BN + col])); + asm volatile("flw f13, %0" ::"m"(smem_B[(5 * smem_B_n) + warp_x * BN + col])); + asm volatile("flw f14, %0" ::"m"(smem_B[(6 * smem_B_n) + warp_x * BN + col])); + asm volatile("flw f15, %0" ::"m"(smem_B[(7 * smem_B_n) + warp_x * BN + col])); +} + +inline void initialize_C() { + // initialize C to zeros + asm volatile("fmv.w.x f16, x0"); + asm volatile("fmv.w.x f17, x0"); + asm volatile("fmv.w.x f18, x0"); + asm volatile("fmv.w.x f19, x0"); + asm volatile("fmv.w.x f20, x0"); + asm volatile("fmv.w.x f21, x0"); + asm volatile("fmv.w.x f22, x0"); + asm volatile("fmv.w.x f23, x0"); +} + +inline void write_results(volatile float *local_warp_results, + int thread_in_warp, int warp_x, int warp_y, int dim_m, + int dim_n, float *C, int threadblock_id_x, + int threadblock_id_y) { + int tid = thread_in_warp; + int tg = tid / 4; + + asm volatile("fsw f16, %0" ::"m"(local_warp_results[tid * 8 + 0])); + asm volatile("fsw f17, %0" ::"m"(local_warp_results[tid * 8 + 1])); + asm volatile("fsw f18, %0" ::"m"(local_warp_results[tid * 8 + 2])); + asm volatile("fsw f19, %0" ::"m"(local_warp_results[tid * 8 + 3])); + asm volatile("fsw f20, %0" ::"m"(local_warp_results[tid * 8 + 4])); + asm volatile("fsw f21, %0" ::"m"(local_warp_results[tid * 8 + 5])); + asm volatile("fsw f22, %0" ::"m"(local_warp_results[tid * 8 + 6])); + asm volatile("fsw f23, %0" ::"m"(local_warp_results[tid * 8 + 7])); + + /* + col = ((threadgroup % 4) // 2) * 8 + row = (threadgroup * 8) % 16 + row += (threadgroup // 4) * 4 + offsets = [(0, 0), (0, 1), (2, 0), (2, 1), (0, 4), (0, 5), (2, 4), (2, 5)] + offset = offsets[register-16] + row += offset[0] + col += offset[1] + thread_offsets = [(0, 0), (1, 0), (0, 2), (1, 2)] + thread_offset = thread_offsets[thread % 4] + row += thread_offset[0] + col += thread_offset[1] + return (row, col) + */ + + int local_row = 0; + int local_col = 0; + map_c_32lanes(tid, local_row, local_col); + + float *global_offset_C = C + + (threadblock_id_y * BM * 2 + warp_y * BM) * dim_n + + threadblock_id_x * BN * 2 + warp_x * BM; + for (int i = 0; i < 8; i += 1) { + int row_offset = ((i / 2) % 2) * 2; + int col_offset = (i / 4) * 4 + i % 2; + + int adjusted_local_row = local_row + row_offset; + int adjusted_local_col = local_col + col_offset; + + float v = local_warp_results[tid * 8 + i]; + global_offset_C[adjusted_local_row * dim_n + adjusted_local_col] = v; + } +} + +void threadblock_barrier(unsigned int tid_in_threadblock, unsigned int barrier_id, unsigned int count) { + vx_fence(); + vx_barrier(barrier_id, count); +} + +void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg, + const uint32_t tid_in_threadblock, + const uint32_t threadblock_dim_x, + const uint32_t threadblock_dim_y, + const uint32_t threadblock_id_x, + const uint32_t threadblock_id_y, + const uint32_t threadblock_id, + float *sharedmem_per_threadblock) { + const float *A = (const float *)arg->addr_a; + const float *B = (const float *)arg->addr_b; + float *C = (float *)arg->addr_c; + + const uint32_t dim_m = arg->dim_m; + const uint32_t dim_n = arg->dim_n; + const uint32_t dim_k = arg->dim_k; + + // FIXME: Output block size is assumed to be square, i.e. BM == BN + // const uint32_t BM = threadblock_dim_y; + // const uint32_t BN = threadblock_dim_y; + // const uint32_t BK = threadblock_dim_x; + // constexpr uint32_t BM = 8; + // constexpr uint32_t BN = 8; + // constexpr uint32_t BK = 2; + + const uint32_t warp_in_threadblock = tid_in_threadblock / 32; + const uint32_t tid_in_warp = tid_in_threadblock % 32; + const uint32_t warp_x = warp_in_threadblock % 2; + const uint32_t warp_y = warp_in_threadblock / 2; + + const uint32_t global_a_row = threadblock_dim_y * threadblock_id_y; + + // 32 * 8 block of A, being loaded by 4 warps + const uint32_t local_a_row = warp_y * BM + warp_x * (BM / 2) + (tid_in_warp / BK); + const uint32_t local_a_col = tid_in_warp % BK; + + // 8 * 32 block of B, being loaded by 4 warps + // do a fat coalesced load + const uint32_t global_b_col = threadblock_dim_x * threadblock_id_x; + const uint32_t local_b_row = warp_in_threadblock; + const uint32_t local_b_col = tid_in_warp; + + volatile float *local_a = sharedmem_per_threadblock; + const size_t local_a_elems = (threadblock_dim_y * BK); + volatile float *local_b = sharedmem_per_threadblock + local_a_elems; + const size_t local_b_elems = (threadblock_dim_x * BK); + volatile float *local_warp_results = local_b + local_b_elems + (warp_in_threadblock * BM * BN); + + // clear out C + initialize_C(); + + for (uint32_t k = 0; k < dim_k; k += BK) { + // Data move from GMEM to SMEM + // + // Make sure global offset values for A and B are contiguous between + // neighboring threads to ensure GMEM coalescing. (not possible for A here, but for B it's doable) + + // coalesced load from global memory -> unit-stride store into shared memory + uint32_t global_a_offset = + dim_k * (global_a_row + local_a_row) + (k + local_a_col); + uint32_t shared_a_offset = + BK * local_a_row + local_a_col; + + local_a[shared_a_offset] = A[global_a_offset]; + + global_a_offset += dim_k * (BM / 4); + shared_a_offset += BK * (BM / 4); + + local_a[shared_a_offset] = A[global_a_offset]; + + uint32_t global_b_offset = + dim_n * (k + local_b_row) + (global_b_col + local_b_col); + uint32_t shared_b_offset = + (BN * 2) * (local_b_row) + local_b_col; + + local_b[shared_b_offset] = B[global_b_offset]; + + global_b_offset += dim_n * (BK / 2); + shared_b_offset += (BN * 2) * (BK / 2); + + local_b[shared_b_offset] = B[global_b_offset]; + + // want all 4 warps to reach barrier before moving on (just use barrier 0 for now) + threadblock_barrier(tid_in_threadblock, 0, 4); + + // perform wmma + vx_wmma_load(local_a, local_b, warp_x, warp_y, tid_in_warp); + vx_wmma(); + + // same as above + threadblock_barrier(tid_in_threadblock, 0, 4); + } + + write_results( + local_warp_results, + tid_in_warp, + warp_x, + warp_y, + dim_m, + dim_n, + C, + threadblock_id_x, + threadblock_id_y + ); +} + +void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { + // @perf: All threads are running these compute whose result is mostly same + // across the threadblock + const int NT = 32; // vx_num_threads(); + const int NW = 4; // vx_num_warps(); + const uint32_t threads_per_threadblock = NT * NW; + + // matches 4 warp capacity + const uint32_t threadblock_dim_x = 2 * BN; + const uint32_t threadblock_dim_y = 2 * BM; + const int threadblock_id = task_id / threads_per_threadblock; + const int tid_in_threadblock = task_id % threads_per_threadblock; + + const uint32_t dim_m = arg->dim_m; + const uint32_t dim_n = arg->dim_n; + const uint32_t dim_n_in_blocks = dim_n / threadblock_dim_x; + const int threadblock_id_x = threadblock_id % dim_n_in_blocks; + const int threadblock_id_y = threadblock_id / dim_n_in_blocks; + + // "static" shared memory allocation. This would determine threadblock + // occupancy of a single cluster + // only 1 threadblock running at a time, so this is ok + float *sharedmem_per_threadblock = + (float *)DEV_SMEM_START_ADDR; // + (2 * BM * BK) + (2 * BN * BK) * threadblock_id; + + thread_block_gemm(arg, tid_in_threadblock, threadblock_dim_x, + threadblock_dim_y, threadblock_id_x, threadblock_id_y, + threadblock_id, sharedmem_per_threadblock); +} + +int main() { + kernel_arg_t *arg = (kernel_arg_t *)KERNEL_ARG_DEV_MEM_ADDR; + int NT = vx_num_threads(); + + // TODO: add support for edge-case (m, n not divisible by 16) + const uint32_t grid_size = arg->dim_m * arg->dim_n * NT / (BM * BN); + + // for now, simplifying assumption of just 1 core + // vx_spawn_tasks_contiguous first runs warps 1 through NW, then NW+1 through 2*NW, etc. + // we can thus treat 1 through NW as a single threadblock for the purposes of the optimization. + vx_spawn_tasks_contiguous(grid_size, (vx_spawn_tasks_cb)kernel_body, arg); + return 0; +} diff --git a/tests/regression/sgemm_tcore/kernel.cpp b/tests/regression/sgemm_tcore/kernel.cpp index f498f57b..99621ec2 100644 --- a/tests/regression/sgemm_tcore/kernel.cpp +++ b/tests/regression/sgemm_tcore/kernel.cpp @@ -6,9 +6,25 @@ #include #include "common.h" +// Constraints on parameters: +// * Memory: +// (BM + BN) * BK * sizeof(float) <= sharedmem size. +// BM * BK == BN * BK >= threadblock size >= NT * CORES_PER_CLUSTER +// When larger, the kernel runs a sequential loop to read into sharedmem; +// but smaller case is not handled. +// * Compute: +// ( M* N) / (TM*TN) == grid size >= NC*NW*NT +// (BM*BN) / (TM*TN) == threadblock size < NT * NW * CORES_PER_CLUSTER +// (BM*BN) / (TM*TN) == threadblock size >= NT * CORES_PER_CLUSTER +// * Combining BM * BK >= (BM*BN) / (TM*TN) == threadblock yields +// BM <= BK*TM*TN #define BM 16 -#define BN 16 +#define BN BM #define BK 8 +#define TCM 16 +#define TCN 16 +#define TM 1 +#define TN 1 inline constexpr void map_operand_32lanes(const int tid, int &row, int &col) { const int tg = tid / 4; @@ -90,12 +106,12 @@ void vx_wmma_load(volatile float *smem_A, volatile float *smem_B, int warp_x, int col = 0; map_operand_32lanes(tid, row, col); - int smem_A_m = 32; - int smem_A_n = 8; - int smem_B_m = 8; - int smem_B_n = 32; + int smem_A_rows = BM; + int smem_A_cols = BK; + int smem_B_rows = BK; + int smem_B_cols = BN; - int A_offset = (row + BM * warp_y) * smem_A_n; + int A_offset = (row + TCM * warp_y) * smem_A_cols; asm volatile("flw f0, %0" ::"m"(smem_A[A_offset + 0])); asm volatile("flw f1, %0" ::"m"(smem_A[A_offset + 1])); @@ -106,14 +122,14 @@ void vx_wmma_load(volatile float *smem_A, volatile float *smem_B, int warp_x, asm volatile("flw f6, %0" ::"m"(smem_A[A_offset + 6])); asm volatile("flw f7, %0" ::"m"(smem_A[A_offset + 7])); - asm volatile("flw f8 , %0" ::"m"(smem_B[(0 * smem_B_n) + warp_x * BN + col])); - asm volatile("flw f9 , %0" ::"m"(smem_B[(1 * smem_B_n) + warp_x * BN + col])); - asm volatile("flw f10, %0" ::"m"(smem_B[(2 * smem_B_n) + warp_x * BN + col])); - asm volatile("flw f11, %0" ::"m"(smem_B[(3 * smem_B_n) + warp_x * BN + col])); - asm volatile("flw f12, %0" ::"m"(smem_B[(4 * smem_B_n) + warp_x * BN + col])); - asm volatile("flw f13, %0" ::"m"(smem_B[(5 * smem_B_n) + warp_x * BN + col])); - asm volatile("flw f14, %0" ::"m"(smem_B[(6 * smem_B_n) + warp_x * BN + col])); - asm volatile("flw f15, %0" ::"m"(smem_B[(7 * smem_B_n) + warp_x * BN + col])); + asm volatile("flw f8 , %0" ::"m"(smem_B[(0 * smem_B_cols) + warp_x * TCN + col])); + asm volatile("flw f9 , %0" ::"m"(smem_B[(1 * smem_B_cols) + warp_x * TCN + col])); + asm volatile("flw f10, %0" ::"m"(smem_B[(2 * smem_B_cols) + warp_x * TCN + col])); + asm volatile("flw f11, %0" ::"m"(smem_B[(3 * smem_B_cols) + warp_x * TCN + col])); + asm volatile("flw f12, %0" ::"m"(smem_B[(4 * smem_B_cols) + warp_x * TCN + col])); + asm volatile("flw f13, %0" ::"m"(smem_B[(5 * smem_B_cols) + warp_x * TCN + col])); + asm volatile("flw f14, %0" ::"m"(smem_B[(6 * smem_B_cols) + warp_x * TCN + col])); + asm volatile("flw f15, %0" ::"m"(smem_B[(7 * smem_B_cols) + warp_x * TCN + col])); } inline void initialize_C() { @@ -163,9 +179,15 @@ inline void write_results(volatile float *local_warp_results, int local_col = 0; map_c_32lanes(tid, local_row, local_col); + // C[dim_n * (BM * threadblock_id_y + TM * local_c_row + res_idx_m) + + // (BN * threadblock_id_x + TN * local_c_col + res_idx_n)] = + // reg_c[TN * res_idx_m + res_idx_n]; + // float *global_offset_C = C + + // (threadblock_id_y * TCM * 2 + warp_y * TCM) * dim_n + + // threadblock_id_x * TCN * 2 + warp_x * TCN; float *global_offset_C = C + - (threadblock_id_y * BM * 2 + warp_y * BM) * dim_n + - threadblock_id_x * BN * 2 + warp_x * BM; + (BM * threadblock_id_y /* 1 warp */) * dim_n + + BN * threadblock_id_x /* 1 warp */; for (int i = 0; i < 8; i += 1) { int row_offset = ((i / 2) % 2) * 2; int col_offset = (i / 4) * 4 + i % 2; @@ -173,6 +195,7 @@ inline void write_results(volatile float *local_warp_results, int adjusted_local_row = local_row + row_offset; int adjusted_local_col = local_col + col_offset; + // FIXME: do we need to store to SMEM at all? float v = local_warp_results[tid * 8 + i]; global_offset_C[adjusted_local_row * dim_n + adjusted_local_col] = v; } @@ -184,17 +207,18 @@ void threadblock_barrier(unsigned int tid_in_threadblock, unsigned int barrier_i } void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg, - const uint32_t tid_in_threadblock, - const uint32_t threadblock_dim_x, - const uint32_t threadblock_dim_y, - const uint32_t threadblock_id_x, - const uint32_t threadblock_id_y, - const uint32_t threadblock_id, - float *sharedmem_per_threadblock) { + const uint32_t tid_in_threadblock, + const uint32_t threadblock_dim_x, + const uint32_t threadblock_dim_y, + const uint32_t threadblock_id_x, + const uint32_t threadblock_id_y, + const uint32_t threadblock_id_in_cluster, + float *sharedmem_per_threadblock) { const float *A = (const float *)arg->addr_a; const float *B = (const float *)arg->addr_b; float *C = (float *)arg->addr_c; + // assumes NT == NW == matrix_dim const uint32_t dim_m = arg->dim_m; const uint32_t dim_n = arg->dim_n; const uint32_t dim_k = arg->dim_k; @@ -207,28 +231,38 @@ void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg, // constexpr uint32_t BN = 8; // constexpr uint32_t BK = 2; + const uint32_t local_a_row = tid_in_threadblock / BK; + const uint32_t local_a_col = tid_in_threadblock % BK; + const uint32_t local_b_row = tid_in_threadblock / BN; + const uint32_t local_b_col = tid_in_threadblock % BN; + const uint32_t global_a_row = BM * threadblock_id_y + local_a_row; + const uint32_t global_b_col = BN * threadblock_id_x + local_b_col; + + const uint32_t local_c_row = tid_in_threadblock / (BN / TN); + const uint32_t local_c_col = tid_in_threadblock % (BN / TN); + + // each thread generates TM output element + float reg_c[TM * TN] = { 0.0f }; + float reg_a[TM] = { 0.0f }; + float reg_b[TN] = { 0.0f }; + const uint32_t warp_in_threadblock = tid_in_threadblock / 32; const uint32_t tid_in_warp = tid_in_threadblock % 32; const uint32_t warp_x = warp_in_threadblock % 2; const uint32_t warp_y = warp_in_threadblock / 2; - const uint32_t global_a_row = threadblock_dim_y * threadblock_id_y; - - // 32 * 8 block of A, being loaded by 4 warps - const uint32_t local_a_row = warp_y * BM + warp_x * (BM / 2) + (tid_in_warp / BK); - const uint32_t local_a_col = tid_in_warp % BK; - - // 8 * 32 block of B, being loaded by 4 warps - // do a fat coalesced load - const uint32_t global_b_col = threadblock_dim_x * threadblock_id_x; - const uint32_t local_b_row = warp_in_threadblock; - const uint32_t local_b_col = tid_in_warp; - volatile float *local_a = sharedmem_per_threadblock; - const size_t local_a_elems = (threadblock_dim_y * BK); + // const size_t local_a_elems = threadblock_dim_x * threadblock_dim_y; + // FIXME: this better be BM * BK, but the GMEM->SMEM load assumes all threads + // in TB participates in the load + const size_t local_a_elems = (BM * BN); volatile float *local_b = sharedmem_per_threadblock + local_a_elems; - const size_t local_b_elems = (threadblock_dim_x * BK); - volatile float *local_warp_results = local_b + local_b_elems + (warp_in_threadblock * BM * BN); + const size_t local_b_elems = (BM * BN); + volatile float *local_warp_results = + local_b + local_b_elems + (warp_in_threadblock * TCM * TCN); + + constexpr uint32_t stride_a = (BM * BN) / BK / (TM * TN); + constexpr uint32_t stride_b = (BM * BN) / BN / (TM * TN); // clear out C initialize_C(); @@ -237,97 +271,132 @@ void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg, // Data move from GMEM to SMEM // // Make sure global offset values for A and B are contiguous between - // neighboring threads to ensure GMEM coalescing. (not possible for A here, but for B it's doable) - - // coalesced load from global memory -> unit-stride store into shared memory - uint32_t global_a_offset = - dim_k * (global_a_row + local_a_row) + (k + local_a_col); - uint32_t shared_a_offset = - BK * local_a_row + local_a_col; - - local_a[shared_a_offset] = A[global_a_offset]; - - global_a_offset += dim_k * (BM / 4); - shared_a_offset += BK * (BM / 4); - - local_a[shared_a_offset] = A[global_a_offset]; + // neighboring threads to ensure GMEM coalescing. +#pragma GCC unroll 2 + for (uint32_t load_offset = 0; load_offset < BM; load_offset += stride_a) { + const uint32_t global_a_offset = + dim_k * (global_a_row + load_offset) + (k + local_a_col); + // FIXME: all threads in TB (BM*BN) will do this load, even if this is + // out-of-bounds of BM*BK + local_a[BK * (local_a_row + load_offset) + local_a_col] = + A[global_a_offset]; + } +#pragma GCC unroll 2 + for (uint32_t load_offset = 0; load_offset < BK; load_offset += stride_b) { + const uint32_t global_b_offset = + dim_n * (k + local_b_row + load_offset) + global_b_col; + local_b[BN * (local_b_row + load_offset) + local_b_col] = + B[global_b_offset]; + } - uint32_t global_b_offset = - dim_n * (k + local_b_row) + (global_b_col + local_b_col); - uint32_t shared_b_offset = - (BN * 2) * (local_b_row) + local_b_col; - - local_b[shared_b_offset] = B[global_b_offset]; - - global_b_offset += dim_n * (BK / 2); - shared_b_offset += (BN * 2) * (BK / 2); - - local_b[shared_b_offset] = B[global_b_offset]; - - // want all 4 warps to reach barrier before moving on (just use barrier 0 for now) - threadblock_barrier(tid_in_threadblock, 0, 4); + threadblock_barrier(tid_in_threadblock, threadblock_id_in_cluster, + threadblock_dim_y); // perform wmma - vx_wmma_load(local_a, local_b, warp_x, warp_y, tid_in_warp); - vx_wmma(); - - // same as above - threadblock_barrier(tid_in_threadblock, 0, 4); + // vx_wmma_load(local_a, local_b, warp_x, warp_y, tid_in_warp); + // FIXME: If multiple warps try to issue to Tensor Core at the same time, + // does one stall the other? + if (warp_in_threadblock == 0) { + vx_wmma_load(local_a, local_b, 0, 0, tid_in_warp); + vx_wmma(); + } + +#if 0 + // Compute single tile*tile matmul +#pragma GCC unroll 4 + for (uint32_t local_k = 0; local_k < BK; local_k++) { + // First, pump data from SMEM->RF +#pragma GCC unroll TM + for (uint32_t res_idx_m = 0; res_idx_m < TM; res_idx_m++) { + reg_a[res_idx_m] = + local_a[BK * (TM * local_c_row + res_idx_m) + local_k]; + } +#pragma GCC unroll TN + for (uint32_t res_idx_n = 0; res_idx_n < TN; res_idx_n++) { + reg_b[res_idx_n] = + local_b[BN * local_k + (TN * local_c_col + res_idx_n)]; + } + + // Next, compute multiple result elements (TM*TN) by reusing data in RF +#pragma GCC unroll TM + for (uint32_t res_idx_m = 0; res_idx_m < TM; res_idx_m++) { +#pragma GCC unroll TN + for (uint32_t res_idx_n = 0; res_idx_n < TN; res_idx_n++) { + // NOTE use of local_b_row + reg_c[TN * res_idx_m + res_idx_n] += + reg_a[res_idx_m] * reg_b[res_idx_n]; + // reg_c[TN * res_idx_m + res_idx_n] += + // local_a[BK * (TM * local_c_row + res_idx_m) + local_k] * + // local_b[BN * local_k + (TN * local_c_col + res_idx_n)]; + } + } + } +#endif + + threadblock_barrier(tid_in_threadblock, threadblock_id_in_cluster, + threadblock_dim_y); } - write_results( - local_warp_results, - tid_in_warp, - warp_x, - warp_y, - dim_m, - dim_n, - C, - threadblock_id_x, - threadblock_id_y - ); + if (warp_in_threadblock == 0) { + write_results( + local_warp_results, + tid_in_warp, + // warp_x, + // warp_y, + 0, + 0, + dim_m, + dim_n, + C, + threadblock_id_x, + threadblock_id_y + ); + } } void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { // @perf: All threads are running these compute whose result is mostly same // across the threadblock - const int NT = 32; // vx_num_threads(); - const int NW = 4; // vx_num_warps(); - const uint32_t threads_per_threadblock = NT * NW; - // matches 4 warp capacity - const uint32_t threadblock_dim_x = 2 * BN; - const uint32_t threadblock_dim_y = 2 * BM; + const uint32_t threads_per_threadblock = (BM * BN) / (TM * TN); +#ifdef RADIANCE + const uint32_t threadblocks_per_core = vx_num_threads() * vx_num_warps() / + threads_per_threadblock * + CORES_PER_CLUSTER; +#else + const uint32_t threadblocks_per_core = + vx_num_threads() * vx_num_warps() / threads_per_threadblock; +#endif + const uint32_t threadblock_dim_x = vx_num_threads(); + const uint32_t threadblock_dim_y = vx_num_warps() / threadblocks_per_core; const int threadblock_id = task_id / threads_per_threadblock; + const int threadblock_id_in_cluster = threadblock_id % threadblocks_per_core; const int tid_in_threadblock = task_id % threads_per_threadblock; const uint32_t dim_m = arg->dim_m; const uint32_t dim_n = arg->dim_n; - const uint32_t dim_n_in_blocks = dim_n / threadblock_dim_x; + const uint32_t dim_n_in_blocks = dim_n / BN; const int threadblock_id_x = threadblock_id % dim_n_in_blocks; const int threadblock_id_y = threadblock_id / dim_n_in_blocks; // "static" shared memory allocation. This would determine threadblock // occupancy of a single cluster - // only 1 threadblock running at a time, so this is ok float *sharedmem_per_threadblock = - (float *)DEV_SMEM_START_ADDR; // + (2 * BM * BK) + (2 * BN * BK) * threadblock_id; - + (float *)DEV_SMEM_START_ADDR + (2 * BM * BK) * threadblock_id_in_cluster; thread_block_gemm(arg, tid_in_threadblock, threadblock_dim_x, threadblock_dim_y, threadblock_id_x, threadblock_id_y, - threadblock_id, sharedmem_per_threadblock); + threadblock_id_in_cluster, sharedmem_per_threadblock); } int main() { kernel_arg_t *arg = (kernel_arg_t *)KERNEL_ARG_DEV_MEM_ADDR; - int NT = vx_num_threads(); - - // TODO: add support for edge-case (m, n not divisible by 16) - const uint32_t grid_size = arg->dim_m * arg->dim_n * NT / (BM * BN); - - // for now, simplifying assumption of just 1 core - // vx_spawn_tasks_contiguous first runs warps 1 through NW, then NW+1 through 2*NW, etc. - // we can thus treat 1 through NW as a single threadblock for the purposes of the optimization. + const uint32_t grid_size = arg->dim_m * arg->dim_n / (TM * TN); +#ifdef RADIANCE + vx_spawn_tasks_cluster(grid_size, (vx_spawn_tasks_cb)kernel_body, arg); +#else + // NOTE: This kernel assumes contiguous thread scheduling for efficient shared + // memory allocation, and therefore does not work with original vx_spawn_tasks vx_spawn_tasks_contiguous(grid_size, (vx_spawn_tasks_cb)kernel_body, arg); +#endif return 0; } From d848e88f728555c1b5544664c94f10948eb19528 Mon Sep 17 00:00:00 2001 From: Hansung Kim Date: Mon, 13 May 2024 14:00:50 -0700 Subject: [PATCH 16/55] sgemm_tcore: Move C from regF->GMEM directly --- tests/regression/sgemm_tcore/kernel.cpp | 55 ++++++------------------- 1 file changed, 13 insertions(+), 42 deletions(-) diff --git a/tests/regression/sgemm_tcore/kernel.cpp b/tests/regression/sgemm_tcore/kernel.cpp index 99621ec2..8913d95a 100644 --- a/tests/regression/sgemm_tcore/kernel.cpp +++ b/tests/regression/sgemm_tcore/kernel.cpp @@ -151,54 +151,25 @@ inline void write_results(volatile float *local_warp_results, int tid = thread_in_warp; int tg = tid / 4; - asm volatile("fsw f16, %0" ::"m"(local_warp_results[tid * 8 + 0])); - asm volatile("fsw f17, %0" ::"m"(local_warp_results[tid * 8 + 1])); - asm volatile("fsw f18, %0" ::"m"(local_warp_results[tid * 8 + 2])); - asm volatile("fsw f19, %0" ::"m"(local_warp_results[tid * 8 + 3])); - asm volatile("fsw f20, %0" ::"m"(local_warp_results[tid * 8 + 4])); - asm volatile("fsw f21, %0" ::"m"(local_warp_results[tid * 8 + 5])); - asm volatile("fsw f22, %0" ::"m"(local_warp_results[tid * 8 + 6])); - asm volatile("fsw f23, %0" ::"m"(local_warp_results[tid * 8 + 7])); - - /* - col = ((threadgroup % 4) // 2) * 8 - row = (threadgroup * 8) % 16 - row += (threadgroup // 4) * 4 - offsets = [(0, 0), (0, 1), (2, 0), (2, 1), (0, 4), (0, 5), (2, 4), (2, 5)] - offset = offsets[register-16] - row += offset[0] - col += offset[1] - thread_offsets = [(0, 0), (1, 0), (0, 2), (1, 2)] - thread_offset = thread_offsets[thread % 4] - row += thread_offset[0] - col += thread_offset[1] - return (row, col) - */ - + // these are [0, TCM/TCN) int local_row = 0; int local_col = 0; + map_c_32lanes(tid, local_row, local_col); - // C[dim_n * (BM * threadblock_id_y + TM * local_c_row + res_idx_m) + - // (BN * threadblock_id_x + TN * local_c_col + res_idx_n)] = - // reg_c[TN * res_idx_m + res_idx_n]; - // float *global_offset_C = C + - // (threadblock_id_y * TCM * 2 + warp_y * TCM) * dim_n + - // threadblock_id_x * TCN * 2 + warp_x * TCN; float *global_offset_C = C + - (BM * threadblock_id_y /* 1 warp */) * dim_n + - BN * threadblock_id_x /* 1 warp */; - for (int i = 0; i < 8; i += 1) { - int row_offset = ((i / 2) % 2) * 2; - int col_offset = (i / 4) * 4 + i % 2; + (BM * threadblock_id_y) * dim_n + + BN * threadblock_id_x; - int adjusted_local_row = local_row + row_offset; - int adjusted_local_col = local_col + col_offset; - - // FIXME: do we need to store to SMEM at all? - float v = local_warp_results[tid * 8 + i]; - global_offset_C[adjusted_local_row * dim_n + adjusted_local_col] = v; - } + // @perf: this likely causes a lot of gmem bank conflicts + asm volatile ("fsw f16, %0" :: "m"(global_offset_C[dim_n * (local_row + 0) + (local_col + 0)])); + asm volatile ("fsw f17, %0" :: "m"(global_offset_C[dim_n * (local_row + 0) + (local_col + 1)])); + asm volatile ("fsw f18, %0" :: "m"(global_offset_C[dim_n * (local_row + 2) + (local_col + 0)])); + asm volatile ("fsw f19, %0" :: "m"(global_offset_C[dim_n * (local_row + 2) + (local_col + 1)])); + asm volatile ("fsw f20, %0" :: "m"(global_offset_C[dim_n * (local_row + 0) + (local_col + 4)])); + asm volatile ("fsw f21, %0" :: "m"(global_offset_C[dim_n * (local_row + 0) + (local_col + 5)])); + asm volatile ("fsw f22, %0" :: "m"(global_offset_C[dim_n * (local_row + 2) + (local_col + 4)])); + asm volatile ("fsw f23, %0" :: "m"(global_offset_C[dim_n * (local_row + 2) + (local_col + 5)])); } void threadblock_barrier(unsigned int tid_in_threadblock, unsigned int barrier_id, unsigned int count) { From 09b23ffe87a8398f9accba2a63e90aacc55ae4ee Mon Sep 17 00:00:00 2001 From: Hansung Kim Date: Mon, 13 May 2024 14:52:33 -0700 Subject: [PATCH 17/55] sgemm_tg: 1-octet 8-lane kernel --- tests/regression/sgemm_tcore/kernel.cpp | 78 +++++++++---------------- 1 file changed, 29 insertions(+), 49 deletions(-) diff --git a/tests/regression/sgemm_tcore/kernel.cpp b/tests/regression/sgemm_tcore/kernel.cpp index 8913d95a..1484d555 100644 --- a/tests/regression/sgemm_tcore/kernel.cpp +++ b/tests/regression/sgemm_tcore/kernel.cpp @@ -18,14 +18,16 @@ // (BM*BN) / (TM*TN) == threadblock size >= NT * CORES_PER_CLUSTER // * Combining BM * BK >= (BM*BN) / (TM*TN) == threadblock yields // BM <= BK*TM*TN -#define BM 16 +#define BM 8 #define BN BM #define BK 8 -#define TCM 16 -#define TCN 16 +#define TCM 8 +#define TCN 8 #define TM 1 #define TN 1 +#define NUM_LANES 8 + inline constexpr void map_operand_32lanes(const int tid, int &row, int &col) { const int tg = tid / 4; @@ -67,6 +69,16 @@ inline constexpr void map_operand_8lanes(const int tid, int &row, int &col) { col += tg * 4; } +inline constexpr void map_operand(const int tid, int &row, int &col) { + if constexpr (NUM_LANES == 32) { + map_operand_32lanes(tid, row, col); + } else if constexpr (NUM_LANES == 8) { + map_operand_8lanes(tid, row, col); + } else { + // FIXME: not allowed + } +} + inline constexpr void map_c_32lanes(const int tid, int &row, int &col) { const int tg = tid / 4; @@ -93,6 +105,16 @@ inline constexpr void map_c_8lanes(const int tid, int &row, int &col) { col += ((tid % 4) / 2) * 2; } +inline constexpr void map_c(const int tid, int &row, int &col) { + if constexpr (NUM_LANES == 32) { + map_c_32lanes(tid, row, col); + } else if constexpr (NUM_LANES == 8) { + map_c_8lanes(tid, row, col); + } else { + // FIXME: not allowed + } +} + inline void vx_wmma() { asm volatile (".insn r %0, 0, 0, x0, x0, x0" :: "i"(RISCV_CUSTOM3)); } @@ -104,7 +126,7 @@ void vx_wmma_load(volatile float *smem_A, volatile float *smem_B, int warp_x, int row = 0; int col = 0; - map_operand_32lanes(tid, row, col); + map_operand(tid, row, col); int smem_A_rows = BM; int smem_A_cols = BK; @@ -154,8 +176,7 @@ inline void write_results(volatile float *local_warp_results, // these are [0, TCM/TCN) int local_row = 0; int local_col = 0; - - map_c_32lanes(tid, local_row, local_col); + map_c(tid, local_row, local_col); float *global_offset_C = C + (BM * threadblock_id_y) * dim_n + @@ -189,19 +210,10 @@ void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg, const float *B = (const float *)arg->addr_b; float *C = (float *)arg->addr_c; - // assumes NT == NW == matrix_dim const uint32_t dim_m = arg->dim_m; const uint32_t dim_n = arg->dim_n; const uint32_t dim_k = arg->dim_k; - // FIXME: Output block size is assumed to be square, i.e. BM == BN - // const uint32_t BM = threadblock_dim_y; - // const uint32_t BN = threadblock_dim_y; - // const uint32_t BK = threadblock_dim_x; - // constexpr uint32_t BM = 8; - // constexpr uint32_t BN = 8; - // constexpr uint32_t BK = 2; - const uint32_t local_a_row = tid_in_threadblock / BK; const uint32_t local_a_col = tid_in_threadblock % BK; const uint32_t local_b_row = tid_in_threadblock / BN; @@ -217,8 +229,8 @@ void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg, float reg_a[TM] = { 0.0f }; float reg_b[TN] = { 0.0f }; - const uint32_t warp_in_threadblock = tid_in_threadblock / 32; - const uint32_t tid_in_warp = tid_in_threadblock % 32; + const uint32_t warp_in_threadblock = tid_in_threadblock / NUM_LANES; + const uint32_t tid_in_warp = tid_in_threadblock % NUM_LANES; const uint32_t warp_x = warp_in_threadblock % 2; const uint32_t warp_y = warp_in_threadblock / 2; @@ -272,38 +284,6 @@ void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg, vx_wmma(); } -#if 0 - // Compute single tile*tile matmul -#pragma GCC unroll 4 - for (uint32_t local_k = 0; local_k < BK; local_k++) { - // First, pump data from SMEM->RF -#pragma GCC unroll TM - for (uint32_t res_idx_m = 0; res_idx_m < TM; res_idx_m++) { - reg_a[res_idx_m] = - local_a[BK * (TM * local_c_row + res_idx_m) + local_k]; - } -#pragma GCC unroll TN - for (uint32_t res_idx_n = 0; res_idx_n < TN; res_idx_n++) { - reg_b[res_idx_n] = - local_b[BN * local_k + (TN * local_c_col + res_idx_n)]; - } - - // Next, compute multiple result elements (TM*TN) by reusing data in RF -#pragma GCC unroll TM - for (uint32_t res_idx_m = 0; res_idx_m < TM; res_idx_m++) { -#pragma GCC unroll TN - for (uint32_t res_idx_n = 0; res_idx_n < TN; res_idx_n++) { - // NOTE use of local_b_row - reg_c[TN * res_idx_m + res_idx_n] += - reg_a[res_idx_m] * reg_b[res_idx_n]; - // reg_c[TN * res_idx_m + res_idx_n] += - // local_a[BK * (TM * local_c_row + res_idx_m) + local_k] * - // local_b[BN * local_k + (TN * local_c_col + res_idx_n)]; - } - } - } -#endif - threadblock_barrier(tid_in_threadblock, threadblock_id_in_cluster, threadblock_dim_y); } From 9d2b533d5c0baafd44ea15e9aef9f917a711db51 Mon Sep 17 00:00:00 2001 From: Hansung Kim Date: Mon, 13 May 2024 16:48:13 -0700 Subject: [PATCH 18/55] sgemm_tg: Do operand elf stitching for kernel.elf as well --- tests/regression/common.mk | 22 +++++++++++++--------- 1 file changed, 13 insertions(+), 9 deletions(-) diff --git a/tests/regression/common.mk b/tests/regression/common.mk index f9d71634..82aefd4d 100644 --- a/tests/regression/common.mk +++ b/tests/regression/common.mk @@ -98,17 +98,21 @@ endif kernel.bin: kernel.elf kernel.radiance.elf $(VX_CP) -O binary kernel.elf kernel.bin -kernel.elf: $(VX_SRCS) - $(VX_CXX) $(VX_CFLAGS) $(VX_SRCS) $(VX_LDFLAGS) -o kernel.elf - OBJCOPY ?= $(RISCV_TOOLCHAIN_PATH)/bin/$(RISCV_PREFIX)-objcopy OBJCOPY_FLAGS ?= "LOAD,ALLOC,DATA,CONTENTS" -kernel.radiance.elf: kernel.elf - $(VX_CXX) $(VX_CFLAGS) $(VX_SRCS) $(VX_LDFLAGS) -DRADIANCE -o kernel.radiance.elf - $(OBJCOPY) --set-section-flags .operand.a=$(OBJCOPY_FLAGS) kernel.radiance.elf - $(OBJCOPY) --set-section-flags .operand.b=$(OBJCOPY_FLAGS) kernel.radiance.elf - $(OBJCOPY) --update-section .operand.a=input.a.bin kernel.radiance.elf - $(OBJCOPY) --update-section .operand.b=input.b.bin kernel.radiance.elf +kernel.elf: $(VX_SRCS) + $(VX_CXX) $(VX_CFLAGS) $(VX_SRCS) $(VX_LDFLAGS) -o $@ + $(OBJCOPY) --set-section-flags .operand.a=$(OBJCOPY_FLAGS) $@ + $(OBJCOPY) --set-section-flags .operand.b=$(OBJCOPY_FLAGS) $@ + $(OBJCOPY) --update-section .operand.a=input.a.bin $@ + $(OBJCOPY) --update-section .operand.b=input.b.bin $@ + +kernel.radiance.elf: $(VX_SRCS) + $(VX_CXX) $(VX_CFLAGS) $(VX_SRCS) $(VX_LDFLAGS) -DRADIANCE -o $@ + $(OBJCOPY) --set-section-flags .operand.a=$(OBJCOPY_FLAGS) $@ + $(OBJCOPY) --set-section-flags .operand.b=$(OBJCOPY_FLAGS) $@ + $(OBJCOPY) --update-section .operand.a=input.a.bin $@ + $(OBJCOPY) --update-section .operand.b=input.b.bin $@ ifneq ($(CONFIG),) kernel.radiance$(CONFIGEXT).elf: kernel.radiance.elf From 5de8e7c33aee1cad8ff63746af2c9100c4b7248e Mon Sep 17 00:00:00 2001 From: Hansung Kim Date: Mon, 13 May 2024 23:09:57 -0700 Subject: [PATCH 19/55] sgemm_tg: Fix device address to use ELF operands --- tests/regression/sgemm_tcore/main.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/regression/sgemm_tcore/main.cpp b/tests/regression/sgemm_tcore/main.cpp index eb4f55df..e34b7066 100644 --- a/tests/regression/sgemm_tcore/main.cpp +++ b/tests/regression/sgemm_tcore/main.cpp @@ -169,8 +169,8 @@ int main(int argc, char *argv[]) { // RT_CHECK(vx_mem_alloc(device, src_a_buf_size, VX_MEM_TYPE_GLOBAL, &kernel_arg.addr_a)); // RT_CHECK(vx_mem_alloc(device, src_b_buf_size, VX_MEM_TYPE_GLOBAL, &kernel_arg.addr_b)); // RT_CHECK(vx_mem_alloc(device, dst_buf_size, VX_MEM_TYPE_GLOBAL, &kernel_arg.addr_c)); - kernel_arg.addr_a = 0x20000; - kernel_arg.addr_b = 0x28000; + kernel_arg.addr_a = 0xa0000000; + kernel_arg.addr_b = 0xa1000000; kernel_arg.addr_c = 0xc0000000; kernel_arg.dim_m = dim_m; From df1aa62916d360d22382f1d887109596efbc70d7 Mon Sep 17 00:00:00 2001 From: Hansung Kim Date: Wed, 15 May 2024 15:23:26 -0700 Subject: [PATCH 20/55] sgemm_tcore: Add warptiling parameters FIXME: accumulation is done wrong --- tests/regression/sgemm_tcore/kernel.cpp | 100 +++++++++++++----------- 1 file changed, 54 insertions(+), 46 deletions(-) diff --git a/tests/regression/sgemm_tcore/kernel.cpp b/tests/regression/sgemm_tcore/kernel.cpp index 1484d555..090df810 100644 --- a/tests/regression/sgemm_tcore/kernel.cpp +++ b/tests/regression/sgemm_tcore/kernel.cpp @@ -6,6 +6,8 @@ #include #include "common.h" +#define NUM_LANES 8 + // Constraints on parameters: // * Memory: // (BM + BN) * BK * sizeof(float) <= sharedmem size. @@ -23,10 +25,14 @@ #define BK 8 #define TCM 8 #define TCN 8 +#define WM 8 +#define WN 8 +#define WMITER (WM / TCM) +#define WNITER (WN / TCN) #define TM 1 +// #define TN ((TCM * TCN) / NUM_LANES / TM) #define TN 1 -#define NUM_LANES 8 inline constexpr void map_operand_32lanes(const int tid, int &row, int &col) { const int tg = tid / 4; @@ -119,8 +125,9 @@ inline void vx_wmma() { asm volatile (".insn r %0, 0, 0, x0, x0, x0" :: "i"(RISCV_CUSTOM3)); } -void vx_wmma_load(volatile float *smem_A, volatile float *smem_B, int warp_x, - int warp_y, int thread_in_warp) { +void vx_wmma_load(volatile float *smem_A, volatile float *smem_B, int warp_col, + int warp_row, int wn_iter, int wm_iter, + int thread_in_warp) { int tid = thread_in_warp; int tg = tid / 4; @@ -133,7 +140,7 @@ void vx_wmma_load(volatile float *smem_A, volatile float *smem_B, int warp_x, int smem_B_rows = BK; int smem_B_cols = BN; - int A_offset = (row + TCM * warp_y) * smem_A_cols; + int A_offset = (row + WM * warp_row + TCM * wm_iter) * smem_A_cols; asm volatile("flw f0, %0" ::"m"(smem_A[A_offset + 0])); asm volatile("flw f1, %0" ::"m"(smem_A[A_offset + 1])); @@ -144,14 +151,14 @@ void vx_wmma_load(volatile float *smem_A, volatile float *smem_B, int warp_x, asm volatile("flw f6, %0" ::"m"(smem_A[A_offset + 6])); asm volatile("flw f7, %0" ::"m"(smem_A[A_offset + 7])); - asm volatile("flw f8 , %0" ::"m"(smem_B[(0 * smem_B_cols) + warp_x * TCN + col])); - asm volatile("flw f9 , %0" ::"m"(smem_B[(1 * smem_B_cols) + warp_x * TCN + col])); - asm volatile("flw f10, %0" ::"m"(smem_B[(2 * smem_B_cols) + warp_x * TCN + col])); - asm volatile("flw f11, %0" ::"m"(smem_B[(3 * smem_B_cols) + warp_x * TCN + col])); - asm volatile("flw f12, %0" ::"m"(smem_B[(4 * smem_B_cols) + warp_x * TCN + col])); - asm volatile("flw f13, %0" ::"m"(smem_B[(5 * smem_B_cols) + warp_x * TCN + col])); - asm volatile("flw f14, %0" ::"m"(smem_B[(6 * smem_B_cols) + warp_x * TCN + col])); - asm volatile("flw f15, %0" ::"m"(smem_B[(7 * smem_B_cols) + warp_x * TCN + col])); + asm volatile("flw f8 , %0" ::"m"(smem_B[(0 * smem_B_cols) + (WN * warp_col + TCN * wn_iter) + col])); + asm volatile("flw f9 , %0" ::"m"(smem_B[(1 * smem_B_cols) + (WN * warp_col + TCN * wn_iter) + col])); + asm volatile("flw f10, %0" ::"m"(smem_B[(2 * smem_B_cols) + (WN * warp_col + TCN * wn_iter) + col])); + asm volatile("flw f11, %0" ::"m"(smem_B[(3 * smem_B_cols) + (WN * warp_col + TCN * wn_iter) + col])); + asm volatile("flw f12, %0" ::"m"(smem_B[(4 * smem_B_cols) + (WN * warp_col + TCN * wn_iter) + col])); + asm volatile("flw f13, %0" ::"m"(smem_B[(5 * smem_B_cols) + (WN * warp_col + TCN * wn_iter) + col])); + asm volatile("flw f14, %0" ::"m"(smem_B[(6 * smem_B_cols) + (WN * warp_col + TCN * wn_iter) + col])); + asm volatile("flw f15, %0" ::"m"(smem_B[(7 * smem_B_cols) + (WN * warp_col + TCN * wn_iter) + col])); } inline void initialize_C() { @@ -167,16 +174,20 @@ inline void initialize_C() { } inline void write_results(volatile float *local_warp_results, - int thread_in_warp, int warp_x, int warp_y, int dim_m, - int dim_n, float *C, int threadblock_id_x, + int thread_in_warp, int warp_col, int warp_row, + int wn_iter, int wm_iter, int dim_m, int dim_n, + float *C, int threadblock_id_x, int threadblock_id_y) { int tid = thread_in_warp; int tg = tid / 4; // these are [0, TCM/TCN) - int local_row = 0; - int local_col = 0; - map_c(tid, local_row, local_col); + int tid_row = 0; + int tid_col = 0; + map_c(tid, tid_row, tid_col); + + int local_row = (WM * warp_row + TCM * wm_iter) + tid_row; + int local_col = (WN * warp_col + TCN * wn_iter) + tid_col; float *global_offset_C = C + (BM * threadblock_id_y) * dim_n + @@ -221,18 +232,10 @@ void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg, const uint32_t global_a_row = BM * threadblock_id_y + local_a_row; const uint32_t global_b_col = BN * threadblock_id_x + local_b_col; - const uint32_t local_c_row = tid_in_threadblock / (BN / TN); - const uint32_t local_c_col = tid_in_threadblock % (BN / TN); - - // each thread generates TM output element - float reg_c[TM * TN] = { 0.0f }; - float reg_a[TM] = { 0.0f }; - float reg_b[TN] = { 0.0f }; - const uint32_t warp_in_threadblock = tid_in_threadblock / NUM_LANES; + const uint32_t warp_row = warp_in_threadblock / (BN / WN); + const uint32_t warp_col = warp_in_threadblock % (BN / WN); const uint32_t tid_in_warp = tid_in_threadblock % NUM_LANES; - const uint32_t warp_x = warp_in_threadblock % 2; - const uint32_t warp_y = warp_in_threadblock / 2; volatile float *local_a = sharedmem_per_threadblock; // const size_t local_a_elems = threadblock_dim_x * threadblock_dim_y; @@ -244,8 +247,9 @@ void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg, volatile float *local_warp_results = local_b + local_b_elems + (warp_in_threadblock * TCM * TCN); - constexpr uint32_t stride_a = (BM * BN) / BK / (TM * TN); - constexpr uint32_t stride_b = (BM * BN) / BN / (TM * TN); + // number of rows a full TB can read at a time + constexpr uint32_t row_stride_a = (BM * BN) / BK / (TM * TN); + constexpr uint32_t row_stride_b = (BM * BN) / BN / (TM * TN); // clear out C initialize_C(); @@ -255,8 +259,10 @@ void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg, // // Make sure global offset values for A and B are contiguous between // neighboring threads to ensure GMEM coalescing. + // + // TODO: Sharedmem swizzling is important here #pragma GCC unroll 2 - for (uint32_t load_offset = 0; load_offset < BM; load_offset += stride_a) { + for (uint32_t load_offset = 0; load_offset < BM; load_offset += row_stride_a) { const uint32_t global_a_offset = dim_k * (global_a_row + load_offset) + (k + local_a_col); // FIXME: all threads in TB (BM*BN) will do this load, even if this is @@ -265,7 +271,7 @@ void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg, A[global_a_offset]; } #pragma GCC unroll 2 - for (uint32_t load_offset = 0; load_offset < BK; load_offset += stride_b) { + for (uint32_t load_offset = 0; load_offset < BK; load_offset += row_stride_b) { const uint32_t global_b_offset = dim_n * (k + local_b_row + load_offset) + global_b_col; local_b[BN * (local_b_row + load_offset) + local_b_col] = @@ -279,9 +285,16 @@ void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg, // vx_wmma_load(local_a, local_b, warp_x, warp_y, tid_in_warp); // FIXME: If multiple warps try to issue to Tensor Core at the same time, // does one stall the other? + // FIXME: this is wrong!! need separate accumulation register for + // WM/WN_ITERS if (warp_in_threadblock == 0) { - vx_wmma_load(local_a, local_b, 0, 0, tid_in_warp); - vx_wmma(); + for (int wm_iter = 0; wm_iter < WMITER; wm_iter++) { + for (int wn_iter = 0; wn_iter < WNITER; wn_iter++) { + vx_wmma_load(local_a, local_b, warp_col, warp_row, wn_iter, wm_iter, + tid_in_warp); + vx_wmma(); + } + } } threadblock_barrier(tid_in_threadblock, threadblock_id_in_cluster, @@ -289,19 +302,14 @@ void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg, } if (warp_in_threadblock == 0) { - write_results( - local_warp_results, - tid_in_warp, - // warp_x, - // warp_y, - 0, - 0, - dim_m, - dim_n, - C, - threadblock_id_x, - threadblock_id_y - ); + for (int wm_iter = 0; wm_iter < WMITER; wm_iter++) { + for (int wn_iter = 0; wn_iter < WNITER; wn_iter++) { + write_results(local_warp_results, tid_in_warp, + warp_col, warp_row, + wn_iter, wm_iter, + dim_m, dim_n, C, threadblock_id_x, threadblock_id_y); + } + } } } From 8f64fae7a7a4af899041a0072e4b3aa41a895640 Mon Sep 17 00:00:00 2001 From: Hansung Kim Date: Thu, 16 May 2024 14:09:55 -0700 Subject: [PATCH 21/55] sgemm_tcore: Addr gen for local_k; add SIMT-only for reference --- tests/regression/sgemm_tcore/kernel.cpp | 164 +++++++++++++++++------- tests/regression/sgemm_tcore/main.cpp | 6 +- 2 files changed, 121 insertions(+), 49 deletions(-) diff --git a/tests/regression/sgemm_tcore/kernel.cpp b/tests/regression/sgemm_tcore/kernel.cpp index 090df810..e10a3c0d 100644 --- a/tests/regression/sgemm_tcore/kernel.cpp +++ b/tests/regression/sgemm_tcore/kernel.cpp @@ -6,6 +6,9 @@ #include #include "common.h" +#define USE_TENSOR_CORE 1 +#define TC_SINGLE_WARP 0 + #define NUM_LANES 8 // Constraints on parameters: @@ -20,18 +23,19 @@ // (BM*BN) / (TM*TN) == threadblock size >= NT * CORES_PER_CLUSTER // * Combining BM * BK >= (BM*BN) / (TM*TN) == threadblock yields // BM <= BK*TM*TN -#define BM 8 -#define BN BM -#define BK 8 +#define BM 16 +#define BN 16 +#define BK 32 #define TCM 8 #define TCN 8 +#define TCK 8 #define WM 8 #define WN 8 #define WMITER (WM / TCM) #define WNITER (WN / TCN) #define TM 1 -// #define TN ((TCM * TCN) / NUM_LANES / TM) -#define TN 1 +#define TN ((TCM * TCN) / NUM_LANES / TM) +// #define TN 1 inline constexpr void map_operand_32lanes(const int tid, int &row, int &col) { @@ -125,9 +129,10 @@ inline void vx_wmma() { asm volatile (".insn r %0, 0, 0, x0, x0, x0" :: "i"(RISCV_CUSTOM3)); } -void vx_wmma_load(volatile float *smem_A, volatile float *smem_B, int warp_col, - int warp_row, int wn_iter, int wm_iter, - int thread_in_warp) { +// `local_k` is assumed to be multiple of TCK +void vx_wmma_load(volatile float *smem_A, volatile float *smem_B, const int local_k, + const int warp_col, const int warp_row, const int wn_iter, + const int wm_iter, const int thread_in_warp) { int tid = thread_in_warp; int tg = tid / 4; @@ -142,23 +147,24 @@ void vx_wmma_load(volatile float *smem_A, volatile float *smem_B, int warp_col, int A_offset = (row + WM * warp_row + TCM * wm_iter) * smem_A_cols; - asm volatile("flw f0, %0" ::"m"(smem_A[A_offset + 0])); - asm volatile("flw f1, %0" ::"m"(smem_A[A_offset + 1])); - asm volatile("flw f2, %0" ::"m"(smem_A[A_offset + 2])); - asm volatile("flw f3, %0" ::"m"(smem_A[A_offset + 3])); - asm volatile("flw f4, %0" ::"m"(smem_A[A_offset + 4])); - asm volatile("flw f5, %0" ::"m"(smem_A[A_offset + 5])); - asm volatile("flw f6, %0" ::"m"(smem_A[A_offset + 6])); - asm volatile("flw f7, %0" ::"m"(smem_A[A_offset + 7])); + // @perf: bank conflicts + asm volatile("flw f0, %0" ::"m"(smem_A[A_offset + (local_k + 0)])); + asm volatile("flw f1, %0" ::"m"(smem_A[A_offset + (local_k + 1)])); + asm volatile("flw f2, %0" ::"m"(smem_A[A_offset + (local_k + 2)])); + asm volatile("flw f3, %0" ::"m"(smem_A[A_offset + (local_k + 3)])); + asm volatile("flw f4, %0" ::"m"(smem_A[A_offset + (local_k + 4)])); + asm volatile("flw f5, %0" ::"m"(smem_A[A_offset + (local_k + 5)])); + asm volatile("flw f6, %0" ::"m"(smem_A[A_offset + (local_k + 6)])); + asm volatile("flw f7, %0" ::"m"(smem_A[A_offset + (local_k + 7)])); - asm volatile("flw f8 , %0" ::"m"(smem_B[(0 * smem_B_cols) + (WN * warp_col + TCN * wn_iter) + col])); - asm volatile("flw f9 , %0" ::"m"(smem_B[(1 * smem_B_cols) + (WN * warp_col + TCN * wn_iter) + col])); - asm volatile("flw f10, %0" ::"m"(smem_B[(2 * smem_B_cols) + (WN * warp_col + TCN * wn_iter) + col])); - asm volatile("flw f11, %0" ::"m"(smem_B[(3 * smem_B_cols) + (WN * warp_col + TCN * wn_iter) + col])); - asm volatile("flw f12, %0" ::"m"(smem_B[(4 * smem_B_cols) + (WN * warp_col + TCN * wn_iter) + col])); - asm volatile("flw f13, %0" ::"m"(smem_B[(5 * smem_B_cols) + (WN * warp_col + TCN * wn_iter) + col])); - asm volatile("flw f14, %0" ::"m"(smem_B[(6 * smem_B_cols) + (WN * warp_col + TCN * wn_iter) + col])); - asm volatile("flw f15, %0" ::"m"(smem_B[(7 * smem_B_cols) + (WN * warp_col + TCN * wn_iter) + col])); + asm volatile("flw f8 , %0" ::"m"(smem_B[((local_k + 0) * smem_B_cols) + (WN * warp_col + TCN * wn_iter) + col])); + asm volatile("flw f9 , %0" ::"m"(smem_B[((local_k + 1) * smem_B_cols) + (WN * warp_col + TCN * wn_iter) + col])); + asm volatile("flw f10, %0" ::"m"(smem_B[((local_k + 2) * smem_B_cols) + (WN * warp_col + TCN * wn_iter) + col])); + asm volatile("flw f11, %0" ::"m"(smem_B[((local_k + 3) * smem_B_cols) + (WN * warp_col + TCN * wn_iter) + col])); + asm volatile("flw f12, %0" ::"m"(smem_B[((local_k + 4) * smem_B_cols) + (WN * warp_col + TCN * wn_iter) + col])); + asm volatile("flw f13, %0" ::"m"(smem_B[((local_k + 5) * smem_B_cols) + (WN * warp_col + TCN * wn_iter) + col])); + asm volatile("flw f14, %0" ::"m"(smem_B[((local_k + 6) * smem_B_cols) + (WN * warp_col + TCN * wn_iter) + col])); + asm volatile("flw f15, %0" ::"m"(smem_B[((local_k + 7) * smem_B_cols) + (WN * warp_col + TCN * wn_iter) + col])); } inline void initialize_C() { @@ -232,6 +238,14 @@ void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg, const uint32_t global_a_row = BM * threadblock_id_y + local_a_row; const uint32_t global_b_col = BN * threadblock_id_x + local_b_col; + const uint32_t local_c_row = tid_in_threadblock / (BN / TN); + const uint32_t local_c_col = tid_in_threadblock % (BN / TN); + + // each thread generates TM output element + float reg_c[TM * TN] = { 0.0f }; + float reg_a[TM] = { 0.0f }; + float reg_b[TN] = { 0.0f }; + const uint32_t warp_in_threadblock = tid_in_threadblock / NUM_LANES; const uint32_t warp_row = warp_in_threadblock / (BN / WN); const uint32_t warp_col = warp_in_threadblock % (BN / WN); @@ -239,11 +253,9 @@ void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg, volatile float *local_a = sharedmem_per_threadblock; // const size_t local_a_elems = threadblock_dim_x * threadblock_dim_y; - // FIXME: this better be BM * BK, but the GMEM->SMEM load assumes all threads - // in TB participates in the load - const size_t local_a_elems = (BM * BN); + const size_t local_a_elems = (BM * BK); volatile float *local_b = sharedmem_per_threadblock + local_a_elems; - const size_t local_b_elems = (BM * BN); + const size_t local_b_elems = (BK * BN); volatile float *local_warp_results = local_b + local_b_elems + (warp_in_threadblock * TCM * TCN); @@ -281,36 +293,95 @@ void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg, threadblock_barrier(tid_in_threadblock, threadblock_id_in_cluster, threadblock_dim_y); - // perform wmma - // vx_wmma_load(local_a, local_b, warp_x, warp_y, tid_in_warp); - // FIXME: If multiple warps try to issue to Tensor Core at the same time, - // does one stall the other? - // FIXME: this is wrong!! need separate accumulation register for - // WM/WN_ITERS - if (warp_in_threadblock == 0) { +#if USE_TENSOR_CORE + for (uint32_t local_k = 0; local_k < BK; local_k += TCK) { + // perform wmma + // vx_wmma_load(local_a, local_b, warp_x, warp_y, tid_in_warp); + // FIXME: If multiple warps try to issue to Tensor Core at the same time, + // does one stall the other? + // FIXME: this is wrong!! need separate accumulation register for + // WM/WN_ITERS for (int wm_iter = 0; wm_iter < WMITER; wm_iter++) { for (int wn_iter = 0; wn_iter < WNITER; wn_iter++) { - vx_wmma_load(local_a, local_b, warp_col, warp_row, wn_iter, wm_iter, - tid_in_warp); - vx_wmma(); +#if TC_SINGLE_WARP + if (warp_in_threadblock == 0) { +#endif + vx_wmma_load(local_a, local_b, local_k, warp_col, warp_row, wn_iter, + wm_iter, tid_in_warp); + vx_wmma(); +#if TC_SINGLE_WARP + } +#endif } } } +#else + + // Compute single tile*tile matmul +#pragma GCC unroll 4 + for (uint32_t local_k = 0; local_k < BK; local_k++) { + // First, pump data from SMEM->RF +#pragma GCC unroll TM + for (uint32_t res_idx_m = 0; res_idx_m < TM; res_idx_m++) { + reg_a[res_idx_m] = + local_a[BK * (TM * local_c_row + res_idx_m) + local_k]; + } +#pragma GCC unroll TN + for (uint32_t res_idx_n = 0; res_idx_n < TN; res_idx_n++) { + reg_b[res_idx_n] = + local_b[BN * local_k + (TN * local_c_col + res_idx_n)]; + } + + // Next, compute multiple result elements (TM*TN) by reusing data in RF +#pragma GCC unroll TM + for (uint32_t res_idx_m = 0; res_idx_m < TM; res_idx_m++) { +#pragma GCC unroll TN + for (uint32_t res_idx_n = 0; res_idx_n < TN; res_idx_n++) { + // NOTE use of local_b_row + reg_c[TN * res_idx_m + res_idx_n] += + reg_a[res_idx_m] * reg_b[res_idx_n]; + // reg_c[TN * res_idx_m + res_idx_n] += + // local_a[BK * (TM * local_c_row + res_idx_m) + local_k] * + // local_b[BN * local_k + (TN * local_c_col + res_idx_n)]; + } + } + } +#endif + threadblock_barrier(tid_in_threadblock, threadblock_id_in_cluster, threadblock_dim_y); } - if (warp_in_threadblock == 0) { - for (int wm_iter = 0; wm_iter < WMITER; wm_iter++) { - for (int wn_iter = 0; wn_iter < WNITER; wn_iter++) { - write_results(local_warp_results, tid_in_warp, - warp_col, warp_row, - wn_iter, wm_iter, - dim_m, dim_n, C, threadblock_id_x, threadblock_id_y); +#if USE_TENSOR_CORE + for (int wm_iter = 0; wm_iter < WMITER; wm_iter++) { + for (int wn_iter = 0; wn_iter < WNITER; wn_iter++) { +#if TC_SINGLE_WARP + if (warp_in_threadblock == 0) { +#endif + write_results(local_warp_results, tid_in_warp, warp_col, warp_row, + wn_iter, wm_iter, dim_m, dim_n, C, threadblock_id_x, + threadblock_id_y); +#if TC_SINGLE_WARP } +#endif } } + +#else + + // Store result data from RF to GMEM +#pragma GCC unroll TM + for (uint32_t res_idx_m = 0; res_idx_m < TM; res_idx_m++) { +#pragma GCC unroll TN + for (uint32_t res_idx_n = 0; res_idx_n < TN; res_idx_n++) { + C[dim_n * (BM * threadblock_id_y + TM * local_c_row + res_idx_m) + + (BN * threadblock_id_x + TN * local_c_col + res_idx_n)] = + reg_c[TN * res_idx_m + res_idx_n]; + } + } +#endif + } void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { @@ -340,8 +411,9 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { // "static" shared memory allocation. This would determine threadblock // occupancy of a single cluster + // FIXME: 4* is unnecessary; being safe for overlaps float *sharedmem_per_threadblock = - (float *)DEV_SMEM_START_ADDR + (2 * BM * BK) * threadblock_id_in_cluster; + (float *)DEV_SMEM_START_ADDR + (4 * BM * BK) * threadblock_id_in_cluster; thread_block_gemm(arg, tid_in_threadblock, threadblock_dim_x, threadblock_dim_y, threadblock_id_x, threadblock_id_y, threadblock_id_in_cluster, sharedmem_per_threadblock); diff --git a/tests/regression/sgemm_tcore/main.cpp b/tests/regression/sgemm_tcore/main.cpp index e34b7066..5294f27b 100644 --- a/tests/regression/sgemm_tcore/main.cpp +++ b/tests/regression/sgemm_tcore/main.cpp @@ -147,9 +147,9 @@ int main(int argc, char *argv[]) { RT_CHECK(vx_dev_open(&device)); // FIXME: hardcoded - uint32_t dim_m = 64; - uint32_t dim_n = 64; - uint32_t dim_k = 64; + uint32_t dim_m = 32; + uint32_t dim_n = 32; + uint32_t dim_k = 32; generate_source_matrix(dim_m, dim_n, dim_k); generate_reference_matmul(dim_m, dim_n, dim_k); From 78b2a318c1c2a60d39b4c809c5cbef4efa76182b Mon Sep 17 00:00:00 2001 From: Hansung Kim Date: Thu, 16 May 2024 20:22:15 -0700 Subject: [PATCH 22/55] sgemm_tcore: Implement A transpose for coalesced smem access --- tests/regression/sgemm_tcore/kernel.cpp | 133 +++++++++++++++--------- 1 file changed, 85 insertions(+), 48 deletions(-) diff --git a/tests/regression/sgemm_tcore/kernel.cpp b/tests/regression/sgemm_tcore/kernel.cpp index e10a3c0d..6a067453 100644 --- a/tests/regression/sgemm_tcore/kernel.cpp +++ b/tests/regression/sgemm_tcore/kernel.cpp @@ -7,7 +7,7 @@ #include "common.h" #define USE_TENSOR_CORE 1 -#define TC_SINGLE_WARP 0 +#define TC_SINGLE_WARP 1 #define NUM_LANES 8 @@ -23,9 +23,9 @@ // (BM*BN) / (TM*TN) == threadblock size >= NT * CORES_PER_CLUSTER // * Combining BM * BK >= (BM*BN) / (TM*TN) == threadblock yields // BM <= BK*TM*TN -#define BM 16 -#define BN 16 -#define BK 32 +#define BM 8 +#define BN 8 +#define BK 8 #define TCM 8 #define TCN 8 #define TCK 8 @@ -34,9 +34,13 @@ #define WMITER (WM / TCM) #define WNITER (WN / TCN) #define TM 1 -#define TN ((TCM * TCN) / NUM_LANES / TM) -// #define TN 1 +// #define TN ((TCM * TCN) / NUM_LANES / TM) +#define TN 1 +// number of loop around the inner 0..TCK..BK loop to simulate perfect-DRAM +// scenario +#define BK_LOOP 1 +#define TRANSPOSE_AS 1 inline constexpr void map_operand_32lanes(const int tid, int &row, int &col) { const int tg = tid / 4; @@ -130,7 +134,7 @@ inline void vx_wmma() { } // `local_k` is assumed to be multiple of TCK -void vx_wmma_load(volatile float *smem_A, volatile float *smem_B, const int local_k, +inline void vx_wmma_load(volatile float *smem_A, volatile float *smem_B, const int local_k, const int warp_col, const int warp_row, const int wn_iter, const int wm_iter, const int thread_in_warp) { int tid = thread_in_warp; @@ -145,20 +149,32 @@ void vx_wmma_load(volatile float *smem_A, volatile float *smem_B, const int loca int smem_B_rows = BK; int smem_B_cols = BN; - int A_offset = (row + WM * warp_row + TCM * wm_iter) * smem_A_cols; + if constexpr (!TRANSPOSE_AS) { + int A_offset = (row + WM * warp_row + TCM * wm_iter) * smem_A_cols; - // @perf: bank conflicts - asm volatile("flw f0, %0" ::"m"(smem_A[A_offset + (local_k + 0)])); - asm volatile("flw f1, %0" ::"m"(smem_A[A_offset + (local_k + 1)])); - asm volatile("flw f2, %0" ::"m"(smem_A[A_offset + (local_k + 2)])); - asm volatile("flw f3, %0" ::"m"(smem_A[A_offset + (local_k + 3)])); - asm volatile("flw f4, %0" ::"m"(smem_A[A_offset + (local_k + 4)])); - asm volatile("flw f5, %0" ::"m"(smem_A[A_offset + (local_k + 5)])); - asm volatile("flw f6, %0" ::"m"(smem_A[A_offset + (local_k + 6)])); - asm volatile("flw f7, %0" ::"m"(smem_A[A_offset + (local_k + 7)])); + // @perf: bank conflicts + asm volatile("flw f0, %0" ::"m"(smem_A[A_offset + (local_k + 0)])); + asm volatile("flw f1, %0" ::"m"(smem_A[A_offset + (local_k + 1)])); + asm volatile("flw f2, %0" ::"m"(smem_A[A_offset + (local_k + 2)])); + asm volatile("flw f3, %0" ::"m"(smem_A[A_offset + (local_k + 3)])); + asm volatile("flw f4, %0" ::"m"(smem_A[A_offset + (local_k + 4)])); + asm volatile("flw f5, %0" ::"m"(smem_A[A_offset + (local_k + 5)])); + asm volatile("flw f6, %0" ::"m"(smem_A[A_offset + (local_k + 6)])); + asm volatile("flw f7, %0" ::"m"(smem_A[A_offset + (local_k + 7)])); + } else { + // transposed A + asm volatile("flw f0, %0" ::"m"(smem_A[((local_k + 0) * smem_A_rows) + (WM * warp_row + TCM * wm_iter) + row])); + asm volatile("flw f1, %0" ::"m"(smem_A[((local_k + 1) * smem_A_rows) + (WM * warp_row + TCM * wm_iter) + row])); + asm volatile("flw f2, %0" ::"m"(smem_A[((local_k + 2) * smem_A_rows) + (WM * warp_row + TCM * wm_iter) + row])); + asm volatile("flw f3, %0" ::"m"(smem_A[((local_k + 3) * smem_A_rows) + (WM * warp_row + TCM * wm_iter) + row])); + asm volatile("flw f4, %0" ::"m"(smem_A[((local_k + 4) * smem_A_rows) + (WM * warp_row + TCM * wm_iter) + row])); + asm volatile("flw f5, %0" ::"m"(smem_A[((local_k + 5) * smem_A_rows) + (WM * warp_row + TCM * wm_iter) + row])); + asm volatile("flw f6, %0" ::"m"(smem_A[((local_k + 6) * smem_A_rows) + (WM * warp_row + TCM * wm_iter) + row])); + asm volatile("flw f7, %0" ::"m"(smem_A[((local_k + 7) * smem_A_rows) + (WM * warp_row + TCM * wm_iter) + row])); + } - asm volatile("flw f8 , %0" ::"m"(smem_B[((local_k + 0) * smem_B_cols) + (WN * warp_col + TCN * wn_iter) + col])); - asm volatile("flw f9 , %0" ::"m"(smem_B[((local_k + 1) * smem_B_cols) + (WN * warp_col + TCN * wn_iter) + col])); + asm volatile("flw f8, %0" ::"m"(smem_B[((local_k + 0) * smem_B_cols) + (WN * warp_col + TCN * wn_iter) + col])); + asm volatile("flw f9, %0" ::"m"(smem_B[((local_k + 1) * smem_B_cols) + (WN * warp_col + TCN * wn_iter) + col])); asm volatile("flw f10, %0" ::"m"(smem_B[((local_k + 2) * smem_B_cols) + (WN * warp_col + TCN * wn_iter) + col])); asm volatile("flw f11, %0" ::"m"(smem_B[((local_k + 3) * smem_B_cols) + (WN * warp_col + TCN * wn_iter) + col])); asm volatile("flw f12, %0" ::"m"(smem_B[((local_k + 4) * smem_B_cols) + (WN * warp_col + TCN * wn_iter) + col])); @@ -233,10 +249,10 @@ void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg, const uint32_t local_a_row = tid_in_threadblock / BK; const uint32_t local_a_col = tid_in_threadblock % BK; + const uint32_t local_as_row = tid_in_threadblock / BM; + const uint32_t local_as_col = tid_in_threadblock % BM; const uint32_t local_b_row = tid_in_threadblock / BN; const uint32_t local_b_col = tid_in_threadblock % BN; - const uint32_t global_a_row = BM * threadblock_id_y + local_a_row; - const uint32_t global_b_col = BN * threadblock_id_x + local_b_col; const uint32_t local_c_row = tid_in_threadblock / (BN / TN); const uint32_t local_c_col = tid_in_threadblock % (BN / TN); @@ -259,10 +275,6 @@ void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg, volatile float *local_warp_results = local_b + local_b_elems + (warp_in_threadblock * TCM * TCN); - // number of rows a full TB can read at a time - constexpr uint32_t row_stride_a = (BM * BN) / BK / (TM * TN); - constexpr uint32_t row_stride_b = (BM * BN) / BN / (TM * TN); - // clear out C initialize_C(); @@ -273,16 +285,34 @@ void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg, // neighboring threads to ensure GMEM coalescing. // // TODO: Sharedmem swizzling is important here -#pragma GCC unroll 2 - for (uint32_t load_offset = 0; load_offset < BM; load_offset += row_stride_a) { - const uint32_t global_a_offset = - dim_k * (global_a_row + load_offset) + (k + local_a_col); - // FIXME: all threads in TB (BM*BN) will do this load, even if this is - // out-of-bounds of BM*BK - local_a[BK * (local_a_row + load_offset) + local_a_col] = - A[global_a_offset]; + if constexpr (!TRANSPOSE_AS) { + const uint32_t global_a_row = BM * threadblock_id_y + local_a_row; + // number of rows a full TB can read at a time + constexpr uint32_t row_stride_a = (BM * BN) / BK / (TM * TN); +#pragma GCC unroll 1 + for (uint32_t load_offset = 0; load_offset < BM; load_offset += row_stride_a) { + const uint32_t global_a_offset = + dim_k * (global_a_row + load_offset) + (k + local_a_col); + // NOTE: all threads in TB will do this load; make sure this is not + // out-of-bounds of BM*BK + local_a[BK * (local_a_row + load_offset) + local_a_col] = + A[global_a_offset]; + } + } else { + const uint32_t global_a_row = BM * threadblock_id_y + local_as_col; + constexpr uint32_t row_stride_a = (BM * BN) / BM / (TM * TN); +#pragma GCC unroll 1 + for (uint32_t load_offset = 0; load_offset < BK; load_offset += row_stride_a) { + const uint32_t global_a_offset = + dim_k * (global_a_row + load_offset) + (k + local_as_row); + local_a[BM * (local_as_row + load_offset) + local_as_col] = + A[global_a_offset]; + } } -#pragma GCC unroll 2 + + constexpr uint32_t row_stride_b = (BM * BN) / BN / (TM * TN); + const uint32_t global_b_col = BN * threadblock_id_x + local_b_col; +#pragma GCC unroll 1 for (uint32_t load_offset = 0; load_offset < BK; load_offset += row_stride_b) { const uint32_t global_b_offset = dim_n * (k + local_b_row + load_offset) + global_b_col; @@ -294,24 +324,31 @@ void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg, threadblock_dim_y); #if USE_TENSOR_CORE - for (uint32_t local_k = 0; local_k < BK; local_k += TCK) { - // perform wmma - // vx_wmma_load(local_a, local_b, warp_x, warp_y, tid_in_warp); - // FIXME: If multiple warps try to issue to Tensor Core at the same time, - // does one stall the other? - // FIXME: this is wrong!! need separate accumulation register for - // WM/WN_ITERS - for (int wm_iter = 0; wm_iter < WMITER; wm_iter++) { - for (int wn_iter = 0; wn_iter < WNITER; wn_iter++) { +// #pragma GCC unroll 1 + for (int i = 0; i < BK_LOOP; i++) { +#pragma GCC unroll 1 + // @perf: this loop spills to stack a lot because of all the flws in vx_wmma_load + for (uint32_t local_k = 0; local_k < BK; local_k += TCK) { + // perform wmma + // vx_wmma_load(local_a, local_b, warp_x, warp_y, tid_in_warp); + // FIXME: If multiple warps try to issue to Tensor Core at the same time, + // does one stall the other? + // FIXME: this is wrong!! need separate accumulation register for + // WM/WN_ITERS + for (int wm_iter = 0; wm_iter < WMITER; wm_iter++) { + for (int wn_iter = 0; wn_iter < WNITER; wn_iter++) { #if TC_SINGLE_WARP - if (warp_in_threadblock == 0) { + if (warp_in_threadblock == 0) { #endif - vx_wmma_load(local_a, local_b, local_k, warp_col, warp_row, wn_iter, - wm_iter, tid_in_warp); - vx_wmma(); + // SMEM -> RF + vx_wmma_load(local_a, local_b, local_k, warp_col, warp_row, wn_iter, + wm_iter, tid_in_warp); + // compute + vx_wmma(); #if TC_SINGLE_WARP + } +#endif } -#endif } } } From 18ecebddc0ea8a5e1fb722c4b3460f1096344f30 Mon Sep 17 00:00:00 2001 From: Hansung Kim Date: Thu, 16 May 2024 21:36:24 -0700 Subject: [PATCH 23/55] sgemm_tcore: Fix round-down error with CORES_PER_CLUSTER --- tests/regression/sgemm_tcore/kernel.cpp | 26 +++++++++++++------------ 1 file changed, 14 insertions(+), 12 deletions(-) diff --git a/tests/regression/sgemm_tcore/kernel.cpp b/tests/regression/sgemm_tcore/kernel.cpp index 6a067453..2e95f047 100644 --- a/tests/regression/sgemm_tcore/kernel.cpp +++ b/tests/regression/sgemm_tcore/kernel.cpp @@ -6,9 +6,6 @@ #include #include "common.h" -#define USE_TENSOR_CORE 1 -#define TC_SINGLE_WARP 1 - #define NUM_LANES 8 // Constraints on parameters: @@ -23,9 +20,9 @@ // (BM*BN) / (TM*TN) == threadblock size >= NT * CORES_PER_CLUSTER // * Combining BM * BK >= (BM*BN) / (TM*TN) == threadblock yields // BM <= BK*TM*TN -#define BM 8 -#define BN 8 -#define BK 8 +#define BM 32 +#define BN 32 +#define BK 32 #define TCM 8 #define TCN 8 #define TCK 8 @@ -34,12 +31,14 @@ #define WMITER (WM / TCM) #define WNITER (WN / TCN) #define TM 1 -// #define TN ((TCM * TCN) / NUM_LANES / TM) -#define TN 1 +#define TN ((TCM * TCN) / NUM_LANES / TM) +// #define TN 1 +#define USE_TENSOR_CORE 1 +#define TC_SINGLE_WARP 0 // number of loop around the inner 0..TCK..BK loop to simulate perfect-DRAM // scenario -#define BK_LOOP 1 +#define BK_LOOP 8 #define TRANSPOSE_AS 1 inline constexpr void map_operand_32lanes(const int tid, int &row, int &col) { @@ -171,6 +170,10 @@ inline void vx_wmma_load(volatile float *smem_A, volatile float *smem_B, const i asm volatile("flw f5, %0" ::"m"(smem_A[((local_k + 5) * smem_A_rows) + (WM * warp_row + TCM * wm_iter) + row])); asm volatile("flw f6, %0" ::"m"(smem_A[((local_k + 6) * smem_A_rows) + (WM * warp_row + TCM * wm_iter) + row])); asm volatile("flw f7, %0" ::"m"(smem_A[((local_k + 7) * smem_A_rows) + (WM * warp_row + TCM * wm_iter) + row])); +// #pragma GCC unroll 8 +// for (int i = 0; i < 8; i++) { +// asm volatile("flw f0, %0" ::"m"(smem_A[((local_k + i) * smem_A_rows) + (WM * warp_row + TCM * wm_iter) + row])); +// } } asm volatile("flw f8, %0" ::"m"(smem_B[((local_k + 0) * smem_B_cols) + (WN * warp_col + TCN * wn_iter) + col])); @@ -427,9 +430,8 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { const uint32_t threads_per_threadblock = (BM * BN) / (TM * TN); #ifdef RADIANCE - const uint32_t threadblocks_per_core = vx_num_threads() * vx_num_warps() / - threads_per_threadblock * - CORES_PER_CLUSTER; + const uint32_t threadblocks_per_core = CORES_PER_CLUSTER * vx_num_threads() * vx_num_warps() / + threads_per_threadblock; #else const uint32_t threadblocks_per_core = vx_num_threads() * vx_num_warps() / threads_per_threadblock; From b892c22f003a6e2e3233ebea41804a60fcd8cf14 Mon Sep 17 00:00:00 2001 From: Hansung Kim Date: Thu, 16 May 2024 23:31:52 -0700 Subject: [PATCH 24/55] sgemm_tcore: Reflect WMITER/WNITER in threadblock size --- tests/regression/sgemm_tcore/kernel.cpp | 37 ++++++++++++++++++------- tests/regression/sgemm_tcore/main.cpp | 6 ++-- 2 files changed, 30 insertions(+), 13 deletions(-) diff --git a/tests/regression/sgemm_tcore/kernel.cpp b/tests/regression/sgemm_tcore/kernel.cpp index 2e95f047..ad741156 100644 --- a/tests/regression/sgemm_tcore/kernel.cpp +++ b/tests/regression/sgemm_tcore/kernel.cpp @@ -20,9 +20,9 @@ // (BM*BN) / (TM*TN) == threadblock size >= NT * CORES_PER_CLUSTER // * Combining BM * BK >= (BM*BN) / (TM*TN) == threadblock yields // BM <= BK*TM*TN -#define BM 32 -#define BN 32 -#define BK 32 +#define BM 16 +#define BN 16 +#define BK 8 #define TCM 8 #define TCN 8 #define TCK 8 @@ -33,12 +33,13 @@ #define TM 1 #define TN ((TCM * TCN) / NUM_LANES / TM) // #define TN 1 +#define ELEM_PER_THREAD (WMITER * WNITER * TM * TN) #define USE_TENSOR_CORE 1 #define TC_SINGLE_WARP 0 // number of loop around the inner 0..TCK..BK loop to simulate perfect-DRAM // scenario -#define BK_LOOP 8 +#define BK_LOOP 16 #define TRANSPOSE_AS 1 inline constexpr void map_operand_32lanes(const int tid, int &row, int &col) { @@ -281,6 +282,7 @@ void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg, // clear out C initialize_C(); +#pragma GCC unroll 1 for (uint32_t k = 0; k < dim_k; k += BK) { // Data move from GMEM to SMEM // @@ -291,7 +293,7 @@ void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg, if constexpr (!TRANSPOSE_AS) { const uint32_t global_a_row = BM * threadblock_id_y + local_a_row; // number of rows a full TB can read at a time - constexpr uint32_t row_stride_a = (BM * BN) / BK / (TM * TN); + constexpr uint32_t row_stride_a = (BM * BN) / ELEM_PER_THREAD / BK; #pragma GCC unroll 1 for (uint32_t load_offset = 0; load_offset < BM; load_offset += row_stride_a) { const uint32_t global_a_offset = @@ -303,7 +305,7 @@ void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg, } } else { const uint32_t global_a_row = BM * threadblock_id_y + local_as_col; - constexpr uint32_t row_stride_a = (BM * BN) / BM / (TM * TN); + constexpr uint32_t row_stride_a = (BM * BN) / ELEM_PER_THREAD / BM; #pragma GCC unroll 1 for (uint32_t load_offset = 0; load_offset < BK; load_offset += row_stride_a) { const uint32_t global_a_offset = @@ -313,7 +315,7 @@ void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg, } } - constexpr uint32_t row_stride_b = (BM * BN) / BN / (TM * TN); + constexpr uint32_t row_stride_b = (BM * BN) / ELEM_PER_THREAD / BN; const uint32_t global_b_col = BN * threadblock_id_x + local_b_col; #pragma GCC unroll 1 for (uint32_t load_offset = 0; load_offset < BK; load_offset += row_stride_b) { @@ -329,8 +331,8 @@ void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg, #if USE_TENSOR_CORE // #pragma GCC unroll 1 for (int i = 0; i < BK_LOOP; i++) { -#pragma GCC unroll 1 // @perf: this loop spills to stack a lot because of all the flws in vx_wmma_load +#pragma GCC unroll 1 for (uint32_t local_k = 0; local_k < BK; local_k += TCK) { // perform wmma // vx_wmma_load(local_a, local_b, warp_x, warp_y, tid_in_warp); @@ -338,11 +340,24 @@ void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg, // does one stall the other? // FIXME: this is wrong!! need separate accumulation register for // WM/WN_ITERS +#pragma GCC unroll 1 for (int wm_iter = 0; wm_iter < WMITER; wm_iter++) { +#pragma GCC unroll 1 for (int wn_iter = 0; wn_iter < WNITER; wn_iter++) { #if TC_SINGLE_WARP if (warp_in_threadblock == 0) { #endif + // if ((threadblock_id_in_cluster % 2) == 0) { + // asm volatile("addi a0, a0, 0"); + // asm volatile("addi a0, a0, 0"); + // asm volatile("addi a0, a0, 0"); + // asm volatile("addi a0, a0, 0"); + // asm volatile("addi a0, a0, 0"); + // asm volatile("addi a0, a0, 0"); + // asm volatile("addi a0, a0, 0"); + // asm volatile("addi a0, a0, 0"); + // asm volatile("addi a0, a0, 0"); + // } // SMEM -> RF vx_wmma_load(local_a, local_b, local_k, warp_col, warp_row, wn_iter, wm_iter, tid_in_warp); @@ -394,7 +409,9 @@ void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg, } #if USE_TENSOR_CORE +#pragma GCC unroll 1 for (int wm_iter = 0; wm_iter < WMITER; wm_iter++) { +#pragma GCC unroll 1 for (int wn_iter = 0; wn_iter < WNITER; wn_iter++) { #if TC_SINGLE_WARP if (warp_in_threadblock == 0) { @@ -428,7 +445,7 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { // @perf: All threads are running these compute whose result is mostly same // across the threadblock - const uint32_t threads_per_threadblock = (BM * BN) / (TM * TN); + const uint32_t threads_per_threadblock = (BM * BN) / (ELEM_PER_THREAD); #ifdef RADIANCE const uint32_t threadblocks_per_core = CORES_PER_CLUSTER * vx_num_threads() * vx_num_warps() / threads_per_threadblock; @@ -460,7 +477,7 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { int main() { kernel_arg_t *arg = (kernel_arg_t *)KERNEL_ARG_DEV_MEM_ADDR; - const uint32_t grid_size = arg->dim_m * arg->dim_n / (TM * TN); + const uint32_t grid_size = arg->dim_m * arg->dim_n / ELEM_PER_THREAD; #ifdef RADIANCE vx_spawn_tasks_cluster(grid_size, (vx_spawn_tasks_cb)kernel_body, arg); #else diff --git a/tests/regression/sgemm_tcore/main.cpp b/tests/regression/sgemm_tcore/main.cpp index 5294f27b..e34b7066 100644 --- a/tests/regression/sgemm_tcore/main.cpp +++ b/tests/regression/sgemm_tcore/main.cpp @@ -147,9 +147,9 @@ int main(int argc, char *argv[]) { RT_CHECK(vx_dev_open(&device)); // FIXME: hardcoded - uint32_t dim_m = 32; - uint32_t dim_n = 32; - uint32_t dim_k = 32; + uint32_t dim_m = 64; + uint32_t dim_n = 64; + uint32_t dim_k = 64; generate_source_matrix(dim_m, dim_n, dim_k); generate_reference_matmul(dim_m, dim_n, dim_k); From 0a884e1ead0b462a2e27910e37948aa0d47afe05 Mon Sep 17 00:00:00 2001 From: Hansung Kim Date: Sat, 25 May 2024 20:19:57 -0700 Subject: [PATCH 25/55] tensor: spawn on all warps, 8 lanes --- tests/kernel/tensor/main.cpp | 116 ++++++++++++++++++++--------------- 1 file changed, 68 insertions(+), 48 deletions(-) diff --git a/tests/kernel/tensor/main.cpp b/tests/kernel/tensor/main.cpp index 2fadbbd9..7fa759a8 100644 --- a/tests/kernel/tensor/main.cpp +++ b/tests/kernel/tensor/main.cpp @@ -1,10 +1,11 @@ #define RISCV_CUSTOM3 0x7B +#include #include #include #include -constexpr int DIM_M = 16; +constexpr int DIM_M = 8; inline void vx_wmma() { asm volatile (".insn r %0, 0, 0, x0, x0, x0" :: "i"(RISCV_CUSTOM3)); @@ -86,7 +87,7 @@ void vx_wmma_load() { int row = 0; int col = 0; - map_operand_32lanes(tid, row, col); + map_operand_8lanes(tid, row, col); // load A // each operand element is read twice by two threadgroups (Sec. III-B); @@ -110,49 +111,52 @@ void vx_wmma_load() { asm volatile ("flw f14, %0" :: "m"(B[6][col])); asm volatile ("flw f15, %0" :: "m"(B[7][col])); - map_c_32lanes(tid, row, col); + map_c_8lanes(tid, row, col); // load C - asm volatile ("flw f16, %0" :: "m"(C[row+0][col+0])); - asm volatile ("flw f17, %0" :: "m"(C[row+0][col+1])); - asm volatile ("flw f18, %0" :: "m"(C[row+2][col+0])); - asm volatile ("flw f19, %0" :: "m"(C[row+2][col+1])); - asm volatile ("flw f20, %0" :: "m"(C[row+0][col+4])); - asm volatile ("flw f21, %0" :: "m"(C[row+0][col+5])); - asm volatile ("flw f22, %0" :: "m"(C[row+2][col+4])); - asm volatile ("flw f23, %0" :: "m"(C[row+2][col+5])); + asm volatile ("flw f16, %0" :: "m"(C[row+0][col+0])); + asm volatile ("flw f17, %0" :: "m"(C[row+0][col+1])); + asm volatile ("flw f18, %0" :: "m"(C[row+2][col+0])); + asm volatile ("flw f19, %0" :: "m"(C[row+2][col+1])); + asm volatile ("flw f20, %0" :: "m"(C[row+0][col+4])); + asm volatile ("flw f21, %0" :: "m"(C[row+0][col+5])); + asm volatile ("flw f22, %0" :: "m"(C[row+2][col+4])); + asm volatile ("flw f23, %0" :: "m"(C[row+2][col+5])); } // float results[32*8]; float *const results = reinterpret_cast(0xc0000000UL); void store_wmma_result() { - int tid = vx_thread_id(); - int tg = tid / 4; + int wid = vx_warp_id(); + int tid = vx_thread_id(); + int tg = tid / 4; - int row = 0; - int col = 0; + int row = 0; + int col = 0; - map_c_32lanes(tid, row, col); + map_c_8lanes(tid, row, col); - // store C - // asm volatile ("fsw f16, %0" :: "m"(results[tid*8+0])); - // asm volatile ("fsw f17, %0" :: "m"(results[tid*8+1])); - // asm volatile ("fsw f18, %0" :: "m"(results[tid*8+2])); - // asm volatile ("fsw f19, %0" :: "m"(results[tid*8+3])); - // asm volatile ("fsw f20, %0" :: "m"(results[tid*8+4])); - // asm volatile ("fsw f21, %0" :: "m"(results[tid*8+5])); - // asm volatile ("fsw f22, %0" :: "m"(results[tid*8+6])); - // asm volatile ("fsw f23, %0" :: "m"(results[tid*8+7])); + // store C + // asm volatile ("fsw f16, %0" :: "m"(results[tid*8+0])); + // asm volatile ("fsw f17, %0" :: "m"(results[tid*8+1])); + // asm volatile ("fsw f18, %0" :: "m"(results[tid*8+2])); + // asm volatile ("fsw f19, %0" :: "m"(results[tid*8+3])); + // asm volatile ("fsw f20, %0" :: "m"(results[tid*8+4])); + // asm volatile ("fsw f21, %0" :: "m"(results[tid*8+5])); + // asm volatile ("fsw f22, %0" :: "m"(results[tid*8+6])); + // asm volatile ("fsw f23, %0" :: "m"(results[tid*8+7])); - asm volatile ("fsw f16, %0" :: "m"(results[DIM_M * (row + 0) + (col + 0)])); - asm volatile ("fsw f17, %0" :: "m"(results[DIM_M * (row + 0) + (col + 1)])); - asm volatile ("fsw f18, %0" :: "m"(results[DIM_M * (row + 2) + (col + 0)])); - asm volatile ("fsw f19, %0" :: "m"(results[DIM_M * (row + 2) + (col + 1)])); - asm volatile ("fsw f20, %0" :: "m"(results[DIM_M * (row + 0) + (col + 4)])); - asm volatile ("fsw f21, %0" :: "m"(results[DIM_M * (row + 0) + (col + 5)])); - asm volatile ("fsw f22, %0" :: "m"(results[DIM_M * (row + 2) + (col + 4)])); - asm volatile ("fsw f23, %0" :: "m"(results[DIM_M * (row + 2) + (col + 5)])); + float *const results_wid = results + (DIM_M * DIM_M * wid); + + asm volatile("fsw f16, %0" ::"m"(results_wid[DIM_M * (row + 0) + (col + 0)])); + asm volatile("fsw f17, %0" ::"m"(results_wid[DIM_M * (row + 0) + (col + 1)])); + asm volatile("fsw f18, %0" ::"m"(results_wid[DIM_M * (row + 2) + (col + 0)])); + asm volatile("fsw f19, %0" ::"m"(results_wid[DIM_M * (row + 2) + (col + 1)])); + asm volatile("fsw f20, %0" ::"m"(results_wid[DIM_M * (row + 0) + (col + 4)])); + asm volatile("fsw f21, %0" ::"m"(results_wid[DIM_M * (row + 0) + (col + 5)])); + asm volatile("fsw f22, %0" ::"m"(results_wid[DIM_M * (row + 2) + (col + 4)])); + asm volatile("fsw f23, %0" ::"m"(results_wid[DIM_M * (row + 2) + (col + 5)])); } void print_wmma_result() { @@ -160,23 +164,39 @@ void print_wmma_result() { for (int tid = 0; tid < num_threads; tid += 1) { for (int reg = 0; reg < 8; reg += 1) { - vx_printf("thread %d, f%d: %x\n", tid, 16+reg, *((int*) &results[tid*8+reg])); + vx_printf("thread %d, f%d: %x\n", tid, 16 + reg, + *((int *)&results[tid * 8 + reg])); } } } -int main() -{ - vx_tmc(-1); - vx_wmma_load(); -// #pragma GCC unroll 100 -// for (int i = 0; i < 100; i++) { -// vx_wmma(); -// } - vx_wmma(); - store_wmma_result(); - vx_tmc(1); - // print_wmma_result(); - - return 0; +void wmma() { + vx_tmc(-1); + + // if (vx_warp_id() == 1) { + // for (int i = 0; i < 100; i++) { + // asm volatile ("nop"); + // } + // } + + vx_wmma_load(); + // #pragma GCC unroll 100 + // for (int i = 0; i < 100; i++) { + // vx_wmma(); + // } + vx_wmma(); + + store_wmma_result(); + // print_wmma_result(); + vx_tmc(1); +} + +int main() { + const int num_warps = vx_num_warps(); + + vx_wspawn(num_warps, wmma); + wmma(); + vx_wspawn_wait(); + + return 0; } From bc7bd1a1dd4470b5f346381595fefb5dbf6aa124 Mon Sep 17 00:00:00 2001 From: Hansung Kim Date: Sat, 25 May 2024 22:47:15 -0700 Subject: [PATCH 26/55] sgemm_tcore: Write reference C matrix to file --- tests/regression/sgemm_tcore/main.cpp | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/tests/regression/sgemm_tcore/main.cpp b/tests/regression/sgemm_tcore/main.cpp index e34b7066..0fbd838b 100644 --- a/tests/regression/sgemm_tcore/main.cpp +++ b/tests/regression/sgemm_tcore/main.cpp @@ -108,6 +108,14 @@ int run_test(const kernel_arg_t& kernel_arg, std::cout << "download destination buffer" << std::endl; RT_CHECK(vx_copy_from_dev(device, staging_buf.data(), kernel_arg.addr_c, buf_size)); + std::ofstream ref_file("reference.c.bin", std::ios::binary | std::ios::out); + if (!ref_file) { + std::cerr << "error: failed to open reference.c.bin for writing\n"; + exit(EXIT_FAILURE); + } + ref_file.write(reinterpret_cast(ref_data.data()), buf_size); + ref_file.close(); + // verify result std::cout << "verify result" << std::endl; { @@ -147,9 +155,9 @@ int main(int argc, char *argv[]) { RT_CHECK(vx_dev_open(&device)); // FIXME: hardcoded - uint32_t dim_m = 64; - uint32_t dim_n = 64; - uint32_t dim_k = 64; + uint32_t dim_m = 16; + uint32_t dim_n = 16; + uint32_t dim_k = 16; generate_source_matrix(dim_m, dim_n, dim_k); generate_reference_matmul(dim_m, dim_n, dim_k); From 200fd3e08ca8d05e9ae2ca90695e81477a44bd2b Mon Sep 17 00:00:00 2001 From: Hansung Kim Date: Sat, 25 May 2024 22:47:59 -0700 Subject: [PATCH 27/55] sgemm_tcore: Revert to packed smem alloc --- tests/regression/sgemm_tcore/kernel.cpp | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/tests/regression/sgemm_tcore/kernel.cpp b/tests/regression/sgemm_tcore/kernel.cpp index ad741156..e7fd9976 100644 --- a/tests/regression/sgemm_tcore/kernel.cpp +++ b/tests/regression/sgemm_tcore/kernel.cpp @@ -308,6 +308,7 @@ void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg, constexpr uint32_t row_stride_a = (BM * BN) / ELEM_PER_THREAD / BM; #pragma GCC unroll 1 for (uint32_t load_offset = 0; load_offset < BK; load_offset += row_stride_a) { + // @perf: bank conflicts here const uint32_t global_a_offset = dim_k * (global_a_row + load_offset) + (k + local_as_row); local_a[BM * (local_as_row + load_offset) + local_as_col] = @@ -447,7 +448,8 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { const uint32_t threads_per_threadblock = (BM * BN) / (ELEM_PER_THREAD); #ifdef RADIANCE - const uint32_t threadblocks_per_core = CORES_PER_CLUSTER * vx_num_threads() * vx_num_warps() / + const uint32_t threadblocks_per_core = CORES_PER_CLUSTER * vx_num_threads() * + vx_num_warps() / threads_per_threadblock; #else const uint32_t threadblocks_per_core = @@ -467,9 +469,8 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { // "static" shared memory allocation. This would determine threadblock // occupancy of a single cluster - // FIXME: 4* is unnecessary; being safe for overlaps float *sharedmem_per_threadblock = - (float *)DEV_SMEM_START_ADDR + (4 * BM * BK) * threadblock_id_in_cluster; + (float *)DEV_SMEM_START_ADDR + (2 * BM * BK) * threadblock_id_in_cluster; thread_block_gemm(arg, tid_in_threadblock, threadblock_dim_x, threadblock_dim_y, threadblock_id_x, threadblock_id_y, threadblock_id_in_cluster, sharedmem_per_threadblock); From 1e48bad4f9d3de20d998376ddfedb5c5568891f4 Mon Sep 17 00:00:00 2001 From: Hansung Kim Date: Sun, 26 May 2024 13:51:47 -0700 Subject: [PATCH 28/55] sgemm_tcore: Fix AS transpose --- tests/regression/sgemm_tcore/kernel.cpp | 99 ++++++++++++++----------- 1 file changed, 55 insertions(+), 44 deletions(-) diff --git a/tests/regression/sgemm_tcore/kernel.cpp b/tests/regression/sgemm_tcore/kernel.cpp index e7fd9976..74cfd38a 100644 --- a/tests/regression/sgemm_tcore/kernel.cpp +++ b/tests/regression/sgemm_tcore/kernel.cpp @@ -8,6 +8,13 @@ #define NUM_LANES 8 +#define USE_TENSOR_CORE 1 +#define TC_SINGLE_WARP 0 +// number of loop around the inner 0..TCK..BK loop to simulate perfect-DRAM +// scenario +#define BK_LOOP 1 +#define TRANSPOSE_AS 1 + // Constraints on parameters: // * Memory: // (BM + BN) * BK * sizeof(float) <= sharedmem size. @@ -20,28 +27,25 @@ // (BM*BN) / (TM*TN) == threadblock size >= NT * CORES_PER_CLUSTER // * Combining BM * BK >= (BM*BN) / (TM*TN) == threadblock yields // BM <= BK*TM*TN -#define BM 16 -#define BN 16 +#define BM 8 +#define BN 8 #define BK 8 +#define WM 8 +#define WN 8 #define TCM 8 #define TCN 8 #define TCK 8 -#define WM 8 -#define WN 8 #define WMITER (WM / TCM) #define WNITER (WN / TCN) +#if USE_TENSOR_CORE == 1 #define TM 1 #define TN ((TCM * TCN) / NUM_LANES / TM) -// #define TN 1 +#else +#define TM 1 +#define TN 1 +#endif #define ELEM_PER_THREAD (WMITER * WNITER * TM * TN) -#define USE_TENSOR_CORE 1 -#define TC_SINGLE_WARP 0 -// number of loop around the inner 0..TCK..BK loop to simulate perfect-DRAM -// scenario -#define BK_LOOP 16 -#define TRANSPOSE_AS 1 - inline constexpr void map_operand_32lanes(const int tid, int &row, int &col) { const int tg = tid / 4; @@ -137,46 +141,51 @@ inline void vx_wmma() { inline void vx_wmma_load(volatile float *smem_A, volatile float *smem_B, const int local_k, const int warp_col, const int warp_row, const int wn_iter, const int wm_iter, const int thread_in_warp) { - int tid = thread_in_warp; - int tg = tid / 4; + const int tid = thread_in_warp; + const int tg = tid / 4; int row = 0; int col = 0; map_operand(tid, row, col); - int smem_A_rows = BM; - int smem_A_cols = BK; - int smem_B_rows = BK; - int smem_B_cols = BN; + constexpr int smem_A_rows = BM; + constexpr int smem_A_cols = BK; + constexpr int smem_AS_rows = BK; + constexpr int smem_AS_cols = BM; + constexpr int smem_B_rows = BK; + constexpr int smem_B_cols = BN; if constexpr (!TRANSPOSE_AS) { - int A_offset = (row + WM * warp_row + TCM * wm_iter) * smem_A_cols; + int A_offset = (WM * warp_row + TCM * wm_iter + row) * smem_A_cols; // @perf: bank conflicts - asm volatile("flw f0, %0" ::"m"(smem_A[A_offset + (local_k + 0)])); - asm volatile("flw f1, %0" ::"m"(smem_A[A_offset + (local_k + 1)])); - asm volatile("flw f2, %0" ::"m"(smem_A[A_offset + (local_k + 2)])); - asm volatile("flw f3, %0" ::"m"(smem_A[A_offset + (local_k + 3)])); - asm volatile("flw f4, %0" ::"m"(smem_A[A_offset + (local_k + 4)])); - asm volatile("flw f5, %0" ::"m"(smem_A[A_offset + (local_k + 5)])); - asm volatile("flw f6, %0" ::"m"(smem_A[A_offset + (local_k + 6)])); - asm volatile("flw f7, %0" ::"m"(smem_A[A_offset + (local_k + 7)])); + // f8-f15 stores a single row of A + asm volatile("flw f0, %0" ::"m"(smem_A[A_offset + (local_k + 0)])); + asm volatile("flw f1, %0" ::"m"(smem_A[A_offset + (local_k + 1)])); + asm volatile("flw f2, %0" ::"m"(smem_A[A_offset + (local_k + 2)])); + asm volatile("flw f3, %0" ::"m"(smem_A[A_offset + (local_k + 3)])); + asm volatile("flw f4, %0" ::"m"(smem_A[A_offset + (local_k + 4)])); + asm volatile("flw f5, %0" ::"m"(smem_A[A_offset + (local_k + 5)])); + asm volatile("flw f6, %0" ::"m"(smem_A[A_offset + (local_k + 6)])); + asm volatile("flw f7, %0" ::"m"(smem_A[A_offset + (local_k + 7)])); } else { // transposed A - asm volatile("flw f0, %0" ::"m"(smem_A[((local_k + 0) * smem_A_rows) + (WM * warp_row + TCM * wm_iter) + row])); - asm volatile("flw f1, %0" ::"m"(smem_A[((local_k + 1) * smem_A_rows) + (WM * warp_row + TCM * wm_iter) + row])); - asm volatile("flw f2, %0" ::"m"(smem_A[((local_k + 2) * smem_A_rows) + (WM * warp_row + TCM * wm_iter) + row])); - asm volatile("flw f3, %0" ::"m"(smem_A[((local_k + 3) * smem_A_rows) + (WM * warp_row + TCM * wm_iter) + row])); - asm volatile("flw f4, %0" ::"m"(smem_A[((local_k + 4) * smem_A_rows) + (WM * warp_row + TCM * wm_iter) + row])); - asm volatile("flw f5, %0" ::"m"(smem_A[((local_k + 5) * smem_A_rows) + (WM * warp_row + TCM * wm_iter) + row])); - asm volatile("flw f6, %0" ::"m"(smem_A[((local_k + 6) * smem_A_rows) + (WM * warp_row + TCM * wm_iter) + row])); - asm volatile("flw f7, %0" ::"m"(smem_A[((local_k + 7) * smem_A_rows) + (WM * warp_row + TCM * wm_iter) + row])); + // f8-f15 stores a single row of A + asm volatile("flw f0, %0" ::"m"(smem_A[((local_k + 0) * smem_AS_cols) + (WM * warp_row + TCM * wm_iter) + row])); + asm volatile("flw f1, %0" ::"m"(smem_A[((local_k + 1) * smem_AS_cols) + (WM * warp_row + TCM * wm_iter) + row])); + asm volatile("flw f2, %0" ::"m"(smem_A[((local_k + 2) * smem_AS_cols) + (WM * warp_row + TCM * wm_iter) + row])); + asm volatile("flw f3, %0" ::"m"(smem_A[((local_k + 3) * smem_AS_cols) + (WM * warp_row + TCM * wm_iter) + row])); + asm volatile("flw f4, %0" ::"m"(smem_A[((local_k + 4) * smem_AS_cols) + (WM * warp_row + TCM * wm_iter) + row])); + asm volatile("flw f5, %0" ::"m"(smem_A[((local_k + 5) * smem_AS_cols) + (WM * warp_row + TCM * wm_iter) + row])); + asm volatile("flw f6, %0" ::"m"(smem_A[((local_k + 6) * smem_AS_cols) + (WM * warp_row + TCM * wm_iter) + row])); + asm volatile("flw f7, %0" ::"m"(smem_A[((local_k + 7) * smem_AS_cols) + (WM * warp_row + TCM * wm_iter) + row])); // #pragma GCC unroll 8 // for (int i = 0; i < 8; i++) { // asm volatile("flw f0, %0" ::"m"(smem_A[((local_k + i) * smem_A_rows) + (WM * warp_row + TCM * wm_iter) + row])); // } } + // f8-f15 stores a single column of B asm volatile("flw f8, %0" ::"m"(smem_B[((local_k + 0) * smem_B_cols) + (WN * warp_col + TCN * wn_iter) + col])); asm volatile("flw f9, %0" ::"m"(smem_B[((local_k + 1) * smem_B_cols) + (WN * warp_col + TCN * wn_iter) + col])); asm volatile("flw f10, %0" ::"m"(smem_B[((local_k + 2) * smem_B_cols) + (WN * warp_col + TCN * wn_iter) + col])); @@ -295,29 +304,31 @@ void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg, // number of rows a full TB can read at a time constexpr uint32_t row_stride_a = (BM * BN) / ELEM_PER_THREAD / BK; #pragma GCC unroll 1 - for (uint32_t load_offset = 0; load_offset < BM; load_offset += row_stride_a) { + for (uint32_t local_row_offset = 0; local_row_offset < BM; + local_row_offset += row_stride_a) { const uint32_t global_a_offset = - dim_k * (global_a_row + load_offset) + (k + local_a_col); + dim_k * (global_a_row + local_row_offset) + (k + local_a_col); // NOTE: all threads in TB will do this load; make sure this is not // out-of-bounds of BM*BK - local_a[BK * (local_a_row + load_offset) + local_a_col] = + local_a[BK * (local_a_row + local_row_offset) + local_a_col] = A[global_a_offset]; } } else { const uint32_t global_a_row = BM * threadblock_id_y + local_as_col; - constexpr uint32_t row_stride_a = (BM * BN) / ELEM_PER_THREAD / BM; + constexpr uint32_t row_stride_as = (BM * BN) / ELEM_PER_THREAD / BM; #pragma GCC unroll 1 - for (uint32_t load_offset = 0; load_offset < BK; load_offset += row_stride_a) { + for (uint32_t local_row_offset = 0; local_row_offset < BK; + local_row_offset += row_stride_as) { // @perf: bank conflicts here const uint32_t global_a_offset = - dim_k * (global_a_row + load_offset) + (k + local_as_row); - local_a[BM * (local_as_row + load_offset) + local_as_col] = + dim_k * (global_a_row) + (k + local_as_row + local_row_offset); + local_a[BM * (local_as_row + local_row_offset) + local_as_col] = A[global_a_offset]; } } - constexpr uint32_t row_stride_b = (BM * BN) / ELEM_PER_THREAD / BN; - const uint32_t global_b_col = BN * threadblock_id_x + local_b_col; + constexpr uint32_t row_stride_b = (BM * BN) / ELEM_PER_THREAD / BN; + const uint32_t global_b_col = BN * threadblock_id_x + local_b_col; #pragma GCC unroll 1 for (uint32_t load_offset = 0; load_offset < BK; load_offset += row_stride_b) { const uint32_t global_b_offset = From c08a4cba8bebdc12ab2cace1272d620ff517ab8c Mon Sep 17 00:00:00 2001 From: Hansung Kim Date: Sun, 26 May 2024 13:56:34 -0700 Subject: [PATCH 29/55] Add -ffixed-regs to tests/kernel makefile --- tests/kernel/common.mk | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests/kernel/common.mk b/tests/kernel/common.mk index e276c624..b58b3110 100644 --- a/tests/kernel/common.mk +++ b/tests/kernel/common.mk @@ -21,6 +21,9 @@ CP = $(RISCV_TOOLCHAIN_PATH)/bin/$(RISCV_PREFIX)-objcopy SIM_DIR = ../../../sim CFLAGS += -O3 -mcmodel=medany -fno-exceptions -nostartfiles -fdata-sections -ffunction-sections +CFLAGS += -ffixed-ft0 -ffixed-ft1 -ffixed-ft2 -ffixed-ft3 -ffixed-ft4 -ffixed-ft5 -ffixed-ft6 -ffixed-ft7 +CFLAGS += -ffixed-fs0 -ffixed-fs1 -ffixed-fs2 -ffixed-fs3 -ffixed-fs4 -ffixed-fs5 -ffixed-fs6 -ffixed-fs7 +CFLAGS += -ffixed-fa0 -ffixed-fa1 -ffixed-fa2 -ffixed-fa3 -ffixed-fa4 -ffixed-fa5 -ffixed-fa6 -ffixed-fa7 CFLAGS += -I$(VORTEX_KN_PATH)/include -I$(VORTEX_KN_PATH)/../hw LDFLAGS += -lm -Wl,-Bstatic,--gc-sections,-T,$(VORTEX_KN_PATH)/linker/vx_link$(XLEN).ld,--defsym=STARTUP_ADDR=0x80000000 $(VORTEX_KN_PATH)/libvortexrt.a From 2b5836022da9cc0adbd25d08867b22650dbe953f Mon Sep 17 00:00:00 2001 From: Hansung Kim Date: Sun, 26 May 2024 13:57:07 -0700 Subject: [PATCH 30/55] Also generate kernel.CONFIG.elf --- tests/regression/common.mk | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/tests/regression/common.mk b/tests/regression/common.mk index 82aefd4d..33e78828 100644 --- a/tests/regression/common.mk +++ b/tests/regression/common.mk @@ -82,7 +82,7 @@ endif # CONFIG is supplied from the command line to differentiate ELF files with custom suffixes CONFIGEXT = $(if $(CONFIG),.$(CONFIG),) -all: $(PROJECT) kernel.bin kernel.dump kernel.radiance.dump kernel.radiance$(CONFIGEXT).dump +all: $(PROJECT) kernel.bin kernel.dump kernel.radiance.dump kernel$(CONFIGEXT).dump kernel.radiance$(CONFIGEXT).dump kernel.dump: kernel.elf $(VX_DP) -D kernel.elf > kernel.dump @@ -91,6 +91,9 @@ kernel.radiance.dump: kernel.radiance.elf $(VX_DP) -D kernel.radiance.elf > kernel.radiance.dump ifneq ($(CONFIG),) +kernel$(CONFIGEXT).dump: kernel$(CONFIGEXT).elf + $(VX_DP) -D kernel$(CONFIGEXT).elf > kernel$(CONFIGEXT).dump + kernel.radiance$(CONFIGEXT).dump: kernel.radiance$(CONFIGEXT).elf $(VX_DP) -D kernel.radiance$(CONFIGEXT).elf > kernel.radiance$(CONFIGEXT).dump endif @@ -115,6 +118,9 @@ kernel.radiance.elf: $(VX_SRCS) $(OBJCOPY) --update-section .operand.b=input.b.bin $@ ifneq ($(CONFIG),) +kernel$(CONFIGEXT).elf: kernel.elf + cp $< $@ + kernel.radiance$(CONFIGEXT).elf: kernel.radiance.elf cp $< $@ endif From 220ee0aa5ecebfcdf28f5b0cd912a80d369b1e22 Mon Sep 17 00:00:00 2001 From: Hansung Kim Date: Fri, 31 May 2024 17:35:01 -0700 Subject: [PATCH 31/55] sgemm_tcore: Unroll around WMITER/WNITER This is within a very tight loop so it's worth unrolling at the risk of stack spills somewhere else. --- tests/regression/sgemm_tcore/kernel.cpp | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/tests/regression/sgemm_tcore/kernel.cpp b/tests/regression/sgemm_tcore/kernel.cpp index 74cfd38a..5e048fc5 100644 --- a/tests/regression/sgemm_tcore/kernel.cpp +++ b/tests/regression/sgemm_tcore/kernel.cpp @@ -341,20 +341,19 @@ void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg, threadblock_dim_y); #if USE_TENSOR_CORE -// #pragma GCC unroll 1 + // @perf: this loop spills to stack a lot because of all the flws in + // vx_wmma_load +#pragma GCC unroll 1 for (int i = 0; i < BK_LOOP; i++) { - // @perf: this loop spills to stack a lot because of all the flws in vx_wmma_load #pragma GCC unroll 1 for (uint32_t local_k = 0; local_k < BK; local_k += TCK) { // perform wmma // vx_wmma_load(local_a, local_b, warp_x, warp_y, tid_in_warp); - // FIXME: If multiple warps try to issue to Tensor Core at the same time, - // does one stall the other? // FIXME: this is wrong!! need separate accumulation register for // WM/WN_ITERS -#pragma GCC unroll 1 +#pragma GCC unroll 2 for (int wm_iter = 0; wm_iter < WMITER; wm_iter++) { -#pragma GCC unroll 1 +#pragma GCC unroll 2 for (int wn_iter = 0; wn_iter < WNITER; wn_iter++) { #if TC_SINGLE_WARP if (warp_in_threadblock == 0) { From c8d6c56dd9c952af91d327814855fe50739beccc Mon Sep 17 00:00:00 2001 From: Hansung Kim Date: Fri, 31 May 2024 17:37:06 -0700 Subject: [PATCH 32/55] sgemm_tcore: Split global DMEM load into a function --- tests/regression/sgemm_tcore/kernel.cpp | 122 ++++++++++++++---------- 1 file changed, 73 insertions(+), 49 deletions(-) diff --git a/tests/regression/sgemm_tcore/kernel.cpp b/tests/regression/sgemm_tcore/kernel.cpp index 5e048fc5..4ded4758 100644 --- a/tests/regression/sgemm_tcore/kernel.cpp +++ b/tests/regression/sgemm_tcore/kernel.cpp @@ -239,9 +239,69 @@ inline void write_results(volatile float *local_warp_results, asm volatile ("fsw f23, %0" :: "m"(global_offset_C[dim_n * (local_row + 2) + (local_col + 5)])); } -void threadblock_barrier(unsigned int tid_in_threadblock, unsigned int barrier_id, unsigned int count) { - vx_fence(); - vx_barrier(barrier_id, count); +inline void threadblock_barrier(unsigned int tid_in_threadblock, + unsigned int barrier_id, unsigned int count) { + vx_fence(); + vx_barrier(barrier_id, count); +} + +inline void +global_dmem_load(const uint32_t dim_n, const uint32_t dim_k, const uint32_t k, + const float *A, const float *B, volatile float *local_a, + volatile float *local_b, const uint32_t threadblock_id_x, + const uint32_t threadblock_id_y, const uint32_t local_a_row, + const uint32_t local_a_col, const uint32_t local_as_row, + const uint32_t local_as_col, const uint32_t local_b_row, + const uint32_t local_b_col) { + + // Data move from GMEM to SMEM + // + // Make sure global offset values for A and B are contiguous between + // neighboring threads to ensure GMEM coalescing. + // + // TODO: Sharedmem swizzling is important here + if constexpr (!TRANSPOSE_AS) { + const uint32_t global_a_row = BM * threadblock_id_y + local_a_row; + // number of rows a full TB can read at a time + constexpr uint32_t row_stride_a = (BM * BN) / ELEM_PER_THREAD / BK; +#pragma GCC unroll 1 + for (uint32_t local_row_offset = 0; local_row_offset < BM; + local_row_offset += row_stride_a) { + const uint32_t global_a_offset = + dim_k * (global_a_row + local_row_offset) + (k + local_a_col); + // NOTE: all threads in TB will do this load; make sure this is not + // out-of-bounds of BM*BK + local_a[BK * (local_a_row + local_row_offset) + local_a_col] = + A[global_a_offset]; + } + } else { + const uint32_t global_a_row = BM * threadblock_id_y + local_as_col; + // const uint32_t global_a_row = BM * threadblock_id_y + local_as_row; + constexpr uint32_t row_stride_as = (BM * BN) / ELEM_PER_THREAD / BM; +#pragma GCC unroll 1 + for (uint32_t local_row_offset = 0; local_row_offset < BK; + local_row_offset += row_stride_as) { + // @perf: bank conflicts here + const uint32_t global_a_offset = + dim_k * (global_a_row) + (k + local_as_row + local_row_offset); + // FIXME experimenting with global coalescing + // const uint32_t global_a_offset = + // dim_k * (global_a_row + local_row_offset) + (k + local_as_col); + local_a[BM * (local_as_row + local_row_offset) + local_as_col] = + A[global_a_offset]; + } + } + + constexpr uint32_t row_stride_b = (BM * BN) / ELEM_PER_THREAD / BN; + const uint32_t global_b_col = BN * threadblock_id_x + local_b_col; +#pragma GCC unroll 1 + for (uint32_t load_offset = 0; load_offset < BK; + load_offset += row_stride_b) { + const uint32_t global_b_offset = + dim_n * (k + local_b_row + load_offset) + global_b_col; + local_b[BN * (local_b_row + load_offset) + local_b_col] = + B[global_b_offset]; + } } void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg, @@ -293,49 +353,10 @@ void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg, #pragma GCC unroll 1 for (uint32_t k = 0; k < dim_k; k += BK) { - // Data move from GMEM to SMEM - // - // Make sure global offset values for A and B are contiguous between - // neighboring threads to ensure GMEM coalescing. - // - // TODO: Sharedmem swizzling is important here - if constexpr (!TRANSPOSE_AS) { - const uint32_t global_a_row = BM * threadblock_id_y + local_a_row; - // number of rows a full TB can read at a time - constexpr uint32_t row_stride_a = (BM * BN) / ELEM_PER_THREAD / BK; -#pragma GCC unroll 1 - for (uint32_t local_row_offset = 0; local_row_offset < BM; - local_row_offset += row_stride_a) { - const uint32_t global_a_offset = - dim_k * (global_a_row + local_row_offset) + (k + local_a_col); - // NOTE: all threads in TB will do this load; make sure this is not - // out-of-bounds of BM*BK - local_a[BK * (local_a_row + local_row_offset) + local_a_col] = - A[global_a_offset]; - } - } else { - const uint32_t global_a_row = BM * threadblock_id_y + local_as_col; - constexpr uint32_t row_stride_as = (BM * BN) / ELEM_PER_THREAD / BM; -#pragma GCC unroll 1 - for (uint32_t local_row_offset = 0; local_row_offset < BK; - local_row_offset += row_stride_as) { - // @perf: bank conflicts here - const uint32_t global_a_offset = - dim_k * (global_a_row) + (k + local_as_row + local_row_offset); - local_a[BM * (local_as_row + local_row_offset) + local_as_col] = - A[global_a_offset]; - } - } - - constexpr uint32_t row_stride_b = (BM * BN) / ELEM_PER_THREAD / BN; - const uint32_t global_b_col = BN * threadblock_id_x + local_b_col; -#pragma GCC unroll 1 - for (uint32_t load_offset = 0; load_offset < BK; load_offset += row_stride_b) { - const uint32_t global_b_offset = - dim_n * (k + local_b_row + load_offset) + global_b_col; - local_b[BN * (local_b_row + load_offset) + local_b_col] = - B[global_b_offset]; - } + global_dmem_load(dim_n, dim_k, k, A, B, local_a, local_b, + threadblock_id_x, threadblock_id_y, local_a_row, + local_a_col, local_as_row, local_as_col, local_b_row, + local_b_col); threadblock_barrier(tid_in_threadblock, threadblock_id_in_cluster, threadblock_dim_y); @@ -370,8 +391,8 @@ void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg, // asm volatile("addi a0, a0, 0"); // } // SMEM -> RF - vx_wmma_load(local_a, local_b, local_k, warp_col, warp_row, wn_iter, - wm_iter, tid_in_warp); + vx_wmma_load(local_a, local_b, local_k, warp_col, warp_row, + wn_iter, wm_iter, tid_in_warp); // compute vx_wmma(); #if TC_SINGLE_WARP @@ -382,6 +403,9 @@ void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg, } } + threadblock_barrier(tid_in_threadblock, threadblock_id_in_cluster, + threadblock_dim_y); + #else // Compute single tile*tile matmul @@ -413,10 +437,10 @@ void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg, } } } -#endif threadblock_barrier(tid_in_threadblock, threadblock_id_in_cluster, threadblock_dim_y); +#endif } #if USE_TENSOR_CORE From 4e723c46558d71a1a033bfdf3409baebfaa31e1d Mon Sep 17 00:00:00 2001 From: Hansung Kim Date: Sat, 1 Jun 2024 01:12:08 -0700 Subject: [PATCH 33/55] sgemm_tcore: Support two accumulation reg tiles --- tests/regression/sgemm_tcore/kernel.cpp | 86 ++++++++++++++++--------- 1 file changed, 57 insertions(+), 29 deletions(-) diff --git a/tests/regression/sgemm_tcore/kernel.cpp b/tests/regression/sgemm_tcore/kernel.cpp index 4ded4758..4ac80775 100644 --- a/tests/regression/sgemm_tcore/kernel.cpp +++ b/tests/regression/sgemm_tcore/kernel.cpp @@ -27,10 +27,10 @@ // (BM*BN) / (TM*TN) == threadblock size >= NT * CORES_PER_CLUSTER // * Combining BM * BK >= (BM*BN) / (TM*TN) == threadblock yields // BM <= BK*TM*TN -#define BM 8 -#define BN 8 -#define BK 8 -#define WM 8 +#define BM 32 +#define BN 32 +#define BK 32 +#define WM 16 #define WN 8 #define TCM 8 #define TCN 8 @@ -133,8 +133,12 @@ inline constexpr void map_c(const int tid, int &row, int &col) { } } -inline void vx_wmma() { - asm volatile (".insn r %0, 0, 0, x0, x0, x0" :: "i"(RISCV_CUSTOM3)); +inline void vx_wmma(const int dest_reg) { + if (dest_reg == 0) { + asm volatile (".insn r %0, 0, 0, x0, x0, x0" :: "i"(RISCV_CUSTOM3)); + } else { + asm volatile (".insn r %0, 0, 0, x1, x0, x0" :: "i"(RISCV_CUSTOM3)); + } } // `local_k` is assumed to be multiple of TCK @@ -196,23 +200,35 @@ inline void vx_wmma_load(volatile float *smem_A, volatile float *smem_B, const i asm volatile("flw f15, %0" ::"m"(smem_B[((local_k + 7) * smem_B_cols) + (WN * warp_col + TCN * wn_iter) + col])); } -inline void initialize_C() { +inline void initialize_C(const int dest_reg) { // initialize C to zeros - asm volatile("fmv.w.x f16, x0"); - asm volatile("fmv.w.x f17, x0"); - asm volatile("fmv.w.x f18, x0"); - asm volatile("fmv.w.x f19, x0"); - asm volatile("fmv.w.x f20, x0"); - asm volatile("fmv.w.x f21, x0"); - asm volatile("fmv.w.x f22, x0"); - asm volatile("fmv.w.x f23, x0"); + if (dest_reg == 0) { + asm volatile("fmv.w.x f16, x0"); + asm volatile("fmv.w.x f17, x0"); + asm volatile("fmv.w.x f18, x0"); + asm volatile("fmv.w.x f19, x0"); + asm volatile("fmv.w.x f20, x0"); + asm volatile("fmv.w.x f21, x0"); + asm volatile("fmv.w.x f22, x0"); + asm volatile("fmv.w.x f23, x0"); + } else { + asm volatile("fmv.w.x f24, x0"); + asm volatile("fmv.w.x f25, x0"); + asm volatile("fmv.w.x f26, x0"); + asm volatile("fmv.w.x f27, x0"); + asm volatile("fmv.w.x f28, x0"); + asm volatile("fmv.w.x f29, x0"); + asm volatile("fmv.w.x f30, x0"); + asm volatile("fmv.w.x f31, x0"); + } } inline void write_results(volatile float *local_warp_results, - int thread_in_warp, int warp_col, int warp_row, - int wn_iter, int wm_iter, int dim_m, int dim_n, - float *C, int threadblock_id_x, - int threadblock_id_y) { + const int thread_in_warp, const int warp_col, + const int warp_row, const int wn_iter, + const int wm_iter, const int dim_m, const int dim_n, + float *C, const int threadblock_id_x, + const int threadblock_id_y) { int tid = thread_in_warp; int tg = tid / 4; @@ -229,14 +245,25 @@ inline void write_results(volatile float *local_warp_results, BN * threadblock_id_x; // @perf: this likely causes a lot of gmem bank conflicts - asm volatile ("fsw f16, %0" :: "m"(global_offset_C[dim_n * (local_row + 0) + (local_col + 0)])); - asm volatile ("fsw f17, %0" :: "m"(global_offset_C[dim_n * (local_row + 0) + (local_col + 1)])); - asm volatile ("fsw f18, %0" :: "m"(global_offset_C[dim_n * (local_row + 2) + (local_col + 0)])); - asm volatile ("fsw f19, %0" :: "m"(global_offset_C[dim_n * (local_row + 2) + (local_col + 1)])); - asm volatile ("fsw f20, %0" :: "m"(global_offset_C[dim_n * (local_row + 0) + (local_col + 4)])); - asm volatile ("fsw f21, %0" :: "m"(global_offset_C[dim_n * (local_row + 0) + (local_col + 5)])); - asm volatile ("fsw f22, %0" :: "m"(global_offset_C[dim_n * (local_row + 2) + (local_col + 4)])); - asm volatile ("fsw f23, %0" :: "m"(global_offset_C[dim_n * (local_row + 2) + (local_col + 5)])); + if (wm_iter == 0) { + asm volatile ("fsw f16, %0" :: "m"(global_offset_C[dim_n * (local_row + 0) + (local_col + 0)])); + asm volatile ("fsw f17, %0" :: "m"(global_offset_C[dim_n * (local_row + 0) + (local_col + 1)])); + asm volatile ("fsw f18, %0" :: "m"(global_offset_C[dim_n * (local_row + 2) + (local_col + 0)])); + asm volatile ("fsw f19, %0" :: "m"(global_offset_C[dim_n * (local_row + 2) + (local_col + 1)])); + asm volatile ("fsw f20, %0" :: "m"(global_offset_C[dim_n * (local_row + 0) + (local_col + 4)])); + asm volatile ("fsw f21, %0" :: "m"(global_offset_C[dim_n * (local_row + 0) + (local_col + 5)])); + asm volatile ("fsw f22, %0" :: "m"(global_offset_C[dim_n * (local_row + 2) + (local_col + 4)])); + asm volatile ("fsw f23, %0" :: "m"(global_offset_C[dim_n * (local_row + 2) + (local_col + 5)])); + } else { + asm volatile ("fsw f24, %0" :: "m"(global_offset_C[dim_n * (local_row + 0) + (local_col + 0)])); + asm volatile ("fsw f25, %0" :: "m"(global_offset_C[dim_n * (local_row + 0) + (local_col + 1)])); + asm volatile ("fsw f26, %0" :: "m"(global_offset_C[dim_n * (local_row + 2) + (local_col + 0)])); + asm volatile ("fsw f27, %0" :: "m"(global_offset_C[dim_n * (local_row + 2) + (local_col + 1)])); + asm volatile ("fsw f28, %0" :: "m"(global_offset_C[dim_n * (local_row + 0) + (local_col + 4)])); + asm volatile ("fsw f29, %0" :: "m"(global_offset_C[dim_n * (local_row + 0) + (local_col + 5)])); + asm volatile ("fsw f30, %0" :: "m"(global_offset_C[dim_n * (local_row + 2) + (local_col + 4)])); + asm volatile ("fsw f31, %0" :: "m"(global_offset_C[dim_n * (local_row + 2) + (local_col + 5)])); + } } inline void threadblock_barrier(unsigned int tid_in_threadblock, @@ -349,7 +376,8 @@ void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg, local_b + local_b_elems + (warp_in_threadblock * TCM * TCN); // clear out C - initialize_C(); + initialize_C(0); + initialize_C(1); #pragma GCC unroll 1 for (uint32_t k = 0; k < dim_k; k += BK) { @@ -394,7 +422,7 @@ void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg, vx_wmma_load(local_a, local_b, local_k, warp_col, warp_row, wn_iter, wm_iter, tid_in_warp); // compute - vx_wmma(); + vx_wmma(wm_iter); #if TC_SINGLE_WARP } #endif From 18e3653d31503598b435f98122e0613bca22f8c7 Mon Sep 17 00:00:00 2001 From: Hansung Kim Date: Mon, 3 Jun 2024 21:10:42 -0700 Subject: [PATCH 34/55] sgemm_tcore: Increase RF data reuse for WMITER/WNITER ... by splitting vx_wmma_load to vx_wmma_load_{a,b} and pulling it out of the innermost loop. TODO: there's some duplicate address compute being done in the both functions. --- tests/regression/sgemm_tcore/kernel.cpp | 35 ++++++++++++++++++++----- 1 file changed, 28 insertions(+), 7 deletions(-) diff --git a/tests/regression/sgemm_tcore/kernel.cpp b/tests/regression/sgemm_tcore/kernel.cpp index 4ac80775..69451813 100644 --- a/tests/regression/sgemm_tcore/kernel.cpp +++ b/tests/regression/sgemm_tcore/kernel.cpp @@ -142,12 +142,12 @@ inline void vx_wmma(const int dest_reg) { } // `local_k` is assumed to be multiple of TCK -inline void vx_wmma_load(volatile float *smem_A, volatile float *smem_B, const int local_k, - const int warp_col, const int warp_row, const int wn_iter, - const int wm_iter, const int thread_in_warp) { +inline void vx_wmma_load_a(volatile float *smem_A, const int local_k, + const int warp_row, const int wm_iter, const int thread_in_warp) { const int tid = thread_in_warp; const int tg = tid / 4; + // TODO: this is duplicately computed between vx_wmma_load_a and vx_wmma_load_b int row = 0; int col = 0; map_operand(tid, row, col); @@ -188,6 +188,25 @@ inline void vx_wmma_load(volatile float *smem_A, volatile float *smem_B, const i // asm volatile("flw f0, %0" ::"m"(smem_A[((local_k + i) * smem_A_rows) + (WM * warp_row + TCM * wm_iter) + row])); // } } +} + +// `local_k` is assumed to be multiple of TCK +inline void vx_wmma_load_b(volatile float *smem_B, const int local_k, + const int warp_col, const int wn_iter, + const int thread_in_warp) { + const int tid = thread_in_warp; + const int tg = tid / 4; + + int row = 0; + int col = 0; + map_operand(tid, row, col); + + constexpr int smem_A_rows = BM; + constexpr int smem_A_cols = BK; + constexpr int smem_AS_rows = BK; + constexpr int smem_AS_cols = BM; + constexpr int smem_B_rows = BK; + constexpr int smem_B_cols = BN; // f8-f15 stores a single column of B asm volatile("flw f8, %0" ::"m"(smem_B[((local_k + 0) * smem_B_cols) + (WN * warp_col + TCN * wn_iter) + col])); @@ -401,9 +420,11 @@ void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg, // FIXME: this is wrong!! need separate accumulation register for // WM/WN_ITERS #pragma GCC unroll 2 - for (int wm_iter = 0; wm_iter < WMITER; wm_iter++) { + for (int wn_iter = 0; wn_iter < WNITER; wn_iter++) { + vx_wmma_load_b(local_b, local_k, warp_col, wn_iter, tid_in_warp); + // vx_wmma_load_b(local_b, 0, 0, 0, tid_in_warp); #pragma GCC unroll 2 - for (int wn_iter = 0; wn_iter < WNITER; wn_iter++) { + for (int wm_iter = 0; wm_iter < WMITER; wm_iter++) { #if TC_SINGLE_WARP if (warp_in_threadblock == 0) { #endif @@ -419,8 +440,8 @@ void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg, // asm volatile("addi a0, a0, 0"); // } // SMEM -> RF - vx_wmma_load(local_a, local_b, local_k, warp_col, warp_row, - wn_iter, wm_iter, tid_in_warp); + vx_wmma_load_a(local_a, local_k, warp_row, wm_iter, tid_in_warp); + // vx_wmma_load_a(local_a, 0, 0, 0, tid_in_warp); // compute vx_wmma(wm_iter); #if TC_SINGLE_WARP From d8944db36950d306583cd5074ebbd34990613d74 Mon Sep 17 00:00:00 2001 From: Hansung Kim Date: Tue, 4 Jun 2024 18:23:27 -0700 Subject: [PATCH 35/55] sgemm_tcore: Double-buffer over K-dimension TODO: Not completely parameterized with DOUBLE_BUFFER yet. --- tests/regression/sgemm_tcore/kernel.cpp | 241 +++++++++++++++--------- tests/regression/sgemm_tcore/main.cpp | 6 +- 2 files changed, 155 insertions(+), 92 deletions(-) diff --git a/tests/regression/sgemm_tcore/kernel.cpp b/tests/regression/sgemm_tcore/kernel.cpp index 69451813..3e3bed78 100644 --- a/tests/regression/sgemm_tcore/kernel.cpp +++ b/tests/regression/sgemm_tcore/kernel.cpp @@ -15,6 +15,8 @@ #define BK_LOOP 1 #define TRANSPOSE_AS 1 +#define DOUBLE_BUFFER 1 + // Constraints on parameters: // * Memory: // (BM + BN) * BK * sizeof(float) <= sharedmem size. @@ -29,7 +31,7 @@ // BM <= BK*TM*TN #define BM 32 #define BN 32 -#define BK 32 +#define BK 8 #define WM 16 #define WN 8 #define TCM 8 @@ -44,7 +46,12 @@ #define TM 1 #define TN 1 #endif -#define ELEM_PER_THREAD (WMITER * WNITER * TM * TN) +#define ELEM_PER_THREAD (WMITER * WNITER * TM * TN / (DOUBLE_BUFFER ? 2 : 1)) + +// FIXME: NUM_THREADS and NUM_WARPS hardcoded +#if ((BM * BN / ELEM_PER_THREAD) > (CORES_PER_CLUSTER * 8 * 8)) +#error "threadblock size too big for cluster" +#endif inline constexpr void map_operand_32lanes(const int tid, int &row, int &col) { const int tg = tid / 4; @@ -156,8 +163,6 @@ inline void vx_wmma_load_a(volatile float *smem_A, const int local_k, constexpr int smem_A_cols = BK; constexpr int smem_AS_rows = BK; constexpr int smem_AS_cols = BM; - constexpr int smem_B_rows = BK; - constexpr int smem_B_cols = BN; if constexpr (!TRANSPOSE_AS) { int A_offset = (WM * warp_row + TCM * wm_iter + row) * smem_A_cols; @@ -201,10 +206,6 @@ inline void vx_wmma_load_b(volatile float *smem_B, const int local_k, int col = 0; map_operand(tid, row, col); - constexpr int smem_A_rows = BM; - constexpr int smem_A_cols = BK; - constexpr int smem_AS_rows = BK; - constexpr int smem_AS_cols = BM; constexpr int smem_B_rows = BK; constexpr int smem_B_cols = BN; @@ -294,11 +295,21 @@ inline void threadblock_barrier(unsigned int tid_in_threadblock, inline void global_dmem_load(const uint32_t dim_n, const uint32_t dim_k, const uint32_t k, const float *A, const float *B, volatile float *local_a, - volatile float *local_b, const uint32_t threadblock_id_x, - const uint32_t threadblock_id_y, const uint32_t local_a_row, - const uint32_t local_a_col, const uint32_t local_as_row, - const uint32_t local_as_col, const uint32_t local_b_row, - const uint32_t local_b_col) { + volatile float *local_b, const uint32_t tid_in_threadblock, + const uint32_t threadblock_id_x, + const uint32_t threadblock_id_y) { + constexpr uint32_t BM_d = BM; + constexpr uint32_t BN_d = BN; + + const uint32_t local_a_row = tid_in_threadblock / BK; + const uint32_t local_a_col = tid_in_threadblock % BK; + const uint32_t local_as_row = tid_in_threadblock / BM; + const uint32_t local_as_col = tid_in_threadblock % BM; + const uint32_t local_b_row = tid_in_threadblock / BN; + const uint32_t local_b_col = tid_in_threadblock % BN; + + constexpr uint32_t threads_in_warpgroup = + (BM * BN) / ELEM_PER_THREAD / (DOUBLE_BUFFER ? 2 : 1); // FIXME // Data move from GMEM to SMEM // @@ -307,24 +318,24 @@ global_dmem_load(const uint32_t dim_n, const uint32_t dim_k, const uint32_t k, // // TODO: Sharedmem swizzling is important here if constexpr (!TRANSPOSE_AS) { - const uint32_t global_a_row = BM * threadblock_id_y + local_a_row; + const uint32_t global_a_row = BM_d * threadblock_id_y + local_a_row; // number of rows a full TB can read at a time - constexpr uint32_t row_stride_a = (BM * BN) / ELEM_PER_THREAD / BK; + constexpr uint32_t row_stride_a = threads_in_warpgroup / BK; #pragma GCC unroll 1 - for (uint32_t local_row_offset = 0; local_row_offset < BM; + for (uint32_t local_row_offset = 0; local_row_offset < BM_d; local_row_offset += row_stride_a) { const uint32_t global_a_offset = dim_k * (global_a_row + local_row_offset) + (k + local_a_col); // NOTE: all threads in TB will do this load; make sure this is not - // out-of-bounds of BM*BK + // out-of-bounds of BM_d*BK local_a[BK * (local_a_row + local_row_offset) + local_a_col] = A[global_a_offset]; } } else { - const uint32_t global_a_row = BM * threadblock_id_y + local_as_col; - // const uint32_t global_a_row = BM * threadblock_id_y + local_as_row; - constexpr uint32_t row_stride_as = (BM * BN) / ELEM_PER_THREAD / BM; -#pragma GCC unroll 1 + const uint32_t global_a_row = BM_d * threadblock_id_y + local_as_col; + // const uint32_t global_a_row = BM_d * threadblock_id_y + local_as_row; + constexpr uint32_t row_stride_as = threads_in_warpgroup / BM_d; +#pragma GCC unroll 4 for (uint32_t local_row_offset = 0; local_row_offset < BK; local_row_offset += row_stride_as) { // @perf: bank conflicts here @@ -333,25 +344,26 @@ global_dmem_load(const uint32_t dim_n, const uint32_t dim_k, const uint32_t k, // FIXME experimenting with global coalescing // const uint32_t global_a_offset = // dim_k * (global_a_row + local_row_offset) + (k + local_as_col); - local_a[BM * (local_as_row + local_row_offset) + local_as_col] = + local_a[BM_d * (local_as_row + local_row_offset) + local_as_col] = A[global_a_offset]; } } - constexpr uint32_t row_stride_b = (BM * BN) / ELEM_PER_THREAD / BN; - const uint32_t global_b_col = BN * threadblock_id_x + local_b_col; -#pragma GCC unroll 1 + constexpr uint32_t row_stride_b = threads_in_warpgroup / BN_d; + const uint32_t global_b_col = BN_d * threadblock_id_x + local_b_col; +#pragma GCC unroll 2 for (uint32_t load_offset = 0; load_offset < BK; load_offset += row_stride_b) { const uint32_t global_b_offset = dim_n * (k + local_b_row + load_offset) + global_b_col; - local_b[BN * (local_b_row + load_offset) + local_b_col] = + local_b[BN_d * (local_b_row + load_offset) + local_b_col] = B[global_b_offset]; } } void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg, const uint32_t tid_in_threadblock, + const uint32_t threads_per_threadblock, const uint32_t threadblock_dim_x, const uint32_t threadblock_dim_y, const uint32_t threadblock_id_x, @@ -376,14 +388,20 @@ void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg, const uint32_t local_c_row = tid_in_threadblock / (BN / TN); const uint32_t local_c_col = tid_in_threadblock % (BN / TN); +#if !USE_TENSOR_CORE // each thread generates TM output element float reg_c[TM * TN] = { 0.0f }; float reg_a[TM] = { 0.0f }; float reg_b[TN] = { 0.0f }; +#endif - const uint32_t warp_in_threadblock = tid_in_threadblock / NUM_LANES; - const uint32_t warp_row = warp_in_threadblock / (BN / WN); - const uint32_t warp_col = warp_in_threadblock % (BN / WN); + const uint32_t threads_per_warpgroup = threads_per_threadblock / (DOUBLE_BUFFER ? 2 : 1); + const uint32_t warpgroup_id = tid_in_threadblock / threads_per_warpgroup; + const uint32_t tid_in_warpgroup = tid_in_threadblock % threads_per_warpgroup; // FIXME + const uint32_t warp_in_warpgroup = tid_in_warpgroup / NUM_LANES; + // FIXME: warp_row / BN should be warp-specialized? + const uint32_t warp_row = warp_in_warpgroup / (BN / WN); + const uint32_t warp_col = warp_in_warpgroup % (BN / WN); const uint32_t tid_in_warp = tid_in_threadblock % NUM_LANES; volatile float *local_a = sharedmem_per_threadblock; @@ -391,69 +409,109 @@ void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg, const size_t local_a_elems = (BM * BK); volatile float *local_b = sharedmem_per_threadblock + local_a_elems; const size_t local_b_elems = (BK * BN); + + volatile float *local_a_buf = local_b + local_b_elems; + volatile float *local_b_buf = local_a_buf + local_a_elems; + volatile float *local_warp_results = - local_b + local_b_elems + (warp_in_threadblock * TCM * TCN); + local_b_buf + local_b_elems + (warp_in_warpgroup * TCM * TCN); // clear out C initialize_C(0); initialize_C(1); -#pragma GCC unroll 1 - for (uint32_t k = 0; k < dim_k; k += BK) { - global_dmem_load(dim_n, dim_k, k, A, B, local_a, local_b, - threadblock_id_x, threadblock_id_y, local_a_row, - local_a_col, local_as_row, local_as_col, local_b_row, - local_b_col); - - threadblock_barrier(tid_in_threadblock, threadblock_id_in_cluster, - threadblock_dim_y); - -#if USE_TENSOR_CORE - // @perf: this loop spills to stack a lot because of all the flws in - // vx_wmma_load -#pragma GCC unroll 1 - for (int i = 0; i < BK_LOOP; i++) { -#pragma GCC unroll 1 - for (uint32_t local_k = 0; local_k < BK; local_k += TCK) { - // perform wmma - // vx_wmma_load(local_a, local_b, warp_x, warp_y, tid_in_warp); - // FIXME: this is wrong!! need separate accumulation register for - // WM/WN_ITERS -#pragma GCC unroll 2 - for (int wn_iter = 0; wn_iter < WNITER; wn_iter++) { - vx_wmma_load_b(local_b, local_k, warp_col, wn_iter, tid_in_warp); - // vx_wmma_load_b(local_b, 0, 0, 0, tid_in_warp); -#pragma GCC unroll 2 - for (int wm_iter = 0; wm_iter < WMITER; wm_iter++) { -#if TC_SINGLE_WARP - if (warp_in_threadblock == 0) { -#endif - // if ((threadblock_id_in_cluster % 2) == 0) { - // asm volatile("addi a0, a0, 0"); - // asm volatile("addi a0, a0, 0"); - // asm volatile("addi a0, a0, 0"); - // asm volatile("addi a0, a0, 0"); - // asm volatile("addi a0, a0, 0"); - // asm volatile("addi a0, a0, 0"); - // asm volatile("addi a0, a0, 0"); - // asm volatile("addi a0, a0, 0"); - // asm volatile("addi a0, a0, 0"); - // } - // SMEM -> RF - vx_wmma_load_a(local_a, local_k, warp_row, wm_iter, tid_in_warp); - // vx_wmma_load_a(local_a, 0, 0, 0, tid_in_warp); - // compute - vx_wmma(wm_iter); -#if TC_SINGLE_WARP - } -#endif - } - } - } + if constexpr (DOUBLE_BUFFER) { + // initiate software pipeline + if (warpgroup_id == 0) { + global_dmem_load(dim_n, dim_k, 0 /*k*/, A, B, local_a, local_b, + tid_in_warpgroup, threadblock_id_x, threadblock_id_y); } threadblock_barrier(tid_in_threadblock, threadblock_id_in_cluster, threadblock_dim_y); + } + + uint32_t k_index = 0; + +#pragma GCC unroll 1 + for (uint32_t k = 0; k < dim_k; k += BK) { + volatile float *local_a_produce; + volatile float *local_b_produce; + volatile float *local_a_consume; + volatile float *local_b_consume; + if constexpr (DOUBLE_BUFFER) { + local_a_produce = (k_index % 2) ? local_a : local_a_buf; + local_b_produce = (k_index % 2) ? local_b : local_b_buf; + local_a_consume = (k_index % 2) ? local_a_buf : local_a; + local_b_consume = (k_index % 2) ? local_b_buf : local_b; + } else { + local_a_produce = local_a; + local_b_produce = local_b; + local_a_consume = local_a; + local_b_consume = local_b; + } + k_index++; + + if (warpgroup_id == 0) { + if (k != (dim_k - BK)) { + global_dmem_load(dim_n, dim_k, k + BK /*runahead*/, A, B, + local_a_produce, local_b_produce, tid_in_warpgroup, + threadblock_id_x, threadblock_id_y); + } + + threadblock_barrier(tid_in_threadblock, threadblock_id_in_cluster, + threadblock_dim_y); + } + + else { +#if USE_TENSOR_CORE + // @perf: this loop spills to stack a lot because of all the flws in + // vx_wmma_load +#pragma GCC unroll 1 + for (int i = 0; i < BK_LOOP; i++) { +#pragma GCC unroll 1 + for (uint32_t local_k = 0; local_k < BK; local_k += TCK) { + // perform wmma + // vx_wmma_load(local_a_consume, local_b_consume, warp_x, warp_y, tid_in_warp); + // FIXME: this is wrong!! need separate accumulation register for + // WM/WN_ITERS +#pragma GCC unroll 2 + for (int wn_iter = 0; wn_iter < WNITER; wn_iter++) { + vx_wmma_load_b(local_b_consume, local_k, warp_col, wn_iter, tid_in_warp); + // vx_wmma_load_b(local_b_consume, 0, 0, 0, tid_in_warp); +#pragma GCC unroll 1 + for (int wm_iter = 0; wm_iter < WMITER; wm_iter++) { +#if TC_SINGLE_WARP + if (warp_in_warpgroup == 0) { +#endif + // if ((threadblock_id_in_cluster % 2) == 0) { + // asm volatile("addi a0, a0, 0"); + // asm volatile("addi a0, a0, 0"); + // asm volatile("addi a0, a0, 0"); + // asm volatile("addi a0, a0, 0"); + // asm volatile("addi a0, a0, 0"); + // asm volatile("addi a0, a0, 0"); + // asm volatile("addi a0, a0, 0"); + // asm volatile("addi a0, a0, 0"); + // asm volatile("addi a0, a0, 0"); + // } + // SMEM -> RF + vx_wmma_load_a(local_a_consume, local_k, warp_row, wm_iter, + tid_in_warp); + // vx_wmma_load_a(local_a_consume, 0, 0, 0, tid_in_warp); + // compute + vx_wmma(wm_iter); +#if TC_SINGLE_WARP + } +#endif + } + } + } + } + + threadblock_barrier(tid_in_threadblock, threadblock_id_in_cluster, + threadblock_dim_y); + } #else @@ -498,11 +556,13 @@ void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg, #pragma GCC unroll 1 for (int wn_iter = 0; wn_iter < WNITER; wn_iter++) { #if TC_SINGLE_WARP - if (warp_in_threadblock == 0) { + if (warp_in_warpgroup == 0) { #endif - write_results(local_warp_results, tid_in_warp, warp_col, warp_row, - wn_iter, wm_iter, dim_m, dim_n, C, threadblock_id_x, - threadblock_id_y); + if (warpgroup_id == 1) { + write_results(local_warp_results, tid_in_warp, warp_col, warp_row, + wn_iter, wm_iter, dim_m, dim_n, C, threadblock_id_x, + threadblock_id_y); + } #if TC_SINGLE_WARP } #endif @@ -554,9 +614,12 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { // occupancy of a single cluster float *sharedmem_per_threadblock = (float *)DEV_SMEM_START_ADDR + (2 * BM * BK) * threadblock_id_in_cluster; - thread_block_gemm(arg, tid_in_threadblock, threadblock_dim_x, - threadblock_dim_y, threadblock_id_x, threadblock_id_y, - threadblock_id_in_cluster, sharedmem_per_threadblock); + + const int warp_id = vx_warp_id(); + thread_block_gemm(arg, tid_in_threadblock, threads_per_threadblock, + threadblock_dim_x, threadblock_dim_y, threadblock_id_x, + threadblock_id_y, threadblock_id_in_cluster, + sharedmem_per_threadblock); } int main() { diff --git a/tests/regression/sgemm_tcore/main.cpp b/tests/regression/sgemm_tcore/main.cpp index 0fbd838b..e6f18317 100644 --- a/tests/regression/sgemm_tcore/main.cpp +++ b/tests/regression/sgemm_tcore/main.cpp @@ -155,9 +155,9 @@ int main(int argc, char *argv[]) { RT_CHECK(vx_dev_open(&device)); // FIXME: hardcoded - uint32_t dim_m = 16; - uint32_t dim_n = 16; - uint32_t dim_k = 16; + uint32_t dim_m = 32; + uint32_t dim_n = 32; + uint32_t dim_k = 32; generate_source_matrix(dim_m, dim_n, dim_k); generate_reference_matmul(dim_m, dim_n, dim_k); From ff6e5bf6dc9756ba1b64fd08d65938aab03f5fa5 Mon Sep 17 00:00:00 2001 From: Hansung Kim Date: Wed, 5 Jun 2024 14:50:36 -0700 Subject: [PATCH 36/55] sgemm_tcore: Deconstruct smem addr calc to reduce reg alloc --- tests/regression/sgemm_tcore/kernel.cpp | 150 +++++++++++++++--------- 1 file changed, 96 insertions(+), 54 deletions(-) diff --git a/tests/regression/sgemm_tcore/kernel.cpp b/tests/regression/sgemm_tcore/kernel.cpp index 3e3bed78..6c677326 100644 --- a/tests/regression/sgemm_tcore/kernel.cpp +++ b/tests/regression/sgemm_tcore/kernel.cpp @@ -180,18 +180,32 @@ inline void vx_wmma_load_a(volatile float *smem_A, const int local_k, } else { // transposed A // f8-f15 stores a single row of A - asm volatile("flw f0, %0" ::"m"(smem_A[((local_k + 0) * smem_AS_cols) + (WM * warp_row + TCM * wm_iter) + row])); - asm volatile("flw f1, %0" ::"m"(smem_A[((local_k + 1) * smem_AS_cols) + (WM * warp_row + TCM * wm_iter) + row])); - asm volatile("flw f2, %0" ::"m"(smem_A[((local_k + 2) * smem_AS_cols) + (WM * warp_row + TCM * wm_iter) + row])); - asm volatile("flw f3, %0" ::"m"(smem_A[((local_k + 3) * smem_AS_cols) + (WM * warp_row + TCM * wm_iter) + row])); - asm volatile("flw f4, %0" ::"m"(smem_A[((local_k + 4) * smem_AS_cols) + (WM * warp_row + TCM * wm_iter) + row])); - asm volatile("flw f5, %0" ::"m"(smem_A[((local_k + 5) * smem_AS_cols) + (WM * warp_row + TCM * wm_iter) + row])); - asm volatile("flw f6, %0" ::"m"(smem_A[((local_k + 6) * smem_AS_cols) + (WM * warp_row + TCM * wm_iter) + row])); - asm volatile("flw f7, %0" ::"m"(smem_A[((local_k + 7) * smem_AS_cols) + (WM * warp_row + TCM * wm_iter) + row])); -// #pragma GCC unroll 8 -// for (int i = 0; i < 8; i++) { -// asm volatile("flw f0, %0" ::"m"(smem_A[((local_k + i) * smem_A_rows) + (WM * warp_row + TCM * wm_iter) + row])); -// } + register volatile float *smem_addr asm("t5"); + smem_addr = &smem_A[((local_k + 0) * smem_AS_cols) + (WM * warp_row + TCM * wm_iter) + row]; + asm volatile("flw f0, %0" ::"m"(*smem_addr)); + smem_addr += smem_AS_cols; + asm volatile("flw f1, %0" ::"m"(*smem_addr)); + smem_addr += smem_AS_cols; + asm volatile("flw f2, %0" ::"m"(*smem_addr)); + smem_addr += smem_AS_cols; + asm volatile("flw f3, %0" ::"m"(*smem_addr)); + smem_addr += smem_AS_cols; + asm volatile("flw f4, %0" ::"m"(*smem_addr)); + smem_addr += smem_AS_cols; + asm volatile("flw f5, %0" ::"m"(*smem_addr)); + smem_addr += smem_AS_cols; + asm volatile("flw f6, %0" ::"m"(*smem_addr)); + smem_addr += smem_AS_cols; + asm volatile("flw f7, %0" ::"m"(*smem_addr)); + + // asm volatile("flw f0, %0" ::"m"(smem_A[((local_k + 0) * smem_AS_cols) + (WM * warp_row + TCM * wm_iter) + row])); + // asm volatile("flw f1, %0" ::"m"(smem_A[((local_k + 1) * smem_AS_cols) + (WM * warp_row + TCM * wm_iter) + row])); + // asm volatile("flw f2, %0" ::"m"(smem_A[((local_k + 2) * smem_AS_cols) + (WM * warp_row + TCM * wm_iter) + row])); + // asm volatile("flw f3, %0" ::"m"(smem_A[((local_k + 3) * smem_AS_cols) + (WM * warp_row + TCM * wm_iter) + row])); + // asm volatile("flw f4, %0" ::"m"(smem_A[((local_k + 4) * smem_AS_cols) + (WM * warp_row + TCM * wm_iter) + row])); + // asm volatile("flw f5, %0" ::"m"(smem_A[((local_k + 5) * smem_AS_cols) + (WM * warp_row + TCM * wm_iter) + row])); + // asm volatile("flw f6, %0" ::"m"(smem_A[((local_k + 6) * smem_AS_cols) + (WM * warp_row + TCM * wm_iter) + row])); + // asm volatile("flw f7, %0" ::"m"(smem_A[((local_k + 7) * smem_AS_cols) + (WM * warp_row + TCM * wm_iter) + row])); } } @@ -210,14 +224,31 @@ inline void vx_wmma_load_b(volatile float *smem_B, const int local_k, constexpr int smem_B_cols = BN; // f8-f15 stores a single column of B - asm volatile("flw f8, %0" ::"m"(smem_B[((local_k + 0) * smem_B_cols) + (WN * warp_col + TCN * wn_iter) + col])); - asm volatile("flw f9, %0" ::"m"(smem_B[((local_k + 1) * smem_B_cols) + (WN * warp_col + TCN * wn_iter) + col])); - asm volatile("flw f10, %0" ::"m"(smem_B[((local_k + 2) * smem_B_cols) + (WN * warp_col + TCN * wn_iter) + col])); - asm volatile("flw f11, %0" ::"m"(smem_B[((local_k + 3) * smem_B_cols) + (WN * warp_col + TCN * wn_iter) + col])); - asm volatile("flw f12, %0" ::"m"(smem_B[((local_k + 4) * smem_B_cols) + (WN * warp_col + TCN * wn_iter) + col])); - asm volatile("flw f13, %0" ::"m"(smem_B[((local_k + 5) * smem_B_cols) + (WN * warp_col + TCN * wn_iter) + col])); - asm volatile("flw f14, %0" ::"m"(smem_B[((local_k + 6) * smem_B_cols) + (WN * warp_col + TCN * wn_iter) + col])); - asm volatile("flw f15, %0" ::"m"(smem_B[((local_k + 7) * smem_B_cols) + (WN * warp_col + TCN * wn_iter) + col])); + register volatile float *smem_addr asm("t5"); + smem_addr = &smem_B[((local_k + 0) * smem_B_cols) + (WN * warp_col + TCN * wn_iter) + col]; + asm volatile("flw f8, %0" ::"m"(*smem_addr)); + smem_addr += smem_B_cols; + asm volatile("flw f9, %0" ::"m"(*smem_addr)); + smem_addr += smem_B_cols; + asm volatile("flw f10, %0" ::"m"(*smem_addr)); + smem_addr += smem_B_cols; + asm volatile("flw f11, %0" ::"m"(*smem_addr)); + smem_addr += smem_B_cols; + asm volatile("flw f12, %0" ::"m"(*smem_addr)); + smem_addr += smem_B_cols; + asm volatile("flw f13, %0" ::"m"(*smem_addr)); + smem_addr += smem_B_cols; + asm volatile("flw f14, %0" ::"m"(*smem_addr)); + smem_addr += smem_B_cols; + asm volatile("flw f15, %0" ::"m"(*smem_addr)); + // asm volatile("flw f8, %0" ::"m"(smem_B[((local_k + 0) * smem_B_cols) + (WN * warp_col + TCN * wn_iter) + col])); + // asm volatile("flw f9, %0" ::"m"(smem_B[((local_k + 1) * smem_B_cols) + (WN * warp_col + TCN * wn_iter) + col])); + // asm volatile("flw f10, %0" ::"m"(smem_B[((local_k + 2) * smem_B_cols) + (WN * warp_col + TCN * wn_iter) + col])); + // asm volatile("flw f11, %0" ::"m"(smem_B[((local_k + 3) * smem_B_cols) + (WN * warp_col + TCN * wn_iter) + col])); + // asm volatile("flw f12, %0" ::"m"(smem_B[((local_k + 4) * smem_B_cols) + (WN * warp_col + TCN * wn_iter) + col])); + // asm volatile("flw f13, %0" ::"m"(smem_B[((local_k + 5) * smem_B_cols) + (WN * warp_col + TCN * wn_iter) + col])); + // asm volatile("flw f14, %0" ::"m"(smem_B[((local_k + 6) * smem_B_cols) + (WN * warp_col + TCN * wn_iter) + col])); + // asm volatile("flw f15, %0" ::"m"(smem_B[((local_k + 7) * smem_B_cols) + (WN * warp_col + TCN * wn_iter) + col])); } inline void initialize_C(const int dest_reg) { @@ -243,8 +274,7 @@ inline void initialize_C(const int dest_reg) { } } -inline void write_results(volatile float *local_warp_results, - const int thread_in_warp, const int warp_col, +inline void write_results(const int thread_in_warp, const int warp_col, const int warp_row, const int wn_iter, const int wm_iter, const int dim_m, const int dim_n, float *C, const int threadblock_id_x, @@ -266,28 +296,47 @@ inline void write_results(volatile float *local_warp_results, // @perf: this likely causes a lot of gmem bank conflicts if (wm_iter == 0) { - asm volatile ("fsw f16, %0" :: "m"(global_offset_C[dim_n * (local_row + 0) + (local_col + 0)])); - asm volatile ("fsw f17, %0" :: "m"(global_offset_C[dim_n * (local_row + 0) + (local_col + 1)])); - asm volatile ("fsw f18, %0" :: "m"(global_offset_C[dim_n * (local_row + 2) + (local_col + 0)])); - asm volatile ("fsw f19, %0" :: "m"(global_offset_C[dim_n * (local_row + 2) + (local_col + 1)])); - asm volatile ("fsw f20, %0" :: "m"(global_offset_C[dim_n * (local_row + 0) + (local_col + 4)])); - asm volatile ("fsw f21, %0" :: "m"(global_offset_C[dim_n * (local_row + 0) + (local_col + 5)])); - asm volatile ("fsw f22, %0" :: "m"(global_offset_C[dim_n * (local_row + 2) + (local_col + 4)])); - asm volatile ("fsw f23, %0" :: "m"(global_offset_C[dim_n * (local_row + 2) + (local_col + 5)])); + register volatile float *gmem_addr asm("t5"); + register volatile float *gmem_addr_tmp asm("t6"); + gmem_addr = &global_offset_C[dim_n * (local_row + 0) + (local_col + 0)]; + asm volatile ("fsw f16, %0" :: "m"(*(gmem_addr + 0))); + asm volatile ("fsw f17, %0" :: "m"(*(gmem_addr + 1))); + gmem_addr_tmp = gmem_addr + (2 * dim_n); + asm volatile ("fsw f18, %0" :: "m"(*(gmem_addr_tmp + 0))); + asm volatile ("fsw f19, %0" :: "m"(*(gmem_addr_tmp + 1))); + gmem_addr += 4; + asm volatile ("fsw f20, %0" :: "m"(*(gmem_addr + 0))); + asm volatile ("fsw f21, %0" :: "m"(*(gmem_addr + 1))); + gmem_addr_tmp = gmem_addr + (2 * dim_n); + asm volatile ("fsw f22, %0" :: "m"(*(gmem_addr_tmp + 0))); + asm volatile ("fsw f23, %0" :: "m"(*(gmem_addr_tmp + 1))); + // asm volatile ("fsw f16, %0" :: "m"(global_offset_C[dim_n * (local_row + 0) + (local_col + 0)])); + // asm volatile ("fsw f17, %0" :: "m"(global_offset_C[dim_n * (local_row + 0) + (local_col + 1)])); + // asm volatile ("fsw f18, %0" :: "m"(global_offset_C[dim_n * (local_row + 2) + (local_col + 0)])); + // asm volatile ("fsw f19, %0" :: "m"(global_offset_C[dim_n * (local_row + 2) + (local_col + 1)])); + // asm volatile ("fsw f20, %0" :: "m"(global_offset_C[dim_n * (local_row + 0) + (local_col + 4)])); + // asm volatile ("fsw f21, %0" :: "m"(global_offset_C[dim_n * (local_row + 0) + (local_col + 5)])); + // asm volatile ("fsw f22, %0" :: "m"(global_offset_C[dim_n * (local_row + 2) + (local_col + 4)])); + // asm volatile ("fsw f23, %0" :: "m"(global_offset_C[dim_n * (local_row + 2) + (local_col + 5)])); } else { - asm volatile ("fsw f24, %0" :: "m"(global_offset_C[dim_n * (local_row + 0) + (local_col + 0)])); - asm volatile ("fsw f25, %0" :: "m"(global_offset_C[dim_n * (local_row + 0) + (local_col + 1)])); - asm volatile ("fsw f26, %0" :: "m"(global_offset_C[dim_n * (local_row + 2) + (local_col + 0)])); - asm volatile ("fsw f27, %0" :: "m"(global_offset_C[dim_n * (local_row + 2) + (local_col + 1)])); - asm volatile ("fsw f28, %0" :: "m"(global_offset_C[dim_n * (local_row + 0) + (local_col + 4)])); - asm volatile ("fsw f29, %0" :: "m"(global_offset_C[dim_n * (local_row + 0) + (local_col + 5)])); - asm volatile ("fsw f30, %0" :: "m"(global_offset_C[dim_n * (local_row + 2) + (local_col + 4)])); - asm volatile ("fsw f31, %0" :: "m"(global_offset_C[dim_n * (local_row + 2) + (local_col + 5)])); + register volatile float *gmem_addr asm("t5"); + register volatile float *gmem_addr_tmp asm("t6"); + gmem_addr = &global_offset_C[dim_n * (local_row + 0) + (local_col + 0)]; + gmem_addr_tmp = gmem_addr + (2 * dim_n); + asm volatile ("fsw f24, %0" :: "m"(*(gmem_addr + 0))); + asm volatile ("fsw f25, %0" :: "m"(*(gmem_addr + 1))); + asm volatile ("fsw f26, %0" :: "m"(*(gmem_addr_tmp + 0))); + asm volatile ("fsw f27, %0" :: "m"(*(gmem_addr_tmp + 1))); + gmem_addr += 4; + gmem_addr_tmp = gmem_addr + (2 * dim_n); + asm volatile ("fsw f28, %0" :: "m"(*(gmem_addr + 0))); + asm volatile ("fsw f29, %0" :: "m"(*(gmem_addr + 1))); + asm volatile ("fsw f30, %0" :: "m"(*(gmem_addr_tmp + 0))); + asm volatile ("fsw f31, %0" :: "m"(*(gmem_addr_tmp + 1))); } } -inline void threadblock_barrier(unsigned int tid_in_threadblock, - unsigned int barrier_id, unsigned int count) { +inline void threadblock_barrier(const uint32_t barrier_id, const uint32_t count) { vx_fence(); vx_barrier(barrier_id, count); } @@ -406,16 +455,13 @@ void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg, volatile float *local_a = sharedmem_per_threadblock; // const size_t local_a_elems = threadblock_dim_x * threadblock_dim_y; - const size_t local_a_elems = (BM * BK); + constexpr size_t local_a_elems = (BM * BK); volatile float *local_b = sharedmem_per_threadblock + local_a_elems; - const size_t local_b_elems = (BK * BN); + constexpr size_t local_b_elems = (BK * BN); volatile float *local_a_buf = local_b + local_b_elems; volatile float *local_b_buf = local_a_buf + local_a_elems; - volatile float *local_warp_results = - local_b_buf + local_b_elems + (warp_in_warpgroup * TCM * TCN); - // clear out C initialize_C(0); initialize_C(1); @@ -427,8 +473,7 @@ void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg, tid_in_warpgroup, threadblock_id_x, threadblock_id_y); } - threadblock_barrier(tid_in_threadblock, threadblock_id_in_cluster, - threadblock_dim_y); + threadblock_barrier(threadblock_id_in_cluster, threadblock_dim_y); } uint32_t k_index = 0; @@ -459,8 +504,7 @@ void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg, threadblock_id_x, threadblock_id_y); } - threadblock_barrier(tid_in_threadblock, threadblock_id_in_cluster, - threadblock_dim_y); + threadblock_barrier(threadblock_id_in_cluster, threadblock_dim_y); } else { @@ -509,8 +553,7 @@ void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg, } } - threadblock_barrier(tid_in_threadblock, threadblock_id_in_cluster, - threadblock_dim_y); + threadblock_barrier(threadblock_id_in_cluster, threadblock_dim_y); } #else @@ -559,9 +602,8 @@ void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg, if (warp_in_warpgroup == 0) { #endif if (warpgroup_id == 1) { - write_results(local_warp_results, tid_in_warp, warp_col, warp_row, - wn_iter, wm_iter, dim_m, dim_n, C, threadblock_id_x, - threadblock_id_y); + write_results(tid_in_warp, warp_col, warp_row, wn_iter, wm_iter, + dim_m, dim_n, C, threadblock_id_x, threadblock_id_y); } #if TC_SINGLE_WARP } From e44173c65edcfa0ebedd773ca9aa40d870d60d4e Mon Sep 17 00:00:00 2001 From: Hansung Kim Date: Wed, 5 Jun 2024 15:11:01 -0700 Subject: [PATCH 37/55] sgemm_tcore: Deconstruct addr calc for GMEM->SMEM --- tests/regression/sgemm_tcore/kernel.cpp | 59 ++++++++++++++++++------- 1 file changed, 42 insertions(+), 17 deletions(-) diff --git a/tests/regression/sgemm_tcore/kernel.cpp b/tests/regression/sgemm_tcore/kernel.cpp index 6c677326..760c8467 100644 --- a/tests/regression/sgemm_tcore/kernel.cpp +++ b/tests/regression/sgemm_tcore/kernel.cpp @@ -370,43 +370,64 @@ global_dmem_load(const uint32_t dim_n, const uint32_t dim_k, const uint32_t k, const uint32_t global_a_row = BM_d * threadblock_id_y + local_a_row; // number of rows a full TB can read at a time constexpr uint32_t row_stride_a = threads_in_warpgroup / BK; -#pragma GCC unroll 1 + const float *global_a = A + dim_k * global_a_row + (k + local_a_col); + volatile float *local_a_tmp = local_a + BK * local_a_row + local_a_col; + +#pragma GCC unroll 2 for (uint32_t local_row_offset = 0; local_row_offset < BM_d; local_row_offset += row_stride_a) { - const uint32_t global_a_offset = - dim_k * (global_a_row + local_row_offset) + (k + local_a_col); - // NOTE: all threads in TB will do this load; make sure this is not - // out-of-bounds of BM_d*BK - local_a[BK * (local_a_row + local_row_offset) + local_a_col] = - A[global_a_offset]; + // const uint32_t global_a_offset = + // dim_k * (global_a_row + local_row_offset) + (k + local_a_col); + // local_a[BK * (local_a_row + local_row_offset) + local_a_col] = + // A[global_a_offset]; + *local_a_tmp = *global_a; + + global_a += dim_k * row_stride_a; + local_a_tmp += BK * row_stride_a; } } else { const uint32_t global_a_row = BM_d * threadblock_id_y + local_as_col; // const uint32_t global_a_row = BM_d * threadblock_id_y + local_as_row; constexpr uint32_t row_stride_as = threads_in_warpgroup / BM_d; -#pragma GCC unroll 4 + const float *global_a = A + dim_k * global_a_row + (k + local_as_row); + volatile float *local_a_tmp = local_a + BM_d * local_as_row + local_as_col; + +#pragma GCC ivdep for (uint32_t local_row_offset = 0; local_row_offset < BK; local_row_offset += row_stride_as) { // @perf: bank conflicts here - const uint32_t global_a_offset = - dim_k * (global_a_row) + (k + local_as_row + local_row_offset); + // const uint32_t global_a_offset = + // dim_k * (global_a_row) + (k + local_as_row + local_row_offset); // FIXME experimenting with global coalescing // const uint32_t global_a_offset = // dim_k * (global_a_row + local_row_offset) + (k + local_as_col); - local_a[BM_d * (local_as_row + local_row_offset) + local_as_col] = - A[global_a_offset]; + // local_a[BM_d * (local_as_row + local_row_offset) + local_as_col] = + // A[global_a_offset]; + + *local_a_tmp = *global_a; + + global_a += row_stride_as; + local_a_tmp += BM * row_stride_as; } } constexpr uint32_t row_stride_b = threads_in_warpgroup / BN_d; const uint32_t global_b_col = BN_d * threadblock_id_x + local_b_col; -#pragma GCC unroll 2 + const float *global_b = B + dim_n * (k + local_b_row) + global_b_col; + volatile float *local_b_tmp = local_b + BN_d * local_b_row + local_b_col; + +#pragma GCC ivdep for (uint32_t load_offset = 0; load_offset < BK; load_offset += row_stride_b) { - const uint32_t global_b_offset = - dim_n * (k + local_b_row + load_offset) + global_b_col; - local_b[BN_d * (local_b_row + load_offset) + local_b_col] = - B[global_b_offset]; + // const uint32_t global_b_offset = + // dim_n * (k + local_b_row + load_offset) + global_b_col; + // local_b[BN_d * (local_b_row + load_offset) + local_b_col] = + // B[global_b_offset]; + + *local_b_tmp = *global_b; + + global_b += dim_n * row_stride_b; + local_b_tmp += BN_d * row_stride_b; } } @@ -480,6 +501,10 @@ void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg, #pragma GCC unroll 1 for (uint32_t k = 0; k < dim_k; k += BK) { + // register volatile float *local_a_produce asm("t0"); + // register volatile float *local_b_produce asm("t1"); + // register volatile float *local_a_consume asm("t2"); + // register volatile float *local_b_consume asm("t3"); volatile float *local_a_produce; volatile float *local_b_produce; volatile float *local_a_consume; From 150f14af25667fca55f3ecedb774d421dd12a56b Mon Sep 17 00:00:00 2001 From: Hansung Kim Date: Wed, 5 Jun 2024 16:53:39 -0700 Subject: [PATCH 38/55] sgemm_tcore: Use multiple fp regs for GMEM->SMEM --- tests/regression/sgemm_tcore/kernel.cpp | 80 ++++++++++++++++++++++--- 1 file changed, 71 insertions(+), 9 deletions(-) diff --git a/tests/regression/sgemm_tcore/kernel.cpp b/tests/regression/sgemm_tcore/kernel.cpp index 760c8467..e5a9cf33 100644 --- a/tests/regression/sgemm_tcore/kernel.cpp +++ b/tests/regression/sgemm_tcore/kernel.cpp @@ -386,15 +386,21 @@ global_dmem_load(const uint32_t dim_n, const uint32_t dim_k, const uint32_t k, local_a_tmp += BK * row_stride_a; } } else { - const uint32_t global_a_row = BM_d * threadblock_id_y + local_as_col; - // const uint32_t global_a_row = BM_d * threadblock_id_y + local_as_row; constexpr uint32_t row_stride_as = threads_in_warpgroup / BM_d; + const uint32_t global_a_row = BM_d * threadblock_id_y + local_as_col; const float *global_a = A + dim_k * global_a_row + (k + local_as_row); + // FIXME experimenting with global coalescing + // const uint32_t global_a_row = BM_d * threadblock_id_y + local_as_row; + // const float *global_a = A + dim_k * global_a_row + (k + local_as_col); volatile float *local_a_tmp = local_a + BM_d * local_as_row + local_as_col; + static_assert( + row_stride_as * 8 <= BK, + "manual loop unrolling condition not met; consider increasing BK"); + #pragma GCC ivdep for (uint32_t local_row_offset = 0; local_row_offset < BK; - local_row_offset += row_stride_as) { + local_row_offset += row_stride_as * 8) { // @perf: bank conflicts here // const uint32_t global_a_offset = // dim_k * (global_a_row) + (k + local_as_row + local_row_offset); @@ -404,10 +410,33 @@ global_dmem_load(const uint32_t dim_n, const uint32_t dim_k, const uint32_t k, // local_a[BM_d * (local_as_row + local_row_offset) + local_as_col] = // A[global_a_offset]; - *local_a_tmp = *global_a; - + // *local_a_tmp = *global_a; + asm volatile ("flw ft0, (%0)" :: "r"(global_a)); global_a += row_stride_as; - local_a_tmp += BM * row_stride_as; + asm volatile ("flw ft1, (%0)" :: "r"(global_a)); + global_a += row_stride_as; + asm volatile ("flw ft2, (%0)" :: "r"(global_a)); + global_a += row_stride_as; + asm volatile ("flw ft3, (%0)" :: "r"(global_a)); + global_a += row_stride_as; + asm volatile ("flw ft4, (%0)" :: "r"(global_a)); + global_a += row_stride_as; + asm volatile ("flw ft5, (%0)" :: "r"(global_a)); + global_a += row_stride_as; + asm volatile ("flw ft6, (%0)" :: "r"(global_a)); + global_a += row_stride_as; + asm volatile ("flw ft7, (%0)" :: "r"(global_a)); + global_a += row_stride_as; + + asm volatile ("fsw ft0, %0(%1)" :: "i"(BM * row_stride_as * 0 * sizeof(float)), "r"(local_a_tmp)); + asm volatile ("fsw ft1, %0(%1)" :: "i"(BM * row_stride_as * 1 * sizeof(float)), "r"(local_a_tmp)); + asm volatile ("fsw ft2, %0(%1)" :: "i"(BM * row_stride_as * 2 * sizeof(float)), "r"(local_a_tmp)); + asm volatile ("fsw ft3, %0(%1)" :: "i"(BM * row_stride_as * 3 * sizeof(float)), "r"(local_a_tmp)); + asm volatile ("fsw ft4, %0(%1)" :: "i"(BM * row_stride_as * 4 * sizeof(float)), "r"(local_a_tmp)); + asm volatile ("fsw ft5, %0(%1)" :: "i"(BM * row_stride_as * 5 * sizeof(float)), "r"(local_a_tmp)); + asm volatile ("fsw ft6, %0(%1)" :: "i"(BM * row_stride_as * 6 * sizeof(float)), "r"(local_a_tmp)); + asm volatile ("fsw ft7, %0(%1)" :: "i"(BM * row_stride_as * 7 * sizeof(float)), "r"(local_a_tmp)); + local_a_tmp += BM * row_stride_as * 8; } } @@ -416,18 +445,49 @@ global_dmem_load(const uint32_t dim_n, const uint32_t dim_k, const uint32_t k, const float *global_b = B + dim_n * (k + local_b_row) + global_b_col; volatile float *local_b_tmp = local_b + BN_d * local_b_row + local_b_col; + static_assert( + row_stride_b * 8 <= BK, + "manual loop unrolling condition not met; consider increasing BK"); + #pragma GCC ivdep for (uint32_t load_offset = 0; load_offset < BK; - load_offset += row_stride_b) { + load_offset += row_stride_b * 8) { // const uint32_t global_b_offset = // dim_n * (k + local_b_row + load_offset) + global_b_col; // local_b[BN_d * (local_b_row + load_offset) + local_b_col] = // B[global_b_offset]; - *local_b_tmp = *global_b; + // *local_b_tmp = *global_b; + // global_b += dim_n * row_stride_b; + // local_b_tmp += BN_d * row_stride_b; + + asm volatile ("flw ft0, (%0)" :: "r"(global_b)); global_b += dim_n * row_stride_b; - local_b_tmp += BN_d * row_stride_b; + asm volatile ("flw ft1, (%0)" :: "r"(global_b)); + global_b += dim_n * row_stride_b; + asm volatile ("flw ft2, (%0)" :: "r"(global_b)); + global_b += dim_n * row_stride_b; + asm volatile ("flw ft3, (%0)" :: "r"(global_b)); + global_b += dim_n * row_stride_b; + asm volatile ("flw ft4, (%0)" :: "r"(global_b)); + global_b += dim_n * row_stride_b; + asm volatile ("flw ft5, (%0)" :: "r"(global_b)); + global_b += dim_n * row_stride_b; + asm volatile ("flw ft6, (%0)" :: "r"(global_b)); + global_b += dim_n * row_stride_b; + asm volatile ("flw ft7, (%0)" :: "r"(global_b)); + global_b += dim_n * row_stride_b; + + asm volatile ("fsw ft0, %0(%1)" :: "i"(BN_d * row_stride_b * 0 * sizeof(float)), "r"(local_b_tmp)); + asm volatile ("fsw ft1, %0(%1)" :: "i"(BN_d * row_stride_b * 1 * sizeof(float)), "r"(local_b_tmp)); + asm volatile ("fsw ft2, %0(%1)" :: "i"(BN_d * row_stride_b * 2 * sizeof(float)), "r"(local_b_tmp)); + asm volatile ("fsw ft3, %0(%1)" :: "i"(BN_d * row_stride_b * 3 * sizeof(float)), "r"(local_b_tmp)); + asm volatile ("fsw ft4, %0(%1)" :: "i"(BN_d * row_stride_b * 4 * sizeof(float)), "r"(local_b_tmp)); + asm volatile ("fsw ft5, %0(%1)" :: "i"(BN_d * row_stride_b * 5 * sizeof(float)), "r"(local_b_tmp)); + asm volatile ("fsw ft6, %0(%1)" :: "i"(BN_d * row_stride_b * 6 * sizeof(float)), "r"(local_b_tmp)); + asm volatile ("fsw ft7, %0(%1)" :: "i"(BN_d * row_stride_b * 7 * sizeof(float)), "r"(local_b_tmp)); + local_b_tmp += BN_d * row_stride_b * 8; } } @@ -514,6 +574,8 @@ void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg, local_b_produce = (k_index % 2) ? local_b : local_b_buf; local_a_consume = (k_index % 2) ? local_a_buf : local_a; local_b_consume = (k_index % 2) ? local_b_buf : local_b; + // local_a_consume = local_a_produce; + // local_b_consume = local_b_produce; } else { local_a_produce = local_a; local_b_produce = local_b; From 95b5719847da2a08b13aedc00a02f454aae73160 Mon Sep 17 00:00:00 2001 From: Hansung Kim Date: Wed, 5 Jun 2024 17:14:39 -0700 Subject: [PATCH 39/55] sgemm_tcore: Split K-dim loop between consumer/producer ... so that you don't have to run (warpgroup_id == 0) condition at every loop iteration which is expensive due to vx_split/join. --- tests/regression/sgemm_tcore/kernel.cpp | 134 +++++++++++++----------- 1 file changed, 75 insertions(+), 59 deletions(-) diff --git a/tests/regression/sgemm_tcore/kernel.cpp b/tests/regression/sgemm_tcore/kernel.cpp index e5a9cf33..cbd3b1df 100644 --- a/tests/regression/sgemm_tcore/kernel.cpp +++ b/tests/regression/sgemm_tcore/kernel.cpp @@ -557,44 +557,57 @@ void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg, threadblock_barrier(threadblock_id_in_cluster, threadblock_dim_y); } - uint32_t k_index = 0; - + if (warpgroup_id == 0) { + // TODO: bring initiation pipeline here + uint32_t k_index = 0; #pragma GCC unroll 1 - for (uint32_t k = 0; k < dim_k; k += BK) { - // register volatile float *local_a_produce asm("t0"); - // register volatile float *local_b_produce asm("t1"); - // register volatile float *local_a_consume asm("t2"); - // register volatile float *local_b_consume asm("t3"); - volatile float *local_a_produce; - volatile float *local_b_produce; - volatile float *local_a_consume; - volatile float *local_b_consume; - if constexpr (DOUBLE_BUFFER) { - local_a_produce = (k_index % 2) ? local_a : local_a_buf; - local_b_produce = (k_index % 2) ? local_b : local_b_buf; - local_a_consume = (k_index % 2) ? local_a_buf : local_a; - local_b_consume = (k_index % 2) ? local_b_buf : local_b; - // local_a_consume = local_a_produce; - // local_b_consume = local_b_produce; - } else { - local_a_produce = local_a; - local_b_produce = local_b; - local_a_consume = local_a; - local_b_consume = local_b; - } - k_index++; - - if (warpgroup_id == 0) { - if (k != (dim_k - BK)) { - global_dmem_load(dim_n, dim_k, k + BK /*runahead*/, A, B, - local_a_produce, local_b_produce, tid_in_warpgroup, - threadblock_id_x, threadblock_id_y); + for (uint32_t k = 0; k < dim_k - BK; k += BK) { + volatile float *local_a_produce; + volatile float *local_b_produce; + volatile float *local_a_consume; + volatile float *local_b_consume; + if constexpr (DOUBLE_BUFFER) { + local_a_produce = (k_index % 2) ? local_a : local_a_buf; + local_b_produce = (k_index % 2) ? local_b : local_b_buf; + local_a_consume = (k_index % 2) ? local_a_buf : local_a; + local_b_consume = (k_index % 2) ? local_b_buf : local_b; + } else { + local_a_produce = local_a; + local_b_produce = local_b; + local_a_consume = local_a; + local_b_consume = local_b; } + k_index++; + + global_dmem_load(dim_n, dim_k, k + BK /*runahead*/, A, B, local_a_produce, + local_b_produce, tid_in_warpgroup, threadblock_id_x, + threadblock_id_y); threadblock_barrier(threadblock_id_in_cluster, threadblock_dim_y); } - else { + threadblock_barrier(threadblock_id_in_cluster, threadblock_dim_y); + } else { + uint32_t k_index = 0; +#pragma GCC unroll 1 + for (uint32_t k = 0; k < dim_k; k += BK) { + volatile float *local_a_produce; + volatile float *local_b_produce; + volatile float *local_a_consume; + volatile float *local_b_consume; + if constexpr (DOUBLE_BUFFER) { + local_a_produce = (k_index % 2) ? local_a : local_a_buf; + local_b_produce = (k_index % 2) ? local_b : local_b_buf; + local_a_consume = (k_index % 2) ? local_a_buf : local_a; + local_b_consume = (k_index % 2) ? local_b_buf : local_b; + } else { + local_a_produce = local_a; + local_b_produce = local_b; + local_a_consume = local_a; + local_b_consume = local_b; + } + k_index++; + #if USE_TENSOR_CORE // @perf: this loop spills to stack a lot because of all the flws in // vx_wmma_load @@ -603,12 +616,14 @@ void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg, #pragma GCC unroll 1 for (uint32_t local_k = 0; local_k < BK; local_k += TCK) { // perform wmma - // vx_wmma_load(local_a_consume, local_b_consume, warp_x, warp_y, tid_in_warp); + // vx_wmma_load(local_a_consume, local_b_consume, warp_x, warp_y, + // tid_in_warp); // FIXME: this is wrong!! need separate accumulation register for // WM/WN_ITERS #pragma GCC unroll 2 for (int wn_iter = 0; wn_iter < WNITER; wn_iter++) { - vx_wmma_load_b(local_b_consume, local_k, warp_col, wn_iter, tid_in_warp); + vx_wmma_load_b(local_b_consume, local_k, warp_col, wn_iter, + tid_in_warp); // vx_wmma_load_b(local_b_consume, 0, 0, 0, tid_in_warp); #pragma GCC unroll 1 for (int wm_iter = 0; wm_iter < WMITER; wm_iter++) { @@ -641,43 +656,44 @@ void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg, } threadblock_barrier(threadblock_id_in_cluster, threadblock_dim_y); - } #else - // Compute single tile*tile matmul + // Compute single tile*tile matmul #pragma GCC unroll 4 - for (uint32_t local_k = 0; local_k < BK; local_k++) { - // First, pump data from SMEM->RF + for (uint32_t local_k = 0; local_k < BK; local_k++) { + // First, pump data from SMEM->RF #pragma GCC unroll TM - for (uint32_t res_idx_m = 0; res_idx_m < TM; res_idx_m++) { - reg_a[res_idx_m] = - local_a[BK * (TM * local_c_row + res_idx_m) + local_k]; - } + for (uint32_t res_idx_m = 0; res_idx_m < TM; res_idx_m++) { + reg_a[res_idx_m] = + local_a[BK * (TM * local_c_row + res_idx_m) + local_k]; + } #pragma GCC unroll TN - for (uint32_t res_idx_n = 0; res_idx_n < TN; res_idx_n++) { - reg_b[res_idx_n] = - local_b[BN * local_k + (TN * local_c_col + res_idx_n)]; - } + for (uint32_t res_idx_n = 0; res_idx_n < TN; res_idx_n++) { + reg_b[res_idx_n] = + local_b[BN * local_k + (TN * local_c_col + res_idx_n)]; + } - // Next, compute multiple result elements (TM*TN) by reusing data in RF + // Next, compute multiple result elements (TM*TN) by reusing data in + // RF #pragma GCC unroll TM - for (uint32_t res_idx_m = 0; res_idx_m < TM; res_idx_m++) { + for (uint32_t res_idx_m = 0; res_idx_m < TM; res_idx_m++) { #pragma GCC unroll TN - for (uint32_t res_idx_n = 0; res_idx_n < TN; res_idx_n++) { - // NOTE use of local_b_row - reg_c[TN * res_idx_m + res_idx_n] += - reg_a[res_idx_m] * reg_b[res_idx_n]; - // reg_c[TN * res_idx_m + res_idx_n] += - // local_a[BK * (TM * local_c_row + res_idx_m) + local_k] * - // local_b[BN * local_k + (TN * local_c_col + res_idx_n)]; + for (uint32_t res_idx_n = 0; res_idx_n < TN; res_idx_n++) { + // NOTE use of local_b_row + reg_c[TN * res_idx_m + res_idx_n] += + reg_a[res_idx_m] * reg_b[res_idx_n]; + // reg_c[TN * res_idx_m + res_idx_n] += + // local_a[BK * (TM * local_c_row + res_idx_m) + local_k] * + // local_b[BN * local_k + (TN * local_c_col + res_idx_n)]; + } + } } - } - } - threadblock_barrier(tid_in_threadblock, threadblock_id_in_cluster, - threadblock_dim_y); + threadblock_barrier(tid_in_threadblock, threadblock_id_in_cluster, + threadblock_dim_y); #endif + } } #if USE_TENSOR_CORE From c7a6ed03def3a6ee434da6adf2271d6c7348a731 Mon Sep 17 00:00:00 2001 From: Hansung Kim Date: Wed, 5 Jun 2024 18:02:34 -0700 Subject: [PATCH 40/55] sgemm_tcore: Use constant offset to reduce SMEM addr calc --- tests/regression/sgemm_tcore/kernel.cpp | 47 +++++++++---------------- 1 file changed, 17 insertions(+), 30 deletions(-) diff --git a/tests/regression/sgemm_tcore/kernel.cpp b/tests/regression/sgemm_tcore/kernel.cpp index cbd3b1df..7a05f0d4 100644 --- a/tests/regression/sgemm_tcore/kernel.cpp +++ b/tests/regression/sgemm_tcore/kernel.cpp @@ -182,21 +182,14 @@ inline void vx_wmma_load_a(volatile float *smem_A, const int local_k, // f8-f15 stores a single row of A register volatile float *smem_addr asm("t5"); smem_addr = &smem_A[((local_k + 0) * smem_AS_cols) + (WM * warp_row + TCM * wm_iter) + row]; - asm volatile("flw f0, %0" ::"m"(*smem_addr)); - smem_addr += smem_AS_cols; - asm volatile("flw f1, %0" ::"m"(*smem_addr)); - smem_addr += smem_AS_cols; - asm volatile("flw f2, %0" ::"m"(*smem_addr)); - smem_addr += smem_AS_cols; - asm volatile("flw f3, %0" ::"m"(*smem_addr)); - smem_addr += smem_AS_cols; - asm volatile("flw f4, %0" ::"m"(*smem_addr)); - smem_addr += smem_AS_cols; - asm volatile("flw f5, %0" ::"m"(*smem_addr)); - smem_addr += smem_AS_cols; - asm volatile("flw f6, %0" ::"m"(*smem_addr)); - smem_addr += smem_AS_cols; - asm volatile("flw f7, %0" ::"m"(*smem_addr)); + asm volatile("flw f0, %0(%1)" :: "i"(smem_AS_cols * 0 * sizeof(float)), "r"(smem_addr)); + asm volatile("flw f1, %0(%1)" :: "i"(smem_AS_cols * 1 * sizeof(float)), "r"(smem_addr)); + asm volatile("flw f2, %0(%1)" :: "i"(smem_AS_cols * 2 * sizeof(float)), "r"(smem_addr)); + asm volatile("flw f3, %0(%1)" :: "i"(smem_AS_cols * 3 * sizeof(float)), "r"(smem_addr)); + asm volatile("flw f4, %0(%1)" :: "i"(smem_AS_cols * 4 * sizeof(float)), "r"(smem_addr)); + asm volatile("flw f5, %0(%1)" :: "i"(smem_AS_cols * 5 * sizeof(float)), "r"(smem_addr)); + asm volatile("flw f6, %0(%1)" :: "i"(smem_AS_cols * 6 * sizeof(float)), "r"(smem_addr)); + asm volatile("flw f7, %0(%1)" :: "i"(smem_AS_cols * 7 * sizeof(float)), "r"(smem_addr)); // asm volatile("flw f0, %0" ::"m"(smem_A[((local_k + 0) * smem_AS_cols) + (WM * warp_row + TCM * wm_iter) + row])); // asm volatile("flw f1, %0" ::"m"(smem_A[((local_k + 1) * smem_AS_cols) + (WM * warp_row + TCM * wm_iter) + row])); @@ -226,21 +219,15 @@ inline void vx_wmma_load_b(volatile float *smem_B, const int local_k, // f8-f15 stores a single column of B register volatile float *smem_addr asm("t5"); smem_addr = &smem_B[((local_k + 0) * smem_B_cols) + (WN * warp_col + TCN * wn_iter) + col]; - asm volatile("flw f8, %0" ::"m"(*smem_addr)); - smem_addr += smem_B_cols; - asm volatile("flw f9, %0" ::"m"(*smem_addr)); - smem_addr += smem_B_cols; - asm volatile("flw f10, %0" ::"m"(*smem_addr)); - smem_addr += smem_B_cols; - asm volatile("flw f11, %0" ::"m"(*smem_addr)); - smem_addr += smem_B_cols; - asm volatile("flw f12, %0" ::"m"(*smem_addr)); - smem_addr += smem_B_cols; - asm volatile("flw f13, %0" ::"m"(*smem_addr)); - smem_addr += smem_B_cols; - asm volatile("flw f14, %0" ::"m"(*smem_addr)); - smem_addr += smem_B_cols; - asm volatile("flw f15, %0" ::"m"(*smem_addr)); + asm volatile("flw f8, %0(%1)" :: "i"(smem_B_cols * 0 * sizeof(float)), "r"(smem_addr)); + asm volatile("flw f9, %0(%1)" :: "i"(smem_B_cols * 1 * sizeof(float)), "r"(smem_addr)); + asm volatile("flw f10, %0(%1)" :: "i"(smem_B_cols * 2 * sizeof(float)), "r"(smem_addr)); + asm volatile("flw f11, %0(%1)" :: "i"(smem_B_cols * 3 * sizeof(float)), "r"(smem_addr)); + asm volatile("flw f12, %0(%1)" :: "i"(smem_B_cols * 4 * sizeof(float)), "r"(smem_addr)); + asm volatile("flw f13, %0(%1)" :: "i"(smem_B_cols * 5 * sizeof(float)), "r"(smem_addr)); + asm volatile("flw f14, %0(%1)" :: "i"(smem_B_cols * 6 * sizeof(float)), "r"(smem_addr)); + asm volatile("flw f15, %0(%1)" :: "i"(smem_B_cols * 7 * sizeof(float)), "r"(smem_addr)); + // asm volatile("flw f8, %0" ::"m"(smem_B[((local_k + 0) * smem_B_cols) + (WN * warp_col + TCN * wn_iter) + col])); // asm volatile("flw f9, %0" ::"m"(smem_B[((local_k + 1) * smem_B_cols) + (WN * warp_col + TCN * wn_iter) + col])); // asm volatile("flw f10, %0" ::"m"(smem_B[((local_k + 2) * smem_B_cols) + (WN * warp_col + TCN * wn_iter) + col])); From 65c653afded4eb3fba8816be5a2027a5bd64cac3 Mon Sep 17 00:00:00 2001 From: Hansung Kim Date: Wed, 5 Jun 2024 18:03:08 -0700 Subject: [PATCH 41/55] sgemm_tcore: Use arithmetic instead of branch for double-buffered addr --- tests/regression/sgemm_tcore/kernel.cpp | 37 +++++++++++++------------ 1 file changed, 19 insertions(+), 18 deletions(-) diff --git a/tests/regression/sgemm_tcore/kernel.cpp b/tests/regression/sgemm_tcore/kernel.cpp index 7a05f0d4..8f73e0f4 100644 --- a/tests/regression/sgemm_tcore/kernel.cpp +++ b/tests/regression/sgemm_tcore/kernel.cpp @@ -385,7 +385,7 @@ global_dmem_load(const uint32_t dim_n, const uint32_t dim_k, const uint32_t k, row_stride_as * 8 <= BK, "manual loop unrolling condition not met; consider increasing BK"); -#pragma GCC ivdep +#pragma GCC unroll 2 for (uint32_t local_row_offset = 0; local_row_offset < BK; local_row_offset += row_stride_as * 8) { // @perf: bank conflicts here @@ -436,7 +436,7 @@ global_dmem_load(const uint32_t dim_n, const uint32_t dim_k, const uint32_t k, row_stride_b * 8 <= BK, "manual loop unrolling condition not met; consider increasing BK"); -#pragma GCC ivdep +#pragma GCC unroll 2 for (uint32_t load_offset = 0; load_offset < BK; load_offset += row_stride_b * 8) { // const uint32_t global_b_offset = @@ -551,18 +551,18 @@ void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg, for (uint32_t k = 0; k < dim_k - BK; k += BK) { volatile float *local_a_produce; volatile float *local_b_produce; - volatile float *local_a_consume; - volatile float *local_b_consume; if constexpr (DOUBLE_BUFFER) { local_a_produce = (k_index % 2) ? local_a : local_a_buf; local_b_produce = (k_index % 2) ? local_b : local_b_buf; - local_a_consume = (k_index % 2) ? local_a_buf : local_a; - local_b_consume = (k_index % 2) ? local_b_buf : local_b; + local_a_produce = reinterpret_cast( + ((k_index & 1) & 1) * reinterpret_cast(local_a) + + ((k_index & 1) ^ 1) * reinterpret_cast(local_a_buf)); + local_b_produce = reinterpret_cast( + ((k_index & 1) & 1) * reinterpret_cast(local_b) + + ((k_index & 1) ^ 1) * reinterpret_cast(local_b_buf)); } else { local_a_produce = local_a; local_b_produce = local_b; - local_a_consume = local_a; - local_b_consume = local_b; } k_index++; @@ -578,18 +578,19 @@ void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg, uint32_t k_index = 0; #pragma GCC unroll 1 for (uint32_t k = 0; k < dim_k; k += BK) { - volatile float *local_a_produce; - volatile float *local_b_produce; volatile float *local_a_consume; volatile float *local_b_consume; if constexpr (DOUBLE_BUFFER) { - local_a_produce = (k_index % 2) ? local_a : local_a_buf; - local_b_produce = (k_index % 2) ? local_b : local_b_buf; - local_a_consume = (k_index % 2) ? local_a_buf : local_a; - local_b_consume = (k_index % 2) ? local_b_buf : local_b; + // local_a_consume = (k_index % 2) ? local_a_buf : local_a; + // local_b_consume = (k_index % 2) ? local_b_buf : local_b; + // FIXME: swap multiply with bitshifts + local_a_consume = reinterpret_cast( + ((k_index & 1) & 1) * reinterpret_cast(local_a_buf) + + ((k_index & 1) ^ 1) * reinterpret_cast(local_a)); + local_b_consume = reinterpret_cast( + ((k_index & 1) & 1) * reinterpret_cast(local_b_buf) + + ((k_index & 1) ^ 1) * reinterpret_cast(local_b)); } else { - local_a_produce = local_a; - local_b_produce = local_b; local_a_consume = local_a; local_b_consume = local_b; } @@ -600,7 +601,7 @@ void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg, // vx_wmma_load #pragma GCC unroll 1 for (int i = 0; i < BK_LOOP; i++) { -#pragma GCC unroll 1 +#pragma GCC unroll 10 for (uint32_t local_k = 0; local_k < BK; local_k += TCK) { // perform wmma // vx_wmma_load(local_a_consume, local_b_consume, warp_x, warp_y, @@ -612,7 +613,7 @@ void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg, vx_wmma_load_b(local_b_consume, local_k, warp_col, wn_iter, tid_in_warp); // vx_wmma_load_b(local_b_consume, 0, 0, 0, tid_in_warp); -#pragma GCC unroll 1 +#pragma GCC unroll 2 for (int wm_iter = 0; wm_iter < WMITER; wm_iter++) { #if TC_SINGLE_WARP if (warp_in_warpgroup == 0) { From a42fa6a1131b0288796b5c4ed52f97d2cfcbe372 Mon Sep 17 00:00:00 2001 From: Hansung Kim Date: Wed, 5 Jun 2024 19:01:59 -0700 Subject: [PATCH 42/55] sgemm_tcore: Swap out mul with bitwise ops for addr ping-pong --- tests/regression/sgemm_tcore/kernel.cpp | 30 ++++++++++++++----------- 1 file changed, 17 insertions(+), 13 deletions(-) diff --git a/tests/regression/sgemm_tcore/kernel.cpp b/tests/regression/sgemm_tcore/kernel.cpp index 8f73e0f4..fb06966a 100644 --- a/tests/regression/sgemm_tcore/kernel.cpp +++ b/tests/regression/sgemm_tcore/kernel.cpp @@ -546,20 +546,22 @@ void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg, if (warpgroup_id == 0) { // TODO: bring initiation pipeline here - uint32_t k_index = 0; + int32_t k_index = 0; #pragma GCC unroll 1 for (uint32_t k = 0; k < dim_k - BK; k += BK) { volatile float *local_a_produce; volatile float *local_b_produce; if constexpr (DOUBLE_BUFFER) { - local_a_produce = (k_index % 2) ? local_a : local_a_buf; - local_b_produce = (k_index % 2) ? local_b : local_b_buf; + const uint32_t mask_odd = (k_index & 1) << 31 >> 31; + const uint32_t mask_even = ((k_index & 1) ^ 1) << 31 >> 31; + // local_a_produce = (k_index % 2) ? local_a : local_a_buf; + // local_b_produce = (k_index % 2) ? local_b : local_b_buf; local_a_produce = reinterpret_cast( - ((k_index & 1) & 1) * reinterpret_cast(local_a) + - ((k_index & 1) ^ 1) * reinterpret_cast(local_a_buf)); + (mask_odd & reinterpret_cast(local_a)) | + (mask_even & reinterpret_cast(local_a_buf))); local_b_produce = reinterpret_cast( - ((k_index & 1) & 1) * reinterpret_cast(local_b) + - ((k_index & 1) ^ 1) * reinterpret_cast(local_b_buf)); + (mask_odd & reinterpret_cast(local_b)) | + (mask_even & reinterpret_cast(local_b_buf))); } else { local_a_produce = local_a; local_b_produce = local_b; @@ -575,7 +577,7 @@ void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg, threadblock_barrier(threadblock_id_in_cluster, threadblock_dim_y); } else { - uint32_t k_index = 0; + int32_t k_index = 0; #pragma GCC unroll 1 for (uint32_t k = 0; k < dim_k; k += BK) { volatile float *local_a_consume; @@ -584,12 +586,14 @@ void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg, // local_a_consume = (k_index % 2) ? local_a_buf : local_a; // local_b_consume = (k_index % 2) ? local_b_buf : local_b; // FIXME: swap multiply with bitshifts + const uint32_t mask_odd = (k_index & 1) << 31 >> 31; + const uint32_t mask_even = ((k_index & 1) ^ 1) << 31 >> 31; local_a_consume = reinterpret_cast( - ((k_index & 1) & 1) * reinterpret_cast(local_a_buf) + - ((k_index & 1) ^ 1) * reinterpret_cast(local_a)); + (mask_odd & reinterpret_cast(local_a_buf)) | + (mask_even & reinterpret_cast(local_a))); local_b_consume = reinterpret_cast( - ((k_index & 1) & 1) * reinterpret_cast(local_b_buf) + - ((k_index & 1) ^ 1) * reinterpret_cast(local_b)); + (mask_odd & reinterpret_cast(local_b_buf)) | + (mask_even & reinterpret_cast(local_b))); } else { local_a_consume = local_a; local_b_consume = local_b; @@ -601,7 +605,7 @@ void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg, // vx_wmma_load #pragma GCC unroll 1 for (int i = 0; i < BK_LOOP; i++) { -#pragma GCC unroll 10 +#pragma GCC unroll 4 for (uint32_t local_k = 0; local_k < BK; local_k += TCK) { // perform wmma // vx_wmma_load(local_a_consume, local_b_consume, warp_x, warp_y, From ab4d52597038428d7f720c013e0c454c29ab9555 Mon Sep 17 00:00:00 2001 From: Hansung Kim Date: Thu, 6 Jun 2024 12:26:07 -0700 Subject: [PATCH 43/55] sgemm_tcore: More asserts on manual unrolling --- tests/regression/sgemm_tcore/kernel.cpp | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/tests/regression/sgemm_tcore/kernel.cpp b/tests/regression/sgemm_tcore/kernel.cpp index fb06966a..899afd8a 100644 --- a/tests/regression/sgemm_tcore/kernel.cpp +++ b/tests/regression/sgemm_tcore/kernel.cpp @@ -384,6 +384,10 @@ global_dmem_load(const uint32_t dim_n, const uint32_t dim_k, const uint32_t k, static_assert( row_stride_as * 8 <= BK, "manual loop unrolling condition not met; consider increasing BK"); + static_assert( + (BK % (row_stride_as * 8)) == 0, + "manual loop unrolling condition not met; BK should be power-of-two"); + #pragma GCC unroll 2 for (uint32_t local_row_offset = 0; local_row_offset < BK; @@ -435,6 +439,9 @@ global_dmem_load(const uint32_t dim_n, const uint32_t dim_k, const uint32_t k, static_assert( row_stride_b * 8 <= BK, "manual loop unrolling condition not met; consider increasing BK"); + static_assert( + (BK % (row_stride_b * 8)) == 0, + "manual loop unrolling condition not met; BK should be power-of-two"); #pragma GCC unroll 2 for (uint32_t load_offset = 0; load_offset < BK; @@ -546,6 +553,7 @@ void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg, if (warpgroup_id == 0) { // TODO: bring initiation pipeline here + // NOTE: this *should* be signed integer to trigger arithmetic right-shift int32_t k_index = 0; #pragma GCC unroll 1 for (uint32_t k = 0; k < dim_k - BK; k += BK) { From deb6e5eba27703444d3e897442ad61cdb7c0f8a2 Mon Sep 17 00:00:00 2001 From: Hansung Kim Date: Thu, 6 Jun 2024 12:43:08 -0700 Subject: [PATCH 44/55] sgemm_tcore: Move bank-conflicts to SMEM stores from GMEM loads --- tests/regression/sgemm_tcore/kernel.cpp | 63 ++++++++++++++++++++++++- 1 file changed, 61 insertions(+), 2 deletions(-) diff --git a/tests/regression/sgemm_tcore/kernel.cpp b/tests/regression/sgemm_tcore/kernel.cpp index 899afd8a..fd2e73da 100644 --- a/tests/regression/sgemm_tcore/kernel.cpp +++ b/tests/regression/sgemm_tcore/kernel.cpp @@ -14,6 +14,10 @@ // scenario #define BK_LOOP 1 #define TRANSPOSE_AS 1 +// GMEM_COALESCED sets bank conflict-free accesses for +// 1: GMEM loads of A matrix +// 0: SMEM stores of A matrix +#define GMEM_COALESCED_A 1 #define DOUBLE_BUFFER 1 @@ -31,7 +35,7 @@ // BM <= BK*TM*TN #define BM 32 #define BN 32 -#define BK 8 +#define BK 32 #define WM 16 #define WN 8 #define TCM 8 @@ -326,6 +330,7 @@ inline void write_results(const int thread_in_warp, const int warp_col, inline void threadblock_barrier(const uint32_t barrier_id, const uint32_t count) { vx_fence(); vx_barrier(barrier_id, count); + // vx_barrier(0, count); } inline void @@ -373,6 +378,7 @@ global_dmem_load(const uint32_t dim_n, const uint32_t dim_k, const uint32_t k, local_a_tmp += BK * row_stride_a; } } else { +#if !GMEM_COALESCED_A constexpr uint32_t row_stride_as = threads_in_warpgroup / BM_d; const uint32_t global_a_row = BM_d * threadblock_id_y + local_as_col; const float *global_a = A + dim_k * global_a_row + (k + local_as_row); @@ -388,7 +394,6 @@ global_dmem_load(const uint32_t dim_n, const uint32_t dim_k, const uint32_t k, (BK % (row_stride_as * 8)) == 0, "manual loop unrolling condition not met; BK should be power-of-two"); - #pragma GCC unroll 2 for (uint32_t local_row_offset = 0; local_row_offset < BK; local_row_offset += row_stride_as * 8) { @@ -429,6 +434,59 @@ global_dmem_load(const uint32_t dim_n, const uint32_t dim_k, const uint32_t k, asm volatile ("fsw ft7, %0(%1)" :: "i"(BM * row_stride_as * 7 * sizeof(float)), "r"(local_a_tmp)); local_a_tmp += BM * row_stride_as * 8; } +#else + constexpr uint32_t row_stride_a = threads_in_warpgroup / BK; + const uint32_t global_a_row = BM_d * threadblock_id_y + local_a_row; + const float *global_a = A + dim_k * global_a_row + (k + local_a_col); + // NOTE that SMEM writes are transposed + volatile float *local_a_tmp = local_a + BM_d * local_a_col + local_a_row; + + static_assert( + row_stride_a * 8 <= BM_d, + "manual loop unrolling condition not met; consider increasing BM"); + static_assert( + (BM_d % (row_stride_a * 8)) == 0, + "manual loop unrolling condition not met; BM should be power-of-two"); + +#pragma GCC unroll 2 + for (uint32_t local_row_offset = 0; local_row_offset < BM_d; + local_row_offset += row_stride_a * 8) { + // const uint32_t global_a_offset = + // dim_k * (global_a_row + local_row_offset) + (k + local_a_col); + // NOTE that SMEM writes are transposed + // local_a[BM_d * (local_a_col) + local_a_row + local_row_offset] = + // A[global_a_offset]; + + // *local_a_tmp = *global_a; + asm volatile ("flw ft0, (%0)" :: "r"(global_a)); + global_a += dim_k * row_stride_a; + asm volatile ("flw ft1, (%0)" :: "r"(global_a)); + global_a += dim_k * row_stride_a; + asm volatile ("flw ft2, (%0)" :: "r"(global_a)); + global_a += dim_k * row_stride_a; + asm volatile ("flw ft3, (%0)" :: "r"(global_a)); + global_a += dim_k * row_stride_a; + asm volatile ("flw ft4, (%0)" :: "r"(global_a)); + global_a += dim_k * row_stride_a; + asm volatile ("flw ft5, (%0)" :: "r"(global_a)); + global_a += dim_k * row_stride_a; + asm volatile ("flw ft6, (%0)" :: "r"(global_a)); + global_a += dim_k * row_stride_a; + asm volatile ("flw ft7, (%0)" :: "r"(global_a)); + global_a += dim_k * row_stride_a; + + // stride along columns + asm volatile ("fsw ft0, %0(%1)" :: "i"(row_stride_a * 0 * sizeof(float)), "r"(local_a_tmp)); + asm volatile ("fsw ft1, %0(%1)" :: "i"(row_stride_a * 1 * sizeof(float)), "r"(local_a_tmp)); + asm volatile ("fsw ft2, %0(%1)" :: "i"(row_stride_a * 2 * sizeof(float)), "r"(local_a_tmp)); + asm volatile ("fsw ft3, %0(%1)" :: "i"(row_stride_a * 3 * sizeof(float)), "r"(local_a_tmp)); + asm volatile ("fsw ft4, %0(%1)" :: "i"(row_stride_a * 4 * sizeof(float)), "r"(local_a_tmp)); + asm volatile ("fsw ft5, %0(%1)" :: "i"(row_stride_a * 5 * sizeof(float)), "r"(local_a_tmp)); + asm volatile ("fsw ft6, %0(%1)" :: "i"(row_stride_a * 6 * sizeof(float)), "r"(local_a_tmp)); + asm volatile ("fsw ft7, %0(%1)" :: "i"(row_stride_a * 7 * sizeof(float)), "r"(local_a_tmp)); + local_a_tmp += row_stride_a * 8; + } +#endif } constexpr uint32_t row_stride_b = threads_in_warpgroup / BN_d; @@ -585,6 +643,7 @@ void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg, threadblock_barrier(threadblock_id_in_cluster, threadblock_dim_y); } else { + // NOTE: this *should* be signed integer to trigger arithmetic right-shift int32_t k_index = 0; #pragma GCC unroll 1 for (uint32_t k = 0; k < dim_k; k += BK) { From 7f6f0961912561ab3e755f69550bb8c950c27e43 Mon Sep 17 00:00:00 2001 From: Hansung Kim Date: Thu, 6 Jun 2024 13:43:57 -0700 Subject: [PATCH 45/55] sgemm_tcore: Use if constexpr --- tests/regression/sgemm_tcore/kernel.cpp | 201 ++++++++++++------------ 1 file changed, 100 insertions(+), 101 deletions(-) diff --git a/tests/regression/sgemm_tcore/kernel.cpp b/tests/regression/sgemm_tcore/kernel.cpp index fd2e73da..d26bae36 100644 --- a/tests/regression/sgemm_tcore/kernel.cpp +++ b/tests/regression/sgemm_tcore/kernel.cpp @@ -378,115 +378,114 @@ global_dmem_load(const uint32_t dim_n, const uint32_t dim_k, const uint32_t k, local_a_tmp += BK * row_stride_a; } } else { -#if !GMEM_COALESCED_A - constexpr uint32_t row_stride_as = threads_in_warpgroup / BM_d; - const uint32_t global_a_row = BM_d * threadblock_id_y + local_as_col; - const float *global_a = A + dim_k * global_a_row + (k + local_as_row); - // FIXME experimenting with global coalescing - // const uint32_t global_a_row = BM_d * threadblock_id_y + local_as_row; - // const float *global_a = A + dim_k * global_a_row + (k + local_as_col); - volatile float *local_a_tmp = local_a + BM_d * local_as_row + local_as_col; - - static_assert( - row_stride_as * 8 <= BK, - "manual loop unrolling condition not met; consider increasing BK"); - static_assert( - (BK % (row_stride_as * 8)) == 0, - "manual loop unrolling condition not met; BK should be power-of-two"); - -#pragma GCC unroll 2 - for (uint32_t local_row_offset = 0; local_row_offset < BK; - local_row_offset += row_stride_as * 8) { - // @perf: bank conflicts here - // const uint32_t global_a_offset = - // dim_k * (global_a_row) + (k + local_as_row + local_row_offset); + if constexpr (!GMEM_COALESCED_A) { + constexpr uint32_t row_stride_as = threads_in_warpgroup / BM_d; + const uint32_t global_a_row = BM_d * threadblock_id_y + local_as_col; + const float *global_a = A + dim_k * global_a_row + (k + local_as_row); // FIXME experimenting with global coalescing - // const uint32_t global_a_offset = - // dim_k * (global_a_row + local_row_offset) + (k + local_as_col); - // local_a[BM_d * (local_as_row + local_row_offset) + local_as_col] = - // A[global_a_offset]; + // const uint32_t global_a_row = BM_d * threadblock_id_y + local_as_row; + // const float *global_a = A + dim_k * global_a_row + (k + local_as_col); + volatile float *local_a_tmp = local_a + BM_d * local_as_row + local_as_col; - // *local_a_tmp = *global_a; - asm volatile ("flw ft0, (%0)" :: "r"(global_a)); - global_a += row_stride_as; - asm volatile ("flw ft1, (%0)" :: "r"(global_a)); - global_a += row_stride_as; - asm volatile ("flw ft2, (%0)" :: "r"(global_a)); - global_a += row_stride_as; - asm volatile ("flw ft3, (%0)" :: "r"(global_a)); - global_a += row_stride_as; - asm volatile ("flw ft4, (%0)" :: "r"(global_a)); - global_a += row_stride_as; - asm volatile ("flw ft5, (%0)" :: "r"(global_a)); - global_a += row_stride_as; - asm volatile ("flw ft6, (%0)" :: "r"(global_a)); - global_a += row_stride_as; - asm volatile ("flw ft7, (%0)" :: "r"(global_a)); - global_a += row_stride_as; - - asm volatile ("fsw ft0, %0(%1)" :: "i"(BM * row_stride_as * 0 * sizeof(float)), "r"(local_a_tmp)); - asm volatile ("fsw ft1, %0(%1)" :: "i"(BM * row_stride_as * 1 * sizeof(float)), "r"(local_a_tmp)); - asm volatile ("fsw ft2, %0(%1)" :: "i"(BM * row_stride_as * 2 * sizeof(float)), "r"(local_a_tmp)); - asm volatile ("fsw ft3, %0(%1)" :: "i"(BM * row_stride_as * 3 * sizeof(float)), "r"(local_a_tmp)); - asm volatile ("fsw ft4, %0(%1)" :: "i"(BM * row_stride_as * 4 * sizeof(float)), "r"(local_a_tmp)); - asm volatile ("fsw ft5, %0(%1)" :: "i"(BM * row_stride_as * 5 * sizeof(float)), "r"(local_a_tmp)); - asm volatile ("fsw ft6, %0(%1)" :: "i"(BM * row_stride_as * 6 * sizeof(float)), "r"(local_a_tmp)); - asm volatile ("fsw ft7, %0(%1)" :: "i"(BM * row_stride_as * 7 * sizeof(float)), "r"(local_a_tmp)); - local_a_tmp += BM * row_stride_as * 8; - } -#else - constexpr uint32_t row_stride_a = threads_in_warpgroup / BK; - const uint32_t global_a_row = BM_d * threadblock_id_y + local_a_row; - const float *global_a = A + dim_k * global_a_row + (k + local_a_col); - // NOTE that SMEM writes are transposed - volatile float *local_a_tmp = local_a + BM_d * local_a_col + local_a_row; - - static_assert( - row_stride_a * 8 <= BM_d, - "manual loop unrolling condition not met; consider increasing BM"); - static_assert( - (BM_d % (row_stride_a * 8)) == 0, - "manual loop unrolling condition not met; BM should be power-of-two"); + static_assert( + row_stride_as * 8 <= BK, + "manual loop unrolling condition not met; consider increasing BK"); + static_assert( + (BK % (row_stride_as * 8)) == 0, + "manual loop unrolling condition not met; BK should be power-of-two"); #pragma GCC unroll 2 - for (uint32_t local_row_offset = 0; local_row_offset < BM_d; - local_row_offset += row_stride_a * 8) { - // const uint32_t global_a_offset = - // dim_k * (global_a_row + local_row_offset) + (k + local_a_col); + for (uint32_t local_row_offset = 0; local_row_offset < BK; + local_row_offset += row_stride_as * 8) { + // @perf: bank conflicts here + // const uint32_t global_a_offset = + // dim_k * (global_a_row) + (k + local_as_row + local_row_offset); + // FIXME experimenting with global coalescing + // const uint32_t global_a_offset = + // dim_k * (global_a_row + local_row_offset) + (k + local_as_col); + // local_a[BM_d * (local_as_row + local_row_offset) + local_as_col] = + // A[global_a_offset]; + + // *local_a_tmp = *global_a; + asm volatile ("flw ft0, (%0)" :: "r"(global_a)); + global_a += row_stride_as; + asm volatile ("flw ft1, (%0)" :: "r"(global_a)); + global_a += row_stride_as; + asm volatile ("flw ft2, (%0)" :: "r"(global_a)); + global_a += row_stride_as; + asm volatile ("flw ft3, (%0)" :: "r"(global_a)); + global_a += row_stride_as; + asm volatile ("flw ft4, (%0)" :: "r"(global_a)); + global_a += row_stride_as; + asm volatile ("flw ft5, (%0)" :: "r"(global_a)); + global_a += row_stride_as; + asm volatile ("flw ft6, (%0)" :: "r"(global_a)); + global_a += row_stride_as; + asm volatile ("flw ft7, (%0)" :: "r"(global_a)); + global_a += row_stride_as; + + asm volatile ("fsw ft0, %0(%1)" :: "i"(BM * row_stride_as * 0 * sizeof(float)), "r"(local_a_tmp)); + asm volatile ("fsw ft1, %0(%1)" :: "i"(BM * row_stride_as * 1 * sizeof(float)), "r"(local_a_tmp)); + asm volatile ("fsw ft2, %0(%1)" :: "i"(BM * row_stride_as * 2 * sizeof(float)), "r"(local_a_tmp)); + asm volatile ("fsw ft3, %0(%1)" :: "i"(BM * row_stride_as * 3 * sizeof(float)), "r"(local_a_tmp)); + asm volatile ("fsw ft4, %0(%1)" :: "i"(BM * row_stride_as * 4 * sizeof(float)), "r"(local_a_tmp)); + asm volatile ("fsw ft5, %0(%1)" :: "i"(BM * row_stride_as * 5 * sizeof(float)), "r"(local_a_tmp)); + asm volatile ("fsw ft6, %0(%1)" :: "i"(BM * row_stride_as * 6 * sizeof(float)), "r"(local_a_tmp)); + asm volatile ("fsw ft7, %0(%1)" :: "i"(BM * row_stride_as * 7 * sizeof(float)), "r"(local_a_tmp)); + local_a_tmp += BM * row_stride_as * 8; + } + } else { + constexpr uint32_t row_stride_a = threads_in_warpgroup / BK; + const uint32_t global_a_row = BM_d * threadblock_id_y + local_a_row; + const float *global_a = A + dim_k * global_a_row + (k + local_a_col); // NOTE that SMEM writes are transposed - // local_a[BM_d * (local_a_col) + local_a_row + local_row_offset] = - // A[global_a_offset]; + volatile float *local_a_tmp = local_a + BM_d * local_a_col + local_a_row; - // *local_a_tmp = *global_a; - asm volatile ("flw ft0, (%0)" :: "r"(global_a)); - global_a += dim_k * row_stride_a; - asm volatile ("flw ft1, (%0)" :: "r"(global_a)); - global_a += dim_k * row_stride_a; - asm volatile ("flw ft2, (%0)" :: "r"(global_a)); - global_a += dim_k * row_stride_a; - asm volatile ("flw ft3, (%0)" :: "r"(global_a)); - global_a += dim_k * row_stride_a; - asm volatile ("flw ft4, (%0)" :: "r"(global_a)); - global_a += dim_k * row_stride_a; - asm volatile ("flw ft5, (%0)" :: "r"(global_a)); - global_a += dim_k * row_stride_a; - asm volatile ("flw ft6, (%0)" :: "r"(global_a)); - global_a += dim_k * row_stride_a; - asm volatile ("flw ft7, (%0)" :: "r"(global_a)); - global_a += dim_k * row_stride_a; + static_assert( + row_stride_a * 8 <= BM_d, + "manual loop unrolling condition not met; consider increasing BM"); + static_assert( + (BM_d % (row_stride_a * 8)) == 0, + "manual loop unrolling condition not met; BM should be power-of-two"); - // stride along columns - asm volatile ("fsw ft0, %0(%1)" :: "i"(row_stride_a * 0 * sizeof(float)), "r"(local_a_tmp)); - asm volatile ("fsw ft1, %0(%1)" :: "i"(row_stride_a * 1 * sizeof(float)), "r"(local_a_tmp)); - asm volatile ("fsw ft2, %0(%1)" :: "i"(row_stride_a * 2 * sizeof(float)), "r"(local_a_tmp)); - asm volatile ("fsw ft3, %0(%1)" :: "i"(row_stride_a * 3 * sizeof(float)), "r"(local_a_tmp)); - asm volatile ("fsw ft4, %0(%1)" :: "i"(row_stride_a * 4 * sizeof(float)), "r"(local_a_tmp)); - asm volatile ("fsw ft5, %0(%1)" :: "i"(row_stride_a * 5 * sizeof(float)), "r"(local_a_tmp)); - asm volatile ("fsw ft6, %0(%1)" :: "i"(row_stride_a * 6 * sizeof(float)), "r"(local_a_tmp)); - asm volatile ("fsw ft7, %0(%1)" :: "i"(row_stride_a * 7 * sizeof(float)), "r"(local_a_tmp)); - local_a_tmp += row_stride_a * 8; +#pragma GCC unroll 4 + for (uint32_t local_row_offset = 0; local_row_offset < BM_d; + local_row_offset += row_stride_a * 8) { + // const uint32_t global_a_offset = + // dim_k * (global_a_row + local_row_offset) + (k + local_a_col); + // NOTE that SMEM writes are transposed + // local_a[BM_d * (local_a_col) + local_a_row + local_row_offset] = + // A[global_a_offset]; + + asm volatile ("flw ft0, (%0)" :: "r"(global_a)); + global_a += dim_k * row_stride_a; + asm volatile ("flw ft1, (%0)" :: "r"(global_a)); + global_a += dim_k * row_stride_a; + asm volatile ("flw ft2, (%0)" :: "r"(global_a)); + global_a += dim_k * row_stride_a; + asm volatile ("flw ft3, (%0)" :: "r"(global_a)); + global_a += dim_k * row_stride_a; + asm volatile ("flw ft4, (%0)" :: "r"(global_a)); + global_a += dim_k * row_stride_a; + asm volatile ("flw ft5, (%0)" :: "r"(global_a)); + global_a += dim_k * row_stride_a; + asm volatile ("flw ft6, (%0)" :: "r"(global_a)); + global_a += dim_k * row_stride_a; + asm volatile ("flw ft7, (%0)" :: "r"(global_a)); + global_a += dim_k * row_stride_a; + + // stride along columns + asm volatile ("fsw ft0, %0(%1)" :: "i"(row_stride_a * 0 * sizeof(float)), "r"(local_a_tmp)); + asm volatile ("fsw ft1, %0(%1)" :: "i"(row_stride_a * 1 * sizeof(float)), "r"(local_a_tmp)); + asm volatile ("fsw ft2, %0(%1)" :: "i"(row_stride_a * 2 * sizeof(float)), "r"(local_a_tmp)); + asm volatile ("fsw ft3, %0(%1)" :: "i"(row_stride_a * 3 * sizeof(float)), "r"(local_a_tmp)); + asm volatile ("fsw ft4, %0(%1)" :: "i"(row_stride_a * 4 * sizeof(float)), "r"(local_a_tmp)); + asm volatile ("fsw ft5, %0(%1)" :: "i"(row_stride_a * 5 * sizeof(float)), "r"(local_a_tmp)); + asm volatile ("fsw ft6, %0(%1)" :: "i"(row_stride_a * 6 * sizeof(float)), "r"(local_a_tmp)); + asm volatile ("fsw ft7, %0(%1)" :: "i"(row_stride_a * 7 * sizeof(float)), "r"(local_a_tmp)); + local_a_tmp += row_stride_a * 8; + } } -#endif } constexpr uint32_t row_stride_b = threads_in_warpgroup / BN_d; From 2c50b0cdce9cc431fbecaffa1fa4ea4216a39dd0 Mon Sep 17 00:00:00 2001 From: Hansung Kim Date: Thu, 6 Jun 2024 14:27:57 -0700 Subject: [PATCH 46/55] sgemm_tcore: Remove BM_d/BN_d --- tests/regression/sgemm_tcore/kernel.cpp | 59 ++++++++++++------------- 1 file changed, 29 insertions(+), 30 deletions(-) diff --git a/tests/regression/sgemm_tcore/kernel.cpp b/tests/regression/sgemm_tcore/kernel.cpp index d26bae36..4838e9d8 100644 --- a/tests/regression/sgemm_tcore/kernel.cpp +++ b/tests/regression/sgemm_tcore/kernel.cpp @@ -339,9 +339,6 @@ global_dmem_load(const uint32_t dim_n, const uint32_t dim_k, const uint32_t k, volatile float *local_b, const uint32_t tid_in_threadblock, const uint32_t threadblock_id_x, const uint32_t threadblock_id_y) { - constexpr uint32_t BM_d = BM; - constexpr uint32_t BN_d = BN; - const uint32_t local_a_row = tid_in_threadblock / BK; const uint32_t local_a_col = tid_in_threadblock % BK; const uint32_t local_as_row = tid_in_threadblock / BM; @@ -359,14 +356,16 @@ global_dmem_load(const uint32_t dim_n, const uint32_t dim_k, const uint32_t k, // // TODO: Sharedmem swizzling is important here if constexpr (!TRANSPOSE_AS) { - const uint32_t global_a_row = BM_d * threadblock_id_y + local_a_row; + // FIXME: !TRANSPOSE_AS code is old + + const uint32_t global_a_row = BM * threadblock_id_y + local_a_row; // number of rows a full TB can read at a time constexpr uint32_t row_stride_a = threads_in_warpgroup / BK; const float *global_a = A + dim_k * global_a_row + (k + local_a_col); volatile float *local_a_tmp = local_a + BK * local_a_row + local_a_col; #pragma GCC unroll 2 - for (uint32_t local_row_offset = 0; local_row_offset < BM_d; + for (uint32_t local_row_offset = 0; local_row_offset < BM; local_row_offset += row_stride_a) { // const uint32_t global_a_offset = // dim_k * (global_a_row + local_row_offset) + (k + local_a_col); @@ -379,13 +378,13 @@ global_dmem_load(const uint32_t dim_n, const uint32_t dim_k, const uint32_t k, } } else { if constexpr (!GMEM_COALESCED_A) { - constexpr uint32_t row_stride_as = threads_in_warpgroup / BM_d; - const uint32_t global_a_row = BM_d * threadblock_id_y + local_as_col; + constexpr uint32_t row_stride_as = threads_in_warpgroup / BM; + const uint32_t global_a_row = BM * threadblock_id_y + local_as_col; const float *global_a = A + dim_k * global_a_row + (k + local_as_row); // FIXME experimenting with global coalescing - // const uint32_t global_a_row = BM_d * threadblock_id_y + local_as_row; + // const uint32_t global_a_row = BM * threadblock_id_y + local_as_row; // const float *global_a = A + dim_k * global_a_row + (k + local_as_col); - volatile float *local_a_tmp = local_a + BM_d * local_as_row + local_as_col; + volatile float *local_a_tmp = local_a + BM * local_as_row + local_as_col; static_assert( row_stride_as * 8 <= BK, @@ -403,7 +402,7 @@ global_dmem_load(const uint32_t dim_n, const uint32_t dim_k, const uint32_t k, // FIXME experimenting with global coalescing // const uint32_t global_a_offset = // dim_k * (global_a_row + local_row_offset) + (k + local_as_col); - // local_a[BM_d * (local_as_row + local_row_offset) + local_as_col] = + // local_a[BM * (local_as_row + local_row_offset) + local_as_col] = // A[global_a_offset]; // *local_a_tmp = *global_a; @@ -436,25 +435,25 @@ global_dmem_load(const uint32_t dim_n, const uint32_t dim_k, const uint32_t k, } } else { constexpr uint32_t row_stride_a = threads_in_warpgroup / BK; - const uint32_t global_a_row = BM_d * threadblock_id_y + local_a_row; + const uint32_t global_a_row = BM * threadblock_id_y + local_a_row; const float *global_a = A + dim_k * global_a_row + (k + local_a_col); // NOTE that SMEM writes are transposed - volatile float *local_a_tmp = local_a + BM_d * local_a_col + local_a_row; + volatile float *local_a_tmp = local_a + BM * local_a_col + local_a_row; static_assert( - row_stride_a * 8 <= BM_d, + row_stride_a * 8 <= BM, "manual loop unrolling condition not met; consider increasing BM"); static_assert( - (BM_d % (row_stride_a * 8)) == 0, + (BM % (row_stride_a * 8)) == 0, "manual loop unrolling condition not met; BM should be power-of-two"); #pragma GCC unroll 4 - for (uint32_t local_row_offset = 0; local_row_offset < BM_d; + for (uint32_t local_row_offset = 0; local_row_offset < BM; local_row_offset += row_stride_a * 8) { // const uint32_t global_a_offset = // dim_k * (global_a_row + local_row_offset) + (k + local_a_col); // NOTE that SMEM writes are transposed - // local_a[BM_d * (local_a_col) + local_a_row + local_row_offset] = + // local_a[BM * (local_a_col) + local_a_row + local_row_offset] = // A[global_a_offset]; asm volatile ("flw ft0, (%0)" :: "r"(global_a)); @@ -488,10 +487,10 @@ global_dmem_load(const uint32_t dim_n, const uint32_t dim_k, const uint32_t k, } } - constexpr uint32_t row_stride_b = threads_in_warpgroup / BN_d; - const uint32_t global_b_col = BN_d * threadblock_id_x + local_b_col; + constexpr uint32_t row_stride_b = threads_in_warpgroup / BN; + const uint32_t global_b_col = BN * threadblock_id_x + local_b_col; const float *global_b = B + dim_n * (k + local_b_row) + global_b_col; - volatile float *local_b_tmp = local_b + BN_d * local_b_row + local_b_col; + volatile float *local_b_tmp = local_b + BN * local_b_row + local_b_col; static_assert( row_stride_b * 8 <= BK, @@ -505,13 +504,13 @@ global_dmem_load(const uint32_t dim_n, const uint32_t dim_k, const uint32_t k, load_offset += row_stride_b * 8) { // const uint32_t global_b_offset = // dim_n * (k + local_b_row + load_offset) + global_b_col; - // local_b[BN_d * (local_b_row + load_offset) + local_b_col] = + // local_b[BN * (local_b_row + load_offset) + local_b_col] = // B[global_b_offset]; // *local_b_tmp = *global_b; // global_b += dim_n * row_stride_b; - // local_b_tmp += BN_d * row_stride_b; + // local_b_tmp += BN * row_stride_b; asm volatile ("flw ft0, (%0)" :: "r"(global_b)); global_b += dim_n * row_stride_b; @@ -530,15 +529,15 @@ global_dmem_load(const uint32_t dim_n, const uint32_t dim_k, const uint32_t k, asm volatile ("flw ft7, (%0)" :: "r"(global_b)); global_b += dim_n * row_stride_b; - asm volatile ("fsw ft0, %0(%1)" :: "i"(BN_d * row_stride_b * 0 * sizeof(float)), "r"(local_b_tmp)); - asm volatile ("fsw ft1, %0(%1)" :: "i"(BN_d * row_stride_b * 1 * sizeof(float)), "r"(local_b_tmp)); - asm volatile ("fsw ft2, %0(%1)" :: "i"(BN_d * row_stride_b * 2 * sizeof(float)), "r"(local_b_tmp)); - asm volatile ("fsw ft3, %0(%1)" :: "i"(BN_d * row_stride_b * 3 * sizeof(float)), "r"(local_b_tmp)); - asm volatile ("fsw ft4, %0(%1)" :: "i"(BN_d * row_stride_b * 4 * sizeof(float)), "r"(local_b_tmp)); - asm volatile ("fsw ft5, %0(%1)" :: "i"(BN_d * row_stride_b * 5 * sizeof(float)), "r"(local_b_tmp)); - asm volatile ("fsw ft6, %0(%1)" :: "i"(BN_d * row_stride_b * 6 * sizeof(float)), "r"(local_b_tmp)); - asm volatile ("fsw ft7, %0(%1)" :: "i"(BN_d * row_stride_b * 7 * sizeof(float)), "r"(local_b_tmp)); - local_b_tmp += BN_d * row_stride_b * 8; + asm volatile ("fsw ft0, %0(%1)" :: "i"(BN * row_stride_b * 0 * sizeof(float)), "r"(local_b_tmp)); + asm volatile ("fsw ft1, %0(%1)" :: "i"(BN * row_stride_b * 1 * sizeof(float)), "r"(local_b_tmp)); + asm volatile ("fsw ft2, %0(%1)" :: "i"(BN * row_stride_b * 2 * sizeof(float)), "r"(local_b_tmp)); + asm volatile ("fsw ft3, %0(%1)" :: "i"(BN * row_stride_b * 3 * sizeof(float)), "r"(local_b_tmp)); + asm volatile ("fsw ft4, %0(%1)" :: "i"(BN * row_stride_b * 4 * sizeof(float)), "r"(local_b_tmp)); + asm volatile ("fsw ft5, %0(%1)" :: "i"(BN * row_stride_b * 5 * sizeof(float)), "r"(local_b_tmp)); + asm volatile ("fsw ft6, %0(%1)" :: "i"(BN * row_stride_b * 6 * sizeof(float)), "r"(local_b_tmp)); + asm volatile ("fsw ft7, %0(%1)" :: "i"(BN * row_stride_b * 7 * sizeof(float)), "r"(local_b_tmp)); + local_b_tmp += BN * row_stride_b * 8; } } From d5adacda30490907928c53a9359b5f4779e062bc Mon Sep 17 00:00:00 2001 From: Hansung Kim Date: Thu, 6 Jun 2024 15:19:05 -0700 Subject: [PATCH 47/55] Add args.bin to ELF Change KERNEL_ARG_DEV_MEM_ADDR for sgemm_{wg,gemmini,tcore} --- kernel/linker/vx_link32.ld | 5 +++++ tests/regression/common.mk | 4 ++++ tests/regression/sgemm_gemmini/common.h | 2 +- tests/regression/sgemm_tcore/common.h | 4 ++-- tests/regression/sgemm_wg/common.h | 2 +- 5 files changed, 13 insertions(+), 4 deletions(-) diff --git a/kernel/linker/vx_link32.ld b/kernel/linker/vx_link32.ld index ea5c4e56..40a624f6 100644 --- a/kernel/linker/vx_link32.ld +++ b/kernel/linker/vx_link32.ld @@ -10,6 +10,7 @@ ENTRY(_start) MEMORY { DRAM0 (rwx): ORIGIN = 0x80000000, LENGTH = 512M + DRAMARG (rwx): ORIGIN = 0x9fff0000, LENGTH = 8K DRAM1 (rwx): ORIGIN = 0xa0000000, LENGTH = 32K DRAM2 (rwx): ORIGIN = 0xa1000000, LENGTH = 32K } @@ -259,6 +260,10 @@ SECTIONS .gnu.attributes 0 : { KEEP (*(.gnu.attributes)) } /DISCARD/ : { *(.note.GNU-stack) *(.gnu_debuglink) *(.gnu.lto_*) } + .args : { + *(.args) + . += 8K; + }> DRAMARG .operand.a : { *(.operand.a) . += 32K; diff --git a/tests/regression/common.mk b/tests/regression/common.mk index 33e78828..04ddbb3f 100644 --- a/tests/regression/common.mk +++ b/tests/regression/common.mk @@ -107,15 +107,19 @@ kernel.elf: $(VX_SRCS) $(VX_CXX) $(VX_CFLAGS) $(VX_SRCS) $(VX_LDFLAGS) -o $@ $(OBJCOPY) --set-section-flags .operand.a=$(OBJCOPY_FLAGS) $@ $(OBJCOPY) --set-section-flags .operand.b=$(OBJCOPY_FLAGS) $@ + $(OBJCOPY) --set-section-flags .args=$(OBJCOPY_FLAGS) $@ $(OBJCOPY) --update-section .operand.a=input.a.bin $@ $(OBJCOPY) --update-section .operand.b=input.b.bin $@ + $(OBJCOPY) --update-section .args=args.bin $@ kernel.radiance.elf: $(VX_SRCS) $(VX_CXX) $(VX_CFLAGS) $(VX_SRCS) $(VX_LDFLAGS) -DRADIANCE -o $@ $(OBJCOPY) --set-section-flags .operand.a=$(OBJCOPY_FLAGS) $@ $(OBJCOPY) --set-section-flags .operand.b=$(OBJCOPY_FLAGS) $@ + $(OBJCOPY) --set-section-flags .args=$(OBJCOPY_FLAGS) $@ $(OBJCOPY) --update-section .operand.a=input.a.bin $@ $(OBJCOPY) --update-section .operand.b=input.b.bin $@ + $(OBJCOPY) --update-section .args=args.bin $@ ifneq ($(CONFIG),) kernel$(CONFIGEXT).elf: kernel.elf diff --git a/tests/regression/sgemm_gemmini/common.h b/tests/regression/sgemm_gemmini/common.h index 74941562..5c84f3b7 100644 --- a/tests/regression/sgemm_gemmini/common.h +++ b/tests/regression/sgemm_gemmini/common.h @@ -3,7 +3,7 @@ #include -#define KERNEL_ARG_DEV_MEM_ADDR 0x7fff0000 +#define KERNEL_ARG_DEV_MEM_ADDR 0x9fff0000 #define DEV_SMEM_START_ADDR 0xff000000 typedef struct { diff --git a/tests/regression/sgemm_tcore/common.h b/tests/regression/sgemm_tcore/common.h index d94a270f..5c84f3b7 100644 --- a/tests/regression/sgemm_tcore/common.h +++ b/tests/regression/sgemm_tcore/common.h @@ -3,7 +3,7 @@ #include -#define KERNEL_ARG_DEV_MEM_ADDR 0x7fff0000 +#define KERNEL_ARG_DEV_MEM_ADDR 0x9fff0000 #define DEV_SMEM_START_ADDR 0xff000000 typedef struct { @@ -15,4 +15,4 @@ typedef struct { uint64_t addr_c; } kernel_arg_t; -#endif \ No newline at end of file +#endif diff --git a/tests/regression/sgemm_wg/common.h b/tests/regression/sgemm_wg/common.h index 74941562..5c84f3b7 100644 --- a/tests/regression/sgemm_wg/common.h +++ b/tests/regression/sgemm_wg/common.h @@ -3,7 +3,7 @@ #include -#define KERNEL_ARG_DEV_MEM_ADDR 0x7fff0000 +#define KERNEL_ARG_DEV_MEM_ADDR 0x9fff0000 #define DEV_SMEM_START_ADDR 0xff000000 typedef struct { From 062403066ef6054679ca96154523fc9a8366dbcd Mon Sep 17 00:00:00 2001 From: Hansung Kim Date: Thu, 6 Jun 2024 15:22:01 -0700 Subject: [PATCH 48/55] sgemm_tcore: Bring M/N-loop inside the kernel Instead of spawning multiple threadblocks which comes with stack access overhead, have 1 threadblock work on the entire M/N-space thru a loop. Grid size is fixed to the hardware parallelism. TODO currently only works with 1 cluster in the system. --- tests/regression/sgemm_tcore/kernel.cpp | 346 ++++++++++++------------ 1 file changed, 175 insertions(+), 171 deletions(-) diff --git a/tests/regression/sgemm_tcore/kernel.cpp b/tests/regression/sgemm_tcore/kernel.cpp index 4838e9d8..11187644 100644 --- a/tests/regression/sgemm_tcore/kernel.cpp +++ b/tests/regression/sgemm_tcore/kernel.cpp @@ -9,7 +9,6 @@ #define NUM_LANES 8 #define USE_TENSOR_CORE 1 -#define TC_SINGLE_WARP 0 // number of loop around the inner 0..TCK..BK loop to simulate perfect-DRAM // scenario #define BK_LOOP 1 @@ -267,7 +266,7 @@ inline void initialize_C(const int dest_reg) { inline void write_results(const int thread_in_warp, const int warp_col, const int warp_row, const int wn_iter, - const int wm_iter, const int dim_m, const int dim_n, + const int wm_iter, const int dim_n, float *C, const int threadblock_id_x, const int threadblock_id_y) { int tid = thread_in_warp; @@ -333,12 +332,12 @@ inline void threadblock_barrier(const uint32_t barrier_id, const uint32_t count) // vx_barrier(0, count); } -inline void -global_dmem_load(const uint32_t dim_n, const uint32_t dim_k, const uint32_t k, - const float *A, const float *B, volatile float *local_a, - volatile float *local_b, const uint32_t tid_in_threadblock, - const uint32_t threadblock_id_x, - const uint32_t threadblock_id_y) { +inline void global_dmem_load(const uint32_t dim_n, const uint32_t dim_k, + const uint32_t k, const float *A, const float *B, + volatile float *local_a, volatile float *local_b, + const uint32_t tid_in_threadblock, + const uint32_t threadblock_id_x, + const uint32_t threadblock_id_y) { const uint32_t local_a_row = tid_in_threadblock / BK; const uint32_t local_a_col = tid_in_threadblock % BK; const uint32_t local_as_row = tid_in_threadblock / BM; @@ -546,8 +545,8 @@ void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg, const uint32_t threads_per_threadblock, const uint32_t threadblock_dim_x, const uint32_t threadblock_dim_y, - const uint32_t threadblock_id_x, - const uint32_t threadblock_id_y, + /*const uint32_t threadblock_id_x, + const uint32_t threadblock_id_y,*/ const uint32_t threadblock_id_in_cluster, float *sharedmem_per_threadblock) { const float *A = (const float *)arg->addr_a; @@ -593,198 +592,198 @@ void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg, volatile float *local_a_buf = local_b + local_b_elems; volatile float *local_b_buf = local_a_buf + local_a_elems; - // clear out C - initialize_C(0); - initialize_C(1); - - if constexpr (DOUBLE_BUFFER) { - // initiate software pipeline - if (warpgroup_id == 0) { - global_dmem_load(dim_n, dim_k, 0 /*k*/, A, B, local_a, local_b, - tid_in_warpgroup, threadblock_id_x, threadblock_id_y); - } - - threadblock_barrier(threadblock_id_in_cluster, threadblock_dim_y); - } - if (warpgroup_id == 0) { - // TODO: bring initiation pipeline here - // NOTE: this *should* be signed integer to trigger arithmetic right-shift - int32_t k_index = 0; #pragma GCC unroll 1 - for (uint32_t k = 0; k < dim_k - BK; k += BK) { - volatile float *local_a_produce; - volatile float *local_b_produce; - if constexpr (DOUBLE_BUFFER) { - const uint32_t mask_odd = (k_index & 1) << 31 >> 31; - const uint32_t mask_even = ((k_index & 1) ^ 1) << 31 >> 31; - // local_a_produce = (k_index % 2) ? local_a : local_a_buf; - // local_b_produce = (k_index % 2) ? local_b : local_b_buf; - local_a_produce = reinterpret_cast( - (mask_odd & reinterpret_cast(local_a)) | - (mask_even & reinterpret_cast(local_a_buf))); - local_b_produce = reinterpret_cast( - (mask_odd & reinterpret_cast(local_b)) | - (mask_even & reinterpret_cast(local_b_buf))); - } else { - local_a_produce = local_a; - local_b_produce = local_b; + for (uint32_t block_m = 0; (block_m * BM) < dim_m; block_m++) { +#pragma GCC unroll 1 + for (uint32_t block_n = 0; (block_n * BN) < dim_n; block_n++) { + if constexpr (DOUBLE_BUFFER) { + // initiate software pipeline + global_dmem_load(dim_n, dim_k, 0 /*k*/, A, B, local_a, local_b, + tid_in_warpgroup, block_n, block_m); + + threadblock_barrier(threadblock_id_in_cluster, threadblock_dim_y); + } + + // NOTE: this *should* be signed integer to trigger arithmetic + // right-shift + int32_t k_index = 0; +#pragma GCC unroll 1 + for (uint32_t k = 0; k < (dim_k) - BK; k += BK) { + volatile float *local_a_produce; + volatile float *local_b_produce; + if constexpr (DOUBLE_BUFFER) { + const uint32_t mask_odd = (k_index & 1) << 31 >> 31; + const uint32_t mask_even = ((k_index & 1) ^ 1) << 31 >> 31; + // local_a_produce = (k_index % 2) ? local_a : local_a_buf; + // local_b_produce = (k_index % 2) ? local_b : local_b_buf; + local_a_produce = reinterpret_cast( + (mask_odd & reinterpret_cast(local_a)) | + (mask_even & reinterpret_cast(local_a_buf))); + local_b_produce = reinterpret_cast( + (mask_odd & reinterpret_cast(local_b)) | + (mask_even & reinterpret_cast(local_b_buf))); + } else { + local_a_produce = local_a; + local_b_produce = local_b; + } + k_index++; + + global_dmem_load(dim_n, dim_k, k + BK /*runahead*/, A, B, + local_a_produce, local_b_produce, tid_in_warpgroup, + block_n, block_m); + + threadblock_barrier(threadblock_id_in_cluster, threadblock_dim_y); + } + + // sync with final consumer stage in the k-loop + threadblock_barrier(threadblock_id_in_cluster, threadblock_dim_y); } - k_index++; - - global_dmem_load(dim_n, dim_k, k + BK /*runahead*/, A, B, local_a_produce, - local_b_produce, tid_in_warpgroup, threadblock_id_x, - threadblock_id_y); - - threadblock_barrier(threadblock_id_in_cluster, threadblock_dim_y); } - - threadblock_barrier(threadblock_id_in_cluster, threadblock_dim_y); } else { - // NOTE: this *should* be signed integer to trigger arithmetic right-shift - int32_t k_index = 0; #pragma GCC unroll 1 - for (uint32_t k = 0; k < dim_k; k += BK) { - volatile float *local_a_consume; - volatile float *local_b_consume; - if constexpr (DOUBLE_BUFFER) { - // local_a_consume = (k_index % 2) ? local_a_buf : local_a; - // local_b_consume = (k_index % 2) ? local_b_buf : local_b; - // FIXME: swap multiply with bitshifts - const uint32_t mask_odd = (k_index & 1) << 31 >> 31; - const uint32_t mask_even = ((k_index & 1) ^ 1) << 31 >> 31; - local_a_consume = reinterpret_cast( - (mask_odd & reinterpret_cast(local_a_buf)) | - (mask_even & reinterpret_cast(local_a))); - local_b_consume = reinterpret_cast( - (mask_odd & reinterpret_cast(local_b_buf)) | - (mask_even & reinterpret_cast(local_b))); - } else { - local_a_consume = local_a; - local_b_consume = local_b; - } - k_index++; + for (uint32_t block_m = 0; (block_m * BM) < dim_m; block_m++) { +#pragma GCC unroll 1 + for (uint32_t block_n = 0; (block_n * BN) < dim_n; block_n++) { + // clear out C + initialize_C(0); + initialize_C(1); + + // sync with initial producer stage in the k-loop + threadblock_barrier(threadblock_id_in_cluster, threadblock_dim_y); + + // NOTE: this *should* be signed integer to trigger arithmetic + // right-shift + int32_t k_index = 0; +#pragma GCC unroll 1 + for (uint32_t k = 0; k < (dim_k); k += BK) { + volatile float *local_a_consume; + volatile float *local_b_consume; + if constexpr (DOUBLE_BUFFER) { + // local_a_consume = (k_index % 2) ? local_a_buf : local_a; + // local_b_consume = (k_index % 2) ? local_b_buf : local_b; + // FIXME: swap multiply with bitshifts + const uint32_t mask_odd = (k_index & 1) << 31 >> 31; + const uint32_t mask_even = ((k_index & 1) ^ 1) << 31 >> 31; + local_a_consume = reinterpret_cast( + (mask_odd & reinterpret_cast(local_a_buf)) | + (mask_even & reinterpret_cast(local_a))); + local_b_consume = reinterpret_cast( + (mask_odd & reinterpret_cast(local_b_buf)) | + (mask_even & reinterpret_cast(local_b))); + } else { + local_a_consume = local_a; + local_b_consume = local_b; + } + k_index++; #if USE_TENSOR_CORE - // @perf: this loop spills to stack a lot because of all the flws in - // vx_wmma_load + // @perf: this loop spills to stack a lot because of all the flws in + // vx_wmma_load #pragma GCC unroll 1 - for (int i = 0; i < BK_LOOP; i++) { -#pragma GCC unroll 4 - for (uint32_t local_k = 0; local_k < BK; local_k += TCK) { - // perform wmma - // vx_wmma_load(local_a_consume, local_b_consume, warp_x, warp_y, - // tid_in_warp); - // FIXME: this is wrong!! need separate accumulation register for - // WM/WN_ITERS -#pragma GCC unroll 2 - for (int wn_iter = 0; wn_iter < WNITER; wn_iter++) { - vx_wmma_load_b(local_b_consume, local_k, warp_col, wn_iter, - tid_in_warp); - // vx_wmma_load_b(local_b_consume, 0, 0, 0, tid_in_warp); -#pragma GCC unroll 2 - for (int wm_iter = 0; wm_iter < WMITER; wm_iter++) { -#if TC_SINGLE_WARP - if (warp_in_warpgroup == 0) { -#endif - // if ((threadblock_id_in_cluster % 2) == 0) { - // asm volatile("addi a0, a0, 0"); - // asm volatile("addi a0, a0, 0"); - // asm volatile("addi a0, a0, 0"); - // asm volatile("addi a0, a0, 0"); - // asm volatile("addi a0, a0, 0"); - // asm volatile("addi a0, a0, 0"); - // asm volatile("addi a0, a0, 0"); - // asm volatile("addi a0, a0, 0"); - // asm volatile("addi a0, a0, 0"); - // } - // SMEM -> RF - vx_wmma_load_a(local_a_consume, local_k, warp_row, wm_iter, + for (int i = 0; i < BK_LOOP; i++) { +#pragma GCC unroll 1 + for (uint32_t local_k = 0; local_k < BK; local_k += TCK) { + // perform wmma + // vx_wmma_load(local_a_consume, local_b_consume, warp_x, warp_y, + // tid_in_warp); + // FIXME: this is wrong!! need separate accumulation register for + // WM/WN_ITERS +#pragma GCC unroll 1 + for (int wn_iter = 0; wn_iter < WNITER; wn_iter++) { + vx_wmma_load_b(local_b_consume, local_k, warp_col, wn_iter, tid_in_warp); - // vx_wmma_load_a(local_a_consume, 0, 0, 0, tid_in_warp); - // compute - vx_wmma(wm_iter); -#if TC_SINGLE_WARP + // vx_wmma_load_b(local_b_consume, 0, 0, 0, tid_in_warp); +#pragma GCC unroll 1 + for (int wm_iter = 0; wm_iter < WMITER; wm_iter++) { + // if ((threadblock_id_in_cluster % 2) == 0) { + // asm volatile("addi a0, a0, 0"); + // asm volatile("addi a0, a0, 0"); + // asm volatile("addi a0, a0, 0"); + // asm volatile("addi a0, a0, 0"); + // asm volatile("addi a0, a0, 0"); + // asm volatile("addi a0, a0, 0"); + // asm volatile("addi a0, a0, 0"); + // asm volatile("addi a0, a0, 0"); + // asm volatile("addi a0, a0, 0"); + // } + // SMEM -> RF + vx_wmma_load_a(local_a_consume, local_k, warp_row, wm_iter, + tid_in_warp); + // vx_wmma_load_a(local_a_consume, 0, 0, 0, tid_in_warp); + // compute + vx_wmma(wm_iter); + } } -#endif } } - } - } - threadblock_barrier(threadblock_id_in_cluster, threadblock_dim_y); + threadblock_barrier(threadblock_id_in_cluster, threadblock_dim_y); #else - // Compute single tile*tile matmul + // Compute single tile*tile matmul #pragma GCC unroll 4 - for (uint32_t local_k = 0; local_k < BK; local_k++) { - // First, pump data from SMEM->RF + for (uint32_t local_k = 0; local_k < BK; local_k++) { + // First, pump data from SMEM->RF #pragma GCC unroll TM - for (uint32_t res_idx_m = 0; res_idx_m < TM; res_idx_m++) { - reg_a[res_idx_m] = - local_a[BK * (TM * local_c_row + res_idx_m) + local_k]; - } -#pragma GCC unroll TN - for (uint32_t res_idx_n = 0; res_idx_n < TN; res_idx_n++) { - reg_b[res_idx_n] = - local_b[BN * local_k + (TN * local_c_col + res_idx_n)]; - } - - // Next, compute multiple result elements (TM*TN) by reusing data in - // RF -#pragma GCC unroll TM - for (uint32_t res_idx_m = 0; res_idx_m < TM; res_idx_m++) { + for (uint32_t res_idx_m = 0; res_idx_m < TM; res_idx_m++) { + reg_a[res_idx_m] = + local_a[BK * (TM * local_c_row + res_idx_m) + local_k]; + } #pragma GCC unroll TN for (uint32_t res_idx_n = 0; res_idx_n < TN; res_idx_n++) { - // NOTE use of local_b_row - reg_c[TN * res_idx_m + res_idx_n] += - reg_a[res_idx_m] * reg_b[res_idx_n]; - // reg_c[TN * res_idx_m + res_idx_n] += - // local_a[BK * (TM * local_c_row + res_idx_m) + local_k] * - // local_b[BN * local_k + (TN * local_c_col + res_idx_n)]; + reg_b[res_idx_n] = + local_b[BN * local_k + (TN * local_c_col + res_idx_n)]; + } + + // Next, compute multiple result elements (TM*TN) by reusing data in + // RF +#pragma GCC unroll TM + for (uint32_t res_idx_m = 0; res_idx_m < TM; res_idx_m++) { +#pragma GCC unroll TN + for (uint32_t res_idx_n = 0; res_idx_n < TN; res_idx_n++) { + // NOTE use of local_b_row + reg_c[TN * res_idx_m + res_idx_n] += + reg_a[res_idx_m] * reg_b[res_idx_n]; + // reg_c[TN * res_idx_m + res_idx_n] += + // local_a[BK * (TM * local_c_row + res_idx_m) + local_k] * + // local_b[BN * local_k + (TN * local_c_col + res_idx_n)]; + } } } - } - threadblock_barrier(tid_in_threadblock, threadblock_id_in_cluster, - threadblock_dim_y); + threadblock_barrier(tid_in_threadblock, threadblock_id_in_cluster, + threadblock_dim_y); #endif - } - } + } #if USE_TENSOR_CORE #pragma GCC unroll 1 - for (int wm_iter = 0; wm_iter < WMITER; wm_iter++) { + for (int wm_iter = 0; wm_iter < WMITER; wm_iter++) { #pragma GCC unroll 1 - for (int wn_iter = 0; wn_iter < WNITER; wn_iter++) { -#if TC_SINGLE_WARP - if (warp_in_warpgroup == 0) { -#endif - if (warpgroup_id == 1) { - write_results(tid_in_warp, warp_col, warp_row, wn_iter, wm_iter, - dim_m, dim_n, C, threadblock_id_x, threadblock_id_y); - } -#if TC_SINGLE_WARP - } -#endif - } - } - + for (int wn_iter = 0; wn_iter < WNITER; wn_iter++) { + if (warpgroup_id == 1) { + write_results(tid_in_warp, warp_col, warp_row, wn_iter, wm_iter, + dim_n, C, block_n, block_m); + } #else - - // Store result data from RF to GMEM + // Store result data from RF to GMEM #pragma GCC unroll TM - for (uint32_t res_idx_m = 0; res_idx_m < TM; res_idx_m++) { + for (uint32_t res_idx_m = 0; res_idx_m < TM; res_idx_m++) { #pragma GCC unroll TN - for (uint32_t res_idx_n = 0; res_idx_n < TN; res_idx_n++) { - C[dim_n * (BM * threadblock_id_y + TM * local_c_row + res_idx_m) + - (BN * threadblock_id_x + TN * local_c_col + res_idx_n)] = - reg_c[TN * res_idx_m + res_idx_n]; + for (uint32_t res_idx_n = 0; res_idx_n < TN; res_idx_n++) { + C[dim_n * (BM * threadblock_id_y + TM * local_c_row + res_idx_m) + + (BN * threadblock_id_x + TN * local_c_col + res_idx_n)] = + reg_c[TN * res_idx_m + res_idx_n]; + } + } +#endif + } + } + } } } -#endif - } void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { @@ -819,14 +818,19 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { const int warp_id = vx_warp_id(); thread_block_gemm(arg, tid_in_threadblock, threads_per_threadblock, - threadblock_dim_x, threadblock_dim_y, threadblock_id_x, - threadblock_id_y, threadblock_id_in_cluster, + threadblock_dim_x, threadblock_dim_y, /*threadblock_id_x, + threadblock_id_y,*/ threadblock_id_in_cluster, sharedmem_per_threadblock); } int main() { kernel_arg_t *arg = (kernel_arg_t *)KERNEL_ARG_DEV_MEM_ADDR; - const uint32_t grid_size = arg->dim_m * arg->dim_n / ELEM_PER_THREAD; + + const uint32_t threads_per_cluster = + CORES_PER_CLUSTER * vx_num_threads() * vx_num_warps(); + // const uint32_t grid_size = arg->dim_m * arg->dim_n / ELEM_PER_THREAD; + const uint32_t grid_size = threads_per_cluster; + #ifdef RADIANCE vx_spawn_tasks_cluster(grid_size, (vx_spawn_tasks_cb)kernel_body, arg); #else From 7c4d850074ef5820c9bec91df5522ba01e989dbd Mon Sep 17 00:00:00 2001 From: Hansung Kim Date: Thu, 6 Jun 2024 18:38:43 -0700 Subject: [PATCH 49/55] sgemm_tcore: Experiment with high K; 48% util --- tests/regression/sgemm_tcore/kernel.cpp | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/tests/regression/sgemm_tcore/kernel.cpp b/tests/regression/sgemm_tcore/kernel.cpp index 11187644..6db7ae3d 100644 --- a/tests/regression/sgemm_tcore/kernel.cpp +++ b/tests/regression/sgemm_tcore/kernel.cpp @@ -363,7 +363,7 @@ inline void global_dmem_load(const uint32_t dim_n, const uint32_t dim_k, const float *global_a = A + dim_k * global_a_row + (k + local_a_col); volatile float *local_a_tmp = local_a + BK * local_a_row + local_a_col; -#pragma GCC unroll 2 +#pragma GCC unroll 1 for (uint32_t local_row_offset = 0; local_row_offset < BM; local_row_offset += row_stride_a) { // const uint32_t global_a_offset = @@ -392,7 +392,7 @@ inline void global_dmem_load(const uint32_t dim_n, const uint32_t dim_k, (BK % (row_stride_as * 8)) == 0, "manual loop unrolling condition not met; BK should be power-of-two"); -#pragma GCC unroll 2 +#pragma GCC unroll 1 for (uint32_t local_row_offset = 0; local_row_offset < BK; local_row_offset += row_stride_as * 8) { // @perf: bank conflicts here @@ -446,7 +446,7 @@ inline void global_dmem_load(const uint32_t dim_n, const uint32_t dim_k, (BM % (row_stride_a * 8)) == 0, "manual loop unrolling condition not met; BM should be power-of-two"); -#pragma GCC unroll 4 +#pragma GCC unroll 1 for (uint32_t local_row_offset = 0; local_row_offset < BM; local_row_offset += row_stride_a * 8) { // const uint32_t global_a_offset = @@ -498,7 +498,7 @@ inline void global_dmem_load(const uint32_t dim_n, const uint32_t dim_k, (BK % (row_stride_b * 8)) == 0, "manual loop unrolling condition not met; BK should be power-of-two"); -#pragma GCC unroll 2 +#pragma GCC unroll 1 for (uint32_t load_offset = 0; load_offset < BK; load_offset += row_stride_b * 8) { // const uint32_t global_b_offset = @@ -609,7 +609,7 @@ void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg, // right-shift int32_t k_index = 0; #pragma GCC unroll 1 - for (uint32_t k = 0; k < (dim_k) - BK; k += BK) { + for (uint32_t k = 0; k < (8 * dim_k) - BK; k += BK) { volatile float *local_a_produce; volatile float *local_b_produce; if constexpr (DOUBLE_BUFFER) { @@ -656,7 +656,7 @@ void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg, // right-shift int32_t k_index = 0; #pragma GCC unroll 1 - for (uint32_t k = 0; k < (dim_k); k += BK) { + for (uint32_t k = 0; k < (8 * dim_k); k += BK) { volatile float *local_a_consume; volatile float *local_b_consume; if constexpr (DOUBLE_BUFFER) { @@ -682,19 +682,19 @@ void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg, // vx_wmma_load #pragma GCC unroll 1 for (int i = 0; i < BK_LOOP; i++) { -#pragma GCC unroll 1 +#pragma GCC unroll 2 for (uint32_t local_k = 0; local_k < BK; local_k += TCK) { // perform wmma // vx_wmma_load(local_a_consume, local_b_consume, warp_x, warp_y, // tid_in_warp); // FIXME: this is wrong!! need separate accumulation register for // WM/WN_ITERS -#pragma GCC unroll 1 +#pragma GCC unroll 2 for (int wn_iter = 0; wn_iter < WNITER; wn_iter++) { vx_wmma_load_b(local_b_consume, local_k, warp_col, wn_iter, tid_in_warp); // vx_wmma_load_b(local_b_consume, 0, 0, 0, tid_in_warp); -#pragma GCC unroll 1 +#pragma GCC unroll 2 for (int wm_iter = 0; wm_iter < WMITER; wm_iter++) { // if ((threadblock_id_in_cluster % 2) == 0) { // asm volatile("addi a0, a0, 0"); From 985c5fc0dcf868e1f4677ee02fd57d25ac121e40 Mon Sep 17 00:00:00 2001 From: Hansung Kim Date: Thu, 6 Jun 2024 18:50:31 -0700 Subject: [PATCH 50/55] sgemm_tcore: Remove uneffective register asm --- tests/regression/sgemm_tcore/kernel.cpp | 34 ++++++++++++------------- 1 file changed, 17 insertions(+), 17 deletions(-) diff --git a/tests/regression/sgemm_tcore/kernel.cpp b/tests/regression/sgemm_tcore/kernel.cpp index 6db7ae3d..42443de7 100644 --- a/tests/regression/sgemm_tcore/kernel.cpp +++ b/tests/regression/sgemm_tcore/kernel.cpp @@ -183,7 +183,7 @@ inline void vx_wmma_load_a(volatile float *smem_A, const int local_k, } else { // transposed A // f8-f15 stores a single row of A - register volatile float *smem_addr asm("t5"); + volatile float *smem_addr; smem_addr = &smem_A[((local_k + 0) * smem_AS_cols) + (WM * warp_row + TCM * wm_iter) + row]; asm volatile("flw f0, %0(%1)" :: "i"(smem_AS_cols * 0 * sizeof(float)), "r"(smem_addr)); asm volatile("flw f1, %0(%1)" :: "i"(smem_AS_cols * 1 * sizeof(float)), "r"(smem_addr)); @@ -220,7 +220,7 @@ inline void vx_wmma_load_b(volatile float *smem_B, const int local_k, constexpr int smem_B_cols = BN; // f8-f15 stores a single column of B - register volatile float *smem_addr asm("t5"); + volatile float *smem_addr; smem_addr = &smem_B[((local_k + 0) * smem_B_cols) + (WN * warp_col + TCN * wn_iter) + col]; asm volatile("flw f8, %0(%1)" :: "i"(smem_B_cols * 0 * sizeof(float)), "r"(smem_addr)); asm volatile("flw f9, %0(%1)" :: "i"(smem_B_cols * 1 * sizeof(float)), "r"(smem_addr)); @@ -286,8 +286,8 @@ inline void write_results(const int thread_in_warp, const int warp_col, // @perf: this likely causes a lot of gmem bank conflicts if (wm_iter == 0) { - register volatile float *gmem_addr asm("t5"); - register volatile float *gmem_addr_tmp asm("t6"); + volatile float *gmem_addr; + volatile float *gmem_addr_tmp; gmem_addr = &global_offset_C[dim_n * (local_row + 0) + (local_col + 0)]; asm volatile ("fsw f16, %0" :: "m"(*(gmem_addr + 0))); asm volatile ("fsw f17, %0" :: "m"(*(gmem_addr + 1))); @@ -309,8 +309,8 @@ inline void write_results(const int thread_in_warp, const int warp_col, // asm volatile ("fsw f22, %0" :: "m"(global_offset_C[dim_n * (local_row + 2) + (local_col + 4)])); // asm volatile ("fsw f23, %0" :: "m"(global_offset_C[dim_n * (local_row + 2) + (local_col + 5)])); } else { - register volatile float *gmem_addr asm("t5"); - register volatile float *gmem_addr_tmp asm("t6"); + volatile float *gmem_addr; + volatile float *gmem_addr_tmp; gmem_addr = &global_offset_C[dim_n * (local_row + 0) + (local_col + 0)]; gmem_addr_tmp = gmem_addr + (2 * dim_n); asm volatile ("fsw f24, %0" :: "m"(*(gmem_addr + 0))); @@ -494,9 +494,9 @@ inline void global_dmem_load(const uint32_t dim_n, const uint32_t dim_k, static_assert( row_stride_b * 8 <= BK, "manual loop unrolling condition not met; consider increasing BK"); - static_assert( - (BK % (row_stride_b * 8)) == 0, - "manual loop unrolling condition not met; BK should be power-of-two"); + static_assert( + (BK % (row_stride_b * 8)) == 0, + "manual loop unrolling condition not met; BK should be power-of-two"); #pragma GCC unroll 1 for (uint32_t load_offset = 0; load_offset < BK; @@ -618,11 +618,11 @@ void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg, // local_a_produce = (k_index % 2) ? local_a : local_a_buf; // local_b_produce = (k_index % 2) ? local_b : local_b_buf; local_a_produce = reinterpret_cast( - (mask_odd & reinterpret_cast(local_a)) | - (mask_even & reinterpret_cast(local_a_buf))); + (mask_odd & reinterpret_cast(local_a)) | + (mask_even & reinterpret_cast(local_a_buf))); local_b_produce = reinterpret_cast( - (mask_odd & reinterpret_cast(local_b)) | - (mask_even & reinterpret_cast(local_b_buf))); + (mask_odd & reinterpret_cast(local_b)) | + (mask_even & reinterpret_cast(local_b_buf))); } else { local_a_produce = local_a; local_b_produce = local_b; @@ -666,11 +666,11 @@ void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg, const uint32_t mask_odd = (k_index & 1) << 31 >> 31; const uint32_t mask_even = ((k_index & 1) ^ 1) << 31 >> 31; local_a_consume = reinterpret_cast( - (mask_odd & reinterpret_cast(local_a_buf)) | - (mask_even & reinterpret_cast(local_a))); + (mask_odd & reinterpret_cast(local_a_buf)) | + (mask_even & reinterpret_cast(local_a))); local_b_consume = reinterpret_cast( - (mask_odd & reinterpret_cast(local_b_buf)) | - (mask_even & reinterpret_cast(local_b))); + (mask_odd & reinterpret_cast(local_b_buf)) | + (mask_even & reinterpret_cast(local_b))); } else { local_a_consume = local_a; local_b_consume = local_b; From 856596cbb33bce17fab2a56117a4f973b10182cc Mon Sep 17 00:00:00 2001 From: Hansung Kim Date: Fri, 7 Jun 2024 10:39:11 -0700 Subject: [PATCH 51/55] sgemm_tcore: Write reference C before sim --- tests/regression/sgemm_tcore/main.cpp | 23 ++++++++++++----------- 1 file changed, 12 insertions(+), 11 deletions(-) diff --git a/tests/regression/sgemm_tcore/main.cpp b/tests/regression/sgemm_tcore/main.cpp index e6f18317..84283992 100644 --- a/tests/regression/sgemm_tcore/main.cpp +++ b/tests/regression/sgemm_tcore/main.cpp @@ -108,14 +108,6 @@ int run_test(const kernel_arg_t& kernel_arg, std::cout << "download destination buffer" << std::endl; RT_CHECK(vx_copy_from_dev(device, staging_buf.data(), kernel_arg.addr_c, buf_size)); - std::ofstream ref_file("reference.c.bin", std::ios::binary | std::ios::out); - if (!ref_file) { - std::cerr << "error: failed to open reference.c.bin for writing\n"; - exit(EXIT_FAILURE); - } - ref_file.write(reinterpret_cast(ref_data.data()), buf_size); - ref_file.close(); - // verify result std::cout << "verify result" << std::endl; { @@ -155,13 +147,22 @@ int main(int argc, char *argv[]) { RT_CHECK(vx_dev_open(&device)); // FIXME: hardcoded - uint32_t dim_m = 32; - uint32_t dim_n = 32; - uint32_t dim_k = 32; + uint32_t dim_m = 128; + uint32_t dim_n = 128; + uint32_t dim_k = 128; generate_source_matrix(dim_m, dim_n, dim_k); generate_reference_matmul(dim_m, dim_n, dim_k); + std::cout << "write reference output" << std::endl; + std::ofstream ref_file("reference.c.bin", std::ios::binary | std::ios::out); + if (!ref_file) { + std::cerr << "error: failed to open reference.c.bin for writing\n"; + exit(EXIT_FAILURE); + } + ref_file.write(reinterpret_cast(ref_data.data()), ref_data.size() * sizeof(ref_data[0])); + ref_file.close(); + uint32_t src_a_buf_size = src_a_data.size() * sizeof(src_a_data[0]); uint32_t src_b_buf_size = src_b_data.size() * sizeof(src_b_data[0]); uint32_t dst_buf_size = ref_data.size() * sizeof(src_a_data[0]); From 3a6427a49143edb194901b50b34397250995102e Mon Sep 17 00:00:00 2001 From: Hansung Kim Date: Fri, 7 Jun 2024 16:08:40 -0700 Subject: [PATCH 52/55] sgemm_tcore: Hardcode threadblock id 0 this is fine since we're statically dispatching only one threadblock to the whole cluster. --- tests/regression/sgemm_tcore/kernel.cpp | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/tests/regression/sgemm_tcore/kernel.cpp b/tests/regression/sgemm_tcore/kernel.cpp index 42443de7..2354a3e0 100644 --- a/tests/regression/sgemm_tcore/kernel.cpp +++ b/tests/regression/sgemm_tcore/kernel.cpp @@ -547,7 +547,7 @@ void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg, const uint32_t threadblock_dim_y, /*const uint32_t threadblock_id_x, const uint32_t threadblock_id_y,*/ - const uint32_t threadblock_id_in_cluster, + // const uint32_t threadblock_id_in_cluster, float *sharedmem_per_threadblock) { const float *A = (const float *)arg->addr_a; const float *B = (const float *)arg->addr_b; @@ -602,7 +602,7 @@ void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg, global_dmem_load(dim_n, dim_k, 0 /*k*/, A, B, local_a, local_b, tid_in_warpgroup, block_n, block_m); - threadblock_barrier(threadblock_id_in_cluster, threadblock_dim_y); + threadblock_barrier(0/*threadblock_id_in_cluster*/, threadblock_dim_y); } // NOTE: this *should* be signed integer to trigger arithmetic @@ -633,11 +633,11 @@ void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg, local_a_produce, local_b_produce, tid_in_warpgroup, block_n, block_m); - threadblock_barrier(threadblock_id_in_cluster, threadblock_dim_y); + threadblock_barrier(0/*threadblock_id_in_cluster*/, threadblock_dim_y); } // sync with final consumer stage in the k-loop - threadblock_barrier(threadblock_id_in_cluster, threadblock_dim_y); + threadblock_barrier(0/*threadblock_id_in_cluster*/, threadblock_dim_y); } } } else { @@ -650,7 +650,7 @@ void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg, initialize_C(1); // sync with initial producer stage in the k-loop - threadblock_barrier(threadblock_id_in_cluster, threadblock_dim_y); + threadblock_barrier(0/*threadblock_id_in_cluster*/, threadblock_dim_y); // NOTE: this *should* be signed integer to trigger arithmetic // right-shift @@ -718,7 +718,7 @@ void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg, } } - threadblock_barrier(threadblock_id_in_cluster, threadblock_dim_y); + threadblock_barrier(0/*threadblock_id_in_cluster*/, threadblock_dim_y); #else @@ -819,7 +819,7 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { const int warp_id = vx_warp_id(); thread_block_gemm(arg, tid_in_threadblock, threads_per_threadblock, threadblock_dim_x, threadblock_dim_y, /*threadblock_id_x, - threadblock_id_y,*/ threadblock_id_in_cluster, + threadblock_id_y,*/ /*threadblock_id_in_cluster, */ sharedmem_per_threadblock); } From 2cac995db9dd589092d049a336ad8332851269ba Mon Sep 17 00:00:00 2001 From: Hansung Kim Date: Fri, 7 Jun 2024 18:13:57 -0700 Subject: [PATCH 53/55] tensor: generate 8x8 in correctness script --- kernel/src/vx_spawn.c | 1 + tests/kernel/tensor/check_correctness.py | 15 +++++++++++---- 2 files changed, 12 insertions(+), 4 deletions(-) diff --git a/kernel/src/vx_spawn.c b/kernel/src/vx_spawn.c index 1d838c1f..759b915c 100644 --- a/kernel/src/vx_spawn.c +++ b/kernel/src/vx_spawn.c @@ -254,6 +254,7 @@ void vx_spawn_tasks_cluster(int num_tasks, vx_spawn_tasks_cb callback, void *arg vx_wspawn_wait(); } + // TODO: this is incomplete // TODO: Instead of launching an additional wave just to work on remaining // threads, handle this in the last wave amongst other full warps. if (rem_threads_in_last_warp != 0 && core_id_in_cluster == 0) { diff --git a/tests/kernel/tensor/check_correctness.py b/tests/kernel/tensor/check_correctness.py index de0c976a..13e28891 100644 --- a/tests/kernel/tensor/check_correctness.py +++ b/tests/kernel/tensor/check_correctness.py @@ -82,16 +82,23 @@ with open(file) as f: expected = np.load("abc.npz") -expected_A = expected['A_array'] -expected_B = expected['B_array'] -expected_C = expected['C_array'] +# expected_A = expected['A_array'] +# expected_B = expected['B_array'] +# expected_C = expected['C_array'] +expected_A = expected['A_array'][0:8, 0:8] +expected_B = expected['B_array'][0:8, 0:8] +expected_C = expected['C_array'][0:8, 0:8] expected_C = expected_C + expected_A @ expected_B +print('expected A:') +print(expected_A) +print('expected B:') +print(expected_B) print('expected C:') print(expected_C[0:8, 0:8]) print('got C:') print(C_array[0:8, 0:8]) print('diff C:') -print((expected_C - C_array)[0:8, 0:8]) +print(expected_C[0:8, 0:8] - C_array[0:8, 0:8]) expected_C.astype('float32').tofile("c_expected.bin") From 080923e869497900e48ae421d821949a2e8bfb44 Mon Sep 17 00:00:00 2001 From: Hansung Kim Date: Fri, 7 Jun 2024 18:14:40 -0700 Subject: [PATCH 54/55] common.mk: Add more aggressive inline flag --- tests/regression/common.mk | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/regression/common.mk b/tests/regression/common.mk index 04ddbb3f..3efe9947 100644 --- a/tests/regression/common.mk +++ b/tests/regression/common.mk @@ -49,7 +49,7 @@ VX_CP = $(LLVM_VORTEX)/bin/llvm-objcopy #VX_CP = $(RISCV_TOOLCHAIN_PATH)/bin/$(RISCV_PREFIX)-objcopy VX_CFLAGS += -v -O3 -std=c++17 -VX_CFLAGS += -mcmodel=medany -fno-rtti -fno-exceptions -nostartfiles -fdata-sections -ffunction-sections +VX_CFLAGS += -mcmodel=medany -fno-rtti -fno-exceptions -nostartfiles -fdata-sections -ffunction-sections -mllvm -inline-threshold=8192 VX_CFLAGS += -I$(VORTEX_KN_PATH)/include -I$(VORTEX_KN_PATH)/../hw -I$(GEMMINI_SW_PATH) VX_CFLAGS += -DNDEBUG -DLLVM_VORTEX From 800d9801b59e039cb8de3425ae65a46a8845f0e9 Mon Sep 17 00:00:00 2001 From: Hansung Kim Date: Fri, 7 Jun 2024 18:19:20 -0700 Subject: [PATCH 55/55] tensor: Test with multiple accumulators --- tests/kernel/tensor/main.cpp | 38 +++++++++++++++++++++++++++--------- 1 file changed, 29 insertions(+), 9 deletions(-) diff --git a/tests/kernel/tensor/main.cpp b/tests/kernel/tensor/main.cpp index 7fa759a8..d90c38be 100644 --- a/tests/kernel/tensor/main.cpp +++ b/tests/kernel/tensor/main.cpp @@ -11,6 +11,10 @@ inline void vx_wmma() { asm volatile (".insn r %0, 0, 0, x0, x0, x0" :: "i"(RISCV_CUSTOM3)); } +inline void vx_wmma_new() { + asm volatile (".insn r %0, 0, 0, x1, x0, x0" :: "i"(RISCV_CUSTOM3)); +} + #include "test_data.h" inline constexpr void map_operand_32lanes(const int tid, int &row, int &col) { @@ -122,6 +126,14 @@ void vx_wmma_load() { asm volatile ("flw f21, %0" :: "m"(C[row+0][col+5])); asm volatile ("flw f22, %0" :: "m"(C[row+2][col+4])); asm volatile ("flw f23, %0" :: "m"(C[row+2][col+5])); + asm volatile ("flw f24, %0" :: "m"(C[row+0][col+0])); + asm volatile ("flw f25, %0" :: "m"(C[row+0][col+1])); + asm volatile ("flw f26, %0" :: "m"(C[row+2][col+0])); + asm volatile ("flw f27, %0" :: "m"(C[row+2][col+1])); + asm volatile ("flw f28, %0" :: "m"(C[row+0][col+4])); + asm volatile ("flw f29, %0" :: "m"(C[row+0][col+5])); + asm volatile ("flw f30, %0" :: "m"(C[row+2][col+4])); + asm volatile ("flw f31, %0" :: "m"(C[row+2][col+5])); } // float results[32*8]; @@ -149,14 +161,22 @@ void store_wmma_result() { float *const results_wid = results + (DIM_M * DIM_M * wid); - asm volatile("fsw f16, %0" ::"m"(results_wid[DIM_M * (row + 0) + (col + 0)])); - asm volatile("fsw f17, %0" ::"m"(results_wid[DIM_M * (row + 0) + (col + 1)])); - asm volatile("fsw f18, %0" ::"m"(results_wid[DIM_M * (row + 2) + (col + 0)])); - asm volatile("fsw f19, %0" ::"m"(results_wid[DIM_M * (row + 2) + (col + 1)])); - asm volatile("fsw f20, %0" ::"m"(results_wid[DIM_M * (row + 0) + (col + 4)])); - asm volatile("fsw f21, %0" ::"m"(results_wid[DIM_M * (row + 0) + (col + 5)])); - asm volatile("fsw f22, %0" ::"m"(results_wid[DIM_M * (row + 2) + (col + 4)])); - asm volatile("fsw f23, %0" ::"m"(results_wid[DIM_M * (row + 2) + (col + 5)])); + // asm volatile("fsw f16, %0" ::"m"(results_wid[DIM_M * (row + 0) + (col + 0)])); + // asm volatile("fsw f17, %0" ::"m"(results_wid[DIM_M * (row + 0) + (col + 1)])); + // asm volatile("fsw f18, %0" ::"m"(results_wid[DIM_M * (row + 2) + (col + 0)])); + // asm volatile("fsw f19, %0" ::"m"(results_wid[DIM_M * (row + 2) + (col + 1)])); + // asm volatile("fsw f20, %0" ::"m"(results_wid[DIM_M * (row + 0) + (col + 4)])); + // asm volatile("fsw f21, %0" ::"m"(results_wid[DIM_M * (row + 0) + (col + 5)])); + // asm volatile("fsw f22, %0" ::"m"(results_wid[DIM_M * (row + 2) + (col + 4)])); + // asm volatile("fsw f23, %0" ::"m"(results_wid[DIM_M * (row + 2) + (col + 5)])); + asm volatile("fsw f24, %0" ::"m"(results_wid[DIM_M * (row + 0) + (col + 0)])); + asm volatile("fsw f25, %0" ::"m"(results_wid[DIM_M * (row + 0) + (col + 1)])); + asm volatile("fsw f26, %0" ::"m"(results_wid[DIM_M * (row + 2) + (col + 0)])); + asm volatile("fsw f27, %0" ::"m"(results_wid[DIM_M * (row + 2) + (col + 1)])); + asm volatile("fsw f28, %0" ::"m"(results_wid[DIM_M * (row + 0) + (col + 4)])); + asm volatile("fsw f29, %0" ::"m"(results_wid[DIM_M * (row + 0) + (col + 5)])); + asm volatile("fsw f30, %0" ::"m"(results_wid[DIM_M * (row + 2) + (col + 4)])); + asm volatile("fsw f31, %0" ::"m"(results_wid[DIM_M * (row + 2) + (col + 5)])); } void print_wmma_result() { @@ -184,7 +204,7 @@ void wmma() { // for (int i = 0; i < 100; i++) { // vx_wmma(); // } - vx_wmma(); + vx_wmma_new(); store_wmma_result(); // print_wmma_result();