From 136cf70a5800fd466758c4f22285668ef50860ed Mon Sep 17 00:00:00 2001
From: abnerhexu
Date: Sat, 25 Apr 2026 10:15:31 +0800
Subject: [PATCH] Add Blackwell tensor core baseline plumbing

---
 radiance.mk                                   |   3 +
 src/main/resources/vsrc/vortex                |   2 +-
 .../radiance/core/TensorCoreBlackwell.scala   | 266 ++++++++++++++++++
 .../scala/radiance/subsystem/Configs.scala    |  21 +-
 .../scala/radiance/tile/RadianceTile.scala    | 142 ++++++++++-
 src/main/scala/radiance/tile/VortexCore.scala |   7 +
 .../radiance/TensorCoreBlackwellTest.scala    | 160 ++++++++++++
 7 files changed, 594 insertions(+), 7 deletions(-)
 create mode 100644 src/main/scala/radiance/core/TensorCoreBlackwell.scala
 create mode 100644 src/test/scala/radiance/TensorCoreBlackwellTest.scala

Review notes (ignored by git-am):
- Line breaks and indentation were reconstructed from a whitespace-collapsed
  copy of this patch; apply with --ignore-whitespace if context does not match.
  The blob ids on the "index" lines predate the fixes listed below.
- TensorCoreBlackwell: BWGMMA no longer re-issues a request that one port has
  already accepted while the other port stalls; responses are accepted once
  per port.
- VortexCore: the Blackwell resource is registered in its own `if` so the
  remainder of the hopper block stays under tensorCoreDecoupled.
- RadianceTile: only one TMEM client node is created, matching the single
  node that connectTensorBlackwell drives.

diff --git a/radiance.mk b/radiance.mk
index 82e9a6f..bd84102 100644
--- a/radiance.mk
+++ b/radiance.mk
@@ -23,6 +23,9 @@ endif
 ifeq ($(shell echo $(CONFIG) | grep -E "HopperConfig$$"),$(CONFIG))
     EXTRA_SIM_PREPROC_DEFINES += +define+NUM_CORES=4 +define+EXT_T_HOPPER
 endif
+ifeq ($(shell echo $(CONFIG) | grep -E "BlackwellConfig$$"),$(CONFIG))
+    EXTRA_SIM_PREPROC_DEFINES += +define+NUM_CORES=4 +define+EXT_T_BLACKWELL
+endif
 ifeq ($(shell echo $(CONFIG) | grep -E "FlashConfig$$"),$(CONFIG))
     EXTRA_SIM_PREPROC_DEFINES += +define+NUM_CORES=4
 endif
