Write four-element DPU without accumulation

This commit is contained in:
Hansung Kim
2024-05-28 18:27:56 -07:00
parent db889c5e22
commit 3b1ab4e10d
2 changed files with 99 additions and 36 deletions

View File

@@ -11,35 +11,98 @@ class DPUPipe extends Module with tile.HasFPUParameters {
val fLen = 32
val minFLen = 32
def xLen = 32
val dotProductDim = 4
val io = IO(new Bundle {
val in = Flipped(Valid(new Bundle {
val a = Bits((fLen).W)
val b = Bits((fLen).W)
val c = Bits((fLen).W)
val a = Vec(dotProductDim, Bits((fLen).W))
val b = Vec(dotProductDim, Bits((fLen).W))
val c = Vec(dotProductDim, Bits((fLen).W))
}))
val out = Valid(new Bundle {
val data = Bits((fLen+1).W)
val data = Bits((fLen).W)
})
})
val t = tile.FType.S
val in1 = recode(io.in.bits.a, S)
val in2 = recode(io.in.bits.b, S)
val in3 = recode(io.in.bits.c, S)
val in1 = io.in.bits.a.map(x => unbox(recode(x, S), S, Some(tile.FType.S)))
val in2 = io.in.bits.b.map(x => unbox(recode(x, S), S, Some(tile.FType.S)))
val in3 = io.in.bits.c.map(x => unbox(recode(x, S), S, Some(tile.FType.S)))
val fma = Module(new MulAddRecFNPipe(2, t.exp, t.sig))
fma.io.validin := io.in.valid
fma.io.op := 0.U // FIXME
fma.io.roundingMode := 0.U // FIXME
fma.io.detectTininess := hardfloat.consts.tininess_afterRounding
fma.io.a := unbox(in1, S, Some(tile.FType.S))
fma.io.b := unbox(in2, S, Some(tile.FType.S))
fma.io.c := unbox(in3, S, Some(tile.FType.S))
// val fma = Module(new MulAddRecFNPipe(2, t.exp, t.sig))
// fma.io.validin := io.in.valid
// fma.io.op := 0.U // FIXME
// fma.io.roundingMode := hardfloat.consts.round_near_even
// fma.io.detectTininess := hardfloat.consts.tininess_afterRounding
// fma.io.a := unbox(in1, S, Some(tile.FType.S))
// fma.io.b := unbox(in2, S, Some(tile.FType.S))
// fma.io.c := unbox(in3, S, Some(tile.FType.S))
io.out.valid := fma.io.validout
io.out.bits.data := ieee(box(fma.io.out, S))
val dpu = Module(new DotProductPipe(dotProductDim, t.exp, t.sig))
dpu.io.in.valid := io.in.valid
dpu.io.in.bits.a := in1
dpu.io.in.bits.b := in2
// FIXME: in3 unused
io.out.valid := dpu.io.out.valid
io.out.bits.data := ieee(box(dpu.io.out.bits.data, S))
}
/** Pipelined dot-product unit operating on hardfloat recoded operands.
  *
  * Computes `a(0)*b(0) + a(1)*b(1) + a(2)*b(2) + a(3)*b(3)` with a tree of
  * `hardfloat.MulRecFN` / `hardfloat.AddRecFN` units separated by three
  * `Pipe` stages (multiply, first add level, second add level).  All units
  * are hard-wired to round-to-nearest-even with tininess detected after
  * rounding.  There is no accumulator input, and the exception flags of the
  * arithmetic units are currently dropped.
  *
  * @param dim      number of element pairs; only 4 is supported for now
  * @param expWidth exponent width passed to the hardfloat units
  * @param sigWidth significand width passed to the hardfloat units
  */
class DotProductPipe(dim: Int, expWidth: Int, sigWidth: Int) extends Module {
  require(dim == 4, "DPU currently only supports dimension 4")

  // Bit width of one recoded operand as used on the a/b/data ports.
  val recFLen = expWidth + sigWidth + 1

  val io = IO(new Bundle {
    val in = Flipped(Valid(new Bundle {
      // Sized by `dim` rather than a literal so the ports track the parameter.
      val a = Vec(dim, Bits((recFLen).W))
      val b = Vec(dim, Bits((recFLen).W))
      // val roundingMode = UInt(3.W)
      // val detectTininess = UInt(1.W)
    }))
    val out = Valid(new Bundle {
      val data = Bits((recFLen).W)
    })
  })

