From c0292dd0aa97a8cec3d034b20e2a167f78e54af8 Mon Sep 17 00:00:00 2001 From: Hansung Kim Date: Fri, 18 Oct 2024 21:51:34 -0700 Subject: [PATCH] tensor: Enlarge operand buffer for A for better SMEM reuse --- .../radiance/core/TensorCoreDecoupled.scala | 159 +++++++++++------- 1 file changed, 100 insertions(+), 59 deletions(-) diff --git a/src/main/scala/radiance/core/TensorCoreDecoupled.scala b/src/main/scala/radiance/core/TensorCoreDecoupled.scala index 90cb785..fa3f6e9 100644 --- a/src/main/scala/radiance/core/TensorCoreDecoupled.scala +++ b/src/main/scala/radiance/core/TensorCoreDecoupled.scala @@ -146,18 +146,6 @@ class TensorCoreDecoupled( // Memory traffic generation // ------------------------- // - class TensorMemTag extends Bundle { - val set = UInt(setBits.W) - val step = UInt(stepBits.W) - val substep = UInt(1.W) - } - // use concatenation of set/step as the memory request source. This will get - // translated to the actual TL sourcewidth in sourceGen. - val tag = Wire(new TensorMemTag) - tag.set := setAccess - tag.step := stepAccess - tag.substep := substepAccess - val numTilesM = tilingParams.m / tilingParams.mc val numTilesN = tilingParams.n / tilingParams.nc // @cleanup: generalize in terms of M/N/K-majorness? @@ -198,12 +186,41 @@ class TensorCoreDecoupled( val (addressA, addressB) = addressGen(0.U, 0.U, setAccess, stepAccess, substepAccess) + // 'index' is the index of a memory request among the sequence of requests + // needed to read a full M-column of A or N-row of B. Its range is [0,m/2) + // or [0,n/2), where 2 is the stride can be read in a single request size. + require(tilingParams.m == tilingParams.n, + "currently only supports square SMEM tile") + val numIndices = tilingParams.m / 2 + val indexBits = log2Ceil(numIndices) + val lastIndex = (1 << indexBits) - 1 + + class TensorMemTag extends Bundle { + val set = UInt(setBits.W) + val index = UInt(indexBits.W) + } + + val tagInit = Wire(new TensorMemTag) + tagInit.set := 0.U + tagInit.index := 0.U + val tagA = RegInit(tagInit) + val tagB = RegInit(tagInit) + + when (io.reqA.fire) { + when (tagA.index === lastIndex.U) { + tagA.set := tagA.set + 1.U + } + tagA.index := tagA.index + 1.U + } + when (io.reqB.fire) { + when (tagB.index === lastIndex.U) { + tagB.set := tagB.set + 1.U + } + tagB.index := tagB.index + 1.U + } + val genReqA = (state === TensorState.run) - val numTilesMBits = log2Ceil(numTilesM) - // generate B request at every 4 steps. B achieves reuse through outer - // product so it doesn't require access at every step - val shouldFireB = (stepAccess & ((1 << numTilesMBits) - 1).U) === 0.U - val genReqB = (state === TensorState.run) && shouldFireB + val genReqB = (state === TensorState.run) val respATagged = Wire(Decoupled(new TensorMemRespWithTag(dataWidth))) val respBTagged = Wire(Decoupled(new TensorMemRespWithTag(dataWidth))) @@ -212,11 +229,11 @@ class TensorCoreDecoupled( case ((req, (resp, respTagged)), i) => { val sourceGen = Module(new SourceGenerator( log2Ceil(numSourceIds), - metadata = Some(tag) + metadata = Some(new TensorMemTag) )) sourceGen.io.gen := req.fire - sourceGen.io.meta := tag + sourceGen.io.meta := (if (i == 0) tagA else tagB) req.valid := (if (i == 0) genReqA else genReqB) req.bits.address := (if (i == 0) addressA else addressB) req.bits.source := sourceGen.io.id.bits @@ -243,7 +260,7 @@ class TensorCoreDecoupled( val firedBNow = io.reqB.fire val firedA = firedAReg || firedANow val firedB = firedBReg || firedBNow - val nextSubstepAccess = firedA && (!shouldFireB || firedB) + val nextSubstepAccess = firedA && firedB val nextStepAccess = nextSubstepAccess && (substepAccess === 1.U) // clear out firedABReg every substep when (nextSubstepAccess) { @@ -252,17 +269,12 @@ class TensorCoreDecoupled( substepAccess := substepAccess + 1.U } require(substepAccess.widthOption.get == 1, "there should be only two substeps") - dontTouch(shouldFireB) // Execute stage // ------------- // Backend of the decoupled access/execute pipeline. // // set and step being currently executed in the acc/ex backend - val setExecute = RegInit(0.U(setBits.W)) - val stepExecute = RegInit(0.U(stepBits.W)) - dontTouch(setExecute) - dontTouch(stepExecute) val respQueueDepth = 4 // FIXME: parameterize val respQueueA = Queue(respATagged, respQueueDepth) @@ -283,8 +295,10 @@ class TensorCoreDecoupled( // ready for compute. Also send the set/step tag along the pipe for // alignment check. + // @cleanup: dedup A and B below + val fullA = Module(new FillBuffer( - chiselTypeOf(respQueueB.bits.data), 2/*substeps*/ + chiselTypeOf(respQueueB.bits.data), numIndices )) fullA.io.enq.valid := respQueueA.valid fullA.io.enq.bits := respQueueA.bits.data @@ -337,23 +351,48 @@ class TensorCoreDecoupled( fullB.io.deq.ready := fullBBuf.io.enq.ready fullBTag.io.deq.ready := fullBBuf.io.enq.ready - val operandsValid = fullABuf.io.deq.valid && fullBBuf.io.deq.valid - val operandA = fullABuf.io.deq.bits.data - val operandATag = fullABuf.io.deq.bits.tag - val operandB = fullBBuf.io.deq.bits.data val dpuReady = Wire(Bool()) + val operandsValid = fullABuf.io.deq.valid && fullBBuf.io.deq.valid val dpuFire = operandsValid && dpuReady - val setCompute = fullABuf.io.deq.bits.tag.set - val stepCompute = fullABuf.io.deq.bits.tag.step + + val setCompute = RegInit(0.U(setBits.W)) + val stepCompute = RegInit(0.U(stepBits.W)) val substepCompute = RegInit(0.U(1.W)) + val nextStepCompute = dpuFire && (substepCompute === 1.U) + dontTouch(setCompute) + dontTouch(stepCompute) + dontTouch(substepCompute) when (dpuFire) { substepCompute := substepCompute + 1.U } - // hold full A until two-cycle compute is done - fullABuf.io.deq.ready := dpuFire && (substepCompute === 1.U) - // Hold B tile at respQueueB for multiple steps for reuse, only dequeue when - // we fully iterated a column (M-dimension). + // Operand selection + // + // select the correct 4x4 tile from A operand buffer + val numTilesMBits = log2Ceil(numTilesM) + def selectOperandA(buf: Vec[UInt]): UInt = { + require(buf.length == numIndices) + val stepM = stepCompute & ((1 << numTilesMBits) - 1).U + Cat(buf((stepM << 1) + 1.U), buf(stepM << 1)) + } + val operandA = selectOperandA(fullABuf.io.deq.bits.data) + val operandATag = fullABuf.io.deq.bits.tag + // select the correct 2x4 tile from B operand buffer + val operandB = fullBBuf.io.deq.bits.data(substepCompute) + val operandBTag = fullBBuf.io.deq.bits.tag + dontTouch(operandATag) + dontTouch(operandBTag) + + // Operand buffer dequeue logic + // + // hold A data until the entire set is done + val shouldDequeueAMask = ((1 << stepBits) - 1).U + val shouldDequeueA = + ((stepCompute & shouldDequeueAMask) === shouldDequeueAMask) && + (substepCompute === 1.U) + fullABuf.io.deq.ready := dpuFire && shouldDequeueA + // hold B tile at respQueueB for multiple steps for reuse, only dequeue when + // we fully iterated a column (M-dimension) val shouldDequeueBMask = ((1 << numTilesMBits) - 1).U val shouldDequeueB = ((stepCompute & shouldDequeueBMask) === shouldDequeueBMask) && @@ -361,11 +400,9 @@ class TensorCoreDecoupled( fullBBuf.io.deq.ready := dpuFire && shouldDequeueB dontTouch(respQueueA) dontTouch(respQueueB) + dontTouch(shouldDequeueA) dontTouch(shouldDequeueB) - // FIXME: this should be nextStepCompute - val nextStepExecute = dpuFire && (substepCompute === 1.U) - // Assert that the DPU is computing with operands of the same set/step. Note // that the B resp will only have step values multiple of 4 due to reuse. // @@ -374,11 +411,9 @@ class TensorCoreDecoupled( def assertAligned = { val stepMask = (1 << numTilesMBits).U when (dpuFire) { - assert((fullABuf.io.deq.bits.tag.set === fullBBuf.io.deq.bits.tag.set) && - ((fullABuf.io.deq.bits.tag.step & stepMask) === - (fullBBuf.io.deq.bits.tag.step & stepMask)), - "A and B operands are pointing to different set/steps. " ++ - "This might indicate memory response coming back out-of-order.") + assert(fullABuf.io.deq.bits.tag.set === fullBBuf.io.deq.bits.tag.set, + "A and B operands are pointing to different sets. " ++ + "This might indicate memory response coming back out-of-order.") } } assertAligned @@ -386,23 +421,24 @@ class TensorCoreDecoupled( // Dot-product unit // // 4x2 four-element DPUs summing up to 32 MACs in total + // val ncSubstep = tilingParams.nc / 2 + require(tilingParams.mc * ncSubstep == numLanes, + "substep tile size doesn't match writeback throughput") val dpus = Seq.fill(tilingParams.mc)(Seq.fill(ncSubstep)( Module(new TensorDotProductUnit(half = false)) )) - // operandA is 4x4 in K-major - val operandADimensional = - operandA.asUInt.asBools.grouped(wordSizeInBits).map(VecInit(_).asUInt).toSeq - .grouped(4/*k-dim*/).toSeq + + // reshape operands for easier routing to DPU + def reshapeByFourWords(x: UInt): Seq[Seq[UInt]] = { + x.asBools.grouped(wordSizeInBits).map(VecInit(_).asUInt).toSeq + .grouped(4/*k-dim*/).toSeq + } + val operandADimensional = reshapeByFourWords(operandA) require(operandADimensional.length == tilingParams.mc && operandADimensional(0).length == tilingParams.kc, "operand width doesn't agree with tiling parameter") - // select 2x4 subtile out of operandB that is 4x4 in K-major - val operandBDimensional = - operandB(substepCompute).asBools.grouped(wordSizeInBits).map(VecInit(_).asUInt).toSeq - .grouped(4/*k-dim*/).toSeq - require(tilingParams.mc * ncSubstep == numLanes, - "substep tile size doesn't match writeback throughput") + val operandBDimensional = reshapeByFourWords(operandB) require(operandBDimensional.length == ncSubstep && operandBDimensional(0).length == tilingParams.kc, "operand width doesn't agree with tiling parameter") @@ -444,12 +480,17 @@ class TensorCoreDecoupled( // ---------------- // These queues hold metadata needed for writeback in sync with the DPU. + class TensorComputeTag extends Bundle { + val set = UInt(setBits.W) + val step = UInt(stepBits.W) + val substep = UInt(1.W) + } + val queueDepth = 5 // needs to be at least the DPU latency - val tagQueue = Module(new Queue(chiselTypeOf(operandATag), queueDepth)) + val tagQueue = Module(new Queue(new TensorComputeTag, queueDepth)) tagQueue.io.enq.valid := dpuFire - // A and B should have the same tags - tagQueue.io.enq.bits := operandATag - // @cleanup: awkward + tagQueue.io.enq.bits.set := setCompute + tagQueue.io.enq.bits.step := stepCompute tagQueue.io.enq.bits.substep := substepCompute tagQueue.io.deq.ready := io.writeback.fire assert(tagQueue.io.enq.ready === true.B, @@ -490,7 +531,7 @@ class TensorCoreDecoupled( } } sequenceSetStep(setAccess, stepAccess, nextStepAccess) - sequenceSetStep(setExecute, stepExecute, nextStepExecute) + sequenceSetStep(setCompute, stepCompute, nextStepCompute) switch(state) { is(TensorState.idle) {