把混合精度模型权重导出至本地文件,减少 dtype 开销

--export_casted_ckpt ckpts/unifolm_wma_dual_mixbf16.ckpt \
        --export_only
This commit is contained in:
2026-01-19 15:14:01 +08:00
parent cb334f308b
commit 7e501b17fd
20 changed files with 245 additions and 55 deletions

View File

@@ -441,22 +441,47 @@ def get_filelist(data_dir: str, postfixes: list[str]) -> list[str]:
return file_list
def load_model_checkpoint(model: nn.Module, ckpt: str) -> nn.Module:
def _load_state_dict(model: nn.Module,
state_dict: Mapping[str, torch.Tensor],
strict: bool = True,
assign: bool = False) -> None:
if assign:
try:
model.load_state_dict(state_dict, strict=strict, assign=True)
return
except TypeError:
warnings.warn(
"load_state_dict(assign=True) not supported; "
"falling back to copy load.")
model.load_state_dict(state_dict, strict=strict)
def load_model_checkpoint(model: nn.Module,
ckpt: str,
assign: bool | None = None) -> nn.Module:
"""Load model weights from checkpoint file.
Args:
model (nn.Module): Model instance.
ckpt (str): Path to the checkpoint file.
assign (bool | None): Whether to preserve checkpoint tensor dtypes
via load_state_dict(assign=True). If None, auto-enable when a
casted checkpoint metadata is detected.
Returns:
nn.Module: Model with loaded weights.
"""
state_dict = torch.load(ckpt, map_location="cpu")
if "state_dict" in list(state_dict.keys()):
state_dict = state_dict["state_dict"]
ckpt_data = torch.load(ckpt, map_location="cpu")
use_assign = False
if assign is not None:
use_assign = assign
elif isinstance(ckpt_data, Mapping) and "precision_metadata" in ckpt_data:
use_assign = True
if isinstance(ckpt_data, Mapping) and "state_dict" in ckpt_data:
state_dict = ckpt_data["state_dict"]
try:
model.load_state_dict(state_dict, strict=True)
except:
_load_state_dict(model, state_dict, strict=True, assign=use_assign)
except Exception:
new_pl_sd = OrderedDict()
for k, v in state_dict.items():
new_pl_sd[k] = v
@@ -466,32 +491,93 @@ def load_model_checkpoint(model: nn.Module, ckpt: str) -> nn.Module:
new_key = k.replace("framestride_embed", "fps_embedding")
new_pl_sd[new_key] = new_pl_sd[k]
del new_pl_sd[k]
model.load_state_dict(new_pl_sd, strict=True)
else:
_load_state_dict(model,
new_pl_sd,
strict=True,
assign=use_assign)
elif isinstance(ckpt_data, Mapping) and "module" in ckpt_data:
new_pl_sd = OrderedDict()
for key in state_dict['module'].keys():
new_pl_sd[key[16:]] = state_dict['module'][key]
model.load_state_dict(new_pl_sd)
for key in ckpt_data['module'].keys():
new_pl_sd[key[16:]] = ckpt_data['module'][key]
_load_state_dict(model, new_pl_sd, strict=True, assign=use_assign)
else:
_load_state_dict(model,
ckpt_data,
strict=True,
assign=use_assign)
print('>>> model checkpoint loaded.')
return model
def maybe_cast_module(module: nn.Module | None,
dtype: torch.dtype,
label: str,
profiler: Optional[ProfilerManager] = None,
profile_name: Optional[str] = None) -> None:
if module is None:
return
try:
param = next(module.parameters())
except StopIteration:
print(f">>> {label} has no parameters; skip cast")
return
if param.dtype == dtype:
print(f">>> {label} already {dtype}; skip cast")
return
ctx = nullcontext()
if profiler is not None and profile_name:
ctx = profiler.profile_section(profile_name)
with ctx:
module.to(dtype=dtype)
print(f">>> {label} cast to {dtype}")
def save_casted_checkpoint(model: nn.Module,
                           save_path: str,
                           metadata: Optional[Dict[str, Any]] = None) -> None:
    """Serialize ``model``'s state dict (moved to CPU) to ``save_path``.

    Tensor dtypes are left untouched, so a mixed-precision model is written
    out as-is. A truthy ``metadata`` mapping is stored under the
    ``"precision_metadata"`` key so later loads can recognize a pre-cast
    checkpoint. A falsy ``save_path`` makes the call a no-op.
    """
    if not save_path:
        return
    parent = os.path.dirname(save_path)
    if parent:
        os.makedirs(parent, exist_ok=True)

    def _to_cpu(value):
        # Non-tensor entries in the state dict pass through unchanged.
        if isinstance(value, torch.Tensor):
            return value.detach().to("cpu")
        return value

    payload: Dict[str, Any] = {
        "state_dict": {k: _to_cpu(v) for k, v in model.state_dict().items()}
    }
    if metadata:
        payload["precision_metadata"] = metadata
    torch.save(payload, save_path)
    print(f">>> Saved casted checkpoint to {save_path}")
