diff --git a/src/main/scala/radiance/core/TensorDPU.scala b/src/main/scala/radiance/core/TensorDPU.scala index dd848d4..5732fe2 100644 --- a/src/main/scala/radiance/core/TensorDPU.scala +++ b/src/main/scala/radiance/core/TensorDPU.scala @@ -43,6 +43,49 @@ class TensorDotProductUnit extends Module with tile.HasFPUParameters { io.out.bits.data := ieee(box(dpu.io.out.bits.data, S)) } +// Copied from chisel3.util.Pipe. +class StallingPipe[T <: Data](val gen: T, val latency: Int = 1) extends Module { + /** A non-ambiguous name of this `StallingPipe` for use in generated Verilog + * names. Includes the latency cycle count in the name as well as the + * parameterized generator's `typeName`, e.g. `Pipe4_UInt4` + */ + override def desiredName = s"${simpleClassName(this.getClass)}${latency}_${gen.typeName}" + + class StallingPipeIO extends Bundle { + val stall = Input(Bool()) + val enq = Input(Valid(gen)) + val deq = Output(Valid(gen)) + } + + val io = IO(new StallingPipeIO) + + io.deq <> StallingPipe(io.stall, io.enq, latency) +} + +object StallingPipe { + import chisel3.experimental.prefix + + def apply[T <: Data](stall: Bool, enqValid: Bool, enqBits: T, latency: Int): Valid[T] = { + require(latency == 1, "StallingPipe only supports latency equals one!") + prefix("stalling_pipe") { + val out = Wire(Valid(chiselTypeOf(enqBits))) + val v = RegEnable(enqValid, false.B, !stall) + val b = RegEnable(enqBits, !stall && enqValid) + out.valid := v + out.bits := b + out + } + } + + def apply[T <: Data](stall: Bool, enqValid: Bool, enqBits: T): Valid[T] = { + apply(stall, enqValid, enqBits, 1) + } + + def apply[T <: Data](stall: Bool, enq: Valid[T], latency: Int = 1): Valid[T] = { + apply(stall, enq.valid, enq.bits, latency) + } +} + // Computes d = a(0)*b(0) + ... + a(3)*b(3) + c. // Fully pipelined with a fixed latency of 4 cycles. class DotProductPipe(dim: Int, expWidth: Int, sigWidth: Int) extends Module { @@ -72,8 +115,8 @@ class DotProductPipe(dim: Int, expWidth: Int, sigWidth: Int) extends Module { m.io.b := io.in.bits.b(i) } - val mulStageOut = Pipe(!io.stall && io.in.valid, VecInit(mul.map(_.io.out))) - val mulStageC = Pipe(!io.stall && io.in.valid, io.in.bits.c) + val mulStageOut = StallingPipe(io.stall, io.in.valid, VecInit(mul.map(_.io.out))) + val mulStageC = StallingPipe(io.stall, io.in.valid, io.in.bits.c) // mul stage end ------------------------------------------------------------- @@ -86,8 +129,8 @@ class DotProductPipe(dim: Int, expWidth: Int, sigWidth: Int) extends Module { a.io.detectTininess := hardfloat.consts.tininess_afterRounding } - val add1StageOut = Pipe(!io.stall && mulStageOut.valid, VecInit(add1.map(_.io.out)), latency = 0) - val add1StageC = Pipe(!io.stall && mulStageOut.valid, mulStageC.bits, latency = 0) + val add1StageOut = StallingPipe(io.stall, mulStageOut.valid, VecInit(add1.map(_.io.out))) + val add1StageC = StallingPipe(io.stall, mulStageOut.valid, mulStageC.bits) // add1 stage end ------------------------------------------------------------ @@ -99,8 +142,8 @@ class DotProductPipe(dim: Int, expWidth: Int, sigWidth: Int) extends Module { add2.io.roundingMode := hardfloat.consts.round_near_even add2.io.detectTininess := hardfloat.consts.tininess_afterRounding - val add2StageOut = Pipe(!io.stall && add1StageOut.valid, add2.io.out, latency = 0) - val add2StageC = Pipe(!io.stall && add1StageOut.valid, add1StageC.bits, latency = 0) + val add2StageOut = StallingPipe(io.stall, add1StageOut.valid, add2.io.out) + val add2StageC = StallingPipe(io.stall, add1StageOut.valid, add1StageC.bits) // add2 stage end ------------------------------------------------------------ @@ -111,7 +154,7 @@ class DotProductPipe(dim: Int, expWidth: Int, sigWidth: Int) extends Module { acc.io.roundingMode := hardfloat.consts.round_near_even acc.io.detectTininess := hardfloat.consts.tininess_afterRounding - val accStageOut = Pipe(!io.stall && add2StageOut.valid, acc.io.out) + val accStageOut = StallingPipe(io.stall, add2StageOut.valid, acc.io.out) // FIXME: exception output ignored // acc stage end ------------------------------------------------------------- diff --git a/src/test/scala/radiance/TensorDPUTest.scala b/src/test/scala/radiance/TensorDPUTest.scala index 1fbfd4d..b31e3fe 100644 --- a/src/test/scala/radiance/TensorDPUTest.scala +++ b/src/test/scala/radiance/TensorDPUTest.scala @@ -48,7 +48,7 @@ class TensorDotProductUnitTest extends AnyFlatSpec with ChiselScalatestTester { it should "pass" in { test(new TensorDotProductUnit) - .withAnnotations(Seq(VerilatorBackendAnnotation)) + // .withAnnotations(Seq(VerilatorBackendAnnotation)) // .withAnnotations(Seq(WriteVcdAnnotation)) { c => c.io.in.valid.poke(true.B) @@ -69,14 +69,20 @@ class TensorDotProductUnitTest extends AnyFlatSpec with ChiselScalatestTester { c.clock.step() c.io.in.valid.poke(false.B) c.io.out.valid.expect(false.B) + // stall the pipeline - // c.io.stall.poke(true.B) + c.io.stall.poke(true.B) c.clock.step() - // c.io.stall.poke(false.B) - // c.io.out.valid.expect(false.B) - // c.clock.step() - // c.clock.step() - // 4-cycle latency + c.io.stall.poke(true.B) + c.clock.step() + c.io.stall.poke(true.B) + c.clock.step() + c.io.stall.poke(false.B) + + c.clock.step() + c.clock.step() + c.clock.step() + // 4-cycle latency + stalls c.io.out.valid.expect(true.B) c.io.out.bits.data.expect(0x42da0000L.U)