1723 lines
68 KiB
Python
1723 lines
68 KiB
Python
import argparse, os, glob
|
|
import pandas as pd
|
|
import random
|
|
import torch
|
|
import torchvision
|
|
import h5py
|
|
import numpy as np
|
|
import logging
|
|
import einops
|
|
import warnings
|
|
import imageio
|
|
import atexit
|
|
import signal
|
|
import multiprocessing as mp
|
|
import time
|
|
from concurrent.futures import ThreadPoolExecutor
|
|
from queue import Empty, Queue
|
|
|
|
from pytorch_lightning import seed_everything
|
|
from omegaconf import OmegaConf
|
|
from tqdm import tqdm
|
|
from einops import rearrange, repeat
|
|
from collections import OrderedDict
|
|
from torch import nn
|
|
from eval_utils import populate_queues
|
|
from collections import deque
|
|
from typing import Optional, List, Any
|
|
from types import SimpleNamespace
|
|
|
|
# Allow TF32 for CUDA matmuls and cuDNN ops; set once at import time.
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
|
|
from torch import Tensor
|
|
from PIL import Image
|
|
|
|
from unifolm_wma.models.samplers.ddim import DDIMSampler
|
|
from unifolm_wma.utils.utils import instantiate_from_config
|
|
|
|
|
|
def get_device_from_parameters(module: nn.Module) -> torch.device:
    """Get a module's device by checking one of its parameters.

    Falls back to the module's buffers when it owns no parameters, so that
    parameter-free modules (e.g. non-affine normalisation layers) still work.

    Args:
        module (nn.Module): The model whose device is to be inferred.

    Returns:
        torch.device: The device of the first parameter (or buffer) found.

    Raises:
        ValueError: If the module has neither parameters nor buffers.
    """
    for tensor in module.parameters():
        return tensor.device
    for tensor in module.buffers():
        return tensor.device
    # A bare StopIteration from next() is easy to swallow by accident inside
    # generators, so report the problem explicitly instead.
    raise ValueError(
        "Cannot infer device: module has no parameters or buffers.")
|
|
|
|
|
|
def clone_observation_queues(
        queues: dict[str, deque]) -> dict[str, deque]:
    """Copy every queue so pipeline branches can diverge safely.

    Tensor entries are cloned; any other entry is carried over as the same
    object. Each new deque keeps the ``maxlen`` of its source.
    """

    def _copy_entry(entry):
        return entry.clone() if torch.is_tensor(entry) else entry

    return {
        name: deque(map(_copy_entry, source), maxlen=source.maxlen)
        for name, source in queues.items()
    }
|
|
|
|
|
|
def clone_observation_queues_to_cpu(
        queues: dict[str, deque]) -> dict[str, deque]:
    """Build detached CPU copies of every queue.

    Tensor entries are detached, moved to CPU and cloned; non-tensor entries
    are passed through unchanged. Each new deque keeps the source ``maxlen``.
    """

    def _to_host(entry):
        if torch.is_tensor(entry):
            return entry.detach().cpu().clone()
        return entry

    host_queues = {}
    for name, source in queues.items():
        host_queues[name] = deque([_to_host(entry) for entry in source],
                                  maxlen=source.maxlen)
    return host_queues
|
|
|
|
|
|
def move_observation_queues_to_device(
        queues: dict[str, deque],
        device: torch.device) -> dict[str, deque]:
    """Rebuild every queue with its tensor entries transferred to ``device``.

    Transfers are issued with ``non_blocking=True``; non-tensor entries are
    passed through unchanged. Each new deque keeps the source ``maxlen``.
    """
    relocated = {}
    for name, source in queues.items():
        entries = []
        for entry in source:
            if torch.is_tensor(entry):
                entry = entry.to(device, non_blocking=True)
            entries.append(entry)
        relocated[name] = deque(entries, maxlen=source.maxlen)
    return relocated
|
|
|
|
|
|
def sync_module_device_attributes(module: nn.Module,
                                  device: torch.device) -> None:
    """Overwrite any cached ``.device`` attribute in the module tree.

    Visits ``module`` and all of its submodules. Wherever a ``device``
    attribute exists it is set to ``device``; assignments that fail (for
    example on a read-only property) are deliberately ignored.
    """
    for child in module.modules():
        if not hasattr(child, 'device'):
            continue
        try:
            child.device = device
        except Exception:
            # Best-effort: leave attributes that cannot be assigned untouched.
            pass
|
|
|
|
|
|
def pipeline_print(message: str) -> None:
    """Print ``message`` to stdout and flush the stream immediately."""
    print(message, end="\n", flush=True)
|
|
|
|
|
|
def build_observation_from_queues(
        queues: dict[str, deque],
        device: torch.device) -> dict[str, torch.Tensor]:
    """Stack the queued frames into batched observation tensors on ``device``.

    Each queue is stacked along a new dim 1. The image stack additionally has
    dims 1 and 2 swapped (``permute(0, 2, 1, 3, 4)``).

    Args:
        queues: Must contain 'observation.images.top', 'observation.state'
            and 'action' queues of tensors.
        device: Target device for the returned tensors.

    Returns:
        Dict with the same three keys, each moved to ``device``
        (``non_blocking=True``).
    """
    image_stack = torch.stack(tuple(queues['observation.images.top']), dim=1)
    state_stack = torch.stack(tuple(queues['observation.state']), dim=1)
    action_stack = torch.stack(tuple(queues['action']), dim=1)

    stacked = {
        'observation.images.top': image_stack.permute(0, 2, 1, 3, 4),
        'observation.state': state_stack,
        'action': action_stack,
    }
    on_device = {}
    for name, tensor in stacked.items():
        on_device[name] = tensor.to(device, non_blocking=True)
    return on_device
|
|
|
|
|
|
def append_action_sequence(
        queues: dict[str, deque],
        action_seq: torch.Tensor,
        ori_action_dim: int) -> dict[str, deque]:
    """Push each step of the first sample's action sequence into ``queues``.

    For every index along dim 1 of ``action_seq`` a one-step slice of sample 0
    is cloned, its columns from ``ori_action_dim`` onward are zeroed, and the
    result is handed to ``populate_queues`` under the 'action' key.

    Returns:
        The queues object returned by the last ``populate_queues`` call (or
        ``queues`` unchanged when the sequence is empty).
    """
    num_steps = action_seq.shape[1]
    for step in range(num_steps):
        one_step = action_seq[0][step:step + 1].clone()
        # Blank everything beyond the original action dimensionality.
        one_step[:, ori_action_dim:] = 0.0
        queues = populate_queues(queues, {'action': one_step})
    return queues
|
|
|
|
|
|
def rollout_execution_segment(
        queues: dict[str, deque],
        seg_video: torch.Tensor,
        pred_states: torch.Tensor,
        zero_action_template: torch.Tensor,
        exe_steps: int,
        ori_state_dim: int,
        zero_pred_state: bool) -> dict[str, deque]:
    """Feed ``exe_steps`` predicted frames and states into the queues.

    For each step, one frame of sample 0 of ``seg_video`` and one state row of
    sample 0 of ``pred_states`` are pushed through ``populate_queues``, along
    with an all-zero action shaped like ``zero_action_template``.

    Args:
        queues: Observation queues to extend.
        seg_video: Video tensor; sample 0 is sliced along its dim 1 per step.
        pred_states: Predicted states; sample 0 is sliced along dim 0 per step.
        zero_action_template: Tensor whose shape/dtype the zero action copies.
        exe_steps: Number of steps to roll out.
        ori_state_dim: State columns from this index onward are zeroed.
        zero_pred_state: If True, push all-zero states instead of predictions.

    Returns:
        The queues object returned by the last ``populate_queues`` call.
    """
    for step in range(exe_steps):
        window = slice(step, step + 1)
        if zero_pred_state:
            state_frame = torch.zeros_like(pred_states[0][window])
        else:
            state_frame = pred_states[0][window].clone()
        # Blank everything beyond the original state dimensionality.
        state_frame[:, ori_state_dim:] = 0.0

        image_frame = seg_video[0][:, window].permute(1, 0, 2, 3)
        step_observation = {
            'observation.images.top': image_frame,
            'observation.state': state_frame,
            'action': torch.zeros_like(zero_action_template),
        }
        queues = populate_queues(queues, step_observation)
    return queues
|
|
|
|
|
|
def write_video(video_path: str, stacked_frames: list, fps: int) -> None:
    """Write a list of frames to ``video_path`` with ``imageio.mimsave``.

    The "pkg_resources is deprecated as an API" DeprecationWarning is
    silenced for the duration of the save.

    Args:
        video_path (str): Output path for the video.
        stacked_frames (list): List of image frames.
        fps (int): Frames per second for the video.
    """
    with warnings.catch_warnings():
        warnings.filterwarnings(
            action="ignore",
            message="pkg_resources is deprecated as an API",
            category=DeprecationWarning)
        imageio.mimsave(video_path, stacked_frames, fps=fps)
|
|
|
|
|
|
def get_filelist(data_dir: str, postfixes: list[str]) -> list[str]:
    """Return sorted list of files in a directory matching specified postfixes.

    Only direct children of ``data_dir`` are considered (no recursion).

    Args:
        data_dir (str): Directory path to search in. Taken literally: glob
            metacharacters in the directory name are escaped.
        postfixes (list[str]): List of file extensions to match, without the
            leading dot.

    Returns:
        list[str]: Sorted list of file paths.
    """
    # Escape the directory so names such as "run[1]" are not treated as glob
    # character classes, which would silently yield no matches.
    safe_dir = glob.escape(data_dir)
    file_list = []
    for postfix in postfixes:
        file_list.extend(glob.glob(os.path.join(safe_dir, f"*.{postfix}")))
    file_list.sort()
    return file_list
