tensor: Enlarge operand buffer for A for better SMEM reuse
This commit is contained in:
@@ -146,18 +146,6 @@ class TensorCoreDecoupled(
|
||||
// Memory traffic generation
|
||||
// -------------------------
|
||||
//
|
||||
class TensorMemTag extends Bundle {
|
||||
val set = UInt(setBits.W)
|
||||
val step = UInt(stepBits.W)
|
||||
val substep = UInt(1.W)
|
||||
}
|
||||
// use concatenation of set/step as the memory request source. This will get
|
||||
// translated to the actual TL sourcewidth in sourceGen.
|
||||
val tag = Wire(new TensorMemTag)
|
||||
tag.set := setAccess
|
||||
tag.step := stepAccess
|
||||
tag.substep := substepAccess
|
||||
|
||||
val numTilesM = tilingParams.m / tilingParams.mc
|
||||
val numTilesN = tilingParams.n / tilingParams.nc
|
||||
// @cleanup: generalize in terms of M/N/K-majorness?
|
||||
@@ -198,12 +186,41 @@ class TensorCoreDecoupled(
|
||||
val (addressA, addressB) =
|
||||
addressGen(0.U, 0.U, setAccess, stepAccess, substepAccess)
|
||||
|
||||
// 'index' is the index of a memory request among the sequence of requests
|
||||
// needed to read a full M-column of A or N-row of B. Its range is [0,m/2)
|
||||
// or [0,n/2), where 2 is the stride can be read in a single request size.
|
||||
require(tilingParams.m == tilingParams.n,
|
||||
"currently only supports square SMEM tile")
|
||||
val numIndices = tilingParams.m / 2
|
||||
val indexBits = log2Ceil(numIndices)
|
||||
val lastIndex = (1 << indexBits) - 1
|
||||
|
||||
class TensorMemTag extends Bundle {
|
||||
val set = UInt(setBits.W)
|
||||
val index = UInt(indexBits.W)
|
||||
}
|
||||
|
||||
val tagInit = Wire(new TensorMemTag)
|
||||
tagInit.set := 0.U
|
||||
tagInit.index := 0.U
|
||||
val tagA = RegInit(tagInit)
|
||||
val tagB = RegInit(tagInit)
|
||||
|
||||
when (io.reqA.fire) {
|
||||
when (tagA.index === lastIndex.U) {
|
||||
tagA.set := tagA.set + 1.U
|
||||
}
|
||||
tagA.index := tagA.index + 1.U
|
||||
}
|
||||
when (io.reqB.fire) {
|
||||
when (tagB.index === lastIndex.U) {
|
||||
tagB.set := tagB.set + 1.U
|
||||
}
|
||||
tagB.index := tagB.index + 1.U
|
||||
}
|
||||
|
||||
val genReqA = (state === TensorState.run)
|
||||
val numTilesMBits = log2Ceil(numTilesM)
|
||||
// generate B request at every 4 steps. B achieves reuse through outer
|
||||
// product so it doesn't require access at every step
|
||||
val shouldFireB = (stepAccess & ((1 << numTilesMBits) - 1).U) === 0.U
|
||||
val genReqB = (state === TensorState.run) && shouldFireB
|
||||
val genReqB = (state === TensorState.run)
|
||||
|
||||
val respATagged = Wire(Decoupled(new TensorMemRespWithTag(dataWidth)))
|
||||
val respBTagged = Wire(Decoupled(new TensorMemRespWithTag(dataWidth)))
|
||||
@@ -212,11 +229,11 @@ class TensorCoreDecoupled(
|
||||
case ((req, (resp, respTagged)), i) => {
|
||||
val sourceGen = Module(new SourceGenerator(
|
||||
log2Ceil(numSourceIds),
|
||||
metadata = Some(tag)
|
||||
metadata = Some(new TensorMemTag)
|
||||
))
|
||||
|
||||
sourceGen.io.gen := req.fire
|
||||
sourceGen.io.meta := tag
|
||||
sourceGen.io.meta := (if (i == 0) tagA else tagB)
|
||||
req.valid := (if (i == 0) genReqA else genReqB)
|
||||
req.bits.address := (if (i == 0) addressA else addressB)
|
||||
req.bits.source := sourceGen.io.id.bits
|
||||
@@ -243,7 +260,7 @@ class TensorCoreDecoupled(
|
||||
val firedBNow = io.reqB.fire
|
||||
val firedA = firedAReg || firedANow
|
||||
val firedB = firedBReg || firedBNow
|
||||
val nextSubstepAccess = firedA && (!shouldFireB || firedB)
|
||||
val nextSubstepAccess = firedA && firedB
|
||||
val nextStepAccess = nextSubstepAccess && (substepAccess === 1.U)
|
||||
// clear out firedABReg every substep
|
||||
when (nextSubstepAccess) {
|
||||
@@ -252,17 +269,12 @@ class TensorCoreDecoupled(
|
||||
substepAccess := substepAccess + 1.U
|
||||
}
|
||||
require(substepAccess.widthOption.get == 1, "there should be only two substeps")
|
||||
dontTouch(shouldFireB)
|
||||
|
||||
// Execute stage
|
||||
// -------------
|
||||
// Backend of the decoupled access/execute pipeline.
|
||||
//
|
||||
// set and step being currently executed in the acc/ex backend
|
||||
val setExecute = RegInit(0.U(setBits.W))
|
||||
val stepExecute = RegInit(0.U(stepBits.W))
|
||||
dontTouch(setExecute)
|
||||
dontTouch(stepExecute)
|
||||
|
||||
val respQueueDepth = 4 // FIXME: parameterize
|
||||
val respQueueA = Queue(respATagged, respQueueDepth)
|
||||
@@ -283,8 +295,10 @@ class TensorCoreDecoupled(
|
||||
// ready for compute. Also send the set/step tag along the pipe for
|
||||
// alignment check.
|
||||
|
||||
// @cleanup: dedup A and B below
|
||||
|
||||
val fullA = Module(new FillBuffer(
|
||||
chiselTypeOf(respQueueB.bits.data), 2/*substeps*/
|
||||
chiselTypeOf(respQueueB.bits.data), numIndices
|
||||
))
|
||||
fullA.io.enq.valid := respQueueA.valid
|
||||
fullA.io.enq.bits := respQueueA.bits.data
|
||||
@@ -337,23 +351,48 @@ class TensorCoreDecoupled(
|
||||
fullB.io.deq.ready := fullBBuf.io.enq.ready
|
||||
fullBTag.io.deq.ready := fullBBuf.io.enq.ready
|
||||
|
||||
val operandsValid = fullABuf.io.deq.valid && fullBBuf.io.deq.valid
|
||||
val operandA = fullABuf.io.deq.bits.data
|
||||
val operandATag = fullABuf.io.deq.bits.tag
|
||||
val operandB = fullBBuf.io.deq.bits.data
|
||||
val dpuReady = Wire(Bool())
|
||||
val operandsValid = fullABuf.io.deq.valid && fullBBuf.io.deq.valid
|
||||
val dpuFire = operandsValid && dpuReady
|
||||
val setCompute = fullABuf.io.deq.bits.tag.set
|
||||
val stepCompute = fullABuf.io.deq.bits.tag.step
|
||||
|
||||
val setCompute = RegInit(0.U(setBits.W))
|
||||
val stepCompute = RegInit(0.U(stepBits.W))
|
||||
val substepCompute = RegInit(0.U(1.W))
|
||||
val nextStepCompute = dpuFire && (substepCompute === 1.U)
|
||||
dontTouch(setCompute)
|
||||
dontTouch(stepCompute)
|
||||
dontTouch(substepCompute)
|
||||
when (dpuFire) {
|
||||
substepCompute := substepCompute + 1.U
|
||||
}
|
||||
|
||||
// hold full A until two-cycle compute is done
|
||||
fullABuf.io.deq.ready := dpuFire && (substepCompute === 1.U)
|
||||
// Hold B tile at respQueueB for multiple steps for reuse, only dequeue when
|
||||
// we fully iterated a column (M-dimension).
|
||||
// Operand selection
|
||||
//
|
||||
// select the correct 4x4 tile from A operand buffer
|
||||
val numTilesMBits = log2Ceil(numTilesM)
|
||||
def selectOperandA(buf: Vec[UInt]): UInt = {
|
||||
require(buf.length == numIndices)
|
||||
val stepM = stepCompute & ((1 << numTilesMBits) - 1).U
|
||||
Cat(buf((stepM << 1) + 1.U), buf(stepM << 1))
|
||||
}
|
||||
val operandA = selectOperandA(fullABuf.io.deq.bits.data)
|
||||
val operandATag = fullABuf.io.deq.bits.tag
|
||||
// select the correct 2x4 tile from B operand buffer
|
||||
val operandB = fullBBuf.io.deq.bits.data(substepCompute)
|
||||
val operandBTag = fullBBuf.io.deq.bits.tag
|
||||
dontTouch(operandATag)
|
||||
dontTouch(operandBTag)
|
||||
|
||||
// Operand buffer dequeue logic
|
||||
//
|
||||
// hold A data until the entire set is done
|
||||
val shouldDequeueAMask = ((1 << stepBits) - 1).U
|
||||
val shouldDequeueA =
|
||||
((stepCompute & shouldDequeueAMask) === shouldDequeueAMask) &&
|
||||
(substepCompute === 1.U)
|
||||
fullABuf.io.deq.ready := dpuFire && shouldDequeueA
|
||||
// hold B tile at respQueueB for multiple steps for reuse, only dequeue when
|
||||
// we fully iterated a column (M-dimension)
|
||||
val shouldDequeueBMask = ((1 << numTilesMBits) - 1).U
|
||||
val shouldDequeueB =
|
||||
((stepCompute & shouldDequeueBMask) === shouldDequeueBMask) &&
|
||||
@@ -361,11 +400,9 @@ class TensorCoreDecoupled(
|
||||
fullBBuf.io.deq.ready := dpuFire && shouldDequeueB
|
||||
dontTouch(respQueueA)
|
||||
dontTouch(respQueueB)
|
||||
dontTouch(shouldDequeueA)
|
||||
dontTouch(shouldDequeueB)
|
||||
|
||||
// FIXME: this should be nextStepCompute
|
||||
val nextStepExecute = dpuFire && (substepCompute === 1.U)
|
||||
|
||||
// Assert that the DPU is computing with operands of the same set/step. Note
|
||||
// that the B resp will only have step values multiple of 4 due to reuse.
|
||||
//
|
||||
@@ -374,10 +411,8 @@ class TensorCoreDecoupled(
|
||||
def assertAligned = {
|
||||
val stepMask = (1 << numTilesMBits).U
|
||||
when (dpuFire) {
|
||||
assert((fullABuf.io.deq.bits.tag.set === fullBBuf.io.deq.bits.tag.set) &&
|
||||
((fullABuf.io.deq.bits.tag.step & stepMask) ===
|
||||
(fullBBuf.io.deq.bits.tag.step & stepMask)),
|
||||
"A and B operands are pointing to different set/steps. " ++
|
||||
assert(fullABuf.io.deq.bits.tag.set === fullBBuf.io.deq.bits.tag.set,
|
||||
"A and B operands are pointing to different sets. " ++
|
||||
"This might indicate memory response coming back out-of-order.")
|
||||
}
|
||||
}
|
||||
@@ -386,23 +421,24 @@ class TensorCoreDecoupled(
|
||||
// Dot-product unit
|
||||
//
|
||||
// 4x2 four-element DPUs summing up to 32 MACs in total
|
||||
//
|
||||
val ncSubstep = tilingParams.nc / 2
|
||||
require(tilingParams.mc * ncSubstep == numLanes,
|
||||
"substep tile size doesn't match writeback throughput")
|
||||
val dpus = Seq.fill(tilingParams.mc)(Seq.fill(ncSubstep)(
|
||||
Module(new TensorDotProductUnit(half = false))
|
||||
))
|
||||
// operandA is 4x4 in K-major
|
||||
val operandADimensional =
|
||||
operandA.asUInt.asBools.grouped(wordSizeInBits).map(VecInit(_).asUInt).toSeq
|
||||
|
||||
// reshape operands for easier routing to DPU
|
||||
def reshapeByFourWords(x: UInt): Seq[Seq[UInt]] = {
|
||||
x.asBools.grouped(wordSizeInBits).map(VecInit(_).asUInt).toSeq
|
||||
.grouped(4/*k-dim*/).toSeq
|
||||
}
|
||||
val operandADimensional = reshapeByFourWords(operandA)
|
||||
require(operandADimensional.length == tilingParams.mc &&
|
||||
operandADimensional(0).length == tilingParams.kc,
|
||||
"operand width doesn't agree with tiling parameter")
|
||||
// select 2x4 subtile out of operandB that is 4x4 in K-major
|
||||
val operandBDimensional =
|
||||
operandB(substepCompute).asBools.grouped(wordSizeInBits).map(VecInit(_).asUInt).toSeq
|
||||
.grouped(4/*k-dim*/).toSeq
|
||||
require(tilingParams.mc * ncSubstep == numLanes,
|
||||
"substep tile size doesn't match writeback throughput")
|
||||
val operandBDimensional = reshapeByFourWords(operandB)
|
||||
require(operandBDimensional.length == ncSubstep &&
|
||||
operandBDimensional(0).length == tilingParams.kc,
|
||||
"operand width doesn't agree with tiling parameter")
|
||||
@@ -444,12 +480,17 @@ class TensorCoreDecoupled(
|
||||
// ----------------
|
||||
// These queues hold metadata needed for writeback in sync with the DPU.
|
||||
|
||||
class TensorComputeTag extends Bundle {
|
||||
val set = UInt(setBits.W)
|
||||
val step = UInt(stepBits.W)
|
||||
val substep = UInt(1.W)
|
||||
}
|
||||
|
||||
val queueDepth = 5 // needs to be at least the DPU latency
|
||||
val tagQueue = Module(new Queue(chiselTypeOf(operandATag), queueDepth))
|
||||
val tagQueue = Module(new Queue(new TensorComputeTag, queueDepth))
|
||||
tagQueue.io.enq.valid := dpuFire
|
||||
// A and B should have the same tags
|
||||
tagQueue.io.enq.bits := operandATag
|
||||
// @cleanup: awkward
|
||||
tagQueue.io.enq.bits.set := setCompute
|
||||
tagQueue.io.enq.bits.step := stepCompute
|
||||
tagQueue.io.enq.bits.substep := substepCompute
|
||||
tagQueue.io.deq.ready := io.writeback.fire
|
||||
assert(tagQueue.io.enq.ready === true.B,
|
||||
@@ -490,7 +531,7 @@ class TensorCoreDecoupled(
|
||||
}
|
||||
}
|
||||
sequenceSetStep(setAccess, stepAccess, nextStepAccess)
|
||||
sequenceSetStep(setExecute, stepExecute, nextStepExecute)
|
||||
sequenceSetStep(setCompute, stepCompute, nextStepCompute)
|
||||
|
||||
switch(state) {
|
||||
is(TensorState.idle) {
|
||||
|
||||
Reference in New Issue
Block a user