route aligned smem requests separately, fix node bugs

This commit is contained in:
Richard Yan
2024-04-09 20:07:58 -07:00
parent f60a318edb
commit c2fbe8388e
4 changed files with 232 additions and 78 deletions

View File

@@ -0,0 +1,103 @@
package radiance.memory
import chisel3._
import chisel3.experimental.SourceInfo
import chisel3.util._
import freechips.rocketchip.diplomacy._
import freechips.rocketchip.tilelink._
import freechips.rocketchip.util.BundleField
import org.chipsalliance.cde.config.Parameters
/**
  * Splits the requests of a single TileLink client by address.
  *
  * Every outgoing edge except the last is paired (in order) with an entry of `filters` and only
  * receives the requests whose address is contained in that filter ("aligned"). The last outgoing
  * edge receives the requests that match none of the filters ("unaligned"). On the return leg the
  * D responses of all outgoing edges are arbitrated round-robin onto the incoming edge.
  *
  * Source ids: aligned requests keep the incoming source id. Unaligned requests are moved into the
  * id range advertised for the unaligned client, i.e. the incoming range shifted up by its own
  * size; the same offset is removed again from unaligned responses.
  *
  * @param filters address sets selecting the aligned requests; exactly one filter is supported
  */
class AlignFilterNode(filters: Seq[AddressSet])(implicit p: Parameters) extends LazyModule {
  val node = TLNexusNode(clientFn = seq => {
    require(seq.map(_.masters.size).sum == 1, s"there should only be one client to a filter node, " +
      s"found ${seq.map(_.masters.size).sum}")
    val master = seq.head.masters.head
    // TODO: to implement multiple filters, source Id mapping needs to be redone
    require(filters.length == 1, "multiple filters currently not supported")
    val emits = seq.map(_.anyEmitClaims).reduce(_ mincover _)
    seq.head.v1copy(
      clients = filters.map { filter =>
        // aligned client: same source ids as the incoming master, restricted to the filter
        master.v2copy(
          name = s"${name}_filter_aligned",
          sourceId = master.sourceId,
          visibility = Seq(filter),
          emits = emits
        )
      } ++ Seq(
        // unaligned client: source ids moved directly above the aligned range
        master.v2copy(
          name = s"${name}_filter_unaligned",
          sourceId = master.sourceId.shift(master.sourceId.size),
          visibility = Seq(AddressSet.everything),
          emits = emits
        )
      )
    )
  }, managerFn = seq => {
    val addresses = seq.flatMap(_.slaves.flatMap(_.address))
    // merge all downstream address ranges into one; fail with a readable message if they
    // cannot be merged instead of an anonymous NoSuchElementException from Option.get
    val unifiedAddressRange = addresses.flatMap(_.toRanges).sorted.reduce { (lhs, rhs) =>
      lhs.union(rhs).getOrElse(
        throw new IllegalArgumentException(s"$name: cannot unify manager address ranges $lhs and $rhs"))
    }
    require(isPow2(unifiedAddressRange.size),
      s"$name: unified manager address range $unifiedAddressRange must have a power-of-two size")
    println(s"$name address range ${unifiedAddressRange}")
    seq.head.v1copy(
      responseFields = BundleField.union(seq.flatMap(_.responseFields)),
      requestKeys = seq.flatMap(_.requestKeys).distinct,
      minLatency = seq.map(_.minLatency).min,
      endSinkId = TLXbar.mapOutputIds(seq).map(_.end).max,
      managers = Seq(TLSlaveParameters.v2(
        name = Some(s"${name}_manager"),
        address = Seq(AddressSet(unifiedAddressRange.base, unifiedAddressRange.size - 1)),
        supports = seq.map(_.anySupportClaims).reduce(_ mincover _)
      ))
    )
  })

  /** Copies `d` field by field into a new wire of the type of `target_d_t`, resizing each field
    * with `asTypeOf` so that bundles built from different edge parameters can be connected. */
  def cast_d[T <: TLBundleD](d: TLBundleD, target_d_t: T): T = {
    val new_d = Wire(target_d_t.cloneType)
    d.elements.foreach { case (name, data) =>
      val new_d_field = new_d.elements(name)
      new_d_field := data.asTypeOf(new_d_field)
    }
    new_d
  }

