Compare commits
3 Commits
wu-archite
...
main
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
2bbee0542b | ||
|
|
ae552222d6 | ||
|
|
a88da88a63 |
3
.gitmodules
vendored
3
.gitmodules
vendored
@@ -1,3 +1,6 @@
|
|||||||
[submodule "src/main/resources/vsrc/vortex"]
|
[submodule "src/main/resources/vsrc/vortex"]
|
||||||
path = src/main/resources/vsrc/vortex
|
path = src/main/resources/vsrc/vortex
|
||||||
url = https://github.com/hansungk/vortex.git
|
url = https://github.com/hansungk/vortex.git
|
||||||
|
[submodule "cyclotron"]
|
||||||
|
path = cyclotron
|
||||||
|
url = https://github.com/hansungk/cyclotron.git
|
||||||
|
|||||||
1
cyclotron
Submodule
1
cyclotron
Submodule
Submodule cyclotron added at ca6933c4ec
@@ -23,9 +23,6 @@ endif
|
|||||||
ifeq ($(shell echo $(CONFIG) | grep -E "HopperConfig$$"),$(CONFIG))
|
ifeq ($(shell echo $(CONFIG) | grep -E "HopperConfig$$"),$(CONFIG))
|
||||||
EXTRA_SIM_PREPROC_DEFINES += +define+NUM_CORES=4 +define+EXT_T_HOPPER
|
EXTRA_SIM_PREPROC_DEFINES += +define+NUM_CORES=4 +define+EXT_T_HOPPER
|
||||||
endif
|
endif
|
||||||
ifeq ($(shell echo $(CONFIG) | grep -E "BlackwellConfig$$"),$(CONFIG))
|
|
||||||
EXTRA_SIM_PREPROC_DEFINES += +define+NUM_CORES=1 +define+NUM_WARPS=4 +define+NUM_THREADS=4 +define+NUM_TENSOR_WARPS=2 +define+EXT_T_BLACKWELL
|
|
||||||
endif
|
|
||||||
ifeq ($(shell echo $(CONFIG) | grep -E "FlashConfig$$"),$(CONFIG))
|
ifeq ($(shell echo $(CONFIG) | grep -E "FlashConfig$$"),$(CONFIG))
|
||||||
EXTRA_SIM_PREPROC_DEFINES += +define+NUM_CORES=4
|
EXTRA_SIM_PREPROC_DEFINES += +define+NUM_CORES=4
|
||||||
endif
|
endif
|
||||||
|
|||||||
Submodule src/main/resources/vsrc/vortex updated: 0ad87bde81...c8529c4339
@@ -1,451 +0,0 @@
|
|||||||
// See LICENSE.SiFive for license details.
|
|
||||||
// See LICENSE.Berkeley for license details.
|
|
||||||
|
|
||||||
package radiance.core
|
|
||||||
|
|
||||||
import chisel3._
|
|
||||||
import chisel3.util._
|
|
||||||
|
|
||||||
class TensorCoreBlackwell(
|
|
||||||
val numWarps: Int,
|
|
||||||
val numLanes: Int,
|
|
||||||
val half: Boolean,
|
|
||||||
val numSourceIds: Int = 16,
|
|
||||||
val numFPRegs: Int = 32
|
|
||||||
) extends Module {
|
|
||||||
require(half, "Blackwell MMA currently supports FP16 inputs only")
|
|
||||||
require(numLanes == 8, "Blackwell MMA currently assumes 8 lanes")
|
|
||||||
|
|
||||||
val numWarpBits = log2Ceil(numWarps)
|
|
||||||
val sourceWidth = log2Ceil(numSourceIds)
|
|
||||||
val laneWidth = 4 * 8
|
|
||||||
val memWidth = numLanes * laneWidth
|
|
||||||
val numFPRegBits = log2Ceil(numFPRegs)
|
|
||||||
val addressWidth = 32
|
|
||||||
val maskWidth = memWidth / 8
|
|
||||||
val fragOffsetBits = log2Ceil(memWidth / 8)
|
|
||||||
|
|
||||||
val numSets = 4
|
|
||||||
val numAFragsPerSet = 8
|
|
||||||
val numBGroups = 4
|
|
||||||
val numBFragsPerGroup = 2
|
|
||||||
val numMGroups = 4
|
|
||||||
val numCFrags = 32
|
|
||||||
|
|
||||||
object Ops {
|
|
||||||
val bwgmma :: bwgmmaWait :: tcgen05Cp :: tcgen05CpWait :: tcgen05Ld :: tcgen05St :: tcgen05Cb :: Nil = Enum(7)
|
|
||||||
}
|
|
||||||
|
|
||||||
class TensorMemReq(
|
|
||||||
sourceWidth: Int,
|
|
||||||
dataWidth: Int
|
|
||||||
) extends Bundle {
|
|
||||||
val rw = Bool()
|
|
||||||
val byteen = UInt((dataWidth / 8).W)
|
|
||||||
val source = UInt(sourceWidth.W)
|
|
||||||
val address = UInt(addressWidth.W)
|
|
||||||
val data = UInt(dataWidth.W)
|
|
||||||
}
|
|
||||||
|
|
||||||
class TensorMemResp(
|
|
||||||
sourceWidth: Int,
|
|
||||||
dataWidth: Int
|
|
||||||
) extends Bundle {
|
|
||||||
val source = UInt(sourceWidth.W)
|
|
||||||
val data = UInt(dataWidth.W)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Direct SRAM port for TMEM (no TileLink overhead)
|
|
||||||
class TmemSramPort extends Bundle {
|
|
||||||
val aRen = Output(Bool())
|
|
||||||
val aRready = Input(Bool())
|
|
||||||
val aRaddr = Output(UInt(log2Ceil(numWarps * numCFrags * 2).W))
|
|
||||||
val aRdata = Input(UInt(memWidth.W))
|
|
||||||
|
|
||||||
val cRen = Output(Bool())
|
|
||||||
val cRready = Input(Bool())
|
|
||||||
val cRaddr = Output(UInt(log2Ceil(numWarps * numCFrags * 2).W))
|
|
||||||
val cRdata = Input(UInt(memWidth.W))
|
|
||||||
|
|
||||||
val cWen = Output(Bool())
|
|
||||||
val cWready = Input(Bool())
|
|
||||||
val cWaddr = Output(UInt(log2Ceil(numWarps * numCFrags * 2).W))
|
|
||||||
val cWdata = Output(UInt(memWidth.W))
|
|
||||||
val cMask = Output(UInt(maskWidth.W))
|
|
||||||
}
|
|
||||||
|
|
||||||
val io = IO(new Bundle {
|
|
||||||
val initiate = Flipped(Decoupled(new Bundle {
|
|
||||||
val op = UInt(3.W)
|
|
||||||
val wid = UInt(numWarpBits.W)
|
|
||||||
val rd = UInt(numFPRegBits.W)
|
|
||||||
val addressA = UInt(addressWidth.W)
|
|
||||||
val addressB = UInt(addressWidth.W)
|
|
||||||
val addressC = UInt(addressWidth.W)
|
|
||||||
}))
|
|
||||||
val writeback = Decoupled(new Bundle {
|
|
||||||
val last = Bool()
|
|
||||||
val wid = UInt(numWarpBits.W)
|
|
||||||
val rd = UInt(numFPRegBits.W)
|
|
||||||
val data = Vec(numLanes, UInt(laneWidth.W))
|
|
||||||
})
|
|
||||||
val respA = Flipped(Decoupled(new TensorMemResp(sourceWidth, memWidth)))
|
|
||||||
val respB = Flipped(Decoupled(new TensorMemResp(sourceWidth, memWidth)))
|
|
||||||
val respC = Input(UInt(memWidth.W))
|
|
||||||
val reqA = Decoupled(new TensorMemReq(sourceWidth, memWidth))
|
|
||||||
val reqB = Decoupled(new TensorMemReq(sourceWidth, memWidth))
|
|
||||||
val reqC = Output(Valid(UInt(numFPRegBits.W)))
|
|
||||||
val tmemC = new TmemSramPort // direct SRAM for C matrix (replaces reqCmem/respCmem)
|
|
||||||
})
|
|
||||||
|
|
||||||
object State extends ChiselEnum {
|
|
||||||
val idle, bwLoadAReq, bwLoadAResp, bwLoadBReq, bwLoadBResp,
|
|
||||||
bwReadCReq, bwReadCResp, bwCompute, bwDpuResp, bwWriteCReq,
|
|
||||||
bwWriteCWait, bwDone, cpRead, cpWrite, ldReq, stReq, stWrite, waitWb,
|
|
||||||
cbRead, cbCapture, cbWrite = Value
|
|
||||||
}
|
|
||||||
val state = RegInit(State.idle)
|
|
||||||
|
|
||||||
val opReg = RegInit(0.U(3.W))
|
|
||||||
val widReg = RegInit(0.U(numWarpBits.W))
|
|
||||||
val rdReg = RegInit(0.U(numFPRegBits.W))
|
|
||||||
val addrAReg = RegInit(0.U(addressWidth.W))
|
|
||||||
val addrBReg = RegInit(0.U(addressWidth.W))
|
|
||||||
val addrCReg = RegInit(0.U(addressWidth.W))
|
|
||||||
val sourceCounter = RegInit(0.U(sourceWidth.W))
|
|
||||||
|
|
||||||
val setReg = RegInit(0.U(log2Ceil(numSets).W))
|
|
||||||
val aIndexReg = RegInit(0.U(log2Ceil(numAFragsPerSet).W))
|
|
||||||
val bGroupReg = RegInit(0.U(log2Ceil(numBGroups).W))
|
|
||||||
val bIndexReg = RegInit(0.U(log2Ceil(numBFragsPerGroup).W))
|
|
||||||
val mGroupReg = RegInit(0.U(log2Ceil(numMGroups).W))
|
|
||||||
val substepReg = RegInit(0.U(1.W))
|
|
||||||
val elemReg = RegInit(0.U(log2Ceil(numLanes).W))
|
|
||||||
val waitCounter = RegInit(0.U(3.W))
|
|
||||||
|
|
||||||
val aBuf = Reg(Vec(numAFragsPerSet, UInt(memWidth.W)))
|
|
||||||
val bBuf = Reg(Vec(numBFragsPerGroup, UInt(memWidth.W)))
|
|
||||||
val cDataReg = Reg(UInt(memWidth.W))
|
|
||||||
val mmaDataReg = Reg(Vec(numLanes, UInt(laneWidth.W)))
|
|
||||||
|
|
||||||
private def bumpSource(): Unit = {
|
|
||||||
sourceCounter := sourceCounter + 1.U
|
|
||||||
}
|
|
||||||
|
|
||||||
private def byteAddress(base: UInt, fragIndex: UInt): UInt = {
|
|
||||||
base + (fragIndex << fragOffsetBits).asUInt
|
|
||||||
}
|
|
||||||
|
|
||||||
val aFragIndex = (setReg << 3) + aIndexReg
|
|
||||||
val bFragIndex = (setReg << 3) + (bGroupReg << 1) + bIndexReg
|
|
||||||
val stepIndex = Cat(bGroupReg, mGroupReg)
|
|
||||||
val cFragIndex = (stepIndex << 1) + substepReg
|
|
||||||
val aReqAddress = byteAddress(addrAReg, aFragIndex)
|
|
||||||
val bReqAddress = byteAddress(addrBReg, bFragIndex)
|
|
||||||
val cReqAddress = byteAddress(addrCReg, cFragIndex)
|
|
||||||
val tmemABase = (addrAReg >> fragOffsetBits.U).asUInt
|
|
||||||
val tmemCBase = (addrCReg >> fragOffsetBits.U).asUInt
|
|
||||||
|
|
||||||
val reqA = Wire(Decoupled(new TensorMemReq(sourceWidth, memWidth)))
|
|
||||||
val reqB = Wire(Decoupled(new TensorMemReq(sourceWidth, memWidth)))
|
|
||||||
reqA.valid := false.B
|
|
||||||
reqA.bits := 0.U.asTypeOf(reqA.bits)
|
|
||||||
reqB.valid := false.B
|
|
||||||
reqB.bits := 0.U.asTypeOf(reqB.bits)
|
|
||||||
io.reqA <> reqA
|
|
||||||
io.reqB <> reqB
|
|
||||||
|
|
||||||
io.tmemC.aRen := false.B
|
|
||||||
io.tmemC.aRaddr := 0.U
|
|
||||||
io.tmemC.cRen := false.B
|
|
||||||
io.tmemC.cRaddr := 0.U
|
|
||||||
io.tmemC.cWen := false.B
|
|
||||||
io.tmemC.cWaddr := 0.U
|
|
||||||
io.tmemC.cWdata := 0.U
|
|
||||||
io.tmemC.cMask := 0.U
|
|
||||||
|
|
||||||
val wbValid = RegInit(false.B)
|
|
||||||
val wbData = Reg(Vec(numLanes, UInt(laneWidth.W)))
|
|
||||||
io.writeback.valid := wbValid
|
|
||||||
io.writeback.bits.last := true.B
|
|
||||||
io.writeback.bits.wid := widReg
|
|
||||||
io.writeback.bits.rd := rdReg
|
|
||||||
io.writeback.bits.data := wbData
|
|
||||||
|
|
||||||
io.reqC.valid := false.B
|
|
||||||
io.reqC.bits := rdReg
|
|
||||||
|
|
||||||
// drain stale write-ack from TMEM so TLRAM doesn't stall on r_full
|
|
||||||
io.respA.ready := state === State.idle
|
|
||||||
io.respB.ready := false.B
|
|
||||||
io.initiate.ready := state === State.idle && !wbValid
|
|
||||||
|
|
||||||
val operandA = Cat(aBuf((mGroupReg << 1) + 1.U), aBuf(mGroupReg << 1))
|
|
||||||
val operandB = bBuf(substepReg)
|
|
||||||
val cWords = cDataReg.asTypeOf(Vec(numLanes, UInt(laneWidth.W)))
|
|
||||||
val dpuInValid = WireDefault(false.B)
|
|
||||||
val dpu = Module(new TensorDotProductUnit(
|
|
||||||
dim = 8,
|
|
||||||
half = true
|
|
||||||
))
|
|
||||||
|
|
||||||
private def halfWord(x: UInt, idx: Int): UInt = {
|
|
||||||
x((idx + 1) * 16 - 1, idx * 16)
|
|
||||||
}
|
|
||||||
|
|
||||||
val elemM = elemReg(1, 0)
|
|
||||||
val elemN = elemReg(2)
|
|
||||||
dpu.io.in.valid := dpuInValid
|
|
||||||
for (k <- 0 until 8) {
|
|
||||||
dpu.io.in.bits.a(k) := MuxLookup(elemM, halfWord(operandA, k))(Seq(
|
|
||||||
0.U -> halfWord(operandA, k),
|
|
||||||
1.U -> halfWord(operandA, 8 + k),
|
|
||||||
2.U -> halfWord(operandA, 16 + k),
|
|
||||||
3.U -> halfWord(operandA, 24 + k)
|
|
||||||
))
|
|
||||||
dpu.io.in.bits.b(k) := Mux(elemN.asBool, halfWord(operandB, 8 + k), halfWord(operandB, k))
|
|
||||||
}
|
|
||||||
dpu.io.in.bits.c := cWords(elemReg)
|
|
||||||
dpu.io.stall := false.B
|
|
||||||
val dpuValid = dpu.io.out.valid
|
|
||||||
|
|
||||||
when(io.writeback.fire) {
|
|
||||||
wbValid := false.B
|
|
||||||
}
|
|
||||||
|
|
||||||
when(io.initiate.fire) {
|
|
||||||
opReg := io.initiate.bits.op
|
|
||||||
widReg := io.initiate.bits.wid
|
|
||||||
rdReg := io.initiate.bits.rd
|
|
||||||
addrAReg := io.initiate.bits.addressA
|
|
||||||
addrBReg := io.initiate.bits.addressB
|
|
||||||
addrCReg := io.initiate.bits.addressC
|
|
||||||
setReg := 0.U
|
|
||||||
aIndexReg := 0.U
|
|
||||||
bGroupReg := 0.U
|
|
||||||
bIndexReg := 0.U
|
|
||||||
mGroupReg := 0.U
|
|
||||||
substepReg := 0.U
|
|
||||||
elemReg := 0.U
|
|
||||||
switch(io.initiate.bits.op) {
|
|
||||||
is(Ops.bwgmma) { state := State.bwLoadAReq }
|
|
||||||
is(Ops.tcgen05Cp) { state := State.cpRead }
|
|
||||||
is(Ops.tcgen05Ld) { state := State.ldReq }
|
|
||||||
is(Ops.tcgen05St) { state := State.stReq }
|
|
||||||
is(Ops.bwgmmaWait) { state := State.idle }
|
|
||||||
is(Ops.tcgen05CpWait) { state := State.idle }
|
|
||||||
is(Ops.tcgen05Cb) { state := State.cbRead }
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
when(state === State.bwLoadAReq) {
|
|
||||||
io.tmemC.aRen := true.B
|
|
||||||
io.tmemC.aRaddr := tmemABase + aFragIndex
|
|
||||||
when(io.tmemC.aRready) {
|
|
||||||
state := State.bwLoadAResp
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
when(state === State.bwLoadAResp) {
|
|
||||||
aBuf(aIndexReg) := io.tmemC.aRdata
|
|
||||||
when(aIndexReg === (numAFragsPerSet - 1).U) {
|
|
||||||
bGroupReg := 0.U
|
|
||||||
bIndexReg := 0.U
|
|
||||||
state := State.bwLoadBReq
|
|
||||||
}.otherwise {
|
|
||||||
aIndexReg := aIndexReg + 1.U
|
|
||||||
state := State.bwLoadAReq
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
when(state === State.bwLoadBReq) {
|
|
||||||
reqB.valid := true.B
|
|
||||||
reqB.bits.rw := false.B
|
|
||||||
reqB.bits.byteen := Fill(maskWidth, 1.U(1.W))
|
|
||||||
reqB.bits.address := bReqAddress
|
|
||||||
reqB.bits.source := sourceCounter
|
|
||||||
when(reqB.fire) {
|
|
||||||
bumpSource()
|
|
||||||
state := State.bwLoadBResp
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
when(state === State.bwLoadBResp) {
|
|
||||||
io.respB.ready := true.B
|
|
||||||
when(io.respB.fire) {
|
|
||||||
bBuf(bIndexReg) := io.respB.bits.data
|
|
||||||
when(bIndexReg === (numBFragsPerGroup - 1).U) {
|
|
||||||
mGroupReg := 0.U
|
|
||||||
substepReg := 0.U
|
|
||||||
state := State.bwReadCReq
|
|
||||||
}.otherwise {
|
|
||||||
bIndexReg := bIndexReg + 1.U
|
|
||||||
state := State.bwLoadBReq
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
when(state === State.bwReadCReq) {
|
|
||||||
io.tmemC.cRen := true.B
|
|
||||||
io.tmemC.cRaddr := tmemCBase + cFragIndex
|
|
||||||
when(io.tmemC.cRready) {
|
|
||||||
state := State.bwReadCResp
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
when(state === State.bwReadCResp) {
|
|
||||||
cDataReg := io.tmemC.cRdata
|
|
||||||
elemReg := 0.U
|
|
||||||
state := State.bwCompute
|
|
||||||
}
|
|
||||||
|
|
||||||
when(state === State.bwCompute) {
|
|
||||||
dpuInValid := true.B
|
|
||||||
state := State.bwDpuResp
|
|
||||||
}
|
|
||||||
|
|
||||||
when(state === State.bwDpuResp) {
|
|
||||||
when(dpuValid) {
|
|
||||||
mmaDataReg(elemReg) := dpu.io.out.bits.data
|
|
||||||
when(elemReg === (numLanes - 1).U) {
|
|
||||||
state := State.bwWriteCReq
|
|
||||||
}.otherwise {
|
|
||||||
elemReg := elemReg + 1.U
|
|
||||||
state := State.bwCompute
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
when(state === State.bwWriteCReq) {
|
|
||||||
io.tmemC.cWen := true.B
|
|
||||||
io.tmemC.cWaddr := tmemCBase + cFragIndex
|
|
||||||
io.tmemC.cWdata := mmaDataReg.asUInt
|
|
||||||
io.tmemC.cMask := Fill(maskWidth, 1.U(1.W))
|
|
||||||
when(io.tmemC.cWready) {
|
|
||||||
when(substepReg === 0.U) {
|
|
||||||
substepReg := 1.U
|
|
||||||
state := State.bwReadCReq
|
|
||||||
}.elsewhen(mGroupReg =/= (numMGroups - 1).U) {
|
|
||||||
substepReg := 0.U
|
|
||||||
mGroupReg := mGroupReg + 1.U
|
|
||||||
state := State.bwReadCReq
|
|
||||||
}.elsewhen(bGroupReg =/= (numBGroups - 1).U) {
|
|
||||||
substepReg := 0.U
|
|
||||||
mGroupReg := 0.U
|
|
||||||
bGroupReg := bGroupReg + 1.U
|
|
||||||
bIndexReg := 0.U
|
|
||||||
state := State.bwLoadBReq
|
|
||||||
}.elsewhen(setReg =/= (numSets - 1).U) {
|
|
||||||
substepReg := 0.U
|
|
||||||
mGroupReg := 0.U
|
|
||||||
bGroupReg := 0.U
|
|
||||||
bIndexReg := 0.U
|
|
||||||
setReg := setReg + 1.U
|
|
||||||
aIndexReg := 0.U
|
|
||||||
state := State.bwLoadAReq
|
|
||||||
}.otherwise {
|
|
||||||
waitCounter := 7.U
|
|
||||||
state := State.bwWriteCWait
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
when(state === State.bwWriteCWait) {
|
|
||||||
when(waitCounter === 0.U) {
|
|
||||||
state := State.bwDone
|
|
||||||
}.otherwise {
|
|
||||||
waitCounter := waitCounter - 1.U
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
when(state === State.bwDone) {
|
|
||||||
wbData := mmaDataReg
|
|
||||||
wbValid := true.B
|
|
||||||
state := State.idle
|
|
||||||
}
|
|
||||||
|
|
||||||
when(state === State.cpRead) {
|
|
||||||
reqA.valid := true.B
|
|
||||||
reqA.bits.rw := false.B
|
|
||||||
reqA.bits.byteen := Fill(maskWidth, 1.U(1.W))
|
|
||||||
reqA.bits.address := addrBReg
|
|
||||||
reqA.bits.source := sourceCounter
|
|
||||||
when(reqA.fire) {
|
|
||||||
bumpSource()
|
|
||||||
state := State.cpWrite
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
when(state === State.cpWrite) {
|
|
||||||
io.respA.ready := io.tmemC.cWready
|
|
||||||
io.tmemC.cWen := io.respA.valid
|
|
||||||
io.tmemC.cWaddr := (addrAReg >> fragOffsetBits.U).asUInt
|
|
||||||
io.tmemC.cWdata := io.respA.bits.data
|
|
||||||
io.tmemC.cMask := Fill(maskWidth, 1.U(1.W))
|
|
||||||
when(io.respA.fire) {
|
|
||||||
state := State.idle
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
when(state === State.ldReq) {
|
|
||||||
io.tmemC.cRen := true.B
|
|
||||||
io.tmemC.cRaddr := (addrAReg >> fragOffsetBits.U).asUInt
|
|
||||||
when(io.tmemC.cRready) {
|
|
||||||
state := State.waitWb
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
when(state === State.waitWb && opReg === Ops.tcgen05Ld) {
|
|
||||||
wbData := io.tmemC.cRdata.asTypeOf(Vec(numLanes, UInt(laneWidth.W)))
|
|
||||||
wbValid := true.B
|
|
||||||
state := State.idle
|
|
||||||
}
|
|
||||||
|
|
||||||
when(state === State.stReq) {
|
|
||||||
io.reqC.valid := true.B
|
|
||||||
state := State.stWrite
|
|
||||||
}
|
|
||||||
|
|
||||||
when(state === State.stWrite) {
|
|
||||||
io.tmemC.cWen := true.B
|
|
||||||
io.tmemC.cWaddr := (addrAReg >> fragOffsetBits.U).asUInt
|
|
||||||
io.tmemC.cWdata := io.respC
|
|
||||||
io.tmemC.cMask := Fill(maskWidth, 1.U(1.W))
|
|
||||||
when(io.tmemC.cWready) {
|
|
||||||
state := State.idle
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
when(state === State.cbRead) {
|
|
||||||
io.tmemC.cRen := true.B
|
|
||||||
io.tmemC.cRaddr := (addrAReg >> fragOffsetBits.U).asUInt
|
|
||||||
when(io.tmemC.cRready) {
|
|
||||||
state := State.cbCapture
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
when(state === State.cbCapture) {
|
|
||||||
cDataReg := io.tmemC.cRdata
|
|
||||||
state := State.cbWrite
|
|
||||||
}
|
|
||||||
|
|
||||||
when(state === State.cbWrite) {
|
|
||||||
reqA.valid := true.B
|
|
||||||
reqA.bits.rw := true.B
|
|
||||||
reqA.bits.byteen := Fill(maskWidth, 1.U(1.W))
|
|
||||||
reqA.bits.address := addrBReg
|
|
||||||
reqA.bits.source := sourceCounter
|
|
||||||
reqA.bits.data := cDataReg
|
|
||||||
when(reqA.fire) {
|
|
||||||
bumpSource()
|
|
||||||
state := State.waitWb
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
when(state === State.waitWb && opReg === Ops.tcgen05Cb) {
|
|
||||||
io.respA.ready := true.B
|
|
||||||
when(io.respA.fire) {
|
|
||||||
state := State.idle
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -201,10 +201,8 @@ class DotProductPipe(dim: Int, inputType: tile.FType, outputType: tile.FType) ex
|
|||||||
// pipeline and connect outputs to the next stage
|
// pipeline and connect outputs to the next stage
|
||||||
outputs := StallingPipe(io.stall, inputs.valid, VecInit(addOuts))
|
outputs := StallingPipe(io.stall, inputs.valid, VecInit(addOuts))
|
||||||
outC := StallingPipe(io.stall, inputs.valid, inC.bits)
|
outC := StallingPipe(io.stall, inputs.valid, inC.bits)
|
||||||
when (inputs.valid =/= inC.valid) {
|
assert(inputs.valid === inC.valid,
|
||||||
printf("WARN: DotProductPipe input/C valid mismatch: inputs=%d c=%d\n",
|
"adder inputs valid and C pipe valid went out-of-sync")
|
||||||
inputs.valid, inC.valid)
|
|
||||||
}
|
|
||||||
|
|
||||||
(outputs, outC)
|
(outputs, outC)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -50,9 +50,6 @@ class WithRadianceCores(
|
|||||||
crossing: RocketCrossingParams,
|
crossing: RocketCrossingParams,
|
||||||
tensorCoreFP16: Boolean,
|
tensorCoreFP16: Boolean,
|
||||||
tensorCoreDecoupled: Boolean,
|
tensorCoreDecoupled: Boolean,
|
||||||
tensorCoreBlackwell: Boolean,
|
|
||||||
numTensorWarps: Int,
|
|
||||||
startupAddress: BigInt,
|
|
||||||
useVxCache: Boolean
|
useVxCache: Boolean
|
||||||
) extends Config((site, _, up) => {
|
) extends Config((site, _, up) => {
|
||||||
case TilesLocated(`location`) => {
|
case TilesLocated(`location`) => {
|
||||||
@@ -62,10 +59,7 @@ class WithRadianceCores(
|
|||||||
val vortex = RadianceTileParams(
|
val vortex = RadianceTileParams(
|
||||||
core = VortexCoreParams(
|
core = VortexCoreParams(
|
||||||
tensorCoreFP16 = tensorCoreFP16,
|
tensorCoreFP16 = tensorCoreFP16,
|
||||||
tensorCoreDecoupled = tensorCoreDecoupled,
|
tensorCoreDecoupled = tensorCoreDecoupled
|
||||||
tensorCoreBlackwell = tensorCoreBlackwell,
|
|
||||||
numTensorWarps = numTensorWarps,
|
|
||||||
startupAddress = startupAddress
|
|
||||||
),
|
),
|
||||||
btb = None,
|
btb = None,
|
||||||
useVxCache = useVxCache,
|
useVxCache = useVxCache,
|
||||||
@@ -102,9 +96,6 @@ class WithRadianceCores(
|
|||||||
// constructor override that omits `crossing`
|
// constructor override that omits `crossing`
|
||||||
def this(n: Int, location: HierarchicalLocation = InSubsystem,
|
def this(n: Int, location: HierarchicalLocation = InSubsystem,
|
||||||
tensorCoreFP16: Boolean = false, tensorCoreDecoupled: Boolean = false,
|
tensorCoreFP16: Boolean = false, tensorCoreDecoupled: Boolean = false,
|
||||||
tensorCoreBlackwell: Boolean = false,
|
|
||||||
numTensorWarps: Int = 4,
|
|
||||||
startupAddress: BigInt = BigInt("10100", 16),
|
|
||||||
useVxCache: Boolean = false)
|
useVxCache: Boolean = false)
|
||||||
= this(n, location, RocketCrossingParams(
|
= this(n, location, RocketCrossingParams(
|
||||||
master = HierarchicalElementMasterPortParams.locationDefault(location),
|
master = HierarchicalElementMasterPortParams.locationDefault(location),
|
||||||
@@ -113,23 +104,9 @@ class WithRadianceCores(
|
|||||||
case InSubsystem => CBUS
|
case InSubsystem => CBUS
|
||||||
case InCluster(clusterId) => CCBUS(clusterId)
|
case InCluster(clusterId) => CCBUS(clusterId)
|
||||||
}
|
}
|
||||||
), tensorCoreFP16, tensorCoreDecoupled, tensorCoreBlackwell, numTensorWarps, startupAddress, useVxCache)
|
), tensorCoreFP16, tensorCoreDecoupled, useVxCache)
|
||||||
}
|
}
|
||||||
|
|
||||||
class WithBlackwellTensorCore(location: HierarchicalLocation = InSubsystem) extends Config((site, _, up) => {
|
|
||||||
case TilesLocated(`location`) =>
|
|
||||||
up(TilesLocated(`location`)).map {
|
|
||||||
case r: RadianceTileAttachParams =>
|
|
||||||
r.copy(tileParams = r.tileParams.copy(
|
|
||||||
core = r.tileParams.core.copy(
|
|
||||||
tensorCoreBlackwell = true,
|
|
||||||
tensorCoreDecoupled = false
|
|
||||||
)
|
|
||||||
))
|
|
||||||
case other => other
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
class WithEmulatorCores(
|
class WithEmulatorCores(
|
||||||
n: Int,
|
n: Int,
|
||||||
useVxCache: Boolean
|
useVxCache: Boolean
|
||||||
|
|||||||
@@ -216,6 +216,7 @@ class GemminiTileModuleImp(outer: GemminiTile) extends BaseTileModuleImp(outer)
|
|||||||
val squareBoundsInst = ciscInstT.Lit(_.inst -> 0x1220b07b.U, _.rs1 -> 0.U,
|
val squareBoundsInst = ciscInstT.Lit(_.inst -> 0x1220b07b.U, _.rs1 -> 0.U,
|
||||||
_.rs2 -> (tileSizeM | (tileSizeM << 16) | (BigInt(tileSizeM) << 32)).U)
|
_.rs2 -> (tileSizeM | (tileSizeM << 16) | (BigInt(tileSizeM) << 32)).U)
|
||||||
val boundsInst = Mux(ciscId(7), squareBoundsInst, rectBoundsInst)
|
val boundsInst = Mux(ciscId(7), squareBoundsInst, rectBoundsInst)
|
||||||
|
val nopInst = ciscInstT.Lit(_.inst -> 0.U, _.rs1 -> 0.U, _.rs2 -> 0.U)
|
||||||
|
|
||||||
def genStrideInst(tileA: UInt, tileB: UInt) = {
|
def genStrideInst(tileA: UInt, tileB: UInt) = {
|
||||||
val inst = Wire(ciscInstT)
|
val inst = Wire(ciscInstT)
|
||||||
@@ -249,7 +250,9 @@ class GemminiTileModuleImp(outer: GemminiTile) extends BaseTileModuleImp(outer)
|
|||||||
val accSkipInst = genAccSkipInst(0.U, ((ciscArgs(23, 16) * spadHexadecile.U) << 32).asUInt | 0x238.U)
|
val accSkipInst = genAccSkipInst(0.U, ((ciscArgs(23, 16) * spadHexadecile.U) << 32).asUInt | 0x238.U)
|
||||||
ciscInst := microcodeEntry(Seq(boundsInst, strideInst, accSkipInst))
|
ciscInst := microcodeEntry(Seq(boundsInst, strideInst, accSkipInst))
|
||||||
}
|
}
|
||||||
is (2.U) {} // no actual invocation, fake job placeholder
|
is (2.U) {
|
||||||
|
ciscInst := microcodeEntry(Seq(nopInst))
|
||||||
|
} // no actual invocation, fake job placeholder
|
||||||
is (8.U) { // set a, b stride
|
is (8.U) { // set a, b stride
|
||||||
val inst = Wire(ciscInstT)
|
val inst = Wire(ciscInstT)
|
||||||
inst.inst := 0x1820b07b.U
|
inst.inst := 0x1820b07b.U
|
||||||
@@ -337,7 +340,7 @@ class GemminiTileModuleImp(outer: GemminiTile) extends BaseTileModuleImp(outer)
|
|||||||
gemminiIO.bits.inst := Mux(ciscValid, ciscInst.inst.asTypeOf(gemminiIO.bits.inst), regCommand)
|
gemminiIO.bits.inst := Mux(ciscValid, ciscInst.inst.asTypeOf(gemminiIO.bits.inst), regCommand)
|
||||||
gemminiIO.bits.rs1 := Mux(ciscValid, ciscInst.rs1, Cat(gemminiRs1RegMSB, gemminiRs1RegLSB))
|
gemminiIO.bits.rs1 := Mux(ciscValid, ciscInst.rs1, Cat(gemminiRs1RegMSB, gemminiRs1RegLSB))
|
||||||
gemminiIO.bits.rs2 := Mux(ciscValid, ciscInst.rs2, Cat(gemminiRs2RegMSB, gemminiRs2RegLSB))
|
gemminiIO.bits.rs2 := Mux(ciscValid, ciscInst.rs2, Cat(gemminiRs2RegMSB, gemminiRs2RegLSB))
|
||||||
gemminiIO.valid := ciscValid || regValid
|
gemminiIO.valid := (ciscValid && (ciscInst.inst =/= 0.U)) || regValid
|
||||||
assert(gemminiIO.ready || !gemminiIO.valid)
|
assert(gemminiIO.ready || !gemminiIO.valid)
|
||||||
|
|
||||||
accSlave.status := RegNext(outer.gemmini.module.io.busy).asUInt
|
accSlave.status := RegNext(outer.gemmini.module.io.busy).asUInt
|
||||||
|
|||||||
@@ -101,9 +101,6 @@ case class VortexCoreParams(
|
|||||||
fpu: Option[FPUParams] = None,
|
fpu: Option[FPUParams] = None,
|
||||||
tensorCoreFP16: Boolean = false, // FP16 if true, FP32 if false
|
tensorCoreFP16: Boolean = false, // FP16 if true, FP32 if false
|
||||||
tensorCoreDecoupled: Boolean = false, // hopper-style SMEM operand decoupling
|
tensorCoreDecoupled: Boolean = false, // hopper-style SMEM operand decoupling
|
||||||
tensorCoreBlackwell: Boolean = false, // blackwell-style TMEM + SMEM tensor core
|
|
||||||
numTensorWarps: Int = 4,
|
|
||||||
startupAddress: BigInt = BigInt("10100", 16), // initial warp PC programmed through startup DCRs
|
|
||||||
debugROB: Boolean = false, // if enabled, uses a C++ debug ROB to generate trace-with-wdata
|
debugROB: Boolean = false, // if enabled, uses a C++ debug ROB to generate trace-with-wdata
|
||||||
haveCease: Boolean = true, // non-standard CEASE instruction
|
haveCease: Boolean = true, // non-standard CEASE instruction
|
||||||
haveSimTimeout: Boolean = true // add plusarg for simulation timeout
|
haveSimTimeout: Boolean = true // add plusarg for simulation timeout
|
||||||
@@ -155,10 +152,6 @@ class RadianceTile private (
|
|||||||
p(SIMTCoreKey).isDefined,
|
p(SIMTCoreKey).isDefined,
|
||||||
"SIMTCoreKey not defined; make sure to use WithSimtConfig when using RadianceTile"
|
"SIMTCoreKey not defined; make sure to use WithSimtConfig when using RadianceTile"
|
||||||
)
|
)
|
||||||
require(
|
|
||||||
!(radianceParams.core.tensorCoreDecoupled && radianceParams.core.tensorCoreBlackwell),
|
|
||||||
"tensorCoreDecoupled and tensorCoreBlackwell are mutually exclusive"
|
|
||||||
)
|
|
||||||
|
|
||||||
// NOTE: when changing these, remember to change +define+NUM_CORES/THREADS/WARPS in
|
// NOTE: when changing these, remember to change +define+NUM_CORES/THREADS/WARPS in
|
||||||
// radiance.mk as well!
|
// radiance.mk as well!
|
||||||
@@ -211,9 +204,7 @@ class RadianceTile private (
|
|||||||
case Some(false) => 1
|
case Some(false) => 1
|
||||||
case None => 1
|
case None => 1
|
||||||
}
|
}
|
||||||
// Must match VX_gpu_pkg.sv: ICACHE_TAG_WIDTH = domain + UUID + wid.
|
val imemTagWidth = UUID_WIDTH + NW_WIDTH
|
||||||
val imemDomainWidth = 1
|
|
||||||
val imemTagWidth = imemDomainWidth + UUID_WIDTH + NW_WIDTH
|
|
||||||
|
|
||||||
require(numWarps >= numLsuLanes,
|
require(numWarps >= numLsuLanes,
|
||||||
s"Vortex core requires numWarps (${numWarps}) >= numLsuLanes (${numLsuLanes})")
|
s"Vortex core requires numWarps (${numWarps}) >= numLsuLanes (${numLsuLanes})")
|
||||||
@@ -289,17 +280,7 @@ class RadianceTile private (
|
|||||||
}
|
}
|
||||||
|
|
||||||
val tcSmemSize = 32
|
val tcSmemSize = 32
|
||||||
val numTensorWarps = radianceParams.core.numTensorWarps
|
val tcSmemNodes = Seq.tabulate(if (radianceParams.core.tensorCoreDecoupled) 2 else 0) { i =>
|
||||||
val numScalarWarps = numWarps - numTensorWarps
|
|
||||||
require(numTensorWarps > 0 && numTensorWarps < numWarps,
|
|
||||||
s"Wu requires 0 < numTensorWarps (${numTensorWarps}) < numWarps (${numWarps})")
|
|
||||||
val numTensorCores = if (radianceParams.core.tensorCoreBlackwell) numTensorWarps else 1
|
|
||||||
if (radianceParams.core.tensorCoreBlackwell) {
|
|
||||||
require(numTensorCores == numTensorWarps, "Wu Blackwell binding requires one Tensor Core per Tensor warp")
|
|
||||||
}
|
|
||||||
val tensorUsesAsyncMem = radianceParams.core.tensorCoreDecoupled || radianceParams.core.tensorCoreBlackwell
|
|
||||||
val tcSmemNodeCount = if (radianceParams.core.tensorCoreDecoupled) 2 else if (radianceParams.core.tensorCoreBlackwell) numTensorCores else 0
|
|
||||||
val tcSmemNodes = Seq.tabulate(tcSmemNodeCount) { i =>
|
|
||||||
TLClientNode(Seq(TLMasterPortParameters.v2(
|
TLClientNode(Seq(TLMasterPortParameters.v2(
|
||||||
masters = Seq(TLMasterParameters.v2(
|
masters = Seq(TLMasterParameters.v2(
|
||||||
name = s"rad_tc_${radianceParams.coreId}_$i",
|
name = s"rad_tc_${radianceParams.coreId}_$i",
|
||||||
@@ -307,30 +288,12 @@ class RadianceTile private (
|
|||||||
supports = TLSlaveToMasterTransferSizes(
|
supports = TLSlaveToMasterTransferSizes(
|
||||||
probe = TransferSizes(1, tcSmemSize),
|
probe = TransferSizes(1, tcSmemSize),
|
||||||
get = TransferSizes(1, tcSmemSize),
|
get = TransferSizes(1, tcSmemSize),
|
||||||
putFull = TransferSizes(1, tcSmemSize),
|
|
||||||
),
|
),
|
||||||
requestFifo = true
|
requestFifo = true
|
||||||
))
|
))
|
||||||
)))
|
)))
|
||||||
}
|
}
|
||||||
|
|
||||||
// For Blackwell, tcSmemNodes accesses SMEM (bwgmma B operand)
|
|
||||||
// tcGmemNodes provide global memory access for cp (global→tmem) and cb (tmem→global)
|
|
||||||
val tcGmemNodes = if (radianceParams.core.tensorCoreBlackwell) {
|
|
||||||
Seq.tabulate(numTensorCores) { i =>
|
|
||||||
TLClientNode(Seq(TLMasterPortParameters.v2(masters = Seq(TLMasterParameters.v2(
|
|
||||||
name = s"rad_tc_gmem_${radianceParams.coreId}_$i",
|
|
||||||
sourceId = IdRange(0, 1 << dmemSourceWidth),
|
|
||||||
supports = TLSlaveToMasterTransferSizes(
|
|
||||||
probe = TransferSizes(1, tcSmemSize),
|
|
||||||
get = TransferSizes(1, tcSmemSize),
|
|
||||||
putFull = TransferSizes(1, tcSmemSize),
|
|
||||||
),
|
|
||||||
requestFifo = true
|
|
||||||
)))))
|
|
||||||
}
|
|
||||||
} else Seq.empty
|
|
||||||
|
|
||||||
// combine outgoing per-lane dmemNode into 1 idenity node
|
// combine outgoing per-lane dmemNode into 1 idenity node
|
||||||
//
|
//
|
||||||
// NOTE: We need TLWidthWidget here because there might be a data width
|
// NOTE: We need TLWidthWidget here because there might be a data width
|
||||||
@@ -419,7 +382,6 @@ class RadianceTile private (
|
|||||||
// imemNodes.foreach { tlMasterXbar.node := TLWidthWidget(4) := _ }
|
// imemNodes.foreach { tlMasterXbar.node := TLWidthWidget(4) := _ }
|
||||||
tlMasterXbar.node :=* AddressOrNode(base) :=* icacheNode
|
tlMasterXbar.node :=* AddressOrNode(base) :=* icacheNode
|
||||||
tlMasterXbar.node :=* AddressOrNode(base) :=* dcacheNode
|
tlMasterXbar.node :=* AddressOrNode(base) :=* dcacheNode
|
||||||
tcGmemNodes.foreach { n => tlMasterXbar.node := AddressOrNode(base) := n }
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/* below are copied from rocket */
|
/* below are copied from rocket */
|
||||||
@@ -781,18 +743,12 @@ class RadianceTileModuleImp(outer: RadianceTile)
|
|||||||
val tcb0 = new {
|
val tcb0 = new {
|
||||||
val addr = core.io.tc_a_bits_address(31, 0)
|
val addr = core.io.tc_a_bits_address(31, 0)
|
||||||
val tag = core.io.tc_a_bits_tag(outer.tensorTagWidth - 1, 0)
|
val tag = core.io.tc_a_bits_tag(outer.tensorTagWidth - 1, 0)
|
||||||
val write = core.io.tc_a_bits_write(0)
|
|
||||||
val mask = core.io.tc_a_bits_mask(31, 0)
|
|
||||||
val data = core.io.tc_a_bits_data(255, 0)
|
|
||||||
val aValid = core.io.tc_a_valid(0)
|
val aValid = core.io.tc_a_valid(0)
|
||||||
val dReady = core.io.tc_d_ready(0)
|
val dReady = core.io.tc_d_ready(0)
|
||||||
}
|
}
|
||||||
val tcb1 = new {
|
val tcb1 = new {
|
||||||
val addr = core.io.tc_a_bits_address(63, 32)
|
val addr = core.io.tc_a_bits_address(63, 32)
|
||||||
val tag = core.io.tc_a_bits_tag(4 + outer.tensorTagWidth - 1, 4)
|
val tag = core.io.tc_a_bits_tag(4 + outer.tensorTagWidth - 1, 4)
|
||||||
val write = core.io.tc_a_bits_write(1)
|
|
||||||
val mask = core.io.tc_a_bits_mask(63, 32)
|
|
||||||
val data = core.io.tc_a_bits_data(511, 256)
|
|
||||||
val aValid = core.io.tc_a_valid(1)
|
val aValid = core.io.tc_a_valid(1)
|
||||||
val dReady = core.io.tc_d_ready(1)
|
val dReady = core.io.tc_d_ready(1)
|
||||||
}
|
}
|
||||||
@@ -814,182 +770,26 @@ class RadianceTileModuleImp(outer: RadianceTile)
|
|||||||
adapter.io.inReq.bits.address := bundle.addr
|
adapter.io.inReq.bits.address := bundle.addr
|
||||||
adapter.io.inReq.bits.source := bundle.tag
|
adapter.io.inReq.bits.source := bundle.tag
|
||||||
adapter.io.inReq.bits.size := 5.U // 256 bits
|
adapter.io.inReq.bits.size := 5.U // 256 bits
|
||||||
adapter.io.inReq.bits.opcode := Mux(bundle.write.asBool, TLMessages.PutFullData, TLMessages.Get)
|
adapter.io.inReq.bits.opcode := TLMessages.Get
|
||||||
adapter.io.inReq.bits.mask := bundle.mask
|
adapter.io.inReq.bits.mask := x"ffffffff".U
|
||||||
adapter.io.inReq.bits.data := bundle.data
|
|
||||||
adapter.io.inResp.ready := bundle.dReady
|
adapter.io.inResp.ready := bundle.dReady
|
||||||
|
|
||||||
client._1.a <> adapter.io.outReq
|
client._1.a <> adapter.io.outReq
|
||||||
adapter.io.outResp <> client._1.d
|
adapter.io.outResp <> client._1.d
|
||||||
adapter
|
adapter
|
||||||
}
|
}
|
||||||
core.io.tc_a_ready := Cat(0.U(1.W), adapters.last.io.inReq.ready, adapters.head.io.inReq.ready)
|
core.io.tc_a_ready := Cat(adapters.last.io.inReq.ready, adapters.head.io.inReq.ready)
|
||||||
core.io.tc_d_valid := Cat(0.U(1.W), adapters.last.io.inResp.valid, adapters.head.io.inResp.valid)
|
core.io.tc_d_valid := Cat(adapters.last.io.inResp.valid, adapters.head.io.inResp.valid)
|
||||||
core.io.tc_d_bits_data := Cat(0.U((32 * 8).W), adapters.last.io.inResp.bits.data, adapters.head.io.inResp.bits.data)
|
core.io.tc_d_bits_data := Cat(adapters.last.io.inResp.bits.data, adapters.head.io.inResp.bits.data)
|
||||||
core.io.tc_d_bits_tag := Cat(0.U(outer.tensorTagWidth.W), adapters.last.io.inResp.bits.source, adapters.head.io.inResp.bits.source)
|
core.io.tc_d_bits_tag := Cat(adapters.last.io.inResp.bits.source, adapters.head.io.inResp.bits.source)
|
||||||
require(core.io.tc_d_bits_data.widthOption.get == adapters.head.io.inResp.bits.data.widthOption.get * 3)
|
require(core.io.tc_d_bits_data.widthOption.get == adapters.head.io.inResp.bits.data.widthOption.get * 2)
|
||||||
require(core.io.tc_d_bits_tag.widthOption.get == adapters.head.io.inResp.bits.source.widthOption.get * 3)
|
require(core.io.tc_d_bits_tag.widthOption.get == adapters.head.io.inResp.bits.source.widthOption.get * 2)
|
||||||
} else {
|
} else {
|
||||||
core.io.tc_a_ready := false.B
|
core.io.tc_a_ready := false.B
|
||||||
core.io.tc_d_valid := false.B
|
core.io.tc_d_valid := false.B
|
||||||
core.io.tc_d_bits_data := DontCare
|
core.io.tc_d_bits_data := DontCare
|
||||||
core.io.tc_d_bits_tag := DontCare
|
core.io.tc_d_bits_tag := DontCare
|
||||||
}
|
}
|
||||||
core.io.tc_tmem_A_rready := DontCare
|
|
||||||
core.io.tc_tmem_A_rdata := DontCare
|
|
||||||
core.io.tc_tmem_C_rready := DontCare
|
|
||||||
core.io.tc_tmem_C_rdata := DontCare
|
|
||||||
core.io.tc_tmem_C_wready := DontCare
|
|
||||||
}
|
|
||||||
|
|
||||||
def connectTensorBlackwell = {
|
|
||||||
if (outer.radianceParams.core.tensorCoreBlackwell) {
|
|
||||||
require(outer.tcSmemNodes.nonEmpty)
|
|
||||||
require(outer.tcSmemNodes.length == outer.numTensorCores)
|
|
||||||
require(outer.tcGmemNodes.length == outer.numTensorCores)
|
|
||||||
|
|
||||||
val nTC = outer.numTensorCores
|
|
||||||
val tcPorts = 3
|
|
||||||
val tcDataBits = outer.tcSmemSize * 8
|
|
||||||
val tmemAddrBits = 9
|
|
||||||
val tmemDataBits = outer.numLsuLanes * 32
|
|
||||||
val tmemMaskBits = outer.numLsuLanes * 4
|
|
||||||
|
|
||||||
def slice(u: UInt, width: Int, idx: Int): UInt = u(width * (idx + 1) - 1, width * idx)
|
|
||||||
def port(tc: Int, p: Int): Int = tc * tcPorts + p
|
|
||||||
|
|
||||||
val tcAReady = Wire(Vec(nTC * tcPorts, Bool()))
|
|
||||||
val tcDValid = Wire(Vec(nTC * tcPorts, Bool()))
|
|
||||||
val tcDData = Wire(Vec(nTC * tcPorts, UInt(tcDataBits.W)))
|
|
||||||
val tcDTag = Wire(Vec(nTC * tcPorts, UInt(outer.tensorTagWidth.W)))
|
|
||||||
tcAReady.foreach(_ := false.B)
|
|
||||||
tcDValid.foreach(_ := false.B)
|
|
||||||
tcDData.foreach(_ := 0.U)
|
|
||||||
tcDTag.foreach(_ := 0.U)
|
|
||||||
|
|
||||||
// TMEM matrix: one shared 2R1W SRAM. read0 is operand A, read1 is C.
|
|
||||||
// Each warp needs 2 tiles (A + C), each tile = 32 frags × 32B = 1KB
|
|
||||||
val tmemDepth = outer.numWarps * outer.tcSmemSize * 2 // numWarps × 64 rows
|
|
||||||
val tmem = Module(new radiance.memory.TwoReadOneWriteSyncMem(
|
|
||||||
tmemDepth, UInt((outer.tcSmemSize * 8).W)))
|
|
||||||
|
|
||||||
val aReadArb = Module(new RRArbiter(UInt(tmemAddrBits.W), nTC))
|
|
||||||
val cReadArb = Module(new RRArbiter(UInt(tmemAddrBits.W), nTC))
|
|
||||||
|
|
||||||
class TmemWriteReq extends Bundle {
|
|
||||||
val addr = UInt(tmemAddrBits.W)
|
|
||||||
val data = UInt(tmemDataBits.W)
|
|
||||||
val mask = UInt(tmemMaskBits.W)
|
|
||||||
}
|
|
||||||
val cWriteArb = Module(new RRArbiter(new TmemWriteReq, nTC))
|
|
||||||
|
|
||||||
(0 until nTC).foreach { tc =>
|
|
||||||
aReadArb.io.in(tc).valid := core.io.tc_tmem_A_ren(tc)
|
|
||||||
aReadArb.io.in(tc).bits := slice(core.io.tc_tmem_A_raddr, tmemAddrBits, tc)
|
|
||||||
cReadArb.io.in(tc).valid := core.io.tc_tmem_C_ren(tc)
|
|
||||||
cReadArb.io.in(tc).bits := slice(core.io.tc_tmem_C_raddr, tmemAddrBits, tc)
|
|
||||||
cWriteArb.io.in(tc).valid := core.io.tc_tmem_C_wen(tc)
|
|
||||||
cWriteArb.io.in(tc).bits.addr := slice(core.io.tc_tmem_C_waddr, tmemAddrBits, tc)
|
|
||||||
cWriteArb.io.in(tc).bits.data := slice(core.io.tc_tmem_C_wdata, tmemDataBits, tc)
|
|
||||||
cWriteArb.io.in(tc).bits.mask := slice(core.io.tc_tmem_C_mask, tmemMaskBits, tc)
|
|
||||||
}
|
|
||||||
|
|
||||||
aReadArb.io.out.ready := true.B
|
|
||||||
cReadArb.io.out.ready := true.B
|
|
||||||
cWriteArb.io.out.ready := true.B
|
|
||||||
|
|
||||||
tmem.io.ren0 := aReadArb.io.out.fire
|
|
||||||
tmem.io.raddr0 := aReadArb.io.out.bits
|
|
||||||
tmem.io.ren1 := cReadArb.io.out.fire
|
|
||||||
tmem.io.raddr1 := cReadArb.io.out.bits
|
|
||||||
tmem.io.wen := cWriteArb.io.out.fire
|
|
||||||
tmem.io.waddr := cWriteArb.io.out.bits.addr
|
|
||||||
tmem.io.wdata := cWriteArb.io.out.bits.data
|
|
||||||
tmem.io.mask := cWriteArb.io.out.bits.mask
|
|
||||||
|
|
||||||
val aReadGrant = RegNext(Mux(aReadArb.io.out.fire, UIntToOH(aReadArb.io.chosen, nTC), 0.U(nTC.W)))
|
|
||||||
val cReadGrant = RegNext(Mux(cReadArb.io.out.fire, UIntToOH(cReadArb.io.chosen, nTC), 0.U(nTC.W)))
|
|
||||||
core.io.tc_tmem_A_rready := VecInit(aReadArb.io.in.map(_.fire)).asUInt
|
|
||||||
core.io.tc_tmem_C_rready := VecInit(cReadArb.io.in.map(_.fire)).asUInt
|
|
||||||
core.io.tc_tmem_C_wready := VecInit(cWriteArb.io.in.map(_.fire)).asUInt
|
|
||||||
core.io.tc_tmem_A_rdata := VecInit((0 until nTC).map { tc =>
|
|
||||||
Mux(aReadGrant(tc), tmem.io.rdata0, 0.U(tmemDataBits.W))
|
|
||||||
}).asUInt
|
|
||||||
core.io.tc_tmem_C_rdata := VecInit((0 until nTC).map { tc =>
|
|
||||||
Mux(cReadGrant(tc), tmem.io.rdata1, 0.U(tmemDataBits.W))
|
|
||||||
}).asUInt
|
|
||||||
|
|
||||||
// port 2: SMEM B, one TL client per tensor core. RadianceSharedMem arbitrates them.
|
|
||||||
(0 until nTC).foreach { tc =>
|
|
||||||
val p2 = port(tc, 2)
|
|
||||||
val client = outer.tcSmemNodes(tc).out.head
|
|
||||||
val adapter = Module(new VortexTLAdapter(
|
|
||||||
outer.smemSourceWidth,
|
|
||||||
new VortexBundleA(tagWidth = outer.tensorTagWidth, dataWidth = tcDataBits),
|
|
||||||
new VortexBundleD(tagWidth = outer.tensorTagWidth, dataWidth = tcDataBits),
|
|
||||||
client
|
|
||||||
))
|
|
||||||
adapter.io.inReq.bits <> DontCare
|
|
||||||
adapter.io.inReq.valid := core.io.tc_a_valid(p2)
|
|
||||||
adapter.io.inReq.bits.address := slice(core.io.tc_a_bits_address, 32, p2)
|
|
||||||
adapter.io.inReq.bits.source := slice(core.io.tc_a_bits_tag, outer.tensorTagWidth, p2)
|
|
||||||
adapter.io.inReq.bits.size := 5.U
|
|
||||||
adapter.io.inReq.bits.opcode := Mux(core.io.tc_a_bits_write(p2).asBool, TLMessages.PutFullData, TLMessages.Get)
|
|
||||||
adapter.io.inReq.bits.mask := slice(core.io.tc_a_bits_mask, 32, p2)
|
|
||||||
adapter.io.inReq.bits.data := slice(core.io.tc_a_bits_data, tcDataBits, p2)
|
|
||||||
adapter.io.inResp.ready := core.io.tc_d_ready(p2)
|
|
||||||
client._1.a <> adapter.io.outReq
|
|
||||||
adapter.io.outResp <> client._1.d
|
|
||||||
|
|
||||||
tcAReady(p2) := adapter.io.inReq.ready
|
|
||||||
tcDValid(p2) := adapter.io.inResp.valid
|
|
||||||
tcDData(p2) := adapter.io.inResp.bits.data
|
|
||||||
tcDTag(p2) := adapter.io.inResp.bits.source
|
|
||||||
}
|
|
||||||
|
|
||||||
// port 0: global memory (cp/cb), one TL client per tensor core.
|
|
||||||
(0 until nTC).foreach { tc =>
|
|
||||||
val p0 = port(tc, 0)
|
|
||||||
val gmemClient = outer.tcGmemNodes(tc).out.head
|
|
||||||
val gmemAdapter = Module(new VortexTLAdapter(
|
|
||||||
outer.dmemSourceWidth,
|
|
||||||
new VortexBundleA(tagWidth = outer.tensorTagWidth, dataWidth = tcDataBits),
|
|
||||||
new VortexBundleD(tagWidth = outer.tensorTagWidth, dataWidth = tcDataBits),
|
|
||||||
gmemClient
|
|
||||||
))
|
|
||||||
gmemAdapter.io.inReq.bits <> DontCare
|
|
||||||
gmemAdapter.io.inReq.valid := core.io.tc_a_valid(p0)
|
|
||||||
gmemAdapter.io.inReq.bits.address := slice(core.io.tc_a_bits_address, 32, p0)
|
|
||||||
gmemAdapter.io.inReq.bits.source := slice(core.io.tc_a_bits_tag, outer.tensorTagWidth, p0)
|
|
||||||
gmemAdapter.io.inReq.bits.size := 5.U
|
|
||||||
gmemAdapter.io.inReq.bits.opcode := Mux(core.io.tc_a_bits_write(p0).asBool, TLMessages.PutFullData, TLMessages.Get)
|
|
||||||
gmemAdapter.io.inReq.bits.mask := slice(core.io.tc_a_bits_mask, 32, p0)
|
|
||||||
gmemAdapter.io.inReq.bits.data := slice(core.io.tc_a_bits_data, tcDataBits, p0)
|
|
||||||
gmemAdapter.io.inResp.ready := core.io.tc_d_ready(p0)
|
|
||||||
gmemClient._1.a <> gmemAdapter.io.outReq
|
|
||||||
gmemAdapter.io.outResp <> gmemClient._1.d
|
|
||||||
|
|
||||||
tcAReady(p0) := gmemAdapter.io.inReq.ready
|
|
||||||
tcDValid(p0) := gmemAdapter.io.inResp.valid
|
|
||||||
tcDData(p0) := gmemAdapter.io.inResp.bits.data
|
|
||||||
tcDTag(p0) := gmemAdapter.io.inResp.bits.source
|
|
||||||
}
|
|
||||||
|
|
||||||
core.io.tc_a_ready := tcAReady.asUInt
|
|
||||||
core.io.tc_d_valid := tcDValid.asUInt
|
|
||||||
core.io.tc_d_bits_data := tcDData.asUInt
|
|
||||||
core.io.tc_d_bits_tag := tcDTag.asUInt
|
|
||||||
} else {
|
|
||||||
core.io.tc_a_ready := false.B
|
|
||||||
core.io.tc_d_valid := false.B
|
|
||||||
core.io.tc_d_bits_data := DontCare
|
|
||||||
core.io.tc_d_bits_tag := DontCare
|
|
||||||
core.io.tc_tmem_A_rready := DontCare
|
|
||||||
core.io.tc_tmem_A_rdata := DontCare
|
|
||||||
core.io.tc_tmem_C_rready := DontCare
|
|
||||||
core.io.tc_tmem_C_rdata := DontCare
|
|
||||||
core.io.tc_tmem_C_wready := DontCare
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
def connectBarrier = {
|
def connectBarrier = {
|
||||||
@@ -1047,11 +847,7 @@ class RadianceTileModuleImp(outer: RadianceTile)
|
|||||||
connectImem
|
connectImem
|
||||||
connectDmem
|
connectDmem
|
||||||
connectSmem
|
connectSmem
|
||||||
if (outer.radianceParams.core.tensorCoreBlackwell) {
|
connectTensor
|
||||||
connectTensorBlackwell
|
|
||||||
} else {
|
|
||||||
connectTensor
|
|
||||||
}
|
|
||||||
connectBarrier
|
connectBarrier
|
||||||
connectAccelerator
|
connectAccelerator
|
||||||
}
|
}
|
||||||
@@ -1078,27 +874,6 @@ class RadianceTileModuleImp(outer: RadianceTile)
|
|||||||
tensor.io.reqA.ready := false.B
|
tensor.io.reqA.ready := false.B
|
||||||
tensor.io.reqB.ready := false.B
|
tensor.io.reqB.ready := false.B
|
||||||
tensor.io.writeback.ready := false.B
|
tensor.io.writeback.ready := false.B
|
||||||
dontTouch(tensor.io)
|
|
||||||
} else if (outer.radianceParams.core.tensorCoreBlackwell) {
|
|
||||||
val tensorNumSourceIds = (1 << outer.tensorTagWidth)
|
|
||||||
val tensor = Module(new radiance.core.TensorCoreBlackwell(
|
|
||||||
8, 8, half = true, tensorNumSourceIds))
|
|
||||||
tensor.io.initiate.valid := false.B
|
|
||||||
tensor.io.initiate.bits := DontCare
|
|
||||||
tensor.io.respA.valid := false.B
|
|
||||||
tensor.io.respA.bits := DontCare
|
|
||||||
tensor.io.respB.valid := false.B
|
|
||||||
tensor.io.respB.bits := DontCare
|
|
||||||
tensor.io.respC := DontCare
|
|
||||||
tensor.io.reqA.ready := false.B
|
|
||||||
tensor.io.reqB.ready := false.B
|
|
||||||
tensor.io.writeback.ready := false.B
|
|
||||||
tensor.io.tmemC.aRready := false.B
|
|
||||||
tensor.io.tmemC.aRdata := DontCare
|
|
||||||
tensor.io.tmemC.cRready := false.B
|
|
||||||
tensor.io.tmemC.cRdata := DontCare
|
|
||||||
tensor.io.tmemC.cWready := false.B
|
|
||||||
dontTouch(tensor.io)
|
|
||||||
} else {
|
} else {
|
||||||
if (outer.radianceParams.core.tensorCoreFP16) {
|
if (outer.radianceParams.core.tensorCoreFP16) {
|
||||||
val dpu = Module(new radiance.core.TensorDotProductUnit(4, half = true))
|
val dpu = Module(new radiance.core.TensorDotProductUnit(4, half = true))
|
||||||
|
|||||||
@@ -90,36 +90,14 @@ class VortexBundle(tile: RadianceTile)(implicit p: Parameters) extends CoreBundl
|
|||||||
val smem_d_bits_data = Input(UInt((tile.numLsuLanes * 32).W))
|
val smem_d_bits_data = Input(UInt((tile.numLsuLanes * 32).W))
|
||||||
val smem_d_ready = Output(UInt((tile.numLsuLanes * 1).W))
|
val smem_d_ready = Output(UInt((tile.numLsuLanes * 1).W))
|
||||||
|
|
||||||
val numTensorCores = if (tile.radianceParams.core.tensorCoreBlackwell) tile.numTensorCores else 1
|
val tc_a_valid = Output(UInt(2.W))
|
||||||
val tcPortCount = 3
|
val tc_a_bits_address = Output(UInt((2 * 32).W))
|
||||||
val tcFlatPortCount = tcPortCount * numTensorCores
|
val tc_a_bits_tag = Output(UInt((2 * 4).W))
|
||||||
val tc_a_valid = Output(UInt(tcFlatPortCount.W))
|
val tc_a_ready = Input(UInt(2.W))
|
||||||
val tc_a_bits_write = Output(UInt(tcFlatPortCount.W))
|
val tc_d_valid = Input(UInt(2.W))
|
||||||
val tc_a_bits_address = Output(UInt((tcFlatPortCount * 32).W))
|
val tc_d_bits_data = Input(UInt((2 * 32 * 8).W))
|
||||||
val tc_a_bits_tag = Output(UInt((tcFlatPortCount * 4).W))
|
val tc_d_bits_tag = Input(UInt((2 * 4).W))
|
||||||
val tc_a_bits_mask = Output(UInt((tcFlatPortCount * 32).W))
|
val tc_d_ready = Output(UInt(2.W))
|
||||||
val tc_a_bits_data = Output(UInt((tcFlatPortCount * 32 * 8).W))
|
|
||||||
val tc_a_ready = Input(UInt(tcFlatPortCount.W))
|
|
||||||
val tc_d_valid = Input(UInt(tcFlatPortCount.W))
|
|
||||||
val tc_d_bits_data = Input(UInt((tcFlatPortCount * 32 * 8).W))
|
|
||||||
val tc_d_bits_tag = Input(UInt((tcFlatPortCount * 4).W))
|
|
||||||
val tc_d_ready = Output(UInt(tcFlatPortCount.W))
|
|
||||||
|
|
||||||
// Direct SRAM ports for shared TMEM (bypasses TileLink)
|
|
||||||
val numLanes = tile.numLsuLanes
|
|
||||||
val tc_tmem_A_ren = Output(UInt(numTensorCores.W))
|
|
||||||
val tc_tmem_A_rready = Input(UInt(numTensorCores.W))
|
|
||||||
val tc_tmem_A_raddr = Output(UInt((numTensorCores * 9).W))
|
|
||||||
val tc_tmem_A_rdata = Input(UInt((numTensorCores * numLanes * 32).W))
|
|
||||||
val tc_tmem_C_ren = Output(UInt(numTensorCores.W))
|
|
||||||
val tc_tmem_C_rready = Input(UInt(numTensorCores.W))
|
|
||||||
val tc_tmem_C_raddr = Output(UInt((numTensorCores * 9).W))
|
|
||||||
val tc_tmem_C_rdata = Input(UInt((numTensorCores * numLanes * 32).W))
|
|
||||||
val tc_tmem_C_wen = Output(UInt(numTensorCores.W))
|
|
||||||
val tc_tmem_C_wready = Input(UInt(numTensorCores.W))
|
|
||||||
val tc_tmem_C_waddr = Output(UInt((numTensorCores * 9).W))
|
|
||||||
val tc_tmem_C_wdata = Output(UInt((numTensorCores * numLanes * 32).W))
|
|
||||||
val tc_tmem_C_mask = Output(UInt((numTensorCores * numLanes * 4).W))
|
|
||||||
|
|
||||||
// FIXME: hardcoded
|
// FIXME: hardcoded
|
||||||
val barrierIdBits = tile.barrierMasterNode.out(0)._2.barrierIdBits
|
val barrierIdBits = tile.barrierMasterNode.out(0)._2.barrierIdBits
|
||||||
@@ -154,9 +132,9 @@ class Vortex(tile: RadianceTile)(implicit p: Parameters)
|
|||||||
Map(
|
Map(
|
||||||
"CORE_ID" -> tile.radianceParams.coreId,
|
"CORE_ID" -> tile.radianceParams.coreId,
|
||||||
"TENSOR_FP16" -> (if (tile.radianceParams.core.tensorCoreFP16) 1 else 0),
|
"TENSOR_FP16" -> (if (tile.radianceParams.core.tensorCoreFP16) 1 else 0),
|
||||||
"STARTUP_ADDR" -> tile.radianceParams.core.startupAddress,
|
// TODO: can we get this as a parameter?
|
||||||
"NUM_THREADS" -> tile.numLsuLanes,
|
"BOOTROM_HANG100" -> 0x10100,
|
||||||
"NUM_TENSOR_CORES" -> (if (tile.radianceParams.core.tensorCoreBlackwell) tile.numTensorCores else 1)
|
"NUM_THREADS" -> tile.numLsuLanes
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
with HasBlackBoxResource with HasBlackBoxPath {
|
with HasBlackBoxResource with HasBlackBoxPath {
|
||||||
@@ -220,7 +198,6 @@ class Vortex(tile: RadianceTile)(implicit p: Parameters)
|
|||||||
addResource("/vsrc/vortex/hw/rtl/core/VX_scoreboard.sv")
|
addResource("/vsrc/vortex/hw/rtl/core/VX_scoreboard.sv")
|
||||||
addResource("/vsrc/vortex/hw/rtl/core/VX_sfu_unit.sv")
|
addResource("/vsrc/vortex/hw/rtl/core/VX_sfu_unit.sv")
|
||||||
addResource("/vsrc/vortex/hw/rtl/core/VX_smem_unit.sv")
|
addResource("/vsrc/vortex/hw/rtl/core/VX_smem_unit.sv")
|
||||||
addResource("/vsrc/vortex/hw/rtl/core/VX_tensor_ctrl_unit.sv")
|
|
||||||
addResource("/vsrc/vortex/hw/rtl/core/VX_split_join.sv")
|
addResource("/vsrc/vortex/hw/rtl/core/VX_split_join.sv")
|
||||||
addResource("/vsrc/vortex/hw/rtl/core/VX_trace.vh")
|
addResource("/vsrc/vortex/hw/rtl/core/VX_trace.vh")
|
||||||
addResource("/vsrc/vortex/hw/rtl/core/VX_wctl_unit.sv")
|
addResource("/vsrc/vortex/hw/rtl/core/VX_wctl_unit.sv")
|
||||||
@@ -434,8 +411,6 @@ class Vortex(tile: RadianceTile)(implicit p: Parameters)
|
|||||||
// hopper-style SMEM operand decoupling
|
// hopper-style SMEM operand decoupling
|
||||||
if (tile.radianceParams.core.tensorCoreDecoupled) {
|
if (tile.radianceParams.core.tensorCoreDecoupled) {
|
||||||
addResource("/vsrc/vortex/hw/rtl/core/VX_tensor_hopper_core.sv")
|
addResource("/vsrc/vortex/hw/rtl/core/VX_tensor_hopper_core.sv")
|
||||||
} else if (tile.radianceParams.core.tensorCoreBlackwell) {
|
|
||||||
addResource("/vsrc/vortex/hw/rtl/core/VX_tensor_blackwell_core.sv")
|
|
||||||
// addResource("/vsrc/vortex/hw/rtl/core/VX_tensor_ucode.vh")
|
// addResource("/vsrc/vortex/hw/rtl/core/VX_tensor_ucode.vh")
|
||||||
def addHopperTensorCore = {
|
def addHopperTensorCore = {
|
||||||
addPath("/scratch/hansung/chipyard/sims/vcs/generated-src/chipyard.unittest.TestHarness.TensorUnitTestConfig/gen-collateral/AddRawFN.sv")
|
addPath("/scratch/hansung/chipyard/sims/vcs/generated-src/chipyard.unittest.TestHarness.TensorUnitTestConfig/gen-collateral/AddRawFN.sv")
|
||||||
@@ -469,9 +444,7 @@ class Vortex(tile: RadianceTile)(implicit p: Parameters)
|
|||||||
|
|
||||||
addResource("/vsrc/vortex/hw/rtl/core/VX_uop_sequencer.sv")
|
addResource("/vsrc/vortex/hw/rtl/core/VX_uop_sequencer.sv")
|
||||||
addResource("/vsrc/vortex/hw/rtl/core/VX_reduce_unit.sv")
|
addResource("/vsrc/vortex/hw/rtl/core/VX_reduce_unit.sv")
|
||||||
if (!tile.radianceParams.core.tensorCoreBlackwell) {
|
addResource("/vsrc/vortex/hw/rtl/fpu/VX_tensor_dpu.sv")
|
||||||
addResource("/vsrc/vortex/hw/rtl/fpu/VX_tensor_dpu.sv")
|
|
||||||
}
|
|
||||||
|
|
||||||
if (tile.radianceParams.useVxCache) {
|
if (tile.radianceParams.useVxCache) {
|
||||||
addResource("/vsrc/vortex/hw/rtl/libs/VX_pending_size.sv")
|
addResource("/vsrc/vortex/hw/rtl/libs/VX_pending_size.sv")
|
||||||
|
|||||||
@@ -1,338 +0,0 @@
|
|||||||
package radiance.core
|
|
||||||
|
|
||||||
import chisel3._
|
|
||||||
import chiseltest._
|
|
||||||
import chiseltest.simulator.VerilatorBackendAnnotation
|
|
||||||
import org.scalatest.flatspec.AnyFlatSpec
|
|
||||||
|
|
||||||
import scala.collection.mutable
|
|
||||||
|
|
||||||
class TensorCoreBlackwellExtendedTest extends AnyFlatSpec with ChiselScalatestTester {
|
|
||||||
behavior of "TensorCoreBlackwell Extended Tests"
|
|
||||||
|
|
||||||
private val numWarps = 4
|
|
||||||
private val numLanes = 8
|
|
||||||
private val fragBytes = 32
|
|
||||||
|
|
||||||
private def idleIO(c: TensorCoreBlackwell): Unit = {
|
|
||||||
c.io.initiate.valid.poke(false.B)
|
|
||||||
c.io.respA.valid.poke(false.B)
|
|
||||||
c.io.respB.valid.poke(false.B)
|
|
||||||
c.io.respA.bits.source.poke(0.U)
|
|
||||||
c.io.respB.bits.source.poke(0.U)
|
|
||||||
c.io.respA.bits.data.poke(0.U)
|
|
||||||
c.io.respB.bits.data.poke(0.U)
|
|
||||||
c.io.reqA.ready.poke(false.B)
|
|
||||||
c.io.reqB.ready.poke(false.B)
|
|
||||||
c.io.respC.poke(0.U)
|
|
||||||
c.io.writeback.ready.poke(false.B)
|
|
||||||
c.io.tmemC.rdata.poke(0.U)
|
|
||||||
}
|
|
||||||
|
|
||||||
private def packWords(words: Seq[BigInt], width: Int): BigInt = {
|
|
||||||
val mask = (BigInt(1) << width) - 1
|
|
||||||
words.zipWithIndex.foldLeft(BigInt(0)) {
|
|
||||||
case (acc, (word, i)) => acc | ((word & mask) << (i * width))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
private def makeTmem() = mutable.Map[BigInt, BigInt]().withDefaultValue(BigInt(0))
|
|
||||||
|
|
||||||
private def stepTmem(c: TensorCoreBlackwell, tmem: mutable.Map[BigInt, BigInt]): Unit = {
|
|
||||||
if (c.io.tmemC.ren.peek().litToBoolean) {
|
|
||||||
val addr = c.io.tmemC.raddr.peek().litValue
|
|
||||||
c.io.tmemC.rdata.poke(tmem(addr).U)
|
|
||||||
}
|
|
||||||
if (c.io.tmemC.wen.peek().litToBoolean) {
|
|
||||||
val addr = c.io.tmemC.waddr.peek().litValue
|
|
||||||
tmem(addr) = c.io.tmemC.wdata.peek().litValue
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
it should "verify bwgmma address offset with non-zero base addresses" in {
|
|
||||||
test(new TensorCoreBlackwell(numWarps, numLanes, half = true, numSourceIds = 4))
|
|
||||||
.withAnnotations(Seq(VerilatorBackendAnnotation)) { c =>
|
|
||||||
idleIO(c)
|
|
||||||
val tmem = makeTmem()
|
|
||||||
|
|
||||||
// Use non-zero base addresses to verify offset calculation
|
|
||||||
val aBase = BigInt(0x200) // row 16, A tile rows 16~47
|
|
||||||
val cBase = BigInt(0x600) // row 48, C tile rows 48~79 (no overlap with A)
|
|
||||||
val bBase = BigInt(0x800)
|
|
||||||
|
|
||||||
val fp16One = BigInt(0x3c00)
|
|
||||||
val fp32Zero = BigInt(0)
|
|
||||||
// 4 sets × 8 dot products × (1.0 × 2.0) = 64.0f
|
|
||||||
val fp32SixtyFour = BigInt(0x42800000L)
|
|
||||||
|
|
||||||
// Populate TMEM A at offset aBase (all 32 frags)
|
|
||||||
val aFrag = packWords(Seq.fill(16)(fp16One), 16)
|
|
||||||
val cFrag = packWords(Seq.fill(numLanes)(fp32Zero), 32)
|
|
||||||
for (i <- 0 until 32) {
|
|
||||||
tmem(aBase / fragBytes + i) = aFrag
|
|
||||||
tmem(cBase / fragBytes + i) = cFrag
|
|
||||||
}
|
|
||||||
|
|
||||||
// SMEM B with fp16 2.0
|
|
||||||
val fp16Two = BigInt(0x4000)
|
|
||||||
val bFrag = packWords(Seq.fill(16)(fp16Two), 16)
|
|
||||||
val bMem = mutable.Map[BigInt, BigInt]().withDefaultValue(bFrag)
|
|
||||||
for (i <- 0 until 32) bMem(bBase + i * fragBytes) = bFrag
|
|
||||||
|
|
||||||
c.io.reqB.ready.poke(true.B)
|
|
||||||
c.io.writeback.ready.poke(true.B)
|
|
||||||
|
|
||||||
c.io.initiate.valid.poke(true.B)
|
|
||||||
c.io.initiate.bits.op.poke(0.U)
|
|
||||||
c.io.initiate.bits.wid.poke(0.U)
|
|
||||||
c.io.initiate.bits.rd.poke(0.U)
|
|
||||||
c.io.initiate.bits.addressA.poke(aBase.U)
|
|
||||||
c.io.initiate.bits.addressB.poke(bBase.U)
|
|
||||||
c.io.initiate.bits.addressC.poke(cBase.U)
|
|
||||||
c.clock.step()
|
|
||||||
c.io.initiate.valid.poke(false.B)
|
|
||||||
|
|
||||||
var pendingB = Option.empty[(BigInt, BigInt)]
|
|
||||||
var sawWriteback = false
|
|
||||||
|
|
||||||
for (_ <- 0 until 50000 if !sawWriteback) {
|
|
||||||
stepTmem(c, tmem)
|
|
||||||
pendingB.foreach { case (src, data) =>
|
|
||||||
c.io.respB.valid.poke(true.B)
|
|
||||||
c.io.respB.bits.source.poke(src.U)
|
|
||||||
c.io.respB.bits.data.poke(data.U)
|
|
||||||
}
|
|
||||||
if (pendingB.isEmpty) c.io.respB.valid.poke(false.B)
|
|
||||||
|
|
||||||
if (c.io.writeback.valid.peek().litToBoolean) {
|
|
||||||
sawWriteback = true
|
|
||||||
} else {
|
|
||||||
val nextB = if (c.io.reqB.valid.peek().litToBoolean) {
|
|
||||||
val addr = c.io.reqB.bits.address.peek().litValue
|
|
||||||
val src = c.io.reqB.bits.source.peek().litValue
|
|
||||||
Some((src, bMem(addr)))
|
|
||||||
} else None
|
|
||||||
c.clock.step()
|
|
||||||
pendingB = nextB
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
assert(sawWriteback, "BWGMMA did not complete")
|
|
||||||
val expectedC = packWords(Seq.fill(numLanes)(fp32SixtyFour), 32)
|
|
||||||
for (i <- 0 until 8) {
|
|
||||||
val row = cBase / fragBytes + i
|
|
||||||
assert(tmem(row) == expectedC,
|
|
||||||
s"C frag $i at row $row: got 0x${tmem(row).toString(16)}, expected 0x${expectedC.toString(16)}")
|
|
||||||
}
|
|
||||||
for (i <- 0 until 8) {
|
|
||||||
assert(tmem(aBase / fragBytes + i) == aFrag, s"A frag $i should be unchanged")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
it should "cp then ld round-trip: data written via cp is readable via ld" in {
|
|
||||||
test(new TensorCoreBlackwell(numWarps, numLanes, half = true, numSourceIds = 4)) { c =>
|
|
||||||
idleIO(c)
|
|
||||||
val tmem = makeTmem()
|
|
||||||
val tmemAddr = BigInt(0x100)
|
|
||||||
val cpData = packWords(Seq.tabulate(numLanes)(i => BigInt(0xABCD0000L + i)), 32)
|
|
||||||
|
|
||||||
// Issue cp: global mem -> tmem
|
|
||||||
c.io.initiate.valid.poke(true.B)
|
|
||||||
c.io.initiate.bits.op.poke(2.U)
|
|
||||||
c.io.initiate.bits.addressA.poke(tmemAddr.U)
|
|
||||||
c.io.initiate.bits.addressB.poke("h10000000".U)
|
|
||||||
c.io.reqA.ready.poke(true.B)
|
|
||||||
c.clock.step()
|
|
||||||
c.io.initiate.valid.poke(false.B)
|
|
||||||
|
|
||||||
// cpRead: reqA issued
|
|
||||||
c.io.reqA.valid.expect(true.B)
|
|
||||||
c.io.reqA.bits.rw.expect(false.B)
|
|
||||||
c.clock.step()
|
|
||||||
|
|
||||||
// cpWrite: respA fires, tmemC written
|
|
||||||
c.io.respA.valid.poke(true.B)
|
|
||||||
c.io.respA.bits.data.poke(cpData.U)
|
|
||||||
c.io.tmemC.wen.expect(true.B)
|
|
||||||
c.io.tmemC.waddr.expect((tmemAddr / fragBytes).U)
|
|
||||||
c.io.tmemC.wdata.expect(cpData.U)
|
|
||||||
stepTmem(c, tmem)
|
|
||||||
c.clock.step()
|
|
||||||
c.io.respA.valid.poke(false.B)
|
|
||||||
|
|
||||||
// Now issue ld from same tmem address
|
|
||||||
c.io.initiate.valid.poke(true.B)
|
|
||||||
c.io.initiate.bits.op.poke(4.U)
|
|
||||||
c.io.initiate.bits.rd.poke(2.U)
|
|
||||||
c.io.initiate.bits.addressA.poke(tmemAddr.U)
|
|
||||||
c.io.writeback.ready.poke(true.B)
|
|
||||||
c.clock.step()
|
|
||||||
c.io.initiate.valid.poke(false.B)
|
|
||||||
|
|
||||||
// ldReq: ren asserted, serve from tmem model
|
|
||||||
c.io.tmemC.ren.expect(true.B)
|
|
||||||
c.io.tmemC.rdata.poke(tmem(tmemAddr / fragBytes).U)
|
|
||||||
c.clock.step()
|
|
||||||
c.io.tmemC.rdata.poke(tmem(tmemAddr / fragBytes).U)
|
|
||||||
c.clock.step()
|
|
||||||
|
|
||||||
// writeback should carry cpData
|
|
||||||
c.io.writeback.valid.expect(true.B)
|
|
||||||
for (i <- 0 until numLanes) {
|
|
||||||
c.io.writeback.bits.data(i).expect((BigInt(0xABCD0000L) + i).U)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
it should "st then cb round-trip: data written via st is readable via cb" in {
|
|
||||||
test(new TensorCoreBlackwell(numWarps, numLanes, half = true, numSourceIds = 4)) { c =>
|
|
||||||
idleIO(c)
|
|
||||||
val tmem = makeTmem()
|
|
||||||
val tmemAddr = BigInt(0x140)
|
|
||||||
val stData = packWords(Seq.tabulate(numLanes)(i => BigInt(0xDEAD0000L + i)), 32)
|
|
||||||
|
|
||||||
// Issue st: respC -> tmem
|
|
||||||
c.io.initiate.valid.poke(true.B)
|
|
||||||
c.io.initiate.bits.op.poke(5.U)
|
|
||||||
c.io.initiate.bits.rd.poke(4.U)
|
|
||||||
c.io.initiate.bits.addressA.poke(tmemAddr.U)
|
|
||||||
c.io.respC.poke(stData.U)
|
|
||||||
c.clock.step()
|
|
||||||
c.io.initiate.valid.poke(false.B)
|
|
||||||
|
|
||||||
// stReq: reqC valid
|
|
||||||
c.io.reqC.valid.expect(true.B)
|
|
||||||
c.clock.step()
|
|
||||||
|
|
||||||
// stWrite: tmemC written
|
|
||||||
c.io.tmemC.wen.expect(true.B)
|
|
||||||
c.io.tmemC.wdata.expect(stData.U)
|
|
||||||
stepTmem(c, tmem)
|
|
||||||
c.clock.step()
|
|
||||||
|
|
||||||
// Issue cb: tmem -> global mem
|
|
||||||
c.io.initiate.valid.poke(true.B)
|
|
||||||
c.io.initiate.bits.op.poke(6.U)
|
|
||||||
c.io.initiate.bits.addressA.poke(tmemAddr.U)
|
|
||||||
c.io.initiate.bits.addressB.poke("h20000000".U)
|
|
||||||
c.io.reqA.ready.poke(true.B)
|
|
||||||
c.io.tmemC.rdata.poke(tmem(tmemAddr / fragBytes).U)
|
|
||||||
c.clock.step()
|
|
||||||
c.io.initiate.valid.poke(false.B)
|
|
||||||
|
|
||||||
// cbRead: ren asserted
|
|
||||||
c.io.tmemC.ren.expect(true.B)
|
|
||||||
c.io.tmemC.rdata.poke(tmem(tmemAddr / fragBytes).U)
|
|
||||||
c.clock.step()
|
|
||||||
|
|
||||||
// cbWrite: reqA write with stData
|
|
||||||
c.io.reqA.valid.expect(true.B)
|
|
||||||
c.io.reqA.bits.rw.expect(true.B)
|
|
||||||
c.io.reqA.bits.address.expect("h20000000".U)
|
|
||||||
c.io.reqA.bits.data.expect(stData.U)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
it should "wait ops are no-ops and do not stall pipeline" in {
|
|
||||||
test(new TensorCoreBlackwell(numWarps, numLanes, half = true, numSourceIds = 4)) { c =>
|
|
||||||
idleIO(c)
|
|
||||||
|
|
||||||
// bwgmmaWait: should accept immediately and stay idle
|
|
||||||
c.io.initiate.valid.poke(true.B)
|
|
||||||
c.io.initiate.bits.op.poke(1.U) // bwgmmaWait
|
|
||||||
c.io.initiate.ready.expect(true.B)
|
|
||||||
c.clock.step()
|
|
||||||
c.io.initiate.valid.poke(false.B)
|
|
||||||
c.io.writeback.valid.expect(false.B)
|
|
||||||
c.io.reqA.valid.expect(false.B)
|
|
||||||
c.io.reqB.valid.expect(false.B)
|
|
||||||
|
|
||||||
// tcgen05CpWait: same
|
|
||||||
c.io.initiate.valid.poke(true.B)
|
|
||||||
c.io.initiate.bits.op.poke(3.U) // tcgen05CpWait
|
|
||||||
c.io.initiate.ready.expect(true.B)
|
|
||||||
c.clock.step()
|
|
||||||
c.io.initiate.valid.poke(false.B)
|
|
||||||
c.io.writeback.valid.expect(false.B)
|
|
||||||
c.io.reqA.valid.expect(false.B)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
it should "not accept a second tensor op until the first one completes" in {
|
|
||||||
test(new TensorCoreBlackwell(numWarps, numLanes, half = true, numSourceIds = 4)) { c =>
|
|
||||||
idleIO(c)
|
|
||||||
val firstAddr = BigInt(0x180)
|
|
||||||
val secondAddr = BigInt(0x1a0)
|
|
||||||
val storeData = packWords(Seq.tabulate(numLanes)(i => BigInt(0xCAFE0000L + i)), 32)
|
|
||||||
|
|
||||||
c.io.initiate.valid.poke(true.B)
|
|
||||||
c.io.initiate.bits.op.poke(5.U)
|
|
||||||
c.io.initiate.bits.addressA.poke(firstAddr.U)
|
|
||||||
c.io.respC.poke(storeData.U)
|
|
||||||
c.io.initiate.ready.expect(true.B)
|
|
||||||
c.clock.step()
|
|
||||||
|
|
||||||
c.io.initiate.bits.op.poke(4.U)
|
|
||||||
c.io.initiate.bits.addressA.poke(secondAddr.U)
|
|
||||||
c.io.initiate.bits.rd.poke(2.U)
|
|
||||||
c.io.initiate.ready.expect(false.B)
|
|
||||||
c.clock.step()
|
|
||||||
c.io.initiate.ready.expect(false.B)
|
|
||||||
|
|
||||||
c.io.tmemC.wen.expect(true.B)
|
|
||||||
c.clock.step()
|
|
||||||
c.io.initiate.ready.expect(true.B)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
it should "multi-warp TMEM isolation: warp 0 and warp 3 do not alias" in {
|
|
||||||
test(new TensorCoreBlackwell(numWarps, numLanes, half = true, numSourceIds = 4)) { c =>
|
|
||||||
idleIO(c)
|
|
||||||
val tmem = makeTmem()
|
|
||||||
|
|
||||||
// warp 0: tmem_slot_base(0) = 0, tmem_a_base = 0
|
|
||||||
val warp0TmemA = BigInt(0x000)
|
|
||||||
val warp0Data = packWords(Seq.fill(numLanes)(BigInt(0xAAAAAAAAL)), 32)
|
|
||||||
|
|
||||||
// warp 3: tmem_slot_base(3) = 3*2048 = 6144 = 0x1800, tmem_a_base = 0x1800
|
|
||||||
val warp3TmemA = BigInt(0x1800)
|
|
||||||
val warp3Data = packWords(Seq.fill(numLanes)(BigInt(0xBBBBBBBBL)), 32)
|
|
||||||
|
|
||||||
// Write warp 0 data via st
|
|
||||||
c.io.initiate.valid.poke(true.B)
|
|
||||||
c.io.initiate.bits.op.poke(5.U)
|
|
||||||
c.io.initiate.bits.wid.poke(0.U)
|
|
||||||
c.io.initiate.bits.addressA.poke(warp0TmemA.U)
|
|
||||||
c.io.respC.poke(warp0Data.U)
|
|
||||||
c.clock.step()
|
|
||||||
c.io.initiate.valid.poke(false.B)
|
|
||||||
c.io.reqC.valid.expect(true.B)
|
|
||||||
c.clock.step()
|
|
||||||
c.io.tmemC.wen.expect(true.B)
|
|
||||||
c.io.tmemC.waddr.expect((warp0TmemA / fragBytes).U)
|
|
||||||
stepTmem(c, tmem)
|
|
||||||
c.clock.step()
|
|
||||||
|
|
||||||
// Write warp 3 data via st
|
|
||||||
c.io.initiate.valid.poke(true.B)
|
|
||||||
c.io.initiate.bits.op.poke(5.U)
|
|
||||||
c.io.initiate.bits.wid.poke(3.U)
|
|
||||||
c.io.initiate.bits.addressA.poke(warp3TmemA.U)
|
|
||||||
c.io.respC.poke(warp3Data.U)
|
|
||||||
c.clock.step()
|
|
||||||
c.io.initiate.valid.poke(false.B)
|
|
||||||
c.io.reqC.valid.expect(true.B)
|
|
||||||
c.clock.step()
|
|
||||||
c.io.tmemC.wen.expect(true.B)
|
|
||||||
c.io.tmemC.waddr.expect((warp3TmemA / fragBytes).U)
|
|
||||||
stepTmem(c, tmem)
|
|
||||||
c.clock.step()
|
|
||||||
|
|
||||||
// Verify no aliasing: warp 0 row != warp 3 row
|
|
||||||
assert(warp0TmemA / fragBytes != warp3TmemA / fragBytes)
|
|
||||||
assert(tmem(warp0TmemA / fragBytes) == warp0Data)
|
|
||||||
assert(tmem(warp3TmemA / fragBytes) == warp3Data)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,281 +0,0 @@
|
|||||||
package radiance.core
|
|
||||||
|
|
||||||
import chisel3._
|
|
||||||
import chiseltest._
|
|
||||||
import chiseltest.simulator.VerilatorBackendAnnotation
|
|
||||||
import org.scalatest.flatspec.AnyFlatSpec
|
|
||||||
|
|
||||||
import scala.collection.mutable
|
|
||||||
|
|
||||||
class TensorCoreBlackwellTest extends AnyFlatSpec with ChiselScalatestTester {
|
|
||||||
behavior of "TensorCoreBlackwell"
|
|
||||||
|
|
||||||
private val numWarps = 4
|
|
||||||
private val numLanes = 8
|
|
||||||
|
|
||||||
private def idleIO(c: TensorCoreBlackwell): Unit = {
|
|
||||||
c.io.initiate.valid.poke(false.B)
|
|
||||||
c.io.respA.valid.poke(false.B)
|
|
||||||
c.io.respB.valid.poke(false.B)
|
|
||||||
c.io.respA.bits.source.poke(0.U)
|
|
||||||
c.io.respB.bits.source.poke(0.U)
|
|
||||||
c.io.respA.bits.data.poke(0.U)
|
|
||||||
c.io.respB.bits.data.poke(0.U)
|
|
||||||
c.io.reqA.ready.poke(false.B)
|
|
||||||
c.io.reqB.ready.poke(false.B)
|
|
||||||
c.io.respC.poke(0.U)
|
|
||||||
c.io.writeback.ready.poke(false.B)
|
|
||||||
c.io.tmemC.rdata.poke(0.U)
|
|
||||||
}
|
|
||||||
|
|
||||||
private def packWords(words: Seq[BigInt], width: Int): BigInt = {
|
|
||||||
val mask = (BigInt(1) << width) - 1
|
|
||||||
words.zipWithIndex.foldLeft(BigInt(0)) {
|
|
||||||
case (acc, (word, i)) => acc | ((word & mask) << (i * width))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Simple TMEM model: address → 256-bit row
|
|
||||||
private def makeTmem() = mutable.Map[BigInt, BigInt]().withDefaultValue(BigInt(0))
|
|
||||||
|
|
||||||
// Drive tmemC read response from model, handle write
|
|
||||||
private def stepTmem(c: TensorCoreBlackwell, tmem: mutable.Map[BigInt, BigInt]): Unit = {
|
|
||||||
if (c.io.tmemC.ren.peek().litToBoolean) {
|
|
||||||
val addr = c.io.tmemC.raddr.peek().litValue
|
|
||||||
c.io.tmemC.rdata.poke(tmem(addr).U)
|
|
||||||
}
|
|
||||||
if (c.io.tmemC.wen.peek().litToBoolean) {
|
|
||||||
val addr = c.io.tmemC.waddr.peek().litValue
|
|
||||||
tmem(addr) = c.io.tmemC.wdata.peek().litValue
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
it should "tcgen05_ld: read from TMEM to writeback" in {
|
|
||||||
test(new TensorCoreBlackwell(numWarps, numLanes, half = true, numSourceIds = 4)) { c =>
|
|
||||||
idleIO(c)
|
|
||||||
val tmem = makeTmem()
|
|
||||||
val fragBytes = 32
|
|
||||||
val tmemAddr = BigInt(0x40) // row 2 (0x40 / 32 = 2)
|
|
||||||
val testData = packWords(Seq.tabulate(numLanes)(i => BigInt(0x1000 + i)), 32)
|
|
||||||
tmem(tmemAddr / fragBytes) = testData
|
|
||||||
|
|
||||||
c.io.initiate.valid.poke(true.B)
|
|
||||||
c.io.initiate.bits.op.poke(4.U) // tcgen05Ld
|
|
||||||
c.io.initiate.bits.wid.poke(0.U)
|
|
||||||
c.io.initiate.bits.rd.poke(3.U)
|
|
||||||
c.io.initiate.bits.addressA.poke(tmemAddr.U)
|
|
||||||
c.io.writeback.ready.poke(true.B)
|
|
||||||
c.io.tmemC.rdata.poke(testData.U)
|
|
||||||
c.clock.step()
|
|
||||||
c.io.initiate.valid.poke(false.B)
|
|
||||||
c.io.initiate.ready.expect(false.B)
|
|
||||||
|
|
||||||
// ldReq: tmemC.ren asserted; rdata must be valid before next step
|
|
||||||
c.io.tmemC.ren.expect(true.B)
|
|
||||||
c.io.tmemC.raddr.expect((tmemAddr / fragBytes).U)
|
|
||||||
c.io.tmemC.rdata.poke(testData.U)
|
|
||||||
c.clock.step()
|
|
||||||
|
|
||||||
// waitWb: wbValid gets set this cycle, step to let it register
|
|
||||||
c.io.tmemC.rdata.poke(testData.U)
|
|
||||||
c.clock.step()
|
|
||||||
|
|
||||||
// idle: writeback.valid now true
|
|
||||||
c.io.writeback.valid.expect(true.B)
|
|
||||||
c.io.initiate.ready.expect(false.B)
|
|
||||||
c.io.writeback.bits.rd.expect(3.U)
|
|
||||||
c.io.writeback.bits.wid.expect(0.U)
|
|
||||||
for (i <- 0 until numLanes) {
|
|
||||||
c.io.writeback.bits.data(i).expect((0x1000 + i).U)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
it should "tcgen05_st: write from respC to TMEM" in {
|
|
||||||
test(new TensorCoreBlackwell(numWarps, numLanes, half = true, numSourceIds = 4)) { c =>
|
|
||||||
idleIO(c)
|
|
||||||
val fragBytes = 32
|
|
||||||
val tmemAddr = BigInt(0x60)
|
|
||||||
val storeData = packWords(Seq.tabulate(numLanes)(i => BigInt(0xAB00 + i)), 32)
|
|
||||||
|
|
||||||
c.io.initiate.valid.poke(true.B)
|
|
||||||
c.io.initiate.bits.op.poke(5.U) // tcgen05St
|
|
||||||
c.io.initiate.bits.wid.poke(0.U)
|
|
||||||
c.io.initiate.bits.rd.poke(7.U)
|
|
||||||
c.io.initiate.bits.addressA.poke(tmemAddr.U)
|
|
||||||
c.io.respC.poke(storeData.U)
|
|
||||||
c.clock.step()
|
|
||||||
c.io.initiate.valid.poke(false.B)
|
|
||||||
c.io.initiate.ready.expect(false.B)
|
|
||||||
|
|
||||||
// stReq: reqC.valid asserted
|
|
||||||
c.io.reqC.valid.expect(true.B)
|
|
||||||
c.io.reqC.bits.expect(7.U)
|
|
||||||
c.clock.step()
|
|
||||||
|
|
||||||
// stWrite: tmemC.wen asserted with storeData
|
|
||||||
c.io.tmemC.wen.expect(true.B)
|
|
||||||
c.io.tmemC.waddr.expect((tmemAddr / fragBytes).U)
|
|
||||||
c.io.tmemC.wdata.expect(storeData.U)
|
|
||||||
c.clock.step()
|
|
||||||
c.io.initiate.ready.expect(true.B)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
it should "tcgen05_cp: read from global mem (reqA) and write to TMEM" in {
|
|
||||||
test(new TensorCoreBlackwell(numWarps, numLanes, half = true, numSourceIds = 4)) { c =>
|
|
||||||
idleIO(c)
|
|
||||||
val fragBytes = 32
|
|
||||||
val tmemAddr = BigInt(0x80)
|
|
||||||
val gmemAddr = "ha0001000"
|
|
||||||
val cpData = packWords(Seq.fill(numLanes)(BigInt(0xdeadbeefL)), 32)
|
|
||||||
|
|
||||||
c.io.initiate.valid.poke(true.B)
|
|
||||||
c.io.initiate.bits.op.poke(2.U) // tcgen05Cp
|
|
||||||
c.io.initiate.bits.addressA.poke(tmemAddr.U)
|
|
||||||
c.io.initiate.bits.addressB.poke(gmemAddr.U)
|
|
||||||
c.io.reqA.ready.poke(true.B)
|
|
||||||
c.clock.step()
|
|
||||||
c.io.initiate.valid.poke(false.B)
|
|
||||||
c.io.initiate.ready.expect(false.B)
|
|
||||||
|
|
||||||
// cpRead: reqA issued to global mem
|
|
||||||
c.io.reqA.valid.expect(true.B)
|
|
||||||
c.io.reqA.bits.rw.expect(false.B)
|
|
||||||
c.io.reqA.bits.address.expect(gmemAddr.U)
|
|
||||||
c.clock.step()
|
|
||||||
c.io.initiate.ready.expect(false.B)
|
|
||||||
|
|
||||||
// cpWrite: respA fires → tmemC.wen in same cycle
|
|
||||||
c.io.respA.valid.poke(true.B)
|
|
||||||
c.io.respA.bits.data.poke(cpData.U)
|
|
||||||
|
|
||||||
// tmemC write happens combinatorially when respA fires
|
|
||||||
c.io.tmemC.wen.expect(true.B)
|
|
||||||
c.io.tmemC.waddr.expect((tmemAddr / fragBytes).U)
|
|
||||||
c.io.tmemC.wdata.expect(cpData.U)
|
|
||||||
c.clock.step()
|
|
||||||
c.io.initiate.ready.expect(true.B)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
it should "tcgen05_cb: read from TMEM and write to global mem (reqA)" in {
|
|
||||||
test(new TensorCoreBlackwell(numWarps, numLanes, half = true, numSourceIds = 4)) { c =>
|
|
||||||
idleIO(c)
|
|
||||||
val fragBytes = 32
|
|
||||||
val tmemAddr = BigInt(0xa0)
|
|
||||||
val gmemAddr = "ha2000000"
|
|
||||||
val cbData = packWords(Seq.tabulate(numLanes)(i => BigInt(0xC000 + i)), 32)
|
|
||||||
|
|
||||||
c.io.initiate.valid.poke(true.B)
|
|
||||||
c.io.initiate.bits.op.poke(6.U) // tcgen05Cb
|
|
||||||
c.io.initiate.bits.addressA.poke(tmemAddr.U)
|
|
||||||
c.io.initiate.bits.addressB.poke(gmemAddr.U)
|
|
||||||
c.io.reqA.ready.poke(true.B)
|
|
||||||
c.io.tmemC.rdata.poke(cbData.U)
|
|
||||||
c.clock.step()
|
|
||||||
c.io.initiate.valid.poke(false.B)
|
|
||||||
c.io.initiate.ready.expect(false.B)
|
|
||||||
|
|
||||||
// cbRead: tmemC.ren asserted
|
|
||||||
c.io.tmemC.ren.expect(true.B)
|
|
||||||
c.io.tmemC.raddr.expect((tmemAddr / fragBytes).U)
|
|
||||||
c.clock.step()
|
|
||||||
c.io.initiate.ready.expect(false.B)
|
|
||||||
|
|
||||||
// cbWrite: reqA write to global mem
|
|
||||||
c.io.reqA.valid.expect(true.B)
|
|
||||||
c.io.reqA.bits.rw.expect(true.B)
|
|
||||||
c.io.reqA.bits.address.expect(gmemAddr.U)
|
|
||||||
c.io.reqA.bits.data.expect(cbData.U)
|
|
||||||
c.clock.step()
|
|
||||||
c.io.initiate.ready.expect(false.B)
|
|
||||||
c.io.respA.valid.poke(true.B)
|
|
||||||
c.io.respA.bits.data.poke(0.U)
|
|
||||||
c.clock.step()
|
|
||||||
c.io.initiate.ready.expect(true.B)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
it should "run bwgmma: TMEM_C = TMEM_A * SMEM_B + TMEM_C" in {
|
|
||||||
test(new TensorCoreBlackwell(numWarps, numLanes, half = true, numSourceIds = 4))
|
|
||||||
.withAnnotations(Seq(VerilatorBackendAnnotation)) { c =>
|
|
||||||
idleIO(c)
|
|
||||||
|
|
||||||
val fragBytes = 32
|
|
||||||
val aBase = BigInt(0x100)
|
|
||||||
val bBase = BigInt(0x800)
|
|
||||||
val cBase = BigInt(0x1000)
|
|
||||||
|
|
||||||
// A: all fp16 1.0 (0x3c00), 16 halves per frag
|
|
||||||
val fp16One = BigInt(0x3c00)
|
|
||||||
val fp16Two = BigInt(0x4000)
|
|
||||||
val fp32One = BigInt(0x3f800000)
|
|
||||||
val fp32SixtyFive = BigInt(0x42820000)
|
|
||||||
val aFrag = packWords(Seq.fill(16)(fp16One), 16)
|
|
||||||
val bFrag = packWords(Seq.fill(16)(fp16Two), 16)
|
|
||||||
val cFrag = packWords(Seq.fill(numLanes)(fp32One), 32)
|
|
||||||
val expectedCFrag = packWords(Seq.fill(numLanes)(fp32SixtyFive), 32)
|
|
||||||
|
|
||||||
// Populate TMEM with A and C tiles
|
|
||||||
val tmem = makeTmem()
|
|
||||||
for (i <- 0 until 32) {
|
|
||||||
tmem(aBase / fragBytes + i) = aFrag
|
|
||||||
tmem(cBase / fragBytes + i) = cFrag
|
|
||||||
}
|
|
||||||
val bMem = mutable.Map[BigInt, BigInt]()
|
|
||||||
for (i <- 0 until 32) bMem(bBase + i * fragBytes) = bFrag
|
|
||||||
|
|
||||||
c.io.reqB.ready.poke(true.B)
|
|
||||||
c.io.writeback.ready.poke(true.B)
|
|
||||||
|
|
||||||
c.io.initiate.valid.poke(true.B)
|
|
||||||
c.io.initiate.bits.op.poke(0.U) // bwgmma
|
|
||||||
c.io.initiate.bits.wid.poke(1.U)
|
|
||||||
c.io.initiate.bits.rd.poke(0.U)
|
|
||||||
c.io.initiate.bits.addressA.poke(aBase.U)
|
|
||||||
c.io.initiate.bits.addressB.poke(bBase.U)
|
|
||||||
c.io.initiate.bits.addressC.poke(cBase.U)
|
|
||||||
c.clock.step()
|
|
||||||
c.io.initiate.valid.poke(false.B)
|
|
||||||
|
|
||||||
var pendingB = Option.empty[(BigInt, BigInt)]
|
|
||||||
var sawWriteback = false
|
|
||||||
|
|
||||||
for (_ <- 0 until 20000 if !sawWriteback) {
|
|
||||||
// Drive TMEM reads/writes
|
|
||||||
stepTmem(c, tmem)
|
|
||||||
|
|
||||||
// Drive SMEM B responses
|
|
||||||
pendingB.foreach { case (src, data) =>
|
|
||||||
c.io.respB.valid.poke(true.B)
|
|
||||||
c.io.respB.bits.source.poke(src.U)
|
|
||||||
c.io.respB.bits.data.poke(data.U)
|
|
||||||
}
|
|
||||||
if (pendingB.isEmpty) c.io.respB.valid.poke(false.B)
|
|
||||||
|
|
||||||
if (c.io.writeback.valid.peek().litToBoolean) {
|
|
||||||
sawWriteback = true
|
|
||||||
} else {
|
|
||||||
val nextB = if (c.io.reqB.valid.peek().litToBoolean) {
|
|
||||||
val addr = c.io.reqB.bits.address.peek().litValue
|
|
||||||
val src = c.io.reqB.bits.source.peek().litValue
|
|
||||||
Some((src, bMem(addr)))
|
|
||||||
} else None
|
|
||||||
|
|
||||||
c.clock.step()
|
|
||||||
pendingB = nextB
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
assert(sawWriteback, "BWGMMA did not complete")
|
|
||||||
c.io.writeback.bits.wid.expect(1.U)
|
|
||||||
// Verify all 32 C frags in TMEM
|
|
||||||
for (i <- 0 until 32) {
|
|
||||||
val row = cBase / fragBytes + i
|
|
||||||
assert(tmem(row) == expectedCFrag,
|
|
||||||
s"C frag $i mismatch: got 0x${tmem(row).toString(16)}, expected 0x${expectedCFrag.toString(16)}")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
Reference in New Issue
Block a user