tensor: Replace staging logic for A with FillBuffer
This commit is contained in:
@@ -272,46 +272,41 @@ class TensorCoreDecoupled(
|
||||
io.writeback.bits.data.widthOption.get,
|
||||
"response data width does not match the writeback data width")
|
||||
|
||||
// FIXME: unnecessary
|
||||
val substepDeqA = RegInit(0.U(1.W))
|
||||
when (respQueueA.fire) {
|
||||
substepDeqA := substepDeqA + 1.U
|
||||
}
|
||||
dontTouch(substepDeqA)
|
||||
|
||||
// Do pipelining for the A operand so that we obtain the full 4x4 A tile
|
||||
// ready for compute. The pipeline is two-stage:
|
||||
// - stage one (halfAQueue) for assembling the full A tile from half-tiles
|
||||
// coming from the resp queue, and
|
||||
// - stage two (fullAQueue) for holding the full A tile until it gets
|
||||
// matched with two 4x2 B tiles, and compute is complete.
|
||||
//
|
||||
// Note that the half-tile assembly is unnecessary for B since the B tile is
|
||||
// only 4x2.
|
||||
// Also send the set/step tag along the pipe for alignment check.
|
||||
// Stage the operands in a pipeline so that we obtain the full 4x4 tiles
|
||||
// ready for compute. Also send the set/step tag along the pipe for
|
||||
// alignment check.
|
||||
|
||||
// note combinationally coupled ready with `pipe`
|
||||
val halfAQueue = Module(new Queue(
|
||||
chiselTypeOf(respQueueA.bits), entries = 1, pipe = true
|
||||
val fullA = Module(new FillBuffer(
|
||||
chiselTypeOf(respQueueB.bits.data), 2/*substeps*/
|
||||
))
|
||||
halfAQueue.io.enq.valid := respQueueA.valid && (substepDeqA === 0.U)
|
||||
halfAQueue.io.enq.bits := respQueueA.bits
|
||||
fullA.io.enq.valid := respQueueA.valid
|
||||
fullA.io.enq.bits := respQueueA.bits.data
|
||||
respQueueA.ready := fullA.io.enq.ready
|
||||
// `pipe` combinationally couples enq-deq ready
|
||||
val fullATag = Module(new Queue(
|
||||
new TensorMemTag, entries = 1, pipe = true
|
||||
))
|
||||
fullATag.io.enq.valid := respQueueA.valid
|
||||
fullATag.io.enq.bits := respQueueA.bits.tag
|
||||
|
||||
// substep == 0 data goes to the LSB
|
||||
val fullAEnqData = Cat(respQueueA.bits.data, halfAQueue.io.deq.bits.data)
|
||||
require(fullAEnqData.widthOption.get == dataWidth * 2,
|
||||
"assumes 2-cycle read for a full compute tile of A")
|
||||
// only use the lower halfA's tag. substep will be incorrect.
|
||||
val fullAEnqTag = halfAQueue.io.deq.bits.tag
|
||||
val fullAQueue = Module(new Queue(
|
||||
// stage the full A tile once more so that FillBuffer can be filled up in the
|
||||
// background while the tile is being used for compute. This does come with
|
||||
// capacity overhead.
|
||||
val fullABuf = Module(new Queue(
|
||||
new TensorMemRespWithTag(dataWidth * 2), entries = 1, pipe = true
|
||||
))
|
||||
// hold first half A data for the first substep
|
||||
halfAQueue.io.deq.ready := respQueueA.valid && (substepDeqA === 1.U) &&
|
||||
fullAQueue.io.enq.ready
|
||||
fullAQueue.io.enq.valid := respQueueA.valid && (substepDeqA === 1.U) &&
|
||||
halfAQueue.io.deq.valid
|
||||
fullAQueue.io.enq.bits.data := fullAEnqData
|
||||
fullAQueue.io.enq.bits.tag := fullAEnqTag
|
||||
fullABuf.io.enq.valid := fullA.io.deq.valid
|
||||
fullABuf.io.enq.bits.data := fullA.io.deq.bits.asUInt
|
||||
fullABuf.io.enq.bits.tag := fullATag.io.deq.bits
|
||||
fullA.io.deq.ready := fullABuf.io.enq.ready
|
||||
fullATag.io.deq.ready := fullABuf.io.enq.ready
|
||||
|
||||
// serialize every two B responses into one full 4x4 B tile
|
||||
// FIXME: do the same for A
|
||||
@@ -327,29 +322,24 @@ class TensorCoreDecoupled(
|
||||
fullBTag.io.enq.valid := respQueueB.valid
|
||||
fullBTag.io.enq.bits := respQueueB.bits.tag
|
||||
|
||||
val operandsValid = fullAQueue.io.deq.valid && fullB.io.deq.valid
|
||||
val operandA = fullAQueue.io.deq.bits.data
|
||||
val operandATag = fullAQueue.io.deq.bits.tag
|
||||
val operandsValid = fullABuf.io.deq.valid && fullB.io.deq.valid
|
||||
val operandA = fullABuf.io.deq.bits.data
|
||||
val operandATag = fullABuf.io.deq.bits.tag
|
||||
val operandB = fullB.io.deq.bits
|
||||
val dpuReady = Wire(Bool())
|
||||
val dpuFire = operandsValid && dpuReady
|
||||
val setCompute = fullAQueue.io.deq.bits.tag.set
|
||||
val stepCompute = fullAQueue.io.deq.bits.tag.step
|
||||
val setCompute = fullABuf.io.deq.bits.tag.set
|
||||
val stepCompute = fullABuf.io.deq.bits.tag.step
|
||||
val substepCompute = RegInit(0.U(1.W))
|
||||
when (dpuFire) {
|
||||
substepCompute := substepCompute + 1.U
|
||||
}
|
||||
|
||||
// respQueueA output arbitrates to either halfAQueue or fullAQueue depending
|
||||
// on the substep
|
||||
respQueueA.ready := MuxCase(false.B,
|
||||
Seq((substepDeqA === 0.U) -> halfAQueue.io.enq.ready,
|
||||
(substepDeqA === 1.U) -> fullAQueue.io.enq.ready))
|
||||
// Hold B tile at respQueueB for multiple steps for reuse, only dequeue when
|
||||
// we fully iterated a column (M-dimension).
|
||||
val shouldDequeueBMask = ((1 << numTilesMBits) - 1).U
|
||||
val shouldDequeueB =
|
||||
((stepExecute & shouldDequeueBMask) === shouldDequeueBMask) &&
|
||||
((stepCompute & shouldDequeueBMask) === shouldDequeueBMask) &&
|
||||
(substepCompute === 1.U)
|
||||
fullB.io.deq.ready := dpuFire && shouldDequeueB
|
||||
fullBTag.io.deq.ready := dpuFire && shouldDequeueB
|
||||
@@ -358,7 +348,8 @@ class TensorCoreDecoupled(
|
||||
dontTouch(shouldDequeueB)
|
||||
|
||||
// hold full A until two-cycle compute is done
|
||||
fullAQueue.io.deq.ready := dpuFire && (substepCompute === 1.U)
|
||||
fullABuf.io.deq.ready := dpuFire && (substepCompute === 1.U)
|
||||
// FIXME: this should be nextStepCompute
|
||||
val nextStepExecute = dpuFire && (substepCompute === 1.U)
|
||||
|
||||
// Assert that the DPU is computing with operands of the same set/step. Note
|
||||
@@ -369,8 +360,8 @@ class TensorCoreDecoupled(
|
||||
def assertAligned = {
|
||||
val stepMask = (1 << numTilesMBits).U
|
||||
when (dpuFire) {
|
||||
assert((fullAQueue.io.deq.bits.tag.set === fullBTag.io.deq.bits.set) &&
|
||||
((fullAQueue.io.deq.bits.tag.step & stepMask) ===
|
||||
assert((fullABuf.io.deq.bits.tag.set === fullBTag.io.deq.bits.set) &&
|
||||
((fullABuf.io.deq.bits.tag.step & stepMask) ===
|
||||
(fullBTag.io.deq.bits.step & stepMask)),
|
||||
"A and B operands are pointing to different set/steps. " ++
|
||||
"This might indicate memory response coming back out-of-order.")
|
||||
|
||||
Reference in New Issue
Block a user