Files
unifolm-world-model-action/scripts/evaluation/world_model_interaction.py
olivame 1d23e5d36d Layer 3: 延迟 decode,只解码 CLIP 需要的 1 帧
- world model 调用 decode_video=False,跳过 16 帧全量 decode
- 只 decode 最后 1 帧给 CLIP embedding / observation queue
- 存 raw latent,循环结束后统一 batch decode 生成最终视频
- 每轮省 15 次 VAE decode,8 轮共省 120 次
- 跳过中间迭代的 wm tensorboard/mp4 保存

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-02-11 07:11:55 +00:00

1015 lines
40 KiB
Python

import argparse, os, glob
from contextlib import nullcontext
import pandas as pd
import random
import torch
import torchvision
import h5py
import numpy as np
import logging
import einops
import warnings
import imageio
from pytorch_lightning import seed_everything
from omegaconf import OmegaConf
from tqdm import tqdm
from einops import rearrange, repeat
from collections import OrderedDict
from torch import nn
from eval_utils import populate_queues, log_to_tensorboard
from collections import deque
from torch import Tensor
from torch.utils.tensorboard import SummaryWriter
from PIL import Image
from unifolm_wma.models.samplers.ddim import DDIMSampler
from unifolm_wma.utils.utils import instantiate_from_config
import torch.nn.functional as F
def patch_norm_bypass_autocast():
    """Swap in autocast-free forwards for ``GroupNorm`` and ``LayerNorm``.

    The replacement ``forward`` methods run with CUDA autocast disabled and
    cast the optional affine ``weight``/``bias`` to the dtype of the incoming
    activation, so the normalization is evaluated in the activation's own
    dtype. Per the original author's note, the intent is to avoid
    bf16 -> fp32 -> bf16 conversions during the UNet forward pass.

    The patch is installed on the ``torch.nn`` classes themselves, so it
    affects every instance in the process, not only this model's layers.
    """

    def _as_input_dtype(param, reference):
        # Affine parameters are optional; forward ``None`` unchanged.
        return None if param is None else param.to(reference.dtype)

    def _group_norm_forward(self, x):
        with torch.amp.autocast('cuda', enabled=False):
            scale = _as_input_dtype(self.weight, x)
            shift = _as_input_dtype(self.bias, x)
            return F.group_norm(x, self.num_groups, scale, shift, self.eps)

    def _layer_norm_forward(self, x):
        with torch.amp.autocast('cuda', enabled=False):
            scale = _as_input_dtype(self.weight, x)
            shift = _as_input_dtype(self.bias, x)
            return F.layer_norm(x, self.normalized_shape, scale, shift,
                                self.eps)

    torch.nn.GroupNorm.forward = _group_norm_forward
    torch.nn.LayerNorm.forward = _layer_norm_forward
def get_device_from_parameters(module: nn.Module) -> torch.device:
    """Return the device on which a module's first parameter lives.

    Args:
        module (nn.Module): The model whose device is to be inferred.
    Returns:
        torch.device: The device of the first parameter yielded by
            ``module.parameters()``.
    Raises:
        StopIteration: If the module has no parameters.
    """
    first_param = next(module.parameters())
    return first_param.device
def apply_precision_settings(model: nn.Module, args: argparse.Namespace) -> nn.Module:
    """Apply per-component precision settings chosen on the command line.

    The function converts sub-module weights in place and records, as plain
    attributes on ``model``, which autocast dtype (if any) the sampling code
    should use for each component:

    * ``model.diffusion_autocast_dtype``
    * ``model.projector_autocast_dtype``
    * ``model.encoder_autocast_dtype``

    A value of ``None`` means "do not wrap this component in autocast".

    Args:
        model (nn.Module): The model to apply precision settings to. It must
            expose ``model``, ``state_projector``, ``action_projector``,
            ``embedder``, ``image_proj_model`` and ``first_stage_model``.
        args (argparse.Namespace): Parsed command-line arguments providing
            ``diffusion_dtype``, ``projector_mode``, ``encoder_mode`` and
            ``vae_dtype``.
    Returns:
        nn.Module: The same model instance, with precision settings applied.
    """
    bf16 = torch.bfloat16

    print(">>> Applying precision settings:")
    print(f" - Diffusion dtype: {args.diffusion_dtype}")
    print(f" - Projector mode: {args.projector_mode}")
    print(f" - Encoder mode: {args.encoder_mode}")
    print(f" - VAE dtype: {args.vae_dtype}")

    # 1. Diffusion backbone
    if args.diffusion_dtype == "bf16":
        model.model.to(bf16)
        model.diffusion_autocast_dtype = bf16
        print(" ✓ Diffusion model weights converted to bfloat16")
    else:
        # fp32 must not enable bf16 autocast: the sampling code only skips
        # autocast when this attribute is None. This also matches the fp32
        # branches of the projector and encoder settings below.
        model.diffusion_autocast_dtype = None
        print(" ✓ Diffusion model using fp32")

    # 2. Projectors
    if args.projector_mode == "bf16_full":
        model.state_projector.to(bf16)
        model.action_projector.to(bf16)
        model.projector_autocast_dtype = None
        print(" ✓ Projectors converted to bfloat16")
    elif args.projector_mode == "autocast":
        model.projector_autocast_dtype = bf16
        print(" ✓ Projectors will use autocast (weights fp32, compute bf16)")
    else:
        # fp32 mode: keep original precision, no autocast.
        model.projector_autocast_dtype = None

    # 3. Encoders
    if args.encoder_mode == "bf16_full":
        model.embedder.to(bf16)
        model.image_proj_model.to(bf16)
        model.encoder_autocast_dtype = None
        print(" ✓ Encoders converted to bfloat16")
    elif args.encoder_mode == "autocast":
        model.encoder_autocast_dtype = bf16
        print(" ✓ Encoders will use autocast (weights fp32, compute bf16)")
    else:
        # fp32 mode: keep original precision, no autocast.
        model.encoder_autocast_dtype = None

    # 4. VAE
    if args.vae_dtype == "bf16":
        model.first_stage_model.to(bf16)
        print(" ✓ VAE converted to bfloat16")
    else:
        print(" ✓ VAE kept in fp32 for best quality")

    # 5. Safety net for the bf16 diffusion path: convert every parameter of
    # the whole model that is still fp32.
    # NOTE(review): this walks model.named_parameters(), so it also converts
    # the VAE / projectors / encoders even when their own option asked for
    # fp32 or autocast. Behaviour kept as-is; confirm whether the condition
    # should require all components to be bf16.
    if args.diffusion_dtype == "bf16":
        fp32_params = [(name, param)
                       for name, param in model.named_parameters()
                       if param.dtype == torch.float32]
        if fp32_params:
            print(f" ⚠ Found {len(fp32_params)} fp32 params, converting to bf16")
            for _, param in fp32_params:
                param.data = param.data.to(bf16)
            print(" ✓ All parameters converted to bfloat16")
    return model
