fp16 gemmini support

This commit is contained in:
Richard Yan
2024-08-06 02:32:35 -07:00
parent 0d9c2ca6ad
commit af60ae3332
2 changed files with 47 additions and 9 deletions

View File

@@ -3,14 +3,17 @@
package radiance.subsystem
import chisel3._
import chisel3.util._
import org.chipsalliance.cde.config._
import freechips.rocketchip.rocket._
import freechips.rocketchip.tile._
import freechips.rocketchip.subsystem._
import gemmini.{CapacityInKilobytes, GemminiFPConfigs}
import gemmini._
import gemmini.Arithmetic.FloatArithmetic._
import radiance.tile._
import radiance.memory._
import radiance.subsystem.RadianceGemminiDataType.{BF16, FP16, FP32, Int8}
case class RadianceSharedMemKey(address: BigInt,
size: Int,
@@ -84,9 +87,14 @@ class WithRadianceCores(
), useVxCache)
}
class WithRadianceGemmini(location: HierarchicalLocation,
crossing: RocketCrossingParams,
dim: Int, accSizeInKB: Int, tileSize: Int) extends Config((site, _, up) => {
object RadianceGemminiDataType extends Enumeration {
type Type = Value
val FP32, FP16, BF16, Int8 = Value
}
class WithRadianceGemmini(location: HierarchicalLocation, crossing: RocketCrossingParams,
dim: Int, accSizeInKB: Int, tileSize: Int,
dataType: RadianceGemminiDataType.Type, dmaBytes: Int) extends Config((site, _, up) => {
case TilesLocated(`location`) => {
val prev = up(TilesLocated(`location`))
val idOffset = up(NumTiles)
@@ -100,7 +108,31 @@ class WithRadianceGemmini(location: HierarchicalLocation,
}.sum
val smKey = site(RadianceSharedMemKey).get
val tileParams = GemminiTileParams(
gemminiConfig = GemminiFPConfigs.FP32DefaultConfig.copy(
gemminiConfig = {
implicit val arithmetic: Arithmetic[Float] =
Arithmetic.FloatArithmetic.asInstanceOf[Arithmetic[Float]]
dataType match {
case FP32 => GemminiFPConfigs.FP32DefaultConfig
case FP16 => GemminiFPConfigs.FP16DefaultConfig.copy(
acc_scale_args = Some(ScaleArguments(
(t: Float, u: Float) => {t},
1, Float(8, 24), -1, identity = "1.0", c_str = "((x))"
)),
mvin_scale_args = Some(ScaleArguments(
(t: Float, u: Float) => t * u,
1, Float(5, 11), -1, identity = "1.0", c_str="((x) * (scale))"
)),
mvin_scale_acc_args = None,
has_training_convs = false,
// hardcode_d_to_garbage_addr = true,
acc_read_full_width = false, // set to true to output fp32
)
case BF16 => GemminiFPConfigs.BF16DefaultConfig
// TODO: Int8
}}.copy(
dataflow = Dataflow.WS,
ex_read_from_acc = false,
ex_write_to_spad = false,
has_training_convs = false,
has_max_pool = false,
use_tl_ext_mem = true,
@@ -112,8 +144,10 @@ class WithRadianceGemmini(location: HierarchicalLocation,
meshRows = dim,
meshColumns = dim,
tile_latency = 0,
mesh_output_delay = 1,
acc_latency = 3,
dma_maxbytes = site(CacheBlockBytes),
dma_buswidth = 256, // TODO: parameterize
dma_buswidth = dmaBytes,
tl_ext_mem_base = smKey.address,
sp_banks = smKey.numBanks,
sp_capacity = CapacityInKilobytes(smKey.size >> 10),
@@ -130,7 +164,8 @@ class WithRadianceGemmini(location: HierarchicalLocation,
}
case NumTiles => up(NumTiles) + 1
}) {
def this(location: HierarchicalLocation = InSubsystem, dim: Int, accSizeInKB: Int, tileSize: Int) =
def this(location: HierarchicalLocation = InSubsystem, dim: Int, accSizeInKB: Int, tileSize: Int,
dataType: RadianceGemminiDataType.Type = RadianceGemminiDataType.FP32, dmaBytes: Int = 256) =
this(location, RocketCrossingParams(
master = HierarchicalElementMasterPortParams.locationDefault(location),
slave = HierarchicalElementSlavePortParams.locationDefault(location),
@@ -138,7 +173,7 @@ class WithRadianceGemmini(location: HierarchicalLocation,
case InSubsystem => CBUS
case InCluster(clusterId) => CCBUS(clusterId)
}
), dim, accSizeInKB, tileSize)
), dim, accSizeInKB, tileSize, dataType, dmaBytes)
}
class WithRadianceSharedMem(address: BigInt,

View File

@@ -44,7 +44,10 @@ class RadianceCluster (
val gemminiTiles = leafTiles.values.filter(_.isInstanceOf[GemminiTile]).toSeq.asInstanceOf[Seq[GemminiTile]]
val gemminis = gemminiTiles.map(_.gemmini)
val gemminiConfigs = gemminis.map(_.config)
// val gemminiConfig = thisClusterParams.gemminiConfig.get // TODO: handle None gracefully
if (!(gemminiConfigs.tail.map(_.inputType == gemminiConfigs.head.inputType).reduce(_ && _))) {
println("******** WARNING ********\n******** gemmini data types do not match\n******** WARNING ********")
}
val radianceTiles = leafTiles.values.filter(_.isInstanceOf[RadianceTile]).toSeq.asInstanceOf[Seq[RadianceTile]]