Make dpu 2-stage

For debugging, need to revert.
This commit is contained in:
Hansung Kim
2024-05-29 13:31:38 -07:00
parent 8dd3994012
commit 4a43d0126d
2 changed files with 28 additions and 22 deletions

View File

@@ -72,8 +72,8 @@ class DotProductPipe(dim: Int, expWidth: Int, sigWidth: Int) extends Module {
m.io.b := io.in.bits.b(i)
}
val mulStageOut = Pipe(io.in.valid, VecInit(mul.map(_.io.out)))
val mulStageC = Pipe(io.in.valid, io.in.bits.c)
val mulStageOut = Pipe(!io.stall && io.in.valid, VecInit(mul.map(_.io.out)))
val mulStageC = Pipe(!io.stall && io.in.valid, io.in.bits.c)
// mul stage end -------------------------------------------------------------
@@ -86,8 +86,8 @@ class DotProductPipe(dim: Int, expWidth: Int, sigWidth: Int) extends Module {
a.io.detectTininess := hardfloat.consts.tininess_afterRounding
}
val add1StageOut = Pipe(mulStageOut.valid, VecInit(add1.map(_.io.out)))
val add1StageC = Pipe(mulStageC)
val add1StageOut = Pipe(!io.stall && mulStageOut.valid, VecInit(add1.map(_.io.out)), latency = 0)
val add1StageC = Pipe(!io.stall && mulStageOut.valid, mulStageC.bits, latency = 0)
// add1 stage end ------------------------------------------------------------
@@ -99,8 +99,8 @@ class DotProductPipe(dim: Int, expWidth: Int, sigWidth: Int) extends Module {
add2.io.roundingMode := hardfloat.consts.round_near_even
add2.io.detectTininess := hardfloat.consts.tininess_afterRounding
val add2StageOut = Pipe(add1StageOut.valid, add2.io.out)
val add2StageC = Pipe(add1StageC)
val add2StageOut = Pipe(!io.stall && add1StageOut.valid, add2.io.out, latency = 0)
val add2StageC = Pipe(!io.stall && add1StageOut.valid, add1StageC.bits, latency = 0)
// add2 stage end ------------------------------------------------------------
@@ -111,11 +111,13 @@ class DotProductPipe(dim: Int, expWidth: Int, sigWidth: Int) extends Module {
acc.io.roundingMode := hardfloat.consts.round_near_even
acc.io.detectTininess := hardfloat.consts.tininess_afterRounding
io.out.valid := Pipe(add2StageOut.valid, false.B).valid
io.out.bits.data := Pipe(add2StageOut.valid, acc.io.out).bits
val accStageOut = Pipe(!io.stall && add2StageOut.valid, acc.io.out)
// FIXME: exception output ignored
// acc stage end -------------------------------------------------------------
io.out.valid := accStageOut.valid
io.out.bits.data := accStageOut.bits
}
class MulAddRecFNPipe(latency: Int, expWidth: Int, sigWidth: Int) extends Module {

View File

@@ -49,33 +49,37 @@ class TensorDotProductUnitTest extends AnyFlatSpec with ChiselScalatestTester {
it should "pass" in {
test(new TensorDotProductUnit)
.withAnnotations(Seq(VerilatorBackendAnnotation))
.withAnnotations(Seq(WriteVcdAnnotation))
// .withAnnotations(Seq(WriteVcdAnnotation))
{ c =>
c.io.in.valid.poke(true.B)
c.io.stall.poke(false.B)
// (2,2,2,2)*(2,2,2,2) + 3 = 19
c.io.in.bits.a(0).poke(0x40000000L.U(64.W))
c.io.in.bits.a(1).poke(0x40000000L.U(64.W))
c.io.in.bits.a(2).poke(0x40000000L.U(64.W))
c.io.in.bits.a(3).poke(0x40000000L.U(64.W))
// (1,3,5,7)*(2,4,6,8) + 9 = 109
c.io.in.bits.a(0).poke(0x3f800000L.U(64.W))
c.io.in.bits.a(1).poke(0x40400000L.U(64.W))
c.io.in.bits.a(2).poke(0x40a00000L.U(64.W))
c.io.in.bits.a(3).poke(0x40e00000L.U(64.W))
c.io.in.bits.b(0).poke(0x40000000L.U(64.W))
c.io.in.bits.b(1).poke(0x40000000L.U(64.W))
c.io.in.bits.b(2).poke(0x40000000L.U(64.W))
c.io.in.bits.b(3).poke(0x40000000L.U(64.W))
c.io.in.bits.c .poke(0x40400000L.U(64.W))
c.io.in.bits.b(1).poke(0x40800000L.U(64.W))
c.io.in.bits.b(2).poke(0x40c00000L.U(64.W))
c.io.in.bits.b(3).poke(0x41000000L.U(64.W))
c.io.in.bits.c .poke(0x41100000L.U(64.W))
c.io.out.valid.expect(false.B)
c.clock.step()
c.io.in.valid.poke(false.B)
c.io.out.valid.expect(false.B)
// stall the pipeline
// c.io.stall.poke(true.B)
c.clock.step()
c.io.stall.poke(false.B)
c.clock.step()
c.clock.step()
// c.io.stall.poke(false.B)
// c.io.out.valid.expect(false.B)
// c.clock.step()
// c.clock.step()
// 4-cycle latency
c.io.out.valid.expect(true.B)
c.io.out.bits.data.expect(0x41980000L.U)
c.io.out.bits.data.expect(0x42da0000L.U)
c.clock.step()