From 43e064fe82fa5727ac439b465002b2a09c6786ff Mon Sep 17 00:00:00 2001 From: Hansung Kim Date: Fri, 25 Oct 2024 15:22:52 -0700 Subject: [PATCH] tensor: Add access logic for C from regfile --- .../radiance/core/TensorCoreDecoupled.scala | 96 +++++++++++++++++-- 1 file changed, 89 insertions(+), 7 deletions(-) diff --git a/src/main/scala/radiance/core/TensorCoreDecoupled.scala b/src/main/scala/radiance/core/TensorCoreDecoupled.scala index c42dc29..535dbdd 100644 --- a/src/main/scala/radiance/core/TensorCoreDecoupled.scala +++ b/src/main/scala/radiance/core/TensorCoreDecoupled.scala @@ -51,8 +51,10 @@ class TensorCoreDecoupled( }) val respA = Flipped(Decoupled(new TensorMemResp(sourceWidth, dataWidth))) val respB = Flipped(Decoupled(new TensorMemResp(sourceWidth, dataWidth))) + val respC = Input(UInt(dataWidth.W)) val reqA = Decoupled(new TensorMemReq(sourceWidth)) val reqB = Decoupled(new TensorMemReq(sourceWidth)) + val reqC = Output(Valid(UInt(numFPRegBits.W))) }) dontTouch(io) @@ -131,9 +133,7 @@ class TensorCoreDecoupled( val stateA = RegInit(stateInit) val stateB = RegInit(stateInit) dontTouch(stateA) - dontTouch(stateA.index) dontTouch(stateB) - dontTouch(stateB.index) io.initiate.ready := (state === AccessorState.idle) when (io.initiate.fire) { @@ -262,6 +262,48 @@ class TensorCoreDecoupled( } } + // C access from regfile + // + + // since regfile is fixed-latency, respC valid should be determined at the + // request sending side. + val respCValid = RegInit(false.B) + + // regfile latency is 1 cycle; don't need a deep response queue + val respQueueCDepth = 1 + val respQueueC = Module(new Queue( + chiselTypeOf(io.respC), respQueueCDepth + )) + respQueueC.io.enq.valid := respCValid + respQueueC.io.enq.bits := io.respC + + // serialize every two C responses into one full 4x4 C tile + val fullC = Module(new FillBuffer( + chiselTypeOf(io.respC), 2/*substeps*/ + )) + fullC.io.enq.valid := respQueueC.io.deq.valid + fullC.io.enq.bits := respQueueC.io.deq.bits + respQueueC.io.deq.ready := fullC.io.enq.ready + + // make sure there's space at the response queue to be latched at the next + // cycle + val genReqC = (state === AccessorState.access) && respQueueC.io.enq.ready + // 1-cycle delay + respCValid := genReqC + + io.reqC.valid := genReqC + io.reqC.bits := 5.U // FIXME + + // set/index state of the C accumulator value that will be latched ath the + // next cycle. + val stateRegC = RegInit(stateInit) + when (genReqC) { + when (stateRegC.index === lastIndex.U) { + stateRegC.set := stateRegC.set + 1.U + } + stateRegC.index := stateRegC.index + 1.U + } + // =========================================================================== // Execute stage // =========================================================================== @@ -349,9 +391,31 @@ class TensorCoreDecoupled( fullB.io.deq.ready := fullBBuf.io.enq.ready fullBTag.io.deq.ready := fullBBuf.io.enq.ready + // fullC is instiated at the access stage + + val fullCTag = Module(new Queue( + new TensorMemTag, entries = 1, pipe = true + )) + fullCTag.io.enq.valid := respQueueB.valid + fullCTag.io.enq.bits := respQueueB.bits.tag + + val fullCBuf = Module(new Queue( + new Bundle { + val data = chiselTypeOf(fullC.io.deq.bits) + val tag = new TensorMemTag + }, entries = 1, pipe = true + )) + fullCBuf.io.enq.valid := fullC.io.deq.valid + fullCBuf.io.enq.bits.data := fullC.io.deq.bits + fullCBuf.io.enq.bits.tag := fullCTag.io.deq.bits + fullC.io.deq.ready := fullCBuf.io.enq.ready + fullCTag.io.deq.ready := fullCBuf.io.enq.ready + val dpuReady = Wire(Bool()) + val dpuFire = Wire(Bool()) val operandsValid = fullABuf.io.deq.valid && fullBBuf.io.deq.valid - val dpuFire = operandsValid && dpuReady + dpuFire := operandsValid && dpuReady + dontTouch(dpuFire) val setCompute = RegInit(0.U(setBits.W)) val stepCompute = RegInit(0.U(stepBits.W)) @@ -376,11 +440,14 @@ class TensorCoreDecoupled( } val operandA = selectOperandA(fullABuf.io.deq.bits.data) val operandATag = fullABuf.io.deq.bits.tag - // select the correct 2x4 tile from B operand buffer + // select the correct 2x4 tile from B/C operand buffer val operandB = fullBBuf.io.deq.bits.data(substepCompute) val operandBTag = fullBBuf.io.deq.bits.tag + val operandC = fullCBuf.io.deq.bits.data(substepCompute) + val operandCTag = fullCBuf.io.deq.bits.tag dontTouch(operandATag) dontTouch(operandBTag) + dontTouch(operandCTag) // Operand buffer logic // @@ -397,6 +464,10 @@ class TensorCoreDecoupled( ((stepCompute & shouldDequeueBMask) === shouldDequeueBMask) && (substepCompute === 1.U) fullBBuf.io.deq.ready := dpuFire && shouldDequeueB + + // C buf should be synced with B buf + fullCBuf.io.deq.ready := dpuFire && shouldDequeueB + dontTouch(respQueueA) dontTouch(respQueueB) dontTouch(shouldDequeueA) @@ -414,6 +485,10 @@ class TensorCoreDecoupled( operandATag.set === operandBTag.set, "A and B operands are pointing to different warps and sets. " ++ "This might indicate memory response coming back out-of-order.") + assert(operandATag.warp === operandCTag.warp && + operandATag.set === operandCTag.set, + "A and C operands are pointing to different warps and sets. " ++ + "This might indicate memory response coming back out-of-order.") assert(operandATag.set === setCompute, "Operand arrived from memory is pointing at a different set than the FSM.") } @@ -422,7 +497,7 @@ class TensorCoreDecoupled( // Dot-product unit // - // 4x2 four-element DPUs summing up to 32 MACs in total + // 4x2 four-element DPUs summing up to 32 FP32 MACs in total // val ncSubstep = tilingParams.nc / 2 require(tilingParams.mc * ncSubstep == numLanes, @@ -444,13 +519,18 @@ class TensorCoreDecoupled( require(operandBDimensional.length == ncSubstep && operandBDimensional(0).length == tilingParams.kc, "operand width doesn't agree with tiling parameter") + // note operand C is M-major + val operandCDimensional = reshapeByFourWords(operandC) + require(operandCDimensional.length == ncSubstep && + operandCDimensional(0).length == tilingParams.mc, + "operand width doesn't agree with tiling parameter") for (m <- 0 until tilingParams.mc) { for (n <- 0 until ncSubstep) { dpus(m)(n).io.in.valid := dpuFire dpus(m)(n).io.in.bits.a := operandADimensional(m) dpus(m)(n).io.in.bits.b := operandBDimensional(n) - dpus(m)(n).io.in.bits.c := 0.U // FIXME: bogus accum data + dpus(m)(n).io.in.bits.c := operandCDimensional(n)(m) // dpu ready couples with writeback backpressure dpus(m)(n).io.stall := !io.writeback.ready } @@ -631,8 +711,10 @@ class TensorCoreDecoupledTLImp(outer: TensorCoreDecoupledTL) tensor.io.respB.bits.source := tlOutB.d.bits.source tlOutB.d.ready := tensor.io.respB.ready + tensor.io.respC := 42.U // FIXME bogus + tensor.io.initiate.valid := io.start - tensor.io.initiate.bits.wid := 0.U // TODO + tensor.io.initiate.bits.wid := 0.U // FIXME bogus tensor.io.writeback.ready := true.B io.finished := tensor.io.writeback.valid && tensor.io.writeback.bits.last