tensor: Fix timing of fullCTag
This commit is contained in:
@@ -272,25 +272,25 @@ class TensorCoreDecoupled(
|
|||||||
// regfile latency is 1 cycle; don't need a deep response queue
|
// regfile latency is 1 cycle; don't need a deep response queue
|
||||||
val respQueueCDepth = 1
|
val respQueueCDepth = 1
|
||||||
val respQueueC = Module(new Queue(
|
val respQueueC = Module(new Queue(
|
||||||
chiselTypeOf(io.respC), respQueueCDepth
|
new Bundle {
|
||||||
|
val tag = new TensorMemTag
|
||||||
|
val data = UInt(io.respC.widthOption.get.W)
|
||||||
|
},
|
||||||
|
respQueueCDepth
|
||||||
))
|
))
|
||||||
respQueueC.io.enq.valid := respCValid
|
|
||||||
respQueueC.io.enq.bits := io.respC
|
|
||||||
|
|
||||||
// serialize every two C responses into one full 4x4 C tile
|
|
||||||
val fullC = Module(new FillBuffer(
|
|
||||||
chiselTypeOf(io.respC), 2/*substeps*/
|
|
||||||
))
|
|
||||||
fullC.io.enq.valid := respQueueC.io.deq.valid
|
|
||||||
fullC.io.enq.bits := respQueueC.io.deq.bits
|
|
||||||
respQueueC.io.deq.ready := fullC.io.enq.ready
|
|
||||||
|
|
||||||
// make sure there's space at the response queue to be latched at the next
|
// make sure there's space at the response queue to be latched at the next
|
||||||
// cycle
|
// cycle
|
||||||
val genReqC = (state === AccessorState.access) && respQueueC.io.enq.ready
|
val genReqC = (state === AccessorState.access) && respQueueC.io.enq.ready
|
||||||
// 1-cycle delay
|
// 1-cycle delay
|
||||||
respCValid := genReqC
|
respCValid := genReqC
|
||||||
|
|
||||||
|
// note rd is independent of sets
|
||||||
|
def rdGen(step: UInt, substep: UInt): UInt = {
|
||||||
|
// each step produces 4x4 output tile, written by 8 threads with 2 regs per
|
||||||
|
// thread
|
||||||
|
(step << 1/*2 substeps*/) + substep
|
||||||
|
}
|
||||||
|
|
||||||
io.reqC.valid := genReqC
|
io.reqC.valid := genReqC
|
||||||
io.reqC.bits := 5.U // FIXME
|
io.reqC.bits := 5.U // FIXME
|
||||||
|
|
||||||
@@ -303,6 +303,28 @@ class TensorCoreDecoupled(
|
|||||||
}
|
}
|
||||||
stateRegC.index := stateRegC.index + 1.U
|
stateRegC.index := stateRegC.index + 1.U
|
||||||
}
|
}
|
||||||
|
dontTouch(stateRegC)
|
||||||
|
|
||||||
|
// queue the regfile response to buffers
|
||||||
|
// these strictly belong to the execute stage
|
||||||
|
respQueueC.io.enq.valid := respCValid
|
||||||
|
respQueueC.io.enq.bits.tag.warp := warpAccess
|
||||||
|
respQueueC.io.enq.bits.tag.set := stateRegC.set
|
||||||
|
respQueueC.io.enq.bits.tag.index := stateRegC.index
|
||||||
|
respQueueC.io.enq.bits.data := io.respC
|
||||||
|
|
||||||
|
// serialize every two C responses into one full 4x4 C tile
|
||||||
|
val fullC = Module(new FillBuffer(
|
||||||
|
chiselTypeOf(respQueueC.io.deq.bits.data), 2/*substeps*/
|
||||||
|
))
|
||||||
|
fullC.io.enq.valid := respQueueC.io.deq.valid
|
||||||
|
fullC.io.enq.bits := respQueueC.io.deq.bits.data
|
||||||
|
respQueueC.io.deq.ready := fullC.io.enq.ready
|
||||||
|
val fullCTag = Module(new Queue(
|
||||||
|
new TensorMemTag, entries = 1, pipe = true
|
||||||
|
))
|
||||||
|
fullCTag.io.enq.valid := respQueueC.io.deq.valid
|
||||||
|
fullCTag.io.enq.bits := respQueueC.io.deq.bits.tag
|
||||||
|
|
||||||
// ===========================================================================
|
// ===========================================================================
|
||||||
// Execute stage
|
// Execute stage
|
||||||
@@ -391,18 +413,12 @@ class TensorCoreDecoupled(
|
|||||||
fullB.io.deq.ready := fullBBuf.io.enq.ready
|
fullB.io.deq.ready := fullBBuf.io.enq.ready
|
||||||
fullBTag.io.deq.ready := fullBBuf.io.enq.ready
|
fullBTag.io.deq.ready := fullBBuf.io.enq.ready
|
||||||
|
|
||||||
// fullC is instantiated at the access stage
|
// fullC/fullCTag are instantiated at the access stage
|
||||||
|
|
||||||
val fullCTag = Module(new Queue(
|
|
||||||
new TensorMemTag, entries = 1, pipe = true
|
|
||||||
))
|
|
||||||
fullCTag.io.enq.valid := respQueueB.valid
|
|
||||||
fullCTag.io.enq.bits := respQueueB.bits.tag
|
|
||||||
|
|
||||||
val fullCBuf = Module(new Queue(
|
val fullCBuf = Module(new Queue(
|
||||||
new Bundle {
|
new Bundle {
|
||||||
val data = chiselTypeOf(fullC.io.deq.bits)
|
|
||||||
val tag = new TensorMemTag
|
val tag = new TensorMemTag
|
||||||
|
val data = chiselTypeOf(fullC.io.deq.bits)
|
||||||
}, entries = 1, pipe = true
|
}, entries = 1, pipe = true
|
||||||
))
|
))
|
||||||
fullCBuf.io.enq.valid := fullC.io.deq.valid
|
fullCBuf.io.enq.valid := fullC.io.deq.valid
|
||||||
@@ -413,7 +429,8 @@ class TensorCoreDecoupled(
|
|||||||
|
|
||||||
val dpuReady = Wire(Bool())
|
val dpuReady = Wire(Bool())
|
||||||
val dpuFire = Wire(Bool())
|
val dpuFire = Wire(Bool())
|
||||||
val operandsValid = fullABuf.io.deq.valid && fullBBuf.io.deq.valid
|
val operandsValid =
|
||||||
|
fullABuf.io.deq.valid && fullBBuf.io.deq.valid && fullCBuf.io.deq.valid
|
||||||
dpuFire := operandsValid && dpuReady
|
dpuFire := operandsValid && dpuReady
|
||||||
dontTouch(dpuFire)
|
dontTouch(dpuFire)
|
||||||
|
|
||||||
@@ -465,7 +482,7 @@ class TensorCoreDecoupled(
|
|||||||
(substepCompute === 1.U)
|
(substepCompute === 1.U)
|
||||||
fullBBuf.io.deq.ready := dpuFire && shouldDequeueB
|
fullBBuf.io.deq.ready := dpuFire && shouldDequeueB
|
||||||
|
|
||||||
// C buf should be synced with B buf
|
// C deq should be synced with B deq
|
||||||
fullCBuf.io.deq.ready := dpuFire && shouldDequeueB
|
fullCBuf.io.deq.ready := dpuFire && shouldDequeueB
|
||||||
|
|
||||||
dontTouch(respQueueA)
|
dontTouch(respQueueA)
|
||||||
@@ -584,13 +601,6 @@ class TensorCoreDecoupled(
|
|||||||
|
|
||||||
// val widQueue = Queue(io.initiate, queueDepth, pipe = (queueDepth == 1))
|
// val widQueue = Queue(io.initiate, queueDepth, pipe = (queueDepth == 1))
|
||||||
|
|
||||||
// note rd is independent of sets
|
|
||||||
def rdGen(step: UInt, substep: UInt): UInt = {
|
|
||||||
// each step produces 4x4 output tile, written by 8 threads with 2 regs per
|
|
||||||
// thread
|
|
||||||
(step << 1/*2 substeps*/) + substep
|
|
||||||
}
|
|
||||||
|
|
||||||
val warpWriteback = tagQueue.io.deq.bits.warp
|
val warpWriteback = tagQueue.io.deq.bits.warp
|
||||||
val setWriteback = tagQueue.io.deq.bits.set
|
val setWriteback = tagQueue.io.deq.bits.set
|
||||||
val stepWriteback = tagQueue.io.deq.bits.step
|
val stepWriteback = tagQueue.io.deq.bits.step
|
||||||
|
|||||||
Reference in New Issue
Block a user