From 3b1ab4e10d4ec0dc698de763c34b1fc8e2a79f22 Mon Sep 17 00:00:00 2001
From: Hansung Kim
Date: Tue, 28 May 2024 18:27:56 -0700
Subject: [PATCH] Write four-element dpu without accumulation

---
 src/main/scala/radiance/core/TensorDPU.scala | 97 ++++++++++++++++----
 src/test/scala/radiance/TensorDPUTest.scala  | 38 ++++----
 2 files changed, 99 insertions(+), 36 deletions(-)

diff --git a/src/main/scala/radiance/core/TensorDPU.scala b/src/main/scala/radiance/core/TensorDPU.scala
index 42de021..90dd2ec 100644
--- a/src/main/scala/radiance/core/TensorDPU.scala
+++ b/src/main/scala/radiance/core/TensorDPU.scala
@@ -11,35 +11,98 @@ class DPUPipe extends Module with tile.HasFPUParameters {
   val fLen = 32
   val minFLen = 32
   def xLen = 32
+  val dotProductDim = 4
 
   val io = IO(new Bundle {
     val in = Flipped(Valid(new Bundle {
-      val a = Bits((fLen).W)
-      val b = Bits((fLen).W)
-      val c = Bits((fLen).W)
+      val a = Vec(dotProductDim, Bits((fLen).W))
+      val b = Vec(dotProductDim, Bits((fLen).W))
+      val c = Vec(dotProductDim, Bits((fLen).W))
     }))
     val out = Valid(new Bundle {
-      val data = Bits((fLen+1).W)
+      val data = Bits((fLen).W)
     })
   })
 
   val t = tile.FType.S
 
-  val in1 = recode(io.in.bits.a, S)
-  val in2 = recode(io.in.bits.b, S)
-  val in3 = recode(io.in.bits.c, S)
+  val in1 = io.in.bits.a.map(x => unbox(recode(x, S), S, Some(tile.FType.S)))
+  val in2 = io.in.bits.b.map(x => unbox(recode(x, S), S, Some(tile.FType.S)))
+  val in3 = io.in.bits.c.map(x => unbox(recode(x, S), S, Some(tile.FType.S)))
 
-  val fma = Module(new MulAddRecFNPipe(2, t.exp, t.sig))
-  fma.io.validin := io.in.valid
-  fma.io.op := 0.U // FIXME
-  fma.io.roundingMode := 0.U // FIXME
-  fma.io.detectTininess := hardfloat.consts.tininess_afterRounding
-  fma.io.a := unbox(in1, S, Some(tile.FType.S))
-  fma.io.b := unbox(in2, S, Some(tile.FType.S))
-  fma.io.c := unbox(in3, S, Some(tile.FType.S))
+  // val fma = Module(new MulAddRecFNPipe(2, t.exp, t.sig))
+  // fma.io.validin := io.in.valid
+  // fma.io.op := 0.U // FIXME
+  // fma.io.roundingMode := hardfloat.consts.round_near_even
+  // fma.io.detectTininess := hardfloat.consts.tininess_afterRounding
+  // fma.io.a := unbox(in1, S, Some(tile.FType.S))
+  // fma.io.b := unbox(in2, S, Some(tile.FType.S))
+  // fma.io.c := unbox(in3, S, Some(tile.FType.S))
 
-  io.out.valid := fma.io.validout
-  io.out.bits.data := ieee(box(fma.io.out, S))
+  val dpu = Module(new DotProductPipe(dotProductDim, t.exp, t.sig))
+  dpu.io.in.valid := io.in.valid
+  dpu.io.in.bits.a := in1
+  dpu.io.in.bits.b := in2
+  // FIXME: in3 unused
+
+  io.out.valid := dpu.io.out.valid
+  io.out.bits.data := ieee(box(dpu.io.out.bits.data, S))
+}
+
+class DotProductPipe(dim: Int, expWidth: Int, sigWidth: Int) extends Module {
+  require(dim == 4, "DPU currently only supports dimension 4")
+
+  val recFLen = expWidth + sigWidth + 1
+  val io = IO(new Bundle {
+    val in = Flipped(Valid(new Bundle {
+      val a = Vec(4, Bits((recFLen).W))
+      val b = Vec(4, Bits((recFLen).W))
+      // val roundingMode = UInt(3.W)
+      // val detectTininess = UInt(1.W)
+    }))
+    val out = Valid(new Bundle {
+      val data = Bits((recFLen).W)
+    })
+  })
+
+  val mul = Seq.fill(dim)(Module(new hardfloat.MulRecFN(expWidth, sigWidth)))
+  mul.zipWithIndex.foreach { case (m, i) =>
+    m.io.roundingMode := hardfloat.consts.round_near_even // consts.round_near_maxMag
+    m.io.detectTininess := hardfloat.consts.tininess_afterRounding
+    m.io.a := io.in.bits.a(i)
+    m.io.b := io.in.bits.b(i)
+  }
+
+  val mulStageOut = Pipe(io.in.valid, VecInit(mul.map(_.io.out)))
+
+  // mul stage end -------------------------------------------------------
+
+  val add1 = Seq.fill(dim / 2)(Module(new hardfloat.AddRecFN(expWidth, sigWidth)))
+  add1.zipWithIndex.foreach { case (a, i) =>
+    a.io.subOp := 0.U // FIXME
+    a.io.a := mulStageOut.bits(2 * i + 0)
+    a.io.b := mulStageOut.bits(2 * i + 1)
+    a.io.roundingMode := hardfloat.consts.round_near_even
+    a.io.detectTininess := hardfloat.consts.tininess_afterRounding
+  }
+
+  val add1StageOut = Pipe(mulStageOut.valid, VecInit(add1.map(_.io.out)))
+
+  // add stage 1 end -----------------------------------------------------
+
+  val add2 = Module(new hardfloat.AddRecFN(expWidth, sigWidth))
+  add2.io.subOp := 0.U // FIXME
+  assert(add1StageOut.bits.length == 2)
+  add2.io.a := add1StageOut.bits(0)
+  add2.io.b := add1StageOut.bits(1)
+  add2.io.roundingMode := hardfloat.consts.round_near_even
+  add2.io.detectTininess := hardfloat.consts.tininess_afterRounding
+
+  io.out.valid := Pipe(add1StageOut.valid, false.B).valid
+  io.out.bits.data := Pipe(add1StageOut.valid, add2.io.out).bits
+  // FIXME: exception output ignored
+
+  // add stage 2 end -----------------------------------------------------
 }
 
 class MulAddRecFNPipe(latency: Int, expWidth: Int, sigWidth: Int) extends Module {
diff --git a/src/test/scala/radiance/TensorDPUTest.scala b/src/test/scala/radiance/TensorDPUTest.scala
index f90b717..5993edf 100644
--- a/src/test/scala/radiance/TensorDPUTest.scala
+++ b/src/test/scala/radiance/TensorDPUTest.scala
@@ -1,7 +1,6 @@
 package radiance.core
 
 import chisel3._
-import chisel3.stage.PrintFullStackTraceAnnotation
 import chisel3.util._
 import chiseltest._
 import chiseltest.simulator.VerilatorFlags
@@ -49,25 +48,26 @@ class DPUPipeTest extends AnyFlatSpec with ChiselScalatestTester {
 
   it should "pass" in {
     test(new DPUPipe)
+      // .withAnnotations(Seq(VerilatorBackendAnnotation, WriteFstAnnotation))
       // .withAnnotations(Seq(WriteVcdAnnotation))
-    { fma =>
-      fma.io.in.valid.poke(true.B)
-      fma.io.in.bits.a.poke(0x40000000L.U(64.W))
-      fma.io.in.bits.b.poke(0x40400000L.U(64.W))
-      fma.io.in.bits.c.poke(0x3f800000L.U(64.W))
-      fma.clock.step()
-      fma.io.in.valid.poke(true.B)
-      fma.io.in.bits.a.poke(0x40000000L.U(64.W))
-      fma.io.in.bits.b.poke(0x3f800000L.U(64.W))
-      fma.io.in.bits.c.poke(0x3f800000L.U(64.W))
-      fma.clock.step()
-      fma.io.in.valid.poke(false.B)
-      fma.io.out.valid.expect(true.B)
-      fma.io.out.bits.data.expect(0x40e00000L.U)
-      fma.clock.step()
-      // pipelined back-to-back response
-      fma.io.out.valid.expect(true.B)
-      fma.io.out.bits.data.expect(0x40400000L.U)
+    { c =>
+      c.io.in.valid.poke(true.B)
+      c.io.in.bits.a(0).poke(0x40000000L.U(64.W)) // 0x40000000 is 2.0f; every a/b lane is 2.0f
+      c.io.in.bits.a(1).poke(0x40000000L.U(64.W))
+      c.io.in.bits.a(2).poke(0x40000000L.U(64.W))
+      c.io.in.bits.a(3).poke(0x40000000L.U(64.W))
+      c.io.in.bits.b(0).poke(0x40000000L.U(64.W))
+      c.io.in.bits.b(1).poke(0x40000000L.U(64.W))
+      c.io.in.bits.b(2).poke(0x40000000L.U(64.W))
+      c.io.in.bits.b(3).poke(0x40000000L.U(64.W))
+      c.clock.step()
+      c.io.in.valid.poke(false.B)
+      c.clock.step()
+      c.clock.step() // third step: result has passed the mul, add1 and add2 Pipe stages
+      c.io.out.valid.expect(true.B)
+      c.io.out.bits.data.expect(0x41800000L.U) // 16.0f = 4 * (2.0f * 2.0f); c is not accumulated
+      c.clock.step()
+      c.io.out.valid.expect(false.B)
     }
   }
 }