tensor: Properly route FillBuffer to DPU
This commit is contained in:
@@ -313,17 +313,24 @@ class TensorCoreDecoupled(
|
||||
fullAQueue.io.enq.bits.data := fullAEnqData
|
||||
fullAQueue.io.enq.bits.tag := fullAEnqTag
|
||||
|
||||
val fillBufB = Module(new FillBuffer(
|
||||
// serialize every two B responses into one full 4x4 B tile
|
||||
// FIXME: do the same for A
|
||||
val fullB = Module(new FillBuffer(
|
||||
chiselTypeOf(respQueueB.bits.data), 2/*substeps*/
|
||||
))
|
||||
fillBufB.io.enq.valid := respQueueB.valid
|
||||
fillBufB.io.enq.bits := respQueueB.bits.data
|
||||
respQueueB.ready := fillBufB.io.enq.ready
|
||||
fullB.io.enq.valid := respQueueB.valid
|
||||
fullB.io.enq.bits := respQueueB.bits.data
|
||||
respQueueB.ready := fullB.io.enq.ready
|
||||
val fullBTag = Module(new Queue(
|
||||
new TensorMemTag, entries = 1, pipe = true
|
||||
))
|
||||
fullBTag.io.enq.valid := respQueueB.valid
|
||||
fullBTag.io.enq.bits := respQueueB.bits.tag
|
||||
|
||||
val operandsValid = fullAQueue.io.deq.valid && fillBufB.io.deq.valid
|
||||
val operandsValid = fullAQueue.io.deq.valid && fullB.io.deq.valid
|
||||
val operandA = fullAQueue.io.deq.bits.data
|
||||
val operandATag = fullAQueue.io.deq.bits.tag
|
||||
val operandB = fillBufB.io.deq.bits
|
||||
val operandB = fullB.io.deq.bits
|
||||
val dpuReady = Wire(Bool())
|
||||
val dpuFire = operandsValid && dpuReady
|
||||
val setCompute = fullAQueue.io.deq.bits.tag.set
|
||||
@@ -333,10 +340,6 @@ class TensorCoreDecoupled(
|
||||
substepCompute := substepCompute + 1.U
|
||||
}
|
||||
|
||||
// hold full A until two-cycle compute is done
|
||||
fullAQueue.io.deq.ready := dpuFire && (substepCompute === 1.U)
|
||||
val nextStepExecute = dpuFire && (substepCompute === 1.U)
|
||||
|
||||
// respQueueA output arbitrates to either halfAQueue or fullAQueue depending
|
||||
// on the substep
|
||||
respQueueA.ready := MuxCase(false.B,
|
||||
@@ -345,12 +348,19 @@ class TensorCoreDecoupled(
|
||||
// Hold B tile at respQueueB for multiple steps for reuse, only dequeue when
|
||||
// we fully iterated a column (M-dimension).
|
||||
val shouldDequeueBMask = ((1 << numTilesMBits) - 1).U
|
||||
val shouldDequeueB = (stepExecute & shouldDequeueBMask) === shouldDequeueBMask
|
||||
fillBufB.io.deq.ready := dpuFire && shouldDequeueB
|
||||
val shouldDequeueB =
|
||||
((stepExecute & shouldDequeueBMask) === shouldDequeueBMask) &&
|
||||
(substepCompute === 1.U)
|
||||
fullB.io.deq.ready := dpuFire && shouldDequeueB
|
||||
fullBTag.io.deq.ready := dpuFire && shouldDequeueB
|
||||
dontTouch(respQueueA)
|
||||
dontTouch(respQueueB)
|
||||
dontTouch(shouldDequeueB)
|
||||
|
||||
// hold full A until two-cycle compute is done
|
||||
fullAQueue.io.deq.ready := dpuFire && (substepCompute === 1.U)
|
||||
val nextStepExecute = dpuFire && (substepCompute === 1.U)
|
||||
|
||||
// Assert that the DPU is computing with operands of the same set/step. Note
|
||||
// that the B resp will only have step values multiple of 4 due to reuse.
|
||||
//
|
||||
@@ -359,9 +369,9 @@ class TensorCoreDecoupled(
|
||||
def assertAligned = {
|
||||
val stepMask = (1 << numTilesMBits).U
|
||||
when (dpuFire) {
|
||||
assert((fullAQueue.io.deq.bits.tag.set === respQueueB.bits.tag.set) &&
|
||||
assert((fullAQueue.io.deq.bits.tag.set === fullBTag.io.deq.bits.set) &&
|
||||
((fullAQueue.io.deq.bits.tag.step & stepMask) ===
|
||||
(respQueueB.bits.tag.step & stepMask)),
|
||||
(fullBTag.io.deq.bits.step & stepMask)),
|
||||
"A and B operands are pointing to different set/steps. " ++
|
||||
"This might indicate memory response coming back out-of-order.")
|
||||
}
|
||||
@@ -387,7 +397,7 @@ class TensorCoreDecoupled(
|
||||
// operandB.asBools.grouped(wordSizeInBits).map(VecInit(_).asUInt).toSeq
|
||||
// .grouped(4/*k-dim*/).toSeq
|
||||
val operandBDimensional =
|
||||
operandB(0)/*FIXME!*/.asBools.grouped(wordSizeInBits).map(VecInit(_).asUInt).toSeq
|
||||
operandB(substepCompute).asBools.grouped(wordSizeInBits).map(VecInit(_).asUInt).toSeq
|
||||
.grouped(4/*k-dim*/).toSeq
|
||||
require(tilingParams.mc * ncSubstep == numLanes,
|
||||
"substep tile size doesn't match writeback throughput")
|
||||
|
||||
Reference in New Issue
Block a user