diff --git a/src/main/scala/radiance/core/TensorCoreDecoupled.scala b/src/main/scala/radiance/core/TensorCoreDecoupled.scala index 903cc1d..d266edd 100644 --- a/src/main/scala/radiance/core/TensorCoreDecoupled.scala +++ b/src/main/scala/radiance/core/TensorCoreDecoupled.scala @@ -56,10 +56,14 @@ class TensorCoreDecoupled( val laneWidth = 4/*bytes*/ * 8/*bits*/ val memWidth = numLanes * laneWidth val numFPRegBits = log2Ceil(numFPRegs) + val addressWidth = 32 val io = IO(new Bundle { val initiate = Flipped(Decoupled(new Bundle { val wid = UInt(numWarpBits.W) + // SMEM start address of A and B tile + val addressA = UInt(addressWidth.W) + val addressB = UInt(addressWidth.W) })) val writeback = Decoupled(new Bundle { val last = Bool() @@ -80,7 +84,7 @@ class TensorCoreDecoupled( sourceWidth: Int ) extends Bundle { val source = UInt(sourceWidth.W) - val address = UInt(32.W) + val address = UInt(addressWidth.W) } class TensorMemResp( sourceWidth: Int, @@ -140,6 +144,8 @@ class TensorCoreDecoupled( dontTouch(allReqsDone) val warpAccess = RegInit(0.U(numWarpBits.W)) + val addrAAccess = RegInit(0.U(addressWidth.W)) + val addrBAccess = RegInit(0.U(addressWidth.W)) class BlockState extends Bundle { val set = UInt(setBits.W) @@ -156,6 +162,8 @@ class TensorCoreDecoupled( io.initiate.ready := (state === AccessorState.idle) when (io.initiate.fire) { warpAccess := io.initiate.bits.wid + addrAAccess := io.initiate.bits.addressA + addrBAccess := io.initiate.bits.addressB assert(stateA.set === 0.U && stateA.index === 0.U && stateB.set === 0.U && stateB.index === 0.U, "stateA and stateB not initialized to zero") @@ -219,10 +227,8 @@ class TensorCoreDecoupled( // base + tileOffset } - // FIXME: bogus base address - val addressA = addressGen(0.U, stateA.set, stateA.index) - // SMEM 256KB, 8 banks: 0x8000B(32KB) per bank - val addressB = addressGen(0x8000.U, stateB.set, stateB.index) + val addressA = addressGen(addrAAccess, stateA.set, stateA.index) + val addressB = addressGen(addrBAccess, stateB.set, stateB.index) val doneReqA = RegInit(false.B) val doneReqB = RegInit(false.B)