From 0989d90dd25c51d8a7dc6b0aaac809868e39024d Mon Sep 17 00:00:00 2001 From: Richard Yan Date: Mon, 7 Oct 2024 02:59:06 -0700 Subject: [PATCH] connect tc nodes and maybe fix distributor node --- src/main/resources/vsrc/vortex | 2 +- .../radiance/memory/DistributorNode.scala | 39 ++++++----- .../scala/radiance/tile/RadianceTile.scala | 67 +++++++++++++++---- .../tile/VirgoSharedMemComponents.scala | 8 ++- src/main/scala/radiance/tile/VortexCore.scala | 9 +++ 5 files changed, 93 insertions(+), 32 deletions(-) diff --git a/src/main/resources/vsrc/vortex b/src/main/resources/vsrc/vortex index da54162..8bf7f39 160000 --- a/src/main/resources/vsrc/vortex +++ b/src/main/resources/vsrc/vortex @@ -1 +1 @@ -Subproject commit da54162241da020807274bd4087844d379d8170e +Subproject commit 8bf7f39f04e6d3cbc47559fdd3cacca0febe9baa diff --git a/src/main/scala/radiance/memory/DistributorNode.scala b/src/main/scala/radiance/memory/DistributorNode.scala index 8f46af8..29ccebd 100644 --- a/src/main/scala/radiance/memory/DistributorNode.scala +++ b/src/main/scala/radiance/memory/DistributorNode.scala @@ -91,13 +91,15 @@ class DistributorNode(from: Int, to: Int)(implicit p: Parameters) extends LazyMo } def partialData: UInt = VecInit(mn.map(_.d).map(d => Mux(d.fire, d.bits.data, 0.U(d.bits.data.getWidth.W)))).asUInt - def partialValid: UInt = VecInit(mn.map(_.d.fire)).asUInt + def partialValid: UInt = VecInit(mn.map(_.d.valid)).asUInt + def partialFire: UInt = VecInit(mn.map(_.d.fire)).asUInt mn.map(_.d.ready).zip(arrived.asBools).foreach { case (r, a) => r := cn.d.ready && (!partialWait || !a) // if waiting for partial response, ready only if not arrived yet } // TODO: might need coverage test for this + cd := DontCare when (!partialWait) { cn.d.valid := false.B partialWait := false.B @@ -109,31 +111,36 @@ class DistributorNode(from: Int, to: Int)(implicit p: Parameters) extends LazyMo assert(cd.data === partialData, "sanity check") }.elsewhen (partialValid.orR) { // at least 1 valid: enter partial valid state, store partial data into regs - partialWait := true.B - arrived := partialValid + partialWait := cn.d.ready // if something fired, enter partial wait + arrived := partialFire cdReg.data := partialData - when (mn.head.d.valid) { setMetadata(cdReg, mn.head.d.bits) } + when (mn.head.d.fire) { setMetadata(cdReg, mn.head.d.bits) } } }.otherwise { cn.d.valid := false.B partialWait := true.B when ((arrived | partialValid).andR) { // all valids received now - when (mn.head.d.valid) { - setMetadata(cd, mn.head.d.bits) - }.otherwise { - cd := cdReg - } cn.d.valid := true.B - cd.data := cdReg.data | partialData - partialWait := false.B - cdReg := 0.U.asTypeOf(cdReg.cloneType) - arrived := 0.U + when (cn.d.ready) { + assert((arrived | partialFire).andR) + when (mn.head.d.valid) { + setMetadata(cd, mn.head.d.bits) + }.otherwise { + cd := cdReg + } + cd.data := cdReg.data | partialData + partialWait := false.B + cdReg := 0.U.asTypeOf(cdReg.cloneType) + arrived := 0.U + } }.elsewhen (partialValid.orR) { // update partial data - arrived := arrived | partialValid - cdReg.data := cdReg.data | partialData - when (mn.head.d.valid) { setMetadata(cdReg, mn.head.d.bits) } + when (cn.d.ready) { + arrived := arrived | partialValid + cdReg.data := cdReg.data | partialData + when (mn.head.d.valid) { setMetadata(cdReg, mn.head.d.bits) } + } } } } diff --git a/src/main/scala/radiance/tile/RadianceTile.scala b/src/main/scala/radiance/tile/RadianceTile.scala index e6b48f3..8ab5984 100644 --- a/src/main/scala/radiance/tile/RadianceTile.scala +++ b/src/main/scala/radiance/tile/RadianceTile.scala @@ -11,6 +11,7 @@ import freechips.rocketchip.diplomacy._ import org.chipsalliance.diplomacy.lazymodule.LazyModule import freechips.rocketchip.prci.{ClockCrossingType, ClockSinkParameters, RationalCrossing} import freechips.rocketchip.regmapper.RegField +import freechips.rocketchip.resources.BigIntHexContext import freechips.rocketchip.rocket._ import freechips.rocketchip.subsystem.HierarchicalElementCrossingParamsLike import freechips.rocketchip.tile._ @@ -275,17 +276,18 @@ class RadianceTile private ( } val tcSmemSize = 32 - val tcSmemNodes = Seq(TLClientNode(Seq(TLMasterPortParameters.v2( - masters = Seq(TLMasterParameters.v2( - name = s"rad_tc_${radianceParams.coreId}", - sourceId = IdRange(0, 1 << smemSourceWidth), - supports = TLSlaveToMasterTransferSizes( - get = TransferSizes(1, tcSmemSize), - putFull = TransferSizes(1, tcSmemSize), - putPartial = TransferSizes(1, tcSmemSize) - ) - )) - )))) + val tcSmemNodes = Seq.tabulate(2) { i => + TLClientNode(Seq(TLMasterPortParameters.v2( + masters = Seq(TLMasterParameters.v2( + name = s"rad_tc_${radianceParams.coreId}_$i", + sourceId = IdRange(0, 1 << smemSourceWidth), + supports = TLSlaveToMasterTransferSizes( + probe = TransferSizes(1, tcSmemSize), + get = TransferSizes(1, tcSmemSize), + ) + )) + ))) + } // combine outgoing per-lane dmemNode into 1 idenity node // @@ -686,7 +688,7 @@ class RadianceTileModuleImp(outer: RadianceTile) outer.smemSourceWidth, new VortexBundleA(tagWidth = outer.smemTagWidth, dataWidth = 32), new VortexBundleD(tagWidth = outer.smemTagWidth, dataWidth = 32), - outer.smemNodes(0).out.head + outer.smemNodes.head.out.head ) ) } @@ -731,6 +733,46 @@ class RadianceTileModuleImp(outer: RadianceTile) } } + def connectTc { + val tcb0 = new { + val addr = core.io.tc_a_bits_address(31, 0) + val aValid = core.io.tc_a_valid(0) + val dReady = core.io.tc_d_ready(0) + } + val tcb1 = new { + val addr = core.io.tc_a_bits_address(63, 32) + val aValid = core.io.tc_a_valid(1) + val dReady = core.io.tc_d_ready(1) + } + val tcBundles = Seq(tcb0, tcb1) + val adapters = (outer.tcSmemNodes zip tcBundles).zipWithIndex.map { case ((node, bundle), i) => + val client = node.out.head + val adapter = Module( + new VortexTLAdapter( + outer.smemSourceWidth, + new VortexBundleA(tagWidth = 1, dataWidth = 32 * 8), + new VortexBundleD(tagWidth = 1, dataWidth = 32 * 8), + client + ) + ) + adapter.io.inReq.bits <> DontCare + adapter.io.inReq.valid := bundle.aValid + adapter.io.inReq.bits.address := bundle.addr + adapter.io.inReq.bits.source := i.U + adapter.io.inReq.bits.size := 5.U + adapter.io.inReq.bits.opcode := TLMessages.Get + adapter.io.inReq.bits.mask := x"ffffffff".U + adapter.io.inResp.ready := bundle.dReady + + client._1.a <> adapter.io.outReq + adapter.io.outResp <> client._1.d + adapter + } + core.io.tc_a_ready := Cat(adapters.last.io.inReq.ready, adapters.head.io.inReq.ready) + core.io.tc_d_valid := Cat(adapters.last.io.inResp.valid, adapters.head.io.inResp.valid) + core.io.tc_d_bits_data := Cat(adapters.last.io.inResp.bits.data, adapters.head.io.inResp.bits.data) + } + def connectBarrier = { require(outer.barrierMasterNode.out.length == 1) // FIXME: bits not flattened @@ -786,6 +828,7 @@ class RadianceTileModuleImp(outer: RadianceTile) connectImem connectDmem connectSmem + connectTc connectBarrier connectAccelerator } diff --git a/src/main/scala/radiance/tile/VirgoSharedMemComponents.scala b/src/main/scala/radiance/tile/VirgoSharedMemComponents.scala index c72fc7f..8ea5c6f 100644 --- a/src/main/scala/radiance/tile/VirgoSharedMemComponents.scala +++ b/src/main/scala/radiance/tile/VirgoSharedMemComponents.scala @@ -54,7 +54,9 @@ class VirgoSharedMemComponents( smemFanoutXbar.node } } - val tcNodeFanouts = radianceTiles.flatMap(_.tcSmemNodes).map(connectXbarName(_, Some("tc_fanout"))) + val tcNodeFanouts = radianceTiles.flatMap(_.tcSmemNodes) + .map(connectOne(_, () => TLBuffer(BufferParams(2, false, false), BufferParams(0)))) + .map(connectXbarName(_, Some("tc_fanout"))) val clBusClients: Seq[TLNode] = radianceSmemFanout val (uniformRNodes, uniformWNodes, nonuniformRNodes, nonuniformWNodes) = @@ -69,7 +71,7 @@ class VirgoSharedMemComponents( dist := node } val fanout = Seq.tabulate(spSubbanks) { w => - val buf = TLBuffer(BufferParams(1, false, true), BufferParams(0)) + val buf = TLBuffer(BufferParams(2, false, false), BufferParams(0)) buf := dist connectXbarName(buf, Some(s"spad_g${gemminiIdx}w${w}_fanout_$suffix")) } @@ -88,7 +90,7 @@ class VirgoSharedMemComponents( // tensor core read nodes val tcDistNodes = Seq.fill(smemBanks)(tcNodeFanouts.map(connectOne(_, () => DistributorNode(smemWidth, wordSize)))) val tcNodes = tcDistNodes.map { tcBank => - Seq.fill(smemSubbanks)(tcBank.map(connectXbarName(_, Some("tc_dist_fanout")))) + Seq.fill(smemSubbanks)(tcBank.map(connectOne(_, () => TLBuffer(BufferParams(2, false, false)))).map(connectXbarName(_, Some("tc_dist_fanout")))) } // (banks, subbanks, tc client) if (filterAligned) { diff --git a/src/main/scala/radiance/tile/VortexCore.scala b/src/main/scala/radiance/tile/VortexCore.scala index d202aaa..1409ddd 100644 --- a/src/main/scala/radiance/tile/VortexCore.scala +++ b/src/main/scala/radiance/tile/VortexCore.scala @@ -90,6 +90,13 @@ class VortexBundle(tile: RadianceTile)(implicit p: Parameters) extends CoreBundl val smem_d_bits_data = Input(UInt((tile.numLsuLanes * 32).W)) val smem_d_ready = Output(UInt((tile.numLsuLanes * 1).W)) + val tc_a_valid = Output(UInt(2.W)) + val tc_a_bits_address = Output(UInt((2 * 32).W)) + val tc_a_ready = Input(UInt(2.W)) + val tc_d_valid = Input(UInt(2.W)) + val tc_d_bits_data = Input(UInt((2 * 32 * 8).W)) + val tc_d_ready = Output(UInt(2.W)) + // FIXME: hardcoded val barrierIdBits = tile.barrierMasterNode.out(0)._2.barrierIdBits val coreIdBits = tile.barrierMasterNode.out(0)._2.numCoreBits @@ -233,6 +240,8 @@ class Vortex(tile: RadianceTile)(implicit p: Parameters) // addResource("/vsrc/vortex/hw/rtl/mem/VX_gbar_arb.sv") // addResource("/vsrc/vortex/hw/rtl/mem/VX_gbar_unit.sv") + addResource("/vsrc/vortex/hw/rtl/mem/VX_tc_bus_if.sv") + addResource("/vsrc/vortex/hw/rtl/libs/VX_allocator.sv") // addResource("/vsrc/vortex/hw/rtl/libs/VX_avs_adapter.sv") // addResource("/vsrc/vortex/hw/rtl/libs/VX_axi_adapter.sv")