diff --git a/radiance.mk b/radiance.mk index bd84102..d0a913a 100644 --- a/radiance.mk +++ b/radiance.mk @@ -24,7 +24,7 @@ ifeq ($(shell echo $(CONFIG) | grep -E "HopperConfig$$"),$(CONFIG)) EXTRA_SIM_PREPROC_DEFINES += +define+NUM_CORES=4 +define+EXT_T_HOPPER endif ifeq ($(shell echo $(CONFIG) | grep -E "BlackwellConfig$$"),$(CONFIG)) - EXTRA_SIM_PREPROC_DEFINES += +define+NUM_CORES=4 +define+EXT_T_BLACKWELL + EXTRA_SIM_PREPROC_DEFINES += +define+NUM_CORES=1 +define+NUM_WARPS=4 +define+NUM_THREADS=4 +define+NUM_TENSOR_WARPS=2 +define+EXT_T_BLACKWELL endif ifeq ($(shell echo $(CONFIG) | grep -E "FlashConfig$$"),$(CONFIG)) EXTRA_SIM_PREPROC_DEFINES += +define+NUM_CORES=4 diff --git a/src/main/resources/vsrc/vortex b/src/main/resources/vsrc/vortex index 323ed7d..0ad87bd 160000 --- a/src/main/resources/vsrc/vortex +++ b/src/main/resources/vsrc/vortex @@ -1 +1 @@ -Subproject commit 323ed7d7e9c3fd403a5ceb5bba7e4371aa37f6fc +Subproject commit 0ad87bde81fe0d45a2936c57975b4b6010657bca diff --git a/src/main/scala/radiance/core/TensorCoreBlackwell.scala b/src/main/scala/radiance/core/TensorCoreBlackwell.scala index 2068867..a1e4327 100644 --- a/src/main/scala/radiance/core/TensorCoreBlackwell.scala +++ b/src/main/scala/radiance/core/TensorCoreBlackwell.scala @@ -57,13 +57,21 @@ class TensorCoreBlackwell( // Direct SRAM port for TMEM (no TileLink overhead) class TmemSramPort extends Bundle { - val wen = Output(Bool()) - val ren = Output(Bool()) - val waddr = Output(UInt(log2Ceil(numWarps * numCFrags * 2).W)) - val raddr = Output(UInt(log2Ceil(numWarps * numCFrags * 2).W)) - val wdata = Output(UInt(memWidth.W)) - val mask = Output(UInt(maskWidth.W)) - val rdata = Input(UInt(memWidth.W)) + val aRen = Output(Bool()) + val aRready = Input(Bool()) + val aRaddr = Output(UInt(log2Ceil(numWarps * numCFrags * 2).W)) + val aRdata = Input(UInt(memWidth.W)) + + val cRen = Output(Bool()) + val cRready = Input(Bool()) + val cRaddr = Output(UInt(log2Ceil(numWarps * numCFrags * 2).W)) + val cRdata = Input(UInt(memWidth.W)) + + val cWen = Output(Bool()) + val cWready = Input(Bool()) + val cWaddr = Output(UInt(log2Ceil(numWarps * numCFrags * 2).W)) + val cWdata = Output(UInt(memWidth.W)) + val cMask = Output(UInt(maskWidth.W)) } val io = IO(new Bundle { @@ -94,7 +102,7 @@ class TensorCoreBlackwell( val idle, bwLoadAReq, bwLoadAResp, bwLoadBReq, bwLoadBResp, bwReadCReq, bwReadCResp, bwCompute, bwDpuResp, bwWriteCReq, bwWriteCWait, bwDone, cpRead, cpWrite, ldReq, stReq, stWrite, waitWb, - cbRead, cbWrite = Value + cbRead, cbCapture, cbWrite = Value } val state = RegInit(State.idle) @@ -147,12 +155,14 @@ class TensorCoreBlackwell( io.reqA <> reqA io.reqB <> reqB - io.tmemC.wen := false.B - io.tmemC.ren := false.B - io.tmemC.waddr := 0.U - io.tmemC.raddr := 0.U - io.tmemC.wdata := 0.U - io.tmemC.mask := 0.U + io.tmemC.aRen := false.B + io.tmemC.aRaddr := 0.U + io.tmemC.cRen := false.B + io.tmemC.cRaddr := 0.U + io.tmemC.cWen := false.B + io.tmemC.cWaddr := 0.U + io.tmemC.cWdata := 0.U + io.tmemC.cMask := 0.U val wbValid = RegInit(false.B) val wbData = Reg(Vec(numLanes, UInt(laneWidth.W))) @@ -229,13 +239,15 @@ class TensorCoreBlackwell( } when(state === State.bwLoadAReq) { - io.tmemC.ren := true.B - io.tmemC.raddr := tmemABase + aFragIndex - state := State.bwLoadAResp + io.tmemC.aRen := true.B + io.tmemC.aRaddr := tmemABase + aFragIndex + when(io.tmemC.aRready) { + state := State.bwLoadAResp + } } when(state === State.bwLoadAResp) { - aBuf(aIndexReg) := io.tmemC.rdata + aBuf(aIndexReg) := io.tmemC.aRdata when(aIndexReg === (numAFragsPerSet - 1).U) { bGroupReg := 0.U bIndexReg := 0.U @@ -274,13 +286,15 @@ class TensorCoreBlackwell( } when(state === State.bwReadCReq) { - io.tmemC.ren := true.B - io.tmemC.raddr := tmemCBase + cFragIndex - state := State.bwReadCResp + io.tmemC.cRen := true.B + io.tmemC.cRaddr := tmemCBase + cFragIndex + when(io.tmemC.cRready) { + state := State.bwReadCResp + } } when(state === State.bwReadCResp) { - cDataReg := io.tmemC.rdata + cDataReg := io.tmemC.cRdata elemReg := 0.U state := State.bwCompute } @@ -303,34 +317,36 @@ class TensorCoreBlackwell( } when(state === State.bwWriteCReq) { - io.tmemC.wen := true.B - io.tmemC.waddr := tmemCBase + cFragIndex - io.tmemC.wdata := mmaDataReg.asUInt - io.tmemC.mask := Fill(maskWidth, 1.U(1.W)) - when(substepReg === 0.U) { - substepReg := 1.U - state := State.bwReadCReq - }.elsewhen(mGroupReg =/= (numMGroups - 1).U) { - substepReg := 0.U - mGroupReg := mGroupReg + 1.U - state := State.bwReadCReq - }.elsewhen(bGroupReg =/= (numBGroups - 1).U) { - substepReg := 0.U - mGroupReg := 0.U - bGroupReg := bGroupReg + 1.U - bIndexReg := 0.U - state := State.bwLoadBReq - }.elsewhen(setReg =/= (numSets - 1).U) { - substepReg := 0.U - mGroupReg := 0.U - bGroupReg := 0.U - bIndexReg := 0.U - setReg := setReg + 1.U - aIndexReg := 0.U - state := State.bwLoadAReq - }.otherwise { - waitCounter := 7.U - state := State.bwWriteCWait + io.tmemC.cWen := true.B + io.tmemC.cWaddr := tmemCBase + cFragIndex + io.tmemC.cWdata := mmaDataReg.asUInt + io.tmemC.cMask := Fill(maskWidth, 1.U(1.W)) + when(io.tmemC.cWready) { + when(substepReg === 0.U) { + substepReg := 1.U + state := State.bwReadCReq + }.elsewhen(mGroupReg =/= (numMGroups - 1).U) { + substepReg := 0.U + mGroupReg := mGroupReg + 1.U + state := State.bwReadCReq + }.elsewhen(bGroupReg =/= (numBGroups - 1).U) { + substepReg := 0.U + mGroupReg := 0.U + bGroupReg := bGroupReg + 1.U + bIndexReg := 0.U + state := State.bwLoadBReq + }.elsewhen(setReg =/= (numSets - 1).U) { + substepReg := 0.U + mGroupReg := 0.U + bGroupReg := 0.U + bIndexReg := 0.U + setReg := setReg + 1.U + aIndexReg := 0.U + state := State.bwLoadAReq + }.otherwise { + waitCounter := 7.U + state := State.bwWriteCWait + } } } @@ -361,24 +377,26 @@ class TensorCoreBlackwell( } when(state === State.cpWrite) { - io.respA.ready := true.B + io.respA.ready := io.tmemC.cWready + io.tmemC.cWen := io.respA.valid + io.tmemC.cWaddr := (addrAReg >> fragOffsetBits.U).asUInt + io.tmemC.cWdata := io.respA.bits.data + io.tmemC.cMask := Fill(maskWidth, 1.U(1.W)) when(io.respA.fire) { - io.tmemC.wen := true.B - io.tmemC.waddr := (addrAReg >> fragOffsetBits.U).asUInt - io.tmemC.wdata := io.respA.bits.data - io.tmemC.mask := Fill(maskWidth, 1.U(1.W)) state := State.idle } } when(state === State.ldReq) { - io.tmemC.ren := true.B - io.tmemC.raddr := (addrAReg >> fragOffsetBits.U).asUInt - state := State.waitWb + io.tmemC.cRen := true.B + io.tmemC.cRaddr := (addrAReg >> fragOffsetBits.U).asUInt + when(io.tmemC.cRready) { + state := State.waitWb + } } when(state === State.waitWb && opReg === Ops.tcgen05Ld) { - wbData := io.tmemC.rdata.asTypeOf(Vec(numLanes, UInt(laneWidth.W))) + wbData := io.tmemC.cRdata.asTypeOf(Vec(numLanes, UInt(laneWidth.W))) wbValid := true.B state := State.idle } @@ -389,16 +407,25 @@ class TensorCoreBlackwell( } when(state === State.stWrite) { - io.tmemC.wen := true.B - io.tmemC.waddr := (addrAReg >> fragOffsetBits.U).asUInt - io.tmemC.wdata := io.respC - io.tmemC.mask := Fill(maskWidth, 1.U(1.W)) - state := State.idle + io.tmemC.cWen := true.B + io.tmemC.cWaddr := (addrAReg >> fragOffsetBits.U).asUInt + io.tmemC.cWdata := io.respC + io.tmemC.cMask := Fill(maskWidth, 1.U(1.W)) + when(io.tmemC.cWready) { + state := State.idle + } } when(state === State.cbRead) { - io.tmemC.ren := true.B - io.tmemC.raddr := (addrAReg >> fragOffsetBits.U).asUInt + io.tmemC.cRen := true.B + io.tmemC.cRaddr := (addrAReg >> fragOffsetBits.U).asUInt + when(io.tmemC.cRready) { + state := State.cbCapture + } + } + + when(state === State.cbCapture) { + cDataReg := io.tmemC.cRdata state := State.cbWrite } @@ -408,7 +435,7 @@ class TensorCoreBlackwell( reqA.bits.byteen := Fill(maskWidth, 1.U(1.W)) reqA.bits.address := addrBReg reqA.bits.source := sourceCounter - reqA.bits.data := io.tmemC.rdata + reqA.bits.data := cDataReg when(reqA.fire) { bumpSource() state := State.waitWb diff --git a/src/main/scala/radiance/subsystem/Configs.scala b/src/main/scala/radiance/subsystem/Configs.scala index 58c4187..239d43b 100644 --- a/src/main/scala/radiance/subsystem/Configs.scala +++ b/src/main/scala/radiance/subsystem/Configs.scala @@ -51,6 +51,7 @@ class WithRadianceCores( tensorCoreFP16: Boolean, tensorCoreDecoupled: Boolean, tensorCoreBlackwell: Boolean, + numTensorWarps: Int, startupAddress: BigInt, useVxCache: Boolean ) extends Config((site, _, up) => { @@ -63,6 +64,7 @@ class WithRadianceCores( tensorCoreFP16 = tensorCoreFP16, tensorCoreDecoupled = tensorCoreDecoupled, tensorCoreBlackwell = tensorCoreBlackwell, + numTensorWarps = numTensorWarps, startupAddress = startupAddress ), btb = None, @@ -101,6 +103,7 @@ class WithRadianceCores( def this(n: Int, location: HierarchicalLocation = InSubsystem, tensorCoreFP16: Boolean = false, tensorCoreDecoupled: Boolean = false, tensorCoreBlackwell: Boolean = false, + numTensorWarps: Int = 4, startupAddress: BigInt = BigInt("10100", 16), useVxCache: Boolean = false) = this(n, location, RocketCrossingParams( @@ -110,7 +113,7 @@ class WithRadianceCores( case InSubsystem => CBUS case InCluster(clusterId) => CCBUS(clusterId) } - ), tensorCoreFP16, tensorCoreDecoupled, tensorCoreBlackwell, startupAddress, useVxCache) + ), tensorCoreFP16, tensorCoreDecoupled, tensorCoreBlackwell, numTensorWarps, startupAddress, useVxCache) } class WithBlackwellTensorCore(location: HierarchicalLocation = InSubsystem) extends Config((site, _, up) => { diff --git a/src/main/scala/radiance/tile/RadianceTile.scala b/src/main/scala/radiance/tile/RadianceTile.scala index d74e91b..40ebc82 100644 --- a/src/main/scala/radiance/tile/RadianceTile.scala +++ b/src/main/scala/radiance/tile/RadianceTile.scala @@ -102,6 +102,7 @@ case class VortexCoreParams( tensorCoreFP16: Boolean = false, // FP16 if true, FP32 if false tensorCoreDecoupled: Boolean = false, // hopper-style SMEM operand decoupling tensorCoreBlackwell: Boolean = false, // blackwell-style TMEM + SMEM tensor core + numTensorWarps: Int = 4, startupAddress: BigInt = BigInt("10100", 16), // initial warp PC programmed through startup DCRs debugROB: Boolean = false, // if enabled, uses a C++ debug ROB to generate trace-with-wdata haveCease: Boolean = true, // non-standard CEASE instruction @@ -210,7 +211,9 @@ class RadianceTile private ( case Some(false) => 1 case None => 1 } - val imemTagWidth = UUID_WIDTH + NW_WIDTH + // Must match VX_gpu_pkg.sv: ICACHE_TAG_WIDTH = domain + UUID + wid. + val imemDomainWidth = 1 + val imemTagWidth = imemDomainWidth + UUID_WIDTH + NW_WIDTH require(numWarps >= numLsuLanes, s"Vortex core requires numWarps (${numWarps}) >= numLsuLanes (${numLsuLanes})") @@ -286,8 +289,16 @@ class RadianceTile private ( } val tcSmemSize = 32 + val numTensorWarps = radianceParams.core.numTensorWarps + val numScalarWarps = numWarps - numTensorWarps + require(numTensorWarps > 0 && numTensorWarps < numWarps, + s"Wu requires 0 < numTensorWarps (${numTensorWarps}) < numWarps (${numWarps})") + val numTensorCores = if (radianceParams.core.tensorCoreBlackwell) numTensorWarps else 1 + if (radianceParams.core.tensorCoreBlackwell) { + require(numTensorCores == numTensorWarps, "Wu Blackwell binding requires one Tensor Core per Tensor warp") + } val tensorUsesAsyncMem = radianceParams.core.tensorCoreDecoupled || radianceParams.core.tensorCoreBlackwell - val tcSmemNodeCount = if (radianceParams.core.tensorCoreDecoupled) 2 else if (radianceParams.core.tensorCoreBlackwell) 1 else 0 + val tcSmemNodeCount = if (radianceParams.core.tensorCoreDecoupled) 2 else if (radianceParams.core.tensorCoreBlackwell) numTensorCores else 0 val tcSmemNodes = Seq.tabulate(tcSmemNodeCount) { i => TLClientNode(Seq(TLMasterPortParameters.v2( masters = Seq(TLMasterParameters.v2( @@ -304,19 +315,21 @@ class RadianceTile private ( } // For Blackwell, tcSmemNodes accesses SMEM (bwgmma B operand) - // tcGmemNode provides global memory access for cp (global→tmem) and cb (tmem→global) - val tcGmemNode = if (radianceParams.core.tensorCoreBlackwell) Some(TLClientNode(Seq( - TLMasterPortParameters.v2(masters = Seq(TLMasterParameters.v2( - name = s"rad_tc_gmem_${radianceParams.coreId}", - sourceId = IdRange(0, 1 << dmemSourceWidth), - supports = TLSlaveToMasterTransferSizes( - probe = TransferSizes(1, tcSmemSize), - get = TransferSizes(1, tcSmemSize), - putFull = TransferSizes(1, tcSmemSize), - ), - requestFifo = true - ))) - ))) else None + // tcGmemNodes provide global memory access for cp (global→tmem) and cb (tmem→global) + val tcGmemNodes = if (radianceParams.core.tensorCoreBlackwell) { + Seq.tabulate(numTensorCores) { i => + TLClientNode(Seq(TLMasterPortParameters.v2(masters = Seq(TLMasterParameters.v2( + name = s"rad_tc_gmem_${radianceParams.coreId}_$i", + sourceId = IdRange(0, 1 << dmemSourceWidth), + supports = TLSlaveToMasterTransferSizes( + probe = TransferSizes(1, tcSmemSize), + get = TransferSizes(1, tcSmemSize), + putFull = TransferSizes(1, tcSmemSize), + ), + requestFifo = true + ))))) + } + } else Seq.empty // combine outgoing per-lane dmemNode into 1 idenity node // @@ -406,7 +419,7 @@ class RadianceTile private ( // imemNodes.foreach { tlMasterXbar.node := TLWidthWidget(4) := _ } tlMasterXbar.node :=* AddressOrNode(base) :=* icacheNode tlMasterXbar.node :=* AddressOrNode(base) :=* dcacheNode - tcGmemNode.foreach { n => tlMasterXbar.node := AddressOrNode(base) := n } + tcGmemNodes.foreach { n => tlMasterXbar.node := AddressOrNode(base) := n } } /* below are copied from rocket */ @@ -822,86 +835,160 @@ class RadianceTileModuleImp(outer: RadianceTile) core.io.tc_d_bits_data := DontCare core.io.tc_d_bits_tag := DontCare } + core.io.tc_tmem_A_rready := DontCare + core.io.tc_tmem_A_rdata := DontCare + core.io.tc_tmem_C_rready := DontCare + core.io.tc_tmem_C_rdata := DontCare + core.io.tc_tmem_C_wready := DontCare } def connectTensorBlackwell = { if (outer.radianceParams.core.tensorCoreBlackwell) { require(outer.tcSmemNodes.nonEmpty) + require(outer.tcSmemNodes.length == outer.numTensorCores) + require(outer.tcGmemNodes.length == outer.numTensorCores) - // TMEM C matrix: direct SRAM (no TileLink), connected via VortexCore IO + val nTC = outer.numTensorCores + val tcPorts = 3 + val tcDataBits = outer.tcSmemSize * 8 + val tmemAddrBits = 9 + val tmemDataBits = outer.numLsuLanes * 32 + val tmemMaskBits = outer.numLsuLanes * 4 + + def slice(u: UInt, width: Int, idx: Int): UInt = u(width * (idx + 1) - 1, width * idx) + def port(tc: Int, p: Int): Int = tc * tcPorts + p + + val tcAReady = Wire(Vec(nTC * tcPorts, Bool())) + val tcDValid = Wire(Vec(nTC * tcPorts, Bool())) + val tcDData = Wire(Vec(nTC * tcPorts, UInt(tcDataBits.W))) + val tcDTag = Wire(Vec(nTC * tcPorts, UInt(outer.tensorTagWidth.W))) + tcAReady.foreach(_ := false.B) + tcDValid.foreach(_ := false.B) + tcDData.foreach(_ := 0.U) + tcDTag.foreach(_ := 0.U) + + // TMEM matrix: one shared 2R1W SRAM. read0 is operand A, read1 is C. // Each warp needs 2 tiles (A + C), each tile = 32 frags × 32B = 1KB val tmemDepth = outer.numWarps * outer.tcSmemSize * 2 // numWarps × 64 rows val tmem = Module(new radiance.memory.TwoReadOneWriteSyncMem( tmemDepth, UInt((outer.tcSmemSize * 8).W))) - tmem.io.ren0 := core.io.tc_tmem_C_ren - tmem.io.raddr0 := core.io.tc_tmem_C_raddr - core.io.tc_tmem_C_rdata := tmem.io.rdata0 - tmem.io.ren1 := false.B - tmem.io.raddr1 := 0.U - tmem.io.wen := core.io.tc_tmem_C_wen - tmem.io.waddr := core.io.tc_tmem_C_waddr - tmem.io.wdata := core.io.tc_tmem_C_wdata - tmem.io.mask := core.io.tc_tmem_C_mask - // smem_B (port 2): Global Memory via TileLink - val smemBBundle = new { - val addr = core.io.tc_a_bits_address(95, 64) - val tag = core.io.tc_a_bits_tag(8 + outer.tensorTagWidth - 1, 8) - val write = core.io.tc_a_bits_write(2) - val mask = core.io.tc_a_bits_mask(95, 64) - val data = core.io.tc_a_bits_data(767, 512) - val aValid = core.io.tc_a_valid(2) - val dReady = core.io.tc_d_ready(2) + val aReadArb = Module(new RRArbiter(UInt(tmemAddrBits.W), nTC)) + val cReadArb = Module(new RRArbiter(UInt(tmemAddrBits.W), nTC)) + + class TmemWriteReq extends Bundle { + val addr = UInt(tmemAddrBits.W) + val data = UInt(tmemDataBits.W) + val mask = UInt(tmemMaskBits.W) } - val client = outer.tcSmemNodes.head.out.head - val adapter = Module(new VortexTLAdapter( - outer.smemSourceWidth, - new VortexBundleA(tagWidth = outer.tensorTagWidth, dataWidth = 32 * 8), - new VortexBundleD(tagWidth = outer.tensorTagWidth, dataWidth = 32 * 8), - client - )) - adapter.io.inReq.bits <> DontCare - adapter.io.inReq.valid := smemBBundle.aValid - adapter.io.inReq.bits.address := smemBBundle.addr - adapter.io.inReq.bits.source := smemBBundle.tag - adapter.io.inReq.bits.size := 5.U - adapter.io.inReq.bits.opcode := Mux(smemBBundle.write.asBool, TLMessages.PutFullData, TLMessages.Get) - adapter.io.inReq.bits.mask := smemBBundle.mask - adapter.io.inReq.bits.data := smemBBundle.data - adapter.io.inResp.ready := smemBBundle.dReady - client._1.a <> adapter.io.outReq - adapter.io.outResp <> client._1.d + val cWriteArb = Module(new RRArbiter(new TmemWriteReq, nTC)) - // port 0: global memory (cp/cb) - val gmemClient = outer.tcGmemNode.get.out.head - val gmemAdapter = Module(new VortexTLAdapter( - outer.dmemSourceWidth, - new VortexBundleA(tagWidth = outer.tensorTagWidth, dataWidth = 32 * 8), - new VortexBundleD(tagWidth = outer.tensorTagWidth, dataWidth = 32 * 8), - gmemClient - )) - gmemAdapter.io.inReq.bits <> DontCare - gmemAdapter.io.inReq.valid := core.io.tc_a_valid(0) - gmemAdapter.io.inReq.bits.address := core.io.tc_a_bits_address(31, 0) - gmemAdapter.io.inReq.bits.source := core.io.tc_a_bits_tag(outer.tensorTagWidth - 1, 0) - gmemAdapter.io.inReq.bits.size := 5.U - gmemAdapter.io.inReq.bits.opcode := Mux(core.io.tc_a_bits_write(0).asBool, TLMessages.PutFullData, TLMessages.Get) - gmemAdapter.io.inReq.bits.mask := core.io.tc_a_bits_mask(31, 0) - gmemAdapter.io.inReq.bits.data := core.io.tc_a_bits_data(255, 0) - gmemAdapter.io.inResp.ready := core.io.tc_d_ready(0) - gmemClient._1.a <> gmemAdapter.io.outReq - gmemAdapter.io.outResp <> gmemClient._1.d + (0 until nTC).foreach { tc => + aReadArb.io.in(tc).valid := core.io.tc_tmem_A_ren(tc) + aReadArb.io.in(tc).bits := slice(core.io.tc_tmem_A_raddr, tmemAddrBits, tc) + cReadArb.io.in(tc).valid := core.io.tc_tmem_C_ren(tc) + cReadArb.io.in(tc).bits := slice(core.io.tc_tmem_C_raddr, tmemAddrBits, tc) + cWriteArb.io.in(tc).valid := core.io.tc_tmem_C_wen(tc) + cWriteArb.io.in(tc).bits.addr := slice(core.io.tc_tmem_C_waddr, tmemAddrBits, tc) + cWriteArb.io.in(tc).bits.data := slice(core.io.tc_tmem_C_wdata, tmemDataBits, tc) + cWriteArb.io.in(tc).bits.mask := slice(core.io.tc_tmem_C_mask, tmemMaskBits, tc) + } - core.io.tc_a_ready := Cat(adapter.io.inReq.ready, 0.U(1.W), gmemAdapter.io.inReq.ready) - core.io.tc_d_valid := Cat(adapter.io.inResp.valid, 0.U(1.W), gmemAdapter.io.inResp.valid) - core.io.tc_d_bits_data := Cat(adapter.io.inResp.bits.data, 0.U((outer.tcSmemSize * 8).W), gmemAdapter.io.inResp.bits.data) - core.io.tc_d_bits_tag := Cat(adapter.io.inResp.bits.source, 0.U(outer.tensorTagWidth.W), gmemAdapter.io.inResp.bits.source) + aReadArb.io.out.ready := true.B + cReadArb.io.out.ready := true.B + cWriteArb.io.out.ready := true.B + + tmem.io.ren0 := aReadArb.io.out.fire + tmem.io.raddr0 := aReadArb.io.out.bits + tmem.io.ren1 := cReadArb.io.out.fire + tmem.io.raddr1 := cReadArb.io.out.bits + tmem.io.wen := cWriteArb.io.out.fire + tmem.io.waddr := cWriteArb.io.out.bits.addr + tmem.io.wdata := cWriteArb.io.out.bits.data + tmem.io.mask := cWriteArb.io.out.bits.mask + + val aReadGrant = RegNext(Mux(aReadArb.io.out.fire, UIntToOH(aReadArb.io.chosen, nTC), 0.U(nTC.W))) + val cReadGrant = RegNext(Mux(cReadArb.io.out.fire, UIntToOH(cReadArb.io.chosen, nTC), 0.U(nTC.W))) + core.io.tc_tmem_A_rready := VecInit(aReadArb.io.in.map(_.fire)).asUInt + core.io.tc_tmem_C_rready := VecInit(cReadArb.io.in.map(_.fire)).asUInt + core.io.tc_tmem_C_wready := VecInit(cWriteArb.io.in.map(_.fire)).asUInt + core.io.tc_tmem_A_rdata := VecInit((0 until nTC).map { tc => + Mux(aReadGrant(tc), tmem.io.rdata0, 0.U(tmemDataBits.W)) + }).asUInt + core.io.tc_tmem_C_rdata := VecInit((0 until nTC).map { tc => + Mux(cReadGrant(tc), tmem.io.rdata1, 0.U(tmemDataBits.W)) + }).asUInt + + // port 2: SMEM B, one TL client per tensor core. RadianceSharedMem arbitrates them. + (0 until nTC).foreach { tc => + val p2 = port(tc, 2) + val client = outer.tcSmemNodes(tc).out.head + val adapter = Module(new VortexTLAdapter( + outer.smemSourceWidth, + new VortexBundleA(tagWidth = outer.tensorTagWidth, dataWidth = tcDataBits), + new VortexBundleD(tagWidth = outer.tensorTagWidth, dataWidth = tcDataBits), + client + )) + adapter.io.inReq.bits <> DontCare + adapter.io.inReq.valid := core.io.tc_a_valid(p2) + adapter.io.inReq.bits.address := slice(core.io.tc_a_bits_address, 32, p2) + adapter.io.inReq.bits.source := slice(core.io.tc_a_bits_tag, outer.tensorTagWidth, p2) + adapter.io.inReq.bits.size := 5.U + adapter.io.inReq.bits.opcode := Mux(core.io.tc_a_bits_write(p2).asBool, TLMessages.PutFullData, TLMessages.Get) + adapter.io.inReq.bits.mask := slice(core.io.tc_a_bits_mask, 32, p2) + adapter.io.inReq.bits.data := slice(core.io.tc_a_bits_data, tcDataBits, p2) + adapter.io.inResp.ready := core.io.tc_d_ready(p2) + client._1.a <> adapter.io.outReq + adapter.io.outResp <> client._1.d + + tcAReady(p2) := adapter.io.inReq.ready + tcDValid(p2) := adapter.io.inResp.valid + tcDData(p2) := adapter.io.inResp.bits.data + tcDTag(p2) := adapter.io.inResp.bits.source + } + + // port 0: global memory (cp/cb), one TL client per tensor core. + (0 until nTC).foreach { tc => + val p0 = port(tc, 0) + val gmemClient = outer.tcGmemNodes(tc).out.head + val gmemAdapter = Module(new VortexTLAdapter( + outer.dmemSourceWidth, + new VortexBundleA(tagWidth = outer.tensorTagWidth, dataWidth = tcDataBits), + new VortexBundleD(tagWidth = outer.tensorTagWidth, dataWidth = tcDataBits), + gmemClient + )) + gmemAdapter.io.inReq.bits <> DontCare + gmemAdapter.io.inReq.valid := core.io.tc_a_valid(p0) + gmemAdapter.io.inReq.bits.address := slice(core.io.tc_a_bits_address, 32, p0) + gmemAdapter.io.inReq.bits.source := slice(core.io.tc_a_bits_tag, outer.tensorTagWidth, p0) + gmemAdapter.io.inReq.bits.size := 5.U + gmemAdapter.io.inReq.bits.opcode := Mux(core.io.tc_a_bits_write(p0).asBool, TLMessages.PutFullData, TLMessages.Get) + gmemAdapter.io.inReq.bits.mask := slice(core.io.tc_a_bits_mask, 32, p0) + gmemAdapter.io.inReq.bits.data := slice(core.io.tc_a_bits_data, tcDataBits, p0) + gmemAdapter.io.inResp.ready := core.io.tc_d_ready(p0) + gmemClient._1.a <> gmemAdapter.io.outReq + gmemAdapter.io.outResp <> gmemClient._1.d + + tcAReady(p0) := gmemAdapter.io.inReq.ready + tcDValid(p0) := gmemAdapter.io.inResp.valid + tcDData(p0) := gmemAdapter.io.inResp.bits.data + tcDTag(p0) := gmemAdapter.io.inResp.bits.source + } + + core.io.tc_a_ready := tcAReady.asUInt + core.io.tc_d_valid := tcDValid.asUInt + core.io.tc_d_bits_data := tcDData.asUInt + core.io.tc_d_bits_tag := tcDTag.asUInt } else { core.io.tc_a_ready := false.B core.io.tc_d_valid := false.B core.io.tc_d_bits_data := DontCare core.io.tc_d_bits_tag := DontCare + core.io.tc_tmem_A_rready := DontCare + core.io.tc_tmem_A_rdata := DontCare + core.io.tc_tmem_C_rready := DontCare core.io.tc_tmem_C_rdata := DontCare + core.io.tc_tmem_C_wready := DontCare } } @@ -1006,7 +1093,11 @@ class RadianceTileModuleImp(outer: RadianceTile) tensor.io.reqA.ready := false.B tensor.io.reqB.ready := false.B tensor.io.writeback.ready := false.B - tensor.io.tmemC.rdata := DontCare + tensor.io.tmemC.aRready := false.B + tensor.io.tmemC.aRdata := DontCare + tensor.io.tmemC.cRready := false.B + tensor.io.tmemC.cRdata := DontCare + tensor.io.tmemC.cWready := false.B dontTouch(tensor.io) } else { if (outer.radianceParams.core.tensorCoreFP16) { diff --git a/src/main/scala/radiance/tile/VortexCore.scala b/src/main/scala/radiance/tile/VortexCore.scala index 91b65b3..803322c 100644 --- a/src/main/scala/radiance/tile/VortexCore.scala +++ b/src/main/scala/radiance/tile/VortexCore.scala @@ -90,28 +90,36 @@ class VortexBundle(tile: RadianceTile)(implicit p: Parameters) extends CoreBundl val smem_d_bits_data = Input(UInt((tile.numLsuLanes * 32).W)) val smem_d_ready = Output(UInt((tile.numLsuLanes * 1).W)) + val numTensorCores = if (tile.radianceParams.core.tensorCoreBlackwell) tile.numTensorCores else 1 val tcPortCount = 3 - val tc_a_valid = Output(UInt(tcPortCount.W)) - val tc_a_bits_write = Output(UInt(tcPortCount.W)) - val tc_a_bits_address = Output(UInt((tcPortCount * 32).W)) - val tc_a_bits_tag = Output(UInt((tcPortCount * 4).W)) - val tc_a_bits_mask = Output(UInt((tcPortCount * 32).W)) - val tc_a_bits_data = Output(UInt((tcPortCount * 32 * 8).W)) - val tc_a_ready = Input(UInt(tcPortCount.W)) - val tc_d_valid = Input(UInt(tcPortCount.W)) - val tc_d_bits_data = Input(UInt((tcPortCount * 32 * 8).W)) - val tc_d_bits_tag = Input(UInt((tcPortCount * 4).W)) - val tc_d_ready = Output(UInt(tcPortCount.W)) + val tcFlatPortCount = tcPortCount * numTensorCores + val tc_a_valid = Output(UInt(tcFlatPortCount.W)) + val tc_a_bits_write = Output(UInt(tcFlatPortCount.W)) + val tc_a_bits_address = Output(UInt((tcFlatPortCount * 32).W)) + val tc_a_bits_tag = Output(UInt((tcFlatPortCount * 4).W)) + val tc_a_bits_mask = Output(UInt((tcFlatPortCount * 32).W)) + val tc_a_bits_data = Output(UInt((tcFlatPortCount * 32 * 8).W)) + val tc_a_ready = Input(UInt(tcFlatPortCount.W)) + val tc_d_valid = Input(UInt(tcFlatPortCount.W)) + val tc_d_bits_data = Input(UInt((tcFlatPortCount * 32 * 8).W)) + val tc_d_bits_tag = Input(UInt((tcFlatPortCount * 4).W)) + val tc_d_ready = Output(UInt(tcFlatPortCount.W)) - // Direct SRAM port for TMEM C (bypasses TileLink) + // Direct SRAM ports for shared TMEM (bypasses TileLink) val numLanes = tile.numLsuLanes - val tc_tmem_C_wen = Output(Bool()) - val tc_tmem_C_ren = Output(Bool()) - val tc_tmem_C_waddr = Output(UInt(9.W)) - val tc_tmem_C_raddr = Output(UInt(9.W)) - val tc_tmem_C_wdata = Output(UInt((numLanes * 32).W)) - val tc_tmem_C_mask = Output(UInt((numLanes * 4).W)) - val tc_tmem_C_rdata = Input(UInt((numLanes * 32).W)) + val tc_tmem_A_ren = Output(UInt(numTensorCores.W)) + val tc_tmem_A_rready = Input(UInt(numTensorCores.W)) + val tc_tmem_A_raddr = Output(UInt((numTensorCores * 9).W)) + val tc_tmem_A_rdata = Input(UInt((numTensorCores * numLanes * 32).W)) + val tc_tmem_C_ren = Output(UInt(numTensorCores.W)) + val tc_tmem_C_rready = Input(UInt(numTensorCores.W)) + val tc_tmem_C_raddr = Output(UInt((numTensorCores * 9).W)) + val tc_tmem_C_rdata = Input(UInt((numTensorCores * numLanes * 32).W)) + val tc_tmem_C_wen = Output(UInt(numTensorCores.W)) + val tc_tmem_C_wready = Input(UInt(numTensorCores.W)) + val tc_tmem_C_waddr = Output(UInt((numTensorCores * 9).W)) + val tc_tmem_C_wdata = Output(UInt((numTensorCores * numLanes * 32).W)) + val tc_tmem_C_mask = Output(UInt((numTensorCores * numLanes * 4).W)) // FIXME: hardcoded val barrierIdBits = tile.barrierMasterNode.out(0)._2.barrierIdBits @@ -147,7 +155,8 @@ class Vortex(tile: RadianceTile)(implicit p: Parameters) "CORE_ID" -> tile.radianceParams.coreId, "TENSOR_FP16" -> (if (tile.radianceParams.core.tensorCoreFP16) 1 else 0), "STARTUP_ADDR" -> tile.radianceParams.core.startupAddress, - "NUM_THREADS" -> tile.numLsuLanes + "NUM_THREADS" -> tile.numLsuLanes, + "NUM_TENSOR_CORES" -> (if (tile.radianceParams.core.tensorCoreBlackwell) tile.numTensorCores else 1) ) ) with HasBlackBoxResource with HasBlackBoxPath { @@ -211,6 +220,7 @@ class Vortex(tile: RadianceTile)(implicit p: Parameters) addResource("/vsrc/vortex/hw/rtl/core/VX_scoreboard.sv") addResource("/vsrc/vortex/hw/rtl/core/VX_sfu_unit.sv") addResource("/vsrc/vortex/hw/rtl/core/VX_smem_unit.sv") + addResource("/vsrc/vortex/hw/rtl/core/VX_tensor_ctrl_unit.sv") addResource("/vsrc/vortex/hw/rtl/core/VX_split_join.sv") addResource("/vsrc/vortex/hw/rtl/core/VX_trace.vh") addResource("/vsrc/vortex/hw/rtl/core/VX_wctl_unit.sv")