diff --git a/src/main/scala/radiance/subsystem/Configs.scala b/src/main/scala/radiance/subsystem/Configs.scala index 8c3b2dd..2ecfc3a 100644 --- a/src/main/scala/radiance/subsystem/Configs.scala +++ b/src/main/scala/radiance/subsystem/Configs.scala @@ -3,14 +3,17 @@ package radiance.subsystem +import chisel3._ import chisel3.util._ import org.chipsalliance.cde.config._ import freechips.rocketchip.rocket._ import freechips.rocketchip.tile._ import freechips.rocketchip.subsystem._ -import gemmini.{CapacityInKilobytes, GemminiFPConfigs} +import gemmini._ +import gemmini.Arithmetic.FloatArithmetic._ import radiance.tile._ import radiance.memory._ +import radiance.subsystem.RadianceGemminiDataType.{BF16, FP16, FP32, Int8} case class RadianceSharedMemKey(address: BigInt, size: Int, @@ -84,9 +87,14 @@ class WithRadianceCores( ), useVxCache) } -class WithRadianceGemmini(location: HierarchicalLocation, - crossing: RocketCrossingParams, - dim: Int, accSizeInKB: Int, tileSize: Int) extends Config((site, _, up) => { +object RadianceGemminiDataType extends Enumeration { + type Type = Value + val FP32, FP16, BF16, Int8 = Value +} + +class WithRadianceGemmini(location: HierarchicalLocation, crossing: RocketCrossingParams, + dim: Int, accSizeInKB: Int, tileSize: Int, + dataType: RadianceGemminiDataType.Type, dmaBytes: Int) extends Config((site, _, up) => { case TilesLocated(`location`) => { val prev = up(TilesLocated(`location`)) val idOffset = up(NumTiles) @@ -100,7 +108,31 @@ class WithRadianceGemmini(location: HierarchicalLocation, }.sum val smKey = site(RadianceSharedMemKey).get val tileParams = GemminiTileParams( - gemminiConfig = GemminiFPConfigs.FP32DefaultConfig.copy( + gemminiConfig = { + implicit val arithmetic: Arithmetic[Float] = + Arithmetic.FloatArithmetic.asInstanceOf[Arithmetic[Float]] + dataType match { + case FP32 => GemminiFPConfigs.FP32DefaultConfig + case FP16 => GemminiFPConfigs.FP16DefaultConfig.copy( + acc_scale_args = Some(ScaleArguments( + (t: Float, u: Float) => {t}, + 1, Float(8, 24), -1, identity = "1.0", c_str = "((x))" + )), + mvin_scale_args = Some(ScaleArguments( + (t: Float, u: Float) => t * u, + 1, Float(5, 11), -1, identity = "1.0", c_str="((x) * (scale))" + )), + mvin_scale_acc_args = None, + has_training_convs = false, + // hardcode_d_to_garbage_addr = true, + acc_read_full_width = false, // set to true to output fp32 + ) + case BF16 => GemminiFPConfigs.BF16DefaultConfig + // TODO: Int8 + }}.copy( + dataflow = Dataflow.WS, + ex_read_from_acc = false, + ex_write_to_spad = false, has_training_convs = false, has_max_pool = false, use_tl_ext_mem = true, @@ -112,8 +144,10 @@ class WithRadianceGemmini(location: HierarchicalLocation, meshRows = dim, meshColumns = dim, tile_latency = 0, + mesh_output_delay = 1, + acc_latency = 3, dma_maxbytes = site(CacheBlockBytes), - dma_buswidth = 256, // TODO: parameterize + dma_buswidth = dmaBytes, tl_ext_mem_base = smKey.address, sp_banks = smKey.numBanks, sp_capacity = CapacityInKilobytes(smKey.size >> 10), @@ -130,7 +164,8 @@ class WithRadianceGemmini(location: HierarchicalLocation, } case NumTiles => up(NumTiles) + 1 }) { - def this(location: HierarchicalLocation = InSubsystem, dim: Int, accSizeInKB: Int, tileSize: Int) = + def this(location: HierarchicalLocation = InSubsystem, dim: Int, accSizeInKB: Int, tileSize: Int, + dataType: RadianceGemminiDataType.Type = RadianceGemminiDataType.FP32, dmaBytes: Int = 256) = this(location, RocketCrossingParams( master = HierarchicalElementMasterPortParams.locationDefault(location), slave = HierarchicalElementSlavePortParams.locationDefault(location), @@ -138,7 +173,7 @@ class WithRadianceGemmini(location: HierarchicalLocation, case InSubsystem => CBUS case InCluster(clusterId) => CCBUS(clusterId) } - ), dim, accSizeInKB, tileSize) + ), dim, accSizeInKB, tileSize, dataType, dmaBytes) } class WithRadianceSharedMem(address: BigInt, diff --git a/src/main/scala/radiance/tile/RadianceCluster.scala b/src/main/scala/radiance/tile/RadianceCluster.scala index f2a79f4..3890180 100644 --- a/src/main/scala/radiance/tile/RadianceCluster.scala +++ b/src/main/scala/radiance/tile/RadianceCluster.scala @@ -44,7 +44,10 @@ class RadianceCluster ( val gemminiTiles = leafTiles.values.filter(_.isInstanceOf[GemminiTile]).toSeq.asInstanceOf[Seq[GemminiTile]] val gemminis = gemminiTiles.map(_.gemmini) val gemminiConfigs = gemminis.map(_.config) - // val gemminiConfig = thisClusterParams.gemminiConfig.get // TODO: handle None gracefully + + if (!(gemminiConfigs.tail.map(_.inputType == gemminiConfigs.head.inputType).reduce(_ && _))) { + println("******** WARNING ********\n******** gemmini data types do not match\n******** WARNING ********") + } val radianceTiles = leafTiles.values.filter(_.isInstanceOf[RadianceTile]).toSeq.asInstanceOf[Seq[RadianceTile]]