|
|
|
|
|
|
def load_model_checkpoint(model: nn.Module, ckpt: str) -> nn.Module:
    """Load model weights from checkpoint file.

    Two checkpoint layouts are handled:

    * A dict with a ``"state_dict"`` entry: loaded strictly. If that fails,
      keys containing ``framestride_embed`` are renamed to ``fps_embedding``
      and a strict load is retried.
    * Anything else: weights are read from ``checkpoint['module']`` with the
      first 16 characters of every key dropped.

    Args:
        model (nn.Module): Model instance.
        ckpt (str): Path to the checkpoint file.

    Returns:
        nn.Module: Model with loaded weights.

    Raises:
        RuntimeError: If the weights still do not match after key renaming.
        KeyError: If the checkpoint has neither ``state_dict`` nor ``module``.
    """
    checkpoint = torch.load(ckpt, map_location="cpu")
    if "state_dict" in checkpoint:
        state_dict = checkpoint["state_dict"]
        try:
            model.load_state_dict(state_dict, strict=True)
        except RuntimeError:
            # load_state_dict reports key/shape mismatches as RuntimeError.
            # Catch only that, so Ctrl-C and unrelated bugs are not swallowed.
            renamed = OrderedDict(state_dict)
            for key in list(renamed):
                if "framestride_embed" in key:
                    new_key = key.replace("framestride_embed", "fps_embedding")
                    renamed[new_key] = renamed.pop(key)
            model.load_state_dict(renamed, strict=True)
    else:
        # NOTE(review): the 16-character prefix is presumably a wrapper prefix
        # such as "_forward_module." -- confirm against the checkpoint writer.
        wrapped = checkpoint['module']
        stripped = OrderedDict((key[16:], value)
                               for key, value in wrapped.items())
        model.load_state_dict(stripped)
    print('>>> model checkpoint loaded.')
    return model
|
|
|
|
|
|
def is_inferenced(save_dir: str, filename: str) -> bool:
    """Tell whether the first sample for ``filename`` already exists on disk.

    The expected output is ``<save_dir>/samples_separate/<stem>_sample0.mp4``
    where ``<stem>`` is ``filename`` with its last four characters removed.

    Args:
        save_dir (str): Directory where results are saved.
        filename (str): Name of the file to check.

    Returns:
        bool: True if processed file exists, False otherwise.
    """
    stem = filename[:-4]
    expected_output = os.path.join(save_dir, "samples_separate",
                                   stem + "_sample0.mp4")
    return os.path.exists(expected_output)
|
|
|
|
|
|
def save_results(video: Tensor, filename: str, fps: int = 8) -> None:
    """Save video tensor to file using torchvision.

    Blocking counterpart of ``save_results_async``: the tensor is detached and
    copied to CPU, then written by ``_save_results_sync`` so both entry points
    share one encoding path instead of two duplicated copies.

    Args:
        video (Tensor): Tensor of shape (B, C, T, H, W).
        filename (str): Output file path.
        fps (int, optional): Frames per second. Defaults to 8.
    """
    _save_results_sync(video.detach().cpu(), filename, fps)
|
|
|
|
|
|
# ========== Async I/O ==========
# Thread pool for background I/O; created lazily by _get_io_executor().
_io_executor: Optional[ThreadPoolExecutor] = None
# Futures of submitted background I/O jobs; drained by _flush_io().
_io_futures: List[Any] = []
# Child processes tracked so they can be terminated on exit or on a signal.
_child_processes: List[mp.Process] = []
|
|
|
|
|
|
def _get_io_executor() -> ThreadPoolExecutor:
    """Return the shared I/O thread pool, creating it on first use.

    The pool is stored in the module-level ``_io_executor`` and is created
    with two worker threads.
    """
    global _io_executor
    if _io_executor is not None:
        return _io_executor
    _io_executor = ThreadPoolExecutor(max_workers=2)
    return _io_executor
|
|
|
|
|
|
def _flush_io():
    """Block until every recorded async I/O future has completed.

    An exception raised by a job is printed rather than propagated. The list
    of tracked futures is emptied afterwards.
    """
    for pending in _io_futures:
        try:
            pending.result()
        except Exception as exc:
            print(f">>> [async I/O] error: {exc}")
    _io_futures.clear()
|
|
|
|
|
|
# Drain pending background I/O before the interpreter exits.
atexit.register(_flush_io)
|
|
|
|
|
|
def _register_child_process(proc: mp.Process) -> None:
    """Record ``proc`` in the module-level list of tracked child processes."""
    _child_processes.append(proc)
|
|
|
|
|
|
def _unregister_child_process(proc: mp.Process) -> None:
    """Stop tracking ``proc``; a process that is not tracked is ignored."""
    try:
        _child_processes.remove(proc)
    except ValueError:
        # Not in the list (never registered or already removed).
        return
|
|
|
|
|
|
def _terminate_process(proc: mp.Process, join_timeout: float = 3.0) -> None:
|
|
if proc is None:
|
|
return
|
|
try:
|
|
alive = proc.is_alive()
|
|
except Exception:
|
|
alive = False
|
|
if not alive:
|
|
try:
|
|
proc.join(timeout=0.1)
|
|
except Exception:
|
|
pass
|
|
return
|
|
|
|
try:
|
|
proc.terminate()
|
|
except Exception:
|
|
pass
|
|
try:
|
|
proc.join(timeout=join_timeout)
|
|
except Exception:
|
|
pass
|
|
|
|
try:
|
|
if proc.is_alive():
|
|
proc.kill()
|
|
proc.join(timeout=join_timeout)
|
|
except Exception:
|
|
pass
|
|
|
|
|
|
def _terminate_all_child_processes() -> None:
    """Terminate every tracked child process, then empty the tracking list."""
    # Iterate over a snapshot so the list itself is not walked while changing.
    snapshot = tuple(_child_processes)
    for child in snapshot:
        _terminate_process(child)
    _child_processes.clear()
|
|
|
|
|
|
def _handle_termination_signal(signum, _frame) -> None:
    """Signal handler: stop child processes, then exit with ``128 + signum``.

    Args:
        signum: Number of the received signal.
        _frame: Current stack frame (unused).

    Raises:
        SystemExit: Always, with exit status ``128 + signum``.
    """
    print(f">>> Received {signal.Signals(signum).name}, "
          f"terminating child processes ...")
    _terminate_all_child_processes()
    raise SystemExit(128 + signum)
|
|
|
|
|
|
# Route SIGINT and SIGTERM through _handle_termination_signal so tracked child
# processes are terminated first; also clean them up on normal exit.
signal.signal(signal.SIGINT, _handle_termination_signal)
signal.signal(signal.SIGTERM, _handle_termination_signal)
atexit.register(_terminate_all_child_processes)
|
|
|
|
|
|
def _save_results_sync(video_cpu: Tensor, filename: str, fps: int) -> None:
    """Encode a CPU video tensor to ``filename`` (runs in background thread).

    Values are clamped to [-1, 1]. For every index along dim 2 the batch is
    tiled into one row with ``make_grid``. The result is rescaled to uint8 and
    written as h264 with CRF 10.

    Args:
        video_cpu (Tensor): 5-D tensor; dim 0 is tiled side by side and dim 2
            is iterated as frames.
        filename (str): Output file path.
        fps (int): Frames per second.
    """
    clipped = torch.clamp(video_cpu.float(), -1., 1.)
    tiles_per_row = int(clipped.shape[0])

    sheets = []
    for framesheet in clipped.permute(2, 0, 1, 3, 4):
        sheets.append(
            torchvision.utils.make_grid(framesheet,
                                        nrow=tiles_per_row,
                                        padding=0))

    grid = torch.stack(sheets, dim=0)
    grid = (grid + 1.0) / 2.0
    # Channel-last uint8 frames for the encoder.
    frames_uint8 = (grid * 255).to(torch.uint8).permute(0, 2, 3, 1)
    torchvision.io.write_video(filename,
                               frames_uint8,
                               fps=fps,
                               video_codec='h264',
                               options={'crf': '10'})
|
|
|
|
|
|
def save_results_async(video: Tensor, filename: str, fps: int = 8) -> None:
    """Queue a video save on the background I/O thread pool.

    The tensor is detached and copied to CPU in the calling thread; encoding
    and writing happen in ``_save_results_sync`` on a worker. The returned
    future is recorded in ``_io_futures`` so ``_flush_io`` can wait for it.
    """
    host_video = video.detach().cpu()
    pool = _get_io_executor()
    _io_futures.append(
        pool.submit(_save_results_sync, host_video, filename, fps))
|
|
|
|
|
|
def _log_to_tb_sync(writer, video_cpu: Tensor, tag: str, fps: int) -> None:
    """Log a CPU video tensor to TensorBoard (runs in background thread).

    Only 5-D tensors are logged; anything else is ignored. For every index
    along dim 2 the batch is tiled into one row with ``make_grid``, the stack
    is shifted from [-1, 1] to [0, 1], and a leading dim is added before
    ``writer.add_video`` is called.
    """
    if video_cpu.dim() != 5:
        return

    tiles_per_row = int(video_cpu.shape[0])
    sheets = [
        torchvision.utils.make_grid(framesheet,
                                    nrow=tiles_per_row,
                                    padding=0)
        for framesheet in video_cpu.permute(2, 0, 1, 3, 4)
    ]
    grid = (torch.stack(sheets, dim=0) + 1.0) / 2.0
    writer.add_video(tag, grid.unsqueeze(dim=0), fps=fps)
|
|
|
|
|
|
def log_to_tensorboard_async(writer, data: Tensor, tag: str, fps: int = 10) -> None:
    """Queue a TensorBoard video log on the background I/O thread pool.

    Inputs that are not 5-D tensors are silently ignored. Otherwise the
    tensor is detached and copied to CPU, ``_log_to_tb_sync`` is submitted,
    and the future is recorded in ``_io_futures``.
    """
    if not (isinstance(data, torch.Tensor) and data.dim() == 5):
        return
    host_data = data.detach().cpu()
    pool = _get_io_executor()
    _io_futures.append(
        pool.submit(_log_to_tb_sync, writer, host_data, tag, fps))
