From 93c9bcc32f5b516f3bd51990ff60e22e0348f409 Mon Sep 17 00:00:00 2001 From: Hansung Kim Date: Fri, 18 Oct 2024 20:12:15 -0700 Subject: [PATCH] tensor: Stage B as well for full throughput --- .../radiance/core/TensorCoreDecoupled.scala | 41 ++++++++++++------- 1 file changed, 26 insertions(+), 15 deletions(-) diff --git a/src/main/scala/radiance/core/TensorCoreDecoupled.scala b/src/main/scala/radiance/core/TensorCoreDecoupled.scala index deb4dc1..90cb785 100644 --- a/src/main/scala/radiance/core/TensorCoreDecoupled.scala +++ b/src/main/scala/radiance/core/TensorCoreDecoupled.scala @@ -300,10 +300,13 @@ class TensorCoreDecoupled( // background while the tile is being used for compute. This does come with // capacity overhead. val fullABuf = Module(new Queue( - new TensorMemRespWithTag(dataWidth * 2), entries = 1, pipe = true + new Bundle { + val data = chiselTypeOf(fullA.io.deq.bits) + val tag = new TensorMemTag + }, entries = 1, pipe = true )) fullABuf.io.enq.valid := fullA.io.deq.valid - fullABuf.io.enq.bits.data := fullA.io.deq.bits.asUInt + fullABuf.io.enq.bits.data := fullA.io.deq.bits fullABuf.io.enq.bits.tag := fullATag.io.deq.bits fullA.io.deq.ready := fullABuf.io.enq.ready fullATag.io.deq.ready := fullABuf.io.enq.ready @@ -322,10 +325,22 @@ class TensorCoreDecoupled( fullBTag.io.enq.valid := respQueueB.valid fullBTag.io.enq.bits := respQueueB.bits.tag - val operandsValid = fullABuf.io.deq.valid && fullB.io.deq.valid + val fullBBuf = Module(new Queue( + new Bundle { + val data = chiselTypeOf(fullB.io.deq.bits) + val tag = new TensorMemTag + }, entries = 1, pipe = true + )) + fullBBuf.io.enq.valid := fullB.io.deq.valid + fullBBuf.io.enq.bits.data := fullB.io.deq.bits + fullBBuf.io.enq.bits.tag := fullBTag.io.deq.bits + fullB.io.deq.ready := fullBBuf.io.enq.ready + fullBTag.io.deq.ready := fullBBuf.io.enq.ready + + val operandsValid = fullABuf.io.deq.valid && fullBBuf.io.deq.valid val operandA = fullABuf.io.deq.bits.data val operandATag = fullABuf.io.deq.bits.tag - val operandB = fullB.io.deq.bits + val operandB = fullBBuf.io.deq.bits.data val dpuReady = Wire(Bool()) val dpuFire = operandsValid && dpuReady val setCompute = fullABuf.io.deq.bits.tag.set @@ -335,20 +350,19 @@ class TensorCoreDecoupled( substepCompute := substepCompute + 1.U } + // hold full A until two-cycle compute is done + fullABuf.io.deq.ready := dpuFire && (substepCompute === 1.U) // Hold B tile at respQueueB for multiple steps for reuse, only dequeue when // we fully iterated a column (M-dimension). val shouldDequeueBMask = ((1 << numTilesMBits) - 1).U val shouldDequeueB = ((stepCompute & shouldDequeueBMask) === shouldDequeueBMask) && (substepCompute === 1.U) - fullB.io.deq.ready := dpuFire && shouldDequeueB - fullBTag.io.deq.ready := dpuFire && shouldDequeueB + fullBBuf.io.deq.ready := dpuFire && shouldDequeueB dontTouch(respQueueA) dontTouch(respQueueB) dontTouch(shouldDequeueB) - // hold full A until two-cycle compute is done - fullABuf.io.deq.ready := dpuFire && (substepCompute === 1.U) // FIXME: this should be nextStepCompute val nextStepExecute = dpuFire && (substepCompute === 1.U) @@ -360,9 +374,9 @@ class TensorCoreDecoupled( def assertAligned = { val stepMask = (1 << numTilesMBits).U when (dpuFire) { - assert((fullABuf.io.deq.bits.tag.set === fullBTag.io.deq.bits.set) && + assert((fullABuf.io.deq.bits.tag.set === fullBBuf.io.deq.bits.tag.set) && ((fullABuf.io.deq.bits.tag.step & stepMask) === - (fullBTag.io.deq.bits.step & stepMask)), + (fullBBuf.io.deq.bits.tag.step & stepMask)), "A and B operands are pointing to different set/steps. " ++ "This might indicate memory response coming back out-of-order.") } @@ -378,15 +392,12 @@ class TensorCoreDecoupled( )) // operandA is 4x4 in K-major val operandADimensional = - operandA.asBools.grouped(wordSizeInBits).map(VecInit(_).asUInt).toSeq + operandA.asUInt.asBools.grouped(wordSizeInBits).map(VecInit(_).asUInt).toSeq .grouped(4/*k-dim*/).toSeq require(operandADimensional.length == tilingParams.mc && operandADimensional(0).length == tilingParams.kc, "operand width doesn't agree with tiling parameter") - // operandB is 2x4 in K-major - // val operandBDimensional = - // operandB.asBools.grouped(wordSizeInBits).map(VecInit(_).asUInt).toSeq - // .grouped(4/*k-dim*/).toSeq + // select 2x4 subtile out of operandB that is 4x4 in K-major val operandBDimensional = operandB(substepCompute).asBools.grouped(wordSizeInBits).map(VecInit(_).asUInt).toSeq .grouped(4/*k-dim*/).toSeq