def apply_torch_compile(model, hot_indices=(5, 8, 9)):
    """Wrap ``ResBlock._forward`` with ``torch.compile`` in selected blocks.

    Args:
        model: Model exposing ``model.diffusion_model.output_blocks``.
        hot_indices: Indices into ``output_blocks`` whose ``ResBlock`` layers
            get compiled. Defaults to ``(5, 8, 9)``.
    Returns:
        The same ``model`` object, modified in place.
    """
    from unifolm_wma.modules.networks.wma_model import ResBlock

    output_blocks = model.model.diffusion_model.output_blocks
    num_compiled = 0
    for block_index in hot_indices:
        res_layers = [
            layer for layer in output_blocks[block_index]
            if isinstance(layer, ResBlock)
        ]
        for res_layer in res_layers:
            # Replace the bound method on the instance with its compiled form.
            res_layer._forward = torch.compile(res_layer._forward,
                                               mode="default")
        num_compiled += len(res_layers)
    print(f" ✓ torch.compile: {num_compiled} ResBlocks in output_blocks{list(hot_indices)}")
    return model
def write_video(video_path: str, stacked_frames: list, fps: int) -> None:
    """Write a sequence of frames to ``video_path`` with ``imageio.mimsave``.

    Args:
        video_path (str): Output path for the video.
        stacked_frames (list): List of image frames.
        fps (int): Frames per second for the video.
    """
    with warnings.catch_warnings():
        # Silence only the pkg_resources deprecation notice while saving.
        warnings.filterwarnings(
            action="ignore",
            message="pkg_resources is deprecated as an API",
            category=DeprecationWarning,
        )
        imageio.mimsave(video_path, stacked_frames, fps=fps)
def get_filelist(data_dir: str, postfixes: list[str]) -> list[str]:
    """Collect files directly inside ``data_dir`` that end in given extensions.

    Args:
        data_dir (str): Directory path to search in (not recursive).
        postfixes (list[str]): File extensions to match, without the dot.
    Returns:
        list[str]: Matching file paths, sorted lexicographically.
    """
    matches = [
        path
        for postfix in postfixes
        for path in glob.glob(os.path.join(data_dir, f"*.{postfix}"))
    ]
    return sorted(matches)
def load_model_checkpoint(model: nn.Module, ckpt: str) -> nn.Module:
    """Load model weights from a checkpoint file.

    Two checkpoint layouts are handled:

    * ``{"state_dict": ...}``: loaded strictly. If that fails, keys containing
      ``framestride_embed`` are renamed to ``fps_embedding`` and the strict
      load is retried.
    * ``{"module": ...}``: the first 16 characters of every key are stripped
      before loading (presumably a wrapper prefix; verify against the
      checkpoint producer).

    Args:
        model (nn.Module): Model instance.
        ckpt (str): Path to the checkpoint file.
    Returns:
        nn.Module: The same model, with loaded weights.
    Raises:
        RuntimeError: If the state dict still does not match after the rename.
        KeyError: If the checkpoint has neither ``state_dict`` nor ``module``.
    """
    # NOTE(security): torch.load unpickles the file; only load checkpoints
    # from trusted sources.
    checkpoint = torch.load(ckpt, map_location="cpu")

    if "state_dict" in checkpoint:
        state_dict = checkpoint["state_dict"]
        try:
            model.load_state_dict(state_dict, strict=True)
        except RuntimeError:
            # load_state_dict reports key/shape mismatches as RuntimeError.
            # Anything else (e.g. KeyboardInterrupt) must propagate rather
            # than trigger the legacy-key fallback.
            renamed = OrderedDict(
                (key.replace("framestride_embed", "fps_embedding"), value)
                for key, value in state_dict.items())
            model.load_state_dict(renamed, strict=True)
    else:
        wrapped = checkpoint['module']
        stripped = OrderedDict(
            (key[16:], value) for key, value in wrapped.items())
        model.load_state_dict(stripped)

    print('>>> model checkpoint loaded.')
    return model
