2 Commits

Author SHA1 Message Date
qhy
3069666a15 脚本修改 2026-02-10 14:49:26 +08:00
qhy
68369cc15f 合并后测试 2026-02-10 14:45:14 +08:00
7 changed files with 178 additions and 55 deletions

View File

@@ -222,7 +222,7 @@ data:
test: test:
target: unifolm_wma.data.wma_data.WMAData target: unifolm_wma.data.wma_data.WMAData
params: params:
data_dir: '/mnt/ASC1637/unifolm-world-model-action/examples/world_model_interaction_prompts' data_dir: '/home/qhy/unifolm-world-model-action/examples/world_model_interaction_prompts'
video_length: ${model.params.wma_config.params.temporal_length} video_length: ${model.params.wma_config.params.temporal_length}
frame_stride: 2 frame_stride: 2
load_raw_resolution: True load_raw_resolution: True

View File

@@ -1,5 +1,7 @@
import argparse, os, glob import argparse, os, glob
from contextlib import nullcontext from contextlib import nullcontext
import atexit
from concurrent.futures import ThreadPoolExecutor
import pandas as pd import pandas as pd
import random import random
import torch import torch
@@ -11,13 +13,15 @@ import einops
import warnings import warnings
import imageio import imageio
from typing import Optional, List, Any
from pytorch_lightning import seed_everything from pytorch_lightning import seed_everything
from omegaconf import OmegaConf from omegaconf import OmegaConf
from tqdm import tqdm from tqdm import tqdm
from einops import rearrange, repeat from einops import rearrange, repeat
from collections import OrderedDict from collections import OrderedDict
from torch import nn from torch import nn
from eval_utils import populate_queues, log_to_tensorboard from eval_utils import populate_queues
from collections import deque from collections import deque
from torch import Tensor from torch import Tensor
from torch.utils.tensorboard import SummaryWriter from torch.utils.tensorboard import SummaryWriter
@@ -28,6 +32,80 @@ from unifolm_wma.utils.utils import instantiate_from_config
import torch.nn.functional as F import torch.nn.functional as F
# ========== Async I/O utilities ==========
_io_executor: Optional[ThreadPoolExecutor] = None
_io_futures: List[Any] = []
def _get_io_executor() -> ThreadPoolExecutor:
global _io_executor
if _io_executor is None:
_io_executor = ThreadPoolExecutor(max_workers=2)
return _io_executor
def _flush_io():
"""Wait for all pending async I/O to finish."""
global _io_futures
for fut in _io_futures:
try:
fut.result()
except Exception as e:
print(f">>> [async I/O] error: {e}")
_io_futures.clear()
atexit.register(_flush_io)
def _save_results_sync(video_cpu: Tensor, filename: str, fps: int) -> None:
"""Synchronous save on CPU tensor (runs in background thread)."""
video = torch.clamp(video_cpu.float(), -1., 1.)
n = video.shape[0]
video = video.permute(2, 0, 1, 3, 4)
frame_grids = [
torchvision.utils.make_grid(framesheet, nrow=int(n), padding=0)
for framesheet in video
]
grid = torch.stack(frame_grids, dim=0)
grid = (grid + 1.0) / 2.0
grid = (grid * 255).to(torch.uint8).permute(0, 2, 3, 1)
torchvision.io.write_video(filename,
grid,
fps=fps,
video_codec='h264',
options={'crf': '10'})
def save_results_async(video: Tensor, filename: str, fps: int = 8) -> None:
"""Submit video saving to background thread pool."""
video_cpu = video.detach().cpu()
fut = _get_io_executor().submit(_save_results_sync, video_cpu, filename, fps)
_io_futures.append(fut)
def _log_to_tb_sync(video_cpu: Tensor, writer: SummaryWriter, tag: str, fps: int) -> None:
"""Synchronous tensorboard logging on CPU tensor (runs in background thread)."""
video = video_cpu.float()
n = video.shape[0]
video = video.permute(2, 0, 1, 3, 4)
frame_grids = [
torchvision.utils.make_grid(framesheet, nrow=int(n), padding=0)
for framesheet in video
]
grid = torch.stack(frame_grids, dim=0)
grid = (grid + 1.0) / 2.0
grid = grid.unsqueeze(dim=0)
writer.add_video(tag, grid, fps=fps)
def log_to_tensorboard_async(writer: SummaryWriter, video: Tensor, tag: str, fps: int = 10) -> None:
"""Submit tensorboard logging to background thread pool."""
video_cpu = video.detach().cpu()
fut = _get_io_executor().submit(_log_to_tb_sync, video_cpu, writer, tag, fps)
_io_futures.append(fut)
def patch_norm_bypass_autocast(): def patch_norm_bypass_autocast():
"""Monkey-patch GroupNorm and LayerNorm to bypass autocast's fp32 policy. """Monkey-patch GroupNorm and LayerNorm to bypass autocast's fp32 policy.
This eliminates bf16->fp32->bf16 dtype conversions during UNet forward.""" This eliminates bf16->fp32->bf16 dtype conversions during UNet forward."""
@@ -185,17 +263,18 @@ def get_filelist(data_dir: str, postfixes: list[str]) -> list[str]:
return file_list return file_list
def load_model_checkpoint(model: nn.Module, ckpt: str) -> nn.Module: def load_model_checkpoint(model: nn.Module, ckpt: str, device: str = "cpu") -> nn.Module:
"""Load model weights from checkpoint file. """Load model weights from checkpoint file.
Args: Args:
model (nn.Module): Model instance. model (nn.Module): Model instance.
ckpt (str): Path to the checkpoint file. ckpt (str): Path to the checkpoint file.
device (str): Target device for loaded tensors.
Returns: Returns:
nn.Module: Model with loaded weights. nn.Module: Model with loaded weights.
""" """
state_dict = torch.load(ckpt, map_location="cpu") state_dict = torch.load(ckpt, map_location=device)
if "state_dict" in list(state_dict.keys()): if "state_dict" in list(state_dict.keys()):
state_dict = state_dict["state_dict"] state_dict = state_dict["state_dict"]
try: try:
@@ -610,36 +689,63 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
# Load config # Load config
config = OmegaConf.load(args.config) config = OmegaConf.load(args.config)
config['model']['params']['wma_config']['params'][
'use_checkpoint'] = False
model = instantiate_from_config(config.model)
model.perframe_ae = args.perframe_ae
assert os.path.exists(args.ckpt_path), "Error: checkpoint Not Found!"
model = load_model_checkpoint(model, args.ckpt_path)
model.eval()
print(f'>>> Load pre-trained model ...')
# Apply precision settings before moving to GPU prepared_path = args.ckpt_path + ".prepared.pt"
model = apply_precision_settings(model, args) if os.path.exists(prepared_path):
# ---- Fast path: load the fully-prepared model ----
print(f">>> Loading prepared model from {prepared_path} ...")
model = torch.load(prepared_path,
map_location=f"cuda:{gpu_no}",
weights_only=False)
model.eval()
# Compile hot ResBlocks for operator fusion # Restore autocast attributes (weights already cast, just need contexts)
apply_torch_compile(model) model.diffusion_autocast_dtype = torch.bfloat16 if args.diffusion_dtype == "bf16" else torch.bfloat16
model.projector_autocast_dtype = torch.bfloat16 if args.projector_mode == "autocast" else None
model.encoder_autocast_dtype = torch.bfloat16 if args.encoder_mode == "autocast" else None
# Export precision-converted checkpoint if requested # Compile hot ResBlocks for operator fusion
if args.export_precision_ckpt: apply_torch_compile(model)
export_path = args.export_precision_ckpt
os.makedirs(os.path.dirname(export_path) or '.', exist_ok=True)
torch.save({"state_dict": model.state_dict()}, export_path)
print(f">>> Precision-converted checkpoint saved to: {export_path}")
return
# Build unnomalizer print(f">>> Prepared model loaded.")
else:
# ---- Normal path: construct + checkpoint + casting ----
config['model']['params']['wma_config']['params'][
'use_checkpoint'] = False
model = instantiate_from_config(config.model)
model.perframe_ae = args.perframe_ae
assert os.path.exists(args.ckpt_path), "Error: checkpoint Not Found!"
model = load_model_checkpoint(model, args.ckpt_path,
device=f"cuda:{gpu_no}")
model.eval()
print(f'>>> Load pre-trained model ...')
# Apply precision settings before moving to GPU
model = apply_precision_settings(model, args)
# Export precision-converted checkpoint if requested
if args.export_precision_ckpt:
export_path = args.export_precision_ckpt
os.makedirs(os.path.dirname(export_path) or '.', exist_ok=True)
torch.save({"state_dict": model.state_dict()}, export_path)
print(f">>> Precision-converted checkpoint saved to: {export_path}")
return
model = model.cuda(gpu_no)
# Save prepared model for fast loading next time (before torch.compile)
print(f">>> Saving prepared model to {prepared_path} ...")
torch.save(model, prepared_path)
print(f">>> Prepared model saved ({os.path.getsize(prepared_path) / 1024**3:.1f} GB).")
# Compile hot ResBlocks for operator fusion (after save, compiled objects can't be pickled)
apply_torch_compile(model)
# Build normalizer (always needed, independent of model loading path)
logging.info("***** Configing Data *****") logging.info("***** Configing Data *****")
data = instantiate_from_config(config.data) data = instantiate_from_config(config.data)
data.setup() data.setup()
print(">>> Dataset is successfully loaded ...") print(">>> Dataset is successfully loaded ...")
model = model.cuda(gpu_no)
device = get_device_from_parameters(model) device = get_device_from_parameters(model)
# Run over data # Run over data
@@ -817,28 +923,28 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
# Save the imagen videos for decision-making # Save the imagen videos for decision-making
if pred_videos_0 is not None: if pred_videos_0 is not None:
sample_tag = f"{args.dataset}-vid{sample['videoid']}-dm-fs-{fs}/itr-{itr}" sample_tag = f"{args.dataset}-vid{sample['videoid']}-dm-fs-{fs}/itr-{itr}"
log_to_tensorboard(writer, log_to_tensorboard_async(writer,
pred_videos_0, pred_videos_0,
sample_tag, sample_tag,
fps=args.save_fps) fps=args.save_fps)
# Save videos environment changes via world-model interaction # Save videos environment changes via world-model interaction
sample_tag = f"{args.dataset}-vid{sample['videoid']}-wd-fs-{fs}/itr-{itr}" sample_tag = f"{args.dataset}-vid{sample['videoid']}-wd-fs-{fs}/itr-{itr}"
log_to_tensorboard(writer, log_to_tensorboard_async(writer,
pred_videos_1, pred_videos_1,
sample_tag, sample_tag,
fps=args.save_fps) fps=args.save_fps)
# Save the imagen videos for decision-making # Save the imagen videos for decision-making
if pred_videos_0 is not None: if pred_videos_0 is not None:
sample_video_file = f'{video_save_dir}/dm/{fs}/itr-{itr}.mp4' sample_video_file = f'{video_save_dir}/dm/{fs}/itr-{itr}.mp4'
save_results(pred_videos_0.cpu(), save_results_async(pred_videos_0,
sample_video_file, sample_video_file,
fps=args.save_fps) fps=args.save_fps)
# Save videos environment changes via world-model interaction # Save videos environment changes via world-model interaction
sample_video_file = f'{video_save_dir}/wm/{fs}/itr-{itr}.mp4' sample_video_file = f'{video_save_dir}/wm/{fs}/itr-{itr}.mp4'
save_results(pred_videos_1.cpu(), save_results_async(pred_videos_1,
sample_video_file, sample_video_file,
fps=args.save_fps) fps=args.save_fps)
print('>' * 24) print('>' * 24)
# Collect the result of world-model interactions # Collect the result of world-model interactions
@@ -846,12 +952,15 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
full_video = torch.cat(wm_video, dim=2) full_video = torch.cat(wm_video, dim=2)
sample_tag = f"{args.dataset}-vid{sample['videoid']}-wd-fs-{fs}/full" sample_tag = f"{args.dataset}-vid{sample['videoid']}-wd-fs-{fs}/full"
log_to_tensorboard(writer, log_to_tensorboard_async(writer,
full_video, full_video,
sample_tag, sample_tag,
fps=args.save_fps) fps=args.save_fps)
sample_full_video_file = f"{video_save_dir}/../{sample['videoid']}_full_fs{fs}.mp4" sample_full_video_file = f"{video_save_dir}/../{sample['videoid']}_full_fs{fs}.mp4"
save_results(full_video, sample_full_video_file, fps=args.save_fps) save_results_async(full_video, sample_full_video_file, fps=args.save_fps)
# Wait for all async I/O to complete
_flush_io()
def get_parser(): def get_parser():

