Files
unifolm-world-model-action/scripts/evaluation/world_model_interaction.py
2026-05-17 15:07:06 +08:00

1723 lines
68 KiB
Python

import argparse, os, glob
import pandas as pd
import random
import torch
import torchvision
import h5py
import numpy as np
import logging
import einops
import warnings
import imageio
import atexit
import signal
import multiprocessing as mp
import time
from concurrent.futures import ThreadPoolExecutor
from queue import Empty, Queue
from pytorch_lightning import seed_everything
from omegaconf import OmegaConf
from tqdm import tqdm
from einops import rearrange, repeat
from collections import OrderedDict
from torch import nn
from eval_utils import populate_queues
from collections import deque
from typing import Optional, List, Any
from types import SimpleNamespace
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
from torch import Tensor
from PIL import Image
from unifolm_wma.models.samplers.ddim import DDIMSampler
from unifolm_wma.utils.utils import instantiate_from_config
def get_device_from_parameters(module: nn.Module) -> torch.device:
    """Infer which device a module lives on from its first parameter.

    Args:
        module (nn.Module): Module whose placement should be reported.

    Returns:
        torch.device: Device of the first tensor yielded by
        ``module.parameters()``. A module without parameters raises
        ``StopIteration``.
    """
    first_param = next(iter(module.parameters()))
    return first_param.device
def clone_observation_queues(
        queues: dict[str, deque]) -> dict[str, deque]:
    """Deep-clone queue tensors so pipeline branches can diverge safely.

    Every tensor item is replaced by ``item.clone()``; non-tensor items are
    shared by reference. Each new deque keeps the source ``maxlen`` so later
    appends still evict the oldest entry.
    """
    def _copy_item(item):
        return item.clone() if torch.is_tensor(item) else item

    return {
        name: deque(map(_copy_item, items), maxlen=items.maxlen)
        for name, items in queues.items()
    }
def clone_observation_queues_to_cpu(
        queues: dict[str, deque]) -> dict[str, deque]:
    """Snapshot ``queues`` onto the host.

    Each tensor is detached, moved to CPU and cloned, so the snapshot never
    aliases the source storage (even for tensors already on CPU). Non-tensor
    items pass through untouched and every deque keeps its ``maxlen``.
    """
    snapshot = {}
    for name, items in queues.items():
        host_items = []
        for item in items:
            if torch.is_tensor(item):
                item = item.detach().cpu().clone()
            host_items.append(item)
        snapshot[name] = deque(host_items, maxlen=items.maxlen)
    return snapshot
def move_observation_queues_to_device(
        queues: dict[str, deque],
        device: torch.device) -> dict[str, deque]:
    """Return new deques whose tensor items live on ``device``.

    Transfers are issued with ``non_blocking=True``; a tensor already on
    ``device`` is passed through by ``Tensor.to`` without a copy. Non-tensor
    items are kept as-is and each deque's ``maxlen`` is preserved.
    """
    def _to_target(item):
        if not torch.is_tensor(item):
            return item
        return item.to(device, non_blocking=True)

    return {
        name: deque((_to_target(item) for item in items), maxlen=items.maxlen)
        for name, items in queues.items()
    }
def sync_module_device_attributes(module: nn.Module,
                                  device: torch.device) -> None:
    """Align cached `.device` attributes with the actual target device.

    Walks ``module.modules()`` (the module itself plus all descendants) and
    assigns ``device`` wherever a ``device`` attribute already exists.
    Assignment failures (for example a property without a setter) are
    deliberately ignored: this is a best-effort fix-up.
    """
    candidates = (m for m in module.modules() if hasattr(m, 'device'))
    for submodule in candidates:
        try:
            setattr(submodule, 'device', device)
        except Exception:
            # Best effort: some modules expose `device` read-only.
            pass
def pipeline_print(message: str) -> None:
    """Print one pipeline progress line and flush stdout right away.

    Flushing makes progress lines appear promptly even when stdout is
    block-buffered (e.g. piped or redirected).
    """
    print(message, end="\n", flush=True)
def build_observation_from_queues(
        queues: dict[str, deque],
        device: torch.device) -> dict[str, torch.Tensor]:
    """Stack the rolling observation queues into batched model inputs.

    Each queue is stacked along a new dim 1. The image stack is then permuted
    with ``(0, 2, 1, 3, 4)`` so the stacked (history) axis ends up at dim 2.
    All outputs are moved to ``device`` with ``non_blocking=True``.
    """
    def _stacked(key: str) -> torch.Tensor:
        return torch.stack(list(queues[key]), dim=1)

    images = _stacked('observation.images.top').permute(0, 2, 1, 3, 4)
    states = _stacked('observation.state')
    actions = _stacked('action')
    return {
        'observation.images.top': images.to(device, non_blocking=True),
        'observation.state': states.to(device, non_blocking=True),
        'action': actions.to(device, non_blocking=True),
    }
def append_action_sequence(
        queues: dict[str, deque],
        action_seq: torch.Tensor,
        ori_action_dim: int) -> dict[str, deque]:
    """Push each time step of the first batch item into the 'action' queue.

    Only ``action_seq[0]`` is consumed (``action_seq`` is indexed as
    ``(B, T, D)``). Every pushed frame is a cloned ``(1, D)`` slice whose
    columns from ``ori_action_dim`` onward are zeroed, presumably the
    padding added when actions were mapped to the unified vector. The
    result of ``populate_queues`` is threaded through and returned.
    """
    first_item = action_seq[0]
    for step in range(action_seq.shape[1]):
        frame = first_item[step:step + 1].clone()
        frame[:, ori_action_dim:] = 0.0
        queues = populate_queues(queues, {'action': frame})
    return queues
def rollout_execution_segment(
        queues: dict[str, deque],
        seg_video: torch.Tensor,
        pred_states: torch.Tensor,
        zero_action_template: torch.Tensor,
        exe_steps: int,
        ori_state_dim: int,
        zero_pred_state: bool) -> dict[str, deque]:
    """Feed ``exe_steps`` predicted frames/states back into the queues.

    For each step ``k`` one observation is pushed via ``populate_queues``:
      * image: ``seg_video[0][:, k:k+1]`` permuted with ``(1, 0, 2, 3)``;
      * state: ``pred_states[0][k:k+1]`` cloned (or all zeros when
        ``zero_pred_state``), with columns from ``ori_state_dim`` zeroed;
      * action: zeros shaped like ``zero_action_template``.
    """
    first_video = seg_video[0]
    first_states = pred_states[0]
    for step in range(exe_steps):
        state_slice = first_states[step:step + 1]
        if zero_pred_state:
            state_frame = torch.zeros_like(state_slice)
        else:
            state_frame = state_slice.clone()
        state_frame[:, ori_state_dim:] = 0.0
        frame = first_video[:, step:step + 1].permute(1, 0, 2, 3)
        queues = populate_queues(queues, {
            'observation.images.top': frame,
            'observation.state': state_frame,
            'action': torch.zeros_like(zero_action_template),
        })
    return queues
def write_video(video_path: str, stacked_frames: list, fps: int) -> None:
    """Encode a list of frames to ``video_path`` via ``imageio.mimsave``.

    A "pkg_resources is deprecated as an API" DeprecationWarning that the
    imageio backend may emit is silenced for the duration of this call only.

    Args:
        video_path (str): Output path for the video.
        stacked_frames (list): Frames to write, in order.
        fps (int): Frames per second of the output.
    """
    with warnings.catch_warnings():
        warnings.filterwarnings(
            "ignore",
            message="pkg_resources is deprecated as an API",
            category=DeprecationWarning,
        )
        imageio.mimsave(video_path, stacked_frames, fps=fps)
def get_filelist(data_dir: str, postfixes: list[str]) -> list[str]:
    """Collect files directly inside ``data_dir`` with a listed extension.

    Non-recursive: ``data_dir/*.<postfix>`` is globbed for each postfix, the
    matches are concatenated and then sorted lexicographically. A postfix
    listed twice yields its matches twice.

    Args:
        data_dir (str): Directory to search.
        postfixes (list[str]): Extensions to match, without the leading dot.

    Returns:
        list[str]: Sorted matching paths.
    """
    matches = [
        path
        for postfix in postfixes
        for path in glob.glob(os.path.join(data_dir, f"*.{postfix}"))
    ]
    return sorted(matches)
def load_model_checkpoint(model: nn.Module, ckpt: str) -> nn.Module:
    """Load weights from ``ckpt`` into ``model`` in place and return it.

    Two checkpoint layouts are handled:

    * A dict with a ``"state_dict"`` entry. It is first loaded strictly
      as-is; if that raises ``RuntimeError`` (missing, unexpected or
      mismatched keys) it is retried once with every ``framestride_embed``
      in a key renamed to ``fps_embedding``.
    * Otherwise a dict with a ``"module"`` entry whose keys carry a
      16-character prefix that is stripped before loading.

    Args:
        model (nn.Module): Model instance to populate.
        ckpt (str): Path to the checkpoint file.

    Returns:
        nn.Module: The same ``model`` object with weights loaded.

    Raises:
        RuntimeError: If the weights still do not match after renaming.
        KeyError: If the file has neither ``"state_dict"`` nor ``"module"``.
    """
    # NOTE(review): on torch>=2.6 `torch.load` defaults to weights_only=True;
    # checkpoints that pickle extra (non-tensor) objects may need
    # weights_only=False, which is only safe for trusted files.
    checkpoint = torch.load(ckpt, map_location="cpu")
    if "state_dict" in checkpoint:
        state_dict = checkpoint["state_dict"]
        try:
            model.load_state_dict(state_dict, strict=True)
        except RuntimeError:
            # Only key/shape mismatches are retried; a bare `except:` here
            # used to swallow KeyboardInterrupt/SystemExit as well.
            # Older checkpoints name the frame-stride embedding
            # `framestride_embed`; retry with the current attribute name.
            renamed = OrderedDict(state_dict)
            for key in list(renamed.keys()):
                if "framestride_embed" in key:
                    new_key = key.replace("framestride_embed", "fps_embedding")
                    renamed[new_key] = renamed[key]
                    del renamed[key]
            model.load_state_dict(renamed, strict=True)
    else:
        # Strip a 16-character wrapper prefix from each key (presumably
        # `_forward_module.`, which is exactly 16 characters — confirm
        # against the checkpoint producer).
        stripped = OrderedDict(
            (key[16:], value) for key, value in checkpoint['module'].items())
        model.load_state_dict(stripped)
    print('>>> model checkpoint loaded.')
    return model
