Support fp16 input, fp32 output in TensorDPU
TODO could see improvement towards handling raw format as much as possible.
This commit is contained in:
@@ -10,22 +10,26 @@ import freechips.rocketchip.tile
|
||||
// Implements the four-element dot product (FEDP) unit in Volta Tensor Cores.
|
||||
// `half`: if True, generate fp16 MACs; if False fp32.
|
||||
class TensorDotProductUnit(val half: Boolean) extends Module with tile.HasFPUParameters {
|
||||
val t = if (half) tile.FType.H else tile.FType.S
|
||||
val tIn = if (half) tile.FType.H else tile.FType.S
|
||||
// output datatype fixed to single-precision
|
||||
val tOut = tile.FType.S
|
||||
|
||||
val fLen = t.ieeeWidth
|
||||
val inFLen = tIn.ieeeWidth
|
||||
val outFLen = tOut.ieeeWidth
|
||||
val fLen = outFLen // needed for HasFPUParameters
|
||||
val minFLen = 16 // fp16
|
||||
def xLen = 32
|
||||
val dotProductDim = 4
|
||||
|
||||
val io = IO(new Bundle {
|
||||
val in = Flipped(Valid(new Bundle {
|
||||
val a = Vec(dotProductDim, Bits((fLen).W))
|
||||
val b = Vec(dotProductDim, Bits((fLen).W))
|
||||
val c = Bits((fLen).W)
|
||||
val a = Vec(dotProductDim, Bits((inFLen).W))
|
||||
val b = Vec(dotProductDim, Bits((inFLen).W))
|
||||
val c = Bits((inFLen).W)
|
||||
}))
|
||||
val stall = Input(Bool())
|
||||
val out = Valid(new Bundle {
|
||||
val data = Bits((fLen).W)
|
||||
val data = Bits((outFLen).W)
|
||||
})
|
||||
})
|
||||
|
||||
@@ -33,11 +37,11 @@ class TensorDotProductUnit(val half: Boolean) extends Module with tile.HasFPUPar
|
||||
// make sure recoding/uncoding happens only at the edge, not at every
|
||||
// pipeline stage inside the dpu
|
||||
val tag = if (half) H else S
|
||||
val in1 = io.in.bits.a.map(x => unbox(recode(x, tag), tag, Some(t)))
|
||||
val in2 = io.in.bits.b.map(x => unbox(recode(x, tag), tag, Some(t)))
|
||||
val in3 = unbox(recode(io.in.bits.c, tag), tag, Some(t))
|
||||
val in1 = io.in.bits.a.map(x => unbox(recode(x, tag), tag, Some(tIn)))
|
||||
val in2 = io.in.bits.b.map(x => unbox(recode(x, tag), tag, Some(tIn)))
|
||||
val in3 = unbox(recode(io.in.bits.c, tag), tag, Some(tIn))
|
||||
|
||||
val dpu = Module(new DotProductPipe(dotProductDim, t.exp, t.sig))
|
||||
val dpu = Module(new DotProductPipe(dotProductDim, tIn, tOut))
|
||||
dpu.io.in.valid := io.in.valid
|
||||
dpu.io.in.bits.a := in1
|
||||
dpu.io.in.bits.b := in2
|
||||
@@ -45,7 +49,7 @@ class TensorDotProductUnit(val half: Boolean) extends Module with tile.HasFPUPar
|
||||
dpu.io.stall := io.stall
|
||||
|
||||
io.out.valid := dpu.io.out.valid
|
||||
io.out.bits.data := ieee(box(dpu.io.out.bits.data, tag))
|
||||
io.out.bits.data := ieee(box(dpu.io.out.bits.data, S))
|
||||
}
|
||||
|
||||
// Copied from chisel3.util.Pipe.
|
||||
@@ -94,74 +98,116 @@ object StallingPipe {
|
||||
|
||||
// Computes d = a(0)*b(0) + ... + a(`dim`-1)*b(`dim`-1) + c.
|
||||
// Fully pipelined with a fixed latency determined by `dim`.
|
||||
class DotProductPipe(dim: Int, expWidth: Int, sigWidth: Int) extends Module {
|
||||
class DotProductPipe(dim: Int, inputType: tile.FType, outputType: tile.FType) extends Module {
|
||||
require(dim == 4, "DPU currently only supports dimension 4")
|
||||
val expWidth = inputType.exp
|
||||
val sigWidth = inputType.sig
|
||||
val outExpWidth = outputType.exp
|
||||
val outSigWidth = outputType.sig
|
||||
|
||||
val recFLen = expWidth + sigWidth + 1
|
||||
val recInFLen = expWidth + sigWidth + 1
|
||||
val recOutFLen = outExpWidth + outSigWidth + 1
|
||||
val io = IO(new Bundle {
|
||||
val in = Flipped(Valid(new Bundle {
|
||||
val a = Vec(4, Bits((recFLen).W))
|
||||
val b = Vec(4, Bits((recFLen).W))
|
||||
val c = Bits((recFLen).W)
|
||||
val a = Vec(4, Bits((recInFLen).W))
|
||||
val b = Vec(4, Bits((recInFLen).W))
|
||||
val c = Bits((recInFLen).W)
|
||||
// val roundingMode = UInt(3.W)
|
||||
// val detectTininess = UInt(1.W)
|
||||
}))
|
||||
val stall = Input(Bool())
|
||||
val out = Valid(new Bundle {
|
||||
val data = Bits((recFLen).W)
|
||||
val data = Bits((recOutFLen).W)
|
||||
})
|
||||
})
|
||||
|
||||
val mul = Seq.fill(dim)(Module(new hardfloat.MulRecFN(expWidth, sigWidth)))
|
||||
mul.zipWithIndex.foreach { case (m, i) =>
|
||||
val mul = Seq.fill(dim)(Module(new hardfloat.MulFullRawFN(expWidth, sigWidth)))
|
||||
val mulOuts = mul.zipWithIndex.map { case (m, i) =>
|
||||
// FIXME: these settings are arbitrary
|
||||
m.io.roundingMode := hardfloat.consts.round_near_even
|
||||
m.io.detectTininess := hardfloat.consts.tininess_afterRounding
|
||||
m.io.a := io.in.bits.a(i)
|
||||
m.io.b := io.in.bits.b(i)
|
||||
// m.io.roundingMode := hardfloat.consts.round_near_even
|
||||
// m.io.detectTininess := hardfloat.consts.tininess_afterRounding
|
||||
// m.io.a := io.in.bits.a(i)
|
||||
// m.io.b := io.in.bits.b(i)
|
||||
val rawInA = hardfloat.rawFloatFromRecFN(expWidth, sigWidth, io.in.bits.a(i))
|
||||
val rawInB = hardfloat.rawFloatFromRecFN(expWidth, sigWidth, io.in.bits.b(i))
|
||||
m.io.a := rawInA
|
||||
m.io.b := rawInB
|
||||
// m.io.invalidExc output ignored
|
||||
// assert(m.io.invalidExc === false.B)
|
||||
}
|
||||
|
||||
val mulStageOut = StallingPipe(io.stall, io.in.valid, VecInit(mul.map(_.io.out)))
|
||||
val mulStageOut = StallingPipe(io.stall, io.in.valid, VecInit(mul.map(_.io.rawOut)))
|
||||
val mulStageC = StallingPipe(io.stall, io.in.valid, io.in.bits.c)
|
||||
|
||||
val mulExpWidth = mulStageOut.bits.head.expWidth
|
||||
val mulSigWidth = mulStageOut.bits.head.sigWidth
|
||||
|
||||
// mul stage end -------------------------------------------------------------
|
||||
|
||||
val add1 = Seq.fill(dim / 2)(Module(new hardfloat.AddRecFN(expWidth, sigWidth)))
|
||||
add1.zipWithIndex.foreach { case (a, i) =>
|
||||
a.io.subOp := 0.U // FIXME
|
||||
val add1 = Seq.fill(dim / 2)(Module(new hardfloat.AddRawFN(mulExpWidth, mulSigWidth)))
|
||||
val add1Outs = add1.zipWithIndex.map { case (a, i) =>
|
||||
a.io.subOp := 0.U // FIXME dont know what this is
|
||||
a.io.a := mulStageOut.bits(2 * i + 0)
|
||||
a.io.b := mulStageOut.bits(2 * i + 1)
|
||||
a.io.roundingMode := hardfloat.consts.round_near_even
|
||||
a.io.detectTininess := hardfloat.consts.tininess_afterRounding
|
||||
// a.io.detectTininess := hardfloat.consts.tininess_afterRounding
|
||||
// a.io.invalidExc output ignored
|
||||
// assert(a.io.invalidExc === false.B)
|
||||
|
||||
// round back to fp32 recoded format
|
||||
// FIXME: awkward to do this in the middle; do right after mul?
|
||||
val addExpWidth = a.io.rawOut.expWidth
|
||||
val addSigWidth = a.io.rawOut.sigWidth
|
||||
val roundRawFNToRecFN =
|
||||
Module(new hardfloat.RoundAnyRawFNToRecFN(addExpWidth, addSigWidth, outExpWidth, outSigWidth, 0))
|
||||
roundRawFNToRecFN.io.invalidExc := a.io.invalidExc
|
||||
roundRawFNToRecFN.io.infiniteExc := false.B
|
||||
roundRawFNToRecFN.io.in := a.io.rawOut
|
||||
roundRawFNToRecFN.io.roundingMode := hardfloat.consts.round_near_even
|
||||
roundRawFNToRecFN.io.detectTininess := hardfloat.consts.tininess_afterRounding
|
||||
roundRawFNToRecFN.io.out
|
||||
// roundRawFNToRecFN.io.exceptionFlags ignored
|
||||
}
|
||||
|
||||
val add1StageOut = StallingPipe(io.stall, mulStageOut.valid, VecInit(add1.map(_.io.out)))
|
||||
// val add1StageOut = StallingPipe(io.stall, mulStageOut.valid, VecInit(add1.map(_.io.out)))
|
||||
val add1StageOut = StallingPipe(io.stall, mulStageOut.valid, VecInit(add1Outs))
|
||||
val add1StageC = StallingPipe(io.stall, mulStageOut.valid, mulStageC.bits)
|
||||
|
||||
// add1 stage end ------------------------------------------------------------
|
||||
|
||||
val add2 = Module(new hardfloat.AddRecFN(expWidth, sigWidth))
|
||||
val add2 = Module(new hardfloat.AddRecFN(outExpWidth, outSigWidth))
|
||||
add2.io.subOp := 0.U // FIXME
|
||||
assert(add1StageOut.bits.length == 2)
|
||||
add2.io.a := add1StageOut.bits(0)
|
||||
add2.io.b := add1StageOut.bits(1)
|
||||
add2.io.roundingMode := hardfloat.consts.round_near_even
|
||||
add2.io.detectTininess := hardfloat.consts.tininess_afterRounding
|
||||
assert(add2.io.exceptionFlags === 0.U)
|
||||
|
||||
val add2StageOut = StallingPipe(io.stall, add1StageOut.valid, add2.io.out)
|
||||
val add2StageC = StallingPipe(io.stall, add1StageOut.valid, add1StageC.bits)
|
||||
|
||||
// add2 stage end ------------------------------------------------------------
|
||||
|
||||
val acc = Module(new hardfloat.AddRecFN(expWidth, sigWidth))
|
||||
// convert to recoded format for addition to C
|
||||
// TODO: raw+raw addition might be cheaper?
|
||||
val recToRec = Module(
|
||||
new hardfloat.RecFNToRecFN(expWidth, sigWidth, outExpWidth, outSigWidth))
|
||||
recToRec.io.in := add2StageC.bits
|
||||
recToRec.io.roundingMode := hardfloat.consts.round_near_even
|
||||
recToRec.io.detectTininess := hardfloat.consts.tininess_afterRounding
|
||||
assert(recToRec.io.exceptionFlags === 0.U)
|
||||
val add2StageCRec = recToRec.io.out
|
||||
|
||||
val acc = Module(new hardfloat.AddRecFN(outExpWidth, outSigWidth))
|
||||
acc.io.subOp := 0.U // FIXME
|
||||
acc.io.a := add2StageOut.bits
|
||||
acc.io.b := add2StageC.bits
|
||||
acc.io.b := add2StageCRec
|
||||
acc.io.roundingMode := hardfloat.consts.round_near_even
|
||||
acc.io.detectTininess := hardfloat.consts.tininess_afterRounding
|
||||
assert(acc.io.exceptionFlags === 0.U)
|
||||
|
||||
val accStageOut = StallingPipe(io.stall, add2StageOut.valid, acc.io.out)
|
||||
// FIXME: exception output ignored
|
||||
|
||||
// acc stage end -------------------------------------------------------------
|
||||
|
||||
|
||||
@@ -85,7 +85,7 @@ class TensorDotProductUnitTest extends AnyFlatSpec with ChiselScalatestTester {
|
||||
// 4-cycle latency + stalls
|
||||
|
||||
c.io.out.valid.expect(true.B)
|
||||
c.io.out.bits.data.expect(0x56d0.U)
|
||||
c.io.out.bits.data.expect(0x42da0000L.U)
|
||||
|
||||
c.clock.step()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user