tensor: Add FP16 parameter and expose to VX_core

This commit is contained in:
Hansung Kim
2024-09-10 15:25:48 -07:00
parent a968bdd69b
commit da54162241
5 changed files with 24 additions and 11 deletions

View File

@@ -4,6 +4,7 @@
module Vortex import VX_gpu_pkg::*; #( module Vortex import VX_gpu_pkg::*; #(
parameter CORE_ID = 0, parameter CORE_ID = 0,
parameter TENSOR_FP16 = 0,
parameter BOOTROM_HANG100 = 32'h10100, parameter BOOTROM_HANG100 = 32'h10100,
parameter NUM_THREADS = 0 parameter NUM_THREADS = 0
) ( ) (
@@ -394,7 +395,8 @@ module Vortex import VX_gpu_pkg::*; #(
// TODO: SCOPE_IO_BIND should be socket id // TODO: SCOPE_IO_BIND should be socket id
VX_core #( VX_core #(
.CORE_ID (CORE_ID) .CORE_ID (CORE_ID),
.TENSOR_FP16 (TENSOR_FP16)
) core ( ) core (
`SCOPE_IO_BIND (0) `SCOPE_IO_BIND (0)

View File

@@ -18,7 +18,8 @@
`endif `endif
module VX_core import VX_gpu_pkg::*; #( module VX_core import VX_gpu_pkg::*; #(
parameter CORE_ID = 0 parameter CORE_ID = 0,
parameter TENSOR_FP16 = 0
) ( ) (
`SCOPE_IO_DECL `SCOPE_IO_DECL
@@ -191,7 +192,8 @@ module VX_core import VX_gpu_pkg::*; #(
); );
VX_execute #( VX_execute #(
.CORE_ID (CORE_ID) .CORE_ID (CORE_ID),
.TENSOR_FP16 (TENSOR_FP16)
) execute ( ) execute (
`SCOPE_IO_BIND (2) `SCOPE_IO_BIND (2)

View File

@@ -14,7 +14,8 @@
`include "VX_define.vh" `include "VX_define.vh"
module VX_execute import VX_gpu_pkg::*; #( module VX_execute import VX_gpu_pkg::*; #(
parameter CORE_ID = 0 parameter CORE_ID = 0,
parameter TENSOR_FP16 = 0
) ( ) (
`SCOPE_IO_DECL `SCOPE_IO_DECL
@@ -144,7 +145,7 @@ module VX_execute import VX_gpu_pkg::*; #(
`ifdef EXT_T_ENABLE `ifdef EXT_T_ENABLE
VX_tensor_core #( VX_tensor_core #(
.FP16 (TENSOR_FP16)
) tensor_core ( ) tensor_core (
.clk(clk), .clk(clk),
.reset(reset), .reset(reset),

View File

@@ -2,7 +2,7 @@
`include "VX_fpu_define.vh" `include "VX_fpu_define.vh"
module VX_tensor_core import VX_gpu_pkg::*; #( module VX_tensor_core import VX_gpu_pkg::*; #(
parameter FP16
) ( ) (
input clk, input clk,
input reset, input reset,
@@ -52,7 +52,8 @@ module VX_tensor_core import VX_gpu_pkg::*; #(
for (genvar block_idx = 0; block_idx < BLOCK_SIZE; ++block_idx) begin for (genvar block_idx = 0; block_idx < BLOCK_SIZE; ++block_idx) begin
VX_tensor_core_block #( VX_tensor_core_block #(
.ISW(1) // FIXME: not block_idx .ISW(1), // FIXME: not block_idx
.FP16(FP16)
) tensor_core ( ) tensor_core (
.clk(clk), .clk(clk),
.reset(reset), .reset(reset),
@@ -65,7 +66,8 @@ module VX_tensor_core import VX_gpu_pkg::*; #(
endmodule endmodule
module VX_tensor_core_block import VX_gpu_pkg::*; #( module VX_tensor_core_block import VX_gpu_pkg::*; #(
parameter ISW parameter ISW,
parameter FP16
) ( ) (
input clk, input clk,
input reset, input reset,
@@ -121,7 +123,8 @@ module VX_tensor_core_block import VX_gpu_pkg::*; #(
VX_tensor_octet #( VX_tensor_octet #(
.ISW(ISW), .ISW(ISW),
.OCTET(i) .OCTET(i),
.FP16(FP16)
) octet ( ) octet (
.clk(clk), .clk(clk),
.reset(reset), .reset(reset),
@@ -329,6 +332,7 @@ endmodule
module VX_tensor_octet #( module VX_tensor_octet #(
parameter ISW, parameter ISW,
parameter OCTET, parameter OCTET,
parameter FP16,
parameter RESULT_BUFFER_DEPTH = 2 parameter RESULT_BUFFER_DEPTH = 2
) ( ) (
input clk, input clk,
@@ -519,6 +523,7 @@ module VX_tensor_octet #(
VX_tensor_threadgroups #( VX_tensor_threadgroups #(
.ISW(ISW), .ISW(ISW),
.OCTET(OCTET), .OCTET(OCTET),
.FP16(FP16),
.OPERAND_BUFFER_DEPTH(4 /*@perf: arbitrary*/) .OPERAND_BUFFER_DEPTH(4 /*@perf: arbitrary*/)
) dpu ( ) dpu (
.clk(clk), .clk(clk),

View File

@@ -5,6 +5,7 @@
module VX_tensor_threadgroups #( module VX_tensor_threadgroups #(
parameter ISW, parameter ISW,
parameter OCTET, parameter OCTET,
parameter FP16,
// @perf: has big impact on throughput. A rule of thumb is to set it to // @perf: has big impact on throughput. A rule of thumb is to set it to
// the pipeline length of FEDPs in order to make sure there are enough // the pipeline length of FEDPs in order to make sure there are enough
// entries to fully saturate the pipeline, but this is still rough // entries to fully saturate the pipeline, but this is still rough
@@ -102,6 +103,7 @@ module VX_tensor_threadgroups #(
// threadgroup DPUs; B_tile is shared across the two threadgroups. See // threadgroup DPUs; B_tile is shared across the two threadgroups. See
// Figure 13 in paper // Figure 13 in paper
VX_tensor_threadgroup #( VX_tensor_threadgroup #(
.FP16(FP16)
) threadgroup_0 ( ) threadgroup_0 (
.clk (clk), .clk (clk),
.reset (reset), .reset (reset),
@@ -115,6 +117,7 @@ module VX_tensor_threadgroups #(
.D_frag (D_tile[1:0]) .D_frag (D_tile[1:0])
); );
VX_tensor_threadgroup #( VX_tensor_threadgroup #(
.FP16(FP16)
) threadgroup_1 ( ) threadgroup_1 (
.clk (clk), .clk (clk),
.reset (reset), .reset (reset),
@@ -165,7 +168,7 @@ endmodule
// does (m,n,k) = (2,4,2) matmul compute over 2 cycles. // does (m,n,k) = (2,4,2) matmul compute over 2 cycles.
// see Figure 10(b) of the paper. // see Figure 10(b) of the paper.
module VX_tensor_threadgroup #( module VX_tensor_threadgroup #(
parameter HALF_PRECISION = 1 parameter FP16
) ( ) (
input clk, input clk,
input reset, input reset,
@@ -297,7 +300,7 @@ module VX_tensor_threadgroup #(
wire [31:0] d_col_sel = (substep_in == 1'b0) ? d_col : (d_col + 1); wire [31:0] d_col_sel = (substep_in == 1'b0) ? d_col : (d_col + 1);
// Dot product (FEDP) unit generated from Chisel // Dot product (FEDP) unit generated from Chisel
if (HALF_PRECISION != 0) begin if (FP16 != 0) begin
TensorDotProductUnit fedp ( TensorDotProductUnit fedp (
.clock (clk), .clock (clk),
.reset (reset), .reset (reset),