Add accumulation to dpu
This commit is contained in:
@@ -17,7 +17,7 @@ class DPUPipe extends Module with tile.HasFPUParameters {
|
||||
val in = Flipped(Valid(new Bundle {
|
||||
val a = Vec(dotProductDim, Bits((fLen).W))
|
||||
val b = Vec(dotProductDim, Bits((fLen).W))
|
||||
val c = Vec(dotProductDim, Bits((fLen).W))
|
||||
val c = Bits((fLen).W)
|
||||
}))
|
||||
val out = Valid(new Bundle {
|
||||
val data = Bits((fLen).W)
|
||||
@@ -28,7 +28,7 @@ class DPUPipe extends Module with tile.HasFPUParameters {
|
||||
|
||||
val in1 = io.in.bits.a.map(x => unbox(recode(x, S), S, Some(tile.FType.S)))
|
||||
val in2 = io.in.bits.b.map(x => unbox(recode(x, S), S, Some(tile.FType.S)))
|
||||
val in3 = io.in.bits.c.map(x => unbox(recode(x, S), S, Some(tile.FType.S)))
|
||||
val in3 = unbox(recode(io.in.bits.c, S), S, Some(tile.FType.S))
|
||||
|
||||
// val fma = Module(new MulAddRecFNPipe(2, t.exp, t.sig))
|
||||
// fma.io.validin := io.in.valid
|
||||
@@ -43,12 +43,14 @@ class DPUPipe extends Module with tile.HasFPUParameters {
|
||||
dpu.io.in.valid := io.in.valid
|
||||
dpu.io.in.bits.a := in1
|
||||
dpu.io.in.bits.b := in2
|
||||
// FIXME: in3 unused
|
||||
dpu.io.in.bits.c := in3
|
||||
|
||||
io.out.valid := dpu.io.out.valid
|
||||
io.out.bits.data := ieee(box(dpu.io.out.bits.data, S))
|
||||
}
|
||||
|
||||
// Computes d = a(0)*b(0) + ... + a(3)*b(3) + c.
|
||||
// Fully pipelined with a fixed latency of 4 cycles.
|
||||
class DotProductPipe(dim: Int, expWidth: Int, sigWidth: Int) extends Module {
|
||||
require(dim == 4, "DPU currently only supports dimension 4")
|
||||
|
||||
@@ -57,6 +59,7 @@ class DotProductPipe(dim: Int, expWidth: Int, sigWidth: Int) extends Module {
|
||||
val in = Flipped(Valid(new Bundle {
|
||||
val a = Vec(4, Bits((recFLen).W))
|
||||
val b = Vec(4, Bits((recFLen).W))
|
||||
val c = Bits((recFLen).W)
|
||||
// val roundingMode = UInt(3.W)
|
||||
// val detectTininess = UInt(1.W)
|
||||
}))
|
||||
@@ -74,6 +77,7 @@ class DotProductPipe(dim: Int, expWidth: Int, sigWidth: Int) extends Module {
|
||||
}
|
||||
|
||||
val mulStageOut = Pipe(io.in.valid, VecInit(mul.map(_.io.out)))
|
||||
val mulStageC = Pipe(io.in.valid, io.in.bits.c)
|
||||
|
||||
// mul stage end -------------------------------------------------------
|
||||
|
||||
@@ -87,8 +91,9 @@ class DotProductPipe(dim: Int, expWidth: Int, sigWidth: Int) extends Module {
|
||||
}
|
||||
|
||||
val add1StageOut = Pipe(mulStageOut.valid, VecInit(add1.map(_.io.out)))
|
||||
val add1StageC = Pipe(mulStageC)
|
||||
|
||||
// add stage 1 end -----------------------------------------------------
|
||||
// add1 stage end -----------------------------------------------------
|
||||
|
||||
val add2 = Module(new hardfloat.AddRecFN(expWidth, sigWidth))
|
||||
add2.io.subOp := 0.U // FIXME
|
||||
@@ -98,11 +103,23 @@ class DotProductPipe(dim: Int, expWidth: Int, sigWidth: Int) extends Module {
|
||||
add2.io.roundingMode := hardfloat.consts.round_near_even
|
||||
add2.io.detectTininess := hardfloat.consts.tininess_afterRounding
|
||||
|
||||
io.out.valid := Pipe(add1StageOut.valid, false.B).valid
|
||||
io.out.bits.data := Pipe(add1StageOut.valid, add2.io.out).bits
|
||||
val add2StageOut = Pipe(add1StageOut.valid, add2.io.out)
|
||||
val add2StageC = Pipe(add1StageC)
|
||||
|
||||
// add2 stage end -----------------------------------------------------
|
||||
|
||||
val acc = Module(new hardfloat.AddRecFN(expWidth, sigWidth))
|
||||
acc.io.subOp := 0.U // FIXME
|
||||
acc.io.a := add2StageOut.bits
|
||||
acc.io.b := add2StageC.bits
|
||||
acc.io.roundingMode := hardfloat.consts.round_near_even
|
||||
acc.io.detectTininess := hardfloat.consts.tininess_afterRounding
|
||||
|
||||
io.out.valid := Pipe(add2StageOut.valid, false.B).valid
|
||||
io.out.bits.data := Pipe(add2StageOut.valid, acc.io.out).bits
|
||||
// FIXME: exception output ignored
|
||||
|
||||
// add stage 2 end -----------------------------------------------------
|
||||
// acc stage end -----------------------------------------------------
|
||||
}
|
||||
|
||||
class MulAddRecFNPipe(latency: Int, expWidth: Int, sigWidth: Int) extends Module {
|
||||
|
||||
@@ -60,13 +60,20 @@ class DPUPipeTest extends AnyFlatSpec with ChiselScalatestTester {
|
||||
c.io.in.bits.b(1).poke(0x40000000L.U(64.W))
|
||||
c.io.in.bits.b(2).poke(0x40000000L.U(64.W))
|
||||
c.io.in.bits.b(3).poke(0x40000000L.U(64.W))
|
||||
c.io.in.bits.c .poke(0x40400000L.U(64.W))
|
||||
|
||||
c.clock.step()
|
||||
c.io.in.valid.poke(false.B)
|
||||
c.clock.step()
|
||||
c.clock.step()
|
||||
c.clock.step()
|
||||
// 4-cycle latency
|
||||
|
||||
c.io.out.valid.expect(true.B)
|
||||
c.io.out.bits.data.expect(0x40e00000L.U)
|
||||
|
||||
c.clock.step()
|
||||
|
||||
c.io.out.valid.expect(false.B)
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user