tensor: Consider data reuse for B memory request
B is reused every 4 steps because of the k->i->j iteration order.
This commit is contained in:
@@ -145,8 +145,6 @@ class TensorCoreDecoupled(
|
|||||||
// Memory traffic generation
|
// Memory traffic generation
|
||||||
// -------------------------
|
// -------------------------
|
||||||
//
|
//
|
||||||
val genReq = (state === TensorState.run)
|
|
||||||
|
|
||||||
class TensorMemTag extends Bundle {
|
class TensorMemTag extends Bundle {
|
||||||
val set = UInt(setBits.W)
|
val set = UInt(setBits.W)
|
||||||
val step = UInt(stepBits.W)
|
val step = UInt(stepBits.W)
|
||||||
@@ -159,16 +157,14 @@ class TensorCoreDecoupled(
|
|||||||
tag.step := stepAccess
|
tag.step := stepAccess
|
||||||
tag.substep := substepAccess
|
tag.substep := substepAccess
|
||||||
|
|
||||||
|
val numTilesM = tilingParams.m / tilingParams.mc
|
||||||
|
val numTilesN = tilingParams.n / tilingParams.nc
|
||||||
// @cleanup: generalize in terms of M/N/K-majorness?
|
// @cleanup: generalize in terms of M/N/K-majorness?
|
||||||
def addressGen(baseA: UInt, baseB: UInt, set: UInt, step: UInt, substep: UInt)
|
def addressGen(baseA: UInt, baseB: UInt, set: UInt, step: UInt, substep: UInt)
|
||||||
: (UInt/*A*/, UInt/*B*/) = {
|
: (UInt/*A*/, UInt/*B*/) = {
|
||||||
// note that step iterates along N first, then M
|
// note that step iterates along N first, then M
|
||||||
val numComputeTilesM = tilingParams.m / tilingParams.mc
|
val tileM = step % numTilesM.U
|
||||||
val numComputeTilesN = tilingParams.n / tilingParams.nc
|
val tileN = step / numTilesM.U
|
||||||
val tileM = step % numComputeTilesM.U
|
|
||||||
val tileN = step / numComputeTilesM.U
|
|
||||||
val mcSubstep = tilingParams.mc / 2
|
|
||||||
val ncSubstep = tilingParams.nc / 2
|
|
||||||
|
|
||||||
// note that both A and B are K-major to facilitate bank conflict-free SMEM
|
// note that both A and B are K-major to facilitate bank conflict-free SMEM
|
||||||
// accesses
|
// accesses
|
||||||
@@ -180,11 +176,11 @@ class TensorCoreDecoupled(
|
|||||||
val tileColB = set // K
|
val tileColB = set // K
|
||||||
// (row,col) coordinate of the starting element of the compute tile
|
// (row,col) coordinate of the starting element of the compute tile
|
||||||
val elemRowA = (tileRowA << log2Ceil(tilingParams.mc)) +
|
val elemRowA = (tileRowA << log2Ceil(tilingParams.mc)) +
|
||||||
(substep << log2Ceil(mcSubstep))
|
(substep << log2Ceil(tilingParams.mc / 2))
|
||||||
val elemColA = tileColA << log2Ceil(tilingParams.kc)
|
val elemColA = tileColA << log2Ceil(tilingParams.kc)
|
||||||
val elemRowB = tileRowB << log2Ceil(tilingParams.nc)
|
val elemRowB = (tileRowB << log2Ceil(tilingParams.nc)) +
|
||||||
(substep << log2Ceil(ncSubstep))
|
(substep << log2Ceil(tilingParams.nc / 2))
|
||||||
val elemColB = tileColB << log2Ceil(tilingParams.kc)
|
val elemColB = tileColB << log2Ceil(tilingParams.kc)
|
||||||
val rowStrideA = wordSize * tilingParams.k
|
val rowStrideA = wordSize * tilingParams.k
|
||||||
val rowStrideABits = log2Ceil(rowStrideA)
|
val rowStrideABits = log2Ceil(rowStrideA)
|
||||||
val rowStrideB = wordSize * tilingParams.k
|
val rowStrideB = wordSize * tilingParams.k
|
||||||
@@ -201,6 +197,13 @@ class TensorCoreDecoupled(
|
|||||||
val (addressA, addressB) =
|
val (addressA, addressB) =
|
||||||
addressGen(0.U, 0.U, setAccess, stepAccess, substepAccess)
|
addressGen(0.U, 0.U, setAccess, stepAccess, substepAccess)
|
||||||
|
|
||||||
|
val genReqA = (state === TensorState.run)
|
||||||
|
val numTilesMBits = log2Ceil(numTilesM)
|
||||||
|
// generate B request at every 4 steps. B achieves reuse through outer
|
||||||
|
// product so it doesn't require access at every step
|
||||||
|
val shouldFireB = (stepAccess & ((1 << numTilesMBits) - 1).U) === 0.U
|
||||||
|
val genReqB = (state === TensorState.run) && shouldFireB
|
||||||
|
|
||||||
val respATagged = Wire(Decoupled(new TensorMemRespWithTag(dataWidth)))
|
val respATagged = Wire(Decoupled(new TensorMemRespWithTag(dataWidth)))
|
||||||
val respBTagged = Wire(Decoupled(new TensorMemRespWithTag(dataWidth)))
|
val respBTagged = Wire(Decoupled(new TensorMemRespWithTag(dataWidth)))
|
||||||
Seq((io.reqA, (io.respA, respATagged)),
|
Seq((io.reqA, (io.respA, respATagged)),
|
||||||
@@ -213,7 +216,7 @@ class TensorCoreDecoupled(
|
|||||||
|
|
||||||
sourceGen.io.gen := req.fire
|
sourceGen.io.gen := req.fire
|
||||||
sourceGen.io.meta := tag
|
sourceGen.io.meta := tag
|
||||||
req.valid := genReq
|
req.valid := (if (i == 0) genReqA else genReqB)
|
||||||
req.bits.address := (if (i == 0) addressA else addressB)
|
req.bits.address := (if (i == 0) addressA else addressB)
|
||||||
req.bits.source := sourceGen.io.id.bits
|
req.bits.source := sourceGen.io.id.bits
|
||||||
|
|
||||||
@@ -228,23 +231,27 @@ class TensorCoreDecoupled(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// only advance to the next step if we fired mem requests for both A and B
|
// only advance to the next step if we fired mem requests for both A and B.
|
||||||
// TODO: @perf: too strict? should be able to have A and B progress
|
// also consider that B doesn't have to be fired every time due to reuse.
|
||||||
// separately
|
// @perf: too strict? should be able to have A and B progress separately
|
||||||
val firedABReg = RegInit(VecInit(false.B, false.B))
|
val firedAReg = RegInit(false.B)
|
||||||
val firedABNow = VecInit((Seq(io.reqA, io.reqB) zip firedABReg).map {
|
val firedBReg = RegInit(false.B)
|
||||||
case (req, fired) => { when (req.fire) { fired := true.B } }
|
when (io.reqA.fire) { firedAReg := true.B }
|
||||||
req.fire
|
when (io.reqB.fire) { firedBReg := true.B }
|
||||||
})
|
val firedANow = io.reqA.fire
|
||||||
val firedAB = (firedABNow.asUInt | firedABReg.asUInt)
|
val firedBNow = io.reqB.fire
|
||||||
val nextSubstepAccess = firedAB.andR
|
val firedA = firedAReg || firedANow
|
||||||
|
val firedB = firedBReg || firedBNow
|
||||||
|
val nextSubstepAccess = firedA && (!shouldFireB || firedB)
|
||||||
val nextStepAccess = nextSubstepAccess && (substepAccess === 1.U)
|
val nextStepAccess = nextSubstepAccess && (substepAccess === 1.U)
|
||||||
// clear out firedABReg every substep
|
// clear out firedABReg every substep
|
||||||
when (nextSubstepAccess) {
|
when (nextSubstepAccess) {
|
||||||
firedABReg := Seq(false.B, false.B)
|
firedAReg := false.B
|
||||||
|
firedBReg := false.B
|
||||||
substepAccess := substepAccess + 1.U
|
substepAccess := substepAccess + 1.U
|
||||||
}
|
}
|
||||||
require(substepAccess.widthOption.get == 1, "there should be only two substeps")
|
require(substepAccess.widthOption.get == 1, "there should be only two substeps")
|
||||||
|
dontTouch(shouldFireB)
|
||||||
|
|
||||||
// Execute stage
|
// Execute stage
|
||||||
// -------------
|
// -------------
|
||||||
@@ -327,18 +334,26 @@ class TensorCoreDecoupled(
|
|||||||
respQueueA.ready := MuxCase(false.B,
|
respQueueA.ready := MuxCase(false.B,
|
||||||
Seq((substepExecute === 0.U) -> halfAQueue.io.enq.ready,
|
Seq((substepExecute === 0.U) -> halfAQueue.io.enq.ready,
|
||||||
(substepExecute === 1.U) -> fullAQueue.io.enq.ready))
|
(substepExecute === 1.U) -> fullAQueue.io.enq.ready))
|
||||||
respQueueB.ready := dpuFire
|
// Hold B tile at respQueueB for multiple steps for reuse, only dequeue when
|
||||||
|
// we fully iterated a column (M-dimension).
|
||||||
|
val shouldDequeueBMask = ((1 << numTilesMBits) - 1).U
|
||||||
|
val shouldDequeueB = (stepExecute & shouldDequeueBMask) === shouldDequeueBMask
|
||||||
|
respQueueB.ready := dpuFire && shouldDequeueB
|
||||||
dontTouch(respQueueA)
|
dontTouch(respQueueA)
|
||||||
dontTouch(respQueueB)
|
dontTouch(respQueueB)
|
||||||
|
dontTouch(shouldDequeueB)
|
||||||
|
|
||||||
// assert that the DPU is computing with operands of the same set/step
|
// Assert that the DPU is computing with operands of the same set/step. Note
|
||||||
|
// that the B resp will only have step values multiple of 4 due to reuse.
|
||||||
//
|
//
|
||||||
// this assumes that memory responses come back in-order. this might be too
|
// This check assumes that memory responses come back in-order. Might be too
|
||||||
// strong an assumption depending on the backing memory
|
// strong of an assumption depending on the backing memory.
|
||||||
def assertAligned = {
|
def assertAligned = {
|
||||||
|
val stepMask = (1 << numTilesMBits).U
|
||||||
when (dpuFire) {
|
when (dpuFire) {
|
||||||
assert((fullAQueue.io.deq.bits.tag.set === respQueueB.bits.tag.set) &&
|
assert((fullAQueue.io.deq.bits.tag.set === respQueueB.bits.tag.set) &&
|
||||||
(fullAQueue.io.deq.bits.tag.step === respQueueB.bits.tag.step),
|
((fullAQueue.io.deq.bits.tag.step & stepMask) ===
|
||||||
|
(respQueueB.bits.tag.step & stepMask)),
|
||||||
"A and B operands are pointing to different set/steps. " ++
|
"A and B operands are pointing to different set/steps. " ++
|
||||||
"This might indicate memory response coming back out-of-order.")
|
"This might indicate memory response coming back out-of-order.")
|
||||||
}
|
}
|
||||||
@@ -348,26 +363,26 @@ class TensorCoreDecoupled(
|
|||||||
// Dot-product unit
|
// Dot-product unit
|
||||||
//
|
//
|
||||||
// 4x2 four-element DPUs summing up to 32 MACs in total
|
// 4x2 four-element DPUs summing up to 32 MACs in total
|
||||||
val dpus = Seq.fill(4)(Seq.fill(2)(
|
val ncSubstep = tilingParams.nc / 2
|
||||||
|
val dpus = Seq.fill(tilingParams.mc)(Seq.fill(ncSubstep)(
|
||||||
Module(new TensorDotProductUnit(half = false))
|
Module(new TensorDotProductUnit(half = false))
|
||||||
))
|
))
|
||||||
// operandA is 4x4 in K-major
|
// operandA is 4x4 in K-major
|
||||||
val operandADimensional =
|
val operandADimensional =
|
||||||
operandA.asBools.grouped(wordSizeInBits).map(VecInit(_).asUInt).toSeq
|
operandA.asBools.grouped(wordSizeInBits).map(VecInit(_).asUInt).toSeq
|
||||||
.grouped(4).toSeq
|
.grouped(4/*k-dim*/).toSeq
|
||||||
assert(operandADimensional.length == tilingParams.mc &&
|
require(operandADimensional.length == tilingParams.mc &&
|
||||||
operandADimensional(0).length == tilingParams.kc,
|
operandADimensional(0).length == tilingParams.kc,
|
||||||
"operand width doesn't agree with tiling parameter")
|
"operand width doesn't agree with tiling parameter")
|
||||||
// operandB is 2x4, i.e. 4x2 in N-major
|
// operandB is 2x4 in K-major
|
||||||
val operandBDimensional =
|
val operandBDimensional =
|
||||||
operandB.asBools.grouped(wordSizeInBits).map(VecInit(_).asUInt).toSeq
|
operandB.asBools.grouped(wordSizeInBits).map(VecInit(_).asUInt).toSeq
|
||||||
.grouped(4).toSeq
|
.grouped(4/*k-dim*/).toSeq
|
||||||
val ncSubstep = tilingParams.nc / 2
|
require(tilingParams.mc * ncSubstep == numLanes,
|
||||||
assert(tilingParams.mc * ncSubstep == numLanes,
|
"substep tile size doesn't match writeback throughput")
|
||||||
"substep tile size doesn't match writeback throughput")
|
require(operandBDimensional.length == ncSubstep &&
|
||||||
assert(operandBDimensional.length == ncSubstep &&
|
operandBDimensional(0).length == tilingParams.kc,
|
||||||
operandBDimensional(0).length == tilingParams.kc,
|
"operand width doesn't agree with tiling parameter")
|
||||||
"operand width doesn't agree with tiling parameter")
|
|
||||||
|
|
||||||
for (m <- 0 until tilingParams.mc) {
|
for (m <- 0 until tilingParams.mc) {
|
||||||
for (n <- 0 until ncSubstep) {
|
for (n <- 0 until ncSubstep) {
|
||||||
@@ -406,10 +421,8 @@ class TensorCoreDecoupled(
|
|||||||
// ----------------
|
// ----------------
|
||||||
// These queues hold metadata needed for writeback in sync with the DPU.
|
// These queues hold metadata needed for writeback in sync with the DPU.
|
||||||
|
|
||||||
val queueDepth = 6 // needs to be at least the DPU latency
|
val queueDepth = 5 // needs to be at least the DPU latency
|
||||||
val tagQueue = Module(new Queue(
|
val tagQueue = Module(new Queue(chiselTypeOf(operandATag), queueDepth))
|
||||||
chiselTypeOf(operandATag), queueDepth
|
|
||||||
))
|
|
||||||
tagQueue.io.enq.valid := dpuFire
|
tagQueue.io.enq.valid := dpuFire
|
||||||
// A and B should have the same tags
|
// A and B should have the same tags
|
||||||
tagQueue.io.enq.bits := operandATag
|
tagQueue.io.enq.bits := operandATag
|
||||||
@@ -573,11 +586,11 @@ class TensorCoreDecoupledTwoTLRAM(implicit p: Parameters) extends LazyModule {
|
|||||||
val tensor = LazyModule(new TensorCoreDecoupledTL)
|
val tensor = LazyModule(new TensorCoreDecoupledTL)
|
||||||
val xbar = LazyModule(new TLXbar)
|
val xbar = LazyModule(new TLXbar)
|
||||||
val ramA = LazyModule(new TLRAM(
|
val ramA = LazyModule(new TLRAM(
|
||||||
address = AddressSet(0x000, 0xfffeff),
|
address = AddressSet(0x000, 0xfffbff),
|
||||||
beatBytes = 32 // @cleanup: hardcoded
|
beatBytes = 32 // @cleanup: hardcoded
|
||||||
))
|
))
|
||||||
val ramB = LazyModule(new TLRAM(
|
val ramB = LazyModule(new TLRAM(
|
||||||
address = AddressSet(0x100, 0xfffeff),
|
address = AddressSet(0x400, 0xfffbff),
|
||||||
beatBytes = 32 // @cleanup: hardcoded
|
beatBytes = 32 // @cleanup: hardcoded
|
||||||
))
|
))
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user