tensor: Reassert initiate.ready as soon as access ready
This commit is contained in:
@@ -69,6 +69,11 @@ class TensorCoreDecoupled(
|
||||
val source = UInt(sourceWidth.W)
|
||||
val data = UInt(dataWidth.W)
|
||||
}
|
||||
class TensorMemTag extends Bundle {
|
||||
val warp = UInt(numWarpBits.W)
|
||||
val set = UInt(setBits.W)
|
||||
val index = UInt(indexBits.W)
|
||||
}
|
||||
// mem response after translation from TL source to set/step tag
|
||||
class TensorMemRespWithTag(
|
||||
dataWidth: Int
|
||||
@@ -77,15 +82,11 @@ class TensorCoreDecoupled(
|
||||
val data = UInt(dataWidth.W)
|
||||
}
|
||||
|
||||
// FSM
|
||||
// ---
|
||||
// This drives the overall pipeline of memory requests, dot-product unit
|
||||
// operations and regfile writeback.
|
||||
|
||||
val busy = RegInit(false.B)
|
||||
// Holds the warp id the core is currently working on. Note that we only
|
||||
// support one outstanding warp request
|
||||
val warpAccess = RegInit(0.U(numWarpBits.W))
|
||||
// ===========================================================================
|
||||
// Access stage
|
||||
// ===========================================================================
|
||||
//
|
||||
// Frontend of the decoupled access/execute pipeline.
|
||||
|
||||
// sets: k iteration
|
||||
val numSets = (tilingParams.k / tilingParams.kc)
|
||||
@@ -97,39 +98,15 @@ class TensorCoreDecoupled(
|
||||
val lastStep = ((1 << stepBits) - 1)
|
||||
def setDone(set: UInt) = (set === lastSet.U)
|
||||
def stepDone(step: UInt) = (step === lastStep.U)
|
||||
// 'index' is the index of a memory request among the sequence of requests
|
||||
// needed to read a full M-column of A or N-row of B. Its range is [0,m/2)
|
||||
// or [0,n/2), where 2 is the stride can be read in a single request size.
|
||||
require(tilingParams.m == tilingParams.n,
|
||||
"currently only supports square SMEM tile")
|
||||
val numIndices = tilingParams.m / 2/*FIXME:hardcoded?*/
|
||||
val indexBits = log2Ceil(numIndices)
|
||||
val lastIndex = (1 << indexBits) - 1
|
||||
|
||||
when (io.initiate.fire) {
|
||||
val wid = io.initiate.bits.wid
|
||||
busy := true.B
|
||||
warpAccess := wid
|
||||
when(io.writeback.fire) {
|
||||
assert(
|
||||
io.writeback.bits.wid =/= wid,
|
||||
"unsupported concurrent initiate and writeback to the same warp"
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: @perf: Instead of waiting until the last writeback, release busy as
|
||||
// soon as the access frontend is complete so that there's a better chance to
|
||||
// saturate the backend with back-to-back HGMMAs. This would require sending
|
||||
// the 'wid' register to backend instead of having it shared with the
|
||||
// frontend.
|
||||
when(io.writeback.fire && io.writeback.bits.last) {
|
||||
busy := false.B
|
||||
}
|
||||
|
||||
// serialize every HGMMA request
|
||||
io.initiate.ready := !busy
|
||||
|
||||
// ===========================================================================
|
||||
// Access stage
|
||||
// ===========================================================================
|
||||
//
|
||||
// Frontend of the decoupled access/execute pipeline.
|
||||
|
||||
// States
|
||||
//
|
||||
object AccessorState extends ChiselEnum {
|
||||
val idle = Value(0.U)
|
||||
val access = Value(1.U)
|
||||
@@ -142,6 +119,30 @@ class TensorCoreDecoupled(
|
||||
val allReqsDone = WireInit(false.B)
|
||||
dontTouch(allReqsDone)
|
||||
|
||||
val warpAccess = RegInit(0.U(numWarpBits.W))
|
||||
|
||||
class BlockState extends Bundle {
|
||||
val set = UInt(setBits.W)
|
||||
val index = UInt(indexBits.W)
|
||||
}
|
||||
val stateInit = Wire(new BlockState)
|
||||
stateInit.set := 0.U
|
||||
stateInit.index := 0.U
|
||||
val stateA = RegInit(stateInit)
|
||||
val stateB = RegInit(stateInit)
|
||||
dontTouch(stateA)
|
||||
dontTouch(stateA.index)
|
||||
dontTouch(stateB)
|
||||
dontTouch(stateB.index)
|
||||
|
||||
io.initiate.ready := (state === AccessorState.idle)
|
||||
when (io.initiate.fire) {
|
||||
warpAccess := io.initiate.bits.wid
|
||||
assert(stateA.set === 0.U && stateA.index === 0.U &&
|
||||
stateB.set === 0.U && stateB.index === 0.U,
|
||||
"stateA and stateB not initialized to zero")
|
||||
}
|
||||
|
||||
switch(state) {
|
||||
is(AccessorState.idle) {
|
||||
when(io.initiate.fire) {
|
||||
@@ -154,39 +155,10 @@ class TensorCoreDecoupled(
|
||||
}
|
||||
}
|
||||
is(AccessorState.finish) {
|
||||
// FIXME: decouple writeback
|
||||
when(io.writeback.fire) {
|
||||
// FIXME: is finish state needed?
|
||||
state := AccessorState.idle
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 'index' is the index of a memory request among the sequence of requests
|
||||
// needed to read a full M-column of A or N-row of B. Its range is [0,m/2)
|
||||
// or [0,n/2), where 2 is the stride can be read in a single request size.
|
||||
require(tilingParams.m == tilingParams.n,
|
||||
"currently only supports square SMEM tile")
|
||||
val numIndices = tilingParams.m / 2/*FIXME:hardcoded?*/
|
||||
val indexBits = log2Ceil(numIndices)
|
||||
val lastIndex = (1 << indexBits) - 1
|
||||
|
||||
class State extends Bundle {
|
||||
val set = UInt(setBits.W)
|
||||
val index = UInt(indexBits.W)
|
||||
}
|
||||
class TensorMemTag extends Bundle {
|
||||
val warp = UInt(numWarpBits.W)
|
||||
val set = UInt(setBits.W)
|
||||
val index = UInt(indexBits.W)
|
||||
}
|
||||
|
||||
val stateInit = Wire(new State)
|
||||
stateInit.set := 0.U
|
||||
stateInit.index := 0.U
|
||||
val stateA = RegInit(stateInit)
|
||||
val stateB = RegInit(stateInit)
|
||||
dontTouch(stateA)
|
||||
dontTouch(stateB)
|
||||
|
||||
when (io.reqA.fire) {
|
||||
when (stateA.index === lastIndex.U) {
|
||||
|
||||
Reference in New Issue
Block a user