|
|
|
|
|
|
def _video_tensor_to_frames(video: Tensor) -> np.ndarray:
|
|
video = torch.clamp(video.float(), -1., 1.)
|
|
n = video.shape[0]
|
|
if n == 1:
|
|
# Fast path for bs=1: skip make_grid and convert directly.
|
|
frames = video[0].permute(1, 2, 3, 0).contiguous()
|
|
frames = ((frames + 1.0) / 2.0 * 255).to(torch.uint8)
|
|
return frames.numpy()
|
|
|
|
video = video.permute(2, 0, 1, 3, 4)
|
|
frame_grids = [
|
|
torchvision.utils.make_grid(f, nrow=int(n), padding=0) for f in video
|
|
]
|
|
grid = torch.stack(frame_grids, dim=0)
|
|
grid = ((grid + 1.0) / 2.0 * 255).to(torch.uint8).permute(0, 2, 3, 1)
|
|
return grid.numpy()
|
|
|
|
|
|
def _tensor_stats(name: str, tensor: Tensor) -> str:
|
|
t = tensor.detach().float()
|
|
finite_mask = torch.isfinite(t)
|
|
finite_ratio = finite_mask.float().mean().item()
|
|
if finite_mask.any():
|
|
tf = t[finite_mask]
|
|
mean = tf.mean().item()
|
|
std = tf.std(unbiased=False).item()
|
|
min_v = tf.min().item()
|
|
max_v = tf.max().item()
|
|
abs_mean = tf.abs().mean().item()
|
|
else:
|
|
mean = std = min_v = max_v = abs_mean = float('nan')
|
|
return (f"{name}: shape={tuple(t.shape)} dtype={tensor.dtype} "
|
|
f"mean={mean:.6f} std={std:.6f} min={min_v:.6f} max={max_v:.6f} "
|
|
f"abs_mean={abs_mean:.6f} finite={finite_ratio:.6f}")
|
|
|
|
|
|
def _debug_world_model_stats(wm_samples: Tensor, wm_video: Tensor,
                             prefix: str) -> None:
    """Print one statistics line each for the latent and the decoded video."""
    reports = (
        ("latent", "wm_samples", wm_samples),
        ("decoded", "wm_video", wm_video),
    )
    for stage, label, tensor in reports:
        print(f">>> [debug_wm_stats] {prefix} {stage} "
              f"{_tensor_stats(label, tensor)}")
|
|
|
|
|
|
def _video_writer_process(q: mp.Queue, filename: str, fps: int):
    """Worker loop: pull video tensors from ``q`` and append them to a file.

    Each queue item is converted with ``_video_tensor_to_frames`` and its
    frames are appended to an imageio writer (libx264, CRF 10, yuv420p). The
    writer is opened when the first item arrives. ``None`` is the shutdown
    sentinel.

    Args:
        q (mp.Queue): Queue delivering video tensors, terminated by ``None``.
        filename (str): Output video path.
        fps (int): Frames per second.
    """
    writer = None
    try:
        while True:
            item = q.get()
            if item is None:
                break
            frames = _video_tensor_to_frames(item)
            if writer is None:
                writer = imageio.get_writer(
                    filename,
                    fps=fps,
                    codec='libx264',
                    ffmpeg_params=['-crf', '10', '-pix_fmt', 'yuv420p'])
            for frame in frames:
                writer.append_data(frame)
    finally:
        # Close the writer even if conversion or encoding raised, so the
        # output is finalised and the encoder is not left open.
        if writer is not None:
            writer.close()
|
|
|
|
|
|
def _stop_writer_process(writer_proc: mp.Process, write_q: mp.Queue) -> None:
    """Shut down a writer process and its queue on a best-effort basis.

    Steps, each with its own errors swallowed:
      1. Send the ``None`` sentinel (1 s put timeout).
      2. Join the process for up to 5 s.
      3. If it is still alive, terminate it via ``_terminate_process``.
      4. Close the queue and join its feeder thread.

    Finally the process is removed from the tracked child list.
    """

    def _best_effort(step) -> None:
        try:
            step()
        except Exception:
            pass

    def _force_stop_if_alive() -> None:
        if writer_proc.is_alive():
            _terminate_process(writer_proc)

    def _release_queue() -> None:
        write_q.close()
        write_q.join_thread()

    shutdown_steps = (
        lambda: write_q.put(None, timeout=1.0),
        lambda: writer_proc.join(timeout=5.0),
        _force_stop_if_alive,
        _release_queue,
    )
    for step in shutdown_steps:
        _best_effort(step)

    _unregister_child_process(writer_proc)
|
|
|
|
|
|
def get_init_frame_path(data_dir: str, sample: dict) -> str:
    """Build the path of a sample's initial-frame PNG.

    The result is ``<data_dir>/images/<sample['data_dir']>/<videoid>.png``.

    Args:
        data_dir (str): Base directory of the dataset.
        sample (dict): Dictionary containing 'data_dir' and 'videoid'.

    Returns:
        str: Full path to the PNG file.
    """
    sub_dir = sample['data_dir']
    image_name = str(sample['videoid']) + '.png'
    relative_path = os.path.join(sub_dir, image_name)
    return os.path.join(data_dir, 'images', relative_path)
|
|
|
|
|
|
def get_transition_path(data_dir: str, sample: dict) -> str:
    """Build the path of a sample's HDF5 transition file.

    The result is
    ``<data_dir>/transitions/<sample['data_dir']>/<videoid>.h5``.

    Args:
        data_dir (str): Base directory of the dataset.
        sample (dict): Dictionary containing 'data_dir' and 'videoid'.

    Returns:
        str: Full path to the HDF5 transition file.
    """
    sub_dir = sample['data_dir']
    transition_name = str(sample['videoid']) + '.h5'
    relative_path = os.path.join(sub_dir, transition_name)
    return os.path.join(data_dir, 'transitions', relative_path)
|
|
|
|
|
|
def prepare_init_input(
        start_idx: int,
        init_frame_path: str,
        transition_dict: dict[str, torch.Tensor],
        frame_stride: int,
        wma_data,
        video_length: int = 16,
        n_obs_steps: int = 2) -> tuple[dict[str, Tensor], int, int]:
    """Build the initial model input from one image and the transition data.

    Loads the initial frame, gathers the ``n_obs_steps`` states ending at
    ``start_idx`` (left-padded with the first state when too few exist),
    gathers ``video_length`` actions sampled every ``frame_stride`` steps,
    then normalises and unifies states/actions through ``wma_data``.

    Args:
        start_idx (int): Starting frame index for the current clip.
        init_frame_path (str): Path of the image used as the initial frame.
        transition_dict (dict[str, Tensor]): Must provide 'action',
            'observation.state', 'action_type' and 'state_type'.
        frame_stride (int): Temporal stride between sampled action indices.
        wma_data: Object providing ``normalizer``, ``get_uni_vec`` and
            ``spatial_transform`` (which may be None).
        video_length (int, optional): Number of action steps to sample.
            Default is 16.
        n_obs_steps (int, optional): Number of historical state steps.
            Default is 2.

    Returns:
        tuple: ``(data, ori_state_dim, ori_action_dim)`` where ``data`` holds
        'observation.image' (scaled to [-1, 1]) plus the entries returned by
        ``wma_data.get_uni_vec``, and the two ints are the last-dim sizes of
        the raw state and action tensors before unification.
    """
    indices = [start_idx + frame_stride * i for i in range(video_length)]

    # Use a context manager so the image file handle is closed
    # deterministically instead of whenever the object is collected.
    with Image.open(init_frame_path) as image_file:
        rgb_frame = image_file.convert('RGB')
    # (H, W, C) -> (1, H, W, C) -> (C, 1, H, W)
    init_frame = torch.tensor(np.array(rgb_frame)).unsqueeze(0).permute(
        3, 0, 1, 2).float()

    if start_idx < n_obs_steps - 1:
        # Not enough history yet: repeat the earliest state as left padding.
        state_indices = list(range(0, start_idx + 1))
        states = transition_dict['observation.state'][state_indices, :]
        num_padding = n_obs_steps - 1 - start_idx
        first_slice = states[0:1, :]  # (1, d)
        padding = first_slice.repeat(num_padding, 1)
        states = torch.cat((padding, states), dim=0)
    else:
        state_indices = list(range(start_idx - n_obs_steps + 1, start_idx + 1))
        states = transition_dict['observation.state'][state_indices, :]

    actions = transition_dict['action'][indices, :]

    # Record the raw dimensionalities before get_uni_vec changes them.
    ori_state_dim = states.shape[-1]
    ori_action_dim = actions.shape[-1]

    frames_action_state_dict = {
        'action': actions,
        'observation.state': states,
    }
    frames_action_state_dict = wma_data.normalizer(frames_action_state_dict)
    frames_action_state_dict = wma_data.get_uni_vec(
        frames_action_state_dict,
        transition_dict['action_type'],
        transition_dict['state_type'],
    )

    if wma_data.spatial_transform is not None:
        init_frame = wma_data.spatial_transform(init_frame)
    # Map pixel values from [0, 255] to [-1, 1].
    init_frame = (init_frame / 255 - 0.5) * 2

    data = {
        'observation.image': init_frame,
    }
    data.update(frames_action_state_dict)
    return data, ori_state_dim, ori_action_dim
|
|
|
|
|
|
def get_latent_z(model, videos: Tensor) -> Tensor:
    """Encode every frame of a video batch with the first-stage encoder.

    Frames are folded into the batch dimension, passed through
    ``model.encode_first_stage``, and unfolded back into a 5-D tensor.

    Args:
        model: the world model.
        videos (Tensor): Input videos of shape [B, C, T, H, W].

    Returns:
        Tensor: Latents with layout [B, C, T, H, W]; channel and spatial
        sizes are whatever the encoder produces.
    """
    batch, _channels, frames, _height, _width = videos.shape
    flat_frames = rearrange(videos, 'b c t h w -> (b t) c h w')
    flat_latents = model.encode_first_stage(flat_frames)
    return rearrange(flat_latents,
                     '(b t) c h w -> b c t h w',
                     b=batch,
                     t=frames)
