Merge branch 'new-cisc'

This commit is contained in:
Hansung Kim
2024-11-09 22:36:27 -08:00

View File

@@ -199,7 +199,7 @@ class GemminiTileModuleImp(outer: GemminiTile) extends BaseTileModuleImp(outer)
case Right(v: Int) => (v, v, v)
}
val config = outer.gemminiParams.gemminiConfig
val spadQuartile = config.sp_bank_entries * config.sp_banks / 4
val spadHexadecile = config.sp_bank_entries * config.sp_banks / 16
// TODO: as a temporary hack, bit 7 of the cisc opcode
// TODO: will force the tile size to be a square base on M.
@@ -210,71 +210,66 @@ class GemminiTileModuleImp(outer: GemminiTile) extends BaseTileModuleImp(outer)
_.rs2 -> (tileSizeM | (tileSizeM << 16) | (BigInt(tileSizeM) << 32)).U)
val boundsInst = Mux(ciscId(7), squareBoundsInst, rectBoundsInst)
def genStrideInst(tileA: UInt, tileB: UInt) = {
val inst = Wire(ciscInstT)
inst.inst := 0x3020b07b.U
inst.rs1 := tileA * spadHexadecile.U // A should be stored from the start of this block
inst.rs2 := (tileB + 1.U) * spadHexadecile.U // B should be stored up till the end of this block
inst
}
def genAccSkipInst(accumulate: UInt, skips: UInt) = {
val inst = Wire(ciscInstT)
inst.inst := 0x1020b07b.U
inst.rs1 := accumulate
inst.rs2 := skips
inst
}
println(s"gemmini cisc initialized with DIM=${config.DIM}, tileSize=${tileSizeM},${tileSizeN},${tileSizeK}")
println(f"boundsInst=${rectBoundsInst.litValue}%x, quartile=${spadQuartile}")
println(f"boundsInst=${rectBoundsInst.litValue}%x, hexadecile=${spadHexadecile}")
when (ciscValid) {
switch (ciscId(6, 0)) {
is (0.U) { // compute on given quadrants
ciscInst := microcodeEntry(Seq(boundsInst,
ciscInstT.Lit(_.inst -> 0x3020b07b.U, _.rs1 -> 0.U, _.rs2 -> (spadQuartile * 3).U), // set A, B address
ciscInstT.Lit(_.inst -> 0x1020b07b.U, _.rs1 -> 0.U, _.rs2 -> x"0_000002b8".U) // set skip, acc
))
is (0.U) { // compute on given hexadeciles
val strideInst = genStrideInst(ciscArgs(7, 0), ciscArgs(15, 8))
val accSkipInst = genAccSkipInst(ciscArgs(16), 0x2b8.U)
ciscInst := microcodeEntry(Seq(boundsInst, strideInst, accSkipInst))
} // replaces opcode 0: (a, b, accum) = (0, 2, 0), op 1 = (0, 2, 1), op 2 = (1, 3, 1), op 3 = (1, 3, 0)
is (1.U) { // compute on given hexadeciles and mvout to spad
val strideInst = genStrideInst(ciscArgs(7, 0), ciscArgs(15, 8))
// note that accumulation is disabled
val accSkipInst = genAccSkipInst(0.U, ((ciscArgs(23, 16) * spadHexadecile.U) << 32).asUInt | 0x238.U)
ciscInst := microcodeEntry(Seq(boundsInst, strideInst, accSkipInst))
}
is (2.U) {
ciscInst := microcodeEntry(Seq(boundsInst,
ciscInstT.Lit(_.inst -> 0x3020b07b.U, _.rs1 -> (spadQuartile * 1).U, _.rs2 -> (spadQuartile * 4).U),
ciscInstT.Lit(_.inst -> 0x1020b07b.U, _.rs1 -> 0x1.U, _.rs2 -> x"0_000002b8".U)
))
}
is (1.U) {
ciscInst := microcodeEntry(Seq(boundsInst,
ciscInstT.Lit(_.inst -> 0x3020b07b.U, _.rs1 -> 0.U, _.rs2 -> (spadQuartile * 3).U),
ciscInstT.Lit(_.inst -> 0x1020b07b.U, _.rs1 -> 0x1.U, _.rs2 -> x"0_000002b8".U)
))
}
is (3.U) {
ciscInst := microcodeEntry(Seq(boundsInst,
ciscInstT.Lit(_.inst -> 0x3020b07b.U, _.rs1 -> (spadQuartile * 1).U, _.rs2 -> (spadQuartile * 4).U),
ciscInstT.Lit(_.inst -> 0x1020b07b.U, _.rs1 -> 0x0.U, _.rs2 -> x"0_000002b8".U)
))
}
is (8.U) {
is (8.U) { // set a, b stride
val inst = Wire(ciscInstT)
inst.inst := 0x1820b07b.U
inst.rs1 := ciscArgs(11, 0)
inst.rs2 := ciscArgs(23, 12)
inst.rs1 := ciscArgs(11, 0) // a
inst.rs2 := ciscArgs(23, 12) // b
ciscInst := microcodeEntry(Seq(inst))
}
is (9.U) {
ciscInst := microcodeEntry(Seq(boundsInst,
ciscInstT.Lit(_.inst -> 0x1020b07b.U, _.rs1 -> 0.U, _.rs2 -> 0x278.U),
))
is (9.U) { // move out to scratchpad
val accSkipInst = genAccSkipInst(0.U, ((ciscArgs(7, 0) * spadHexadecile.U) << 32).asUInt | 0x278.U)
ciscInst := microcodeEntry(Seq(boundsInst, accSkipInst))
}
is (10.U) {
ciscInst := microcodeEntry(Seq(boundsInst,
ciscInstT.Lit(_.inst -> 0x3020b07b.U, _.rs1 -> 0.U, _.rs2 -> (spadQuartile * 3).U),
ciscInstT.Lit(_.inst -> 0x1020b07b.U, _.rs1 -> 0x1.U, _.rs2 -> x"0_000002e0".U)
))
is (10.U) { // load to scratchpad hexadeciles
val strideInst = genStrideInst(ciscArgs(7, 0), ciscArgs(15, 8))
val accSkipInst = genAccSkipInst(1.U, 0x2e0.U)
ciscInst := microcodeEntry(Seq(boundsInst, strideInst, accSkipInst))
} // replaces opcode 10: (a, b) = (0, 2), opcode 11 = (1, 3), opcode 12 = (0, 0), opcode 13 = (2, 2)
is (11.U) { // set d, c stride
val inst = Wire(ciscInstT)
inst.inst := 0x1a20b07b.U
inst.rs1 := ciscArgs(11, 0) // d
inst.rs2 := ciscArgs(23, 12) // c
ciscInst := microcodeEntry(Seq(inst))
}
is (11.U) {
ciscInst := microcodeEntry(Seq(boundsInst,
ciscInstT.Lit(_.inst -> 0x3020b07b.U, _.rs1 -> (spadQuartile * 1).U, _.rs2 -> (spadQuartile * 4).U),
ciscInstT.Lit(_.inst -> 0x1020b07b.U, _.rs1 -> 0x1.U, _.rs2 -> x"0_000002e0".U)
))
is (12.U) { // store to gmem
val accSkipInst = genAccSkipInst(0.U, 0x78.U)
ciscInst := microcodeEntry(Seq(boundsInst, accSkipInst))
}
is (12.U) { // test: DMA for tensor core
ciscInst := microcodeEntry(Seq(boundsInst,
ciscInstT.Lit(_.inst -> 0x3020b07b.U, _.rs1 -> (spadQuartile * 0).U, _.rs2 -> (spadQuartile * 1).U),
ciscInstT.Lit(_.inst -> 0x1020b07b.U, _.rs1 -> 0x1.U, _.rs2 -> x"0_000002e0".U)
))
}
is (13.U) { // test: DMA for tensor core
ciscInst := microcodeEntry(Seq(boundsInst,
ciscInstT.Lit(_.inst -> 0x3020b07b.U, _.rs1 -> (spadQuartile * 2).U, _.rs2 -> (spadQuartile * 3).U),
ciscInstT.Lit(_.inst -> 0x1020b07b.U, _.rs1 -> 0x1.U, _.rs2 -> x"0_000002e0".U)
))
}
is (16.U) {
is (16.U) { // unused, configure gemmini
ciscInst := microcodeEntry(Seq(
ciscInstT.Lit(_.inst -> 0x0020b07b.U, _.rs1 -> x"3f800000_00080101".U, _.rs2 -> 0.U),
ciscInstT.Lit(_.inst -> 0x0020b07b.U, _.rs1 -> x"3f800000_00010004".U, _.rs2 -> x"10000_00000000".U),