对扩散主干做 BF16

量化对象:model.model(扩散 UNet/WMAModel 主体)
This commit is contained in:
2026-01-18 17:14:16 +08:00
parent 7b499284bf
commit 2b634cde90
5 changed files with 141 additions and 75 deletions

View File

@@ -714,22 +714,23 @@ def preprocess_observation(
return return_observations
def image_guided_synthesis_sim_mode(
model: torch.nn.Module,
prompts: list[str],
observation: dict,
noise_shape: tuple[int, int, int, int, int],
action_cond_step: int = 16,
n_samples: int = 1,
ddim_steps: int = 50,
ddim_eta: float = 1.0,
unconditional_guidance_scale: float = 1.0,
fs: int | None = None,
text_input: bool = True,
timestep_spacing: str = 'uniform',
guidance_rescale: float = 0.0,
sim_mode: bool = True,
**kwargs) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
def image_guided_synthesis_sim_mode(
model: torch.nn.Module,
prompts: list[str],
observation: dict,
noise_shape: tuple[int, int, int, int, int],
action_cond_step: int = 16,
n_samples: int = 1,
ddim_steps: int = 50,
ddim_eta: float = 1.0,
unconditional_guidance_scale: float = 1.0,
fs: int | None = None,
text_input: bool = True,
timestep_spacing: str = 'uniform',
guidance_rescale: float = 0.0,
sim_mode: bool = True,
diffusion_autocast_dtype: Optional[torch.dtype] = None,
**kwargs) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Performs image-guided video generation in a simulation-style mode with optional multimodal guidance (image, state, action, text).
@@ -750,9 +751,10 @@ def image_guided_synthesis_sim_mode(
fs (int | None): Frame index to condition on, broadcasted across the batch if specified. Default is None.
text_input (bool): Whether to use text prompt as conditioning. If False, uses empty strings. Default is True.
timestep_spacing (str): Timestep sampling method in DDIM sampler. Typically "uniform" or "linspace".
guidance_rescale (float): Guidance rescaling factor to mitigate overexposure from classifier-free guidance.
sim_mode (bool): Whether to perform world-model interaction or decision-making using the world-model.
**kwargs: Additional arguments passed to the DDIM sampler.
guidance_rescale (float): Guidance rescaling factor to mitigate overexposure from classifier-free guidance.
sim_mode (bool): Whether to perform world-model interaction or decision-making using the world-model.
diffusion_autocast_dtype (Optional[torch.dtype]): Autocast dtype for diffusion sampling (e.g., torch.bfloat16).
**kwargs: Additional arguments passed to the DDIM sampler.
Returns:
batch_variants (torch.Tensor): Predicted pixel-space video frames [B, C, T, H, W].
@@ -810,31 +812,37 @@ def image_guided_synthesis_sim_mode(
uc = None
kwargs.update({"unconditional_conditioning_img_nonetext": None})
cond_mask = None
cond_z0 = None
if ddim_sampler is not None:
with profiler.profile_section("synthesis/ddim_sampling"):
samples, actions, states, intermedia = ddim_sampler.sample(
S=ddim_steps,
conditioning=cond,
batch_size=batch_size,
shape=noise_shape[1:],
verbose=False,
unconditional_guidance_scale=unconditional_guidance_scale,
unconditional_conditioning=uc,
eta=ddim_eta,
cfg_img=None,
mask=cond_mask,
x0=cond_z0,
fs=fs,
timestep_spacing=timestep_spacing,
guidance_rescale=guidance_rescale,
**kwargs)
# Reconstruct from latent to pixel space
with profiler.profile_section("synthesis/decode_first_stage"):
batch_images = model.decode_first_stage(samples)
batch_variants = batch_images
cond_z0 = None
if ddim_sampler is not None:
with profiler.profile_section("synthesis/ddim_sampling"):
autocast_ctx = nullcontext()
if diffusion_autocast_dtype is not None and model.device.type == "cuda":
autocast_ctx = torch.autocast("cuda", dtype=diffusion_autocast_dtype)
with autocast_ctx:
samples, actions, states, intermedia = ddim_sampler.sample(
S=ddim_steps,
conditioning=cond,
batch_size=batch_size,
shape=noise_shape[1:],
verbose=False,
unconditional_guidance_scale=unconditional_guidance_scale,
unconditional_conditioning=uc,
eta=ddim_eta,
cfg_img=None,
mask=cond_mask,
x0=cond_z0,
fs=fs,
timestep_spacing=timestep_spacing,
guidance_rescale=guidance_rescale,
**kwargs)
# Reconstruct from latent to pixel space
with profiler.profile_section("synthesis/decode_first_stage"):
if samples.dtype != torch.float32:
samples = samples.float()
batch_images = model.decode_first_stage(samples)
batch_variants = batch_images
return batch_variants, actions, states
@@ -889,6 +897,13 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
model = model.cuda(gpu_no)
device = get_device_from_parameters(model)
diffusion_autocast_dtype = None
if args.diffusion_dtype == "bf16":
with profiler.profile_section("model_loading/diffusion_bf16"):
model.model.to(dtype=torch.bfloat16)
diffusion_autocast_dtype = torch.bfloat16
print(">>> diffusion backbone set to bfloat16")
log_inference_precision(model)
profiler.record_memory("after_model_load")
@@ -1014,20 +1029,21 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
# Use world-model in policy to generate action
print(f'>>> Step {itr}: generating actions ...')
with profiler.profile_section("action_generation"):
pred_videos_0, pred_actions, _ = image_guided_synthesis_sim_mode(
model,
sample['instruction'],
observation,
noise_shape,
pred_videos_0, pred_actions, _ = image_guided_synthesis_sim_mode(
model,
sample['instruction'],
observation,
noise_shape,
action_cond_step=args.exe_steps,
ddim_steps=args.ddim_steps,
ddim_eta=args.ddim_eta,
unconditional_guidance_scale=args.
unconditional_guidance_scale,
fs=model_input_fs,
timestep_spacing=args.timestep_spacing,
guidance_rescale=args.guidance_rescale,
sim_mode=False)
unconditional_guidance_scale=args.
unconditional_guidance_scale,
fs=model_input_fs,
timestep_spacing=args.timestep_spacing,
guidance_rescale=args.guidance_rescale,
sim_mode=False,
diffusion_autocast_dtype=diffusion_autocast_dtype)
# Update future actions in the observation queues
with profiler.profile_section("update_action_queues"):
@@ -1058,20 +1074,21 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
# Interaction with the world-model
print(f'>>> Step {itr}: interacting with world model ...')
with profiler.profile_section("world_model_interaction"):
pred_videos_1, _, pred_states = image_guided_synthesis_sim_mode(
model,
"",
observation,
noise_shape,
pred_videos_1, _, pred_states = image_guided_synthesis_sim_mode(
model,
"",
observation,
noise_shape,
action_cond_step=args.exe_steps,
ddim_steps=args.ddim_steps,
ddim_eta=args.ddim_eta,
unconditional_guidance_scale=args.
unconditional_guidance_scale,
fs=model_input_fs,
text_input=False,
timestep_spacing=args.timestep_spacing,
guidance_rescale=args.guidance_rescale)
unconditional_guidance_scale,
fs=model_input_fs,
text_input=False,
timestep_spacing=args.timestep_spacing,
guidance_rescale=args.guidance_rescale,
diffusion_autocast_dtype=diffusion_autocast_dtype)
with profiler.profile_section("update_state_queues"):
for step_idx in range(args.exe_steps):
@@ -1216,13 +1233,20 @@ def get_parser():
help=
"Rescale factor for guidance as discussed in 'Common Diffusion Noise Schedules and Sample Steps are Flawed' (https://huggingface.co/papers/2305.08891)."
)
parser.add_argument(
"--perframe_ae",
action='store_true',
default=False,
help=
"Use per-frame autoencoder decoding to reduce GPU memory usage. Recommended for models with resolutions like 576x1024."
)
parser.add_argument(
"--perframe_ae",
action='store_true',
default=False,
help=
"Use per-frame autoencoder decoding to reduce GPU memory usage. Recommended for models with resolutions like 576x1024."
)
parser.add_argument(
"--diffusion_dtype",
type=str,
choices=["fp32", "bf16"],
default="fp32",
help="Dtype for diffusion backbone weights and sampling autocast."
)
parser.add_argument(
"--n_action_steps",
type=int,

View File

@@ -2,7 +2,7 @@ res_dir="unitree_g1_pack_camera/case1"
dataset="unitree_g1_pack_camera"
{
time CUDA_VISIBLE_DEVICES=0 python3 scripts/evaluation/world_model_interaction.py \
time CUDA_VISIBLE_DEVICES=1 python3 scripts/evaluation/world_model_interaction.py \
--seed 123 \
--ckpt_path ckpts/unifolm_wma_dual.ckpt \
--config configs/inference/world_model_interaction.yaml \
@@ -20,5 +20,6 @@ dataset="unitree_g1_pack_camera"
--n_iter 11 \
--timestep_spacing 'uniform_trailing' \
--guidance_rescale 0.7 \
--perframe_ae
--perframe_ae \
--diffusion_dtype bf16
} 2>&1 | tee "${res_dir}/output.log"

View File

@@ -29,4 +29,45 @@ python3 psnr_score_for_challenge.py --gt_video unitree_g1_pack_camera/case1/unit
1. torch.compile + AMP + TF32 + cudnn.benchmark
2. 排查 .to()/copy/clone 的重复位置并移出循环
3. 若需要更大幅度,再换采样器/降步数
3. 若需要更大幅度,再换采样器/降步数
A100 上我推荐 BF16 优先(稳定性更好、PSNR 更稳),FP16 作为速度优先方案。
下面是“分模块”的 消融方案(从稳到激进):
0)基线
- 全 FP32(你现在就是这个)
1)只对扩散主干做 BF16(最推荐)
- 量化对象:model.model(扩散 UNet/WMAModel 主体)
- 保持 FP32:first_stage_model(VAE 编/解码)、cond_stage_model(文本)、embedder(图像)、image_proj_model
- 预期:PSNR 基本不掉 or 极小波动
2)+ 轻量投影/MLP 做 BF16
- 增加:image_proj_model、state_projector、action_projector
- 预期:几乎不影响 PSNR
3)+ 文本/图像编码做 BF16
- 增加:cond_stage_model、embedder
- 预期:可能有轻微波动,通常仍可接受
4)VAE 也做 BF16(最容易伤 PSNR)
- 增加:first_stage_model
- 预期:画质/PSNR 最敏感,建议最后做消融
———
具体建议(A100):
- 优先 BF16(稳定性好于 FP16)
- 只做半精度,不做 INT 量化:保持 PSNR
- VAE 尽量 FP32(最影响画质的模块)