def is_inferenced(save_dir: str, filename: str) -> bool:
    """Tell whether the first output sample for ``filename`` already exists.

    Args:
        save_dir (str): Directory where results are saved.
        filename (str): Name of the input file; its last four characters
            are dropped to form the output stem.
    Returns:
        bool: True if ``<save_dir>/samples_separate/<stem>_sample0.mp4``
            exists, False otherwise.
    """
    stem = filename[:-4]
    expected_video = os.path.join(save_dir, "samples_separate",
                                  stem + "_sample0.mp4")
    return os.path.exists(expected_video)
def save_results(video: Tensor, filename: str, fps: int = 8) -> None:
    """Write a batch of videos to one H.264 file, batch items side by side.

    Values are clamped to [-1, 1], mapped to [0, 255] and stored as uint8.

    Args:
        video (Tensor): Tensor of shape (B, C, T, H, W).
        filename (str): Output file path.
        fps (int, optional): Frames per second. Defaults to 8.
    """
    clip = video.detach().cpu().float()
    clip = torch.clamp(clip, -1., 1.)
    batch_size = int(clip.shape[0])

    # (B, C, T, H, W) -> (T, B, C, H, W): iterate over time, tile the batch.
    per_frame = clip.permute(2, 0, 1, 3, 4)
    tiled = []
    for frame_batch in per_frame:
        tiled.append(
            torchvision.utils.make_grid(frame_batch,
                                        nrow=batch_size,
                                        padding=0))
    frames = torch.stack(tiled, dim=0)

    # [-1, 1] -> [0, 1] -> [0, 255], then channels-last for the video writer.
    frames = (frames + 1.0) / 2.0
    frames = (frames * 255).to(torch.uint8)
    frames = frames.permute(0, 2, 3, 1)

    torchvision.io.write_video(filename,
                               frames,
                               fps=fps,
                               video_codec='h264',
                               options={'crf': '10'})
def get_init_frame_path(data_dir: str, sample: dict) -> str:
    """Build the path of a sample's initial-frame PNG.

    Args:
        data_dir (str): Base directory; images live under ``<data_dir>/images``.
        sample (dict): Mapping containing 'data_dir' and 'videoid'.
    Returns:
        str: ``<data_dir>/images/<sample data_dir>/<videoid>.png``.
    """
    image_name = f"{sample['videoid']}.png"
    relative_path = os.path.join(sample['data_dir'], image_name)
    return os.path.join(data_dir, 'images', relative_path)
def get_transition_path(data_dir: str, sample: dict) -> str:
    """Build the path of a sample's HDF5 transition file.

    Args:
        data_dir (str): Base directory; transitions live under
            ``<data_dir>/transitions``.
        sample (dict): Mapping containing 'data_dir' and 'videoid'.
    Returns:
        str: ``<data_dir>/transitions/<sample data_dir>/<videoid>.h5``.
    """
    transition_name = f"{sample['videoid']}.h5"
    relative_path = os.path.join(sample['data_dir'], transition_name)
    return os.path.join(data_dir, 'transitions', relative_path)
def prepare_init_input(start_idx: int,
                       init_frame_path: str,
                       transition_dict: dict[str, torch.Tensor],
                       frame_stride: int,
                       wma_data,
                       video_length: int = 16,
                       n_obs_steps: int = 2) -> tuple[dict[str, Tensor], int, int]:
    """
    Build the initial model input from one image and a recorded transition.

    The initial frame is read from disk, the observed states (padded at the
    start of an episode) and the action sequence are sliced out of
    ``transition_dict``, normalized and converted with ``wma_data``.

    Args:
        start_idx (int): Index of the current step in the transition data.
        init_frame_path (str): Path of the image used as the initial frame.
        transition_dict (dict[str, Tensor]): Holds 'action',
            'observation.state', 'action_type' and 'state_type'.
        frame_stride (int): Stride between the sampled action indices.
        wma_data: Dataset object providing ``normalizer``, ``get_uni_vec``
            and ``spatial_transform``.
        video_length (int, optional): Number of action steps to sample.
            Default is 16.
        n_obs_steps (int, optional): Number of historical states to return.
            Default is 2.
    Returns:
        tuple: ``(data, ori_state_dim, ori_action_dim)`` where ``data`` holds
        'observation.image' plus the processed 'action' and
        'observation.state' entries, and the two ints are the last-dimension
        sizes of the raw states and actions before ``get_uni_vec``.
    """
    action_indices = [start_idx + frame_stride * i for i in range(video_length)]

    # Image -> tensor (H, W, C) -> (1, H, W, C) -> (C, 1, H, W), as float.
    init_frame = Image.open(init_frame_path).convert('RGB')
    init_frame = torch.tensor(np.array(init_frame)).unsqueeze(0).permute(
        3, 0, 1, 2).float()

    all_states = transition_dict['observation.state']
    if start_idx < n_obs_steps - 1:
        # Not enough history yet: left-pad by repeating the first state so
        # that exactly n_obs_steps states are returned.
        state_indices = list(range(0, start_idx + 1))
        states = all_states[state_indices, :]
        num_padding = n_obs_steps - 1 - start_idx
        padding = states[0:1, :].repeat(num_padding, 1)
        states = torch.cat((padding, states), dim=0)
    else:
        state_indices = list(range(start_idx - n_obs_steps + 1, start_idx + 1))
        states = all_states[state_indices, :]

    actions = transition_dict['action'][action_indices, :]

    # Remember the raw widths; callers use them after get_uni_vec.
    ori_state_dim = states.shape[-1]
    ori_action_dim = actions.shape[-1]

    frames_action_state_dict = {
        'action': actions,
        'observation.state': states,
    }
    frames_action_state_dict = wma_data.normalizer(frames_action_state_dict)
    frames_action_state_dict = wma_data.get_uni_vec(
        frames_action_state_dict,
        transition_dict['action_type'],
        transition_dict['state_type'],
    )

    if wma_data.spatial_transform is not None:
        init_frame = wma_data.spatial_transform(init_frame)
    # Map pixel values from [0, 255] to [-1, 1].
    init_frame = (init_frame / 255 - 0.5) * 2

    data = {
        'observation.image': init_frame,
    }
    data.update(frames_action_state_dict)
    return data, ori_state_dim, ori_action_dim
