tensor: Do dot-product in fp16, only do accum in fp32

This is to better match Gemmini PEs doing MACs in full fp16, and only
doing accumulation in fp32.
This commit is contained in:
Hansung Kim
2024-11-13 16:01:11 -08:00
parent 8c473f52e3
commit 3b71276c4a
2 changed files with 11 additions and 5 deletions

View File

@@ -147,7 +147,7 @@ class DotProductPipe(dim: Int, inputType: tile.FType, outputType: tile.FType) ex
val mulSigWidth = m.io.rawOut.sigWidth val mulSigWidth = m.io.rawOut.sigWidth
val roundRawFNToRecFN = val roundRawFNToRecFN =
Module(new hardfloat.RoundAnyRawFNToRecFN( Module(new hardfloat.RoundAnyRawFNToRecFN(
mulExpWidth, mulSigWidth, outExpWidth, outSigWidth, 0)) mulExpWidth, mulSigWidth, expWidth, sigWidth, 0))
roundRawFNToRecFN.io.invalidExc := m.io.invalidExc roundRawFNToRecFN.io.invalidExc := m.io.invalidExc
roundRawFNToRecFN.io.infiniteExc := false.B roundRawFNToRecFN.io.infiniteExc := false.B
roundRawFNToRecFN.io.in := m.io.rawOut roundRawFNToRecFN.io.in := m.io.rawOut
@@ -169,7 +169,7 @@ class DotProductPipe(dim: Int, inputType: tile.FType, outputType: tile.FType) ex
// instantiate wires for input values to each reduction pipeline stage // instantiate wires for input values to each reduction pipeline stage
val interim = (log2Dim to 0 by -1).map { i => val interim = (log2Dim to 0 by -1).map { i =>
Wire(Valid(Vec(1 << i, Bits(recOutFLen.W)))) Wire(Valid(Vec(1 << i, Bits(recInFLen.W))))
} }
// instantiate wires for pipe registers for C // instantiate wires for pipe registers for C
val interimC = (log2Dim to 0 by -1).map( _ => Wire(Valid(Bits(recOutFLen.W))) ) val interimC = (log2Dim to 0 by -1).map( _ => Wire(Valid(Bits(recOutFLen.W))) )
@@ -186,7 +186,7 @@ class DotProductPipe(dim: Int, inputType: tile.FType, outputType: tile.FType) ex
require(inputs.bits.length == 2 * outputs.bits.length) require(inputs.bits.length == 2 * outputs.bits.length)
val thisDim = inputs.bits.length val thisDim = inputs.bits.length
val adders = Seq.fill(thisDim / 2)( val adders = Seq.fill(thisDim / 2)(
Module(new hardfloat.AddRecFN(outExpWidth, outSigWidth)) Module(new hardfloat.AddRecFN(expWidth, sigWidth))
) )
val addOuts = adders.zipWithIndex.map { case (a, i) => val addOuts = adders.zipWithIndex.map { case (a, i) =>
a.io.subOp := 0.U // FIXME dont know what this is a.io.subOp := 0.U // FIXME dont know what this is
@@ -212,9 +212,15 @@ class DotProductPipe(dim: Int, inputType: tile.FType, outputType: tile.FType) ex
// add stages end ------------------------------------------------------------ // add stages end ------------------------------------------------------------
// add final A and B dot-product result to accumulator C // add final A and B dot-product result to accumulator C
val conv = Module(new hardfloat.RecFNToRecFN(expWidth, sigWidth, outExpWidth, outSigWidth))
conv.io.in := addStageOut.bits(0)
conv.io.roundingMode := hardfloat.consts.round_near_even
conv.io.detectTininess := hardfloat.consts.tininess_afterRounding
// assert(conv.io.exceptionFlags === 0.U)
val acc = Module(new hardfloat.AddRecFN(outExpWidth, outSigWidth)) val acc = Module(new hardfloat.AddRecFN(outExpWidth, outSigWidth))
acc.io.subOp := 0.U // FIXME acc.io.subOp := 0.U // FIXME
acc.io.a := addStageOut.bits(0) acc.io.a := conv.io.out
acc.io.b := addStageC.bits acc.io.b := addStageC.bits
acc.io.roundingMode := hardfloat.consts.round_near_even acc.io.roundingMode := hardfloat.consts.round_near_even
acc.io.detectTininess := hardfloat.consts.tininess_afterRounding acc.io.detectTininess := hardfloat.consts.tininess_afterRounding

View File

@@ -9,7 +9,7 @@ class TensorCoreDecoupledTest extends AnyFlatSpec with ChiselScalatestTester {
behavior of "TensorCoreDecoupled" behavior of "TensorCoreDecoupled"
it should "do the right thing" in { it should "do the right thing" in {
test(new TensorCoreDecoupled(8, 8, numSourceIds = 4, tilingParams = TensorTilingParams())) test(new TensorCoreDecoupled(8, 8, numSourceIds = 4, half = true))
{ c => { c =>
c.io.initiate.valid.poke(true.B) c.io.initiate.valid.poke(true.B)
c.io.initiate.bits.wid.poke(0.U) c.io.initiate.bits.wid.poke(0.U)