def is_inferenced(save_dir: str, filename: str) -> bool:
    """Report whether the first sample video for ``filename`` already exists.

    The expected output is ``<save_dir>/samples_separate/<stem>_sample0.mp4``,
    where ``<stem>`` is ``filename`` without its last four characters (a
    four-character extension such as ``.mp4`` is assumed).

    Args:
        save_dir (str): Directory where results are saved.
        filename (str): Source file name, including its extension.

    Returns:
        bool: True if that video file exists.
    """
    stem = filename[:-4]
    expected = os.path.join(save_dir, "samples_separate",
                            f"{stem}_sample0.mp4")
    return os.path.exists(expected)
def save_results(video: Tensor, filename: str, fps: int = 8) -> None:
    """Synchronously write a batch of videos to ``filename`` as one grid video.

    The tensor is detached and moved to CPU here. Encoding is delegated to
    ``_save_results_sync`` so there is a single copy of the encode logic
    (it is also what ``save_results_async`` runs in the background): values
    are clamped to [-1, 1], the ``B`` samples of each frame are tiled into
    one row, pixels are mapped to uint8 and written as H.264 (CRF 10).

    Args:
        video (Tensor): Tensor of shape (B, C, T, H, W); values outside
            [-1, 1] are clamped.
        filename (str): Output file path.
        fps (int, optional): Frames per second. Defaults to 8.
    """
    _save_results_sync(video.detach().cpu(), filename, fps)
# ========== Async I/O ==========
# Lazily created worker pool shared by the *_async helpers, the futures they
# submit (drained by _flush_io), and the spawned child processes to clean up.
_io_executor: Optional[ThreadPoolExecutor] = None
_io_futures: List[Any] = []
_child_processes: List[mp.Process] = []


def _get_io_executor() -> ThreadPoolExecutor:
    """Return the shared 2-worker I/O thread pool, creating it on first use."""
    global _io_executor
    if _io_executor is not None:
        return _io_executor
    _io_executor = ThreadPoolExecutor(max_workers=2)
    return _io_executor
def _flush_io():
    """Wait for all pending async I/O to finish.

    A job's exception is printed rather than re-raised, so one failed write
    does not stop the remaining futures from being awaited. The futures list
    is emptied in place afterwards. Also registered with ``atexit`` so queued
    writes are awaited at interpreter exit.
    """
    for pending in _io_futures:
        try:
            pending.result()
        except Exception as exc:
            print(f">>> [async I/O] error: {exc}")
    _io_futures.clear()


atexit.register(_flush_io)
def _register_child_process(proc: mp.Process) -> None:
    """Track ``proc`` so exit/signal cleanup will terminate it."""
    _child_processes.append(proc)
def _unregister_child_process(proc: mp.Process) -> None:
    """Stop tracking ``proc``; silently does nothing if it is not tracked."""
    try:
        _child_processes.remove(proc)
    except ValueError:
        # Already gone (e.g. the registry was cleared during shutdown).
        return
def _terminate_process(proc: mp.Process, join_timeout: float = 3.0) -> None:
    """Stop ``proc``, escalating from ``terminate()`` to ``kill()``.

    Every call on the process is best-effort (exceptions are swallowed):
      * ``None`` is ignored;
      * an already-dead process (or one whose liveness cannot be queried) is
        just joined for up to 0.1 s to reap it;
      * otherwise ``terminate()`` is sent and the process joined for up to
        ``join_timeout`` seconds; if still alive it is ``kill()``-ed and
        joined again with the same timeout.
    """
    if proc is None:
        return
    try:
        still_running = proc.is_alive()
    except Exception:
        still_running = False
    if still_running:
        try:
            proc.terminate()
        except Exception:
            pass
        try:
            proc.join(timeout=join_timeout)
        except Exception:
            pass
        # Escalate if the polite request was ignored.
        try:
            if proc.is_alive():
                proc.kill()
                proc.join(timeout=join_timeout)
        except Exception:
            pass
    else:
        try:
            proc.join(timeout=0.1)
        except Exception:
            pass
def _terminate_all_child_processes() -> None:
    """Terminate every tracked child process, then empty the registry."""
    snapshot = _child_processes[:]  # iterate a copy; the list is cleared below
    for proc in snapshot:
        _terminate_process(proc)
    _child_processes.clear()
def _handle_termination_signal(signum, _frame) -> None:
    """Signal handler: kill tracked children, then exit with ``128 + signum``.

    ``128 + signum`` follows the shell convention for "terminated by signal"
    (e.g. 130 for SIGINT, 143 for SIGTERM). Raising ``SystemExit`` instead
    of calling ``os._exit`` lets ``finally`` blocks and ``atexit`` hooks run.
    """
    name = signal.Signals(signum).name
    print(f">>> Received {name}, terminating child processes ...")
    _terminate_all_child_processes()
    raise SystemExit(128 + signum)


# Installed at import time. NOTE(review): signal.signal() only works in the
# main thread, and children started with the fork method inherit these
# handlers.
signal.signal(signal.SIGINT, _handle_termination_signal)
signal.signal(signal.SIGTERM, _handle_termination_signal)
atexit.register(_terminate_all_child_processes)
def _save_results_sync(video_cpu: Tensor, filename: str, fps: int) -> None:
    """Encode a CPU video batch to ``filename`` (runs in a background thread).

    Values are clamped to [-1, 1]. For every time step, the ``B`` samples are
    tiled side by side with ``make_grid(nrow=B, padding=0)``. Pixels are
    mapped to [0, 255], cast to uint8 (truncating) and written as H.264 with
    CRF 10.

    Args:
        video_cpu (Tensor): CPU tensor of shape (B, C, T, H, W).
        filename (str): Output video path.
        fps (int): Frame rate of the output video.
    """
    clamped = video_cpu.float().clamp(-1., 1.)
    num_samples = clamped.shape[0]
    per_frame = []
    for frame_batch in clamped.permute(2, 0, 1, 3, 4):
        per_frame.append(
            torchvision.utils.make_grid(frame_batch,
                                        nrow=int(num_samples),
                                        padding=0))
    frames = (torch.stack(per_frame, dim=0) + 1.0) / 2.0
    frames = (frames * 255).to(torch.uint8).permute(0, 2, 3, 1)
    torchvision.io.write_video(filename,
                               frames,
                               fps=fps,
                               video_codec='h264',
                               options={'crf': '10'})
def save_results_async(video: Tensor, filename: str, fps: int = 8) -> None:
    """Queue ``video`` for encoding on the shared background I/O pool.

    The tensor is detached and moved to CPU on the calling thread. This is a
    real copy for device tensors; for tensors already on CPU it is not, and
    the queued tensor shares storage with the caller's. The resulting future
    is appended to ``_io_futures`` so ``_flush_io`` can await it.
    """
    host_video = video.detach().cpu()
    executor = _get_io_executor()
    _io_futures.append(
        executor.submit(_save_results_sync, host_video, filename, fps))
def _log_to_tb_sync(writer, video_cpu: Tensor, tag: str, fps: int) -> None:
    """Log a (B, C, T, H, W) CPU video to ``writer.add_video`` (background).

    Tensors that are not 5-D are ignored. For each time step the ``B``
    samples are tiled into one row, values are mapped from [-1, 1] to [0, 1]
    (no clamping), and a leading batch axis is added, giving
    ``(1, T, C, H, B*W)``. ``writer`` is presumably a TensorBoard
    ``SummaryWriter`` — confirm with callers.
    """
    if video_cpu.dim() != 5:
        return
    num_samples = video_cpu.shape[0]
    grids = [
        torchvision.utils.make_grid(frame_batch,
                                    nrow=int(num_samples),
                                    padding=0)
        for frame_batch in video_cpu.permute(2, 0, 1, 3, 4)
    ]
    video_grid = ((torch.stack(grids, dim=0) + 1.0) / 2.0).unsqueeze(dim=0)
    writer.add_video(tag, video_grid, fps=fps)
def log_to_tensorboard_async(writer, data: Tensor, tag: str, fps: int = 10) -> None:
    """Submit TensorBoard logging of a 5-D video tensor to the I/O pool.

    Anything that is not a 5-D ``torch.Tensor`` is silently ignored. The
    tensor is detached and moved to CPU before submission, and the future is
    tracked in ``_io_futures``.
    """
    if not (isinstance(data, torch.Tensor) and data.dim() == 5):
        return
    host_data = data.detach().cpu()
    future = _get_io_executor().submit(_log_to_tb_sync, writer, host_data,
                                       tag, fps)
    _io_futures.append(future)
def _video_tensor_to_frames(video: Tensor) -> np.ndarray:
    """Convert a (B, C, T, H, W) video in [-1, 1] to uint8 frames.

    Values are clamped to [-1, 1] and mapped to [0, 255] with truncation.
    With ``B == 1`` the result is ``(T, H, W, C)``. Otherwise, for each frame
    the ``B`` samples are tiled into one row with ``make_grid`` and the
    result is ``(T, H, B*W, C)``.
    """
    clipped = torch.clamp(video.float(), -1., 1.)
    batch = clipped.shape[0]
    if batch != 1:
        tiles = [
            torchvision.utils.make_grid(frame_batch, nrow=int(batch), padding=0)
            for frame_batch in clipped.permute(2, 0, 1, 3, 4)
        ]
        stacked = torch.stack(tiles, dim=0)
        stacked = ((stacked + 1.0) / 2.0 * 255).to(torch.uint8)
        return stacked.permute(0, 2, 3, 1).numpy()
    # Single sample: no tiling needed, just move channels last.
    single = clipped[0].permute(1, 2, 3, 0).contiguous()
    return ((single + 1.0) / 2.0 * 255).to(torch.uint8).numpy()
