tensor: Fix access state transition to consider C req

This commit is contained in:
Hansung Kim
2024-10-25 18:23:51 -07:00
parent 991025e896
commit 1a1a4a088d

View File

@@ -205,24 +205,11 @@ class TensorCoreDecoupled(
// SMEM 256KB, 8 banks: 0x8000B(32KB) per bank
// Address of the B request, produced by addressGen from 0x8000 and the
// current set/index held in stateB.
val addressB = addressGen(0x8000.U, stateB.set, stateB.index)
// lastReqA/lastReqB: true while the corresponding state sits at
// (lastSet, lastIndex).
val lastReqA = (stateA.set === lastSet.U) && (stateA.index === lastIndex.U)
val lastReqB = (stateB.set === lastSet.U) && (stateB.index === lastIndex.U)
// doneReqA/doneReqB: registers reset to false, set once the request at
// (lastSet, lastIndex) fires, and cleared again in the finish state.
val doneReqA = RegInit(false.B)
val doneReqB = RegInit(false.B)
when (lastReqA && io.reqA.fire) { doneReqA := true.B }
when (lastReqB && io.reqB.fire) { doneReqB := true.B }
// Matching done flag for C requests; it is used to gate genReqC and is
// set/cleared in the finalize logic further down.
val doneReqC = RegInit(false.B)
// A/B requests are generated only in the access state, and only until the
// respective done flag has been set.
val genReqA = (state === AccessorState.access) && !doneReqA
val genReqB = (state === AccessorState.access) && !doneReqB
// In the finish state, clear the A/B done flags and rewind the A/B
// set/index state to zero.
when (state === AccessorState.finish) {
doneReqA := false.B
doneReqB := false.B
stateA.set := 0.U
stateA.index := 0.U
stateB.set := 0.U
stateB.index := 0.U
}
// This form of allReqsDone looks only at A and B.
allReqsDone := doneReqA && doneReqB
// Request generation
//
@@ -288,7 +275,7 @@ class TensorCoreDecoupled(
respQueueC.io.count <= Mux(respQueueC.io.deq.fire,
(respQueueC.entries - 1).U,
(respQueueC.entries - 2).U)
// C request generation: only in the access state and only while hasSpace
// holds.
val genReqC = (state === AccessorState.access) && hasSpace
// Same condition, additionally gated off once doneReqC has been set.
val genReqC = (state === AccessorState.access) && hasSpace && !doneReqC
// 1-cycle delay
respCValid := genReqC
@@ -298,27 +285,26 @@ class TensorCoreDecoupled(
// thread
(step << 1/*2 substeps*/) + substep
}
io.reqC.valid := genReqC
io.reqC.bits := 5.U // FIXME
// set/index state of the C accumulator value that will be latched at the
// next cycle.
// Register holding the set/index of the C access, reset to stateInit.
// (stateRegC and stateC are the same register under two names.)
val stateRegC = RegInit(stateInit)
val stateC = RegInit(stateInit)
// Advance on every generated C request: index is incremented each time, and
// set is incremented when index is currently at lastIndex.
// NOTE(review): index is incremented unconditionally; presumably it wraps to
// zero by bit width after lastIndex -- confirm against stateInit's type.
when (genReqC) {
when (stateRegC.index === lastIndex.U) {
stateRegC.set := stateRegC.set + 1.U
when (stateC.index === lastIndex.U) {
stateC.set := stateC.set + 1.U
}
stateRegC.index := stateRegC.index + 1.U
stateC.index := stateC.index + 1.U
}
// dontTouch marks the register so it is not removed by optimization.
dontTouch(stateRegC)
dontTouch(stateC)
// queue the regfile response to buffers
// these strictly belong to the execute stage
// Enqueue is valid when respCValid is set; each entry carries io.respC as
// data, tagged with warpAccess and the set/index of the C state register.
respQueueC.io.enq.valid := respCValid
respQueueC.io.enq.bits.tag.warp := warpAccess
respQueueC.io.enq.bits.tag.set := stateRegC.set
respQueueC.io.enq.bits.tag.index := stateRegC.index
respQueueC.io.enq.bits.tag.set := stateC.set
respQueueC.io.enq.bits.tag.index := stateC.index
respQueueC.io.enq.bits.data := io.respC
// serialize every two C responses into one full 4x4 C tile
@@ -334,6 +320,27 @@ class TensorCoreDecoupled(
// The enqueue side of fullCTag mirrors the valid signal and the tag of
// respQueueC's dequeue side.
fullCTag.io.enq.valid := respQueueC.io.deq.valid
fullCTag.io.enq.bits := respQueueC.io.deq.bits.tag
// finalize state when everything has been accessed
// lastReqX: true while the X state sits at (lastSet, lastIndex).
val lastReqA = (stateA.set === lastSet.U) && (stateA.index === lastIndex.U)
val lastReqB = (stateB.set === lastSet.U) && (stateB.index === lastIndex.U)
val lastReqC = (stateC.set === lastSet.U) && (stateC.index === lastIndex.U)
// Set the done flag when the final request is issued. A and B are qualified
// by the request port's fire; C is qualified by genReqC instead.
when (lastReqA && io.reqA.fire) { doneReqA := true.B }
when (lastReqB && io.reqB.fire) { doneReqB := true.B }
when (lastReqC && genReqC) { doneReqC := true.B }
// In the finish state, clear all three done flags and rewind the A, B and C
// set/index state to zero.
when (state === AccessorState.finish) {
doneReqA := false.B
doneReqB := false.B
doneReqC := false.B
stateA.set := 0.U
stateA.index := 0.U
stateB.set := 0.U
stateB.index := 0.U
stateC.set := 0.U
stateC.index := 0.U
}
// All requests are done only once A, B and C have each issued their final
// request.
allReqsDone := doneReqA && doneReqB && doneReqC
// ===========================================================================
// Execute stage
// ===========================================================================