tensor: Fix access state transition to consider C req
This commit is contained in:
@@ -205,24 +205,11 @@ class TensorCoreDecoupled(
|
|||||||
// SMEM 256KB, 8 banks: 0x8000B(32KB) per bank
|
// SMEM 256KB, 8 banks: 0x8000B(32KB) per bank
|
||||||
val addressB = addressGen(0x8000.U, stateB.set, stateB.index)
|
val addressB = addressGen(0x8000.U, stateB.set, stateB.index)
|
||||||
|
|
||||||
val lastReqA = (stateA.set === lastSet.U) && (stateA.index === lastIndex.U)
|
|
||||||
val lastReqB = (stateB.set === lastSet.U) && (stateB.index === lastIndex.U)
|
|
||||||
val doneReqA = RegInit(false.B)
|
val doneReqA = RegInit(false.B)
|
||||||
val doneReqB = RegInit(false.B)
|
val doneReqB = RegInit(false.B)
|
||||||
when (lastReqA && io.reqA.fire) { doneReqA := true.B }
|
val doneReqC = RegInit(false.B)
|
||||||
when (lastReqB && io.reqB.fire) { doneReqB := true.B }
|
|
||||||
val genReqA = (state === AccessorState.access) && !doneReqA
|
val genReqA = (state === AccessorState.access) && !doneReqA
|
||||||
val genReqB = (state === AccessorState.access) && !doneReqB
|
val genReqB = (state === AccessorState.access) && !doneReqB
|
||||||
when (state === AccessorState.finish) {
|
|
||||||
doneReqA := false.B
|
|
||||||
doneReqB := false.B
|
|
||||||
stateA.set := 0.U
|
|
||||||
stateA.index := 0.U
|
|
||||||
stateB.set := 0.U
|
|
||||||
stateB.index := 0.U
|
|
||||||
}
|
|
||||||
|
|
||||||
allReqsDone := doneReqA && doneReqB
|
|
||||||
|
|
||||||
// Request generation
|
// Request generation
|
||||||
//
|
//
|
||||||
@@ -288,7 +275,7 @@ class TensorCoreDecoupled(
|
|||||||
respQueueC.io.count <= Mux(respQueueC.io.deq.fire,
|
respQueueC.io.count <= Mux(respQueueC.io.deq.fire,
|
||||||
(respQueueC.entries - 1).U,
|
(respQueueC.entries - 1).U,
|
||||||
(respQueueC.entries - 2).U)
|
(respQueueC.entries - 2).U)
|
||||||
val genReqC = (state === AccessorState.access) && hasSpace
|
val genReqC = (state === AccessorState.access) && hasSpace && !doneReqC
|
||||||
// 1-cycle delay
|
// 1-cycle delay
|
||||||
respCValid := genReqC
|
respCValid := genReqC
|
||||||
|
|
||||||
@@ -298,27 +285,26 @@ class TensorCoreDecoupled(
|
|||||||
// thread
|
// thread
|
||||||
(step << 1/*2 substeps*/) + substep
|
(step << 1/*2 substeps*/) + substep
|
||||||
}
|
}
|
||||||
|
|
||||||
io.reqC.valid := genReqC
|
io.reqC.valid := genReqC
|
||||||
io.reqC.bits := 5.U // FIXME
|
io.reqC.bits := 5.U // FIXME
|
||||||
|
|
||||||
// set/index state of the C accumulator value that will be latched ath the
|
// set/index state of the C accumulator value that will be latched ath the
|
||||||
// next cycle.
|
// next cycle.
|
||||||
val stateRegC = RegInit(stateInit)
|
val stateC = RegInit(stateInit)
|
||||||
when (genReqC) {
|
when (genReqC) {
|
||||||
when (stateRegC.index === lastIndex.U) {
|
when (stateC.index === lastIndex.U) {
|
||||||
stateRegC.set := stateRegC.set + 1.U
|
stateC.set := stateC.set + 1.U
|
||||||
}
|
}
|
||||||
stateRegC.index := stateRegC.index + 1.U
|
stateC.index := stateC.index + 1.U
|
||||||
}
|
}
|
||||||
dontTouch(stateRegC)
|
dontTouch(stateC)
|
||||||
|
|
||||||
// queue the regfile response to buffers
|
// queue the regfile response to buffers
|
||||||
// these strictly belong to the execute stage
|
// these strictly belong to the execute stage
|
||||||
respQueueC.io.enq.valid := respCValid
|
respQueueC.io.enq.valid := respCValid
|
||||||
respQueueC.io.enq.bits.tag.warp := warpAccess
|
respQueueC.io.enq.bits.tag.warp := warpAccess
|
||||||
respQueueC.io.enq.bits.tag.set := stateRegC.set
|
respQueueC.io.enq.bits.tag.set := stateC.set
|
||||||
respQueueC.io.enq.bits.tag.index := stateRegC.index
|
respQueueC.io.enq.bits.tag.index := stateC.index
|
||||||
respQueueC.io.enq.bits.data := io.respC
|
respQueueC.io.enq.bits.data := io.respC
|
||||||
|
|
||||||
// serialize every two C responses into one full 4x4 C tile
|
// serialize every two C responses into one full 4x4 C tile
|
||||||
@@ -334,6 +320,27 @@ class TensorCoreDecoupled(
|
|||||||
fullCTag.io.enq.valid := respQueueC.io.deq.valid
|
fullCTag.io.enq.valid := respQueueC.io.deq.valid
|
||||||
fullCTag.io.enq.bits := respQueueC.io.deq.bits.tag
|
fullCTag.io.enq.bits := respQueueC.io.deq.bits.tag
|
||||||
|
|
||||||
|
// finalize state when everything has been accessed
|
||||||
|
val lastReqA = (stateA.set === lastSet.U) && (stateA.index === lastIndex.U)
|
||||||
|
val lastReqB = (stateB.set === lastSet.U) && (stateB.index === lastIndex.U)
|
||||||
|
val lastReqC = (stateC.set === lastSet.U) && (stateC.index === lastIndex.U)
|
||||||
|
when (lastReqA && io.reqA.fire) { doneReqA := true.B }
|
||||||
|
when (lastReqB && io.reqB.fire) { doneReqB := true.B }
|
||||||
|
when (lastReqC && genReqC) { doneReqC := true.B }
|
||||||
|
when (state === AccessorState.finish) {
|
||||||
|
doneReqA := false.B
|
||||||
|
doneReqB := false.B
|
||||||
|
doneReqC := false.B
|
||||||
|
stateA.set := 0.U
|
||||||
|
stateA.index := 0.U
|
||||||
|
stateB.set := 0.U
|
||||||
|
stateB.index := 0.U
|
||||||
|
stateC.set := 0.U
|
||||||
|
stateC.index := 0.U
|
||||||
|
}
|
||||||
|
|
||||||
|
allReqsDone := doneReqA && doneReqB && doneReqC
|
||||||
|
|
||||||
// ===========================================================================
|
// ===========================================================================
|
||||||
// Execute stage
|
// Execute stage
|
||||||
// ===========================================================================
|
// ===========================================================================
|
||||||
|
|||||||
Reference in New Issue
Block a user