add tensor core read client

This commit is contained in:
Richard Yan
2024-10-05 02:48:47 -07:00
parent 2929a84ecc
commit c6df484c00
2 changed files with 22 additions and 2 deletions

View File

@@ -274,6 +274,19 @@ class RadianceTile private (
)
}
val tcSmemSize = 32
val tcSmemNodes = Seq(TLClientNode(Seq(TLMasterPortParameters.v2(
masters = Seq(TLMasterParameters.v2(
name = s"rad_tc_${radianceParams.coreId}",
sourceId = IdRange(0, 1 << smemSourceWidth),
supports = TLSlaveToMasterTransferSizes(
get = TransferSizes(1, tcSmemSize),
putFull = TransferSizes(1, tcSmemSize),
putPartial = TransferSizes(1, tcSmemSize)
)
))
))))
// combine outgoing per-lane dmemNode into 1 idenity node
//
// NOTE: We need TLWidthWidget here because there might be a data width

View File

@@ -54,6 +54,7 @@ class VirgoSharedMemComponents(
smemFanoutXbar.node
}
}
val tcNodeFanouts = radianceTiles.flatMap(_.tcSmemNodes).map(connectXbarName(_, Some("tc_fanout")))
val clBusClients: Seq[TLNode] = radianceSmemFanout
val (uniformRNodes, uniformWNodes, nonuniformRNodes, nonuniformWNodes) =
@@ -84,6 +85,12 @@ class VirgoSharedMemComponents(
val spadSpWriteNodesSingleBank = distAndDuplicate(gemminis.map(_.spad.spad_writer.node), "ws")
val spadSpWriteNodes = Seq.fill(smemBanks)(spadSpWriteNodesSingleBank) // executed only once
// tensor core read nodes
val tcDistNodes = Seq.fill(smemBanks)(tcNodeFanouts.map(connectOne(_, () => DistributorNode(smemWidth, wordSize))))
val tcNodes = tcDistNodes.map { tcBank =>
Seq.fill(smemSubbanks)(tcBank.map(connectXbarName(_, Some("tc_dist_fanout"))))
} // (banks, subbanks, tc client)
if (filterAligned) {
val numLsuLanes = radianceTiles.head.numLsuLanes
val numLaneDupes = Math.max(1, smemSubbanks / numLsuLanes)
@@ -186,8 +193,8 @@ class VirgoSharedMemComponents(
}
val uniformRNodes: Seq[Seq[Seq[TLNexusNode]]] = spadReadNodes.map { rb =>
(rb zip fAligned.head).map { case (rw, fa) => rw ++ fa }
val uniformRNodes: Seq[Seq[Seq[TLNexusNode]]] = (spadReadNodes zip tcNodes).map { case (rb, tcrb) =>
(rb lazyZip tcrb lazyZip fAligned.head).map { case (rw, tcrw, fa) => rw ++ tcrw ++ fa }
}
val uniformWNodes: Seq[Seq[Seq[TLNexusNode]]] = (spadWriteNodes zip spadSpWriteNodes).map { case (wb, wsb) =>
(wb lazyZip wsb lazyZip fAligned.last).map {