  /** Wraps every decoupled D channel in `ds` with a new decoupled wire of the type of
    * `target_d_t`: valid and (cast) bits flow forward, ready flows back to the original. */
  def cast_d[T <: DecoupledIO[TLBundleD]](ds: Seq[DecoupledIO[TLBundleD]], target_d_t: T): Seq[T] = {
    ds.map { d =>
      val new_d = Wire(target_d_t.cloneType)
      new_d.valid := d.valid
      new_d.bits := cast_d(d.bits, target_d_t.bits)
      d.ready := new_d.ready
      new_d
    }
  }

  lazy val module = new LazyModuleImp(this) {
    val (c, c_edge) = node.in.head
    val a = node.out.init.map(_._1)
    val ua = node.out.last._1

    // Must equal the shift applied to the unaligned client's sourceId in clientFn above.
    // Deriving it from the width of the source field instead only matches that shift when the
    // incoming id range starts at 0 and has a power-of-two size of at least 2.
    val unalignedSourceOffset = c_edge.client.masters.head.sourceId.size

    val a_aligned = filters.map(_.contains(c.a.bits.address))
    (a zip a_aligned).foreach { case (a, aligned) =>
      a.a.bits := c.a.bits
      a.a.valid := c.a.valid && aligned
    }

    ua.a.bits := c.a.bits
    // +& keeps the carry bit so the shifted id is not truncated before the assignment
    ua.a.bits.source := c.a.bits.source +& unalignedSourceOffset.U
    ua.a.valid := c.a.valid && !a_aligned.reduce(_ || _)

    // the request is accepted by whichever edge it was steered to
    c.a.ready := MuxCase(ua.a.ready, (a zip a_aligned).map { case (a, aligned) => aligned -> a.a.ready })

    val ds = cast_d(a.map(_.d) ++ Seq(ua.d), c.d)
    // undo the source id shift on unaligned responses (last connection wins over the plain cast)
    ds.last.bits.source := (ua.d.bits.source - unalignedSourceOffset.U).asTypeOf(ds.last.bits.source)
    TLArbiter.robin(c_edge, c.d, ds: _*)
  }
}
/** Factory that builds an [[AlignFilterNode]] lazy module and hands back its diplomatic node. */
object AlignFilterNode {
  /**
    * @param filters address sets forwarded to the [[AlignFilterNode]] constructor
    * @return the nexus node of the newly created lazy module
    */
  def apply(filters: Seq[AddressSet])
           (implicit p: Parameters, valName: ValName, sourceInfo: SourceInfo): TLNexusNode = {
    val filterModule = LazyModule(new AlignFilterNode(filters))
    filterModule.node
  }
}

View File

