connect tc nodes and maybe fix distributor node

This commit is contained in:
Richard Yan
2024-10-07 02:59:06 -07:00
parent 4f057c6994
commit 0989d90dd2
5 changed files with 93 additions and 32 deletions

View File

@@ -91,13 +91,15 @@ class DistributorNode(from: Int, to: Int)(implicit p: Parameters) extends LazyMo
} }
def partialData: UInt = VecInit(mn.map(_.d).map(d => Mux(d.fire, d.bits.data, 0.U(d.bits.data.getWidth.W)))).asUInt def partialData: UInt = VecInit(mn.map(_.d).map(d => Mux(d.fire, d.bits.data, 0.U(d.bits.data.getWidth.W)))).asUInt
def partialValid: UInt = VecInit(mn.map(_.d.fire)).asUInt def partialValid: UInt = VecInit(mn.map(_.d.valid)).asUInt
def partialFire: UInt = VecInit(mn.map(_.d.fire)).asUInt
mn.map(_.d.ready).zip(arrived.asBools).foreach { case (r, a) => mn.map(_.d.ready).zip(arrived.asBools).foreach { case (r, a) =>
r := cn.d.ready && (!partialWait || !a) // if waiting for partial response, ready only if not arrived yet r := cn.d.ready && (!partialWait || !a) // if waiting for partial response, ready only if not arrived yet
} }
// TODO: might need coverage test for this // TODO: might need coverage test for this
cd := DontCare
when (!partialWait) { when (!partialWait) {
cn.d.valid := false.B cn.d.valid := false.B
partialWait := false.B partialWait := false.B
@@ -109,34 +111,39 @@ class DistributorNode(from: Int, to: Int)(implicit p: Parameters) extends LazyMo
assert(cd.data === partialData, "sanity check") assert(cd.data === partialData, "sanity check")
}.elsewhen (partialValid.orR) { }.elsewhen (partialValid.orR) {
// at least 1 valid: enter partial valid state, store partial data into regs // at least 1 valid: enter partial valid state, store partial data into regs
partialWait := true.B partialWait := cn.d.ready // if something fired, enter partial wait
arrived := partialValid arrived := partialFire
cdReg.data := partialData cdReg.data := partialData
when (mn.head.d.valid) { setMetadata(cdReg, mn.head.d.bits) } when (mn.head.d.fire) { setMetadata(cdReg, mn.head.d.bits) }
} }
}.otherwise { }.otherwise {
cn.d.valid := false.B cn.d.valid := false.B
partialWait := true.B partialWait := true.B
when ((arrived | partialValid).andR) { when ((arrived | partialValid).andR) {
// all valids received now // all valids received now
cn.d.valid := true.B
when (cn.d.ready) {
assert((arrived | partialFire).andR)
when (mn.head.d.valid) { when (mn.head.d.valid) {
setMetadata(cd, mn.head.d.bits) setMetadata(cd, mn.head.d.bits)
}.otherwise { }.otherwise {
cd := cdReg cd := cdReg
} }
cn.d.valid := true.B
cd.data := cdReg.data | partialData cd.data := cdReg.data | partialData
partialWait := false.B partialWait := false.B
cdReg := 0.U.asTypeOf(cdReg.cloneType) cdReg := 0.U.asTypeOf(cdReg.cloneType)
arrived := 0.U arrived := 0.U
}
}.elsewhen (partialValid.orR) { }.elsewhen (partialValid.orR) {
// update partial data // update partial data
when (cn.d.ready) {
arrived := arrived | partialValid arrived := arrived | partialValid
cdReg.data := cdReg.data | partialData cdReg.data := cdReg.data | partialData
when (mn.head.d.valid) { setMetadata(cdReg, mn.head.d.bits) } when (mn.head.d.valid) { setMetadata(cdReg, mn.head.d.bits) }
} }
} }
} }
}
} }
object DistributorNode { object DistributorNode {

View File

@@ -11,6 +11,7 @@ import freechips.rocketchip.diplomacy._
import org.chipsalliance.diplomacy.lazymodule.LazyModule import org.chipsalliance.diplomacy.lazymodule.LazyModule
import freechips.rocketchip.prci.{ClockCrossingType, ClockSinkParameters, RationalCrossing} import freechips.rocketchip.prci.{ClockCrossingType, ClockSinkParameters, RationalCrossing}
import freechips.rocketchip.regmapper.RegField import freechips.rocketchip.regmapper.RegField
import freechips.rocketchip.resources.BigIntHexContext
import freechips.rocketchip.rocket._ import freechips.rocketchip.rocket._
import freechips.rocketchip.subsystem.HierarchicalElementCrossingParamsLike import freechips.rocketchip.subsystem.HierarchicalElementCrossingParamsLike
import freechips.rocketchip.tile._ import freechips.rocketchip.tile._
@@ -275,17 +276,18 @@ class RadianceTile private (
} }
val tcSmemSize = 32 val tcSmemSize = 32
val tcSmemNodes = Seq(TLClientNode(Seq(TLMasterPortParameters.v2( val tcSmemNodes = Seq.tabulate(2) { i =>
TLClientNode(Seq(TLMasterPortParameters.v2(
masters = Seq(TLMasterParameters.v2( masters = Seq(TLMasterParameters.v2(
name = s"rad_tc_${radianceParams.coreId}", name = s"rad_tc_${radianceParams.coreId}_$i",
sourceId = IdRange(0, 1 << smemSourceWidth), sourceId = IdRange(0, 1 << smemSourceWidth),
supports = TLSlaveToMasterTransferSizes( supports = TLSlaveToMasterTransferSizes(
probe = TransferSizes(1, tcSmemSize),
get = TransferSizes(1, tcSmemSize), get = TransferSizes(1, tcSmemSize),
putFull = TransferSizes(1, tcSmemSize),
putPartial = TransferSizes(1, tcSmemSize)
) )
)) ))
)))) )))
}
// combine outgoing per-lane dmemNode into 1 idenity node // combine outgoing per-lane dmemNode into 1 idenity node
// //
@@ -686,7 +688,7 @@ class RadianceTileModuleImp(outer: RadianceTile)
outer.smemSourceWidth, outer.smemSourceWidth,
new VortexBundleA(tagWidth = outer.smemTagWidth, dataWidth = 32), new VortexBundleA(tagWidth = outer.smemTagWidth, dataWidth = 32),
new VortexBundleD(tagWidth = outer.smemTagWidth, dataWidth = 32), new VortexBundleD(tagWidth = outer.smemTagWidth, dataWidth = 32),
outer.smemNodes(0).out.head outer.smemNodes.head.out.head
) )
) )
} }
@@ -731,6 +733,46 @@ class RadianceTileModuleImp(outer: RadianceTile)
} }
} }
def connectTc {
val tcb0 = new {
val addr = core.io.tc_a_bits_address(31, 0)
val aValid = core.io.tc_a_valid(0)
val dReady = core.io.tc_d_ready(0)
}
val tcb1 = new {
val addr = core.io.tc_a_bits_address(63, 32)
val aValid = core.io.tc_a_valid(1)
val dReady = core.io.tc_d_ready(1)
}
val tcBundles = Seq(tcb0, tcb1)
val adapters = (outer.tcSmemNodes zip tcBundles).zipWithIndex.map { case ((node, bundle), i) =>
val client = node.out.head
val adapter = Module(
new VortexTLAdapter(
outer.smemSourceWidth,
new VortexBundleA(tagWidth = 1, dataWidth = 32 * 8),
new VortexBundleD(tagWidth = 1, dataWidth = 32 * 8),
client
)
)
adapter.io.inReq.bits <> DontCare
adapter.io.inReq.valid := bundle.aValid
adapter.io.inReq.bits.address := bundle.addr
adapter.io.inReq.bits.source := i.U
adapter.io.inReq.bits.size := 5.U
adapter.io.inReq.bits.opcode := TLMessages.Get
adapter.io.inReq.bits.mask := x"ffffffff".U
adapter.io.inResp.ready := bundle.dReady
client._1.a <> adapter.io.outReq
adapter.io.outResp <> client._1.d
adapter
}
core.io.tc_a_ready := Cat(adapters.last.io.inReq.ready, adapters.head.io.inReq.ready)
core.io.tc_d_valid := Cat(adapters.last.io.inResp.valid, adapters.head.io.inResp.valid)
core.io.tc_d_bits_data := Cat(adapters.last.io.inResp.bits.data, adapters.head.io.inResp.bits.data)
}
def connectBarrier = { def connectBarrier = {
require(outer.barrierMasterNode.out.length == 1) require(outer.barrierMasterNode.out.length == 1)
// FIXME: bits not flattened // FIXME: bits not flattened
@@ -786,6 +828,7 @@ class RadianceTileModuleImp(outer: RadianceTile)
connectImem connectImem
connectDmem connectDmem
connectSmem connectSmem
connectTc
connectBarrier connectBarrier
connectAccelerator connectAccelerator
} }

