tensor: Add access logic for C from regfile

This commit is contained in:
Hansung Kim
2024-10-25 15:22:52 -07:00
parent fc5b864b86
commit 43e064fe82

View File

@@ -51,8 +51,10 @@ class TensorCoreDecoupled(
}) })
val respA = Flipped(Decoupled(new TensorMemResp(sourceWidth, dataWidth))) val respA = Flipped(Decoupled(new TensorMemResp(sourceWidth, dataWidth)))
val respB = Flipped(Decoupled(new TensorMemResp(sourceWidth, dataWidth))) val respB = Flipped(Decoupled(new TensorMemResp(sourceWidth, dataWidth)))
val respC = Input(UInt(dataWidth.W))
val reqA = Decoupled(new TensorMemReq(sourceWidth)) val reqA = Decoupled(new TensorMemReq(sourceWidth))
val reqB = Decoupled(new TensorMemReq(sourceWidth)) val reqB = Decoupled(new TensorMemReq(sourceWidth))
val reqC = Output(Valid(UInt(numFPRegBits.W)))
}) })
dontTouch(io) dontTouch(io)
@@ -131,9 +133,7 @@ class TensorCoreDecoupled(
val stateA = RegInit(stateInit) val stateA = RegInit(stateInit)
val stateB = RegInit(stateInit) val stateB = RegInit(stateInit)
dontTouch(stateA) dontTouch(stateA)
dontTouch(stateA.index)
dontTouch(stateB) dontTouch(stateB)
dontTouch(stateB.index)
io.initiate.ready := (state === AccessorState.idle) io.initiate.ready := (state === AccessorState.idle)
when (io.initiate.fire) { when (io.initiate.fire) {
@@ -262,6 +262,48 @@ class TensorCoreDecoupled(
} }
} }
// C access from regfile
//
// since regfile is fixed-latency, respC valid should be determined at the
// request sending side.
val respCValid = RegInit(false.B)
// regfile latency is 1 cycle; don't need a deep response queue
val respQueueCDepth = 1
val respQueueC = Module(new Queue(
chiselTypeOf(io.respC), respQueueCDepth
))
respQueueC.io.enq.valid := respCValid
respQueueC.io.enq.bits := io.respC
// serialize every two C responses into one full 4x4 C tile
val fullC = Module(new FillBuffer(
chiselTypeOf(io.respC), 2/*substeps*/
))
fullC.io.enq.valid := respQueueC.io.deq.valid
fullC.io.enq.bits := respQueueC.io.deq.bits
respQueueC.io.deq.ready := fullC.io.enq.ready
// make sure there's space at the response queue to be latched at the next
// cycle
val genReqC = (state === AccessorState.access) && respQueueC.io.enq.ready
// 1-cycle delay
respCValid := genReqC
io.reqC.valid := genReqC
io.reqC.bits := 5.U // FIXME
// set/index state of the C accumulator value that will be latched at the
// next cycle.
val stateRegC = RegInit(stateInit)
when (genReqC) {
when (stateRegC.index === lastIndex.U) {
stateRegC.set := stateRegC.set + 1.U
}
stateRegC.index := stateRegC.index + 1.U
}
// =========================================================================== // ===========================================================================
// Execute stage // Execute stage
// =========================================================================== // ===========================================================================
@@ -349,9 +391,31 @@ class TensorCoreDecoupled(
fullB.io.deq.ready := fullBBuf.io.enq.ready fullB.io.deq.ready := fullBBuf.io.enq.ready
fullBTag.io.deq.ready := fullBBuf.io.enq.ready fullBTag.io.deq.ready := fullBBuf.io.enq.ready
// fullC is instantiated at the access stage
val fullCTag = Module(new Queue(
new TensorMemTag, entries = 1, pipe = true
))
fullCTag.io.enq.valid := respQueueB.valid
fullCTag.io.enq.bits := respQueueB.bits.tag
val fullCBuf = Module(new Queue(
new Bundle {
val data = chiselTypeOf(fullC.io.deq.bits)
val tag = new TensorMemTag
}, entries = 1, pipe = true
))
fullCBuf.io.enq.valid := fullC.io.deq.valid
fullCBuf.io.enq.bits.data := fullC.io.deq.bits
fullCBuf.io.enq.bits.tag := fullCTag.io.deq.bits
fullC.io.deq.ready := fullCBuf.io.enq.ready
fullCTag.io.deq.ready := fullCBuf.io.enq.ready
val dpuReady = Wire(Bool()) val dpuReady = Wire(Bool())
val dpuFire = Wire(Bool())
val operandsValid = fullABuf.io.deq.valid && fullBBuf.io.deq.valid val operandsValid = fullABuf.io.deq.valid && fullBBuf.io.deq.valid
val dpuFire = operandsValid && dpuReady dpuFire := operandsValid && dpuReady
dontTouch(dpuFire)
val setCompute = RegInit(0.U(setBits.W)) val setCompute = RegInit(0.U(setBits.W))
val stepCompute = RegInit(0.U(stepBits.W)) val stepCompute = RegInit(0.U(stepBits.W))
@@ -376,11 +440,14 @@ class TensorCoreDecoupled(
} }
val operandA = selectOperandA(fullABuf.io.deq.bits.data) val operandA = selectOperandA(fullABuf.io.deq.bits.data)
val operandATag = fullABuf.io.deq.bits.tag val operandATag = fullABuf.io.deq.bits.tag
// select the correct 2x4 tile from B operand buffer // select the correct 2x4 tile from B/C operand buffer
val operandB = fullBBuf.io.deq.bits.data(substepCompute) val operandB = fullBBuf.io.deq.bits.data(substepCompute)
val operandBTag = fullBBuf.io.deq.bits.tag val operandBTag = fullBBuf.io.deq.bits.tag
val operandC = fullCBuf.io.deq.bits.data(substepCompute)
val operandCTag = fullCBuf.io.deq.bits.tag
dontTouch(operandATag) dontTouch(operandATag)
dontTouch(operandBTag) dontTouch(operandBTag)
dontTouch(operandCTag)
// Operand buffer logic // Operand buffer logic
// //
@@ -397,6 +464,10 @@ class TensorCoreDecoupled(
((stepCompute & shouldDequeueBMask) === shouldDequeueBMask) && ((stepCompute & shouldDequeueBMask) === shouldDequeueBMask) &&
(substepCompute === 1.U) (substepCompute === 1.U)
fullBBuf.io.deq.ready := dpuFire && shouldDequeueB fullBBuf.io.deq.ready := dpuFire && shouldDequeueB
// C buf should be synced with B buf
fullCBuf.io.deq.ready := dpuFire && shouldDequeueB
dontTouch(respQueueA) dontTouch(respQueueA)
dontTouch(respQueueB) dontTouch(respQueueB)
dontTouch(shouldDequeueA) dontTouch(shouldDequeueA)
@@ -414,6 +485,10 @@ class TensorCoreDecoupled(
operandATag.set === operandBTag.set, operandATag.set === operandBTag.set,
"A and B operands are pointing to different warps and sets. " ++ "A and B operands are pointing to different warps and sets. " ++
"This might indicate memory response coming back out-of-order.") "This might indicate memory response coming back out-of-order.")
assert(operandATag.warp === operandCTag.warp &&
operandATag.set === operandCTag.set,
"A and C operands are pointing to different warps and sets. " ++
"This might indicate memory response coming back out-of-order.")
assert(operandATag.set === setCompute, assert(operandATag.set === setCompute,
"Operand arrived from memory is pointing at a different set than the FSM.") "Operand arrived from memory is pointing at a different set than the FSM.")
} }
@@ -422,7 +497,7 @@ class TensorCoreDecoupled(
// Dot-product unit // Dot-product unit
// //
// 4x2 four-element DPUs summing up to 32 MACs in total // 4x2 four-element DPUs summing up to 32 FP32 MACs in total
// //
val ncSubstep = tilingParams.nc / 2 val ncSubstep = tilingParams.nc / 2
require(tilingParams.mc * ncSubstep == numLanes, require(tilingParams.mc * ncSubstep == numLanes,
@@ -444,13 +519,18 @@ class TensorCoreDecoupled(
require(operandBDimensional.length == ncSubstep && require(operandBDimensional.length == ncSubstep &&
operandBDimensional(0).length == tilingParams.kc, operandBDimensional(0).length == tilingParams.kc,
"operand width doesn't agree with tiling parameter") "operand width doesn't agree with tiling parameter")
// note operand C is M-major
val operandCDimensional = reshapeByFourWords(operandC)
require(operandCDimensional.length == ncSubstep &&
operandCDimensional(0).length == tilingParams.mc,
"operand width doesn't agree with tiling parameter")
for (m <- 0 until tilingParams.mc) { for (m <- 0 until tilingParams.mc) {
for (n <- 0 until ncSubstep) { for (n <- 0 until ncSubstep) {
dpus(m)(n).io.in.valid := dpuFire dpus(m)(n).io.in.valid := dpuFire
dpus(m)(n).io.in.bits.a := operandADimensional(m) dpus(m)(n).io.in.bits.a := operandADimensional(m)
dpus(m)(n).io.in.bits.b := operandBDimensional(n) dpus(m)(n).io.in.bits.b := operandBDimensional(n)
dpus(m)(n).io.in.bits.c := 0.U // FIXME: bogus accum data dpus(m)(n).io.in.bits.c := operandCDimensional(n)(m)
// dpu ready couples with writeback backpressure // dpu ready couples with writeback backpressure
dpus(m)(n).io.stall := !io.writeback.ready dpus(m)(n).io.stall := !io.writeback.ready
} }
@@ -631,8 +711,10 @@ class TensorCoreDecoupledTLImp(outer: TensorCoreDecoupledTL)
tensor.io.respB.bits.source := tlOutB.d.bits.source tensor.io.respB.bits.source := tlOutB.d.bits.source
tlOutB.d.ready := tensor.io.respB.ready tlOutB.d.ready := tensor.io.respB.ready
tensor.io.respC := 42.U // FIXME bogus
tensor.io.initiate.valid := io.start tensor.io.initiate.valid := io.start
tensor.io.initiate.bits.wid := 0.U // TODO tensor.io.initiate.bits.wid := 0.U // FIXME bogus
tensor.io.writeback.ready := true.B tensor.io.writeback.ready := true.B
io.finished := tensor.io.writeback.valid && tensor.io.writeback.bits.last io.finished := tensor.io.writeback.valid && tensor.io.writeback.bits.last