View File

@@ -99,6 +99,8 @@ class AutoencoderKL(pl.LightningModule):
print(f"Restored from {path}") print(f"Restored from {path}")
def encode(self, x, **kwargs): def encode(self, x, **kwargs):
if getattr(self, '_channels_last', False):
x = x.to(memory_format=torch.channels_last)
h = self.encoder(x) h = self.encoder(x)
moments = self.quant_conv(h) moments = self.quant_conv(h)
@@ -106,6 +108,8 @@ class AutoencoderKL(pl.LightningModule):
return posterior return posterior
def decode(self, z, **kwargs): def decode(self, z, **kwargs):
if getattr(self, '_channels_last', False):
z = z.to(memory_format=torch.channels_last)
z = self.post_quant_conv(z) z = self.post_quant_conv(z)
dec = self.decoder(z) dec = self.decoder(z)
return dec return dec

View File

@@ -1074,10 +1074,10 @@ class LatentDiffusion(DDPM):
encoder_posterior = self.first_stage_model.encode(x) encoder_posterior = self.first_stage_model.encode(x)
results = self.get_first_stage_encoding(encoder_posterior).detach() results = self.get_first_stage_encoding(encoder_posterior).detach()
else: ## Consume less GPU memory but slower else: ## Consume less GPU memory but slower
bs = getattr(self, 'vae_encode_bs', 1)
results = [] results = []
for index in range(x.shape[0]): for i in range(0, x.shape[0], bs):
frame_batch = self.first_stage_model.encode(x[index:index + frame_batch = self.first_stage_model.encode(x[i:i + bs])
1, :, :, :])
frame_result = self.get_first_stage_encoding( frame_result = self.get_first_stage_encoding(
frame_batch).detach() frame_batch).detach()
results.append(frame_result) results.append(frame_result)
@@ -1109,14 +1109,14 @@ class LatentDiffusion(DDPM):
vae_dtype = next(self.first_stage_model.parameters()).dtype vae_dtype = next(self.first_stage_model.parameters()).dtype
z = z.to(dtype=vae_dtype) z = z.to(dtype=vae_dtype)
z = 1. / self.scale_factor * z
if not self.perframe_ae: if not self.perframe_ae:
z = 1. / self.scale_factor * z
results = self.first_stage_model.decode(z, **kwargs) results = self.first_stage_model.decode(z, **kwargs)
else: else:
bs = getattr(self, 'vae_decode_bs', 1)
results = [] results = []
for index in range(z.shape[0]): for i in range(0, z.shape[0], bs):
frame_z = 1. / self.scale_factor * z[index:index + 1, :, :, :] frame_result = self.first_stage_model.decode(z[i:i + bs], **kwargs)
frame_result = self.first_stage_model.decode(frame_z, **kwargs)
results.append(frame_result) results.append(frame_result)
results = torch.cat(results, dim=0) results = torch.cat(results, dim=0)

