Add Blackwell tensor core baseline plumbing
This commit is contained in:
@@ -23,6 +23,9 @@ endif
|
|||||||
ifeq ($(shell echo $(CONFIG) | grep -E "HopperConfig$$"),$(CONFIG))
|
ifeq ($(shell echo $(CONFIG) | grep -E "HopperConfig$$"),$(CONFIG))
|
||||||
EXTRA_SIM_PREPROC_DEFINES += +define+NUM_CORES=4 +define+EXT_T_HOPPER
|
EXTRA_SIM_PREPROC_DEFINES += +define+NUM_CORES=4 +define+EXT_T_HOPPER
|
||||||
endif
|
endif
|
||||||
|
# Blackwell tensor-core builds: when $(CONFIG) ends in "BlackwellConfig"
# (grep echoes the whole name back only on a match, so the ifeq compares
# equal), add the NUM_CORES=4 and EXT_T_BLACKWELL simulator preprocessor
# defines.
ifeq ($(shell echo $(CONFIG) | grep -E "BlackwellConfig$$"),$(CONFIG))
|
||||||
|
EXTRA_SIM_PREPROC_DEFINES += +define+NUM_CORES=4 +define+EXT_T_BLACKWELL
|
||||||
|
endif
|
||||||
ifeq ($(shell echo $(CONFIG) | grep -E "FlashConfig$$"),$(CONFIG))
|
ifeq ($(shell echo $(CONFIG) | grep -E "FlashConfig$$"),$(CONFIG))
|
||||||
EXTRA_SIM_PREPROC_DEFINES += +define+NUM_CORES=4
|
EXTRA_SIM_PREPROC_DEFINES += +define+NUM_CORES=4
|
||||||
endif
|
endif
|
||||||
|
|||||||
Submodule src/main/resources/vsrc/vortex updated: f1d0fac518...cb912d3b8b
238
src/main/scala/radiance/core/TensorCoreBlackwell.scala
Normal file
238
src/main/scala/radiance/core/TensorCoreBlackwell.scala
Normal file
@@ -0,0 +1,238 @@
|
|||||||
|
// See LICENSE.SiFive for license details.
|
||||||
|
// See LICENSE.Berkeley for license details.
|
||||||
|
|
||||||
|
package radiance.core
|
||||||
|
|
||||||
|
import chisel3._
|
||||||
|
import chisel3.util._
|
||||||
|
|
||||||
|
/** Baseline sequencer for the Blackwell-style tensor core.
  *
  * One operation is accepted on `io.initiate` at a time and is executed by a
  * small state machine that talks to two decoupled memory ports (`A` and `B`)
  * and one register-file style port (`C`):
  *
  *  - `bwgmma`     : read one line from A and one from B, then write back the
  *                   lane-wise sum `A + B + C` (placeholder arithmetic).
  *  - `tcgen05Cp`  : read one line from B and write it to A at `addressA`.
  *  - `tcgen05Ld`  : read one line from A and write it back to the core.
  *  - `tcgen05St`  : write `io.respC` to A at `addressA`.
  *  - `bwgmmaWait` / `tcgen05CpWait` : accepted and retired immediately.
  *
  * Op encodings follow the declaration order of [[Ops]] (0 to 5); any other
  * encoding is accepted and ignored.
  *
  * @param numWarps     number of warps; sizes the `wid` fields
  * @param numLanes     number of 32-bit lanes per memory line / writeback
  * @param half         not referenced inside this module
  * @param numSourceIds number of request source IDs; sizes the `source` fields
  * @param numFPRegs    number of FP registers; sizes the `rd` fields
  */
