tensor: Decouple A and B access states

Get rid of set/stepAccess states and let A and B access progress
independently.
This commit is contained in:
Hansung Kim
2024-10-18 22:42:41 -07:00
parent c0292dd0aa
commit 0aadc6074a

View File

@@ -82,15 +82,6 @@ class TensorCoreDecoupled(
// This drives the overall pipeline of memory requests, dot-product unit
// operations and regfile writeback.
object TensorState extends ChiselEnum {
val idle = Value(0.U)
val run = Value(1.U)
// All set/step sequencing is complete and the tensor core is holding the
// result data until downstream writeback is ready.
// FIXME: is this necessary if writeback is decoupled with queues?
val finish = Value(2.U)
}
val state = RegInit(TensorState.idle)
val busy = RegInit(false.B)
// Holds the warp id the core is currently working on. Note that we only
// support one outstanding warp request
@@ -107,22 +98,10 @@ class TensorCoreDecoupled(
def setDone(set: UInt) = (set === lastSet.U)
def stepDone(step: UInt) = (step === lastStep.U)
// set and step being currently accessed in the acc/ex frontend
val setAccess = RegInit(0.U(setBits.W))
val stepAccess = RegInit(0.U(stepBits.W))
// we need full 4x4 A tile to fire DPU, but since the memory width is 8
// words, we need 2 cycles to read A. `substep` tells which cycle we're at.
val substepAccess = RegInit(0.U(1.W))
dontTouch(setAccess)
dontTouch(stepAccess)
dontTouch(substepAccess)
when(io.initiate.fire) {
when (io.initiate.fire) {
val wid = io.initiate.bits.wid
busy := true.B
warpReg := wid
setAccess := 0.U
stepAccess := 0.U
when(io.writeback.fire) {
assert(
io.writeback.bits.wid =/= wid,
@@ -143,55 +122,51 @@ class TensorCoreDecoupled(
// serialize every HGMMA request
io.initiate.ready := !busy
// Memory traffic generation
// -------------------------
// ===========================================================================
// Access stage
// ===========================================================================
//
val numTilesM = tilingParams.m / tilingParams.mc
val numTilesN = tilingParams.n / tilingParams.nc
// @cleanup: generalize in terms of M/N/K-majorness?
def addressGen(baseA: UInt, baseB: UInt, set: UInt, step: UInt, substep: UInt)
: (UInt/*A*/, UInt/*B*/) = {
// note that step iterates along N first, then M
val tileM = step % numTilesM.U
val tileN = step / numTilesM.U
// Frontend of the decoupled access/execute pipeline.
// note that both A and B are K-major to facilitate bank conflict-free SMEM
// accesses
//
// (row,col) coordinate of the compute tile
val tileRowA = tileM // M
val tileColA = set // K
val tileRowB = tileN // N
val tileColB = set // K
// (row,col) coordinate of the starting element of the compute tile
val elemRowA = (tileRowA << log2Ceil(tilingParams.mc)) +
(substep << log2Ceil(tilingParams.mc / 2))
val elemColA = tileColA << log2Ceil(tilingParams.kc)
val elemRowB = (tileRowB << log2Ceil(tilingParams.nc)) +
(substep << log2Ceil(tilingParams.nc / 2))
val elemColB = tileColB << log2Ceil(tilingParams.kc)
val rowStrideA = wordSize * tilingParams.k
val rowStrideABits = log2Ceil(rowStrideA)
val rowStrideB = wordSize * tilingParams.k
val rowStrideBBits = log2Ceil(rowStrideB)
val wordStrideBits = log2Ceil(wordSize)
val tileOffsetA = (elemRowA << rowStrideABits) + (elemColA << wordStrideBits)
val tileOffsetB = (elemRowB << rowStrideBBits) + (elemColB << wordStrideBits)
(baseA + tileOffsetA, baseB + tileOffsetB)
// States
//
object AccessorState extends ChiselEnum {
val idle = Value(0.U)
val access = Value(1.U)
// All set/step sequencing is complete and the tensor core is holding the
// result data until downstream writeback is ready.
// FIXME: is this necessary if writeback is decoupled with queues?
val finish = Value(2.U)
}
val state = RegInit(AccessorState.idle)
val allReqsDone = WireInit(false.B)
dontTouch(allReqsDone)
// FIXME: bogus base address
val (addressA, addressB) =
addressGen(0.U, 0.U, setAccess, stepAccess, substepAccess)
switch(state) {
is(AccessorState.idle) {
when(io.initiate.fire) {
state := AccessorState.access
}
}
is(AccessorState.access) {
when (allReqsDone) {
state := AccessorState.finish
}
}
is(AccessorState.finish) {
// FIXME: decouple writeback
when(io.writeback.fire) {
state := AccessorState.idle
}
}
}
// 'index' is the index of a memory request among the sequence of requests
// needed to read a full M-column of A or N-row of B. Its range is [0,m/2)
// or [0,n/2), where 2 is the stride can be read in a single request size.
require(tilingParams.m == tilingParams.n,
"currently only supports square SMEM tile")
val numIndices = tilingParams.m / 2
val numIndices = tilingParams.m / 2/*FIXME:hardcoded?*/
val indexBits = log2Ceil(numIndices)
val lastIndex = (1 << indexBits) - 1
@@ -219,9 +194,51 @@ class TensorCoreDecoupled(
tagB.index := tagB.index + 1.U
}
val genReqA = (state === TensorState.run)
val genReqB = (state === TensorState.run)
// Address generation
//
def addressGen(base: UInt, set: UInt, index: UInt): UInt = {
// note that both A and B are K-major to facilitate bank conflict-free SMEM
// accesses, so that below code applies to both.
//
// (row,col) coordinate of the compute tile
val tileRow = index
val tileCol = set
// (row,col) coordinate of the starting element of the compute tile
val elemRow = index << 1
val elemCol = tileCol << log2Ceil(tilingParams.kc)
val rowStride = tilingParams.k * wordSize
val rowStrideBits = log2Ceil(rowStride)
val wordStrideBits = log2Ceil(wordSize)
val tileOffset = (elemRow << rowStrideBits) + (elemCol << wordStrideBits)
base + tileOffset
}
// FIXME: bogus base address
val addressA = addressGen(0.U, tagA.set, tagA.index)
val addressB = addressGen(0.U, tagB.set, tagB.index)
val lastReqA = (tagA.set === lastSet.U) && (tagA.index === lastIndex.U)
val lastReqB = (tagB.set === lastSet.U) && (tagB.index === lastIndex.U)
val doneReqA = RegInit(false.B)
val doneReqB = RegInit(false.B)
when (lastReqA && io.reqA.fire) { doneReqA := true.B }
when (lastReqB && io.reqB.fire) { doneReqB := true.B }
val genReqA = (state === AccessorState.access) && !doneReqA
val genReqB = (state === AccessorState.access) && !doneReqA
when (state === AccessorState.finish) {
doneReqA := false.B
doneReqB := false.B
tagA.set := 0.U
tagA.index := 0.U
tagB.set := 0.U
tagB.index := 0.U
}
allReqsDone := doneReqA && doneReqB
// Request generation
//
val respATagged = Wire(Decoupled(new TensorMemRespWithTag(dataWidth)))
val respBTagged = Wire(Decoupled(new TensorMemRespWithTag(dataWidth)))
Seq((io.reqA, (io.respA, respATagged)),
@@ -249,34 +266,13 @@ class TensorCoreDecoupled(
}
}
// only advance to the next step if we fired mem requests for both A and B.
// also consider that B doesn't have to be fired every time due to reuse.
// @perf: too strict? should be able to have A and B progress separately
val firedAReg = RegInit(false.B)
val firedBReg = RegInit(false.B)
when (io.reqA.fire) { firedAReg := true.B }
when (io.reqB.fire) { firedBReg := true.B }
val firedANow = io.reqA.fire
val firedBNow = io.reqB.fire
val firedA = firedAReg || firedANow
val firedB = firedBReg || firedBNow
val nextSubstepAccess = firedA && firedB
val nextStepAccess = nextSubstepAccess && (substepAccess === 1.U)
// clear out firedABReg every substep
when (nextSubstepAccess) {
firedAReg := false.B
firedBReg := false.B
substepAccess := substepAccess + 1.U
}
require(substepAccess.widthOption.get == 1, "there should be only two substeps")
// ===========================================================================
// Execute stage
// -------------
// ===========================================================================
//
// Backend of the decoupled access/execute pipeline.
//
// set and step being currently executed in the acc/ex backend
val respQueueDepth = 4 // FIXME: parameterize
val respQueueDepth = 8 // FIXME: parameterize
val respQueueA = Queue(respATagged, respQueueDepth)
val respQueueB = Queue(respBTagged, respQueueDepth)
@@ -369,6 +365,7 @@ class TensorCoreDecoupled(
// Operand selection
//
// select the correct 4x4 tile from A operand buffer
val numTilesM = tilingParams.m / tilingParams.mc
val numTilesMBits = log2Ceil(numTilesM)
def selectOperandA(buf: Vec[UInt]): UInt = {
require(buf.length == numIndices)
@@ -383,7 +380,7 @@ class TensorCoreDecoupled(
dontTouch(operandATag)
dontTouch(operandBTag)
// Operand buffer dequeue logic
// Operand buffer logic
//
// hold A data until the entire set is done
val shouldDequeueAMask = ((1 << stepBits) - 1).U
@@ -476,8 +473,8 @@ class TensorCoreDecoupled(
}
io.writeback.bits.data := flattenedDPUOut
// Writeback queues
// ----------------
// Writeback logic
//
// These queues hold metadata needed for writeback in sync with the DPU.
class TensorComputeTag extends Bundle {
@@ -530,28 +527,7 @@ class TensorCoreDecoupled(
}
}
}
sequenceSetStep(setAccess, stepAccess, nextStepAccess)
sequenceSetStep(setCompute, stepCompute, nextStepCompute)
switch(state) {
is(TensorState.idle) {
when(io.initiate.fire) {
state := TensorState.run
}
}
is(TensorState.run) {
when (setDone(setAccess) && stepDone(stepAccess) && nextStepAccess) {
when (state === TensorState.run) {
state := TensorState.finish
}
}
}
is(TensorState.finish) {
when(io.writeback.fire) {
state := TensorState.idle
}
}
}
}
// A buffer that collects multiple entries of input data and exposes the