diff --git a/macros/src/main/scala/MacroCompiler.scala b/macros/src/main/scala/MacroCompiler.scala index 32a8b64a..64cee44e 100644 --- a/macros/src/main/scala/MacroCompiler.scala +++ b/macros/src/main/scala/MacroCompiler.scala @@ -19,7 +19,7 @@ import Utils._ */ // TODO: eventually explore compiling a single target memory using multiple // different kinds of target memory. -trait CostMetric { +trait CostMetric extends Serializable { /** * Cost function that returns the cost of compiling a memory using a certain * macro. @@ -100,6 +100,32 @@ object NewDefaultMetric extends CostMetric { } } +object MacroCompilerUtil { + import java.io._ + import java.util.Base64 + + // Adapted from https://stackoverflow.com/a/134918 + + /** Serialize an arbitrary object to String. + * Used to pass structured values through as an annotation. */ + def objToString(o: Serializable): String = { + val baos: ByteArrayOutputStream = new ByteArrayOutputStream + val oos: ObjectOutputStream = new ObjectOutputStream(baos) + oos.writeObject(o) + oos.close() + return Base64.getEncoder.encodeToString(baos.toByteArray) + } + + /** Deserialize an arbitrary object from String. */ + def objFromString(s: String): AnyRef = { + val data = Base64.getDecoder.decode(s) + val ois: ObjectInputStream = new ObjectInputStream(new ByteArrayInputStream(data)) + val o = ois.readObject + ois.close() + return o + } +} + object CostMetric { /** Define some default metric. */ val default: CostMetric = NewDefaultMetric @@ -108,30 +134,48 @@ object CostMetric { def getCostMetric(m: String, params: Map[String, String]): CostMetric = m match { case "default" => default case "PalmerMetric" => PalmerMetric - case "ExternalMetric" => new ExternalMetric(params.get("path").get) + case "ExternalMetric" => { + try { + new ExternalMetric(params.get("path").get) + } catch { + case e: NoSuchElementException => throw new IllegalArgumentException("Missing parameter 'path'") + } + } case "NewDefaultMetric" => NewDefaultMetric case _ => throw new IllegalArgumentException("Invalid cost metric " + m) } } object MacroCompilerAnnotation { - def apply(c: String, mem: File, lib: Option[File], synflops: Boolean): Annotation = - apply(c, mem.toString, lib map (_.toString), synflops) + /** + * Parameters associated to this MacroCompilerAnnotation. + * @param mem Path to memory lib + * @param lib Path to library lib or None if no libraries + * @param costMetric Cost metric to use + * @param synflops True to syn flops + */ + case class Params(mem: String, lib: Option[String], costMetric: CostMetric, synflops: Boolean) + + /** + * Create a MacroCompilerAnnotation. + * @param c Name of the module(?) for this annotation. + * @param p Parameters (see above). + */ + def apply(c: String, p: Params): Annotation = + Annotation(CircuitName(c), classOf[MacroCompilerTransform], MacroCompilerUtil.objToString(p)) - def apply(c: String, mem: String, lib: Option[String], synflops: Boolean): Annotation = { - Annotation(CircuitName(c), classOf[MacroCompilerTransform], - s"${mem} %s ${synflops}".format(lib getOrElse "")) - } - private val matcher = "([^ ]+) ([^ ]*) (true|false)".r def unapply(a: Annotation) = a match { - case Annotation(CircuitName(c), t, matcher(mem, lib, synflops)) if t == classOf[MacroCompilerTransform] => - Some((c, Some(mem), if (lib.isEmpty) None else Some(lib), synflops.toBoolean)) + case Annotation(CircuitName(c), t, serialized) if t == classOf[MacroCompilerTransform] => { + val p: Params = MacroCompilerUtil.objFromString(serialized).asInstanceOf[Params] + Some(c, p) + } case _ => None } } class MacroCompilerPass(mems: Option[Seq[Macro]], - libs: Option[Seq[Macro]]) extends firrtl.passes.Pass { + libs: Option[Seq[Macro]], + costMetric: CostMetric = CostMetric.default) extends firrtl.passes.Pass { def compile(mem: Macro, lib: Macro): Option[(Module, ExtModule)] = { val pairedPorts = mem.sortedPorts zip lib.sortedPorts @@ -437,7 +481,7 @@ class MacroCompilerPass(mems: Option[Seq[Macro]], (best, cost) case ((best, cost), lib) => // Run the cost function to evaluate this potential compile. - CostMetric.default.cost(mem, lib) match { + costMetric.cost(mem, lib) match { case Some(newCost) => { System.err.println(s"Cost of ${lib.src.name} for ${mem.src.name}: ${newCost}") if (newCost > cost) (best, cost) @@ -469,10 +513,9 @@ class MacroCompilerTransform extends Transform { def inputForm = MidForm def outputForm = MidForm def execute(state: CircuitState) = getMyAnnotations(state) match { - case Seq(MacroCompilerAnnotation(state.circuit.main, memFile, libFile, synflops)) => - require(memFile.isDefined) + case Seq(MacroCompilerAnnotation(state.circuit.main, MacroCompilerAnnotation.Params(memFile, libFile, costMetric, synflops))) => // Read, eliminate None, get only SRAM, make firrtl macro - val mems: Option[Seq[Macro]] = mdf.macrolib.Utils.readMDFFromPath(memFile) match { + val mems: Option[Seq[Macro]] = mdf.macrolib.Utils.readMDFFromPath(Some(memFile)) match { case Some(x:Seq[mdf.macrolib.Macro]) => Some(Utils.filterForSRAM(Some(x)) getOrElse(List()) map {new Macro(_)}) case _ => None @@ -483,7 +526,7 @@ class MacroCompilerTransform extends Transform { case _ => None } val transforms = Seq( - new MacroCompilerPass(mems, libs), + new MacroCompilerPass(mems, libs, costMetric), new SynFlopsPass(synflops, libs getOrElse mems.get)) (transforms foldLeft state)((s, xform) => xform runTransform s).copy(form=outputForm) case _ => state @@ -529,9 +572,9 @@ object MacroCompiler extends App { " -cp, --cost-param: Cost function parameter. (Optional depending on the cost function.). e.g. -c ExternalMetric -cp path /path/to/my/cost/script", " --syn-flops: Produces synthesizable flop-based memories (for all memories and library memory macros); likely useful for simulation purposes") mkString "\n" - def parseArgs(map: MacroParamMap, costMap: CostParamMap, synflops: Boolean, args: List[String]): (MacroParamMap, Boolean) = + def parseArgs(map: MacroParamMap, costMap: CostParamMap, synflops: Boolean, args: List[String]): (MacroParamMap, CostParamMap, Boolean) = args match { - case Nil => (map, synflops) + case Nil => (map, costMap, synflops) case ("-m" | "--macro-list") :: value :: tail => parseArgs(map + (Macros -> value), costMap, synflops, tail) case ("-l" | "--library") :: value :: tail => @@ -551,7 +594,7 @@ object MacroCompiler extends App { } def run(args: List[String]) { - val (params, synflops) = parseArgs(Map[MacroParam, String](), Map[String, String](), false, args) + val (params, costParams, synflops) = parseArgs(Map[MacroParam, String](), Map[String, String](), false, args) try { val macros = Utils.filterForSRAM(mdf.macrolib.Utils.readMDFFromPath(params.get(Macros))).get map (x => (new Macro(x)).blackbox) @@ -562,8 +605,16 @@ object MacroCompiler extends App { // Note: the last macro in the input list is (seemingly arbitrarily) // determined as the firrtl "top-level module". val circuit = Circuit(NoInfo, macros, macros.last.name) - val annotations = AnnotationMap(Seq(MacroCompilerAnnotation( - circuit.main, params.get(Macros).get, params.get(Library), synflops))) + val annotations = AnnotationMap( + Seq(MacroCompilerAnnotation( + circuit.main, + MacroCompilerAnnotation.Params( + params.get(Macros).get, params.get(Library), + CostMetric.getCostMetric(params.getOrElse(CostFunc, "default"), costParams), + synflops + ) + )) + ) val state = CircuitState(circuit, HighForm, Some(annotations)) // Run the compiler.