def _tensor_stats(name: str, tensor: Tensor) -> str:
    """Summarise ``tensor`` on one line for debug logging.

    Statistics are computed in float32 over the finite entries only, and
    ``std`` is the population standard deviation. When no entry is finite,
    every statistic is reported as nan. ``finite`` is the fraction of
    entries that are finite.
    """
    values = tensor.detach().float()
    is_finite = torch.isfinite(values)
    finite_fraction = is_finite.float().mean().item()
    if not is_finite.any():
        mean = std = min_v = max_v = abs_mean = float('nan')
    else:
        kept = values[is_finite]
        mean = kept.mean().item()
        std = kept.std(unbiased=False).item()
        min_v = kept.min().item()
        max_v = kept.max().item()
        abs_mean = kept.abs().mean().item()
    head = f"{name}: shape={tuple(values.shape)} dtype={tensor.dtype} "
    body = (f"mean={mean:.6f} std={std:.6f} min={min_v:.6f} max={max_v:.6f} "
            f"abs_mean={abs_mean:.6f} finite={finite_fraction:.6f}")
    return head + body
def _debug_world_model_stats(wm_samples: Tensor, wm_video: Tensor,
                             prefix: str) -> None:
    """Print summary statistics of the world-model latent and decoded video."""
    header = f">>> [debug_wm_stats] {prefix}"
    print(f"{header} latent {_tensor_stats('wm_samples', wm_samples)}")
    print(f"{header} decoded {_tensor_stats('wm_video', wm_video)}")
def _video_writer_process(q: mp.Queue, filename: str, fps: int):
    """Child-process loop that streams video chunks from ``q`` into one file.

    Each item taken from ``q`` is a video tensor. It is converted with
    ``_video_tensor_to_frames`` and appended frame by frame. ``None`` is the
    stop sentinel. The ffmpeg writer (libx264, CRF 10, yuv420p) is opened
    lazily on the first chunk, so no file is created if nothing arrives.

    The writer is closed in ``finally``, so the container is finalized even
    if an exception (including ``SystemExit``) escapes the loop. Previously a
    failed chunk left the file unclosed and the ffmpeg subprocess dangling.
    """
    writer = None
    try:
        while True:
            item = q.get()
            if item is None:
                break
            frames = _video_tensor_to_frames(item)
            if writer is None:
                writer = imageio.get_writer(
                    filename,
                    fps=fps,
                    codec='libx264',
                    ffmpeg_params=['-crf', '10', '-pix_fmt', 'yuv420p'])
            for frame in frames:
                writer.append_data(frame)
    finally:
        if writer is not None:
            writer.close()
def _stop_writer_process(writer_proc: mp.Process, write_q: mp.Queue) -> None:
    """Shut down a `_video_writer_process` child and release its queue.

    Every step is best-effort (exceptions are swallowed) so cleanup always
    reaches the end:
      1. send the `None` stop sentinel, waiting at most 1 s for queue space;
      2. wait up to 5 s for the writer to exit on its own;
      3. if it is still alive, force it down via `_terminate_process`;
      4. close the queue and join its feeder thread;
      5. drop the process from the child-process registry.
    """
    try:
        write_q.put(None, timeout=1.0)
    except Exception:
        # Queue full/closed: rely on the join/terminate steps below.
        pass
    try:
        # Give the writer time to drain remaining chunks and close the file.
        writer_proc.join(timeout=5.0)
    except Exception:
        pass
    try:
        if writer_proc.is_alive():
            _terminate_process(writer_proc)
    except Exception:
        pass
    try:
        write_q.close()
        write_q.join_thread()
    except Exception:
        pass
    _unregister_child_process(writer_proc)
def get_init_frame_path(data_dir: str, sample: dict) -> str:
    """Build the path of a sample's initial-frame image.

    Layout: ``<data_dir>/images/<sample['data_dir']>/<sample['videoid']>.png``.

    Args:
        data_dir (str): Dataset root directory.
        sample (dict): Metadata row providing 'data_dir' and 'videoid'.

    Returns:
        str: Path to the PNG initial frame.
    """
    image_name = str(sample['videoid']) + '.png'
    return os.path.join(data_dir, 'images', sample['data_dir'], image_name)
def get_transition_path(data_dir: str, sample: dict) -> str:
    """Build the path of a sample's HDF5 transition file.

    Layout: ``<data_dir>/transitions/<sample['data_dir']>/<sample['videoid']>.h5``.

    Args:
        data_dir (str): Dataset root directory.
        sample (dict): Metadata row providing 'data_dir' and 'videoid'.

    Returns:
        str: Path to the ``.h5`` transition file.
    """
    file_name = str(sample['videoid']) + '.h5'
    return os.path.join(data_dir, 'transitions', sample['data_dir'], file_name)
def prepare_init_input(
        start_idx: int,
        init_frame_path: str,
        transition_dict: dict[str, torch.Tensor],
        frame_stride: int,
        wma_data,
        video_length: int = 16,
        n_obs_steps: int = 2) -> tuple[dict[str, Tensor], int, int]:
    """Build the initial model input (image, states, actions) for one clip.

    Args:
        start_idx (int): Index of the clip's first time step.
        init_frame_path (str): Path to the initial-frame image (read as RGB).
        transition_dict (dict[str, Tensor]): 'action' and 'observation.state'
            tensors indexed as ``[time, dim]``, plus 'action_type' and
            'state_type' forwarded to ``wma_data.get_uni_vec``.
        frame_stride (int): Step between consecutive sampled action indices.
        wma_data: Dataset helper providing ``normalizer``, ``get_uni_vec``
            and an optional ``spatial_transform``.
        video_length (int, optional): Number of actions to take. Default 16.
        n_obs_steps (int, optional): Number of state history steps. Default 2.

    Returns:
        tuple: ``(data, ori_state_dim, ori_action_dim)``.
          * ``data`` has 'observation.image', of shape ``(C, 1, H, W)``
            before ``spatial_transform``. It is rescaled by
            ``(x / 255 - 0.5) * 2``, which maps [0, 255] to [-1, 1] provided
            the transform keeps the [0, 255] range. ``data`` also holds
            whatever the normalizer/``get_uni_vec`` return for 'action' and
            'observation.state'.
          * ``ori_state_dim`` / ``ori_action_dim`` are the raw last-dim
            sizes before ``get_uni_vec``.
    """
    # Actions: `video_length` steps starting at start_idx, `frame_stride` apart.
    indices = [start_idx + frame_stride * i for i in range(video_length)]
    # Use the image as a context manager so the file handle is released.
    with Image.open(init_frame_path) as pil_frame:
        rgb = np.array(pil_frame.convert('RGB'))
    # (H, W, C) -> (C, 1, H, W): a single-frame clip.
    init_frame = torch.tensor(rgb).unsqueeze(0).permute(3, 0, 1, 2).float()

    # States: the last `n_obs_steps` steps ending at start_idx; near the
    # start of the episode the history is left-padded with the first state.
    if start_idx < n_obs_steps - 1:
        state_indices = list(range(0, start_idx + 1))
        states = transition_dict['observation.state'][state_indices, :]
        num_padding = n_obs_steps - 1 - start_idx
        first_slice = states[0:1, :]  # (1, d)
        padding = first_slice.repeat(num_padding, 1)
        states = torch.cat((padding, states), dim=0)
    else:
        state_indices = list(range(start_idx - n_obs_steps + 1, start_idx + 1))
        states = transition_dict['observation.state'][state_indices, :]
    actions = transition_dict['action'][indices, :]
    ori_state_dim = states.shape[-1]
    ori_action_dim = actions.shape[-1]

    frames_action_state_dict = {
        'action': actions,
        'observation.state': states,
    }
    frames_action_state_dict = wma_data.normalizer(frames_action_state_dict)
    frames_action_state_dict = wma_data.get_uni_vec(
        frames_action_state_dict,
        transition_dict['action_type'],
        transition_dict['state_type'],
    )

    if wma_data.spatial_transform is not None:
        init_frame = wma_data.spatial_transform(init_frame)
    init_frame = (init_frame / 255 - 0.5) * 2

    data = {
        'observation.image': init_frame,
    }
    data.update(frames_action_state_dict)
    return data, ori_state_dim, ori_action_dim
def get_latent_z(model, videos: Tensor) -> Tensor:
    """Encode a video batch frame by frame with the first-stage autoencoder.

    Frames are folded into the batch axis, passed through
    ``model.encode_first_stage``, and unfolded back into clips.

    Args:
        model: The world model (must expose ``encode_first_stage``).
        videos (Tensor): Input videos of shape [B, C, T, H, W].

    Returns:
        Tensor: Latent video of shape [B, C', T, H', W'], with channel and
        spatial sizes set by the encoder.
    """
    batch, _, frames, _, _ = videos.shape
    per_frame = rearrange(videos, 'b c t h w -> (b t) c h w')
    latents = model.encode_first_stage(per_frame)
    return rearrange(latents, '(b t) c h w -> b c t h w', b=batch, t=frames)
def preprocess_observation(
        model, observations: dict[str, np.ndarray]) -> dict[str, Tensor]:
    """Convert an environment observation into LeRobot-style tensors.

    Args:
        model: Object exposing ``device`` and ``normalize_inputs``; the
            latter is applied to the state.
        observations: Mapping with ``"pixels"`` and ``"agent_pos"``.
            ``"pixels"`` is a uint8 channel-last array ``(B, H, W, C)``, or a
            dict of such arrays keyed by camera name. ``"agent_pos"`` is the
            state array.

    Returns:
        dict[str, Tensor]: ``observation.images.<camera>`` float32
        ``(B, C, H, W)`` tensors, plus a normalized ``observation.state`` on
        ``model.device``. Pixel values stay in [0, 255] (no rescaling) and
        images stay on CPU.
        NOTE(review): LeRobot's original helper divides images by 255;
        confirm callers expect the unscaled range.

    Raises:
        AssertionError: If an image is not channel-last or not uint8.
    """
    pixels = observations["pixels"]
    if isinstance(pixels, dict):
        imgs = {f"observation.images.{key}": img for key, img in pixels.items()}
    else:
        imgs = {"observation.images.top": pixels}

    return_observations = {}
    for imgkey, img in imgs.items():
        img = torch.from_numpy(img)
        # Sanity check that images are channel last (channels is the
        # smallest of the three trailing dims).
        _, h, w, c = img.shape
        assert c < h and c < w, f"expect channel last images, but instead {img.shape}"
        assert img.dtype == torch.uint8, f"expect torch.uint8, but instead {img.dtype=}"
        # (B, H, W, C) -> (B, C, H, W), float32, values unchanged.
        img = img.permute(0, 3, 1, 2).contiguous().to(torch.float32)
        return_observations[imgkey] = img

    state = torch.from_numpy(observations["agent_pos"]).float()
    return_observations["observation.state"] = model.normalize_inputs({
        "observation.state": state.to(model.device)
    })["observation.state"]
    return return_observations
