diff --git a/src/main/scala/radiance/core/TensorCoreDecoupled.scala b/src/main/scala/radiance/core/TensorCoreDecoupled.scala index 7fa05ee..05fe576 100644 --- a/src/main/scala/radiance/core/TensorCoreDecoupled.scala +++ b/src/main/scala/radiance/core/TensorCoreDecoupled.scala @@ -95,23 +95,10 @@ class TensorCoreDecoupled( busy := false.B } - // set/step sequencing logic - val nextStep = true.B // TODO - val lastSet = ((1 << setBits) - 1) - val lastStep = ((1 << stepBits) - 1) - val setDone = (set === lastSet.U) - val stepDone = (step === lastStep.U) - when (nextStep) { - step := (step + 1.U) & lastStep.U - when (stepDone) { - set := (set + 1.U) & lastSet.U - } - } - // memory traffic generation val genReq = (state === TensorState.run) - List((io.reqA, io.respA), (io.reqB, io.respB)).foreach { + Seq((io.reqA, io.respA), (io.reqB, io.respB)).foreach { case (req, resp) => { val sourceGen = Module(new SourceGenerator(log2Ceil(numSourceIds))) @@ -126,9 +113,35 @@ class TensorCoreDecoupled( } } + // only advance to the next step if we fired mem requests for both A and B + val firedABReg = RegInit(VecInit(false.B, false.B)) + val firedABNow = VecInit((Seq(io.reqA, io.reqB) zip firedABReg).map { + case (req, fired) => { when (req.fire) { fired := true.B } } + req.fire + }) + val firedAB = (firedABNow.asUInt | firedABReg.asUInt) + val nextStep = firedAB.andR + // clear out firedABReg every step. 
this when-block comes last, so it overrides the + // `fired := true.B` write made in the cycle where the last of A and B fires + when (nextStep) { + firedABReg := Seq(false.B, false.B) + } + io.respA.ready := true.B io.respB.ready := true.B + // set/step sequencing logic + val lastSet = ((1 << setBits) - 1) + val lastStep = ((1 << stepBits) - 1) + val setDone = (set === lastSet.U) + val stepDone = (step === lastStep.U) + when (nextStep) { + step := (step + 1.U) & lastStep.U + when (stepDone) { + set := (set + 1.U) & lastSet.U + } + } + // state transition logic switch(state) { is(TensorState.idle) { @@ -213,8 +226,8 @@ class TensorCoreDecoupledTLImp(outer: TensorCoreDecoupledTL) 8, 8, outer.numSrcIds , TensorTilingParams())) val wordSize = 4 // FIXME: hardcoded - val zip = List((outer.node.out(0), tensor.io.reqA), - (outer.node.out(1), tensor.io.reqB)) + val zip = Seq((outer.node.out(0), tensor.io.reqA), + (outer.node.out(1), tensor.io.reqB)) zip.foreach { case ((tl, edge), req) => tl.a.valid := req.valid val (legal, bits) = edge.Get(