emergency push

This commit is contained in:
Richard Yan
2024-10-21 13:50:26 -07:00
parent ffdabf9184
commit 8307d8d154
3 changed files with 332 additions and 131 deletions

View File

@@ -0,0 +1,134 @@
package radiance.memory
import chisel3._
import chisel3.util._
import freechips.rocketchip.diplomacy.{AddressSet, TransferSizes, IdRange}
import freechips.rocketchip.tilelink._
import freechips.rocketchip.util.BundleField
import org.chipsalliance.cde.config.Parameters
import org.chipsalliance.diplomacy.ValName
import org.chipsalliance.diplomacy.lazymodule._
class DuplicatorNode(override val name: String = "dup")
(implicit p: Parameters) extends LazyModule {
// tilelink node that has two identical managers for parallelizing request processing
// one of the two managers must deassert ready when A channel is valid
val node = TLNexusNode(
clientFn = { seq =>
val inMapping = TLXbar.mapInputIds(seq)
val sourceRange = IdRange(inMapping.map(_.start).min, inMapping.map(_.end).max)
assert((sourceRange.start == 0) && isPow2(sourceRange.end))
val visibilities = seq.flatMap(_.masters.flatMap(_.visibility))
val unifiedVis = if (visibilities.map(_ == AddressSet.everything).reduce(_ || _)) Seq(AddressSet.everything)
else AddressSet.unify(visibilities)
seq.head.v1copy(
echoFields = BundleField.union(seq.flatMap(_.echoFields)),
requestFields = BundleField.union(seq.flatMap(_.requestFields)),
responseKeys = seq.flatMap(_.responseKeys).distinct,
minLatency = seq.map(_.minLatency).min,
clients = Seq.tabulate(2) { i =>
TLMasterParameters.v1(
name = s"${name}_read_client",
sourceId = sourceRange.shift(sourceRange.size * i),
visibility = unifiedVis,
supportsProbe = TransferSizes.mincover(seq.map(_.anyEmitClaims.get)),
supportsGet = TransferSizes.mincover(seq.map(_.anyEmitClaims.get)),
supportsPutFull = TransferSizes.mincover(seq.map(_.anyEmitClaims.putFull)),
supportsPutPartial = TransferSizes.mincover(seq.map(_.anyEmitClaims.putPartial))
)
}
)
},
managerFn = { seq =>
println(f"combined address range of $name managers: " +
f"${AddressSet.unify(seq.flatMap(_.slaves.flatMap(_.address)))}, supports:" +
f"${seq.map(_.anySupportClaims).reduce(_ mincover _)}")
seq.head.v1copy(
responseFields = BundleField.union(seq.flatMap(_.responseFields)),
requestKeys = seq.flatMap(_.requestKeys).distinct,
minLatency = seq.map(_.minLatency).min,
endSinkId = TLXbar.mapOutputIds(seq).map(_.end).max,
managers = Seq(TLSlaveParameters.v2(
name = Some(s"${name}_manager"),
address = AddressSet.unify(seq.flatMap(_.slaves.flatMap(_.address))),
supports = seq.map(_.anySupportClaims).reduce(_ mincover _),
fifoId = Some(0),
))
)
}
)
lazy val module = new LazyModuleImp(this) {
assert(node.out.length == 2, s"$name should have 2 outgoing edges but has ${node.out.length}")
assert(node.in.length == 1, s"$name should have one incoming edge but has ${node.in.length}")
val inSourceWidth = log2Ceil(node.in.head._2.master.endSourceId)
val inSourceEnd = 1 << inSourceWidth
val nodeIn = node.in.head._1
val nodeOuts = node.out.map(_._1)
val sourceEnq = Wire(DecoupledIO(UInt(inSourceWidth.W)))
sourceEnq.valid := nodeIn.a.valid && nodeOuts.map(_.a.ready).reduce(_ || _)
sourceEnq.bits := nodeIn.a.bits.source
val idQueue = Queue(sourceEnq, entries = 4, pipe = false, flow = false)
val srcMatch = nodeOuts.map(_.d.bits.source(inSourceWidth - 1, 0) === idQueue.bits)
idQueue.ready := nodeIn.d.ready && srcMatch.reduce(_ || _)
assert(sourceEnq.fire === nodeIn.a.fire)
assert(idQueue.fire === nodeIn.d.fire)
(nodeOuts zip srcMatch).foreach { case (o, m) =>
o.a.bits := nodeIn.a.bits
o.a.bits.source := nodeIn.a.bits.source | inSourceEnd.U
o.a.valid := nodeIn.a.valid
nodeIn.d.bits := o.d.bits
nodeIn.d.bits.source := o.d.bits.source(inSourceWidth - 1, 0)
nodeIn.d.valid := o.d.valid
o.d.ready := nodeIn.d.ready && m
}
assert(!(nodeOuts.head.a.ready && nodeOuts.last.a.ready) || !nodeIn.a.valid, "double output fire")
nodeIn.a.ready := nodeOuts.map(_.a.ready).reduce(_ || _) && sourceEnq.ready
}
}
object DuplicatorNode {
def apply()(implicit p: Parameters): TLNexusNode = {
LazyModule(new DuplicatorNode()).node
}
}
class DoubleOutXbar(clients: Seq[TLNode], override val name: String = "2o_xbar")
(implicit p: Parameters) extends LazyModule {
val xbar0 = TLXbar(TLArbiter.lowestIndexFirst)
val xbar1 = TLXbar(TLArbiter.lowestIndexFirst)
implicit val disableMonitors: Boolean = false
val dupedIds = clients.map(connectOne(_, DuplicatorNode.apply)).map { c =>
val id0 = connectOne(c, TLIdentityNode.apply)
val id1 = connectOne(c, TLIdentityNode.apply)
xbar0 := id0
xbar1 := id1
Seq(id0, id1)
}.transpose
lazy val module = new LazyModuleImp(this) {
val id0InReadys = VecInit(dupedIds.head.map(_.in.head._1.a.ready)).asUInt
val id1InValids = VecInit(dupedIds.last.map(_.in.head._1.a.valid)).asUInt
(dupedIds.last.map(_.out.head._1.a.valid) zip (id1InValids & (~id0InReadys).asUInt).asBools)
.foreach { case (o, i) => o := i }
}
}
object DoubleOutXbar {
def apply(clients: Seq[TLNode])(implicit p: Parameters): Seq[TLNode] = {
val doubleOutXbar: DoubleOutXbar = LazyModule(new DoubleOutXbar(clients))
Seq(doubleOutXbar.xbar0, doubleOutXbar.xbar1)
}
}