def image_guided_synthesis_sim_mode(
        model: torch.nn.Module,
        prompts: list[str],
        observation: dict,
        noise_shape: tuple[int, int, int, int, int],
        action_cond_step: int = 16,
        n_samples: int = 1,
        ddim_steps: int = 50,
        ddim_eta: float = 1.0,
        unconditional_guidance_scale: float = 1.0,
        precision: int | None = 16,
        fs: int | None = None,
        text_input: bool = True,
        timestep_spacing: str = 'uniform',
        guidance_rescale: float = 0.0,
        sim_mode: bool = True,
        decode_video: bool = True,
        pipeline_split_step: int = 0,
        pipeline_compare_full: bool = False,
        handoff_callback=None,
        stop_at_handoff: bool = False,
        **kwargs) -> tuple[torch.Tensor | None, torch.Tensor, torch.Tensor,
                           torch.Tensor, dict[str, Any]]:
    """Run one DDIM sampling pass with image/state/action/text conditioning.

    With ``sim_mode=True`` the action embedding is part of the
    cross-attention context (world-model interaction). With
    ``sim_mode=False`` it is zeroed (decision making).

    Args:
        model (torch.nn.Module): Diffusion model with multimodal conditioning.
        prompts (list[str]): Text prompt(s); replaced by ``[""] * B`` when
            ``text_input`` is False.
        observation (dict): Observed inputs:
            - 'observation.images.top': [B, C, O, H, W] (O = history length)
            - 'observation.state': [B, O, D]
            - 'action': [B, T, D]
        noise_shape (tuple): Latent shape (B, C, T, H, W) to sample.
        action_cond_step (int): Unused; kept for interface compatibility.
        n_samples (int): Unused; one sample per batch item is generated.
        ddim_steps (int): Number of DDIM sampling steps.
        ddim_eta (float): DDIM stochasticity.
        unconditional_guidance_scale (float): CFG scale (1.0 disables CFG).
        precision (int | None): Forwarded to the sampler.
        fs (int | None): Frame-stride condition, broadcast over the batch.
            Despite the ``None`` default a value is required:
            ``torch.tensor([None] * B, dtype=torch.long)`` raises.
        text_input (bool): Whether to use ``prompts`` as text conditioning.
        timestep_spacing (str): DDIM timestep schedule, e.g. "uniform".
        guidance_rescale (float): CFG rescale factor.
        sim_mode (bool): World-model mode (True) or policy mode (False).
        decode_video (bool): Decode latents to pixels; False skips the VAE.
        pipeline_split_step (int): Forwarded to the sampler as ``handoff_step``.
        pipeline_compare_full (bool): Unused; kept for compatibility.
        handoff_callback: Forwarded to the sampler.
        stop_at_handoff (bool): Forwarded to the sampler.
        **kwargs: Extra DDIM sampler arguments.

    Returns:
        tuple: ``(batch_variants, actions, states, samples, intermedia)``.
        ``batch_variants`` is the decoded video [B, C, T, H, W], or None when
        ``decode_video`` is False. The remaining four items are the sampler's
        outputs: predicted actions, predicted states, latent samples, and
        its intermediate-results dict.
    """
    ddim_sampler = DDIMSampler(model)
    batch_size = noise_shape[0]
    fs = torch.tensor([fs] * batch_size, dtype=torch.long, device=model.device)

    # (B, C, O, H, W) -> (B, O, C, H, W).
    img = observation['observation.images.top'].permute(0, 2, 1, 3, 4)
    # Embed only the last row of the flattened (B*O) frames: the most recent
    # observed frame of the last batch item.
    cond_img = rearrange(img, 'b o c h w -> (b o) c h w')[-1:]
    cond_img_emb = model.embedder(cond_img)
    cond_img_emb = model.image_proj_model(cond_img_emb)

    # Always define `cond`: it used to exist only for 'hybrid' models, so
    # any other conditioning key failed with NameError further down.
    cond = {}
    if model.model.conditioning_key == 'hybrid':
        # Concat conditioning: latent of the newest observed frame, repeated
        # over the T frames being generated.
        z = get_latent_z(model, img.permute(0, 2, 1, 3, 4))
        img_cat_cond = z[:, :, -1:, :, :]
        img_cat_cond = repeat(img_cat_cond,
                              'b c t h w -> b c (repeat t) h w',
                              repeat=noise_shape[2])
        cond["c_concat"] = [img_cat_cond]

    if not text_input:
        prompts = [""] * batch_size
    cond_ins_emb = model.get_learned_conditioning(prompts)
    cond_state_emb = model.state_projector(observation['observation.state'])
    cond_state_emb = cond_state_emb + model.agent_state_pos_emb
    cond_action_emb = model.action_projector(observation['action'])
    cond_action_emb = cond_action_emb + model.agent_action_pos_emb
    if not sim_mode:
        # Policy mode: hide the actions from the cross-attention context.
        cond_action_emb = torch.zeros_like(cond_action_emb)
    cond["c_crossattn"] = [
        torch.cat(
            [cond_state_emb, cond_action_emb, cond_ins_emb, cond_img_emb],
            dim=1)
    ]
    # Action-head inputs: the last `n_obs_steps_acting` images (dim 2 = O)
    # and states (dim 1 = O), plus the mode flag.
    cond["c_crossattn_action"] = [
        observation['observation.images.top'][:, :,
                                              -model.n_obs_steps_acting:],
        observation['observation.state'][:, -model.n_obs_steps_acting:],
        sim_mode,
        False,
    ]

    kwargs.update({"unconditional_conditioning_img_nonetext": None})
    samples, actions, states, intermedia = ddim_sampler.sample(
        S=ddim_steps,
        conditioning=cond,
        batch_size=batch_size,
        shape=noise_shape[1:],
        verbose=False,
        unconditional_guidance_scale=unconditional_guidance_scale,
        unconditional_conditioning=None,
        eta=ddim_eta,
        cfg_img=None,
        mask=None,
        x0=None,
        precision=precision,
        fs=fs,
        timestep_spacing=timestep_spacing,
        guidance_rescale=guidance_rescale,
        handoff_step=pipeline_split_step,
        handoff_callback=handoff_callback,
        stop_at_handoff=stop_at_handoff,
        **kwargs)

    # Reconstruct from latent to pixel space only when requested.
    batch_variants = model.decode_first_stage(samples) if decode_video else None
    return batch_variants, actions, states, samples, intermedia
def _is_prepared_model_fresh(prepared_path: str, ckpt_path: str) -> bool:
    """Return True when the cached prepared model may be reused."""
    if not os.path.exists(prepared_path):
        return False
    if not os.path.exists(ckpt_path):
        # Without the checkpoint the cache is the only copy of the weights.
        return True
    if os.path.getmtime(prepared_path) < os.path.getmtime(ckpt_path):
        print(f">>> Prepared model {prepared_path} is older than "
              f"{ckpt_path}; rebuilding ...")
        return False
    return True


def _load_prepared_model(prepared_path: str, gpu_no: int) -> nn.Module:
    """Unpickle a cached model onto ``cuda:<gpu_no>`` and backfill attributes."""
    print(f">>> Loading prepared model from {prepared_path} ...")
    # weights_only=False unpickles arbitrary objects: only load cache files
    # written by this script.
    model = torch.load(prepared_path,
                       map_location=f"cuda:{gpu_no}",
                       weights_only=False,
                       mmap=True)
    model.eval()
    model = model.cuda(gpu_no)
    diffusion_model = model.model.diffusion_model
    # Caches pickled by an older code version may lack these runtime
    # attributes; give them the same defaults.
    if not hasattr(diffusion_model, '_ctx_cache_enabled'):
        diffusion_model._ctx_cache_enabled = False
    if not hasattr(diffusion_model, '_ctx_cache'):
        diffusion_model._ctx_cache = {}
    if not hasattr(diffusion_model, '_trt_backbone'):
        diffusion_model._trt_backbone = None
    if not hasattr(diffusion_model, '_state_stream'):
        diffusion_model._state_stream = torch.cuda.Stream(
            device=torch.device(f"cuda:{gpu_no}"))
    print(f">>> Prepared model loaded.")
    return model


def _build_and_cache_model(config, args: argparse.Namespace, gpu_no: int,
                           prepared_path: str) -> nn.Module:
    """Instantiate the model from config, load the checkpoint, cache it."""
    config['model']['params']['wma_config']['params'][
        'use_checkpoint'] = False
    model = instantiate_from_config(config.model)
    model.perframe_ae = args.perframe_ae
    assert os.path.exists(args.ckpt_path), "Error: checkpoint Not Found!"
    model = load_model_checkpoint(model, args.ckpt_path)
    model.eval()
    model = model.cuda(gpu_no)
    print(f'>>> Load pre-trained model ...')
    print(f">>> Saving prepared model to {prepared_path} ...")
    # Save to a per-process temp file, then rename atomically, so an
    # interrupted save never leaves a truncated cache for the next run.
    tmp_path = f"{prepared_path}.tmp{os.getpid()}"
    try:
        torch.save(model, tmp_path)
        os.replace(tmp_path, prepared_path)
    except BaseException:
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise
    print(f">>> Prepared model saved ({os.path.getsize(prepared_path) / 1024**3:.1f} GB).")
    return model


