tensor: Consider data reuse for B memory request

B is reused every 4 steps because of the k->i->j iteration order.
This commit is contained in:
Hansung Kim
2024-10-18 13:46:04 -07:00
parent a2519da58f
commit 64ea48ace3

View File

@@ -145,8 +145,6 @@ class TensorCoreDecoupled(
// Memory traffic generation // Memory traffic generation
// ------------------------- // -------------------------
// //
val genReq = (state === TensorState.run)
class TensorMemTag extends Bundle { class TensorMemTag extends Bundle {
val set = UInt(setBits.W) val set = UInt(setBits.W)
val step = UInt(stepBits.W) val step = UInt(stepBits.W)
@@ -159,16 +157,14 @@ class TensorCoreDecoupled(
tag.step := stepAccess tag.step := stepAccess
tag.substep := substepAccess tag.substep := substepAccess
val numTilesM = tilingParams.m / tilingParams.mc
val numTilesN = tilingParams.n / tilingParams.nc
// @cleanup: generalize in terms of M/N/K-majorness? // @cleanup: generalize in terms of M/N/K-majorness?
def addressGen(baseA: UInt, baseB: UInt, set: UInt, step: UInt, substep: UInt) def addressGen(baseA: UInt, baseB: UInt, set: UInt, step: UInt, substep: UInt)
: (UInt/*A*/, UInt/*B*/) = { : (UInt/*A*/, UInt/*B*/) = {
// note that step iterates along N first, then M // note that step iterates along N first, then M
val numComputeTilesM = tilingParams.m / tilingParams.mc val tileM = step % numTilesM.U
val numComputeTilesN = tilingParams.n / tilingParams.nc val tileN = step / numTilesM.U
val tileM = step % numComputeTilesM.U
val tileN = step / numComputeTilesM.U
val mcSubstep = tilingParams.mc / 2
val ncSubstep = tilingParams.nc / 2
// note that both A and B are K-major to facilitate bank conflict-free SMEM // note that both A and B are K-major to facilitate bank conflict-free SMEM
// accesses // accesses
@@ -180,11 +176,11 @@ class TensorCoreDecoupled(
val tileColB = set // K val tileColB = set // K
// (row,col) coordinate of the starting element of the compute tile // (row,col) coordinate of the starting element of the compute tile
val elemRowA = (tileRowA << log2Ceil(tilingParams.mc)) + val elemRowA = (tileRowA << log2Ceil(tilingParams.mc)) +
(substep << log2Ceil(mcSubstep)) (substep << log2Ceil(tilingParams.mc / 2))
val elemColA = tileColA << log2Ceil(tilingParams.kc) val elemColA = tileColA << log2Ceil(tilingParams.kc)
val elemRowB = tileRowB << log2Ceil(tilingParams.nc) val elemRowB = (tileRowB << log2Ceil(tilingParams.nc)) +
(substep << log2Ceil(ncSubstep)) (substep << log2Ceil(tilingParams.nc / 2))
val elemColB = tileColB << log2Ceil(tilingParams.kc) val elemColB = tileColB << log2Ceil(tilingParams.kc)
val rowStrideA = wordSize * tilingParams.k val rowStrideA = wordSize * tilingParams.k
val rowStrideABits = log2Ceil(rowStrideA) val rowStrideABits = log2Ceil(rowStrideA)
val rowStrideB = wordSize * tilingParams.k val rowStrideB = wordSize * tilingParams.k
@@ -201,6 +197,13 @@ class TensorCoreDecoupled(
val (addressA, addressB) = val (addressA, addressB) =
addressGen(0.U, 0.U, setAccess, stepAccess, substepAccess) addressGen(0.U, 0.U, setAccess, stepAccess, substepAccess)
val genReqA = (state === TensorState.run)
val numTilesMBits = log2Ceil(numTilesM)
// generate a B request only on steps whose low numTilesMBits bits are zero
// (every 4 steps here). B achieves reuse through the outer product, so it
// doesn't require a memory access at every step
val shouldFireB = (stepAccess & ((1 << numTilesMBits) - 1).U) === 0.U
val genReqB = (state === TensorState.run) && shouldFireB
val respATagged = Wire(Decoupled(new TensorMemRespWithTag(dataWidth))) val respATagged = Wire(Decoupled(new TensorMemRespWithTag(dataWidth)))
val respBTagged = Wire(Decoupled(new TensorMemRespWithTag(dataWidth))) val respBTagged = Wire(Decoupled(new TensorMemRespWithTag(dataWidth)))
Seq((io.reqA, (io.respA, respATagged)), Seq((io.reqA, (io.respA, respATagged)),
@@ -213,7 +216,7 @@ class TensorCoreDecoupled(
sourceGen.io.gen := req.fire sourceGen.io.gen := req.fire
sourceGen.io.meta := tag sourceGen.io.meta := tag
req.valid := genReq req.valid := (if (i == 0) genReqA else genReqB)
req.bits.address := (if (i == 0) addressA else addressB) req.bits.address := (if (i == 0) addressA else addressB)
req.bits.source := sourceGen.io.id.bits req.bits.source := sourceGen.io.id.bits
@@ -228,23 +231,27 @@ class TensorCoreDecoupled(
} }
} }
// only advance to the next step if we fired mem requests for both A and B // only advance to the next step if we fired mem requests for both A and B.
// TODO: @perf: too strict? should be able to have A and B progress // also consider that B doesn't have to be fired every time due to reuse.
// separately // @perf: too strict? should be able to have A and B progress separately
val firedABReg = RegInit(VecInit(false.B, false.B)) val firedAReg = RegInit(false.B)
val firedABNow = VecInit((Seq(io.reqA, io.reqB) zip firedABReg).map { val firedBReg = RegInit(false.B)
case (req, fired) => { when (req.fire) { fired := true.B } } when (io.reqA.fire) { firedAReg := true.B }
req.fire when (io.reqB.fire) { firedBReg := true.B }
}) val firedANow = io.reqA.fire
val firedAB = (firedABNow.asUInt | firedABReg.asUInt) val firedBNow = io.reqB.fire
val nextSubstepAccess = firedAB.andR val firedA = firedAReg || firedANow
val firedB = firedBReg || firedBNow
val nextSubstepAccess = firedA && (!shouldFireB || firedB)
val nextStepAccess = nextSubstepAccess && (substepAccess === 1.U) val nextStepAccess = nextSubstepAccess && (substepAccess === 1.U)
// clear out firedABReg every substep // clear out firedABReg every substep
when (nextSubstepAccess) { when (nextSubstepAccess) {
firedABReg := Seq(false.B, false.B) firedAReg := false.B
firedBReg := false.B
substepAccess := substepAccess + 1.U substepAccess := substepAccess + 1.U
} }
require(substepAccess.widthOption.get == 1, "there should be only two substeps") require(substepAccess.widthOption.get == 1, "there should be only two substeps")
dontTouch(shouldFireB)
// Execute stage // Execute stage
// ------------- // -------------
@@ -327,18 +334,26 @@ class TensorCoreDecoupled(
respQueueA.ready := MuxCase(false.B, respQueueA.ready := MuxCase(false.B,
Seq((substepExecute === 0.U) -> halfAQueue.io.enq.ready, Seq((substepExecute === 0.U) -> halfAQueue.io.enq.ready,
(substepExecute === 1.U) -> fullAQueue.io.enq.ready)) (substepExecute === 1.U) -> fullAQueue.io.enq.ready))
respQueueB.ready := dpuFire // Hold B tile at respQueueB for multiple steps for reuse, only dequeue when
// we have fully iterated over a column (M-dimension).
val shouldDequeueBMask = ((1 << numTilesMBits) - 1).U
val shouldDequeueB = (stepExecute & shouldDequeueBMask) === shouldDequeueBMask
respQueueB.ready := dpuFire && shouldDequeueB
dontTouch(respQueueA) dontTouch(respQueueA)
dontTouch(respQueueB) dontTouch(respQueueB)
dontTouch(shouldDequeueB)
// assert that the DPU is computing with operands of the same set/step // Assert that the DPU is computing with operands of the same set/step. Note
// that the B resp will only have step values that are multiples of 4 due to reuse.
// //
// this assumes that memory responses come back in-order. this might be too // This check assumes that memory responses come back in-order. Might be too
// strong an assumption depending on the backing memory // strong of an assumption depending on the backing memory.
def assertAligned = { def assertAligned = {
val stepMask = (1 << numTilesMBits).U
when (dpuFire) { when (dpuFire) {
assert((fullAQueue.io.deq.bits.tag.set === respQueueB.bits.tag.set) && assert((fullAQueue.io.deq.bits.tag.set === respQueueB.bits.tag.set) &&
(fullAQueue.io.deq.bits.tag.step === respQueueB.bits.tag.step), ((fullAQueue.io.deq.bits.tag.step & stepMask) ===
(respQueueB.bits.tag.step & stepMask)),
"A and B operands are pointing to different set/steps. " ++ "A and B operands are pointing to different set/steps. " ++
"This might indicate memory response coming back out-of-order.") "This might indicate memory response coming back out-of-order.")
} }
@@ -348,26 +363,26 @@ class TensorCoreDecoupled(
// Dot-product unit // Dot-product unit
// //
// 4x2 four-element DPUs summing up to 32 MACs in total // 4x2 four-element DPUs summing up to 32 MACs in total
val dpus = Seq.fill(4)(Seq.fill(2)( val ncSubstep = tilingParams.nc / 2
val dpus = Seq.fill(tilingParams.mc)(Seq.fill(ncSubstep)(
Module(new TensorDotProductUnit(half = false)) Module(new TensorDotProductUnit(half = false))
)) ))
// operandA is 4x4 in K-major // operandA is 4x4 in K-major
val operandADimensional = val operandADimensional =
operandA.asBools.grouped(wordSizeInBits).map(VecInit(_).asUInt).toSeq operandA.asBools.grouped(wordSizeInBits).map(VecInit(_).asUInt).toSeq
.grouped(4).toSeq .grouped(4/*k-dim*/).toSeq
assert(operandADimensional.length == tilingParams.mc && require(operandADimensional.length == tilingParams.mc &&
operandADimensional(0).length == tilingParams.kc, operandADimensional(0).length == tilingParams.kc,
"operand width doesn't agree with tiling parameter") "operand width doesn't agree with tiling parameter")
// operandB is 2x4, i.e. 4x2 in N-major // operandB is 2x4 in K-major
val operandBDimensional = val operandBDimensional =
operandB.asBools.grouped(wordSizeInBits).map(VecInit(_).asUInt).toSeq operandB.asBools.grouped(wordSizeInBits).map(VecInit(_).asUInt).toSeq
.grouped(4).toSeq .grouped(4/*k-dim*/).toSeq
val ncSubstep = tilingParams.nc / 2 require(tilingParams.mc * ncSubstep == numLanes,
assert(tilingParams.mc * ncSubstep == numLanes, "substep tile size doesn't match writeback throughput")
"substep tile size doesn't match writeback throughput") require(operandBDimensional.length == ncSubstep &&
assert(operandBDimensional.length == ncSubstep && operandBDimensional(0).length == tilingParams.kc,
operandBDimensional(0).length == tilingParams.kc, "operand width doesn't agree with tiling parameter")
"operand width doesn't agree with tiling parameter")
for (m <- 0 until tilingParams.mc) { for (m <- 0 until tilingParams.mc) {
for (n <- 0 until ncSubstep) { for (n <- 0 until ncSubstep) {
@@ -406,10 +421,8 @@ class TensorCoreDecoupled(
// ---------------- // ----------------
// These queues hold metadata needed for writeback in sync with the DPU. // These queues hold metadata needed for writeback in sync with the DPU.
val queueDepth = 6 // needs to be at least the DPU latency val queueDepth = 5 // needs to be at least the DPU latency
val tagQueue = Module(new Queue( val tagQueue = Module(new Queue(chiselTypeOf(operandATag), queueDepth))
chiselTypeOf(operandATag), queueDepth
))
tagQueue.io.enq.valid := dpuFire tagQueue.io.enq.valid := dpuFire
// A and B should have the same tags // A and B should have the same tags
tagQueue.io.enq.bits := operandATag tagQueue.io.enq.bits := operandATag
@@ -573,11 +586,11 @@ class TensorCoreDecoupledTwoTLRAM(implicit p: Parameters) extends LazyModule {
val tensor = LazyModule(new TensorCoreDecoupledTL) val tensor = LazyModule(new TensorCoreDecoupledTL)
val xbar = LazyModule(new TLXbar) val xbar = LazyModule(new TLXbar)
val ramA = LazyModule(new TLRAM( val ramA = LazyModule(new TLRAM(
address = AddressSet(0x000, 0xfffeff), address = AddressSet(0x000, 0xfffbff),
beatBytes = 32 // @cleanup: hardcoded beatBytes = 32 // @cleanup: hardcoded
)) ))
val ramB = LazyModule(new TLRAM( val ramB = LazyModule(new TLRAM(
address = AddressSet(0x100, 0xfffeff), address = AddressSet(0x400, 0xfffbff),
beatBytes = 32 // @cleanup: hardcoded beatBytes = 32 // @cleanup: hardcoded
)) ))