fp16 gemmini support
This commit is contained in:
@@ -3,14 +3,17 @@
|
||||
|
||||
package radiance.subsystem
|
||||
|
||||
import chisel3._
|
||||
import chisel3.util._
|
||||
import org.chipsalliance.cde.config._
|
||||
import freechips.rocketchip.rocket._
|
||||
import freechips.rocketchip.tile._
|
||||
import freechips.rocketchip.subsystem._
|
||||
import gemmini.{CapacityInKilobytes, GemminiFPConfigs}
|
||||
import gemmini._
|
||||
import gemmini.Arithmetic.FloatArithmetic._
|
||||
import radiance.tile._
|
||||
import radiance.memory._
|
||||
import radiance.subsystem.RadianceGemminiDataType.{BF16, FP16, FP32, Int8}
|
||||
|
||||
case class RadianceSharedMemKey(address: BigInt,
|
||||
size: Int,
|
||||
@@ -84,9 +87,14 @@ class WithRadianceCores(
|
||||
), useVxCache)
|
||||
}
|
||||
|
||||
class WithRadianceGemmini(location: HierarchicalLocation,
|
||||
crossing: RocketCrossingParams,
|
||||
dim: Int, accSizeInKB: Int, tileSize: Int) extends Config((site, _, up) => {
|
||||
object RadianceGemminiDataType extends Enumeration {
|
||||
type Type = Value
|
||||
val FP32, FP16, BF16, Int8 = Value
|
||||
}
|
||||
|
||||
class WithRadianceGemmini(location: HierarchicalLocation, crossing: RocketCrossingParams,
|
||||
dim: Int, accSizeInKB: Int, tileSize: Int,
|
||||
dataType: RadianceGemminiDataType.Type, dmaBytes: Int) extends Config((site, _, up) => {
|
||||
case TilesLocated(`location`) => {
|
||||
val prev = up(TilesLocated(`location`))
|
||||
val idOffset = up(NumTiles)
|
||||
@@ -100,7 +108,31 @@ class WithRadianceGemmini(location: HierarchicalLocation,
|
||||
}.sum
|
||||
val smKey = site(RadianceSharedMemKey).get
|
||||
val tileParams = GemminiTileParams(
|
||||
gemminiConfig = GemminiFPConfigs.FP32DefaultConfig.copy(
|
||||
gemminiConfig = {
|
||||
implicit val arithmetic: Arithmetic[Float] =
|
||||
Arithmetic.FloatArithmetic.asInstanceOf[Arithmetic[Float]]
|
||||
dataType match {
|
||||
case FP32 => GemminiFPConfigs.FP32DefaultConfig
|
||||
case FP16 => GemminiFPConfigs.FP16DefaultConfig.copy(
|
||||
acc_scale_args = Some(ScaleArguments(
|
||||
(t: Float, u: Float) => {t},
|
||||
1, Float(8, 24), -1, identity = "1.0", c_str = "((x))"
|
||||
)),
|
||||
mvin_scale_args = Some(ScaleArguments(
|
||||
(t: Float, u: Float) => t * u,
|
||||
1, Float(5, 11), -1, identity = "1.0", c_str="((x) * (scale))"
|
||||
)),
|
||||
mvin_scale_acc_args = None,
|
||||
has_training_convs = false,
|
||||
// hardcode_d_to_garbage_addr = true,
|
||||
acc_read_full_width = false, // set to true to output fp32
|
||||
)
|
||||
case BF16 => GemminiFPConfigs.BF16DefaultConfig
|
||||
// TODO: Int8
|
||||
}}.copy(
|
||||
dataflow = Dataflow.WS,
|
||||
ex_read_from_acc = false,
|
||||
ex_write_to_spad = false,
|
||||
has_training_convs = false,
|
||||
has_max_pool = false,
|
||||
use_tl_ext_mem = true,
|
||||
@@ -112,8 +144,10 @@ class WithRadianceGemmini(location: HierarchicalLocation,
|
||||
meshRows = dim,
|
||||
meshColumns = dim,
|
||||
tile_latency = 0,
|
||||
mesh_output_delay = 1,
|
||||
acc_latency = 3,
|
||||
dma_maxbytes = site(CacheBlockBytes),
|
||||
dma_buswidth = 256, // TODO: parameterize
|
||||
dma_buswidth = dmaBytes,
|
||||
tl_ext_mem_base = smKey.address,
|
||||
sp_banks = smKey.numBanks,
|
||||
sp_capacity = CapacityInKilobytes(smKey.size >> 10),
|
||||
@@ -130,7 +164,8 @@ class WithRadianceGemmini(location: HierarchicalLocation,
|
||||
}
|
||||
case NumTiles => up(NumTiles) + 1
|
||||
}) {
|
||||
def this(location: HierarchicalLocation = InSubsystem, dim: Int, accSizeInKB: Int, tileSize: Int) =
|
||||
def this(location: HierarchicalLocation = InSubsystem, dim: Int, accSizeInKB: Int, tileSize: Int,
|
||||
dataType: RadianceGemminiDataType.Type = RadianceGemminiDataType.FP32, dmaBytes: Int = 256) =
|
||||
this(location, RocketCrossingParams(
|
||||
master = HierarchicalElementMasterPortParams.locationDefault(location),
|
||||
slave = HierarchicalElementSlavePortParams.locationDefault(location),
|
||||
@@ -138,7 +173,7 @@ class WithRadianceGemmini(location: HierarchicalLocation,
|
||||
case InSubsystem => CBUS
|
||||
case InCluster(clusterId) => CCBUS(clusterId)
|
||||
}
|
||||
), dim, accSizeInKB, tileSize)
|
||||
), dim, accSizeInKB, tileSize, dataType, dmaBytes)
|
||||
}
|
||||
|
||||
class WithRadianceSharedMem(address: BigInt,
|
||||
|
||||
@@ -44,7 +44,10 @@ class RadianceCluster (
|
||||
val gemminiTiles = leafTiles.values.filter(_.isInstanceOf[GemminiTile]).toSeq.asInstanceOf[Seq[GemminiTile]]
|
||||
val gemminis = gemminiTiles.map(_.gemmini)
|
||||
val gemminiConfigs = gemminis.map(_.config)
|
||||
// val gemminiConfig = thisClusterParams.gemminiConfig.get // TODO: handle None gracefully
|
||||
|
||||
if (!(gemminiConfigs.tail.map(_.inputType == gemminiConfigs.head.inputType).reduce(_ && _))) {
|
||||
println("******** WARNING ********\n******** gemmini data types do not match\n******** WARNING ********")
|
||||
}
|
||||
|
||||
val radianceTiles = leafTiles.values.filter(_.isInstanceOf[RadianceTile]).toSeq.asInstanceOf[Seq[RadianceTile]]
|
||||
|
||||
|
||||
Reference in New Issue
Block a user