diff --git a/src/main/resources/vsrc/vortex b/src/main/resources/vsrc/vortex
index f1d0fac..cb912d3 160000
--- a/src/main/resources/vsrc/vortex
+++ b/src/main/resources/vsrc/vortex
@@ -1 +1 @@
-Subproject commit f1d0fac51869eb10b218fead7607459f96ec99b6
+Subproject commit cb912d3b8b689683f0a283039aa4c1633cddd2f3
diff --git a/src/main/scala/radiance/core/TensorCoreBlackwell.scala b/src/main/scala/radiance/core/TensorCoreBlackwell.scala
new file mode 100644
index 0000000..6b57361
--- /dev/null
+++ b/src/main/scala/radiance/core/TensorCoreBlackwell.scala
@@ -0,0 +1,266 @@
+// See LICENSE.SiFive for license details.
+// See LICENSE.Berkeley for license details.
+
+package radiance.core
+
+import chisel3._
+import chisel3.util._
+
+/** Baseline control FSM for the Blackwell-style tensor core.
+ *
+ * Commands arrive on `initiate`. BWGMMA reads one beat from each of ports A and
+ * B, TCGEN05_CP copies one beat from port B to port A, TCGEN05_LD / TCGEN05_ST
+ * read / write one beat on port A, and the two WAIT ops return to idle at once.
+ * BWGMMA and TCGEN05_LD finish with a single `writeback` beat; the BWGMMA
+ * datapath is a placeholder that adds the A, B and C words lane by lane.
+ *
+ * @param numWarps     number of warps; sizes the `wid` fields
+ * @param numLanes     number of 32-bit lanes per memory beat and per writeback
+ * @param half         currently unused by this module
+ * @param numSourceIds number of request source IDs; sizes the `source` fields
+ * @param numFPRegs    number of FP registers; sizes the `rd` fields
+ */
+class TensorCoreBlackwell(
+  val numWarps: Int,
+  val numLanes: Int,
+  val half: Boolean,
+  val numSourceIds: Int = 16,
+  val numFPRegs: Int = 32
+) extends Module {
+  val numWarpBits = log2Ceil(numWarps)
+  val sourceWidth = log2Ceil(numSourceIds)
+  val laneWidth = 4 * 8
+  val memWidth = numLanes * laneWidth
+  val numFPRegBits = log2Ceil(numFPRegs)
+  val addressWidth = 32
+  val maskWidth = memWidth / 8
+
+  // command encodings carried in io.initiate.bits.op
+  object Ops {
+    val bwgmma :: bwgmmaWait :: tcgen05Cp :: tcgen05CpWait :: tcgen05Ld :: tcgen05St :: Nil = Enum(6)
+  }
+
+  class TensorMemReq(
+    sourceWidth: Int,
+    dataWidth: Int
+  ) extends Bundle {
+    val rw = Bool()
+    val byteen = UInt((dataWidth / 8).W)
+    val source = UInt(sourceWidth.W)
+    val address = UInt(addressWidth.W)
+    val data = UInt(dataWidth.W)
+  }
+
+  class TensorMemResp(
+    sourceWidth: Int,
+    dataWidth: Int
+  ) extends Bundle {
+    val source = UInt(sourceWidth.W)
+    val data = UInt(dataWidth.W)
+  }
+
+  val io = IO(new Bundle {
+    val initiate = Flipped(Decoupled(new Bundle {
+      val op = UInt(3.W)
+      val wid = UInt(numWarpBits.W)
+      val rd = UInt(numFPRegBits.W)
+      val addressA = UInt(addressWidth.W)
+      val addressB = UInt(addressWidth.W)
+    }))
+    val writeback = Decoupled(new Bundle {
+      val last = Bool()
+      val wid = UInt(numWarpBits.W)
+      val rd = UInt(numFPRegBits.W)
+      val data = Vec(numLanes, UInt(laneWidth.W))
+    })
+    val respA = Flipped(Decoupled(new TensorMemResp(sourceWidth, memWidth)))
+    val respB = Flipped(Decoupled(new TensorMemResp(sourceWidth, memWidth)))
+    val respC = Input(UInt(memWidth.W))
+    val reqA = Decoupled(new TensorMemReq(sourceWidth, memWidth))
+    val reqB = Decoupled(new TensorMemReq(sourceWidth, memWidth))
+    val reqC = Output(Valid(UInt(numFPRegBits.W)))
+  })
+
+  object State extends ChiselEnum {
+    val idle, bwReq, bwResp, cpRead, cpWrite, ldReq, stReq, waitWb = Value
+  }
+  val state = RegInit(State.idle)
+
+  val opReg = RegInit(0.U(3.W))
+  val widReg = RegInit(0.U(numWarpBits.W))
+  val rdReg = RegInit(0.U(numFPRegBits.W))
+  val addrAReg = RegInit(0.U(addressWidth.W))
+  val addrBReg = RegInit(0.U(addressWidth.W))
+  val aDataReg = Reg(UInt(memWidth.W))
+  val bDataReg = Reg(UInt(memWidth.W))
+  val haveA = RegInit(false.B)
+  val haveB = RegInit(false.B)
+  // per-port "request already accepted" flags for the two-port BWGMMA issue
+  val sentA = RegInit(false.B)
+  val sentB = RegInit(false.B)
+  val sourceCounter = RegInit(0.U(sourceWidth.W))
+
+  private def bumpSource(): Unit = {
+    sourceCounter := sourceCounter + 1.U
+  }
+
+  val reqA = Wire(Decoupled(new TensorMemReq(sourceWidth, memWidth)))
+  val reqB = Wire(Decoupled(new TensorMemReq(sourceWidth, memWidth)))
+  reqA.valid := false.B
+  reqA.bits := 0.U.asTypeOf(reqA.bits)
+  reqB.valid := false.B
+  reqB.bits := 0.U.asTypeOf(reqB.bits)
+  io.reqA <> reqA
+  io.reqB <> reqB
+
+  val wbValid = RegInit(false.B)
+  val wbData = Reg(Vec(numLanes, UInt(laneWidth.W)))
+  io.writeback.valid := wbValid
+  io.writeback.bits.last := true.B
+  io.writeback.bits.wid := widReg
+  io.writeback.bits.rd := rdReg
+  io.writeback.bits.data := wbData
+
+  io.reqC.valid := false.B
+  io.reqC.bits := rdReg
+
+  io.respA.ready := false.B
+  io.respB.ready := false.B
+  // a new command is only taken once the previous writeback has drained
+  io.initiate.ready := state === State.idle && !wbValid
+
+  when(io.writeback.fire) {
+    wbValid := false.B
+  }
+
+  when(io.initiate.fire) {
+    opReg := io.initiate.bits.op
+    widReg := io.initiate.bits.wid
+    rdReg := io.initiate.bits.rd
+    addrAReg := io.initiate.bits.addressA
+    addrBReg := io.initiate.bits.addressB
+    haveA := false.B
+    haveB := false.B
+    sentA := false.B
+    sentB := false.B
+    switch(io.initiate.bits.op) {
+      is(Ops.bwgmma) { state := State.bwReq }
+      is(Ops.tcgen05Cp) { state := State.cpRead }
+      is(Ops.tcgen05Ld) { state := State.ldReq }
+      is(Ops.tcgen05St) { state := State.stReq }
+      is(Ops.bwgmmaWait) { state := State.idle }
+      is(Ops.tcgen05CpWait) { state := State.idle }
+    }
+  }
+
+  when(state === State.bwReq) {
+    // each port stops requesting once its request has been accepted, so a
+    // stalled port cannot cause the other port's request to be issued twice
+    reqA.valid := !sentA
+    reqA.bits.rw := false.B
+    reqA.bits.byteen := Fill(maskWidth, 1.U(1.W))
+    reqA.bits.address := addrAReg
+    reqA.bits.source := sourceCounter
+
+    reqB.valid := !sentB
+    reqB.bits.rw := false.B
+    reqB.bits.byteen := Fill(maskWidth, 1.U(1.W))
+    reqB.bits.address := addrBReg
+    reqB.bits.source := sourceCounter
+
+    io.reqC.valid := true.B
+    when(reqA.fire) { sentA := true.B }
+    when(reqB.fire) { sentB := true.B }
+    when((sentA || reqA.fire) && (sentB || reqB.fire)) {
+      bumpSource()
+      state := State.bwResp
+    }
+  }
+
+  when(state === State.bwResp) {
+    // accept exactly one response per port; a captured operand is not overwritten
+    io.respA.ready := !haveA
+    io.respB.ready := !haveB
+    when(io.respA.fire) {
+      aDataReg := io.respA.bits.data
+      haveA := true.B
+    }
+    when(io.respB.fire) {
+      bDataReg := io.respB.bits.data
+      haveB := true.B
+    }
+    when(haveA && haveB) {
+      val cWords = io.respC.asTypeOf(Vec(numLanes, UInt(laneWidth.W)))
+      val aWords = aDataReg.asTypeOf(Vec(numLanes, UInt(laneWidth.W)))
+      val bWords = bDataReg.asTypeOf(Vec(numLanes, UInt(laneWidth.W)))
+      // placeholder datapath: lane-wise A + B + C
+      for (i <- 0 until numLanes) {
+        wbData(i) := aWords(i) + bWords(i) + cWords(i)
+      }
+      wbValid := true.B
+      state := State.idle
+    }
+  }
+
+  when(state === State.cpRead) {
+    reqB.valid := true.B
+    reqB.bits.rw := false.B
+    reqB.bits.byteen := Fill(maskWidth, 1.U(1.W))
+    reqB.bits.address := addrBReg
+    reqB.bits.source := sourceCounter
+    when(reqB.fire) {
+      bumpSource()
+      state := State.cpWrite
+    }
+  }
+
+  when(state === State.cpWrite) {
+    // forward the B response straight into a write on port A
+    io.respB.ready := reqA.ready
+    reqA.valid := io.respB.valid
+    reqA.bits.rw := true.B
+    reqA.bits.byteen := Fill(maskWidth, 1.U(1.W))
+    reqA.bits.address := addrAReg
+    reqA.bits.source := sourceCounter
+    reqA.bits.data := io.respB.bits.data
+    when(reqA.fire) {
+      bumpSource()
+      state := State.idle
+    }
+  }
+
+  when(state === State.ldReq) {
+    reqA.valid := true.B
+    reqA.bits.rw := false.B
+    reqA.bits.byteen := Fill(maskWidth, 1.U(1.W))
+    reqA.bits.address := addrAReg
+    reqA.bits.source := sourceCounter
+    when(reqA.fire) {
+      bumpSource()
+      state := State.waitWb
+    }
+  }
+
+  when(state === State.waitWb && opReg === Ops.tcgen05Ld) {
+    io.respA.ready := !wbValid
+    when(io.respA.fire) {
+      wbData := io.respA.bits.data.asTypeOf(Vec(numLanes, UInt(laneWidth.W)))
+      wbValid := true.B
+      state := State.idle
+    }
+  }
+
+  when(state === State.stReq) {
+    io.reqC.valid := true.B
+    reqA.valid := true.B
+    reqA.bits.rw := true.B
+    reqA.bits.byteen := Fill(maskWidth, 1.U(1.W))
+    reqA.bits.address := addrAReg
+    reqA.bits.source := sourceCounter
+    reqA.bits.data := io.respC
+    when(reqA.fire) {
+      bumpSource()
+      state := State.idle
+    }
+  }
+}
diff --git a/src/main/scala/radiance/subsystem/Configs.scala b/src/main/scala/radiance/subsystem/Configs.scala
index c2cdb18..ab60748 100644
--- a/src/main/scala/radiance/subsystem/Configs.scala
+++ b/src/main/scala/radiance/subsystem/Configs.scala
@@ -50,6 +50,7 @@ class WithRadianceCores(
   crossing: RocketCrossingParams,
   tensorCoreFP16: Boolean,
   tensorCoreDecoupled: Boolean,
+  tensorCoreBlackwell: Boolean,
   useVxCache: Boolean
 ) extends Config((site, _, up) => {
   case TilesLocated(`location`) => {
@@ -59,7 +60,8 @@ class WithRadianceCores(
       val vortex = RadianceTileParams(
         core = VortexCoreParams(
           tensorCoreFP16 = tensorCoreFP16,
-          tensorCoreDecoupled = tensorCoreDecoupled
+          tensorCoreDecoupled = tensorCoreDecoupled,
+          tensorCoreBlackwell = tensorCoreBlackwell
         ),
         btb = None,
         useVxCache = useVxCache,
@@ -96,6 +98,7 @@ class WithRadianceCores(
   // constructor override that omits `crossing`
   def this(n: Int, location: HierarchicalLocation = InSubsystem,
     tensorCoreFP16: Boolean = false, tensorCoreDecoupled: Boolean = false,
+    tensorCoreBlackwell: Boolean = false,
     useVxCache: Boolean = false)
   = this(n, location, RocketCrossingParams(
     master = HierarchicalElementMasterPortParams.locationDefault(location),
@@ -104,9 +107,23 @@ class WithRadianceCores(
       case InSubsystem => CBUS
       case InCluster(clusterId) => CCBUS(clusterId)
     }
-  ), tensorCoreFP16, tensorCoreDecoupled, useVxCache)
+  ), tensorCoreFP16, tensorCoreDecoupled, tensorCoreBlackwell, useVxCache)
 }
 
+class WithBlackwellTensorCore(location: HierarchicalLocation = InSubsystem) extends Config((site, _, up) => {
+  case TilesLocated(`location`) =>
+    up(TilesLocated(`location`)).map {
+      case r: RadianceTileAttachParams =>
+        r.copy(tileParams = r.tileParams.copy(
+          core = r.tileParams.core.copy(
+            tensorCoreBlackwell = true,
+            tensorCoreDecoupled = false
+          )
+        ))
+      case other => other
+    }
+})
+
 class WithEmulatorCores(
   n: Int,
   useVxCache: Boolean
diff --git a/src/main/scala/radiance/tile/RadianceTile.scala b/src/main/scala/radiance/tile/RadianceTile.scala
index f4e4165..d16084d 100644
--- a/src/main/scala/radiance/tile/RadianceTile.scala
+++ b/src/main/scala/radiance/tile/RadianceTile.scala
@@ -101,6 +101,7 @@ case class VortexCoreParams(
   fpu: Option[FPUParams] = None,
   tensorCoreFP16: Boolean = false, // FP16 if true, FP32 if false
   tensorCoreDecoupled: Boolean = false, // hopper-style SMEM operand decoupling
+  tensorCoreBlackwell: Boolean = false, // blackwell-style TMEM + SMEM tensor core
   debugROB: Boolean = false, // if enabled, uses a C++ debug ROB to generate trace-with-wdata
   haveCease: Boolean = true, // non-standard CEASE instruction
   haveSimTimeout: Boolean = true // add plusarg for simulation timeout
@@ -152,6 +153,10 @@ class RadianceTile private (
     p(SIMTCoreKey).isDefined,
     "SIMTCoreKey not defined; make sure to use WithSimtConfig when using RadianceTile"
   )
+  require(
+    !(radianceParams.core.tensorCoreDecoupled && radianceParams.core.tensorCoreBlackwell),
+    "tensorCoreDecoupled and tensorCoreBlackwell are mutually exclusive"
+  )
 
   // NOTE: when changing these, remember to change +define+NUM_CORES/THREADS/WARPS in
   // radiance.mk as well!
@@ -280,7 +285,9 @@ class RadianceTile private (
   }
 
   val tcSmemSize = 32
-  val tcSmemNodes = Seq.tabulate(if (radianceParams.core.tensorCoreDecoupled) 2 else 0) { i =>
+  val tensorUsesAsyncMem = radianceParams.core.tensorCoreDecoupled || radianceParams.core.tensorCoreBlackwell
+  val tcSmemNodeCount = if (radianceParams.core.tensorCoreDecoupled) 2 else if (radianceParams.core.tensorCoreBlackwell) 1 else 0
+  val tcSmemNodes = Seq.tabulate(tcSmemNodeCount) { i =>
     TLClientNode(Seq(TLMasterPortParameters.v2(
       masters = Seq(TLMasterParameters.v2(
         name = s"rad_tc_${radianceParams.coreId}_$i",
@@ -294,6 +301,43 @@ class RadianceTile private (
     )))
   }
 
+  // one TMEM client only: connectTensorBlackwell drives just tmemNodes.head
+  val tmemNodes = Seq.tabulate(if (radianceParams.core.tensorCoreBlackwell) 1 else 0) { i =>
+    TLClientNode(Seq(TLMasterPortParameters.v2(
+      masters = Seq(TLMasterParameters.v2(
+        name = s"rad_tmem_${radianceParams.coreId}_$i",
+        sourceId = IdRange(0, 1 << smemSourceWidth),
+        supports = TLSlaveToMasterTransferSizes(
+          probe = TransferSizes(1, tcSmemSize),
+          get = TransferSizes(1, tcSmemSize),
+          putFull = TransferSizes(1, tcSmemSize),
+          putPartial = TransferSizes(1, tcSmemSize),
+        ),
+        requestFifo = true
+      ))
+    )))
+  }
+
+  val tmemNode = if (radianceParams.core.tensorCoreBlackwell) {
+    Some(LazyModule(new TLRAM(
+      address = AddressSet(0x0, 0x3fff),
+      beatBytes = tcSmemSize
+    )))
+  } else {
+    None
+  }
+  val tmemXbar = if (radianceParams.core.tensorCoreBlackwell) {
+    Some(LazyModule(new TLXbar))
+  } else {
+    None
+  }
+  (tmemNode, tmemXbar) match {
+    case (Some(tmem), Some(xbar)) =>
+      tmem.node :=* xbar.node
+      tmemNodes.foreach { node => xbar.node :=* node }
+    case _ =>
+  }
+
   // combine outgoing per-lane dmemNode into 1 idenity node
   //
   // NOTE: We need TLWidthWidget here because there might be a data width
@@ -743,12 +787,18 @@ class RadianceTileModuleImp(outer: RadianceTile)
       val tcb0 = new {
         val addr = core.io.tc_a_bits_address(31, 0)
         val tag = core.io.tc_a_bits_tag(outer.tensorTagWidth - 1, 0)
+        val write = core.io.tc_a_bits_write(0)
+        val mask = core.io.tc_a_bits_mask(31, 0)
+        val data = core.io.tc_a_bits_data(255, 0)
         val aValid = core.io.tc_a_valid(0)
         val dReady = core.io.tc_d_ready(0)
       }
       val tcb1 = new {
         val addr = core.io.tc_a_bits_address(63, 32)
         val tag = core.io.tc_a_bits_tag(4 + outer.tensorTagWidth - 1, 4)
+        val write = core.io.tc_a_bits_write(1)
+        val mask = core.io.tc_a_bits_mask(63, 32)
+        val data = core.io.tc_a_bits_data(511, 256)
         val aValid = core.io.tc_a_valid(1)
         val dReady = core.io.tc_d_ready(1)
       }
@@ -770,8 +820,9 @@ class RadianceTileModuleImp(outer: RadianceTile)
         adapter.io.inReq.bits.address := bundle.addr
         adapter.io.inReq.bits.source := bundle.tag
         adapter.io.inReq.bits.size := 5.U // 256 bits
-        adapter.io.inReq.bits.opcode := TLMessages.Get
-        adapter.io.inReq.bits.mask := x"ffffffff".U
+        adapter.io.inReq.bits.opcode := Mux(bundle.write.asBool, TLMessages.PutFullData, TLMessages.Get)
+        adapter.io.inReq.bits.mask := bundle.mask
+        adapter.io.inReq.bits.data := bundle.data
         adapter.io.inResp.ready := bundle.dReady
 
         client._1.a <> adapter.io.outReq
@@ -792,6 +843,71 @@ class RadianceTileModuleImp(outer: RadianceTile)
     }
   }
 
+  def connectTensorBlackwell = {
+    if (outer.radianceParams.core.tensorCoreBlackwell) {
+      require(outer.tmemNodes.nonEmpty)
+      require(outer.tcSmemNodes.nonEmpty)
+
+      val bundles = Seq(
+        (outer.tmemNodes.head, new {
+          val addr = core.io.tc_a_bits_address(31, 0)
+          val tag = core.io.tc_a_bits_tag(outer.tensorTagWidth - 1, 0)
+          val write = core.io.tc_a_bits_write(0)
+          val mask = core.io.tc_a_bits_mask(31, 0)
+          val data = core.io.tc_a_bits_data(255, 0)
+          val aValid = core.io.tc_a_valid(0)
+          val dReady = core.io.tc_d_ready(0)
+        }),
+        (outer.tcSmemNodes.head, new {
+          val addr = core.io.tc_a_bits_address(63, 32)
+          val tag = core.io.tc_a_bits_tag(4 + outer.tensorTagWidth - 1, 4)
+          val write = core.io.tc_a_bits_write(1)
+          val mask = core.io.tc_a_bits_mask(63, 32)
+          val data = core.io.tc_a_bits_data(511, 256)
+          val aValid = core.io.tc_a_valid(1)
+          val dReady = core.io.tc_d_ready(1)
+        })
+      )
+
+      val adapters = bundles.map { case (node, bundle) =>
+        val client = node.out.head
+        val adapter = Module(
+          new VortexTLAdapter(
+            outer.smemSourceWidth,
+            new VortexBundleA(tagWidth = outer.tensorTagWidth, dataWidth = 32 * 8),
+            new VortexBundleD(tagWidth = outer.tensorTagWidth, dataWidth = 32 * 8),
+            client
+          )
+        )
+        require(adapter.io.inReq.bits.source.widthOption.get == bundle.tag.widthOption.get)
+        require(adapter.io.inReq.bits.address.widthOption.get == bundle.addr.widthOption.get)
+        adapter.io.inReq.bits <> DontCare
+        adapter.io.inReq.valid := bundle.aValid
+        adapter.io.inReq.bits.address := bundle.addr
+        adapter.io.inReq.bits.source := bundle.tag
+        adapter.io.inReq.bits.size := 5.U
+        adapter.io.inReq.bits.opcode := Mux(bundle.write.asBool, TLMessages.PutFullData, TLMessages.Get)
+        adapter.io.inReq.bits.mask := bundle.mask
+        adapter.io.inReq.bits.data := bundle.data
+        adapter.io.inResp.ready := bundle.dReady
+
+        client._1.a <> adapter.io.outReq
+        adapter.io.outResp <> client._1.d
+        adapter
+      }
+
+      core.io.tc_a_ready := Cat(adapters.last.io.inReq.ready, adapters.head.io.inReq.ready)
+      core.io.tc_d_valid := Cat(adapters.last.io.inResp.valid, adapters.head.io.inResp.valid)
+      core.io.tc_d_bits_data := Cat(adapters.last.io.inResp.bits.data, adapters.head.io.inResp.bits.data)
+      core.io.tc_d_bits_tag := Cat(adapters.last.io.inResp.bits.source, adapters.head.io.inResp.bits.source)
+    } else {
+      core.io.tc_a_ready := false.B
+      core.io.tc_d_valid := false.B
+      core.io.tc_d_bits_data := DontCare
+      core.io.tc_d_bits_tag := DontCare
+    }
+  }
+
   def connectBarrier = {
     require(outer.barrierMasterNode.out.length == 1)
     // FIXME: bits not flattened
@@ -847,7 +963,11 @@ class RadianceTileModuleImp(outer: RadianceTile)
   connectImem
   connectDmem
   connectSmem
-  connectTensor
+  if (outer.radianceParams.core.tensorCoreBlackwell) {
+    connectTensorBlackwell
+  } else {
+    connectTensor
+  }
   connectBarrier
   connectAccelerator
 }
@@ -874,6 +994,20 @@ class RadianceTileModuleImp(outer: RadianceTile)
     tensor.io.reqA.ready := false.B
     tensor.io.reqB.ready := false.B
     tensor.io.writeback.ready := false.B
+  } else if (outer.radianceParams.core.tensorCoreBlackwell) {
+    val tensorNumSourceIds = (1 << outer.tensorTagWidth)
+    val tensor = Module(new radiance.core.TensorCoreBlackwell(
+      8, 8, half = true, tensorNumSourceIds))
+    tensor.io.initiate.valid := false.B
+    tensor.io.initiate.bits := DontCare
+    tensor.io.respA.valid := false.B
+    tensor.io.respA.bits := DontCare
+    tensor.io.respB.valid := false.B
+    tensor.io.respB.bits := DontCare
+    tensor.io.respC := DontCare
+    tensor.io.reqA.ready := false.B
+    tensor.io.reqB.ready := false.B
+    tensor.io.writeback.ready := false.B
   } else {
     if (outer.radianceParams.core.tensorCoreFP16) {
       val dpu = Module(new radiance.core.TensorDotProductUnit(4, half = true))
diff --git a/src/main/scala/radiance/tile/VortexCore.scala b/src/main/scala/radiance/tile/VortexCore.scala
index d078b2c..9f89527 100644
--- a/src/main/scala/radiance/tile/VortexCore.scala
+++ b/src/main/scala/radiance/tile/VortexCore.scala
@@ -91,8 +91,11 @@ class VortexBundle(tile: RadianceTile)(implicit p: Parameters) extends CoreBundl
   val smem_d_ready = Output(UInt((tile.numLsuLanes * 1).W))
 
   val tc_a_valid = Output(UInt(2.W))
+  val tc_a_bits_write = Output(UInt(2.W))
   val tc_a_bits_address = Output(UInt((2 * 32).W))
   val tc_a_bits_tag = Output(UInt((2 * 4).W))
+  val tc_a_bits_mask = Output(UInt((2 * 32).W))
+  val tc_a_bits_data = Output(UInt((2 * 32 * 8).W))
   val tc_a_ready = Input(UInt(2.W))
   val tc_d_valid = Input(UInt(2.W))
   val tc_d_bits_data = Input(UInt((2 * 32 * 8).W))
@@ -411,6 +414,10 @@ class Vortex(tile: RadianceTile)(implicit p: Parameters)
+  // blackwell-style TMEM + SMEM tensor core (mutually exclusive with the hopper path)
+  if (tile.radianceParams.core.tensorCoreBlackwell) {
+    addResource("/vsrc/vortex/hw/rtl/core/VX_tensor_blackwell_core.sv")
+  }
   // hopper-style SMEM operand decoupling
   if (tile.radianceParams.core.tensorCoreDecoupled) {
     addResource("/vsrc/vortex/hw/rtl/core/VX_tensor_hopper_core.sv")
     // addResource("/vsrc/vortex/hw/rtl/core/VX_tensor_ucode.vh")
     def addHopperTensorCore = {
       addPath("/scratch/hansung/chipyard/sims/vcs/generated-src/chipyard.unittest.TestHarness.TensorUnitTestConfig/gen-collateral/AddRawFN.sv")
diff --git a/src/test/scala/radiance/TensorCoreBlackwellTest.scala b/src/test/scala/radiance/TensorCoreBlackwellTest.scala
new file mode 100644
index 0000000..3c8546b
--- /dev/null
+++ b/src/test/scala/radiance/TensorCoreBlackwellTest.scala
@@ -0,0 +1,160 @@
+package radiance.core
+
+import chisel3._
+import chiseltest._
+import org.scalatest.flatspec.AnyFlatSpec
+
+class TensorCoreBlackwellTest extends AnyFlatSpec with ChiselScalatestTester {
+  behavior of "TensorCoreBlackwell"
+
+  private def idleIO(c: TensorCoreBlackwell): Unit = {
+    c.io.initiate.valid.poke(false.B)
+    c.io.respA.valid.poke(false.B)
+    c.io.respB.valid.poke(false.B)
+    c.io.respA.bits.source.poke(0.U)
+    c.io.respB.bits.source.poke(0.U)
+    c.io.respA.bits.data.poke(0.U)
+    c.io.respB.bits.data.poke(0.U)
+    c.io.respC.poke(0.U)
+    c.io.writeback.ready.poke(false.B)
+  }
+
+  it should "run a minimal BWGMMA path" in {
+    test(new TensorCoreBlackwell(8, 8, numSourceIds = 4, half = true)) { c =>
+      idleIO(c)
+
+      c.io.initiate.valid.poke(true.B)
+      c.io.initiate.bits.op.poke(0.U)
+      c.io.initiate.bits.wid.poke(1.U)
+      c.io.initiate.bits.rd.poke(3.U)
+      c.io.initiate.bits.addressA.poke(0x40.U)
+      c.io.initiate.bits.addressB.poke(0x80.U)
+      c.io.reqA.ready.poke(true.B)
+      c.io.reqB.ready.poke(true.B)
+      c.io.respC.poke("h0000000800000007000000060000000500000004000000030000000200000001".U)
+      c.clock.step()
+
+      c.io.initiate.valid.poke(false.B)
+      c.io.reqA.valid.expect(true.B)
+      c.io.reqB.valid.expect(true.B)
+      c.clock.step()
+
+      c.io.respA.valid.poke(true.B)
+      c.io.respB.valid.poke(true.B)
+      c.io.respA.bits.data.poke("h0000000800000007000000060000000500000004000000030000000200000001".U)
+      c.io.respB.bits.data.poke("h000000100000000f0000000e0000000d0000000c0000000b0000000a00000009".U)
+      c.clock.step()
+
+      c.io.respA.valid.poke(false.B)
+      c.io.respB.valid.poke(false.B)
+      c.clock.step()
+      c.clock.step()
+      c.io.writeback.valid.expect(true.B)
+      c.io.writeback.bits.rd.expect(3.U)
+      c.io.writeback.bits.wid.expect(1.U)
+      // baseline datapath is lane-wise A + B + C: 1 + 9 + 1 and 8 + 16 + 8
+      c.io.writeback.bits.data(0).expect(11.U)
+      c.io.writeback.bits.data(7).expect(32.U)
+      c.io.writeback.ready.poke(true.B)
+      c.clock.step()
+    }
+  }
+
+  it should "not re-issue an accepted BWGMMA request while the other port stalls" in {
+    test(new TensorCoreBlackwell(8, 8, numSourceIds = 4, half = true)) { c =>
+      idleIO(c)
+
+      c.io.initiate.valid.poke(true.B)
+      c.io.initiate.bits.op.poke(0.U)
+      c.io.initiate.bits.wid.poke(0.U)
+      c.io.initiate.bits.rd.poke(1.U)
+      c.io.initiate.bits.addressA.poke(0x40.U)
+      c.io.initiate.bits.addressB.poke(0x80.U)
+      c.io.reqA.ready.poke(true.B)
+      c.io.reqB.ready.poke(false.B)
+      c.clock.step()
+
+      c.io.initiate.valid.poke(false.B)
+      c.io.reqA.valid.expect(true.B)
+      c.io.reqB.valid.expect(true.B)
+      c.clock.step()
+
+      // reqA was accepted in the previous cycle; only reqB is still outstanding
+      c.io.reqA.valid.expect(false.B)
+      c.io.reqB.valid.expect(true.B)
+      c.io.reqB.ready.poke(true.B)
+      c.clock.step()
+
+      c.io.reqA.valid.expect(false.B)
+      c.io.reqB.valid.expect(false.B)
+      c.io.respA.ready.expect(true.B)
+      c.io.respB.ready.expect(true.B)
+    }
+  }
+
+  it should "copy from SMEM to TMEM on TCGEN05_CP" in {
+    test(new TensorCoreBlackwell(8, 8, numSourceIds = 4, half = true)) { c =>
+      idleIO(c)
+
+      c.io.initiate.valid.poke(true.B)
+      c.io.initiate.bits.op.poke(2.U)
+      c.io.initiate.bits.wid.poke(0.U)
+      c.io.initiate.bits.rd.poke(0.U)
+      c.io.initiate.bits.addressA.poke(0x100.U)
+      c.io.initiate.bits.addressB.poke(0x200.U)
+      c.io.reqB.ready.poke(true.B)
+      c.clock.step()
+
+      c.io.initiate.valid.poke(false.B)
+      c.io.reqB.valid.expect(true.B)
+      c.io.respB.valid.poke(true.B)
+      c.io.respB.bits.data.poke("hdeadbeef".U)
+      c.io.reqA.ready.poke(true.B)
+      c.clock.step()
+      c.io.reqA.valid.expect(true.B)
+      c.io.reqA.bits.rw.expect(true.B)
+      c.io.reqA.bits.address.expect(0x100.U)
+    }
+  }
+
+  it should "load and store fragments through TMEM" in {
+    test(new TensorCoreBlackwell(8, 8, numSourceIds = 4, half = true)) { c =>
+      idleIO(c)
+
+      c.io.initiate.valid.poke(true.B)
+      c.io.initiate.bits.op.poke(4.U)
+      c.io.initiate.bits.wid.poke(2.U)
+      c.io.initiate.bits.rd.poke(5.U)
+      c.io.initiate.bits.addressA.poke(0x300.U)
+      c.io.initiate.bits.addressB.poke(0.U)
+      c.io.reqA.ready.poke(true.B)
+      c.clock.step()
+
+      c.io.initiate.valid.poke(false.B)
+      c.clock.step()
+      c.io.respA.valid.poke(true.B)
+      c.io.respA.bits.data.poke("h1234".U)
+      c.clock.step()
+      c.io.respA.valid.poke(false.B)
+      c.clock.step()
+      c.io.writeback.valid.expect(true.B)
+      c.io.writeback.bits.rd.expect(5.U)
+      c.io.writeback.ready.poke(true.B)
+      c.clock.step()
+
+      idleIO(c)
+      c.io.initiate.valid.poke(true.B)
+      c.io.initiate.bits.op.poke(5.U)
+      c.io.initiate.bits.wid.poke(2.U)
+      c.io.initiate.bits.rd.poke(6.U)
+      c.io.initiate.bits.addressA.poke(0x340.U)
+      c.io.initiate.bits.addressB.poke(0.U)
+      c.io.reqA.ready.poke(true.B)
+      c.io.respC.poke("habcd".U)
+      c.clock.step()
+      c.io.reqA.valid.expect(true.B)
+      c.io.reqA.bits.rw.expect(true.B)
+      c.io.reqA.bits.address.expect(0x340.U)
+    }
+  }
+}