From c4b5a11fdefbbfbe73b765bb1feece25d2a1d3f1 Mon Sep 17 00:00:00 2001 From: Hansung Kim Date: Fri, 18 Oct 2024 19:54:20 -0700 Subject: [PATCH] tensor: Replace staging logic for A with FillBuffer --- .../radiance/core/TensorCoreDecoupled.scala | 77 ++++++++----------- 1 file changed, 34 insertions(+), 43 deletions(-) diff --git a/src/main/scala/radiance/core/TensorCoreDecoupled.scala b/src/main/scala/radiance/core/TensorCoreDecoupled.scala index 206250e..deb4dc1 100644 --- a/src/main/scala/radiance/core/TensorCoreDecoupled.scala +++ b/src/main/scala/radiance/core/TensorCoreDecoupled.scala @@ -272,46 +272,41 @@ class TensorCoreDecoupled( io.writeback.bits.data.widthOption.get, "response data width does not match the writeback data width") + // FIXME: substepDeqA is unnecessary once halfAQueue is gone; remove val substepDeqA = RegInit(0.U(1.W)) when (respQueueA.fire) { substepDeqA := substepDeqA + 1.U } dontTouch(substepDeqA) - // Do pipelining for the A operand so that we obtain the full 4x4 A tile - // ready for compute. The pipeline is two-stage: - // - stage one (halfAQueue) for assembling the full A tile from half-tiles - // coming from the resp queue, and - // - stage two (fullAQueue) for holding the full A tile until it gets - // matched with two 4x2 B tiles, and compute is complete. - // - // Note that the half-tile assembly is unnecessary for B since the B tile is - // only 4x2. - // Also send the set/step tag along the pipe for alignment check. + // Stage the operands in a pipeline so that we obtain the full 4x4 tiles + // ready for compute. Also send the set/step tag along the pipe for + // alignment check. 
- // note combinationally coupled ready with `pipe` - val halfAQueue = Module(new Queue( - chiselTypeOf(respQueueA.bits), entries = 1, pipe = true + val fullA = Module(new FillBuffer( + chiselTypeOf(respQueueB.bits.data), 2/*substeps*/ )) - halfAQueue.io.enq.valid := respQueueA.valid && (substepDeqA === 0.U) - halfAQueue.io.enq.bits := respQueueA.bits + fullA.io.enq.valid := respQueueA.valid + fullA.io.enq.bits := respQueueA.bits.data + respQueueA.ready := fullA.io.enq.ready + // `pipe` combinationally couples enq-deq ready + val fullATag = Module(new Queue( + new TensorMemTag, entries = 1, pipe = true + )) + fullATag.io.enq.valid := respQueueA.valid + fullATag.io.enq.bits := respQueueA.bits.tag - // substep == 0 data goes to the LSB - val fullAEnqData = Cat(respQueueA.bits.data, halfAQueue.io.deq.bits.data) - require(fullAEnqData.widthOption.get == dataWidth * 2, - "assumes 2-cycle read for a full compute tile of A") - // only use the lower halfA's tag. substep will be incorrect. - val fullAEnqTag = halfAQueue.io.deq.bits.tag - val fullAQueue = Module(new Queue( + // stage the full A tile once more so that FillBuffer can be filled up in the + // background while the tile is being used for compute. This does come with + // capacity overhead. 
+ val fullABuf = Module(new Queue( new TensorMemRespWithTag(dataWidth * 2), entries = 1, pipe = true )) - // hold first half A data for the first substep - halfAQueue.io.deq.ready := respQueueA.valid && (substepDeqA === 1.U) && - fullAQueue.io.enq.ready - fullAQueue.io.enq.valid := respQueueA.valid && (substepDeqA === 1.U) && - halfAQueue.io.deq.valid - fullAQueue.io.enq.bits.data := fullAEnqData - fullAQueue.io.enq.bits.tag := fullAEnqTag + fullABuf.io.enq.valid := fullA.io.deq.valid + fullABuf.io.enq.bits.data := fullA.io.deq.bits.asUInt + fullABuf.io.enq.bits.tag := fullATag.io.deq.bits + fullA.io.deq.ready := fullABuf.io.enq.ready + fullATag.io.deq.ready := fullABuf.io.enq.ready // serialize every two B responses into one full 4x4 B tile // FIXME: do the same for A @@ -327,29 +322,24 @@ class TensorCoreDecoupled( fullBTag.io.enq.valid := respQueueB.valid fullBTag.io.enq.bits := respQueueB.bits.tag - val operandsValid = fullAQueue.io.deq.valid && fullB.io.deq.valid - val operandA = fullAQueue.io.deq.bits.data - val operandATag = fullAQueue.io.deq.bits.tag + val operandsValid = fullABuf.io.deq.valid && fullB.io.deq.valid + val operandA = fullABuf.io.deq.bits.data + val operandATag = fullABuf.io.deq.bits.tag val operandB = fullB.io.deq.bits val dpuReady = Wire(Bool()) val dpuFire = operandsValid && dpuReady - val setCompute = fullAQueue.io.deq.bits.tag.set - val stepCompute = fullAQueue.io.deq.bits.tag.step + val setCompute = fullABuf.io.deq.bits.tag.set + val stepCompute = fullABuf.io.deq.bits.tag.step val substepCompute = RegInit(0.U(1.W)) when (dpuFire) { substepCompute := substepCompute + 1.U } - // respQueueA output arbitrates to either halfAQueue or fullAQueue depending - // on the substep - respQueueA.ready := MuxCase(false.B, - Seq((substepDeqA === 0.U) -> halfAQueue.io.enq.ready, - (substepDeqA === 1.U) -> fullAQueue.io.enq.ready)) // Hold B tile at respQueueB for multiple steps for reuse, only dequeue when // we fully iterated a column (M-dimension). 
val shouldDequeueBMask = ((1 << numTilesMBits) - 1).U val shouldDequeueB = - ((stepExecute & shouldDequeueBMask) === shouldDequeueBMask) && + ((stepCompute & shouldDequeueBMask) === shouldDequeueBMask) && (substepCompute === 1.U) fullB.io.deq.ready := dpuFire && shouldDequeueB fullBTag.io.deq.ready := dpuFire && shouldDequeueB @@ -358,7 +348,8 @@ class TensorCoreDecoupled( dontTouch(shouldDequeueB) // hold full A until two-cycle compute is done - fullAQueue.io.deq.ready := dpuFire && (substepCompute === 1.U) + fullABuf.io.deq.ready := dpuFire && (substepCompute === 1.U) + // FIXME: rename nextStepExecute to nextStepCompute, matching stepCompute val nextStepExecute = dpuFire && (substepCompute === 1.U) // Assert that the DPU is computing with operands of the same set/step. Note @@ -369,8 +360,8 @@ class TensorCoreDecoupled( def assertAligned = { val stepMask = (1 << numTilesMBits).U when (dpuFire) { - assert((fullAQueue.io.deq.bits.tag.set === fullBTag.io.deq.bits.set) && - ((fullAQueue.io.deq.bits.tag.step & stepMask) === + assert((fullABuf.io.deq.bits.tag.set === fullBTag.io.deq.bits.set) && + ((fullABuf.io.deq.bits.tag.step & stepMask) === (fullBTag.io.deq.bits.step & stepMask)), "A and B operands are pointing to different set/steps. " ++ "This might indicate memory response coming back out-of-order.")