|
|
|
|
|
|
def preprocess_observation(
        model, observations: dict[str, np.ndarray]) -> dict[str, Tensor]:
    """Convert environment observation to LeRobot format observation.

    Args:
        model: Object providing ``normalize_inputs`` and ``device``; used to
            normalise the agent state.
        observations: Dictionary of observation batches from a Gym vector
            environment. ``observations["pixels"]`` is either one
            channel-last uint8 image batch or a dict of such batches;
            ``observations["agent_pos"]`` holds the agent state.

    Returns:
        Dictionary of observation batches with keys renamed to LeRobot format
        and values as tensors. Images are channel-first float32 and are NOT
        rescaled (values stay in the 0..255 range). The state is normalised
        by ``model.normalize_inputs`` on ``model.device``.

    Raises:
        AssertionError: If an image batch is not channel-last or not uint8.
    """
    # Map to expected inputs for the policy
    return_observations = {}

    pixels = observations["pixels"]
    if isinstance(pixels, dict):
        imgs = {
            f"observation.images.{key}": img
            for key, img in pixels.items()
        }
    else:
        imgs = {"observation.images.top": pixels}

    for imgkey, img in imgs.items():
        img = torch.from_numpy(img)

        # Sanity check that images are channel last
        _, h, w, c = img.shape
        assert c < h and c < w, f"expect channel last images, but instead {img.shape}"

        # Sanity check that images are uint8
        assert img.dtype == torch.uint8, f"expect torch.uint8, but instead {img.dtype=}"

        # Convert to channel first float32 (b h w c -> b c h w); no rescaling.
        img = img.permute(0, 3, 1, 2).contiguous()
        img = img.type(torch.float32)

        return_observations[imgkey] = img

    state = torch.from_numpy(observations["agent_pos"]).float()
    return_observations['observation.state'] = model.normalize_inputs({
        'observation.state': state.to(model.device)
    })['observation.state']

    return return_observations
|
|
|
|
|
|
def image_guided_synthesis_sim_mode(
        model: torch.nn.Module,
        prompts: list[str],
        observation: dict,
        noise_shape: tuple[int, int, int, int, int],
        action_cond_step: int = 16,
        n_samples: int = 1,
        ddim_steps: int = 50,
        ddim_eta: float = 1.0,
        unconditional_guidance_scale: float = 1.0,
        precision: int | None = 16,
        fs: int | None = None,
        text_input: bool = True,
        timestep_spacing: str = 'uniform',
        guidance_rescale: float = 0.0,
        sim_mode: bool = True,
        decode_video: bool = True,
        pipeline_split_step: int = 0,
        pipeline_compare_full: bool = False,
        handoff_callback=None,
        stop_at_handoff: bool = False,
        **kwargs) -> tuple[torch.Tensor | None, torch.Tensor, torch.Tensor,
                           torch.Tensor, dict[str, Any]]:
    """
    Performs image-guided video generation in a simulation-style mode with optional multimodal guidance (image, state, action, text).

    Args:
        model (torch.nn.Module): The diffusion-based generative model with multimodal conditioning.
        prompts (list[str]): A list of textual prompts to guide the synthesis process.
            Replaced by empty strings when ``text_input`` is False.
        observation (dict): A dictionary containing observed inputs including:
            - 'observation.images.top': Tensor laid out [B, C, O, H, W]; it is
              permuted to [B, O, C, H, W] internally.
            - 'observation.state': state tensor passed to ``model.state_projector``.
            - 'action': action tensor passed to ``model.action_projector``.
        noise_shape (tuple[int, int, int, int, int]): Shape of the latent variable to generate,
            typically (B, C, T, H, W). Must have exactly five entries.
        action_cond_step (int): Accepted for interface compatibility; not used here.
        n_samples (int): Accepted for interface compatibility; not used here.
        ddim_steps (int): Number of DDIM sampling steps. Default is 50.
        ddim_eta (float): DDIM eta parameter controlling the stochasticity. Default is 1.0.
        unconditional_guidance_scale (float): Scale for classifier-free guidance, forwarded to the sampler.
        precision (int | None): Forwarded to the sampler unchanged.
        fs (int | None): Value repeated across the batch into a long tensor on
            ``model.device``. NOTE(review): the ``None`` default cannot be
            turned into a long tensor, so callers must pass an int.
        text_input (bool): Whether to use text prompt as conditioning. If False, uses empty strings. Default is True.
        timestep_spacing (str): Timestep sampling method, forwarded to the sampler.
        guidance_rescale (float): Guidance rescaling factor, forwarded to the sampler.
        sim_mode (bool): When False the action embedding is replaced by zeros.
            The flag is also forwarded inside ``c_crossattn_action``.
        decode_video (bool): Whether to decode latent samples to pixel-space video.
            Set to False to skip VAE decode for speed when only actions/states are needed.
        pipeline_split_step (int): Forwarded to the sampler as ``handoff_step``.
        pipeline_compare_full (bool): Accepted for interface compatibility; not used here.
        handoff_callback: Forwarded to the sampler.
        stop_at_handoff (bool): Forwarded to the sampler.
        **kwargs: Additional arguments passed to the DDIM sampler.

    Returns:
        batch_variants (torch.Tensor | None): Output of ``model.decode_first_stage``
            on the sampled latents, or None when decode_video=False.
        actions (torch.Tensor): Action sequences returned by the sampler.
        states (torch.Tensor): State sequences returned by the sampler.
        samples (torch.Tensor): Latent samples returned by the sampler.
        intermedia (dict[str, Any]): Intermediate outputs returned by the sampler.
    """
    # Unpacking also enforces that noise_shape has exactly five entries.
    batch_size, _, num_frames, _, _ = noise_shape
    ddim_sampler = DDIMSampler(model)

    fs = torch.tensor([fs] * batch_size, dtype=torch.long, device=model.device)

    img = observation['observation.images.top'].permute(0, 2, 1, 3, 4)
    # NOTE(review): [-1:] keeps only the very last frame of the flattened
    # (batch, obs) axis, i.e. one image even when B > 1 -- confirm intended.
    cond_img = rearrange(img, 'b o c h w -> (b o) c h w')[-1:]
    cond_img_emb = model.embedder(cond_img)
    cond_img_emb = model.image_proj_model(cond_img_emb)

    # Always bind `cond`; it was previously created only in the 'hybrid'
    # branch, so any other conditioning key crashed with a NameError below.
    cond: dict[str, Any] = {}
    if model.model.conditioning_key == 'hybrid':
        z = get_latent_z(model, img.permute(0, 2, 1, 3, 4))
        # Repeat the latent of the last observed frame across all frames.
        img_cat_cond = z[:, :, -1:, :, :]
        img_cat_cond = repeat(img_cat_cond,
                              'b c t h w -> b c (repeat t) h w',
                              repeat=num_frames)
        cond["c_concat"] = [img_cat_cond]

    if not text_input:
        prompts = [""] * batch_size
    cond_ins_emb = model.get_learned_conditioning(prompts)

    cond_state_emb = model.state_projector(observation['observation.state'])
    cond_state_emb = cond_state_emb + model.agent_state_pos_emb

    cond_action_emb = model.action_projector(observation['action'])
    cond_action_emb = cond_action_emb + model.agent_action_pos_emb

    if not sim_mode:
        cond_action_emb = torch.zeros_like(cond_action_emb)

    cond["c_crossattn"] = [
        torch.cat(
            [cond_state_emb, cond_action_emb, cond_ins_emb, cond_img_emb],
            dim=1)
    ]
    cond["c_crossattn_action"] = [
        observation['observation.images.top'][:, :,
                                              -model.n_obs_steps_acting:],
        observation['observation.state'][:, -model.n_obs_steps_acting:],
        sim_mode,
        False,
    ]

    kwargs["unconditional_conditioning_img_nonetext"] = None

    samples, actions, states, intermedia = ddim_sampler.sample(
        S=ddim_steps,
        conditioning=cond,
        batch_size=batch_size,
        shape=noise_shape[1:],
        verbose=False,
        unconditional_guidance_scale=unconditional_guidance_scale,
        unconditional_conditioning=None,
        eta=ddim_eta,
        cfg_img=None,
        mask=None,
        x0=None,
        precision=precision,
        fs=fs,
        timestep_spacing=timestep_spacing,
        guidance_rescale=guidance_rescale,
        **kwargs,
        handoff_step=pipeline_split_step,
        handoff_callback=handoff_callback,
        stop_at_handoff=stop_at_handoff)

    batch_variants = None
    if decode_video:
        # Reconstruct from latent to pixel space
        batch_variants = model.decode_first_stage(samples)

    return batch_variants, actions, states, samples, intermedia
