tensor: Instantiate actual DPU

This commit is contained in:
Hansung Kim
2024-10-17 14:37:33 -07:00
parent e1e3ac8274
commit 8847278ad1

View File

@@ -33,8 +33,9 @@ class TensorCoreDecoupled(
) extends Module {
val numWarpBits = log2Ceil(numWarps)
val wordSize = 4 // TODO FP16
val wordSizeInBits = wordSize * 8 // TODO FP16
val sourceWidth = log2Ceil(numSourceIds)
val dataWidth = numLanes * wordSize * 8/*bits*/ // TODO FP16
val dataWidth = numLanes * wordSizeInBits // TODO FP16
val numFPRegBits = log2Ceil(numFPRegs)
val io = IO(new Bundle {
@@ -45,7 +46,7 @@ class TensorCoreDecoupled(
val last = Bool()
val wid = UInt(numWarpBits.W)
val rd = UInt(numFPRegBits.W)
val data = Vec(numLanes, UInt((wordSize * 8/*bits*/).W))
val data = Vec(numLanes, UInt((wordSizeInBits).W))
})
val respA = Flipped(Decoupled(new TensorMemResp(sourceWidth, dataWidth)))
val respB = Flipped(Decoupled(new TensorMemResp(sourceWidth, dataWidth)))
@@ -223,9 +224,6 @@ class TensorCoreDecoupled(
io.writeback.bits.data.widthOption.get,
"response data width does not match the writeback data width")
// FIXME: this need to change to dpu_ready
val dpuReady = io.writeback.ready // FIXME: this need be actual dpu
val substepExecute = RegInit(0.U(1.W))
when (respQueueA.fire) {
substepExecute := substepExecute + 1.U
@@ -267,7 +265,10 @@ class TensorCoreDecoupled(
fullAQueue.io.enq.bits.data := fullAEnqData
fullAQueue.io.enq.bits.tag := fullAEnqTag
val operandsValid = fullAQueue.io.deq.valid && respQueueB.valid // FIXME?
val operandsValid = fullAQueue.io.deq.valid && respQueueB.valid
val operandA = fullAQueue.io.deq.bits.data
val operandB = respQueueB.bits.data
val dpuReady = Wire(Bool())
val dpuFire = operandsValid && dpuReady
val substepCompute = RegInit(0.U(1.W))
when (dpuFire) {
@@ -301,6 +302,66 @@ class TensorCoreDecoupled(
}
assertAligned
// Dot-product unit
//
// 4x2 four-element DPUs summing up to 32 MACs in total
val dpus = Seq.fill(4)(Seq.fill(2)(
Module(new TensorDotProductUnit(half = false))
))
// operandA is 4x4 in K-major
val operandADimensional =
operandA.asBools.grouped(wordSizeInBits).map(VecInit(_).asUInt).toSeq
.grouped(4).toSeq
println(s"operandA: ${fullAQueue.io.deq.bits.data.widthOption.get} bits")
println(s"A: ${operandADimensional.length}, ${operandADimensional(0).length}")
assert(operandADimensional.length == tilingParams.mc &&
operandADimensional(0).length == tilingParams.kc,
"operand width doesn't agree with tiling parameter")
// operandB is 2x4, i.e. 4x2 in N-major
val operandBDimensional =
operandB.asBools.grouped(wordSizeInBits).map(VecInit(_).asUInt).toSeq
.grouped(4).toSeq
println(s"B: ${operandBDimensional.length}, ${operandBDimensional(0).length}")
val ncSubstep = tilingParams.nc / 2
assert(tilingParams.mc * ncSubstep == numLanes,
"substep tile size doesn't match writeback throughput")
assert(operandBDimensional.length == ncSubstep &&
operandBDimensional(0).length == tilingParams.kc,
"operand width doesn't agree with tiling parameter")
for (m <- 0 until tilingParams.mc) {
for (n <- 0 until ncSubstep) {
dpus(m)(n).io.in.valid := dpuFire
dpus(m)(n).io.in.bits.a := operandADimensional(m)
dpus(m)(n).io.in.bits.b := operandBDimensional(n)
dpus(m)(n).io.in.bits.c := 0.U // FIXME: bogus accum data
// dpu ready couples with writeback backpressure
dpus(m)(n).io.stall := !io.writeback.ready
}
}
dpuReady := !dpus(0)(0).io.stall
dontTouch(dpuFire)
dontTouch(dpuReady)
val dpuValids = dpus.flatMap(_.map(_.io.out.valid))
val dpuValid = dpuValids.reduce(_ && _)
def assertDPU = {
val dpuStalls = dpus.flatMap(_.map(_.io.stall))
assert(dpuStalls.reduce(_ && _) === dpuStalls.reduce(_ || _),
"stall signals of DPUs went unaligned")
assert(dpuValids.reduce(_ && _) === dpuValids.reduce(_ || _),
"valid signals of DPUs went unaligned")
}
assertDPU
// flatten DPU output into 1D array in M-major order
val flattenedDPUOut = (0 until ncSubstep).flatMap { n =>
(0 until tilingParams.mc).map { m =>
dpus(m)(n).io.out.bits.data
}
}
io.writeback.bits.data := flattenedDPUOut
def rdGen(set: UInt, step: UInt): UInt = {
// each step produces 4x4 output tile, written by 8 threads with 2 regs per
// thread
@@ -309,19 +370,11 @@ class TensorCoreDecoupled(
// FIXME: add substep here
}
io.writeback.valid := operandsValid // FIXME: bypass logic
io.writeback.valid := dpuValid
io.writeback.bits.wid := warpReg
io.writeback.bits.rd := rdGen(setExecute, stepExecute)
io.writeback.bits.last := setDone(setExecute) && stepDone(stepExecute)
// FIXME: debug dummy: pipe A directly to writeback
val groupedRespA = respQueueA.bits.data
.asBools.grouped(wordSize * 8/*bits*/)
.map(VecInit(_).asUInt)
(io.writeback.bits.data zip groupedRespA).foreach { case (wb, data) =>
wb := data
}
// State transition
// ----------------
//
@@ -400,7 +453,7 @@ class TensorCoreDecoupledTLImp(outer: TensorCoreDecoupledTL)
val tensor = Module(new TensorCoreDecoupled(
8, 8, outer.numSrcIds , TensorTilingParams()))
val wordSize = 4 // FIXME: hardcoded
val wordSize = 4 // @cleanup: hardcoded
val zip = Seq((outer.node.out(0), tensor.io.reqA),
(outer.node.out(1), tensor.io.reqB))
@@ -431,7 +484,7 @@ class TensorCoreDecoupledTLImp(outer: TensorCoreDecoupledTL)
tlOutB.d.ready := tensor.io.respB.ready
tensor.io.initiate.valid := io.start
tensor.io.initiate.bits.wid := 0.U // FIXME
tensor.io.initiate.bits.wid := 0.U // TODO
tensor.io.writeback.ready := true.B
io.finished := tensor.io.writeback.valid && tensor.io.writeback.bits.last
@@ -443,7 +496,7 @@ class TensorCoreDecoupledTLRAM(implicit p: Parameters) extends LazyModule {
val xbar = LazyModule(new TLXbar)
val ram = LazyModule(new TLRAM(
address = AddressSet(0x0000, 0xffffff),
beatBytes = 32 // FIXME: hardcoded
beatBytes = 32 // @cleanup: hardcoded
))
ram.node :=* xbar.node :=* tensor.node
@@ -461,11 +514,11 @@ class TensorCoreDecoupledTwoTLRAM(implicit p: Parameters) extends LazyModule {
val xbar = LazyModule(new TLXbar)
val ramA = LazyModule(new TLRAM(
address = AddressSet(0x000, 0xfffeff),
beatBytes = 32 // FIXME: hardcoded
beatBytes = 32 // @cleanup: hardcoded
))
val ramB = LazyModule(new TLRAM(
address = AddressSet(0x100, 0xfffeff),
beatBytes = 32 // FIXME: hardcoded
beatBytes = 32 // @cleanup: hardcoded
))
xbar.node :=* tensor.node