tensor: Consider data reuse for B memory request
B is reused every 4 steps because of the k->i->j iteration order.
This commit is contained in:
@@ -145,8 +145,6 @@ class TensorCoreDecoupled(
|
||||
// Memory traffic generation
|
||||
// -------------------------
|
||||
//
|
||||
val genReq = (state === TensorState.run)
|
||||
|
||||
class TensorMemTag extends Bundle {
|
||||
val set = UInt(setBits.W)
|
||||
val step = UInt(stepBits.W)
|
||||
@@ -159,16 +157,14 @@ class TensorCoreDecoupled(
|
||||
tag.step := stepAccess
|
||||
tag.substep := substepAccess
|
||||
|
||||
val numTilesM = tilingParams.m / tilingParams.mc
|
||||
val numTilesN = tilingParams.n / tilingParams.nc
|
||||
// @cleanup: generalize in terms of M/N/K-majorness?
|
||||
def addressGen(baseA: UInt, baseB: UInt, set: UInt, step: UInt, substep: UInt)
|
||||
: (UInt/*A*/, UInt/*B*/) = {
|
||||
// note that step iterates along N first, then M
|
||||
val numComputeTilesM = tilingParams.m / tilingParams.mc
|
||||
val numComputeTilesN = tilingParams.n / tilingParams.nc
|
||||
val tileM = step % numComputeTilesM.U
|
||||
val tileN = step / numComputeTilesM.U
|
||||
val mcSubstep = tilingParams.mc / 2
|
||||
val ncSubstep = tilingParams.nc / 2
|
||||
val tileM = step % numTilesM.U
|
||||
val tileN = step / numTilesM.U
|
||||
|
||||
// note that both A and B are K-major to facilitate bank conflict-free SMEM
|
||||
// accesses
|
||||
@@ -180,11 +176,11 @@ class TensorCoreDecoupled(
|
||||
val tileColB = set // K
|
||||
// (row,col) coordinate of the starting element of the compute tile
|
||||
val elemRowA = (tileRowA << log2Ceil(tilingParams.mc)) +
|
||||
(substep << log2Ceil(mcSubstep))
|
||||
val elemColA = tileColA << log2Ceil(tilingParams.kc)
|
||||
val elemRowB = tileRowB << log2Ceil(tilingParams.nc)
|
||||
(substep << log2Ceil(ncSubstep))
|
||||
val elemColB = tileColB << log2Ceil(tilingParams.kc)
|
||||
(substep << log2Ceil(tilingParams.mc / 2))
|
||||
val elemColA = tileColA << log2Ceil(tilingParams.kc)
|
||||
val elemRowB = (tileRowB << log2Ceil(tilingParams.nc)) +
|
||||
(substep << log2Ceil(tilingParams.nc / 2))
|
||||
val elemColB = tileColB << log2Ceil(tilingParams.kc)
|
||||
val rowStrideA = wordSize * tilingParams.k
|
||||
val rowStrideABits = log2Ceil(rowStrideA)
|
||||
val rowStrideB = wordSize * tilingParams.k
|
||||
@@ -201,6 +197,13 @@ class TensorCoreDecoupled(
|
||||
val (addressA, addressB) =
|
||||
addressGen(0.U, 0.U, setAccess, stepAccess, substepAccess)
|
||||
|
||||
val genReqA = (state === TensorState.run)
|
||||
val numTilesMBits = log2Ceil(numTilesM)
|
||||
// generate B request at every 4 steps. B achieves reuse through outer
|
||||
// product so it doesn't require access at every step
|
||||
val shouldFireB = (stepAccess & ((1 << numTilesMBits) - 1).U) === 0.U
|
||||
val genReqB = (state === TensorState.run) && shouldFireB
|
||||
|
||||
val respATagged = Wire(Decoupled(new TensorMemRespWithTag(dataWidth)))
|
||||
val respBTagged = Wire(Decoupled(new TensorMemRespWithTag(dataWidth)))
|
||||
Seq((io.reqA, (io.respA, respATagged)),
|
||||
@@ -213,7 +216,7 @@ class TensorCoreDecoupled(
|
||||
|
||||
sourceGen.io.gen := req.fire
|
||||
sourceGen.io.meta := tag
|
||||
req.valid := genReq
|
||||
req.valid := (if (i == 0) genReqA else genReqB)
|
||||
req.bits.address := (if (i == 0) addressA else addressB)
|
||||
req.bits.source := sourceGen.io.id.bits
|
||||
|
||||
@@ -228,23 +231,27 @@ class TensorCoreDecoupled(
|
||||
}
|
||||
}
|
||||
|
||||
// only advance to the next step if we fired mem requests for both A and B
|
||||
// TODO: @perf: too strict? should be able to have A and B progress
|
||||
// separately
|
||||
val firedABReg = RegInit(VecInit(false.B, false.B))
|
||||
val firedABNow = VecInit((Seq(io.reqA, io.reqB) zip firedABReg).map {
|
||||
case (req, fired) => { when (req.fire) { fired := true.B } }
|
||||
req.fire
|
||||
})
|
||||
val firedAB = (firedABNow.asUInt | firedABReg.asUInt)
|
||||
val nextSubstepAccess = firedAB.andR
|
||||
// only advance to the next step if we fired mem requests for both A and B.
|
||||
// also consider that B doesn't have to be fired every time due to reuse.
|
||||
// @perf: too strict? should be able to have A and B progress separately
|
||||
val firedAReg = RegInit(false.B)
|
||||
val firedBReg = RegInit(false.B)
|
||||
when (io.reqA.fire) { firedAReg := true.B }
|
||||
when (io.reqB.fire) { firedBReg := true.B }
|
||||
val firedANow = io.reqA.fire
|
||||
val firedBNow = io.reqB.fire
|
||||
val firedA = firedAReg || firedANow
|
||||
val firedB = firedBReg || firedBNow
|
||||
val nextSubstepAccess = firedA && (!shouldFireB || firedB)
|
||||
val nextStepAccess = nextSubstepAccess && (substepAccess === 1.U)
|
||||
// clear out firedABReg every substep
|
||||
when (nextSubstepAccess) {
|
||||
firedABReg := Seq(false.B, false.B)
|
||||
firedAReg := false.B
|
||||
firedBReg := false.B
|
||||
substepAccess := substepAccess + 1.U
|
||||
}
|
||||
require(substepAccess.widthOption.get == 1, "there should be only two substeps")
|
||||
dontTouch(shouldFireB)
|
||||
|
||||
// Execute stage
|
||||
// -------------
|
||||
@@ -327,18 +334,26 @@ class TensorCoreDecoupled(
|
||||
respQueueA.ready := MuxCase(false.B,
|
||||
Seq((substepExecute === 0.U) -> halfAQueue.io.enq.ready,
|
||||
(substepExecute === 1.U) -> fullAQueue.io.enq.ready))
|
||||
respQueueB.ready := dpuFire
|
||||
// Hold B tile at respQueueB for multiple steps for reuse, only dequeue when
|
||||
// we fully iterated a column (M-dimension).
|
||||
val shouldDequeueBMask = ((1 << numTilesMBits) - 1).U
|
||||
val shouldDequeueB = (stepExecute & shouldDequeueBMask) === shouldDequeueBMask
|
||||
respQueueB.ready := dpuFire && shouldDequeueB
|
||||
dontTouch(respQueueA)
|
||||
dontTouch(respQueueB)
|
||||
dontTouch(shouldDequeueB)
|
||||
|
||||
// assert that the DPU is computing with operands of the same set/step
|
||||
// Assert that the DPU is computing with operands of the same set/step. Note
|
||||
// that the B resp will only have step values multiple of 4 due to reuse.
|
||||
//
|
||||
// this assumes that memory responses come back in-order. this might be too
|
||||
// strong an assumption depending on the backing memory
|
||||
// This check assumes that memory responses come back in-order. Might be too
|
||||
// strong of an assumption depending on the backing memory.
|
||||
def assertAligned = {
|
||||
val stepMask = (1 << numTilesMBits).U
|
||||
when (dpuFire) {
|
||||
assert((fullAQueue.io.deq.bits.tag.set === respQueueB.bits.tag.set) &&
|
||||
(fullAQueue.io.deq.bits.tag.step === respQueueB.bits.tag.step),
|
||||
((fullAQueue.io.deq.bits.tag.step & stepMask) ===
|
||||
(respQueueB.bits.tag.step & stepMask)),
|
||||
"A and B operands are pointing to different set/steps. " ++
|
||||
"This might indicate memory response coming back out-of-order.")
|
||||
}
|
||||
@@ -348,26 +363,26 @@ class TensorCoreDecoupled(
|
||||
// Dot-product unit
|
||||
//
|
||||
// 4x2 four-element DPUs summing up to 32 MACs in total
|
||||
val dpus = Seq.fill(4)(Seq.fill(2)(
|
||||
val ncSubstep = tilingParams.nc / 2
|
||||
val dpus = Seq.fill(tilingParams.mc)(Seq.fill(ncSubstep)(
|
||||
Module(new TensorDotProductUnit(half = false))
|
||||
))
|
||||
// operandA is 4x4 in K-major
|
||||
val operandADimensional =
|
||||
operandA.asBools.grouped(wordSizeInBits).map(VecInit(_).asUInt).toSeq
|
||||
.grouped(4).toSeq
|
||||
assert(operandADimensional.length == tilingParams.mc &&
|
||||
operandADimensional(0).length == tilingParams.kc,
|
||||
"operand width doesn't agree with tiling parameter")
|
||||
// operandB is 2x4, i.e. 4x2 in N-major
|
||||
.grouped(4/*k-dim*/).toSeq
|
||||
require(operandADimensional.length == tilingParams.mc &&
|
||||
operandADimensional(0).length == tilingParams.kc,
|
||||
"operand width doesn't agree with tiling parameter")
|
||||
// operandB is 2x4 in K-major
|
||||
val operandBDimensional =
|
||||
operandB.asBools.grouped(wordSizeInBits).map(VecInit(_).asUInt).toSeq
|
||||
.grouped(4).toSeq
|
||||
val ncSubstep = tilingParams.nc / 2
|
||||
assert(tilingParams.mc * ncSubstep == numLanes,
|
||||
"substep tile size doesn't match writeback throughput")
|
||||
assert(operandBDimensional.length == ncSubstep &&
|
||||
operandBDimensional(0).length == tilingParams.kc,
|
||||
"operand width doesn't agree with tiling parameter")
|
||||
.grouped(4/*k-dim*/).toSeq
|
||||
require(tilingParams.mc * ncSubstep == numLanes,
|
||||
"substep tile size doesn't match writeback throughput")
|
||||
require(operandBDimensional.length == ncSubstep &&
|
||||
operandBDimensional(0).length == tilingParams.kc,
|
||||
"operand width doesn't agree with tiling parameter")
|
||||
|
||||
for (m <- 0 until tilingParams.mc) {
|
||||
for (n <- 0 until ncSubstep) {
|
||||
@@ -406,10 +421,8 @@ class TensorCoreDecoupled(
|
||||
// ----------------
|
||||
// These queues hold metadata needed for writeback in sync with the DPU.
|
||||
|
||||
val queueDepth = 6 // needs to be at least the DPU latency
|
||||
val tagQueue = Module(new Queue(
|
||||
chiselTypeOf(operandATag), queueDepth
|
||||
))
|
||||
val queueDepth = 5 // needs to be at least the DPU latency
|
||||
val tagQueue = Module(new Queue(chiselTypeOf(operandATag), queueDepth))
|
||||
tagQueue.io.enq.valid := dpuFire
|
||||
// A and B should have the same tags
|
||||
tagQueue.io.enq.bits := operandATag
|
||||
@@ -573,11 +586,11 @@ class TensorCoreDecoupledTwoTLRAM(implicit p: Parameters) extends LazyModule {
|
||||
val tensor = LazyModule(new TensorCoreDecoupledTL)
|
||||
val xbar = LazyModule(new TLXbar)
|
||||
val ramA = LazyModule(new TLRAM(
|
||||
address = AddressSet(0x000, 0xfffeff),
|
||||
address = AddressSet(0x000, 0xfffbff),
|
||||
beatBytes = 32 // @cleanup: hardcoded
|
||||
))
|
||||
val ramB = LazyModule(new TLRAM(
|
||||
address = AddressSet(0x100, 0xfffeff),
|
||||
address = AddressSet(0x400, 0xfffbff),
|
||||
beatBytes = 32 // @cleanup: hardcoded
|
||||
))
|
||||
|
||||
|
||||
Reference in New Issue
Block a user