def _module_param_dtype(module: nn.Module | None) -> str:
if module is None:
return "None"
dtype_counts: Dict[str, int] = {}
for param in module.parameters():
return str(param.dtype)
dtype_key = str(param.dtype)
dtype_counts[dtype_key] = dtype_counts.get(dtype_key, 0) + param.numel()
if not dtype_counts:
return "no_params"
if len(dtype_counts) == 1:
return next(iter(dtype_counts))
total = sum(dtype_counts.values())
parts = []
for dtype_key in sorted(dtype_counts.keys()):
ratio = dtype_counts[dtype_key] / total
parts.append(f"{dtype_key}={ratio:.1%}")
return f"mixed({', '.join(parts)})"
def log_inference_precision(model: nn.Module) -> None:
try:
param = next(model.parameters())
device = str(param.device)
model_dtype = str(param.dtype)
except StopIteration:
device = "unknown"
model_dtype = "no_params"
for param in model.parameters():
device = str(param.device)
break
model_dtype = _module_param_dtype(model)
print(f">>> inference precision: model={model_dtype}, device={device}")
for attr in [
@@ -966,16 +1052,25 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
diffusion_autocast_dtype = None
if args.diffusion_dtype == "bf16":
with profiler.profile_section("model_loading/diffusion_bf16"):
model.model.to(dtype=torch.bfloat16)
maybe_cast_module(
model.model,
torch.bfloat16,
"diffusion backbone",
profiler=profiler,
profile_name="model_loading/diffusion_bf16",
)
diffusion_autocast_dtype = torch.bfloat16
print(">>> diffusion backbone set to bfloat16")
if hasattr(model, "first_stage_model") and model.first_stage_model is not None:
if args.vae_dtype == "bf16":
model.first_stage_model.to(dtype=torch.bfloat16)
else:
model.first_stage_model.to(dtype=torch.float32)
vae_weight_dtype = torch.bfloat16 if args.vae_dtype == "bf16" else torch.float32
maybe_cast_module(
model.first_stage_model,
vae_weight_dtype,
"VAE",
profiler=profiler,
profile_name="model_loading/vae_cast",
)
model.vae_bf16 = args.vae_dtype == "bf16"
print(f">>> VAE dtype set to {args.vae_dtype}")
@@ -983,9 +1078,21 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
encoder_bf16 = encoder_mode in ("autocast", "bf16_full")
encoder_weight_dtype = torch.bfloat16 if encoder_mode == "bf16_full" else torch.float32
if hasattr(model, "cond_stage_model") and model.cond_stage_model is not None:
model.cond_stage_model.to(dtype=encoder_weight_dtype)
maybe_cast_module(
model.cond_stage_model,
encoder_weight_dtype,
"cond_stage_model",
profiler=profiler,
profile_name="model_loading/encoder_cond_cast",
)
if hasattr(model, "embedder") and model.embedder is not None:
model.embedder.to(dtype=encoder_weight_dtype)
maybe_cast_module(
model.embedder,
encoder_weight_dtype,
"embedder",
profiler=profiler,
profile_name="model_loading/encoder_embedder_cast",
)
model.encoder_bf16 = encoder_bf16
model.encoder_mode = encoder_mode
print(
@@ -996,11 +1103,29 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
projector_bf16 = projector_mode in ("autocast", "bf16_full")
projector_weight_dtype = torch.bfloat16 if projector_mode == "bf16_full" else torch.float32
if hasattr(model, "image_proj_model") and model.image_proj_model is not None:
model.image_proj_model.to(dtype=projector_weight_dtype)
maybe_cast_module(
model.image_proj_model,
projector_weight_dtype,
"image_proj_model",
profiler=profiler,
profile_name="model_loading/projector_image_cast",
)
if hasattr(model, "state_projector") and model.state_projector is not None:
model.state_projector.to(dtype=projector_weight_dtype)
maybe_cast_module(
model.state_projector,
projector_weight_dtype,
"state_projector",
profiler=profiler,
profile_name="model_loading/projector_state_cast",
)
if hasattr(model, "action_projector") and model.action_projector is not None:
model.action_projector.to(dtype=projector_weight_dtype)
maybe_cast_module(
model.action_projector,
projector_weight_dtype,
"action_projector",
profiler=profiler,
profile_name="model_loading/projector_action_cast",
)
if hasattr(model, "projector_bf16"):
model.projector_bf16 = projector_bf16
model.projector_mode = projector_mode
@@ -1010,6 +1135,19 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
log_inference_precision(model)
if args.export_casted_ckpt:
metadata = {
"diffusion_dtype": args.diffusion_dtype,
"vae_dtype": args.vae_dtype,
"encoder_mode": args.encoder_mode,
"projector_mode": args.projector_mode,
"perframe_ae": args.perframe_ae,
}
save_casted_checkpoint(model, args.export_casted_ckpt, metadata)
if args.export_only:
print(">>> export_only set; skipping inference.")
return
profiler.record_memory("after_model_load")
# Run over data
@@ -1373,6 +1511,19 @@ def get_parser():
default="fp32",
help="Dtype for VAE/first_stage_model weights and forward autocast."
)
parser.add_argument(
"--export_casted_ckpt",
type=str,
default=None,
help=
"Save a checkpoint after applying precision settings (mixed dtypes preserved)."
)
parser.add_argument(
"--export_only",
action='store_true',
default=False,
help="Exit after exporting the casted checkpoint."
)
parser.add_argument(
"--step_log_every",
type=int,

View File

@@ -4,7 +4,7 @@ dataset="unitree_g1_pack_camera"
{
time CUDA_VISIBLE_DEVICES=1 python3 scripts/evaluation/world_model_interaction.py \
--seed 123 \
--ckpt_path ckpts/unifolm_wma_dual.ckpt \
--ckpt_path ckpts/unifolm_wma_dual_mixbf16.ckpt \
--config configs/inference/world_model_interaction.yaml \
--savedir "${res_dir}/output" \
--bs 1 --height 320 --width 512 \
@@ -22,7 +22,6 @@ dataset="unitree_g1_pack_camera"
--guidance_rescale 0.7 \
--perframe_ae \
--diffusion_dtype bf16 \
--projector_mode autocast \
--encoder_mode bf16_full \
--vae_dtype bf16
--projector_mode bf16_full \
--encoder_mode bf16_full
} 2>&1 | tee "${res_dir}/output.log"

View File

@@ -119,3 +119,43 @@ embedder
- _get_attn_mask_aa 现在直接在目标 device 构造并缓存,避免每步 CPU→GPU 拷贝
- 位置:src/unifolm_wma/modules/attention.py
--export_casted_ckpt ckpts/unifolm_wma_dual_mixbf16.ckpt \
--export_only
checkpoint or previous run; diffusion backbone and condition embedder are bfloat16 while
projector is fp32, so casts are skipped as expected. The "inference precision" line reflects the
top-level param dtype (possibly fp32 embeddings), explaining the apparent inconsistency with
submodule param dtypes.
Explaining skip cast due to preserved dtypes
The skip cast happens because checkpoint loading preserves stored dtypes, so weights already
match the target (bf16 or fp32), making casting unnecessary. The log shows first param dtype for
checks, which can differ from submodules but is expected. For forcing recast, exporting a new
casted checkpoint is needed since repeated casts to the same dtype no-op.
• “skip cast” 的意思是该模块的权重已经是目标 dtype所以 maybe_cast_module 直接跳过转换。你这段日志
里:
- diffusion backbone / cond_stage_model / embedder 已经是 torch.bfloat16
- projector 仍是 torch.float32因为你设的是 projector_mode=autocast权重保持 fp32
- VAE 你设置为 fp32也已经是 fp32
所以全部都匹配目标精度了,自然会 “skip cast”。这通常说明你加载的 ckpt 本身就是混合精度权重(比如
之前用 --export_casted_ckpt 导出的)。
另外这行:
>>> inference precision: model=torch.float32
只是取了模型第一个参数的 dtypelog_inference_precision 的实现如此),模型里可能还有一些主模块外
的 fp32 参数(比如位置编码等),所以看到整体是 fp32 很正常,不代表 diffusion 不在 bf16。后面的子
模块打印才更准确。
如果你想确认是否从“原始 fp32 ckpt”重新 cast只要把 --ckpt_path 指回原始 ckpt 跑一次即可。