|
|
|
|
|
|
def load_inference_runtime(args: argparse.Namespace,
                           gpu_no: int) -> SimpleNamespace:
    """Load the model once and keep all GPU-side runtime state together.

    If ``<args.ckpt_path>.prepared.pt`` exists, the pickled model is loaded
    from it. Otherwise the model is built from the config, the checkpoint is
    loaded, and the prepared file is written for later runs. The model is then
    put in eval mode on ``cuda:<gpu_no>``, cached ``.device`` attributes are
    synced, KV projections are fused, and a TensorRT backbone is loaded when
    an engine file is available and not disabled.

    Args:
        args: Reads ``config``, ``ckpt_path``, ``perframe_ae``,
            ``trt_engine_path`` and ``disable_trt``.
        gpu_no: CUDA device index.

    Returns:
        SimpleNamespace with ``model``, ``config``, ``device`` and ``gpu_no``.
    """
    config = OmegaConf.load(args.config)
    prepared_path = args.ckpt_path + ".prepared.pt"

    if not os.path.exists(prepared_path):
        config['model']['params']['wma_config']['params'][
            'use_checkpoint'] = False
        model = instantiate_from_config(config.model)
        model.perframe_ae = args.perframe_ae
        assert os.path.exists(args.ckpt_path), "Error: checkpoint Not Found!"
        model = load_model_checkpoint(model, args.ckpt_path)
        model.eval()
        model = model.cuda(gpu_no)
        print('>>> Load pre-trained model ...')
        print(f">>> Saving prepared model to {prepared_path} ...")
        torch.save(model, prepared_path)
        prepared_gib = os.path.getsize(prepared_path) / 1024**3
        print(f">>> Prepared model saved ({prepared_gib:.1f} GB).")
    else:
        print(f">>> Loading prepared model from {prepared_path} ...")
        # NOTE: weights_only=False unpickles arbitrary objects; only load
        # prepared files that this script produced itself.
        model = torch.load(prepared_path,
                           map_location=f"cuda:{gpu_no}",
                           weights_only=False,
                           mmap=True)
        model.eval()
        model = model.cuda(gpu_no)
        diffusion_model = model.model.diffusion_model
        # Fill in runtime attributes that the prepared file may lack.
        missing_defaults = (
            ('_ctx_cache_enabled', False),
            ('_ctx_cache', {}),
            ('_trt_backbone', None),
        )
        for attr_name, default in missing_defaults:
            if not hasattr(diffusion_model, attr_name):
                setattr(diffusion_model, attr_name, default)
        if not hasattr(diffusion_model, '_state_stream'):
            diffusion_model._state_stream = torch.cuda.Stream(
                device=torch.device(f"cuda:{gpu_no}"))
        print(">>> Prepared model loaded.")

    device = get_device_from_parameters(model)
    sync_module_device_attributes(model, device)
    text_encoder = getattr(model, 'cond_stage_model', None)
    if hasattr(model, 'cond_stage_model') and text_encoder is not None:
        sync_module_device_attributes(text_encoder, device)
    if hasattr(model.model.diffusion_model, '_state_stream'):
        # Recreate the stream so it belongs to the model's actual device.
        model.model.diffusion_model._state_stream = torch.cuda.Stream(
            device=device)

    # Fuse KV projections in attention layers (to_k + to_v -> to_kv).
    from unifolm_wma.modules.attention import CrossAttention
    kv_count = 0
    for layer in model.modules():
        if isinstance(layer, CrossAttention) and layer.fuse_kv():
            kv_count += 1
    print(f" ✓ KV fused: {kv_count} attention layers")

    # Load TRT backbone if engine exists and the user did not request Torch fallback.
    trt_engine_path = args.trt_engine_path
    if trt_engine_path is None:
        script_dir = os.path.dirname(os.path.abspath(__file__))
        trt_engine_path = os.path.join(script_dir, '..', '..', 'trt_engines',
                                       'video_backbone.engine')
    if args.disable_trt:
        print(">>> TRT disabled by --disable_trt; using PyTorch video backbone.")
    elif os.path.exists(trt_engine_path):
        model.model.diffusion_model.load_trt_backbone(trt_engine_path)
    else:
        print(f">>> TRT engine not found at {trt_engine_path}; using PyTorch video backbone.")

    if torch.cuda.is_available():
        torch.cuda.synchronize(device)

    return SimpleNamespace(model=model,
                           config=config,
                           device=device,
                           gpu_no=gpu_no)
|
|
|
|
|
|
def load_inference_data(config: OmegaConf,
                        args: argparse.Namespace) -> SimpleNamespace:
    """Instantiate the test dataset described by ``config`` for ``args.dataset``.

    A resolved copy of ``config.data.params.test`` is taken and four of its
    params are overridden: ``data_dir`` (from the config, else
    ``args.prompt_dir``), ``meta_path`` (``<prompt_dir>/<dataset>.csv``),
    ``transition_dir`` (``<data_dir>/transitions``) and ``dataset_name``.

    Returns:
        SimpleNamespace whose ``test_datasets`` maps ``args.dataset`` to the
        instantiated dataset.
    """
    logging.info("***** Configing Data *****")
    resolved_test = OmegaConf.to_container(config.data.params.test,
                                           resolve=True)
    test_cfg = OmegaConf.create(resolved_test)

    stats_root = OmegaConf.select(config, "data.params.test.params.data_dir")
    if stats_root is None:
        stats_root = args.prompt_dir

    overrides = {
        "data_dir": stats_root,
        "meta_path": os.path.join(args.prompt_dir, f"{args.dataset}.csv"),
        "transition_dir": os.path.join(stats_root, "transitions"),
        "dataset_name": args.dataset,
    }
    for param_name, param_value in overrides.items():
        test_cfg.params[param_name] = param_value

    target_dataset = instantiate_from_config(test_cfg)
    print(">>> Dataset is successfully loaded ...")
    return SimpleNamespace(test_datasets={args.dataset: target_dataset})
|
|
|
|
|
|
def build_pipeline_iteration_input(
        runtime: SimpleNamespace,
        args: argparse.Namespace,
        base_queues: dict[str, deque],
        action_seq: torch.Tensor,
        noise_shape: list[int],
        model_input_fs: int,
        ori_action_dim: int,
        ori_state_dim: int) -> dict[str, deque]:
    """Derive the next iteration's observation queues from a handoff.

    Steps:
      1. Build an observation from ``base_queues`` and append ``action_seq``
         to a clone of the queues to obtain the action conditioning.
      2. Run the world model without decoding, stopping at the handoff step.
      3. Decode the handoff latent (``pred_x0``, else ``samples``, else the
         sampler's final samples) to video.
      4. On a fresh clone of ``base_queues``, append ``action_seq`` and roll
         out ``args.exe_steps`` frames/states from the handoff prediction.

    Returns:
        The rolled-out observation queues.
    """
    model = runtime.model
    device = runtime.device

    policy_observation = build_observation_from_queues(base_queues, device)

    action_queues = append_action_sequence(
        clone_observation_queues(base_queues), action_seq, ori_action_dim)
    stacked_actions = torch.stack(list(action_queues['action']), dim=1)
    wm_handoff_observation = {
        'observation.images.top': policy_observation['observation.images.top'],
        'observation.state': policy_observation['observation.state'],
        'action': stacked_actions.to(device, non_blocking=True),
    }

    world_model_outputs = image_guided_synthesis_sim_mode(
        model,
        "",
        wm_handoff_observation,
        noise_shape,
        action_cond_step=args.exe_steps,
        ddim_steps=args.ddim_steps,
        ddim_eta=args.ddim_eta,
        unconditional_guidance_scale=args.unconditional_guidance_scale,
        precision=args.precision,
        fs=model_input_fs,
        text_input=False,
        timestep_spacing=args.timestep_spacing,
        guidance_rescale=args.guidance_rescale,
        decode_video=False,
        pipeline_split_step=args.pipeline_split_step,
        pipeline_compare_full=False,
        stop_at_handoff=True)
    _, _, pred_states, wm_samples, wm_intermedia = world_model_outputs

    handoff = wm_intermedia.get('handoff', {})
    if 'pred_x0' in handoff:
        handoff_latent = handoff['pred_x0']
    else:
        handoff_latent = handoff.get('samples', None)
    if handoff_latent is None:
        # No usable handoff latent: fall back to the sampler's output.
        handoff_latent = wm_samples
    handoff_video = model.decode_first_stage(handoff_latent)
    handoff_states = handoff.get('states', pred_states)

    next_queues = append_action_sequence(
        clone_observation_queues(base_queues), action_seq, ori_action_dim)
    return rollout_execution_segment(
        queues=next_queues,
        seg_video=handoff_video[:, :, :args.exe_steps],
        pred_states=handoff_states,
        zero_action_template=action_seq[0][-1:],
        exe_steps=args.exe_steps,
        ori_state_dim=ori_state_dim,
        zero_pred_state=args.zero_pred_state)
