Add Blackwell tensor core baseline plumbing

This commit is contained in:
2026-04-25 10:15:31 +08:00
parent 4a0b1c05cd
commit 136cf70a58
7 changed files with 528 additions and 7 deletions

View File

@@ -23,6 +23,9 @@ endif
ifeq ($(shell echo $(CONFIG) | grep -E "HopperConfig$$"),$(CONFIG))
EXTRA_SIM_PREPROC_DEFINES += +define+NUM_CORES=4 +define+EXT_T_HOPPER
endif
ifeq ($(shell echo $(CONFIG) | grep -E "BlackwellConfig$$"),$(CONFIG))
EXTRA_SIM_PREPROC_DEFINES += +define+NUM_CORES=4 +define+EXT_T_BLACKWELL
endif
ifeq ($(shell echo $(CONFIG) | grep -E "FlashConfig$$"),$(CONFIG))
EXTRA_SIM_PREPROC_DEFINES += +define+NUM_CORES=4
endif

View File

@@ -0,0 +1,238 @@
// See LICENSE.SiFive for license details.
// See LICENSE.Berkeley for license details.
package radiance.core
import chisel3._
import chisel3.util._
class TensorCoreBlackwell(
val numWarps: Int,
val numLanes: Int,
val half: Boolean,
val numSourceIds: Int = 16,
val numFPRegs: Int = 32
) extends Module {
val numWarpBits = log2Ceil(numWarps)
val sourceWidth = log2Ceil(numSourceIds)
val laneWidth = 4 * 8
val memWidth = numLanes * laneWidth
val numFPRegBits = log2Ceil(numFPRegs)
val addressWidth = 32
val maskWidth = memWidth / 8
object Ops {
val bwgmma :: bwgmmaWait :: tcgen05Cp :: tcgen05CpWait :: tcgen05Ld :: tcgen05St :: Nil = Enum(6)
}
class TensorMemReq(
sourceWidth: Int,
dataWidth: Int
) extends Bundle {
val rw = Bool()
val byteen = UInt((dataWidth / 8).W)
val source = UInt(sourceWidth.W)
val address = UInt(addressWidth.W)
val data = UInt(dataWidth.W)
}
class TensorMemResp(
sourceWidth: Int,
dataWidth: Int
) extends Bundle {
val source = UInt(sourceWidth.W)
val data = UInt(dataWidth.W)
}
val io = IO(new Bundle {
val initiate = Flipped(Decoupled(new Bundle {
val op = UInt(3.W)
val wid = UInt(numWarpBits.W)
val rd = UInt(numFPRegBits.W)
val addressA = UInt(addressWidth.W)
val addressB = UInt(addressWidth.W)
}))
val writeback = Decoupled(new Bundle {
val last = Bool()
val wid = UInt(numWarpBits.W)
val rd = UInt(numFPRegBits.W)
val data = Vec(numLanes, UInt(laneWidth.W))
})
val respA = Flipped(Decoupled(new TensorMemResp(sourceWidth, memWidth)))
val respB = Flipped(Decoupled(new TensorMemResp(sourceWidth, memWidth)))
val respC = Input(UInt(memWidth.W))
val reqA = Decoupled(new TensorMemReq(sourceWidth, memWidth))
val reqB = Decoupled(new TensorMemReq(sourceWidth, memWidth))
val reqC = Output(Valid(UInt(numFPRegBits.W)))
})
object State extends ChiselEnum {
val idle, bwReq, bwResp, cpRead, cpWrite, ldReq, stReq, waitWb = Value
}
val state = RegInit(State.idle)
val opReg = RegInit(0.U(3.W))
val widReg = RegInit(0.U(numWarpBits.W))
val rdReg = RegInit(0.U(numFPRegBits.W))
val addrAReg = RegInit(0.U(addressWidth.W))
val addrBReg = RegInit(0.U(addressWidth.W))
val aDataReg = Reg(UInt(memWidth.W))
val bDataReg = Reg(UInt(memWidth.W))
val haveA = RegInit(false.B)
val haveB = RegInit(false.B)
val sourceCounter = RegInit(0.U(sourceWidth.W))
private def bumpSource(): Unit = {
sourceCounter := sourceCounter + 1.U
}
val reqA = Wire(Decoupled(new TensorMemReq(sourceWidth, memWidth)))
val reqB = Wire(Decoupled(new TensorMemReq(sourceWidth, memWidth)))
reqA.valid := false.B
reqA.bits := 0.U.asTypeOf(reqA.bits)
reqB.valid := false.B
reqB.bits := 0.U.asTypeOf(reqB.bits)
io.reqA <> reqA
io.reqB <> reqB
val wbValid = RegInit(false.B)
val wbData = Reg(Vec(numLanes, UInt(laneWidth.W)))
io.writeback.valid := wbValid
io.writeback.bits.last := true.B
io.writeback.bits.wid := widReg
io.writeback.bits.rd := rdReg
io.writeback.bits.data := wbData
io.reqC.valid := false.B
io.reqC.bits := rdReg
io.respA.ready := false.B
io.respB.ready := false.B
io.initiate.ready := state === State.idle && !wbValid
when(io.writeback.fire) {
wbValid := false.B
}
when(io.initiate.fire) {
opReg := io.initiate.bits.op
widReg := io.initiate.bits.wid
rdReg := io.initiate.bits.rd
addrAReg := io.initiate.bits.addressA
addrBReg := io.initiate.bits.addressB
haveA := false.B
haveB := false.B
switch(io.initiate.bits.op) {
is(Ops.bwgmma) { state := State.bwReq }
is(Ops.tcgen05Cp) { state := State.cpRead }
is(Ops.tcgen05Ld) { state := State.ldReq }
is(Ops.tcgen05St) { state := State.stReq }
is(Ops.bwgmmaWait) { state := State.idle }
is(Ops.tcgen05CpWait) { state := State.idle }
}
}
when(state === State.bwReq) {
reqA.valid := true.B
reqA.bits.rw := false.B
reqA.bits.byteen := Fill(maskWidth, 1.U(1.W))
reqA.bits.address := addrAReg
reqA.bits.source := sourceCounter
reqB.valid := true.B
reqB.bits.rw := false.B
reqB.bits.byteen := Fill(maskWidth, 1.U(1.W))
reqB.bits.address := addrBReg
reqB.bits.source := sourceCounter
io.reqC.valid := true.B
when(reqA.fire && reqB.fire) {
bumpSource()
state := State.bwResp
}
}
when(state === State.bwResp) {
io.respA.ready := true.B
io.respB.ready := true.B
when(io.respA.fire) {
aDataReg := io.respA.bits.data
haveA := true.B
}
when(io.respB.fire) {
bDataReg := io.respB.bits.data
haveB := true.B
}
when(haveA && haveB) {
val cWords = io.respC.asTypeOf(Vec(numLanes, UInt(laneWidth.W)))
val aWords = aDataReg.asTypeOf(Vec(numLanes, UInt(laneWidth.W)))
val bWords = bDataReg.asTypeOf(Vec(numLanes, UInt(laneWidth.W)))
for (i <- 0 until numLanes) {
wbData(i) := aWords(i) + bWords(i) + cWords(i)
}
wbValid := true.B
state := State.idle
}
}
when(state === State.cpRead) {
reqB.valid := true.B
reqB.bits.rw := false.B
reqB.bits.byteen := Fill(maskWidth, 1.U(1.W))
reqB.bits.address := addrBReg
reqB.bits.source := sourceCounter
when(reqB.fire) {
bumpSource()
state := State.cpWrite
}
}
when(state === State.cpWrite) {
io.respB.ready := reqA.ready
reqA.valid := io.respB.valid
reqA.bits.rw := true.B
reqA.bits.byteen := Fill(maskWidth, 1.U(1.W))
reqA.bits.address := addrAReg
reqA.bits.source := sourceCounter
reqA.bits.data := io.respB.bits.data
when(reqA.fire) {
bumpSource()
state := State.idle
}
}
when(state === State.ldReq) {
reqA.valid := true.B
reqA.bits.rw := false.B
reqA.bits.byteen := Fill(maskWidth, 1.U(1.W))
reqA.bits.address := addrAReg
reqA.bits.source := sourceCounter
when(reqA.fire) {
bumpSource()
state := State.waitWb
}
}
when(state === State.waitWb && opReg === Ops.tcgen05Ld) {
io.respA.ready := !wbValid
when(io.respA.fire) {
wbData := io.respA.bits.data.asTypeOf(Vec(numLanes, UInt(laneWidth.W)))
wbValid := true.B
state := State.idle
}
}
when(state === State.stReq) {
io.reqC.valid := true.B
reqA.valid := true.B
reqA.bits.rw := true.B
reqA.bits.byteen := Fill(maskWidth, 1.U(1.W))
reqA.bits.address := addrAReg
reqA.bits.source := sourceCounter
reqA.bits.data := io.respC
when(reqA.fire) {
bumpSource()
state := State.idle
}
}
}

