From e588182642fb6897725d660f6c5ec14a7b2d497b Mon Sep 17 00:00:00 2001 From: olivame Date: Sun, 8 Feb 2026 12:35:59 +0000 Subject: [PATCH] =?UTF-8?q?=E4=BF=AE=E5=A4=8D=E6=B7=B7=E5=90=88=E7=B2=BE?= =?UTF-8?q?=E5=BA=A6vae=E7=9B=B8=E5=85=B3=E7=9A=84=E9=85=8D=E7=BD=AE?= =?UTF-8?q?=E9=94=99=E8=AF=AF=EF=BC=8C=E7=A1=AE=E4=BF=9D=E5=9C=A8=E6=8E=A8?= =?UTF-8?q?=E7=90=86=E9=98=B6=E6=AE=B5=E6=AD=A3=E7=A1=AE=E4=BD=BF=E7=94=A8?= =?UTF-8?q?=E4=BA=86=E6=B7=B7=E5=90=88=E7=B2=BE=E5=BA=A6=E6=A8=A1=E5=9E=8B?= =?UTF-8?q?=EF=BC=8C=E5=B9=B6=E4=B8=94=E5=AF=BC=E5=87=BA=E4=BA=86=E6=AD=A3?= =?UTF-8?q?=E7=A1=AE=E7=B2=BE=E5=BA=A6=E7=9A=84=E6=A3=80=E6=9F=A5=E7=82=B9?= =?UTF-8?q?=E6=96=87=E4=BB=B6=E3=80=82?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- scripts/evaluation/world_model_interaction.py | 156 +++++++++++++++++- src/unifolm_wma/models/ddpms.py | 5 +- .../case1/output.log | 35 ++-- .../case1/psnr_result1.json | 2 +- .../case1/run_world_model_interaction.sh | 5 +- 5 files changed, 178 insertions(+), 25 deletions(-) diff --git a/scripts/evaluation/world_model_interaction.py b/scripts/evaluation/world_model_interaction.py index ccc9747..8f18401 100644 --- a/scripts/evaluation/world_model_interaction.py +++ b/scripts/evaluation/world_model_interaction.py @@ -1,4 +1,5 @@ import argparse, os, glob +from contextlib import nullcontext import pandas as pd import random import torch @@ -38,6 +39,68 @@ def get_device_from_parameters(module: nn.Module) -> torch.device: return next(iter(module.parameters())).device +def apply_precision_settings(model: nn.Module, args: argparse.Namespace) -> nn.Module: + """Apply precision settings to model components based on command-line arguments. + + Args: + model (nn.Module): The model to apply precision settings to. + args (argparse.Namespace): Parsed command-line arguments containing precision settings. + + Returns: + nn.Module: Model with precision settings applied. + """ + print(f">>> Applying precision settings:") + print(f" - Diffusion dtype: {args.diffusion_dtype}") + print(f" - Projector mode: {args.projector_mode}") + print(f" - Encoder mode: {args.encoder_mode}") + print(f" - VAE dtype: {args.vae_dtype}") + + # 1. Set Diffusion backbone precision + if args.diffusion_dtype == "bf16": + # Convert diffusion model weights to bf16 + model.model.to(torch.bfloat16) + model.diffusion_autocast_dtype = torch.bfloat16 + print(" ✓ Diffusion model weights converted to bfloat16") + else: + model.diffusion_autocast_dtype = None + print(" ✓ Diffusion model using fp32") + + # 2. Set Projector precision + if args.projector_mode == "bf16_full": + model.state_projector.to(torch.bfloat16) + model.action_projector.to(torch.bfloat16) + model.projector_autocast_dtype = None + print(" ✓ Projectors converted to bfloat16") + elif args.projector_mode == "autocast": + model.projector_autocast_dtype = torch.bfloat16 + print(" ✓ Projectors will use autocast (weights fp32, compute bf16)") + else: + model.projector_autocast_dtype = None + # fp32 mode: do nothing, keep original precision + + # 3. Set Encoder precision + if args.encoder_mode == "bf16_full": + model.embedder.to(torch.bfloat16) + model.image_proj_model.to(torch.bfloat16) + model.encoder_autocast_dtype = None + print(" ✓ Encoders converted to bfloat16") + elif args.encoder_mode == "autocast": + model.encoder_autocast_dtype = torch.bfloat16 + print(" ✓ Encoders will use autocast (weights fp32, compute bf16)") + else: + model.encoder_autocast_dtype = None + # fp32 mode: do nothing, keep original precision + + # 4. Set VAE precision + if args.vae_dtype == "bf16": + model.first_stage_model.to(torch.bfloat16) + print(" ✓ VAE converted to bfloat16") + else: + print(" ✓ VAE kept in fp32 for best quality") + + return model + + def write_video(video_path: str, stacked_frames: list, fps: int) -> None: """Save a list of frames to a video file. @@ -262,6 +325,11 @@ def get_latent_z(model, videos: Tensor) -> Tensor: """ b, c, t, h, w = videos.shape x = rearrange(videos, 'b c t h w -> (b t) c h w') + + # Auto-detect VAE dtype and convert input + vae_dtype = next(model.first_stage_model.parameters()).dtype + x = x.to(dtype=vae_dtype) + z = model.encode_first_stage(x) z = rearrange(z, '(b t) c h w -> b c t h w', b=b, t=t) return z @@ -363,10 +431,22 @@ def image_guided_synthesis_sim_mode( fs = torch.tensor([fs] * batch_size, dtype=torch.long, device=model.device) + # Auto-detect model dtype and convert inputs accordingly + model_dtype = next(model.embedder.parameters()).dtype + img = observation['observation.images.top'].permute(0, 2, 1, 3, 4) - cond_img = rearrange(img, 'b o c h w -> (b o) c h w')[-1:] - cond_img_emb = model.embedder(cond_img) - cond_img_emb = model.image_proj_model(cond_img_emb) + cond_img = rearrange(img, 'b o c h w -> (b o) c h w')[-1:].to(dtype=model_dtype) + + # Encoder autocast: weights stay fp32, compute in bf16 + enc_ac_dtype = getattr(model, 'encoder_autocast_dtype', None) + if enc_ac_dtype is not None and model.device.type == 'cuda': + enc_ctx = torch.autocast('cuda', dtype=enc_ac_dtype) + else: + enc_ctx = nullcontext() + + with enc_ctx: + cond_img_emb = model.embedder(cond_img) + cond_img_emb = model.image_proj_model(cond_img_emb) if model.model.conditioning_key == 'hybrid': z = get_latent_z(model, img.permute(0, 2, 1, 3, 4)) @@ -380,11 +460,22 @@ def image_guided_synthesis_sim_mode( prompts = [""] * batch_size cond_ins_emb = model.get_learned_conditioning(prompts) - cond_state_emb = model.state_projector(observation['observation.state']) - cond_state_emb = cond_state_emb + model.agent_state_pos_emb + # Auto-detect projector dtype and convert inputs + projector_dtype = next(model.state_projector.parameters()).dtype - cond_action_emb = model.action_projector(observation['action']) - cond_action_emb = cond_action_emb + model.agent_action_pos_emb + # Projector autocast: weights stay fp32, compute in bf16 + proj_ac_dtype = getattr(model, 'projector_autocast_dtype', None) + if proj_ac_dtype is not None and model.device.type == 'cuda': + proj_ctx = torch.autocast('cuda', dtype=proj_ac_dtype) + else: + proj_ctx = nullcontext() + + with proj_ctx: + cond_state_emb = model.state_projector(observation['observation.state'].to(dtype=projector_dtype)) + cond_state_emb = cond_state_emb + model.agent_state_pos_emb + + cond_action_emb = model.action_projector(observation['action'].to(dtype=projector_dtype)) + cond_action_emb = cond_action_emb + model.agent_action_pos_emb if not sim_mode: cond_action_emb = torch.zeros_like(cond_action_emb) @@ -406,8 +497,17 @@ def image_guided_synthesis_sim_mode( kwargs.update({"unconditional_conditioning_img_nonetext": None}) cond_mask = None cond_z0 = None + + # Setup autocast context for diffusion sampling + autocast_dtype = getattr(model, 'diffusion_autocast_dtype', None) + if autocast_dtype is not None and model.device.type == 'cuda': + autocast_ctx = torch.autocast('cuda', dtype=autocast_dtype) + else: + autocast_ctx = nullcontext() + if ddim_sampler is not None: - samples, actions, states, intermedia = ddim_sampler.sample( + with autocast_ctx: + samples, actions, states, intermedia = ddim_sampler.sample( S=ddim_steps, conditioning=cond, batch_size=batch_size, @@ -464,6 +564,17 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None: model.eval() print(f'>>> Load pre-trained model ...') + # Apply precision settings before moving to GPU + model = apply_precision_settings(model, args) + + # Export precision-converted checkpoint if requested + if args.export_precision_ckpt: + export_path = args.export_precision_ckpt + os.makedirs(os.path.dirname(export_path) or '.', exist_ok=True) + torch.save({"state_dict": model.state_dict()}, export_path) + print(f">>> Precision-converted checkpoint saved to: {export_path}") + return + # Build unnomalizer logging.info("***** Configing Data *****") data = instantiate_from_config(config.data) @@ -798,6 +909,35 @@ def get_parser(): type=int, default=8, help="fps for the saving video") + parser.add_argument( + "--diffusion_dtype", + type=str, + choices=["fp32", "bf16"], + default="bf16", + help="Diffusion backbone precision (fp32/bf16)") + parser.add_argument( + "--projector_mode", + type=str, + choices=["fp32", "autocast", "bf16_full"], + default="bf16_full", + help="Projector precision mode (fp32/autocast/bf16_full)") + parser.add_argument( + "--encoder_mode", + type=str, + choices=["fp32", "autocast", "bf16_full"], + default="bf16_full", + help="Encoder precision mode (fp32/autocast/bf16_full)") + parser.add_argument( + "--vae_dtype", + type=str, + choices=["fp32", "bf16"], + default="fp32", + help="VAE precision (fp32/bf16, most affects image quality)") + parser.add_argument( + "--export_precision_ckpt", + type=str, + default=None, + help="Export precision-converted checkpoint to this path, then exit.") return parser diff --git a/src/unifolm_wma/models/ddpms.py b/src/unifolm_wma/models/ddpms.py index fbf2042..383169d 100644 --- a/src/unifolm_wma/models/ddpms.py +++ b/src/unifolm_wma/models/ddpms.py @@ -1105,6 +1105,10 @@ class LatentDiffusion(DDPM): else: reshape_back = False + # Align input dtype with VAE weights (e.g. fp32 samples → bf16 VAE) + vae_dtype = next(self.first_stage_model.parameters()).dtype + z = z.to(dtype=vae_dtype) + if not self.perframe_ae: z = 1. / self.scale_factor * z results = self.first_stage_model.decode(z, **kwargs) @@ -2457,7 +2461,6 @@ class DiffusionWrapper(pl.LightningModule): Returns: Output from the inner diffusion model (tensor or tuple, depending on the model). """ - if self.conditioning_key is None: out = self.diffusion_model(x, t) elif self.conditioning_key == 'concat': diff --git a/unitree_z1_dual_arm_cleanup_pencils/case1/output.log b/unitree_z1_dual_arm_cleanup_pencils/case1/output.log index 77ba312..9ee412e 100644 --- a/unitree_z1_dual_arm_cleanup_pencils/case1/output.log +++ b/unitree_z1_dual_arm_cleanup_pencils/case1/output.log @@ -1,12 +1,12 @@ -2026-02-08 09:20:29.036523: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`. -2026-02-08 09:20:29.301726: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used. -2026-02-08 09:20:29.656318: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered -2026-02-08 09:20:29.656367: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered -2026-02-08 09:20:29.662840: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered -2026-02-08 09:20:29.718736: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used. -2026-02-08 09:20:29.718991: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations. +2026-02-08 12:22:55.885867: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`. +2026-02-08 12:22:55.890510: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used. +2026-02-08 12:22:55.938683: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered +2026-02-08 12:22:55.938759: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered +2026-02-08 12:22:55.941091: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered +2026-02-08 12:22:55.952450: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used. +2026-02-08 12:22:55.952933: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations. To enable the following instructions: AVX2 AVX512F AVX512_VNNI AVX512_BF16 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags. -2026-02-08 09:20:31.661239: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT +2026-02-08 12:22:56.593653: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT [rank: 0] Global seed set to 123 /mnt/ASC1637/miniconda3/envs/unifolm-wma/lib/python3.10/site-packages/kornia/feature/lightglue.py:44: FutureWarning: `torch.cuda.amp.custom_fwd(args...)` is deprecated. Please use `torch.amp.custom_fwd(args..., device_type='cuda')` instead. @torch.cuda.amp.custom_fwd(cast_inputs=torch.float32) @@ -23,10 +23,19 @@ INFO:root:Loading pretrained ViT-H-14 weights (laion2b_s32b_b79k). INFO:root:Loaded ViT-H-14 model config. DEBUG:urllib3.connectionpool:https://hf-mirror.com:443 "HEAD /laion/CLIP-ViT-H-14-laion2B-s32B-b79K/resolve/main/open_clip_pytorch_model.bin HTTP/1.1" 302 0 INFO:root:Loading pretrained ViT-H-14 weights (laion2b_s32b_b79k). -/mnt/ASC1637/unifolm-world-model-action/scripts/evaluation/world_model_interaction.py:86: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature. +/mnt/ASC1637/unifolm-world-model-action/scripts/evaluation/world_model_interaction.py:149: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature. state_dict = torch.load(ckpt, map_location="cpu") >>> model checkpoint loaded. >>> Load pre-trained model ... +>>> Applying precision settings: + - Diffusion dtype: bf16 + - Projector mode: bf16_full + - Encoder mode: bf16_full + - VAE dtype: bf16 + ✓ Diffusion model weights converted to bfloat16 + ✓ Projectors converted to bfloat16 + ✓ Encoders converted to bfloat16 + ✓ VAE converted to bfloat16 INFO:root:***** Configing Data ***** >>> unitree_z1_stackbox: 1 data samples loaded. >>> unitree_z1_stackbox: data stats loaded. @@ -106,7 +115,7 @@ DEBUG:PIL.Image:Importing WmfImagePlugin DEBUG:PIL.Image:Importing XbmImagePlugin DEBUG:PIL.Image:Importing XpmImagePlugin DEBUG:PIL.Image:Importing XVThumbImagePlugin - 12%|█▎ | 1/8 [01:40<11:43, 100.53s/it] 25%|██▌ | 2/8 [03:20<10:01, 100.30s/it] 38%|███▊ | 3/8 [05:00<08:21, 100.22s/it] 50%|█████ | 4/8 [06:41<06:41, 100.29s/it] 62%|██████▎ | 5/8 [08:21<05:01, 100.38s/it] 75%|███████▌ | 6/8 [10:02<03:20, 100.41s/it] 88%|████████▊ | 7/8 [11:42<01:40, 100.37s/it] 100%|██████████| 8/8 [13:23<00:00, 100.44s/it] 100%|██████████| 8/8 [13:23<00:00, 100.39s/it] + 12%|█▎ | 1/8 [01:24<09:53, 84.82s/it] 25%|██▌ | 2/8 [02:49<08:26, 84.48s/it] 38%|███▊ | 3/8 [04:13<07:01, 84.40s/it] 50%|█████ | 4/8 [05:37<05:37, 84.43s/it] 62%|██████▎ | 5/8 [07:02<04:13, 84.44s/it] 75%|███████▌ | 6/8 [08:26<02:48, 84.44s/it] 88%|████████▊ | 7/8 [09:50<01:24, 84.36s/it] 100%|██████████| 8/8 [11:15<00:00, 84.41s/it] 100%|██████████| 8/8 [11:15<00:00, 84.43s/it] >>>>>>>>>>>>>>>>>>>>>>>> >>> Step 1: generating actions ... >>> Step 1: interacting with world model ... @@ -130,6 +139,6 @@ DEBUG:PIL.Image:Importing XVThumbImagePlugin >>> Step 7: interacting with world model ... >>>>>>>>>>>>>>>>>>>>>>>> -real 14m41.368s -user 13m8.275s -sys 0m48.945s +real 12m19.457s +user 12m13.197s +sys 0m38.223s diff --git a/unitree_z1_dual_arm_cleanup_pencils/case1/psnr_result1.json b/unitree_z1_dual_arm_cleanup_pencils/case1/psnr_result1.json index 6f778e6..05b95c4 100644 --- a/unitree_z1_dual_arm_cleanup_pencils/case1/psnr_result1.json +++ b/unitree_z1_dual_arm_cleanup_pencils/case1/psnr_result1.json @@ -1,5 +1,5 @@ { "gt_video": "/mnt/ASC1637/unifolm-world-model-action/unitree_z1_dual_arm_cleanup_pencils/case1/output/inference/unitree_z1_dual_arm_cleanup_pencils_case1_amd.mp4", "pred_video": "/mnt/ASC1637/unifolm-world-model-action/unitree_z1_dual_arm_cleanup_pencils/case1/output/inference/0_full_fs4.mp4", - "psnr": 44.83864567508593 + "psnr": 30.44844270035179 } \ No newline at end of file diff --git a/unitree_z1_dual_arm_cleanup_pencils/case1/run_world_model_interaction.sh b/unitree_z1_dual_arm_cleanup_pencils/case1/run_world_model_interaction.sh index 8fe141f..304cb31 100644 --- a/unitree_z1_dual_arm_cleanup_pencils/case1/run_world_model_interaction.sh +++ b/unitree_z1_dual_arm_cleanup_pencils/case1/run_world_model_interaction.sh @@ -4,7 +4,7 @@ dataset="unitree_z1_dual_arm_cleanup_pencils" { time CUDA_VISIBLE_DEVICES=0 python3 scripts/evaluation/world_model_interaction.py \ --seed 123 \ - --ckpt_path ckpts/unifolm_wma_dual.ckpt \ + --ckpt_path ckpts/unifolm_wma_dual_mix_bf16.ckpt \ --config configs/inference/world_model_interaction.yaml \ --savedir "${res_dir}/output" \ --bs 1 --height 320 --width 512 \ @@ -20,5 +20,6 @@ dataset="unitree_z1_dual_arm_cleanup_pencils" --n_iter 8 \ --timestep_spacing 'uniform_trailing' \ --guidance_rescale 0.7 \ - --perframe_ae + --perframe_ae \ + --vae_dtype bf16 } 2>&1 | tee "${res_dir}/output.log"