tensor: Write staging pipeline for A tile
This commit is contained in:
@@ -108,8 +108,12 @@ class TensorCoreDecoupled(
|
|||||||
// set and step being currently accessed in the acc/ex frontend
|
// set and step being currently accessed in the acc/ex frontend
|
||||||
val setAccess = RegInit(0.U(setBits.W))
|
val setAccess = RegInit(0.U(setBits.W))
|
||||||
val stepAccess = RegInit(0.U(stepBits.W))
|
val stepAccess = RegInit(0.U(stepBits.W))
|
||||||
|
// we need the full 4x4 A tile to fire the DPU, but since the memory width is 8
|
||||||
|
// words, we need 2 cycles to read A. `substep` tells which cycle we're at.
|
||||||
|
val substepAccess = RegInit(0.U(1.W))
|
||||||
dontTouch(setAccess)
|
dontTouch(setAccess)
|
||||||
dontTouch(stepAccess)
|
dontTouch(stepAccess)
|
||||||
|
dontTouch(substepAccess)
|
||||||
|
|
||||||
when(io.initiate.fire) {
|
when(io.initiate.fire) {
|
||||||
val wid = io.initiate.bits.wid
|
val wid = io.initiate.bits.wid
|
||||||
@@ -139,16 +143,19 @@ class TensorCoreDecoupled(
|
|||||||
class TensorMemTag extends Bundle {
|
class TensorMemTag extends Bundle {
|
||||||
val set = UInt(setBits.W)
|
val set = UInt(setBits.W)
|
||||||
val step = UInt(stepBits.W)
|
val step = UInt(stepBits.W)
|
||||||
|
val substep = UInt(1.W)
|
||||||
}
|
}
|
||||||
// use concatenation of set/step as the memory request source. This will get
|
// use concatenation of set/step as the memory request source. This will get
|
||||||
// translated to the actual TL sourcewidth in sourceGen.
|
// translated to the actual TL sourcewidth in sourceGen.
|
||||||
val tag = Wire(new TensorMemTag)
|
val tag = Wire(new TensorMemTag)
|
||||||
tag.set := setAccess
|
tag.set := setAccess
|
||||||
tag.step := stepAccess
|
tag.step := stepAccess
|
||||||
|
tag.substep := substepAccess
|
||||||
|
|
||||||
val respATagged = Wire(Decoupled(new TensorMemRespWithTag(dataWidth)))
|
val respATagged = Wire(Decoupled(new TensorMemRespWithTag(dataWidth)))
|
||||||
val respBTagged = Wire(Decoupled(new TensorMemRespWithTag(dataWidth)))
|
val respBTagged = Wire(Decoupled(new TensorMemRespWithTag(dataWidth)))
|
||||||
Seq((io.reqA, (io.respA, respATagged)), (io.reqB, (io.respB, respBTagged))).foreach {
|
Seq((io.reqA, (io.respA, respATagged)),
|
||||||
|
(io.reqB, (io.respB, respBTagged))).foreach {
|
||||||
case (req, (resp, respTagged)) => {
|
case (req, (resp, respTagged)) => {
|
||||||
val sourceGen = Module(new SourceGenerator(
|
val sourceGen = Module(new SourceGenerator(
|
||||||
log2Ceil(numSourceIds),
|
log2Ceil(numSourceIds),
|
||||||
@@ -173,18 +180,22 @@ class TensorCoreDecoupled(
|
|||||||
}
|
}
|
||||||
|
|
||||||
// only advance to the next step if we fired mem requests for both A and B
|
// only advance to the next step if we fired mem requests for both A and B
|
||||||
|
// TODO: @perf: too strict? should be able to have A and B progress
|
||||||
|
// separately
|
||||||
val firedABReg = RegInit(VecInit(false.B, false.B))
|
val firedABReg = RegInit(VecInit(false.B, false.B))
|
||||||
val firedABNow = VecInit((Seq(io.reqA, io.reqB) zip firedABReg).map {
|
val firedABNow = VecInit((Seq(io.reqA, io.reqB) zip firedABReg).map {
|
||||||
case (req, fired) => { when (req.fire) { fired := true.B } }
|
case (req, fired) => { when (req.fire) { fired := true.B } }
|
||||||
req.fire
|
req.fire
|
||||||
})
|
})
|
||||||
val firedAB = (firedABNow.asUInt | firedABReg.asUInt)
|
val firedAB = (firedABNow.asUInt | firedABReg.asUInt)
|
||||||
val nextStepAccess = firedAB.andR
|
val nextSubstepAccess = firedAB.andR
|
||||||
// clear out firedABReg every step. this will overwrite the previous fired
|
val nextStepAccess = nextSubstepAccess && (substepAccess === 1.U)
|
||||||
// write upon the last fire out of A and B
|
// clear out firedABReg every substep
|
||||||
when (nextStepAccess) {
|
when (nextSubstepAccess) {
|
||||||
firedABReg := Seq(false.B, false.B)
|
firedABReg := Seq(false.B, false.B)
|
||||||
|
substepAccess := substepAccess + 1.U
|
||||||
}
|
}
|
||||||
|
require(substepAccess.widthOption.get == 1, "there should be only two substeps")
|
||||||
|
|
||||||
// Execute stage
|
// Execute stage
|
||||||
// -------------
|
// -------------
|
||||||
@@ -204,22 +215,72 @@ class TensorCoreDecoupled(
|
|||||||
io.writeback.bits.data.widthOption.get,
|
io.writeback.bits.data.widthOption.get,
|
||||||
"response data width does not match the writeback data width")
|
"response data width does not match the writeback data width")
|
||||||
|
|
||||||
|
// FIXME: this needs to change to dpu_ready
|
||||||
|
val dpuReady = io.writeback.ready // FIXME: this needs to be the actual dpu
|
||||||
|
|
||||||
|
val substepExecute = RegInit(0.U(1.W))
|
||||||
|
when (respQueueA.fire) {
|
||||||
|
substepExecute := substepExecute + 1.U
|
||||||
|
}
|
||||||
|
dontTouch(substepExecute)
|
||||||
|
|
||||||
|
// note combinationally coupled ready with `pipe`
|
||||||
|
val halfAQueue = Module(new Queue(
|
||||||
|
chiselTypeOf(respQueueA.bits.data), entries = 1, pipe = true
|
||||||
|
))
|
||||||
|
halfAQueue.io.enq.valid := respQueueA.valid && (substepExecute === 0.U)
|
||||||
|
halfAQueue.io.enq.bits := respQueueA.bits.data
|
||||||
|
|
||||||
|
// we need the full data for A because we divide the D tile by half along N;
|
||||||
|
// for B, the DPU can immediately start computing with a 4x2 tile.
|
||||||
|
//
|
||||||
|
// substep == 0 data goes to the LSB
|
||||||
|
val fullAEnqData = Cat(respQueueA.bits.data, halfAQueue.io.deq.bits)
|
||||||
|
val fullAQueue = Module(new Queue(
|
||||||
|
chiselTypeOf(fullAEnqData), entries = 1, pipe = true
|
||||||
|
))
|
||||||
|
// hold first half A data for the first substep
|
||||||
|
halfAQueue.io.deq.ready := respQueueA.valid && (substepExecute === 1.U) &&
|
||||||
|
fullAQueue.io.enq.ready
|
||||||
|
|
||||||
|
require(fullAEnqData.widthOption.get == dataWidth * 2,
|
||||||
|
"assumes 2-cycle read for a full compute tile of A")
|
||||||
|
fullAQueue.io.enq.valid := respQueueA.valid && (substepExecute === 1.U) &&
|
||||||
|
halfAQueue.io.deq.valid
|
||||||
|
fullAQueue.io.enq.bits := fullAEnqData
|
||||||
|
|
||||||
|
val operandsValid = fullAQueue.io.deq.valid && respQueueB.valid // FIXME?
|
||||||
|
val dpuFire = operandsValid && dpuReady
|
||||||
|
fullAQueue.io.deq.ready := dpuFire
|
||||||
|
val nextStepExecute = dpuFire
|
||||||
|
|
||||||
|
// FIXME: need to hold A for two cycles!!
|
||||||
|
|
||||||
|
// make sure to dequeue from response queues only when both A and B valid
|
||||||
|
respQueueA.ready := MuxCase(false.B,
|
||||||
|
Seq((substepExecute === 0.U) -> halfAQueue.io.enq.ready,
|
||||||
|
(substepExecute === 1.U) -> fullAQueue.io.enq.ready))
|
||||||
|
respQueueB.ready := dpuFire
|
||||||
|
dontTouch(respQueueA)
|
||||||
|
dontTouch(respQueueB)
|
||||||
|
|
||||||
|
// assert that the A and B response queue heads always point to the same
|
||||||
|
// set/step/substep
|
||||||
|
//
|
||||||
|
// this assumes that memory responses come back in-order. this might be too
|
||||||
|
// strong an assumption depending on the backing memory
|
||||||
|
def assertAligned = {
|
||||||
val bothQueueValid = (respQueueA.valid && respQueueB.valid)
|
val bothQueueValid = (respQueueA.valid && respQueueB.valid)
|
||||||
// assume in-order response and that A/B responses are always aligned; this
|
when (bothQueueValid && (substepExecute === 0.U)) {
|
||||||
// might be too strong an assumption depending on the backing memory
|
|
||||||
when (bothQueueValid) {
|
|
||||||
assert((respQueueA.bits.tag.set === respQueueB.bits.tag.set) &&
|
assert((respQueueA.bits.tag.set === respQueueB.bits.tag.set) &&
|
||||||
(respQueueA.bits.tag.step === respQueueB.bits.tag.step),
|
(respQueueA.bits.tag.step === respQueueB.bits.tag.step),
|
||||||
"A and B response queue pointing to different set/steps. " ++
|
"A and B response queue pointing to different set/steps. " ++
|
||||||
"This might indicate memory response coming back out-of-order.")
|
"This might indicate memory response coming back out-of-order.")
|
||||||
}
|
}
|
||||||
// dequeue is synchronized between A and B
|
dontTouch(respQueueA.bits.tag)
|
||||||
// FIXME: this need to change to dpu_ready
|
dontTouch(respQueueB.bits.tag)
|
||||||
val deqResp = bothQueueValid && io.writeback.ready
|
}
|
||||||
respQueueA.ready := deqResp
|
assertAligned
|
||||||
respQueueB.ready := deqResp
|
|
||||||
// FIXME: this need to change to dpu_fire
|
|
||||||
val nextStepExecute = io.writeback.fire
|
|
||||||
|
|
||||||
def rdGen(set: UInt, step: UInt): UInt = {
|
def rdGen(set: UInt, step: UInt): UInt = {
|
||||||
// each step produces 4x4 output tile, written by 8 threads with 2 regs per
|
// each step produces 4x4 output tile, written by 8 threads with 2 regs per
|
||||||
@@ -229,7 +290,7 @@ class TensorCoreDecoupled(
|
|||||||
// FIXME: add substep here
|
// FIXME: add substep here
|
||||||
}
|
}
|
||||||
|
|
||||||
io.writeback.valid := bothQueueValid
|
io.writeback.valid := operandsValid // FIXME: bypass logic
|
||||||
io.writeback.bits.wid := warpReg
|
io.writeback.bits.wid := warpReg
|
||||||
io.writeback.bits.rd := rdGen(setExecute, stepExecute)
|
io.writeback.bits.rd := rdGen(setExecute, stepExecute)
|
||||||
io.writeback.bits.last := setDone(setExecute) && stepDone(stepExecute)
|
io.writeback.bits.last := setDone(setExecute) && stepDone(stepExecute)
|
||||||
|
|||||||
@@ -27,6 +27,7 @@ class TensorDotProductUnit(val half: Boolean) extends Module with tile.HasFPUPar
|
|||||||
val b = Vec(dotProductDim, Bits((inFLen).W))
|
val b = Vec(dotProductDim, Bits((inFLen).W))
|
||||||
val c = Bits((outFLen).W) // note C has the out length for accumulation
|
val c = Bits((outFLen).W) // note C has the out length for accumulation
|
||||||
}))
|
}))
|
||||||
|
// 'stall' is effectively out.ready, combinationally coupled to in.ready
|
||||||
val stall = Input(Bool())
|
val stall = Input(Bool())
|
||||||
val out = Valid(new Bundle {
|
val out = Valid(new Bundle {
|
||||||
val data = Bits((outFLen).W)
|
val data = Bits((outFLen).W)
|
||||||
|
|||||||
Reference in New Issue
Block a user