tensor: Instantiate actual DPU
This commit is contained in:
@@ -33,8 +33,9 @@ class TensorCoreDecoupled(
|
||||
) extends Module {
|
||||
val numWarpBits = log2Ceil(numWarps)
|
||||
val wordSize = 4 // TODO FP16
|
||||
val wordSizeInBits = wordSize * 8 // TODO FP16
|
||||
val sourceWidth = log2Ceil(numSourceIds)
|
||||
val dataWidth = numLanes * wordSize * 8/*bits*/ // TODO FP16
|
||||
val dataWidth = numLanes * wordSizeInBits // TODO FP16
|
||||
val numFPRegBits = log2Ceil(numFPRegs)
|
||||
|
||||
val io = IO(new Bundle {
|
||||
@@ -45,7 +46,7 @@ class TensorCoreDecoupled(
|
||||
val last = Bool()
|
||||
val wid = UInt(numWarpBits.W)
|
||||
val rd = UInt(numFPRegBits.W)
|
||||
val data = Vec(numLanes, UInt((wordSize * 8/*bits*/).W))
|
||||
val data = Vec(numLanes, UInt((wordSizeInBits).W))
|
||||
})
|
||||
val respA = Flipped(Decoupled(new TensorMemResp(sourceWidth, dataWidth)))
|
||||
val respB = Flipped(Decoupled(new TensorMemResp(sourceWidth, dataWidth)))
|
||||
@@ -223,9 +224,6 @@ class TensorCoreDecoupled(
|
||||
io.writeback.bits.data.widthOption.get,
|
||||
"response data width does not match the writeback data width")
|
||||
|
||||
// FIXME: this need to change to dpu_ready
|
||||
val dpuReady = io.writeback.ready // FIXME: this need be actual dpu
|
||||
|
||||
val substepExecute = RegInit(0.U(1.W))
|
||||
when (respQueueA.fire) {
|
||||
substepExecute := substepExecute + 1.U
|
||||
@@ -267,7 +265,10 @@ class TensorCoreDecoupled(
|
||||
fullAQueue.io.enq.bits.data := fullAEnqData
|
||||
fullAQueue.io.enq.bits.tag := fullAEnqTag
|
||||
|
||||
val operandsValid = fullAQueue.io.deq.valid && respQueueB.valid // FIXME?
|
||||
val operandsValid = fullAQueue.io.deq.valid && respQueueB.valid
|
||||
val operandA = fullAQueue.io.deq.bits.data
|
||||
val operandB = respQueueB.bits.data
|
||||
val dpuReady = Wire(Bool())
|
||||
val dpuFire = operandsValid && dpuReady
|
||||
val substepCompute = RegInit(0.U(1.W))
|
||||
when (dpuFire) {
|
||||
@@ -301,6 +302,66 @@ class TensorCoreDecoupled(
|
||||
}
|
||||
assertAligned
|
||||
|
||||
// Dot-product unit
|
||||
//
|
||||
// 4x2 four-element DPUs summing up to 32 MACs in total
|
||||
val dpus = Seq.fill(4)(Seq.fill(2)(
|
||||
Module(new TensorDotProductUnit(half = false))
|
||||
))
|
||||
// operandA is 4x4 in K-major
|
||||
val operandADimensional =
|
||||
operandA.asBools.grouped(wordSizeInBits).map(VecInit(_).asUInt).toSeq
|
||||
.grouped(4).toSeq
|
||||
println(s"operandA: ${fullAQueue.io.deq.bits.data.widthOption.get} bits")
|
||||
println(s"A: ${operandADimensional.length}, ${operandADimensional(0).length}")
|
||||
assert(operandADimensional.length == tilingParams.mc &&
|
||||
operandADimensional(0).length == tilingParams.kc,
|
||||
"operand width doesn't agree with tiling parameter")
|
||||
// operandB is 2x4, i.e. 4x2 in N-major
|
||||
val operandBDimensional =
|
||||
operandB.asBools.grouped(wordSizeInBits).map(VecInit(_).asUInt).toSeq
|
||||
.grouped(4).toSeq
|
||||
println(s"B: ${operandBDimensional.length}, ${operandBDimensional(0).length}")
|
||||
val ncSubstep = tilingParams.nc / 2
|
||||
assert(tilingParams.mc * ncSubstep == numLanes,
|
||||
"substep tile size doesn't match writeback throughput")
|
||||
assert(operandBDimensional.length == ncSubstep &&
|
||||
operandBDimensional(0).length == tilingParams.kc,
|
||||
"operand width doesn't agree with tiling parameter")
|
||||
|
||||
for (m <- 0 until tilingParams.mc) {
|
||||
for (n <- 0 until ncSubstep) {
|
||||
dpus(m)(n).io.in.valid := dpuFire
|
||||
dpus(m)(n).io.in.bits.a := operandADimensional(m)
|
||||
dpus(m)(n).io.in.bits.b := operandBDimensional(n)
|
||||
dpus(m)(n).io.in.bits.c := 0.U // FIXME: bogus accum data
|
||||
// dpu ready couples with writeback backpressure
|
||||
dpus(m)(n).io.stall := !io.writeback.ready
|
||||
}
|
||||
}
|
||||
dpuReady := !dpus(0)(0).io.stall
|
||||
dontTouch(dpuFire)
|
||||
dontTouch(dpuReady)
|
||||
|
||||
val dpuValids = dpus.flatMap(_.map(_.io.out.valid))
|
||||
val dpuValid = dpuValids.reduce(_ && _)
|
||||
def assertDPU = {
|
||||
val dpuStalls = dpus.flatMap(_.map(_.io.stall))
|
||||
assert(dpuStalls.reduce(_ && _) === dpuStalls.reduce(_ || _),
|
||||
"stall signals of DPUs went unaligned")
|
||||
assert(dpuValids.reduce(_ && _) === dpuValids.reduce(_ || _),
|
||||
"valid signals of DPUs went unaligned")
|
||||
}
|
||||
assertDPU
|
||||
|
||||
// flatten DPU output into 1D array in M-major order
|
||||
val flattenedDPUOut = (0 until ncSubstep).flatMap { n =>
|
||||
(0 until tilingParams.mc).map { m =>
|
||||
dpus(m)(n).io.out.bits.data
|
||||
}
|
||||
}
|
||||
io.writeback.bits.data := flattenedDPUOut
|
||||
|
||||
def rdGen(set: UInt, step: UInt): UInt = {
|
||||
// each step produces 4x4 output tile, written by 8 threads with 2 regs per
|
||||
// thread
|
||||
@@ -309,19 +370,11 @@ class TensorCoreDecoupled(
|
||||
// FIXME: add substep here
|
||||
}
|
||||
|
||||
io.writeback.valid := operandsValid // FIXME: bypass logic
|
||||
io.writeback.valid := dpuValid
|
||||
io.writeback.bits.wid := warpReg
|
||||
io.writeback.bits.rd := rdGen(setExecute, stepExecute)
|
||||
io.writeback.bits.last := setDone(setExecute) && stepDone(stepExecute)
|
||||
|
||||
// FIXME: debug dummy: pipe A directly to writeback
|
||||
val groupedRespA = respQueueA.bits.data
|
||||
.asBools.grouped(wordSize * 8/*bits*/)
|
||||
.map(VecInit(_).asUInt)
|
||||
(io.writeback.bits.data zip groupedRespA).foreach { case (wb, data) =>
|
||||
wb := data
|
||||
}
|
||||
|
||||
// State transition
|
||||
// ----------------
|
||||
//
|
||||
@@ -400,7 +453,7 @@ class TensorCoreDecoupledTLImp(outer: TensorCoreDecoupledTL)
|
||||
|
||||
val tensor = Module(new TensorCoreDecoupled(
|
||||
8, 8, outer.numSrcIds , TensorTilingParams()))
|
||||
val wordSize = 4 // FIXME: hardcoded
|
||||
val wordSize = 4 // @cleanup: hardcoded
|
||||
|
||||
val zip = Seq((outer.node.out(0), tensor.io.reqA),
|
||||
(outer.node.out(1), tensor.io.reqB))
|
||||
@@ -431,7 +484,7 @@ class TensorCoreDecoupledTLImp(outer: TensorCoreDecoupledTL)
|
||||
tlOutB.d.ready := tensor.io.respB.ready
|
||||
|
||||
tensor.io.initiate.valid := io.start
|
||||
tensor.io.initiate.bits.wid := 0.U // FIXME
|
||||
tensor.io.initiate.bits.wid := 0.U // TODO
|
||||
tensor.io.writeback.ready := true.B
|
||||
|
||||
io.finished := tensor.io.writeback.valid && tensor.io.writeback.bits.last
|
||||
@@ -443,7 +496,7 @@ class TensorCoreDecoupledTLRAM(implicit p: Parameters) extends LazyModule {
|
||||
val xbar = LazyModule(new TLXbar)
|
||||
val ram = LazyModule(new TLRAM(
|
||||
address = AddressSet(0x0000, 0xffffff),
|
||||
beatBytes = 32 // FIXME: hardcoded
|
||||
beatBytes = 32 // @cleanup: hardcoded
|
||||
))
|
||||
|
||||
ram.node :=* xbar.node :=* tensor.node
|
||||
@@ -461,11 +514,11 @@ class TensorCoreDecoupledTwoTLRAM(implicit p: Parameters) extends LazyModule {
|
||||
val xbar = LazyModule(new TLXbar)
|
||||
val ramA = LazyModule(new TLRAM(
|
||||
address = AddressSet(0x000, 0xfffeff),
|
||||
beatBytes = 32 // FIXME: hardcoded
|
||||
beatBytes = 32 // @cleanup: hardcoded
|
||||
))
|
||||
val ramB = LazyModule(new TLRAM(
|
||||
address = AddressSet(0x100, 0xfffeff),
|
||||
beatBytes = 32 // FIXME: hardcoded
|
||||
beatBytes = 32 // @cleanup: hardcoded
|
||||
))
|
||||
|
||||
xbar.node :=* tensor.node
|
||||
|
||||
Reference in New Issue
Block a user