|
|
|
|
@@ -9,7 +9,10 @@ import freechips.rocketchip.tile
|
|
|
|
|
|
|
|
|
|
// Implements the four-element dot product (FEDP) unit in Volta Tensor Cores.
|
|
|
|
|
// `half`: if true, generate fp16 MACs; if false, generate fp32 MACs.
|
|
|
|
|
class TensorDotProductUnit(val half: Boolean) extends Module with tile.HasFPUParameters {
|
|
|
|
|
class TensorDotProductUnit(
|
|
|
|
|
val dim: Int = 4,
|
|
|
|
|
val half: Boolean
|
|
|
|
|
) extends Module with tile.HasFPUParameters {
|
|
|
|
|
val tIn = if (half) tile.FType.H else tile.FType.S
|
|
|
|
|
// output datatype fixed to single-precision
|
|
|
|
|
val tOut = tile.FType.S
|
|
|
|
|
@@ -19,12 +22,11 @@ class TensorDotProductUnit(val half: Boolean) extends Module with tile.HasFPUPar
|
|
|
|
|
val fLen = outFLen // needed for HasFPUParameters
|
|
|
|
|
val minFLen = 16 // fp16
|
|
|
|
|
def xLen = 32
|
|
|
|
|
val dotProductDim = 4
|
|
|
|
|
|
|
|
|
|
val io = IO(new Bundle {
|
|
|
|
|
val in = Flipped(Valid(new Bundle {
|
|
|
|
|
val a = Vec(dotProductDim, Bits((inFLen).W))
|
|
|
|
|
val b = Vec(dotProductDim, Bits((inFLen).W))
|
|
|
|
|
val a = Vec(dim, Bits((inFLen).W))
|
|
|
|
|
val b = Vec(dim, Bits((inFLen).W))
|
|
|
|
|
val c = Bits((outFLen).W) // note C has the out length for accumulation
|
|
|
|
|
}))
|
|
|
|
|
// 'stall' is effectively out.ready, combinationally coupled to in.ready
|
|
|
|
|
@@ -43,7 +45,7 @@ class TensorDotProductUnit(val half: Boolean) extends Module with tile.HasFPUPar
|
|
|
|
|
val in2 = io.in.bits.b.map(x => unbox(recode(x, tag), tag, Some(tIn)))
|
|
|
|
|
val in3 = unbox(recode(io.in.bits.c, S), S, Some(tOut))
|
|
|
|
|
|
|
|
|
|
val dpu = Module(new DotProductPipe(dotProductDim, tIn, tOut))
|
|
|
|
|
val dpu = Module(new DotProductPipe(dim, tIn, tOut))
|
|
|
|
|
dpu.io.in.valid := io.in.valid
|
|
|
|
|
dpu.io.in.bits.a := in1
|
|
|
|
|
dpu.io.in.bits.b := in2
|
|
|
|
|
@@ -101,7 +103,6 @@ object StallingPipe {
|
|
|
|
|
// Computes d = a(0)*b(0) + ... + a(`dim`-1)*b(`dim`-1) + c.
|
|
|
|
|
// Fully pipelined with a fixed latency determined by `dim`.
|
|
|
|
|
class DotProductPipe(dim: Int, inputType: tile.FType, outputType: tile.FType) extends Module {
|
|
|
|
|
require(dim == 4, "DPU currently only supports dimension 4")
|
|
|
|
|
val expWidth = inputType.exp
|
|
|
|
|
val sigWidth = inputType.sig
|
|
|
|
|
val outExpWidth = outputType.exp
|
|
|
|
|
@@ -111,8 +112,8 @@ class DotProductPipe(dim: Int, inputType: tile.FType, outputType: tile.FType) ex
|
|
|
|
|
val recOutFLen = outExpWidth + outSigWidth + 1
|
|
|
|
|
val io = IO(new Bundle {
|
|
|
|
|
val in = Flipped(Valid(new Bundle {
|
|
|
|
|
val a = Vec(4, Bits((recInFLen).W))
|
|
|
|
|
val b = Vec(4, Bits((recInFLen).W))
|
|
|
|
|
val a = Vec(dim, Bits((recInFLen).W))
|
|
|
|
|
val b = Vec(dim, Bits((recInFLen).W))
|
|
|
|
|
val c = Bits((recOutFLen).W)
|
|
|
|
|
// val roundingMode = UInt(3.W)
|
|
|
|
|
// val detectTininess = UInt(1.W)
|
|
|
|
|
@@ -141,6 +142,7 @@ class DotProductPipe(dim: Int, inputType: tile.FType, outputType: tile.FType) ex
|
|
|
|
|
// assert(m.io.invalidExc === false.B)
|
|
|
|
|
|
|
|
|
|
// round fp16*fp16 raw result back to fp32 recoded format
|
|
|
|
|
// @perf: possibly pipeline here for better timing
|
|
|
|
|
val mulExpWidth = m.io.rawOut.expWidth
|
|
|
|
|
val mulSigWidth = m.io.rawOut.sigWidth
|
|
|
|
|
val roundRawFNToRecFN =
|
|
|
|
|
@@ -160,45 +162,65 @@ class DotProductPipe(dim: Int, inputType: tile.FType, outputType: tile.FType) ex
|
|
|
|
|
|
|
|
|
|
// mul stage end -------------------------------------------------------------
|
|
|
|
|
|
|
|
|
|
val add1 = Seq.fill(dim / 2)(Module(new hardfloat.AddRecFN(outExpWidth, outSigWidth)))
|
|
|
|
|
val add1Outs = add1.zipWithIndex.map { case (a, i) =>
|
|
|
|
|
a.io.subOp := 0.U // 0 selects addition (a + b); 1 would compute a - b
|
|
|
|
|
a.io.a := mulStageOut.bits(2 * i + 0)
|
|
|
|
|
a.io.b := mulStageOut.bits(2 * i + 1)
|
|
|
|
|
a.io.roundingMode := hardfloat.consts.round_near_even
|
|
|
|
|
a.io.detectTininess := hardfloat.consts.tininess_afterRounding
|
|
|
|
|
// assert(a.io.exceptionFlags === 0.U)
|
|
|
|
|
a.io.out
|
|
|
|
|
// reduce-add `dim` mul results down to one in a tree reduction
|
|
|
|
|
//
|
|
|
|
|
val log2Dim = log2Ceil(dim)
|
|
|
|
|
require(dim == (1 << log2Dim), s"dim (${dim}) is not power of two!")
|
|
|
|
|
|
|
|
|
|
// instantiate wires for input values to each reduction pipeline stage
|
|
|
|
|
val interim = (log2Dim to 0 by -1).map { i =>
|
|
|
|
|
Wire(Valid(Vec(1 << i, Bits(recOutFLen.W))))
|
|
|
|
|
}
|
|
|
|
|
// instantiate wires for pipe registers for C
|
|
|
|
|
val interimC = (log2Dim to 0 by -1).map( _ => Wire(Valid(Bits(recOutFLen.W))) )
|
|
|
|
|
// connect the first stage inputs
|
|
|
|
|
interim(0) := mulStageOut
|
|
|
|
|
interimC(0) := mulStageC
|
|
|
|
|
|
|
|
|
|
val add1StageOut = StallingPipe(io.stall, mulStageOut.valid, VecInit(add1Outs))
|
|
|
|
|
val add1StageC = StallingPipe(io.stall, mulStageOut.valid, mulStageC.bits)
|
|
|
|
|
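// Walk adjacent stage pairs with `reduce`: each step instantiates the adders that
// sum one stage's values pairwise, pipes the results into the next stage's wires,
// and hands those wires on as the input of the following step.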
|
|
|
|
|
val (addStageOut, addStageC) = (interim zip interimC).reduce {
|
|
|
|
|
(inputsAndC, outputsAndC) => {
|
|
|
|
|
val (inputs, inC) = inputsAndC
|
|
|
|
|
val (outputs, outC) = outputsAndC
|
|
|
|
|
|
|
|
|
|
// add1 stage end ------------------------------------------------------------
|
|
|
|
|
require(inputs.bits.length == 2 * outputs.bits.length)
|
|
|
|
|
val thisDim = inputs.bits.length
|
|
|
|
|
val adders = Seq.fill(thisDim / 2)(
|
|
|
|
|
Module(new hardfloat.AddRecFN(outExpWidth, outSigWidth))
|
|
|
|
|
)
|
|
|
|
|
val addOuts = adders.zipWithIndex.map { case (a, i) =>
|
|
|
|
|
a.io.subOp := 0.U // 0 selects addition (a + b); 1 would compute a - b
|
|
|
|
|
a.io.a := inputs.bits(2 * i + 0)
|
|
|
|
|
a.io.b := inputs.bits(2 * i + 1)
|
|
|
|
|
a.io.roundingMode := hardfloat.consts.round_near_even
|
|
|
|
|
a.io.detectTininess := hardfloat.consts.tininess_afterRounding
|
|
|
|
|
// assert(a.io.exceptionFlags === 0.U)
|
|
|
|
|
a.io.out
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
val add2 = Module(new hardfloat.AddRecFN(outExpWidth, outSigWidth))
|
|
|
|
|
add2.io.subOp := 0.U // 0 selects addition (a + b); 1 would compute a - b
|
|
|
|
|
add2.io.a := add1StageOut.bits(0)
|
|
|
|
|
add2.io.b := add1StageOut.bits(1)
|
|
|
|
|
add2.io.roundingMode := hardfloat.consts.round_near_even
|
|
|
|
|
add2.io.detectTininess := hardfloat.consts.tininess_afterRounding
|
|
|
|
|
// assert(add2.io.exceptionFlags === 0.U)
|
|
|
|
|
// pipeline and connect outputs to the next stage
|
|
|
|
|
outputs := StallingPipe(io.stall, inputs.valid, VecInit(addOuts))
|
|
|
|
|
outC := StallingPipe(io.stall, inputs.valid, inC.bits)
|
|
|
|
|
assert(inputs.valid === inC.valid,
|
|
|
|
|
"adder inputs valid and C pipe valid went out-of-sync")
|
|
|
|
|
|
|
|
|
|
val add2StageOut = StallingPipe(io.stall, add1StageOut.valid, add2.io.out)
|
|
|
|
|
val add2StageC = StallingPipe(io.stall, add1StageOut.valid, add1StageC.bits)
|
|
|
|
|
(outputs, outC)
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
require(addStageOut.bits.length == 1)
|
|
|
|
|
|
|
|
|
|
// add2 stage end ------------------------------------------------------------
|
|
|
|
|
// add stages end ------------------------------------------------------------
|
|
|
|
|
|
|
|
|
|
// add final A and B dot-product result to accumulator C
|
|
|
|
|
val acc = Module(new hardfloat.AddRecFN(outExpWidth, outSigWidth))
|
|
|
|
|
acc.io.subOp := 0.U // 0 selects addition (a + b); 1 would compute a - b
|
|
|
|
|
acc.io.a := add2StageOut.bits
|
|
|
|
|
// acc.io.b := add2StageCRec
|
|
|
|
|
acc.io.b := add2StageC.bits
|
|
|
|
|
acc.io.a := addStageOut.bits(0)
|
|
|
|
|
acc.io.b := addStageC.bits
|
|
|
|
|
acc.io.roundingMode := hardfloat.consts.round_near_even
|
|
|
|
|
acc.io.detectTininess := hardfloat.consts.tininess_afterRounding
|
|
|
|
|
// assert(acc.io.exceptionFlags === 0.U)
|
|
|
|
|
|
|
|
|
|
val accStageOut = StallingPipe(io.stall, add2StageOut.valid, acc.io.out)
|
|
|
|
|
val accStageOut = StallingPipe(io.stall, addStageOut.valid, acc.io.out)
|
|
|
|
|
|
|
|
|
|
// acc stage end -------------------------------------------------------------
|
|
|
|
|
|
|
|
|
|
|