Write four-element DPU without accumulation
This commit is contained in:
@@ -11,35 +11,98 @@ class DPUPipe extends Module with tile.HasFPUParameters {
|
||||
val fLen = 32
|
||||
val minFLen = 32
|
||||
def xLen = 32
|
||||
val dotProductDim = 4
|
||||
|
||||
val io = IO(new Bundle {
|
||||
val in = Flipped(Valid(new Bundle {
|
||||
val a = Bits((fLen).W)
|
||||
val b = Bits((fLen).W)
|
||||
val c = Bits((fLen).W)
|
||||
val a = Vec(dotProductDim, Bits((fLen).W))
|
||||
val b = Vec(dotProductDim, Bits((fLen).W))
|
||||
val c = Vec(dotProductDim, Bits((fLen).W))
|
||||
}))
|
||||
val out = Valid(new Bundle {
|
||||
val data = Bits((fLen+1).W)
|
||||
val data = Bits((fLen).W)
|
||||
})
|
||||
})
|
||||
|
||||
val t = tile.FType.S
|
||||
|
||||
val in1 = recode(io.in.bits.a, S)
|
||||
val in2 = recode(io.in.bits.b, S)
|
||||
val in3 = recode(io.in.bits.c, S)
|
||||
val in1 = io.in.bits.a.map(x => unbox(recode(x, S), S, Some(tile.FType.S)))
|
||||
val in2 = io.in.bits.b.map(x => unbox(recode(x, S), S, Some(tile.FType.S)))
|
||||
val in3 = io.in.bits.c.map(x => unbox(recode(x, S), S, Some(tile.FType.S)))
|
||||
|
||||
val fma = Module(new MulAddRecFNPipe(2, t.exp, t.sig))
|
||||
fma.io.validin := io.in.valid
|
||||
fma.io.op := 0.U // FIXME
|
||||
fma.io.roundingMode := 0.U // FIXME
|
||||
fma.io.detectTininess := hardfloat.consts.tininess_afterRounding
|
||||
fma.io.a := unbox(in1, S, Some(tile.FType.S))
|
||||
fma.io.b := unbox(in2, S, Some(tile.FType.S))
|
||||
fma.io.c := unbox(in3, S, Some(tile.FType.S))
|
||||
// val fma = Module(new MulAddRecFNPipe(2, t.exp, t.sig))
|
||||
// fma.io.validin := io.in.valid
|
||||
// fma.io.op := 0.U // FIXME
|
||||
// fma.io.roundingMode := hardfloat.consts.round_near_even
|
||||
// fma.io.detectTininess := hardfloat.consts.tininess_afterRounding
|
||||
// fma.io.a := unbox(in1, S, Some(tile.FType.S))
|
||||
// fma.io.b := unbox(in2, S, Some(tile.FType.S))
|
||||
// fma.io.c := unbox(in3, S, Some(tile.FType.S))
|
||||
|
||||
io.out.valid := fma.io.validout
|
||||
io.out.bits.data := ieee(box(fma.io.out, S))
|
||||
val dpu = Module(new DotProductPipe(dotProductDim, t.exp, t.sig))
|
||||
dpu.io.in.valid := io.in.valid
|
||||
dpu.io.in.bits.a := in1
|
||||
dpu.io.in.bits.b := in2
|
||||
// FIXME: in3 unused
|
||||
|
||||
io.out.valid := dpu.io.out.valid
|
||||
io.out.bits.data := ieee(box(dpu.io.out.bits.data, S))
|
||||
}
|
||||
|
||||
/** Pipelined floating-point dot product over recoded (hardfloat "recFN") operands.
  *
  * Computes `sum(a(i) * b(i))` for `i` in `0 until dim` with a three-stage
  * pipeline, each stage registered through a `Pipe`:
  *
  *   1. `dim` parallel multipliers (`hardfloat.MulRecFN`)
  *   2. `dim / 2` pairwise adders (`hardfloat.AddRecFN`)
  *   3. one final adder combining the two partial sums
  *
  * All operators are hard-wired to round-to-nearest-even with tininess
  * detected after rounding. There is no accumulator input, and the
  * exception flags produced by the hardfloat units are not propagated.
  *
  * @param dim      number of elements in each input vector; must be 4
  * @param expWidth exponent width of the floating-point format
  * @param sigWidth significand width of the floating-point format
  */
class DotProductPipe(dim: Int, expWidth: Int, sigWidth: Int) extends Module {
  require(dim == 4, "DPU currently only supports dimension 4")

  // Width of one operand in recoded format: one bit wider than exp + sig.
  val recFLen = expWidth + sigWidth + 1
  val io = IO(new Bundle {
    val in = Flipped(Valid(new Bundle {
      // Sized from `dim` rather than a literal so the IO follows the parameter.
      val a = Vec(dim, Bits((recFLen).W))
      val b = Vec(dim, Bits((recFLen).W))
      // val roundingMode = UInt(3.W)
      // val detectTininess = UInt(1.W)
    }))
    val out = Valid(new Bundle {
      val data = Bits((recFLen).W)
    })
  })

