Integrate WU architecture in Radiance

This commit is contained in:
2026-05-25 19:25:59 +08:00
parent 5112f3665a
commit 87a4bbc757
6 changed files with 298 additions and 167 deletions

View File

@@ -24,7 +24,7 @@ ifeq ($(shell echo $(CONFIG) | grep -E "HopperConfig$$"),$(CONFIG))
EXTRA_SIM_PREPROC_DEFINES += +define+NUM_CORES=4 +define+EXT_T_HOPPER EXTRA_SIM_PREPROC_DEFINES += +define+NUM_CORES=4 +define+EXT_T_HOPPER
endif endif
ifeq ($(shell echo $(CONFIG) | grep -E "BlackwellConfig$$"),$(CONFIG)) ifeq ($(shell echo $(CONFIG) | grep -E "BlackwellConfig$$"),$(CONFIG))
EXTRA_SIM_PREPROC_DEFINES += +define+NUM_CORES=4 +define+EXT_T_BLACKWELL EXTRA_SIM_PREPROC_DEFINES += +define+NUM_CORES=1 +define+NUM_WARPS=4 +define+NUM_THREADS=4 +define+NUM_TENSOR_WARPS=2 +define+EXT_T_BLACKWELL
endif endif
ifeq ($(shell echo $(CONFIG) | grep -E "FlashConfig$$"),$(CONFIG)) ifeq ($(shell echo $(CONFIG) | grep -E "FlashConfig$$"),$(CONFIG))
EXTRA_SIM_PREPROC_DEFINES += +define+NUM_CORES=4 EXTRA_SIM_PREPROC_DEFINES += +define+NUM_CORES=4

View File

