tensor: Write staging pipeline for A tile

This commit is contained in:
Hansung Kim
2024-10-16 21:21:48 -07:00
parent 444dd5d7e1
commit 77dae3e1f9
2 changed files with 83 additions and 21 deletions

View File

@@ -108,8 +108,12 @@ class TensorCoreDecoupled(
// set and step being currently accessed in the acc/ex frontend
val setAccess = RegInit(0.U(setBits.W))
val stepAccess = RegInit(0.U(stepBits.W))
// we need the full 4x4 A tile to fire the DPU, but since the memory width is
// 8 words, we need 2 cycles to read A. `substepAccess` tells which cycle we're at.
val substepAccess = RegInit(0.U(1.W))
dontTouch(setAccess)
dontTouch(stepAccess)
dontTouch(substepAccess)
when(io.initiate.fire) {
val wid = io.initiate.bits.wid
@@ -139,16 +143,19 @@ class TensorCoreDecoupled(
// Tag carried along with each tensor memory request. It records the
// set/step/substep the request was issued for (the fields are driven from
// setAccess/stepAccess/substepAccess right below).
// NOTE(review): the field declaration order determines how the tag is packed
// into bits; do not reorder without checking the users of the tag.
class TensorMemTag extends Bundle {
// set index, setBits wide
val set = UInt(setBits.W)
// step index, stepBits wide
val step = UInt(stepBits.W)
// 1-bit substep index; selects one of the two read cycles of a step
val substep = UInt(1.W)
}
// use concatenation of set/step/substep as the memory request source. This
// will get translated to the actual TL source width in sourceGen.
val tag = Wire(new TensorMemTag)
tag.set := setAccess
tag.step := stepAccess
tag.substep := substepAccess
val respATagged = Wire(Decoupled(new TensorMemRespWithTag(dataWidth)))
val respBTagged = Wire(Decoupled(new TensorMemRespWithTag(dataWidth)))
Seq((io.reqA, (io.respA, respATagged)), (io.reqB, (io.respB, respBTagged))).foreach {
Seq((io.reqA, (io.respA, respATagged)),
(io.reqB, (io.respB, respBTagged))).foreach {
case (req, (resp, respTagged)) => {
val sourceGen = Module(new SourceGenerator(
log2Ceil(numSourceIds),
@@ -173,18 +180,22 @@ class TensorCoreDecoupled(
}
// only advance to the next step if we fired mem requests for both A and B
// TODO: @perf: too strict? should be able to have A and B progress
// separately
// Per-port "request already fired" flags, one register per port in the order
// (io.reqA, io.reqB). A flag is set when its request fires; it is cleared
// outside of this block.
val firedABReg = RegInit(VecInit(false.B, false.B))
// For each port: set its flag register on fire (side effect inside the map),
// and collect the same-cycle fire signal as the element of the resulting Vec.
val firedABNow = VecInit((Seq(io.reqA, io.reqB) zip firedABReg).map {
case (req, fired) => { when (req.fire) { fired := true.B } }
req.fire
})
// A port counts as fired if it fired in an earlier cycle (register) or is
// firing in the current cycle (combinational), bitwise per port.
val firedAB = (firedABNow.asUInt | firedABReg.asUInt)
val nextStepAccess = firedAB.andR
// clear out firedABReg every step. this will overwrite the previous fired
// write upon the last fire out of A and B
when (nextStepAccess) {
val nextSubstepAccess = firedAB.andR
val nextStepAccess = nextSubstepAccess && (substepAccess === 1.U)
// clear out firedABReg every substep
when (nextSubstepAccess) {
firedABReg := Seq(false.B, false.B)
substepAccess := substepAccess + 1.U
}
require(substepAccess.widthOption.get == 1, "there should be only two substeps")
// Execute stage
// -------------
@@ -204,22 +215,72 @@ class TensorCoreDecoupled(
io.writeback.bits.data.widthOption.get,
"response data width does not match the writeback data width")
val bothQueueValid = (respQueueA.valid && respQueueB.valid)
// assume in-order response and that A/B responses are always aligned; this
// might be too strong an assumption depending on the backing memory
when (bothQueueValid) {
assert((respQueueA.bits.tag.set === respQueueB.bits.tag.set) &&
(respQueueA.bits.tag.step === respQueueB.bits.tag.step),
"A and B response queue pointing to different set/steps. " ++
"This might indicate memory response coming back out-of-order.")
}
// dequeue is synchronized between A and B
// FIXME: this needs to change to dpu_ready
val deqResp = bothQueueValid && io.writeback.ready
respQueueA.ready := deqResp
respQueueB.ready := deqResp
// FIXME: this needs to change to dpu_fire
val nextStepExecute = io.writeback.fire
// Execute-side staging of the A operand.
//
// Two consecutive responses from respQueueA are concatenated into one operand
// of width dataWidth * 2 (checked by the `require` below). `halfAQueue` holds
// the first response; `fullAQueue` holds the concatenated pair until the DPU
// fires.
val dpuReady = io.writeback.ready // FIXME: this needs to be the actual DPU ready signal
// 1-bit counter telling which half of A is at the head of respQueueA; it
// toggles every time a response is dequeued from respQueueA.
val substepExecute = RegInit(0.U(1.W))
when (respQueueA.fire) {
substepExecute := substepExecute + 1.U
}
dontTouch(substepExecute)

// note combinationally coupled ready with `pipe`
val halfAQueue = Module(new Queue(
chiselTypeOf(respQueueA.bits.data), entries = 1, pipe = true
))
// the first half is captured only while in substep 0
halfAQueue.io.enq.valid := respQueueA.valid && (substepExecute === 0.U)
halfAQueue.io.enq.bits := respQueueA.bits.data
// we need the full data for A because we divide the D tile by half along N;
// for B, the DPU can immediately start computing with a 4x2 tile.
//
// substep == 0 data goes to the LSB
val fullAEnqData = Cat(respQueueA.bits.data, halfAQueue.io.deq.bits)
val fullAQueue = Module(new Queue(
chiselTypeOf(fullAEnqData), entries = 1, pipe = true
))
// hold first half A data for the first substep; release it only in substep 1,
// in the same cycle the second half is presented and fullAQueue can accept
halfAQueue.io.deq.ready := respQueueA.valid && (substepExecute === 1.U) &&
fullAQueue.io.enq.ready
require(fullAEnqData.widthOption.get == dataWidth * 2,
"assumes 2-cycle read for a full compute tile of A")
// the full tile is enqueued in substep 1, when both the stored first half and
// the incoming second half are valid
fullAQueue.io.enq.valid := respQueueA.valid && (substepExecute === 1.U) &&
halfAQueue.io.deq.valid
fullAQueue.io.enq.bits := fullAEnqData

// the DPU fires when a full A tile and a B response are both available and
// the downstream is ready
val operandsValid = fullAQueue.io.deq.valid && respQueueB.valid // FIXME?
val dpuFire = operandsValid && dpuReady
fullAQueue.io.deq.ready := dpuFire
val nextStepExecute = dpuFire

// FIXME: need to hold A for two cycles!!
// respQueueA is drained into whichever staging queue matches the current
// substep; respQueueB is dequeued only when the DPU fires.
// NOTE(review): A is dequeued based on staging-queue readiness alone, not on B
// being valid; only B waits for dpuFire.
// NOTE(review): in substep 1, respQueueA.ready does not include
// halfAQueue.io.deq.valid, unlike fullAQueue.io.enq.valid above; presumably
// halfAQueue is always valid in substep 1 -- confirm.
respQueueA.ready := MuxCase(false.B,
Seq((substepExecute === 0.U) -> halfAQueue.io.enq.ready,
(substepExecute === 1.U) -> fullAQueue.io.enq.ready))
respQueueB.ready := dpuFire
dontTouch(respQueueA)
dontTouch(respQueueB)
// assert that the A and B response queue heads always point to the same
// set/step/substep
//
// this assumes that memory responses come back in-order. this might be too
// strong an assumption depending on the backing memory
// Emits an assertion that the set and step tags at the heads of respQueueA
// and respQueueB are equal. The check is active only when both heads are
// valid and the execute side is in substep 0. Also marks both tag bundles
// with dontTouch.
// NOTE(review): only `set` and `step` are compared; the `substep` tag field
// is not checked here.
def assertAligned = {
val bothQueueValid = (respQueueA.valid && respQueueB.valid)
// gated on substep 0
// NOTE(review): presumably because the A head advances through its two halves
// independently of B -- confirm
when (bothQueueValid && (substepExecute === 0.U)) {
assert((respQueueA.bits.tag.set === respQueueB.bits.tag.set) &&
(respQueueA.bits.tag.step === respQueueB.bits.tag.step),
"A and B response queue pointing to different set/steps. " ++
"This might indicate memory response coming back out-of-order.")
}
dontTouch(respQueueA.bits.tag)
dontTouch(respQueueB.bits.tag)
}
assertAligned
def rdGen(set: UInt, step: UInt): UInt = {
// each step produces 4x4 output tile, written by 8 threads with 2 regs per
@@ -229,7 +290,7 @@ class TensorCoreDecoupled(
// FIXME: add substep here
}
io.writeback.valid := bothQueueValid
io.writeback.valid := operandsValid // FIXME: bypass logic
io.writeback.bits.wid := warpReg
io.writeback.bits.rd := rdGen(setExecute, stepExecute)
io.writeback.bits.last := setDone(setExecute) && stepDone(stepExecute)

View File

@@ -27,6 +27,7 @@ class TensorDotProductUnit(val half: Boolean) extends Module with tile.HasFPUPar
val b = Vec(dotProductDim, Bits((inFLen).W))
val c = Bits((outFLen).W) // note C has the out length for accumulation
}))
// 'stall' is effectively out.ready, combinationally coupled to in.ready
val stall = Input(Bool())
val out = Valid(new Bundle {
val data = Bits((outFLen).W)