tensor: Add FP16 parameter and expose to VX_core
This commit is contained in:
@@ -4,6 +4,7 @@
|
|||||||
|
|
||||||
module Vortex import VX_gpu_pkg::*; #(
|
module Vortex import VX_gpu_pkg::*; #(
|
||||||
parameter CORE_ID = 0,
|
parameter CORE_ID = 0,
|
||||||
|
parameter TENSOR_FP16 = 0,
|
||||||
parameter BOOTROM_HANG100 = 32'h10100,
|
parameter BOOTROM_HANG100 = 32'h10100,
|
||||||
parameter NUM_THREADS = 0
|
parameter NUM_THREADS = 0
|
||||||
) (
|
) (
|
||||||
@@ -394,7 +395,8 @@ module Vortex import VX_gpu_pkg::*; #(
|
|||||||
|
|
||||||
// TODO: SCOPE_IO_BIND should be socket id
|
// TODO: SCOPE_IO_BIND should be socket id
|
||||||
VX_core #(
|
VX_core #(
|
||||||
.CORE_ID (CORE_ID)
|
.CORE_ID (CORE_ID),
|
||||||
|
.TENSOR_FP16 (TENSOR_FP16)
|
||||||
) core (
|
) core (
|
||||||
`SCOPE_IO_BIND (0)
|
`SCOPE_IO_BIND (0)
|
||||||
|
|
||||||
|
|||||||
@@ -18,7 +18,8 @@
|
|||||||
`endif
|
`endif
|
||||||
|
|
||||||
module VX_core import VX_gpu_pkg::*; #(
|
module VX_core import VX_gpu_pkg::*; #(
|
||||||
parameter CORE_ID = 0
|
parameter CORE_ID = 0,
|
||||||
|
parameter TENSOR_FP16 = 0
|
||||||
) (
|
) (
|
||||||
`SCOPE_IO_DECL
|
`SCOPE_IO_DECL
|
||||||
|
|
||||||
@@ -191,7 +192,8 @@ module VX_core import VX_gpu_pkg::*; #(
|
|||||||
);
|
);
|
||||||
|
|
||||||
VX_execute #(
|
VX_execute #(
|
||||||
.CORE_ID (CORE_ID)
|
.CORE_ID (CORE_ID),
|
||||||
|
.TENSOR_FP16 (TENSOR_FP16)
|
||||||
) execute (
|
) execute (
|
||||||
`SCOPE_IO_BIND (2)
|
`SCOPE_IO_BIND (2)
|
||||||
|
|
||||||
|
|||||||
@@ -14,7 +14,8 @@
|
|||||||
`include "VX_define.vh"
|
`include "VX_define.vh"
|
||||||
|
|
||||||
module VX_execute import VX_gpu_pkg::*; #(
|
module VX_execute import VX_gpu_pkg::*; #(
|
||||||
parameter CORE_ID = 0
|
parameter CORE_ID = 0,
|
||||||
|
parameter TENSOR_FP16 = 0
|
||||||
) (
|
) (
|
||||||
`SCOPE_IO_DECL
|
`SCOPE_IO_DECL
|
||||||
|
|
||||||
@@ -144,7 +145,7 @@ module VX_execute import VX_gpu_pkg::*; #(
|
|||||||
|
|
||||||
`ifdef EXT_T_ENABLE
|
`ifdef EXT_T_ENABLE
|
||||||
VX_tensor_core #(
|
VX_tensor_core #(
|
||||||
|
.FP16 (TENSOR_FP16)
|
||||||
) tensor_core (
|
) tensor_core (
|
||||||
.clk(clk),
|
.clk(clk),
|
||||||
.reset(reset),
|
.reset(reset),
|
||||||
|
|||||||
@@ -2,7 +2,7 @@
|
|||||||
`include "VX_fpu_define.vh"
|
`include "VX_fpu_define.vh"
|
||||||
|
|
||||||
module VX_tensor_core import VX_gpu_pkg::*; #(
|
module VX_tensor_core import VX_gpu_pkg::*; #(
|
||||||
|
parameter FP16
|
||||||
) (
|
) (
|
||||||
input clk,
|
input clk,
|
||||||
input reset,
|
input reset,
|
||||||
@@ -52,7 +52,8 @@ module VX_tensor_core import VX_gpu_pkg::*; #(
|
|||||||
|
|
||||||
for (genvar block_idx = 0; block_idx < BLOCK_SIZE; ++block_idx) begin
|
for (genvar block_idx = 0; block_idx < BLOCK_SIZE; ++block_idx) begin
|
||||||
VX_tensor_core_block #(
|
VX_tensor_core_block #(
|
||||||
.ISW(1) // FIXME: not block_idx
|
.ISW(1), // FIXME: not block_idx
|
||||||
|
.FP16(FP16)
|
||||||
) tensor_core (
|
) tensor_core (
|
||||||
.clk(clk),
|
.clk(clk),
|
||||||
.reset(reset),
|
.reset(reset),
|
||||||
@@ -65,7 +66,8 @@ module VX_tensor_core import VX_gpu_pkg::*; #(
|
|||||||
endmodule
|
endmodule
|
||||||
|
|
||||||
module VX_tensor_core_block import VX_gpu_pkg::*; #(
|
module VX_tensor_core_block import VX_gpu_pkg::*; #(
|
||||||
parameter ISW
|
parameter ISW,
|
||||||
|
parameter FP16
|
||||||
) (
|
) (
|
||||||
input clk,
|
input clk,
|
||||||
input reset,
|
input reset,
|
||||||
@@ -121,7 +123,8 @@ module VX_tensor_core_block import VX_gpu_pkg::*; #(
|
|||||||
|
|
||||||
VX_tensor_octet #(
|
VX_tensor_octet #(
|
||||||
.ISW(ISW),
|
.ISW(ISW),
|
||||||
.OCTET(i)
|
.OCTET(i),
|
||||||
|
.FP16(FP16)
|
||||||
) octet (
|
) octet (
|
||||||
.clk(clk),
|
.clk(clk),
|
||||||
.reset(reset),
|
.reset(reset),
|
||||||
@@ -329,6 +332,7 @@ endmodule
|
|||||||
module VX_tensor_octet #(
|
module VX_tensor_octet #(
|
||||||
parameter ISW,
|
parameter ISW,
|
||||||
parameter OCTET,
|
parameter OCTET,
|
||||||
|
parameter FP16,
|
||||||
parameter RESULT_BUFFER_DEPTH = 2
|
parameter RESULT_BUFFER_DEPTH = 2
|
||||||
) (
|
) (
|
||||||
input clk,
|
input clk,
|
||||||
@@ -519,6 +523,7 @@ module VX_tensor_octet #(
|
|||||||
VX_tensor_threadgroups #(
|
VX_tensor_threadgroups #(
|
||||||
.ISW(ISW),
|
.ISW(ISW),
|
||||||
.OCTET(OCTET),
|
.OCTET(OCTET),
|
||||||
|
.FP16(FP16),
|
||||||
.OPERAND_BUFFER_DEPTH(4 /*@perf: arbtirary*/)
|
.OPERAND_BUFFER_DEPTH(4 /*@perf: arbtirary*/)
|
||||||
) dpu (
|
) dpu (
|
||||||
.clk(clk),
|
.clk(clk),
|
||||||
|
|||||||
@@ -5,6 +5,7 @@
|
|||||||
module VX_tensor_threadgroups #(
|
module VX_tensor_threadgroups #(
|
||||||
parameter ISW,
|
parameter ISW,
|
||||||
parameter OCTET,
|
parameter OCTET,
|
||||||
|
parameter FP16,
|
||||||
// @perf: has big impact on throughput. A rule of thumb is to set it to
|
// @perf: has big impact on throughput. A rule of thumb is to set it to
|
||||||
// the pipeline length of FEDPs in order to make sure there are enough
|
// the pipeline length of FEDPs in order to make sure there are enough
|
||||||
// entries to fully saturate the pipeline, but this is still rough
|
// entries to fully saturate the pipeline, but this is still rough
|
||||||
@@ -102,6 +103,7 @@ module VX_tensor_threadgroups #(
|
|||||||
// threadgroup DPUs; B_tile is shared across the two threadgroups. See
|
// threadgroup DPUs; B_tile is shared across the two threadgroups. See
|
||||||
// Figure 13 in paper
|
// Figure 13 in paper
|
||||||
VX_tensor_threadgroup #(
|
VX_tensor_threadgroup #(
|
||||||
|
.FP16(FP16)
|
||||||
) threadgroup_0 (
|
) threadgroup_0 (
|
||||||
.clk (clk),
|
.clk (clk),
|
||||||
.reset (reset),
|
.reset (reset),
|
||||||
@@ -115,6 +117,7 @@ module VX_tensor_threadgroups #(
|
|||||||
.D_frag (D_tile[1:0])
|
.D_frag (D_tile[1:0])
|
||||||
);
|
);
|
||||||
VX_tensor_threadgroup #(
|
VX_tensor_threadgroup #(
|
||||||
|
.FP16(FP16)
|
||||||
) threadgroup_1 (
|
) threadgroup_1 (
|
||||||
.clk (clk),
|
.clk (clk),
|
||||||
.reset (reset),
|
.reset (reset),
|
||||||
@@ -165,7 +168,7 @@ endmodule
|
|||||||
// does (m,n,k) = (2,4,2) matmul compute over 2 cycles.
|
// does (m,n,k) = (2,4,2) matmul compute over 2 cycles.
|
||||||
// see Figure 10(b) of the paper.
|
// see Figure 10(b) of the paper.
|
||||||
module VX_tensor_threadgroup #(
|
module VX_tensor_threadgroup #(
|
||||||
parameter HALF_PRECISION = 1
|
parameter FP16
|
||||||
) (
|
) (
|
||||||
input clk,
|
input clk,
|
||||||
input reset,
|
input reset,
|
||||||
@@ -297,7 +300,7 @@ module VX_tensor_threadgroup #(
|
|||||||
wire [31:0] d_col_sel = (substep_in == 1'b0) ? d_col : (d_col + 1);
|
wire [31:0] d_col_sel = (substep_in == 1'b0) ? d_col : (d_col + 1);
|
||||||
|
|
||||||
// Dot product (FEDP) unit generated from Chisel
|
// Dot product (FEDP) unit generated from Chisel
|
||||||
if (HALF_PRECISION != 0) begin
|
if (FP16 != 0) begin
|
||||||
TensorDotProductUnit fedp (
|
TensorDotProductUnit fedp (
|
||||||
.clock (clk),
|
.clock (clk),
|
||||||
.reset (reset),
|
.reset (reset),
|
||||||
|
|||||||
Reference in New Issue
Block a user