@@ -9,7 +9,8 @@ import freechips.rocketchip.util.BundleField
import org.chipsalliance.cde.config.Parameters
class RWSplitterNode(name: String = "rw_splitter")(implicit p: Parameters) extends LazyModule {
class RWSplitterNode(visibility: Option[AddressSet], override val name: String = "rw_splitter")
(implicit p: Parameters) extends LazyModule {
// this node accepts both read and write requests,
// splits & arbitrates them into one client node per type of operation;
// there will be N incoming edges, two outgoing edges, with two N:1 muxes;
@@ -21,11 +22,9 @@ class RWSplitterNode(name: String = "rw_splitter")(implicit p: Parameters) exten
assert((read_src_range.start == 0) && isPow2(read_src_range.end))
val write_src_range = read_src_range.shift(read_src_range.size)
val visibilities = seq.flatMap(_.masters.flatMap(_.visibility))
val vis_min = visibilities.map(_.base).min
val vis_max = visibilities.map{ x => x.base + x.mask }.max
val vis_mask = vis_max - vis_min
require(isPow2(vis_mask + 1) || vis_mask == -1)
println(f"combined visibilities of splitter memory node clients: ${vis_min}, ${vis_mask}")
val unified_vis = if (visibilities.map(_ == AddressSet.everything).reduce(_ || _)) Seq(AddressSet.everything)
else AddressSet.unify(visibilities)
println(s"$name has input visibilities $visibilities, unified to $unified_vis")
seq.head.v1copy(
echoFields = BundleField.union(seq.flatMap(_.echoFields)),
@@ -36,7 +35,7 @@ class RWSplitterNode(name: String = "rw_splitter")(implicit p: Parameters) exten
TLMasterParameters.v1(
name = s"${name}_read_client",
sourceId = read_src_range,
visibility = Seq(AddressSet(vis_min, vis_mask)),
visibility = Seq(visibility.getOrElse(unified_vis)),
supportsProbe = TransferSizes.mincover(seq.map(_.anyEmitClaims.get)),
supportsGet = TransferSizes.mincover(seq.map(_.anyEmitClaims.get)),
supportsPutFull = TransferSizes.none,
@@ -45,7 +44,7 @@ class RWSplitterNode(name: String = "rw_splitter")(implicit p: Parameters) exten
TLMasterParameters.v1(
name = s"${name}_write_client",
sourceId = write_src_range,
visibility = Seq(AddressSet(vis_min, vis_mask)),
visibility = Seq(visibility.getOrElse(unified_vis)),
supportsProbe = TransferSizes.mincover(
seq.map(_.anyEmitClaims.putFull) ++seq.map(_.anyEmitClaims.putPartial)),
supportsGet = TransferSizes.none,
@@ -57,6 +56,10 @@ class RWSplitterNode(name: String = "rw_splitter")(implicit p: Parameters) exten
},
managerFn = { seq =>
// val fifoIdFactory = TLXbar.relabeler()
println(f"combined address range of $name managers: " +
f"${AddressSet.unify(seq.flatMap(_.slaves.flatMap(_.address)))}, supports:" +
f"${seq.map(_.anySupportClaims).reduce(_ mincover _)}")
seq.head.v1copy(
responseFields = BundleField.union(seq.flatMap(_.responseFields)),
requestKeys = seq.flatMap(_.requestKeys).distinct,
@@ -65,11 +68,7 @@ class RWSplitterNode(name: String = "rw_splitter")(implicit p: Parameters) exten
managers = Seq(TLSlaveParameters.v2(
name = Some(s"${name}_manager"),
address = AddressSet.unify(seq.flatMap(_.slaves.flatMap(_.address))),
supports = TLMasterToSlaveTransferSizes(
get = TransferSizes.mincover(seq.flatMap(_.slaves.map(_.supportsGet))),
putFull = TransferSizes.mincover(seq.flatMap(_.slaves.map(_.supportsPutFull))),
putPartial = TransferSizes.mincover(seq.flatMap(_.slaves.map(_.supportsPutPartial)))
),
supports = seq.map(_.anySupportClaims).reduce(_ mincover _),
fifoId = Some(0),
))
)
@@ -79,8 +78,7 @@ class RWSplitterNode(name: String = "rw_splitter")(implicit p: Parameters) exten
lazy val module = new LazyModuleImp(this) {
val u_out = node.out
val u_in = node.in
assert(u_out.length == 2)
println(f"${name} has ${u_in.length} incoming client(s)")
assert(u_out.length == 2, s"$name should have 2 outgoing edges but has ${u_out.length}")
val r_out = u_out.head
val w_out = u_out.last
@@ -154,10 +152,16 @@ class RWSplitterNode(name: String = "rw_splitter")(implicit p: Parameters) exten
object RWSplitterNode {
def apply()(implicit p: Parameters, valName: ValName, sourceInfo: SourceInfo): TLNexusNode = {
LazyModule(new RWSplitterNode(name = valName.name)).node
LazyModule(new RWSplitterNode(None, name = valName.name)).node
}
def apply(name: String)(implicit p: Parameters, valName: ValName, sourceInfo: SourceInfo): TLNexusNode = {
LazyModule(new RWSplitterNode(name = name)).node
def apply(visibility: AddressSet)
(implicit p: Parameters, valName: ValName, sourceInfo: SourceInfo): TLNexusNode = {
apply(visibility, valName.name)
}
def apply(visibility: AddressSet, name: String)
(implicit p: Parameters, valName: ValName, sourceInfo: SourceInfo): TLNexusNode = {
LazyModule(new RWSplitterNode(Some(visibility), name = name)).node
}
}

View File

@@ -19,7 +19,8 @@ case class RadianceClusterParams(
) extends InstantiableClusterParams[RadianceCluster] {
val baseName = "radiance_cluster"
val uniqueName = s"${baseName}_$clusterId"
def instantiate(crossing: HierarchicalElementCrossingParamsLike, lookup: LookupByClusterIdImpl)(implicit p: Parameters): RadianceCluster = {
def instantiate(crossing: HierarchicalElementCrossingParamsLike, lookup: LookupByClusterIdImpl)
(implicit p: Parameters): RadianceCluster = {
new RadianceCluster(this, crossing.crossingType, lookup)
}
}
@@ -41,7 +42,8 @@ class RadianceCluster (
// val numLsuLanes = 4 // FIXME: hardcoded
val wordSize = 4
val gemminis = leafTiles.values.filter(_.isInstanceOf[GemminiTile]).asInstanceOf[Iterable[GemminiTile]]
// must toSeq here, otherwise Iterable is lazy and will break diplomacy
val gemminis = leafTiles.values.filter(_.isInstanceOf[GemminiTile]).toSeq.asInstanceOf[Seq[GemminiTile]]
require(gemminis.size == 1, "there should be one and only one gemmini per cluster")
val gemmini = gemminis.head.gemmini
val gemminiTile = gemminis.head
@@ -50,7 +52,7 @@ class RadianceCluster (
val max_write_width_bytes = gemminiConfig.dma_buswidth / 8
val radianceTiles = leafTiles.values.filter(_.isInstanceOf[RadianceTile]).asInstanceOf[Iterable[RadianceTile]]
val radianceTiles = leafTiles.values.filter(_.isInstanceOf[RadianceTile]).toSeq.asInstanceOf[Seq[RadianceTile]]
val numCores = leafTiles.size - gemminis.size
@@ -79,12 +81,15 @@ class RadianceCluster (
val smem_size = smem_width * smem_depth * smem_banks
val stride_by_word = true
val filter_aligned = true
val disable_monitors = false // otherwise it generate 1k+ different tl monitors
val radiance_smem_fanout = radianceTiles.flatMap {
_.smemNodes.map { m =>
val smem_fanout_xbar = TLXbar()
smem_fanout_xbar :=* m
smem_fanout_xbar
val radiance_smem_fanout = radianceTiles.zipWithIndex.flatMap { case (tile, cid) =>
tile.smemNodes.zipWithIndex.map { case (m, lid) =>
val smem_fanout_xbar = LazyModule(new TLXbar())
smem_fanout_xbar.suggestName(f"rad_smem_fanout_cl${thisClusterParams.clusterId}_c${cid}_l${lid}_xbar")
smem_fanout_xbar.node :=* m
smem_fanout_xbar.node
}
}
@@ -151,80 +156,121 @@ class RadianceCluster (
}
}
// Creates a sink node with `to`, binds `from` into it with `:=`, and returns the sink.
// When `disable_monitors` is set, the binding is made inside DisableMonitors.
def connect_one[T <: BaseNode with TLNode](from: TLNode, to: () => T): T = {
  val sink = to()
  if (!disable_monitors) {
    sink := from
  } else {
    DisableMonitors { implicit p => sink := from }
  }
  sink
}
// Creates a fresh TLXbar, binds `from` into it with `:=`, and returns the xbar node.
// When `disable_monitors` is set, the binding is made inside DisableMonitors.
def connect_xbar(from: TLNode): TLNode = {
  val xbar = TLXbar()
  if (!disable_monitors) {
    xbar := from
  } else {
    DisableMonitors { implicit p => xbar := from }
  }
  xbar
}
if (stride_by_word) {
// ask if you need to deal with this, it's not supposed to be readable
val spad_read_nodes = Seq.fill(smem_banks) {
val r_dist = DistributorNode(from = smem_width, to = wordSize)
r_dist := gemmini.spad_read_nodes
Seq.fill(smem_subbanks) {
val id_node = TLIdentityNode()
id_node := r_dist
id_node
}
Seq.fill(smem_subbanks) { connect_one(r_dist, TLIdentityNode.apply) }
}
val spad_write_nodes = Seq.fill(smem_banks) {
val w_dist = DistributorNode(from = smem_width, to = wordSize)
w_dist := gemmini.spad_write_nodes
Seq.fill(smem_subbanks) {
val id_node = TLIdentityNode()
id_node := w_dist
id_node
}
Seq.fill(smem_subbanks) { connect_one(w_dist, TLIdentityNode.apply) }
}
val ws_dist = DistributorNode(from = smem_width, to = wordSize)
ws_dist := gemmini.spad.spad_writer.node // this is the dma write node
val spad_sp_write_nodes = Seq.fill(smem_subbanks) {
val ws_xbar = TLXbar() // fanout to 4 banks
ws_xbar := ws_dist
ws_xbar
}
val spad_sp_write_nodes = Seq.fill(smem_subbanks) { connect_xbar(ws_dist) }
// spad_read_nodes.flatten.foreach(node => unified_mem_read_node :=* node)
// spad_write_nodes.flatten.foreach(node => unified_mem_write_node :=* node)
// spad_sp_write_nodes.foreach(node => unified_mem_write_node :=* node)
// unified_mem_write_node :=* DistributorNode(from = smem_width, to = wordSize) :=* gemmini.spad.spad_writer.node // this is the dma write node
// unified_mem_read_node :=* TLWidthWidget(acc_data_len) :=* acc_read_nodes
// unified_mem_write_node :=* TLWidthWidget(acc_data_len) :=* acc_write_nodes
val (uniform_r_nodes, uniform_w_nodes, nonuniform_r_nodes, nonuniform_w_nodes):
(Seq[Seq[Seq[TLNode]]], Seq[Seq[Seq[TLNode]]], Seq[TLNode], Seq[TLNode]) = if (filter_aligned) {
// these nodes access an entire line simultaneously
val uniform_r_nodes: Seq[Seq[Seq[TLNode]]] = spad_read_nodes.map { rb =>
rb.map { rw => Seq(rw) }
}
val uniform_w_nodes: Seq[Seq[Seq[TLNode]]] = spad_write_nodes.map { wb =>
(wb zip spad_sp_write_nodes).map { case (ww, sw) => Seq(ww, sw) }
}
val num_lanes = radianceTiles.head.numCoreLanes
val num_lsu_lanes = radianceTiles.head.numLsuLanes
assert(num_lanes >= smem_subbanks)
val splitter_nodes = radiance_smem_fanout.map { m =>
val splitter_node = RWSplitterNode()
splitter_node := m
splitter_node
// since num lanes >= num subbanks, should be only one filter node per core/lane
val filter_nodes: Seq[Seq[(TLNode, TLNode)]] = Seq.tabulate(smem_subbanks) { wid =>
val address = AddressSet(smem_base + wordSize * wid, (smem_size - 1) - (smem_subbanks - 1) * wordSize)
radiance_smem_fanout.grouped(num_lsu_lanes).toList.zipWithIndex.flatMap { case (lanes, cid) =>
lanes.zipWithIndex.flatMap { case (lane, lid) =>
if ((lid % smem_subbanks) == wid) {
println(f"c${cid}_l${lid} connected to w${wid}")
val filter_node = AlignFilterNode(Seq(address))(p, valName = ValName(s"filter_l${lid}_w$wid"), info)
DisableMonitors { implicit p => filter_node := lane }
// Seq((aligned splitter, unaligned splitter))
Seq((connect_one(filter_node, () => RWSplitterNode(address, s"aligned_splitter_c${cid}_l${lid}_w$wid")),
connect_one(filter_node, () => RWSplitterNode(AddressSet.everything, s"unaligned_splitter_c${cid}_l${lid}_w$wid"))))
} else Seq()
}
}
}
val f_aligned = Seq.fill(2)(filter_nodes.map(_.map(_._1).map(connect_xbar)))
val f_unaligned = Seq.fill(2)(filter_nodes.map(_.map(_._2).map(connect_xbar)))
val uniform_r_nodes: Seq[Seq[Seq[TLNode]]] = spad_read_nodes.map { rb =>
(rb zip f_aligned.head).map { case (rw, fa) => Seq(rw) ++ fa }
}
val uniform_w_nodes: Seq[Seq[Seq[TLNode]]] = spad_write_nodes.map { wb =>
(wb lazyZip spad_sp_write_nodes lazyZip f_aligned.last).map {
case (ww, sw, fa) => Seq(ww, sw) ++ fa
}
}
// all to all xbar
val Seq(nonuniform_r_nodes, nonuniform_w_nodes) = f_unaligned.map(_.flatten)
(uniform_r_nodes, uniform_w_nodes, nonuniform_r_nodes, nonuniform_w_nodes)
} else {
val splitter_nodes = radiance_smem_fanout.map { connect_one(_, RWSplitterNode.apply) }
// these nodes access an entire line simultaneously
val uniform_r_nodes: Seq[Seq[Seq[TLNode]]] = spad_read_nodes.map { rb =>
rb.map { rw => Seq(rw) }
}
val uniform_w_nodes: Seq[Seq[Seq[TLNode]]] = spad_write_nodes.map { wb =>
(wb zip spad_sp_write_nodes).map { case (ww, sw) => Seq(ww, sw) }
}
// these nodes are random access
val nonuniform_r_nodes: Seq[TLNode] = splitter_nodes.map(connect_xbar)
val nonuniform_w_nodes: Seq[TLNode] = splitter_nodes.map(connect_xbar)
(uniform_r_nodes, uniform_w_nodes, nonuniform_r_nodes, nonuniform_w_nodes)
}
radiance_smem_fanout.foreach(clbus.inwardNode := _)
// these nodes are random access
val nonuniform_r_nodes: Seq[TLNode] = splitter_nodes.map { s =>
val nu_r_xbar = TLXbar()
nu_r_xbar := s
nu_r_xbar
}.toSeq
val nonuniform_w_nodes: Seq[TLNode] = splitter_nodes.map { s =>
val nu_w_xbar = TLXbar()
nu_w_xbar := s
nu_w_xbar
}.toSeq
smem_bank_mgrs.grouped(smem_subbanks).zipWithIndex.foreach { case (bank_mgrs, bid) =>
bank_mgrs.zipWithIndex.foreach { case (Seq(r, w), wid) =>
// TODO: this should be a coordinated round robin
val subbank_r_xbar = TLXbar(TLArbiter.lowestIndexFirst)
val subbank_w_xbar = TLXbar(TLArbiter.lowestIndexFirst)
r := subbank_r_xbar
w := subbank_w_xbar
uniform_r_nodes(bid)(wid).foreach( subbank_r_xbar := _ )
uniform_w_nodes(bid)(wid).foreach( subbank_w_xbar := _ )
nonuniform_r_nodes.foreach( subbank_r_xbar := _ )
nonuniform_w_nodes.foreach( subbank_w_xbar := _ )
// Binds the read/write managers of this subbank to their xbars, then attaches the uniform nodes
// selected by (bid, wid) and all nonuniform nodes to those xbars. The statement order is kept
// as is, since it determines the order in which the bindings are made.
// Parameters are taken implicitly so that the `:=` bindings below capture the Parameters chosen
// by the caller; without this they would always capture the enclosing `p` and wrapping the call
// in DisableMonitors would have no effect.
def connect_smem_banks()(implicit p: Parameters): Unit = {
  r := subbank_r_xbar
  w := subbank_w_xbar
  uniform_r_nodes(bid)(wid).foreach( subbank_r_xbar := _ )
  uniform_w_nodes(bid)(wid).foreach( subbank_w_xbar := _ )
  nonuniform_r_nodes.foreach( subbank_r_xbar := _ )
  nonuniform_w_nodes.foreach( subbank_w_xbar := _ )
}
if (disable_monitors) {
  // pass the monitor-disabled Parameters on to the bindings instead of discarding them
  DisableMonitors { implicit p => connect_smem_banks() }
} else {
  connect_smem_banks()
}
}
}
} else {
@@ -357,7 +403,8 @@ class RadianceClusterModuleImp(outer: RadianceCluster) extends ClusterModuleImp(
when (r_node.d.ready && sram_read_backup_reg.valid && !data_pipe.valid) {
sram_read_backup_reg.valid := false.B
}
assert(!(sram_read_backup_reg.valid && data_pipe.valid && data_pipe_in.fire)) // must empty backup before filling data pipe
// must empty backup before filling data pipe
assert(!(sram_read_backup_reg.valid && data_pipe.valid && data_pipe_in.fire))
assert(data_pipe_in.valid === data_pipe_in.fire)
r_node.d.bits := r_edge.AccessAck(