tensor: Do dot-product in fp16, only do accum in fp32

This is to better match Gemmini PEs doing MACs in full fp16, and only
doing accumulation in fp32.
This commit is contained in:
Hansung Kim
2024-11-13 16:01:11 -08:00
parent 8c473f52e3
commit 3b71276c4a
2 changed files with 11 additions and 5 deletions

View File

@@ -147,7 +147,7 @@ class DotProductPipe(dim: Int, inputType: tile.FType, outputType: tile.FType) ex
val mulSigWidth = m.io.rawOut.sigWidth
val roundRawFNToRecFN =
Module(new hardfloat.RoundAnyRawFNToRecFN(
mulExpWidth, mulSigWidth, outExpWidth, outSigWidth, 0))
mulExpWidth, mulSigWidth, expWidth, sigWidth, 0))
roundRawFNToRecFN.io.invalidExc := m.io.invalidExc
roundRawFNToRecFN.io.infiniteExc := false.B
roundRawFNToRecFN.io.in := m.io.rawOut
@@ -169,7 +169,7 @@ class DotProductPipe(dim: Int, inputType: tile.FType, outputType: tile.FType) ex
// instantiate wires for input values to each reduction pipeline stage
val interim = (log2Dim to 0 by -1).map { i =>
Wire(Valid(Vec(1 << i, Bits(recOutFLen.W))))
Wire(Valid(Vec(1 << i, Bits(recInFLen.W))))
}
// instantiate wires for pipe registers for C
val interimC = (log2Dim to 0 by -1).map( _ => Wire(Valid(Bits(recOutFLen.W))) )
@@ -186,7 +186,7 @@ class DotProductPipe(dim: Int, inputType: tile.FType, outputType: tile.FType) ex
require(inputs.bits.length == 2 * outputs.bits.length)
val thisDim = inputs.bits.length
val adders = Seq.fill(thisDim / 2)(
Module(new hardfloat.AddRecFN(outExpWidth, outSigWidth))
Module(new hardfloat.AddRecFN(expWidth, sigWidth))
)
val addOuts = adders.zipWithIndex.map { case (a, i) =>
a.io.subOp := 0.U // FIXME dont know what this is
@@ -212,9 +212,15 @@ class DotProductPipe(dim: Int, inputType: tile.FType, outputType: tile.FType) ex
// add stages end ------------------------------------------------------------
// add final A and B dot-product result to accumulator C
val conv = Module(new hardfloat.RecFNToRecFN(expWidth, sigWidth, outExpWidth, outSigWidth))
conv.io.in := addStageOut.bits(0)
conv.io.roundingMode := hardfloat.consts.round_near_even
conv.io.detectTininess := hardfloat.consts.tininess_afterRounding
// assert(conv.io.exceptionFlags === 0.U)
val acc = Module(new hardfloat.AddRecFN(outExpWidth, outSigWidth))
acc.io.subOp := 0.U // FIXME
acc.io.a := addStageOut.bits(0)
acc.io.a := conv.io.out
acc.io.b := addStageC.bits
acc.io.roundingMode := hardfloat.consts.round_near_even
acc.io.detectTininess := hardfloat.consts.tininess_afterRounding

View File

@@ -9,7 +9,7 @@ class TensorCoreDecoupledTest extends AnyFlatSpec with ChiselScalatestTester {
behavior of "TensorCoreDecoupled"
it should "do the right thing" in {
test(new TensorCoreDecoupled(8, 8, numSourceIds = 4, tilingParams = TensorTilingParams()))
test(new TensorCoreDecoupled(8, 8, numSourceIds = 4, half = true))
{ c =>
c.io.initiate.valid.poke(true.B)
c.io.initiate.bits.wid.poke(0.U)