tensor: Fix operand alignment in pipelining

This commit is contained in:
Hansung Kim
2024-10-16 22:01:02 -07:00
parent 77dae3e1f9
commit 6cad8edd18

View File

@@ -224,37 +224,51 @@ class TensorCoreDecoupled(
}
dontTouch(substepExecute)
// Do pipelining for the A operand so that we obtain the full 4x4 A tile
// ready for compute. The pipeline is two-stage:
// - stage one (halfAQueue) for assembling the full A tile from half-tiles
// coming from the resp queue, and
// - stage two (fullAQueue) for holding the full A tile until it gets
// matched with two 4x2 B tiles, and compute is complete.
//
// Note that the half-tile assembly is unnecessary for B since the B tile is
// only 4x2.
// Also send the set/step tag along the pipe for alignment check.
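// Note: with `pipe = true`, the queue's enq.ready is combinationally coupled to
// its deq.ready, so a full single-entry queue can accept a new entry in the
// same cycle its current entry is dequeued.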
val halfAQueue = Module(new Queue(
chiselTypeOf(respQueueA.bits.data), entries = 1, pipe = true
chiselTypeOf(respQueueA.bits), entries = 1, pipe = true
))
halfAQueue.io.enq.valid := respQueueA.valid && (substepExecute === 0.U)
halfAQueue.io.enq.bits := respQueueA.bits.data
halfAQueue.io.enq.bits := respQueueA.bits
// we need the full data for A because we divide the D tile by half along N;
// for B, the DPU can immediately start computing with a 4x2 tile.
//
// substep == 0 data goes to the LSB
val fullAEnqData = Cat(respQueueA.bits.data, halfAQueue.io.deq.bits)
val fullAEnqData = Cat(respQueueA.bits.data, halfAQueue.io.deq.bits.data)
require(fullAEnqData.widthOption.get == dataWidth * 2,
"assumes 2-cycle read for a full compute tile of A")
// Only the lower half-tile's tag is forwarded with the full tile. Its substep
// field does not describe the assembled tile, so only set/step are meaningful.
val fullAEnqTag = halfAQueue.io.deq.bits.tag
val fullAQueue = Module(new Queue(
chiselTypeOf(fullAEnqData), entries = 1, pipe = true
new TensorMemRespWithTag(dataWidth * 2), entries = 1, pipe = true
))
// hold first half A data for the first substep
halfAQueue.io.deq.ready := respQueueA.valid && (substepExecute === 1.U) &&
fullAQueue.io.enq.ready
require(fullAEnqData.widthOption.get == dataWidth * 2,
"assumes 2-cycle read for a full compute tile of A")
fullAQueue.io.enq.valid := respQueueA.valid && (substepExecute === 1.U) &&
halfAQueue.io.deq.valid
fullAQueue.io.enq.bits := fullAEnqData
fullAQueue.io.enq.bits.data := fullAEnqData
fullAQueue.io.enq.bits.tag := fullAEnqTag
val operandsValid = fullAQueue.io.deq.valid && respQueueB.valid // FIXME?
val dpuFire = operandsValid && dpuReady
fullAQueue.io.deq.ready := dpuFire
val nextStepExecute = dpuFire
val substepCompute = RegInit(0.U(1.W))
when (dpuFire) {
substepCompute := substepCompute + 1.U
}
// FIXME: need to hold A for two cycles!!
// hold full A until two-cycle compute is done
fullAQueue.io.deq.ready := dpuFire && (substepCompute === 1.U)
val nextStepExecute = dpuFire && (substepCompute === 1.U)
// Make sure to dequeue from the response queues only when both A and B are valid.
respQueueA.ready := MuxCase(false.B,
@@ -264,21 +278,17 @@ class TensorCoreDecoupled(
dontTouch(respQueueA)
dontTouch(respQueueB)
// assert that the A and B response queue heads always point to the same
// set/step/substep
// assert that the DPU is computing with operands of the same set/step
//
// This assumes that memory responses come back in order, which might be too
// strong an assumption depending on the backing memory.
def assertAligned = {
val bothQueueValid = (respQueueA.valid && respQueueB.valid)
when (bothQueueValid && (substepExecute === 0.U)) {
assert((respQueueA.bits.tag.set === respQueueB.bits.tag.set) &&
(respQueueA.bits.tag.step === respQueueB.bits.tag.step),
"A and B response queue pointing to different set/steps. " ++
when (dpuFire) {
assert((fullAQueue.io.deq.bits.tag.set === respQueueB.bits.tag.set) &&
(fullAQueue.io.deq.bits.tag.step === respQueueB.bits.tag.step),
"A and B operands are pointing to different set/steps. " ++
"This might indicate memory response coming back out-of-order.")
}
dontTouch(respQueueA.bits.tag)
dontTouch(respQueueB.bits.tag)
}
assertAligned