diff --git a/src/main/scala/radiance/core/TensorCoreDecoupled.scala b/src/main/scala/radiance/core/TensorCoreDecoupled.scala index 7c07564..92f98b7 100644 --- a/src/main/scala/radiance/core/TensorCoreDecoupled.scala +++ b/src/main/scala/radiance/core/TensorCoreDecoupled.scala @@ -28,12 +28,14 @@ class TensorCoreDecoupled( val numWarps: Int, val numLanes: Int, val numSourceIds: Int, - val tilingParams: TensorTilingParams + val tilingParams: TensorTilingParams, + val numFPRegs: Int = 32 ) extends Module { val numWarpBits = log2Ceil(numWarps) val wordSize = 4 // TODO FP16 val sourceWidth = log2Ceil(numSourceIds) val dataWidth = numLanes * wordSize * 8/*bits*/ // TODO FP16 + val numFPRegBits = log2Ceil(numFPRegs) val io = IO(new Bundle { val initiate = Flipped(Decoupled(new Bundle { @@ -42,6 +44,7 @@ class TensorCoreDecoupled( val writeback = Decoupled(new Bundle { val last = Bool() val wid = UInt(numWarpBits.W) + val rd = UInt(numFPRegBits.W) /* register index, log2Ceil(numFPRegs) bits wide; driven from rdGen(setExecute, stepExecute) */ val data = Vec(numLanes, UInt((wordSize * 8/*bits*/).W)) }) val respA = Flipped(Decoupled(new TensorMemResp(sourceWidth, dataWidth))) @@ -218,8 +221,17 @@ class TensorCoreDecoupled( // FIXME: this need to change to dpu_fire val nextStepExecute = io.writeback.fire + /** Generates the value driven onto io.writeback.bits.rd: concatenates `set` (upper bits) with `step` (lower bits) and shifts the result right by one bit. Elaboration fails via `require` unless numLanes == 8. NOTE(review): confirm the shifted value fits in numFPRegBits bits for the set/step widths in use. */ def rdGen(set: UInt, step: UInt): UInt = { + // each step produces 4x4 output tile, written by 8 threads with 2 regs per + // thread + require(numLanes == 8, "currently assumes 8-wide warps") + (Cat(set, step) >> 1/*2 regs/thread*/) + // FIXME: add substep here + } + io.writeback.valid := bothQueueValid io.writeback.bits.wid := warpReg + io.writeback.bits.rd := rdGen(setExecute, stepExecute) io.writeback.bits.last := setDone(setExecute) && stepDone(stepExecute) // FIXME: debug dummy: pipe A directly to writeback