def load_inference_runtime(args: argparse.Namespace,
                           gpu_no: int) -> SimpleNamespace:
    """Load the model once and keep all GPU-side runtime state together.

    The fully built model is cached next to the checkpoint as
    ``<ckpt_path>.prepared.pt`` (a pickled module), so later runs skip config
    instantiation and state-dict loading. The cache is rebuilt when it is
    older than the checkpoint. After loading, cached ``.device`` attributes
    are synced and KV projections are fused. A TensorRT video backbone is
    loaded when the engine file exists and ``--disable_trt`` is not set.

    Returns:
        SimpleNamespace: ``model``, ``config``, ``device`` and ``gpu_no``.
    """
    config = OmegaConf.load(args.config)
    prepared_path = args.ckpt_path + ".prepared.pt"
    if _is_prepared_model_fresh(prepared_path, args.ckpt_path):
        model = _load_prepared_model(prepared_path, gpu_no)
    else:
        model = _build_and_cache_model(config, args, gpu_no, prepared_path)
    # Apply the CLI flag on both paths; a cached model would otherwise keep
    # the value it was pickled with.
    model.perframe_ae = args.perframe_ae

    device = get_device_from_parameters(model)
    sync_module_device_attributes(model, device)
    if hasattr(model, 'cond_stage_model') and model.cond_stage_model is not None:
        sync_module_device_attributes(model.cond_stage_model, device)
    if hasattr(model.model.diffusion_model, '_state_stream'):
        # CUDA streams are device-bound; recreate on the actual device.
        model.model.diffusion_model._state_stream = torch.cuda.Stream(
            device=device)

    # Fuse KV projections in attention layers (to_k + to_v -> to_kv).
    from unifolm_wma.modules.attention import CrossAttention
    kv_count = sum(1 for m in model.modules()
                   if isinstance(m, CrossAttention) and m.fuse_kv())
    print(f" ✓ KV fused: {kv_count} attention layers")

    # Load TRT backbone if engine exists and the user did not request Torch fallback.
    trt_engine_path = args.trt_engine_path
    if trt_engine_path is None:
        trt_engine_path = os.path.join(os.path.dirname(os.path.abspath(__file__)),
                                       '..', '..', 'trt_engines',
                                       'video_backbone.engine')
    if args.disable_trt:
        print(">>> TRT disabled by --disable_trt; using PyTorch video backbone.")
    elif os.path.exists(trt_engine_path):
        model.model.diffusion_model.load_trt_backbone(trt_engine_path)
    else:
        print(f">>> TRT engine not found at {trt_engine_path}; using PyTorch video backbone.")
    if torch.cuda.is_available():
        torch.cuda.synchronize(device)
    return SimpleNamespace(model=model, config=config, device=device, gpu_no=gpu_no)
def load_inference_data(config: OmegaConf,
                        args: argparse.Namespace) -> SimpleNamespace:
    """Instantiate the test dataset for ``args.dataset`` from ``config``.

    ``config.data.params.test`` is copied with interpolations resolved, and
    four of its params are overridden:
      * ``data_dir``: the configured value, or ``args.prompt_dir`` if unset;
      * ``meta_path``: ``<prompt_dir>/<dataset>.csv``;
      * ``transition_dir``: ``<data_dir>/transitions``;
      * ``dataset_name``: ``args.dataset``.

    Returns:
        SimpleNamespace: with ``test_datasets = {args.dataset: dataset}``.
    """
    logging.info("***** Configing Data *****")
    test_cfg = OmegaConf.create(
        OmegaConf.to_container(config.data.params.test, resolve=True))
    data_root = OmegaConf.select(config, "data.params.test.params.data_dir")
    if data_root is None:
        data_root = args.prompt_dir
    overrides = {
        "data_dir": data_root,
        "meta_path": os.path.join(args.prompt_dir, f"{args.dataset}.csv"),
        "transition_dir": os.path.join(data_root, "transitions"),
        "dataset_name": args.dataset,
    }
    for key, value in overrides.items():
        test_cfg.params[key] = value
    dataset = instantiate_from_config(test_cfg)
    print(">>> Dataset is successfully loaded ...")
    return SimpleNamespace(test_datasets={args.dataset: dataset})
def build_pipeline_iteration_input(
        runtime: SimpleNamespace,
        args: argparse.Namespace,
        base_queues: dict[str, deque],
        action_seq: torch.Tensor,
        noise_shape: list[int],
        model_input_fs: int,
        ori_action_dim: int,
        ori_state_dim: int) -> dict[str, deque]:
    """Speculatively build the next iteration's observation queues.

    Takes the queues that fed the previous iteration and the action sequence
    its policy handed off early. Then:
      1. runs the world model on those actions with ``stop_at_handoff=True``
         and ``pipeline_split_step=args.pipeline_split_step``;
      2. picks a latent from ``intermedia['handoff']``: its 'pred_x0' entry,
         else 'samples' if 'pred_x0' is absent; if the chosen value is None,
         the sampler's returned samples are used instead;
      3. decodes that latent to pixels;
      4. returns a fresh copy of ``base_queues`` with the actions appended
         and the first ``args.exe_steps`` decoded frames and predicted
         states (handoff 'states' if present) rolled in.
    """
    model = runtime.model
    device = runtime.device

    policy_obs = build_observation_from_queues(base_queues, device)
    action_queues = append_action_sequence(
        clone_observation_queues(base_queues), action_seq, ori_action_dim)
    stacked_actions = torch.stack(list(action_queues['action']), dim=1)
    wm_input = {
        'observation.images.top': policy_obs['observation.images.top'],
        'observation.state': policy_obs['observation.state'],
        'action': stacked_actions.to(device, non_blocking=True),
    }
    _, _, pred_states, wm_samples, wm_intermedia = image_guided_synthesis_sim_mode(
        model,
        "",
        wm_input,
        noise_shape,
        action_cond_step=args.exe_steps,
        ddim_steps=args.ddim_steps,
        ddim_eta=args.ddim_eta,
        unconditional_guidance_scale=args.unconditional_guidance_scale,
        precision=args.precision,
        fs=model_input_fs,
        text_input=False,
        timestep_spacing=args.timestep_spacing,
        guidance_rescale=args.guidance_rescale,
        decode_video=False,
        pipeline_split_step=args.pipeline_split_step,
        pipeline_compare_full=False,
        stop_at_handoff=True)

    handoff = wm_intermedia.get('handoff', {})
    handoff_latent = handoff.get('pred_x0', handoff.get('samples', None))
    if handoff_latent is None:
        handoff_latent = wm_samples
    handoff_video = model.decode_first_stage(handoff_latent)
    handoff_states = handoff.get('states', pred_states)

    next_queues = append_action_sequence(
        clone_observation_queues(base_queues), action_seq, ori_action_dim)
    return rollout_execution_segment(
        queues=next_queues,
        seg_video=handoff_video[:, :, :args.exe_steps],
        pred_states=handoff_states,
        zero_action_template=action_seq[0][-1:],
        exe_steps=args.exe_steps,
        ori_state_dim=ori_state_dim,
        zero_pred_state=args.zero_pred_state)
