tensor: Assert alignment of A and B response queues

This commit is contained in:
Hansung Kim
2024-10-15 17:08:14 -07:00
parent de393115cd
commit efaf599fbe

View File

@@ -97,15 +97,16 @@ class TensorCoreDecoupled(
// steps: i-j iteration // steps: i-j iteration
val numSteps = (tilingParams.m * tilingParams.n) / (tilingParams.mc * tilingParams.nc) val numSteps = (tilingParams.m * tilingParams.n) / (tilingParams.mc * tilingParams.nc)
val stepBits = log2Ceil(numSteps) val stepBits = log2Ceil(numSteps)
val set = RegInit(0.U(setBits.W)) // set and step being currently accessed in the acc/ex frontend
val step = RegInit(0.U(stepBits.W)) val setAccess = RegInit(0.U(setBits.W))
val stepAccess = RegInit(0.U(stepBits.W))
when(io.initiate.fire) { when(io.initiate.fire) {
val wid = io.initiate.bits.wid val wid = io.initiate.bits.wid
busy := true.B busy := true.B
warpReg := wid warpReg := wid
set := 0.U setAccess := 0.U
step := 0.U stepAccess := 0.U
when(io.writeback.fire) { when(io.writeback.fire) {
assert( assert(
io.writeback.bits.wid =/= wid, io.writeback.bits.wid =/= wid,
@@ -129,8 +130,8 @@ class TensorCoreDecoupled(
// use concatenation of set/step as the memory request source. This will get // use concatenation of set/step as the memory request source. This will get
// translated to the actual TL sourcewidth in sourceGen. // translated to the actual TL sourcewidth in sourceGen.
val tag = Wire(new TensorMemTag) val tag = Wire(new TensorMemTag)
tag.set := set tag.set := setAccess
tag.step := step tag.step := stepAccess
val respATagged = Wire(Decoupled(new TensorMemRespWithTag(dataWidth))) val respATagged = Wire(Decoupled(new TensorMemRespWithTag(dataWidth)))
val respBTagged = Wire(Decoupled(new TensorMemRespWithTag(dataWidth))) val respBTagged = Wire(Decoupled(new TensorMemRespWithTag(dataWidth)))
@@ -176,16 +177,32 @@ class TensorCoreDecoupled(
// ------------- // -------------
// Backend of the decoupled access/execute pipeline. // Backend of the decoupled access/execute pipeline.
// //
// set and step being currently executed in the acc/ex backend
val setExecute = RegInit(0.U(setBits.W))
val stepExecute = RegInit(0.U(stepBits.W))
val respQueueDepth = 4 // FIXME: parameterize val respQueueDepth = 4 // FIXME: parameterize
val respQueueA = Queue(respATagged, respQueueDepth) val respQueueA = Queue(respATagged, respQueueDepth)
val respQueueB = Queue(respBTagged, respQueueDepth) val respQueueB = Queue(respBTagged, respQueueDepth)
respQueueA.ready := io.writeback.ready // FIXME
respQueueB.ready := io.writeback.ready // FIXME
require(respQueueA.bits.data.widthOption.get == require(respQueueA.bits.data.widthOption.get ==
io.writeback.bits.data.widthOption.get, io.writeback.bits.data.widthOption.get,
"response data width does not match the writeback data width") "response data width does not match the writeback data width")
val bothQueueValid = (respQueueA.valid && respQueueB.valid)
// assume in-order response and that A/B responses are always aligned; this
// might be too strong an assumption depending on the backing memory
when (bothQueueValid) {
assert((respQueueA.bits.tag.set === respQueueB.bits.tag.set) &&
(respQueueA.bits.tag.step === respQueueB.bits.tag.step),
"A and B response queue pointing to different set/steps. " ++
"This might indicate memory response coming back out-of-order.")
}
// synchronized dequeue
val deqResp = bothQueueValid && io.writeback.ready
respQueueA.ready := deqResp
respQueueB.ready := deqResp
// FIXME: debug dummy: pipe A directly to writeback // FIXME: debug dummy: pipe A directly to writeback
io.writeback.valid := respQueueA.valid io.writeback.valid := respQueueA.valid
val groupedRespA = respQueueA.bits.data val groupedRespA = respQueueA.bits.data
@@ -201,12 +218,12 @@ class TensorCoreDecoupled(
// set/step sequencing logic // set/step sequencing logic
val lastSet = ((1 << setBits) - 1) val lastSet = ((1 << setBits) - 1)
val lastStep = ((1 << stepBits) - 1) val lastStep = ((1 << stepBits) - 1)
val setDone = (set === lastSet.U) val setDone = (setAccess === lastSet.U)
val stepDone = (step === lastStep.U) val stepDone = (stepAccess === lastStep.U)
when (nextStep) { when (nextStep) {
step := (step + 1.U) & lastStep.U stepAccess := (stepAccess + 1.U) & lastStep.U
when (stepDone) { when (stepDone) {
set := (set + 1.U) & lastSet.U setAccess := (setAccess + 1.U) & lastSet.U
} }
} }