diff --git a/src/main/scala/radiance/core/TensorCoreDecoupled.scala b/src/main/scala/radiance/core/TensorCoreDecoupled.scala index fa3f6e9..ed241b5 100644 --- a/src/main/scala/radiance/core/TensorCoreDecoupled.scala +++ b/src/main/scala/radiance/core/TensorCoreDecoupled.scala @@ -82,15 +82,6 @@ class TensorCoreDecoupled( // This drives the overall pipeline of memory requests, dot-product unit // operations and regfile writeback. - object TensorState extends ChiselEnum { - val idle = Value(0.U) - val run = Value(1.U) - // All set/step sequencing is complete and the tensor core is holding the - // result data until downstream writeback is ready. - // FIXME: is this necessary if writeback is decoupled with queues? - val finish = Value(2.U) - } - val state = RegInit(TensorState.idle) val busy = RegInit(false.B) // Holds the warp id the core is currently working on. Note that we only // support one outstanding warp request @@ -107,22 +98,10 @@ class TensorCoreDecoupled( def setDone(set: UInt) = (set === lastSet.U) def stepDone(step: UInt) = (step === lastStep.U) - // set and step being currently accessed in the acc/ex frontend - val setAccess = RegInit(0.U(setBits.W)) - val stepAccess = RegInit(0.U(stepBits.W)) - // we need full 4x4 A tile to fire DPU, but since the memory width is 8 - // words, we need 2 cycles to read A. `substep` tells which cycle we're at. 
- val substepAccess = RegInit(0.U(1.W)) - dontTouch(setAccess) - dontTouch(stepAccess) - dontTouch(substepAccess) - - when(io.initiate.fire) { + when (io.initiate.fire) { val wid = io.initiate.bits.wid busy := true.B warpReg := wid - setAccess := 0.U - stepAccess := 0.U when(io.writeback.fire) { assert( io.writeback.bits.wid =/= wid, @@ -143,55 +122,51 @@ class TensorCoreDecoupled( // serialize every HGMMA request io.initiate.ready := !busy - // Memory traffic generation - // ------------------------- + // =========================================================================== + // Access stage + // =========================================================================== // - val numTilesM = tilingParams.m / tilingParams.mc - val numTilesN = tilingParams.n / tilingParams.nc - // @cleanup: generalize in terms of M/N/K-majorness? - def addressGen(baseA: UInt, baseB: UInt, set: UInt, step: UInt, substep: UInt) - : (UInt/*A*/, UInt/*B*/) = { - // note that step iterates along N first, then M - val tileM = step % numTilesM.U - val tileN = step / numTilesM.U + // Frontend of the decoupled access/execute pipeline. 
- // note that both A and B are K-major to facilitate bank conflict-free SMEM - // accesses - // - // (row,col) coordinate of the compute tile - val tileRowA = tileM // M - val tileColA = set // K - val tileRowB = tileN // N - val tileColB = set // K - // (row,col) coordinate of the starting element of the compute tile - val elemRowA = (tileRowA << log2Ceil(tilingParams.mc)) + - (substep << log2Ceil(tilingParams.mc / 2)) - val elemColA = tileColA << log2Ceil(tilingParams.kc) - val elemRowB = (tileRowB << log2Ceil(tilingParams.nc)) + - (substep << log2Ceil(tilingParams.nc / 2)) - val elemColB = tileColB << log2Ceil(tilingParams.kc) - val rowStrideA = wordSize * tilingParams.k - val rowStrideABits = log2Ceil(rowStrideA) - val rowStrideB = wordSize * tilingParams.k - val rowStrideBBits = log2Ceil(rowStrideB) - val wordStrideBits = log2Ceil(wordSize) - - val tileOffsetA = (elemRowA << rowStrideABits) + (elemColA << wordStrideBits) - val tileOffsetB = (elemRowB << rowStrideBBits) + (elemColB << wordStrideBits) - - (baseA + tileOffsetA, baseB + tileOffsetB) + // States + // + object AccessorState extends ChiselEnum { + val idle = Value(0.U) + val access = Value(1.U) + // All set/step sequencing is complete and the tensor core is holding the + // result data until downstream writeback is ready. + // FIXME: is this necessary if writeback is decoupled with queues? 
+ val finish = Value(2.U) } + val state = RegInit(AccessorState.idle) + val allReqsDone = WireInit(false.B) + dontTouch(allReqsDone) - // FIXME: bogus base address - val (addressA, addressB) = - addressGen(0.U, 0.U, setAccess, stepAccess, substepAccess) + switch(state) { + is(AccessorState.idle) { + when(io.initiate.fire) { + state := AccessorState.access + } + } + is(AccessorState.access) { + when (allReqsDone) { + state := AccessorState.finish + } + } + is(AccessorState.finish) { + // FIXME: decouple writeback + when(io.writeback.fire) { + state := AccessorState.idle + } + } + } // 'index' is the index of a memory request among the sequence of requests // needed to read a full M-column of A or N-row of B. Its range is [0,m/2) // or [0,n/2), where 2 is the stride can be read in a single request size. require(tilingParams.m == tilingParams.n, "currently only supports square SMEM tile") - val numIndices = tilingParams.m / 2 + val numIndices = tilingParams.m / 2/*FIXME:hardcoded?*/ val indexBits = log2Ceil(numIndices) val lastIndex = (1 << indexBits) - 1 @@ -219,9 +194,51 @@ class TensorCoreDecoupled( tagB.index := tagB.index + 1.U } - val genReqA = (state === TensorState.run) - val genReqB = (state === TensorState.run) + // Address generation + // + def addressGen(base: UInt, set: UInt, index: UInt): UInt = { + // note that both A and B are K-major to facilitate bank conflict-free SMEM + // accesses, so that below code applies to both. 
+ // + // (row,col) coordinate of the compute tile + val tileRow = index + val tileCol = set + // (row,col) coordinate of the starting element of the compute tile + val elemRow = index << 1 + val elemCol = tileCol << log2Ceil(tilingParams.kc) + val rowStride = tilingParams.k * wordSize + val rowStrideBits = log2Ceil(rowStride) + val wordStrideBits = log2Ceil(wordSize) + val tileOffset = (elemRow << rowStrideBits) + (elemCol << wordStrideBits) + base + tileOffset + } + + // FIXME: bogus base address + val addressA = addressGen(0.U, tagA.set, tagA.index) + val addressB = addressGen(0.U, tagB.set, tagB.index) + // A and B may fire their last request on different cycles, so each stream is gated by its own done flag. + val lastReqA = (tagA.set === lastSet.U) && (tagA.index === lastIndex.U) + val lastReqB = (tagB.set === lastSet.U) && (tagB.index === lastIndex.U) + val doneReqA = RegInit(false.B) + val doneReqB = RegInit(false.B) + when (lastReqA && io.reqA.fire) { doneReqA := true.B } + when (lastReqB && io.reqB.fire) { doneReqB := true.B } + val genReqA = (state === AccessorState.access) && !doneReqA + val genReqB = (state === AccessorState.access) && !doneReqB + when (state === AccessorState.finish) { + doneReqA := false.B + doneReqB := false.B + tagA.set := 0.U + tagA.index := 0.U + tagB.set := 0.U + tagB.index := 0.U + } + + allReqsDone := doneReqA && doneReqB + + // Request generation + // val respATagged = Wire(Decoupled(new TensorMemRespWithTag(dataWidth))) val respBTagged = Wire(Decoupled(new TensorMemRespWithTag(dataWidth))) Seq((io.reqA, (io.respA, respATagged)), @@ -249,34 +266,13 @@ class TensorCoreDecoupled( } } - // only advance to the next step if we fired mem requests for both A and B. - // also consider that B doesn't have to be fired every time due to reuse. - // @perf: too strict? 
should be able to have A and B progress separately - val firedAReg = RegInit(false.B) - val firedBReg = RegInit(false.B) - when (io.reqA.fire) { firedAReg := true.B } - when (io.reqB.fire) { firedBReg := true.B } - val firedANow = io.reqA.fire - val firedBNow = io.reqB.fire - val firedA = firedAReg || firedANow - val firedB = firedBReg || firedBNow - val nextSubstepAccess = firedA && firedB - val nextStepAccess = nextSubstepAccess && (substepAccess === 1.U) - // clear out firedABReg every substep - when (nextSubstepAccess) { - firedAReg := false.B - firedBReg := false.B - substepAccess := substepAccess + 1.U - } - require(substepAccess.widthOption.get == 1, "there should be only two substeps") - + // =========================================================================== // Execute stage - // ------------- + // =========================================================================== + // // Backend of the decoupled access/execute pipeline. // - // set and step being currently executed in the acc/ex backend - - val respQueueDepth = 4 // FIXME: parameterize + val respQueueDepth = 8 // FIXME: parameterize val respQueueA = Queue(respATagged, respQueueDepth) val respQueueB = Queue(respBTagged, respQueueDepth) @@ -369,6 +365,7 @@ class TensorCoreDecoupled( // Operand selection // // select the correct 4x4 tile from A operand buffer + val numTilesM = tilingParams.m / tilingParams.mc val numTilesMBits = log2Ceil(numTilesM) def selectOperandA(buf: Vec[UInt]): UInt = { require(buf.length == numIndices) @@ -383,7 +380,7 @@ class TensorCoreDecoupled( dontTouch(operandATag) dontTouch(operandBTag) - // Operand buffer dequeue logic + // Operand buffer logic // // hold A data until the entire set is done val shouldDequeueAMask = ((1 << stepBits) - 1).U @@ -476,8 +473,8 @@ class TensorCoreDecoupled( } io.writeback.bits.data := flattenedDPUOut - // Writeback queues - // ---------------- + // Writeback logic + // // These queues hold metadata needed for writeback in sync with 
the DPU. class TensorComputeTag extends Bundle { @@ -530,28 +527,7 @@ class TensorCoreDecoupled( } } } - sequenceSetStep(setAccess, stepAccess, nextStepAccess) sequenceSetStep(setCompute, stepCompute, nextStepCompute) - - switch(state) { - is(TensorState.idle) { - when(io.initiate.fire) { - state := TensorState.run - } - } - is(TensorState.run) { - when (setDone(setAccess) && stepDone(stepAccess) && nextStepAccess) { - when (state === TensorState.run) { - state := TensorState.finish - } - } - } - is(TensorState.finish) { - when(io.writeback.fire) { - state := TensorState.idle - } - } - } } // A buffer that collects multiple entries of input data and exposes the