def run_pipeline_iteration_task(
        runtime: SimpleNamespace,
        args: argparse.Namespace,
        iter_idx: int,
        instruction: str,
        input_payload: dict[str, Any],
        noise_shape: list[int],
        model_input_fs: int,
        ori_action_dim: int,
        ori_state_dim: int,
        handoff_queue: Queue) -> dict[str, Any]:
    """Run one policy + world-model iteration on ``runtime``'s GPU.

    ``input_payload['mode']`` selects how the conditioning queues are built:
      * 'initial': ``input_payload['queues_cpu']`` moved to the device;
      * 'handoff': rebuilt with ``build_pipeline_iteration_input`` from
        ``base_queues_cpu`` and ``action_seq_cpu`` sent by iteration
        ``source_iter``.
    Any other mode raises ``ValueError``.

    The policy pass (``sim_mode=False``) may invoke ``_on_policy_handoff``
    when ``args.pipeline_split_step > 0``. That callback puts at most one
    payload on ``handoff_queue``, and none on the last iteration, so the
    next iteration can start early. The full predicted actions are then
    appended and the world model generates the video.

    Returns:
        dict: 'iter_idx', 'gpu_no' and 'seg_video_cpu' (the first
        ``args.exe_steps`` decoded world-model frames, on CPU).
    """
    model = runtime.model
    device = runtime.device
    if input_payload['mode'] == 'initial':
        cond_obs_queues = move_observation_queues_to_device(
            input_payload['queues_cpu'], device)
    elif input_payload['mode'] == 'handoff':
        base_queues = move_observation_queues_to_device(
            input_payload['base_queues_cpu'], device)
        action_seq = input_payload['action_seq_cpu'].to(
            device, non_blocking=True)
        source_iter = input_payload['source_iter']
        pipeline_print(
            f'>>> Step {iter_idx}@gpu{runtime.gpu_no}: building pipeline input '
            f'from step {source_iter} handoff ...')
        cond_obs_queues = build_pipeline_iteration_input(
            runtime=runtime,
            args=args,
            base_queues=base_queues,
            action_seq=action_seq,
            noise_shape=noise_shape,
            model_input_fs=model_input_fs,
            ori_action_dim=ori_action_dim,
            ori_state_dim=ori_state_dim)
    else:
        raise ValueError(f"Unsupported input mode: {input_payload['mode']}")
    # CPU snapshot of this iteration's inputs; it is what the handoff payload
    # ships to the next iteration. NOTE(review): the intermediate device clone
    # looks redundant, since clone_observation_queues_to_cpu already copies.
    iter_input_queues = clone_observation_queues(cond_obs_queues)
    iter_input_queues_cpu = clone_observation_queues_to_cpu(iter_input_queues)
    handoff_sent = False

    def _on_policy_handoff(handoff: dict[str, torch.Tensor]) -> None:
        """Emit (once) the early policy actions for iteration ``iter_idx + 1``."""
        nonlocal handoff_sent
        # Only one handoff per iteration, and none after the final iteration.
        if handoff_sent or (iter_idx + 1) >= args.n_iter:
            return
        handoff_queue.put({
            'source_iter':
            iter_idx,
            'next_iter':
            iter_idx + 1,
            'base_queues_cpu':
            iter_input_queues_cpu,
            'action_seq_cpu':
            handoff['actions'].detach().cpu().clone(),
            'handoff_step':
            handoff.get('step', args.pipeline_split_step),
        })
        handoff_sent = True
        handoff_step = handoff.get('step', args.pipeline_split_step)
        pipeline_print(
            f'>>> Step {iter_idx}@gpu{runtime.gpu_no}: emitted policy handoff '
            f'for step {iter_idx + 1} at ddim step '
            f'{handoff_step} ...')

    # Policy pass: predict actions with the action embedding zeroed.
    policy_observation = build_observation_from_queues(cond_obs_queues, device)
    pipeline_print(
        f'>>> Step {iter_idx}@gpu{runtime.gpu_no}: generating actions ...')
    _, pred_actions, _, _, _ = image_guided_synthesis_sim_mode(
        model,
        instruction,
        policy_observation,
        noise_shape,
        action_cond_step=args.exe_steps,
        ddim_steps=args.ddim_steps,
        ddim_eta=args.ddim_eta,
        unconditional_guidance_scale=args.unconditional_guidance_scale,
        precision=args.precision,
        fs=model_input_fs,
        timestep_spacing=args.timestep_spacing,
        guidance_rescale=args.guidance_rescale,
        sim_mode=False,
        decode_video=False,
        pipeline_split_step=args.pipeline_split_step,
        pipeline_compare_full=False,
        handoff_callback=_on_policy_handoff if args.pipeline_split_step > 0 else
        None)
    # World-model pass: condition on the same observations plus the newly
    # predicted actions.
    cond_obs_queues = append_action_sequence(cond_obs_queues, pred_actions,
                                             ori_action_dim)
    wm_observation = {
        'observation.images.top': policy_observation['observation.images.top'],
        'observation.state': policy_observation['observation.state'],
        'action': torch.stack(list(cond_obs_queues['action']), dim=1).to(
            device, non_blocking=True),
    }
    pipeline_print(
        f'>>> Step {iter_idx}@gpu{runtime.gpu_no}: interacting with world model ...')
    _, _, pred_states, wm_samples, _ = image_guided_synthesis_sim_mode(
        model,
        "",
        wm_observation,
        noise_shape,
        action_cond_step=args.exe_steps,
        ddim_steps=args.ddim_steps,
        ddim_eta=args.ddim_eta,
        unconditional_guidance_scale=args.unconditional_guidance_scale,
        precision=args.precision,
        fs=model_input_fs,
        text_input=False,
        timestep_spacing=args.timestep_spacing,
        guidance_rescale=args.guidance_rescale,
        decode_video=False,
        pipeline_split_step=0,
        pipeline_compare_full=False)
    wm_video = model.decode_first_stage(wm_samples)
    if args.debug_wm_stats:
        _debug_world_model_stats(wm_samples,
                                 wm_video,
                                 prefix=f"step={iter_idx}@gpu{runtime.gpu_no}")
    # Keep only the frames that are actually "executed" this iteration.
    seg_video = wm_video[:, :, :args.exe_steps]
    pipeline_print('>' * 24)
    return {
        'iter_idx': iter_idx,
        'gpu_no': runtime.gpu_no,
        'seg_video_cpu': seg_video.detach().cpu(),
    }
def run_inference_multi_gpu_pipeline(
        args: argparse.Namespace,
        runtimes: list[SimpleNamespace]) -> None:
    """Run world-model interaction with iterations scheduled across GPUs.

    Iteration 0 is submitted to ``runtimes[0]``. Running tasks push handoff
    events onto a shared queue; each event names the next iteration to start
    and carries the CPU queues / action sequence it needs. The scheduler
    submits that iteration to ``runtimes[next_iter % len(runtimes)]``. Each
    runtime owns a single-worker thread pool. Finished segments are sent to
    the video writer process strictly in iteration order.

    Args:
        args: Parsed command-line arguments.
        runtimes: Loaded per-GPU runtimes exposing ``config``, ``model`` and
            ``gpu_no``. ``runtimes[0]`` supplies the config, data and model
            geometry used for every GPU.

    Raises:
        RuntimeError: If no task is in flight and no handoff is pending while
            iterations remain unwritten (the scheduler could never finish).
    """
    os.makedirs(args.savedir + '/inference', exist_ok=True)
    config = runtimes[0].config
    data = load_inference_data(config, args)
    df = pd.read_csv(os.path.join(args.prompt_dir, f"{args.dataset}.csv"))
    model0 = runtimes[0].model
    assert (args.height % 16 == 0) and (
        args.width % 16
        == 0), "Error: image size [h,w] should be multiples of 16!"
    assert args.bs == 1, "Current implementation only support [batch size = 1]!"
    # Latent resolution is 1/8 of the pixel resolution.
    h, w = args.height // 8, args.width // 8
    channels = model0.model.diffusion_model.out_channels
    n_frames = args.video_length
    pipeline_print(f'>>> Generate {n_frames} frames under each generation ...')
    noise_shape = [args.bs, channels, n_frames, h, w]
    gpu_ids = [runtime.gpu_no for runtime in runtimes]
    pipeline_print(f'>>> Multi-GPU pipeline enabled on GPUs {gpu_ids}.')
    for idx in range(0, len(df)):
        sample = df.iloc[idx]
        init_frame_path = get_init_frame_path(args.prompt_dir, sample)
        ori_fps = float(sample['fps'])
        video_save_dir = args.savedir + f"/inference/sample_{sample['videoid']}"
        os.makedirs(video_save_dir, exist_ok=True)
        transition_path = get_transition_path(args.prompt_dir, sample)
        with h5py.File(transition_path, 'r') as h5f:
            transition_dict = {}
            for key in h5f.keys():
                transition_dict[key] = torch.tensor(h5f[key][()])
            for key in h5f.attrs.keys():
                transition_dict[key] = h5f.attrs[key]
        for fs in args.frame_stride:
            sample_full_video_file = f"{video_save_dir}/../{sample['videoid']}_full_fs{fs}.mp4"
            # Bounded queue so the scheduler blocks instead of buffering
            # unbounded decoded video if the writer falls behind.
            write_q = mp.Queue(maxsize=4)
            writer_proc = mp.Process(target=_video_writer_process,
                                     args=(write_q, sample_full_video_file,
                                           args.save_fps))
            writer_proc.daemon = True
            writer_proc.start()
            _register_child_process(writer_proc)
            # One single-worker executor per GPU serialises work per device.
            executors = {
                runtime.gpu_no: ThreadPoolExecutor(max_workers=1)
                for runtime in runtimes
            }
            handoff_events = Queue()
            future_meta = {}
            submitted_iters = set()
            completed_results = {}
            next_write_idx = 0
            model_input_fs = ori_fps // fs
            progress = tqdm(total=args.n_iter,
                            ascii=False,
                            desc=f'fs={fs} pipeline',
                            leave=True)
            try:
                batch, ori_state_dim, ori_action_dim = prepare_init_input(
                    0,
                    init_frame_path,
                    transition_dict,
                    fs,
                    data.test_datasets[args.dataset],
                    n_obs_steps=model0.n_obs_steps_imagen)
                initial_observation = {
                    'observation.images.top':
                    batch['observation.image'].permute(1, 0, 2, 3)[-1].unsqueeze(0),
                    'observation.state':
                    batch['observation.state'][-1].unsqueeze(0),
                    'action':
                    torch.zeros_like(batch['action'][-1]).unsqueeze(0),
                }
                initial_queues = {
                    "observation.images.top":
                    deque(maxlen=model0.n_obs_steps_imagen),
                    "observation.state": deque(maxlen=model0.n_obs_steps_imagen),
                    "action": deque(maxlen=args.video_length),
                }
                initial_queues = populate_queues(initial_queues,
                                                 initial_observation)
                initial_payload = {
                    'mode': 'initial',
                    'queues_cpu': clone_observation_queues_to_cpu(initial_queues),
                }
                first_gpu = runtimes[0].gpu_no
                future = executors[first_gpu].submit(
                    run_pipeline_iteration_task, runtimes[0], args, 0,
                    sample['instruction'], initial_payload, noise_shape,
                    model_input_fs, ori_action_dim, ori_state_dim,
                    handoff_events)
                future_meta[future] = {'iter_idx': 0, 'gpu_no': first_gpu}
                submitted_iters.add(0)
                while next_write_idx < args.n_iter:
                    # 1) Harvest finished tasks (re-raises task exceptions).
                    done_futures = [fut for fut in list(future_meta.keys())
                                    if fut.done()]
                    for fut in done_futures:
                        meta = future_meta.pop(fut)
                        result = fut.result()
                        completed_results[result['iter_idx']] = result
                        pipeline_print(
                            f'>>> Step {result["iter_idx"]}@gpu{meta["gpu_no"]}: '
                            f'final segment ready.')
                    # 2) Drain handoff events and schedule the next iterations.
                    while True:
                        try:
                            event = handoff_events.get_nowait()
                        except Empty:
                            break
                        next_iter = event['next_iter']
                        if next_iter >= args.n_iter or next_iter in submitted_iters:
                            continue
                        target_runtime = runtimes[next_iter % len(runtimes)]
                        payload = {
                            'mode': 'handoff',
                            'base_queues_cpu': event['base_queues_cpu'],
                            'action_seq_cpu': event['action_seq_cpu'],
                            'source_iter': event['source_iter'],
                            'handoff_step': event['handoff_step'],
                        }
                        pipeline_print(
                            f'>>> Step {event["source_iter"]}: queueing step '
                            f'{next_iter} on gpu{target_runtime.gpu_no} from '
                            f'ddim step {event["handoff_step"]} ...')
                        future = executors[target_runtime.gpu_no].submit(
                            run_pipeline_iteration_task, target_runtime, args,
                            next_iter, sample['instruction'], payload,
                            noise_shape, model_input_fs, ori_action_dim,
                            ori_state_dim, handoff_events)
                        future_meta[future] = {
                            'iter_idx': next_iter,
                            'gpu_no': target_runtime.gpu_no
                        }
                        submitted_iters.add(next_iter)
                    # 3) Write results in iteration order.
                    while next_write_idx in completed_results:
                        result = completed_results.pop(next_write_idx)
                        write_q.put(result['seg_video_cpu'])
                        next_write_idx += 1
                        progress.update(1)
                        progress.set_postfix_str(
                            f'written={next_write_idx}/{args.n_iter}',
                            refresh=False)
                    if next_write_idx < args.n_iter:
                        # Events from every finished task were drained in step
                        # 2, so with nothing in flight and nothing queued no
                        # further progress is possible: fail instead of
                        # spinning forever (this also covers a task that
                        # finished without emitting a handoff).
                        if not future_meta and handoff_events.empty():
                            raise RuntimeError(
                                "Pipeline scheduler is idle before all iterations finished.")
                        time.sleep(0.05)
            finally:
                progress.close()
                for executor in executors.values():
                    # On success every submitted iteration has already
                    # completed; on failure, drop queued-but-unstarted
                    # iterations instead of running them to completion.
                    executor.shutdown(wait=True, cancel_futures=True)
                _stop_writer_process(writer_proc, write_q)
    _flush_io()
