tensor: Consider data reuse for B memory request

B is reused every 4 steps because of the k->i->j iteration order.
This commit is contained in:
Hansung Kim
2024-10-18 13:46:04 -07:00
parent a2519da58f
commit 64ea48ace3

View File

@@ -145,8 +145,6 @@ class TensorCoreDecoupled(
// Memory traffic generation
// -------------------------
//
val genReq = (state === TensorState.run)
class TensorMemTag extends Bundle {
val set = UInt(setBits.W)
val step = UInt(stepBits.W)
@@ -159,16 +157,14 @@ class TensorCoreDecoupled(
tag.step := stepAccess
tag.substep := substepAccess
val numTilesM = tilingParams.m / tilingParams.mc
val numTilesN = tilingParams.n / tilingParams.nc
// @cleanup: generalize in terms of M/N/K-majorness?
def addressGen(baseA: UInt, baseB: UInt, set: UInt, step: UInt, substep: UInt)
: (UInt/*A*/, UInt/*B*/) = {
// note that step iterates along N first, then M
val numComputeTilesM = tilingParams.m / tilingParams.mc
val numComputeTilesN = tilingParams.n / tilingParams.nc
val tileM = step % numComputeTilesM.U
val tileN = step / numComputeTilesM.U
val mcSubstep = tilingParams.mc / 2
val ncSubstep = tilingParams.nc / 2
val tileM = step % numTilesM.U
val tileN = step / numTilesM.U
// note that both A and B are K-major to facilitate bank conflict-free SMEM
// accesses
@@ -180,11 +176,11 @@ class TensorCoreDecoupled(
val tileColB = set // K
// (row,col) coordinate of the starting element of the compute tile
val elemRowA = (tileRowA << log2Ceil(tilingParams.mc)) +
(substep << log2Ceil(mcSubstep))
val elemColA = tileColA << log2Ceil(tilingParams.kc)
val elemRowB = tileRowB << log2Ceil(tilingParams.nc)
(substep << log2Ceil(ncSubstep))
val elemColB = tileColB << log2Ceil(tilingParams.kc)
(substep << log2Ceil(tilingParams.mc / 2))
val elemColA = tileColA << log2Ceil(tilingParams.kc)
val elemRowB = (tileRowB << log2Ceil(tilingParams.nc)) +
(substep << log2Ceil(tilingParams.nc / 2))
val elemColB = tileColB << log2Ceil(tilingParams.kc)
val rowStrideA = wordSize * tilingParams.k
val rowStrideABits = log2Ceil(rowStrideA)
val rowStrideB = wordSize * tilingParams.k
@@ -201,6 +197,13 @@ class TensorCoreDecoupled(
val (addressA, addressB) =
addressGen(0.U, 0.U, setAccess, stepAccess, substepAccess)
val genReqA = (state === TensorState.run)
val numTilesMBits = log2Ceil(numTilesM)
// generate B request at every 4 steps. B achieves reuse through outer
// product so it doesn't require access at every step
val shouldFireB = (stepAccess & ((1 << numTilesMBits) - 1).U) === 0.U
val genReqB = (state === TensorState.run) && shouldFireB
val respATagged = Wire(Decoupled(new TensorMemRespWithTag(dataWidth)))
val respBTagged = Wire(Decoupled(new TensorMemRespWithTag(dataWidth)))
Seq((io.reqA, (io.respA, respATagged)),
@@ -213,7 +216,7 @@ class TensorCoreDecoupled(
sourceGen.io.gen := req.fire
sourceGen.io.meta := tag
req.valid := genReq
req.valid := (if (i == 0) genReqA else genReqB)
req.bits.address := (if (i == 0) addressA else addressB)
req.bits.source := sourceGen.io.id.bits
@@ -228,23 +231,27 @@ class TensorCoreDecoupled(
}
}
// only advance to the next step if we fired mem requests for both A and B
// TODO: @perf: too strict? should be able to have A and B progress
// separately
val firedABReg = RegInit(VecInit(false.B, false.B))
val firedABNow = VecInit((Seq(io.reqA, io.reqB) zip firedABReg).map {
case (req, fired) => { when (req.fire) { fired := true.B } }
req.fire
})
val firedAB = (firedABNow.asUInt | firedABReg.asUInt)
val nextSubstepAccess = firedAB.andR
// only advance to the next step if we fired mem requests for both A and B.
// also consider that B doesn't have to be fired every time due to reuse.
// @perf: too strict? should be able to have A and B progress separately
val firedAReg = RegInit(false.B)
val firedBReg = RegInit(false.B)
when (io.reqA.fire) { firedAReg := true.B }
when (io.reqB.fire) { firedBReg := true.B }
val firedANow = io.reqA.fire
val firedBNow = io.reqB.fire
val firedA = firedAReg || firedANow
val firedB = firedBReg || firedBNow
val nextSubstepAccess = firedA && (!shouldFireB || firedB)
val nextStepAccess = nextSubstepAccess && (substepAccess === 1.U)
// clear out firedABReg every substep
when (nextSubstepAccess) {
firedABReg := Seq(false.B, false.B)
firedAReg := false.B
firedBReg := false.B
substepAccess := substepAccess + 1.U
}
require(substepAccess.widthOption.get == 1, "there should be only two substeps")
dontTouch(shouldFireB)
// Execute stage
// -------------
@@ -327,18 +334,26 @@ class TensorCoreDecoupled(
respQueueA.ready := MuxCase(false.B,
Seq((substepExecute === 0.U) -> halfAQueue.io.enq.ready,
(substepExecute === 1.U) -> fullAQueue.io.enq.ready))
respQueueB.ready := dpuFire
// Hold B tile at respQueueB for multiple steps for reuse, only dequeue when
// we fully iterated a column (M-dimension).
val shouldDequeueBMask = ((1 << numTilesMBits) - 1).U
val shouldDequeueB = (stepExecute & shouldDequeueBMask) === shouldDequeueBMask
respQueueB.ready := dpuFire && shouldDequeueB
dontTouch(respQueueA)
dontTouch(respQueueB)
dontTouch(shouldDequeueB)
// assert that the DPU is computing with operands of the same set/step
// Assert that the DPU is computing with operands of the same set/step. Note
// that the B resp will only have step values multiple of 4 due to reuse.
//
// this assumes that memory responses come back in-order. this might be too
// strong an assumption depending on the backing memory
// This check assumes that memory responses come back in-order. Might be too
// strong of an assumption depending on the backing memory.
def assertAligned = {
val stepMask = (1 << numTilesMBits).U
when (dpuFire) {
assert((fullAQueue.io.deq.bits.tag.set === respQueueB.bits.tag.set) &&
(fullAQueue.io.deq.bits.tag.step === respQueueB.bits.tag.step),
((fullAQueue.io.deq.bits.tag.step & stepMask) ===
(respQueueB.bits.tag.step & stepMask)),
"A and B operands are pointing to different set/steps. " ++
"This might indicate memory response coming back out-of-order.")
}
@@ -348,26 +363,26 @@ class TensorCoreDecoupled(
// Dot-product unit
//
// 4x2 four-element DPUs summing up to 32 MACs in total
val dpus = Seq.fill(4)(Seq.fill(2)(
val ncSubstep = tilingParams.nc / 2
val dpus = Seq.fill(tilingParams.mc)(Seq.fill(ncSubstep)(
Module(new TensorDotProductUnit(half = false))
))
// operandA is 4x4 in K-major
val operandADimensional =
operandA.asBools.grouped(wordSizeInBits).map(VecInit(_).asUInt).toSeq
.grouped(4).toSeq
assert(operandADimensional.length == tilingParams.mc &&
operandADimensional(0).length == tilingParams.kc,
"operand width doesn't agree with tiling parameter")
// operandB is 2x4, i.e. 4x2 in N-major
.grouped(4/*k-dim*/).toSeq
require(operandADimensional.length == tilingParams.mc &&
operandADimensional(0).length == tilingParams.kc,
"operand width doesn't agree with tiling parameter")
// operandB is 2x4 in K-major
val operandBDimensional =
operandB.asBools.grouped(wordSizeInBits).map(VecInit(_).asUInt).toSeq
.grouped(4).toSeq
val ncSubstep = tilingParams.nc / 2
assert(tilingParams.mc * ncSubstep == numLanes,
"substep tile size doesn't match writeback throughput")
assert(operandBDimensional.length == ncSubstep &&
operandBDimensional(0).length == tilingParams.kc,
"operand width doesn't agree with tiling parameter")
.grouped(4/*k-dim*/).toSeq
require(tilingParams.mc * ncSubstep == numLanes,
"substep tile size doesn't match writeback throughput")
require(operandBDimensional.length == ncSubstep &&
operandBDimensional(0).length == tilingParams.kc,
"operand width doesn't agree with tiling parameter")
for (m <- 0 until tilingParams.mc) {
for (n <- 0 until ncSubstep) {
@@ -406,10 +421,8 @@ class TensorCoreDecoupled(
// ----------------
// These queues hold metadata needed for writeback in sync with the DPU.
val queueDepth = 6 // needs to be at least the DPU latency
val tagQueue = Module(new Queue(
chiselTypeOf(operandATag), queueDepth
))
val queueDepth = 5 // needs to be at least the DPU latency
val tagQueue = Module(new Queue(chiselTypeOf(operandATag), queueDepth))
tagQueue.io.enq.valid := dpuFire
// A and B should have the same tags
tagQueue.io.enq.bits := operandATag
@@ -573,11 +586,11 @@ class TensorCoreDecoupledTwoTLRAM(implicit p: Parameters) extends LazyModule {
val tensor = LazyModule(new TensorCoreDecoupledTL)
val xbar = LazyModule(new TLXbar)
val ramA = LazyModule(new TLRAM(
address = AddressSet(0x000, 0xfffeff),
address = AddressSet(0x000, 0xfffbff),
beatBytes = 32 // @cleanup: hardcoded
))
val ramB = LazyModule(new TLRAM(
address = AddressSet(0x100, 0xfffeff),
address = AddressSet(0x400, 0xfffbff),
beatBytes = 32 // @cleanup: hardcoded
))