tensor: Enlarge operand buffer for A for better SMEM reuse
This commit is contained in:
@@ -146,18 +146,6 @@ class TensorCoreDecoupled(
|
|||||||
// Memory traffic generation
|
// Memory traffic generation
|
||||||
// -------------------------
|
// -------------------------
|
||||||
//
|
//
|
||||||
class TensorMemTag extends Bundle {
|
|
||||||
val set = UInt(setBits.W)
|
|
||||||
val step = UInt(stepBits.W)
|
|
||||||
val substep = UInt(1.W)
|
|
||||||
}
|
|
||||||
// use concatenation of set/step as the memory request source. This will get
|
|
||||||
// translated to the actual TL sourcewidth in sourceGen.
|
|
||||||
val tag = Wire(new TensorMemTag)
|
|
||||||
tag.set := setAccess
|
|
||||||
tag.step := stepAccess
|
|
||||||
tag.substep := substepAccess
|
|
||||||
|
|
||||||
val numTilesM = tilingParams.m / tilingParams.mc
|
val numTilesM = tilingParams.m / tilingParams.mc
|
||||||
val numTilesN = tilingParams.n / tilingParams.nc
|
val numTilesN = tilingParams.n / tilingParams.nc
|
||||||
// @cleanup: generalize in terms of M/N/K-majorness?
|
// @cleanup: generalize in terms of M/N/K-majorness?
|
||||||
@@ -198,12 +186,41 @@ class TensorCoreDecoupled(
|
|||||||
val (addressA, addressB) =
|
val (addressA, addressB) =
|
||||||
addressGen(0.U, 0.U, setAccess, stepAccess, substepAccess)
|
addressGen(0.U, 0.U, setAccess, stepAccess, substepAccess)
|
||||||
|
|
||||||
|
// 'index' is the index of a memory request among the sequence of requests
|
||||||
|
// needed to read a full M-column of A or N-row of B. Its range is [0,m/2)
|
||||||
|
// or [0,n/2), where 2 is the stride can be read in a single request size.
|
||||||
|
require(tilingParams.m == tilingParams.n,
|
||||||
|
"currently only supports square SMEM tile")
|
||||||
|
val numIndices = tilingParams.m / 2
|
||||||
|
val indexBits = log2Ceil(numIndices)
|
||||||
|
val lastIndex = (1 << indexBits) - 1
|
||||||
|
|
||||||
|
class TensorMemTag extends Bundle {
|
||||||
|
val set = UInt(setBits.W)
|
||||||
|
val index = UInt(indexBits.W)
|
||||||
|
}
|
||||||
|
|
||||||
|
val tagInit = Wire(new TensorMemTag)
|
||||||
|
tagInit.set := 0.U
|
||||||
|
tagInit.index := 0.U
|
||||||
|
val tagA = RegInit(tagInit)
|
||||||
|
val tagB = RegInit(tagInit)
|
||||||
|
|
||||||
|
when (io.reqA.fire) {
|
||||||
|
when (tagA.index === lastIndex.U) {
|
||||||
|
tagA.set := tagA.set + 1.U
|
||||||
|
}
|
||||||
|
tagA.index := tagA.index + 1.U
|
||||||
|
}
|
||||||
|
when (io.reqB.fire) {
|
||||||
|
when (tagB.index === lastIndex.U) {
|
||||||
|
tagB.set := tagB.set + 1.U
|
||||||
|
}
|
||||||
|
tagB.index := tagB.index + 1.U
|
||||||
|
}
|
||||||
|
|
||||||
val genReqA = (state === TensorState.run)
|
val genReqA = (state === TensorState.run)
|
||||||
val numTilesMBits = log2Ceil(numTilesM)
|
val genReqB = (state === TensorState.run)
|
||||||
// generate B request at every 4 steps. B achieves reuse through outer
|
|
||||||
// product so it doesn't require access at every step
|
|
||||||
val shouldFireB = (stepAccess & ((1 << numTilesMBits) - 1).U) === 0.U
|
|
||||||
val genReqB = (state === TensorState.run) && shouldFireB
|
|
||||||
|
|
||||||
val respATagged = Wire(Decoupled(new TensorMemRespWithTag(dataWidth)))
|
val respATagged = Wire(Decoupled(new TensorMemRespWithTag(dataWidth)))
|
||||||
val respBTagged = Wire(Decoupled(new TensorMemRespWithTag(dataWidth)))
|
val respBTagged = Wire(Decoupled(new TensorMemRespWithTag(dataWidth)))
|
||||||
@@ -212,11 +229,11 @@ class TensorCoreDecoupled(
|
|||||||
case ((req, (resp, respTagged)), i) => {
|
case ((req, (resp, respTagged)), i) => {
|
||||||
val sourceGen = Module(new SourceGenerator(
|
val sourceGen = Module(new SourceGenerator(
|
||||||
log2Ceil(numSourceIds),
|
log2Ceil(numSourceIds),
|
||||||
metadata = Some(tag)
|
metadata = Some(new TensorMemTag)
|
||||||
))
|
))
|
||||||
|
|
||||||
sourceGen.io.gen := req.fire
|
sourceGen.io.gen := req.fire
|
||||||
sourceGen.io.meta := tag
|
sourceGen.io.meta := (if (i == 0) tagA else tagB)
|
||||||
req.valid := (if (i == 0) genReqA else genReqB)
|
req.valid := (if (i == 0) genReqA else genReqB)
|
||||||
req.bits.address := (if (i == 0) addressA else addressB)
|
req.bits.address := (if (i == 0) addressA else addressB)
|
||||||
req.bits.source := sourceGen.io.id.bits
|
req.bits.source := sourceGen.io.id.bits
|
||||||
@@ -243,7 +260,7 @@ class TensorCoreDecoupled(
|
|||||||
val firedBNow = io.reqB.fire
|
val firedBNow = io.reqB.fire
|
||||||
val firedA = firedAReg || firedANow
|
val firedA = firedAReg || firedANow
|
||||||
val firedB = firedBReg || firedBNow
|
val firedB = firedBReg || firedBNow
|
||||||
val nextSubstepAccess = firedA && (!shouldFireB || firedB)
|
val nextSubstepAccess = firedA && firedB
|
||||||
val nextStepAccess = nextSubstepAccess && (substepAccess === 1.U)
|
val nextStepAccess = nextSubstepAccess && (substepAccess === 1.U)
|
||||||
// clear out firedABReg every substep
|
// clear out firedABReg every substep
|
||||||
when (nextSubstepAccess) {
|
when (nextSubstepAccess) {
|
||||||
@@ -252,17 +269,12 @@ class TensorCoreDecoupled(
|
|||||||
substepAccess := substepAccess + 1.U
|
substepAccess := substepAccess + 1.U
|
||||||
}
|
}
|
||||||
require(substepAccess.widthOption.get == 1, "there should be only two substeps")
|
require(substepAccess.widthOption.get == 1, "there should be only two substeps")
|
||||||
dontTouch(shouldFireB)
|
|
||||||
|
|
||||||
// Execute stage
|
// Execute stage
|
||||||
// -------------
|
// -------------
|
||||||
// Backend of the decoupled access/execute pipeline.
|
// Backend of the decoupled access/execute pipeline.
|
||||||
//
|
//
|
||||||
// set and step being currently executed in the acc/ex backend
|
// set and step being currently executed in the acc/ex backend
|
||||||
val setExecute = RegInit(0.U(setBits.W))
|
|
||||||
val stepExecute = RegInit(0.U(stepBits.W))
|
|
||||||
dontTouch(setExecute)
|
|
||||||
dontTouch(stepExecute)
|
|
||||||
|
|
||||||
val respQueueDepth = 4 // FIXME: parameterize
|
val respQueueDepth = 4 // FIXME: parameterize
|
||||||
val respQueueA = Queue(respATagged, respQueueDepth)
|
val respQueueA = Queue(respATagged, respQueueDepth)
|
||||||
@@ -283,8 +295,10 @@ class TensorCoreDecoupled(
|
|||||||
// ready for compute. Also send the set/step tag along the pipe for
|
// ready for compute. Also send the set/step tag along the pipe for
|
||||||
// alignment check.
|
// alignment check.
|
||||||
|
|
||||||
|
// @cleanup: dedup A and B below
|
||||||
|
|
||||||
val fullA = Module(new FillBuffer(
|
val fullA = Module(new FillBuffer(
|
||||||
chiselTypeOf(respQueueB.bits.data), 2/*substeps*/
|
chiselTypeOf(respQueueB.bits.data), numIndices
|
||||||
))
|
))
|
||||||
fullA.io.enq.valid := respQueueA.valid
|
fullA.io.enq.valid := respQueueA.valid
|
||||||
fullA.io.enq.bits := respQueueA.bits.data
|
fullA.io.enq.bits := respQueueA.bits.data
|
||||||
@@ -337,23 +351,48 @@ class TensorCoreDecoupled(
|
|||||||
fullB.io.deq.ready := fullBBuf.io.enq.ready
|
fullB.io.deq.ready := fullBBuf.io.enq.ready
|
||||||
fullBTag.io.deq.ready := fullBBuf.io.enq.ready
|
fullBTag.io.deq.ready := fullBBuf.io.enq.ready
|
||||||
|
|
||||||
val operandsValid = fullABuf.io.deq.valid && fullBBuf.io.deq.valid
|
|
||||||
val operandA = fullABuf.io.deq.bits.data
|
|
||||||
val operandATag = fullABuf.io.deq.bits.tag
|
|
||||||
val operandB = fullBBuf.io.deq.bits.data
|
|
||||||
val dpuReady = Wire(Bool())
|
val dpuReady = Wire(Bool())
|
||||||
|
val operandsValid = fullABuf.io.deq.valid && fullBBuf.io.deq.valid
|
||||||
val dpuFire = operandsValid && dpuReady
|
val dpuFire = operandsValid && dpuReady
|
||||||
val setCompute = fullABuf.io.deq.bits.tag.set
|
|
||||||
val stepCompute = fullABuf.io.deq.bits.tag.step
|
val setCompute = RegInit(0.U(setBits.W))
|
||||||
|
val stepCompute = RegInit(0.U(stepBits.W))
|
||||||
val substepCompute = RegInit(0.U(1.W))
|
val substepCompute = RegInit(0.U(1.W))
|
||||||
|
val nextStepCompute = dpuFire && (substepCompute === 1.U)
|
||||||
|
dontTouch(setCompute)
|
||||||
|
dontTouch(stepCompute)
|
||||||
|
dontTouch(substepCompute)
|
||||||
when (dpuFire) {
|
when (dpuFire) {
|
||||||
substepCompute := substepCompute + 1.U
|
substepCompute := substepCompute + 1.U
|
||||||
}
|
}
|
||||||
|
|
||||||
// hold full A until two-cycle compute is done
|
// Operand selection
|
||||||
fullABuf.io.deq.ready := dpuFire && (substepCompute === 1.U)
|
//
|
||||||
// Hold B tile at respQueueB for multiple steps for reuse, only dequeue when
|
// select the correct 4x4 tile from A operand buffer
|
||||||
// we fully iterated a column (M-dimension).
|
val numTilesMBits = log2Ceil(numTilesM)
|
||||||
|
def selectOperandA(buf: Vec[UInt]): UInt = {
|
||||||
|
require(buf.length == numIndices)
|
||||||
|
val stepM = stepCompute & ((1 << numTilesMBits) - 1).U
|
||||||
|
Cat(buf((stepM << 1) + 1.U), buf(stepM << 1))
|
||||||
|
}
|
||||||
|
val operandA = selectOperandA(fullABuf.io.deq.bits.data)
|
||||||
|
val operandATag = fullABuf.io.deq.bits.tag
|
||||||
|
// select the correct 2x4 tile from B operand buffer
|
||||||
|
val operandB = fullBBuf.io.deq.bits.data(substepCompute)
|
||||||
|
val operandBTag = fullBBuf.io.deq.bits.tag
|
||||||
|
dontTouch(operandATag)
|
||||||
|
dontTouch(operandBTag)
|
||||||
|
|
||||||
|
// Operand buffer dequeue logic
|
||||||
|
//
|
||||||
|
// hold A data until the entire set is done
|
||||||
|
val shouldDequeueAMask = ((1 << stepBits) - 1).U
|
||||||
|
val shouldDequeueA =
|
||||||
|
((stepCompute & shouldDequeueAMask) === shouldDequeueAMask) &&
|
||||||
|
(substepCompute === 1.U)
|
||||||
|
fullABuf.io.deq.ready := dpuFire && shouldDequeueA
|
||||||
|
// hold B tile at respQueueB for multiple steps for reuse, only dequeue when
|
||||||
|
// we fully iterated a column (M-dimension)
|
||||||
val shouldDequeueBMask = ((1 << numTilesMBits) - 1).U
|
val shouldDequeueBMask = ((1 << numTilesMBits) - 1).U
|
||||||
val shouldDequeueB =
|
val shouldDequeueB =
|
||||||
((stepCompute & shouldDequeueBMask) === shouldDequeueBMask) &&
|
((stepCompute & shouldDequeueBMask) === shouldDequeueBMask) &&
|
||||||
@@ -361,11 +400,9 @@ class TensorCoreDecoupled(
|
|||||||
fullBBuf.io.deq.ready := dpuFire && shouldDequeueB
|
fullBBuf.io.deq.ready := dpuFire && shouldDequeueB
|
||||||
dontTouch(respQueueA)
|
dontTouch(respQueueA)
|
||||||
dontTouch(respQueueB)
|
dontTouch(respQueueB)
|
||||||
|
dontTouch(shouldDequeueA)
|
||||||
dontTouch(shouldDequeueB)
|
dontTouch(shouldDequeueB)
|
||||||
|
|
||||||
// FIXME: this should be nextStepCompute
|
|
||||||
val nextStepExecute = dpuFire && (substepCompute === 1.U)
|
|
||||||
|
|
||||||
// Assert that the DPU is computing with operands of the same set/step. Note
|
// Assert that the DPU is computing with operands of the same set/step. Note
|
||||||
// that the B resp will only have step values multiple of 4 due to reuse.
|
// that the B resp will only have step values multiple of 4 due to reuse.
|
||||||
//
|
//
|
||||||
@@ -374,11 +411,9 @@ class TensorCoreDecoupled(
|
|||||||
def assertAligned = {
|
def assertAligned = {
|
||||||
val stepMask = (1 << numTilesMBits).U
|
val stepMask = (1 << numTilesMBits).U
|
||||||
when (dpuFire) {
|
when (dpuFire) {
|
||||||
assert((fullABuf.io.deq.bits.tag.set === fullBBuf.io.deq.bits.tag.set) &&
|
assert(fullABuf.io.deq.bits.tag.set === fullBBuf.io.deq.bits.tag.set,
|
||||||
((fullABuf.io.deq.bits.tag.step & stepMask) ===
|
"A and B operands are pointing to different sets. " ++
|
||||||
(fullBBuf.io.deq.bits.tag.step & stepMask)),
|
"This might indicate memory response coming back out-of-order.")
|
||||||
"A and B operands are pointing to different set/steps. " ++
|
|
||||||
"This might indicate memory response coming back out-of-order.")
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
assertAligned
|
assertAligned
|
||||||
@@ -386,23 +421,24 @@ class TensorCoreDecoupled(
|
|||||||
// Dot-product unit
|
// Dot-product unit
|
||||||
//
|
//
|
||||||
// 4x2 four-element DPUs summing up to 32 MACs in total
|
// 4x2 four-element DPUs summing up to 32 MACs in total
|
||||||
|
//
|
||||||
val ncSubstep = tilingParams.nc / 2
|
val ncSubstep = tilingParams.nc / 2
|
||||||
|
require(tilingParams.mc * ncSubstep == numLanes,
|
||||||
|
"substep tile size doesn't match writeback throughput")
|
||||||
val dpus = Seq.fill(tilingParams.mc)(Seq.fill(ncSubstep)(
|
val dpus = Seq.fill(tilingParams.mc)(Seq.fill(ncSubstep)(
|
||||||
Module(new TensorDotProductUnit(half = false))
|
Module(new TensorDotProductUnit(half = false))
|
||||||
))
|
))
|
||||||
// operandA is 4x4 in K-major
|
|
||||||
val operandADimensional =
|
// reshape operands for easier routing to DPU
|
||||||
operandA.asUInt.asBools.grouped(wordSizeInBits).map(VecInit(_).asUInt).toSeq
|
def reshapeByFourWords(x: UInt): Seq[Seq[UInt]] = {
|
||||||
.grouped(4/*k-dim*/).toSeq
|
x.asBools.grouped(wordSizeInBits).map(VecInit(_).asUInt).toSeq
|
||||||
|
.grouped(4/*k-dim*/).toSeq
|
||||||
|
}
|
||||||
|
val operandADimensional = reshapeByFourWords(operandA)
|
||||||
require(operandADimensional.length == tilingParams.mc &&
|
require(operandADimensional.length == tilingParams.mc &&
|
||||||
operandADimensional(0).length == tilingParams.kc,
|
operandADimensional(0).length == tilingParams.kc,
|
||||||
"operand width doesn't agree with tiling parameter")
|
"operand width doesn't agree with tiling parameter")
|
||||||
// select 2x4 subtile out of operandB that is 4x4 in K-major
|
val operandBDimensional = reshapeByFourWords(operandB)
|
||||||
val operandBDimensional =
|
|
||||||
operandB(substepCompute).asBools.grouped(wordSizeInBits).map(VecInit(_).asUInt).toSeq
|
|
||||||
.grouped(4/*k-dim*/).toSeq
|
|
||||||
require(tilingParams.mc * ncSubstep == numLanes,
|
|
||||||
"substep tile size doesn't match writeback throughput")
|
|
||||||
require(operandBDimensional.length == ncSubstep &&
|
require(operandBDimensional.length == ncSubstep &&
|
||||||
operandBDimensional(0).length == tilingParams.kc,
|
operandBDimensional(0).length == tilingParams.kc,
|
||||||
"operand width doesn't agree with tiling parameter")
|
"operand width doesn't agree with tiling parameter")
|
||||||
@@ -444,12 +480,17 @@ class TensorCoreDecoupled(
|
|||||||
// ----------------
|
// ----------------
|
||||||
// These queues hold metadata needed for writeback in sync with the DPU.
|
// These queues hold metadata needed for writeback in sync with the DPU.
|
||||||
|
|
||||||
|
class TensorComputeTag extends Bundle {
|
||||||
|
val set = UInt(setBits.W)
|
||||||
|
val step = UInt(stepBits.W)
|
||||||
|
val substep = UInt(1.W)
|
||||||
|
}
|
||||||
|
|
||||||
val queueDepth = 5 // needs to be at least the DPU latency
|
val queueDepth = 5 // needs to be at least the DPU latency
|
||||||
val tagQueue = Module(new Queue(chiselTypeOf(operandATag), queueDepth))
|
val tagQueue = Module(new Queue(new TensorComputeTag, queueDepth))
|
||||||
tagQueue.io.enq.valid := dpuFire
|
tagQueue.io.enq.valid := dpuFire
|
||||||
// A and B should have the same tags
|
tagQueue.io.enq.bits.set := setCompute
|
||||||
tagQueue.io.enq.bits := operandATag
|
tagQueue.io.enq.bits.step := stepCompute
|
||||||
// @cleanup: awkward
|
|
||||||
tagQueue.io.enq.bits.substep := substepCompute
|
tagQueue.io.enq.bits.substep := substepCompute
|
||||||
tagQueue.io.deq.ready := io.writeback.fire
|
tagQueue.io.deq.ready := io.writeback.fire
|
||||||
assert(tagQueue.io.enq.ready === true.B,
|
assert(tagQueue.io.enq.ready === true.B,
|
||||||
@@ -490,7 +531,7 @@ class TensorCoreDecoupled(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
sequenceSetStep(setAccess, stepAccess, nextStepAccess)
|
sequenceSetStep(setAccess, stepAccess, nextStepAccess)
|
||||||
sequenceSetStep(setExecute, stepExecute, nextStepExecute)
|
sequenceSetStep(setCompute, stepCompute, nextStepCompute)
|
||||||
|
|
||||||
switch(state) {
|
switch(state) {
|
||||||
is(TensorState.idle) {
|
is(TensorState.idle) {
|
||||||
|
|||||||
Reference in New Issue
Block a user