View File

@@ -16,16 +16,20 @@ import radiance.memory._
import radiance.subsystem.RadianceGemminiDataType.{BF16, FP16, FP32, Int8}
sealed trait RadianceSmemSerialization
case object FullySerialized extends RadianceSmemSerialization
case object CoreSerialized extends RadianceSmemSerialization
case object NotSerialized extends RadianceSmemSerialization
sealed trait MemType
case object TwoPort extends MemType
case object TwoReadOneWrite extends MemType
case class RadianceSharedMemKey(address: BigInt,
size: Int,
numBanks: Int,
numWords: Int,
wordSize: Int = 4,
memType: MemType = TwoPort,
strideByWord: Boolean = true,
filterAligned: Boolean = true,
disableMonitors: Boolean = true,
@@ -197,6 +201,7 @@ class WithRadianceSharedMem(address: BigInt,
size: Int,
numBanks: Int,
numWords: Int,
memType: MemType = TwoPort,
strideByWord: Boolean = true,
filterAligned: Boolean = true,
disableMonitors: Boolean = true,
@@ -205,8 +210,8 @@ class WithRadianceSharedMem(address: BigInt,
case RadianceSharedMemKey => {
require(isPow2(size) && size >= 1024)
Some(RadianceSharedMemKey(
address, size, numBanks, numWords, 4, strideByWord,
filterAligned, disableMonitors, serializeUnaligned
address, size, numBanks, numWords, 4, memType,
strideByWord, filterAligned, disableMonitors, serializeUnaligned
))
}
})

View File

@@ -7,8 +7,9 @@ import org.chipsalliance.cde.config.Parameters
import freechips.rocketchip.tilelink._
import freechips.rocketchip.diplomacy.{AddressSet, TransferSizes}
import gemmini.Pipeline
import radiance.subsystem.RadianceSharedMemKey
import radiance.subsystem.{RadianceSharedMemKey, TwoPort, TwoReadOneWrite}
import radiance.memory._
import scala.collection.mutable.ArrayBuffer
abstract class RadianceSmemNodeProvider {
@@ -49,60 +50,72 @@ class RadianceSharedMem[T <: RadianceSmemNodeProvider](
require(isPow2(smemSubbanks))
(0 until smemBanks).flatMap { bid =>
(0 until smemSubbanks).map { wid =>
Seq(TLManagerNode(Seq(TLSlavePortParameters.v1(
managers = Seq(TLSlaveParameters.v2(
name = Some(f"sp_bank${bid}_word${wid}_read_mgr"),
address = Seq(AddressSet(
smemBase + (smemDepth * smemWidth * bid) + wordSize * wid,
smemDepth * smemWidth - smemWidth + wordSize - 1
Seq.fill(smemKey.memType match {
case TwoPort => 1
case TwoReadOneWrite => 2
})(
TLManagerNode(Seq(TLSlavePortParameters.v1(
managers = Seq(TLSlaveParameters.v2(
name = Some(f"sp_bank${bid}_word${wid}_read_mgr"),
address = Seq(AddressSet(
smemBase + (smemDepth * smemWidth * bid) + wordSize * wid,
smemDepth * smemWidth - smemWidth + wordSize - 1
)),
supports = TLMasterToSlaveTransferSizes(
get = TransferSizes(wordSize, wordSize)),
fifoId = Some(0)
)),
supports = TLMasterToSlaveTransferSizes(
get = TransferSizes(wordSize, wordSize)),
fifoId = Some(0)
)),
beatBytes = wordSize
))
), TLManagerNode(Seq(TLSlavePortParameters.v1(
managers = Seq(TLSlaveParameters.v2(
name = Some(f"sp_bank${bid}_word${wid}_write_mgr"),
address = Seq(AddressSet(
smemBase + (smemDepth * smemWidth * bid) + wordSize * wid,
smemDepth * smemWidth - smemWidth + wordSize - 1
beatBytes = wordSize
)))
) ++ Seq(
TLManagerNode(Seq(TLSlavePortParameters.v1(
managers = Seq(TLSlaveParameters.v2(
name = Some(f"sp_bank${bid}_word${wid}_write_mgr"),
address = Seq(AddressSet(
smemBase + (smemDepth * smemWidth * bid) + wordSize * wid,
smemDepth * smemWidth - smemWidth + wordSize - 1
)),
supports = TLMasterToSlaveTransferSizes(
putFull = TransferSizes(wordSize, wordSize),
putPartial = TransferSizes(wordSize, wordSize)),
fifoId = Some(0)
)),
supports = TLMasterToSlaveTransferSizes(
putFull = TransferSizes(wordSize, wordSize),
putPartial = TransferSizes(wordSize, wordSize)),
fifoId = Some(0)
)),
beatBytes = wordSize
))))
beatBytes = wordSize
)))
)
}
}
} else {
(0 until smemBanks).map { bank =>
Seq(TLManagerNode(Seq(TLSlavePortParameters.v1(
managers = Seq(TLSlaveParameters.v2(
name = Some(f"sp_bank${bank}_read_mgr"),
address = Seq(AddressSet(smemBase + (smemDepth * smemWidth * bank),
smemDepth * smemWidth - 1)),
supports = TLMasterToSlaveTransferSizes(
get = TransferSizes(1, smemWidth)),
fifoId = Some(0)
)),
beatBytes = smemWidth
))
), TLManagerNode(Seq(TLSlavePortParameters.v1(
managers = Seq(TLSlaveParameters.v2(
name = Some(f"sp_bank${bank}_write_mgr"),
address = Seq(AddressSet(smemBase + (smemDepth * smemWidth * bank),
smemDepth * smemWidth - 1)),
supports = TLMasterToSlaveTransferSizes(
putFull = TransferSizes(1, smemWidth),
putPartial = TransferSizes(1, smemWidth)),
fifoId = Some(0)
)),
beatBytes = smemWidth
))))
Seq.fill(smemKey.memType match {
case TwoPort => 1
case TwoReadOneWrite => 2
})(
TLManagerNode(Seq(TLSlavePortParameters.v1(
managers = Seq(TLSlaveParameters.v2(
name = Some(f"sp_bank${bank}_read_mgr"),
address = Seq(AddressSet(smemBase + (smemDepth * smemWidth * bank),
smemDepth * smemWidth - 1)),
supports = TLMasterToSlaveTransferSizes(
get = TransferSizes(1, smemWidth)),
fifoId = Some(0)
)),
beatBytes = smemWidth
)))
) ++ Seq(
TLManagerNode(Seq(TLSlavePortParameters.v1(
managers = Seq(TLSlaveParameters.v2(
name = Some(f"sp_bank${bank}_write_mgr"),
address = Seq(AddressSet(smemBase + (smemDepth * smemWidth * bank),
smemDepth * smemWidth - 1)),
supports = TLMasterToSlaveTransferSizes(
putFull = TransferSizes(1, smemWidth),
putPartial = TransferSizes(1, smemWidth)),
fifoId = Some(0)
)),
beatBytes = smemWidth
)))
)
}
}
@@ -115,19 +128,15 @@ class RadianceSharedMem[T <: RadianceSmemNodeProvider](
if (strideByWord) {
smemBankMgrs.grouped(smemSubbanks).zipWithIndex.foreach { case (bankMgrs, bid) =>
bankMgrs.zipWithIndex.foreach { case (Seq(r, w), wid) =>
// TODO: this should be a coordinated round robin
val subbankRXbar = LazyModule(new TLXbar(TLArbiter.lowestIndexFirst))
val subbankWXbar = LazyModule(new TLXbar(TLArbiter.lowestIndexFirst))
subbankRXbar.suggestName(s"smem_b${bid}_w${wid}_r_xbar")
subbankWXbar.suggestName(s"smem_b${bid}_w${wid}_w_xbar")
bankMgrs.zipWithIndex.foreach { case (ports, wid) =>
val readPorts = ports.init
val writePort = ports.last
guardMonitors { implicit p =>
r := subbankRXbar.node
w := subbankWXbar.node
val urXbar = XbarWithExtPolicy(Some(s"ur_b${bid}_w${wid}"))
val uwXbar = XbarWithExtPolicy(Some(s"uw_b${bid}_w${wid}"))
// connect policy nodes
val rPolicyNode = ExtPolicyMasterNode(uniformRNodes(bid)(wid).length)
val wPolicyNode = ExtPolicyMasterNode(uniformWNodes(bid)(wid).length)
urXbar.policySlaveNode := rPolicyNode
@@ -135,6 +144,7 @@ class RadianceSharedMem[T <: RadianceSmemNodeProvider](
uniformPolicyNodes.head(bid)(wid) = rPolicyNode
uniformPolicyNodes.last(bid)(wid) = wPolicyNode
// connect clients
(Seq(urXbar, uwXbar) lazyZip uniformNodesIn lazyZip Seq(uniformRNodes, uniformWNodes))
.foreach { case (xbar, idBuf, uNodes) =>
@@ -145,17 +155,33 @@ class RadianceSharedMem[T <: RadianceSmemNodeProvider](
}
}
uniformNodesOut.head(bid)(wid) = TLIdentityNode()
uniformNodesOut.last(bid)(wid) = TLIdentityNode()
subbankRXbar.node := uniformNodesOut.head(bid)(wid) := urXbar.node
subbankWXbar.node := uniformNodesOut.last(bid)(wid) := uwXbar.node
uniformNodesOut.head(bid)(wid) = connectOne(urXbar.node, TLIdentityNode.apply)
uniformNodesOut.last(bid)(wid) = connectOne(uwXbar.node, TLIdentityNode.apply)
nonuniformRNodes.foreach( subbankRXbar.node :=* _ )
nonuniformWNodes.foreach( subbankWXbar.node :=* _ )
// connect memory
smemKey.memType match {
case TwoPort => {
val subbankRXbar = TLXbar(TLArbiter.lowestIndexFirst, Some(s"smem_b${bid}_w${wid}_r_xbar"))
subbankRXbar := uniformNodesOut.head(bid)(wid)
nonuniformRNodes.foreach( subbankRXbar :=* _ )
readPorts.head := subbankRXbar
}
case TwoReadOneWrite => {
val subbankRXbars = DoubleOutXbar(Seq(uniformNodesOut.head(bid)(wid)) ++ nonuniformRNodes)
(readPorts zip subbankRXbars).foreach { case (rp, sbx) => rp := sbx }
}
}
val subbankWXbar = TLXbar(TLArbiter.lowestIndexFirst, Some(s"smem_b${bid}_w${wid}_w_xbar"))
writePort := subbankWXbar
subbankWXbar := uniformNodesOut.last(bid)(wid)
nonuniformWNodes.foreach( subbankWXbar :=* _ )
}
}
}
} else { // not stride by word
require(smemKey.memType == TwoPort, "double read ports not implemented")
val smemRXbar = TLXbar()
val smemWXbar = TLXbar()
@@ -184,23 +210,25 @@ class RadianceSharedMemImp[T <: RadianceSmemNodeProvider](outer: RadianceSharedM
val smNodesImp = outer.providerImp.map(impFn => impFn(outer.smNodes))
def makeBuffer[U <: Data](mem: TwoPortSyncMem[U], rNode: TLBundle, rEdge: TLEdgeIn,
wNode: TLBundle, wEdge: TLEdgeIn): Unit = {
mem.io.ren := rNode.a.fire
case class ReadPort[U <: Data](ren: Bool, data: U)
case class WritePort[U <: Data](wen: Bool, data: U, mask: UInt)
val dataPipeIn = Wire(DecoupledIO(mem.io.rdata.cloneType))
dataPipeIn.valid := RegNext(mem.io.ren)
dataPipeIn.bits := mem.io.rdata
def makeReadBuffer[U <: Data](port: ReadPort[U], rNode: TLBundle, rEdge: TLEdgeIn): Unit = {
port.ren := rNode.a.fire
val dataPipeIn = Wire(DecoupledIO(port.data.cloneType))
dataPipeIn.valid := RegNext(port.ren)
dataPipeIn.bits := port.data
val metadataPipeIn = Wire(DecoupledIO(new Bundle {
val source = rNode.a.bits.source.cloneType
val size = rNode.a.bits.size.cloneType
}))
metadataPipeIn.valid := mem.io.ren
metadataPipeIn.valid := port.ren
metadataPipeIn.bits.source := rNode.a.bits.source
metadataPipeIn.bits.size := rNode.a.bits.size
val sramReadBackupReg = RegInit(0.U.asTypeOf(Valid(mem.io.rdata.cloneType)))
val sramReadBackupReg = RegInit(0.U.asTypeOf(Valid(port.data.cloneType)))
val dataPipeInst = Module(new Pipeline(dataPipeIn.bits.cloneType, 1)())
dataPipeInst.io.in <> dataPipeIn
@@ -214,12 +242,12 @@ class RadianceSharedMemImp[T <: RadianceSmemNodeProvider](outer: RadianceSharedM
assert(!sramReadBackupReg.valid) // backup reg should be empty
assert(!metadataPipeIn.ready) // metadata should be filled previous cycle
sramReadBackupReg.valid := true.B
sramReadBackupReg.bits := mem.io.rdata
sramReadBackupReg.bits := port.data
}.otherwise {
assert(dataPipeIn.ready || !dataPipeIn.valid) // do not skip any response
}
assert(metadataPipeIn.fire || !mem.io.ren) // when requesting sram, metadata needs to be ready
assert(metadataPipeIn.fire || !port.ren) // when requesting sram, metadata needs to be ready
assert(rNode.d.fire === metadataPipe.fire) // metadata dequeues iff D fires
// when D becomes ready, and data pipe has emptied, time for backup to empty
@@ -238,11 +266,12 @@ class RadianceSharedMemImp[T <: RadianceSmemNodeProvider](outer: RadianceSharedM
rNode.a.ready := rNode.d.ready && !(dataPipe.valid && sramReadBackupReg.valid)
dataPipe.ready := rNode.d.ready
metadataPipe.ready := rNode.d.ready
}
// WRITE
mem.io.wen := RegNext(wNode.a.fire)
mem.io.wdata := RegNext(wNode.a.bits.data)
mem.io.mask := RegNext(wNode.a.bits.mask)
def makeWriteBuffer[U <: Data](port: WritePort[U], wNode: TLBundle, wEdge: TLEdgeIn): Unit = {
port.wen := RegNext(wNode.a.fire)
port.data := RegNext(wNode.a.bits.data)
port.mask := RegNext(wNode.a.bits.mask)
val writeResp = Wire(Flipped(wNode.d.cloneType))
writeResp.bits := wEdge.AccessAck(wNode.a.bits)
@@ -251,21 +280,79 @@ class RadianceSharedMemImp[T <: RadianceSmemNodeProvider](outer: RadianceSharedM
wNode.d <> Queue(writeResp, 2)
}
// read/write access counter for smem banks
val Seq(smemReadsPerCycle, smemWritesPerCycle) = outer.smemBankMgrs.transpose.map { rw =>
VecInit(rw.map(_.in.head._1.a.fire.asUInt)).reduceTree(_ +& _)
}
val smemReadCounter = RegInit(0.U(32.W))
val smemWriteCounter = RegInit(0.U(32.W))
smemReadCounter := smemReadCounter +& smemReadsPerCycle
smemWriteCounter := smemWriteCounter +& smemWritesPerCycle
dontTouch(smemReadCounter)
dontTouch(smemWriteCounter)
if (outer.strideByWord) {
val uniformFires = Seq.fill(2)(VecInit.fill(outer.smemBanks)(VecInit.fill(outer.smemSubbanks)(false.B)))
// instantiate sram banks and connect
outer.smemBankMgrs.grouped(outer.smemSubbanks).zipWithIndex.foreach { case (bankMgrs, bid) =>
bankMgrs.zipWithIndex.foreach { case (ports, wid) =>
val readPorts = ports.init
val writePort = ports.last
assert(!readPorts.flatMap(_.portParams.map(_.anySupportPutFull)).reduce(_ || _))
assert(!writePort.portParams.map(_.anySupportGet).reduce(_ || _))
val memDepth = outer.smemDepth
val memWidth = outer.smemWidth
val wordWidth = outer.wordSize
outer.smemKey.memType match {
case TwoPort =>
val mem = TwoPortSyncMem(
n = memDepth,
t = UInt((wordWidth * 8).W),
)
// TODO: bring in cluster id
// mem.suggestName(s"rad_smem_cl${outer.thisClusterParams.clusterId}_b${bid}_w${wid}")
val (rNode, rEdge) = readPorts.head.in.head
val (wNode, wEdge) = writePort.in.head
// address format is
// [ smem_base | bank_id | line_id | word_id | byte_offset ]
// line_id is used to index into the SRAMs
mem.io.raddr := (rNode.a.bits.address & (memDepth * memWidth - 1).U) >> log2Ceil(memWidth).U
mem.io.waddr := RegNext((wNode.a.bits.address & (memDepth * memWidth - 1).U) >> log2Ceil(memWidth).U)
assert((bid.U === ((rNode.a.bits.address & (memDepth * memWidth * outer.smemBanks - 1).U) >>
log2Ceil(memDepth * memWidth).U).asUInt) || !rNode.a.valid, "bank id mismatch with request")
assert((wid.U === ((rNode.a.bits.address & (memWidth - 1).U) >>
log2Ceil(wordWidth).U).asUInt) || !rNode.a.valid, "word id mismatch with request")
makeReadBuffer(ReadPort(mem.io.ren, mem.io.rdata), rNode, rEdge)
makeWriteBuffer(WritePort(mem.io.wen, mem.io.wdata, mem.io.mask), wNode, wEdge)
case TwoReadOneWrite =>
val mem = TwoReadOneWriteSyncMem(
n = memDepth,
t = UInt((wordWidth * 8).W),
)
val (rNode0, rEdge0) = readPorts.head.in.head
val (rNode1, rEdge1) = readPorts.last.in.head
val (wNode, wEdge) = writePort.in.head
mem.io.raddr0 := (rNode0.a.bits.address & (memDepth * memWidth - 1).U) >> log2Ceil(memWidth).U
mem.io.raddr1 := (rNode1.a.bits.address & (memDepth * memWidth - 1).U) >> log2Ceil(memWidth).U
mem.io.waddr := RegNext((wNode.a.bits.address & (memDepth * memWidth - 1).U) >> log2Ceil(memWidth).U)
makeReadBuffer(ReadPort(mem.io.ren0, mem.io.rdata0), rNode0, rEdge0)
makeReadBuffer(ReadPort(mem.io.ren1, mem.io.rdata1), rNode1, rEdge1)
makeWriteBuffer(WritePort(mem.io.wen, mem.io.wdata, mem.io.mask), wNode, wEdge)
}
}
}
// set up uniform mux selects
Seq.tabulate(outer.smemBanks) { bid =>
// note down fire here so the round-robin knows when an input is selected
Seq.tabulate(outer.smemSubbanks) { wid =>
(uniformFires zip outer.uniformNodesOut).foreach { case (uf, n) =>
uf(bid)(wid) := n(bid)(wid).in.head._1.a.fire
}
}
// have a uniform hint to all subbanks in a bank
val wordSelects1h = Seq(
Wire(UInt(outer.uniformNodesIn.head(bid).head.length.W)).suggestName(s"ws_r_b${bid}"),
@@ -275,43 +362,6 @@ class RadianceSharedMemImp[T <: RadianceSmemNodeProvider](outer: RadianceSharedM
VecInit(wordsInIdx.toSeq).asUInt.orR
}.toSeq).asUInt.suggestName(s"valid_sources_rw${rw}_b${bid}")
}
assert(bankMgrs.flatten.size == 2/* read and write */ * outer.smemSubbanks)
bankMgrs.zipWithIndex.foreach { case (Seq(r, w), wid) =>
assert(!r.portParams.map(_.anySupportPutFull).reduce(_ || _))
assert(!w.portParams.map(_.anySupportGet).reduce(_ || _))
val memDepth = outer.smemDepth
val memWidth = outer.smemWidth
val wordWidth = outer.wordSize
val mem = TwoPortSyncMem(
n = memDepth,
t = UInt((wordWidth * 8).W),
)
// TODO: bring in cluster id
// mem.suggestName(s"rad_smem_cl${outer.thisClusterParams.clusterId}_b${bid}_w${wid}")
val (rNode, rEdge) = r.in.head
val (wNode, wEdge) = w.in.head
// address format is
// [ smem_base | bank_id | line_id | word_id | byte_offset ]
// line_id is used to index into the SRAMs
mem.io.raddr := (rNode.a.bits.address & (memDepth * memWidth - 1).U) >> log2Ceil(memWidth).U
mem.io.waddr := RegNext((wNode.a.bits.address & (memDepth * memWidth - 1).U) >> log2Ceil(memWidth).U)
assert((bid.U === ((rNode.a.bits.address & (memDepth * memWidth * outer.smemBanks - 1).U) >>
log2Ceil(memDepth * memWidth).U).asUInt) || !rNode.a.valid, "bank id mismatch with request")
assert((wid.U === ((rNode.a.bits.address & (memWidth - 1).U) >>
log2Ceil(wordWidth).U).asUInt) || !rNode.a.valid, "word id mismatch with request")
makeBuffer(mem, rNode, rEdge, wNode, wEdge)
(uniformFires zip outer.uniformNodesOut).foreach { case (uf, n) =>
uf(bid)(wid) := n(bid)(wid).in.head._1.a.fire
}
}
// use round robin to decide uniform select
(wordSelects1h zip Seq(validRSources, validWSources)).zipWithIndex.foreach { case ((ws, vs), rw) =>
ws := TLArbiter.roundRobin(vs.getWidth, vs, uniformFires(rw)(bid).asUInt.orR)
@@ -331,7 +381,7 @@ class RadianceSharedMemImp[T <: RadianceSmemNodeProvider](outer: RadianceSharedM
}
}
}
// set policy to use the uniform select as hint
(outer.uniformPolicyNodes zip wordSelects1h).zipWithIndex.foreach { case ((nodesBw, ws), rw) =>
nodesBw(bid).foreach { policy =>
policy.out.head._1.hint := ws
@@ -355,8 +405,20 @@ class RadianceSharedMemImp[T <: RadianceSmemNodeProvider](outer: RadianceSharedM
mem.io.raddr := (rNode.a.bits.address ^ outer.smemBase.U) >> log2Ceil(memWidth).U
mem.io.waddr := RegNext((wNode.a.bits.address ^ outer.smemBase.U) >> log2Ceil(memWidth).U)
makeBuffer(mem, rNode, rEdge, wNode, wEdge)
makeReadBuffer(ReadPort(mem.io.ren, mem.io.rdata), rNode, rEdge)
makeWriteBuffer(WritePort(mem.io.wen, mem.io.wdata, mem.io.mask), wNode, wEdge)
}
}
// read/write access counter for smem banks
val smemAccessesPerCycle = outer.smemBankMgrs.transpose.map { rw =>
VecInit(rw.map(_.in.head._1.a.fire.asUInt)).reduceTree(_ +& _)
}
val smemReadCounter = RegInit(0.U(32.W))
val smemWriteCounter = RegInit(0.U(32.W))
smemReadCounter := smemReadCounter +& smemAccessesPerCycle.init.reduce(_ +& _)
smemWriteCounter := smemWriteCounter +& smemAccessesPerCycle.last
dontTouch(smemReadCounter)
dontTouch(smemWriteCounter)
}