def run_inference(args: argparse.Namespace,
                  gpu_num: int,
                  gpu_no: int,
                  runtime: Optional[SimpleNamespace] = None) -> None:
    """
    Run the serial (single-GPU) world-model interaction loop.

    For every row of ``<prompt_dir>/<dataset>.csv`` and every requested frame
    stride, this alternates two passes. The policy pass predicts actions. The
    world-model pass predicts the next video segment and states. Each executed
    segment is streamed to a writer process, and the observation queues are
    rolled forward.

    When ``args.pipeline_split_step > 0``, every iteration except the last also
    builds a handoff branch. Its rolled-out queues replace the regular ones at
    the start of the next iteration.

    Args:
        args (argparse.Namespace): Parsed command-line arguments.
        gpu_num (int): Number of GPUs (not used by this function).
        gpu_no (int): Index of the current GPU; only used to load a runtime
            when ``runtime`` is None.
        runtime (Optional[SimpleNamespace]): Preloaded runtime exposing
            ``config``, ``model`` and ``device``. When None, the config is
            read from ``args.config`` and the model is loaded here.
    Returns:
        None
    """
    # Create inference dir
    os.makedirs(args.savedir + '/inference', exist_ok=True)
    # Load prompt
    csv_path = os.path.join(args.prompt_dir, f"{args.dataset}.csv")
    df = pd.read_csv(csv_path)
    # Load config (always needed for data setup)
    config = runtime.config if runtime is not None else OmegaConf.load(args.config)

    # Parallel loading: model and data are independent
    def _load_model():
        return load_inference_runtime(args, gpu_no).model

    def _load_data():
        return load_inference_data(config, args)

    if runtime is None:
        with ThreadPoolExecutor(max_workers=2) as executor:
            model_future = executor.submit(_load_model)
            data_future = executor.submit(_load_data)
            model = model_future.result()
            data = data_future.result()
        device = get_device_from_parameters(model)
    else:
        model = runtime.model
        device = runtime.device
        data = _load_data()
    # Run over data
    assert (args.height % 16 == 0) and (
        args.width % 16
        == 0), "Error: image size [h,w] should be multiples of 16!"
    assert args.bs == 1, "Current implementation only support [batch size = 1]!"
    # Get latent noise shape (latents are 1/8 of the pixel resolution)
    h, w = args.height // 8, args.width // 8
    channels = model.model.diffusion_model.out_channels
    n_frames = args.video_length
    print(f'>>> Generate {n_frames} frames under each generation ...')
    noise_shape = [args.bs, channels, n_frames, h, w]
    # Start inference
    for idx in range(0, len(df)):
        sample = df.iloc[idx]
        # Got initial frame path
        init_frame_path = get_init_frame_path(args.prompt_dir, sample)
        ori_fps = float(sample['fps'])
        video_save_dir = args.savedir + f"/inference/sample_{sample['videoid']}"
        os.makedirs(video_save_dir, exist_ok=True)
        # Load transitions to get the initial state later
        transition_path = get_transition_path(args.prompt_dir, sample)
        with h5py.File(transition_path, 'r') as h5f:
            transition_dict = {}
            for key in h5f.keys():
                transition_dict[key] = torch.tensor(h5f[key][()])
            for key in h5f.attrs.keys():
                transition_dict[key] = h5f.attrs[key]
        # If many, test various frequency control and world-model generation
        for fs in args.frame_stride:
            # Writer process for incremental video saving
            sample_full_video_file = f"{video_save_dir}/../{sample['videoid']}_full_fs{fs}.mp4"
            write_q = mp.Queue(maxsize=4)
            writer_proc = mp.Process(
                target=_video_writer_process,
                args=(write_q, sample_full_video_file, args.save_fps))
            writer_proc.daemon = True
            writer_proc.start()
            _register_child_process(writer_proc)
            try:
                # Initialize observation queues
                cond_obs_queues = {
                    "observation.images.top":
                    deque(maxlen=model.n_obs_steps_imagen),
                    "observation.state": deque(maxlen=model.n_obs_steps_imagen),
                    "action": deque(maxlen=args.video_length),
                }
                # Obtain initial frame and state
                start_idx = 0
                model_input_fs = ori_fps // fs
                batch, ori_state_dim, ori_action_dim = prepare_init_input(
                    start_idx,
                    init_frame_path,
                    transition_dict,
                    fs,
                    data.test_datasets[args.dataset],
                    n_obs_steps=model.n_obs_steps_imagen)
                observation = {
                    'observation.images.top':
                    batch['observation.image'].permute(1, 0, 2,
                                                       3)[-1].unsqueeze(0),
                    'observation.state':
                    batch['observation.state'][-1].unsqueeze(0),
                    'action':
                    torch.zeros_like(batch['action'][-1]).unsqueeze(0)
                }
                observation = {
                    key: observation[key].to(device, non_blocking=True)
                    for key in observation
                }
                # Update observation queues
                cond_obs_queues = populate_queues(cond_obs_queues, observation)
                # Multi-round interaction with the world-model
                pending_handoff = None
                pending_handoff_source_itr = None
                pending_handoff_ddim_step = None
                for itr in tqdm(range(args.n_iter), ascii=False):
                    # The handoff branch is only built when splitting is
                    # enabled and a next iteration exists to consume it.
                    use_handoff = (args.pipeline_split_step > 0
                                   and (itr + 1) < args.n_iter)
                    # Build observation for policy pass.
                    if pending_handoff is not None:
                        print(
                            f'>>> Step {itr}: consuming pipeline handoff '
                            f'from step {pending_handoff_source_itr} '
                            f'at ddim step {pending_handoff_ddim_step} ...')
                        cond_obs_queues = pending_handoff
                        pending_handoff = None
                        pending_handoff_source_itr = None
                        pending_handoff_ddim_step = None
                    # Snapshot of this iteration's input queues. Only the
                    # handoff branch reads it, so skip the deep clone otherwise.
                    iter_input_queues = (
                        clone_observation_queues(cond_obs_queues)
                        if use_handoff else None)
                    policy_observation = build_observation_from_queues(
                        cond_obs_queues, device)
                    # Use world-model in policy to generate action
                    print(f'>>> Step {itr}: generating actions ...')
                    _, pred_actions, _, _, policy_intermedia = image_guided_synthesis_sim_mode(
                        model,
                        sample['instruction'],
                        policy_observation,
                        noise_shape,
                        action_cond_step=args.exe_steps,
                        ddim_steps=args.ddim_steps,
                        ddim_eta=args.ddim_eta,
                        unconditional_guidance_scale=args.
                        unconditional_guidance_scale,
                        precision=args.precision,
                        fs=model_input_fs,
                        timestep_spacing=args.timestep_spacing,
                        guidance_rescale=args.guidance_rescale,
                        sim_mode=False,
                        decode_video=False,
                        pipeline_split_step=args.pipeline_split_step,
                        pipeline_compare_full=False)
                    pipeline_action_seq = None
                    pipeline_pred_states = None
                    pipeline_wm_samples = None
                    pipeline_wm_intermedia = {}
                    if use_handoff:
                        # Prefer the actions captured at the DDIM handoff step;
                        # fall back to the final actions.
                        pipeline_action_seq = policy_intermedia.get(
                            'handoff', {}).get('actions', pred_actions)
                        pipeline_action_queues = clone_observation_queues(
                            iter_input_queues)
                        pipeline_action_queues = append_action_sequence(
                            pipeline_action_queues, pipeline_action_seq,
                            ori_action_dim)
                        wm_handoff_observation = {
                            'observation.images.top':
                            policy_observation['observation.images.top'],
                            'observation.state':
                            policy_observation['observation.state'],
                            'action':
                            torch.stack(list(pipeline_action_queues['action']),
                                        dim=1).to(device, non_blocking=True),
                        }
                        print(
                            f'>>> Step {itr}: preparing pipeline handoff branch ...'
                        )
                        _, _, pipeline_pred_states, pipeline_wm_samples, pipeline_wm_intermedia = image_guided_synthesis_sim_mode(
                            model,
                            "",
                            wm_handoff_observation,
                            noise_shape,
                            action_cond_step=args.exe_steps,
                            ddim_steps=args.ddim_steps,
                            ddim_eta=args.ddim_eta,
                            unconditional_guidance_scale=args.
                            unconditional_guidance_scale,
                            precision=args.precision,
                            fs=model_input_fs,
                            text_input=False,
                            timestep_spacing=args.timestep_spacing,
                            guidance_rescale=args.guidance_rescale,
                            decode_video=False,
                            pipeline_split_step=args.pipeline_split_step,
                            pipeline_compare_full=False)
                    # Update future actions in the observation queues
                    cond_obs_queues = append_action_sequence(
                        cond_obs_queues, pred_actions, ori_action_dim)
                    # Reuse images/state and only rebuild action for WM pass.
                    wm_observation = {
                        'observation.images.top': policy_observation['observation.images.top'],
                        'observation.state': policy_observation['observation.state'],
                        'action': torch.stack(list(cond_obs_queues['action']), dim=1).to(
                            device, non_blocking=True),
                    }
                    # Interaction with the world-model
                    print(f'>>> Step {itr}: interacting with world model ...')
                    _, _, pred_states, wm_samples, _ = image_guided_synthesis_sim_mode(
                        model,
                        "",
                        wm_observation,
                        noise_shape,
                        action_cond_step=args.exe_steps,
                        ddim_steps=args.ddim_steps,
                        ddim_eta=args.ddim_eta,
                        unconditional_guidance_scale=args.
                        unconditional_guidance_scale,
                        precision=args.precision,
                        fs=model_input_fs,
                        text_input=False,
                        timestep_spacing=args.timestep_spacing,
                        guidance_rescale=args.guidance_rescale,
                        decode_video=False,
                        pipeline_split_step=args.pipeline_split_step,
                        pipeline_compare_full=False)
                    # Decode full WM clip, then take executable segment to keep behavior closer to previous path.
                    wm_video = model.decode_first_stage(wm_samples)
                    if args.debug_wm_stats:
                        _debug_world_model_stats(
                            wm_samples,
                            wm_video,
                            prefix=f"step={itr}")
                    seg_video = wm_video[:, :, :args.exe_steps]
                    cond_obs_queues = rollout_execution_segment(
                        queues=cond_obs_queues,
                        seg_video=seg_video,
                        pred_states=pred_states,
                        zero_action_template=pred_actions[0][-1:],
                        exe_steps=args.exe_steps,
                        ori_state_dim=ori_state_dim,
                        zero_pred_state=args.zero_pred_state)
                    if use_handoff:
                        policy_handoff = policy_intermedia.get('handoff', {})
                        wm_handoff = pipeline_wm_intermedia.get('handoff', {})
                        handoff_ddim_step = wm_handoff.get(
                            'step',
                            policy_handoff.get('step', args.pipeline_split_step))
                        next_action_seq = pipeline_action_seq
                        # Prefer the predicted x0 at the handoff step, then the
                        # intermediate samples, then the final branch samples.
                        handoff_video_latent = wm_handoff.get(
                            'pred_x0', wm_handoff.get('samples', None))
                        if handoff_video_latent is None:
                            handoff_video_latent = pipeline_wm_samples
                        handoff_video = model.decode_first_stage(
                            handoff_video_latent)
                        handoff_state_seq = wm_handoff.get('states',
                                                           pipeline_pred_states)
                        pending_handoff = clone_observation_queues(
                            iter_input_queues)
                        pending_handoff = append_action_sequence(
                            pending_handoff, next_action_seq, ori_action_dim)
                        pending_handoff = rollout_execution_segment(
                            queues=pending_handoff,
                            seg_video=handoff_video[:, :, :args.exe_steps],
                            pred_states=handoff_state_seq,
                            zero_action_template=next_action_seq[0][-1:],
                            exe_steps=args.exe_steps,
                            ori_state_dim=ori_state_dim,
                            zero_pred_state=args.zero_pred_state)
                        pending_handoff_source_itr = itr
                        pending_handoff_ddim_step = handoff_ddim_step
                        print(
                            f'>>> Step {itr}: prepared pipeline handoff '
                            f'for step {itr + 1} at ddim step '
                            f'{handoff_ddim_step} ...')
                    print('>' * 24)
                    # Send decoded segment to writer process
                    write_q.put(seg_video.detach().cpu())
            finally:
                _stop_writer_process(writer_proc, write_q)
    # Wait for all async I/O to complete
    _flush_io()