class TensorCoreBlackwell(
  val numWarps: Int,
  val numLanes: Int,
  val half: Boolean,
  val numSourceIds: Int = 16,
  val numFPRegs: Int = 32
) extends Module {
  val numWarpBits = log2Ceil(numWarps)
  val sourceWidth = log2Ceil(numSourceIds)
  val laneWidth = 4 * 8
  val memWidth = numLanes * laneWidth
  val numFPRegBits = log2Ceil(numFPRegs)
  val addressWidth = 32
  val maskWidth = memWidth / 8

  /** Operation encodings accepted on `io.initiate.bits.op`. */
  object Ops {
    val bwgmma :: bwgmmaWait :: tcgen05Cp :: tcgen05CpWait :: tcgen05Ld :: tcgen05St :: Nil = Enum(6)
  }

  /** Memory request: `rw` is driven high for writes and low for reads. */
  class TensorMemReq(
    sourceWidth: Int,
    dataWidth: Int
  ) extends Bundle {
    val rw = Bool()
    val byteen = UInt((dataWidth / 8).W)
    val source = UInt(sourceWidth.W)
    val address = UInt(addressWidth.W)
    val data = UInt(dataWidth.W)
  }

  /** Memory response carrying the source ID and one full data line. */
  class TensorMemResp(
    sourceWidth: Int,
    dataWidth: Int
  ) extends Bundle {
    val source = UInt(sourceWidth.W)
    val data = UInt(dataWidth.W)
  }

  val io = IO(new Bundle {
    val initiate = Flipped(Decoupled(new Bundle {
      val op = UInt(3.W)
      val wid = UInt(numWarpBits.W)
      val rd = UInt(numFPRegBits.W)
      val addressA = UInt(addressWidth.W)
      val addressB = UInt(addressWidth.W)
    }))
    val writeback = Decoupled(new Bundle {
      val last = Bool()
      val wid = UInt(numWarpBits.W)
      val rd = UInt(numFPRegBits.W)
      val data = Vec(numLanes, UInt(laneWidth.W))
    })
    val respA = Flipped(Decoupled(new TensorMemResp(sourceWidth, memWidth)))
    val respB = Flipped(Decoupled(new TensorMemResp(sourceWidth, memWidth)))
    val respC = Input(UInt(memWidth.W))
    val reqA = Decoupled(new TensorMemReq(sourceWidth, memWidth))
    val reqB = Decoupled(new TensorMemReq(sourceWidth, memWidth))
    val reqC = Output(Valid(UInt(numFPRegBits.W)))
  })

  object State extends ChiselEnum {
    val idle, bwReq, bwResp, cpRead, cpWrite, ldReq, stReq, waitWb = Value
  }
  val state = RegInit(State.idle)

  // Fields latched from the accepted `initiate` request.
  val opReg = RegInit(0.U(3.W))
  val widReg = RegInit(0.U(numWarpBits.W))
  val rdReg = RegInit(0.U(numFPRegBits.W))
  val addrAReg = RegInit(0.U(addressWidth.W))
  val addrBReg = RegInit(0.U(addressWidth.W))

  // Operand lines captured from the A/B responses of a bwgmma.
  val aDataReg = Reg(UInt(memWidth.W))
  val bDataReg = Reg(UInt(memWidth.W))
  val haveA = RegInit(false.B)
  val haveB = RegInit(false.B)

  // Per-port "request already accepted" flags for bwgmma. Without these a
  // port that accepts early would see the same read re-issued every cycle
  // until the other port also accepts.
  private val sentA = RegInit(false.B)
  private val sentB = RegInit(false.B)

  val sourceCounter = RegInit(0.U(sourceWidth.W))

  // Advance the source ID handed to the next request (wraps at its width).
  private def bumpSource(): Unit = {
    sourceCounter := sourceCounter + 1.U
  }

  // Request channels default to idle / all-zero; the active state overrides.
  val reqA = Wire(Decoupled(new TensorMemReq(sourceWidth, memWidth)))
  val reqB = Wire(Decoupled(new TensorMemReq(sourceWidth, memWidth)))
  reqA.valid := false.B
  reqA.bits := 0.U.asTypeOf(reqA.bits)
  reqB.valid := false.B
  reqB.bits := 0.U.asTypeOf(reqB.bits)
  io.reqA <> reqA
  io.reqB <> reqB

  // Single-entry writeback buffer; `last` is always asserted.
  val wbValid = RegInit(false.B)
  val wbData = Reg(Vec(numLanes, UInt(laneWidth.W)))
  io.writeback.valid := wbValid
  io.writeback.bits.last := true.B
  io.writeback.bits.wid := widReg
  io.writeback.bits.rd := rdReg
  io.writeback.bits.data := wbData

  io.reqC.valid := false.B
  io.reqC.bits := rdReg

  io.respA.ready := false.B
  io.respB.ready := false.B
  // A new op is only accepted once the previous writeback has drained.
  io.initiate.ready := state === State.idle && !wbValid

  when(io.writeback.fire) {
    wbValid := false.B
  }

  when(io.initiate.fire) {
    opReg := io.initiate.bits.op
    widReg := io.initiate.bits.wid
    rdReg := io.initiate.bits.rd
    addrAReg := io.initiate.bits.addressA
    addrBReg := io.initiate.bits.addressB
    haveA := false.B
    haveB := false.B
    sentA := false.B
    sentB := false.B
    switch(io.initiate.bits.op) {
      is(Ops.bwgmma) { state := State.bwReq }
      is(Ops.tcgen05Cp) { state := State.cpRead }
      is(Ops.tcgen05Ld) { state := State.ldReq }
      is(Ops.tcgen05St) { state := State.stReq }
      is(Ops.bwgmmaWait) { state := State.idle }
      is(Ops.tcgen05CpWait) { state := State.idle }
    }
  }

  // bwgmma, phase 1: issue one read on each of A and B. Each port is offered
  // its request only until that port has accepted it.
  when(state === State.bwReq) {
    reqA.valid := !sentA
    reqA.bits.rw := false.B
    reqA.bits.byteen := Fill(maskWidth, 1.U(1.W))
    reqA.bits.address := addrAReg
    reqA.bits.source := sourceCounter

    reqB.valid := !sentB
    reqB.bits.rw := false.B
    reqB.bits.byteen := Fill(maskWidth, 1.U(1.W))
    reqB.bits.address := addrBReg
    reqB.bits.source := sourceCounter

    // NOTE(review): reqC is only asserted here, while io.respC is sampled
    // later in bwResp - confirm the C source holds its data that long.
    io.reqC.valid := true.B

    when(reqA.fire) { sentA := true.B }
    when(reqB.fire) { sentB := true.B }
    when((sentA || reqA.fire) && (sentB || reqB.fire)) {
      bumpSource()
      state := State.bwResp
    }
  }

  // bwgmma, phase 2: capture exactly one response per port, then compute.
  when(state === State.bwResp) {
    // Stop accepting once the operand is held so a further beat cannot
    // overwrite it or be silently dropped.
    io.respA.ready := !haveA
    io.respB.ready := !haveB
    when(io.respA.fire) {
      aDataReg := io.respA.bits.data
      haveA := true.B
    }
    when(io.respB.fire) {
      bDataReg := io.respB.bits.data
      haveB := true.B
    }
    when(haveA && haveB) {
      val cWords = io.respC.asTypeOf(Vec(numLanes, UInt(laneWidth.W)))
      val aWords = aDataReg.asTypeOf(Vec(numLanes, UInt(laneWidth.W)))
      val bWords = bDataReg.asTypeOf(Vec(numLanes, UInt(laneWidth.W)))
      for (i <- 0 until numLanes) {
        wbData(i) := aWords(i) + bWords(i) + cWords(i)
      }
      wbValid := true.B
      state := State.idle
    }
  }

  // tcgen05Cp, phase 1: read the source line through port B.
  when(state === State.cpRead) {
    reqB.valid := true.B
    reqB.bits.rw := false.B
    reqB.bits.byteen := Fill(maskWidth, 1.U(1.W))
    reqB.bits.address := addrBReg
    reqB.bits.source := sourceCounter
    when(reqB.fire) {
      bumpSource()
      state := State.cpWrite
    }
  }

  // tcgen05Cp, phase 2: forward the B response straight into a write on A;
  // the response is only consumed in the cycle the write is accepted.
  when(state === State.cpWrite) {
    io.respB.ready := reqA.ready
    reqA.valid := io.respB.valid
    reqA.bits.rw := true.B
    reqA.bits.byteen := Fill(maskWidth, 1.U(1.W))
    reqA.bits.address := addrAReg
    reqA.bits.source := sourceCounter
    reqA.bits.data := io.respB.bits.data
    when(reqA.fire) {
      bumpSource()
      state := State.idle
    }
  }

  // tcgen05Ld, phase 1: read one line through port A.
  when(state === State.ldReq) {
    reqA.valid := true.B
    reqA.bits.rw := false.B
    reqA.bits.byteen := Fill(maskWidth, 1.U(1.W))
    reqA.bits.address := addrAReg
    reqA.bits.source := sourceCounter
    when(reqA.fire) {
      bumpSource()
      state := State.waitWb
    }
  }

  // tcgen05Ld, phase 2: move the A response into the writeback buffer.
  when(state === State.waitWb && opReg === Ops.tcgen05Ld) {
    io.respA.ready := !wbValid
    when(io.respA.fire) {
      wbData := io.respA.bits.data.asTypeOf(Vec(numLanes, UInt(laneWidth.W)))
      wbValid := true.B
      state := State.idle
    }
  }

  // tcgen05St: write the C-port data to A at addressA.
  when(state === State.stReq) {
    io.reqC.valid := true.B
    reqA.valid := true.B
    reqA.bits.rw := true.B
    reqA.bits.byteen := Fill(maskWidth, 1.U(1.W))
    reqA.bits.address := addrAReg
    reqA.bits.source := sourceCounter
    reqA.bits.data := io.respC
    when(reqA.fire) {
      bumpSource()
      state := State.idle
    }
  }
}
|
||||||
@@ -50,6 +50,7 @@ class WithRadianceCores(
|
|||||||
crossing: RocketCrossingParams,
|
crossing: RocketCrossingParams,
|
||||||
tensorCoreFP16: Boolean,
|
tensorCoreFP16: Boolean,
|
||||||
tensorCoreDecoupled: Boolean,
|
tensorCoreDecoupled: Boolean,
|
||||||
|
tensorCoreBlackwell: Boolean,
|
||||||
useVxCache: Boolean
|
useVxCache: Boolean
|
||||||
) extends Config((site, _, up) => {
|
) extends Config((site, _, up) => {
|
||||||
case TilesLocated(`location`) => {
|
case TilesLocated(`location`) => {
|
||||||
@@ -59,7 +60,8 @@ class WithRadianceCores(
|
|||||||
val vortex = RadianceTileParams(
|
val vortex = RadianceTileParams(
|
||||||
core = VortexCoreParams(
|
core = VortexCoreParams(
|
||||||
tensorCoreFP16 = tensorCoreFP16,
|
tensorCoreFP16 = tensorCoreFP16,
|
||||||
tensorCoreDecoupled = tensorCoreDecoupled
|
tensorCoreDecoupled = tensorCoreDecoupled,
|
||||||
|
tensorCoreBlackwell = tensorCoreBlackwell
|
||||||
),
|
),
|
||||||
btb = None,
|
btb = None,
|
||||||
useVxCache = useVxCache,
|
useVxCache = useVxCache,
|
||||||
@@ -96,6 +98,7 @@ class WithRadianceCores(
|
|||||||
// constructor override that omits `crossing`
|
// constructor override that omits `crossing`
|
||||||
def this(n: Int, location: HierarchicalLocation = InSubsystem,
|
def this(n: Int, location: HierarchicalLocation = InSubsystem,
|
||||||
tensorCoreFP16: Boolean = false, tensorCoreDecoupled: Boolean = false,
|
tensorCoreFP16: Boolean = false, tensorCoreDecoupled: Boolean = false,
|
||||||
|
tensorCoreBlackwell: Boolean = false,
|
||||||
useVxCache: Boolean = false)
|
useVxCache: Boolean = false)
|
||||||
= this(n, location, RocketCrossingParams(
|
= this(n, location, RocketCrossingParams(
|
||||||
master = HierarchicalElementMasterPortParams.locationDefault(location),
|
master = HierarchicalElementMasterPortParams.locationDefault(location),
|
||||||
@@ -104,9 +107,23 @@ class WithRadianceCores(
|
|||||||
case InSubsystem => CBUS
|
case InSubsystem => CBUS
|
||||||
case InCluster(clusterId) => CCBUS(clusterId)
|
case InCluster(clusterId) => CCBUS(clusterId)
|
||||||
}
|
}
|
||||||
), tensorCoreFP16, tensorCoreDecoupled, useVxCache)
|
), tensorCoreFP16, tensorCoreDecoupled, tensorCoreBlackwell, useVxCache)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/** Config fragment that rewrites the tiles already attached at `location`:
  * every Radiance tile gets `tensorCoreBlackwell` enabled and
  * `tensorCoreDecoupled` disabled; all other tile kinds pass through as-is.
  */
class WithBlackwellTensorCore(location: HierarchicalLocation = InSubsystem) extends Config((site, _, up) => {
  case TilesLocated(`location`) =>
    val inherited = up(TilesLocated(`location`))
    inherited.map {
      case radiance: RadianceTileAttachParams =>
        // Swap only the two tensor-core switches; keep every other core field.
        val blackwellCore = radiance.tileParams.core.copy(
          tensorCoreBlackwell = true,
          tensorCoreDecoupled = false
        )
        radiance.copy(tileParams = radiance.tileParams.copy(core = blackwellCore))
      case untouched => untouched
    }
})
|
||||||
|
|
||||||
class WithEmulatorCores(
|
class WithEmulatorCores(
|
||||||
n: Int,
|
n: Int,
|
||||||
useVxCache: Boolean
|
useVxCache: Boolean
|
||||||
|
|||||||
@@ -101,6 +101,7 @@ case class VortexCoreParams(
|
|||||||
fpu: Option[FPUParams] = None,
|
fpu: Option[FPUParams] = None,
|
||||||
tensorCoreFP16: Boolean = false, // FP16 if true, FP32 if false
|
tensorCoreFP16: Boolean = false, // FP16 if true, FP32 if false
|
||||||
tensorCoreDecoupled: Boolean = false, // hopper-style SMEM operand decoupling
|
tensorCoreDecoupled: Boolean = false, // hopper-style SMEM operand decoupling
|
||||||
|
tensorCoreBlackwell: Boolean = false, // blackwell-style TMEM + SMEM tensor core
|
||||||
debugROB: Boolean = false, // if enabled, uses a C++ debug ROB to generate trace-with-wdata
|
debugROB: Boolean = false, // if enabled, uses a C++ debug ROB to generate trace-with-wdata
|
||||||
haveCease: Boolean = true, // non-standard CEASE instruction
|
haveCease: Boolean = true, // non-standard CEASE instruction
|
||||||
haveSimTimeout: Boolean = true // add plusarg for simulation timeout
|
haveSimTimeout: Boolean = true // add plusarg for simulation timeout
|
||||||
@@ -152,6 +153,10 @@ class RadianceTile private (
|
|||||||
p(SIMTCoreKey).isDefined,
|
p(SIMTCoreKey).isDefined,
|
||||||
"SIMTCoreKey not defined; make sure to use WithSimtConfig when using RadianceTile"
|
"SIMTCoreKey not defined; make sure to use WithSimtConfig when using RadianceTile"
|
||||||
)
|
)
|
||||||
|
require(
|
||||||
|
!(radianceParams.core.tensorCoreDecoupled && radianceParams.core.tensorCoreBlackwell),
|
||||||
|
"tensorCoreDecoupled and tensorCoreBlackwell are mutually exclusive"
|
||||||
|
)
|
||||||
|
|
||||||
// NOTE: when changing these, remember to change +define+NUM_CORES/THREADS/WARPS in
|
// NOTE: when changing these, remember to change +define+NUM_CORES/THREADS/WARPS in
|
||||||
// radiance.mk as well!
|
// radiance.mk as well!
|
||||||
@@ -280,7 +285,9 @@ class RadianceTile private (
|
|||||||
}
|
}
|
||||||
|
|
||||||
val tcSmemSize = 32
|
val tcSmemSize = 32
|
||||||
val tcSmemNodes = Seq.tabulate(if (radianceParams.core.tensorCoreDecoupled) 2 else 0) { i =>
|
val tensorUsesAsyncMem = radianceParams.core.tensorCoreDecoupled || radianceParams.core.tensorCoreBlackwell
|
||||||
|
val tcSmemNodeCount = if (radianceParams.core.tensorCoreDecoupled) 2 else if (radianceParams.core.tensorCoreBlackwell) 1 else 0
|
||||||
|
val tcSmemNodes = Seq.tabulate(tcSmemNodeCount) { i =>
|
||||||
TLClientNode(Seq(TLMasterPortParameters.v2(
|
TLClientNode(Seq(TLMasterPortParameters.v2(
|
||||||
masters = Seq(TLMasterParameters.v2(
|
masters = Seq(TLMasterParameters.v2(
|
||||||
name = s"rad_tc_${radianceParams.coreId}_$i",
|
name = s"rad_tc_${radianceParams.coreId}_$i",
|
||||||
@@ -294,6 +301,42 @@ class RadianceTile private (
|
|||||||
)))
|
)))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TileLink client nodes for the tensor-memory (TMEM) side of the Blackwell
// tensor core. Two nodes are created when tensorCoreBlackwell is set, none
// otherwise. Each advertises source IDs [0, 2^smemSourceWidth) and transfer
// sizes from 1 to tcSmemSize bytes.
// NOTE(review): connectTensorBlackwell only drives tmemNodes.head in this
// change - confirm the second node is connected elsewhere, or drop it.
val tmemNodes = Seq.tabulate(if (radianceParams.core.tensorCoreBlackwell) 2 else 0) { i =>
|
||||||
|
TLClientNode(Seq(TLMasterPortParameters.v2(
|
||||||
|
masters = Seq(TLMasterParameters.v2(
|
||||||
|
name = s"rad_tmem_${radianceParams.coreId}_$i",
|
||||||
|
sourceId = IdRange(0, 1 << smemSourceWidth),
|
||||||
|
supports = TLSlaveToMasterTransferSizes(
|
||||||
|
probe = TransferSizes(1, tcSmemSize),
|
||||||
|
get = TransferSizes(1, tcSmemSize),
|
||||||
|
putFull = TransferSizes(1, tcSmemSize),
|
||||||
|
putPartial = TransferSizes(1, tcSmemSize),
|
||||||
|
),
|
||||||
|
requestFifo = true
|
||||||
|
))
|
||||||
|
)))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Backing storage for TMEM: a TLRAM covering addresses 0x0-0x3fff with a
// beat width of tcSmemSize bytes. Only instantiated for Blackwell configs.
val tmemNode = if (radianceParams.core.tensorCoreBlackwell) {
|
||||||
|
Some(LazyModule(new TLRAM(
|
||||||
|
address = AddressSet(0x0, 0x3fff),
|
||||||
|
beatBytes = tcSmemSize
|
||||||
|
)))
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
}
|
||||||
|
// Crossbar placed in front of the TMEM RAM; only instantiated for Blackwell
// configs.
val tmemXbar = if (radianceParams.core.tensorCoreBlackwell) {
|
||||||
|
Some(LazyModule(new TLXbar))
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
}
|
||||||
|
// Wire up the TMEM graph when both pieces exist: every tmemNodes client feeds
// the crossbar, and the crossbar feeds the RAM. Nothing is connected for
// non-Blackwell configs.
(tmemNode, tmemXbar) match {
|
||||||
|
case (Some(tmem), Some(xbar)) =>
|
||||||
|
tmem.node :=* xbar.node
|
||||||
|
tmemNodes.foreach { node => xbar.node :=* node }
|
||||||
|
case _ =>
|
||||||
|
}
|
||||||
|
|
||||||
// combine outgoing per-lane dmemNode into 1 identity node
|
// combine outgoing per-lane dmemNode into 1 identity node
|
||||||
//
|
//
|
||||||
// NOTE: We need TLWidthWidget here because there might be a data width
|
// NOTE: We need TLWidthWidget here because there might be a data width
|
||||||
@@ -743,12 +786,18 @@ class RadianceTileModuleImp(outer: RadianceTile)
|
|||||||
val tcb0 = new {
|
val tcb0 = new {
|
||||||
val addr = core.io.tc_a_bits_address(31, 0)
|
val addr = core.io.tc_a_bits_address(31, 0)
|
||||||
val tag = core.io.tc_a_bits_tag(outer.tensorTagWidth - 1, 0)
|
val tag = core.io.tc_a_bits_tag(outer.tensorTagWidth - 1, 0)
|
||||||
|
val write = core.io.tc_a_bits_write(0)
|
||||||
|
val mask = core.io.tc_a_bits_mask(31, 0)
|
||||||
|
val data = core.io.tc_a_bits_data(255, 0)
|
||||||
val aValid = core.io.tc_a_valid(0)
|
val aValid = core.io.tc_a_valid(0)
|
||||||
val dReady = core.io.tc_d_ready(0)
|
val dReady = core.io.tc_d_ready(0)
|
||||||
}
|
}
|
||||||
val tcb1 = new {
|
val tcb1 = new {
|
||||||
val addr = core.io.tc_a_bits_address(63, 32)
|
val addr = core.io.tc_a_bits_address(63, 32)
|
||||||
val tag = core.io.tc_a_bits_tag(4 + outer.tensorTagWidth - 1, 4)
|
val tag = core.io.tc_a_bits_tag(4 + outer.tensorTagWidth - 1, 4)
|
||||||
|
val write = core.io.tc_a_bits_write(1)
|
||||||
|
val mask = core.io.tc_a_bits_mask(63, 32)
|
||||||
|
val data = core.io.tc_a_bits_data(511, 256)
|
||||||
val aValid = core.io.tc_a_valid(1)
|
val aValid = core.io.tc_a_valid(1)
|
||||||
val dReady = core.io.tc_d_ready(1)
|
val dReady = core.io.tc_d_ready(1)
|
||||||
}
|
}
|
||||||
@@ -770,8 +819,9 @@ class RadianceTileModuleImp(outer: RadianceTile)
|
|||||||
adapter.io.inReq.bits.address := bundle.addr
|
adapter.io.inReq.bits.address := bundle.addr
|
||||||
adapter.io.inReq.bits.source := bundle.tag
|
adapter.io.inReq.bits.source := bundle.tag
|
||||||
adapter.io.inReq.bits.size := 5.U // 256 bits
|
adapter.io.inReq.bits.size := 5.U // 256 bits
|
||||||
adapter.io.inReq.bits.opcode := TLMessages.Get
|
adapter.io.inReq.bits.opcode := Mux(bundle.write.asBool, TLMessages.PutFullData, TLMessages.Get)
|
||||||
adapter.io.inReq.bits.mask := x"ffffffff".U
|
adapter.io.inReq.bits.mask := bundle.mask
|
||||||
|
adapter.io.inReq.bits.data := bundle.data
|
||||||
adapter.io.inResp.ready := bundle.dReady
|
adapter.io.inResp.ready := bundle.dReady
|
||||||
|
|
||||||
client._1.a <> adapter.io.outReq
|
client._1.a <> adapter.io.outReq
|
||||||
@@ -792,6 +842,71 @@ class RadianceTileModuleImp(outer: RadianceTile)
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Hook the core's two tensor-core memory ports up to TileLink for the
// Blackwell configuration: port 0 goes to the first TMEM node, port 1 to the
// first tensor SMEM node. When the Blackwell tensor core is disabled the
// core-facing inputs are tied off instead.
def connectTensorBlackwell = {
  val blackwellEnabled = outer.radianceParams.core.tensorCoreBlackwell
  if (!blackwellEnabled) {
    core.io.tc_a_ready := false.B
    core.io.tc_d_valid := false.B
    core.io.tc_d_bits_data := DontCare
    core.io.tc_d_bits_tag := DontCare
  } else {
    require(outer.tmemNodes.nonEmpty)
    require(outer.tcSmemNodes.nonEmpty)

    // Index in this list == port index on the flattened core.io.tc_* buses.
    val portNodes = Seq(outer.tmemNodes.head, outer.tcSmemNodes.head)

    val portAdapters = portNodes.zipWithIndex.map { case (node, port) =>
      // Slice this port's fields out of the flattened buses: 32 address bits,
      // a tag at a 4-bit stride, 32 mask bits and 256 data bits per port.
      val reqAddr = core.io.tc_a_bits_address(32 * port + 31, 32 * port)
      val reqTag = core.io.tc_a_bits_tag(4 * port + outer.tensorTagWidth - 1, 4 * port)
      val reqWrite = core.io.tc_a_bits_write(port)
      val reqMask = core.io.tc_a_bits_mask(32 * port + 31, 32 * port)
      val reqData = core.io.tc_a_bits_data(256 * port + 255, 256 * port)
      val reqValid = core.io.tc_a_valid(port)
      val respReady = core.io.tc_d_ready(port)

      val tlPort = node.out.head
      val tlAdapter = Module(
        new VortexTLAdapter(
          outer.smemSourceWidth,
          new VortexBundleA(tagWidth = outer.tensorTagWidth, dataWidth = 32 * 8),
          new VortexBundleD(tagWidth = outer.tensorTagWidth, dataWidth = 32 * 8),
          tlPort
        )
      )
      require(tlAdapter.io.inReq.bits.source.widthOption.get == reqTag.widthOption.get)
      require(tlAdapter.io.inReq.bits.address.widthOption.get == reqAddr.widthOption.get)

      val request = tlAdapter.io.inReq
      request.bits <> DontCare
      request.valid := reqValid
      request.bits.address := reqAddr
      request.bits.source := reqTag
      request.bits.size := 5.U
      request.bits.opcode := Mux(reqWrite.asBool, TLMessages.PutFullData, TLMessages.Get)
      request.bits.mask := reqMask
      request.bits.data := reqData
      tlAdapter.io.inResp.ready := respReady

      tlPort._1.a <> tlAdapter.io.outReq
      tlAdapter.io.outResp <> tlPort._1.d
      tlAdapter
    }

    // Re-flatten the per-port signals; port 1 occupies the upper bits.
    val port0 = portAdapters.head
    val port1 = portAdapters.last
    core.io.tc_a_ready := Cat(port1.io.inReq.ready, port0.io.inReq.ready)
    core.io.tc_d_valid := Cat(port1.io.inResp.valid, port0.io.inResp.valid)
    core.io.tc_d_bits_data := Cat(port1.io.inResp.bits.data, port0.io.inResp.bits.data)
    core.io.tc_d_bits_tag := Cat(port1.io.inResp.bits.source, port0.io.inResp.bits.source)
  }
}
|
||||||
|
|
||||||
def connectBarrier = {
|
def connectBarrier = {
|
||||||
require(outer.barrierMasterNode.out.length == 1)
|
require(outer.barrierMasterNode.out.length == 1)
|
||||||
// FIXME: bits not flattened
|
// FIXME: bits not flattened
|
||||||
@@ -847,7 +962,11 @@ class RadianceTileModuleImp(outer: RadianceTile)
|
|||||||
connectImem
|
connectImem
|
||||||
connectDmem
|
connectDmem
|
||||||
connectSmem
|
connectSmem
|
||||||
connectTensor
|
if (outer.radianceParams.core.tensorCoreBlackwell) {
|
||||||
|
connectTensorBlackwell
|
||||||
|
} else {
|
||||||
|
connectTensor
|
||||||
|
}
|
||||||
connectBarrier
|
connectBarrier
|
||||||
connectAccelerator
|
connectAccelerator
|
||||||
}
|
}
|
||||||
@@ -874,6 +993,20 @@ class RadianceTileModuleImp(outer: RadianceTile)
|
|||||||
tensor.io.reqA.ready := false.B
|
tensor.io.reqA.ready := false.B
|
||||||
tensor.io.reqB.ready := false.B
|
tensor.io.reqB.ready := false.B
|
||||||
tensor.io.writeback.ready := false.B
|
tensor.io.writeback.ready := false.B
|
||||||
|
} else if (outer.radianceParams.core.tensorCoreBlackwell) {
|
||||||
|
val tensorNumSourceIds = (1 << outer.tensorTagWidth)
|
||||||
|
val tensor = Module(new radiance.core.TensorCoreBlackwell(
|
||||||
|
8, 8, half = true, tensorNumSourceIds))
|
||||||
|
tensor.io.initiate.valid := false.B
|
||||||
|
tensor.io.initiate.bits := DontCare
|
||||||
|
tensor.io.respA.valid := false.B
|
||||||
|
tensor.io.respA.bits := DontCare
|
||||||
|
tensor.io.respB.valid := false.B
|
||||||
|
tensor.io.respB.bits := DontCare
|
||||||
|
tensor.io.respC := DontCare
|
||||||
|
tensor.io.reqA.ready := false.B
|
||||||
|
tensor.io.reqB.ready := false.B
|
||||||
|
tensor.io.writeback.ready := false.B
|
||||||
} else {
|
} else {
|
||||||
if (outer.radianceParams.core.tensorCoreFP16) {
|
if (outer.radianceParams.core.tensorCoreFP16) {
|
||||||
val dpu = Module(new radiance.core.TensorDotProductUnit(4, half = true))
|
val dpu = Module(new radiance.core.TensorDotProductUnit(4, half = true))
|
||||||
|
|||||||
@@ -91,8 +91,11 @@ class VortexBundle(tile: RadianceTile)(implicit p: Parameters) extends CoreBundl
|
|||||||
val smem_d_ready = Output(UInt((tile.numLsuLanes * 1).W))
|
val smem_d_ready = Output(UInt((tile.numLsuLanes * 1).W))
|
||||||
|
|
||||||
val tc_a_valid = Output(UInt(2.W))
|
val tc_a_valid = Output(UInt(2.W))
|
||||||
|
val tc_a_bits_write = Output(UInt(2.W))
|
||||||
val tc_a_bits_address = Output(UInt((2 * 32).W))
|
val tc_a_bits_address = Output(UInt((2 * 32).W))
|
||||||
val tc_a_bits_tag = Output(UInt((2 * 4).W))
|
val tc_a_bits_tag = Output(UInt((2 * 4).W))
|
||||||
|
val tc_a_bits_mask = Output(UInt((2 * 32).W))
|
||||||
|
val tc_a_bits_data = Output(UInt((2 * 32 * 8).W))
|
||||||
val tc_a_ready = Input(UInt(2.W))
|
val tc_a_ready = Input(UInt(2.W))
|
||||||
val tc_d_valid = Input(UInt(2.W))
|
val tc_d_valid = Input(UInt(2.W))
|
||||||
val tc_d_bits_data = Input(UInt((2 * 32 * 8).W))
|
val tc_d_bits_data = Input(UInt((2 * 32 * 8).W))
|
||||||
@@ -411,6 +414,8 @@ class Vortex(tile: RadianceTile)(implicit p: Parameters)
|
|||||||
// hopper-style SMEM operand decoupling
|
// hopper-style SMEM operand decoupling
|
||||||
if (tile.radianceParams.core.tensorCoreDecoupled) {
|
if (tile.radianceParams.core.tensorCoreDecoupled) {
|
||||||
addResource("/vsrc/vortex/hw/rtl/core/VX_tensor_hopper_core.sv")
|
addResource("/vsrc/vortex/hw/rtl/core/VX_tensor_hopper_core.sv")
|
||||||
|
} else if (tile.radianceParams.core.tensorCoreBlackwell) {
|
||||||
|
addResource("/vsrc/vortex/hw/rtl/core/VX_tensor_blackwell_core.sv")
|
||||||
// addResource("/vsrc/vortex/hw/rtl/core/VX_tensor_ucode.vh")
|
// addResource("/vsrc/vortex/hw/rtl/core/VX_tensor_ucode.vh")
|
||||||
def addHopperTensorCore = {
|
def addHopperTensorCore = {
|
||||||
addPath("/scratch/hansung/chipyard/sims/vcs/generated-src/chipyard.unittest.TestHarness.TensorUnitTestConfig/gen-collateral/AddRawFN.sv")
|
addPath("/scratch/hansung/chipyard/sims/vcs/generated-src/chipyard.unittest.TestHarness.TensorUnitTestConfig/gen-collateral/AddRawFN.sv")
|
||||||
|
|||||||
125
src/test/scala/radiance/TensorCoreBlackwellTest.scala
Normal file
125
src/test/scala/radiance/TensorCoreBlackwellTest.scala
Normal file
@@ -0,0 +1,125 @@
|
|||||||
|
package radiance.core
|
||||||
|
|
||||||
|
import chisel3._
|
||||||
|
import chiseltest._
|
||||||
|
import org.scalatest.flatspec.AnyFlatSpec
|
||||||
|
|
||||||
|
/** Unit tests for [[TensorCoreBlackwell]].
  *
  * Each test drives one operation through the sequencer and checks both the
  * control handshake and the data that comes out, so a broken datapath is
  * caught as well as a broken state machine.
  */
class TensorCoreBlackwellTest extends AnyFlatSpec with ChiselScalatestTester {
  behavior of "TensorCoreBlackwell"

  // Drive every input the tests use to its inactive value.
  private def idleIO(c: TensorCoreBlackwell): Unit = {
    c.io.initiate.valid.poke(false.B)
    c.io.respA.valid.poke(false.B)
    c.io.respB.valid.poke(false.B)
    c.io.respA.bits.source.poke(0.U)
    c.io.respB.bits.source.poke(0.U)
    c.io.respA.bits.data.poke(0.U)
    c.io.respB.bits.data.poke(0.U)
    c.io.respC.poke(0.U)
    c.io.writeback.ready.poke(false.B)
  }

  it should "run a minimal BWGMMA path" in {
    test(new TensorCoreBlackwell(8, 8, numSourceIds = 4, half = true)) { c =>
      idleIO(c)

      c.io.initiate.valid.poke(true.B)
      c.io.initiate.bits.op.poke(0.U)
      c.io.initiate.bits.wid.poke(1.U)
      c.io.initiate.bits.rd.poke(3.U)
      c.io.initiate.bits.addressA.poke(0x40.U)
      c.io.initiate.bits.addressB.poke(0x80.U)
      c.io.reqA.ready.poke(true.B)
      c.io.reqB.ready.poke(true.B)
      // C operand: lane i holds i + 1 (lane 0 is the least-significant word).
      c.io.respC.poke("h0000000800000007000000060000000500000004000000030000000200000001".U)
      c.clock.step()

      c.io.initiate.valid.poke(false.B)
      c.io.reqA.valid.expect(true.B)
      c.io.reqB.valid.expect(true.B)
      c.io.reqA.bits.address.expect(0x40.U)
      c.io.reqB.bits.address.expect(0x80.U)
      c.clock.step()

      // A operand: lane i holds i + 1; B operand: lane i holds i + 9.
      c.io.respA.valid.poke(true.B)
      c.io.respB.valid.poke(true.B)
      c.io.respA.bits.data.poke("h0000000800000007000000060000000500000004000000030000000200000001".U)
      c.io.respB.bits.data.poke("h000000100000000f0000000e0000000d0000000c0000000b0000000a00000009".U)
      c.clock.step()

      c.io.respA.valid.poke(false.B)
      c.io.respB.valid.poke(false.B)
      c.clock.step()
      c.clock.step()
      c.io.writeback.valid.expect(true.B)
      c.io.writeback.bits.rd.expect(3.U)
      c.io.writeback.bits.wid.expect(1.U)
      // Lane-wise A + B + C = (i + 1) + (i + 9) + (i + 1) = 3 * i + 11.
      for (lane <- 0 until 8) {
        c.io.writeback.bits.data(lane).expect((3 * lane + 11).U)
      }
      c.io.writeback.ready.poke(true.B)
      c.clock.step()
    }
  }

  it should "copy from SMEM to TMEM on TCGEN05_CP" in {
    test(new TensorCoreBlackwell(8, 8, numSourceIds = 4, half = true)) { c =>
      idleIO(c)

      c.io.initiate.valid.poke(true.B)
      c.io.initiate.bits.op.poke(2.U)
      c.io.initiate.bits.wid.poke(0.U)
      c.io.initiate.bits.rd.poke(0.U)
      c.io.initiate.bits.addressA.poke(0x100.U)
      c.io.initiate.bits.addressB.poke(0x200.U)
      c.io.reqB.ready.poke(true.B)
      c.clock.step()

      c.io.initiate.valid.poke(false.B)
      c.io.reqB.valid.expect(true.B)
      c.io.reqB.bits.address.expect(0x200.U)
      c.io.respB.valid.poke(true.B)
      c.io.respB.bits.data.poke("hdeadbeef".U)
      c.io.reqA.ready.poke(true.B)
      c.clock.step()
      c.io.reqA.valid.expect(true.B)
      c.io.reqA.bits.rw.expect(true.B)
      c.io.reqA.bits.address.expect(0x100.U)
      // The B response must be forwarded unchanged as the write payload.
      c.io.reqA.bits.data.expect("hdeadbeef".U)
    }
  }

  it should "load and store fragments through TMEM" in {
    test(new TensorCoreBlackwell(8, 8, numSourceIds = 4, half = true)) { c =>
      idleIO(c)

      // Load: read one line through port A and write it back to the core.
      c.io.initiate.valid.poke(true.B)
      c.io.initiate.bits.op.poke(4.U)
      c.io.initiate.bits.wid.poke(2.U)
      c.io.initiate.bits.rd.poke(5.U)
      c.io.initiate.bits.addressA.poke(0x300.U)
      c.io.initiate.bits.addressB.poke(0.U)
      c.io.reqA.ready.poke(true.B)
      c.clock.step()

      c.io.initiate.valid.poke(false.B)
      c.clock.step()
      c.io.respA.valid.poke(true.B)
      c.io.respA.bits.data.poke("h1234".U)
      c.clock.step()
      c.io.respA.valid.poke(false.B)
      c.clock.step()
      c.io.writeback.valid.expect(true.B)
      c.io.writeback.bits.rd.expect(5.U)
      // Only lane 0 of the loaded line was non-zero.
      c.io.writeback.bits.data(0).expect("h1234".U)
      for (lane <- 1 until 8) {
        c.io.writeback.bits.data(lane).expect(0.U)
      }
      c.io.writeback.ready.poke(true.B)
      c.clock.step()

      // Store: write the C-port data to port A.
      idleIO(c)
      c.io.initiate.valid.poke(true.B)
      c.io.initiate.bits.op.poke(5.U)
      c.io.initiate.bits.wid.poke(2.U)
      c.io.initiate.bits.rd.poke(6.U)
      c.io.initiate.bits.addressA.poke(0x340.U)
      c.io.initiate.bits.addressB.poke(0.U)
      c.io.reqA.ready.poke(true.B)
      c.io.respC.poke("habcd".U)
      c.clock.step()
      c.io.reqA.valid.expect(true.B)
      c.io.reqA.bits.rw.expect(true.B)
      c.io.reqA.bits.address.expect(0x340.U)
      c.io.reqA.bits.data.expect("habcd".U)
      c.io.reqC.valid.expect(true.B)
      c.io.reqC.bits.expect(6.U)
    }
  }
}
|
||||||
Reference in New Issue
Block a user