把混和精度模型权重导出至本地文件,减少dtype开销
--export_casted_ckpt ckpts/unifolm_wma_dual_mixbf16.ckpt \
--export_only
This commit is contained in:
@@ -441,57 +441,143 @@ def get_filelist(data_dir: str, postfixes: list[str]) -> list[str]:
|
||||
return file_list
|
||||
|
||||
|
||||
def load_model_checkpoint(model: nn.Module, ckpt: str) -> nn.Module:
|
||||
"""Load model weights from checkpoint file.
|
||||
|
||||
Args:
|
||||
model (nn.Module): Model instance.
|
||||
ckpt (str): Path to the checkpoint file.
|
||||
|
||||
Returns:
|
||||
nn.Module: Model with loaded weights.
|
||||
"""
|
||||
state_dict = torch.load(ckpt, map_location="cpu")
|
||||
if "state_dict" in list(state_dict.keys()):
|
||||
state_dict = state_dict["state_dict"]
|
||||
try:
|
||||
model.load_state_dict(state_dict, strict=True)
|
||||
except:
|
||||
new_pl_sd = OrderedDict()
|
||||
for k, v in state_dict.items():
|
||||
new_pl_sd[k] = v
|
||||
|
||||
for k in list(new_pl_sd.keys()):
|
||||
if "framestride_embed" in k:
|
||||
new_key = k.replace("framestride_embed", "fps_embedding")
|
||||
new_pl_sd[new_key] = new_pl_sd[k]
|
||||
del new_pl_sd[k]
|
||||
model.load_state_dict(new_pl_sd, strict=True)
|
||||
else:
|
||||
new_pl_sd = OrderedDict()
|
||||
for key in state_dict['module'].keys():
|
||||
new_pl_sd[key[16:]] = state_dict['module'][key]
|
||||
model.load_state_dict(new_pl_sd)
|
||||
def _load_state_dict(model: nn.Module,
|
||||
state_dict: Mapping[str, torch.Tensor],
|
||||
strict: bool = True,
|
||||
assign: bool = False) -> None:
|
||||
if assign:
|
||||
try:
|
||||
model.load_state_dict(state_dict, strict=strict, assign=True)
|
||||
return
|
||||
except TypeError:
|
||||
warnings.warn(
|
||||
"load_state_dict(assign=True) not supported; "
|
||||
"falling back to copy load.")
|
||||
model.load_state_dict(state_dict, strict=strict)
|
||||
|
||||
|
||||
def load_model_checkpoint(model: nn.Module,
|
||||
ckpt: str,
|
||||
assign: bool | None = None) -> nn.Module:
|
||||
"""Load model weights from checkpoint file.
|
||||
|
||||
Args:
|
||||
model (nn.Module): Model instance.
|
||||
ckpt (str): Path to the checkpoint file.
|
||||
assign (bool | None): Whether to preserve checkpoint tensor dtypes
|
||||
via load_state_dict(assign=True). If None, auto-enable when a
|
||||
casted checkpoint metadata is detected.
|
||||
|
||||
Returns:
|
||||
nn.Module: Model with loaded weights.
|
||||
"""
|
||||
ckpt_data = torch.load(ckpt, map_location="cpu")
|
||||
use_assign = False
|
||||
if assign is not None:
|
||||
use_assign = assign
|
||||
elif isinstance(ckpt_data, Mapping) and "precision_metadata" in ckpt_data:
|
||||
use_assign = True
|
||||
if isinstance(ckpt_data, Mapping) and "state_dict" in ckpt_data:
|
||||
state_dict = ckpt_data["state_dict"]
|
||||
try:
|
||||
_load_state_dict(model, state_dict, strict=True, assign=use_assign)
|
||||
except Exception:
|
||||
new_pl_sd = OrderedDict()
|
||||
for k, v in state_dict.items():
|
||||
new_pl_sd[k] = v
|
||||
|
||||
for k in list(new_pl_sd.keys()):
|
||||
if "framestride_embed" in k:
|
||||
new_key = k.replace("framestride_embed", "fps_embedding")
|
||||
new_pl_sd[new_key] = new_pl_sd[k]
|
||||
del new_pl_sd[k]
|
||||
_load_state_dict(model,
|
||||
new_pl_sd,
|
||||
strict=True,
|
||||
assign=use_assign)
|
||||
elif isinstance(ckpt_data, Mapping) and "module" in ckpt_data:
|
||||
new_pl_sd = OrderedDict()
|
||||
for key in ckpt_data['module'].keys():
|
||||
new_pl_sd[key[16:]] = ckpt_data['module'][key]
|
||||
_load_state_dict(model, new_pl_sd, strict=True, assign=use_assign)
|
||||
else:
|
||||
_load_state_dict(model,
|
||||
ckpt_data,
|
||||
strict=True,
|
||||
assign=use_assign)
|
||||
print('>>> model checkpoint loaded.')
|
||||
return model
|
||||
|
||||
|
||||
def maybe_cast_module(module: nn.Module | None,
|
||||
dtype: torch.dtype,
|
||||
label: str,
|
||||
profiler: Optional[ProfilerManager] = None,
|
||||
profile_name: Optional[str] = None) -> None:
|
||||
if module is None:
|
||||
return
|
||||
try:
|
||||
param = next(module.parameters())
|
||||
except StopIteration:
|
||||
print(f">>> {label} has no parameters; skip cast")
|
||||
return
|
||||
if param.dtype == dtype:
|
||||
print(f">>> {label} already {dtype}; skip cast")
|
||||
return
|
||||
ctx = nullcontext()
|
||||
if profiler is not None and profile_name:
|
||||
ctx = profiler.profile_section(profile_name)
|
||||
with ctx:
|
||||
module.to(dtype=dtype)
|
||||
print(f">>> {label} cast to {dtype}")
|
||||
|
||||
|
||||
def save_casted_checkpoint(model: nn.Module,
|
||||
save_path: str,
|
||||
metadata: Optional[Dict[str, Any]] = None) -> None:
|
||||
if not save_path:
|
||||
return
|
||||
save_dir = os.path.dirname(save_path)
|
||||
if save_dir:
|
||||
os.makedirs(save_dir, exist_ok=True)
|
||||
cpu_state = {}
|
||||
for key, value in model.state_dict().items():
|
||||
if isinstance(value, torch.Tensor):
|
||||
cpu_state[key] = value.detach().to("cpu")
|
||||
else:
|
||||
cpu_state[key] = value
|
||||
payload: Dict[str, Any] = {"state_dict": cpu_state}
|
||||
if metadata:
|
||||
payload["precision_metadata"] = metadata
|
||||
torch.save(payload, save_path)
|
||||
print(f">>> Saved casted checkpoint to {save_path}")
|
||||
|
||||
|
||||
def _module_param_dtype(module: nn.Module | None) -> str:
|
||||
if module is None:
|
||||
return "None"
|
||||
dtype_counts: Dict[str, int] = {}
|
||||
for param in module.parameters():
|
||||
return str(param.dtype)
|
||||
return "no_params"
|
||||
dtype_key = str(param.dtype)
|
||||
dtype_counts[dtype_key] = dtype_counts.get(dtype_key, 0) + param.numel()
|
||||
if not dtype_counts:
|
||||
return "no_params"
|
||||
if len(dtype_counts) == 1:
|
||||
return next(iter(dtype_counts))
|
||||
total = sum(dtype_counts.values())
|
||||
parts = []
|
||||
for dtype_key in sorted(dtype_counts.keys()):
|
||||
ratio = dtype_counts[dtype_key] / total
|
||||
parts.append(f"{dtype_key}={ratio:.1%}")
|
||||
return f"mixed({', '.join(parts)})"
|
||||
|
||||
|
||||
def log_inference_precision(model: nn.Module) -> None:
|
||||
try:
|
||||
param = next(model.parameters())
|
||||
device = "unknown"
|
||||
for param in model.parameters():
|
||||
device = str(param.device)
|
||||
model_dtype = str(param.dtype)
|
||||
except StopIteration:
|
||||
device = "unknown"
|
||||
model_dtype = "no_params"
|
||||
break
|
||||
model_dtype = _module_param_dtype(model)
|
||||
|
||||
print(f">>> inference precision: model={model_dtype}, device={device}")
|
||||
for attr in [
|
||||
@@ -966,16 +1052,25 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
|
||||
|
||||
diffusion_autocast_dtype = None
|
||||
if args.diffusion_dtype == "bf16":
|
||||
with profiler.profile_section("model_loading/diffusion_bf16"):
|
||||
model.model.to(dtype=torch.bfloat16)
|
||||
maybe_cast_module(
|
||||
model.model,
|
||||
torch.bfloat16,
|
||||
"diffusion backbone",
|
||||
profiler=profiler,
|
||||
profile_name="model_loading/diffusion_bf16",
|
||||
)
|
||||
diffusion_autocast_dtype = torch.bfloat16
|
||||
print(">>> diffusion backbone set to bfloat16")
|
||||
|
||||
if hasattr(model, "first_stage_model") and model.first_stage_model is not None:
|
||||
if args.vae_dtype == "bf16":
|
||||
model.first_stage_model.to(dtype=torch.bfloat16)
|
||||
else:
|
||||
model.first_stage_model.to(dtype=torch.float32)
|
||||
vae_weight_dtype = torch.bfloat16 if args.vae_dtype == "bf16" else torch.float32
|
||||
maybe_cast_module(
|
||||
model.first_stage_model,
|
||||
vae_weight_dtype,
|
||||
"VAE",
|
||||
profiler=profiler,
|
||||
profile_name="model_loading/vae_cast",
|
||||
)
|
||||
model.vae_bf16 = args.vae_dtype == "bf16"
|
||||
print(f">>> VAE dtype set to {args.vae_dtype}")
|
||||
|
||||
@@ -983,9 +1078,21 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
|
||||
encoder_bf16 = encoder_mode in ("autocast", "bf16_full")
|
||||
encoder_weight_dtype = torch.bfloat16 if encoder_mode == "bf16_full" else torch.float32
|
||||
if hasattr(model, "cond_stage_model") and model.cond_stage_model is not None:
|
||||
model.cond_stage_model.to(dtype=encoder_weight_dtype)
|
||||
maybe_cast_module(
|
||||
model.cond_stage_model,
|
||||
encoder_weight_dtype,
|
||||
"cond_stage_model",
|
||||
profiler=profiler,
|
||||
profile_name="model_loading/encoder_cond_cast",
|
||||
)
|
||||
if hasattr(model, "embedder") and model.embedder is not None:
|
||||
model.embedder.to(dtype=encoder_weight_dtype)
|
||||
maybe_cast_module(
|
||||
model.embedder,
|
||||
encoder_weight_dtype,
|
||||
"embedder",
|
||||
profiler=profiler,
|
||||
profile_name="model_loading/encoder_embedder_cast",
|
||||
)
|
||||
model.encoder_bf16 = encoder_bf16
|
||||
model.encoder_mode = encoder_mode
|
||||
print(
|
||||
@@ -996,11 +1103,29 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
|
||||
projector_bf16 = projector_mode in ("autocast", "bf16_full")
|
||||
projector_weight_dtype = torch.bfloat16 if projector_mode == "bf16_full" else torch.float32
|
||||
if hasattr(model, "image_proj_model") and model.image_proj_model is not None:
|
||||
model.image_proj_model.to(dtype=projector_weight_dtype)
|
||||
maybe_cast_module(
|
||||
model.image_proj_model,
|
||||
projector_weight_dtype,
|
||||
"image_proj_model",
|
||||
profiler=profiler,
|
||||
profile_name="model_loading/projector_image_cast",
|
||||
)
|
||||
if hasattr(model, "state_projector") and model.state_projector is not None:
|
||||
model.state_projector.to(dtype=projector_weight_dtype)
|
||||
maybe_cast_module(
|
||||
model.state_projector,
|
||||
projector_weight_dtype,
|
||||
"state_projector",
|
||||
profiler=profiler,
|
||||
profile_name="model_loading/projector_state_cast",
|
||||
)
|
||||
if hasattr(model, "action_projector") and model.action_projector is not None:
|
||||
model.action_projector.to(dtype=projector_weight_dtype)
|
||||
maybe_cast_module(
|
||||
model.action_projector,
|
||||
projector_weight_dtype,
|
||||
"action_projector",
|
||||
profiler=profiler,
|
||||
profile_name="model_loading/projector_action_cast",
|
||||
)
|
||||
if hasattr(model, "projector_bf16"):
|
||||
model.projector_bf16 = projector_bf16
|
||||
model.projector_mode = projector_mode
|
||||
@@ -1010,6 +1135,19 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
|
||||
|
||||
log_inference_precision(model)
|
||||
|
||||
if args.export_casted_ckpt:
|
||||
metadata = {
|
||||
"diffusion_dtype": args.diffusion_dtype,
|
||||
"vae_dtype": args.vae_dtype,
|
||||
"encoder_mode": args.encoder_mode,
|
||||
"projector_mode": args.projector_mode,
|
||||
"perframe_ae": args.perframe_ae,
|
||||
}
|
||||
save_casted_checkpoint(model, args.export_casted_ckpt, metadata)
|
||||
if args.export_only:
|
||||
print(">>> export_only set; skipping inference.")
|
||||
return
|
||||
|
||||
profiler.record_memory("after_model_load")
|
||||
|
||||
# Run over data
|
||||
@@ -1373,6 +1511,19 @@ def get_parser():
|
||||
default="fp32",
|
||||
help="Dtype for VAE/first_stage_model weights and forward autocast."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--export_casted_ckpt",
|
||||
type=str,
|
||||
default=None,
|
||||
help=
|
||||
"Save a checkpoint after applying precision settings (mixed dtypes preserved)."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--export_only",
|
||||
action='store_true',
|
||||
default=False,
|
||||
help="Exit after exporting the casted checkpoint."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--step_log_every",
|
||||
type=int,
|
||||
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -4,7 +4,7 @@ dataset="unitree_g1_pack_camera"
|
||||
{
|
||||
time CUDA_VISIBLE_DEVICES=1 python3 scripts/evaluation/world_model_interaction.py \
|
||||
--seed 123 \
|
||||
--ckpt_path ckpts/unifolm_wma_dual.ckpt \
|
||||
--ckpt_path ckpts/unifolm_wma_dual_mixbf16.ckpt \
|
||||
--config configs/inference/world_model_interaction.yaml \
|
||||
--savedir "${res_dir}/output" \
|
||||
--bs 1 --height 320 --width 512 \
|
||||
@@ -22,7 +22,6 @@ dataset="unitree_g1_pack_camera"
|
||||
--guidance_rescale 0.7 \
|
||||
--perframe_ae \
|
||||
--diffusion_dtype bf16 \
|
||||
--projector_mode autocast \
|
||||
--encoder_mode bf16_full \
|
||||
--vae_dtype bf16
|
||||
--projector_mode bf16_full \
|
||||
--encoder_mode bf16_full
|
||||
} 2>&1 | tee "${res_dir}/output.log"
|
||||
|
||||
Binary file not shown.
Binary file not shown.
42
useful.sh
42
useful.sh
@@ -118,4 +118,44 @@ embedder:
|
||||
2. attention mask 缓存到 GPU(已完成)
|
||||
|
||||
- _get_attn_mask_aa 现在直接在目标 device 构造并缓存,避免每步 CPU→GPU 拷贝
|
||||
- 位置:src/unifolm_wma/modules/attention.py
|
||||
- 位置:src/unifolm_wma/modules/attention.py
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
--export_casted_ckpt ckpts/unifolm_wma_dual_mixbf16.ckpt \
|
||||
--export_only
|
||||
|
||||
|
||||
checkpoint or previous run; diffusion backbone and condition embedder are bfloat16 while
|
||||
projector is fp32, so casts are skipped as expected. The "inference precision" line reflects the
|
||||
top-level param dtype (possibly fp32 embeddings), explaining the apparent inconsistency with
|
||||
submodule param dtypes.
|
||||
|
||||
Explaining skip cast due to preserved dtypes
|
||||
|
||||
The skip cast happens because checkpoint loading preserves stored dtypes, so weights already
|
||||
match the target (bf16 or fp32), making casting unnecessary. The log shows first param dtype for
|
||||
checks, which can differ from submodules but is expected. For forcing recast, exporting a new
|
||||
casted checkpoint is needed since repeated casts to the same dtype no-op.
|
||||
|
||||
• “skip cast” 的意思是该模块的权重已经是目标 dtype,所以 maybe_cast_module 直接跳过转换。你这段日志
|
||||
里:
|
||||
|
||||
- diffusion backbone / cond_stage_model / embedder 已经是 torch.bfloat16
|
||||
- projector 仍是 torch.float32(因为你设的是 projector_mode=autocast,权重保持 fp32)
|
||||
- VAE 你设置为 fp32,也已经是 fp32
|
||||
|
||||
所以全部都匹配目标精度了,自然会 “skip cast”。这通常说明你加载的 ckpt 本身就是混合精度权重(比如
|
||||
之前用 --export_casted_ckpt 导出的)。
|
||||
|
||||
另外这行:
|
||||
|
||||
>>> inference precision: model=torch.float32
|
||||
|
||||
只是取了模型第一个参数的 dtype(log_inference_precision 的实现如此),模型里可能还有一些主模块外
|
||||
的 fp32 参数(比如位置编码等),所以看到整体是 fp32 很正常,不代表 diffusion 不在 bf16。后面的子
|
||||
模块打印才更准确。
|
||||
|
||||
如果你想确认是否从“原始 fp32 ckpt”重新 cast,只要把 --ckpt_path 指回原始 ckpt 跑一次即可。
|
||||
Reference in New Issue
Block a user