View File

@@ -54,7 +54,9 @@ class VirgoSharedMemComponents(
smemFanoutXbar.node smemFanoutXbar.node
} }
} }
val tcNodeFanouts = radianceTiles.flatMap(_.tcSmemNodes).map(connectXbarName(_, Some("tc_fanout"))) val tcNodeFanouts = radianceTiles.flatMap(_.tcSmemNodes)
.map(connectOne(_, () => TLBuffer(BufferParams(2, false, false), BufferParams(0))))
.map(connectXbarName(_, Some("tc_fanout")))
val clBusClients: Seq[TLNode] = radianceSmemFanout val clBusClients: Seq[TLNode] = radianceSmemFanout
val (uniformRNodes, uniformWNodes, nonuniformRNodes, nonuniformWNodes) = val (uniformRNodes, uniformWNodes, nonuniformRNodes, nonuniformWNodes) =
@@ -69,7 +71,7 @@ class VirgoSharedMemComponents(
dist := node dist := node
} }
val fanout = Seq.tabulate(spSubbanks) { w => val fanout = Seq.tabulate(spSubbanks) { w =>
val buf = TLBuffer(BufferParams(1, false, true), BufferParams(0)) val buf = TLBuffer(BufferParams(2, false, false), BufferParams(0))
buf := dist buf := dist
connectXbarName(buf, Some(s"spad_g${gemminiIdx}w${w}_fanout_$suffix")) connectXbarName(buf, Some(s"spad_g${gemminiIdx}w${w}_fanout_$suffix"))
} }
@@ -88,7 +90,7 @@ class VirgoSharedMemComponents(
// tensor core read nodes // tensor core read nodes
val tcDistNodes = Seq.fill(smemBanks)(tcNodeFanouts.map(connectOne(_, () => DistributorNode(smemWidth, wordSize)))) val tcDistNodes = Seq.fill(smemBanks)(tcNodeFanouts.map(connectOne(_, () => DistributorNode(smemWidth, wordSize))))
val tcNodes = tcDistNodes.map { tcBank => val tcNodes = tcDistNodes.map { tcBank =>
Seq.fill(smemSubbanks)(tcBank.map(connectXbarName(_, Some("tc_dist_fanout")))) Seq.fill(smemSubbanks)(tcBank.map(connectOne(_, () => TLBuffer(BufferParams(2, false, false)))).map(connectXbarName(_, Some("tc_dist_fanout"))))
} // (banks, subbanks, tc client) } // (banks, subbanks, tc client)
if (filterAligned) { if (filterAligned) {

View File

@@ -90,6 +90,13 @@ class VortexBundle(tile: RadianceTile)(implicit p: Parameters) extends CoreBundl
val smem_d_bits_data = Input(UInt((tile.numLsuLanes * 32).W)) val smem_d_bits_data = Input(UInt((tile.numLsuLanes * 32).W))
val smem_d_ready = Output(UInt((tile.numLsuLanes * 1).W)) val smem_d_ready = Output(UInt((tile.numLsuLanes * 1).W))
val tc_a_valid = Output(UInt(2.W))
val tc_a_bits_address = Output(UInt((2 * 32).W))
val tc_a_ready = Input(UInt(2.W))
val tc_d_valid = Input(UInt(2.W))
val tc_d_bits_data = Input(UInt((2 * 32 * 8).W))
val tc_d_ready = Output(UInt(2.W))
// FIXME: hardcoded // FIXME: hardcoded
val barrierIdBits = tile.barrierMasterNode.out(0)._2.barrierIdBits val barrierIdBits = tile.barrierMasterNode.out(0)._2.barrierIdBits
val coreIdBits = tile.barrierMasterNode.out(0)._2.numCoreBits val coreIdBits = tile.barrierMasterNode.out(0)._2.numCoreBits
@@ -233,6 +240,8 @@ class Vortex(tile: RadianceTile)(implicit p: Parameters)
// addResource("/vsrc/vortex/hw/rtl/mem/VX_gbar_arb.sv") // addResource("/vsrc/vortex/hw/rtl/mem/VX_gbar_arb.sv")
// addResource("/vsrc/vortex/hw/rtl/mem/VX_gbar_unit.sv") // addResource("/vsrc/vortex/hw/rtl/mem/VX_gbar_unit.sv")
addResource("/vsrc/vortex/hw/rtl/mem/VX_tc_bus_if.sv")
addResource("/vsrc/vortex/hw/rtl/libs/VX_allocator.sv") addResource("/vsrc/vortex/hw/rtl/libs/VX_allocator.sv")
// addResource("/vsrc/vortex/hw/rtl/libs/VX_avs_adapter.sv") // addResource("/vsrc/vortex/hw/rtl/libs/VX_avs_adapter.sv")
// addResource("/vsrc/vortex/hw/rtl/libs/VX_axi_adapter.sv") // addResource("/vsrc/vortex/hw/rtl/libs/VX_axi_adapter.sv")