Add stall IO to dpu

This commit is contained in:
Hansung Kim
2024-05-28 21:18:19 -07:00
parent 907150e51c
commit 793db0e29d

View File

@@ -7,7 +7,8 @@ import chisel3._
import chisel3.util._
import freechips.rocketchip.tile
class DPUPipe extends Module with tile.HasFPUParameters {
// Implements the four-element dot product (FEDP) unit in Volta Tensor Cores.
class TensorDotProductUnit extends Module with tile.HasFPUParameters {
val fLen = 32
val minFLen = 32
def xLen = 32
@@ -19,6 +20,7 @@ class DPUPipe extends Module with tile.HasFPUParameters {
val b = Vec(dotProductDim, Bits((fLen).W))
val c = Bits((fLen).W)
}))
val stall = Input(Bool())
val out = Valid(new Bundle {
val data = Bits((fLen).W)
})
@@ -30,20 +32,12 @@ class DPUPipe extends Module with tile.HasFPUParameters {
val in2 = io.in.bits.b.map(x => unbox(recode(x, S), S, Some(tile.FType.S)))
val in3 = unbox(recode(io.in.bits.c, S), S, Some(tile.FType.S))
// val fma = Module(new MulAddRecFNPipe(2, t.exp, t.sig))
// fma.io.validin := io.in.valid
// fma.io.op := 0.U // FIXME
// fma.io.roundingMode := hardfloat.consts.round_near_even
// fma.io.detectTininess := hardfloat.consts.tininess_afterRounding
// fma.io.a := unbox(in1, S, Some(tile.FType.S))
// fma.io.b := unbox(in2, S, Some(tile.FType.S))
// fma.io.c := unbox(in3, S, Some(tile.FType.S))
val dpu = Module(new DotProductPipe(dotProductDim, t.exp, t.sig))
dpu.io.in.valid := io.in.valid
dpu.io.in.bits.a := in1
dpu.io.in.bits.b := in2
dpu.io.in.bits.c := in3
dpu.io.stall := io.stall
io.out.valid := dpu.io.out.valid
io.out.bits.data := ieee(box(dpu.io.out.bits.data, S))
@@ -63,6 +57,7 @@ class DotProductPipe(dim: Int, expWidth: Int, sigWidth: Int) extends Module {
// val roundingMode = UInt(3.W)
// val detectTininess = UInt(1.W)
}))
val stall = Input(Bool())
val out = Valid(new Bundle {
val data = Bits((recFLen).W)
})
@@ -70,7 +65,8 @@ class DotProductPipe(dim: Int, expWidth: Int, sigWidth: Int) extends Module {
val mul = Seq.fill(dim)(Module(new hardfloat.MulRecFN(expWidth, sigWidth)))
mul.zipWithIndex.foreach { case (m, i) =>
m.io.roundingMode := hardfloat.consts.round_near_even // consts.round_near_maxMag
// FIXME: these settings are arbitrary
m.io.roundingMode := hardfloat.consts.round_near_even
m.io.detectTininess := hardfloat.consts.tininess_afterRounding
m.io.a := io.in.bits.a(i)
m.io.b := io.in.bits.b(i)
@@ -79,7 +75,7 @@ class DotProductPipe(dim: Int, expWidth: Int, sigWidth: Int) extends Module {
val mulStageOut = Pipe(io.in.valid, VecInit(mul.map(_.io.out)))
val mulStageC = Pipe(io.in.valid, io.in.bits.c)
// mul stage end -------------------------------------------------------
// mul stage end -------------------------------------------------------------
val add1 = Seq.fill(dim / 2)(Module(new hardfloat.AddRecFN(expWidth, sigWidth)))
add1.zipWithIndex.foreach { case (a, i) =>
@@ -93,7 +89,7 @@ class DotProductPipe(dim: Int, expWidth: Int, sigWidth: Int) extends Module {
val add1StageOut = Pipe(mulStageOut.valid, VecInit(add1.map(_.io.out)))
val add1StageC = Pipe(mulStageC)
// add1 stage end -----------------------------------------------------
// add1 stage end ------------------------------------------------------------
val add2 = Module(new hardfloat.AddRecFN(expWidth, sigWidth))
add2.io.subOp := 0.U // FIXME
@@ -106,7 +102,7 @@ class DotProductPipe(dim: Int, expWidth: Int, sigWidth: Int) extends Module {
val add2StageOut = Pipe(add1StageOut.valid, add2.io.out)
val add2StageC = Pipe(add1StageC)
// add2 stage end -----------------------------------------------------
// add2 stage end ------------------------------------------------------------
val acc = Module(new hardfloat.AddRecFN(expWidth, sigWidth))
acc.io.subOp := 0.U // FIXME
@@ -119,7 +115,7 @@ class DotProductPipe(dim: Int, expWidth: Int, sigWidth: Int) extends Module {
io.out.bits.data := Pipe(add2StageOut.valid, acc.io.out).bits
// FIXME: exception output ignored
// acc stage end -----------------------------------------------------
// acc stage end -------------------------------------------------------------
}
class MulAddRecFNPipe(latency: Int, expWidth: Int, sigWidth: Int) extends Module {