tensor: Fix access state transition to consider C req

This commit is contained in:
Hansung Kim
2024-10-25 18:23:51 -07:00
parent 991025e896
commit 1a1a4a088d

View File

@@ -205,24 +205,11 @@ class TensorCoreDecoupled(
// SMEM 256KB, 8 banks: 0x8000B(32KB) per bank // SMEM 256KB, 8 banks: 0x8000B(32KB) per bank
val addressB = addressGen(0x8000.U, stateB.set, stateB.index) val addressB = addressGen(0x8000.U, stateB.set, stateB.index)
val lastReqA = (stateA.set === lastSet.U) && (stateA.index === lastIndex.U)
val lastReqB = (stateB.set === lastSet.U) && (stateB.index === lastIndex.U)
val doneReqA = RegInit(false.B) val doneReqA = RegInit(false.B)
val doneReqB = RegInit(false.B) val doneReqB = RegInit(false.B)
when (lastReqA && io.reqA.fire) { doneReqA := true.B } val doneReqC = RegInit(false.B)
when (lastReqB && io.reqB.fire) { doneReqB := true.B }
val genReqA = (state === AccessorState.access) && !doneReqA val genReqA = (state === AccessorState.access) && !doneReqA
val genReqB = (state === AccessorState.access) && !doneReqB val genReqB = (state === AccessorState.access) && !doneReqB
when (state === AccessorState.finish) {
doneReqA := false.B
doneReqB := false.B
stateA.set := 0.U
stateA.index := 0.U
stateB.set := 0.U
stateB.index := 0.U
}
allReqsDone := doneReqA && doneReqB
// Request generation // Request generation
// //
@@ -288,7 +275,7 @@ class TensorCoreDecoupled(
respQueueC.io.count <= Mux(respQueueC.io.deq.fire, respQueueC.io.count <= Mux(respQueueC.io.deq.fire,
(respQueueC.entries - 1).U, (respQueueC.entries - 1).U,
(respQueueC.entries - 2).U) (respQueueC.entries - 2).U)
val genReqC = (state === AccessorState.access) && hasSpace val genReqC = (state === AccessorState.access) && hasSpace && !doneReqC
// 1-cycle delay // 1-cycle delay
respCValid := genReqC respCValid := genReqC
@@ -298,27 +285,26 @@ class TensorCoreDecoupled(
// thread // thread
(step << 1/*2 substeps*/) + substep (step << 1/*2 substeps*/) + substep
} }
io.reqC.valid := genReqC io.reqC.valid := genReqC
io.reqC.bits := 5.U // FIXME io.reqC.bits := 5.U // FIXME
// set/index state of the C accumulator value that will be latched ath the // set/index state of the C accumulator value that will be latched ath the
// next cycle. // next cycle.
val stateRegC = RegInit(stateInit) val stateC = RegInit(stateInit)
when (genReqC) { when (genReqC) {
when (stateRegC.index === lastIndex.U) { when (stateC.index === lastIndex.U) {
stateRegC.set := stateRegC.set + 1.U stateC.set := stateC.set + 1.U
} }
stateRegC.index := stateRegC.index + 1.U stateC.index := stateC.index + 1.U
} }
dontTouch(stateRegC) dontTouch(stateC)
// queue the regfile response to buffers // queue the regfile response to buffers
// these strictly belong to the execute stage // these strictly belong to the execute stage
respQueueC.io.enq.valid := respCValid respQueueC.io.enq.valid := respCValid
respQueueC.io.enq.bits.tag.warp := warpAccess respQueueC.io.enq.bits.tag.warp := warpAccess
respQueueC.io.enq.bits.tag.set := stateRegC.set respQueueC.io.enq.bits.tag.set := stateC.set
respQueueC.io.enq.bits.tag.index := stateRegC.index respQueueC.io.enq.bits.tag.index := stateC.index
respQueueC.io.enq.bits.data := io.respC respQueueC.io.enq.bits.data := io.respC
// serialize every two C responses into one full 4x4 C tile // serialize every two C responses into one full 4x4 C tile
@@ -334,6 +320,27 @@ class TensorCoreDecoupled(
fullCTag.io.enq.valid := respQueueC.io.deq.valid fullCTag.io.enq.valid := respQueueC.io.deq.valid
fullCTag.io.enq.bits := respQueueC.io.deq.bits.tag fullCTag.io.enq.bits := respQueueC.io.deq.bits.tag
// finalize state when everything has been accessed
val lastReqA = (stateA.set === lastSet.U) && (stateA.index === lastIndex.U)
val lastReqB = (stateB.set === lastSet.U) && (stateB.index === lastIndex.U)
val lastReqC = (stateC.set === lastSet.U) && (stateC.index === lastIndex.U)
when (lastReqA && io.reqA.fire) { doneReqA := true.B }
when (lastReqB && io.reqB.fire) { doneReqB := true.B }
when (lastReqC && genReqC) { doneReqC := true.B }
when (state === AccessorState.finish) {
doneReqA := false.B
doneReqB := false.B
doneReqC := false.B
stateA.set := 0.U
stateA.index := 0.U
stateB.set := 0.U
stateB.index := 0.U
stateC.set := 0.U
stateC.index := 0.U
}
allReqsDone := doneReqA && doneReqB && doneReqC
// =========================================================================== // ===========================================================================
// Execute stage // Execute stage
// =========================================================================== // ===========================================================================