tensor: Move C reg access to execute stage for higher util
This prevents coupling between C access in frontend & queue freeup in backend.
This commit is contained in:
@@ -131,15 +131,15 @@ class TensorCoreDecoupled(
|
||||
val indexBits = log2Ceil(numIndices)
|
||||
val lastIndex = (1 << indexBits) - 1
|
||||
|
||||
object AccessorState extends ChiselEnum {
|
||||
object State extends ChiselEnum {
|
||||
val idle = Value(0.U)
|
||||
val access = Value(1.U)
|
||||
val run = Value(1.U)
|
||||
// All set/step sequencing is complete and the tensor core is holding the
|
||||
// result data until downstream writeback is ready.
|
||||
// FIXME: is this necessary if writeback is decoupled with queues?
|
||||
val finish = Value(2.U)
|
||||
}
|
||||
val state = RegInit(AccessorState.idle)
|
||||
val state = RegInit(State.idle)
|
||||
val allReqsDone = WireInit(false.B)
|
||||
dontTouch(allReqsDone)
|
||||
|
||||
@@ -159,7 +159,7 @@ class TensorCoreDecoupled(
|
||||
dontTouch(stateA)
|
||||
dontTouch(stateB)
|
||||
|
||||
io.initiate.ready := (state === AccessorState.idle)
|
||||
io.initiate.ready := (state === State.idle)
|
||||
when (io.initiate.fire) {
|
||||
warpAccess := io.initiate.bits.wid
|
||||
addrAAccess := io.initiate.bits.addressA
|
||||
@@ -170,19 +170,19 @@ class TensorCoreDecoupled(
|
||||
}
|
||||
|
||||
switch(state) {
|
||||
is(AccessorState.idle) {
|
||||
is(State.idle) {
|
||||
when(io.initiate.fire) {
|
||||
state := AccessorState.access
|
||||
state := State.run
|
||||
}
|
||||
}
|
||||
is(AccessorState.access) {
|
||||
is(State.run) {
|
||||
when (allReqsDone) {
|
||||
state := AccessorState.finish
|
||||
state := State.finish
|
||||
}
|
||||
}
|
||||
is(AccessorState.finish) {
|
||||
is(State.finish) {
|
||||
// FIXME: is finish state needed?
|
||||
state := AccessorState.idle
|
||||
state := State.idle
|
||||
}
|
||||
}
|
||||
|
||||
@@ -232,9 +232,8 @@ class TensorCoreDecoupled(
|
||||
|
||||
val doneReqA = RegInit(false.B)
|
||||
val doneReqB = RegInit(false.B)
|
||||
val doneReqC = RegInit(false.B)
|
||||
val genReqA = (state === AccessorState.access) && !doneReqA
|
||||
val genReqB = (state === AccessorState.access) && !doneReqB
|
||||
val genReqA = (state === State.run) && !doneReqA
|
||||
val genReqB = (state === State.run) && !doneReqB
|
||||
|
||||
// Request generation
|
||||
//
|
||||
@@ -274,105 +273,21 @@ class TensorCoreDecoupled(
|
||||
}
|
||||
}
|
||||
|
||||
// C access from regfile
|
||||
//
|
||||
|
||||
// regfile is fast; don't need a deep response queue
|
||||
val respQueueCDepth = 2
|
||||
val respQueueC = Module(new Queue(
|
||||
new Bundle {
|
||||
val tag = new Bundle {
|
||||
val warp = UInt(numWarpBits.W)
|
||||
val set = UInt(setBits.W)
|
||||
val step = UInt(stepBits.W)
|
||||
}
|
||||
val data = UInt(io.respC.widthOption.get.W)
|
||||
},
|
||||
respQueueCDepth
|
||||
))
|
||||
|
||||
// because C reg IO arrives 1 cycle later, it will get latched onto the
|
||||
// response queue 2 cycles later. Make sure there's at least two entries of
|
||||
// space at the queue before requesting the reg value
|
||||
// FIXME: doesn't feel very clean
|
||||
require(respQueueC.entries >= 2)
|
||||
val hasSpace =
|
||||
respQueueC.io.count <= Mux(respQueueC.io.deq.fire,
|
||||
(respQueueC.entries - 1).U,
|
||||
(respQueueC.entries - 2).U)
|
||||
val genReqC = (state === AccessorState.access) && hasSpace && !doneReqC
|
||||
|
||||
// set/step state of the C accumulator value that will be latched ath the
|
||||
// next cycle.
|
||||
val setAccessC = RegInit(0.U(setBits.W))
|
||||
val stepAccessC = RegInit(0.U(stepBits.W))
|
||||
val substepAccessC = RegInit(0.U(1.W))
|
||||
val nextStepAccessC = genReqC && (substepAccessC === 1.U)
|
||||
when (genReqC) {
|
||||
substepAccessC := substepAccessC + 1.U
|
||||
}
|
||||
dontTouch(stepAccessC)
|
||||
dontTouch(substepAccessC)
|
||||
dontTouch(nextStepAccessC)
|
||||
|
||||
// give 1-cycle delay to sync valid/metadata with C regfile response
|
||||
val respCValid = RegNext(genReqC)
|
||||
val warpAccessCDelayed = RegNext(warpAccess)
|
||||
val setAccessCDelayed = RegNext(setAccessC)
|
||||
val stepAccessCDelayed = RegNext(stepAccessC)
|
||||
|
||||
// note rd is independent to sets
|
||||
def rdGen(step: UInt, substep: UInt): UInt = {
|
||||
// each step produces 4x4 output tile, written by 8 threads with 2 regs per
|
||||
// thread
|
||||
(step << 1/*2 substeps*/) + substep
|
||||
}
|
||||
io.reqC.valid := genReqC
|
||||
io.reqC.bits := rdGen(stepAccessC, substepAccessC)
|
||||
|
||||
// queue the regfile response to buffers
|
||||
// these strictly belong to the execute stage
|
||||
respQueueC.io.enq.valid := respCValid
|
||||
respQueueC.io.enq.bits.tag.warp := warpAccessCDelayed
|
||||
respQueueC.io.enq.bits.tag.set := setAccessCDelayed
|
||||
respQueueC.io.enq.bits.tag.step := stepAccessCDelayed
|
||||
respQueueC.io.enq.bits.data := io.respC
|
||||
|
||||
// serialize every two C responses into one full 4x4 C tile
|
||||
val fullC = Module(new FillBuffer(
|
||||
chiselTypeOf(respQueueC.io.deq.bits.data), 2/*substeps*/
|
||||
))
|
||||
fullC.io.enq.valid := respQueueC.io.deq.valid
|
||||
fullC.io.enq.bits := respQueueC.io.deq.bits.data
|
||||
respQueueC.io.deq.ready := fullC.io.enq.ready
|
||||
val fullCTag = Module(new Queue(
|
||||
chiselTypeOf(respQueueC.io.deq.bits.tag),
|
||||
entries = 1, pipe = true
|
||||
))
|
||||
fullCTag.io.enq.valid := respQueueC.io.deq.valid
|
||||
fullCTag.io.enq.bits := respQueueC.io.deq.bits.tag
|
||||
|
||||
// finalize state when everything has been accessed
|
||||
val lastReqA = (stateA.set === lastSet.U) && (stateA.index === lastIndex.U)
|
||||
val lastReqB = (stateB.set === lastSet.U) && (stateB.index === lastIndex.U)
|
||||
val lastReqC = (setAccessC === lastSet.U) && (stepAccessC === lastStep.U)
|
||||
when (lastReqA && io.reqA.fire) { doneReqA := true.B }
|
||||
when (lastReqB && io.reqB.fire) { doneReqB := true.B }
|
||||
when (lastReqC && nextStepAccessC) { doneReqC := true.B }
|
||||
when (state === AccessorState.finish) {
|
||||
when (state === State.finish) {
|
||||
doneReqA := false.B
|
||||
doneReqB := false.B
|
||||
doneReqC := false.B
|
||||
stateA.set := 0.U
|
||||
stateA.index := 0.U
|
||||
stateB.set := 0.U
|
||||
stateB.index := 0.U
|
||||
setAccessC := 0.U
|
||||
stepAccessC := 0.U
|
||||
substepAccessC := 0.U
|
||||
}
|
||||
|
||||
allReqsDone := doneReqA && doneReqB && doneReqC
|
||||
allReqsDone := doneReqA && doneReqB
|
||||
|
||||
// ===========================================================================
|
||||
// Execute stage
|
||||
@@ -394,13 +309,6 @@ class TensorCoreDecoupled(
|
||||
io.writeback.bits.data.widthOption.get,
|
||||
"response data width does not match the writeback data width")
|
||||
|
||||
// FIXME: unnecessary
|
||||
val substepDeqA = RegInit(0.U(1.W))
|
||||
when (respQueueA.fire) {
|
||||
substepDeqA := substepDeqA + 1.U
|
||||
}
|
||||
dontTouch(substepDeqA)
|
||||
|
||||
// Stage the operands in a pipeline so that we obtain the full 4x4 tiles
|
||||
// ready for compute. Also send the set/step tag along the pipe for
|
||||
// alignment check.
|
||||
@@ -461,7 +369,125 @@ class TensorCoreDecoupled(
|
||||
fullB.io.deq.ready := fullBBuf.io.enq.ready
|
||||
fullBTag.io.deq.ready := fullBBuf.io.enq.ready
|
||||
|
||||
// fullC/fullCTag is instiated at the access stage
|
||||
// C access from regfile
|
||||
//
|
||||
// C access is initiated from the backend because the regfile access latency
|
||||
// is much lower than the SMEM latency. If C access is done in the frontend,
|
||||
// the C response queue will fill up and block further run-ahead until the A
|
||||
// and B smem reqs arrive and backend starts running, causing the frontend to
|
||||
// stall waiting for remaining C accesses to finish.
|
||||
|
||||
// regfile is fast; don't need a deep response queue
|
||||
val respQueueCDepth = 3
|
||||
val respQueueC = Module(new Queue(
|
||||
new Bundle {
|
||||
val tag = new Bundle {
|
||||
val warp = UInt(numWarpBits.W)
|
||||
val set = UInt(setBits.W)
|
||||
val step = UInt(stepBits.W)
|
||||
}
|
||||
val data = UInt(io.respC.widthOption.get.W)
|
||||
},
|
||||
respQueueCDepth
|
||||
))
|
||||
|
||||
// access state of the C accumulator value
|
||||
val stateAccessC = RegInit(State.idle)
|
||||
// note this is different from warpAccess and belongs to execute
|
||||
val warpAccessC = RegInit(0.U(numWarpBits.W))
|
||||
val setAccessC = RegInit(0.U(setBits.W))
|
||||
val stepAccessC = RegInit(0.U(stepBits.W))
|
||||
val substepAccessC = RegInit(0.U(1.W))
|
||||
val genReqC = WireInit(false.B)
|
||||
val nextStepAccessC = genReqC && (substepAccessC === 1.U)
|
||||
when (genReqC) {
|
||||
substepAccessC := substepAccessC + 1.U
|
||||
}
|
||||
dontTouch(stepAccessC)
|
||||
dontTouch(substepAccessC)
|
||||
|
||||
val doneReqC = RegInit(false.B)
|
||||
val lastReqC = (setAccessC === lastSet.U) && (stepAccessC === lastStep.U)
|
||||
when (lastReqC && nextStepAccessC) { doneReqC := true.B }
|
||||
|
||||
switch(stateAccessC) {
|
||||
is(State.idle) {
|
||||
//arrival of A and B response kicks off the execute backend
|
||||
when (respQueueA.valid || respQueueB.valid) {
|
||||
stateAccessC := State.run
|
||||
when (respQueueA.valid) {
|
||||
warpAccessC := respQueueA.bits.tag.warp
|
||||
}.elsewhen (respQueueB.valid) {
|
||||
warpAccessC := respQueueB.bits.tag.warp
|
||||
}
|
||||
}
|
||||
}
|
||||
is(State.run) {
|
||||
when (doneReqC) {
|
||||
stateAccessC := State.finish
|
||||
}
|
||||
}
|
||||
is(State.finish) {
|
||||
// FIXME: is finish state needed?
|
||||
stateAccessC := State.idle
|
||||
}
|
||||
}
|
||||
|
||||
when (stateAccessC === State.finish) {
|
||||
doneReqC := false.B
|
||||
setAccessC := 0.U
|
||||
stepAccessC := 0.U
|
||||
substepAccessC := 0.U
|
||||
}
|
||||
|
||||
// because C reg IO arrives 1 cycle later, it will get latched onto the
|
||||
// response queue 2 cycles later. Make sure there's at least two entries of
|
||||
// space at the queue before requesting the reg value
|
||||
// FIXME: doesn't feel very clean
|
||||
require(respQueueC.entries >= 2)
|
||||
val hasSpace =
|
||||
respQueueC.io.count <= Mux(respQueueC.io.deq.fire,
|
||||
(respQueueC.entries - 1).U,
|
||||
(respQueueC.entries - 2).U)
|
||||
|
||||
genReqC := (stateAccessC === State.run) && hasSpace && !doneReqC
|
||||
|
||||
// give 1-cycle delay to sync valid/metadata with C regfile response
|
||||
val respCValid = RegNext(genReqC)
|
||||
val warpAccessCDelayed = RegNext(warpAccessC)
|
||||
val setAccessCDelayed = RegNext(setAccessC)
|
||||
val stepAccessCDelayed = RegNext(stepAccessC)
|
||||
|
||||
// note rd is independent to sets
|
||||
def rdGen(step: UInt, substep: UInt): UInt = {
|
||||
// each step produces 4x4 output tile, written by 8 threads with 2 regs per
|
||||
// thread
|
||||
(step << 1/*2 substeps*/) + substep
|
||||
}
|
||||
io.reqC.valid := genReqC
|
||||
io.reqC.bits := rdGen(stepAccessC, substepAccessC)
|
||||
|
||||
// queue the regfile response to buffers
|
||||
// these strictly belong to the execute stage
|
||||
respQueueC.io.enq.valid := respCValid
|
||||
respQueueC.io.enq.bits.tag.warp := warpAccessCDelayed
|
||||
respQueueC.io.enq.bits.tag.set := setAccessCDelayed
|
||||
respQueueC.io.enq.bits.tag.step := stepAccessCDelayed
|
||||
respQueueC.io.enq.bits.data := io.respC
|
||||
|
||||
// serialize every two C responses into one full 4x4 C tile
|
||||
val fullC = Module(new FillBuffer(
|
||||
chiselTypeOf(respQueueC.io.deq.bits.data), 2/*substeps*/
|
||||
))
|
||||
fullC.io.enq.valid := respQueueC.io.deq.valid
|
||||
fullC.io.enq.bits := respQueueC.io.deq.bits.data
|
||||
respQueueC.io.deq.ready := fullC.io.enq.ready
|
||||
val fullCTag = Module(new Queue(
|
||||
chiselTypeOf(respQueueC.io.deq.bits.tag),
|
||||
entries = 1, pipe = true
|
||||
))
|
||||
fullCTag.io.enq.valid := respQueueC.io.deq.valid
|
||||
fullCTag.io.enq.bits := respQueueC.io.deq.bits.tag
|
||||
|
||||
val fullCBuf = Module(new Queue(
|
||||
new Bundle {
|
||||
@@ -797,6 +823,8 @@ class TensorCoreDecoupledTLImp(outer: TensorCoreDecoupledTL)
|
||||
|
||||
tensor.io.initiate.valid := io.start
|
||||
tensor.io.initiate.bits.wid := 3.U // bogus, static value
|
||||
tensor.io.initiate.bits.addressA := 0x0.U
|
||||
tensor.io.initiate.bits.addressB := 0x800.U
|
||||
tensor.io.writeback.ready := true.B
|
||||
|
||||
io.finished := tensor.io.writeback.valid && tensor.io.writeback.bits.last
|
||||
|
||||
@@ -849,8 +849,8 @@ class RadianceTileModuleImp(outer: RadianceTile)
|
||||
|
||||
// Instantiate a fake tensor core module to force unique-ification of module
|
||||
// names in the Chisel-generated Verilog. These should be left out for
|
||||
// synthesis runs, although these will likely be optimized-out if the inputs
|
||||
// are tied to low.
|
||||
// synthesis runs, although it's likely they will be optimized-out with all
|
||||
// inputs tied to low.
|
||||
|
||||
if (outer.radianceParams.core.tensorCoreDecoupled) {
|
||||
val tensorNumSourceIds = (1 << outer.tensorTagWidth)
|
||||
|
||||
Reference in New Issue
Block a user