tensor: Support FP16 in TensorCoreDecoupled
This commit is contained in:
@@ -15,28 +15,46 @@ import radiance.memory.SourceGenerator
|
||||
|
||||
/** Tiling configuration for the tensor core.
  *
  * `m`, `n` and `k` give the dimensions of the SMEM tile; `mc`, `nc` and `kc`
  * give the dimensions of the compute tile, which is determined by the number
  * of MAC units.
  */
case class TensorTilingParams(
|
||||
// Dimension of the SMEM tile
|
||||
m: Int = 16,
|
||||
n: Int = 16,
|
||||
k: Int = 16,
|
||||
m: Int,
|
||||
n: Int,
|
||||
k: Int,
|
||||
// Dimension of the compute tile. This is determined by the number of MAC
|
||||
// units
|
||||
mc: Int = 4,
|
||||
nc: Int = 4,
|
||||
kc: Int = 4
|
||||
mc: Int,
|
||||
nc: Int,
|
||||
kc: Int,
|
||||
)
|
||||
|
||||
/** Preset tiling configurations, one per supported input datatype.
  *
  * Both presets share the same `m`, `n`, `mc` and `nc`; the FP16 preset uses
  * twice the `k` and `kc` of the FP32 preset.
  */
object TensorTilingParams {
|
||||
// Preset for FP16 inputs: 16x16x32 SMEM tile, 4x4x8 compute tile.
def fp16: TensorTilingParams = {
|
||||
TensorTilingParams (
|
||||
m = 16, n = 16, k = 32,
|
||||
mc = 4, nc = 4, kc = 8
|
||||
)
|
||||
}
|
||||
// Preset for FP32 inputs: 16x16x16 SMEM tile, 4x4x4 compute tile.
def fp32: TensorTilingParams = {
|
||||
TensorTilingParams (
|
||||
m = 16, n = 16, k = 16,
|
||||
mc = 4, nc = 4, kc = 4
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
class TensorCoreDecoupled(
|
||||
val numWarps: Int,
|
||||
val numLanes: Int,
|
||||
val numSourceIds: Int,
|
||||
val tilingParams: TensorTilingParams,
|
||||
val half: Boolean, // input datatype is FP16 if true, FP32 if false
|
||||
val numSourceIds: Int = 16,
|
||||
val numFPRegs: Int = 32
|
||||
) extends Module {
|
||||
val tilingParams =
|
||||
if (half) TensorTilingParams.fp16 else TensorTilingParams.fp32
|
||||
val numWarpBits = log2Ceil(numWarps)
|
||||
val wordSize = 4 // TODO FP16
|
||||
val wordSizeInBits = wordSize * 8 // TODO FP16
|
||||
val wordSize = if (half) 2 else 4
|
||||
val wordSizeInBits = wordSize * 8/*bits*/
|
||||
val sourceWidth = log2Ceil(numSourceIds)
|
||||
val dataWidth = numLanes * wordSizeInBits // TODO FP16
|
||||
val laneWidth = 4/*bytes*/ * 8/*bits*/
|
||||
val memWidth = numLanes * laneWidth
|
||||
val numFPRegBits = log2Ceil(numFPRegs)
|
||||
|
||||
val io = IO(new Bundle {
|
||||
@@ -47,11 +65,11 @@ class TensorCoreDecoupled(
|
||||
val last = Bool()
|
||||
val wid = UInt(numWarpBits.W)
|
||||
val rd = UInt(numFPRegBits.W)
|
||||
val data = Vec(numLanes, UInt((wordSizeInBits).W))
|
||||
val data = Vec(numLanes, UInt(laneWidth.W))
|
||||
})
|
||||
val respA = Flipped(Decoupled(new TensorMemResp(sourceWidth, dataWidth)))
|
||||
val respB = Flipped(Decoupled(new TensorMemResp(sourceWidth, dataWidth)))
|
||||
val respC = Input(UInt(dataWidth.W))
|
||||
val respA = Flipped(Decoupled(new TensorMemResp(sourceWidth, memWidth)))
|
||||
val respB = Flipped(Decoupled(new TensorMemResp(sourceWidth, memWidth)))
|
||||
val respC = Input(UInt(memWidth.W))
|
||||
val reqA = Decoupled(new TensorMemReq(sourceWidth))
|
||||
val reqB = Decoupled(new TensorMemReq(sourceWidth))
|
||||
val reqC = Output(Valid(UInt(numFPRegBits.W)))
|
||||
@@ -185,7 +203,8 @@ class TensorCoreDecoupled(
|
||||
val blockRow = set
|
||||
val blockCol = index
|
||||
val blockIndex = (blockRow << indexBits) + blockCol
|
||||
val blockSize = numLanes * wordSize
|
||||
val blockSize = numLanes * laneWidth
|
||||
require(blockSize == memWidth)
|
||||
val blockSizeBits = log2Ceil(blockSize)
|
||||
val byteOffset = blockIndex << blockSizeBits
|
||||
base + byteOffset
|
||||
@@ -222,8 +241,8 @@ class TensorCoreDecoupled(
|
||||
tagB.set := stateB.set
|
||||
tagB.index := stateB.index
|
||||
|
||||
val respATagged = Wire(Decoupled(new TensorMemRespWithTag(dataWidth)))
|
||||
val respBTagged = Wire(Decoupled(new TensorMemRespWithTag(dataWidth)))
|
||||
val respATagged = Wire(Decoupled(new TensorMemRespWithTag(memWidth)))
|
||||
val respBTagged = Wire(Decoupled(new TensorMemRespWithTag(memWidth)))
|
||||
Seq((io.reqA, (io.respA, respATagged)),
|
||||
(io.reqB, (io.respB, respBTagged))).zipWithIndex.foreach {
|
||||
case ((req, (resp, respTagged)), i) => {
|
||||
@@ -543,24 +562,32 @@ class TensorCoreDecoupled(
|
||||
require(tilingParams.mc * ncSubstep == numLanes,
|
||||
"substep tile size doesn't match writeback throughput")
|
||||
val dpus = Seq.fill(tilingParams.mc)(Seq.fill(ncSubstep)(
|
||||
Module(new TensorDotProductUnit(dim = 4, half = false))
|
||||
Module(new TensorDotProductUnit(
|
||||
dim = tilingParams.kc,
|
||||
half = half,
|
||||
))
|
||||
))
|
||||
|
||||
// reshape operands for easier routing to DPU
|
||||
def reshapeByFourWords(x: UInt): Seq[Seq[UInt]] = {
|
||||
// reshape UInt into a two-dimensional array where the innermost dimension
|
||||
// has `numWords` elements
|
||||
def reshapeByWords(x: UInt, wordSizeInBits: Int, numWords: Int)
|
||||
: Seq[Seq[UInt]] = {
|
||||
x.asBools.grouped(wordSizeInBits).map(VecInit(_).asUInt).toSeq
|
||||
.grouped(4/*k-dim*/).toSeq
|
||||
.grouped(numWords).toSeq
|
||||
}
|
||||
val operandADimensional = reshapeByFourWords(operandA)
|
||||
val operandADimensional =
|
||||
reshapeByWords(operandA, wordSizeInBits, tilingParams.kc)
|
||||
require(operandADimensional.length == tilingParams.mc &&
|
||||
operandADimensional(0).length == tilingParams.kc,
|
||||
"operand width doesn't agree with tiling parameter")
|
||||
val operandBDimensional = reshapeByFourWords(operandB)
|
||||
val operandBDimensional =
|
||||
reshapeByWords(operandB, wordSizeInBits, tilingParams.kc)
|
||||
require(operandBDimensional.length == ncSubstep &&
|
||||
operandBDimensional(0).length == tilingParams.kc,
|
||||
"operand width doesn't agree with tiling parameter")
|
||||
// note operand C is M-major
|
||||
val operandCDimensional = reshapeByFourWords(operandC)
|
||||
// note operand C is M-major, and always FP32
|
||||
val operandCDimensional =
|
||||
reshapeByWords(operandC, 4/*fp32*/ * 8/*bits*/, tilingParams.mc)
|
||||
require(operandCDimensional.length == ncSubstep &&
|
||||
operandCDimensional(0).length == tilingParams.mc,
|
||||
"operand width doesn't agree with tiling parameter")
|
||||
@@ -609,7 +636,7 @@ class TensorCoreDecoupled(
|
||||
val substep = UInt(1.W)
|
||||
}
|
||||
|
||||
val queueDepth = 5 // needs to be at least the DPU latency
|
||||
val queueDepth = (if (half) 6 else 5) // needs to be at least the DPU latency
|
||||
val tagQueue = Module(new Queue(new TensorComputeTag, queueDepth))
|
||||
tagQueue.io.enq.valid := dpuFire
|
||||
tagQueue.io.enq.bits.warp := operandATag.warp
|
||||
@@ -617,6 +644,8 @@ class TensorCoreDecoupled(
|
||||
tagQueue.io.enq.bits.step := stepCompute
|
||||
tagQueue.io.enq.bits.substep := substepCompute
|
||||
tagQueue.io.deq.ready := io.writeback.fire
|
||||
// this is not necessary for correctness, and might trigger when there's a
|
||||
// lot of writeback contention
|
||||
assert(tagQueue.io.enq.ready === true.B,
|
||||
"tag queue full, DPU operation might be throttled")
|
||||
assert(!dpuValid || tagQueue.io.deq.valid,
|
||||
@@ -727,7 +756,7 @@ class TensorCoreDecoupledTLImp(outer: TensorCoreDecoupledTL)
|
||||
require(outer.node.out.length == 2/*A and B*/)
|
||||
|
||||
val tensor = Module(new TensorCoreDecoupled(
|
||||
8, 8, outer.numSourceIds , TensorTilingParams()))
|
||||
8, 8, half = true, outer.numSourceIds))
|
||||
val wordSize = 4 // @cleanup: hardcoded
|
||||
|
||||
val zip = Seq((outer.node.out(0), tensor.io.reqA),
|
||||
|
||||
Reference in New Issue
Block a user