tensor: Properly route FillBuffer to DPU

This commit is contained in:
Hansung Kim
2024-10-18 17:33:55 -07:00
parent 91d9897c27
commit 7fab6f89ad

View File

@@ -313,17 +313,24 @@ class TensorCoreDecoupled(
fullAQueue.io.enq.bits.data := fullAEnqData fullAQueue.io.enq.bits.data := fullAEnqData
fullAQueue.io.enq.bits.tag := fullAEnqTag fullAQueue.io.enq.bits.tag := fullAEnqTag
val fillBufB = Module(new FillBuffer( // serialize every two B responses into one full 4x4 B tile
// FIXME: do the same for A
val fullB = Module(new FillBuffer(
chiselTypeOf(respQueueB.bits.data), 2/*substeps*/ chiselTypeOf(respQueueB.bits.data), 2/*substeps*/
)) ))
fillBufB.io.enq.valid := respQueueB.valid fullB.io.enq.valid := respQueueB.valid
fillBufB.io.enq.bits := respQueueB.bits.data fullB.io.enq.bits := respQueueB.bits.data
respQueueB.ready := fillBufB.io.enq.ready respQueueB.ready := fullB.io.enq.ready
val fullBTag = Module(new Queue(
new TensorMemTag, entries = 1, pipe = true
))
fullBTag.io.enq.valid := respQueueB.valid
fullBTag.io.enq.bits := respQueueB.bits.tag
val operandsValid = fullAQueue.io.deq.valid && fillBufB.io.deq.valid val operandsValid = fullAQueue.io.deq.valid && fullB.io.deq.valid
val operandA = fullAQueue.io.deq.bits.data val operandA = fullAQueue.io.deq.bits.data
val operandATag = fullAQueue.io.deq.bits.tag val operandATag = fullAQueue.io.deq.bits.tag
val operandB = fillBufB.io.deq.bits val operandB = fullB.io.deq.bits
val dpuReady = Wire(Bool()) val dpuReady = Wire(Bool())
val dpuFire = operandsValid && dpuReady val dpuFire = operandsValid && dpuReady
val setCompute = fullAQueue.io.deq.bits.tag.set val setCompute = fullAQueue.io.deq.bits.tag.set
@@ -333,10 +340,6 @@ class TensorCoreDecoupled(
substepCompute := substepCompute + 1.U substepCompute := substepCompute + 1.U
} }
// hold full A until two-cycle compute is done
fullAQueue.io.deq.ready := dpuFire && (substepCompute === 1.U)
val nextStepExecute = dpuFire && (substepCompute === 1.U)
// respQueueA output arbitrates to either halfAQueue or fullAQueue depending // respQueueA output arbitrates to either halfAQueue or fullAQueue depending
// on the substep // on the substep
respQueueA.ready := MuxCase(false.B, respQueueA.ready := MuxCase(false.B,
@@ -345,12 +348,19 @@ class TensorCoreDecoupled(
// Hold B tile at respQueueB for multiple steps for reuse, only dequeue when // Hold B tile at respQueueB for multiple steps for reuse, only dequeue when
// we fully iterated a column (M-dimension). // we fully iterated a column (M-dimension).
val shouldDequeueBMask = ((1 << numTilesMBits) - 1).U val shouldDequeueBMask = ((1 << numTilesMBits) - 1).U
val shouldDequeueB = (stepExecute & shouldDequeueBMask) === shouldDequeueBMask val shouldDequeueB =
fillBufB.io.deq.ready := dpuFire && shouldDequeueB ((stepExecute & shouldDequeueBMask) === shouldDequeueBMask) &&
(substepCompute === 1.U)
fullB.io.deq.ready := dpuFire && shouldDequeueB
fullBTag.io.deq.ready := dpuFire && shouldDequeueB
dontTouch(respQueueA) dontTouch(respQueueA)
dontTouch(respQueueB) dontTouch(respQueueB)
dontTouch(shouldDequeueB) dontTouch(shouldDequeueB)
// hold full A until two-cycle compute is done
fullAQueue.io.deq.ready := dpuFire && (substepCompute === 1.U)
val nextStepExecute = dpuFire && (substepCompute === 1.U)
// Assert that the DPU is computing with operands of the same set/step. Note // Assert that the DPU is computing with operands of the same set/step. Note
// that the B resp will only have step values multiple of 4 due to reuse. // that the B resp will only have step values multiple of 4 due to reuse.
// //
@@ -359,9 +369,9 @@ class TensorCoreDecoupled(
def assertAligned = { def assertAligned = {
val stepMask = (1 << numTilesMBits).U val stepMask = (1 << numTilesMBits).U
when (dpuFire) { when (dpuFire) {
assert((fullAQueue.io.deq.bits.tag.set === respQueueB.bits.tag.set) && assert((fullAQueue.io.deq.bits.tag.set === fullBTag.io.deq.bits.set) &&
((fullAQueue.io.deq.bits.tag.step & stepMask) === ((fullAQueue.io.deq.bits.tag.step & stepMask) ===
(respQueueB.bits.tag.step & stepMask)), (fullBTag.io.deq.bits.step & stepMask)),
"A and B operands are pointing to different set/steps. " ++ "A and B operands are pointing to different set/steps. " ++
"This might indicate memory response coming back out-of-order.") "This might indicate memory response coming back out-of-order.")
} }
@@ -387,7 +397,7 @@ class TensorCoreDecoupled(
// operandB.asBools.grouped(wordSizeInBits).map(VecInit(_).asUInt).toSeq // operandB.asBools.grouped(wordSizeInBits).map(VecInit(_).asUInt).toSeq
// .grouped(4/*k-dim*/).toSeq // .grouped(4/*k-dim*/).toSeq
val operandBDimensional = val operandBDimensional =
operandB(0)/*FIXME!*/.asBools.grouped(wordSizeInBits).map(VecInit(_).asUInt).toSeq operandB(substepCompute).asBools.grouped(wordSizeInBits).map(VecInit(_).asUInt).toSeq
.grouped(4/*k-dim*/).toSeq .grouped(4/*k-dim*/).toSeq
require(tilingParams.mc * ncSubstep == numLanes, require(tilingParams.mc * ncSubstep == numLanes,
"substep tile size doesn't match writeback throughput") "substep tile size doesn't match writeback throughput")