def get_latent_z(model, videos: Tensor) -> Tensor:
    """
    Encode every frame of a video batch with the model's first-stage encoder.

    Frames are folded into the batch axis, cast to the dtype of the
    first-stage model's parameters, encoded, and unfolded again.

    Args:
        model: the world model.
        videos (Tensor): Input videos of shape [B, C, T, H, W].
    Returns:
        Tensor: Latent video tensor of shape [B, C, T, H, W].
    """
    batch, _, frames, _, _ = videos.shape
    encoder_dtype = next(model.first_stage_model.parameters()).dtype
    flat_frames = rearrange(videos, 'b c t h w -> (b t) c h w')
    flat_frames = flat_frames.to(dtype=encoder_dtype)
    flat_latents = model.encode_first_stage(flat_frames)
    return rearrange(flat_latents,
                     '(b t) c h w -> b c t h w',
                     b=batch,
                     t=frames)
def preprocess_observation(
        model, observations: dict[str, np.ndarray]) -> dict[str, Tensor]:
    """Convert environment observation to LeRobot format observation.

    Args:
        model: Object providing ``device`` and ``normalize_inputs``.
        observations: Dictionary of observation batches. ``"pixels"`` is
            either one uint8 array shaped (B, H, W, C) or a dict of such
            arrays keyed by camera name; ``"agent_pos"`` is the state array.
    Returns:
        Dictionary with ``observation.images.<name>`` tensors (float32,
        channel first) and a normalized ``observation.state`` tensor.
    Raises:
        AssertionError: If an image is not channel last or not uint8.
    """
    return_observations = {}

    pixels = observations["pixels"]
    if isinstance(pixels, dict):
        imgs = {
            f"observation.images.{key}": img
            for key, img in pixels.items()
        }
    else:
        imgs = {"observation.images.top": pixels}

    for imgkey, img in imgs.items():
        img = torch.from_numpy(img)
        # Sanity check that images are channel last
        _, h, w, c = img.shape
        assert c < h and c < w, f"expect channel last images, but instead {img.shape}"
        # Sanity check that images are uint8
        assert img.dtype == torch.uint8, f"expect torch.uint8, but instead {img.dtype=}"
        # Convert to channel first float32. Values are NOT rescaled here:
        # they stay in the original 0..255 range.
        img = img.permute(0, 3, 1, 2).contiguous()
        img = img.type(torch.float32)
        return_observations[imgkey] = img

    state = torch.from_numpy(observations["agent_pos"]).float()
    return_observations['observation.state'] = model.normalize_inputs({
        'observation.state': state.to(model.device)
    })['observation.state']
    return return_observations
