tensor: Do dot-product in fp16, only do accum in fp32
This better matches the Gemmini PEs, which perform their MACs entirely in fp16 and use fp32 only for the final accumulation.
This commit is contained in:
@@ -147,7 +147,7 @@ class DotProductPipe(dim: Int, inputType: tile.FType, outputType: tile.FType) ex
|
||||
val mulSigWidth = m.io.rawOut.sigWidth
|
||||
val roundRawFNToRecFN =
|
||||
Module(new hardfloat.RoundAnyRawFNToRecFN(
|
||||
mulExpWidth, mulSigWidth, outExpWidth, outSigWidth, 0))
|
||||
mulExpWidth, mulSigWidth, expWidth, sigWidth, 0))
|
||||
roundRawFNToRecFN.io.invalidExc := m.io.invalidExc
|
||||
roundRawFNToRecFN.io.infiniteExc := false.B
|
||||
roundRawFNToRecFN.io.in := m.io.rawOut
|
||||
@@ -169,7 +169,7 @@ class DotProductPipe(dim: Int, inputType: tile.FType, outputType: tile.FType) ex
|
||||
|
||||
// instantiate wires for input values to each reduction pipeline stage
|
||||
val interim = (log2Dim to 0 by -1).map { i =>
|
||||
Wire(Valid(Vec(1 << i, Bits(recOutFLen.W))))
|
||||
Wire(Valid(Vec(1 << i, Bits(recInFLen.W))))
|
||||
}
|
||||
// instantiate wires for pipe registers for C
|
||||
val interimC = (log2Dim to 0 by -1).map( _ => Wire(Valid(Bits(recOutFLen.W))) )
|
||||
@@ -186,7 +186,7 @@ class DotProductPipe(dim: Int, inputType: tile.FType, outputType: tile.FType) ex
|
||||
require(inputs.bits.length == 2 * outputs.bits.length)
|
||||
val thisDim = inputs.bits.length
|
||||
val adders = Seq.fill(thisDim / 2)(
|
||||
Module(new hardfloat.AddRecFN(outExpWidth, outSigWidth))
|
||||
Module(new hardfloat.AddRecFN(expWidth, sigWidth))
|
||||
)
|
||||
val addOuts = adders.zipWithIndex.map { case (a, i) =>
|
||||
a.io.subOp := 0.U // FIXME dont know what this is
|
||||
@@ -212,9 +212,15 @@ class DotProductPipe(dim: Int, inputType: tile.FType, outputType: tile.FType) ex
|
||||
// add stages end ------------------------------------------------------------
|
||||
|
||||
// add final A and B dot-product result to accumulator C
|
||||
val conv = Module(new hardfloat.RecFNToRecFN(expWidth, sigWidth, outExpWidth, outSigWidth))
|
||||
conv.io.in := addStageOut.bits(0)
|
||||
conv.io.roundingMode := hardfloat.consts.round_near_even
|
||||
conv.io.detectTininess := hardfloat.consts.tininess_afterRounding
|
||||
// assert(conv.io.exceptionFlags === 0.U)
|
||||
|
||||
val acc = Module(new hardfloat.AddRecFN(outExpWidth, outSigWidth))
|
||||
acc.io.subOp := 0.U // FIXME
|
||||
acc.io.a := addStageOut.bits(0)
|
||||
acc.io.a := conv.io.out
|
||||
acc.io.b := addStageC.bits
|
||||
acc.io.roundingMode := hardfloat.consts.round_near_even
|
||||
acc.io.detectTininess := hardfloat.consts.tininess_afterRounding
|
||||
|
||||
@@ -9,7 +9,7 @@ class TensorCoreDecoupledTest extends AnyFlatSpec with ChiselScalatestTester {
|
||||
behavior of "TensorCoreDecoupled"
|
||||
|
||||
it should "do the right thing" in {
|
||||
test(new TensorCoreDecoupled(8, 8, numSourceIds = 4, tilingParams = TensorTilingParams()))
|
||||
test(new TensorCoreDecoupled(8, 8, numSourceIds = 4, half = true))
|
||||
{ c =>
|
||||
c.io.initiate.valid.poke(true.B)
|
||||
c.io.initiate.bits.wid.poke(0.U)
|
||||
|
||||
Reference in New Issue
Block a user