  // Stage 1: element-wise products a(i) * b(i).
  val mul = Seq.fill(dim)(Module(new hardfloat.MulRecFN(expWidth, sigWidth)))
  mul.zipWithIndex.foreach { case (m, i) =>
    m.io.roundingMode := hardfloat.consts.round_near_even // consts.round_near_maxMag
    m.io.detectTininess := hardfloat.consts.tininess_afterRounding
    m.io.a := io.in.bits.a(i)
    m.io.b := io.in.bits.b(i)
  }

  val mulStageOut = Pipe(io.in.valid, VecInit(mul.map(_.io.out)))

  // mul stage end -------------------------------------------------------

  // Stage 2: add adjacent products, (p0 + p1) and (p2 + p3).
  val add1 = Seq.fill(dim / 2)(Module(new hardfloat.AddRecFN(expWidth, sigWidth)))
  add1.zipWithIndex.foreach { case (a, i) =>
    a.io.subOp := false.B // always add; FIXME: subtraction not exposed
    a.io.a := mulStageOut.bits(2 * i + 0)
    a.io.b := mulStageOut.bits(2 * i + 1)
    a.io.roundingMode := hardfloat.consts.round_near_even
    a.io.detectTininess := hardfloat.consts.tininess_afterRounding
  }

  val add1StageOut = Pipe(mulStageOut.valid, VecInit(add1.map(_.io.out)))

  // add stage 1 end -----------------------------------------------------

  // Stage 3: combine the two partial sums into the final result.
  // Elaboration-time check: the single final adder only handles two partial sums.
  require(add1StageOut.bits.length == 2, "final adder expects exactly two partial sums")
  val add2 = Module(new hardfloat.AddRecFN(expWidth, sigWidth))
  add2.io.subOp := false.B // always add; FIXME: subtraction not exposed
  add2.io.a := add1StageOut.bits(0)
  add2.io.b := add1StageOut.bits(1)
  add2.io.roundingMode := hardfloat.consts.round_near_even
  add2.io.detectTininess := hardfloat.consts.tininess_afterRounding

  // One Pipe carries both valid and data, mirroring the earlier stages, so
  // the two outputs share a single register stage instead of two parallel ones.
  val add2StageOut = Pipe(add1StageOut.valid, add2.io.out)
  io.out.valid := add2StageOut.valid
  io.out.bits.data := add2StageOut.bits
  // FIXME: exception output ignored

  // add stage 2 end -----------------------------------------------------
}
|
||||
|
||||
class MulAddRecFNPipe(latency: Int, expWidth: Int, sigWidth: Int) extends Module {
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
package radiance.core
|
||||
|
||||
import chisel3._
|
||||
import chisel3.stage.PrintFullStackTraceAnnotation
|
||||
import chisel3.util._
|
||||
import chiseltest._
|
||||
import chiseltest.simulator.VerilatorFlags
|
||||
@@ -49,25 +48,26 @@ class DPUPipeTest extends AnyFlatSpec with ChiselScalatestTester {
|
||||
|
||||
it should "pass" in {
|
||||
test(new DPUPipe)
|
||||
// .withAnnotations(Seq(VerilatorBackendAnnotation, WriteFstAnnotation))
|
||||
// .withAnnotations(Seq(WriteVcdAnnotation))
|
||||
{ fma =>
|
||||
fma.io.in.valid.poke(true.B)
|
||||
fma.io.in.bits.a.poke(0x40000000L.U(64.W))
|
||||
fma.io.in.bits.b.poke(0x40400000L.U(64.W))
|
||||
fma.io.in.bits.c.poke(0x3f800000L.U(64.W))
|
||||
fma.clock.step()
|
||||
fma.io.in.valid.poke(true.B)
|
||||
fma.io.in.bits.a.poke(0x40000000L.U(64.W))
|
||||
fma.io.in.bits.b.poke(0x3f800000L.U(64.W))
|
||||
fma.io.in.bits.c.poke(0x3f800000L.U(64.W))
|
||||
fma.clock.step()
|
||||
fma.io.in.valid.poke(false.B)
|
||||
fma.io.out.valid.expect(true.B)
|
||||
fma.io.out.bits.data.expect(0x40e00000L.U)
|
||||
fma.clock.step()
|
||||
// pipelined back-to-back response
|
||||
fma.io.out.valid.expect(true.B)
|
||||
fma.io.out.bits.data.expect(0x40400000L.U)
|
||||
{ c =>
  // 2.0f encoded as an IEEE-754 single-precision bit pattern.
  val two = 0x40000000L.U(64.W)

  // Drive one valid request with a = b = [2.0, 2.0, 2.0, 2.0].
  c.io.in.valid.poke(true.B)
  (0 until 4).foreach { i =>
    c.io.in.bits.a(i).poke(two)
    c.io.in.bits.b(i).poke(two)
  }
  c.clock.step()
  c.io.in.valid.poke(false.B)

  // Three steps in total after the request is applied: one per pipeline
  // stage (multiply, first add, second add).
  c.clock.step()
  c.clock.step()
  c.io.out.valid.expect(true.B)
  // 4 * (2.0 * 2.0) = 16.0f = 0x41800000. The previous expectation of
  // 0x40e00000 (7.0f) was a leftover from the FMA test (2 * 3 + 1).
  c.io.out.bits.data.expect(0x41800000L.U)

  // Only a single request was issued, so valid must drop on the next cycle.
  c.clock.step()
  c.io.out.valid.expect(false.B)
}
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user