diff --git a/src/main/scala/radiance/core/TensorCoreDecoupled.scala b/src/main/scala/radiance/core/TensorCoreDecoupled.scala
index e61b542..7fa05ee 100644
--- a/src/main/scala/radiance/core/TensorCoreDecoupled.scala
+++ b/src/main/scala/radiance/core/TensorCoreDecoupled.scala
@@ -8,8 +8,9 @@ import chisel3.util._
 import org.chipsalliance.cde.config.Parameters
 import org.chipsalliance.diplomacy.lazymodule.{LazyModule, LazyModuleImp}
 import freechips.rocketchip.tilelink._
-import freechips.rocketchip.diplomacy.AddressSet
+import freechips.rocketchip.diplomacy.{IdRange, AddressSet}
 import freechips.rocketchip.unittest.{UnitTest, UnitTestModule}
+import radiance.memory.SourceGenerator
 
 case class TensorTilingParams(
   // Dimension of the SMEM tile
@@ -26,11 +27,13 @@ case class TensorTilingParams(
 class TensorCoreDecoupled(
     val numWarps: Int,
     val numLanes: Int,
+    val numSourceIds: Int,
     val tilingParams: TensorTilingParams
 ) extends Module {
   val numWarpBits = log2Ceil(numWarps)
   val wordSize = 4 // TODO FP16
   val dataWidth = numLanes * wordSize // TODO FP16
+  val sourceWidth = log2Ceil(numSourceIds)
 
   val io = IO(new Bundle {
     val initiate = Flipped(Decoupled(new Bundle {
@@ -40,10 +43,10 @@ class TensorCoreDecoupled(
       val wid = UInt(numWarpBits.W)
       val last = Bool()
     })
-    val respA = Flipped(Decoupled(new TensorMemResp(dataWidth)))
-    val respB = Flipped(Decoupled(new TensorMemResp(dataWidth)))
-    val reqA = Decoupled(new TensorMemReq)
-    val reqB = Decoupled(new TensorMemReq)
+    val respA = Flipped(Decoupled(new TensorMemResp(sourceWidth, dataWidth)))
+    val respB = Flipped(Decoupled(new TensorMemResp(sourceWidth, dataWidth)))
+    val reqA = Decoupled(new TensorMemReq(sourceWidth))
+    val reqB = Decoupled(new TensorMemReq(sourceWidth))
   })
 
   dontTouch(io)
@@ -106,13 +109,28 @@ class TensorCoreDecoupled(
   }
 
   // memory traffic generation
-  io.reqA.valid := (state === TensorState.run) // FIXME
-  io.reqA.bits.address := 0.U // FIXME
+  val genReq = (state === TensorState.run)
+
+  // One source-ID generator per operand port: an ID is taken when a request
+  // fires and handed back when the response carrying it fires.
+  List((io.reqA, io.respA), (io.reqB, io.respB)).foreach {
+    case (req, resp) => {
+      val sourceGen = Module(new SourceGenerator(sourceWidth))
+
+      sourceGen.io.gen := req.fire
+      sourceGen.io.meta := DontCare
+      // hold the request back until the generator reports a valid ID
+      req.valid := genReq && sourceGen.io.id.valid
+      req.bits.address := 0.U // FIXME
+      req.bits.source := sourceGen.io.id.bits
+
+      sourceGen.io.reclaim.valid := resp.fire
+      sourceGen.io.reclaim.bits := resp.bits.source
+    }
+  }
+
   io.respA.ready := true.B
   io.respB.ready := true.B
-  // FIXME
-  io.reqB.valid := false.B
-  io.reqB.bits := DontCare
 
   // state transition logic
   switch(state) {
@@ -150,12 +168,20 @@ class TensorCoreDecoupled(
   // val rdQueue = Queue(io.initiate, queueDepth, pipe = (queueDepth == 1))
 }
 
-class TensorMemReq extends Bundle {
-  // TODO: tag
+// Memory request issued by the tensor core, tagged with a source ID.
+class TensorMemReq(
+    sourceWidth: Int
+) extends Bundle {
+  val source = UInt(sourceWidth.W)
   val address = UInt(32.W)
 }
-class TensorMemResp(val dataWidth: Int) extends Bundle {
-  // TODO: tag
+// Memory response; `source` is the ID handed back to the source generator.
+// NOTE(review): `dataWidth` is not used yet -- `data` is fixed at 32 bits.
+class TensorMemResp(
+    sourceWidth: Int,
+    dataWidth: Int
+) extends Bundle {
+  val source = UInt(sourceWidth.W)
   val data = UInt(32.W)
 }
 
@@ -164,18 +190,20 @@ class TensorMemResp(val dataWidth: Int) extends Bundle {
 // wraps TensorCoreDecoupled with TileLink client node for use in a Diplomacy
 // network.
 class TensorCoreDecoupledTL(implicit p: Parameters) extends LazyModule {
+  val numSrcIds = 4
+
   // node with two edges; one for A and one for B matrix
   val node = TLClientNode(Seq(
     TLMasterPortParameters.v2(
       Seq(TLMasterParameters.v2(
         name = "TensorCoreDecoupledMatrixANode",
-        // sourceId : TODO
+        sourceId = IdRange(0, numSrcIds)
       ))
     ),
     TLMasterPortParameters.v2(
       Seq(TLMasterParameters.v2(
         name = "TensorCoreDecoupledMatrixBNode",
-        // sourceId : TODO
+        sourceId = IdRange(0, numSrcIds)
       ))
     )
   ))
@@ -185,42 +213,42 @@ class TensorCoreDecoupledTL(implicit p: Parameters) extends LazyModule {
 class TensorCoreDecoupledTLImp(outer: TensorCoreDecoupledTL)
     extends LazyModuleImp(outer)
     with UnitTestModule {
-  val tensor = Module(new TensorCoreDecoupled(8, 8, TensorTilingParams()))
-  val wordSize = 4 // FIXME: hardcoded
-
   require(outer.node.out.length == 2/*A and B*/)
-  val (tlOut, edge) = outer.node.out(0)
-  val (tlOutB, edgeB) = outer.node.out(1)
+  val tensor = Module(new TensorCoreDecoupled(
+    8, 8, outer.numSrcIds, TensorTilingParams()))
+  val wordSize = 4 // FIXME: hardcoded
 
   val zip = List((outer.node.out(0), tensor.io.reqA),
                  (outer.node.out(1), tensor.io.reqB))
 
   zip.foreach { case ((tl, edge), req) =>
     tl.a.valid := req.valid
     val (legal, bits) = edge.Get(
-      fromSource = 0.U, // TODO: sourceGen.io.id.bits,
+      fromSource = req.bits.source,
       toAddress = req.bits.address,
       lgSize = log2Ceil(wordSize).U
     )
     tl.a.bits := bits
+    req.ready := tl.a.ready
     when(tl.a.fire) {
       assert(legal, "illegal TL req gen")
     }
   }
 
-  tensor.io.respA.valid := tlOut.d.valid
-  tensor.io.respA.bits.data := tlOut.d.bits.data
-  // TODO: tensor.io.respA.bits.source := tlOut.d.bits.source
+  // TODO: dedup A and B
+  val (tlOutA, _) = outer.node.out(0)
+  val (tlOutB, _) = outer.node.out(1)
+  tensor.io.respA.valid := tlOutA.d.valid
+  tensor.io.respA.bits.data := tlOutA.d.bits.data
+  tensor.io.respA.bits.source := tlOutA.d.bits.source
+  tlOutA.d.ready := tensor.io.respA.ready
+  tensor.io.respB.valid := tlOutB.d.valid
+  tensor.io.respB.bits.data := tlOutB.d.bits.data
+  tensor.io.respB.bits.source := tlOutB.d.bits.source
+  tlOutB.d.ready := tensor.io.respB.ready
 
   tensor.io.initiate.valid := io.start
-  tensor.io.initiate.bits.wid := 0.U
-  // TODO
-  tensor.io.respA.valid := false.B
-  tensor.io.respA.bits := DontCare
-  tensor.io.respB.valid := false.B
-  tensor.io.respB.bits := DontCare
-  tensor.io.reqA.ready := true.B
-  tensor.io.reqB.ready := true.B
+  tensor.io.initiate.bits.wid := 0.U // FIXME
   tensor.io.writeback.ready := true.B
 
   io.finished := tensor.io.writeback.valid