tensor: Move C reg access to execute stage for higher util
This prevents coupling between C access in frontend & queue freeup in backend.
This commit is contained in:
@@ -131,15 +131,15 @@ class TensorCoreDecoupled(
|
|||||||
val indexBits = log2Ceil(numIndices)
|
val indexBits = log2Ceil(numIndices)
|
||||||
val lastIndex = (1 << indexBits) - 1
|
val lastIndex = (1 << indexBits) - 1
|
||||||
|
|
||||||
object AccessorState extends ChiselEnum {
|
object State extends ChiselEnum {
|
||||||
val idle = Value(0.U)
|
val idle = Value(0.U)
|
||||||
val access = Value(1.U)
|
val run = Value(1.U)
|
||||||
// All set/step sequencing is complete and the tensor core is holding the
|
// All set/step sequencing is complete and the tensor core is holding the
|
||||||
// result data until downstream writeback is ready.
|
// result data until downstream writeback is ready.
|
||||||
// FIXME: is this necessary if writeback is decoupled with queues?
|
// FIXME: is this necessary if writeback is decoupled with queues?
|
||||||
val finish = Value(2.U)
|
val finish = Value(2.U)
|
||||||
}
|
}
|
||||||
val state = RegInit(AccessorState.idle)
|
val state = RegInit(State.idle)
|
||||||
val allReqsDone = WireInit(false.B)
|
val allReqsDone = WireInit(false.B)
|
||||||
dontTouch(allReqsDone)
|
dontTouch(allReqsDone)
|
||||||
|
|
||||||
@@ -159,7 +159,7 @@ class TensorCoreDecoupled(
|
|||||||
dontTouch(stateA)
|
dontTouch(stateA)
|
||||||
dontTouch(stateB)
|
dontTouch(stateB)
|
||||||
|
|
||||||
io.initiate.ready := (state === AccessorState.idle)
|
io.initiate.ready := (state === State.idle)
|
||||||
when (io.initiate.fire) {
|
when (io.initiate.fire) {
|
||||||
warpAccess := io.initiate.bits.wid
|
warpAccess := io.initiate.bits.wid
|
||||||
addrAAccess := io.initiate.bits.addressA
|
addrAAccess := io.initiate.bits.addressA
|
||||||
@@ -170,19 +170,19 @@ class TensorCoreDecoupled(
|
|||||||
}
|
}
|
||||||
|
|
||||||
switch(state) {
|
switch(state) {
|
||||||
is(AccessorState.idle) {
|
is(State.idle) {
|
||||||
when(io.initiate.fire) {
|
when(io.initiate.fire) {
|
||||||
state := AccessorState.access
|
state := State.run
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
is(AccessorState.access) {
|
is(State.run) {
|
||||||
when (allReqsDone) {
|
when (allReqsDone) {
|
||||||
state := AccessorState.finish
|
state := State.finish
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
is(AccessorState.finish) {
|
is(State.finish) {
|
||||||
// FIXME: is finish state needed?
|
// FIXME: is finish state needed?
|
||||||
state := AccessorState.idle
|
state := State.idle
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -232,9 +232,8 @@ class TensorCoreDecoupled(
|
|||||||
|
|
||||||
val doneReqA = RegInit(false.B)
|
val doneReqA = RegInit(false.B)
|
||||||
val doneReqB = RegInit(false.B)
|
val doneReqB = RegInit(false.B)
|
||||||
val doneReqC = RegInit(false.B)
|
val genReqA = (state === State.run) && !doneReqA
|
||||||
val genReqA = (state === AccessorState.access) && !doneReqA
|
val genReqB = (state === State.run) && !doneReqB
|
||||||
val genReqB = (state === AccessorState.access) && !doneReqB
|
|
||||||
|
|
||||||
// Request generation
|
// Request generation
|
||||||
//
|
//
|
||||||
@@ -274,105 +273,21 @@ class TensorCoreDecoupled(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// C access from regfile
|
|
||||||
//
|
|
||||||
|
|
||||||
// regfile is fast; don't need a deep response queue
|
|
||||||
val respQueueCDepth = 2
|
|
||||||
val respQueueC = Module(new Queue(
|
|
||||||
new Bundle {
|
|
||||||
val tag = new Bundle {
|
|
||||||
val warp = UInt(numWarpBits.W)
|
|
||||||
val set = UInt(setBits.W)
|
|
||||||
val step = UInt(stepBits.W)
|
|
||||||
}
|
|
||||||
val data = UInt(io.respC.widthOption.get.W)
|
|
||||||
},
|
|
||||||
respQueueCDepth
|
|
||||||
))
|
|
||||||
|
|
||||||
// because C reg IO arrives 1 cycle later, it will get latched onto the
|
|
||||||
// response queue 2 cycles later. Make sure there's at least two entries of
|
|
||||||
// space at the queue before requesting the reg value
|
|
||||||
// FIXME: doesn't feel very clean
|
|
||||||
require(respQueueC.entries >= 2)
|
|
||||||
val hasSpace =
|
|
||||||
respQueueC.io.count <= Mux(respQueueC.io.deq.fire,
|
|
||||||
(respQueueC.entries - 1).U,
|
|
||||||
(respQueueC.entries - 2).U)
|
|
||||||
val genReqC = (state === AccessorState.access) && hasSpace && !doneReqC
|
|
||||||
|
|
||||||
// set/step state of the C accumulator value that will be latched at the
|
|
||||||
// next cycle.
|
|
||||||
val setAccessC = RegInit(0.U(setBits.W))
|
|
||||||
val stepAccessC = RegInit(0.U(stepBits.W))
|
|
||||||
val substepAccessC = RegInit(0.U(1.W))
|
|
||||||
val nextStepAccessC = genReqC && (substepAccessC === 1.U)
|
|
||||||
when (genReqC) {
|
|
||||||
substepAccessC := substepAccessC + 1.U
|
|
||||||
}
|
|
||||||
dontTouch(stepAccessC)
|
|
||||||
dontTouch(substepAccessC)
|
|
||||||
dontTouch(nextStepAccessC)
|
|
||||||
|
|
||||||
// give 1-cycle delay to sync valid/metadata with C regfile response
|
|
||||||
val respCValid = RegNext(genReqC)
|
|
||||||
val warpAccessCDelayed = RegNext(warpAccess)
|
|
||||||
val setAccessCDelayed = RegNext(setAccessC)
|
|
||||||
val stepAccessCDelayed = RegNext(stepAccessC)
|
|
||||||
|
|
||||||
// note rd is independent of sets
|
|
||||||
def rdGen(step: UInt, substep: UInt): UInt = {
|
|
||||||
// each step produces 4x4 output tile, written by 8 threads with 2 regs per
|
|
||||||
// thread
|
|
||||||
(step << 1/*2 substeps*/) + substep
|
|
||||||
}
|
|
||||||
io.reqC.valid := genReqC
|
|
||||||
io.reqC.bits := rdGen(stepAccessC, substepAccessC)
|
|
||||||
|
|
||||||
// queue the regfile response to buffers
|
|
||||||
// these strictly belong to the execute stage
|
|
||||||
respQueueC.io.enq.valid := respCValid
|
|
||||||
respQueueC.io.enq.bits.tag.warp := warpAccessCDelayed
|
|
||||||
respQueueC.io.enq.bits.tag.set := setAccessCDelayed
|
|
||||||
respQueueC.io.enq.bits.tag.step := stepAccessCDelayed
|
|
||||||
respQueueC.io.enq.bits.data := io.respC
|
|
||||||
|
|
||||||
// serialize every two C responses into one full 4x4 C tile
|
|
||||||
val fullC = Module(new FillBuffer(
|
|
||||||
chiselTypeOf(respQueueC.io.deq.bits.data), 2/*substeps*/
|
|
||||||
))
|
|
||||||
fullC.io.enq.valid := respQueueC.io.deq.valid
|
|
||||||
fullC.io.enq.bits := respQueueC.io.deq.bits.data
|
|
||||||
respQueueC.io.deq.ready := fullC.io.enq.ready
|
|
||||||
val fullCTag = Module(new Queue(
|
|
||||||
chiselTypeOf(respQueueC.io.deq.bits.tag),
|
|
||||||
entries = 1, pipe = true
|
|
||||||
))
|
|
||||||
fullCTag.io.enq.valid := respQueueC.io.deq.valid
|
|
||||||
fullCTag.io.enq.bits := respQueueC.io.deq.bits.tag
|
|
||||||
|
|
||||||
// finalize state when everything has been accessed
|
// finalize state when everything has been accessed
|
||||||
val lastReqA = (stateA.set === lastSet.U) && (stateA.index === lastIndex.U)
|
val lastReqA = (stateA.set === lastSet.U) && (stateA.index === lastIndex.U)
|
||||||
val lastReqB = (stateB.set === lastSet.U) && (stateB.index === lastIndex.U)
|
val lastReqB = (stateB.set === lastSet.U) && (stateB.index === lastIndex.U)
|
||||||
val lastReqC = (setAccessC === lastSet.U) && (stepAccessC === lastStep.U)
|
|
||||||
when (lastReqA && io.reqA.fire) { doneReqA := true.B }
|
when (lastReqA && io.reqA.fire) { doneReqA := true.B }
|
||||||
when (lastReqB && io.reqB.fire) { doneReqB := true.B }
|
when (lastReqB && io.reqB.fire) { doneReqB := true.B }
|
||||||
when (lastReqC && nextStepAccessC) { doneReqC := true.B }
|
when (state === State.finish) {
|
||||||
when (state === AccessorState.finish) {
|
|
||||||
doneReqA := false.B
|
doneReqA := false.B
|
||||||
doneReqB := false.B
|
doneReqB := false.B
|
||||||
doneReqC := false.B
|
|
||||||
stateA.set := 0.U
|
stateA.set := 0.U
|
||||||
stateA.index := 0.U
|
stateA.index := 0.U
|
||||||
stateB.set := 0.U
|
stateB.set := 0.U
|
||||||
stateB.index := 0.U
|
stateB.index := 0.U
|
||||||
setAccessC := 0.U
|
|
||||||
stepAccessC := 0.U
|
|
||||||
substepAccessC := 0.U
|
|
||||||
}
|
}
|
||||||
|
|
||||||
allReqsDone := doneReqA && doneReqB && doneReqC
|
allReqsDone := doneReqA && doneReqB
|
||||||
|
|
||||||
// ===========================================================================
|
// ===========================================================================
|
||||||
// Execute stage
|
// Execute stage
|
||||||
@@ -394,13 +309,6 @@ class TensorCoreDecoupled(
|
|||||||
io.writeback.bits.data.widthOption.get,
|
io.writeback.bits.data.widthOption.get,
|
||||||
"response data width does not match the writeback data width")
|
"response data width does not match the writeback data width")
|
||||||
|
|
||||||
// FIXME: unnecessary
|
|
||||||
val substepDeqA = RegInit(0.U(1.W))
|
|
||||||
when (respQueueA.fire) {
|
|
||||||
substepDeqA := substepDeqA + 1.U
|
|
||||||
}
|
|
||||||
dontTouch(substepDeqA)
|
|
||||||
|
|
||||||
// Stage the operands in a pipeline so that we obtain the full 4x4 tiles
|
// Stage the operands in a pipeline so that we obtain the full 4x4 tiles
|
||||||
// ready for compute. Also send the set/step tag along the pipe for
|
// ready for compute. Also send the set/step tag along the pipe for
|
||||||
// alignment check.
|
// alignment check.
|
||||||
@@ -461,7 +369,125 @@ class TensorCoreDecoupled(
|
|||||||
fullB.io.deq.ready := fullBBuf.io.enq.ready
|
fullB.io.deq.ready := fullBBuf.io.enq.ready
|
||||||
fullBTag.io.deq.ready := fullBBuf.io.enq.ready
|
fullBTag.io.deq.ready := fullBBuf.io.enq.ready
|
||||||
|
|
||||||
// fullC/fullCTag is instantiated at the access stage
|
// C access from regfile
|
||||||
|
//
|
||||||
|
// C access is initiated from the backend because the regfile access latency
|
||||||
|
// is much lower than the SMEM latency. If C access is done in the frontend,
|
||||||
|
// the C response queue will fill up and block further run-ahead until the A
|
||||||
|
// and B smem reqs arrive and backend starts running, causing the frontend to
|
||||||
|
// stall waiting for remaining C accesses to finish.
|
||||||
|
|
||||||
|
// regfile is fast; don't need a deep response queue
|
||||||
|
val respQueueCDepth = 3
|
||||||
|
val respQueueC = Module(new Queue(
|
||||||
|
new Bundle {
|
||||||
|
val tag = new Bundle {
|
||||||
|
val warp = UInt(numWarpBits.W)
|
||||||
|
val set = UInt(setBits.W)
|
||||||
|
val step = UInt(stepBits.W)
|
||||||
|
}
|
||||||
|
val data = UInt(io.respC.widthOption.get.W)
|
||||||
|
},
|
||||||
|
respQueueCDepth
|
||||||
|
))
|
||||||
|
|
||||||
|
// access state of the C accumulator value
|
||||||
|
val stateAccessC = RegInit(State.idle)
|
||||||
|
// note this is different from warpAccess and belongs to execute
|
||||||
|
val warpAccessC = RegInit(0.U(numWarpBits.W))
|
||||||
|
val setAccessC = RegInit(0.U(setBits.W))
|
||||||
|
val stepAccessC = RegInit(0.U(stepBits.W))
|
||||||
|
val substepAccessC = RegInit(0.U(1.W))
|
||||||
|
val genReqC = WireInit(false.B)
|
||||||
|
val nextStepAccessC = genReqC && (substepAccessC === 1.U)
|
||||||
|
when (genReqC) {
|
||||||
|
substepAccessC := substepAccessC + 1.U
|
||||||
|
}
|
||||||
|
dontTouch(stepAccessC)
|
||||||
|
dontTouch(substepAccessC)
|
||||||
|
|
||||||
|
val doneReqC = RegInit(false.B)
|
||||||
|
val lastReqC = (setAccessC === lastSet.U) && (stepAccessC === lastStep.U)
|
||||||
|
when (lastReqC && nextStepAccessC) { doneReqC := true.B }
|
||||||
|
|
||||||
|
switch(stateAccessC) {
|
||||||
|
is(State.idle) {
|
||||||
|
// arrival of an A or B response kicks off the execute backend
|
||||||
|
when (respQueueA.valid || respQueueB.valid) {
|
||||||
|
stateAccessC := State.run
|
||||||
|
when (respQueueA.valid) {
|
||||||
|
warpAccessC := respQueueA.bits.tag.warp
|
||||||
|
}.elsewhen (respQueueB.valid) {
|
||||||
|
warpAccessC := respQueueB.bits.tag.warp
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
is(State.run) {
|
||||||
|
when (doneReqC) {
|
||||||
|
stateAccessC := State.finish
|
||||||
|
}
|
||||||
|
}
|
||||||
|
is(State.finish) {
|
||||||
|
// FIXME: is finish state needed?
|
||||||
|
stateAccessC := State.idle
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
when (stateAccessC === State.finish) {
|
||||||
|
doneReqC := false.B
|
||||||
|
setAccessC := 0.U
|
||||||
|
stepAccessC := 0.U
|
||||||
|
substepAccessC := 0.U
|
||||||
|
}
|
||||||
|
|
||||||
|
// because C reg IO arrives 1 cycle later, it will get latched onto the
|
||||||
|
// response queue 2 cycles later. Make sure there's at least two entries of
|
||||||
|
// space at the queue before requesting the reg value
|
||||||
|
// FIXME: doesn't feel very clean
|
||||||
|
require(respQueueC.entries >= 2)
|
||||||
|
val hasSpace =
|
||||||
|
respQueueC.io.count <= Mux(respQueueC.io.deq.fire,
|
||||||
|
(respQueueC.entries - 1).U,
|
||||||
|
(respQueueC.entries - 2).U)
|
||||||
|
|
||||||
|
genReqC := (stateAccessC === State.run) && hasSpace && !doneReqC
|
||||||
|
|
||||||
|
// give 1-cycle delay to sync valid/metadata with C regfile response
|
||||||
|
val respCValid = RegNext(genReqC)
|
||||||
|
val warpAccessCDelayed = RegNext(warpAccessC)
|
||||||
|
val setAccessCDelayed = RegNext(setAccessC)
|
||||||
|
val stepAccessCDelayed = RegNext(stepAccessC)
|
||||||
|
|
||||||
|
// note rd is independent of sets
|
||||||
|
def rdGen(step: UInt, substep: UInt): UInt = {
|
||||||
|
// each step produces 4x4 output tile, written by 8 threads with 2 regs per
|
||||||
|
// thread
|
||||||
|
(step << 1/*2 substeps*/) + substep
|
||||||
|
}
|
||||||
|
io.reqC.valid := genReqC
|
||||||
|
io.reqC.bits := rdGen(stepAccessC, substepAccessC)
|
||||||
|
|
||||||
|
// queue the regfile response to buffers
|
||||||
|
// these strictly belong to the execute stage
|
||||||
|
respQueueC.io.enq.valid := respCValid
|
||||||
|
respQueueC.io.enq.bits.tag.warp := warpAccessCDelayed
|
||||||
|
respQueueC.io.enq.bits.tag.set := setAccessCDelayed
|
||||||
|
respQueueC.io.enq.bits.tag.step := stepAccessCDelayed
|
||||||
|
respQueueC.io.enq.bits.data := io.respC
|
||||||
|
|
||||||
|
// serialize every two C responses into one full 4x4 C tile
|
||||||
|
val fullC = Module(new FillBuffer(
|
||||||
|
chiselTypeOf(respQueueC.io.deq.bits.data), 2/*substeps*/
|
||||||
|
))
|
||||||
|
fullC.io.enq.valid := respQueueC.io.deq.valid
|
||||||
|
fullC.io.enq.bits := respQueueC.io.deq.bits.data
|
||||||
|
respQueueC.io.deq.ready := fullC.io.enq.ready
|
||||||
|
val fullCTag = Module(new Queue(
|
||||||
|
chiselTypeOf(respQueueC.io.deq.bits.tag),
|
||||||
|
entries = 1, pipe = true
|
||||||
|
))
|
||||||
|
fullCTag.io.enq.valid := respQueueC.io.deq.valid
|
||||||
|
fullCTag.io.enq.bits := respQueueC.io.deq.bits.tag
|
||||||
|
|
||||||
val fullCBuf = Module(new Queue(
|
val fullCBuf = Module(new Queue(
|
||||||
new Bundle {
|
new Bundle {
|
||||||
@@ -797,6 +823,8 @@ class TensorCoreDecoupledTLImp(outer: TensorCoreDecoupledTL)
|
|||||||
|
|
||||||
tensor.io.initiate.valid := io.start
|
tensor.io.initiate.valid := io.start
|
||||||
tensor.io.initiate.bits.wid := 3.U // bogus, static value
|
tensor.io.initiate.bits.wid := 3.U // bogus, static value
|
||||||
|
tensor.io.initiate.bits.addressA := 0x0.U
|
||||||
|
tensor.io.initiate.bits.addressB := 0x800.U
|
||||||
tensor.io.writeback.ready := true.B
|
tensor.io.writeback.ready := true.B
|
||||||
|
|
||||||
io.finished := tensor.io.writeback.valid && tensor.io.writeback.bits.last
|
io.finished := tensor.io.writeback.valid && tensor.io.writeback.bits.last
|
||||||
|
|||||||
@@ -849,8 +849,8 @@ class RadianceTileModuleImp(outer: RadianceTile)
|
|||||||
|
|
||||||
// Instantiate a fake tensor core module to force unique-ification of module
|
// Instantiate a fake tensor core module to force unique-ification of module
|
||||||
// names in the Chisel-generated Verilog. These should be left out for
|
// names in the Chisel-generated Verilog. These should be left out for
|
||||||
// synthesis runs, although these will likely be optimized-out if the inputs
|
// synthesis runs, although it's likely they will be optimized-out with all
|
||||||
// are tied to low.
|
// inputs tied to low.
|
||||||
|
|
||||||
if (outer.radianceParams.core.tensorCoreDecoupled) {
|
if (outer.radianceParams.core.tensorCoreDecoupled) {
|
||||||
val tensorNumSourceIds = (1 << outer.tensorTagWidth)
|
val tensorNumSourceIds = (1 << outer.tensorTagWidth)
|
||||||
|
|||||||
Reference in New Issue
Block a user