View File

@@ -50,6 +50,7 @@ class WithRadianceCores(
crossing: RocketCrossingParams,
tensorCoreFP16: Boolean,
tensorCoreDecoupled: Boolean,
tensorCoreBlackwell: Boolean,
useVxCache: Boolean
) extends Config((site, _, up) => {
case TilesLocated(`location`) => {
@@ -59,7 +60,8 @@ class WithRadianceCores(
val vortex = RadianceTileParams(
core = VortexCoreParams(
tensorCoreFP16 = tensorCoreFP16,
tensorCoreDecoupled = tensorCoreDecoupled
tensorCoreDecoupled = tensorCoreDecoupled,
tensorCoreBlackwell = tensorCoreBlackwell
),
btb = None,
useVxCache = useVxCache,
@@ -96,6 +98,7 @@ class WithRadianceCores(
// constructor override that omits `crossing`
def this(n: Int, location: HierarchicalLocation = InSubsystem,
tensorCoreFP16: Boolean = false, tensorCoreDecoupled: Boolean = false,
tensorCoreBlackwell: Boolean = false,
useVxCache: Boolean = false)
= this(n, location, RocketCrossingParams(
master = HierarchicalElementMasterPortParams.locationDefault(location),
@@ -104,9 +107,23 @@ class WithRadianceCores(
case InSubsystem => CBUS
case InCluster(clusterId) => CCBUS(clusterId)
}
), tensorCoreFP16, tensorCoreDecoupled, useVxCache)
), tensorCoreFP16, tensorCoreDecoupled, tensorCoreBlackwell, useVxCache)
}
class WithBlackwellTensorCore(location: HierarchicalLocation = InSubsystem) extends Config((site, _, up) => {
case TilesLocated(`location`) =>
up(TilesLocated(`location`)).map {
case r: RadianceTileAttachParams =>
r.copy(tileParams = r.tileParams.copy(
core = r.tileParams.core.copy(
tensorCoreBlackwell = true,
tensorCoreDecoupled = false
)
))
case other => other
}
})
class WithEmulatorCores(
n: Int,
useVxCache: Boolean

View File

@@ -101,6 +101,7 @@ case class VortexCoreParams(
fpu: Option[FPUParams] = None,
tensorCoreFP16: Boolean = false, // FP16 if true, FP32 if false
tensorCoreDecoupled: Boolean = false, // hopper-style SMEM operand decoupling
tensorCoreBlackwell: Boolean = false, // blackwell-style TMEM + SMEM tensor core
debugROB: Boolean = false, // if enabled, uses a C++ debug ROB to generate trace-with-wdata
haveCease: Boolean = true, // non-standard CEASE instruction
haveSimTimeout: Boolean = true // add plusarg for simulation timeout
@@ -152,6 +153,10 @@ class RadianceTile private (
p(SIMTCoreKey).isDefined,
"SIMTCoreKey not defined; make sure to use WithSimtConfig when using RadianceTile"
)
require(
!(radianceParams.core.tensorCoreDecoupled && radianceParams.core.tensorCoreBlackwell),
"tensorCoreDecoupled and tensorCoreBlackwell are mutually exclusive"
)
// NOTE: when changing these, remember to change +define+NUM_CORES/THREADS/WARPS in
// radiance.mk as well!
@@ -280,7 +285,9 @@ class RadianceTile private (
}
val tcSmemSize = 32
val tcSmemNodes = Seq.tabulate(if (radianceParams.core.tensorCoreDecoupled) 2 else 0) { i =>
val tensorUsesAsyncMem = radianceParams.core.tensorCoreDecoupled || radianceParams.core.tensorCoreBlackwell
val tcSmemNodeCount = if (radianceParams.core.tensorCoreDecoupled) 2 else if (radianceParams.core.tensorCoreBlackwell) 1 else 0
val tcSmemNodes = Seq.tabulate(tcSmemNodeCount) { i =>
TLClientNode(Seq(TLMasterPortParameters.v2(
masters = Seq(TLMasterParameters.v2(
name = s"rad_tc_${radianceParams.coreId}_$i",
@@ -294,6 +301,42 @@ class RadianceTile private (
)))
}
val tmemNodes = Seq.tabulate(if (radianceParams.core.tensorCoreBlackwell) 2 else 0) { i =>
TLClientNode(Seq(TLMasterPortParameters.v2(
masters = Seq(TLMasterParameters.v2(
name = s"rad_tmem_${radianceParams.coreId}_$i",
sourceId = IdRange(0, 1 << smemSourceWidth),
supports = TLSlaveToMasterTransferSizes(
probe = TransferSizes(1, tcSmemSize),
get = TransferSizes(1, tcSmemSize),
putFull = TransferSizes(1, tcSmemSize),
putPartial = TransferSizes(1, tcSmemSize),
),
requestFifo = true
))
)))
}
val tmemNode = if (radianceParams.core.tensorCoreBlackwell) {
Some(LazyModule(new TLRAM(
address = AddressSet(0x0, 0x3fff),
beatBytes = tcSmemSize
)))
} else {
None
}
val tmemXbar = if (radianceParams.core.tensorCoreBlackwell) {
Some(LazyModule(new TLXbar))
} else {
None
}
(tmemNode, tmemXbar) match {
case (Some(tmem), Some(xbar)) =>
tmem.node :=* xbar.node
tmemNodes.foreach { node => xbar.node :=* node }
case _ =>
}
// combine outgoing per-lane dmemNode into 1 idenity node
//
// NOTE: We need TLWidthWidget here because there might be a data width
@@ -743,12 +786,18 @@ class RadianceTileModuleImp(outer: RadianceTile)
val tcb0 = new {
val addr = core.io.tc_a_bits_address(31, 0)
val tag = core.io.tc_a_bits_tag(outer.tensorTagWidth - 1, 0)
val write = core.io.tc_a_bits_write(0)
val mask = core.io.tc_a_bits_mask(31, 0)
val data = core.io.tc_a_bits_data(255, 0)
val aValid = core.io.tc_a_valid(0)
val dReady = core.io.tc_d_ready(0)
}
val tcb1 = new {
val addr = core.io.tc_a_bits_address(63, 32)
val tag = core.io.tc_a_bits_tag(4 + outer.tensorTagWidth - 1, 4)
val write = core.io.tc_a_bits_write(1)
val mask = core.io.tc_a_bits_mask(63, 32)
val data = core.io.tc_a_bits_data(511, 256)
val aValid = core.io.tc_a_valid(1)
val dReady = core.io.tc_d_ready(1)
}
@@ -770,8 +819,9 @@ class RadianceTileModuleImp(outer: RadianceTile)
adapter.io.inReq.bits.address := bundle.addr
adapter.io.inReq.bits.source := bundle.tag
adapter.io.inReq.bits.size := 5.U // 256 bits
adapter.io.inReq.bits.opcode := TLMessages.Get
adapter.io.inReq.bits.mask := x"ffffffff".U
adapter.io.inReq.bits.opcode := Mux(bundle.write.asBool, TLMessages.PutFullData, TLMessages.Get)
adapter.io.inReq.bits.mask := bundle.mask
adapter.io.inReq.bits.data := bundle.data
adapter.io.inResp.ready := bundle.dReady
client._1.a <> adapter.io.outReq
@@ -792,6 +842,71 @@ class RadianceTileModuleImp(outer: RadianceTile)
}
}
def connectTensorBlackwell = {
if (outer.radianceParams.core.tensorCoreBlackwell) {
require(outer.tmemNodes.nonEmpty)
require(outer.tcSmemNodes.nonEmpty)
val bundles = Seq(
(outer.tmemNodes.head, new {
val addr = core.io.tc_a_bits_address(31, 0)
val tag = core.io.tc_a_bits_tag(outer.tensorTagWidth - 1, 0)
val write = core.io.tc_a_bits_write(0)
val mask = core.io.tc_a_bits_mask(31, 0)
val data = core.io.tc_a_bits_data(255, 0)
val aValid = core.io.tc_a_valid(0)
val dReady = core.io.tc_d_ready(0)
}),
(outer.tcSmemNodes.head, new {
val addr = core.io.tc_a_bits_address(63, 32)
val tag = core.io.tc_a_bits_tag(4 + outer.tensorTagWidth - 1, 4)
val write = core.io.tc_a_bits_write(1)
val mask = core.io.tc_a_bits_mask(63, 32)
val data = core.io.tc_a_bits_data(511, 256)
val aValid = core.io.tc_a_valid(1)
val dReady = core.io.tc_d_ready(1)
})
)
val adapters = bundles.map { case (node, bundle) =>
val client = node.out.head
val adapter = Module(
new VortexTLAdapter(
outer.smemSourceWidth,
new VortexBundleA(tagWidth = outer.tensorTagWidth, dataWidth = 32 * 8),
new VortexBundleD(tagWidth = outer.tensorTagWidth, dataWidth = 32 * 8),
client
)
)
require(adapter.io.inReq.bits.source.widthOption.get == bundle.tag.widthOption.get)
require(adapter.io.inReq.bits.address.widthOption.get == bundle.addr.widthOption.get)
adapter.io.inReq.bits <> DontCare
adapter.io.inReq.valid := bundle.aValid
adapter.io.inReq.bits.address := bundle.addr
adapter.io.inReq.bits.source := bundle.tag
adapter.io.inReq.bits.size := 5.U
adapter.io.inReq.bits.opcode := Mux(bundle.write.asBool, TLMessages.PutFullData, TLMessages.Get)
adapter.io.inReq.bits.mask := bundle.mask
adapter.io.inReq.bits.data := bundle.data
adapter.io.inResp.ready := bundle.dReady
client._1.a <> adapter.io.outReq
adapter.io.outResp <> client._1.d
adapter
}
core.io.tc_a_ready := Cat(adapters.last.io.inReq.ready, adapters.head.io.inReq.ready)
core.io.tc_d_valid := Cat(adapters.last.io.inResp.valid, adapters.head.io.inResp.valid)
core.io.tc_d_bits_data := Cat(adapters.last.io.inResp.bits.data, adapters.head.io.inResp.bits.data)
core.io.tc_d_bits_tag := Cat(adapters.last.io.inResp.bits.source, adapters.head.io.inResp.bits.source)
} else {
core.io.tc_a_ready := false.B
core.io.tc_d_valid := false.B
core.io.tc_d_bits_data := DontCare
core.io.tc_d_bits_tag := DontCare
}
}
def connectBarrier = {
require(outer.barrierMasterNode.out.length == 1)
// FIXME: bits not flattened
@@ -847,7 +962,11 @@ class RadianceTileModuleImp(outer: RadianceTile)
connectImem
connectDmem
connectSmem
if (outer.radianceParams.core.tensorCoreBlackwell) {
connectTensorBlackwell
} else {
connectTensor
}
connectBarrier
connectAccelerator
}
@@ -874,6 +993,20 @@ class RadianceTileModuleImp(outer: RadianceTile)
tensor.io.reqA.ready := false.B
tensor.io.reqB.ready := false.B
tensor.io.writeback.ready := false.B
} else if (outer.radianceParams.core.tensorCoreBlackwell) {
val tensorNumSourceIds = (1 << outer.tensorTagWidth)
val tensor = Module(new radiance.core.TensorCoreBlackwell(
8, 8, half = true, tensorNumSourceIds))
tensor.io.initiate.valid := false.B
tensor.io.initiate.bits := DontCare
tensor.io.respA.valid := false.B
tensor.io.respA.bits := DontCare
tensor.io.respB.valid := false.B
tensor.io.respB.bits := DontCare
tensor.io.respC := DontCare
tensor.io.reqA.ready := false.B
tensor.io.reqB.ready := false.B
tensor.io.writeback.ready := false.B
} else {
if (outer.radianceParams.core.tensorCoreFP16) {
val dpu = Module(new radiance.core.TensorDotProductUnit(4, half = true))

View File

@@ -91,8 +91,11 @@ class VortexBundle(tile: RadianceTile)(implicit p: Parameters) extends CoreBundl
val smem_d_ready = Output(UInt((tile.numLsuLanes * 1).W))
val tc_a_valid = Output(UInt(2.W))
val tc_a_bits_write = Output(UInt(2.W))
val tc_a_bits_address = Output(UInt((2 * 32).W))
val tc_a_bits_tag = Output(UInt((2 * 4).W))
val tc_a_bits_mask = Output(UInt((2 * 32).W))
val tc_a_bits_data = Output(UInt((2 * 32 * 8).W))
val tc_a_ready = Input(UInt(2.W))
val tc_d_valid = Input(UInt(2.W))
val tc_d_bits_data = Input(UInt((2 * 32 * 8).W))
@@ -411,6 +414,8 @@ class Vortex(tile: RadianceTile)(implicit p: Parameters)
// hopper-style SMEM operand decoupling
if (tile.radianceParams.core.tensorCoreDecoupled) {
addResource("/vsrc/vortex/hw/rtl/core/VX_tensor_hopper_core.sv")
} else if (tile.radianceParams.core.tensorCoreBlackwell) {
addResource("/vsrc/vortex/hw/rtl/core/VX_tensor_blackwell_core.sv")
// addResource("/vsrc/vortex/hw/rtl/core/VX_tensor_ucode.vh")
def addHopperTensorCore = {
addPath("/scratch/hansung/chipyard/sims/vcs/generated-src/chipyard.unittest.TestHarness.TensorUnitTestConfig/gen-collateral/AddRawFN.sv")

View File

@@ -0,0 +1,125 @@
package radiance.core
import chisel3._
import chiseltest._
import org.scalatest.flatspec.AnyFlatSpec
class TensorCoreBlackwellTest extends AnyFlatSpec with ChiselScalatestTester {
behavior of "TensorCoreBlackwell"
private def idleIO(c: TensorCoreBlackwell): Unit = {
c.io.initiate.valid.poke(false.B)
c.io.respA.valid.poke(false.B)
c.io.respB.valid.poke(false.B)
c.io.respA.bits.source.poke(0.U)
c.io.respB.bits.source.poke(0.U)
c.io.respA.bits.data.poke(0.U)
c.io.respB.bits.data.poke(0.U)
c.io.respC.poke(0.U)
c.io.writeback.ready.poke(false.B)
}
it should "run a minimal BWGMMA path" in {
test(new TensorCoreBlackwell(8, 8, numSourceIds = 4, half = true)) { c =>
idleIO(c)
c.io.initiate.valid.poke(true.B)
c.io.initiate.bits.op.poke(0.U)
c.io.initiate.bits.wid.poke(1.U)
c.io.initiate.bits.rd.poke(3.U)
c.io.initiate.bits.addressA.poke(0x40.U)
c.io.initiate.bits.addressB.poke(0x80.U)
c.io.reqA.ready.poke(true.B)
c.io.reqB.ready.poke(true.B)
c.io.respC.poke("h0000000800000007000000060000000500000004000000030000000200000001".U)
c.clock.step()
c.io.initiate.valid.poke(false.B)
c.io.reqA.valid.expect(true.B)
c.io.reqB.valid.expect(true.B)
c.clock.step()
c.io.respA.valid.poke(true.B)
c.io.respB.valid.poke(true.B)
c.io.respA.bits.data.poke("h0000000800000007000000060000000500000004000000030000000200000001".U)
c.io.respB.bits.data.poke("h000000100000000f0000000e0000000d0000000c0000000b0000000a00000009".U)
c.clock.step()
c.io.respA.valid.poke(false.B)
c.io.respB.valid.poke(false.B)
c.clock.step()
c.clock.step()
c.io.writeback.valid.expect(true.B)
c.io.writeback.bits.rd.expect(3.U)
c.io.writeback.bits.wid.expect(1.U)
c.io.writeback.ready.poke(true.B)
c.clock.step()
}
}
it should "copy from SMEM to TMEM on TCGEN05_CP" in {
test(new TensorCoreBlackwell(8, 8, numSourceIds = 4, half = true)) { c =>
idleIO(c)
c.io.initiate.valid.poke(true.B)
c.io.initiate.bits.op.poke(2.U)
c.io.initiate.bits.wid.poke(0.U)
c.io.initiate.bits.rd.poke(0.U)
c.io.initiate.bits.addressA.poke(0x100.U)
c.io.initiate.bits.addressB.poke(0x200.U)
c.io.reqB.ready.poke(true.B)
c.clock.step()
c.io.initiate.valid.poke(false.B)
c.io.reqB.valid.expect(true.B)
c.io.respB.valid.poke(true.B)
c.io.respB.bits.data.poke("hdeadbeef".U)
c.io.reqA.ready.poke(true.B)
c.clock.step()
c.io.reqA.valid.expect(true.B)
c.io.reqA.bits.rw.expect(true.B)
c.io.reqA.bits.address.expect(0x100.U)
}
}
it should "load and store fragments through TMEM" in {
test(new TensorCoreBlackwell(8, 8, numSourceIds = 4, half = true)) { c =>
idleIO(c)
c.io.initiate.valid.poke(true.B)
c.io.initiate.bits.op.poke(4.U)
c.io.initiate.bits.wid.poke(2.U)
c.io.initiate.bits.rd.poke(5.U)
c.io.initiate.bits.addressA.poke(0x300.U)
c.io.initiate.bits.addressB.poke(0.U)
c.io.reqA.ready.poke(true.B)
c.clock.step()
c.io.initiate.valid.poke(false.B)
c.clock.step()
c.io.respA.valid.poke(true.B)
c.io.respA.bits.data.poke("h1234".U)
c.clock.step()
c.io.respA.valid.poke(false.B)
c.clock.step()
c.io.writeback.valid.expect(true.B)
c.io.writeback.bits.rd.expect(5.U)
c.io.writeback.ready.poke(true.B)
c.clock.step()
idleIO(c)
c.io.initiate.valid.poke(true.B)
c.io.initiate.bits.op.poke(5.U)
c.io.initiate.bits.wid.poke(2.U)
c.io.initiate.bits.rd.poke(6.U)
c.io.initiate.bits.addressA.poke(0x340.U)
c.io.initiate.bits.addressB.poke(0.U)
c.io.reqA.ready.poke(true.B)
c.io.respC.poke("habcd".U)
c.clock.step()
c.io.reqA.valid.expect(true.B)
c.io.reqA.bits.rw.expect(true.B)
c.io.reqA.bits.address.expect(0x340.U)
}
}
}