tensor: Hold step until req fired for both A and B

This commit is contained in:
Hansung Kim
2024-10-14 22:06:58 -07:00
parent 14a640bf2d
commit 8d2e13b4ee

View File

@@ -95,23 +95,10 @@ class TensorCoreDecoupled(
busy := false.B
}
// set/step sequencing logic
val nextStep = true.B // TODO
val lastSet = ((1 << setBits) - 1)
val lastStep = ((1 << stepBits) - 1)
val setDone = (set === lastSet.U)
val stepDone = (step === lastStep.U)
when (nextStep) {
step := (step + 1.U) & lastStep.U
when (stepDone) {
set := (set + 1.U) & lastSet.U
}
}
// memory traffic generation
val genReq = (state === TensorState.run)
List((io.reqA, io.respA), (io.reqB, io.respB)).foreach {
Seq((io.reqA, io.respA), (io.reqB, io.respB)).foreach {
case (req, resp) => {
val sourceGen = Module(new SourceGenerator(log2Ceil(numSourceIds)))
@@ -126,9 +113,35 @@ class TensorCoreDecoupled(
}
}
// only advance to the next step if we fired mem requests for both A and B
val firedABReg = RegInit(VecInit(false.B, false.B))
val firedABNow = VecInit((Seq(io.reqA, io.reqB) zip firedABReg).map {
case (req, fired) => { when (req.fire) { fired := true.B } }
req.fire
})
val firedAB = (firedABNow.asUInt | firedABReg.asUInt)
val nextStep = firedAB.andR
// clear out firedABReg every step. this will overwrite the previous fired
// write upon the last fire out of A and B
when (nextStep) {
firedABReg := Seq(false.B, false.B)
}
io.respA.ready := true.B
io.respB.ready := true.B
// set/step sequencing logic
val lastSet = ((1 << setBits) - 1)
val lastStep = ((1 << stepBits) - 1)
val setDone = (set === lastSet.U)
val stepDone = (step === lastStep.U)
when (nextStep) {
step := (step + 1.U) & lastStep.U
when (stepDone) {
set := (set + 1.U) & lastSet.U
}
}
// state transition logic
switch(state) {
is(TensorState.idle) {
@@ -213,8 +226,8 @@ class TensorCoreDecoupledTLImp(outer: TensorCoreDecoupledTL)
8, 8, outer.numSrcIds , TensorTilingParams()))
val wordSize = 4 // FIXME: hardcoded
val zip = List((outer.node.out(0), tensor.io.reqA),
(outer.node.out(1), tensor.io.reqB))
val zip = Seq((outer.node.out(0), tensor.io.reqA),
(outer.node.out(1), tensor.io.reqB))
zip.foreach { case ((tl, edge), req) =>
tl.a.valid := req.valid
val (legal, bits) = edge.Get(