tensor: Fix operand alignment in pipelining
This commit is contained in:
@@ -224,37 +224,51 @@ class TensorCoreDecoupled(
|
||||
}
|
||||
dontTouch(substepExecute)
|
||||
|
||||
// Do pipelining for the A operand so that we obtain the full 4x4 A tile
|
||||
// ready for compute. The pipeline is two-stage:
|
||||
// - stage one (halfAQueue) for assembling the full A tile from half-tiles
|
||||
// coming from the resp queue, and
|
||||
// - stage two (fullAQueue) for holding the full A tile until it gets
|
||||
// matched with two 4x2 B tiles, and compute is complete.
|
||||
//
|
||||
// Note that the half-tile assembly is unnecessary for B since the B tile is
|
||||
// only 4x2.
|
||||
// Also send the set/step tag along the pipe for alignment check.
|
||||
|
||||
// Note: with `pipe = true`, enq.ready is combinationally coupled to deq.ready.
|
||||
val halfAQueue = Module(new Queue(
|
||||
chiselTypeOf(respQueueA.bits.data), entries = 1, pipe = true
|
||||
chiselTypeOf(respQueueA.bits), entries = 1, pipe = true
|
||||
))
|
||||
halfAQueue.io.enq.valid := respQueueA.valid && (substepExecute === 0.U)
|
||||
halfAQueue.io.enq.bits := respQueueA.bits.data
|
||||
halfAQueue.io.enq.bits := respQueueA.bits
|
||||
|
||||
// we need the full data for A because we divide the D tile by half along N;
|
||||
// for B, the DPU can immediately start computing with a 4x2 tile.
|
||||
//
|
||||
// substep == 0 data goes to the LSB
|
||||
val fullAEnqData = Cat(respQueueA.bits.data, halfAQueue.io.deq.bits)
|
||||
val fullAEnqData = Cat(respQueueA.bits.data, halfAQueue.io.deq.bits.data)
|
||||
require(fullAEnqData.widthOption.get == dataWidth * 2,
|
||||
"assumes 2-cycle read for a full compute tile of A")
|
||||
// Only the lower (substep 0) half-tile's tag is forwarded with the full A tile, so the tag's substep field is not meaningful downstream.
|
||||
val fullAEnqTag = halfAQueue.io.deq.bits.tag
|
||||
val fullAQueue = Module(new Queue(
|
||||
chiselTypeOf(fullAEnqData), entries = 1, pipe = true
|
||||
new TensorMemRespWithTag(dataWidth * 2), entries = 1, pipe = true
|
||||
))
|
||||
// hold first half A data for the first substep
|
||||
halfAQueue.io.deq.ready := respQueueA.valid && (substepExecute === 1.U) &&
|
||||
fullAQueue.io.enq.ready
|
||||
|
||||
require(fullAEnqData.widthOption.get == dataWidth * 2,
|
||||
"assumes 2-cycle read for a full compute tile of A")
|
||||
fullAQueue.io.enq.valid := respQueueA.valid && (substepExecute === 1.U) &&
|
||||
halfAQueue.io.deq.valid
|
||||
fullAQueue.io.enq.bits := fullAEnqData
|
||||
fullAQueue.io.enq.bits.data := fullAEnqData
|
||||
fullAQueue.io.enq.bits.tag := fullAEnqTag
|
||||
|
||||
val operandsValid = fullAQueue.io.deq.valid && respQueueB.valid // FIXME?
|
||||
val dpuFire = operandsValid && dpuReady
|
||||
fullAQueue.io.deq.ready := dpuFire
|
||||
val nextStepExecute = dpuFire
|
||||
val substepCompute = RegInit(0.U(1.W))
|
||||
when (dpuFire) {
|
||||
substepCompute := substepCompute + 1.U
|
||||
}
|
||||
|
||||
// FIXME: need to hold A for two cycles!!
|
||||
// hold full A until two-cycle compute is done
|
||||
fullAQueue.io.deq.ready := dpuFire && (substepCompute === 1.U)
|
||||
val nextStepExecute = dpuFire && (substepCompute === 1.U)
|
||||
|
||||
// make sure to dequeue from response queues only when both A and B valid
|
||||
respQueueA.ready := MuxCase(false.B,
|
||||
@@ -264,21 +278,17 @@ class TensorCoreDecoupled(
|
||||
dontTouch(respQueueA)
|
||||
dontTouch(respQueueB)
|
||||
|
||||
// assert that the A and B response queue heads always point to the same
|
||||
// set/step/substep
|
||||
// assert that the DPU is computing with operands of the same set/step
|
||||
//
|
||||
// this assumes that memory responses come back in-order. this might be too
|
||||
// strong an assumption depending on the backing memory
|
||||
def assertAligned = {
|
||||
val bothQueueValid = (respQueueA.valid && respQueueB.valid)
|
||||
when (bothQueueValid && (substepExecute === 0.U)) {
|
||||
assert((respQueueA.bits.tag.set === respQueueB.bits.tag.set) &&
|
||||
(respQueueA.bits.tag.step === respQueueB.bits.tag.step),
|
||||
"A and B response queue pointing to different set/steps. " ++
|
||||
when (dpuFire) {
|
||||
assert((fullAQueue.io.deq.bits.tag.set === respQueueB.bits.tag.set) &&
|
||||
(fullAQueue.io.deq.bits.tag.step === respQueueB.bits.tag.step),
|
||||
"A and B operands are pointing to different set/steps. " ++
|
||||
"This might indicate memory response coming back out-of-order.")
|
||||
}
|
||||
dontTouch(respQueueA.bits.tag)
|
||||
dontTouch(respQueueB.bits.tag)
|
||||
}
|
||||
assertAligned
|
||||
|
||||
|
||||
Reference in New Issue
Block a user