def image_guided_synthesis_sim_mode(
        model: torch.nn.Module,
        prompts: list[str],
        observation: dict,
        noise_shape: tuple[int, int, int, int, int],
        action_cond_step: int = 16,
        n_samples: int = 1,
        ddim_steps: int = 50,
        ddim_eta: float = 1.0,
        unconditional_guidance_scale: float = 1.0,
        fs: int | None = None,
        text_input: bool = True,
        timestep_spacing: str = 'uniform',
        guidance_rescale: float = 0.0,
        sim_mode: bool = True,
        decode_video: bool = True,
        **kwargs) -> tuple[torch.Tensor | None, torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Run one DDIM sampling pass conditioned on image, state, action and text.

    The last observed image is embedded and, for the 'hybrid' conditioning
    key, its latent is repeated over time as concat-conditioning. State,
    action, text and image embeddings are concatenated into the
    cross-attention conditioning. Sampling runs under the model's configured
    diffusion autocast dtype, and the latent result is optionally decoded.

    Args:
        model (torch.nn.Module): The diffusion-based generative model with multimodal conditioning.
        prompts (list[str]): Textual prompts passed to ``model.get_learned_conditioning``;
            replaced by empty strings when ``text_input`` is False.
        observation (dict): A dictionary containing observed inputs including:
            - 'observation.images.top': image tensor; it is permuted with (0, 2, 1, 3, 4)
              and then read as 'b o c h w', i.e. the input layout is [B, C, O, H, W]
            - 'observation.state': Tensor of shape [B, O, D] (state vector)
            - 'action': Tensor of shape [B, T, D] (action sequence)
        noise_shape (tuple[int, int, int, int, int]): Shape of the latent variable to generate,
            typically (B, C, T, H, W).
        action_cond_step (int): Accepted for interface compatibility; not used in this function.
        n_samples (int): Accepted for interface compatibility; not used in this function.
        ddim_steps (int): Number of DDIM sampling steps. Default is 50.
        ddim_eta (float): DDIM eta parameter controlling the stochasticity. Default is 1.0.
        unconditional_guidance_scale (float): Scale for classifier-free guidance, forwarded to the sampler.
        fs (int | None): Value broadcast across the batch into a long tensor and passed to the sampler as ``fs``.
        text_input (bool): Whether to use text prompt as conditioning. If False, uses empty strings. Default is True.
        timestep_spacing (str): Timestep sampling method in DDIM sampler. Typically "uniform" or "linspace".
        guidance_rescale (float): Guidance rescaling factor, forwarded to the sampler.
        sim_mode (bool): If False, the action embedding is zeroed before conditioning. The flag is
            also forwarded inside ``c_crossattn_action``.
        decode_video (bool): Whether to decode latent samples to pixel-space video.
            Set to False to skip VAE decode for speed when only actions/states are needed.
        **kwargs: Additional arguments passed to the DDIM sampler.
    Returns:
        batch_variants (torch.Tensor | None): Decoded output of ``model.decode_first_stage``,
            or None when decode_video=False.
        actions (torch.Tensor): Action sequences returned by the sampler.
        states (torch.Tensor): State sequences returned by the sampler.
        samples (torch.Tensor): Raw latent samples returned by the sampler.
    """
    # NOTE: b and t are unpacked but not used below.
    b, _, t, _, _ = noise_shape
    ddim_sampler = DDIMSampler(model)
    batch_size = noise_shape[0]
    # NOTE(review): with the default fs=None this builds torch.tensor([None]),
    # which cannot be converted to a long tensor; callers must pass a number.
    fs = torch.tensor([fs] * batch_size, dtype=torch.long, device=model.device)
    # Auto-detect model dtype and convert inputs accordingly
    model_dtype = next(model.embedder.parameters()).dtype
    img = observation['observation.images.top'].permute(0, 2, 1, 3, 4)
    # Only the last image (over the flattened batch*obs axis) is embedded.
    cond_img = rearrange(img, 'b o c h w -> (b o) c h w')[-1:].to(dtype=model_dtype)
    # Encoder autocast: weights stay fp32, compute in bf16
    enc_ac_dtype = getattr(model, 'encoder_autocast_dtype', None)
    if enc_ac_dtype is not None and model.device.type == 'cuda':
        enc_ctx = torch.autocast('cuda', dtype=enc_ac_dtype)
    else:
        enc_ctx = nullcontext()
    with enc_ctx:
        cond_img_emb = model.embedder(cond_img)
        cond_img_emb = model.image_proj_model(cond_img_emb)
    # NOTE(review): `cond` is only created in this branch; for any other
    # conditioning_key the assignments to cond[...] below raise NameError.
    if model.model.conditioning_key == 'hybrid':
        # Permuting back restores the original observation layout.
        z = get_latent_z(model, img.permute(0, 2, 1, 3, 4))
        # Repeat the latent of the last observed frame over all output frames.
        img_cat_cond = z[:, :, -1:, :, :]
        img_cat_cond = repeat(img_cat_cond,
                              'b c t h w -> b c (repeat t) h w',
                              repeat=noise_shape[2])
        cond = {"c_concat": [img_cat_cond]}
    if not text_input:
        prompts = [""] * batch_size
    cond_ins_emb = model.get_learned_conditioning(prompts)
    # Auto-detect projector dtype and convert inputs
    projector_dtype = next(model.state_projector.parameters()).dtype
    # Projector autocast: weights stay fp32, compute in bf16
    proj_ac_dtype = getattr(model, 'projector_autocast_dtype', None)
    if proj_ac_dtype is not None and model.device.type == 'cuda':
        proj_ctx = torch.autocast('cuda', dtype=proj_ac_dtype)
    else:
        proj_ctx = nullcontext()
    with proj_ctx:
        cond_state_emb = model.state_projector(observation['observation.state'].to(dtype=projector_dtype))
        cond_state_emb = cond_state_emb + model.agent_state_pos_emb
        cond_action_emb = model.action_projector(observation['action'].to(dtype=projector_dtype))
        cond_action_emb = cond_action_emb + model.agent_action_pos_emb
    # Outside simulation mode the action conditioning is blanked out.
    if not sim_mode:
        cond_action_emb = torch.zeros_like(cond_action_emb)
    cond["c_crossattn"] = [
        torch.cat(
            [cond_state_emb, cond_action_emb, cond_ins_emb, cond_img_emb],
            dim=1)
    ]
    # Raw (un-embedded) recent images and states, plus two flags, for the
    # action branch of the model.
    cond["c_crossattn_action"] = [
        observation['observation.images.top'][:, :,
                                              -model.n_obs_steps_acting:],
        observation['observation.state'][:, -model.n_obs_steps_acting:],
        sim_mode,
        False,
    ]
    # No unconditional branch, mask or fixed x0 is supplied to the sampler.
    uc = None
    kwargs.update({"unconditional_conditioning_img_nonetext": None})
    cond_mask = None
    cond_z0 = None
    # Setup autocast context for diffusion sampling
    autocast_dtype = getattr(model, 'diffusion_autocast_dtype', None)
    if autocast_dtype is not None and model.device.type == 'cuda':
        autocast_ctx = torch.autocast('cuda', dtype=autocast_dtype)
    else:
        autocast_ctx = nullcontext()
    batch_variants = None
    samples = None
    # ddim_sampler was constructed above, so this condition always holds.
    if ddim_sampler is not None:
        with autocast_ctx:
            samples, actions, states, intermedia = ddim_sampler.sample(
                S=ddim_steps,
                conditioning=cond,
                batch_size=batch_size,
                shape=noise_shape[1:],
                verbose=False,
                unconditional_guidance_scale=unconditional_guidance_scale,
                unconditional_conditioning=uc,
                eta=ddim_eta,
                cfg_img=None,
                mask=cond_mask,
                x0=cond_z0,
                fs=fs,
                timestep_spacing=timestep_spacing,
                guidance_rescale=guidance_rescale,
                **kwargs)
        if decode_video:
            # Reconstruct from latent to pixel space
            batch_images = model.decode_first_stage(samples)
            batch_variants = batch_images
    return batch_variants, actions, states, samples
def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
    """
    Run the policy / world-model interaction loop for every sample in the CSV.

    For each sample and each frame stride, the loop alternates ``args.n_iter``
    times between (1) a policy pass that predicts actions and (2) a
    world-model pass driven by those actions. World-model latents are kept
    undecoded during the loop; only the single frame needed as the next
    observation is decoded per round. After the loop all stored latents are
    decoded in one call, logged to tensorboard and written as an mp4.

    If ``args.export_precision_ckpt`` is set, the precision-converted state
    dict is saved and the function returns before any inference.

    Args:
        args (argparse.Namespace): Parsed command-line arguments.
        gpu_num (int): Number of GPUs (not used in this function).
        gpu_no (int): Index of the GPU the model is moved to.
    Returns:
        None
    """
    # Create inference and tensorboard dirs
    os.makedirs(args.savedir + '/inference', exist_ok=True)
    log_dir = args.savedir + "/tensorboard"
    os.makedirs(log_dir, exist_ok=True)
    writer = SummaryWriter(log_dir=log_dir)

    # Load prompt
    csv_path = os.path.join(args.prompt_dir, f"{args.dataset}.csv")
    df = pd.read_csv(csv_path)

    # Load config
    config = OmegaConf.load(args.config)
    config['model']['params']['wma_config']['params'][
        'use_checkpoint'] = False
    model = instantiate_from_config(config.model)
    model.perframe_ae = args.perframe_ae
    assert os.path.exists(args.ckpt_path), "Error: checkpoint Not Found!"
    model = load_model_checkpoint(model, args.ckpt_path)
    model.eval()
    print('>>> Load pre-trained model ...')

    # Apply precision settings before moving to GPU
    model = apply_precision_settings(model, args)

    # Compile hot ResBlocks for operator fusion
    apply_torch_compile(model)

    # Fuse KV projections in attention layers (to_k + to_v -> to_kv)
    from unifolm_wma.modules.attention import CrossAttention
    kv_count = sum(1 for m in model.modules()
                   if isinstance(m, CrossAttention) and m.fuse_kv())
    print(f" ✓ KV fused: {kv_count} attention layers")

    # Export precision-converted checkpoint if requested
    if args.export_precision_ckpt:
        export_path = args.export_precision_ckpt
        os.makedirs(os.path.dirname(export_path) or '.', exist_ok=True)
        torch.save({"state_dict": model.state_dict()}, export_path)
        print(f">>> Precision-converted checkpoint saved to: {export_path}")
        writer.close()
        return

    # Build the dataset object (provides normalizer / transforms)
    logging.info("***** Configing Data *****")
    data = instantiate_from_config(config.data)
    data.setup()
    print(">>> Dataset is successfully loaded ...")

    model = model.cuda(gpu_no)
    device = get_device_from_parameters(model)

    # Run over data
    assert (args.height % 16 == 0) and (
        args.width % 16
        == 0), "Error: image size [h,w] should be multiples of 16!"
    assert args.bs == 1, "Current implementation only support [batch size = 1]!"

    # Get latent noise shape
    h, w = args.height // 8, args.width // 8
    channels = model.model.diffusion_model.out_channels
    n_frames = args.video_length
    print(f'>>> Generate {n_frames} frames under each generation ...')
    noise_shape = [args.bs, channels, n_frames, h, w]

    # Start inference
    for idx in range(0, len(df)):
        sample = df.iloc[idx]

        # Got initial frame path
        init_frame_path = get_init_frame_path(args.prompt_dir, sample)
        ori_fps = float(sample['fps'])
        video_save_dir = args.savedir + f"/inference/sample_{sample['videoid']}"
        os.makedirs(video_save_dir, exist_ok=True)
        os.makedirs(video_save_dir + '/dm', exist_ok=True)
        os.makedirs(video_save_dir + '/wm', exist_ok=True)

        # Load transitions to get the initial state later
        transition_path = get_transition_path(args.prompt_dir, sample)
        with h5py.File(transition_path, 'r') as h5f:
            transition_dict = {}
            for key in h5f.keys():
                transition_dict[key] = torch.tensor(h5f[key][()])
            for key in h5f.attrs.keys():
                transition_dict[key] = h5f.attrs[key]

        # If many, test various frequency controls and world-model generation
        for fs in args.frame_stride:
            # For saving imagined videos of the policy
            sample_save_dir = f'{video_save_dir}/dm/{fs}'
            os.makedirs(sample_save_dir, exist_ok=True)
            # For saving environmental changes in world-model
            sample_save_dir = f'{video_save_dir}/wm/{fs}'
            os.makedirs(sample_save_dir, exist_ok=True)

            # Raw world-model latents, decoded once after the loop
            wm_latent = []

            # Initialize observation queues
            cond_obs_queues = {
                "observation.images.top":
                deque(maxlen=model.n_obs_steps_imagen),
                "observation.state": deque(maxlen=model.n_obs_steps_imagen),
                "action": deque(maxlen=args.video_length),
            }

            # Obtain initial frame and state
            start_idx = 0
            model_input_fs = ori_fps // fs
            batch, ori_state_dim, ori_action_dim = prepare_init_input(
                start_idx,
                init_frame_path,
                transition_dict,
                fs,
                data.test_datasets[args.dataset],
                n_obs_steps=model.n_obs_steps_imagen)
            observation = {
                'observation.images.top':
                batch['observation.image'].permute(1, 0, 2,
                                                   3)[-1].unsqueeze(0),
                'observation.state':
                batch['observation.state'][-1].unsqueeze(0),
                'action':
                torch.zeros_like(batch['action'][-1]).unsqueeze(0)
            }
            observation = {
                key: observation[key].to(device, non_blocking=True)
                for key in observation
            }
            # Update observation queues
            cond_obs_queues = populate_queues(cond_obs_queues, observation)

            # Multi-round interaction with the world-model
            for itr in tqdm(range(args.n_iter)):
                # Get observation
                observation = {
                    'observation.images.top':
                    torch.stack(list(
                        cond_obs_queues['observation.images.top']),
                                dim=1).permute(0, 2, 1, 3, 4),
                    'observation.state':
                    torch.stack(list(cond_obs_queues['observation.state']),
                                dim=1),
                    'action':
                    torch.stack(list(cond_obs_queues['action']), dim=1),
                }
                observation = {
                    key: observation[key].to(device, non_blocking=True)
                    for key in observation
                }

                # Use world-model in policy to generate action
                print(f'>>> Step {itr}: generating actions ...')
                pred_videos_0, pred_actions, _, _ = image_guided_synthesis_sim_mode(
                    model,
                    sample['instruction'],
                    observation,
                    noise_shape,
                    action_cond_step=args.exe_steps,
                    ddim_steps=args.ddim_steps,
                    ddim_eta=args.ddim_eta,
                    unconditional_guidance_scale=args.
                    unconditional_guidance_scale,
                    fs=model_input_fs,
                    timestep_spacing=args.timestep_spacing,
                    guidance_rescale=args.guidance_rescale,
                    sim_mode=False,
                    decode_video=not args.fast_policy_no_decode)

                # Update future actions in the observation queues
                for step in range(len(pred_actions[0])):
                    observation = {'action': pred_actions[0][step:step + 1]}
                    # Zero the padded dims beyond the dataset's real action
                    # width (this writes through to pred_actions).
                    observation['action'][:, ori_action_dim:] = 0.0
                    cond_obs_queues = populate_queues(cond_obs_queues,
                                                      observation)

                # Collect data for interacting the world-model using the predicted actions
                observation = {
                    'observation.images.top':
                    torch.stack(list(
                        cond_obs_queues['observation.images.top']),
                                dim=1).permute(0, 2, 1, 3, 4),
                    'observation.state':
                    torch.stack(list(cond_obs_queues['observation.state']),
                                dim=1),
                    'action':
                    torch.stack(list(cond_obs_queues['action']), dim=1),
                }
                observation = {
                    key: observation[key].to(device, non_blocking=True)
                    for key in observation
                }

                # Interaction with the world-model
                print(f'>>> Step {itr}: interacting with world model ...')
                pred_videos_1, _, pred_states, wm_samples = image_guided_synthesis_sim_mode(
                    model,
                    "",
                    observation,
                    noise_shape,
                    action_cond_step=args.exe_steps,
                    ddim_steps=args.ddim_steps,
                    ddim_eta=args.ddim_eta,
                    unconditional_guidance_scale=args.
                    unconditional_guidance_scale,
                    fs=model_input_fs,
                    text_input=False,
                    timestep_spacing=args.timestep_spacing,
                    guidance_rescale=args.guidance_rescale,
                    decode_video=False)

                # Decode only the frame that becomes the next observation.
                # It must be the last *executed* frame (exe_steps - 1), not
                # the last generated one: only the first exe_steps frames and
                # states are kept below, so using frame -1 would desync the
                # observation from the stored video when
                # exe_steps < video_length.
                last_exec = min(args.exe_steps, wm_samples.shape[2]) - 1
                last_frame_pixel = model.decode_first_stage(
                    wm_samples[:, :, last_exec:last_exec + 1, :, :])

                for step in range(args.exe_steps):
                    observation = {
                        'observation.images.top':
                        last_frame_pixel[0, :, 0:1].permute(1, 0, 2, 3),
                        'observation.state':
                        torch.zeros_like(pred_states[0][step:step + 1]) if
                        args.zero_pred_state else pred_states[0][step:step + 1],
                        'action':
                        torch.zeros_like(pred_actions[0][-1:])
                    }
                    observation['observation.state'][:, ori_state_dim:] = 0.0
                    cond_obs_queues = populate_queues(cond_obs_queues,
                                                      observation)

                # Save the imagined videos for decision-making
                if pred_videos_0 is not None:
                    sample_tag = f"{args.dataset}-vid{sample['videoid']}-dm-fs-{fs}/itr-{itr}"
                    log_to_tensorboard(writer,
                                       pred_videos_0,
                                       sample_tag,
                                       fps=args.save_fps)

                print('>' * 24)
                # Store raw latent for deferred decode
                wm_latent.append(wm_samples[:, :, :args.exe_steps].cpu())

            # Nothing was generated (n_iter == 0): torch.cat on an empty
            # list would raise, so skip decoding/saving for this stride.
            if not wm_latent:
                continue

            # Deferred decode: batch decode all stored latents
            full_latent = torch.cat(wm_latent, dim=2).to(device)
            full_video = model.decode_first_stage(full_latent).cpu()
            sample_tag = f"{args.dataset}-vid{sample['videoid']}-wd-fs-{fs}/full"
            log_to_tensorboard(writer,
                               full_video,
                               sample_tag,
                               fps=args.save_fps)
            sample_full_video_file = f"{video_save_dir}/../{sample['videoid']}_full_fs{fs}.mp4"
            save_results(full_video, sample_full_video_file, fps=args.save_fps)

    # Flush pending tensorboard events.
    writer.close()
def get_parser():
    """Build the command-line parser for the world-model interaction script.

    Returns:
        argparse.ArgumentParser: Parser with all script options registered.
        ``--frame_stride`` is the only required option.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--savedir",
                        type=str,
                        default=None,
                        help="Path to save the results.")
    parser.add_argument("--ckpt_path",
                        type=str,
                        default=None,
                        help="Path to the model checkpoint.")
    # The help text previously duplicated the checkpoint description.
    parser.add_argument("--config",
                        type=str,
                        help="Path to the model config file.")
    parser.add_argument(
        "--prompt_dir",
        type=str,
        default=None,
        help="Directory containing videos and corresponding prompts.")
    parser.add_argument("--dataset",
                        type=str,
                        default=None,
                        help="the name of dataset to test")
    parser.add_argument(
        "--ddim_steps",
        type=int,
        default=50,
        help="Number of DDIM steps. If non-positive, DDPM is used instead.")
    parser.add_argument(
        "--ddim_eta",
        type=float,
        default=1.0,
        help="Eta for DDIM sampling. Set to 0.0 for deterministic results.")
    parser.add_argument("--bs",
                        type=int,
                        default=1,
                        help="Batch size for inference. Must be 1.")
    parser.add_argument("--height",
                        type=int,
                        default=320,
                        help="Height of the generated images in pixels.")
    parser.add_argument("--width",
                        type=int,
                        default=512,
                        help="Width of the generated images in pixels.")
    parser.add_argument(
        "--frame_stride",
        type=int,
        nargs='+',
        required=True,
        help=
        "frame stride control for 256 model (larger->larger motion), FPS control for 512 or 1024 model (smaller->larger motion)"
    )
    parser.add_argument(
        "--unconditional_guidance_scale",
        type=float,
        default=1.0,
        help="Scale for classifier-free guidance during sampling.")
    parser.add_argument("--seed",
                        type=int,
                        default=123,
                        help="Random seed for reproducibility.")
    parser.add_argument("--video_length",
                        type=int,
                        default=16,
                        help="Number of frames in the generated video.")
    # NOTE(review): this help text looks copied from --seed; the option's
    # real meaning is not visible in this file, so the text is left as-is.
    parser.add_argument("--num_generation",
                        type=int,
                        default=1,
                        help="seed for seed_everything")
    parser.add_argument(
        "--timestep_spacing",
        type=str,
        default="uniform",
        help=
        "Strategy for timestep scaling. See Table 2 in the paper: 'Common Diffusion Noise Schedules and Sample Steps are Flawed' (https://huggingface.co/papers/2305.08891)."
    )
    parser.add_argument(
        "--guidance_rescale",
        type=float,
        default=0.0,
        help=
        "Rescale factor for guidance as discussed in 'Common Diffusion Noise Schedules and Sample Steps are Flawed' (https://huggingface.co/papers/2305.08891)."
    )
    parser.add_argument(
        "--perframe_ae",
        action='store_true',
        default=False,
        help=
        "Use per-frame autoencoder decoding to reduce GPU memory usage. Recommended for models with resolutions like 576x1024."
    )
    # NOTE(review): the two help texts below talk about "samples", which
    # does not obviously match the option names; confirm the wording.
    parser.add_argument(
        "--n_action_steps",
        type=int,
        default=16,
        help="num of samples per prompt",
    )
    parser.add_argument(
        "--exe_steps",
        type=int,
        default=16,
        help="num of samples to execute",
    )
    parser.add_argument(
        "--n_iter",
        type=int,
        default=40,
        help="num of iteration to interact with the world model",
    )
    parser.add_argument("--zero_pred_state",
                        action='store_true',
                        default=False,
                        help="not using the predicted states as comparison")
    parser.add_argument(
        "--fast_policy_no_decode",
        action='store_true',
        default=False,
        help="Speed mode: policy pass only predicts actions, skip policy video decode/log/save.")
    parser.add_argument("--save_fps",
                        type=int,
                        default=8,
                        help="fps for the saving video")
    parser.add_argument(
        "--diffusion_dtype",
        type=str,
        choices=["fp32", "bf16"],
        default="bf16",
        help="Diffusion backbone precision (fp32/bf16)")
    parser.add_argument(
        "--projector_mode",
        type=str,
        choices=["fp32", "autocast", "bf16_full"],
        default="bf16_full",
        help="Projector precision mode (fp32/autocast/bf16_full)")
    parser.add_argument(
        "--encoder_mode",
        type=str,
        choices=["fp32", "autocast", "bf16_full"],
        default="bf16_full",
        help="Encoder precision mode (fp32/autocast/bf16_full)")
    parser.add_argument(
        "--vae_dtype",
        type=str,
        choices=["fp32", "bf16"],
        default="fp32",
        help="VAE precision (fp32/bf16, most affects image quality)")
    parser.add_argument(
        "--export_precision_ckpt",
        type=str,
        default=None,
        help="Export precision-converted checkpoint to this path, then exit.")
    return parser
if __name__ == '__main__':
    # Install the norm-layer forward patch before any model is built.
    patch_norm_bypass_autocast()

    args = get_parser().parse_args()

    # A negative --seed requests a randomly drawn seed.
    seed = args.seed
    if seed < 0:
        seed = random.randint(0, 2**31)
    seed_everything(seed)

    # Single-process run: one GPU, rank 0.
    rank = 0
    gpu_num = 1
    run_inference(args, gpu_num, rank)