diff --git a/src/main/scala/radiance/core/TensorCoreDecoupled.scala b/src/main/scala/radiance/core/TensorCoreDecoupled.scala index 535dbdd..bbea6bd 100644 --- a/src/main/scala/radiance/core/TensorCoreDecoupled.scala +++ b/src/main/scala/radiance/core/TensorCoreDecoupled.scala @@ -272,25 +272,25 @@ class TensorCoreDecoupled( // regfile latency is 1 cycle; don't need a deep response queue val respQueueCDepth = 1 val respQueueC = Module(new Queue( - chiselTypeOf(io.respC), respQueueCDepth + new Bundle { + val tag = new TensorMemTag + val data = UInt(io.respC.widthOption.get.W) + }, + respQueueCDepth )) - respQueueC.io.enq.valid := respCValid - respQueueC.io.enq.bits := io.respC - - // serialize every two C responses into one full 4x4 C tile - val fullC = Module(new FillBuffer( - chiselTypeOf(io.respC), 2/*substeps*/ - )) - fullC.io.enq.valid := respQueueC.io.deq.valid - fullC.io.enq.bits := respQueueC.io.deq.bits - respQueueC.io.deq.ready := fullC.io.enq.ready - // make sure there's space at the response queue to be latched at the next // cycle val genReqC = (state === AccessorState.access) && respQueueC.io.enq.ready // 1-cycle delay respCValid := genReqC + // note rd is independent to sets + def rdGen(step: UInt, substep: UInt): UInt = { + // each step produces 4x4 output tile, written by 8 threads with 2 regs per + // thread + (step << 1/*2 substeps*/) + substep + } + io.reqC.valid := genReqC io.reqC.bits := 5.U // FIXME @@ -303,6 +303,28 @@ class TensorCoreDecoupled( } stateRegC.index := stateRegC.index + 1.U } + dontTouch(stateRegC) + + // queue the regfile response to buffers + // these strictly belong to the execute stage + respQueueC.io.enq.valid := respCValid + respQueueC.io.enq.bits.tag.warp := warpAccess + respQueueC.io.enq.bits.tag.set := stateRegC.set + respQueueC.io.enq.bits.tag.index := stateRegC.index + respQueueC.io.enq.bits.data := io.respC + + // serialize every two C responses into one full 4x4 C tile + val fullC = Module(new FillBuffer( + chiselTypeOf(respQueueC.io.deq.bits.data), 2/*substeps*/ + )) + fullC.io.enq.valid := respQueueC.io.deq.valid + fullC.io.enq.bits := respQueueC.io.deq.bits.data + respQueueC.io.deq.ready := fullC.io.enq.ready + val fullCTag = Module(new Queue( + new TensorMemTag, entries = 1, pipe = true + )) + fullCTag.io.enq.valid := respQueueC.io.deq.valid + fullCTag.io.enq.bits := respQueueC.io.deq.bits.tag // =========================================================================== // Execute stage @@ -391,18 +413,12 @@ class TensorCoreDecoupled( fullB.io.deq.ready := fullBBuf.io.enq.ready fullBTag.io.deq.ready := fullBBuf.io.enq.ready - // fullC is instiated at the access stage - - val fullCTag = Module(new Queue( - new TensorMemTag, entries = 1, pipe = true - )) - fullCTag.io.enq.valid := respQueueB.valid - fullCTag.io.enq.bits := respQueueB.bits.tag + // fullC/fullCTag is instiated at the access stage val fullCBuf = Module(new Queue( new Bundle { - val data = chiselTypeOf(fullC.io.deq.bits) val tag = new TensorMemTag + val data = chiselTypeOf(fullC.io.deq.bits) }, entries = 1, pipe = true )) fullCBuf.io.enq.valid := fullC.io.deq.valid @@ -413,7 +429,8 @@ class TensorCoreDecoupled( val dpuReady = Wire(Bool()) val dpuFire = Wire(Bool()) - val operandsValid = fullABuf.io.deq.valid && fullBBuf.io.deq.valid + val operandsValid = + fullABuf.io.deq.valid && fullBBuf.io.deq.valid && fullCBuf.io.deq.valid dpuFire := operandsValid && dpuReady dontTouch(dpuFire) @@ -465,7 +482,7 @@ class TensorCoreDecoupled( (substepCompute === 1.U) fullBBuf.io.deq.ready := dpuFire && shouldDequeueB - // C buf should be synced with B buf + // C deq should be synced with B deq fullCBuf.io.deq.ready := dpuFire && shouldDequeueB dontTouch(respQueueA) @@ -584,13 +601,6 @@ class TensorCoreDecoupled( // val widQueue = Queue(io.initiate, queueDepth, pipe = (queueDepth == 1)) - // note rd is independent to sets - def rdGen(step: UInt, substep: UInt): UInt = { - // each step produces 4x4 output tile, written by 8 threads with 2 regs per - // thread - (step << 1/*2 substeps*/) + substep - } - val warpWriteback = tagQueue.io.deq.bits.warp val setWriteback = tagQueue.io.deq.bits.set val stepWriteback = tagQueue.io.deq.bits.step