From 4dba0def01c738985972840d06909b9f95e38ddd Mon Sep 17 00:00:00 2001 From: Hansung Kim Date: Tue, 28 May 2024 16:41:44 -0700 Subject: [PATCH] Do proper recoding and boxing for FMA input --- src/main/scala/radiance/core/TensorDPU.scala | 37 ++++++++++++++++- src/test/scala/radiance/TensorDPUTest.scala | 43 +++++++++++++++++--- 2 files changed, 74 insertions(+), 6 deletions(-) diff --git a/src/main/scala/radiance/core/TensorDPU.scala b/src/main/scala/radiance/core/TensorDPU.scala index ce60990..42de021 100644 --- a/src/main/scala/radiance/core/TensorDPU.scala +++ b/src/main/scala/radiance/core/TensorDPU.scala @@ -5,7 +5,42 @@ package radiance.core import chisel3._ import chisel3.util._ -import freechips.rocketchip.rocket._ +import freechips.rocketchip.tile + +class DPUPipe extends Module with tile.HasFPUParameters { + val fLen = 32 + val minFLen = 32 + def xLen = 32 + + val io = IO(new Bundle { + val in = Flipped(Valid(new Bundle { + val a = Bits((fLen).W) + val b = Bits((fLen).W) + val c = Bits((fLen).W) + })) + val out = Valid(new Bundle { + val data = Bits((fLen+1).W) + }) + }) + + val t = tile.FType.S + + val in1 = recode(io.in.bits.a, S) + val in2 = recode(io.in.bits.b, S) + val in3 = recode(io.in.bits.c, S) + + val fma = Module(new MulAddRecFNPipe(2, t.exp, t.sig)) + fma.io.validin := io.in.valid + fma.io.op := 0.U // FIXME + fma.io.roundingMode := 0.U // FIXME + fma.io.detectTininess := hardfloat.consts.tininess_afterRounding + fma.io.a := unbox(in1, S, Some(tile.FType.S)) + fma.io.b := unbox(in2, S, Some(tile.FType.S)) + fma.io.c := unbox(in3, S, Some(tile.FType.S)) + + io.out.valid := fma.io.validout + io.out.bits.data := ieee(box(fma.io.out, S)) +} class MulAddRecFNPipe(latency: Int, expWidth: Int, sigWidth: Int) extends Module { require(latency <= 2) diff --git a/src/test/scala/radiance/TensorDPUTest.scala b/src/test/scala/radiance/TensorDPUTest.scala index 4ecaa9b..f90b717 100644 --- a/src/test/scala/radiance/TensorDPUTest.scala +++ b/src/test/scala/radiance/TensorDPUTest.scala @@ -5,13 +5,16 @@ import chisel3.stage.PrintFullStackTraceAnnotation import chisel3.util._ import chiseltest._ import chiseltest.simulator.VerilatorFlags +import org.chipsalliance.cde.config.Parameters +import freechips.rocketchip.tile import org.scalatest.flatspec.AnyFlatSpec class MulAddTest extends AnyFlatSpec with ChiselScalatestTester { behavior of "MulAddRecFNPipe" + val t = tile.FType.S it should "do basic arithmetic" in { - test(new MulAddRecFNPipe(2, 8, 23)) + test(new MulAddRecFNPipe(2, t.exp, t.sig)) // .withAnnotations(Seq(WriteVcdAnnotation)) { c => c.io.validin.poke(true.B) @@ -24,18 +27,48 @@ class MulAddTest extends AnyFlatSpec with ChiselScalatestTester { // 0: round to nearest, ties to even c.io.roundingMode.poke(0.U) c.io.detectTininess.poke(hardfloat.consts.tininess_beforeRounding) - c.io.a.poke(0x3f800000.U/*2.0*/) - c.io.b.poke(0x3f800000.U/*3.0*/) - c.io.c.poke(0x00000000.U/*0.0*/) + c.io.a.poke(0x3f800000.U) + c.io.b.poke(0x3f800000.U) + c.io.c.poke(0x00000000.U) c.clock.step() c.io.validin.poke(false.B) c.io.validout.expect(false.B) c.clock.step() c.io.validout.expect(true.B) - c.io.out.expect(0x40c00000.U/*6.0*/) + c.io.out.expect(0x40c00000.U) c.clock.step() c.io.validout.expect(false.B) } } } +class DPUPipeTest extends AnyFlatSpec with ChiselScalatestTester { + behavior of "DPUPipe" + + implicit val p: Parameters = Parameters.empty + + it should "pass" in { + test(new DPUPipe) + // .withAnnotations(Seq(WriteVcdAnnotation)) + { fma => + fma.io.in.valid.poke(true.B) + fma.io.in.bits.a.poke(0x40000000L.U(64.W)) + fma.io.in.bits.b.poke(0x40400000L.U(64.W)) + fma.io.in.bits.c.poke(0x3f800000L.U(64.W)) + fma.clock.step() + fma.io.in.valid.poke(true.B) + fma.io.in.bits.a.poke(0x40000000L.U(64.W)) + fma.io.in.bits.b.poke(0x3f800000L.U(64.W)) + fma.io.in.bits.c.poke(0x3f800000L.U(64.W)) + fma.clock.step() + fma.io.in.valid.poke(false.B) + fma.io.out.valid.expect(true.B) + fma.io.out.bits.data.expect(0x40e00000L.U) + fma.clock.step() + // pipelined back-to-back response + fma.io.out.valid.expect(true.B) + fma.io.out.bits.data.expect(0x40400000L.U) + } + } +} +