tensor: Fix access state transition to consider C req
This commit is contained in:
@@ -205,24 +205,11 @@ class TensorCoreDecoupled(
|
||||
// SMEM 256KB, 8 banks: 0x8000B(32KB) per bank
|
||||
val addressB = addressGen(0x8000.U, stateB.set, stateB.index)
|
||||
|
||||
val lastReqA = (stateA.set === lastSet.U) && (stateA.index === lastIndex.U)
|
||||
val lastReqB = (stateB.set === lastSet.U) && (stateB.index === lastIndex.U)
|
||||
val doneReqA = RegInit(false.B)
|
||||
val doneReqB = RegInit(false.B)
|
||||
when (lastReqA && io.reqA.fire) { doneReqA := true.B }
|
||||
when (lastReqB && io.reqB.fire) { doneReqB := true.B }
|
||||
val doneReqC = RegInit(false.B)
|
||||
val genReqA = (state === AccessorState.access) && !doneReqA
|
||||
val genReqB = (state === AccessorState.access) && !doneReqB
|
||||
when (state === AccessorState.finish) {
|
||||
doneReqA := false.B
|
||||
doneReqB := false.B
|
||||
stateA.set := 0.U
|
||||
stateA.index := 0.U
|
||||
stateB.set := 0.U
|
||||
stateB.index := 0.U
|
||||
}
|
||||
|
||||
allReqsDone := doneReqA && doneReqB
|
||||
|
||||
// Request generation
|
||||
//
|
||||
@@ -288,7 +275,7 @@ class TensorCoreDecoupled(
|
||||
respQueueC.io.count <= Mux(respQueueC.io.deq.fire,
|
||||
(respQueueC.entries - 1).U,
|
||||
(respQueueC.entries - 2).U)
|
||||
val genReqC = (state === AccessorState.access) && hasSpace
|
||||
val genReqC = (state === AccessorState.access) && hasSpace && !doneReqC
|
||||
// 1-cycle delay
|
||||
respCValid := genReqC
|
||||
|
||||
@@ -298,27 +285,26 @@ class TensorCoreDecoupled(
|
||||
// thread
|
||||
(step << 1/*2 substeps*/) + substep
|
||||
}
|
||||
|
||||
io.reqC.valid := genReqC
|
||||
io.reqC.bits := 5.U // FIXME
|
||||
|
||||
// set/index state of the C accumulator value that will be latched ath the
|
||||
// next cycle.
|
||||
val stateRegC = RegInit(stateInit)
|
||||
val stateC = RegInit(stateInit)
|
||||
when (genReqC) {
|
||||
when (stateRegC.index === lastIndex.U) {
|
||||
stateRegC.set := stateRegC.set + 1.U
|
||||
when (stateC.index === lastIndex.U) {
|
||||
stateC.set := stateC.set + 1.U
|
||||
}
|
||||
stateRegC.index := stateRegC.index + 1.U
|
||||
stateC.index := stateC.index + 1.U
|
||||
}
|
||||
dontTouch(stateRegC)
|
||||
dontTouch(stateC)
|
||||
|
||||
// queue the regfile response to buffers
|
||||
// these strictly belong to the execute stage
|
||||
respQueueC.io.enq.valid := respCValid
|
||||
respQueueC.io.enq.bits.tag.warp := warpAccess
|
||||
respQueueC.io.enq.bits.tag.set := stateRegC.set
|
||||
respQueueC.io.enq.bits.tag.index := stateRegC.index
|
||||
respQueueC.io.enq.bits.tag.set := stateC.set
|
||||
respQueueC.io.enq.bits.tag.index := stateC.index
|
||||
respQueueC.io.enq.bits.data := io.respC
|
||||
|
||||
// serialize every two C responses into one full 4x4 C tile
|
||||
@@ -334,6 +320,27 @@ class TensorCoreDecoupled(
|
||||
fullCTag.io.enq.valid := respQueueC.io.deq.valid
|
||||
fullCTag.io.enq.bits := respQueueC.io.deq.bits.tag
|
||||
|
||||
// finalize state when everything has been accessed
|
||||
val lastReqA = (stateA.set === lastSet.U) && (stateA.index === lastIndex.U)
|
||||
val lastReqB = (stateB.set === lastSet.U) && (stateB.index === lastIndex.U)
|
||||
val lastReqC = (stateC.set === lastSet.U) && (stateC.index === lastIndex.U)
|
||||
when (lastReqA && io.reqA.fire) { doneReqA := true.B }
|
||||
when (lastReqB && io.reqB.fire) { doneReqB := true.B }
|
||||
when (lastReqC && genReqC) { doneReqC := true.B }
|
||||
when (state === AccessorState.finish) {
|
||||
doneReqA := false.B
|
||||
doneReqB := false.B
|
||||
doneReqC := false.B
|
||||
stateA.set := 0.U
|
||||
stateA.index := 0.U
|
||||
stateB.set := 0.U
|
||||
stateB.index := 0.U
|
||||
stateC.set := 0.U
|
||||
stateC.index := 0.U
|
||||
}
|
||||
|
||||
allReqsDone := doneReqA && doneReqB && doneReqC
|
||||
|
||||
// ===========================================================================
|
||||
// Execute stage
|
||||
// ===========================================================================
|
||||
|
||||
Reference in New Issue
Block a user