tensor: Add access logic for C from regfile
This commit is contained in:
@@ -51,8 +51,10 @@ class TensorCoreDecoupled(
|
|||||||
})
|
})
|
||||||
val respA = Flipped(Decoupled(new TensorMemResp(sourceWidth, dataWidth)))
|
val respA = Flipped(Decoupled(new TensorMemResp(sourceWidth, dataWidth)))
|
||||||
val respB = Flipped(Decoupled(new TensorMemResp(sourceWidth, dataWidth)))
|
val respB = Flipped(Decoupled(new TensorMemResp(sourceWidth, dataWidth)))
|
||||||
|
val respC = Input(UInt(dataWidth.W))
|
||||||
val reqA = Decoupled(new TensorMemReq(sourceWidth))
|
val reqA = Decoupled(new TensorMemReq(sourceWidth))
|
||||||
val reqB = Decoupled(new TensorMemReq(sourceWidth))
|
val reqB = Decoupled(new TensorMemReq(sourceWidth))
|
||||||
|
val reqC = Output(Valid(UInt(numFPRegBits.W)))
|
||||||
})
|
})
|
||||||
dontTouch(io)
|
dontTouch(io)
|
||||||
|
|
||||||
@@ -131,9 +133,7 @@ class TensorCoreDecoupled(
|
|||||||
val stateA = RegInit(stateInit)
|
val stateA = RegInit(stateInit)
|
||||||
val stateB = RegInit(stateInit)
|
val stateB = RegInit(stateInit)
|
||||||
dontTouch(stateA)
|
dontTouch(stateA)
|
||||||
dontTouch(stateA.index)
|
|
||||||
dontTouch(stateB)
|
dontTouch(stateB)
|
||||||
dontTouch(stateB.index)
|
|
||||||
|
|
||||||
io.initiate.ready := (state === AccessorState.idle)
|
io.initiate.ready := (state === AccessorState.idle)
|
||||||
when (io.initiate.fire) {
|
when (io.initiate.fire) {
|
||||||
@@ -262,6 +262,48 @@ class TensorCoreDecoupled(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// C access from regfile
|
||||||
|
//
|
||||||
|
|
||||||
|
// since regfile is fixed-latency, respC valid should be determined at the
|
||||||
|
// request sending side.
|
||||||
|
val respCValid = RegInit(false.B)
|
||||||
|
|
||||||
|
// regfile latency is 1 cycle; don't need a deep response queue
|
||||||
|
val respQueueCDepth = 1
|
||||||
|
val respQueueC = Module(new Queue(
|
||||||
|
chiselTypeOf(io.respC), respQueueCDepth
|
||||||
|
))
|
||||||
|
respQueueC.io.enq.valid := respCValid
|
||||||
|
respQueueC.io.enq.bits := io.respC
|
||||||
|
|
||||||
|
// serialize every two C responses into one full 4x4 C tile
|
||||||
|
val fullC = Module(new FillBuffer(
|
||||||
|
chiselTypeOf(io.respC), 2/*substeps*/
|
||||||
|
))
|
||||||
|
fullC.io.enq.valid := respQueueC.io.deq.valid
|
||||||
|
fullC.io.enq.bits := respQueueC.io.deq.bits
|
||||||
|
respQueueC.io.deq.ready := fullC.io.enq.ready
|
||||||
|
|
||||||
|
// make sure there's space at the response queue to be latched at the next
|
||||||
|
// cycle
|
||||||
|
val genReqC = (state === AccessorState.access) && respQueueC.io.enq.ready
|
||||||
|
// 1-cycle delay
|
||||||
|
respCValid := genReqC
|
||||||
|
|
||||||
|
io.reqC.valid := genReqC
|
||||||
|
io.reqC.bits := 5.U // FIXME
|
||||||
|
|
||||||
|
// set/index state of the C accumulator value that will be latched ath the
|
||||||
|
// next cycle.
|
||||||
|
val stateRegC = RegInit(stateInit)
|
||||||
|
when (genReqC) {
|
||||||
|
when (stateRegC.index === lastIndex.U) {
|
||||||
|
stateRegC.set := stateRegC.set + 1.U
|
||||||
|
}
|
||||||
|
stateRegC.index := stateRegC.index + 1.U
|
||||||
|
}
|
||||||
|
|
||||||
// ===========================================================================
|
// ===========================================================================
|
||||||
// Execute stage
|
// Execute stage
|
||||||
// ===========================================================================
|
// ===========================================================================
|
||||||
@@ -349,9 +391,31 @@ class TensorCoreDecoupled(
|
|||||||
fullB.io.deq.ready := fullBBuf.io.enq.ready
|
fullB.io.deq.ready := fullBBuf.io.enq.ready
|
||||||
fullBTag.io.deq.ready := fullBBuf.io.enq.ready
|
fullBTag.io.deq.ready := fullBBuf.io.enq.ready
|
||||||
|
|
||||||
|
// fullC is instiated at the access stage
|
||||||
|
|
||||||
|
val fullCTag = Module(new Queue(
|
||||||
|
new TensorMemTag, entries = 1, pipe = true
|
||||||
|
))
|
||||||
|
fullCTag.io.enq.valid := respQueueB.valid
|
||||||
|
fullCTag.io.enq.bits := respQueueB.bits.tag
|
||||||
|
|
||||||
|
val fullCBuf = Module(new Queue(
|
||||||
|
new Bundle {
|
||||||
|
val data = chiselTypeOf(fullC.io.deq.bits)
|
||||||
|
val tag = new TensorMemTag
|
||||||
|
}, entries = 1, pipe = true
|
||||||
|
))
|
||||||
|
fullCBuf.io.enq.valid := fullC.io.deq.valid
|
||||||
|
fullCBuf.io.enq.bits.data := fullC.io.deq.bits
|
||||||
|
fullCBuf.io.enq.bits.tag := fullCTag.io.deq.bits
|
||||||
|
fullC.io.deq.ready := fullCBuf.io.enq.ready
|
||||||
|
fullCTag.io.deq.ready := fullCBuf.io.enq.ready
|
||||||
|
|
||||||
val dpuReady = Wire(Bool())
|
val dpuReady = Wire(Bool())
|
||||||
|
val dpuFire = Wire(Bool())
|
||||||
val operandsValid = fullABuf.io.deq.valid && fullBBuf.io.deq.valid
|
val operandsValid = fullABuf.io.deq.valid && fullBBuf.io.deq.valid
|
||||||
val dpuFire = operandsValid && dpuReady
|
dpuFire := operandsValid && dpuReady
|
||||||
|
dontTouch(dpuFire)
|
||||||
|
|
||||||
val setCompute = RegInit(0.U(setBits.W))
|
val setCompute = RegInit(0.U(setBits.W))
|
||||||
val stepCompute = RegInit(0.U(stepBits.W))
|
val stepCompute = RegInit(0.U(stepBits.W))
|
||||||
@@ -376,11 +440,14 @@ class TensorCoreDecoupled(
|
|||||||
}
|
}
|
||||||
val operandA = selectOperandA(fullABuf.io.deq.bits.data)
|
val operandA = selectOperandA(fullABuf.io.deq.bits.data)
|
||||||
val operandATag = fullABuf.io.deq.bits.tag
|
val operandATag = fullABuf.io.deq.bits.tag
|
||||||
// select the correct 2x4 tile from B operand buffer
|
// select the correct 2x4 tile from B/C operand buffer
|
||||||
val operandB = fullBBuf.io.deq.bits.data(substepCompute)
|
val operandB = fullBBuf.io.deq.bits.data(substepCompute)
|
||||||
val operandBTag = fullBBuf.io.deq.bits.tag
|
val operandBTag = fullBBuf.io.deq.bits.tag
|
||||||
|
val operandC = fullCBuf.io.deq.bits.data(substepCompute)
|
||||||
|
val operandCTag = fullCBuf.io.deq.bits.tag
|
||||||
dontTouch(operandATag)
|
dontTouch(operandATag)
|
||||||
dontTouch(operandBTag)
|
dontTouch(operandBTag)
|
||||||
|
dontTouch(operandCTag)
|
||||||
|
|
||||||
// Operand buffer logic
|
// Operand buffer logic
|
||||||
//
|
//
|
||||||
@@ -397,6 +464,10 @@ class TensorCoreDecoupled(
|
|||||||
((stepCompute & shouldDequeueBMask) === shouldDequeueBMask) &&
|
((stepCompute & shouldDequeueBMask) === shouldDequeueBMask) &&
|
||||||
(substepCompute === 1.U)
|
(substepCompute === 1.U)
|
||||||
fullBBuf.io.deq.ready := dpuFire && shouldDequeueB
|
fullBBuf.io.deq.ready := dpuFire && shouldDequeueB
|
||||||
|
|
||||||
|
// C buf should be synced with B buf
|
||||||
|
fullCBuf.io.deq.ready := dpuFire && shouldDequeueB
|
||||||
|
|
||||||
dontTouch(respQueueA)
|
dontTouch(respQueueA)
|
||||||
dontTouch(respQueueB)
|
dontTouch(respQueueB)
|
||||||
dontTouch(shouldDequeueA)
|
dontTouch(shouldDequeueA)
|
||||||
@@ -414,6 +485,10 @@ class TensorCoreDecoupled(
|
|||||||
operandATag.set === operandBTag.set,
|
operandATag.set === operandBTag.set,
|
||||||
"A and B operands are pointing to different warps and sets. " ++
|
"A and B operands are pointing to different warps and sets. " ++
|
||||||
"This might indicate memory response coming back out-of-order.")
|
"This might indicate memory response coming back out-of-order.")
|
||||||
|
assert(operandATag.warp === operandCTag.warp &&
|
||||||
|
operandATag.set === operandCTag.set,
|
||||||
|
"A and C operands are pointing to different warps and sets. " ++
|
||||||
|
"This might indicate memory response coming back out-of-order.")
|
||||||
assert(operandATag.set === setCompute,
|
assert(operandATag.set === setCompute,
|
||||||
"Operand arrived from memory is pointing at a different set than the FSM.")
|
"Operand arrived from memory is pointing at a different set than the FSM.")
|
||||||
}
|
}
|
||||||
@@ -422,7 +497,7 @@ class TensorCoreDecoupled(
|
|||||||
|
|
||||||
// Dot-product unit
|
// Dot-product unit
|
||||||
//
|
//
|
||||||
// 4x2 four-element DPUs summing up to 32 MACs in total
|
// 4x2 four-element DPUs summing up to 32 FP32 MACs in total
|
||||||
//
|
//
|
||||||
val ncSubstep = tilingParams.nc / 2
|
val ncSubstep = tilingParams.nc / 2
|
||||||
require(tilingParams.mc * ncSubstep == numLanes,
|
require(tilingParams.mc * ncSubstep == numLanes,
|
||||||
@@ -444,13 +519,18 @@ class TensorCoreDecoupled(
|
|||||||
require(operandBDimensional.length == ncSubstep &&
|
require(operandBDimensional.length == ncSubstep &&
|
||||||
operandBDimensional(0).length == tilingParams.kc,
|
operandBDimensional(0).length == tilingParams.kc,
|
||||||
"operand width doesn't agree with tiling parameter")
|
"operand width doesn't agree with tiling parameter")
|
||||||
|
// note operand C is M-major
|
||||||
|
val operandCDimensional = reshapeByFourWords(operandC)
|
||||||
|
require(operandCDimensional.length == ncSubstep &&
|
||||||
|
operandCDimensional(0).length == tilingParams.mc,
|
||||||
|
"operand width doesn't agree with tiling parameter")
|
||||||
|
|
||||||
for (m <- 0 until tilingParams.mc) {
|
for (m <- 0 until tilingParams.mc) {
|
||||||
for (n <- 0 until ncSubstep) {
|
for (n <- 0 until ncSubstep) {
|
||||||
dpus(m)(n).io.in.valid := dpuFire
|
dpus(m)(n).io.in.valid := dpuFire
|
||||||
dpus(m)(n).io.in.bits.a := operandADimensional(m)
|
dpus(m)(n).io.in.bits.a := operandADimensional(m)
|
||||||
dpus(m)(n).io.in.bits.b := operandBDimensional(n)
|
dpus(m)(n).io.in.bits.b := operandBDimensional(n)
|
||||||
dpus(m)(n).io.in.bits.c := 0.U // FIXME: bogus accum data
|
dpus(m)(n).io.in.bits.c := operandCDimensional(n)(m)
|
||||||
// dpu ready couples with writeback backpressure
|
// dpu ready couples with writeback backpressure
|
||||||
dpus(m)(n).io.stall := !io.writeback.ready
|
dpus(m)(n).io.stall := !io.writeback.ready
|
||||||
}
|
}
|
||||||
@@ -631,8 +711,10 @@ class TensorCoreDecoupledTLImp(outer: TensorCoreDecoupledTL)
|
|||||||
tensor.io.respB.bits.source := tlOutB.d.bits.source
|
tensor.io.respB.bits.source := tlOutB.d.bits.source
|
||||||
tlOutB.d.ready := tensor.io.respB.ready
|
tlOutB.d.ready := tensor.io.respB.ready
|
||||||
|
|
||||||
|
tensor.io.respC := 42.U // FIXME bogus
|
||||||
|
|
||||||
tensor.io.initiate.valid := io.start
|
tensor.io.initiate.valid := io.start
|
||||||
tensor.io.initiate.bits.wid := 0.U // TODO
|
tensor.io.initiate.bits.wid := 0.U // FIXME bogus
|
||||||
tensor.io.writeback.ready := true.B
|
tensor.io.writeback.ready := true.B
|
||||||
|
|
||||||
io.finished := tensor.io.writeback.valid && tensor.io.writeback.bits.last
|
io.finished := tensor.io.writeback.valid && tensor.io.writeback.bits.last
|
||||||
|
|||||||
Reference in New Issue
Block a user