diff --git a/src/main/scala/radiance/core/TensorCoreDecoupled.scala b/src/main/scala/radiance/core/TensorCoreDecoupled.scala index e70e59f..206250e 100644 --- a/src/main/scala/radiance/core/TensorCoreDecoupled.scala +++ b/src/main/scala/radiance/core/TensorCoreDecoupled.scala @@ -313,17 +313,24 @@ class TensorCoreDecoupled( fullAQueue.io.enq.bits.data := fullAEnqData fullAQueue.io.enq.bits.tag := fullAEnqTag - val fillBufB = Module(new FillBuffer( + // serialize every two B responses into one full 4x4 B tile + // FIXME: do the same for A + val fullB = Module(new FillBuffer( chiselTypeOf(respQueueB.bits.data), 2/*substeps*/ )) - fillBufB.io.enq.valid := respQueueB.valid - fillBufB.io.enq.bits := respQueueB.bits.data - respQueueB.ready := fillBufB.io.enq.ready + fullB.io.enq.valid := respQueueB.valid + fullB.io.enq.bits := respQueueB.bits.data + respQueueB.ready := fullB.io.enq.ready + val fullBTag = Module(new Queue( + new TensorMemTag, entries = 1, pipe = true + )) + fullBTag.io.enq.valid := respQueueB.valid + fullBTag.io.enq.bits := respQueueB.bits.tag - val operandsValid = fullAQueue.io.deq.valid && fillBufB.io.deq.valid + val operandsValid = fullAQueue.io.deq.valid && fullB.io.deq.valid val operandA = fullAQueue.io.deq.bits.data val operandATag = fullAQueue.io.deq.bits.tag - val operandB = fillBufB.io.deq.bits + val operandB = fullB.io.deq.bits val dpuReady = Wire(Bool()) val dpuFire = operandsValid && dpuReady val setCompute = fullAQueue.io.deq.bits.tag.set @@ -333,10 +340,6 @@ class TensorCoreDecoupled( substepCompute := substepCompute + 1.U } - // hold full A until two-cycle compute is done - fullAQueue.io.deq.ready := dpuFire && (substepCompute === 1.U) - val nextStepExecute = dpuFire && (substepCompute === 1.U) - // respQueueA output arbitrates to either halfAQueue or fullAQueue depending // on the substep respQueueA.ready := MuxCase(false.B, @@ -345,12 +348,19 @@ class TensorCoreDecoupled( // Hold B tile at respQueueB for multiple steps for reuse, only dequeue when // we fully iterated a column (M-dimension). val shouldDequeueBMask = ((1 << numTilesMBits) - 1).U - val shouldDequeueB = (stepExecute & shouldDequeueBMask) === shouldDequeueBMask - fillBufB.io.deq.ready := dpuFire && shouldDequeueB + val shouldDequeueB = + ((stepExecute & shouldDequeueBMask) === shouldDequeueBMask) && + (substepCompute === 1.U) + fullB.io.deq.ready := dpuFire && shouldDequeueB + fullBTag.io.deq.ready := dpuFire && shouldDequeueB dontTouch(respQueueA) dontTouch(respQueueB) dontTouch(shouldDequeueB) + // hold full A until two-cycle compute is done + fullAQueue.io.deq.ready := dpuFire && (substepCompute === 1.U) + val nextStepExecute = dpuFire && (substepCompute === 1.U) + // Assert that the DPU is computing with operands of the same set/step. Note // that the B resp will only have step values multiple of 4 due to reuse. // @@ -359,9 +369,9 @@ class TensorCoreDecoupled( def assertAligned = { val stepMask = (1 << numTilesMBits).U when (dpuFire) { - assert((fullAQueue.io.deq.bits.tag.set === respQueueB.bits.tag.set) && + assert((fullAQueue.io.deq.bits.tag.set === fullBTag.io.deq.bits.set) && ((fullAQueue.io.deq.bits.tag.step & stepMask) === - (respQueueB.bits.tag.step & stepMask)), + (fullBTag.io.deq.bits.step & stepMask)), "A and B operands are pointing to different set/steps. " ++ "This might indicate memory response coming back out-of-order.") } @@ -387,7 +397,7 @@ class TensorCoreDecoupled( // operandB.asBools.grouped(wordSizeInBits).map(VecInit(_).asUInt).toSeq // .grouped(4/*k-dim*/).toSeq val operandBDimensional = - operandB(0)/*FIXME!*/.asBools.grouped(wordSizeInBits).map(VecInit(_).asUInt).toSeq + operandB(substepCompute).asBools.grouped(wordSizeInBits).map(VecInit(_).asUInt).toSeq .grouped(4/*k-dim*/).toSeq require(tilingParams.mc * ncSubstep == numLanes, "substep tile size doesn't match writeback throughput")