tensor: Enlarge operand buffer for A for better SMEM reuse

This commit is contained in:
Hansung Kim
2024-10-18 21:51:34 -07:00
parent 93c9bcc32f
commit c0292dd0aa

View File

@@ -146,18 +146,6 @@ class TensorCoreDecoupled(
// Memory traffic generation // Memory traffic generation
// ------------------------- // -------------------------
// //
class TensorMemTag extends Bundle {
val set = UInt(setBits.W)
val step = UInt(stepBits.W)
val substep = UInt(1.W)
}
// use concatenation of set/step as the memory request source. This will get
// translated to the actual TL sourcewidth in sourceGen.
val tag = Wire(new TensorMemTag)
tag.set := setAccess
tag.step := stepAccess
tag.substep := substepAccess
val numTilesM = tilingParams.m / tilingParams.mc val numTilesM = tilingParams.m / tilingParams.mc
val numTilesN = tilingParams.n / tilingParams.nc val numTilesN = tilingParams.n / tilingParams.nc
// @cleanup: generalize in terms of M/N/K-majorness? // @cleanup: generalize in terms of M/N/K-majorness?
@@ -198,12 +186,41 @@ class TensorCoreDecoupled(
val (addressA, addressB) = val (addressA, addressB) =
addressGen(0.U, 0.U, setAccess, stepAccess, substepAccess) addressGen(0.U, 0.U, setAccess, stepAccess, substepAccess)
// 'index' is the index of a memory request among the sequence of requests
// needed to read a full M-column of A or N-row of B. Its range is [0,m/2)
// or [0,n/2), where 2 is the stride that can be read by a single request.
require(tilingParams.m == tilingParams.n,
"currently only supports square SMEM tile")
val numIndices = tilingParams.m / 2
val indexBits = log2Ceil(numIndices)
val lastIndex = (1 << indexBits) - 1
class TensorMemTag extends Bundle {
val set = UInt(setBits.W)
val index = UInt(indexBits.W)
}
val tagInit = Wire(new TensorMemTag)
tagInit.set := 0.U
tagInit.index := 0.U
val tagA = RegInit(tagInit)
val tagB = RegInit(tagInit)
when (io.reqA.fire) {
when (tagA.index === lastIndex.U) {
tagA.set := tagA.set + 1.U
}
tagA.index := tagA.index + 1.U
}
when (io.reqB.fire) {
when (tagB.index === lastIndex.U) {
tagB.set := tagB.set + 1.U
}
tagB.index := tagB.index + 1.U
}
val genReqA = (state === TensorState.run) val genReqA = (state === TensorState.run)
val numTilesMBits = log2Ceil(numTilesM) val genReqB = (state === TensorState.run)
// generate B request at every 4 steps. B achieves reuse through outer
// product so it doesn't require access at every step
val shouldFireB = (stepAccess & ((1 << numTilesMBits) - 1).U) === 0.U
val genReqB = (state === TensorState.run) && shouldFireB
val respATagged = Wire(Decoupled(new TensorMemRespWithTag(dataWidth))) val respATagged = Wire(Decoupled(new TensorMemRespWithTag(dataWidth)))
val respBTagged = Wire(Decoupled(new TensorMemRespWithTag(dataWidth))) val respBTagged = Wire(Decoupled(new TensorMemRespWithTag(dataWidth)))
@@ -212,11 +229,11 @@ class TensorCoreDecoupled(
case ((req, (resp, respTagged)), i) => { case ((req, (resp, respTagged)), i) => {
val sourceGen = Module(new SourceGenerator( val sourceGen = Module(new SourceGenerator(
log2Ceil(numSourceIds), log2Ceil(numSourceIds),
metadata = Some(tag) metadata = Some(new TensorMemTag)
)) ))
sourceGen.io.gen := req.fire sourceGen.io.gen := req.fire
sourceGen.io.meta := tag sourceGen.io.meta := (if (i == 0) tagA else tagB)
req.valid := (if (i == 0) genReqA else genReqB) req.valid := (if (i == 0) genReqA else genReqB)
req.bits.address := (if (i == 0) addressA else addressB) req.bits.address := (if (i == 0) addressA else addressB)
req.bits.source := sourceGen.io.id.bits req.bits.source := sourceGen.io.id.bits
@@ -243,7 +260,7 @@ class TensorCoreDecoupled(
val firedBNow = io.reqB.fire val firedBNow = io.reqB.fire
val firedA = firedAReg || firedANow val firedA = firedAReg || firedANow
val firedB = firedBReg || firedBNow val firedB = firedBReg || firedBNow
val nextSubstepAccess = firedA && (!shouldFireB || firedB) val nextSubstepAccess = firedA && firedB
val nextStepAccess = nextSubstepAccess && (substepAccess === 1.U) val nextStepAccess = nextSubstepAccess && (substepAccess === 1.U)
// clear out firedABReg every substep // clear out firedABReg every substep
when (nextSubstepAccess) { when (nextSubstepAccess) {
@@ -252,17 +269,12 @@ class TensorCoreDecoupled(
substepAccess := substepAccess + 1.U substepAccess := substepAccess + 1.U
} }
require(substepAccess.widthOption.get == 1, "there should be only two substeps") require(substepAccess.widthOption.get == 1, "there should be only two substeps")
dontTouch(shouldFireB)
// Execute stage // Execute stage
// ------------- // -------------
// Backend of the decoupled access/execute pipeline. // Backend of the decoupled access/execute pipeline.
// //
// set and step being currently executed in the acc/ex backend // set and step being currently executed in the acc/ex backend
val setExecute = RegInit(0.U(setBits.W))
val stepExecute = RegInit(0.U(stepBits.W))
dontTouch(setExecute)
dontTouch(stepExecute)
val respQueueDepth = 4 // FIXME: parameterize val respQueueDepth = 4 // FIXME: parameterize
val respQueueA = Queue(respATagged, respQueueDepth) val respQueueA = Queue(respATagged, respQueueDepth)
@@ -283,8 +295,10 @@ class TensorCoreDecoupled(
// ready for compute. Also send the set/step tag along the pipe for // ready for compute. Also send the set/step tag along the pipe for
// alignment check. // alignment check.
// @cleanup: dedup A and B below
val fullA = Module(new FillBuffer( val fullA = Module(new FillBuffer(
chiselTypeOf(respQueueB.bits.data), 2/*substeps*/ chiselTypeOf(respQueueB.bits.data), numIndices
)) ))
fullA.io.enq.valid := respQueueA.valid fullA.io.enq.valid := respQueueA.valid
fullA.io.enq.bits := respQueueA.bits.data fullA.io.enq.bits := respQueueA.bits.data
@@ -337,23 +351,48 @@ class TensorCoreDecoupled(
fullB.io.deq.ready := fullBBuf.io.enq.ready fullB.io.deq.ready := fullBBuf.io.enq.ready
fullBTag.io.deq.ready := fullBBuf.io.enq.ready fullBTag.io.deq.ready := fullBBuf.io.enq.ready
val operandsValid = fullABuf.io.deq.valid && fullBBuf.io.deq.valid
val operandA = fullABuf.io.deq.bits.data
val operandATag = fullABuf.io.deq.bits.tag
val operandB = fullBBuf.io.deq.bits.data
val dpuReady = Wire(Bool()) val dpuReady = Wire(Bool())
val operandsValid = fullABuf.io.deq.valid && fullBBuf.io.deq.valid
val dpuFire = operandsValid && dpuReady val dpuFire = operandsValid && dpuReady
val setCompute = fullABuf.io.deq.bits.tag.set
val stepCompute = fullABuf.io.deq.bits.tag.step val setCompute = RegInit(0.U(setBits.W))
val stepCompute = RegInit(0.U(stepBits.W))
val substepCompute = RegInit(0.U(1.W)) val substepCompute = RegInit(0.U(1.W))
val nextStepCompute = dpuFire && (substepCompute === 1.U)
dontTouch(setCompute)
dontTouch(stepCompute)
dontTouch(substepCompute)
when (dpuFire) { when (dpuFire) {
substepCompute := substepCompute + 1.U substepCompute := substepCompute + 1.U
} }
// hold full A until two-cycle compute is done // Operand selection
fullABuf.io.deq.ready := dpuFire && (substepCompute === 1.U) //
// Hold B tile at respQueueB for multiple steps for reuse, only dequeue when // select the correct 4x4 tile from A operand buffer
// we fully iterated a column (M-dimension). val numTilesMBits = log2Ceil(numTilesM)
def selectOperandA(buf: Vec[UInt]): UInt = {
require(buf.length == numIndices)
val stepM = stepCompute & ((1 << numTilesMBits) - 1).U
Cat(buf((stepM << 1) + 1.U), buf(stepM << 1))
}
val operandA = selectOperandA(fullABuf.io.deq.bits.data)
val operandATag = fullABuf.io.deq.bits.tag
// select the correct 2x4 tile from B operand buffer
val operandB = fullBBuf.io.deq.bits.data(substepCompute)
val operandBTag = fullBBuf.io.deq.bits.tag
dontTouch(operandATag)
dontTouch(operandBTag)
// Operand buffer dequeue logic
//
// hold A data until the entire set is done
val shouldDequeueAMask = ((1 << stepBits) - 1).U
val shouldDequeueA =
((stepCompute & shouldDequeueAMask) === shouldDequeueAMask) &&
(substepCompute === 1.U)
fullABuf.io.deq.ready := dpuFire && shouldDequeueA
// hold B tile at respQueueB for multiple steps for reuse, only dequeue when
// we fully iterated a column (M-dimension)
val shouldDequeueBMask = ((1 << numTilesMBits) - 1).U val shouldDequeueBMask = ((1 << numTilesMBits) - 1).U
val shouldDequeueB = val shouldDequeueB =
((stepCompute & shouldDequeueBMask) === shouldDequeueBMask) && ((stepCompute & shouldDequeueBMask) === shouldDequeueBMask) &&
@@ -361,11 +400,9 @@ class TensorCoreDecoupled(
fullBBuf.io.deq.ready := dpuFire && shouldDequeueB fullBBuf.io.deq.ready := dpuFire && shouldDequeueB
dontTouch(respQueueA) dontTouch(respQueueA)
dontTouch(respQueueB) dontTouch(respQueueB)
dontTouch(shouldDequeueA)
dontTouch(shouldDequeueB) dontTouch(shouldDequeueB)
// FIXME: this should be nextStepCompute
val nextStepExecute = dpuFire && (substepCompute === 1.U)
// Assert that the DPU is computing with operands of the same set/step. Note // Assert that the DPU is computing with operands of the same set/step. Note
// that the B resp will only have step values multiple of 4 due to reuse. // that the B resp will only have step values multiple of 4 due to reuse.
// //
@@ -374,11 +411,9 @@ class TensorCoreDecoupled(
def assertAligned = { def assertAligned = {
val stepMask = (1 << numTilesMBits).U val stepMask = (1 << numTilesMBits).U
when (dpuFire) { when (dpuFire) {
assert((fullABuf.io.deq.bits.tag.set === fullBBuf.io.deq.bits.tag.set) && assert(fullABuf.io.deq.bits.tag.set === fullBBuf.io.deq.bits.tag.set,
((fullABuf.io.deq.bits.tag.step & stepMask) === "A and B operands are pointing to different sets. " ++
(fullBBuf.io.deq.bits.tag.step & stepMask)), "This might indicate memory response coming back out-of-order.")
"A and B operands are pointing to different set/steps. " ++
"This might indicate memory response coming back out-of-order.")
} }
} }
assertAligned assertAligned
@@ -386,23 +421,24 @@ class TensorCoreDecoupled(
// Dot-product unit // Dot-product unit
// //
// 4x2 four-element DPUs summing up to 32 MACs in total // 4x2 four-element DPUs summing up to 32 MACs in total
//
val ncSubstep = tilingParams.nc / 2 val ncSubstep = tilingParams.nc / 2
require(tilingParams.mc * ncSubstep == numLanes,
"substep tile size doesn't match writeback throughput")
val dpus = Seq.fill(tilingParams.mc)(Seq.fill(ncSubstep)( val dpus = Seq.fill(tilingParams.mc)(Seq.fill(ncSubstep)(
Module(new TensorDotProductUnit(half = false)) Module(new TensorDotProductUnit(half = false))
)) ))
// operandA is 4x4 in K-major
val operandADimensional = // reshape operands for easier routing to DPU
operandA.asUInt.asBools.grouped(wordSizeInBits).map(VecInit(_).asUInt).toSeq def reshapeByFourWords(x: UInt): Seq[Seq[UInt]] = {
.grouped(4/*k-dim*/).toSeq x.asBools.grouped(wordSizeInBits).map(VecInit(_).asUInt).toSeq
.grouped(4/*k-dim*/).toSeq
}
val operandADimensional = reshapeByFourWords(operandA)
require(operandADimensional.length == tilingParams.mc && require(operandADimensional.length == tilingParams.mc &&
operandADimensional(0).length == tilingParams.kc, operandADimensional(0).length == tilingParams.kc,
"operand width doesn't agree with tiling parameter") "operand width doesn't agree with tiling parameter")
// select 2x4 subtile out of operandB that is 4x4 in K-major val operandBDimensional = reshapeByFourWords(operandB)
val operandBDimensional =
operandB(substepCompute).asBools.grouped(wordSizeInBits).map(VecInit(_).asUInt).toSeq
.grouped(4/*k-dim*/).toSeq
require(tilingParams.mc * ncSubstep == numLanes,
"substep tile size doesn't match writeback throughput")
require(operandBDimensional.length == ncSubstep && require(operandBDimensional.length == ncSubstep &&
operandBDimensional(0).length == tilingParams.kc, operandBDimensional(0).length == tilingParams.kc,
"operand width doesn't agree with tiling parameter") "operand width doesn't agree with tiling parameter")
@@ -444,12 +480,17 @@ class TensorCoreDecoupled(
// ---------------- // ----------------
// These queues hold metadata needed for writeback in sync with the DPU. // These queues hold metadata needed for writeback in sync with the DPU.
class TensorComputeTag extends Bundle {
val set = UInt(setBits.W)
val step = UInt(stepBits.W)
val substep = UInt(1.W)
}
val queueDepth = 5 // needs to be at least the DPU latency val queueDepth = 5 // needs to be at least the DPU latency
val tagQueue = Module(new Queue(chiselTypeOf(operandATag), queueDepth)) val tagQueue = Module(new Queue(new TensorComputeTag, queueDepth))
tagQueue.io.enq.valid := dpuFire tagQueue.io.enq.valid := dpuFire
// A and B should have the same tags tagQueue.io.enq.bits.set := setCompute
tagQueue.io.enq.bits := operandATag tagQueue.io.enq.bits.step := stepCompute
// @cleanup: awkward
tagQueue.io.enq.bits.substep := substepCompute tagQueue.io.enq.bits.substep := substepCompute
tagQueue.io.deq.ready := io.writeback.fire tagQueue.io.deq.ready := io.writeback.fire
assert(tagQueue.io.enq.ready === true.B, assert(tagQueue.io.enq.ready === true.B,
@@ -490,7 +531,7 @@ class TensorCoreDecoupled(
} }
} }
sequenceSetStep(setAccess, stepAccess, nextStepAccess) sequenceSetStep(setAccess, stepAccess, nextStepAccess)
sequenceSetStep(setExecute, stepExecute, nextStepExecute) sequenceSetStep(setCompute, stepCompute, nextStepCompute)
switch(state) { switch(state) {
is(TensorState.idle) { is(TensorState.idle) {