tensor: Hold step until req fired for both A and B
This commit is contained in:
@@ -95,23 +95,10 @@ class TensorCoreDecoupled(
|
|||||||
busy := false.B
|
busy := false.B
|
||||||
}
|
}
|
||||||
|
|
||||||
// set/step sequencing logic
|
|
||||||
val nextStep = true.B // TODO
|
|
||||||
val lastSet = ((1 << setBits) - 1)
|
|
||||||
val lastStep = ((1 << stepBits) - 1)
|
|
||||||
val setDone = (set === lastSet.U)
|
|
||||||
val stepDone = (step === lastStep.U)
|
|
||||||
when (nextStep) {
|
|
||||||
step := (step + 1.U) & lastStep.U
|
|
||||||
when (stepDone) {
|
|
||||||
set := (set + 1.U) & lastSet.U
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// memory traffic generation
|
// memory traffic generation
|
||||||
val genReq = (state === TensorState.run)
|
val genReq = (state === TensorState.run)
|
||||||
|
|
||||||
List((io.reqA, io.respA), (io.reqB, io.respB)).foreach {
|
Seq((io.reqA, io.respA), (io.reqB, io.respB)).foreach {
|
||||||
case (req, resp) => {
|
case (req, resp) => {
|
||||||
val sourceGen = Module(new SourceGenerator(log2Ceil(numSourceIds)))
|
val sourceGen = Module(new SourceGenerator(log2Ceil(numSourceIds)))
|
||||||
|
|
||||||
@@ -126,9 +113,35 @@ class TensorCoreDecoupled(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// only advance to the next step if we fired mem requests for both A and B
|
||||||
|
val firedABReg = RegInit(VecInit(false.B, false.B))
|
||||||
|
val firedABNow = VecInit((Seq(io.reqA, io.reqB) zip firedABReg).map {
|
||||||
|
case (req, fired) => { when (req.fire) { fired := true.B } }
|
||||||
|
req.fire
|
||||||
|
})
|
||||||
|
val firedAB = (firedABNow.asUInt | firedABReg.asUInt)
|
||||||
|
val nextStep = firedAB.andR
|
||||||
|
// clear out firedABReg every step. this will overwrite the previous fired
|
||||||
|
// write upon the last fire out of A and B
|
||||||
|
when (nextStep) {
|
||||||
|
firedABReg := Seq(false.B, false.B)
|
||||||
|
}
|
||||||
|
|
||||||
io.respA.ready := true.B
|
io.respA.ready := true.B
|
||||||
io.respB.ready := true.B
|
io.respB.ready := true.B
|
||||||
|
|
||||||
|
// set/step sequencing logic
|
||||||
|
val lastSet = ((1 << setBits) - 1)
|
||||||
|
val lastStep = ((1 << stepBits) - 1)
|
||||||
|
val setDone = (set === lastSet.U)
|
||||||
|
val stepDone = (step === lastStep.U)
|
||||||
|
when (nextStep) {
|
||||||
|
step := (step + 1.U) & lastStep.U
|
||||||
|
when (stepDone) {
|
||||||
|
set := (set + 1.U) & lastSet.U
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// state transition logic
|
// state transition logic
|
||||||
switch(state) {
|
switch(state) {
|
||||||
is(TensorState.idle) {
|
is(TensorState.idle) {
|
||||||
@@ -213,8 +226,8 @@ class TensorCoreDecoupledTLImp(outer: TensorCoreDecoupledTL)
|
|||||||
8, 8, outer.numSrcIds , TensorTilingParams()))
|
8, 8, outer.numSrcIds , TensorTilingParams()))
|
||||||
val wordSize = 4 // FIXME: hardcoded
|
val wordSize = 4 // FIXME: hardcoded
|
||||||
|
|
||||||
val zip = List((outer.node.out(0), tensor.io.reqA),
|
val zip = Seq((outer.node.out(0), tensor.io.reqA),
|
||||||
(outer.node.out(1), tensor.io.reqB))
|
(outer.node.out(1), tensor.io.reqB))
|
||||||
zip.foreach { case ((tl, edge), req) =>
|
zip.foreach { case ((tl, edge), req) =>
|
||||||
tl.a.valid := req.valid
|
tl.a.valid := req.valid
|
||||||
val (legal, bits) = edge.Get(
|
val (legal, bits) = edge.Get(
|
||||||
|
|||||||
Reference in New Issue
Block a user