@@ -57,13 +57,21 @@ class TensorCoreBlackwell(
// Direct SRAM port for TMEM (no TileLink overhead) // Direct SRAM port for TMEM (no TileLink overhead)
class TmemSramPort extends Bundle { class TmemSramPort extends Bundle {
val wen = Output(Bool()) val aRen = Output(Bool())
val ren = Output(Bool()) val aRready = Input(Bool())
val waddr = Output(UInt(log2Ceil(numWarps * numCFrags * 2).W)) val aRaddr = Output(UInt(log2Ceil(numWarps * numCFrags * 2).W))
val raddr = Output(UInt(log2Ceil(numWarps * numCFrags * 2).W)) val aRdata = Input(UInt(memWidth.W))
val wdata = Output(UInt(memWidth.W))
val mask = Output(UInt(maskWidth.W)) val cRen = Output(Bool())
val rdata = Input(UInt(memWidth.W)) val cRready = Input(Bool())
val cRaddr = Output(UInt(log2Ceil(numWarps * numCFrags * 2).W))
val cRdata = Input(UInt(memWidth.W))
val cWen = Output(Bool())
val cWready = Input(Bool())
val cWaddr = Output(UInt(log2Ceil(numWarps * numCFrags * 2).W))
val cWdata = Output(UInt(memWidth.W))
val cMask = Output(UInt(maskWidth.W))
} }
val io = IO(new Bundle { val io = IO(new Bundle {
@@ -94,7 +102,7 @@ class TensorCoreBlackwell(
val idle, bwLoadAReq, bwLoadAResp, bwLoadBReq, bwLoadBResp, val idle, bwLoadAReq, bwLoadAResp, bwLoadBReq, bwLoadBResp,
bwReadCReq, bwReadCResp, bwCompute, bwDpuResp, bwWriteCReq, bwReadCReq, bwReadCResp, bwCompute, bwDpuResp, bwWriteCReq,
bwWriteCWait, bwDone, cpRead, cpWrite, ldReq, stReq, stWrite, waitWb, bwWriteCWait, bwDone, cpRead, cpWrite, ldReq, stReq, stWrite, waitWb,
cbRead, cbWrite = Value cbRead, cbCapture, cbWrite = Value
} }
val state = RegInit(State.idle) val state = RegInit(State.idle)
@@ -147,12 +155,14 @@ class TensorCoreBlackwell(
io.reqA <> reqA io.reqA <> reqA
io.reqB <> reqB io.reqB <> reqB
io.tmemC.wen := false.B io.tmemC.aRen := false.B
io.tmemC.ren := false.B io.tmemC.aRaddr := 0.U
io.tmemC.waddr := 0.U io.tmemC.cRen := false.B
io.tmemC.raddr := 0.U io.tmemC.cRaddr := 0.U
io.tmemC.wdata := 0.U io.tmemC.cWen := false.B
io.tmemC.mask := 0.U io.tmemC.cWaddr := 0.U
io.tmemC.cWdata := 0.U
io.tmemC.cMask := 0.U
val wbValid = RegInit(false.B) val wbValid = RegInit(false.B)
val wbData = Reg(Vec(numLanes, UInt(laneWidth.W))) val wbData = Reg(Vec(numLanes, UInt(laneWidth.W)))
@@ -229,13 +239,15 @@ class TensorCoreBlackwell(
} }
when(state === State.bwLoadAReq) { when(state === State.bwLoadAReq) {
io.tmemC.ren := true.B io.tmemC.aRen := true.B
io.tmemC.raddr := tmemABase + aFragIndex io.tmemC.aRaddr := tmemABase + aFragIndex
when(io.tmemC.aRready) {
state := State.bwLoadAResp state := State.bwLoadAResp
} }
}
when(state === State.bwLoadAResp) { when(state === State.bwLoadAResp) {
aBuf(aIndexReg) := io.tmemC.rdata aBuf(aIndexReg) := io.tmemC.aRdata
when(aIndexReg === (numAFragsPerSet - 1).U) { when(aIndexReg === (numAFragsPerSet - 1).U) {
bGroupReg := 0.U bGroupReg := 0.U
bIndexReg := 0.U bIndexReg := 0.U
@@ -274,13 +286,15 @@ class TensorCoreBlackwell(
} }
when(state === State.bwReadCReq) { when(state === State.bwReadCReq) {
io.tmemC.ren := true.B io.tmemC.cRen := true.B
io.tmemC.raddr := tmemCBase + cFragIndex io.tmemC.cRaddr := tmemCBase + cFragIndex
when(io.tmemC.cRready) {
state := State.bwReadCResp state := State.bwReadCResp
} }
}
when(state === State.bwReadCResp) { when(state === State.bwReadCResp) {
cDataReg := io.tmemC.rdata cDataReg := io.tmemC.cRdata
elemReg := 0.U elemReg := 0.U
state := State.bwCompute state := State.bwCompute
} }
@@ -303,10 +317,11 @@ class TensorCoreBlackwell(
} }
when(state === State.bwWriteCReq) { when(state === State.bwWriteCReq) {
io.tmemC.wen := true.B io.tmemC.cWen := true.B
io.tmemC.waddr := tmemCBase + cFragIndex io.tmemC.cWaddr := tmemCBase + cFragIndex
io.tmemC.wdata := mmaDataReg.asUInt io.tmemC.cWdata := mmaDataReg.asUInt
io.tmemC.mask := Fill(maskWidth, 1.U(1.W)) io.tmemC.cMask := Fill(maskWidth, 1.U(1.W))
when(io.tmemC.cWready) {
when(substepReg === 0.U) { when(substepReg === 0.U) {
substepReg := 1.U substepReg := 1.U
state := State.bwReadCReq state := State.bwReadCReq
@@ -333,6 +348,7 @@ class TensorCoreBlackwell(
state := State.bwWriteCWait state := State.bwWriteCWait
} }
} }
}
when(state === State.bwWriteCWait) { when(state === State.bwWriteCWait) {
when(waitCounter === 0.U) { when(waitCounter === 0.U) {
@@ -361,24 +377,26 @@ class TensorCoreBlackwell(
} }
when(state === State.cpWrite) { when(state === State.cpWrite) {
io.respA.ready := true.B io.respA.ready := io.tmemC.cWready
io.tmemC.cWen := io.respA.valid
io.tmemC.cWaddr := (addrAReg >> fragOffsetBits.U).asUInt
io.tmemC.cWdata := io.respA.bits.data
io.tmemC.cMask := Fill(maskWidth, 1.U(1.W))
when(io.respA.fire) { when(io.respA.fire) {
io.tmemC.wen := true.B
io.tmemC.waddr := (addrAReg >> fragOffsetBits.U).asUInt
io.tmemC.wdata := io.respA.bits.data
io.tmemC.mask := Fill(maskWidth, 1.U(1.W))
state := State.idle state := State.idle
} }
} }
when(state === State.ldReq) { when(state === State.ldReq) {
io.tmemC.ren := true.B io.tmemC.cRen := true.B
io.tmemC.raddr := (addrAReg >> fragOffsetBits.U).asUInt io.tmemC.cRaddr := (addrAReg >> fragOffsetBits.U).asUInt
when(io.tmemC.cRready) {
state := State.waitWb state := State.waitWb
} }
}
when(state === State.waitWb && opReg === Ops.tcgen05Ld) { when(state === State.waitWb && opReg === Ops.tcgen05Ld) {
wbData := io.tmemC.rdata.asTypeOf(Vec(numLanes, UInt(laneWidth.W))) wbData := io.tmemC.cRdata.asTypeOf(Vec(numLanes, UInt(laneWidth.W)))
wbValid := true.B wbValid := true.B
state := State.idle state := State.idle
} }
@@ -389,16 +407,25 @@ class TensorCoreBlackwell(
} }
when(state === State.stWrite) { when(state === State.stWrite) {
io.tmemC.wen := true.B io.tmemC.cWen := true.B
io.tmemC.waddr := (addrAReg >> fragOffsetBits.U).asUInt io.tmemC.cWaddr := (addrAReg >> fragOffsetBits.U).asUInt
io.tmemC.wdata := io.respC io.tmemC.cWdata := io.respC
io.tmemC.mask := Fill(maskWidth, 1.U(1.W)) io.tmemC.cMask := Fill(maskWidth, 1.U(1.W))
when(io.tmemC.cWready) {
state := State.idle state := State.idle
} }
}
when(state === State.cbRead) { when(state === State.cbRead) {
io.tmemC.ren := true.B io.tmemC.cRen := true.B
io.tmemC.raddr := (addrAReg >> fragOffsetBits.U).asUInt io.tmemC.cRaddr := (addrAReg >> fragOffsetBits.U).asUInt
when(io.tmemC.cRready) {
state := State.cbCapture
}
}
when(state === State.cbCapture) {
cDataReg := io.tmemC.cRdata
state := State.cbWrite state := State.cbWrite
} }
@@ -408,7 +435,7 @@ class TensorCoreBlackwell(
reqA.bits.byteen := Fill(maskWidth, 1.U(1.W)) reqA.bits.byteen := Fill(maskWidth, 1.U(1.W))
reqA.bits.address := addrBReg reqA.bits.address := addrBReg
reqA.bits.source := sourceCounter reqA.bits.source := sourceCounter
reqA.bits.data := io.tmemC.rdata reqA.bits.data := cDataReg
when(reqA.fire) { when(reqA.fire) {
bumpSource() bumpSource()
state := State.waitWb state := State.waitWb

View File

@@ -51,6 +51,7 @@ class WithRadianceCores(
tensorCoreFP16: Boolean, tensorCoreFP16: Boolean,
tensorCoreDecoupled: Boolean, tensorCoreDecoupled: Boolean,
tensorCoreBlackwell: Boolean, tensorCoreBlackwell: Boolean,
numTensorWarps: Int,
startupAddress: BigInt, startupAddress: BigInt,
useVxCache: Boolean useVxCache: Boolean
) extends Config((site, _, up) => { ) extends Config((site, _, up) => {
@@ -63,6 +64,7 @@ class WithRadianceCores(
tensorCoreFP16 = tensorCoreFP16, tensorCoreFP16 = tensorCoreFP16,
tensorCoreDecoupled = tensorCoreDecoupled, tensorCoreDecoupled = tensorCoreDecoupled,
tensorCoreBlackwell = tensorCoreBlackwell, tensorCoreBlackwell = tensorCoreBlackwell,
numTensorWarps = numTensorWarps,
startupAddress = startupAddress startupAddress = startupAddress
), ),
btb = None, btb = None,
@@ -101,6 +103,7 @@ class WithRadianceCores(
def this(n: Int, location: HierarchicalLocation = InSubsystem, def this(n: Int, location: HierarchicalLocation = InSubsystem,
tensorCoreFP16: Boolean = false, tensorCoreDecoupled: Boolean = false, tensorCoreFP16: Boolean = false, tensorCoreDecoupled: Boolean = false,
tensorCoreBlackwell: Boolean = false, tensorCoreBlackwell: Boolean = false,
numTensorWarps: Int = 4,
startupAddress: BigInt = BigInt("10100", 16), startupAddress: BigInt = BigInt("10100", 16),
useVxCache: Boolean = false) useVxCache: Boolean = false)
= this(n, location, RocketCrossingParams( = this(n, location, RocketCrossingParams(
@@ -110,7 +113,7 @@ class WithRadianceCores(
case InSubsystem => CBUS case InSubsystem => CBUS
case InCluster(clusterId) => CCBUS(clusterId) case InCluster(clusterId) => CCBUS(clusterId)
} }
), tensorCoreFP16, tensorCoreDecoupled, tensorCoreBlackwell, startupAddress, useVxCache) ), tensorCoreFP16, tensorCoreDecoupled, tensorCoreBlackwell, numTensorWarps, startupAddress, useVxCache)
} }
class WithBlackwellTensorCore(location: HierarchicalLocation = InSubsystem) extends Config((site, _, up) => { class WithBlackwellTensorCore(location: HierarchicalLocation = InSubsystem) extends Config((site, _, up) => {

View File

@@ -102,6 +102,7 @@ case class VortexCoreParams(
tensorCoreFP16: Boolean = false, // FP16 if true, FP32 if false tensorCoreFP16: Boolean = false, // FP16 if true, FP32 if false
tensorCoreDecoupled: Boolean = false, // hopper-style SMEM operand decoupling tensorCoreDecoupled: Boolean = false, // hopper-style SMEM operand decoupling
tensorCoreBlackwell: Boolean = false, // blackwell-style TMEM + SMEM tensor core tensorCoreBlackwell: Boolean = false, // blackwell-style TMEM + SMEM tensor core
numTensorWarps: Int = 4,
startupAddress: BigInt = BigInt("10100", 16), // initial warp PC programmed through startup DCRs startupAddress: BigInt = BigInt("10100", 16), // initial warp PC programmed through startup DCRs
debugROB: Boolean = false, // if enabled, uses a C++ debug ROB to generate trace-with-wdata debugROB: Boolean = false, // if enabled, uses a C++ debug ROB to generate trace-with-wdata
haveCease: Boolean = true, // non-standard CEASE instruction haveCease: Boolean = true, // non-standard CEASE instruction
@@ -210,7 +211,9 @@ class RadianceTile private (
case Some(false) => 1 case Some(false) => 1
case None => 1 case None => 1
} }
val imemTagWidth = UUID_WIDTH + NW_WIDTH // Must match VX_gpu_pkg.sv: ICACHE_TAG_WIDTH = domain + UUID + wid.
val imemDomainWidth = 1
val imemTagWidth = imemDomainWidth + UUID_WIDTH + NW_WIDTH
require(numWarps >= numLsuLanes, require(numWarps >= numLsuLanes,
s"Vortex core requires numWarps (${numWarps}) >= numLsuLanes (${numLsuLanes})") s"Vortex core requires numWarps (${numWarps}) >= numLsuLanes (${numLsuLanes})")
@@ -286,8 +289,16 @@ class RadianceTile private (
} }
val tcSmemSize = 32 val tcSmemSize = 32
val numTensorWarps = radianceParams.core.numTensorWarps
val numScalarWarps = numWarps - numTensorWarps
require(numTensorWarps > 0 && numTensorWarps < numWarps,
s"Wu requires 0 < numTensorWarps (${numTensorWarps}) < numWarps (${numWarps})")
val numTensorCores = if (radianceParams.core.tensorCoreBlackwell) numTensorWarps else 1
if (radianceParams.core.tensorCoreBlackwell) {
require(numTensorCores == numTensorWarps, "Wu Blackwell binding requires one Tensor Core per Tensor warp")
}
val tensorUsesAsyncMem = radianceParams.core.tensorCoreDecoupled || radianceParams.core.tensorCoreBlackwell val tensorUsesAsyncMem = radianceParams.core.tensorCoreDecoupled || radianceParams.core.tensorCoreBlackwell
val tcSmemNodeCount = if (radianceParams.core.tensorCoreDecoupled) 2 else if (radianceParams.core.tensorCoreBlackwell) 1 else 0 val tcSmemNodeCount = if (radianceParams.core.tensorCoreDecoupled) 2 else if (radianceParams.core.tensorCoreBlackwell) numTensorCores else 0
val tcSmemNodes = Seq.tabulate(tcSmemNodeCount) { i => val tcSmemNodes = Seq.tabulate(tcSmemNodeCount) { i =>
TLClientNode(Seq(TLMasterPortParameters.v2( TLClientNode(Seq(TLMasterPortParameters.v2(
masters = Seq(TLMasterParameters.v2( masters = Seq(TLMasterParameters.v2(
@@ -304,10 +315,11 @@ class RadianceTile private (
} }
// For Blackwell, tcSmemNodes accesses SMEM (bwgmma B operand) // For Blackwell, tcSmemNodes accesses SMEM (bwgmma B operand)
// tcGmemNode provides global memory access for cp (global→tmem) and cb (tmem→global) // tcGmemNodes provide global memory access for cp (global→tmem) and cb (tmem→global)
val tcGmemNode = if (radianceParams.core.tensorCoreBlackwell) Some(TLClientNode(Seq( val tcGmemNodes = if (radianceParams.core.tensorCoreBlackwell) {
TLMasterPortParameters.v2(masters = Seq(TLMasterParameters.v2( Seq.tabulate(numTensorCores) { i =>
name = s"rad_tc_gmem_${radianceParams.coreId}", TLClientNode(Seq(TLMasterPortParameters.v2(masters = Seq(TLMasterParameters.v2(
name = s"rad_tc_gmem_${radianceParams.coreId}_$i",
sourceId = IdRange(0, 1 << dmemSourceWidth), sourceId = IdRange(0, 1 << dmemSourceWidth),
supports = TLSlaveToMasterTransferSizes( supports = TLSlaveToMasterTransferSizes(
probe = TransferSizes(1, tcSmemSize), probe = TransferSizes(1, tcSmemSize),
@@ -315,8 +327,9 @@ class RadianceTile private (
putFull = TransferSizes(1, tcSmemSize), putFull = TransferSizes(1, tcSmemSize),
), ),
requestFifo = true requestFifo = true
))) )))))
))) else None }
} else Seq.empty
// combine outgoing per-lane dmemNode into 1 identity node // combine outgoing per-lane dmemNode into 1 identity node
// //
@@ -406,7 +419,7 @@ class RadianceTile private (
// imemNodes.foreach { tlMasterXbar.node := TLWidthWidget(4) := _ } // imemNodes.foreach { tlMasterXbar.node := TLWidthWidget(4) := _ }
tlMasterXbar.node :=* AddressOrNode(base) :=* icacheNode tlMasterXbar.node :=* AddressOrNode(base) :=* icacheNode
tlMasterXbar.node :=* AddressOrNode(base) :=* dcacheNode tlMasterXbar.node :=* AddressOrNode(base) :=* dcacheNode
tcGmemNode.foreach { n => tlMasterXbar.node := AddressOrNode(base) := n } tcGmemNodes.foreach { n => tlMasterXbar.node := AddressOrNode(base) := n }
} }
/* below are copied from rocket */ /* below are copied from rocket */
@@ -822,86 +835,160 @@ class RadianceTileModuleImp(outer: RadianceTile)
core.io.tc_d_bits_data := DontCare core.io.tc_d_bits_data := DontCare
core.io.tc_d_bits_tag := DontCare core.io.tc_d_bits_tag := DontCare
} }
core.io.tc_tmem_A_rready := DontCare
core.io.tc_tmem_A_rdata := DontCare
core.io.tc_tmem_C_rready := DontCare
core.io.tc_tmem_C_rdata := DontCare
core.io.tc_tmem_C_wready := DontCare
} }
def connectTensorBlackwell = { def connectTensorBlackwell = {
if (outer.radianceParams.core.tensorCoreBlackwell) { if (outer.radianceParams.core.tensorCoreBlackwell) {
require(outer.tcSmemNodes.nonEmpty) require(outer.tcSmemNodes.nonEmpty)
require(outer.tcSmemNodes.length == outer.numTensorCores)
require(outer.tcGmemNodes.length == outer.numTensorCores)
// TMEM C matrix: direct SRAM (no TileLink), connected via VortexCore IO val nTC = outer.numTensorCores
val tcPorts = 3
val tcDataBits = outer.tcSmemSize * 8
val tmemAddrBits = 9
val tmemDataBits = outer.numLsuLanes * 32
val tmemMaskBits = outer.numLsuLanes * 4
def slice(u: UInt, width: Int, idx: Int): UInt = u(width * (idx + 1) - 1, width * idx)
def port(tc: Int, p: Int): Int = tc * tcPorts + p
val tcAReady = Wire(Vec(nTC * tcPorts, Bool()))
val tcDValid = Wire(Vec(nTC * tcPorts, Bool()))
val tcDData = Wire(Vec(nTC * tcPorts, UInt(tcDataBits.W)))
val tcDTag = Wire(Vec(nTC * tcPorts, UInt(outer.tensorTagWidth.W)))
tcAReady.foreach(_ := false.B)
tcDValid.foreach(_ := false.B)
tcDData.foreach(_ := 0.U)
tcDTag.foreach(_ := 0.U)
// TMEM matrix: one shared 2R1W SRAM. read0 is operand A, read1 is C.
// Each warp needs 2 tiles (A + C), each tile = 32 frags × 32B = 1KB // Each warp needs 2 tiles (A + C), each tile = 32 frags × 32B = 1KB
val tmemDepth = outer.numWarps * outer.tcSmemSize * 2 // numWarps × 64 rows val tmemDepth = outer.numWarps * outer.tcSmemSize * 2 // numWarps × 64 rows
val tmem = Module(new radiance.memory.TwoReadOneWriteSyncMem( val tmem = Module(new radiance.memory.TwoReadOneWriteSyncMem(
tmemDepth, UInt((outer.tcSmemSize * 8).W))) tmemDepth, UInt((outer.tcSmemSize * 8).W)))
tmem.io.ren0 := core.io.tc_tmem_C_ren
tmem.io.raddr0 := core.io.tc_tmem_C_raddr
core.io.tc_tmem_C_rdata := tmem.io.rdata0
tmem.io.ren1 := false.B
tmem.io.raddr1 := 0.U
tmem.io.wen := core.io.tc_tmem_C_wen
tmem.io.waddr := core.io.tc_tmem_C_waddr
tmem.io.wdata := core.io.tc_tmem_C_wdata
tmem.io.mask := core.io.tc_tmem_C_mask
// smem_B (port 2): Global Memory via TileLink val aReadArb = Module(new RRArbiter(UInt(tmemAddrBits.W), nTC))
val smemBBundle = new { val cReadArb = Module(new RRArbiter(UInt(tmemAddrBits.W), nTC))
val addr = core.io.tc_a_bits_address(95, 64)
val tag = core.io.tc_a_bits_tag(8 + outer.tensorTagWidth - 1, 8) class TmemWriteReq extends Bundle {
val write = core.io.tc_a_bits_write(2) val addr = UInt(tmemAddrBits.W)
val mask = core.io.tc_a_bits_mask(95, 64) val data = UInt(tmemDataBits.W)
val data = core.io.tc_a_bits_data(767, 512) val mask = UInt(tmemMaskBits.W)
val aValid = core.io.tc_a_valid(2)
val dReady = core.io.tc_d_ready(2)
} }
val client = outer.tcSmemNodes.head.out.head val cWriteArb = Module(new RRArbiter(new TmemWriteReq, nTC))
(0 until nTC).foreach { tc =>
aReadArb.io.in(tc).valid := core.io.tc_tmem_A_ren(tc)
aReadArb.io.in(tc).bits := slice(core.io.tc_tmem_A_raddr, tmemAddrBits, tc)
cReadArb.io.in(tc).valid := core.io.tc_tmem_C_ren(tc)
cReadArb.io.in(tc).bits := slice(core.io.tc_tmem_C_raddr, tmemAddrBits, tc)
cWriteArb.io.in(tc).valid := core.io.tc_tmem_C_wen(tc)
cWriteArb.io.in(tc).bits.addr := slice(core.io.tc_tmem_C_waddr, tmemAddrBits, tc)
cWriteArb.io.in(tc).bits.data := slice(core.io.tc_tmem_C_wdata, tmemDataBits, tc)
cWriteArb.io.in(tc).bits.mask := slice(core.io.tc_tmem_C_mask, tmemMaskBits, tc)
}
aReadArb.io.out.ready := true.B
cReadArb.io.out.ready := true.B
cWriteArb.io.out.ready := true.B
tmem.io.ren0 := aReadArb.io.out.fire
tmem.io.raddr0 := aReadArb.io.out.bits
tmem.io.ren1 := cReadArb.io.out.fire
tmem.io.raddr1 := cReadArb.io.out.bits
tmem.io.wen := cWriteArb.io.out.fire
tmem.io.waddr := cWriteArb.io.out.bits.addr
tmem.io.wdata := cWriteArb.io.out.bits.data
tmem.io.mask := cWriteArb.io.out.bits.mask
val aReadGrant = RegNext(Mux(aReadArb.io.out.fire, UIntToOH(aReadArb.io.chosen, nTC), 0.U(nTC.W)))
val cReadGrant = RegNext(Mux(cReadArb.io.out.fire, UIntToOH(cReadArb.io.chosen, nTC), 0.U(nTC.W)))
core.io.tc_tmem_A_rready := VecInit(aReadArb.io.in.map(_.fire)).asUInt
core.io.tc_tmem_C_rready := VecInit(cReadArb.io.in.map(_.fire)).asUInt
core.io.tc_tmem_C_wready := VecInit(cWriteArb.io.in.map(_.fire)).asUInt
core.io.tc_tmem_A_rdata := VecInit((0 until nTC).map { tc =>
Mux(aReadGrant(tc), tmem.io.rdata0, 0.U(tmemDataBits.W))
}).asUInt
core.io.tc_tmem_C_rdata := VecInit((0 until nTC).map { tc =>
Mux(cReadGrant(tc), tmem.io.rdata1, 0.U(tmemDataBits.W))
}).asUInt
// port 2: SMEM B, one TL client per tensor core. RadianceSharedMem arbitrates them.
(0 until nTC).foreach { tc =>
val p2 = port(tc, 2)
val client = outer.tcSmemNodes(tc).out.head
val adapter = Module(new VortexTLAdapter( val adapter = Module(new VortexTLAdapter(
outer.smemSourceWidth, outer.smemSourceWidth,
new VortexBundleA(tagWidth = outer.tensorTagWidth, dataWidth = 32 * 8), new VortexBundleA(tagWidth = outer.tensorTagWidth, dataWidth = tcDataBits),
new VortexBundleD(tagWidth = outer.tensorTagWidth, dataWidth = 32 * 8), new VortexBundleD(tagWidth = outer.tensorTagWidth, dataWidth = tcDataBits),
client client
)) ))
adapter.io.inReq.bits <> DontCare adapter.io.inReq.bits <> DontCare
adapter.io.inReq.valid := smemBBundle.aValid adapter.io.inReq.valid := core.io.tc_a_valid(p2)
adapter.io.inReq.bits.address := smemBBundle.addr adapter.io.inReq.bits.address := slice(core.io.tc_a_bits_address, 32, p2)
adapter.io.inReq.bits.source := smemBBundle.tag adapter.io.inReq.bits.source := slice(core.io.tc_a_bits_tag, outer.tensorTagWidth, p2)
adapter.io.inReq.bits.size := 5.U adapter.io.inReq.bits.size := 5.U
adapter.io.inReq.bits.opcode := Mux(smemBBundle.write.asBool, TLMessages.PutFullData, TLMessages.Get) adapter.io.inReq.bits.opcode := Mux(core.io.tc_a_bits_write(p2).asBool, TLMessages.PutFullData, TLMessages.Get)
adapter.io.inReq.bits.mask := smemBBundle.mask adapter.io.inReq.bits.mask := slice(core.io.tc_a_bits_mask, 32, p2)
adapter.io.inReq.bits.data := smemBBundle.data adapter.io.inReq.bits.data := slice(core.io.tc_a_bits_data, tcDataBits, p2)
adapter.io.inResp.ready := smemBBundle.dReady adapter.io.inResp.ready := core.io.tc_d_ready(p2)
client._1.a <> adapter.io.outReq client._1.a <> adapter.io.outReq
adapter.io.outResp <> client._1.d adapter.io.outResp <> client._1.d
// port 0: global memory (cp/cb) tcAReady(p2) := adapter.io.inReq.ready
val gmemClient = outer.tcGmemNode.get.out.head tcDValid(p2) := adapter.io.inResp.valid
tcDData(p2) := adapter.io.inResp.bits.data
tcDTag(p2) := adapter.io.inResp.bits.source
}
// port 0: global memory (cp/cb), one TL client per tensor core.
(0 until nTC).foreach { tc =>
val p0 = port(tc, 0)
val gmemClient = outer.tcGmemNodes(tc).out.head
val gmemAdapter = Module(new VortexTLAdapter( val gmemAdapter = Module(new VortexTLAdapter(
outer.dmemSourceWidth, outer.dmemSourceWidth,
new VortexBundleA(tagWidth = outer.tensorTagWidth, dataWidth = 32 * 8), new VortexBundleA(tagWidth = outer.tensorTagWidth, dataWidth = tcDataBits),
new VortexBundleD(tagWidth = outer.tensorTagWidth, dataWidth = 32 * 8), new VortexBundleD(tagWidth = outer.tensorTagWidth, dataWidth = tcDataBits),
gmemClient gmemClient
)) ))
gmemAdapter.io.inReq.bits <> DontCare gmemAdapter.io.inReq.bits <> DontCare
gmemAdapter.io.inReq.valid := core.io.tc_a_valid(0) gmemAdapter.io.inReq.valid := core.io.tc_a_valid(p0)
gmemAdapter.io.inReq.bits.address := core.io.tc_a_bits_address(31, 0) gmemAdapter.io.inReq.bits.address := slice(core.io.tc_a_bits_address, 32, p0)
gmemAdapter.io.inReq.bits.source := core.io.tc_a_bits_tag(outer.tensorTagWidth - 1, 0) gmemAdapter.io.inReq.bits.source := slice(core.io.tc_a_bits_tag, outer.tensorTagWidth, p0)
gmemAdapter.io.inReq.bits.size := 5.U gmemAdapter.io.inReq.bits.size := 5.U
gmemAdapter.io.inReq.bits.opcode := Mux(core.io.tc_a_bits_write(0).asBool, TLMessages.PutFullData, TLMessages.Get) gmemAdapter.io.inReq.bits.opcode := Mux(core.io.tc_a_bits_write(p0).asBool, TLMessages.PutFullData, TLMessages.Get)
gmemAdapter.io.inReq.bits.mask := core.io.tc_a_bits_mask(31, 0) gmemAdapter.io.inReq.bits.mask := slice(core.io.tc_a_bits_mask, 32, p0)
gmemAdapter.io.inReq.bits.data := core.io.tc_a_bits_data(255, 0) gmemAdapter.io.inReq.bits.data := slice(core.io.tc_a_bits_data, tcDataBits, p0)
gmemAdapter.io.inResp.ready := core.io.tc_d_ready(0) gmemAdapter.io.inResp.ready := core.io.tc_d_ready(p0)
gmemClient._1.a <> gmemAdapter.io.outReq gmemClient._1.a <> gmemAdapter.io.outReq
gmemAdapter.io.outResp <> gmemClient._1.d gmemAdapter.io.outResp <> gmemClient._1.d
core.io.tc_a_ready := Cat(adapter.io.inReq.ready, 0.U(1.W), gmemAdapter.io.inReq.ready) tcAReady(p0) := gmemAdapter.io.inReq.ready
core.io.tc_d_valid := Cat(adapter.io.inResp.valid, 0.U(1.W), gmemAdapter.io.inResp.valid) tcDValid(p0) := gmemAdapter.io.inResp.valid
core.io.tc_d_bits_data := Cat(adapter.io.inResp.bits.data, 0.U((outer.tcSmemSize * 8).W), gmemAdapter.io.inResp.bits.data) tcDData(p0) := gmemAdapter.io.inResp.bits.data
core.io.tc_d_bits_tag := Cat(adapter.io.inResp.bits.source, 0.U(outer.tensorTagWidth.W), gmemAdapter.io.inResp.bits.source) tcDTag(p0) := gmemAdapter.io.inResp.bits.source
}
core.io.tc_a_ready := tcAReady.asUInt
core.io.tc_d_valid := tcDValid.asUInt
core.io.tc_d_bits_data := tcDData.asUInt
core.io.tc_d_bits_tag := tcDTag.asUInt
} else { } else {
core.io.tc_a_ready := false.B core.io.tc_a_ready := false.B
core.io.tc_d_valid := false.B core.io.tc_d_valid := false.B
core.io.tc_d_bits_data := DontCare core.io.tc_d_bits_data := DontCare
core.io.tc_d_bits_tag := DontCare core.io.tc_d_bits_tag := DontCare
core.io.tc_tmem_A_rready := DontCare
core.io.tc_tmem_A_rdata := DontCare
core.io.tc_tmem_C_rready := DontCare
core.io.tc_tmem_C_rdata := DontCare core.io.tc_tmem_C_rdata := DontCare
core.io.tc_tmem_C_wready := DontCare
} }
} }
@@ -1006,7 +1093,11 @@ class RadianceTileModuleImp(outer: RadianceTile)
tensor.io.reqA.ready := false.B tensor.io.reqA.ready := false.B
tensor.io.reqB.ready := false.B tensor.io.reqB.ready := false.B
tensor.io.writeback.ready := false.B tensor.io.writeback.ready := false.B
tensor.io.tmemC.rdata := DontCare tensor.io.tmemC.aRready := false.B
tensor.io.tmemC.aRdata := DontCare
tensor.io.tmemC.cRready := false.B
tensor.io.tmemC.cRdata := DontCare
tensor.io.tmemC.cWready := false.B
dontTouch(tensor.io) dontTouch(tensor.io)
} else { } else {
if (outer.radianceParams.core.tensorCoreFP16) { if (outer.radianceParams.core.tensorCoreFP16) {

View File

@@ -90,28 +90,36 @@ class VortexBundle(tile: RadianceTile)(implicit p: Parameters) extends CoreBundl
val smem_d_bits_data = Input(UInt((tile.numLsuLanes * 32).W)) val smem_d_bits_data = Input(UInt((tile.numLsuLanes * 32).W))
val smem_d_ready = Output(UInt((tile.numLsuLanes * 1).W)) val smem_d_ready = Output(UInt((tile.numLsuLanes * 1).W))
val numTensorCores = if (tile.radianceParams.core.tensorCoreBlackwell) tile.numTensorCores else 1
val tcPortCount = 3 val tcPortCount = 3
val tc_a_valid = Output(UInt(tcPortCount.W)) val tcFlatPortCount = tcPortCount * numTensorCores
val tc_a_bits_write = Output(UInt(tcPortCount.W)) val tc_a_valid = Output(UInt(tcFlatPortCount.W))
val tc_a_bits_address = Output(UInt((tcPortCount * 32).W)) val tc_a_bits_write = Output(UInt(tcFlatPortCount.W))
val tc_a_bits_tag = Output(UInt((tcPortCount * 4).W)) val tc_a_bits_address = Output(UInt((tcFlatPortCount * 32).W))
val tc_a_bits_mask = Output(UInt((tcPortCount * 32).W)) val tc_a_bits_tag = Output(UInt((tcFlatPortCount * 4).W))
val tc_a_bits_data = Output(UInt((tcPortCount * 32 * 8).W)) val tc_a_bits_mask = Output(UInt((tcFlatPortCount * 32).W))
val tc_a_ready = Input(UInt(tcPortCount.W)) val tc_a_bits_data = Output(UInt((tcFlatPortCount * 32 * 8).W))
val tc_d_valid = Input(UInt(tcPortCount.W)) val tc_a_ready = Input(UInt(tcFlatPortCount.W))
val tc_d_bits_data = Input(UInt((tcPortCount * 32 * 8).W)) val tc_d_valid = Input(UInt(tcFlatPortCount.W))
val tc_d_bits_tag = Input(UInt((tcPortCount * 4).W)) val tc_d_bits_data = Input(UInt((tcFlatPortCount * 32 * 8).W))
val tc_d_ready = Output(UInt(tcPortCount.W)) val tc_d_bits_tag = Input(UInt((tcFlatPortCount * 4).W))
val tc_d_ready = Output(UInt(tcFlatPortCount.W))
// Direct SRAM port for TMEM C (bypasses TileLink) // Direct SRAM ports for shared TMEM (bypasses TileLink)
val numLanes = tile.numLsuLanes val numLanes = tile.numLsuLanes
val tc_tmem_C_wen = Output(Bool()) val tc_tmem_A_ren = Output(UInt(numTensorCores.W))
val tc_tmem_C_ren = Output(Bool()) val tc_tmem_A_rready = Input(UInt(numTensorCores.W))
val tc_tmem_C_waddr = Output(UInt(9.W)) val tc_tmem_A_raddr = Output(UInt((numTensorCores * 9).W))
val tc_tmem_C_raddr = Output(UInt(9.W)) val tc_tmem_A_rdata = Input(UInt((numTensorCores * numLanes * 32).W))
val tc_tmem_C_wdata = Output(UInt((numLanes * 32).W)) val tc_tmem_C_ren = Output(UInt(numTensorCores.W))
val tc_tmem_C_mask = Output(UInt((numLanes * 4).W)) val tc_tmem_C_rready = Input(UInt(numTensorCores.W))
val tc_tmem_C_rdata = Input(UInt((numLanes * 32).W)) val tc_tmem_C_raddr = Output(UInt((numTensorCores * 9).W))
val tc_tmem_C_rdata = Input(UInt((numTensorCores * numLanes * 32).W))
val tc_tmem_C_wen = Output(UInt(numTensorCores.W))
val tc_tmem_C_wready = Input(UInt(numTensorCores.W))
val tc_tmem_C_waddr = Output(UInt((numTensorCores * 9).W))
val tc_tmem_C_wdata = Output(UInt((numTensorCores * numLanes * 32).W))
val tc_tmem_C_mask = Output(UInt((numTensorCores * numLanes * 4).W))
// FIXME: hardcoded // FIXME: hardcoded
val barrierIdBits = tile.barrierMasterNode.out(0)._2.barrierIdBits val barrierIdBits = tile.barrierMasterNode.out(0)._2.barrierIdBits
@@ -147,7 +155,8 @@ class Vortex(tile: RadianceTile)(implicit p: Parameters)
"CORE_ID" -> tile.radianceParams.coreId, "CORE_ID" -> tile.radianceParams.coreId,
"TENSOR_FP16" -> (if (tile.radianceParams.core.tensorCoreFP16) 1 else 0), "TENSOR_FP16" -> (if (tile.radianceParams.core.tensorCoreFP16) 1 else 0),
"STARTUP_ADDR" -> tile.radianceParams.core.startupAddress, "STARTUP_ADDR" -> tile.radianceParams.core.startupAddress,
"NUM_THREADS" -> tile.numLsuLanes "NUM_THREADS" -> tile.numLsuLanes,
"NUM_TENSOR_CORES" -> (if (tile.radianceParams.core.tensorCoreBlackwell) tile.numTensorCores else 1)
) )
) )
with HasBlackBoxResource with HasBlackBoxPath { with HasBlackBoxResource with HasBlackBoxPath {
@@ -211,6 +220,7 @@ class Vortex(tile: RadianceTile)(implicit p: Parameters)
addResource("/vsrc/vortex/hw/rtl/core/VX_scoreboard.sv") addResource("/vsrc/vortex/hw/rtl/core/VX_scoreboard.sv")
addResource("/vsrc/vortex/hw/rtl/core/VX_sfu_unit.sv") addResource("/vsrc/vortex/hw/rtl/core/VX_sfu_unit.sv")
addResource("/vsrc/vortex/hw/rtl/core/VX_smem_unit.sv") addResource("/vsrc/vortex/hw/rtl/core/VX_smem_unit.sv")
addResource("/vsrc/vortex/hw/rtl/core/VX_tensor_ctrl_unit.sv")
addResource("/vsrc/vortex/hw/rtl/core/VX_split_join.sv") addResource("/vsrc/vortex/hw/rtl/core/VX_split_join.sv")
addResource("/vsrc/vortex/hw/rtl/core/VX_trace.vh") addResource("/vsrc/vortex/hw/rtl/core/VX_trace.vh")
addResource("/vsrc/vortex/hw/rtl/core/VX_wctl_unit.sv") addResource("/vsrc/vortex/hw/rtl/core/VX_wctl_unit.sv")