Add stall IO to dpu
This commit is contained in:
@@ -7,7 +7,8 @@ import chisel3._
|
||||
import chisel3.util._
|
||||
import freechips.rocketchip.tile
|
||||
|
||||
class DPUPipe extends Module with tile.HasFPUParameters {
|
||||
// Implements the four-element dot product (FEDP) unit in Volta Tensor Cores.
|
||||
class TensorDotProductUnit extends Module with tile.HasFPUParameters {
|
||||
val fLen = 32
|
||||
val minFLen = 32
|
||||
def xLen = 32
|
||||
@@ -19,6 +20,7 @@ class DPUPipe extends Module with tile.HasFPUParameters {
|
||||
val b = Vec(dotProductDim, Bits((fLen).W))
|
||||
val c = Bits((fLen).W)
|
||||
}))
|
||||
val stall = Input(Bool())
|
||||
val out = Valid(new Bundle {
|
||||
val data = Bits((fLen).W)
|
||||
})
|
||||
@@ -30,20 +32,12 @@ class DPUPipe extends Module with tile.HasFPUParameters {
|
||||
val in2 = io.in.bits.b.map(x => unbox(recode(x, S), S, Some(tile.FType.S)))
|
||||
val in3 = unbox(recode(io.in.bits.c, S), S, Some(tile.FType.S))
|
||||
|
||||
// val fma = Module(new MulAddRecFNPipe(2, t.exp, t.sig))
|
||||
// fma.io.validin := io.in.valid
|
||||
// fma.io.op := 0.U // FIXME
|
||||
// fma.io.roundingMode := hardfloat.consts.round_near_even
|
||||
// fma.io.detectTininess := hardfloat.consts.tininess_afterRounding
|
||||
// fma.io.a := unbox(in1, S, Some(tile.FType.S))
|
||||
// fma.io.b := unbox(in2, S, Some(tile.FType.S))
|
||||
// fma.io.c := unbox(in3, S, Some(tile.FType.S))
|
||||
|
||||
val dpu = Module(new DotProductPipe(dotProductDim, t.exp, t.sig))
|
||||
dpu.io.in.valid := io.in.valid
|
||||
dpu.io.in.bits.a := in1
|
||||
dpu.io.in.bits.b := in2
|
||||
dpu.io.in.bits.c := in3
|
||||
dpu.io.stall := io.stall
|
||||
|
||||
io.out.valid := dpu.io.out.valid
|
||||
io.out.bits.data := ieee(box(dpu.io.out.bits.data, S))
|
||||
@@ -63,6 +57,7 @@ class DotProductPipe(dim: Int, expWidth: Int, sigWidth: Int) extends Module {
|
||||
// val roundingMode = UInt(3.W)
|
||||
// val detectTininess = UInt(1.W)
|
||||
}))
|
||||
val stall = Input(Bool())
|
||||
val out = Valid(new Bundle {
|
||||
val data = Bits((recFLen).W)
|
||||
})
|
||||
@@ -70,7 +65,8 @@ class DotProductPipe(dim: Int, expWidth: Int, sigWidth: Int) extends Module {
|
||||
|
||||
val mul = Seq.fill(dim)(Module(new hardfloat.MulRecFN(expWidth, sigWidth)))
|
||||
mul.zipWithIndex.foreach { case (m, i) =>
|
||||
m.io.roundingMode := hardfloat.consts.round_near_even // consts.round_near_maxMag
|
||||
// FIXME: these settings are arbitrary
|
||||
m.io.roundingMode := hardfloat.consts.round_near_even
|
||||
m.io.detectTininess := hardfloat.consts.tininess_afterRounding
|
||||
m.io.a := io.in.bits.a(i)
|
||||
m.io.b := io.in.bits.b(i)
|
||||
@@ -79,7 +75,7 @@ class DotProductPipe(dim: Int, expWidth: Int, sigWidth: Int) extends Module {
|
||||
val mulStageOut = Pipe(io.in.valid, VecInit(mul.map(_.io.out)))
|
||||
val mulStageC = Pipe(io.in.valid, io.in.bits.c)
|
||||
|
||||
// mul stage end -------------------------------------------------------
|
||||
// mul stage end -------------------------------------------------------------
|
||||
|
||||
val add1 = Seq.fill(dim / 2)(Module(new hardfloat.AddRecFN(expWidth, sigWidth)))
|
||||
add1.zipWithIndex.foreach { case (a, i) =>
|
||||
@@ -93,7 +89,7 @@ class DotProductPipe(dim: Int, expWidth: Int, sigWidth: Int) extends Module {
|
||||
val add1StageOut = Pipe(mulStageOut.valid, VecInit(add1.map(_.io.out)))
|
||||
val add1StageC = Pipe(mulStageC)
|
||||
|
||||
// add1 stage end -----------------------------------------------------
|
||||
// add1 stage end ------------------------------------------------------------
|
||||
|
||||
val add2 = Module(new hardfloat.AddRecFN(expWidth, sigWidth))
|
||||
add2.io.subOp := 0.U // FIXME
|
||||
@@ -106,7 +102,7 @@ class DotProductPipe(dim: Int, expWidth: Int, sigWidth: Int) extends Module {
|
||||
val add2StageOut = Pipe(add1StageOut.valid, add2.io.out)
|
||||
val add2StageC = Pipe(add1StageC)
|
||||
|
||||
// add2 stage end -----------------------------------------------------
|
||||
// add2 stage end ------------------------------------------------------------
|
||||
|
||||
val acc = Module(new hardfloat.AddRecFN(expWidth, sigWidth))
|
||||
acc.io.subOp := 0.U // FIXME
|
||||
@@ -119,7 +115,7 @@ class DotProductPipe(dim: Int, expWidth: Int, sigWidth: Int) extends Module {
|
||||
io.out.bits.data := Pipe(add2StageOut.valid, acc.io.out).bits
|
||||
// FIXME: exception output ignored
|
||||
|
||||
// acc stage end -----------------------------------------------------
|
||||
// acc stage end -------------------------------------------------------------
|
||||
}
|
||||
|
||||
class MulAddRecFNPipe(latency: Int, expWidth: Int, sigWidth: Int) extends Module {
|
||||
|
||||
Reference in New Issue
Block a user