tensor: Move C reg access to execute stage for higher util

This prevents coupling between C access in frontend & queue freeup in
backend.
This commit is contained in:
Hansung Kim
2024-10-29 17:51:32 -07:00
parent 37b8b6470b
commit 99da429cb1
2 changed files with 137 additions and 109 deletions

View File

@@ -131,15 +131,15 @@ class TensorCoreDecoupled(
val indexBits = log2Ceil(numIndices) val indexBits = log2Ceil(numIndices)
val lastIndex = (1 << indexBits) - 1 val lastIndex = (1 << indexBits) - 1
object AccessorState extends ChiselEnum { object State extends ChiselEnum {
val idle = Value(0.U) val idle = Value(0.U)
val access = Value(1.U) val run = Value(1.U)
// All set/step sequencing is complete and the tensor core is holding the // All set/step sequencing is complete and the tensor core is holding the
// result data until downstream writeback is ready. // result data until downstream writeback is ready.
// FIXME: is this necessary if writeback is decoupled with queues? // FIXME: is this necessary if writeback is decoupled with queues?
val finish = Value(2.U) val finish = Value(2.U)
} }
val state = RegInit(AccessorState.idle) val state = RegInit(State.idle)
val allReqsDone = WireInit(false.B) val allReqsDone = WireInit(false.B)
dontTouch(allReqsDone) dontTouch(allReqsDone)
@@ -159,7 +159,7 @@ class TensorCoreDecoupled(
dontTouch(stateA) dontTouch(stateA)
dontTouch(stateB) dontTouch(stateB)
io.initiate.ready := (state === AccessorState.idle) io.initiate.ready := (state === State.idle)
when (io.initiate.fire) { when (io.initiate.fire) {
warpAccess := io.initiate.bits.wid warpAccess := io.initiate.bits.wid
addrAAccess := io.initiate.bits.addressA addrAAccess := io.initiate.bits.addressA
@@ -170,19 +170,19 @@ class TensorCoreDecoupled(
} }
switch(state) { switch(state) {
is(AccessorState.idle) { is(State.idle) {
when(io.initiate.fire) { when(io.initiate.fire) {
state := AccessorState.access state := State.run
} }
} }
is(AccessorState.access) { is(State.run) {
when (allReqsDone) { when (allReqsDone) {
state := AccessorState.finish state := State.finish
} }
} }
is(AccessorState.finish) { is(State.finish) {
// FIXME: is finish state needed? // FIXME: is finish state needed?
state := AccessorState.idle state := State.idle
} }
} }
@@ -232,9 +232,8 @@ class TensorCoreDecoupled(
val doneReqA = RegInit(false.B) val doneReqA = RegInit(false.B)
val doneReqB = RegInit(false.B) val doneReqB = RegInit(false.B)
val doneReqC = RegInit(false.B) val genReqA = (state === State.run) && !doneReqA
val genReqA = (state === AccessorState.access) && !doneReqA val genReqB = (state === State.run) && !doneReqB
val genReqB = (state === AccessorState.access) && !doneReqB
// Request generation // Request generation
// //
@@ -274,105 +273,21 @@ class TensorCoreDecoupled(
} }
} }
// C access from regfile
//
// regfile is fast; don't need a deep response queue
val respQueueCDepth = 2
val respQueueC = Module(new Queue(
new Bundle {
val tag = new Bundle {
val warp = UInt(numWarpBits.W)
val set = UInt(setBits.W)
val step = UInt(stepBits.W)
}
val data = UInt(io.respC.widthOption.get.W)
},
respQueueCDepth
))
// because C reg IO arrives 1 cycle later, it will get latched onto the
// response queue 2 cycles later. Make sure there's at least two entries of
// space at the queue before requesting the reg value
// FIXME: doesn't feel very clean
require(respQueueC.entries >= 2)
val hasSpace =
respQueueC.io.count <= Mux(respQueueC.io.deq.fire,
(respQueueC.entries - 1).U,
(respQueueC.entries - 2).U)
val genReqC = (state === AccessorState.access) && hasSpace && !doneReqC
// set/step state of the C accumulator value that will be latched ath the
// next cycle.
val setAccessC = RegInit(0.U(setBits.W))
val stepAccessC = RegInit(0.U(stepBits.W))
val substepAccessC = RegInit(0.U(1.W))
val nextStepAccessC = genReqC && (substepAccessC === 1.U)
when (genReqC) {
substepAccessC := substepAccessC + 1.U
}
dontTouch(stepAccessC)
dontTouch(substepAccessC)
dontTouch(nextStepAccessC)
// give 1-cycle delay to sync valid/metadata with C regfile response
val respCValid = RegNext(genReqC)
val warpAccessCDelayed = RegNext(warpAccess)
val setAccessCDelayed = RegNext(setAccessC)
val stepAccessCDelayed = RegNext(stepAccessC)
// note rd is independent to sets
def rdGen(step: UInt, substep: UInt): UInt = {
// each step produces 4x4 output tile, written by 8 threads with 2 regs per
// thread
(step << 1/*2 substeps*/) + substep
}
io.reqC.valid := genReqC
io.reqC.bits := rdGen(stepAccessC, substepAccessC)
// queue the regfile response to buffers
// these strictly belong to the execute stage
respQueueC.io.enq.valid := respCValid
respQueueC.io.enq.bits.tag.warp := warpAccessCDelayed
respQueueC.io.enq.bits.tag.set := setAccessCDelayed
respQueueC.io.enq.bits.tag.step := stepAccessCDelayed
respQueueC.io.enq.bits.data := io.respC
// serialize every two C responses into one full 4x4 C tile
val fullC = Module(new FillBuffer(
chiselTypeOf(respQueueC.io.deq.bits.data), 2/*substeps*/
))
fullC.io.enq.valid := respQueueC.io.deq.valid
fullC.io.enq.bits := respQueueC.io.deq.bits.data
respQueueC.io.deq.ready := fullC.io.enq.ready
val fullCTag = Module(new Queue(
chiselTypeOf(respQueueC.io.deq.bits.tag),
entries = 1, pipe = true
))
fullCTag.io.enq.valid := respQueueC.io.deq.valid
fullCTag.io.enq.bits := respQueueC.io.deq.bits.tag
// finalize state when everything has been accessed // finalize state when everything has been accessed
val lastReqA = (stateA.set === lastSet.U) && (stateA.index === lastIndex.U) val lastReqA = (stateA.set === lastSet.U) && (stateA.index === lastIndex.U)
val lastReqB = (stateB.set === lastSet.U) && (stateB.index === lastIndex.U) val lastReqB = (stateB.set === lastSet.U) && (stateB.index === lastIndex.U)
val lastReqC = (setAccessC === lastSet.U) && (stepAccessC === lastStep.U)
when (lastReqA && io.reqA.fire) { doneReqA := true.B } when (lastReqA && io.reqA.fire) { doneReqA := true.B }
when (lastReqB && io.reqB.fire) { doneReqB := true.B } when (lastReqB && io.reqB.fire) { doneReqB := true.B }
when (lastReqC && nextStepAccessC) { doneReqC := true.B } when (state === State.finish) {
when (state === AccessorState.finish) {
doneReqA := false.B doneReqA := false.B
doneReqB := false.B doneReqB := false.B
doneReqC := false.B
stateA.set := 0.U stateA.set := 0.U
stateA.index := 0.U stateA.index := 0.U
stateB.set := 0.U stateB.set := 0.U
stateB.index := 0.U stateB.index := 0.U
setAccessC := 0.U
stepAccessC := 0.U
substepAccessC := 0.U
} }
allReqsDone := doneReqA && doneReqB && doneReqC allReqsDone := doneReqA && doneReqB
// =========================================================================== // ===========================================================================
// Execute stage // Execute stage
@@ -394,13 +309,6 @@ class TensorCoreDecoupled(
io.writeback.bits.data.widthOption.get, io.writeback.bits.data.widthOption.get,
"response data width does not match the writeback data width") "response data width does not match the writeback data width")
// FIXME: unnecessary
val substepDeqA = RegInit(0.U(1.W))
when (respQueueA.fire) {
substepDeqA := substepDeqA + 1.U
}
dontTouch(substepDeqA)
// Stage the operands in a pipeline so that we obtain the full 4x4 tiles // Stage the operands in a pipeline so that we obtain the full 4x4 tiles
// ready for compute. Also send the set/step tag along the pipe for // ready for compute. Also send the set/step tag along the pipe for
// alignment check. // alignment check.
@@ -461,7 +369,125 @@ class TensorCoreDecoupled(
fullB.io.deq.ready := fullBBuf.io.enq.ready fullB.io.deq.ready := fullBBuf.io.enq.ready
fullBTag.io.deq.ready := fullBBuf.io.enq.ready fullBTag.io.deq.ready := fullBBuf.io.enq.ready
// fullC/fullCTag is instiated at the access stage // C access from regfile
//
// C access is initiated from the backend because the regfile access latency
// is much lower than the SMEM latency. If C access is done in the frontend,
// the C response queue will fill up and block further run-ahead until the A
// and B smem reqs arrive and backend starts running, causing the frontend to
// stall waiting for remaining C accesses to finish.
// regfile is fast; don't need a deep response queue
val respQueueCDepth = 3
val respQueueC = Module(new Queue(
new Bundle {
val tag = new Bundle {
val warp = UInt(numWarpBits.W)
val set = UInt(setBits.W)
val step = UInt(stepBits.W)
}
val data = UInt(io.respC.widthOption.get.W)
},
respQueueCDepth
))
// access state of the C accumulator value
val stateAccessC = RegInit(State.idle)
// note this is different from warpAccess and belongs to execute
val warpAccessC = RegInit(0.U(numWarpBits.W))
val setAccessC = RegInit(0.U(setBits.W))
val stepAccessC = RegInit(0.U(stepBits.W))
val substepAccessC = RegInit(0.U(1.W))
val genReqC = WireInit(false.B)
val nextStepAccessC = genReqC && (substepAccessC === 1.U)
when (genReqC) {
substepAccessC := substepAccessC + 1.U
}
dontTouch(stepAccessC)
dontTouch(substepAccessC)
val doneReqC = RegInit(false.B)
val lastReqC = (setAccessC === lastSet.U) && (stepAccessC === lastStep.U)
when (lastReqC && nextStepAccessC) { doneReqC := true.B }
switch(stateAccessC) {
is(State.idle) {
//arrival of A and B response kicks off the execute backend
when (respQueueA.valid || respQueueB.valid) {
stateAccessC := State.run
when (respQueueA.valid) {
warpAccessC := respQueueA.bits.tag.warp
}.elsewhen (respQueueB.valid) {
warpAccessC := respQueueB.bits.tag.warp
}
}
}
is(State.run) {
when (doneReqC) {
stateAccessC := State.finish
}
}
is(State.finish) {
// FIXME: is finish state needed?
stateAccessC := State.idle
}
}
when (stateAccessC === State.finish) {
doneReqC := false.B
setAccessC := 0.U
stepAccessC := 0.U
substepAccessC := 0.U
}
// because C reg IO arrives 1 cycle later, it will get latched onto the
// response queue 2 cycles later. Make sure there's at least two entries of
// space at the queue before requesting the reg value
// FIXME: doesn't feel very clean
require(respQueueC.entries >= 2)
val hasSpace =
respQueueC.io.count <= Mux(respQueueC.io.deq.fire,
(respQueueC.entries - 1).U,
(respQueueC.entries - 2).U)
genReqC := (stateAccessC === State.run) && hasSpace && !doneReqC
// give 1-cycle delay to sync valid/metadata with C regfile response
val respCValid = RegNext(genReqC)
val warpAccessCDelayed = RegNext(warpAccessC)
val setAccessCDelayed = RegNext(setAccessC)
val stepAccessCDelayed = RegNext(stepAccessC)
// note rd is independent to sets
def rdGen(step: UInt, substep: UInt): UInt = {
// each step produces 4x4 output tile, written by 8 threads with 2 regs per
// thread
(step << 1/*2 substeps*/) + substep
}
io.reqC.valid := genReqC
io.reqC.bits := rdGen(stepAccessC, substepAccessC)
// queue the regfile response to buffers
// these strictly belong to the execute stage
respQueueC.io.enq.valid := respCValid
respQueueC.io.enq.bits.tag.warp := warpAccessCDelayed
respQueueC.io.enq.bits.tag.set := setAccessCDelayed
respQueueC.io.enq.bits.tag.step := stepAccessCDelayed
respQueueC.io.enq.bits.data := io.respC
// serialize every two C responses into one full 4x4 C tile
val fullC = Module(new FillBuffer(
chiselTypeOf(respQueueC.io.deq.bits.data), 2/*substeps*/
))
fullC.io.enq.valid := respQueueC.io.deq.valid
fullC.io.enq.bits := respQueueC.io.deq.bits.data
respQueueC.io.deq.ready := fullC.io.enq.ready
val fullCTag = Module(new Queue(
chiselTypeOf(respQueueC.io.deq.bits.tag),
entries = 1, pipe = true
))
fullCTag.io.enq.valid := respQueueC.io.deq.valid
fullCTag.io.enq.bits := respQueueC.io.deq.bits.tag
val fullCBuf = Module(new Queue( val fullCBuf = Module(new Queue(
new Bundle { new Bundle {
@@ -797,6 +823,8 @@ class TensorCoreDecoupledTLImp(outer: TensorCoreDecoupledTL)
tensor.io.initiate.valid := io.start tensor.io.initiate.valid := io.start
tensor.io.initiate.bits.wid := 3.U // bogus, static value tensor.io.initiate.bits.wid := 3.U // bogus, static value
tensor.io.initiate.bits.addressA := 0x0.U
tensor.io.initiate.bits.addressB := 0x800.U
tensor.io.writeback.ready := true.B tensor.io.writeback.ready := true.B
io.finished := tensor.io.writeback.valid && tensor.io.writeback.bits.last io.finished := tensor.io.writeback.valid && tensor.io.writeback.bits.last

View File

@@ -849,8 +849,8 @@ class RadianceTileModuleImp(outer: RadianceTile)
// Instantiate a fake tensor core module to force unique-ification of module // Instantiate a fake tensor core module to force unique-ification of module
// names in the Chisel-generated Verilog. These should be left out for // names in the Chisel-generated Verilog. These should be left out for
// synthesis runs, although these will likely be optimized-out if the inputs // synthesis runs, although it's likely they will be optimized-out with all
// are tied to low. // inputs tied to low.
if (outer.radianceParams.core.tensorCoreDecoupled) { if (outer.radianceParams.core.tensorCoreDecoupled) {
val tensorNumSourceIds = (1 << outer.tensorTagWidth) val tensorNumSourceIds = (1 << outer.tensorTagWidth)