diff --git a/src/main/scala/radiance/subsystem/Configs.scala b/src/main/scala/radiance/subsystem/Configs.scala index 4a3a940..98addce 100644 --- a/src/main/scala/radiance/subsystem/Configs.scala +++ b/src/main/scala/radiance/subsystem/Configs.scala @@ -136,7 +136,7 @@ class WithRadianceGemmini(location: HierarchicalLocation, crossing: RocketCrossi case FP16 => GemminiFPConfigs.FP16DefaultConfig.copy( acc_scale_args = Some(ScaleArguments( (t: Float, u: Float) => {t}, - 1, Float(5, 11), -1, identity = "1.0", c_str = "((x))" + 1, Float(8, 24), -1, identity = "1.0", c_str = "((x))" )), mvin_scale_args = Some(ScaleArguments( (t: Float, u: Float) => t * u, @@ -148,8 +148,8 @@ class WithRadianceGemmini(location: HierarchicalLocation, crossing: RocketCrossi // from sirius spatialArrayInputType = Float(5, 11, isRecoded = skipRecoding), spatialArrayWeightType = Float(5, 11, isRecoded = skipRecoding), - spatialArrayOutputType = Float(5, 11, isRecoded = skipRecoding), - accType = Float(5, 11), + spatialArrayOutputType = Float(8, 24, isRecoded = skipRecoding), + accType = Float(8, 24), // hardcode_d_to_garbage_addr = true, acc_read_full_width = false, // set to true to output fp32 diff --git a/src/main/scala/radiance/tile/GemminiTile.scala b/src/main/scala/radiance/tile/GemminiTile.scala index 54373a6..f458975 100644 --- a/src/main/scala/radiance/tile/GemminiTile.scala +++ b/src/main/scala/radiance/tile/GemminiTile.scala @@ -168,6 +168,8 @@ class GemminiTileModuleImp(outer: GemminiTile) extends BaseTileModuleImp(outer) val rs2 = UInt(64.W) } val ciscInst = Wire(ciscInstT) + val startsLoop = WireInit(false.B) + val runningLoops = RegInit(0.U(4.W)) val accCommandQueue = Module(new Queue(UInt(32.W), 4, false, true)) accCommandQueue.io.enq.bits := accSlave.cmd.bits @@ -228,19 +230,25 @@ class GemminiTileModuleImp(outer: GemminiTile) extends BaseTileModuleImp(outer) println(s"gemmini cisc initialized with DIM=${config.DIM}, tileSize=${tileSizeM},${tileSizeN},${tileSizeK}") println(f"boundsInst=${rectBoundsInst.litValue}%x, hexadecile=${spadHexadecile}") + when (ciscValid) { switch (ciscId(6, 0)) { is (0.U) { // compute on given hexadeciles val strideInst = genStrideInst(ciscArgs(7, 0), ciscArgs(15, 8)) val accSkipInst = genAccSkipInst(ciscArgs(16), 0x2b8.U) + startsLoop := true.B ciscInst := microcodeEntry(Seq(boundsInst, strideInst, accSkipInst)) } // replaces opcode 0: (a, b, accum) = (0, 2, 0), op 1 = (0, 2, 1), op 2 = (1, 3, 1), op 3 = (1, 3, 0) is (1.U) { // compute on given hexadeciles and mvout to spad val strideInst = genStrideInst(ciscArgs(7, 0), ciscArgs(15, 8)) // note that accumulation is disabled val accSkipInst = genAccSkipInst(0.U, ((ciscArgs(23, 16) * spadHexadecile.U) << 32).asUInt | 0x238.U) + startsLoop := true.B ciscInst := microcodeEntry(Seq(boundsInst, strideInst, accSkipInst)) } + is (2.U) { // no actual invocation, fake job placeholder + startsLoop := true.B + } is (8.U) { // set a, b stride val inst = Wire(ciscInstT) inst.inst := 0x1820b07b.U @@ -250,11 +258,13 @@ class GemminiTileModuleImp(outer: GemminiTile) extends BaseTileModuleImp(outer) } is (9.U) { // move out to scratchpad val accSkipInst = genAccSkipInst(0.U, ((ciscArgs(7, 0) * spadHexadecile.U) << 32).asUInt | 0x278.U) + startsLoop := true.B ciscInst := microcodeEntry(Seq(boundsInst, accSkipInst)) } is (10.U) { // load to scratchpad hexadeciles val strideInst = genStrideInst(ciscArgs(7, 0), ciscArgs(15, 8)) val accSkipInst = genAccSkipInst(1.U, 0x2e0.U) + startsLoop := true.B ciscInst := microcodeEntry(Seq(boundsInst, strideInst, accSkipInst)) } // replaces opcode 10: (a, b) = (0, 2), opcode 11 = (1, 3), opcode 12 = (0, 0), opcode 13 = (2, 2) is (11.U) { // set d, c stride @@ -266,6 +276,7 @@ class GemminiTileModuleImp(outer: GemminiTile) extends BaseTileModuleImp(outer) } is (12.U) { // store to gmem val accSkipInst = genAccSkipInst(0.U, 0x78.U) + startsLoop := true.B ciscInst := microcodeEntry(Seq(boundsInst, accSkipInst)) } @@ -279,6 +290,11 @@ class GemminiTileModuleImp(outer: GemminiTile) extends BaseTileModuleImp(outer) } } + val completionCount = PopCount(outer.gemmini.module.completion_io.completed) + val loopStarted = Mux(ciscValid && instCounter.value === 0.U && startsLoop, 1.U, 0.U) + runningLoops := runningLoops + loopStarted - completionCount + assert(runningLoops + loopStarted >= completionCount) + val gemminiIO = outer.gemmini.module.io.cmd val regValid = Wire(Bool()) @@ -299,6 +315,11 @@ class GemminiTileModuleImp(outer: GemminiTile) extends BaseTileModuleImp(outer) // (!outer.gemmini.module.io.busy, outer.gemmini.module.io.busy.asUInt) (true.B, outer.gemmini.module.io.busy.asUInt) } + + def gemminiRunningLoopsReg(_dReady: Bool): (Bool, UInt) = { + (true.B, runningLoops) + } + outer.regNode.regmap( 0x00 -> Seq(RegField.w(32, gemminiCommandReg(_, _))), 0x10 -> Seq( @@ -307,7 +328,8 @@ class GemminiTileModuleImp(outer: GemminiTile) extends BaseTileModuleImp(outer) 0x18 -> Seq( RegField.w(32, gemminiRs2RegLSB), RegField.w(32, gemminiRs2RegMSB)), - 0x20 -> Seq(RegField.r(32, gemminiBusyReg(_))) + 0x20 -> Seq(RegField.r(32, gemminiBusyReg(_))), + 0x28 -> Seq(RegField.r(32, gemminiRunningLoopsReg(_))) ) assert(!regValid || gemminiIO.ready)