tensor: Decouple A and B access states
Get rid of set/stepAccess states and let A and B access progress independently.
This commit is contained in:
@@ -82,15 +82,6 @@ class TensorCoreDecoupled(
|
||||
// This drives the overall pipeline of memory requests, dot-product unit
|
||||
// operations and regfile writeback.
|
||||
|
||||
object TensorState extends ChiselEnum {
|
||||
val idle = Value(0.U)
|
||||
val run = Value(1.U)
|
||||
// All set/step sequencing is complete and the tensor core is holding the
|
||||
// result data until downstream writeback is ready.
|
||||
// FIXME: is this necessary if writeback is decoupled with queues?
|
||||
val finish = Value(2.U)
|
||||
}
|
||||
val state = RegInit(TensorState.idle)
|
||||
val busy = RegInit(false.B)
|
||||
// Holds the warp id the core is currently working on. Note that we only
|
||||
// support one outstanding warp request
|
||||
@@ -107,22 +98,10 @@ class TensorCoreDecoupled(
|
||||
def setDone(set: UInt) = (set === lastSet.U)
|
||||
def stepDone(step: UInt) = (step === lastStep.U)
|
||||
|
||||
// set and step being currently accessed in the acc/ex frontend
|
||||
val setAccess = RegInit(0.U(setBits.W))
|
||||
val stepAccess = RegInit(0.U(stepBits.W))
|
||||
// we need full 4x4 A tile to fire DPU, but since the memory width is 8
|
||||
// words, we need 2 cycles to read A. `substep` tells which cycle we're at.
|
||||
val substepAccess = RegInit(0.U(1.W))
|
||||
dontTouch(setAccess)
|
||||
dontTouch(stepAccess)
|
||||
dontTouch(substepAccess)
|
||||
|
||||
when(io.initiate.fire) {
|
||||
when (io.initiate.fire) {
|
||||
val wid = io.initiate.bits.wid
|
||||
busy := true.B
|
||||
warpReg := wid
|
||||
setAccess := 0.U
|
||||
stepAccess := 0.U
|
||||
when(io.writeback.fire) {
|
||||
assert(
|
||||
io.writeback.bits.wid =/= wid,
|
||||
@@ -143,55 +122,51 @@ class TensorCoreDecoupled(
|
||||
// serialize every HGMMA request
|
||||
io.initiate.ready := !busy
|
||||
|
||||
// Memory traffic generation
|
||||
// -------------------------
|
||||
// ===========================================================================
|
||||
// Access stage
|
||||
// ===========================================================================
|
||||
//
|
||||
val numTilesM = tilingParams.m / tilingParams.mc
|
||||
val numTilesN = tilingParams.n / tilingParams.nc
|
||||
// @cleanup: generalize in terms of M/N/K-majorness?
|
||||
def addressGen(baseA: UInt, baseB: UInt, set: UInt, step: UInt, substep: UInt)
|
||||
: (UInt/*A*/, UInt/*B*/) = {
|
||||
// note that step iterates along N first, then M
|
||||
val tileM = step % numTilesM.U
|
||||
val tileN = step / numTilesM.U
|
||||
// Frontend of the decoupled access/execute pipeline.
|
||||
|
||||
// note that both A and B are K-major to facilitate bank conflict-free SMEM
|
||||
// accesses
|
||||
//
|
||||
// (row,col) coordinate of the compute tile
|
||||
val tileRowA = tileM // M
|
||||
val tileColA = set // K
|
||||
val tileRowB = tileN // N
|
||||
val tileColB = set // K
|
||||
// (row,col) coordinate of the starting element of the compute tile
|
||||
val elemRowA = (tileRowA << log2Ceil(tilingParams.mc)) +
|
||||
(substep << log2Ceil(tilingParams.mc / 2))
|
||||
val elemColA = tileColA << log2Ceil(tilingParams.kc)
|
||||
val elemRowB = (tileRowB << log2Ceil(tilingParams.nc)) +
|
||||
(substep << log2Ceil(tilingParams.nc / 2))
|
||||
val elemColB = tileColB << log2Ceil(tilingParams.kc)
|
||||
val rowStrideA = wordSize * tilingParams.k
|
||||
val rowStrideABits = log2Ceil(rowStrideA)
|
||||
val rowStrideB = wordSize * tilingParams.k
|
||||
val rowStrideBBits = log2Ceil(rowStrideB)
|
||||
val wordStrideBits = log2Ceil(wordSize)
|
||||
|
||||
val tileOffsetA = (elemRowA << rowStrideABits) + (elemColA << wordStrideBits)
|
||||
val tileOffsetB = (elemRowB << rowStrideBBits) + (elemColB << wordStrideBits)
|
||||
|
||||
(baseA + tileOffsetA, baseB + tileOffsetB)
|
||||
// States
|
||||
//
|
||||
object AccessorState extends ChiselEnum {
|
||||
val idle = Value(0.U)
|
||||
val access = Value(1.U)
|
||||
// All set/step sequencing is complete and the tensor core is holding the
|
||||
// result data until downstream writeback is ready.
|
||||
// FIXME: is this necessary if writeback is decoupled with queues?
|
||||
val finish = Value(2.U)
|
||||
}
|
||||
val state = RegInit(AccessorState.idle)
|
||||
val allReqsDone = WireInit(false.B)
|
||||
dontTouch(allReqsDone)
|
||||
|
||||
// FIXME: bogus base address
|
||||
val (addressA, addressB) =
|
||||
addressGen(0.U, 0.U, setAccess, stepAccess, substepAccess)
|
||||
switch(state) {
|
||||
is(AccessorState.idle) {
|
||||
when(io.initiate.fire) {
|
||||
state := AccessorState.access
|
||||
}
|
||||
}
|
||||
is(AccessorState.access) {
|
||||
when (allReqsDone) {
|
||||
state := AccessorState.finish
|
||||
}
|
||||
}
|
||||
is(AccessorState.finish) {
|
||||
// FIXME: decouple writeback
|
||||
when(io.writeback.fire) {
|
||||
state := AccessorState.idle
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 'index' is the index of a memory request among the sequence of requests
|
||||
// needed to read a full M-column of A or N-row of B. Its range is [0,m/2)
|
||||
// or [0,n/2), where 2 is the stride can be read in a single request size.
|
||||
require(tilingParams.m == tilingParams.n,
|
||||
"currently only supports square SMEM tile")
|
||||
val numIndices = tilingParams.m / 2
|
||||
val numIndices = tilingParams.m / 2/*FIXME:hardcoded?*/
|
||||
val indexBits = log2Ceil(numIndices)
|
||||
val lastIndex = (1 << indexBits) - 1
|
||||
|
||||
@@ -219,9 +194,51 @@ class TensorCoreDecoupled(
|
||||
tagB.index := tagB.index + 1.U
|
||||
}
|
||||
|
||||
val genReqA = (state === TensorState.run)
|
||||
val genReqB = (state === TensorState.run)
|
||||
// Address generation
|
||||
//
|
||||
def addressGen(base: UInt, set: UInt, index: UInt): UInt = {
|
||||
// note that both A and B are K-major to facilitate bank conflict-free SMEM
|
||||
// accesses, so that below code applies to both.
|
||||
//
|
||||
// (row,col) coordinate of the compute tile
|
||||
val tileRow = index
|
||||
val tileCol = set
|
||||
// (row,col) coordinate of the starting element of the compute tile
|
||||
val elemRow = index << 1
|
||||
val elemCol = tileCol << log2Ceil(tilingParams.kc)
|
||||
val rowStride = tilingParams.k * wordSize
|
||||
val rowStrideBits = log2Ceil(rowStride)
|
||||
val wordStrideBits = log2Ceil(wordSize)
|
||||
val tileOffset = (elemRow << rowStrideBits) + (elemCol << wordStrideBits)
|
||||
|
||||
base + tileOffset
|
||||
}
|
||||
|
||||
// FIXME: bogus base address
|
||||
val addressA = addressGen(0.U, tagA.set, tagA.index)
|
||||
val addressB = addressGen(0.U, tagB.set, tagB.index)
|
||||
|
||||
val lastReqA = (tagA.set === lastSet.U) && (tagA.index === lastIndex.U)
|
||||
val lastReqB = (tagB.set === lastSet.U) && (tagB.index === lastIndex.U)
|
||||
val doneReqA = RegInit(false.B)
|
||||
val doneReqB = RegInit(false.B)
|
||||
when (lastReqA && io.reqA.fire) { doneReqA := true.B }
|
||||
when (lastReqB && io.reqB.fire) { doneReqB := true.B }
|
||||
val genReqA = (state === AccessorState.access) && !doneReqA
|
||||
val genReqB = (state === AccessorState.access) && !doneReqA
|
||||
when (state === AccessorState.finish) {
|
||||
doneReqA := false.B
|
||||
doneReqB := false.B
|
||||
tagA.set := 0.U
|
||||
tagA.index := 0.U
|
||||
tagB.set := 0.U
|
||||
tagB.index := 0.U
|
||||
}
|
||||
|
||||
allReqsDone := doneReqA && doneReqB
|
||||
|
||||
// Request generation
|
||||
//
|
||||
val respATagged = Wire(Decoupled(new TensorMemRespWithTag(dataWidth)))
|
||||
val respBTagged = Wire(Decoupled(new TensorMemRespWithTag(dataWidth)))
|
||||
Seq((io.reqA, (io.respA, respATagged)),
|
||||
@@ -249,34 +266,13 @@ class TensorCoreDecoupled(
|
||||
}
|
||||
}
|
||||
|
||||
// only advance to the next step if we fired mem requests for both A and B.
|
||||
// also consider that B doesn't have to be fired every time due to reuse.
|
||||
// @perf: too strict? should be able to have A and B progress separately
|
||||
val firedAReg = RegInit(false.B)
|
||||
val firedBReg = RegInit(false.B)
|
||||
when (io.reqA.fire) { firedAReg := true.B }
|
||||
when (io.reqB.fire) { firedBReg := true.B }
|
||||
val firedANow = io.reqA.fire
|
||||
val firedBNow = io.reqB.fire
|
||||
val firedA = firedAReg || firedANow
|
||||
val firedB = firedBReg || firedBNow
|
||||
val nextSubstepAccess = firedA && firedB
|
||||
val nextStepAccess = nextSubstepAccess && (substepAccess === 1.U)
|
||||
// clear out firedABReg every substep
|
||||
when (nextSubstepAccess) {
|
||||
firedAReg := false.B
|
||||
firedBReg := false.B
|
||||
substepAccess := substepAccess + 1.U
|
||||
}
|
||||
require(substepAccess.widthOption.get == 1, "there should be only two substeps")
|
||||
|
||||
// ===========================================================================
|
||||
// Execute stage
|
||||
// -------------
|
||||
// ===========================================================================
|
||||
//
|
||||
// Backend of the decoupled access/execute pipeline.
|
||||
//
|
||||
// set and step being currently executed in the acc/ex backend
|
||||
|
||||
val respQueueDepth = 4 // FIXME: parameterize
|
||||
val respQueueDepth = 8 // FIXME: parameterize
|
||||
val respQueueA = Queue(respATagged, respQueueDepth)
|
||||
val respQueueB = Queue(respBTagged, respQueueDepth)
|
||||
|
||||
@@ -369,6 +365,7 @@ class TensorCoreDecoupled(
|
||||
// Operand selection
|
||||
//
|
||||
// select the correct 4x4 tile from A operand buffer
|
||||
val numTilesM = tilingParams.m / tilingParams.mc
|
||||
val numTilesMBits = log2Ceil(numTilesM)
|
||||
def selectOperandA(buf: Vec[UInt]): UInt = {
|
||||
require(buf.length == numIndices)
|
||||
@@ -383,7 +380,7 @@ class TensorCoreDecoupled(
|
||||
dontTouch(operandATag)
|
||||
dontTouch(operandBTag)
|
||||
|
||||
// Operand buffer dequeue logic
|
||||
// Operand buffer logic
|
||||
//
|
||||
// hold A data until the entire set is done
|
||||
val shouldDequeueAMask = ((1 << stepBits) - 1).U
|
||||
@@ -476,8 +473,8 @@ class TensorCoreDecoupled(
|
||||
}
|
||||
io.writeback.bits.data := flattenedDPUOut
|
||||
|
||||
// Writeback queues
|
||||
// ----------------
|
||||
// Writeback logic
|
||||
//
|
||||
// These queues hold metadata needed for writeback in sync with the DPU.
|
||||
|
||||
class TensorComputeTag extends Bundle {
|
||||
@@ -530,28 +527,7 @@ class TensorCoreDecoupled(
|
||||
}
|
||||
}
|
||||
}
|
||||
sequenceSetStep(setAccess, stepAccess, nextStepAccess)
|
||||
sequenceSetStep(setCompute, stepCompute, nextStepCompute)
|
||||
|
||||
switch(state) {
|
||||
is(TensorState.idle) {
|
||||
when(io.initiate.fire) {
|
||||
state := TensorState.run
|
||||
}
|
||||
}
|
||||
is(TensorState.run) {
|
||||
when (setDone(setAccess) && stepDone(stepAccess) && nextStepAccess) {
|
||||
when (state === TensorState.run) {
|
||||
state := TensorState.finish
|
||||
}
|
||||
}
|
||||
}
|
||||
is(TensorState.finish) {
|
||||
when(io.writeback.fire) {
|
||||
state := TensorState.idle
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// A buffer that collects multiple entries of input data and exposes the
|
||||
|
||||
Reference in New Issue
Block a user