tensor: Do proper source generation
SourceGenerator keeps on givin'
This commit is contained in:
@@ -8,8 +8,9 @@ import chisel3.util._
|
||||
import org.chipsalliance.cde.config.Parameters
|
||||
import org.chipsalliance.diplomacy.lazymodule.{LazyModule, LazyModuleImp}
|
||||
import freechips.rocketchip.tilelink._
|
||||
import freechips.rocketchip.diplomacy.AddressSet
|
||||
import freechips.rocketchip.diplomacy.{IdRange, AddressSet}
|
||||
import freechips.rocketchip.unittest.{UnitTest, UnitTestModule}
|
||||
import radiance.memory.SourceGenerator
|
||||
|
||||
case class TensorTilingParams(
|
||||
// Dimension of the SMEM tile
|
||||
@@ -26,11 +27,13 @@ case class TensorTilingParams(
|
||||
class TensorCoreDecoupled(
|
||||
val numWarps: Int,
|
||||
val numLanes: Int,
|
||||
val numSourceIds: Int,
|
||||
val tilingParams: TensorTilingParams
|
||||
) extends Module {
|
||||
val numWarpBits = log2Ceil(numWarps)
|
||||
val wordSize = 4 // TODO FP16
|
||||
val dataWidth = numLanes * wordSize // TODO FP16
|
||||
val sourceWidth = log2Ceil(numSourceIds)
|
||||
|
||||
val io = IO(new Bundle {
|
||||
val initiate = Flipped(Decoupled(new Bundle {
|
||||
@@ -40,10 +43,10 @@ class TensorCoreDecoupled(
|
||||
val wid = UInt(numWarpBits.W)
|
||||
val last = Bool()
|
||||
})
|
||||
val respA = Flipped(Decoupled(new TensorMemResp(dataWidth)))
|
||||
val respB = Flipped(Decoupled(new TensorMemResp(dataWidth)))
|
||||
val reqA = Decoupled(new TensorMemReq)
|
||||
val reqB = Decoupled(new TensorMemReq)
|
||||
val respA = Flipped(Decoupled(new TensorMemResp(sourceWidth, dataWidth)))
|
||||
val respB = Flipped(Decoupled(new TensorMemResp(sourceWidth, dataWidth)))
|
||||
val reqA = Decoupled(new TensorMemReq(sourceWidth))
|
||||
val reqB = Decoupled(new TensorMemReq(sourceWidth))
|
||||
})
|
||||
dontTouch(io)
|
||||
|
||||
@@ -106,13 +109,25 @@ class TensorCoreDecoupled(
|
||||
}
|
||||
|
||||
// memory traffic generation
|
||||
io.reqA.valid := (state === TensorState.run) // FIXME
|
||||
io.reqA.bits.address := 0.U // FIXME
|
||||
val genReq = (state === TensorState.run)
|
||||
|
||||
List((io.reqA, io.respA), (io.reqB, io.respB)).foreach {
|
||||
case (req, resp) => {
|
||||
val sourceGen = Module(new SourceGenerator(log2Ceil(numSourceIds)))
|
||||
|
||||
sourceGen.io.gen := req.fire
|
||||
sourceGen.io.meta := DontCare
|
||||
req.valid := genReq
|
||||
req.bits.address := 0.U // FIXME
|
||||
req.bits.source := sourceGen.io.id.bits
|
||||
|
||||
sourceGen.io.reclaim.valid := resp.fire
|
||||
sourceGen.io.reclaim.bits := resp.bits.source
|
||||
}
|
||||
}
|
||||
|
||||
io.respA.ready := true.B
|
||||
io.respB.ready := true.B
|
||||
// FIXME
|
||||
io.reqB.valid := false.B
|
||||
io.reqB.bits := DontCare
|
||||
|
||||
// state transition logic
|
||||
switch(state) {
|
||||
@@ -150,12 +165,17 @@ class TensorCoreDecoupled(
|
||||
// val rdQueue = Queue(io.initiate, queueDepth, pipe = (queueDepth == 1))
|
||||
}
|
||||
|
||||
class TensorMemReq extends Bundle {
|
||||
// TODO: tag
|
||||
class TensorMemReq(
|
||||
sourceWidth: Int
|
||||
) extends Bundle {
|
||||
val source = UInt(sourceWidth.W)
|
||||
val address = UInt(32.W)
|
||||
}
|
||||
class TensorMemResp(val dataWidth: Int) extends Bundle {
|
||||
// TODO: tag
|
||||
class TensorMemResp(
|
||||
sourceWidth: Int,
|
||||
dataWidth: Int
|
||||
) extends Bundle {
|
||||
val source = UInt(sourceWidth.W)
|
||||
val data = UInt(32.W)
|
||||
}
|
||||
|
||||
@@ -164,18 +184,20 @@ class TensorMemResp(val dataWidth: Int) extends Bundle {
|
||||
// wraps TensorCoreDecoupled with TileLink client node for use in a Diplomacy
|
||||
// network.
|
||||
class TensorCoreDecoupledTL(implicit p: Parameters) extends LazyModule {
|
||||
val numSrcIds = 4
|
||||
|
||||
// node with two edges; one for A and one for B matrix
|
||||
val node = TLClientNode(Seq(
|
||||
TLMasterPortParameters.v2(
|
||||
Seq(TLMasterParameters.v2(
|
||||
name = "TensorCoreDecoupledMatrixANode",
|
||||
// sourceId : TODO
|
||||
sourceId = IdRange(0, numSrcIds)
|
||||
))
|
||||
),
|
||||
TLMasterPortParameters.v2(
|
||||
Seq(TLMasterParameters.v2(
|
||||
name = "TensorCoreDecoupledMatrixBNode",
|
||||
// sourceId : TODO
|
||||
sourceId = IdRange(0, numSrcIds)
|
||||
))
|
||||
)
|
||||
))
|
||||
@@ -185,42 +207,42 @@ class TensorCoreDecoupledTL(implicit p: Parameters) extends LazyModule {
|
||||
|
||||
class TensorCoreDecoupledTLImp(outer: TensorCoreDecoupledTL)
|
||||
extends LazyModuleImp(outer) with UnitTestModule {
|
||||
val tensor = Module(new TensorCoreDecoupled(8, 8, TensorTilingParams()))
|
||||
val wordSize = 4 // FIXME: hardcoded
|
||||
|
||||
require(outer.node.out.length == 2/*A and B*/)
|
||||
|
||||
val (tlOut, edge) = outer.node.out(0)
|
||||
val (tlOutB, edgeB) = outer.node.out(1)
|
||||
val tensor = Module(new TensorCoreDecoupled(
|
||||
8, 8, outer.numSrcIds , TensorTilingParams()))
|
||||
val wordSize = 4 // FIXME: hardcoded
|
||||
|
||||
val zip = List((outer.node.out(0), tensor.io.reqA),
|
||||
(outer.node.out(1), tensor.io.reqB))
|
||||
zip.foreach { case ((tl, edge), req) =>
|
||||
tl.a.valid := req.valid
|
||||
val (legal, bits) = edge.Get(
|
||||
fromSource = 0.U, // TODO: sourceGen.io.id.bits,
|
||||
fromSource = req.bits.source,
|
||||
toAddress = req.bits.address,
|
||||
lgSize = log2Ceil(wordSize).U
|
||||
)
|
||||
tl.a.bits := bits
|
||||
req.ready := tl.a.ready
|
||||
when(tl.a.fire) {
|
||||
assert(legal, "illegal TL req gen")
|
||||
}
|
||||
}
|
||||
|
||||
tensor.io.respA.valid := tlOut.d.valid
|
||||
tensor.io.respA.bits.data := tlOut.d.bits.data
|
||||
// TODO: tensor.io.respA.bits.source := tlOut.d.bits.source
|
||||
// TODO: dedup A and B
|
||||
val (tlOutA, _) = outer.node.out(0)
|
||||
val (tlOutB, _) = outer.node.out(1)
|
||||
tensor.io.respA.valid := tlOutA.d.valid
|
||||
tensor.io.respA.bits.data := tlOutA.d.bits.data
|
||||
tensor.io.respA.bits.source := tlOutA.d.bits.source
|
||||
tlOutA.d.ready := tensor.io.respA.ready
|
||||
tensor.io.respB.valid := tlOutB.d.valid
|
||||
tensor.io.respB.bits.data := tlOutB.d.bits.data
|
||||
tensor.io.respB.bits.source := tlOutB.d.bits.source
|
||||
tlOutB.d.ready := tensor.io.respB.ready
|
||||
|
||||
tensor.io.initiate.valid := io.start
|
||||
tensor.io.initiate.bits.wid := 0.U
|
||||
// TODO
|
||||
tensor.io.respA.valid := false.B
|
||||
tensor.io.respA.bits := DontCare
|
||||
tensor.io.respB.valid := false.B
|
||||
tensor.io.respB.bits := DontCare
|
||||
tensor.io.reqA.ready := true.B
|
||||
tensor.io.reqB.ready := true.B
|
||||
tensor.io.initiate.bits.wid := 0.U // FIXME
|
||||
tensor.io.writeback.ready := true.B
|
||||
|
||||
io.finished := tensor.io.writeback.valid
|
||||
|
||||
Reference in New Issue
Block a user