tensor: Do dot-product in fp16, only do accum in fp32
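This better matches the Gemmini PEs, which do their MACs in full fp16 and use fp32 only for accumulation: the products and the reduction-tree additions now stay in the fp16 input format, and only the final add into accumulator C is done in fp32.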
This commit is contained in:
@@ -147,7 +147,7 @@ class DotProductPipe(dim: Int, inputType: tile.FType, outputType: tile.FType) ex
|
|||||||
val mulSigWidth = m.io.rawOut.sigWidth
|
val mulSigWidth = m.io.rawOut.sigWidth
|
||||||
val roundRawFNToRecFN =
|
val roundRawFNToRecFN =
|
||||||
Module(new hardfloat.RoundAnyRawFNToRecFN(
|
Module(new hardfloat.RoundAnyRawFNToRecFN(
|
||||||
mulExpWidth, mulSigWidth, outExpWidth, outSigWidth, 0))
|
mulExpWidth, mulSigWidth, expWidth, sigWidth, 0))
|
||||||
roundRawFNToRecFN.io.invalidExc := m.io.invalidExc
|
roundRawFNToRecFN.io.invalidExc := m.io.invalidExc
|
||||||
roundRawFNToRecFN.io.infiniteExc := false.B
|
roundRawFNToRecFN.io.infiniteExc := false.B
|
||||||
roundRawFNToRecFN.io.in := m.io.rawOut
|
roundRawFNToRecFN.io.in := m.io.rawOut
|
||||||
@@ -169,7 +169,7 @@ class DotProductPipe(dim: Int, inputType: tile.FType, outputType: tile.FType) ex
|
|||||||
|
|
||||||
// instantiate wires for input values to each reduction pipeline stage
|
// instantiate wires for input values to each reduction pipeline stage
|
||||||
val interim = (log2Dim to 0 by -1).map { i =>
|
val interim = (log2Dim to 0 by -1).map { i =>
|
||||||
Wire(Valid(Vec(1 << i, Bits(recOutFLen.W))))
|
Wire(Valid(Vec(1 << i, Bits(recInFLen.W))))
|
||||||
}
|
}
|
||||||
// instantiate wires for pipe registers for C
|
// instantiate wires for pipe registers for C
|
||||||
val interimC = (log2Dim to 0 by -1).map( _ => Wire(Valid(Bits(recOutFLen.W))) )
|
val interimC = (log2Dim to 0 by -1).map( _ => Wire(Valid(Bits(recOutFLen.W))) )
|
||||||
@@ -186,7 +186,7 @@ class DotProductPipe(dim: Int, inputType: tile.FType, outputType: tile.FType) ex
|
|||||||
require(inputs.bits.length == 2 * outputs.bits.length)
|
require(inputs.bits.length == 2 * outputs.bits.length)
|
||||||
val thisDim = inputs.bits.length
|
val thisDim = inputs.bits.length
|
||||||
val adders = Seq.fill(thisDim / 2)(
|
val adders = Seq.fill(thisDim / 2)(
|
||||||
Module(new hardfloat.AddRecFN(outExpWidth, outSigWidth))
|
Module(new hardfloat.AddRecFN(expWidth, sigWidth))
|
||||||
)
|
)
|
||||||
val addOuts = adders.zipWithIndex.map { case (a, i) =>
|
val addOuts = adders.zipWithIndex.map { case (a, i) =>
|
||||||
a.io.subOp := 0.U // FIXME dont know what this is
|
a.io.subOp := 0.U // FIXME dont know what this is
|
||||||
@@ -212,9 +212,15 @@ class DotProductPipe(dim: Int, inputType: tile.FType, outputType: tile.FType) ex
|
|||||||
// add stages end ------------------------------------------------------------
|
// add stages end ------------------------------------------------------------
|
||||||
|
|
||||||
// add final A and B dot-product result to accumulator C
|
// add final A and B dot-product result to accumulator C
|
||||||
|
val conv = Module(new hardfloat.RecFNToRecFN(expWidth, sigWidth, outExpWidth, outSigWidth))
|
||||||
|
conv.io.in := addStageOut.bits(0)
|
||||||
|
conv.io.roundingMode := hardfloat.consts.round_near_even
|
||||||
|
conv.io.detectTininess := hardfloat.consts.tininess_afterRounding
|
||||||
|
// assert(conv.io.exceptionFlags === 0.U)
|
||||||
|
|
||||||
val acc = Module(new hardfloat.AddRecFN(outExpWidth, outSigWidth))
|
val acc = Module(new hardfloat.AddRecFN(outExpWidth, outSigWidth))
|
||||||
acc.io.subOp := 0.U // FIXME
|
acc.io.subOp := 0.U // FIXME
|
||||||
acc.io.a := addStageOut.bits(0)
|
acc.io.a := conv.io.out
|
||||||
acc.io.b := addStageC.bits
|
acc.io.b := addStageC.bits
|
||||||
acc.io.roundingMode := hardfloat.consts.round_near_even
|
acc.io.roundingMode := hardfloat.consts.round_near_even
|
||||||
acc.io.detectTininess := hardfloat.consts.tininess_afterRounding
|
acc.io.detectTininess := hardfloat.consts.tininess_afterRounding
|
||||||
|
|||||||
@@ -9,7 +9,7 @@ class TensorCoreDecoupledTest extends AnyFlatSpec with ChiselScalatestTester {
|
|||||||
behavior of "TensorCoreDecoupled"
|
behavior of "TensorCoreDecoupled"
|
||||||
|
|
||||||
it should "do the right thing" in {
|
it should "do the right thing" in {
|
||||||
test(new TensorCoreDecoupled(8, 8, numSourceIds = 4, tilingParams = TensorTilingParams()))
|
test(new TensorCoreDecoupled(8, 8, numSourceIds = 4, half = true))
|
||||||
{ c =>
|
{ c =>
|
||||||
c.io.initiate.valid.poke(true.B)
|
c.io.initiate.valid.poke(true.B)
|
||||||
c.io.initiate.bits.wid.poke(0.U)
|
c.io.initiate.bits.wid.poke(0.U)
|
||||||
|
|||||||
Reference in New Issue
Block a user