|
|
|
|
|
|
def run_pipeline_iteration_task(
        runtime: SimpleNamespace,
        args: argparse.Namespace,
        iter_idx: int,
        instruction: str,
        input_payload: dict[str, Any],
        noise_shape: list[int],
        model_input_fs: int,
        ori_action_dim: int,
        ori_state_dim: int,
        handoff_queue: Queue) -> dict[str, Any]:
    """Run one policy + world-model iteration on the GPU owned by ``runtime``.

    Args:
        runtime: Namespace providing ``model``, ``device`` and ``gpu_no``.
        args: Parsed command-line arguments (sampling and pipeline options).
        iter_idx: Index of the iteration being computed.
        instruction: Text instruction passed to the policy sampling pass.
        input_payload: Describes how to build the input queues. With
            ``mode == 'initial'`` it carries ``queues_cpu``; with
            ``mode == 'handoff'`` it carries ``base_queues_cpu``,
            ``action_seq_cpu`` and ``source_iter``.
        noise_shape: Latent noise shape forwarded to the sampler.
        model_input_fs: Frame-stride / fps conditioning value for the model.
        ori_action_dim: Forwarded to ``append_action_sequence``.
        ori_state_dim: Forwarded to ``build_pipeline_iteration_input``.
        handoff_queue: Queue that receives at most one handoff event for
            iteration ``iter_idx + 1`` (none for the last iteration).

    Returns:
        A dict with ``iter_idx``, ``gpu_no`` and ``seg_video_cpu`` (the first
        ``args.exe_steps`` decoded frames, detached and moved to CPU).

    Raises:
        ValueError: If ``input_payload['mode']`` is not a supported mode.
    """
    model = runtime.model
    device = runtime.device

    # --- Build the observation queues this iteration starts from. ---
    if input_payload['mode'] == 'initial':
        cond_obs_queues = move_observation_queues_to_device(
            input_payload['queues_cpu'], device)
    elif input_payload['mode'] == 'handoff':
        source_iter = input_payload['source_iter']
        base_queues = move_observation_queues_to_device(
            input_payload['base_queues_cpu'], device)
        action_seq = input_payload['action_seq_cpu'].to(device,
                                                        non_blocking=True)
        pipeline_print(
            f'>>> Step {iter_idx}@gpu{runtime.gpu_no}: building pipeline input '
            f'from step {source_iter} handoff ...')
        cond_obs_queues = build_pipeline_iteration_input(
            runtime=runtime,
            args=args,
            base_queues=base_queues,
            action_seq=action_seq,
            noise_shape=noise_shape,
            model_input_fs=model_input_fs,
            ori_action_dim=ori_action_dim,
            ori_state_dim=ori_state_dim)
    else:
        raise ValueError(f"Unsupported input mode: {input_payload['mode']}")

    # Snapshot the iteration input before the queues are extended with the
    # predicted actions; the CPU copy is what gets shipped in the handoff.
    iter_input_queues = clone_observation_queues(cond_obs_queues)
    iter_input_queues_cpu = clone_observation_queues_to_cpu(iter_input_queues)
    handoff_sent = False
    is_last_iteration = (iter_idx + 1) >= args.n_iter

    def _on_policy_handoff(handoff: dict[str, torch.Tensor]) -> None:
        """Publish the policy handoff for the next iteration exactly once."""
        nonlocal handoff_sent
        if handoff_sent or is_last_iteration:
            return
        handoff_step = handoff.get('step', args.pipeline_split_step)
        event = {
            'source_iter': iter_idx,
            'next_iter': iter_idx + 1,
            'base_queues_cpu': iter_input_queues_cpu,
            'action_seq_cpu': handoff['actions'].detach().cpu().clone(),
            'handoff_step': handoff_step,
        }
        handoff_queue.put(event)
        handoff_sent = True
        pipeline_print(
            f'>>> Step {iter_idx}@gpu{runtime.gpu_no}: emitted policy handoff '
            f'for step {iter_idx + 1} at ddim step '
            f'{handoff_step} ...')

    # Sampler options shared by the policy pass and the world-model pass.
    shared_sampling_kwargs = dict(
        action_cond_step=args.exe_steps,
        ddim_steps=args.ddim_steps,
        ddim_eta=args.ddim_eta,
        unconditional_guidance_scale=args.unconditional_guidance_scale,
        precision=args.precision,
        fs=model_input_fs,
        timestep_spacing=args.timestep_spacing,
        guidance_rescale=args.guidance_rescale,
        decode_video=False,
        pipeline_compare_full=False)

    # --- Policy pass: predict the action sequence. ---
    policy_observation = build_observation_from_queues(cond_obs_queues, device)
    pipeline_print(
        f'>>> Step {iter_idx}@gpu{runtime.gpu_no}: generating actions ...')
    callback = _on_policy_handoff if args.pipeline_split_step > 0 else None
    _, pred_actions, _, _, _ = image_guided_synthesis_sim_mode(
        model,
        instruction,
        policy_observation,
        noise_shape,
        sim_mode=False,
        pipeline_split_step=args.pipeline_split_step,
        handoff_callback=callback,
        **shared_sampling_kwargs)

    # --- World-model pass: reuse images/state, rebuild only the actions. ---
    cond_obs_queues = append_action_sequence(cond_obs_queues, pred_actions,
                                             ori_action_dim)
    stacked_actions = torch.stack(list(cond_obs_queues['action']), dim=1)
    wm_observation = {
        'observation.images.top': policy_observation['observation.images.top'],
        'observation.state': policy_observation['observation.state'],
        'action': stacked_actions.to(device, non_blocking=True),
    }
    pipeline_print(
        f'>>> Step {iter_idx}@gpu{runtime.gpu_no}: interacting with world model ...')
    _, _, _, wm_samples, _ = image_guided_synthesis_sim_mode(
        model,
        "",
        wm_observation,
        noise_shape,
        text_input=False,
        pipeline_split_step=0,
        **shared_sampling_kwargs)

    # --- Decode and keep only the executed segment. ---
    wm_video = model.decode_first_stage(wm_samples)
    if args.debug_wm_stats:
        _debug_world_model_stats(wm_samples,
                                 wm_video,
                                 prefix=f"step={iter_idx}@gpu{runtime.gpu_no}")
    seg_video = wm_video[:, :, :args.exe_steps]
    pipeline_print('>' * 24)
    return {
        'iter_idx': iter_idx,
        'gpu_no': runtime.gpu_no,
        'seg_video_cpu': seg_video.detach().cpu(),
    }
|
|
|
|
|
|
def run_inference_multi_gpu_pipeline(
        args: argparse.Namespace,
        runtimes: list[SimpleNamespace]) -> None:
    """Run world-model inference as an asynchronous pipeline across GPUs.

    For every row of ``<prompt_dir>/<dataset>.csv`` and every frame stride in
    ``args.frame_stride``, iteration 0 is submitted on ``runtimes[0]``. Each
    later iteration is submitted when the previous one publishes a handoff
    event, and is assigned round-robin (``next_iter % len(runtimes)``).
    Finished segments are pushed to a video writer process strictly in
    iteration order.

    Args:
        args: Parsed command-line arguments.
        runtimes: One namespace per GPU, each exposing ``model``, ``config``
            (read from the first runtime only) and ``gpu_no``.

    Raises:
        RuntimeError: If no task is in flight and no handoff is pending while
            some iterations are still unwritten (the pipeline has stalled).
        Exception: Any exception raised inside an iteration task is re-raised
            here via ``Future.result()``.
    """
    os.makedirs(args.savedir + '/inference', exist_ok=True)
    config = runtimes[0].config
    data = load_inference_data(config, args)
    df = pd.read_csv(os.path.join(args.prompt_dir, f"{args.dataset}.csv"))

    model0 = runtimes[0].model
    assert (args.height % 16 == 0) and (
        args.width % 16
        == 0), "Error: image size [h,w] should be multiples of 16!"
    assert args.bs == 1, "Current implementation only support [batch size = 1]!"

    # Latent noise shape: spatial dims are the image dims divided by 8.
    h, w = args.height // 8, args.width // 8
    channels = model0.model.diffusion_model.out_channels
    n_frames = args.video_length
    pipeline_print(f'>>> Generate {n_frames} frames under each generation ...')
    noise_shape = [args.bs, channels, n_frames, h, w]
    gpu_ids = [runtime.gpu_no for runtime in runtimes]
    pipeline_print(f'>>> Multi-GPU pipeline enabled on GPUs {gpu_ids}.')

    for idx in range(0, len(df)):
        sample = df.iloc[idx]
        init_frame_path = get_init_frame_path(args.prompt_dir, sample)
        ori_fps = float(sample['fps'])
        video_save_dir = args.savedir + f"/inference/sample_{sample['videoid']}"
        os.makedirs(video_save_dir, exist_ok=True)

        # Load transitions (datasets as tensors, attributes as-is).
        transition_path = get_transition_path(args.prompt_dir, sample)
        with h5py.File(transition_path, 'r') as h5f:
            transition_dict = {}
            for key in h5f.keys():
                transition_dict[key] = torch.tensor(h5f[key][()])
            for key in h5f.attrs.keys():
                transition_dict[key] = h5f.attrs[key]

        for fs in args.frame_stride:
            # Writer process for incremental video saving.
            sample_full_video_file = f"{video_save_dir}/../{sample['videoid']}_full_fs{fs}.mp4"
            write_q = mp.Queue(maxsize=4)
            writer_proc = mp.Process(target=_video_writer_process,
                                     args=(write_q, sample_full_video_file,
                                           args.save_fps))
            writer_proc.daemon = True
            writer_proc.start()
            _register_child_process(writer_proc)

            # One single-worker executor per GPU keeps tasks on a given GPU
            # strictly sequential.
            executors = {
                runtime.gpu_no: ThreadPoolExecutor(max_workers=1)
                for runtime in runtimes
            }
            handoff_events = Queue()
            future_meta = {}  # in-flight future -> {'iter_idx', 'gpu_no'}
            submitted_iters = set()
            completed_results = {}  # iter_idx -> task result awaiting write
            next_write_idx = 0
            model_input_fs = ori_fps // fs
            progress = tqdm(total=args.n_iter,
                            ascii=False,
                            desc=f'fs={fs} pipeline',
                            leave=True)
            try:
                batch, ori_state_dim, ori_action_dim = prepare_init_input(
                    0,
                    init_frame_path,
                    transition_dict,
                    fs,
                    data.test_datasets[args.dataset],
                    n_obs_steps=model0.n_obs_steps_imagen)
                initial_observation = {
                    'observation.images.top':
                    batch['observation.image'].permute(1, 0, 2,
                                                       3)[-1].unsqueeze(0),
                    'observation.state':
                    batch['observation.state'][-1].unsqueeze(0),
                    'action':
                    torch.zeros_like(batch['action'][-1]).unsqueeze(0),
                }
                initial_queues = {
                    "observation.images.top":
                    deque(maxlen=model0.n_obs_steps_imagen),
                    "observation.state":
                    deque(maxlen=model0.n_obs_steps_imagen),
                    "action": deque(maxlen=args.video_length),
                }
                initial_queues = populate_queues(initial_queues,
                                                 initial_observation)
                initial_payload = {
                    'mode': 'initial',
                    'queues_cpu':
                    clone_observation_queues_to_cpu(initial_queues),
                }
                first_gpu = runtimes[0].gpu_no
                future = executors[first_gpu].submit(
                    run_pipeline_iteration_task, runtimes[0], args, 0,
                    sample['instruction'], initial_payload, noise_shape,
                    model_input_fs, ori_action_dim, ori_state_dim,
                    handoff_events)
                future_meta[future] = {'iter_idx': 0, 'gpu_no': first_gpu}
                submitted_iters.add(0)

                while next_write_idx < args.n_iter:
                    # 1) Collect finished tasks; result() re-raises failures.
                    done_futures = [
                        fut for fut in future_meta if fut.done()
                    ]
                    for fut in done_futures:
                        meta = future_meta.pop(fut)
                        result = fut.result()
                        completed_results[result['iter_idx']] = result
                        pipeline_print(
                            f'>>> Step {result["iter_idx"]}@gpu{meta["gpu_no"]}: '
                            f'final segment ready.')

                    # 2) Drain handoff events and schedule next iterations.
                    while True:
                        try:
                            event = handoff_events.get_nowait()
                        except Empty:
                            break
                        next_iter = event['next_iter']
                        if next_iter >= args.n_iter or next_iter in submitted_iters:
                            continue
                        target_runtime = runtimes[next_iter % len(runtimes)]
                        payload = {
                            'mode': 'handoff',
                            'base_queues_cpu': event['base_queues_cpu'],
                            'action_seq_cpu': event['action_seq_cpu'],
                            'source_iter': event['source_iter'],
                            'handoff_step': event['handoff_step'],
                        }
                        pipeline_print(
                            f'>>> Step {event["source_iter"]}: queueing step '
                            f'{next_iter} on gpu{target_runtime.gpu_no} from '
                            f'ddim step {event["handoff_step"]} ...')
                        future = executors[target_runtime.gpu_no].submit(
                            run_pipeline_iteration_task, target_runtime, args,
                            next_iter, sample['instruction'], payload,
                            noise_shape, model_input_fs, ori_action_dim,
                            ori_state_dim, handoff_events)
                        future_meta[future] = {
                            'iter_idx': next_iter,
                            'gpu_no': target_runtime.gpu_no
                        }
                        submitted_iters.add(next_iter)

                    # 3) Write finished segments in iteration order.
                    while next_write_idx in completed_results:
                        result = completed_results.pop(next_write_idx)
                        write_q.put(result['seg_video_cpu'])
                        next_write_idx += 1
                        progress.update(1)
                        progress.set_postfix_str(
                            f'written={next_write_idx}/{args.n_iter}',
                            refresh=False)

                    # 4) Stall detection. A task enqueues its handoff before
                    # its future completes, and the queue was drained after
                    # the done-check above. So with nothing in flight here, no
                    # further iteration can ever be scheduled; raise instead
                    # of polling forever (e.g. when a task finished without
                    # emitting a handoff).
                    if next_write_idx < args.n_iter:
                        if not future_meta:
                            raise RuntimeError(
                                "Pipeline scheduler is idle before all iterations finished.")
                        time.sleep(0.05)
            finally:
                progress.close()
                for executor in executors.values():
                    executor.shutdown(wait=True, cancel_futures=False)
                _stop_writer_process(writer_proc, write_q)

    # Wait for all async I/O to complete.
    _flush_io()
