From 99da429cb11ae4ad952907d0f36589f067b2ca94 Mon Sep 17 00:00:00 2001 From: Hansung Kim Date: Tue, 29 Oct 2024 17:51:32 -0700 Subject: [PATCH] tensor: Move C reg access to execute stage for higher util This prevents coupling between C access in frontend & queue freeup in backend. --- .../radiance/core/TensorCoreDecoupled.scala | 242 ++++++++++-------- .../scala/radiance/tile/RadianceTile.scala | 4 +- 2 files changed, 137 insertions(+), 109 deletions(-) diff --git a/src/main/scala/radiance/core/TensorCoreDecoupled.scala b/src/main/scala/radiance/core/TensorCoreDecoupled.scala index 7f3c7da..dcaa9f9 100644 --- a/src/main/scala/radiance/core/TensorCoreDecoupled.scala +++ b/src/main/scala/radiance/core/TensorCoreDecoupled.scala @@ -131,15 +131,15 @@ class TensorCoreDecoupled( val indexBits = log2Ceil(numIndices) val lastIndex = (1 << indexBits) - 1 - object AccessorState extends ChiselEnum { + object State extends ChiselEnum { val idle = Value(0.U) - val access = Value(1.U) + val run = Value(1.U) // All set/step sequencing is complete and the tensor core is holding the // result data until downstream writeback is ready. // FIXME: is this necessary if writeback is decoupled with queues? val finish = Value(2.U) } - val state = RegInit(AccessorState.idle) + val state = RegInit(State.idle) val allReqsDone = WireInit(false.B) dontTouch(allReqsDone) @@ -159,7 +159,7 @@ class TensorCoreDecoupled( dontTouch(stateA) dontTouch(stateB) - io.initiate.ready := (state === AccessorState.idle) + io.initiate.ready := (state === State.idle) when (io.initiate.fire) { warpAccess := io.initiate.bits.wid addrAAccess := io.initiate.bits.addressA @@ -170,19 +170,19 @@ class TensorCoreDecoupled( } switch(state) { - is(AccessorState.idle) { + is(State.idle) { when(io.initiate.fire) { - state := AccessorState.access + state := State.run } } - is(AccessorState.access) { + is(State.run) { when (allReqsDone) { - state := AccessorState.finish + state := State.finish } } - is(AccessorState.finish) { + is(State.finish) { // FIXME: is finish state needed? - state := AccessorState.idle + state := State.idle } } @@ -232,9 +232,8 @@ class TensorCoreDecoupled( val doneReqA = RegInit(false.B) val doneReqB = RegInit(false.B) - val doneReqC = RegInit(false.B) - val genReqA = (state === AccessorState.access) && !doneReqA - val genReqB = (state === AccessorState.access) && !doneReqB + val genReqA = (state === State.run) && !doneReqA + val genReqB = (state === State.run) && !doneReqB // Request generation // @@ -274,105 +273,21 @@ class TensorCoreDecoupled( } } - // C access from regfile - // - - // regfile is fast; don't need a deep response queue - val respQueueCDepth = 2 - val respQueueC = Module(new Queue( - new Bundle { - val tag = new Bundle { - val warp = UInt(numWarpBits.W) - val set = UInt(setBits.W) - val step = UInt(stepBits.W) - } - val data = UInt(io.respC.widthOption.get.W) - }, - respQueueCDepth - )) - - // because C reg IO arrives 1 cycle later, it will get latched onto the - // response queue 2 cycles later. Make sure there's at least two entries of - // space at the queue before requesting the reg value - // FIXME: doesn't feel very clean - require(respQueueC.entries >= 2) - val hasSpace = - respQueueC.io.count <= Mux(respQueueC.io.deq.fire, - (respQueueC.entries - 1).U, - (respQueueC.entries - 2).U) - val genReqC = (state === AccessorState.access) && hasSpace && !doneReqC - - // set/step state of the C accumulator value that will be latched ath the - // next cycle. - val setAccessC = RegInit(0.U(setBits.W)) - val stepAccessC = RegInit(0.U(stepBits.W)) - val substepAccessC = RegInit(0.U(1.W)) - val nextStepAccessC = genReqC && (substepAccessC === 1.U) - when (genReqC) { - substepAccessC := substepAccessC + 1.U - } - dontTouch(stepAccessC) - dontTouch(substepAccessC) - dontTouch(nextStepAccessC) - - // give 1-cycle delay to sync valid/metadata with C regfile response - val respCValid = RegNext(genReqC) - val warpAccessCDelayed = RegNext(warpAccess) - val setAccessCDelayed = RegNext(setAccessC) - val stepAccessCDelayed = RegNext(stepAccessC) - - // note rd is independent to sets - def rdGen(step: UInt, substep: UInt): UInt = { - // each step produces 4x4 output tile, written by 8 threads with 2 regs per - // thread - (step << 1/*2 substeps*/) + substep - } - io.reqC.valid := genReqC - io.reqC.bits := rdGen(stepAccessC, substepAccessC) - - // queue the regfile response to buffers - // these strictly belong to the execute stage - respQueueC.io.enq.valid := respCValid - respQueueC.io.enq.bits.tag.warp := warpAccessCDelayed - respQueueC.io.enq.bits.tag.set := setAccessCDelayed - respQueueC.io.enq.bits.tag.step := stepAccessCDelayed - respQueueC.io.enq.bits.data := io.respC - - // serialize every two C responses into one full 4x4 C tile - val fullC = Module(new FillBuffer( - chiselTypeOf(respQueueC.io.deq.bits.data), 2/*substeps*/ - )) - fullC.io.enq.valid := respQueueC.io.deq.valid - fullC.io.enq.bits := respQueueC.io.deq.bits.data - respQueueC.io.deq.ready := fullC.io.enq.ready - val fullCTag = Module(new Queue( - chiselTypeOf(respQueueC.io.deq.bits.tag), - entries = 1, pipe = true - )) - fullCTag.io.enq.valid := respQueueC.io.deq.valid - fullCTag.io.enq.bits := respQueueC.io.deq.bits.tag - // finalize state when everything has been accessed val lastReqA = (stateA.set === lastSet.U) && (stateA.index === lastIndex.U) val lastReqB = (stateB.set === lastSet.U) && (stateB.index === lastIndex.U) - val lastReqC = (setAccessC === lastSet.U) && (stepAccessC === lastStep.U) when (lastReqA && io.reqA.fire) { doneReqA := true.B } when (lastReqB && io.reqB.fire) { doneReqB := true.B } - when (lastReqC && nextStepAccessC) { doneReqC := true.B } - when (state === AccessorState.finish) { + when (state === State.finish) { doneReqA := false.B doneReqB := false.B - doneReqC := false.B stateA.set := 0.U stateA.index := 0.U stateB.set := 0.U stateB.index := 0.U - setAccessC := 0.U - stepAccessC := 0.U - substepAccessC := 0.U } - allReqsDone := doneReqA && doneReqB && doneReqC + allReqsDone := doneReqA && doneReqB // =========================================================================== // Execute stage @@ -394,13 +309,6 @@ class TensorCoreDecoupled( io.writeback.bits.data.widthOption.get, "response data width does not match the writeback data width") - // FIXME: unnecessary - val substepDeqA = RegInit(0.U(1.W)) - when (respQueueA.fire) { - substepDeqA := substepDeqA + 1.U - } - dontTouch(substepDeqA) - // Stage the operands in a pipeline so that we obtain the full 4x4 tiles // ready for compute. Also send the set/step tag along the pipe for // alignment check. @@ -461,7 +369,125 @@ class TensorCoreDecoupled( fullB.io.deq.ready := fullBBuf.io.enq.ready fullBTag.io.deq.ready := fullBBuf.io.enq.ready - // fullC/fullCTag is instiated at the access stage + // C access from regfile + // + // C access is initiated from the backend because the regfile access latency + // is much lower than the SMEM latency. If C access is done in the frontend, + // the C response queue will fill up and block further run-ahead until the A + // and B smem reqs arrive and backend starts running, causing the frontend to + // stall waiting for remaining C accesses to finish. + + // regfile is fast; don't need a deep response queue + val respQueueCDepth = 3 + val respQueueC = Module(new Queue( + new Bundle { + val tag = new Bundle { + val warp = UInt(numWarpBits.W) + val set = UInt(setBits.W) + val step = UInt(stepBits.W) + } + val data = UInt(io.respC.widthOption.get.W) + }, + respQueueCDepth + )) + + // access state of the C accumulator value + val stateAccessC = RegInit(State.idle) + // note this is different from warpAccess and belongs to execute + val warpAccessC = RegInit(0.U(numWarpBits.W)) + val setAccessC = RegInit(0.U(setBits.W)) + val stepAccessC = RegInit(0.U(stepBits.W)) + val substepAccessC = RegInit(0.U(1.W)) + val genReqC = WireInit(false.B) + val nextStepAccessC = genReqC && (substepAccessC === 1.U) + when (genReqC) { + substepAccessC := substepAccessC + 1.U + } + dontTouch(stepAccessC) + dontTouch(substepAccessC) + + val doneReqC = RegInit(false.B) + val lastReqC = (setAccessC === lastSet.U) && (stepAccessC === lastStep.U) + when (lastReqC && nextStepAccessC) { doneReqC := true.B } + + switch(stateAccessC) { + is(State.idle) { + //arrival of A and B response kicks off the execute backend + when (respQueueA.valid || respQueueB.valid) { + stateAccessC := State.run + when (respQueueA.valid) { + warpAccessC := respQueueA.bits.tag.warp + }.elsewhen (respQueueB.valid) { + warpAccessC := respQueueB.bits.tag.warp + } + } + } + is(State.run) { + when (doneReqC) { + stateAccessC := State.finish + } + } + is(State.finish) { + // FIXME: is finish state needed? + stateAccessC := State.idle + } + } + + when (stateAccessC === State.finish) { + doneReqC := false.B + setAccessC := 0.U + stepAccessC := 0.U + substepAccessC := 0.U + } + + // because C reg IO arrives 1 cycle later, it will get latched onto the + // response queue 2 cycles later. Make sure there's at least two entries of + // space at the queue before requesting the reg value + // FIXME: doesn't feel very clean + require(respQueueC.entries >= 2) + val hasSpace = + respQueueC.io.count <= Mux(respQueueC.io.deq.fire, + (respQueueC.entries - 1).U, + (respQueueC.entries - 2).U) + + genReqC := (stateAccessC === State.run) && hasSpace && !doneReqC + + // give 1-cycle delay to sync valid/metadata with C regfile response + val respCValid = RegNext(genReqC) + val warpAccessCDelayed = RegNext(warpAccessC) + val setAccessCDelayed = RegNext(setAccessC) + val stepAccessCDelayed = RegNext(stepAccessC) + + // note rd is independent to sets + def rdGen(step: UInt, substep: UInt): UInt = { + // each step produces 4x4 output tile, written by 8 threads with 2 regs per + // thread + (step << 1/*2 substeps*/) + substep + } + io.reqC.valid := genReqC + io.reqC.bits := rdGen(stepAccessC, substepAccessC) + + // queue the regfile response to buffers + // these strictly belong to the execute stage + respQueueC.io.enq.valid := respCValid + respQueueC.io.enq.bits.tag.warp := warpAccessCDelayed + respQueueC.io.enq.bits.tag.set := setAccessCDelayed + respQueueC.io.enq.bits.tag.step := stepAccessCDelayed + respQueueC.io.enq.bits.data := io.respC + + // serialize every two C responses into one full 4x4 C tile + val fullC = Module(new FillBuffer( + chiselTypeOf(respQueueC.io.deq.bits.data), 2/*substeps*/ + )) + fullC.io.enq.valid := respQueueC.io.deq.valid + fullC.io.enq.bits := respQueueC.io.deq.bits.data + respQueueC.io.deq.ready := fullC.io.enq.ready + val fullCTag = Module(new Queue( + chiselTypeOf(respQueueC.io.deq.bits.tag), + entries = 1, pipe = true + )) + fullCTag.io.enq.valid := respQueueC.io.deq.valid + fullCTag.io.enq.bits := respQueueC.io.deq.bits.tag val fullCBuf = Module(new Queue( new Bundle { @@ -797,6 +823,8 @@ class TensorCoreDecoupledTLImp(outer: TensorCoreDecoupledTL) tensor.io.initiate.valid := io.start tensor.io.initiate.bits.wid := 3.U // bogus, static value + tensor.io.initiate.bits.addressA := 0x0.U + tensor.io.initiate.bits.addressB := 0x800.U tensor.io.writeback.ready := true.B io.finished := tensor.io.writeback.valid && tensor.io.writeback.bits.last diff --git a/src/main/scala/radiance/tile/RadianceTile.scala b/src/main/scala/radiance/tile/RadianceTile.scala index a23763c..f24d2a4 100644 --- a/src/main/scala/radiance/tile/RadianceTile.scala +++ b/src/main/scala/radiance/tile/RadianceTile.scala @@ -849,8 +849,8 @@ class RadianceTileModuleImp(outer: RadianceTile) // Instantiate a fake tensor core module to force unique-ification of module // names in the Chisel-generated Verilog. These should be left out for - // synthesis runs, although these will likely be optimized-out if the inputs - // are tied to low. + // synthesis runs, although it's likely they will be optimized-out with all + // inputs tied to low. if (outer.radianceParams.core.tensorCoreDecoupled) { val tensorNumSourceIds = (1 << outer.tensorTagWidth)