tensor: Stage B as well for full throughput

This commit is contained in:
Hansung Kim
2024-10-18 20:12:15 -07:00
parent c4b5a11fde
commit 93c9bcc32f

View File

@@ -300,10 +300,13 @@ class TensorCoreDecoupled(
// background while the tile is being used for compute. This does come with
// capacity overhead.
val fullABuf = Module(new Queue(
new TensorMemRespWithTag(dataWidth * 2), entries = 1, pipe = true
new Bundle {
val data = chiselTypeOf(fullA.io.deq.bits)
val tag = new TensorMemTag
}, entries = 1, pipe = true
))
fullABuf.io.enq.valid := fullA.io.deq.valid
fullABuf.io.enq.bits.data := fullA.io.deq.bits.asUInt
fullABuf.io.enq.bits.data := fullA.io.deq.bits
fullABuf.io.enq.bits.tag := fullATag.io.deq.bits
fullA.io.deq.ready := fullABuf.io.enq.ready
fullATag.io.deq.ready := fullABuf.io.enq.ready
@@ -322,10 +325,22 @@ class TensorCoreDecoupled(
fullBTag.io.enq.valid := respQueueB.valid
fullBTag.io.enq.bits := respQueueB.bits.tag
val operandsValid = fullABuf.io.deq.valid && fullB.io.deq.valid
val fullBBuf = Module(new Queue(
new Bundle {
val data = chiselTypeOf(fullB.io.deq.bits)
val tag = new TensorMemTag
}, entries = 1, pipe = true
))
fullBBuf.io.enq.valid := fullB.io.deq.valid
fullBBuf.io.enq.bits.data := fullB.io.deq.bits
fullBBuf.io.enq.bits.tag := fullBTag.io.deq.bits
fullB.io.deq.ready := fullBBuf.io.enq.ready
fullBTag.io.deq.ready := fullBBuf.io.enq.ready
val operandsValid = fullABuf.io.deq.valid && fullBBuf.io.deq.valid
val operandA = fullABuf.io.deq.bits.data
val operandATag = fullABuf.io.deq.bits.tag
val operandB = fullB.io.deq.bits
val operandB = fullBBuf.io.deq.bits.data
val dpuReady = Wire(Bool())
val dpuFire = operandsValid && dpuReady
val setCompute = fullABuf.io.deq.bits.tag.set
@@ -335,20 +350,19 @@ class TensorCoreDecoupled(
substepCompute := substepCompute + 1.U
}
// hold full A until two-cycle compute is done
fullABuf.io.deq.ready := dpuFire && (substepCompute === 1.U)
// Hold the B tile at respQueueB across multiple steps for reuse; only dequeue
// it once we have fully iterated over a column (M-dimension).
val shouldDequeueBMask = ((1 << numTilesMBits) - 1).U
val shouldDequeueB =
((stepCompute & shouldDequeueBMask) === shouldDequeueBMask) &&
(substepCompute === 1.U)
fullB.io.deq.ready := dpuFire && shouldDequeueB
fullBTag.io.deq.ready := dpuFire && shouldDequeueB
fullBBuf.io.deq.ready := dpuFire && shouldDequeueB
dontTouch(respQueueA)
dontTouch(respQueueB)
dontTouch(shouldDequeueB)
// hold full A until two-cycle compute is done
fullABuf.io.deq.ready := dpuFire && (substepCompute === 1.U)
// FIXME: this should be nextStepCompute
val nextStepExecute = dpuFire && (substepCompute === 1.U)
@@ -360,9 +374,9 @@ class TensorCoreDecoupled(
def assertAligned = {
val stepMask = (1 << numTilesMBits).U
when (dpuFire) {
assert((fullABuf.io.deq.bits.tag.set === fullBTag.io.deq.bits.set) &&
assert((fullABuf.io.deq.bits.tag.set === fullBBuf.io.deq.bits.tag.set) &&
((fullABuf.io.deq.bits.tag.step & stepMask) ===
(fullBTag.io.deq.bits.step & stepMask)),
(fullBBuf.io.deq.bits.tag.step & stepMask)),
"A and B operands are pointing to different set/steps. " ++
"This might indicate memory response coming back out-of-order.")
}
@@ -378,15 +392,12 @@ class TensorCoreDecoupled(
))
// operandA is 4x4 in K-major
val operandADimensional =
operandA.asBools.grouped(wordSizeInBits).map(VecInit(_).asUInt).toSeq
operandA.asUInt.asBools.grouped(wordSizeInBits).map(VecInit(_).asUInt).toSeq
.grouped(4/*k-dim*/).toSeq
require(operandADimensional.length == tilingParams.mc &&
operandADimensional(0).length == tilingParams.kc,
"operand width doesn't agree with tiling parameter")
// operandB is 2x4 in K-major
// val operandBDimensional =
// operandB.asBools.grouped(wordSizeInBits).map(VecInit(_).asUInt).toSeq
// .grouped(4/*k-dim*/).toSeq
// operandB is 4x4 in K-major; select the 2x4 subtile for the current substep
val operandBDimensional =
operandB(substepCompute).asBools.grouped(wordSizeInBits).map(VecInit(_).asUInt).toSeq
.grouped(4/*k-dim*/).toSeq