From 93bf7895bee4fe866ede244e91da9514bb321087 Mon Sep 17 00:00:00 2001 From: edwardcwang Date: Thu, 26 Apr 2018 10:33:55 -0700 Subject: [PATCH] Fix corner case in compiling a small mem using a large lib (#32) * Refactor bit pairs calculation into a separate function * Minor clarifications * Clarify MacroCompilerSpec helpers * Add SmallTagArrayTest test * Fix corner case in compiling a small mem using a large lib --- macros/src/main/scala/MacroCompiler.scala | 99 +++++++++++++------ macros/src/test/scala/MacroCompilerSpec.scala | 89 ++++++++++------- macros/src/test/scala/SpecificExamples.scala | 34 +++++++ 3 files changed, 155 insertions(+), 67 deletions(-) diff --git a/macros/src/main/scala/MacroCompiler.scala b/macros/src/main/scala/MacroCompiler.scala index fac8e309..ad38d344 100644 --- a/macros/src/main/scala/MacroCompiler.scala +++ b/macros/src/main/scala/MacroCompiler.scala @@ -102,17 +102,20 @@ class MacroCompilerPass(mems: Option[Seq[Macro]], }) } - def compile(mem: Macro, lib: Macro): Option[(Module, ExtModule)] = { + /** + * Calculate bit pairs. + * This is a list of submemories by width. + * The tuples are (lsb, msb) inclusive. + * Example: (0, 7) and (8, 15) might be a split for a width=16 memory into two width=8 target memories. + * Another example: (0, 3), (4, 7), (8, 11) may be a split for a width-12 memory into 3 width-4 target memories. + * + * @param mem Memory to compile + * @param lib Lib to compile with + * @return Bit pairs or empty list if there was an error. + */ + private def calculateBitPairs(mem: Macro, lib: Macro): Seq[(BigInt, BigInt)] = { val pairedPorts = mem.sortedPorts zip lib.sortedPorts - // Width mapping - - /** - * This is a list of submemories by width. - * The tuples are (lsb, msb) inclusive. - * e.g. (0, 7) and (8, 15) might be a split for a width=16 memory into two - * width=8 memories. - */ val bitPairs = ArrayBuffer[(BigInt, BigInt)]() var currentLSB: BigInt = 0 @@ -133,7 +136,7 @@ class MacroCompilerPass(mems: Option[Seq[Macro]], // Helper function to check if it's time to split memories. // @param effectiveLibWidth Split memory when we have this many bits. def splitMemory(effectiveLibWidth: Int): Unit = { - assert (!alreadySplit) + assert(!alreadySplit) if (bitsInCurrentMem == effectiveLibWidth) { bitPairCandidates += ((currentLSB, memBit - 1)) @@ -142,8 +145,8 @@ class MacroCompilerPass(mems: Option[Seq[Macro]], } // Make sure we don't have a maskGran larger than the width of the memory. - assert (memPort.src.effectiveMaskGran <= memPort.src.width.get) - assert (libPort.src.effectiveMaskGran <= libPort.src.width.get) + assert(memPort.src.effectiveMaskGran <= memPort.src.width.get) + assert(libPort.src.effectiveMaskGran <= libPort.src.width.get) val libWidth = libPort.src.width.get @@ -182,8 +185,8 @@ class MacroCompilerPass(mems: Option[Seq[Macro]], splitMemory(memMask.get) } else { // e.g. mem mask = 13, lib width = 8 - System.err.println(s"Unmasked target memory: unaligned mem maskGran ${p} with lib (${lib.src.name}) width ${libPort.src.width.get} not supported") - return None + System.err.println(s"Unmasked target memory: unaligned mem maskGran $p with lib (${lib.src.name}) width ${libPort.src.width.get} not supported") + return Seq() } } } @@ -199,8 +202,8 @@ class MacroCompilerPass(mems: Option[Seq[Macro]], // Mem maskGran is a multiple of lib maskGran, carry on as normal. splitMemory(libWidth) } else { - System.err.println(s"Mem maskGran ${m} is not a multiple of lib maskGran ${l}: currently not supported") - return None + System.err.println(s"Mem maskGran $m is not a multiple of lib maskGran $l: currently not supported") + return Seq() } } else { // m < l // Lib maskGran > mem maskGran. @@ -218,8 +221,8 @@ class MacroCompilerPass(mems: Option[Seq[Macro]], // of treating it as simply a width 4 (!!!) memory. // This would require a major refactor though. } else { - System.err.println(s"Lib maskGran ${m} is not a multiple of mem maskGran ${l}: currently not supported") - return None + System.err.println(s"Lib maskGran $m is not a multiple of mem maskGran $l: currently not supported") + return Seq() } } } @@ -228,7 +231,7 @@ class MacroCompilerPass(mems: Option[Seq[Macro]], // Choose an actual bit pair to add. // We'll have to choose the smallest one (e.g. unmasked read port might be more tolerant of a bigger split than the masked write port). - if (bitPairCandidates.length == 0) { + if (bitPairCandidates.isEmpty) { // No pair needed to split, just continue } else { val bestPair = bitPairCandidates.reduceLeft((leftPair, rightPair) => { @@ -240,7 +243,22 @@ class MacroCompilerPass(mems: Option[Seq[Macro]], } // Add in the last chunk if there are any leftovers bitPairs += ((currentLSB, mem.src.width.toInt - 1)) - // Check bit pairs + + bitPairs.toSeq + } + + def compile(mem: Macro, lib: Macro): Option[(Module, ExtModule)] = { + assert(mem.sortedPorts.lengthCompare(lib.sortedPorts.length) == 0, + "mem and lib should have an equal number of ports") + val pairedPorts = mem.sortedPorts zip lib.sortedPorts + + // Width mapping. See calculateBitPairs. + val bitPairs: Seq[(BigInt, BigInt)] = calculateBitPairs(mem, lib) + if (bitPairs.isEmpty) { + System.err.println("Error occurred during bitPairs calculations (bitPairs is empty).") + return None + } + // Check bit pairs. checkBitPairs(bitPairs) // Depth mapping @@ -278,8 +296,9 @@ class MacroCompilerPass(mems: Option[Seq[Macro]], for ((off, i) <- (0 until mem.src.depth by lib.src.depth).zipWithIndex) { for (j <- bitPairs.indices) { val name = s"mem_${i}_${j}" + // Create the instance. stmts += WDefInstance(NoInfo, name, lib.src.name, lib.tpe) - // connect extra ports + // Connect extra ports of the lib. stmts ++= lib.extraPorts map { case (portName, portValue) => Connect(NoInfo, WSubField(WRef(name), portName), portValue) } @@ -383,14 +402,29 @@ class MacroCompilerPass(mems: Option[Seq[Macro]], } else { require(isPowerOfTwo(libPort.src.effectiveMaskGran), "only powers of two masks supported for now") - val effectiveLibWidth = if (memPort.src.maskGran.get < libPort.src.effectiveMaskGran) memPort.src.maskGran.get else libPort.src.width.get + // How much of this lib's width we are effectively using. + // If we have a mem maskGran less than the lib's maskGran, we'll have to take the smaller maskGran. + // Example: if we have a lib whose maskGran is 8 but our mem's maskGran is 4. + // The other case is if we're using a larger lib than mem. + val usingLessThanLibMaskGran = (memPort.src.maskGran.get < libPort.src.effectiveMaskGran) + val effectiveLibWidth = if (usingLessThanLibMaskGran) + memPort.src.maskGran.get + else + libPort.src.width.get + cat(((0 until libPort.src.width.get by libPort.src.effectiveMaskGran) map (i => { - if (memPort.src.maskGran.get < libPort.src.effectiveMaskGran && i >= effectiveLibWidth) { + if (usingLessThanLibMaskGran && i >= effectiveLibWidth) { // If the memMaskGran is smaller than the lib's gran, then // zero out the upper bits. zero } else { - bits(WRef(mem), (low + i) / memPort.src.effectiveMaskGran) + if (i >= memPort.src.width.get) { + // If our bit is larger than the whole width of the mem, just zero out the upper bits. + zero + } else { + // Pick the appropriate bit from the mem mask. + bits(WRef(mem), (low + i) / memPort.src.effectiveMaskGran) + } } })).reverse) } @@ -589,9 +623,11 @@ class MacroCompilerTransform extends Transform { // FIXME: Use firrtl.LowerFirrtlOptimizations class MacroCompilerOptimizations extends SeqTransform { - def inputForm = LowForm - def outputForm = LowForm - def transforms = Seq( + def inputForm: CircuitForm = LowForm + + def outputForm: CircuitForm = LowForm + + def transforms: Seq[Transform] = Seq( passes.RemoveValidIf, new firrtl.transforms.ConstantPropagation, passes.memlib.VerilogMemDelays, @@ -602,11 +638,12 @@ class MacroCompilerOptimizations extends SeqTransform { } class MacroCompiler extends Compiler { - def emitter = new VerilogEmitter - def transforms = + def emitter: Emitter = new VerilogEmitter + + def transforms: Seq[Transform] = Seq(new MacroCompilerTransform) ++ - getLoweringTransforms(firrtl.HighForm, firrtl.LowForm) ++ - Seq(new MacroCompilerOptimizations) + getLoweringTransforms(firrtl.HighForm, firrtl.LowForm) ++ + Seq(new MacroCompilerOptimizations) } object MacroCompiler extends App { diff --git a/macros/src/test/scala/MacroCompilerSpec.scala b/macros/src/test/scala/MacroCompilerSpec.scala index ade8d6ae..40c613ed 100644 --- a/macros/src/test/scala/MacroCompilerSpec.scala +++ b/macros/src/test/scala/MacroCompilerSpec.scala @@ -6,6 +6,8 @@ import firrtl.Parser.parse import firrtl.Utils.ceilLog2 import java.io.{File, StringWriter} +import mdf.macrolib.SRAMMacro + abstract class MacroCompilerSpec extends org.scalatest.FlatSpec with org.scalatest.Matchers { import scala.language.implicitConversions implicit def String2SomeString(i: String): Option[String] = Some(i) @@ -228,7 +230,7 @@ trait HasSimpleTestGenerator { // generator. def generatorType: String = this.getClass.getSimpleName - require (memDepth >= libDepth) + //require (memDepth >= libDepth) // Convenience variables to check if a mask exists. val memHasMask = memMaskGran != None @@ -258,11 +260,14 @@ trait HasSimpleTestGenerator { def generateLibSRAM() = generateSRAM(lib_name, libPortPrefix, libWidth, libDepth, libMaskGran, extraPorts) def generateMemSRAM() = generateSRAM(mem_name, memPortPrefix, memWidth, memDepth, memMaskGran) - val libSRAM = generateLibSRAM - val memSRAM = generateMemSRAM + def libSRAM = generateLibSRAM + def memSRAM = generateMemSRAM - writeToLib(lib, Seq(libSRAM)) - writeToMem(mem, Seq(memSRAM)) + def libSRAMs: Seq[SRAMMacro] = Seq(libSRAM) + def memSRAMs: Seq[SRAMMacro] = Seq(memSRAM) + + writeToLib(lib, libSRAMs) + writeToMem(mem, memSRAMs) // For masks, width it's a bit tricky since we have to consider cases like // memMaskGran = 4 and libMaskGran = 8. @@ -321,41 +326,52 @@ trait HasSimpleTestGenerator { } /** Helper function to generate a port. - * @param prefix Memory port prefix (e.g. "x" for ports like "x_clk") - * @param addrWidth Address port width - * @param width data width - * @param write Has a write port? - * @param writeEnable Has a write enable port? - * @param read Has a read port? - * @param readEnable Has a read enable port? - * @param mask Mask granularity (# bits) of the port or None. */ - def generatePort(prefix: String, addrWidth: Int, width: Int, write: Boolean, writeEnable: Boolean, read: Boolean, readEnable: Boolean, mask: Option[Int]): String = { - val readStr = if (read) s"output ${prefix}_dout : UInt<$width>" else "" - val writeStr = if (write) s"input ${prefix}_din : UInt<$width>" else "" - val readEnableStr = if (readEnable) s"input ${prefix}_read_en : UInt<1>" else "" - val writeEnableStr = if (writeEnable) s"input ${prefix}_write_en : UInt<1>" else "" + * + * @param prefix Memory port prefix (e.g. "x" for ports like "x_clk") + * @param addrWidth Address port width + * @param width data width + * @param write Has a write port? + * @param writeEnable Has a write enable port? + * @param read Has a read port? + * @param readEnable Has a read enable port? + * @param mask Mask granularity (# bits) of the port or None. + * @param extraPorts Extra ports (name, # bits) + */ + def generatePort(prefix: String, addrWidth: Int, width: Int, write: Boolean, writeEnable: Boolean, read: Boolean, readEnable: Boolean, mask: Option[Int], extraPorts: Seq[(String, Int)] = Seq()): String = { + val realPrefix = if (prefix == "") "" else prefix + "_" + + val readStr = if (read) s"output ${realPrefix}dout : UInt<$width>" else "" + val writeStr = if (write) s"input ${realPrefix}din : UInt<$width>" else "" + val readEnableStr = if (readEnable) s"input ${realPrefix}read_en : UInt<1>" else "" + val writeEnableStr = if (writeEnable) s"input ${realPrefix}write_en : UInt<1>" else "" val maskStr = mask match { - case Some(maskBits: Int) => s"input ${prefix}_mask : UInt<${maskBits}>" + case Some(maskBits: Int) => s"input ${realPrefix}mask : UInt<$maskBits>" case _ => "" } -s""" - input ${prefix}_clk : Clock - input ${prefix}_addr : UInt<$addrWidth> - ${writeStr} - ${readStr} - ${readEnableStr} - ${writeEnableStr} - ${maskStr} -""" + val extraPortsStr = extraPorts.map { case (name, bits) => s" input $name : UInt<$bits>" }.mkString("\n") + s""" + input ${realPrefix}clk : Clock + input ${realPrefix}addr : UInt<$addrWidth> + $writeStr + $readStr + $readEnableStr + $writeEnableStr + $maskStr +$extraPortsStr + """ } - /** Helper function to generate a RW footer port. - * @param prefix Memory port prefix (e.g. "x" for ports like "x_clk") - * @param readEnable Has a read enable port? - * @param mask Mask granularity (# bits) of the port or None. */ - def generateReadWriteFooterPort(prefix: String, readEnable: Boolean, mask: Option[Int]): String = { + /** + * Helper function to generate a RW footer port. + * + * @param prefix Memory port prefix (e.g. "x" for ports like "x_clk") + * @param readEnable Has a read enable port? + * @param mask Mask granularity (# bits) of the port or None. + * @param extraPorts Extra ports (name, # bits) + */ + def generateReadWriteFooterPort(prefix: String, readEnable: Boolean, mask: Option[Int], extraPorts: Seq[(String, Int)] = Seq()): String = { generatePort(prefix, lib_addr_width, libWidth, - write=true, writeEnable=true, read=true, readEnable=readEnable, mask) + write = true, writeEnable = true, read = true, readEnable = readEnable, mask = mask, extraPorts = extraPorts) } /** Helper function to generate a RW header port. @@ -385,8 +401,9 @@ ${generateHeaderPorts} // Generate the target memory ports. def generateFooterPorts(): String = { - require (libSRAM.ports.size == 1, "Footer generator only supports single RW port mem") - generateReadWriteFooterPort(libPortPrefix, libSRAM.ports(0).readEnable.isDefined, if (libHasMask) Some(libMaskBits) else None) + require(libSRAM.ports.size == 1, "Footer generator only supports single RW port mem") + generateReadWriteFooterPort(libPortPrefix, libSRAM.ports(0).readEnable.isDefined, + if (libHasMask) Some(libMaskBits) else None, extraPorts.map(p => (p.name, p.width))) } // Generate the footer (contains the target memory extmodule declaration by default). diff --git a/macros/src/test/scala/SpecificExamples.scala b/macros/src/test/scala/SpecificExamples.scala index 338569d6..2ca1ddf0 100644 --- a/macros/src/test/scala/SpecificExamples.scala +++ b/macros/src/test/scala/SpecificExamples.scala @@ -1,3 +1,4 @@ +// See LICENSE for license details. package barstools.macros import mdf.macrolib._ @@ -1232,6 +1233,39 @@ circuit smem_0_ext : compileExecuteAndTest(mem, lib, v, output) } +class SmallTagArrayTest extends MacroCompilerSpec with HasSRAMGenerator with HasSimpleTestGenerator { + // Test that mapping a smaller memory using a larger lib can still work. + override def memWidth: Int = 26 + override def memDepth: Int = 2 + override def memMaskGran: Option[Int] = Some(26) + override def memPortPrefix: String = "" + + override def libWidth: Int = 32 + override def libDepth: Int = 64 + override def libMaskGran: Option[Int] = Some(1) + override def libPortPrefix: String = "" + + override def extraPorts: Seq[MacroExtraPort] = Seq( + MacroExtraPort(name = "must_be_one", portType = Constant, width = 1, value = 1) + ) + + override def generateBody(): String = + s""" + | inst mem_0_0 of $lib_name + | mem_0_0.must_be_one <= UInt<1>("h1") + | mem_0_0.clk <= clk + | mem_0_0.addr <= addr + | node dout_0_0 = bits(mem_0_0.dout, 25, 0) + | mem_0_0.din <= bits(din, 25, 0) + | mem_0_0.mask <= cat(UInt<1>("h0"), cat(UInt<1>("h0"), cat(UInt<1>("h0"), cat(UInt<1>("h0"), cat(UInt<1>("h0"), cat(UInt<1>("h0"), cat(bits(mask, 0, 0), cat(bits(mask, 0, 0), cat(bits(mask, 0, 0), cat(bits(mask, 0, 0), cat(bits(mask, 0, 0), cat(bits(mask, 0, 0), cat(bits(mask, 0, 0), cat(bits(mask, 0, 0), cat(bits(mask, 0, 0), cat(bits(mask, 0, 0), cat(bits(mask, 0, 0), cat(bits(mask, 0, 0), cat(bits(mask, 0, 0), cat(bits(mask, 0, 0), cat(bits(mask, 0, 0), cat(bits(mask, 0, 0), cat(bits(mask, 0, 0), cat(bits(mask, 0, 0), cat(bits(mask, 0, 0), cat(bits(mask, 0, 0), cat(bits(mask, 0, 0), cat(bits(mask, 0, 0), cat(bits(mask, 0, 0), cat(bits(mask, 0, 0), cat(bits(mask, 0, 0), bits(mask, 0, 0)))))))))))))))))))))))))))))))) + | mem_0_0.write_en <= and(and(write_en, UInt<1>("h1")), UInt<1>("h1")) + | node dout_0 = dout_0_0 + | dout <= mux(UInt<1>("h1"), dout_0, UInt<1>("h0")) + """.stripMargin + + compileExecuteAndTest(mem, lib, v, output) +} + class RocketChipTest extends MacroCompilerSpec with HasSRAMGenerator { val mem = s"mem-RocketChipTest.json" val lib = s"lib-RocketChipTest.json"