|
|
|
|
|
|
def run_inference(args: argparse.Namespace,
                  gpu_num: int,
                  gpu_no: int,
                  runtime: Optional[SimpleNamespace] = None) -> None:
    """
    Serial (single-device) policy / world-model interaction loop.

    For each row of the prompt CSV and each frame stride, the policy pass
    predicts actions, the world-model pass generates the resulting clip, and
    the first ``args.exe_steps`` decoded frames are sent to a writer process.
    When ``args.pipeline_split_step > 0`` an extra "handoff" branch prepares
    the input queues consumed by the following iteration.

    Args:
        args (argparse.Namespace): Parsed command-line arguments.
        gpu_num (int): Number of GPUs (not referenced in this function body).
        gpu_no (int): Index of the current GPU; used when loading the model.
        runtime (Optional[SimpleNamespace]): Preloaded runtime providing
            ``config``, ``model`` and ``device``. When ``None`` the config is
            read from ``args.config`` and model and data are loaded here.

    Returns:
        None
    """
    # Output directory and prompt table.
    os.makedirs(args.savedir + '/inference', exist_ok=True)
    csv_path = os.path.join(args.prompt_dir, f"{args.dataset}.csv")
    df = pd.read_csv(csv_path)

    # Config is always needed for data setup.
    config = runtime.config if runtime is not None else OmegaConf.load(args.config)

    def _load_model():
        return load_inference_runtime(args, gpu_no).model

    def _load_data():
        return load_inference_data(config, args)

    if runtime is not None:
        model = runtime.model
        device = runtime.device
        data = _load_data()
    else:
        # Model and data loading are independent, so run them concurrently.
        with ThreadPoolExecutor(max_workers=2) as executor:
            model_future = executor.submit(_load_model)
            data_future = executor.submit(_load_data)
            model = model_future.result()
            data = data_future.result()
        device = get_device_from_parameters(model)

    assert (args.height % 16 == 0) and (
        args.width % 16
        == 0), "Error: image size [h,w] should be multiples of 16!"
    assert args.bs == 1, "Current implementation only support [batch size = 1]!"

    # Latent noise shape: spatial dims are the image dims divided by 8.
    h, w = args.height // 8, args.width // 8
    channels = model.model.diffusion_model.out_channels
    n_frames = args.video_length
    print(f'>>> Generate {n_frames} frames under each generation ...')
    noise_shape = [args.bs, channels, n_frames, h, w]

    def _world_model_inputs(policy_obs, queues):
        """Reuse policy images/state; stack the queued actions on `device`."""
        actions = torch.stack(list(queues['action']), dim=1)
        return {
            'observation.images.top': policy_obs['observation.images.top'],
            'observation.state': policy_obs['observation.state'],
            'action': actions.to(device, non_blocking=True),
        }

    for idx in range(len(df)):
        sample = df.iloc[idx]

        init_frame_path = get_init_frame_path(args.prompt_dir, sample)
        ori_fps = float(sample['fps'])

        video_save_dir = args.savedir + f"/inference/sample_{sample['videoid']}"
        os.makedirs(video_save_dir, exist_ok=True)

        # Transitions provide the initial state: datasets become tensors,
        # file attributes are copied as-is (and win on a key clash).
        transition_path = get_transition_path(args.prompt_dir, sample)
        with h5py.File(transition_path, 'r') as h5f:
            transition_dict = {
                key: torch.tensor(h5f[key][()])
                for key in h5f.keys()
            }
            transition_dict.update(
                {key: h5f.attrs[key]
                 for key in h5f.attrs.keys()})

        # One full rollout per requested frame stride.
        for fs in args.frame_stride:
            # Writer process for incremental video saving.
            sample_full_video_file = f"{video_save_dir}/../{sample['videoid']}_full_fs{fs}.mp4"
            write_q = mp.Queue(maxsize=4)
            writer_proc = mp.Process(target=_video_writer_process,
                                     args=(write_q, sample_full_video_file,
                                           args.save_fps))
            writer_proc.daemon = True
            writer_proc.start()
            _register_child_process(writer_proc)
            try:
                obs_window = model.n_obs_steps_imagen
                cond_obs_queues = {
                    "observation.images.top": deque(maxlen=obs_window),
                    "observation.state": deque(maxlen=obs_window),
                    "action": deque(maxlen=args.video_length),
                }

                # Initial frame and state.
                start_idx = 0
                model_input_fs = ori_fps // fs
                batch, ori_state_dim, ori_action_dim = prepare_init_input(
                    start_idx,
                    init_frame_path,
                    transition_dict,
                    fs,
                    data.test_datasets[args.dataset],
                    n_obs_steps=model.n_obs_steps_imagen)
                first_frame = batch['observation.image'].permute(
                    1, 0, 2, 3)[-1].unsqueeze(0)
                first_state = batch['observation.state'][-1].unsqueeze(0)
                zero_action = torch.zeros_like(
                    batch['action'][-1]).unsqueeze(0)
                observation = {
                    'observation.images.top':
                    first_frame.to(device, non_blocking=True),
                    'observation.state':
                    first_state.to(device, non_blocking=True),
                    'action':
                    zero_action.to(device, non_blocking=True),
                }
                cond_obs_queues = populate_queues(cond_obs_queues, observation)

                # Sampler options shared by every pass at this frame stride.
                sampling_kwargs = dict(
                    action_cond_step=args.exe_steps,
                    ddim_steps=args.ddim_steps,
                    ddim_eta=args.ddim_eta,
                    unconditional_guidance_scale=args.
                    unconditional_guidance_scale,
                    precision=args.precision,
                    fs=model_input_fs,
                    timestep_spacing=args.timestep_spacing,
                    guidance_rescale=args.guidance_rescale,
                    decode_video=False,
                    pipeline_split_step=args.pipeline_split_step,
                    pipeline_compare_full=False)

                # Multi-round interaction with the world model.
                pending_handoff = None
                pending_handoff_source_itr = None
                pending_handoff_ddim_step = None
                for itr in tqdm(range(args.n_iter), ascii=False):
                    # A handoff is prepared for every iteration but the last.
                    wants_handoff = (args.pipeline_split_step > 0
                                     and (itr + 1) < args.n_iter)

                    # Start from the handoff prepared by the previous step.
                    if pending_handoff is not None:
                        print(
                            f'>>> Step {itr}: consuming pipeline handoff '
                            f'from step {pending_handoff_source_itr} '
                            f'at ddim step {pending_handoff_ddim_step} ...')
                        cond_obs_queues = pending_handoff
                        pending_handoff = None
                        pending_handoff_source_itr = None
                        pending_handoff_ddim_step = None

                    iter_input_queues = clone_observation_queues(
                        cond_obs_queues)
                    policy_observation = build_observation_from_queues(
                        cond_obs_queues, device)

                    # Policy pass: generate actions.
                    print(f'>>> Step {itr}: generating actions ...')
                    _, pred_actions, _, _, policy_intermedia = image_guided_synthesis_sim_mode(
                        model,
                        sample['instruction'],
                        policy_observation,
                        noise_shape,
                        sim_mode=False,
                        **sampling_kwargs)

                    # Handoff branch: world-model pass on the handoff actions,
                    # starting from this iteration's untouched input queues.
                    if wants_handoff:
                        pipeline_action_seq = policy_intermedia.get(
                            'handoff', {}).get('actions', pred_actions)
                        pipeline_action_queues = append_action_sequence(
                            clone_observation_queues(iter_input_queues),
                            pipeline_action_seq, ori_action_dim)
                        wm_handoff_observation = _world_model_inputs(
                            policy_observation, pipeline_action_queues)
                        print(
                            f'>>> Step {itr}: preparing pipeline handoff branch ...'
                        )
                        _, _, pipeline_pred_states, pipeline_wm_samples, pipeline_wm_intermedia = image_guided_synthesis_sim_mode(
                            model,
                            "",
                            wm_handoff_observation,
                            noise_shape,
                            text_input=False,
                            **sampling_kwargs)

                    # Main branch: push the predicted actions, then query the
                    # world model with the same images/state.
                    cond_obs_queues = append_action_sequence(
                        cond_obs_queues, pred_actions, ori_action_dim)
                    wm_observation = _world_model_inputs(
                        policy_observation, cond_obs_queues)

                    print(f'>>> Step {itr}: interacting with world model ...')
                    _, _, pred_states, wm_samples, _ = image_guided_synthesis_sim_mode(
                        model,
                        "",
                        wm_observation,
                        noise_shape,
                        text_input=False,
                        **sampling_kwargs)

                    # Decode the full clip, then keep the executed segment.
                    wm_video = model.decode_first_stage(wm_samples)
                    if args.debug_wm_stats:
                        _debug_world_model_stats(wm_samples,
                                                 wm_video,
                                                 prefix=f"step={itr}")
                    seg_video = wm_video[:, :, :args.exe_steps]

                    cond_obs_queues = rollout_execution_segment(
                        queues=cond_obs_queues,
                        seg_video=seg_video,
                        pred_states=pred_states,
                        zero_action_template=pred_actions[0][-1:],
                        exe_steps=args.exe_steps,
                        ori_state_dim=ori_state_dim,
                        zero_pred_state=args.zero_pred_state)

                    # Build the queues the next iteration will start from.
                    if wants_handoff:
                        policy_handoff = policy_intermedia.get('handoff', {})
                        wm_handoff = pipeline_wm_intermedia.get('handoff', {})
                        handoff_ddim_step = wm_handoff.get(
                            'step',
                            policy_handoff.get('step',
                                               args.pipeline_split_step))

                        # Prefer pred_x0, then samples, then the branch output.
                        handoff_video_latent = wm_handoff.get(
                            'pred_x0', wm_handoff.get('samples', None))
                        if handoff_video_latent is None:
                            handoff_video_latent = pipeline_wm_samples
                        handoff_video = model.decode_first_stage(
                            handoff_video_latent)
                        handoff_state_seq = wm_handoff.get(
                            'states', pipeline_pred_states)

                        next_queues = append_action_sequence(
                            clone_observation_queues(iter_input_queues),
                            pipeline_action_seq, ori_action_dim)
                        pending_handoff = rollout_execution_segment(
                            queues=next_queues,
                            seg_video=handoff_video[:, :, :args.exe_steps],
                            pred_states=handoff_state_seq,
                            zero_action_template=pipeline_action_seq[0][-1:],
                            exe_steps=args.exe_steps,
                            ori_state_dim=ori_state_dim,
                            zero_pred_state=args.zero_pred_state)
                        pending_handoff_source_itr = itr
                        pending_handoff_ddim_step = handoff_ddim_step
                        print(
                            f'>>> Step {itr}: prepared pipeline handoff '
                            f'for step {itr + 1} at ddim step '
                            f'{handoff_ddim_step} ...')

                    print('>' * 24)
                    # Send the decoded segment to the writer process.
                    write_q.put(seg_video.detach().cpu())
            finally:
                _stop_writer_process(writer_proc, write_q)

    # Wait for all async I/O to complete.
    _flush_io()
