diff --git a/src/main/scala/radiance/core/TensorDPU.scala b/src/main/scala/radiance/core/TensorDPU.scala index 90dd2ec..9978a96 100644 --- a/src/main/scala/radiance/core/TensorDPU.scala +++ b/src/main/scala/radiance/core/TensorDPU.scala @@ -17,7 +17,7 @@ class DPUPipe extends Module with tile.HasFPUParameters { val in = Flipped(Valid(new Bundle { val a = Vec(dotProductDim, Bits((fLen).W)) val b = Vec(dotProductDim, Bits((fLen).W)) - val c = Vec(dotProductDim, Bits((fLen).W)) + val c = Bits((fLen).W) })) val out = Valid(new Bundle { val data = Bits((fLen).W) @@ -28,7 +28,7 @@ class DPUPipe extends Module with tile.HasFPUParameters { val in1 = io.in.bits.a.map(x => unbox(recode(x, S), S, Some(tile.FType.S))) val in2 = io.in.bits.b.map(x => unbox(recode(x, S), S, Some(tile.FType.S))) - val in3 = io.in.bits.c.map(x => unbox(recode(x, S), S, Some(tile.FType.S))) + val in3 = unbox(recode(io.in.bits.c, S), S, Some(tile.FType.S)) // val fma = Module(new MulAddRecFNPipe(2, t.exp, t.sig)) // fma.io.validin := io.in.valid @@ -43,12 +43,14 @@ class DPUPipe extends Module with tile.HasFPUParameters { dpu.io.in.valid := io.in.valid dpu.io.in.bits.a := in1 dpu.io.in.bits.b := in2 - // FIXME: in3 unused + dpu.io.in.bits.c := in3 io.out.valid := dpu.io.out.valid io.out.bits.data := ieee(box(dpu.io.out.bits.data, S)) } +// Computes d = a(0)*b(0) + ... + a(3)*b(3) + c. +// Fully pipelined with a fixed latency of 4 cycles. class DotProductPipe(dim: Int, expWidth: Int, sigWidth: Int) extends Module { require(dim == 4, "DPU currently only supports dimension 4") @@ -57,6 +59,7 @@ class DotProductPipe(dim: Int, expWidth: Int, sigWidth: Int) extends Module { val in = Flipped(Valid(new Bundle { val a = Vec(4, Bits((recFLen).W)) val b = Vec(4, Bits((recFLen).W)) + val c = Bits((recFLen).W) // val roundingMode = UInt(3.W) // val detectTininess = UInt(1.W) })) @@ -74,6 +77,7 @@ class DotProductPipe(dim: Int, expWidth: Int, sigWidth: Int) extends Module { } val mulStageOut = Pipe(io.in.valid, VecInit(mul.map(_.io.out))) + val mulStageC = Pipe(io.in.valid, io.in.bits.c) // mul stage end ------------------------------------------------------- @@ -87,8 +91,9 @@ class DotProductPipe(dim: Int, expWidth: Int, sigWidth: Int) extends Module { } val add1StageOut = Pipe(mulStageOut.valid, VecInit(add1.map(_.io.out))) + val add1StageC = Pipe(mulStageC) - // add stage 1 end ----------------------------------------------------- + // add1 stage end ----------------------------------------------------- val add2 = Module(new hardfloat.AddRecFN(expWidth, sigWidth)) add2.io.subOp := 0.U // FIXME @@ -98,11 +103,23 @@ class DotProductPipe(dim: Int, expWidth: Int, sigWidth: Int) extends Module { add2.io.roundingMode := hardfloat.consts.round_near_even add2.io.detectTininess := hardfloat.consts.tininess_afterRounding - io.out.valid := Pipe(add1StageOut.valid, false.B).valid - io.out.bits.data := Pipe(add1StageOut.valid, add2.io.out).bits + val add2StageOut = Pipe(add1StageOut.valid, add2.io.out) + val add2StageC = Pipe(add1StageC) + + // add2 stage end ----------------------------------------------------- + + val acc = Module(new hardfloat.AddRecFN(expWidth, sigWidth)) + acc.io.subOp := 0.U // FIXME + acc.io.a := add2StageOut.bits + acc.io.b := add2StageC.bits + acc.io.roundingMode := hardfloat.consts.round_near_even + acc.io.detectTininess := hardfloat.consts.tininess_afterRounding + + io.out.valid := Pipe(add2StageOut.valid, false.B).valid + io.out.bits.data := Pipe(add2StageOut.valid, acc.io.out).bits // FIXME: exception output ignored - // add stage 2 end ----------------------------------------------------- + // acc stage end ----------------------------------------------------- } class MulAddRecFNPipe(latency: Int, expWidth: Int, sigWidth: Int) extends Module { diff --git a/src/test/scala/radiance/TensorDPUTest.scala b/src/test/scala/radiance/TensorDPUTest.scala index 5993edf..55ef3a7 100644 --- a/src/test/scala/radiance/TensorDPUTest.scala +++ b/src/test/scala/radiance/TensorDPUTest.scala @@ -60,13 +60,20 @@ class DPUPipeTest extends AnyFlatSpec with ChiselScalatestTester { c.io.in.bits.b(1).poke(0x40000000L.U(64.W)) c.io.in.bits.b(2).poke(0x40000000L.U(64.W)) c.io.in.bits.b(3).poke(0x40000000L.U(64.W)) + c.io.in.bits.c .poke(0x40400000L.U(64.W)) + c.clock.step() c.io.in.valid.poke(false.B) c.clock.step() c.clock.step() + c.clock.step() + // 4-cycle latency + c.io.out.valid.expect(true.B) c.io.out.bits.data.expect(0x40e00000L.U) + c.clock.step() + c.io.out.valid.expect(false.B) } }