tensor: Add destination reg to IO

This commit is contained in:
Hansung Kim
2024-10-16 14:25:38 -07:00
parent e2abe1cffd
commit 444dd5d7e1

View File

@@ -28,12 +28,14 @@ class TensorCoreDecoupled(
val numWarps: Int,
val numLanes: Int,
val numSourceIds: Int,
val tilingParams: TensorTilingParams
val tilingParams: TensorTilingParams,
val numFPRegs: Int = 32
) extends Module {
val numWarpBits = log2Ceil(numWarps)
val wordSize = 4 // TODO FP16
val sourceWidth = log2Ceil(numSourceIds)
val dataWidth = numLanes * wordSize * 8/*bits*/ // TODO FP16
val numFPRegBits = log2Ceil(numFPRegs)
val io = IO(new Bundle {
val initiate = Flipped(Decoupled(new Bundle {
@@ -42,6 +44,7 @@ class TensorCoreDecoupled(
val writeback = Decoupled(new Bundle {
val last = Bool()
val wid = UInt(numWarpBits.W)
val rd = UInt(numFPRegBits.W)
val data = Vec(numLanes, UInt((wordSize * 8/*bits*/).W))
})
val respA = Flipped(Decoupled(new TensorMemResp(sourceWidth, dataWidth)))
@@ -218,8 +221,17 @@ class TensorCoreDecoupled(
// FIXME: this needs to change to dpu_fire
val nextStepExecute = io.writeback.fire
/** Generates the destination register index reported on the writeback port
  * for a given (set, step) pair.
  *
  * The index is formed by concatenating `set` (upper bits) with `step` (lower
  * bits) and dropping the least-significant bit of the result.
  *
  * @param set  set counter of the operation being executed
  * @param step step counter within the set
  * @return the concatenated `{set, step}` value shifted right by one bit
  */
def rdGen(set: UInt, step: UInt): UInt = {
  // Each step produces a 4x4 output tile, written by 8 threads with 2 regs
  // per thread; the mapping below is only valid for that warp width, so
  // elaboration fails for any other lane count.
  require(numLanes == 8, "currently assumes 8-wide warps")
  // FIXME: add substep here
  val setStep = Cat(set, step)
  setStep >> 1 /* 2 regs/thread */
}
io.writeback.valid := bothQueueValid
io.writeback.bits.wid := warpReg
io.writeback.bits.rd := rdGen(setExecute, stepExecute)
io.writeback.bits.last := setDone(setExecute) && stepDone(stepExecute)
// FIXME: debug dummy: pipe A directly to writeback