Implement proper stalls for dpu
This commit is contained in:
@@ -43,6 +43,49 @@ class TensorDotProductUnit extends Module with tile.HasFPUParameters {
|
||||
io.out.bits.data := ieee(box(dpu.io.out.bits.data, S))
|
||||
}
|
||||
|
||||
// Copied from chisel3.util.Pipe.
|
||||
class StallingPipe[T <: Data](val gen: T, val latency: Int = 1) extends Module {
|
||||
/** A non-ambiguous name of this `StallingPipe` for use in generated Verilog
|
||||
* names. Includes the latency cycle count in the name as well as the
|
||||
* parameterized generator's `typeName`, e.g. `Pipe4_UInt4`
|
||||
*/
|
||||
override def desiredName = s"${simpleClassName(this.getClass)}${latency}_${gen.typeName}"
|
||||
|
||||
class StallingPipeIO extends Bundle {
|
||||
val stall = Input(Bool())
|
||||
val enq = Input(Valid(gen))
|
||||
val deq = Output(Valid(gen))
|
||||
}
|
||||
|
||||
val io = IO(new StallingPipeIO)
|
||||
|
||||
io.deq <> StallingPipe(io.stall, io.enq, latency)
|
||||
}
|
||||
|
||||
object StallingPipe {
|
||||
import chisel3.experimental.prefix
|
||||
|
||||
def apply[T <: Data](stall: Bool, enqValid: Bool, enqBits: T, latency: Int): Valid[T] = {
|
||||
require(latency == 1, "StallingPipe only supports latency equals one!")
|
||||
prefix("stalling_pipe") {
|
||||
val out = Wire(Valid(chiselTypeOf(enqBits)))
|
||||
val v = RegEnable(enqValid, false.B, !stall)
|
||||
val b = RegEnable(enqBits, !stall && enqValid)
|
||||
out.valid := v
|
||||
out.bits := b
|
||||
out
|
||||
}
|
||||
}
|
||||
|
||||
def apply[T <: Data](stall: Bool, enqValid: Bool, enqBits: T): Valid[T] = {
|
||||
apply(stall, enqValid, enqBits, 1)
|
||||
}
|
||||
|
||||
def apply[T <: Data](stall: Bool, enq: Valid[T], latency: Int = 1): Valid[T] = {
|
||||
apply(stall, enq.valid, enq.bits, latency)
|
||||
}
|
||||
}
|
||||
|
||||
// Computes d = a(0)*b(0) + ... + a(3)*b(3) + c.
|
||||
// Fully pipelined with a fixed latency of 4 cycles.
|
||||
class DotProductPipe(dim: Int, expWidth: Int, sigWidth: Int) extends Module {
|
||||
@@ -72,8 +115,8 @@ class DotProductPipe(dim: Int, expWidth: Int, sigWidth: Int) extends Module {
|
||||
m.io.b := io.in.bits.b(i)
|
||||
}
|
||||
|
||||
val mulStageOut = Pipe(!io.stall && io.in.valid, VecInit(mul.map(_.io.out)))
|
||||
val mulStageC = Pipe(!io.stall && io.in.valid, io.in.bits.c)
|
||||
val mulStageOut = StallingPipe(io.stall, io.in.valid, VecInit(mul.map(_.io.out)))
|
||||
val mulStageC = StallingPipe(io.stall, io.in.valid, io.in.bits.c)
|
||||
|
||||
// mul stage end -------------------------------------------------------------
|
||||
|
||||
@@ -86,8 +129,8 @@ class DotProductPipe(dim: Int, expWidth: Int, sigWidth: Int) extends Module {
|
||||
a.io.detectTininess := hardfloat.consts.tininess_afterRounding
|
||||
}
|
||||
|
||||
val add1StageOut = Pipe(!io.stall && mulStageOut.valid, VecInit(add1.map(_.io.out)), latency = 0)
|
||||
val add1StageC = Pipe(!io.stall && mulStageOut.valid, mulStageC.bits, latency = 0)
|
||||
val add1StageOut = StallingPipe(io.stall, mulStageOut.valid, VecInit(add1.map(_.io.out)))
|
||||
val add1StageC = StallingPipe(io.stall, mulStageOut.valid, mulStageC.bits)
|
||||
|
||||
// add1 stage end ------------------------------------------------------------
|
||||
|
||||
@@ -99,8 +142,8 @@ class DotProductPipe(dim: Int, expWidth: Int, sigWidth: Int) extends Module {
|
||||
add2.io.roundingMode := hardfloat.consts.round_near_even
|
||||
add2.io.detectTininess := hardfloat.consts.tininess_afterRounding
|
||||
|
||||
val add2StageOut = Pipe(!io.stall && add1StageOut.valid, add2.io.out, latency = 0)
|
||||
val add2StageC = Pipe(!io.stall && add1StageOut.valid, add1StageC.bits, latency = 0)
|
||||
val add2StageOut = StallingPipe(io.stall, add1StageOut.valid, add2.io.out)
|
||||
val add2StageC = StallingPipe(io.stall, add1StageOut.valid, add1StageC.bits)
|
||||
|
||||
// add2 stage end ------------------------------------------------------------
|
||||
|
||||
@@ -111,7 +154,7 @@ class DotProductPipe(dim: Int, expWidth: Int, sigWidth: Int) extends Module {
|
||||
acc.io.roundingMode := hardfloat.consts.round_near_even
|
||||
acc.io.detectTininess := hardfloat.consts.tininess_afterRounding
|
||||
|
||||
val accStageOut = Pipe(!io.stall && add2StageOut.valid, acc.io.out)
|
||||
val accStageOut = StallingPipe(io.stall, add2StageOut.valid, acc.io.out)
|
||||
// FIXME: exception output ignored
|
||||
|
||||
// acc stage end -------------------------------------------------------------
|
||||
|
||||
@@ -48,7 +48,7 @@ class TensorDotProductUnitTest extends AnyFlatSpec with ChiselScalatestTester {
|
||||
|
||||
it should "pass" in {
|
||||
test(new TensorDotProductUnit)
|
||||
.withAnnotations(Seq(VerilatorBackendAnnotation))
|
||||
// .withAnnotations(Seq(VerilatorBackendAnnotation))
|
||||
// .withAnnotations(Seq(WriteVcdAnnotation))
|
||||
{ c =>
|
||||
c.io.in.valid.poke(true.B)
|
||||
@@ -69,14 +69,20 @@ class TensorDotProductUnitTest extends AnyFlatSpec with ChiselScalatestTester {
|
||||
c.clock.step()
|
||||
c.io.in.valid.poke(false.B)
|
||||
c.io.out.valid.expect(false.B)
|
||||
|
||||
// stall the pipeline
|
||||
// c.io.stall.poke(true.B)
|
||||
c.io.stall.poke(true.B)
|
||||
c.clock.step()
|
||||
// c.io.stall.poke(false.B)
|
||||
// c.io.out.valid.expect(false.B)
|
||||
// c.clock.step()
|
||||
// c.clock.step()
|
||||
// 4-cycle latency
|
||||
c.io.stall.poke(true.B)
|
||||
c.clock.step()
|
||||
c.io.stall.poke(true.B)
|
||||
c.clock.step()
|
||||
c.io.stall.poke(false.B)
|
||||
|
||||
c.clock.step()
|
||||
c.clock.step()
|
||||
c.clock.step()
|
||||
// 4-cycle latency + stalls
|
||||
|
||||
c.io.out.valid.expect(true.B)
|
||||
c.io.out.bits.data.expect(0x42da0000L.U)
|
||||
|
||||
Reference in New Issue
Block a user