tensor: Add destination reg to IO
This commit is contained in:
@@ -28,12 +28,14 @@ class TensorCoreDecoupled(
|
||||
val numWarps: Int,
|
||||
val numLanes: Int,
|
||||
val numSourceIds: Int,
|
||||
val tilingParams: TensorTilingParams
|
||||
val tilingParams: TensorTilingParams,
|
||||
val numFPRegs: Int = 32
|
||||
) extends Module {
|
||||
val numWarpBits = log2Ceil(numWarps)
|
||||
val wordSize = 4 // TODO FP16
|
||||
val sourceWidth = log2Ceil(numSourceIds)
|
||||
val dataWidth = numLanes * wordSize * 8/*bits*/ // TODO FP16
|
||||
val numFPRegBits = log2Ceil(numFPRegs)
|
||||
|
||||
val io = IO(new Bundle {
|
||||
val initiate = Flipped(Decoupled(new Bundle {
|
||||
@@ -42,6 +44,7 @@ class TensorCoreDecoupled(
|
||||
val writeback = Decoupled(new Bundle {
|
||||
val last = Bool()
|
||||
val wid = UInt(numWarpBits.W)
|
||||
val rd = UInt(numFPRegBits.W)
|
||||
val data = Vec(numLanes, UInt((wordSize * 8/*bits*/).W))
|
||||
})
|
||||
val respA = Flipped(Decoupled(new TensorMemResp(sourceWidth, dataWidth)))
|
||||
@@ -218,8 +221,17 @@ class TensorCoreDecoupled(
|
||||
// FIXME: this need to change to dpu_fire
|
||||
val nextStepExecute = io.writeback.fire
|
||||
|
||||
def rdGen(set: UInt, step: UInt): UInt = {
|
||||
// each step produces 4x4 output tile, written by 8 threads with 2 regs per
|
||||
// thread
|
||||
require(numLanes == 8, "currently assumes 8-wide warps")
|
||||
(Cat(set, step) >> 1/*2 regs/thread*/)
|
||||
// FIXME: add substep here
|
||||
}
|
||||
|
||||
io.writeback.valid := bothQueueValid
|
||||
io.writeback.bits.wid := warpReg
|
||||
io.writeback.bits.rd := rdGen(setExecute, stepExecute)
|
||||
io.writeback.bits.last := setDone(setExecute) && stepDone(stepExecute)
|
||||
|
||||
// FIXME: debug dummy: pipe A directly to writeback
|
||||
|
||||
Reference in New Issue
Block a user