View File

@@ -0,0 +1,7 @@
{
"permissions": {
"allow": [
"Bash(python3:*)"
]
}
}

View File

@@ -11,7 +11,7 @@ from unifolm_wma.utils.utils import instantiate_from_config
def nonlinearity(x): def nonlinearity(x):
# swish # swish
return x * torch.sigmoid(x) return torch.nn.functional.silu(x)
def Normalize(in_channels, num_groups=32): def Normalize(in_channels, num_groups=32):

View File

@@ -4,7 +4,7 @@ dataset="unitree_z1_dual_arm_cleanup_pencils"
{ {
time TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1 CUDA_VISIBLE_DEVICES=0 python3 scripts/evaluation/world_model_interaction.py \ time TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1 CUDA_VISIBLE_DEVICES=0 python3 scripts/evaluation/world_model_interaction.py \
--seed 123 \ --seed 123 \
--ckpt_path ckpts/unifolm_wma_dual_mix_bf16.ckpt \ --ckpt_path ckpts/unifolm_wma_dual_mixbf16.ckpt \
--config configs/inference/world_model_interaction.yaml \ --config configs/inference/world_model_interaction.yaml \
--savedir "${res_dir}/output" \ --savedir "${res_dir}/output" \
--bs 1 --height 320 --width 512 \ --bs 1 --height 320 --width 512 \
@@ -21,6 +21,9 @@ dataset="unitree_z1_dual_arm_cleanup_pencils"
--timestep_spacing 'uniform_trailing' \ --timestep_spacing 'uniform_trailing' \
--guidance_rescale 0.7 \ --guidance_rescale 0.7 \
--perframe_ae \ --perframe_ae \
--vae_dtype bf16 \ --diffusion_dtype fp32 \
--projector_mode fp32 \
--encoder_mode fp32 \
--vae_dtype fp32 \
--fast_policy_no_decode --fast_policy_no_decode
} 2>&1 | tee "${res_dir}/output.log" } 2>&1 | tee "${res_dir}/output.log"