tensor: Reassert initiate.ready as soon as access ready

This commit is contained in:
Hansung Kim
2024-10-22 23:10:11 -07:00
parent 95ecc5180f
commit 2a8c488d28

View File

@@ -69,6 +69,11 @@ class TensorCoreDecoupled(
val source = UInt(sourceWidth.W)
val data = UInt(dataWidth.W)
}
class TensorMemTag extends Bundle {
val warp = UInt(numWarpBits.W)
val set = UInt(setBits.W)
val index = UInt(indexBits.W)
}
// mem response after translation from TL source to set/step tag
class TensorMemRespWithTag(
dataWidth: Int
@@ -77,15 +82,11 @@ class TensorCoreDecoupled(
val data = UInt(dataWidth.W)
}
// FSM
// ---
// This drives the overall pipeline of memory requests, dot-product unit
// operations and regfile writeback.
val busy = RegInit(false.B)
// Holds the warp id the core is currently working on. Note that we only
// support one outstanding warp request
val warpAccess = RegInit(0.U(numWarpBits.W))
// ===========================================================================
// Access stage
// ===========================================================================
//
// Frontend of the decoupled access/execute pipeline.
// sets: k iteration
val numSets = (tilingParams.k / tilingParams.kc)
@@ -97,39 +98,15 @@ class TensorCoreDecoupled(
val lastStep = ((1 << stepBits) - 1)
def setDone(set: UInt) = (set === lastSet.U)
def stepDone(step: UInt) = (step === lastStep.U)
// 'index' is the index of a memory request among the sequence of requests
// needed to read a full M-column of A or N-row of B. Its range is [0,m/2)
// or [0,n/2), where 2 is the stride can be read in a single request size.
require(tilingParams.m == tilingParams.n,
"currently only supports square SMEM tile")
val numIndices = tilingParams.m / 2/*FIXME:hardcoded?*/
val indexBits = log2Ceil(numIndices)
val lastIndex = (1 << indexBits) - 1
when (io.initiate.fire) {
val wid = io.initiate.bits.wid
busy := true.B
warpAccess := wid
when(io.writeback.fire) {
assert(
io.writeback.bits.wid =/= wid,
"unsupported concurrent initiate and writeback to the same warp"
)
}
}
// TODO: @perf: Instead of waiting until the last writeback, release busy as
// soon as the access frontend is complete so that there's a better chance to
// saturate the backend with back-to-back HGMMAs. This would require sending
// the 'wid' register to backend instead of having it shared with the
// frontend.
when(io.writeback.fire && io.writeback.bits.last) {
busy := false.B
}
// serialize every HGMMA request
io.initiate.ready := !busy
// ===========================================================================
// Access stage
// ===========================================================================
//
// Frontend of the decoupled access/execute pipeline.
// States
//
object AccessorState extends ChiselEnum {
val idle = Value(0.U)
val access = Value(1.U)
@@ -142,6 +119,30 @@ class TensorCoreDecoupled(
val allReqsDone = WireInit(false.B)
dontTouch(allReqsDone)
val warpAccess = RegInit(0.U(numWarpBits.W))
class BlockState extends Bundle {
val set = UInt(setBits.W)
val index = UInt(indexBits.W)
}
val stateInit = Wire(new BlockState)
stateInit.set := 0.U
stateInit.index := 0.U
val stateA = RegInit(stateInit)
val stateB = RegInit(stateInit)
dontTouch(stateA)
dontTouch(stateA.index)
dontTouch(stateB)
dontTouch(stateB.index)
io.initiate.ready := (state === AccessorState.idle)
when (io.initiate.fire) {
warpAccess := io.initiate.bits.wid
assert(stateA.set === 0.U && stateA.index === 0.U &&
stateB.set === 0.U && stateB.index === 0.U,
"stateA and stateB not initialized to zero")
}
switch(state) {
is(AccessorState.idle) {
when(io.initiate.fire) {
@@ -154,40 +155,11 @@ class TensorCoreDecoupled(
}
}
is(AccessorState.finish) {
// FIXME: decouple writeback
when(io.writeback.fire) {
state := AccessorState.idle
}
// FIXME: is finish state needed?
state := AccessorState.idle
}
}
// 'index' is the index of a memory request among the sequence of requests
// needed to read a full M-column of A or N-row of B. Its range is [0,m/2)
// or [0,n/2), where 2 is the stride can be read in a single request size.
require(tilingParams.m == tilingParams.n,
"currently only supports square SMEM tile")
val numIndices = tilingParams.m / 2/*FIXME:hardcoded?*/
val indexBits = log2Ceil(numIndices)
val lastIndex = (1 << indexBits) - 1
class State extends Bundle {
val set = UInt(setBits.W)
val index = UInt(indexBits.W)
}
class TensorMemTag extends Bundle {
val warp = UInt(numWarpBits.W)
val set = UInt(setBits.W)
val index = UInt(indexBits.W)
}
val stateInit = Wire(new State)
stateInit.set := 0.U
stateInit.index := 0.U
val stateA = RegInit(stateInit)
val stateB = RegInit(stateInit)
dontTouch(stateA)
dontTouch(stateB)
when (io.reqA.fire) {
when (stateA.index === lastIndex.U) {
stateA.set := stateA.set + 1.U