Merge branch 'tensor-decoupled'
This commit is contained in:
Submodule src/main/resources/vsrc/vortex updated: cde8da1f3b...32ccdeef01
@@ -5,6 +5,7 @@ package radiance.core
|
||||
|
||||
import chisel3._
|
||||
import chisel3.util._
|
||||
import chisel3.experimental.requireIsChiselType
|
||||
import org.chipsalliance.cde.config.Parameters
|
||||
import org.chipsalliance.diplomacy.lazymodule.{LazyModule, LazyModuleImp}
|
||||
import freechips.rocketchip.tilelink._
|
||||
@@ -28,12 +29,15 @@ class TensorCoreDecoupled(
|
||||
val numWarps: Int,
|
||||
val numLanes: Int,
|
||||
val numSourceIds: Int,
|
||||
val tilingParams: TensorTilingParams
|
||||
val tilingParams: TensorTilingParams,
|
||||
val numFPRegs: Int = 32
|
||||
) extends Module {
|
||||
val numWarpBits = log2Ceil(numWarps)
|
||||
val wordSize = 4 // TODO FP16
|
||||
val dataWidth = numLanes * wordSize * 8/*bits*/ // TODO FP16
|
||||
val wordSizeInBits = wordSize * 8 // TODO FP16
|
||||
val sourceWidth = log2Ceil(numSourceIds)
|
||||
val dataWidth = numLanes * wordSizeInBits // TODO FP16
|
||||
val numFPRegBits = log2Ceil(numFPRegs)
|
||||
|
||||
val io = IO(new Bundle {
|
||||
val initiate = Flipped(Decoupled(new Bundle {
|
||||
@@ -42,7 +46,8 @@ class TensorCoreDecoupled(
|
||||
val writeback = Decoupled(new Bundle {
|
||||
val last = Bool()
|
||||
val wid = UInt(numWarpBits.W)
|
||||
val data = Vec(numLanes, UInt(wordSize.W))
|
||||
val rd = UInt(numFPRegBits.W)
|
||||
val data = Vec(numLanes, UInt((wordSizeInBits).W))
|
||||
})
|
||||
val respA = Flipped(Decoupled(new TensorMemResp(sourceWidth, dataWidth)))
|
||||
val respB = Flipped(Decoupled(new TensorMemResp(sourceWidth, dataWidth)))
|
||||
@@ -51,20 +56,32 @@ class TensorCoreDecoupled(
|
||||
})
|
||||
dontTouch(io)
|
||||
|
||||
class TensorMemReq(
|
||||
sourceWidth: Int
|
||||
) extends Bundle {
|
||||
val source = UInt(sourceWidth.W)
|
||||
val address = UInt(32.W)
|
||||
}
|
||||
class TensorMemResp(
|
||||
sourceWidth: Int,
|
||||
dataWidth: Int
|
||||
) extends Bundle {
|
||||
val source = UInt(sourceWidth.W)
|
||||
val data = UInt(dataWidth.W)
|
||||
}
|
||||
// mem response after translation from TL source to set/step tag
|
||||
class TensorMemRespWithTag(
|
||||
dataWidth: Int
|
||||
) extends Bundle {
|
||||
val tag = new TensorMemTag
|
||||
val data = UInt(dataWidth.W)
|
||||
}
|
||||
|
||||
// FSM
|
||||
// ---
|
||||
// This drives the overall pipeline of memory requests, dot-product unit
|
||||
// operations and regfile writeback.
|
||||
|
||||
object TensorState extends ChiselEnum {
|
||||
val idle = Value(0.U)
|
||||
val run = Value(1.U)
|
||||
// All set/step sequencing is complete and the tensor core is holding the
|
||||
// result data until downstream writeback is ready.
|
||||
// FIXME: is this necessary if writeback is decoupled with queues?
|
||||
val finish = Value(2.U)
|
||||
}
|
||||
val state = RegInit(TensorState.idle)
|
||||
val busy = RegInit(false.B)
|
||||
// Holds the warp id the core is currently working on. Note that we only
|
||||
// support one outstanding warp request
|
||||
@@ -76,15 +93,15 @@ class TensorCoreDecoupled(
|
||||
// steps: i-j iteration
|
||||
val numSteps = (tilingParams.m * tilingParams.n) / (tilingParams.mc * tilingParams.nc)
|
||||
val stepBits = log2Ceil(numSteps)
|
||||
val set = RegInit(0.U(setBits.W))
|
||||
val step = RegInit(0.U(stepBits.W))
|
||||
val lastSet = ((1 << setBits) - 1)
|
||||
val lastStep = ((1 << stepBits) - 1)
|
||||
def setDone(set: UInt) = (set === lastSet.U)
|
||||
def stepDone(step: UInt) = (step === lastStep.U)
|
||||
|
||||
when(io.initiate.fire) {
|
||||
when (io.initiate.fire) {
|
||||
val wid = io.initiate.bits.wid
|
||||
busy := true.B
|
||||
warpReg := wid
|
||||
set := 0.U
|
||||
step := 0.U
|
||||
when(io.writeback.fire) {
|
||||
assert(
|
||||
io.writeback.bits.wid =/= wid,
|
||||
@@ -92,130 +109,463 @@ class TensorCoreDecoupled(
|
||||
)
|
||||
}
|
||||
}
|
||||
when(io.writeback.fire) {
|
||||
|
||||
// TODO: @perf: Instead of waiting until the last writeback, release busy as
|
||||
// soon as the access frontend is complete so that there's a better chance to
|
||||
// saturate the backend with back-to-back HGMMAs. This would require sending
|
||||
// the 'wid' register to backend instead of having it shared with the
|
||||
// frontend.
|
||||
when(io.writeback.fire && io.writeback.bits.last) {
|
||||
busy := false.B
|
||||
}
|
||||
|
||||
// Memory traffic generation
|
||||
// -------------------------
|
||||
//
|
||||
val genReq = (state === TensorState.run)
|
||||
// serialize every HGMMA request
|
||||
io.initiate.ready := !busy
|
||||
|
||||
Seq((io.reqA, io.respA), (io.reqB, io.respB)).foreach {
|
||||
case (req, resp) => {
|
||||
val sourceGen = Module(new SourceGenerator(log2Ceil(numSourceIds)))
|
||||
// ===========================================================================
|
||||
// Access stage
|
||||
// ===========================================================================
|
||||
//
|
||||
// Frontend of the decoupled access/execute pipeline.
|
||||
|
||||
// States
|
||||
//
|
||||
object AccessorState extends ChiselEnum {
|
||||
val idle = Value(0.U)
|
||||
val access = Value(1.U)
|
||||
// All set/step sequencing is complete and the tensor core is holding the
|
||||
// result data until downstream writeback is ready.
|
||||
// FIXME: is this necessary if writeback is decoupled with queues?
|
||||
val finish = Value(2.U)
|
||||
}
|
||||
val state = RegInit(AccessorState.idle)
|
||||
val allReqsDone = WireInit(false.B)
|
||||
dontTouch(allReqsDone)
|
||||
|
||||
switch(state) {
|
||||
is(AccessorState.idle) {
|
||||
when(io.initiate.fire) {
|
||||
state := AccessorState.access
|
||||
}
|
||||
}
|
||||
is(AccessorState.access) {
|
||||
when (allReqsDone) {
|
||||
state := AccessorState.finish
|
||||
}
|
||||
}
|
||||
is(AccessorState.finish) {
|
||||
// FIXME: decouple writeback
|
||||
when(io.writeback.fire) {
|
||||
state := AccessorState.idle
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 'index' is the index of a memory request among the sequence of requests
|
||||
// needed to read a full M-column of A or N-row of B. Its range is [0,m/2)
|
||||
// or [0,n/2), where 2 is the stride can be read in a single request size.
|
||||
require(tilingParams.m == tilingParams.n,
|
||||
"currently only supports square SMEM tile")
|
||||
val numIndices = tilingParams.m / 2/*FIXME:hardcoded?*/
|
||||
val indexBits = log2Ceil(numIndices)
|
||||
val lastIndex = (1 << indexBits) - 1
|
||||
|
||||
class TensorMemTag extends Bundle {
|
||||
val set = UInt(setBits.W)
|
||||
val index = UInt(indexBits.W)
|
||||
}
|
||||
|
||||
val tagInit = Wire(new TensorMemTag)
|
||||
tagInit.set := 0.U
|
||||
tagInit.index := 0.U
|
||||
val tagA = RegInit(tagInit)
|
||||
val tagB = RegInit(tagInit)
|
||||
|
||||
when (io.reqA.fire) {
|
||||
when (tagA.index === lastIndex.U) {
|
||||
tagA.set := tagA.set + 1.U
|
||||
}
|
||||
tagA.index := tagA.index + 1.U
|
||||
}
|
||||
when (io.reqB.fire) {
|
||||
when (tagB.index === lastIndex.U) {
|
||||
tagB.set := tagB.set + 1.U
|
||||
}
|
||||
tagB.index := tagB.index + 1.U
|
||||
}
|
||||
|
||||
// Address generation
|
||||
//
|
||||
def addressGen(base: UInt, set: UInt, index: UInt): UInt = {
|
||||
// note that both A and B are K-major to facilitate bank conflict-free SMEM
|
||||
// accesses, so that below code applies to both.
|
||||
//
|
||||
// (row,col) coordinate of the compute tile
|
||||
val tileRow = index
|
||||
val tileCol = set
|
||||
// (row,col) coordinate of the starting element of the compute tile
|
||||
val elemRow = index << 1
|
||||
val elemCol = tileCol << log2Ceil(tilingParams.kc)
|
||||
val rowStride = tilingParams.k * wordSize
|
||||
val rowStrideBits = log2Ceil(rowStride)
|
||||
val wordStrideBits = log2Ceil(wordSize)
|
||||
val tileOffset = (elemRow << rowStrideBits) + (elemCol << wordStrideBits)
|
||||
|
||||
base + tileOffset
|
||||
}
|
||||
|
||||
// FIXME: bogus base address
|
||||
val addressA = addressGen(0.U, tagA.set, tagA.index)
|
||||
val addressB = addressGen(0x400.U, tagB.set, tagB.index)
|
||||
|
||||
val lastReqA = (tagA.set === lastSet.U) && (tagA.index === lastIndex.U)
|
||||
val lastReqB = (tagB.set === lastSet.U) && (tagB.index === lastIndex.U)
|
||||
val doneReqA = RegInit(false.B)
|
||||
val doneReqB = RegInit(false.B)
|
||||
when (lastReqA && io.reqA.fire) { doneReqA := true.B }
|
||||
when (lastReqB && io.reqB.fire) { doneReqB := true.B }
|
||||
val genReqA = (state === AccessorState.access) && !doneReqA
|
||||
val genReqB = (state === AccessorState.access) && !doneReqB
|
||||
when (state === AccessorState.finish) {
|
||||
doneReqA := false.B
|
||||
doneReqB := false.B
|
||||
tagA.set := 0.U
|
||||
tagA.index := 0.U
|
||||
tagB.set := 0.U
|
||||
tagB.index := 0.U
|
||||
}
|
||||
|
||||
allReqsDone := doneReqA && doneReqB
|
||||
|
||||
// Request generation
|
||||
//
|
||||
val respATagged = Wire(Decoupled(new TensorMemRespWithTag(dataWidth)))
|
||||
val respBTagged = Wire(Decoupled(new TensorMemRespWithTag(dataWidth)))
|
||||
Seq((io.reqA, (io.respA, respATagged)),
|
||||
(io.reqB, (io.respB, respBTagged))).zipWithIndex.foreach {
|
||||
case ((req, (resp, respTagged)), i) => {
|
||||
val sourceGen = Module(new SourceGenerator(
|
||||
log2Ceil(numSourceIds),
|
||||
metadata = Some(new TensorMemTag)
|
||||
))
|
||||
|
||||
sourceGen.io.gen := req.fire
|
||||
sourceGen.io.meta := DontCare
|
||||
req.valid := genReq
|
||||
req.bits.address := 0.U // FIXME
|
||||
sourceGen.io.meta := (if (i == 0) tagA else tagB)
|
||||
req.valid := (if (i == 0) genReqA else genReqB)
|
||||
req.bits.address := (if (i == 0) addressA else addressB)
|
||||
req.bits.source := sourceGen.io.id.bits
|
||||
|
||||
sourceGen.io.reclaim.valid := resp.fire
|
||||
sourceGen.io.reclaim.bits := resp.bits.source
|
||||
|
||||
// translate source
|
||||
respTagged.valid := resp.valid
|
||||
respTagged.bits.tag := sourceGen.io.peek
|
||||
respTagged.bits.data := resp.bits.data
|
||||
resp.ready := respTagged.ready
|
||||
}
|
||||
}
|
||||
|
||||
// only advance to the next step if we fired mem requests for both A and B
|
||||
val firedABReg = RegInit(VecInit(false.B, false.B))
|
||||
val firedABNow = VecInit((Seq(io.reqA, io.reqB) zip firedABReg).map {
|
||||
case (req, fired) => { when (req.fire) { fired := true.B } }
|
||||
req.fire
|
||||
})
|
||||
val firedAB = (firedABNow.asUInt | firedABReg.asUInt)
|
||||
val nextStep = firedAB.andR
|
||||
// clear out firedABReg every step. this will overwrite the previous fired
|
||||
// write upon the last fire out of A and B
|
||||
when (nextStep) {
|
||||
firedABReg := Seq(false.B, false.B)
|
||||
}
|
||||
|
||||
io.respA.ready := true.B // FIXME
|
||||
io.respB.ready := true.B // FIXME
|
||||
|
||||
// ===========================================================================
|
||||
// Execute stage
|
||||
// -------------
|
||||
// Execute backend of the decoupled access/execute pipeline.
|
||||
// ===========================================================================
|
||||
//
|
||||
// Backend of the decoupled access/execute pipeline.
|
||||
//
|
||||
val respQueueDepth = 4 // FIXME: parameterize
|
||||
val respQueueA = Queue(io.respA, respQueueDepth)
|
||||
val respQueueB = Queue(io.respB, respQueueDepth)
|
||||
respQueueA.ready := io.writeback.ready // FIXME
|
||||
respQueueB.ready := io.writeback.ready // FIXME
|
||||
require(respQueueDepth >= 4,
|
||||
"respQueueDepth must be at least 4. This is because the B operand buffer " ++
|
||||
"is shallower than A's, so the B response queue has to be deep enough to " ++
|
||||
"hold younger requests until A operand buffer becomes valid and the first DPU " ++
|
||||
"fire can happen. FIXME: make operand buffer report per-subtile valid so " ++
|
||||
"the first compute can happen earlier.")
|
||||
val respQueueA = Queue(respATagged, respQueueDepth)
|
||||
val respQueueB = Queue(respBTagged, respQueueDepth)
|
||||
|
||||
require(respQueueA.bits.data.widthOption.get ==
|
||||
io.writeback.bits.data.widthOption.get * numLanes,
|
||||
"response data width does not match the writeback data width")
|
||||
io.writeback.bits.data.widthOption.get,
|
||||
"response data width does not match the writeback data width")
|
||||
|
||||
// FIXME: debug dummy: pipe A directly to writeback
|
||||
io.writeback.valid := respQueueA.valid
|
||||
val groupedRespA = respQueueA.bits.data.asBools.grouped(wordSize * 8/*bits*/)
|
||||
(io.writeback.bits.data zip groupedRespA).foreach { case (wb, data) =>
|
||||
wb := VecInit(data).asUInt
|
||||
// FIXME: unnecessary
|
||||
val substepDeqA = RegInit(0.U(1.W))
|
||||
when (respQueueA.fire) {
|
||||
substepDeqA := substepDeqA + 1.U
|
||||
}
|
||||
dontTouch(substepDeqA)
|
||||
|
||||
// Stage the operands in a pipeline so that we obtain the full 4x4 tiles
|
||||
// ready for compute. Also send the set/step tag along the pipe for
|
||||
// alignment check.
|
||||
|
||||
// @cleanup: dedup A and B below
|
||||
|
||||
val fullA = Module(new FillBuffer(
|
||||
chiselTypeOf(respQueueB.bits.data), numIndices
|
||||
))
|
||||
fullA.io.enq.valid := respQueueA.valid
|
||||
fullA.io.enq.bits := respQueueA.bits.data
|
||||
respQueueA.ready := fullA.io.enq.ready
|
||||
// `pipe` combinationally couples enq-deq ready
|
||||
val fullATag = Module(new Queue(
|
||||
new TensorMemTag, entries = 1, pipe = true
|
||||
))
|
||||
fullATag.io.enq.valid := respQueueA.valid
|
||||
fullATag.io.enq.bits := respQueueA.bits.tag
|
||||
|
||||
// stage the full A tile once more so that FillBuffer can be filled up in the
|
||||
// background while the tile is being used for compute. This does come with
|
||||
// capacity overhead.
|
||||
val fullABuf = Module(new Queue(
|
||||
new Bundle {
|
||||
val data = chiselTypeOf(fullA.io.deq.bits)
|
||||
val tag = new TensorMemTag
|
||||
}, entries = 1, pipe = true
|
||||
))
|
||||
fullABuf.io.enq.valid := fullA.io.deq.valid
|
||||
fullABuf.io.enq.bits.data := fullA.io.deq.bits
|
||||
fullABuf.io.enq.bits.tag := fullATag.io.deq.bits
|
||||
fullA.io.deq.ready := fullABuf.io.enq.ready
|
||||
fullATag.io.deq.ready := fullABuf.io.enq.ready
|
||||
|
||||
// serialize every two B responses into one full 4x4 B tile
|
||||
// FIXME: do the same for A
|
||||
val fullB = Module(new FillBuffer(
|
||||
chiselTypeOf(respQueueB.bits.data), 2/*substeps*/
|
||||
))
|
||||
fullB.io.enq.valid := respQueueB.valid
|
||||
fullB.io.enq.bits := respQueueB.bits.data
|
||||
respQueueB.ready := fullB.io.enq.ready
|
||||
val fullBTag = Module(new Queue(
|
||||
new TensorMemTag, entries = 1, pipe = true
|
||||
))
|
||||
fullBTag.io.enq.valid := respQueueB.valid
|
||||
fullBTag.io.enq.bits := respQueueB.bits.tag
|
||||
|
||||
val fullBBuf = Module(new Queue(
|
||||
new Bundle {
|
||||
val data = chiselTypeOf(fullB.io.deq.bits)
|
||||
val tag = new TensorMemTag
|
||||
}, entries = 1, pipe = true
|
||||
))
|
||||
fullBBuf.io.enq.valid := fullB.io.deq.valid
|
||||
fullBBuf.io.enq.bits.data := fullB.io.deq.bits
|
||||
fullBBuf.io.enq.bits.tag := fullBTag.io.deq.bits
|
||||
fullB.io.deq.ready := fullBBuf.io.enq.ready
|
||||
fullBTag.io.deq.ready := fullBBuf.io.enq.ready
|
||||
|
||||
val dpuReady = Wire(Bool())
|
||||
val operandsValid = fullABuf.io.deq.valid && fullBBuf.io.deq.valid
|
||||
val dpuFire = operandsValid && dpuReady
|
||||
|
||||
val setCompute = RegInit(0.U(setBits.W))
|
||||
val stepCompute = RegInit(0.U(stepBits.W))
|
||||
val substepCompute = RegInit(0.U(1.W))
|
||||
val nextStepCompute = dpuFire && (substepCompute === 1.U)
|
||||
dontTouch(setCompute)
|
||||
dontTouch(stepCompute)
|
||||
dontTouch(substepCompute)
|
||||
when (dpuFire) {
|
||||
substepCompute := substepCompute + 1.U
|
||||
}
|
||||
|
||||
// Operand selection
|
||||
//
|
||||
// select the correct 4x4 tile from A operand buffer
|
||||
val numTilesM = tilingParams.m / tilingParams.mc
|
||||
val numTilesMBits = log2Ceil(numTilesM)
|
||||
def selectOperandA(buf: Vec[UInt]): UInt = {
|
||||
require(buf.length == numIndices)
|
||||
val stepM = stepCompute & ((1 << numTilesMBits) - 1).U
|
||||
Cat(buf((stepM << 1) + 1.U), buf(stepM << 1))
|
||||
}
|
||||
val operandA = selectOperandA(fullABuf.io.deq.bits.data)
|
||||
val operandATag = fullABuf.io.deq.bits.tag
|
||||
// select the correct 2x4 tile from B operand buffer
|
||||
val operandB = fullBBuf.io.deq.bits.data(substepCompute)
|
||||
val operandBTag = fullBBuf.io.deq.bits.tag
|
||||
dontTouch(operandATag)
|
||||
dontTouch(operandBTag)
|
||||
|
||||
// Operand buffer logic
|
||||
//
|
||||
// hold A data until the entire set is done
|
||||
val shouldDequeueAMask = ((1 << stepBits) - 1).U
|
||||
val shouldDequeueA =
|
||||
((stepCompute & shouldDequeueAMask) === shouldDequeueAMask) &&
|
||||
(substepCompute === 1.U)
|
||||
fullABuf.io.deq.ready := dpuFire && shouldDequeueA
|
||||
// hold B tile at respQueueB for multiple steps for reuse, only dequeue when
|
||||
// we fully iterated a column (M-dimension)
|
||||
val shouldDequeueBMask = ((1 << numTilesMBits) - 1).U
|
||||
val shouldDequeueB =
|
||||
((stepCompute & shouldDequeueBMask) === shouldDequeueBMask) &&
|
||||
(substepCompute === 1.U)
|
||||
fullBBuf.io.deq.ready := dpuFire && shouldDequeueB
|
||||
dontTouch(respQueueA)
|
||||
dontTouch(respQueueB)
|
||||
dontTouch(shouldDequeueA)
|
||||
dontTouch(shouldDequeueB)
|
||||
|
||||
// Assert that the DPU is computing with operands of the same set/step. Note
|
||||
// that the B resp will only have step values multiple of 4 due to reuse.
|
||||
//
|
||||
// This check assumes that memory responses come back in-order. Might be too
|
||||
// strong of an assumption depending on the backing memory.
|
||||
def assertAligned = {
|
||||
val stepMask = (1 << numTilesMBits).U
|
||||
when (dpuFire) {
|
||||
assert(fullABuf.io.deq.bits.tag.set === fullBBuf.io.deq.bits.tag.set,
|
||||
"A and B operands are pointing to different sets. " ++
|
||||
"This might indicate memory response coming back out-of-order.")
|
||||
}
|
||||
}
|
||||
assertAligned
|
||||
|
||||
// Dot-product unit
|
||||
//
|
||||
// 4x2 four-element DPUs summing up to 32 MACs in total
|
||||
//
|
||||
val ncSubstep = tilingParams.nc / 2
|
||||
require(tilingParams.mc * ncSubstep == numLanes,
|
||||
"substep tile size doesn't match writeback throughput")
|
||||
val dpus = Seq.fill(tilingParams.mc)(Seq.fill(ncSubstep)(
|
||||
Module(new TensorDotProductUnit(half = false))
|
||||
))
|
||||
|
||||
// reshape operands for easier routing to DPU
|
||||
def reshapeByFourWords(x: UInt): Seq[Seq[UInt]] = {
|
||||
x.asBools.grouped(wordSizeInBits).map(VecInit(_).asUInt).toSeq
|
||||
.grouped(4/*k-dim*/).toSeq
|
||||
}
|
||||
val operandADimensional = reshapeByFourWords(operandA)
|
||||
require(operandADimensional.length == tilingParams.mc &&
|
||||
operandADimensional(0).length == tilingParams.kc,
|
||||
"operand width doesn't agree with tiling parameter")
|
||||
val operandBDimensional = reshapeByFourWords(operandB)
|
||||
require(operandBDimensional.length == ncSubstep &&
|
||||
operandBDimensional(0).length == tilingParams.kc,
|
||||
"operand width doesn't agree with tiling parameter")
|
||||
|
||||
for (m <- 0 until tilingParams.mc) {
|
||||
for (n <- 0 until ncSubstep) {
|
||||
dpus(m)(n).io.in.valid := dpuFire
|
||||
dpus(m)(n).io.in.bits.a := operandADimensional(m)
|
||||
dpus(m)(n).io.in.bits.b := operandBDimensional(n)
|
||||
dpus(m)(n).io.in.bits.c := 0.U // FIXME: bogus accum data
|
||||
// dpu ready couples with writeback backpressure
|
||||
dpus(m)(n).io.stall := !io.writeback.ready
|
||||
}
|
||||
}
|
||||
dpuReady := !dpus(0)(0).io.stall
|
||||
dontTouch(dpuFire)
|
||||
dontTouch(dpuReady)
|
||||
|
||||
val dpuValids = dpus.flatMap(_.map(_.io.out.valid))
|
||||
val dpuValid = dpuValids.reduce(_ && _)
|
||||
def assertDPU = {
|
||||
val dpuStalls = dpus.flatMap(_.map(_.io.stall))
|
||||
assert(dpuStalls.reduce(_ && _) === dpuStalls.reduce(_ || _),
|
||||
"stall signals of DPUs went out of sync")
|
||||
assert(dpuValids.reduce(_ && _) === dpuValids.reduce(_ || _),
|
||||
"valid signals of DPUs went out of sync")
|
||||
}
|
||||
assertDPU
|
||||
|
||||
// flatten DPU output into 1D array in M-major order
|
||||
val flattenedDPUOut = (0 until ncSubstep).flatMap { n =>
|
||||
(0 until tilingParams.mc).map { m =>
|
||||
dpus(m)(n).io.out.bits.data
|
||||
}
|
||||
}
|
||||
io.writeback.bits.data := flattenedDPUOut
|
||||
|
||||
// Writeback logic
|
||||
//
|
||||
// These queues hold metadata needed for writeback in sync with the DPU.
|
||||
|
||||
class TensorComputeTag extends Bundle {
|
||||
val set = UInt(setBits.W)
|
||||
val step = UInt(stepBits.W)
|
||||
val substep = UInt(1.W)
|
||||
}
|
||||
|
||||
val queueDepth = 5 // needs to be at least the DPU latency
|
||||
val tagQueue = Module(new Queue(new TensorComputeTag, queueDepth))
|
||||
tagQueue.io.enq.valid := dpuFire
|
||||
tagQueue.io.enq.bits.set := setCompute
|
||||
tagQueue.io.enq.bits.step := stepCompute
|
||||
tagQueue.io.enq.bits.substep := substepCompute
|
||||
tagQueue.io.deq.ready := io.writeback.fire
|
||||
assert(tagQueue.io.enq.ready === true.B,
|
||||
"tag queue full, DPU operation might be throttled")
|
||||
assert(!dpuValid || tagQueue.io.deq.valid,
|
||||
"tag queue and DPU went out of sync")
|
||||
|
||||
// val widQueue = Queue(io.initiate, queueDepth, pipe = (queueDepth == 1))
|
||||
|
||||
// note rd is independent to sets
|
||||
def rdGen(step: UInt, substep: UInt): UInt = {
|
||||
// each step produces 4x4 output tile, written by 8 threads with 2 regs per
|
||||
// thread
|
||||
(step << 1/*2 substeps*/) + substep
|
||||
}
|
||||
|
||||
val setWriteback = tagQueue.io.deq.bits.set
|
||||
val stepWriteback = tagQueue.io.deq.bits.step
|
||||
val substepWriteback = tagQueue.io.deq.bits.substep
|
||||
io.writeback.valid := dpuValid
|
||||
// TODO: decouple wid from frontend
|
||||
io.writeback.bits.wid := warpReg
|
||||
io.writeback.bits.rd := rdGen(stepWriteback, substepWriteback)
|
||||
io.writeback.bits.last := setDone(setWriteback) && stepDone(stepWriteback) &&
|
||||
(substepWriteback === 1.U)
|
||||
|
||||
// State transition
|
||||
// ----------------
|
||||
//
|
||||
// set/step sequencing logic
|
||||
val lastSet = ((1 << setBits) - 1)
|
||||
val lastStep = ((1 << stepBits) - 1)
|
||||
val setDone = (set === lastSet.U)
|
||||
val stepDone = (step === lastStep.U)
|
||||
when (nextStep) {
|
||||
step := (step + 1.U) & lastStep.U
|
||||
when (stepDone) {
|
||||
set := (set + 1.U) & lastSet.U
|
||||
}
|
||||
}
|
||||
|
||||
switch(state) {
|
||||
is(TensorState.idle) {
|
||||
when(io.initiate.fire) {
|
||||
state := TensorState.run
|
||||
}
|
||||
}
|
||||
is(TensorState.run) {
|
||||
when (setDone && stepDone && nextStep) {
|
||||
when (state === TensorState.run) {
|
||||
state := TensorState.finish
|
||||
}
|
||||
}
|
||||
}
|
||||
is(TensorState.finish) {
|
||||
when(io.writeback.fire) {
|
||||
state := TensorState.idle
|
||||
def sequenceSetStep(set: UInt, step: UInt, nextStep: Bool) = {
|
||||
when (nextStep) {
|
||||
step := (step + 1.U) & lastStep.U
|
||||
when (stepDone(step)) {
|
||||
set := (set + 1.U) & lastSet.U
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
io.initiate.ready := !busy
|
||||
io.writeback.valid := (state === TensorState.finish)
|
||||
io.writeback.bits.wid := warpReg
|
||||
io.writeback.bits.last := false.B // TODO
|
||||
|
||||
// Writeback queues
|
||||
// ----------------
|
||||
// These queues hold the metadata necessary for register
|
||||
// writeback.
|
||||
|
||||
// val queueDepth = 2
|
||||
// val widQueue = Queue(io.initiate, queueDepth, pipe = (queueDepth == 1))
|
||||
// val rdQueue = Queue(io.initiate, queueDepth, pipe = (queueDepth == 1))
|
||||
sequenceSetStep(setCompute, stepCompute, nextStepCompute)
|
||||
}
|
||||
|
||||
class TensorMemReq(
|
||||
sourceWidth: Int
|
||||
) extends Bundle {
|
||||
val source = UInt(sourceWidth.W)
|
||||
val address = UInt(32.W)
|
||||
}
|
||||
class TensorMemResp(
|
||||
sourceWidth: Int,
|
||||
dataWidth: Int
|
||||
) extends Bundle {
|
||||
val source = UInt(sourceWidth.W)
|
||||
val data = UInt(dataWidth.W)
|
||||
// A buffer that collects multiple entries of input data and exposes the
|
||||
// coalesced data as output. Effectively acts as a width-widening
|
||||
// chisel.util.Pipe.
|
||||
class FillBuffer[T <: Data](
|
||||
gen: T,
|
||||
entries: Int
|
||||
) extends Module {
|
||||
require(entries > 0, "FillBuffer must have a positive number of entries")
|
||||
requireIsChiselType(gen)
|
||||
|
||||
val io = IO(new Bundle {
|
||||
val enq = Flipped(Decoupled(gen))
|
||||
val deq = Decoupled(Vec(entries, gen))
|
||||
})
|
||||
|
||||
val data = Reg(Vec(entries, gen))
|
||||
val ptr = Counter(entries + 1)
|
||||
dontTouch(ptr.value)
|
||||
val full = (ptr.value === entries.U)
|
||||
io.enq.ready := !full
|
||||
when (io.enq.fire) {
|
||||
data(ptr.value) := io.enq.bits
|
||||
ptr.inc()
|
||||
}
|
||||
io.deq.valid := full
|
||||
(io.deq.bits zip data).foreach { case (io, d) => io := d }
|
||||
when (io.deq.fire) {
|
||||
assert(ptr.value === entries.U, "FillBuffer fired before buffer was full")
|
||||
ptr.reset()
|
||||
}
|
||||
}
|
||||
|
||||
// synthesizable unit tests
|
||||
@@ -250,7 +600,7 @@ class TensorCoreDecoupledTLImp(outer: TensorCoreDecoupledTL)
|
||||
|
||||
val tensor = Module(new TensorCoreDecoupled(
|
||||
8, 8, outer.numSrcIds , TensorTilingParams()))
|
||||
val wordSize = 4 // FIXME: hardcoded
|
||||
val wordSize = 4 // @cleanup: hardcoded
|
||||
|
||||
val zip = Seq((outer.node.out(0), tensor.io.reqA),
|
||||
(outer.node.out(1), tensor.io.reqB))
|
||||
@@ -281,10 +631,14 @@ class TensorCoreDecoupledTLImp(outer: TensorCoreDecoupledTL)
|
||||
tlOutB.d.ready := tensor.io.respB.ready
|
||||
|
||||
tensor.io.initiate.valid := io.start
|
||||
tensor.io.initiate.bits.wid := 0.U // FIXME
|
||||
tensor.io.initiate.bits.wid := 0.U // TODO
|
||||
tensor.io.writeback.ready := true.B
|
||||
|
||||
io.finished := tensor.io.writeback.valid
|
||||
io.finished := tensor.io.writeback.valid && tensor.io.writeback.bits.last
|
||||
when (io.finished) {
|
||||
// might be too strong
|
||||
assert(tensor.io.writeback.bits.rd === 31.U)
|
||||
}
|
||||
}
|
||||
|
||||
// a minimal Diplomacy graph with a tensor core and a TLRAM
|
||||
@@ -293,7 +647,7 @@ class TensorCoreDecoupledTLRAM(implicit p: Parameters) extends LazyModule {
|
||||
val xbar = LazyModule(new TLXbar)
|
||||
val ram = LazyModule(new TLRAM(
|
||||
address = AddressSet(0x0000, 0xffffff),
|
||||
beatBytes = 32 // FIXME: hardcoded
|
||||
beatBytes = 32 // @cleanup: hardcoded
|
||||
))
|
||||
|
||||
ram.node :=* xbar.node :=* tensor.node
|
||||
@@ -305,10 +659,57 @@ class TensorCoreDecoupledTLRAM(implicit p: Parameters) extends LazyModule {
|
||||
}
|
||||
}
|
||||
|
||||
// two separate TLRAMs for A and B for full throughput
|
||||
class TensorCoreDecoupledTwoTLRAM(implicit p: Parameters) extends LazyModule {
|
||||
val tensor = LazyModule(new TensorCoreDecoupledTL)
|
||||
val xbar = LazyModule(new TLXbar)
|
||||
val ramA = LazyModule(new TLRAM(
|
||||
address = AddressSet(0x000, 0xfffbff),
|
||||
beatBytes = 32 // @cleanup: hardcoded
|
||||
))
|
||||
val ramB = LazyModule(new TLRAM(
|
||||
address = AddressSet(0x400, 0xfffbff),
|
||||
beatBytes = 32 // @cleanup: hardcoded
|
||||
))
|
||||
|
||||
val stutter = new TLIdentityNode
|
||||
xbar.node :=* tensor.node
|
||||
ramA.node := stutter := xbar.node
|
||||
ramB.node := xbar.node
|
||||
|
||||
val fuzz = true
|
||||
|
||||
lazy val module = new Impl
|
||||
class Impl extends LazyModuleImp(this) with UnitTestModule {
|
||||
tensor.module.io.start := io.start
|
||||
io.finished := tensor.module.io.finished
|
||||
|
||||
val (tlIn, _) = stutter.in(0)
|
||||
val (tlOut, _) = stutter.out(0)
|
||||
require(stutter.in.length == 1)
|
||||
require(stutter.out.length == 1)
|
||||
|
||||
// inject stalls for fuzzing
|
||||
val incr = Wire(Bool())
|
||||
val (count, _) = Counter(incr, 0x1000)
|
||||
def cond(x: UInt) = (x & ((1 << 3) - 1).U) =/= 0.U
|
||||
val stall = if (fuzz) cond(count) else false.B
|
||||
|
||||
tlOut.a <> tlIn.a
|
||||
tlIn.d <> tlOut.d
|
||||
incr := tlIn.a.fire || stall
|
||||
when (stall) {
|
||||
tlIn.a.ready := false.B
|
||||
tlOut.a.valid := false.B
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// unit test harness
|
||||
class TensorCoreDecoupledTest(timeout: Int = 500000)(implicit p: Parameters)
|
||||
extends UnitTest(timeout) {
|
||||
val dut = Module(LazyModule(new TensorCoreDecoupledTLRAM).module)
|
||||
// val dut = Module(LazyModule(new TensorCoreDecoupledTLRAM).module)
|
||||
val dut = Module(LazyModule(new TensorCoreDecoupledTwoTLRAM).module)
|
||||
dut.io.start := io.start
|
||||
io.finished := dut.io.finished
|
||||
}
|
||||
|
||||
@@ -27,6 +27,7 @@ class TensorDotProductUnit(val half: Boolean) extends Module with tile.HasFPUPar
|
||||
val b = Vec(dotProductDim, Bits((inFLen).W))
|
||||
val c = Bits((outFLen).W) // note C has the out length for accumulation
|
||||
}))
|
||||
// 'stall' is effectively out.ready, combinationally coupled to in.ready
|
||||
val stall = Input(Bool())
|
||||
val out = Valid(new Bundle {
|
||||
val data = Bits((outFLen).W)
|
||||
@@ -52,7 +53,7 @@ class TensorDotProductUnit(val half: Boolean) extends Module with tile.HasFPUPar
|
||||
io.out.bits.data := ieee(box(dpu.io.out.bits.data, S))
|
||||
}
|
||||
|
||||
// Copied from chisel3.util.Pipe.
|
||||
// An implementation of chisel3.util.Pipe that supports stalls.
|
||||
class StallingPipe[T <: Data](val gen: T, val latency: Int = 1) extends Module {
|
||||
/** A non-ambiguous name of this `StallingPipe` for use in generated Verilog
|
||||
* names. Includes the latency cycle count in the name as well as the
|
||||
|
||||
@@ -372,7 +372,8 @@ class SourceGenerator[T <: Data](
|
||||
outstanding := outstanding + 1.U
|
||||
}
|
||||
}.elsewhen(io.reclaim.valid) {
|
||||
assert(outstanding > 0.U)
|
||||
assert(outstanding > 0.U,
|
||||
"Over-reclaim. Did some responses get dropped?")
|
||||
outstanding := outstanding - 1.U
|
||||
}
|
||||
dontTouch(outstanding)
|
||||
|
||||
@@ -137,7 +137,7 @@ class Vortex(tile: RadianceTile)(implicit p: Parameters)
|
||||
"NUM_THREADS" -> tile.numLsuLanes
|
||||
)
|
||||
)
|
||||
with HasBlackBoxResource {
|
||||
with HasBlackBoxResource with HasBlackBoxPath {
|
||||
// addResource("/vsrc/vortex/hw/unit_tests/generic_queue/testbench.v")
|
||||
// addResource("/vsrc/vortex/hw/unit_tests/VX_divide_tb.v")
|
||||
// addResource("/vsrc/vortex/hw/syn/synopsys/models/memory/cln28hpm/rf2_256x19_wm0/rf2_256x19_wm0_rtl.v")
|
||||
@@ -408,6 +408,34 @@ class Vortex(tile: RadianceTile)(implicit p: Parameters)
|
||||
addResource("/vsrc/vortex/hw/rtl/core/VX_tensor_hopper_core.sv")
|
||||
addResource("/vsrc/vortex/hw/rtl/mem/VX_tc_bus_if.sv")
|
||||
// addResource("/vsrc/vortex/hw/rtl/core/VX_tensor_ucode.vh")
|
||||
def addHopperTensorCore = {
|
||||
addPath("/scratch/hansung/chipyard/sims/vcs/generated-src/chipyard.unittest.TestHarness.TensorUnitTestConfig/gen-collateral/AddRawFN.sv")
|
||||
addPath("/scratch/hansung/chipyard/sims/vcs/generated-src/chipyard.unittest.TestHarness.TensorUnitTestConfig/gen-collateral/AddRecFN.sv")
|
||||
addPath("/scratch/hansung/chipyard/sims/vcs/generated-src/chipyard.unittest.TestHarness.TensorUnitTestConfig/gen-collateral/DotProductPipe.sv")
|
||||
addPath("/scratch/hansung/chipyard/sims/vcs/generated-src/chipyard.unittest.TestHarness.TensorUnitTestConfig/gen-collateral/FillBuffer_1.sv")
|
||||
addPath("/scratch/hansung/chipyard/sims/vcs/generated-src/chipyard.unittest.TestHarness.TensorUnitTestConfig/gen-collateral/FillBuffer.sv")
|
||||
addPath("/scratch/hansung/chipyard/sims/vcs/generated-src/chipyard.unittest.TestHarness.TensorUnitTestConfig/gen-collateral/metadataTable_4x5.sv")
|
||||
addPath("/scratch/hansung/chipyard/sims/vcs/generated-src/chipyard.unittest.TestHarness.TensorUnitTestConfig/gen-collateral/MulFullRawFN.sv")
|
||||
addPath("/scratch/hansung/chipyard/sims/vcs/generated-src/chipyard.unittest.TestHarness.TensorUnitTestConfig/gen-collateral/occupancyTable_4x1.sv")
|
||||
addPath("/scratch/hansung/chipyard/sims/vcs/generated-src/chipyard.unittest.TestHarness.TensorUnitTestConfig/gen-collateral/Queue1_TensorCoreDecoupled_Anon_1.sv")
|
||||
addPath("/scratch/hansung/chipyard/sims/vcs/generated-src/chipyard.unittest.TestHarness.TensorUnitTestConfig/gen-collateral/Queue1_TensorCoreDecoupled_Anon.sv")
|
||||
addPath("/scratch/hansung/chipyard/sims/vcs/generated-src/chipyard.unittest.TestHarness.TensorUnitTestConfig/gen-collateral/Queue1_TensorMemTag.sv")
|
||||
addPath("/scratch/hansung/chipyard/sims/vcs/generated-src/chipyard.unittest.TestHarness.TensorUnitTestConfig/gen-collateral/Queue4_TensorMemRespWithTag.sv")
|
||||
addPath("/scratch/hansung/chipyard/sims/vcs/generated-src/chipyard.unittest.TestHarness.TensorUnitTestConfig/gen-collateral/Queue5_TensorComputeTag.sv")
|
||||
addPath("/scratch/hansung/chipyard/sims/vcs/generated-src/chipyard.unittest.TestHarness.TensorUnitTestConfig/gen-collateral/ram_4x261.sv")
|
||||
addPath("/scratch/hansung/chipyard/sims/vcs/generated-src/chipyard.unittest.TestHarness.TensorUnitTestConfig/gen-collateral/ram_5x7.sv")
|
||||
addPath("/scratch/hansung/chipyard/sims/vcs/generated-src/chipyard.unittest.TestHarness.TensorUnitTestConfig/gen-collateral/RoundAnyRawFNToRecFN_ie8_is26_oe8_os24.sv")
|
||||
addPath("/scratch/hansung/chipyard/sims/vcs/generated-src/chipyard.unittest.TestHarness.TensorUnitTestConfig/gen-collateral/RoundAnyRawFNToRecFN_ie8_is47_oe8_os24.sv")
|
||||
addPath("/scratch/hansung/chipyard/sims/vcs/generated-src/chipyard.unittest.TestHarness.TensorUnitTestConfig/gen-collateral/RoundRawFNToRecFN_e8_s24.sv")
|
||||
addPath("/scratch/hansung/chipyard/sims/vcs/generated-src/chipyard.unittest.TestHarness.TensorUnitTestConfig/gen-collateral/SimpleTimer.sv")
|
||||
addPath("/scratch/hansung/chipyard/sims/vcs/generated-src/chipyard.unittest.TestHarness.TensorUnitTestConfig/gen-collateral/SourceGenerator.sv")
|
||||
addPath("/scratch/hansung/chipyard/sims/vcs/generated-src/chipyard.unittest.TestHarness.TensorUnitTestConfig/gen-collateral/StallingPipe_1.sv")
|
||||
addPath("/scratch/hansung/chipyard/sims/vcs/generated-src/chipyard.unittest.TestHarness.TensorUnitTestConfig/gen-collateral/StallingPipe_2.sv")
|
||||
addPath("/scratch/hansung/chipyard/sims/vcs/generated-src/chipyard.unittest.TestHarness.TensorUnitTestConfig/gen-collateral/StallingPipe.sv")
|
||||
addPath("/scratch/hansung/chipyard/sims/vcs/generated-src/chipyard.unittest.TestHarness.TensorUnitTestConfig/gen-collateral/TensorCoreDecoupled.sv")
|
||||
addPath("/scratch/hansung/chipyard/sims/vcs/generated-src/chipyard.unittest.TestHarness.TensorUnitTestConfig/gen-collateral/TensorDotProductUnit.sv")
|
||||
}
|
||||
addHopperTensorCore
|
||||
addResource("/vsrc/vortex/hw/rtl/core/VX_uop_sequencer.sv")
|
||||
addResource("/vsrc/vortex/hw/rtl/core/VX_reduce_unit.sv")
|
||||
addResource("/vsrc/vortex/hw/rtl/fpu/VX_tensor_dpu.sv")
|
||||
|
||||
Reference in New Issue
Block a user