tensor: Add IO and latching for smem address

This commit is contained in:
Hansung Kim
2024-10-28 19:28:45 -07:00
parent c22fd20616
commit 1ae1965580

View File

@@ -56,10 +56,14 @@ class TensorCoreDecoupled(
val laneWidth = 4/*bytes*/ * 8/*bits*/ val laneWidth = 4/*bytes*/ * 8/*bits*/
val memWidth = numLanes * laneWidth val memWidth = numLanes * laneWidth
val numFPRegBits = log2Ceil(numFPRegs) val numFPRegBits = log2Ceil(numFPRegs)
val addressWidth = 32
val io = IO(new Bundle { val io = IO(new Bundle {
val initiate = Flipped(Decoupled(new Bundle { val initiate = Flipped(Decoupled(new Bundle {
val wid = UInt(numWarpBits.W) val wid = UInt(numWarpBits.W)
// SMEM start address of A and B tile
val addressA = UInt(addressWidth.W)
val addressB = UInt(addressWidth.W)
})) }))
val writeback = Decoupled(new Bundle { val writeback = Decoupled(new Bundle {
val last = Bool() val last = Bool()
@@ -80,7 +84,7 @@ class TensorCoreDecoupled(
sourceWidth: Int sourceWidth: Int
) extends Bundle { ) extends Bundle {
val source = UInt(sourceWidth.W) val source = UInt(sourceWidth.W)
val address = UInt(32.W) val address = UInt(addressWidth.W)
} }
class TensorMemResp( class TensorMemResp(
sourceWidth: Int, sourceWidth: Int,
@@ -140,6 +144,8 @@ class TensorCoreDecoupled(
dontTouch(allReqsDone) dontTouch(allReqsDone)
val warpAccess = RegInit(0.U(numWarpBits.W)) val warpAccess = RegInit(0.U(numWarpBits.W))
val addrAAccess = RegInit(0.U(addressWidth.W))
val addrBAccess = RegInit(0.U(addressWidth.W))
class BlockState extends Bundle { class BlockState extends Bundle {
val set = UInt(setBits.W) val set = UInt(setBits.W)
@@ -156,6 +162,8 @@ class TensorCoreDecoupled(
io.initiate.ready := (state === AccessorState.idle) io.initiate.ready := (state === AccessorState.idle)
when (io.initiate.fire) { when (io.initiate.fire) {
warpAccess := io.initiate.bits.wid warpAccess := io.initiate.bits.wid
addrAAccess := io.initiate.bits.addressA
addrBAccess := io.initiate.bits.addressB
assert(stateA.set === 0.U && stateA.index === 0.U && assert(stateA.set === 0.U && stateA.index === 0.U &&
stateB.set === 0.U && stateB.index === 0.U, stateB.set === 0.U && stateB.index === 0.U,
"stateA and stateB not initialized to zero") "stateA and stateB not initialized to zero")
@@ -219,10 +227,8 @@ class TensorCoreDecoupled(
// base + tileOffset // base + tileOffset
} }
// FIXME: bogus base address val addressA = addressGen(addrAAccess, stateA.set, stateA.index)
val addressA = addressGen(0.U, stateA.set, stateA.index) val addressB = addressGen(addrBAccess, stateB.set, stateB.index)
// SMEM 256KB, 8 banks: 0x8000B(32KB) per bank
val addressB = addressGen(0x8000.U, stateB.set, stateB.index)
val doneReqA = RegInit(false.B) val doneReqA = RegInit(false.B)
val doneReqB = RegInit(false.B) val doneReqB = RegInit(false.B)