|
|
|
|
|
|
def get_parser():
    """Build the command-line parser for the inference script.

    Returns:
        argparse.ArgumentParser: Parser whose only required option is
        ``--frame_stride``; every other option has a default (``--config``
        defaults to ``None`` because no default is given).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--savedir",
                        type=str,
                        default=None,
                        help="Path to save the results.")
    parser.add_argument("--ckpt_path",
                        type=str,
                        default=None,
                        help="Path to the model checkpoint.")
    # The help text used to be a copy of the --ckpt_path description.
    parser.add_argument("--config",
                        type=str,
                        help="Path to the config file (loaded with OmegaConf).")
    parser.add_argument(
        "--prompt_dir",
        type=str,
        default=None,
        help="Directory containing videos and corresponding prompts.")
    parser.add_argument("--dataset",
                        type=str,
                        default=None,
                        help="the name of dataset to test")
    parser.add_argument(
        "--ddim_steps",
        type=int,
        default=50,
        help="Number of DDIM steps. If non-positive, DDPM is used instead.")
    parser.add_argument(
        "--ddim_eta",
        type=float,
        default=1.0,
        help="Eta for DDIM sampling. Set to 0.0 for deterministic results.")
    parser.add_argument("--bs",
                        type=int,
                        default=1,
                        help="Batch size for inference. Must be 1.")
    parser.add_argument("--height",
                        type=int,
                        default=320,
                        help="Height of the generated images in pixels.")
    parser.add_argument("--width",
                        type=int,
                        default=512,
                        help="Width of the generated images in pixels.")
    parser.add_argument(
        "--frame_stride",
        type=int,
        nargs='+',
        required=True,
        help=
        "frame stride control for 256 model (larger->larger motion), FPS control for 512 or 1024 model (smaller->larger motion)"
    )
    parser.add_argument(
        "--unconditional_guidance_scale",
        type=float,
        default=1.0,
        help="Scale for classifier-free guidance during sampling.")
    parser.add_argument("--seed",
                        type=int,
                        default=123,
                        help="Random seed for reproducibility.")
    parser.add_argument("--video_length",
                        type=int,
                        default=16,
                        help="Number of frames in the generated video.")
    # The old help text described the random seed, which is --seed.
    # NOTE(review): confirm the exact meaning against where this flag is read.
    parser.add_argument("--num_generation",
                        type=int,
                        default=1,
                        help="Number of generations.")
    parser.add_argument(
        "--timestep_spacing",
        type=str,
        default="uniform",
        help=
        "Strategy for timestep scaling. See Table 2 in the paper: 'Common Diffusion Noise Schedules and Sample Steps are Flawed' (https://huggingface.co/papers/2305.08891)."
    )
    parser.add_argument(
        "--guidance_rescale",
        type=float,
        default=0.0,
        help=
        "Rescale factor for guidance as discussed in 'Common Diffusion Noise Schedules and Sample Steps are Flawed' (https://huggingface.co/papers/2305.08891)."
    )
    parser.add_argument(
        "--perframe_ae",
        action='store_true',
        default=False,
        help=
        "Use per-frame autoencoder decoding to reduce GPU memory usage. Recommended for models with resolutions like 576x1024."
    )
    parser.add_argument(
        "--precision",
        type=int,
        default=16,
        choices=[16, 32],
        help="Sampling precision for latent/action/state noise initialization. Default is 16.")
    # NOTE(review): this help text looks copy-pasted; confirm and reword.
    parser.add_argument(
        "--n_action_steps",
        type=int,
        default=16,
        help="num of samples per prompt",
    )
    parser.add_argument(
        "--exe_steps",
        type=int,
        default=16,
        help="num of samples to execute",
    )
    parser.add_argument(
        "--n_iter",
        type=int,
        default=40,
        help="num of iteration to interact with the world model",
    )
    parser.add_argument("--zero_pred_state",
                        action='store_true',
                        default=False,
                        help="not using the predicted states as comparison")
    parser.add_argument(
        "--fast_policy_no_decode",
        action='store_true',
        default=False,
        help="Speed mode: policy pass only predicts actions, skip policy video decode/log/save.")
    parser.add_argument(
        "--debug_wm_stats",
        action='store_true',
        default=False,
        help="Print latent/decode statistics for world-model samples before writing video.")
    parser.add_argument(
        "--disable_trt",
        action='store_true',
        default=False,
        help="Disable TensorRT backbone loading and force the PyTorch video backbone path.")
    parser.add_argument(
        "--trt_engine_path",
        type=str,
        default=None,
        help="Optional explicit TensorRT engine path. Defaults to trt_engines/video_backbone.engine.")
    parser.add_argument("--save_fps",
                        type=int,
                        default=8,
                        help="fps for the saving video")
    parser.add_argument(
        "--pipeline_split_step",
        type=int,
        default=0,
        help="Run DDIM sampling in two segments on a single GPU for pipeline experiments. "
        "Set to a value like 25 to test 25+25 splitting; 0 disables it.")
    parser.add_argument(
        "--pipeline_compare_full",
        action='store_true',
        default=False,
        help="When pipeline_split_step is enabled, also run the original full DDIM pass "
        "with the same initial noise and print max-abs diffs for validation.")
    parser.add_argument(
        "--pipeline_multi_gpu",
        action='store_true',
        default=False,
        help="Enable true asynchronous pipeline execution across multiple GPUs. "
        "When disabled, keep the original single-GPU/serial inference path.")
    parser.add_argument(
        "--pipeline_gpu_ids",
        type=int,
        nargs='+',
        default=[0, 1],
        help="Logical CUDA device ids used by the multi-GPU pipeline scheduler.")
    return parser
|
|
|
|
|
|
if __name__ == '__main__':
    # Script entry point: parse arguments, seed all RNGs, then run either the
    # multi-GPU pipeline scheduler or the serial single-device path.
    parser = get_parser()
    args = parser.parse_args()
    seed = args.seed
    if seed < 0:
        # A negative seed requests a randomly drawn one.
        seed = random.randint(0, 2**31)
    seed_everything(seed)
    if args.pipeline_multi_gpu:
        if args.pipeline_split_step <= 0:
            raise ValueError(
                "--pipeline_multi_gpu requires --pipeline_split_step > 0.")
        if len(args.pipeline_gpu_ids) < 2:
            raise ValueError(
                "--pipeline_multi_gpu requires at least two gpu ids.")
        # Only the first two devices are handed to the scheduler, so load
        # runtimes for just those instead of loading a model on every listed
        # GPU and discarding the rest.
        pipeline_gpu_ids = args.pipeline_gpu_ids[:2]
        ignored_gpu_ids = args.pipeline_gpu_ids[2:]
        if ignored_gpu_ids:
            print(f'>>> Multi-GPU pipeline uses GPUs {pipeline_gpu_ids}; '
                  f'ignoring extra gpu ids {ignored_gpu_ids}.')
        runtimes = [
            load_inference_runtime(args, gpu_id)
            for gpu_id in pipeline_gpu_ids
        ]
        run_inference_multi_gpu_pipeline(args, runtimes)
    else:
        rank, gpu_num = 0, 1
        run_inference(args, gpu_num, rank)
|