tensor: Do proper source generation

SourceGenerator keeps on givin'
This commit is contained in:
Hansung Kim
2024-10-14 21:38:54 -07:00
parent bf6f7210b7
commit 14a640bf2d

View File

@@ -8,8 +8,9 @@ import chisel3.util._
import org.chipsalliance.cde.config.Parameters
import org.chipsalliance.diplomacy.lazymodule.{LazyModule, LazyModuleImp}
import freechips.rocketchip.tilelink._
import freechips.rocketchip.diplomacy.AddressSet
import freechips.rocketchip.diplomacy.{IdRange, AddressSet}
import freechips.rocketchip.unittest.{UnitTest, UnitTestModule}
import radiance.memory.SourceGenerator
case class TensorTilingParams(
// Dimension of the SMEM tile
@@ -26,11 +27,13 @@ case class TensorTilingParams(
class TensorCoreDecoupled(
val numWarps: Int,
val numLanes: Int,
val numSourceIds: Int,
val tilingParams: TensorTilingParams
) extends Module {
val numWarpBits = log2Ceil(numWarps)
val wordSize = 4 // TODO FP16
val dataWidth = numLanes * wordSize // TODO FP16
val sourceWidth = log2Ceil(numSourceIds)
val io = IO(new Bundle {
val initiate = Flipped(Decoupled(new Bundle {
@@ -40,10 +43,10 @@ class TensorCoreDecoupled(
val wid = UInt(numWarpBits.W)
val last = Bool()
})
val respA = Flipped(Decoupled(new TensorMemResp(dataWidth)))
val respB = Flipped(Decoupled(new TensorMemResp(dataWidth)))
val reqA = Decoupled(new TensorMemReq)
val reqB = Decoupled(new TensorMemReq)
val respA = Flipped(Decoupled(new TensorMemResp(sourceWidth, dataWidth)))
val respB = Flipped(Decoupled(new TensorMemResp(sourceWidth, dataWidth)))
val reqA = Decoupled(new TensorMemReq(sourceWidth))
val reqB = Decoupled(new TensorMemReq(sourceWidth))
})
dontTouch(io)
@@ -106,13 +109,25 @@ class TensorCoreDecoupled(
}
// memory traffic generation
io.reqA.valid := (state === TensorState.run) // FIXME
io.reqA.bits.address := 0.U // FIXME
val genReq = (state === TensorState.run)
List((io.reqA, io.respA), (io.reqB, io.respB)).foreach {
case (req, resp) => {
val sourceGen = Module(new SourceGenerator(log2Ceil(numSourceIds)))
sourceGen.io.gen := req.fire
sourceGen.io.meta := DontCare
req.valid := genReq
req.bits.address := 0.U // FIXME
req.bits.source := sourceGen.io.id.bits
sourceGen.io.reclaim.valid := resp.fire
sourceGen.io.reclaim.bits := resp.bits.source
}
}
io.respA.ready := true.B
io.respB.ready := true.B
// FIXME
io.reqB.valid := false.B
io.reqB.bits := DontCare
// state transition logic
switch(state) {
@@ -150,12 +165,17 @@ class TensorCoreDecoupled(
// val rdQueue = Queue(io.initiate, queueDepth, pipe = (queueDepth == 1))
}
class TensorMemReq extends Bundle {
// TODO: tag
class TensorMemReq(
sourceWidth: Int
) extends Bundle {
val source = UInt(sourceWidth.W)
val address = UInt(32.W)
}
class TensorMemResp(val dataWidth: Int) extends Bundle {
// TODO: tag
class TensorMemResp(
sourceWidth: Int,
dataWidth: Int
) extends Bundle {
val source = UInt(sourceWidth.W)
val data = UInt(32.W)
}
@@ -164,18 +184,20 @@ class TensorMemResp(val dataWidth: Int) extends Bundle {
// wraps TensorCoreDecoupled with TileLink client node for use in a Diplomacy
// network.
class TensorCoreDecoupledTL(implicit p: Parameters) extends LazyModule {
val numSrcIds = 4
// node with two edges; one for A and one for B matrix
val node = TLClientNode(Seq(
TLMasterPortParameters.v2(
Seq(TLMasterParameters.v2(
name = "TensorCoreDecoupledMatrixANode",
// sourceId : TODO
sourceId = IdRange(0, numSrcIds)
))
),
TLMasterPortParameters.v2(
Seq(TLMasterParameters.v2(
name = "TensorCoreDecoupledMatrixBNode",
// sourceId : TODO
sourceId = IdRange(0, numSrcIds)
))
)
))
@@ -185,42 +207,42 @@ class TensorCoreDecoupledTL(implicit p: Parameters) extends LazyModule {
class TensorCoreDecoupledTLImp(outer: TensorCoreDecoupledTL)
extends LazyModuleImp(outer) with UnitTestModule {
val tensor = Module(new TensorCoreDecoupled(8, 8, TensorTilingParams()))
val wordSize = 4 // FIXME: hardcoded
require(outer.node.out.length == 2/*A and B*/)
val (tlOut, edge) = outer.node.out(0)
val (tlOutB, edgeB) = outer.node.out(1)
val tensor = Module(new TensorCoreDecoupled(
8, 8, outer.numSrcIds , TensorTilingParams()))
val wordSize = 4 // FIXME: hardcoded
val zip = List((outer.node.out(0), tensor.io.reqA),
(outer.node.out(1), tensor.io.reqB))
zip.foreach { case ((tl, edge), req) =>
tl.a.valid := req.valid
val (legal, bits) = edge.Get(
fromSource = 0.U, // TODO: sourceGen.io.id.bits,
fromSource = req.bits.source,
toAddress = req.bits.address,
lgSize = log2Ceil(wordSize).U
)
tl.a.bits := bits
req.ready := tl.a.ready
when(tl.a.fire) {
assert(legal, "illegal TL req gen")
}
}
tensor.io.respA.valid := tlOut.d.valid
tensor.io.respA.bits.data := tlOut.d.bits.data
// TODO: tensor.io.respA.bits.source := tlOut.d.bits.source
// TODO: dedup A and B
val (tlOutA, _) = outer.node.out(0)
val (tlOutB, _) = outer.node.out(1)
tensor.io.respA.valid := tlOutA.d.valid
tensor.io.respA.bits.data := tlOutA.d.bits.data
tensor.io.respA.bits.source := tlOutA.d.bits.source
tlOutA.d.ready := tensor.io.respA.ready
tensor.io.respB.valid := tlOutB.d.valid
tensor.io.respB.bits.data := tlOutB.d.bits.data
tensor.io.respB.bits.source := tlOutB.d.bits.source
tlOutB.d.ready := tensor.io.respB.ready
tensor.io.initiate.valid := io.start
tensor.io.initiate.bits.wid := 0.U
// TODO
tensor.io.respA.valid := false.B
tensor.io.respA.bits := DontCare
tensor.io.respB.valid := false.B
tensor.io.respB.bits := DontCare
tensor.io.reqA.ready := true.B
tensor.io.reqB.ready := true.B
tensor.io.initiate.bits.wid := 0.U // FIXME
tensor.io.writeback.ready := true.B
io.finished := tensor.io.writeback.valid