tensor: Support FP16 in TensorCoreDecoupled

This commit is contained in:
Hansung Kim
2024-10-25 22:26:04 -07:00
parent eed821eda6
commit 543eb2feb4

View File

@@ -15,28 +15,46 @@ import radiance.memory.SourceGenerator
case class TensorTilingParams(
// Dimension of the SMEM tile
m: Int = 16,
n: Int = 16,
k: Int = 16,
m: Int,
n: Int,
k: Int,
// Dimension of the compute tile. This is determined by the number of MAC
// units
mc: Int = 4,
nc: Int = 4,
kc: Int = 4
mc: Int,
nc: Int,
kc: Int,
)
object TensorTilingParams {
def fp16: TensorTilingParams = {
TensorTilingParams (
m = 16, n = 16, k = 32,
mc = 4, nc = 4, kc = 8
)
}
def fp32: TensorTilingParams = {
TensorTilingParams (
m = 16, n = 16, k = 16,
mc = 4, nc = 4, kc = 4
)
}
}
class TensorCoreDecoupled(
val numWarps: Int,
val numLanes: Int,
val numSourceIds: Int,
val tilingParams: TensorTilingParams,
val half: Boolean, // input datatype is FP16 if true, FP32 if false
val numSourceIds: Int = 16,
val numFPRegs: Int = 32
) extends Module {
val tilingParams =
if (half) TensorTilingParams.fp16 else TensorTilingParams.fp32
val numWarpBits = log2Ceil(numWarps)
val wordSize = 4 // TODO FP16
val wordSizeInBits = wordSize * 8 // TODO FP16
val wordSize = if (half) 2 else 4
val wordSizeInBits = wordSize * 8/*bits*/
val sourceWidth = log2Ceil(numSourceIds)
val dataWidth = numLanes * wordSizeInBits // TODO FP16
val laneWidth = 4/*bytes*/ * 8/*bits*/
val memWidth = numLanes * laneWidth
val numFPRegBits = log2Ceil(numFPRegs)
val io = IO(new Bundle {
@@ -47,11 +65,11 @@ class TensorCoreDecoupled(
val last = Bool()
val wid = UInt(numWarpBits.W)
val rd = UInt(numFPRegBits.W)
val data = Vec(numLanes, UInt((wordSizeInBits).W))
val data = Vec(numLanes, UInt(laneWidth.W))
})
val respA = Flipped(Decoupled(new TensorMemResp(sourceWidth, dataWidth)))
val respB = Flipped(Decoupled(new TensorMemResp(sourceWidth, dataWidth)))
val respC = Input(UInt(dataWidth.W))
val respA = Flipped(Decoupled(new TensorMemResp(sourceWidth, memWidth)))
val respB = Flipped(Decoupled(new TensorMemResp(sourceWidth, memWidth)))
val respC = Input(UInt(memWidth.W))
val reqA = Decoupled(new TensorMemReq(sourceWidth))
val reqB = Decoupled(new TensorMemReq(sourceWidth))
val reqC = Output(Valid(UInt(numFPRegBits.W)))
@@ -185,7 +203,8 @@ class TensorCoreDecoupled(
val blockRow = set
val blockCol = index
val blockIndex = (blockRow << indexBits) + blockCol
val blockSize = numLanes * wordSize
val blockSize = numLanes * laneWidth
require(blockSize == memWidth)
val blockSizeBits = log2Ceil(blockSize)
val byteOffset = blockIndex << blockSizeBits
base + byteOffset
@@ -222,8 +241,8 @@ class TensorCoreDecoupled(
tagB.set := stateB.set
tagB.index := stateB.index
val respATagged = Wire(Decoupled(new TensorMemRespWithTag(dataWidth)))
val respBTagged = Wire(Decoupled(new TensorMemRespWithTag(dataWidth)))
val respATagged = Wire(Decoupled(new TensorMemRespWithTag(memWidth)))
val respBTagged = Wire(Decoupled(new TensorMemRespWithTag(memWidth)))
Seq((io.reqA, (io.respA, respATagged)),
(io.reqB, (io.respB, respBTagged))).zipWithIndex.foreach {
case ((req, (resp, respTagged)), i) => {
@@ -543,24 +562,32 @@ class TensorCoreDecoupled(
require(tilingParams.mc * ncSubstep == numLanes,
"substep tile size doesn't match writeback throughput")
val dpus = Seq.fill(tilingParams.mc)(Seq.fill(ncSubstep)(
Module(new TensorDotProductUnit(dim = 4, half = false))
Module(new TensorDotProductUnit(
dim = tilingParams.kc,
half = half,
))
))
// reshape operands for easier routing to DPU
def reshapeByFourWords(x: UInt): Seq[Seq[UInt]] = {
// reshape UInt into a two-dimensional array where the innermost dimension
// has `numWords` elements
def reshapeByWords(x: UInt, wordSizeInBits: Int, numWords: Int)
: Seq[Seq[UInt]] = {
x.asBools.grouped(wordSizeInBits).map(VecInit(_).asUInt).toSeq
.grouped(4/*k-dim*/).toSeq
.grouped(numWords).toSeq
}
val operandADimensional = reshapeByFourWords(operandA)
val operandADimensional =
reshapeByWords(operandA, wordSizeInBits, tilingParams.kc)
require(operandADimensional.length == tilingParams.mc &&
operandADimensional(0).length == tilingParams.kc,
"operand width doesn't agree with tiling parameter")
val operandBDimensional = reshapeByFourWords(operandB)
val operandBDimensional =
reshapeByWords(operandB, wordSizeInBits, tilingParams.kc)
require(operandBDimensional.length == ncSubstep &&
operandBDimensional(0).length == tilingParams.kc,
"operand width doesn't agree with tiling parameter")
// note operand C is M-major
val operandCDimensional = reshapeByFourWords(operandC)
// note operand C is M-major, and always FP32
val operandCDimensional =
reshapeByWords(operandC, 4/*fp32*/ * 8/*bits*/, tilingParams.mc)
require(operandCDimensional.length == ncSubstep &&
operandCDimensional(0).length == tilingParams.mc,
"operand width doesn't agree with tiling parameter")
@@ -609,7 +636,7 @@ class TensorCoreDecoupled(
val substep = UInt(1.W)
}
val queueDepth = 5 // needs to be at least the DPU latency
val queueDepth = (if (half) 6 else 5) // needs to be at least the DPU latency
val tagQueue = Module(new Queue(new TensorComputeTag, queueDepth))
tagQueue.io.enq.valid := dpuFire
tagQueue.io.enq.bits.warp := operandATag.warp
@@ -617,6 +644,8 @@ class TensorCoreDecoupled(
tagQueue.io.enq.bits.step := stepCompute
tagQueue.io.enq.bits.substep := substepCompute
tagQueue.io.deq.ready := io.writeback.fire
// this is not necessary for correctness, and might trigger when there's a
// lot of writeback contention
assert(tagQueue.io.enq.ready === true.B,
"tag queue full, DPU operation might be throttled")
assert(!dpuValid || tagQueue.io.deq.valid,
@@ -727,7 +756,7 @@ class TensorCoreDecoupledTLImp(outer: TensorCoreDecoupledTL)
require(outer.node.out.length == 2/*A and B*/)
val tensor = Module(new TensorCoreDecoupled(
8, 8, outer.numSourceIds , TensorTilingParams()))
8, 8, half = true, outer.numSourceIds))
val wordSize = 4 // @cleanup: hardcoded
val zip = Seq((outer.node.out(0), tensor.io.reqA),