def get_parser():
    """Build the command-line parser for world-model interaction inference.

    Returns:
        argparse.ArgumentParser: Parser for all inference options;
        ``--frame_stride`` is the only required argument.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--savedir",
                        type=str,
                        default=None,
                        help="Path to save the results.")
    parser.add_argument("--ckpt_path",
                        type=str,
                        default=None,
                        help="Path to the model checkpoint.")
    # Fixed copy-paste help text: this is the model *config* file.
    parser.add_argument("--config",
                        type=str,
                        help="Path to the model config file.")
    parser.add_argument(
        "--prompt_dir",
        type=str,
        default=None,
        help="Directory containing videos and corresponding prompts.")
    parser.add_argument("--dataset",
                        type=str,
                        default=None,
                        help="the name of dataset to test")
    parser.add_argument(
        "--ddim_steps",
        type=int,
        default=50,
        help="Number of DDIM steps. If non-positive, DDPM is used instead.")
    parser.add_argument(
        "--ddim_eta",
        type=float,
        default=1.0,
        help="Eta for DDIM sampling. Set to 0.0 for deterministic results.")
    parser.add_argument("--bs",
                        type=int,
                        default=1,
                        help="Batch size for inference. Must be 1.")
    parser.add_argument("--height",
                        type=int,
                        default=320,
                        help="Height of the generated images in pixels.")
    parser.add_argument("--width",
                        type=int,
                        default=512,
                        help="Width of the generated images in pixels.")
    parser.add_argument(
        "--frame_stride",
        type=int,
        nargs='+',
        required=True,
        help=
        "frame stride control for 256 model (larger->larger motion), FPS control for 512 or 1024 model (smaller->larger motion)"
    )
    parser.add_argument(
        "--unconditional_guidance_scale",
        type=float,
        default=1.0,
        help="Scale for classifier-free guidance during sampling.")
    parser.add_argument("--seed",
                        type=int,
                        default=123,
                        help="Random seed for reproducibility.")
    parser.add_argument("--video_length",
                        type=int,
                        default=16,
                        help="Number of frames in the generated video.")
    parser.add_argument("--num_generation",
                        type=int,
                        default=1,
                        help="seed for seed_everything")
    parser.add_argument(
        "--timestep_spacing",
        type=str,
        default="uniform",
        help=
        "Strategy for timestep scaling. See Table 2 in the paper: 'Common Diffusion Noise Schedules and Sample Steps are Flawed' (https://huggingface.co/papers/2305.08891)."
    )
    parser.add_argument(
        "--guidance_rescale",
        type=float,
        default=0.0,
        help=
        "Rescale factor for guidance as discussed in 'Common Diffusion Noise Schedules and Sample Steps are Flawed' (https://huggingface.co/papers/2305.08891)."
    )
    parser.add_argument(
        "--perframe_ae",
        action='store_true',
        default=False,
        help=
        "Use per-frame autoencoder decoding to reduce GPU memory usage. Recommended for models with resolutions like 576x1024."
    )
    parser.add_argument(
        "--precision",
        type=int,
        default=16,
        choices=[16, 32],
        help="Sampling precision for latent/action/state noise initialization. Default is 16.")
    parser.add_argument(
        "--n_action_steps",
        type=int,
        default=16,
        help="num of samples per prompt",
    )
    parser.add_argument(
        "--exe_steps",
        type=int,
        default=16,
        help="num of samples to execute",
    )
    parser.add_argument(
        "--n_iter",
        type=int,
        default=40,
        help="num of iteration to interact with the world model",
    )
    parser.add_argument("--zero_pred_state",
                        action='store_true',
                        default=False,
                        help="not using the predicted states as comparison")
    parser.add_argument(
        "--fast_policy_no_decode",
        action='store_true',
        default=False,
        help="Speed mode: policy pass only predicts actions, skip policy video decode/log/save.")
    parser.add_argument(
        "--debug_wm_stats",
        action='store_true',
        default=False,
        help="Print latent/decode statistics for world-model samples before writing video.")
    parser.add_argument(
        "--disable_trt",
        action='store_true',
        default=False,
        help="Disable TensorRT backbone loading and force the PyTorch video backbone path.")
    parser.add_argument(
        "--trt_engine_path",
        type=str,
        default=None,
        help="Optional explicit TensorRT engine path. Defaults to trt_engines/video_backbone.engine.")
    parser.add_argument("--save_fps",
                        type=int,
                        default=8,
                        help="fps for the saving video")
    parser.add_argument(
        "--pipeline_split_step",
        type=int,
        default=0,
        help="Run DDIM sampling in two segments on a single GPU for pipeline experiments. "
        "Set to a value like 25 to test 25+25 splitting; 0 disables it.")
    parser.add_argument(
        "--pipeline_compare_full",
        action='store_true',
        default=False,
        help="When pipeline_split_step is enabled, also run the original full DDIM pass "
        "with the same initial noise and print max-abs diffs for validation.")
    parser.add_argument(
        "--pipeline_multi_gpu",
        action='store_true',
        default=False,
        help="Enable true asynchronous pipeline execution across multiple GPUs. "
        "When disabled, keep the original single-GPU/serial inference path.")
    parser.add_argument(
        "--pipeline_gpu_ids",
        type=int,
        nargs='+',
        default=[0, 1],
        help="Logical CUDA device ids used by the multi-GPU pipeline scheduler "
        "(only the first two are used).")
    return parser
if __name__ == '__main__':
    parser = get_parser()
    args = parser.parse_args()
    seed = args.seed
    # A negative seed requests a random one.
    if seed < 0:
        seed = random.randint(0, 2**31)
    seed_everything(seed)
    if args.pipeline_multi_gpu:
        if args.pipeline_split_step <= 0:
            raise ValueError(
                "--pipeline_multi_gpu requires --pipeline_split_step > 0.")
        if len(args.pipeline_gpu_ids) < 2:
            raise ValueError(
                "--pipeline_multi_gpu requires at least two gpu ids.")
        # The scheduler only ever receives the first two runtimes, so load
        # models onto those GPUs only; loading the rest would waste time and
        # device memory on runtimes that are never used.
        pipeline_gpu_ids = args.pipeline_gpu_ids[:2]
        if len(args.pipeline_gpu_ids) > 2:
            print(f'>>> Ignoring extra pipeline gpu ids '
                  f'{args.pipeline_gpu_ids[2:]}; using {pipeline_gpu_ids}.')
        runtimes = [
            load_inference_runtime(args, gpu_id)
            for gpu_id in pipeline_gpu_ids
        ]
        run_inference_multi_gpu_pipeline(args, runtimes)
    else:
        rank, gpu_num = 0, 1
        run_inference(args, gpu_num, rank)