From 46a57fdf9b944f9fdb6cb1ca5ad1da238aff2933 Mon Sep 17 00:00:00 2001 From: Hansung Kim Date: Fri, 25 Oct 2024 21:44:36 -0700 Subject: [PATCH] tensor: Parameterize dimension in TensorDotProductUnit --- .../radiance/core/TensorCoreDecoupled.scala | 2 +- src/main/scala/radiance/core/TensorDPU.scala | 90 ++++++++++++------- .../radiance/TensorCoreDecoupledTest.scala | 2 +- src/test/scala/radiance/TensorDPUTest.scala | 14 +-- 4 files changed, 65 insertions(+), 43 deletions(-) diff --git a/src/main/scala/radiance/core/TensorCoreDecoupled.scala b/src/main/scala/radiance/core/TensorCoreDecoupled.scala index 20028d3..2f53269 100644 --- a/src/main/scala/radiance/core/TensorCoreDecoupled.scala +++ b/src/main/scala/radiance/core/TensorCoreDecoupled.scala @@ -543,7 +543,7 @@ class TensorCoreDecoupled( require(tilingParams.mc * ncSubstep == numLanes, "substep tile size doesn't match writeback throughput") val dpus = Seq.fill(tilingParams.mc)(Seq.fill(ncSubstep)( - Module(new TensorDotProductUnit(half = false)) + Module(new TensorDotProductUnit(dim = 4, half = false)) )) // reshape operands for easier routing to DPU diff --git a/src/main/scala/radiance/core/TensorDPU.scala b/src/main/scala/radiance/core/TensorDPU.scala index db98b36..a4e6db0 100644 --- a/src/main/scala/radiance/core/TensorDPU.scala +++ b/src/main/scala/radiance/core/TensorDPU.scala @@ -9,7 +9,10 @@ import freechips.rocketchip.tile // Implements the four-element dot product (FEDP) unit in Volta Tensor Cores. // `half`: if True, generate fp16 MACs; if False fp32. -class TensorDotProductUnit(val half: Boolean) extends Module with tile.HasFPUParameters { +class TensorDotProductUnit( + val dim: Int = 4, + val half: Boolean +) extends Module with tile.HasFPUParameters { val tIn = if (half) tile.FType.H else tile.FType.S // output datatype fixed to single-precision val tOut = tile.FType.S @@ -19,12 +22,11 @@ class TensorDotProductUnit(val half: Boolean) extends Module with tile.HasFPUPar val fLen = outFLen // needed for HasFPUParameters val minFLen = 16 // fp16 def xLen = 32 - val dotProductDim = 4 val io = IO(new Bundle { val in = Flipped(Valid(new Bundle { - val a = Vec(dotProductDim, Bits((inFLen).W)) - val b = Vec(dotProductDim, Bits((inFLen).W)) + val a = Vec(dim, Bits((inFLen).W)) + val b = Vec(dim, Bits((inFLen).W)) val c = Bits((outFLen).W) // note C has the out length for accumulation })) // 'stall' is effectively out.ready, combinationally coupled to in.ready @@ -43,7 +45,7 @@ class TensorDotProductUnit(val half: Boolean) extends Module with tile.HasFPUPar val in2 = io.in.bits.b.map(x => unbox(recode(x, tag), tag, Some(tIn))) val in3 = unbox(recode(io.in.bits.c, S), S, Some(tOut)) - val dpu = Module(new DotProductPipe(dotProductDim, tIn, tOut)) + val dpu = Module(new DotProductPipe(dim, tIn, tOut)) dpu.io.in.valid := io.in.valid dpu.io.in.bits.a := in1 dpu.io.in.bits.b := in2 @@ -101,7 +103,6 @@ object StallingPipe { // Computes d = a(0)*b(0) + ... + a(`dim`-1)*b(`dim`-1) + c. // Fully pipelined with a fixed latency determined by `dim`. class DotProductPipe(dim: Int, inputType: tile.FType, outputType: tile.FType) extends Module { - require(dim == 4, "DPU currently only supports dimension 4") val expWidth = inputType.exp val sigWidth = inputType.sig val outExpWidth = outputType.exp @@ -111,8 +112,8 @@ class DotProductPipe(dim: Int, inputType: tile.FType, outputType: tile.FType) ex val recOutFLen = outExpWidth + outSigWidth + 1 val io = IO(new Bundle { val in = Flipped(Valid(new Bundle { - val a = Vec(4, Bits((recInFLen).W)) - val b = Vec(4, Bits((recInFLen).W)) + val a = Vec(dim, Bits((recInFLen).W)) + val b = Vec(dim, Bits((recInFLen).W)) val c = Bits((recOutFLen).W) // val roundingMode = UInt(3.W) // val detectTininess = UInt(1.W) @@ -141,6 +142,7 @@ class DotProductPipe(dim: Int, inputType: tile.FType, outputType: tile.FType) ex // assert(m.io.invalidExc === false.B) // round fp16*fp16 raw result back to fp32 recoded format + // @perf: possibly pipeline here for better timing val mulExpWidth = m.io.rawOut.expWidth val mulSigWidth = m.io.rawOut.sigWidth val roundRawFNToRecFN = @@ -160,45 +162,65 @@ class DotProductPipe(dim: Int, inputType: tile.FType, outputType: tile.FType) ex // mul stage end ------------------------------------------------------------- - val add1 = Seq.fill(dim / 2)(Module(new hardfloat.AddRecFN(outExpWidth, outSigWidth))) - val add1Outs = add1.zipWithIndex.map { case (a, i) => - a.io.subOp := 0.U // FIXME dont know what this is - a.io.a := mulStageOut.bits(2 * i + 0) - a.io.b := mulStageOut.bits(2 * i + 1) - a.io.roundingMode := hardfloat.consts.round_near_even - a.io.detectTininess := hardfloat.consts.tininess_afterRounding - // assert(a.io.exceptionFlags === 0.U) - a.io.out + // reduce-add `dim` mul results down to one in a tree reduction + // + val log2Dim = log2Ceil(dim) + require(dim == (1 << log2Dim), s"dim (${dim}) is not power of two!") + + // instantiate wires for input values to each reduction pipeline stage + val interim = (log2Dim to 0 by -1).map { i => + Wire(Valid(Vec(1 << i, Bits(recOutFLen.W)))) } + // instantiate wires for pipe registers for C + val interimC = (log2Dim to 0 by -1).map( _ => Wire(Valid(Bits(recOutFLen.W))) ) + // connect the first stage inputs + interim(0) := mulStageOut + interimC(0) := mulStageC - val add1StageOut = StallingPipe(io.stall, mulStageOut.valid, VecInit(add1Outs)) - val add1StageC = StallingPipe(io.stall, mulStageOut.valid, mulStageC.bits) + // now we get fancy + val (addStageOut, addStageC) = (interim zip interimC).reduce { + (inputsAndC, outputsAndC) => { + val (inputs, inC) = inputsAndC + val (outputs, outC) = outputsAndC - // add1 stage end ------------------------------------------------------------ + require(inputs.bits.length == 2 * outputs.bits.length) + val thisDim = inputs.bits.length + val adders = Seq.fill(thisDim / 2)( + Module(new hardfloat.AddRecFN(outExpWidth, outSigWidth)) + ) + val addOuts = adders.zipWithIndex.map { case (a, i) => + a.io.subOp := 0.U // FIXME dont know what this is + a.io.a := inputs.bits(2 * i + 0) + a.io.b := inputs.bits(2 * i + 1) + a.io.roundingMode := hardfloat.consts.round_near_even + a.io.detectTininess := hardfloat.consts.tininess_afterRounding + // assert(a.io.exceptionFlags === 0.U) + a.io.out + } - val add2 = Module(new hardfloat.AddRecFN(outExpWidth, outSigWidth)) - add2.io.subOp := 0.U // FIXME - add2.io.a := add1StageOut.bits(0) - add2.io.b := add1StageOut.bits(1) - add2.io.roundingMode := hardfloat.consts.round_near_even - add2.io.detectTininess := hardfloat.consts.tininess_afterRounding - // assert(add2.io.exceptionFlags === 0.U) + // pipeline and connect outputs to the next stage + outputs := StallingPipe(io.stall, inputs.valid, VecInit(addOuts)) + outC := StallingPipe(io.stall, inputs.valid, inC.bits) + assert(inputs.valid === inC.valid, + "adder inputs valid and C pipe valid went out-of-sync") - val add2StageOut = StallingPipe(io.stall, add1StageOut.valid, add2.io.out) - val add2StageC = StallingPipe(io.stall, add1StageOut.valid, add1StageC.bits) + (outputs, outC) + } + } + require(addStageOut.bits.length == 1) - // add2 stage end ------------------------------------------------------------ + // add stages end ------------------------------------------------------------ + // add final A and B dot-product result to accumulator C val acc = Module(new hardfloat.AddRecFN(outExpWidth, outSigWidth)) acc.io.subOp := 0.U // FIXME - acc.io.a := add2StageOut.bits - // acc.io.b := add2StageCRec - acc.io.b := add2StageC.bits + acc.io.a := addStageOut.bits(0) + acc.io.b := addStageC.bits acc.io.roundingMode := hardfloat.consts.round_near_even acc.io.detectTininess := hardfloat.consts.tininess_afterRounding // assert(acc.io.exceptionFlags === 0.U) - val accStageOut = StallingPipe(io.stall, add2StageOut.valid, acc.io.out) + val accStageOut = StallingPipe(io.stall, addStageOut.valid, acc.io.out) // acc stage end ------------------------------------------------------------- diff --git a/src/test/scala/radiance/TensorCoreDecoupledTest.scala b/src/test/scala/radiance/TensorCoreDecoupledTest.scala index b1e0e9a..7b31eb7 100644 --- a/src/test/scala/radiance/TensorCoreDecoupledTest.scala +++ b/src/test/scala/radiance/TensorCoreDecoupledTest.scala @@ -9,7 +9,7 @@ class TensorCoreDecoupledTest extends AnyFlatSpec with ChiselScalatestTester { behavior of "TensorCoreDecoupled" it should "do the right thing" in { - test(new TensorCoreDecoupled(8, 8, tilingParams = TensorTilingParams())) + test(new TensorCoreDecoupled(8, 8, numSourceIds = 4, tilingParams = TensorTilingParams())) { c => c.io.initiate.valid.poke(true.B) c.io.initiate.bits.wid.poke(0.U) diff --git a/src/test/scala/radiance/TensorDPUTest.scala b/src/test/scala/radiance/TensorDPUTest.scala index 87eb4e7..3eac1ca 100644 --- a/src/test/scala/radiance/TensorDPUTest.scala +++ b/src/test/scala/radiance/TensorDPUTest.scala @@ -46,8 +46,8 @@ class TensorDotProductUnitTest extends AnyFlatSpec with ChiselScalatestTester { implicit val p: Parameters = Parameters.empty - it should "pass fp16" in { - test(new TensorDotProductUnit(half = true)) + it should "pass 4-dim fp16" in { + test(new TensorDotProductUnit(4, half = true)) // .withAnnotations(Seq(VerilatorBackendAnnotation)) // .withAnnotations(Seq(WriteVcdAnnotation)) { c => @@ -93,9 +93,9 @@ class TensorDotProductUnitTest extends AnyFlatSpec with ChiselScalatestTester { } } - it should "pass fp16 2" in { - test(new TensorDotProductUnit(half = true)) - .withAnnotations(Seq(VerilatorBackendAnnotation)) + it should "pass 4-dim fp16 2" in { + test(new TensorDotProductUnit(4, half = true)) + // .withAnnotations(Seq(VerilatorBackendAnnotation)) // .withAnnotations(Seq(WriteVcdAnnotation)) { c => c.io.in.valid.poke(true.B) @@ -129,8 +129,8 @@ class TensorDotProductUnitTest extends AnyFlatSpec with ChiselScalatestTester { } } - it should "pass fp32" in { - test(new TensorDotProductUnit(half = false)) + it should "pass 4-dim fp32" in { + test(new TensorDotProductUnit(4, half = false)) // .withAnnotations(Seq(VerilatorBackendAnnotation)) // .withAnnotations(Seq(WriteVcdAnnotation)) { c =>