tensor: Replace staging logic for A with FillBuffer

This commit is contained in:
Hansung Kim
2024-10-18 19:54:20 -07:00
parent 7fab6f89ad
commit c4b5a11fde

View File

@@ -272,46 +272,41 @@ class TensorCoreDecoupled(
io.writeback.bits.data.widthOption.get,
"response data width does not match the writeback data width")
// FIXME: unnecessary
val substepDeqA = RegInit(0.U(1.W))
when (respQueueA.fire) {
substepDeqA := substepDeqA + 1.U
}
dontTouch(substepDeqA)
// Do pipelining for the A operand so that we obtain the full 4x4 A tile
// ready for compute. The pipeline is two-stage:
// - stage one (halfAQueue) for assembling the full A tile from half-tiles
// coming from the resp queue, and
// - stage two (fullAQueue) for holding the full A tile until it gets
// matched with two 4x2 B tiles, and compute is complete.
//
// Note that the half-tile assembly is unnecessary for B since the B tile is
// only 4x2.
// Also send the set/step tag along the pipe for alignment check.
// Stage the operands in a pipeline so that we obtain the full 4x4 tiles
// ready for compute. Also send the set/step tag along the pipe for
// alignment check.
// note combinationally coupled ready with `pipe`
val halfAQueue = Module(new Queue(
chiselTypeOf(respQueueA.bits), entries = 1, pipe = true
val fullA = Module(new FillBuffer(
chiselTypeOf(respQueueB.bits.data), 2/*substeps*/
))
halfAQueue.io.enq.valid := respQueueA.valid && (substepDeqA === 0.U)
halfAQueue.io.enq.bits := respQueueA.bits
fullA.io.enq.valid := respQueueA.valid
fullA.io.enq.bits := respQueueA.bits.data
respQueueA.ready := fullA.io.enq.ready
// `pipe` combinationally couples enq-deq ready
val fullATag = Module(new Queue(
new TensorMemTag, entries = 1, pipe = true
))
fullATag.io.enq.valid := respQueueA.valid
fullATag.io.enq.bits := respQueueA.bits.tag
// substep == 0 data goes to the LSB
val fullAEnqData = Cat(respQueueA.bits.data, halfAQueue.io.deq.bits.data)
require(fullAEnqData.widthOption.get == dataWidth * 2,
"assumes 2-cycle read for a full compute tile of A")
// only use the lower halfA's tag. substep will be incorrect.
val fullAEnqTag = halfAQueue.io.deq.bits.tag
val fullAQueue = Module(new Queue(
// stage the full A tile once more so that FillBuffer can be filled up in the
// background while the tile is being used for compute. This does come with
// capacity overhead.
val fullABuf = Module(new Queue(
new TensorMemRespWithTag(dataWidth * 2), entries = 1, pipe = true
))
// hold first half A data for the first substep
halfAQueue.io.deq.ready := respQueueA.valid && (substepDeqA === 1.U) &&
fullAQueue.io.enq.ready
fullAQueue.io.enq.valid := respQueueA.valid && (substepDeqA === 1.U) &&
halfAQueue.io.deq.valid
fullAQueue.io.enq.bits.data := fullAEnqData
fullAQueue.io.enq.bits.tag := fullAEnqTag
fullABuf.io.enq.valid := fullA.io.deq.valid
fullABuf.io.enq.bits.data := fullA.io.deq.bits.asUInt
fullABuf.io.enq.bits.tag := fullATag.io.deq.bits
fullA.io.deq.ready := fullABuf.io.enq.ready
fullATag.io.deq.ready := fullABuf.io.enq.ready
// serialize every two B responses into one full 4x4 B tile
// FIXME: do the same for A
@@ -327,29 +322,24 @@ class TensorCoreDecoupled(
fullBTag.io.enq.valid := respQueueB.valid
fullBTag.io.enq.bits := respQueueB.bits.tag
val operandsValid = fullAQueue.io.deq.valid && fullB.io.deq.valid
val operandA = fullAQueue.io.deq.bits.data
val operandATag = fullAQueue.io.deq.bits.tag
val operandsValid = fullABuf.io.deq.valid && fullB.io.deq.valid
val operandA = fullABuf.io.deq.bits.data
val operandATag = fullABuf.io.deq.bits.tag
val operandB = fullB.io.deq.bits
val dpuReady = Wire(Bool())
val dpuFire = operandsValid && dpuReady
val setCompute = fullAQueue.io.deq.bits.tag.set
val stepCompute = fullAQueue.io.deq.bits.tag.step
val setCompute = fullABuf.io.deq.bits.tag.set
val stepCompute = fullABuf.io.deq.bits.tag.step
val substepCompute = RegInit(0.U(1.W))
when (dpuFire) {
substepCompute := substepCompute + 1.U
}
// respQueueA output arbitrates to either halfAQueue or fullAQueue depending
// on the substep
respQueueA.ready := MuxCase(false.B,
Seq((substepDeqA === 0.U) -> halfAQueue.io.enq.ready,
(substepDeqA === 1.U) -> fullAQueue.io.enq.ready))
// Hold B tile at respQueueB for multiple steps for reuse, only dequeue when
// we fully iterated a column (M-dimension).
val shouldDequeueBMask = ((1 << numTilesMBits) - 1).U
val shouldDequeueB =
((stepExecute & shouldDequeueBMask) === shouldDequeueBMask) &&
((stepCompute & shouldDequeueBMask) === shouldDequeueBMask) &&
(substepCompute === 1.U)
fullB.io.deq.ready := dpuFire && shouldDequeueB
fullBTag.io.deq.ready := dpuFire && shouldDequeueB
@@ -358,7 +348,8 @@ class TensorCoreDecoupled(
dontTouch(shouldDequeueB)
// hold full A until two-cycle compute is done
fullAQueue.io.deq.ready := dpuFire && (substepCompute === 1.U)
fullABuf.io.deq.ready := dpuFire && (substepCompute === 1.U)
// FIXME: this should be nextStepCompute
val nextStepExecute = dpuFire && (substepCompute === 1.U)
// Assert that the DPU is computing with operands of the same set/step. Note
@@ -369,8 +360,8 @@ class TensorCoreDecoupled(
def assertAligned = {
val stepMask = (1 << numTilesMBits).U
when (dpuFire) {
assert((fullAQueue.io.deq.bits.tag.set === fullBTag.io.deq.bits.set) &&
((fullAQueue.io.deq.bits.tag.step & stepMask) ===
assert((fullABuf.io.deq.bits.tag.set === fullBTag.io.deq.bits.set) &&
((fullABuf.io.deq.bits.tag.step & stepMask) ===
(fullBTag.io.deq.bits.step & stepMask)),
"A and B operands are pointing to different set/steps. " ++
"This might indicate memory response coming back out-of-order.")