tensor: Fix timing of fullCTag

This commit is contained in:
Hansung Kim
2024-10-25 17:29:35 -07:00
parent 43e064fe82
commit 81efecb3c8

View File

@@ -272,25 +272,25 @@ class TensorCoreDecoupled(
// regfile latency is 1 cycle; don't need a deep response queue
val respQueueCDepth = 1
val respQueueC = Module(new Queue(
chiselTypeOf(io.respC), respQueueCDepth
new Bundle {
val tag = new TensorMemTag
val data = UInt(io.respC.widthOption.get.W)
},
respQueueCDepth
))
respQueueC.io.enq.valid := respCValid
respQueueC.io.enq.bits := io.respC
// serialize every two C responses into one full 4x4 C tile
val fullC = Module(new FillBuffer(
chiselTypeOf(io.respC), 2/*substeps*/
))
fullC.io.enq.valid := respQueueC.io.deq.valid
fullC.io.enq.bits := respQueueC.io.deq.bits
respQueueC.io.deq.ready := fullC.io.enq.ready
// make sure there's space at the response queue to be latched at the next
// cycle
val genReqC = (state === AccessorState.access) && respQueueC.io.enq.ready
// 1-cycle delay
respCValid := genReqC
// note rd is independent to sets
def rdGen(step: UInt, substep: UInt): UInt = {
// each step produces 4x4 output tile, written by 8 threads with 2 regs per
// thread
(step << 1/*2 substeps*/) + substep
}
io.reqC.valid := genReqC
io.reqC.bits := 5.U // FIXME
@@ -303,6 +303,28 @@ class TensorCoreDecoupled(
}
stateRegC.index := stateRegC.index + 1.U
}
dontTouch(stateRegC)
// queue the regfile response to buffers
// these strictly belong to the execute stage
respQueueC.io.enq.valid := respCValid
respQueueC.io.enq.bits.tag.warp := warpAccess
respQueueC.io.enq.bits.tag.set := stateRegC.set
respQueueC.io.enq.bits.tag.index := stateRegC.index
respQueueC.io.enq.bits.data := io.respC
// serialize every two C responses into one full 4x4 C tile
val fullC = Module(new FillBuffer(
chiselTypeOf(respQueueC.io.deq.bits.data), 2/*substeps*/
))
fullC.io.enq.valid := respQueueC.io.deq.valid
fullC.io.enq.bits := respQueueC.io.deq.bits.data
respQueueC.io.deq.ready := fullC.io.enq.ready
val fullCTag = Module(new Queue(
new TensorMemTag, entries = 1, pipe = true
))
fullCTag.io.enq.valid := respQueueC.io.deq.valid
fullCTag.io.enq.bits := respQueueC.io.deq.bits.tag
// ===========================================================================
// Execute stage
@@ -391,18 +413,12 @@ class TensorCoreDecoupled(
fullB.io.deq.ready := fullBBuf.io.enq.ready
fullBTag.io.deq.ready := fullBBuf.io.enq.ready
// fullC is instiated at the access stage
val fullCTag = Module(new Queue(
new TensorMemTag, entries = 1, pipe = true
))
fullCTag.io.enq.valid := respQueueB.valid
fullCTag.io.enq.bits := respQueueB.bits.tag
// fullC/fullCTag is instiated at the access stage
val fullCBuf = Module(new Queue(
new Bundle {
val data = chiselTypeOf(fullC.io.deq.bits)
val tag = new TensorMemTag
val data = chiselTypeOf(fullC.io.deq.bits)
}, entries = 1, pipe = true
))
fullCBuf.io.enq.valid := fullC.io.deq.valid
@@ -413,7 +429,8 @@ class TensorCoreDecoupled(
val dpuReady = Wire(Bool())
val dpuFire = Wire(Bool())
val operandsValid = fullABuf.io.deq.valid && fullBBuf.io.deq.valid
val operandsValid =
fullABuf.io.deq.valid && fullBBuf.io.deq.valid && fullCBuf.io.deq.valid
dpuFire := operandsValid && dpuReady
dontTouch(dpuFire)
@@ -465,7 +482,7 @@ class TensorCoreDecoupled(
(substepCompute === 1.U)
fullBBuf.io.deq.ready := dpuFire && shouldDequeueB
// C buf should be synced with B buf
// C deq should be synced with B deq
fullCBuf.io.deq.ready := dpuFire && shouldDequeueB
dontTouch(respQueueA)
@@ -584,13 +601,6 @@ class TensorCoreDecoupled(
// val widQueue = Queue(io.initiate, queueDepth, pipe = (queueDepth == 1))
// note rd is independent to sets
def rdGen(step: UInt, substep: UInt): UInt = {
// each step produces 4x4 output tile, written by 8 threads with 2 regs per
// thread
(step << 1/*2 substeps*/) + substep
}
val warpWriteback = tagQueue.io.deq.bits.warp
val setWriteback = tagQueue.io.deq.bits.set
val stepWriteback = tagQueue.io.deq.bits.step