tensor: Enlarge operand buffer for A for better SMEM reuse

This commit is contained in:
Hansung Kim
2024-10-18 21:51:34 -07:00
parent 93c9bcc32f
commit c0292dd0aa

View File

@@ -146,18 +146,6 @@ class TensorCoreDecoupled(
// Memory traffic generation
// -------------------------
//
class TensorMemTag extends Bundle {
val set = UInt(setBits.W)
val step = UInt(stepBits.W)
val substep = UInt(1.W)
}
// use concatenation of set/step as the memory request source. This will get
// translated to the actual TL sourcewidth in sourceGen.
val tag = Wire(new TensorMemTag)
tag.set := setAccess
tag.step := stepAccess
tag.substep := substepAccess
val numTilesM = tilingParams.m / tilingParams.mc
val numTilesN = tilingParams.n / tilingParams.nc
// @cleanup: generalize in terms of M/N/K-majorness?
@@ -198,12 +186,41 @@ class TensorCoreDecoupled(
val (addressA, addressB) =
addressGen(0.U, 0.U, setAccess, stepAccess, substepAccess)
// 'index' is the index of a memory request among the sequence of requests
// needed to read a full M-column of A or N-row of B. Its range is [0,m/2)
// or [0,n/2), where 2 is the stride can be read in a single request size.
require(tilingParams.m == tilingParams.n,
"currently only supports square SMEM tile")
val numIndices = tilingParams.m / 2
val indexBits = log2Ceil(numIndices)
val lastIndex = (1 << indexBits) - 1
class TensorMemTag extends Bundle {
val set = UInt(setBits.W)
val index = UInt(indexBits.W)
}
val tagInit = Wire(new TensorMemTag)
tagInit.set := 0.U
tagInit.index := 0.U
val tagA = RegInit(tagInit)
val tagB = RegInit(tagInit)
when (io.reqA.fire) {
when (tagA.index === lastIndex.U) {
tagA.set := tagA.set + 1.U
}
tagA.index := tagA.index + 1.U
}
when (io.reqB.fire) {
when (tagB.index === lastIndex.U) {
tagB.set := tagB.set + 1.U
}
tagB.index := tagB.index + 1.U
}
val genReqA = (state === TensorState.run)
val numTilesMBits = log2Ceil(numTilesM)
// generate B request at every 4 steps. B achieves reuse through outer
// product so it doesn't require access at every step
val shouldFireB = (stepAccess & ((1 << numTilesMBits) - 1).U) === 0.U
val genReqB = (state === TensorState.run) && shouldFireB
val genReqB = (state === TensorState.run)
val respATagged = Wire(Decoupled(new TensorMemRespWithTag(dataWidth)))
val respBTagged = Wire(Decoupled(new TensorMemRespWithTag(dataWidth)))
@@ -212,11 +229,11 @@ class TensorCoreDecoupled(
case ((req, (resp, respTagged)), i) => {
val sourceGen = Module(new SourceGenerator(
log2Ceil(numSourceIds),
metadata = Some(tag)
metadata = Some(new TensorMemTag)
))
sourceGen.io.gen := req.fire
sourceGen.io.meta := tag
sourceGen.io.meta := (if (i == 0) tagA else tagB)
req.valid := (if (i == 0) genReqA else genReqB)
req.bits.address := (if (i == 0) addressA else addressB)
req.bits.source := sourceGen.io.id.bits
@@ -243,7 +260,7 @@ class TensorCoreDecoupled(
val firedBNow = io.reqB.fire
val firedA = firedAReg || firedANow
val firedB = firedBReg || firedBNow
val nextSubstepAccess = firedA && (!shouldFireB || firedB)
val nextSubstepAccess = firedA && firedB
val nextStepAccess = nextSubstepAccess && (substepAccess === 1.U)
// clear out firedABReg every substep
when (nextSubstepAccess) {
@@ -252,17 +269,12 @@ class TensorCoreDecoupled(
substepAccess := substepAccess + 1.U
}
require(substepAccess.widthOption.get == 1, "there should be only two substeps")
dontTouch(shouldFireB)
// Execute stage
// -------------
// Backend of the decoupled access/execute pipeline.
//
// set and step being currently executed in the acc/ex backend
val setExecute = RegInit(0.U(setBits.W))
val stepExecute = RegInit(0.U(stepBits.W))
dontTouch(setExecute)
dontTouch(stepExecute)
val respQueueDepth = 4 // FIXME: parameterize
val respQueueA = Queue(respATagged, respQueueDepth)
@@ -283,8 +295,10 @@ class TensorCoreDecoupled(
// ready for compute. Also send the set/step tag along the pipe for
// alignment check.
// @cleanup: dedup A and B below
val fullA = Module(new FillBuffer(
chiselTypeOf(respQueueB.bits.data), 2/*substeps*/
chiselTypeOf(respQueueB.bits.data), numIndices
))
fullA.io.enq.valid := respQueueA.valid
fullA.io.enq.bits := respQueueA.bits.data
@@ -337,23 +351,48 @@ class TensorCoreDecoupled(
fullB.io.deq.ready := fullBBuf.io.enq.ready
fullBTag.io.deq.ready := fullBBuf.io.enq.ready
val operandsValid = fullABuf.io.deq.valid && fullBBuf.io.deq.valid
val operandA = fullABuf.io.deq.bits.data
val operandATag = fullABuf.io.deq.bits.tag
val operandB = fullBBuf.io.deq.bits.data
val dpuReady = Wire(Bool())
val operandsValid = fullABuf.io.deq.valid && fullBBuf.io.deq.valid
val dpuFire = operandsValid && dpuReady
val setCompute = fullABuf.io.deq.bits.tag.set
val stepCompute = fullABuf.io.deq.bits.tag.step
val setCompute = RegInit(0.U(setBits.W))
val stepCompute = RegInit(0.U(stepBits.W))
val substepCompute = RegInit(0.U(1.W))
val nextStepCompute = dpuFire && (substepCompute === 1.U)
dontTouch(setCompute)
dontTouch(stepCompute)
dontTouch(substepCompute)
when (dpuFire) {
substepCompute := substepCompute + 1.U
}
// hold full A until two-cycle compute is done
fullABuf.io.deq.ready := dpuFire && (substepCompute === 1.U)
// Hold B tile at respQueueB for multiple steps for reuse, only dequeue when
// we fully iterated a column (M-dimension).
// Operand selection
//
// select the correct 4x4 tile from A operand buffer
val numTilesMBits = log2Ceil(numTilesM)
def selectOperandA(buf: Vec[UInt]): UInt = {
require(buf.length == numIndices)
val stepM = stepCompute & ((1 << numTilesMBits) - 1).U
Cat(buf((stepM << 1) + 1.U), buf(stepM << 1))
}
val operandA = selectOperandA(fullABuf.io.deq.bits.data)
val operandATag = fullABuf.io.deq.bits.tag
// select the correct 2x4 tile from B operand buffer
val operandB = fullBBuf.io.deq.bits.data(substepCompute)
val operandBTag = fullBBuf.io.deq.bits.tag
dontTouch(operandATag)
dontTouch(operandBTag)
// Operand buffer dequeue logic
//
// hold A data until the entire set is done
val shouldDequeueAMask = ((1 << stepBits) - 1).U
val shouldDequeueA =
((stepCompute & shouldDequeueAMask) === shouldDequeueAMask) &&
(substepCompute === 1.U)
fullABuf.io.deq.ready := dpuFire && shouldDequeueA
// hold B tile at respQueueB for multiple steps for reuse, only dequeue when
// we fully iterated a column (M-dimension)
val shouldDequeueBMask = ((1 << numTilesMBits) - 1).U
val shouldDequeueB =
((stepCompute & shouldDequeueBMask) === shouldDequeueBMask) &&
@@ -361,11 +400,9 @@ class TensorCoreDecoupled(
fullBBuf.io.deq.ready := dpuFire && shouldDequeueB
dontTouch(respQueueA)
dontTouch(respQueueB)
dontTouch(shouldDequeueA)
dontTouch(shouldDequeueB)
// FIXME: this should be nextStepCompute
val nextStepExecute = dpuFire && (substepCompute === 1.U)
// Assert that the DPU is computing with operands of the same set/step. Note
// that the B resp will only have step values multiple of 4 due to reuse.
//
@@ -374,11 +411,9 @@ class TensorCoreDecoupled(
def assertAligned = {
val stepMask = (1 << numTilesMBits).U
when (dpuFire) {
assert((fullABuf.io.deq.bits.tag.set === fullBBuf.io.deq.bits.tag.set) &&
((fullABuf.io.deq.bits.tag.step & stepMask) ===
(fullBBuf.io.deq.bits.tag.step & stepMask)),
"A and B operands are pointing to different set/steps. " ++
"This might indicate memory response coming back out-of-order.")
assert(fullABuf.io.deq.bits.tag.set === fullBBuf.io.deq.bits.tag.set,
"A and B operands are pointing to different sets. " ++
"This might indicate memory response coming back out-of-order.")
}
}
assertAligned
@@ -386,23 +421,24 @@ class TensorCoreDecoupled(
// Dot-product unit
//
// 4x2 four-element DPUs summing up to 32 MACs in total
//
val ncSubstep = tilingParams.nc / 2
require(tilingParams.mc * ncSubstep == numLanes,
"substep tile size doesn't match writeback throughput")
val dpus = Seq.fill(tilingParams.mc)(Seq.fill(ncSubstep)(
Module(new TensorDotProductUnit(half = false))
))
// operandA is 4x4 in K-major
val operandADimensional =
operandA.asUInt.asBools.grouped(wordSizeInBits).map(VecInit(_).asUInt).toSeq
.grouped(4/*k-dim*/).toSeq
// reshape operands for easier routing to DPU
def reshapeByFourWords(x: UInt): Seq[Seq[UInt]] = {
x.asBools.grouped(wordSizeInBits).map(VecInit(_).asUInt).toSeq
.grouped(4/*k-dim*/).toSeq
}
val operandADimensional = reshapeByFourWords(operandA)
require(operandADimensional.length == tilingParams.mc &&
operandADimensional(0).length == tilingParams.kc,
"operand width doesn't agree with tiling parameter")
// select 2x4 subtile out of operandB that is 4x4 in K-major
val operandBDimensional =
operandB(substepCompute).asBools.grouped(wordSizeInBits).map(VecInit(_).asUInt).toSeq
.grouped(4/*k-dim*/).toSeq
require(tilingParams.mc * ncSubstep == numLanes,
"substep tile size doesn't match writeback throughput")
val operandBDimensional = reshapeByFourWords(operandB)
require(operandBDimensional.length == ncSubstep &&
operandBDimensional(0).length == tilingParams.kc,
"operand width doesn't agree with tiling parameter")
@@ -444,12 +480,17 @@ class TensorCoreDecoupled(
// ----------------
// These queues hold metadata needed for writeback in sync with the DPU.
class TensorComputeTag extends Bundle {
val set = UInt(setBits.W)
val step = UInt(stepBits.W)
val substep = UInt(1.W)
}
val queueDepth = 5 // needs to be at least the DPU latency
val tagQueue = Module(new Queue(chiselTypeOf(operandATag), queueDepth))
val tagQueue = Module(new Queue(new TensorComputeTag, queueDepth))
tagQueue.io.enq.valid := dpuFire
// A and B should have the same tags
tagQueue.io.enq.bits := operandATag
// @cleanup: awkward
tagQueue.io.enq.bits.set := setCompute
tagQueue.io.enq.bits.step := stepCompute
tagQueue.io.enq.bits.substep := substepCompute
tagQueue.io.deq.ready := io.writeback.fire
assert(tagQueue.io.enq.ready === true.B,
@@ -490,7 +531,7 @@ class TensorCoreDecoupled(
}
}
sequenceSetStep(setAccess, stepAccess, nextStepAccess)
sequenceSetStep(setExecute, stepExecute, nextStepExecute)
sequenceSetStep(setCompute, stepCompute, nextStepCompute)
switch(state) {
is(TensorState.idle) {