  // Stage 1: element-wise products ----------------------------------------
  val mul = Seq.fill(dim)(Module(new hardfloat.MulRecFN(expWidth, sigWidth)))
  mul.zipWithIndex.foreach { case (m, i) =>
    m.io.roundingMode := hardfloat.consts.round_near_even // consts.round_near_maxMag
    m.io.detectTininess := hardfloat.consts.tininess_afterRounding
    m.io.a := io.in.bits.a(i)
    m.io.b := io.in.bits.b(i)
  }
  val mulStageOut = Pipe(io.in.valid, VecInit(mul.map(_.io.out)))
  // mul stage end -------------------------------------------------------

  // Stage 2: pairwise sums of adjacent products ---------------------------
  val add1 = Seq.fill(dim / 2)(Module(new hardfloat.AddRecFN(expWidth, sigWidth)))
  add1.zipWithIndex.foreach { case (a, i) =>
    a.io.subOp := 0.U // FIXME
    a.io.a := mulStageOut.bits(2 * i + 0)
    a.io.b := mulStageOut.bits(2 * i + 1)
    a.io.roundingMode := hardfloat.consts.round_near_even
    a.io.detectTininess := hardfloat.consts.tininess_afterRounding
  }
  val add1StageOut = Pipe(mulStageOut.valid, VecInit(add1.map(_.io.out)))
  // add stage 1 end -----------------------------------------------------

  // Stage 3: final sum of the two partial sums ----------------------------
  // Elaboration-time sanity check: the final adder takes exactly two inputs.
  require(add1StageOut.bits.length == 2,
    "final adder stage expects exactly two partial sums")
  val add2 = Module(new hardfloat.AddRecFN(expWidth, sigWidth))
  add2.io.subOp := 0.U // FIXME
  add2.io.a := add1StageOut.bits(0)
  add2.io.b := add1StageOut.bits(1)
  add2.io.roundingMode := hardfloat.consts.round_near_even
  add2.io.detectTininess := hardfloat.consts.tininess_afterRounding

  // A single Pipe carries both the valid bit and the data, so they share one
  // stage register instead of two independently instantiated pipes.
  val add2StageOut = Pipe(add1StageOut.valid, add2.io.out)
  io.out.valid := add2StageOut.valid
  io.out.bits.data := add2StageOut.bits
  // FIXME: exception output ignored
  // add stage 2 end -----------------------------------------------------
}
class MulAddRecFNPipe(latency: Int, expWidth: Int, sigWidth: Int) extends Module {

View File

@@ -1,7 +1,6 @@
package radiance.core
import chisel3._
import chisel3.stage.PrintFullStackTraceAnnotation
import chisel3.util._
import chiseltest._
import chiseltest.simulator.VerilatorFlags
@@ -49,25 +48,26 @@ class DPUPipeTest extends AnyFlatSpec with ChiselScalatestTester {
it should "pass" in {
test(new DPUPipe)
// .withAnnotations(Seq(VerilatorBackendAnnotation, WriteFstAnnotation))
// .withAnnotations(Seq(WriteVcdAnnotation))
{ fma =>
fma.io.in.valid.poke(true.B)
fma.io.in.bits.a.poke(0x40000000L.U(64.W))
fma.io.in.bits.b.poke(0x40400000L.U(64.W))
fma.io.in.bits.c.poke(0x3f800000L.U(64.W))
fma.clock.step()
fma.io.in.valid.poke(true.B)
fma.io.in.bits.a.poke(0x40000000L.U(64.W))
fma.io.in.bits.b.poke(0x3f800000L.U(64.W))
fma.io.in.bits.c.poke(0x3f800000L.U(64.W))
fma.clock.step()
fma.io.in.valid.poke(false.B)
fma.io.out.valid.expect(true.B)
fma.io.out.bits.data.expect(0x40e00000L.U)
fma.clock.step()
// pipelined back-to-back response
fma.io.out.valid.expect(true.B)
fma.io.out.bits.data.expect(0x40400000L.U)
{ c =>
  // 2.0f encoded as an IEEE-754 single-precision bit pattern.
  val two = 0x40000000L.U(64.W)
  val dim = 4

  // Issue one dot-product request with every element of a and b set to 2.0.
  c.io.in.valid.poke(true.B)
  (0 until dim).foreach { i =>
    c.io.in.bits.a(i).poke(two)
    c.io.in.bits.b(i).poke(two)
  }
  c.clock.step()

  // Drop valid after one cycle so exactly one request is in flight.
  c.io.in.valid.poke(false.B)
  c.clock.step()
  c.clock.step()

  // The response is checked three clock steps after the request was issued.
  c.io.out.valid.expect(true.B)
  // 4 * (2.0 * 2.0) = 16.0, which is 0x41800000 in single precision.
  // (0x40e00000 is 7.0, the result of the former 2*3+1 FMA test.)
  c.io.out.bits.data.expect(0x41800000L.U)

  // Only one request was issued, so valid must fall on the next cycle.
  c.clock.step()
  c.io.out.valid.expect(false.B)
}
}
}