tensor: Write staging pipeline for A tile
This commit is contained in:
@@ -108,8 +108,12 @@ class TensorCoreDecoupled(
|
||||
// set and step being currently accessed in the acc/ex frontend
|
||||
val setAccess = RegInit(0.U(setBits.W))
|
||||
val stepAccess = RegInit(0.U(stepBits.W))
|
||||
// we need full 4x4 A tile to fire DPU, but since the memory width is 8
|
||||
// words, we need 2 cycles to read A. `substep` tells which cycle we're at.
|
||||
val substepAccess = RegInit(0.U(1.W))
|
||||
dontTouch(setAccess)
|
||||
dontTouch(stepAccess)
|
||||
dontTouch(substepAccess)
|
||||
|
||||
when(io.initiate.fire) {
|
||||
val wid = io.initiate.bits.wid
|
||||
@@ -139,16 +143,19 @@ class TensorCoreDecoupled(
|
||||
class TensorMemTag extends Bundle {
|
||||
val set = UInt(setBits.W)
|
||||
val step = UInt(stepBits.W)
|
||||
val substep = UInt(1.W)
|
||||
}
|
||||
// use concatenation of set/step as the memory request source. This will get
|
||||
// translated to the actual TL sourcewidth in sourceGen.
|
||||
val tag = Wire(new TensorMemTag)
|
||||
tag.set := setAccess
|
||||
tag.step := stepAccess
|
||||
tag.substep := substepAccess
|
||||
|
||||
val respATagged = Wire(Decoupled(new TensorMemRespWithTag(dataWidth)))
|
||||
val respBTagged = Wire(Decoupled(new TensorMemRespWithTag(dataWidth)))
|
||||
Seq((io.reqA, (io.respA, respATagged)), (io.reqB, (io.respB, respBTagged))).foreach {
|
||||
Seq((io.reqA, (io.respA, respATagged)),
|
||||
(io.reqB, (io.respB, respBTagged))).foreach {
|
||||
case (req, (resp, respTagged)) => {
|
||||
val sourceGen = Module(new SourceGenerator(
|
||||
log2Ceil(numSourceIds),
|
||||
@@ -173,18 +180,22 @@ class TensorCoreDecoupled(
|
||||
}
|
||||
|
||||
// only advance to the next step if we fired mem requests for both A and B
|
||||
// TODO: @perf: too strict? should be able to have A and B progress
|
||||
// separately
|
||||
val firedABReg = RegInit(VecInit(false.B, false.B))
|
||||
val firedABNow = VecInit((Seq(io.reqA, io.reqB) zip firedABReg).map {
|
||||
case (req, fired) => { when (req.fire) { fired := true.B } }
|
||||
req.fire
|
||||
})
|
||||
val firedAB = (firedABNow.asUInt | firedABReg.asUInt)
|
||||
val nextStepAccess = firedAB.andR
|
||||
// clear out firedABReg every step. this will overwrite the previous fired
|
||||
// write upon the last fire out of A and B
|
||||
when (nextStepAccess) {
|
||||
val nextSubstepAccess = firedAB.andR
|
||||
val nextStepAccess = nextSubstepAccess && (substepAccess === 1.U)
|
||||
// clear out firedABReg every substep
|
||||
when (nextSubstepAccess) {
|
||||
firedABReg := Seq(false.B, false.B)
|
||||
substepAccess := substepAccess + 1.U
|
||||
}
|
||||
require(substepAccess.widthOption.get == 1, "there should be only two substeps")
|
||||
|
||||
// Execute stage
|
||||
// -------------
|
||||
@@ -204,22 +215,72 @@ class TensorCoreDecoupled(
|
||||
io.writeback.bits.data.widthOption.get,
|
||||
"response data width does not match the writeback data width")
|
||||
|
||||
val bothQueueValid = (respQueueA.valid && respQueueB.valid)
|
||||
// assume in-order response and that A/B responses are always aligned; this
|
||||
// might be too strong an assumption depending on the backing memory
|
||||
when (bothQueueValid) {
|
||||
assert((respQueueA.bits.tag.set === respQueueB.bits.tag.set) &&
|
||||
(respQueueA.bits.tag.step === respQueueB.bits.tag.step),
|
||||
"A and B response queue pointing to different set/steps. " ++
|
||||
"This might indicate memory response coming back out-of-order.")
|
||||
}
|
||||
// dequeue is synchronized between A and B
|
||||
// FIXME: this need to change to dpu_ready
|
||||
val deqResp = bothQueueValid && io.writeback.ready
|
||||
respQueueA.ready := deqResp
|
||||
respQueueB.ready := deqResp
|
||||
// FIXME: this need to change to dpu_fire
|
||||
val nextStepExecute = io.writeback.fire
|
||||
val dpuReady = io.writeback.ready // FIXME: this need be actual dpu
|
||||
|
||||
val substepExecute = RegInit(0.U(1.W))
|
||||
when (respQueueA.fire) {
|
||||
substepExecute := substepExecute + 1.U
|
||||
}
|
||||
dontTouch(substepExecute)
|
||||
|
||||
// note combinationally coupled ready with `pipe`
|
||||
val halfAQueue = Module(new Queue(
|
||||
chiselTypeOf(respQueueA.bits.data), entries = 1, pipe = true
|
||||
))
|
||||
halfAQueue.io.enq.valid := respQueueA.valid && (substepExecute === 0.U)
|
||||
halfAQueue.io.enq.bits := respQueueA.bits.data
|
||||
|
||||
// we need the full data for A because we divide the D tile by half along N;
|
||||
// for B, the DPU can immediately start computing with a 4x2 tile.
|
||||
//
|
||||
// substep == 0 data goes to the LSB
|
||||
val fullAEnqData = Cat(respQueueA.bits.data, halfAQueue.io.deq.bits)
|
||||
val fullAQueue = Module(new Queue(
|
||||
chiselTypeOf(fullAEnqData), entries = 1, pipe = true
|
||||
))
|
||||
// hold first half A data for the first substep
|
||||
halfAQueue.io.deq.ready := respQueueA.valid && (substepExecute === 1.U) &&
|
||||
fullAQueue.io.enq.ready
|
||||
|
||||
require(fullAEnqData.widthOption.get == dataWidth * 2,
|
||||
"assumes 2-cycle read for a full compute tile of A")
|
||||
fullAQueue.io.enq.valid := respQueueA.valid && (substepExecute === 1.U) &&
|
||||
halfAQueue.io.deq.valid
|
||||
fullAQueue.io.enq.bits := fullAEnqData
|
||||
|
||||
val operandsValid = fullAQueue.io.deq.valid && respQueueB.valid // FIXME?
|
||||
val dpuFire = operandsValid && dpuReady
|
||||
fullAQueue.io.deq.ready := dpuFire
|
||||
val nextStepExecute = dpuFire
|
||||
|
||||
// FIXME: need to hold A for two cycles!!
|
||||
|
||||
// make sure to dequeue from response queues only when both A and B valid
|
||||
respQueueA.ready := MuxCase(false.B,
|
||||
Seq((substepExecute === 0.U) -> halfAQueue.io.enq.ready,
|
||||
(substepExecute === 1.U) -> fullAQueue.io.enq.ready))
|
||||
respQueueB.ready := dpuFire
|
||||
dontTouch(respQueueA)
|
||||
dontTouch(respQueueB)
|
||||
|
||||
// assert that the A and B response queue heads always point to the same
|
||||
// set/step/substep
|
||||
//
|
||||
// this assumes that memory responses come back in-order. this might be too
|
||||
// strong an assumption depending on the backing memory
|
||||
def assertAligned = {
|
||||
val bothQueueValid = (respQueueA.valid && respQueueB.valid)
|
||||
when (bothQueueValid && (substepExecute === 0.U)) {
|
||||
assert((respQueueA.bits.tag.set === respQueueB.bits.tag.set) &&
|
||||
(respQueueA.bits.tag.step === respQueueB.bits.tag.step),
|
||||
"A and B response queue pointing to different set/steps. " ++
|
||||
"This might indicate memory response coming back out-of-order.")
|
||||
}
|
||||
dontTouch(respQueueA.bits.tag)
|
||||
dontTouch(respQueueB.bits.tag)
|
||||
}
|
||||
assertAligned
|
||||
|
||||
def rdGen(set: UInt, step: UInt): UInt = {
|
||||
// each step produces 4x4 output tile, written by 8 threads with 2 regs per
|
||||
@@ -229,7 +290,7 @@ class TensorCoreDecoupled(
|
||||
// FIXME: add substep here
|
||||
}
|
||||
|
||||
io.writeback.valid := bothQueueValid
|
||||
io.writeback.valid := operandsValid // FIXME: bypass logic
|
||||
io.writeback.bits.wid := warpReg
|
||||
io.writeback.bits.rd := rdGen(setExecute, stepExecute)
|
||||
io.writeback.bits.last := setDone(setExecute) && stepDone(stepExecute)
|
||||
|
||||
@@ -27,6 +27,7 @@ class TensorDotProductUnit(val half: Boolean) extends Module with tile.HasFPUPar
|
||||
val b = Vec(dotProductDim, Bits((inFLen).W))
|
||||
val c = Bits((outFLen).W) // note C has the out length for accumulation
|
||||
}))
|
||||
// 'stall' is effectively out.ready, combinationally coupled to in.ready
|
||||
val stall = Input(Bool())
|
||||
val out = Valid(new Bundle {
|
||||
val data = Bits((outFLen).W)
|
||||
|
||||
Reference in New Issue
Block a user