Files
unifolm-world-model-action/scripts/evaluation/world_model_interaction.py
2026-03-15 12:41:53 +08:00

1372 lines
57 KiB
Python

import argparse, os, glob
import json
import pandas as pd
import random
import torch
import torchvision
import h5py
import numpy as np
import logging
import einops
import warnings
import imageio
import time
from pytorch_lightning import seed_everything
from omegaconf import OmegaConf
from tqdm import tqdm
from einops import rearrange, repeat
from collections import OrderedDict
from torch import nn
import torch.nn.functional as F
from eval_utils import populate_queues, log_to_tensorboard
from collections import deque
from torch import Tensor
from torch.utils.tensorboard import SummaryWriter
from PIL import Image
from unifolm_wma.models.samplers.ddim import DDIMSampler
from unifolm_wma.utils.utils import instantiate_from_config
def get_device_from_parameters(module: nn.Module) -> torch.device:
    """Return the device holding ``module``'s parameters.

    Only the first parameter yielded by ``module.parameters()`` is inspected,
    so a module spread over several devices reports that parameter's device.

    Args:
        module (nn.Module): Model whose device should be inferred.
    Returns:
        torch.device: Device of the first parameter.
    Raises:
        StopIteration: If the module has no parameters.
    """
    first_param = next(iter(module.parameters()))
    return first_param.device
def get_scene_name(sample: pd.Series, fallback: str) -> str:
    """Return the scene label used in analysis logs.

    Uses ``str(sample['data_dir'])`` when that field exists and is not NA;
    otherwise ``fallback`` is returned unchanged.
    """
    has_scene = 'data_dir' in sample and pd.notna(sample['data_dir'])
    return str(sample['data_dir']) if has_scene else fallback
def build_sample_id(dataset: str, sample: pd.Series, frame_stride: int) -> str:
    """Compose a flat sample id of the form ``<dataset>-vid<videoid>-fs<stride>``.

    The id stays a single string so the CSV schema does not need extra columns.
    """
    video_part = f"vid{sample['videoid']}"
    stride_part = f"fs{frame_stride}"
    return f"{dataset}-{video_part}-{stride_part}"
def flatten_batch_tensor(tensor: torch.Tensor) -> torch.Tensor:
    """Return a detached float32 copy/view of ``tensor`` shaped ``(batch, -1)``.

    Dimension 0 is kept as the batch axis; all remaining axes are merged so
    per-item metrics can be computed with ``dim=1`` reductions.
    """
    batch_size = tensor.shape[0]
    as_float = tensor.detach().float()
    return as_float.reshape(batch_size, -1)
def batch_relative_l2(current: torch.Tensor,
                      previous: torch.Tensor) -> list[float]:
    """Return ``||current - previous|| / ||previous||`` for every batch item.

    Both tensors are flattened per item with ``flatten_batch_tensor``. The
    denominator is clamped to at least 1e-8 so an all-zero ``previous`` does
    not divide by zero.
    """
    cur_flat, prev_flat = (flatten_batch_tensor(t) for t in (current, previous))
    change = torch.linalg.vector_norm(cur_flat - prev_flat, dim=1)
    scale = torch.linalg.vector_norm(prev_flat, dim=1).clamp_min(1e-8)
    return (change / scale).cpu().tolist()
def batch_l2_distance(current: torch.Tensor,
                      reference: torch.Tensor) -> list[float]:
    """Return the per-item Euclidean distance between ``current`` and ``reference``.

    Both tensors are flattened per item with ``flatten_batch_tensor``.
    """
    cur_flat, ref_flat = (flatten_batch_tensor(t) for t in (current, reference))
    distances = torch.linalg.vector_norm(cur_flat - ref_flat, dim=1)
    return distances.cpu().tolist()
def batch_cosine_similarity(current: torch.Tensor,
                            reference: torch.Tensor) -> list[float]:
    """Return the per-item cosine similarity between ``current`` and ``reference``.

    Both tensors are flattened per item with ``flatten_batch_tensor``; an
    epsilon of 1e-8 guards against zero-norm vectors.
    """
    cur_flat, ref_flat = (flatten_batch_tensor(t) for t in (current, reference))
    similarity = F.cosine_similarity(cur_flat, ref_flat, dim=1, eps=1e-8)
    return similarity.cpu().tolist()
def first_consecutive_below(values: list[float], threshold: float,
                            window: int) -> float:
    """Return the 1-based start of the first run of ``window`` values below ``threshold``.

    NaN/None entries break a run. Returns NaN when ``window`` is not positive
    or when no qualifying run exists (including when ``values`` is shorter
    than ``window``).
    """
    if window <= 0:
        return np.nan
    streak = 0
    for position, value in enumerate(values, start=1):
        streak = streak + 1 if (pd.notna(value) and value < threshold) else 0
        if streak == window:
            # The run ends at `position`; report where it started.
            return float(position - window + 1)
    return np.nan
def first_at_least(values: list[float], threshold: float) -> float:
    """Return the first 1-based index whose value is ``>= threshold``.

    NaN/None entries are skipped; returns NaN when no entry qualifies.
    """
    hits = (position for position, value in enumerate(values, start=1)
            if pd.notna(value) and value >= threshold)
    first_hit = next(hits, None)
    return np.nan if first_hit is None else float(first_hit)
def first_at_most(values: list[float], threshold: float) -> float:
    """Return the first 1-based index whose value is at or below ``threshold``.

    NaN/None entries are skipped; returns NaN when no entry qualifies.
    """
    for position, value in enumerate(values, 1):
        if not pd.notna(value):
            continue
        if value <= threshold:
            return float(position)
    return np.nan
def safe_mean(values: list[float]) -> float:
    """Return the mean of the non-NaN entries, or NaN when there are none."""
    kept = list(filter(pd.notna, values))
    return float(np.mean(kept)) if kept else np.nan
def make_sampling_noise_bundle(model: nn.Module,
                               noise_shape: list[int]) -> dict[str, torch.Tensor]:
    """Draw aligned initial Gaussian noise for the latent, action and state streams.

    Args:
        model: Object exposing ``device``, ``agent_action_dim`` and
            ``agent_state_dim``.
        noise_shape: Latent noise shape; index 0 is used as the batch size and
            index 2 as the horizon of the action/state noise.
    Returns:
        dict with keys ``'img'`` (shape ``noise_shape``), ``'action'``
        ``(batch, horizon, agent_action_dim)`` and ``'state'``
        ``(batch, horizon, agent_state_dim)``, drawn in that order.
    """
    device = model.device
    batch_size, horizon = noise_shape[0], noise_shape[2]
    # Draw order (img, action, state) matters for reproducibility under a fixed seed.
    bundle = {'img': torch.randn(noise_shape, device=device)}
    bundle['action'] = torch.randn(
        (batch_size, horizon, model.agent_action_dim), device=device)
    bundle['state'] = torch.randn(
        (batch_size, horizon, model.agent_state_dim), device=device)
    return bundle
def load_psnr_lookup(psnr_path: str | None) -> dict[str, float]:
"""Load optional PSNR values keyed by sample_id or videoid."""
if not psnr_path:
return {}
if not os.path.exists(psnr_path):
logging.warning("PSNR file not found: %s", psnr_path)
return {}
suffix = os.path.splitext(psnr_path)[1].lower()
lookup: dict[str, float] = {}
if suffix == '.csv':
df = pd.read_csv(psnr_path)
key_column = 'sample_id' if 'sample_id' in df.columns else 'videoid'
value_column = 'psnr_full50' if 'psnr_full50' in df.columns else 'psnr'
for _, row in df.iterrows():
if pd.notna(row[key_column]) and pd.notna(row[value_column]):
lookup[str(row[key_column])] = float(row[value_column])
return lookup
if suffix == '.json':
with open(psnr_path, 'r', encoding='utf-8') as file:
data = json.load(file)
if isinstance(data, dict):
if 'sample_id' in data and ('psnr_full50' in data or 'psnr' in data):
lookup[str(data['sample_id'])] = float(
data.get('psnr_full50', data['psnr']))
else:
for key, value in data.items():
if isinstance(value, (int, float)):
lookup[str(key)] = float(value)
elif isinstance(data, list):
for item in data:
if not isinstance(item, dict):
continue
if 'sample_id' in item and ('psnr_full50' in item or 'psnr' in
item):
lookup[str(item['sample_id'])] = float(
item.get('psnr_full50', item['psnr']))
return lookup
logging.warning("Unsupported PSNR file format: %s", psnr_path)
return {}
class InteractionAnalysisLogger:
"""Collect stepwise metrics and aggregated per-sample summaries."""
STEP_COLUMNS = [
'sample_id',
'scene',
'pass_type',
'round_id',
'step',
'step_time_s',
'latent_delta',
'action_delta',
'state_delta',
'action_cosine_vs_full50',
'state_cosine_vs_full50',
'latent_l2_vs_full50',
]
SUMMARY_COLUMNS = [
'sample_id',
'scene',
'pass_type',
'pass_total_time_s',
'action_first_stable_step',
'state_first_stable_step',
'latent_first_stable_step',
'action_vs_full50_90pct_step',
'action_vs_full50_95pct_step',
'oracle_budget_action',
'oracle_budget_state',
'oracle_budget_latent',
'latent_init_dist_to_prev_round',
'action_drift_vs_prev_round',
'round_total_time_s',
'policy_pass_total_time_s',
'world_model_pass_total_time_s',
'psnr_full50',
]
ROUND_COLUMNS = [
'sample_id',
'scene',
'round_id',
'policy_pass_total_time_s',
'world_model_pass_total_time_s',
'round_total_time_s',
'latent_init_dist_to_prev_round',
'action_drift_vs_prev_round',
'psnr_full50',
]
def __init__(self, output_dir: str, psnr_lookup: dict[str, float]):
self.output_dir = output_dir
self.psnr_lookup = psnr_lookup
self.step_rows: list[dict] = []
self.summary_buckets: dict[tuple[str, str, str], dict] = {}
self.round_rows: list[dict] = []
self.round_buckets: dict[tuple[str, str], dict] = {}
self.prev_policy_action: dict[str, torch.Tensor] = {}
self.prev_world_latent: dict[str, torch.Tensor] = {}
def resolve_psnr(self, sample_id: str, videoid: int) -> float:
"""Resolve a PSNR value by full sample id first, then by raw video id."""
candidates = [sample_id, sample_id.rsplit('-fs', 1)[0], str(videoid)]
for candidate in candidates:
if candidate in self.psnr_lookup:
return float(self.psnr_lookup[candidate])
return np.nan
def append_summary_row(self, row: dict) -> None:
"""Store per-round summaries and aggregate them later by sample and pass type."""
key = (row['sample_id'], row['scene'], row['pass_type'])
metric_columns = self.SUMMARY_COLUMNS[3:]
if key not in self.summary_buckets:
self.summary_buckets[key] = {
'sample_id': row['sample_id'],
'scene': row['scene'],
'pass_type': row['pass_type'],
**{column: [] for column in metric_columns},
}
for column in metric_columns:
self.summary_buckets[key][column].append(row.get(column, np.nan))
def append_round_row(self, row: dict) -> None:
"""Store per-round metrics and aggregate them later by sample."""
self.round_rows.append(row)
key = (row['sample_id'], row['scene'])
metric_columns = self.ROUND_COLUMNS[3:]
if key not in self.round_buckets:
self.round_buckets[key] = {
'sample_id': row['sample_id'],
'scene': row['scene'],
**{column: [] for column in metric_columns},
}
for column in metric_columns:
self.round_buckets[key][column].append(row.get(column, np.nan))
def collect_trace_series(self, debug_info: dict, reference_action: torch.Tensor,
reference_state: torch.Tensor,
reference_latent: torch.Tensor) -> tuple[list[float], list[float], list[float]]:
"""Extract cosine/L2 curves for either the target pass or the full-50 reference pass."""
action_cosines = []
state_cosines = []
latent_l2s = []
for record in debug_info['step_records']:
action_cosines.append(
batch_cosine_similarity(record['action'], reference_action)[0])
state_cosines.append(
batch_cosine_similarity(record['state'], reference_state)[0])
latent_l2s.append(
batch_l2_distance(record['pred_x0'], reference_latent)[0])
return action_cosines, state_cosines, latent_l2s
def log_pass(self, sample_id: str, videoid: int, scene: str, pass_type: str,
round_id: int, pass_total_time_s: float, target_debug: dict,
reference_debug: dict) -> dict | None:
"""Log one pass worth of stepwise and aggregated metrics."""
if not target_debug or not target_debug.get('step_records'):
return None
if not reference_debug or not reference_debug.get('step_records'):
reference_debug = target_debug
reference_final_action = reference_debug['step_records'][-1]['action']
reference_final_state = reference_debug['step_records'][-1]['state']
reference_final_latent = reference_debug['step_records'][-1]['pred_x0']
prev_img = target_debug['analysis_init']['img']
prev_action = target_debug['analysis_init']['action']
prev_state = target_debug['analysis_init']['state']
action_deltas: list[float] = []
state_deltas: list[float] = []
latent_deltas: list[float] = []
action_cosines: list[float] = []
state_cosines: list[float] = []
latent_l2s: list[float] = []
for record in target_debug['step_records']:
latent_delta = batch_relative_l2(record['img'], prev_img)[0]
action_delta = batch_relative_l2(record['action'], prev_action)[0]
state_delta = batch_relative_l2(record['state'], prev_state)[0]
action_cosine = batch_cosine_similarity(record['action'],
reference_final_action)[0]
state_cosine = batch_cosine_similarity(record['state'],
reference_final_state)[0]
latent_l2 = batch_l2_distance(record['pred_x0'],
reference_final_latent)[0]
action_deltas.append(action_delta)
state_deltas.append(state_delta)
latent_deltas.append(latent_delta)
action_cosines.append(action_cosine)
state_cosines.append(state_cosine)
latent_l2s.append(latent_l2)
self.step_rows.append({
'sample_id': sample_id,
'scene': scene,
'pass_type': pass_type,
'round_id': round_id,
'step': record['step_index'],
'step_time_s': float(record['step_time_s']),
'latent_delta': latent_delta,
'action_delta': action_delta,
'state_delta': state_delta,
'action_cosine_vs_full50': action_cosine,
'state_cosine_vs_full50': state_cosine,
'latent_l2_vs_full50': latent_l2,
})
prev_img = record['img']
prev_action = record['action']
prev_state = record['state']
oracle_action_cosines, oracle_state_cosines, oracle_latent_l2s = self.collect_trace_series(
reference_debug, reference_final_action, reference_final_state,
reference_final_latent)
latent_init_dist_to_prev_round = np.nan
action_drift_vs_prev_round = np.nan
if pass_type == 'policy':
previous_action = self.prev_policy_action.get(sample_id)
if previous_action is not None:
action_drift_vs_prev_round = 1.0 - batch_cosine_similarity(
reference_final_action, previous_action)[0]
self.prev_policy_action[sample_id] = reference_final_action.clone()
elif pass_type == 'world_model':
previous_latent = self.prev_world_latent.get(sample_id)
if previous_latent is not None:
latent_init_dist_to_prev_round = batch_l2_distance(
reference_final_latent, previous_latent)[0]
self.prev_world_latent[sample_id] = reference_final_latent.clone()
summary_row = {
'sample_id': sample_id,
'scene': scene,
'pass_type': pass_type,
'pass_total_time_s': float(pass_total_time_s),
'action_first_stable_step': np.nan,
'state_first_stable_step': np.nan,
'latent_first_stable_step': np.nan,
'action_vs_full50_90pct_step': first_at_least(action_cosines, 0.90),
'action_vs_full50_95pct_step': first_at_least(action_cosines, 0.95),
'oracle_budget_action': first_at_least(oracle_action_cosines, 0.95),
'oracle_budget_state': first_at_least(oracle_state_cosines, 0.95),
'oracle_budget_latent': np.nan,
'latent_init_dist_to_prev_round': latent_init_dist_to_prev_round,
'action_drift_vs_prev_round': action_drift_vs_prev_round,
'round_total_time_s': np.nan,
'policy_pass_total_time_s': np.nan,
'world_model_pass_total_time_s': np.nan,
'psnr_full50': self.resolve_psnr(sample_id, videoid),
}
self.append_summary_row(summary_row)
return summary_row
def log_round(self, sample_id: str, videoid: int, scene: str, round_id: int,
policy_pass_total_time_s: float,
world_model_pass_total_time_s: float, round_total_time_s: float,
latent_init_dist_to_prev_round: float,
action_drift_vs_prev_round: float) -> None:
"""Log one interaction round consisting of one policy pass and one world-model pass."""
self.append_round_row({
'sample_id': sample_id,
'scene': scene,
'round_id': round_id,
'policy_pass_total_time_s': float(policy_pass_total_time_s),
'world_model_pass_total_time_s': float(world_model_pass_total_time_s),
'round_total_time_s': float(round_total_time_s),
'latent_init_dist_to_prev_round': latent_init_dist_to_prev_round,
'action_drift_vs_prev_round': action_drift_vs_prev_round,
'psnr_full50': self.resolve_psnr(sample_id, videoid),
})
def flush(self) -> None:
"""Write analysis CSVs to disk."""
os.makedirs(self.output_dir, exist_ok=True)
stepwise_path = os.path.join(self.output_dir, 'stepwise_log.csv')
summary_path = os.path.join(self.output_dir, 'sample_summary.csv')
round_path = os.path.join(self.output_dir, 'round_summary.csv')
stepwise_df = pd.DataFrame(self.step_rows, columns=self.STEP_COLUMNS)
stepwise_df.to_csv(stepwise_path, index=False)
round_df = pd.DataFrame(self.round_rows, columns=self.ROUND_COLUMNS)
round_df.to_csv(round_path, index=False)
summary_rows = []
metric_columns = self.SUMMARY_COLUMNS[3:]
for bucket in self.summary_buckets.values():
round_bucket = self.round_buckets.get((bucket['sample_id'],
bucket['scene']))
row = {
'sample_id': bucket['sample_id'],
'scene': bucket['scene'],
'pass_type': bucket['pass_type'],
}
for column in metric_columns:
if round_bucket is not None and column in round_bucket:
row[column] = safe_mean(round_bucket[column])
else:
row[column] = safe_mean(bucket[column])
summary_rows.append(row)
summary_df = pd.DataFrame(summary_rows, columns=self.SUMMARY_COLUMNS)
summary_df.to_csv(summary_path, index=False)
def write_video(video_path: str, stacked_frames: list, fps: int) -> None:
    """Encode a list of frames into a video file with ``imageio.mimsave``.

    A ``DeprecationWarning`` whose message starts with
    "pkg_resources is deprecated as an API" is suppressed while writing.

    Args:
        video_path (str): Output path for the video.
        stacked_frames (list): Frames to write, in order.
        fps (int): Playback frame rate.
    """
    with warnings.catch_warnings():
        warnings.filterwarnings(action="ignore",
                                message="pkg_resources is deprecated as an API",
                                category=DeprecationWarning)
        imageio.mimsave(video_path, stacked_frames, fps=fps)
def get_filelist(data_dir: str, postfixes: list[str]) -> list[str]:
    """Return the sorted paths in ``data_dir`` matching ``*.<postfix>`` for any postfix.

    Matching is non-recursive (one ``glob`` pattern per postfix). A file
    matched by several postfixes appears once per match.

    Args:
        data_dir (str): Directory to search.
        postfixes (list[str]): File extensions without the leading dot.
    Returns:
        list[str]: Sorted list of matching paths.
    """
    matches = (path
               for postfix in postfixes
               for path in glob.glob(os.path.join(data_dir, f"*.{postfix}")))
    return sorted(matches)
def load_model_checkpoint(model: nn.Module, ckpt: str) -> nn.Module:
    """Load model weights from a checkpoint file.

    Two checkpoint layouts are handled:
      * A dict with a ``"state_dict"`` entry: it is loaded strictly. If that
        raises, keys containing ``framestride_embed`` are renamed to
        ``fps_embedding`` and the strict load is retried.
      * Otherwise the checkpoint must contain a ``"module"`` entry; the first
        16 characters of every key are dropped before loading (presumably a
        wrapper prefix such as ``_forward_module.`` -- TODO confirm).

    Args:
        model (nn.Module): Model instance; weights are loaded in place.
        ckpt (str): Path to the checkpoint file.
    Returns:
        nn.Module: The same model with loaded weights.
    Raises:
        RuntimeError: If the (possibly renamed) state dict does not match.
        KeyError: If the checkpoint has neither ``"state_dict"`` nor ``"module"``.
    """
    # NOTE: torch.load unpickles the file -- only load trusted checkpoints.
    checkpoint = torch.load(ckpt, map_location="cpu")
    if "state_dict" in checkpoint:
        state_dict = checkpoint["state_dict"]
        try:
            model.load_state_dict(state_dict, strict=True)
        except RuntimeError:
            # load_state_dict raises RuntimeError on key/shape mismatch; older
            # checkpoints name the frame-stride embedding "framestride_embed".
            renamed = OrderedDict(state_dict)
            for key in list(renamed):
                if "framestride_embed" in key:
                    new_key = key.replace("framestride_embed", "fps_embedding")
                    renamed[new_key] = renamed.pop(key)
            model.load_state_dict(renamed, strict=True)
    else:
        module_state = checkpoint['module']
        stripped = OrderedDict(
            (key[16:], value) for key, value in module_state.items())
        model.load_state_dict(stripped)
    print('>>> model checkpoint loaded.')
    return model
def is_inferenced(save_dir: str, filename: str) -> bool:
    """Tell whether ``filename`` already has a saved result video.

    The expected file is ``<save_dir>/samples_separate/<stem>_sample0.mp4``,
    where ``<stem>`` is ``filename`` with its last four characters removed
    (i.e. a 3-letter extension plus the dot).

    Args:
        save_dir (str): Directory where results are saved.
        filename (str): Input file name to check.
    Returns:
        bool: True if the result video exists.
    """
    stem = filename[:-4]
    target = os.path.join(save_dir, "samples_separate", stem + "_sample0.mp4")
    return os.path.exists(target)
def save_results(video: Tensor, filename: str, fps: int = 8) -> None:
    """Write a batch of videos side by side into one H.264 file.

    Values are clamped to [-1, 1] and mapped to uint8 [0, 255]. For every
    time step the B samples are tiled into a single row with no padding.

    Args:
        video (Tensor): Tensor of shape (B, C, T, H, W).
        filename (str): Output file path.
        fps (int, optional): Frames per second. Defaults to 8.
    """
    clipped = torch.clamp(video.detach().cpu().float(), -1., 1.)
    n_samples = clipped.shape[0]
    # (B, C, T, H, W) -> (T, B, C, H, W): iterate over time steps.
    per_timestep = clipped.permute(2, 0, 1, 3, 4)
    tiles = [
        torchvision.utils.make_grid(frames, nrow=int(n_samples), padding=0)
        for frames in per_timestep
    ]
    stacked = (torch.stack(tiles, dim=0) + 1.0) / 2.0
    # (T, C, H, W) float in [0, 1] -> (T, H, W, C) uint8 as write_video expects.
    pixels = (stacked * 255).to(torch.uint8).permute(0, 2, 3, 1)
    torchvision.io.write_video(filename,
                               pixels,
                               fps=fps,
                               video_codec='h264',
                               options={'crf': '10'})
def get_init_frame_path(data_dir: str, sample: dict) -> str:
    """Return ``<data_dir>/images/<sample['data_dir']>/<sample['videoid']>.png``.

    Args:
        data_dir (str): Dataset root directory.
        sample (dict): Mapping with 'data_dir' and 'videoid'.
    Returns:
        str: Path of the initial-frame PNG.
    """
    image_name = str(sample['videoid']) + '.png'
    return os.path.join(data_dir, 'images', sample['data_dir'], image_name)
def get_transition_path(data_dir: str, sample: dict) -> str:
    """Return ``<data_dir>/transitions/<sample['data_dir']>/<sample['videoid']>.h5``.

    Args:
        data_dir (str): Dataset root directory.
        sample (dict): Mapping with 'data_dir' and 'videoid'.
    Returns:
        str: Path of the HDF5 transition file.
    """
    transition_name = str(sample['videoid']) + '.h5'
    return os.path.join(data_dir, 'transitions', sample['data_dir'],
                        transition_name)
def prepare_init_input(start_idx: int,
                       init_frame_path: str,
                       transition_dict: dict[str, torch.Tensor],
                       frame_stride: int,
                       wma_data,
                       video_length: int = 16,
                       n_obs_steps: int = 2) -> tuple[dict[str, Tensor], int, int]:
    """Build the initial model input: conditioning frame, state history and action chunk.

    Args:
        start_idx (int): Index of the current step in the transition arrays.
        init_frame_path (str): Path to the image used as the conditioning frame.
        transition_dict (dict[str, Tensor]): Must contain 'action',
            'observation.state', 'action_type' and 'state_type'; 'action' and
            'observation.state' are indexed as 2-D ``[time, dim]``.
        frame_stride (int): Stride between the gathered action indices.
        wma_data: Data object providing ``normalizer``, ``get_uni_vec`` and
            ``spatial_transform`` (the latter may be ``None``).
        video_length (int, optional): Number of actions gathered, starting at
            ``start_idx``. Default is 16.
        n_obs_steps (int, optional): Number of states ending at ``start_idx``;
            when fewer exist, the first state is repeated as left padding.
            Default is 2.
    Returns:
        tuple: ``(data, ori_state_dim, ori_action_dim)``. ``data`` holds
        'observation.image' plus the entries returned by
        ``wma_data.get_uni_vec``; the ints are the last-dimension sizes of
        the raw (pre-normalization) states and actions.
    """
    indices = [start_idx + frame_stride * i for i in range(video_length)]
    # Use a context manager so the file handle is closed deterministically;
    # convert() returns an independent, fully loaded image.
    with Image.open(init_frame_path) as image:
        init_frame = image.convert('RGB')
    # (H, W, 3) array -> (3, 1, H, W) float tensor.
    init_frame = torch.tensor(np.array(init_frame)).unsqueeze(0).permute(
        3, 0, 1, 2).float()
    if start_idx < n_obs_steps - 1:
        # Not enough history yet: left-pad with copies of the first state.
        state_indices = list(range(0, start_idx + 1))
        states = transition_dict['observation.state'][state_indices, :]
        num_padding = n_obs_steps - 1 - start_idx
        first_slice = states[0:1, :]  # (1, d)
        padding = first_slice.repeat(num_padding, 1)
        states = torch.cat((padding, states), dim=0)
    else:
        state_indices = list(range(start_idx - n_obs_steps + 1, start_idx + 1))
        states = transition_dict['observation.state'][state_indices, :]
    actions = transition_dict['action'][indices, :]
    ori_state_dim = states.shape[-1]
    ori_action_dim = actions.shape[-1]
    frames_action_state_dict = {
        'action': actions,
        'observation.state': states,
    }
    frames_action_state_dict = wma_data.normalizer(frames_action_state_dict)
    frames_action_state_dict = wma_data.get_uni_vec(
        frames_action_state_dict,
        transition_dict['action_type'],
        transition_dict['state_type'],
    )
    if wma_data.spatial_transform is not None:
        init_frame = wma_data.spatial_transform(init_frame)
    # Map [0, 255] pixels to [-1, 1] (assumes spatial_transform keeps the range).
    init_frame = (init_frame / 255 - 0.5) * 2
    data = {
        'observation.image': init_frame,
    }
    data.update(frames_action_state_dict)
    return data, ori_state_dim, ori_action_dim
def get_latent_z(model, videos: Tensor) -> Tensor:
    """Encode every frame of a video batch with ``model.encode_first_stage``.

    Frames are flattened into the batch axis for encoding and regrouped
    afterwards.

    Args:
        model: The world model (must provide ``encode_first_stage``).
        videos (Tensor): Input videos of shape [B, C, T, H, W].
    Returns:
        Tensor: Latents arranged as [B, C', T, H', W'].
    """
    n_videos, _, n_frames, _, _ = videos.shape
    flat_frames = rearrange(videos, 'b c t h w -> (b t) c h w')
    encoded = model.encode_first_stage(flat_frames)
    return rearrange(encoded,
                     '(b t) c h w -> b c t h w',
                     b=n_videos,
                     t=n_frames)
def preprocess_observation(
        model, observations: dict[str, np.ndarray]) -> dict[str, Tensor]:
    """Convert a Gym-style observation batch into LeRobot-format tensors.

    Args:
        model: Model providing ``device`` and ``normalize_inputs`` (used to
            normalize the state).
        observations: Batch with ``"pixels"`` -- a uint8 channel-last
            ``(B, H, W, C)`` array, or a dict of camera name -> such array --
            and ``"agent_pos"`` (the state array).
    Returns:
        dict: ``observation.images.<camera>`` float32 ``(B, C, H, W)`` tensors
        (camera ``top`` when ``"pixels"`` is a single array) and the
        normalized ``observation.state`` on ``model.device``.
    Raises:
        AssertionError: If an image is not channel-last or not uint8.
    """
    # Map to expected inputs for the policy
    return_observations = {}
    pixels = observations["pixels"]
    if isinstance(pixels, dict):
        imgs = {
            f"observation.images.{key}": img
            for key, img in pixels.items()
        }
    else:
        imgs = {"observation.images.top": pixels}
    for imgkey, img in imgs.items():
        img = torch.from_numpy(img)
        # Sanity check that images are channel last
        _, h, w, c = img.shape
        assert c < h and c < w, f"expect channel last images, but instead {img.shape}"
        # Sanity check that images are uint8
        assert img.dtype == torch.uint8, f"expect torch.uint8, but instead {img.dtype=}"
        # Convert (B, H, W, C) -> (B, C, H, W) and cast to float32.
        # NOTE(review): pixel values are NOT divided by 255 here, so they stay
        # in [0, 255]; confirm downstream expects that range.
        img = img.permute(0, 3, 1, 2).contiguous()
        return_observations[imgkey] = img.type(torch.float32)
    state = torch.from_numpy(observations["agent_pos"]).float()
    return_observations['observation.state'] = model.normalize_inputs({
        'observation.state': state.to(model.device)
    })['observation.state']
    return return_observations
def image_guided_synthesis_sim_mode(
        model: torch.nn.Module,
        prompts: list[str],
        observation: dict,
        noise_shape: tuple[int, int, int, int, int],
        action_cond_step: int = 16,
        n_samples: int = 1,
        ddim_steps: int = 50,
        ddim_eta: float = 1.0,
        unconditional_guidance_scale: float = 1.0,
        fs: int | None = None,
        text_input: bool = True,
        timestep_spacing: str = 'uniform',
        guidance_rescale: float = 0.0,
        sim_mode: bool = True,
        init_noise_bundle: dict[str, torch.Tensor] | None = None,
        decode_video: bool = True,
        return_debug_info: bool = False,
        **kwargs) -> tuple[torch.Tensor | None, torch.Tensor, torch.Tensor, dict | None]:
    """
    Performs image-guided video generation in a simulation-style mode with optional multimodal guidance (image, state, action, text).
    Args:
        model (torch.nn.Module): The diffusion-based generative model with multimodal conditioning.
        prompts (list[str]): A list of textual prompts to guide the synthesis process.
        observation (dict): A dictionary containing observed inputs including:
            - 'observation.images.top': image tensor indexed as [B, C, O, H, W]
              (dims 1 and 2 are swapped below to obtain [B, O, C, H, W]).
            - 'observation.state': Tensor indexed as [B, O, D] (state vector).
            - 'action': action sequence, presumably [B, T, D] -- TODO confirm.
        noise_shape (tuple[int, int, int, int, int]): Shape of the latent variable to generate,
            unpacked as (B, C, T, H, W).
        action_cond_step (int): Not referenced in this function's body. Default is 16.
        n_samples (int): Not referenced in this function's body. Default is 1.
        ddim_steps (int): Number of DDIM sampling steps. Default is 50.
        ddim_eta (float): DDIM eta parameter controlling the stochasticity. Default is 1.0.
        unconditional_guidance_scale (float): Scale for classifier-free guidance. If 1.0, guidance is off.
        fs (int | None): Frame-stride condition, broadcast across the batch when given.
            When None, None is forwarded to the sampler.
        text_input (bool): Whether to use text prompt as conditioning. If False, uses empty strings. Default is True.
        timestep_spacing (str): Timestep sampling method in DDIM sampler. Typically "uniform" or "linspace".
        guidance_rescale (float): Guidance rescaling factor to mitigate overexposure from classifier-free guidance.
        sim_mode (bool): True for world-model interaction; False for decision-making,
            in which case the action embedding is replaced by zeros.
        init_noise_bundle (dict[str, torch.Tensor] | None): Optional aligned noise for
            'img' / 'action' / 'state', passed as x_T / action_T / state_T.
        decode_video (bool): Whether to decode the final latent into pixel space.
        return_debug_info (bool): Whether to return per-step traces for analysis logging.
        **kwargs: Additional arguments passed to the DDIM sampler.
    Returns:
        batch_variants (torch.Tensor | None): Decoded video, or None when decode_video is False.
        actions (torch.Tensor): Predicted action sequences from diffusion decoding.
        states (torch.Tensor): Predicted state sequences from diffusion decoding.
        debug_info (dict | None): Per-step trace (CPU copies of final outputs plus the
            sampler's 'analysis_init' / 'step_records') when return_debug_info is True.
    """
    batch_size, _, n_frames, _, _ = noise_shape
    ddim_sampler = DDIMSampler(model)
    debug_info = None

    # torch.tensor([None]) raises, so only broadcast fs when it was given.
    # NOTE(review): assumes the sampler/model accept fs=None (e.g. via a
    # default frame stride) -- confirm.
    if fs is not None:
        fs = torch.tensor([fs] * batch_size,
                          dtype=torch.long,
                          device=model.device)

    # Swap to [B, O, C, H, W] and take the last observed frame for the image embedding.
    img = observation['observation.images.top'].permute(0, 2, 1, 3, 4)
    cond_img = rearrange(img, 'b o c h w -> (b o) c h w')[-1:]
    cond_img_emb = model.image_proj_model(model.embedder(cond_img))

    # Initialize unconditionally: previously `cond` only existed for the
    # 'hybrid' conditioning key and any other key raised NameError below.
    cond = {}
    if model.model.conditioning_key == 'hybrid':
        # Concatenation condition: latent of the last observed frame,
        # repeated over the generated horizon.
        z = get_latent_z(model, img.permute(0, 2, 1, 3, 4))
        img_cat_cond = repeat(z[:, :, -1:, :, :],
                              'b c t h w -> b c (repeat t) h w',
                              repeat=n_frames)
        cond["c_concat"] = [img_cat_cond]

    if not text_input:
        prompts = [""] * batch_size
    cond_ins_emb = model.get_learned_conditioning(prompts)
    cond_state_emb = model.state_projector(
        observation['observation.state']) + model.agent_state_pos_emb
    cond_action_emb = model.action_projector(
        observation['action']) + model.agent_action_pos_emb
    if not sim_mode:
        cond_action_emb = torch.zeros_like(cond_action_emb)
    cond["c_crossattn"] = [
        torch.cat(
            [cond_state_emb, cond_action_emb, cond_ins_emb, cond_img_emb],
            dim=1)
    ]
    cond["c_crossattn_action"] = [
        observation['observation.images.top'][:, :,
                                              -model.n_obs_steps_acting:],
        observation['observation.state'][:, -model.n_obs_steps_acting:],
        sim_mode,
        False,
    ]

    kwargs.update({"unconditional_conditioning_img_nonetext": None})
    if init_noise_bundle is None:
        x_T = action_T = state_T = None
    else:
        x_T = init_noise_bundle['img']
        action_T = init_noise_bundle['action']
        state_T = init_noise_bundle['state']

    samples, actions, states, intermedia = ddim_sampler.sample(
        S=ddim_steps,
        conditioning=cond,
        batch_size=batch_size,
        shape=noise_shape[1:],
        verbose=False,
        unconditional_guidance_scale=unconditional_guidance_scale,
        unconditional_conditioning=None,
        eta=ddim_eta,
        cfg_img=None,
        mask=None,
        x0=None,
        fs=fs,
        timestep_spacing=timestep_spacing,
        guidance_rescale=guidance_rescale,
        x_T=x_T,
        action_T=action_T,
        state_T=state_T,
        record_step_outputs=return_debug_info,
        **kwargs)

    batch_variants = model.decode_first_stage(samples) if decode_video else None
    if return_debug_info:
        debug_info = {
            'analysis_init': intermedia.get('analysis_init'),
            'step_records': intermedia.get('step_records', []),
            'final_latent': samples.detach().cpu(),
            'final_action': actions.detach().cpu(),
            'final_state': states.detach().cpu(),
        }
    return batch_variants, actions, states, debug_info
def _stack_observation_queues(cond_obs_queues: dict,
                              device: torch.device) -> dict:
    """Stack the conditioning queues into batched tensors on ``device``.

    Each queue's entries are stacked along a new dim 1. The image stack is
    additionally permuted with ``(0, 2, 1, 3, 4)`` so the stacked step axis
    ends up at position 2 (assumes queued images are 4-D ``(b, c, h, w)``,
    giving ``(b, c, t, h, w)`` — TODO confirm against populate_queues).

    Args:
        cond_obs_queues (dict): Deques keyed by 'observation.images.top',
            'observation.state' and 'action'.
        device (torch.device): Device the stacked tensors are moved to.

    Returns:
        dict: The same three keys mapped to the stacked tensors on ``device``.
    """
    observation = {
        'observation.images.top':
        torch.stack(list(cond_obs_queues['observation.images.top']),
                    dim=1).permute(0, 2, 1, 3, 4),
        'observation.state':
        torch.stack(list(cond_obs_queues['observation.state']), dim=1),
        'action':
        torch.stack(list(cond_obs_queues['action']), dim=1),
    }
    return {
        key: value.to(device, non_blocking=True)
        for key, value in observation.items()
    }


def _load_transition_dict(transition_path: str) -> dict:
    """Read every dataset and file attribute of an HDF5 transition file.

    Datasets are converted with ``torch.tensor``; attributes are copied
    as-is. An attribute overwrites a dataset with the same key.
    """
    transition_dict = {}
    with h5py.File(transition_path, 'r') as h5f:
        for key in h5f.keys():
            transition_dict[key] = torch.tensor(h5f[key][()])
        for key in h5f.attrs.keys():
            transition_dict[key] = h5f.attrs[key]
    return transition_dict


def _run_interaction_loop(args: argparse.Namespace, df: pd.DataFrame,
                          model: nn.Module, data, device: torch.device,
                          noise_shape: list, writer,
                          analysis_logger) -> None:
    """Run the policy / world-model interaction episodes for every CSV row.

    For each row of ``df`` and each value in ``args.frame_stride``, runs
    ``args.n_iter`` rounds. Each round has two passes:

    - A policy pass (instruction prompt, ``sim_mode=False``) predicts
      actions, which are pushed into the action queue.
    - A world-model pass (empty prompt, ``text_input=False``) predicts
      frames and states. The first ``args.exe_steps`` of these are fed back
      into the conditioning queues.

    Per-round videos are written to
    ``<savedir>/inference/sample_<videoid>/{dm,wm}/<fs>/itr-<itr>.mp4``
    and to TensorBoard. The concatenated episode goes to
    ``<savedir>/inference/<videoid>_full_fs<fs>.mp4``.

    Args:
        args (argparse.Namespace): Parsed command-line arguments.
        df (pd.DataFrame): Prompt table loaded from ``<dataset>.csv``.
        model (nn.Module): Loaded world model, already on ``device``.
        data: Data module whose ``test_datasets[args.dataset]`` is passed to
            ``prepare_init_input``.
        device (torch.device): Device used for observations.
        noise_shape (list): Latent shape ``[bs, channels, frames, h, w]``.
        writer: TensorBoard ``SummaryWriter`` used for video logging.
        analysis_logger: ``InteractionAnalysisLogger`` or ``None`` when
            analysis logging is disabled.
    """
    for row_idx in range(len(df)):
        sample = df.iloc[row_idx]
        # Got initial frame path
        init_frame_path = get_init_frame_path(args.prompt_dir, sample)
        ori_fps = float(sample['fps'])
        video_save_dir = args.savedir + f"/inference/sample_{sample['videoid']}"
        os.makedirs(video_save_dir, exist_ok=True)
        os.makedirs(video_save_dir + '/dm', exist_ok=True)
        os.makedirs(video_save_dir + '/wm', exist_ok=True)
        # Load transitions to get the initial state later
        transition_dict = _load_transition_dict(
            get_transition_path(args.prompt_dir, sample))
        # If many, test various frequence control and world-model generation
        for fs in args.frame_stride:
            sample_id = build_sample_id(args.dataset, sample, fs)
            scene = get_scene_name(sample, args.dataset)
            # Per-round output dirs: policy imagination (dm) and
            # world-model rollouts (wm).
            os.makedirs(f'{video_save_dir}/dm/{fs}', exist_ok=True)
            os.makedirs(f'{video_save_dir}/wm/{fs}', exist_ok=True)
            # For collecting interaction videos
            wm_video = []
            # Initialize observation queues
            cond_obs_queues = {
                "observation.images.top":
                deque(maxlen=model.n_obs_steps_imagen),
                "observation.state": deque(maxlen=model.n_obs_steps_imagen),
                "action": deque(maxlen=args.video_length),
            }
            # Obtain initial frame and state
            start_idx = 0
            model_input_fs = ori_fps // fs
            batch, ori_state_dim, ori_action_dim = prepare_init_input(
                start_idx,
                init_frame_path,
                transition_dict,
                fs,
                data.test_datasets[args.dataset],
                n_obs_steps=model.n_obs_steps_imagen)
            observation = {
                'observation.images.top':
                batch['observation.image'].permute(1, 0, 2,
                                                   3)[-1].unsqueeze(0),
                'observation.state':
                batch['observation.state'][-1].unsqueeze(0),
                'action':
                torch.zeros_like(batch['action'][-1]).unsqueeze(0)
            }
            observation = {
                key: observation[key].to(device, non_blocking=True)
                for key in observation
            }
            # Update observation queues
            cond_obs_queues = populate_queues(cond_obs_queues, observation)
            # Multi-round interaction with the world-model
            for itr in tqdm(range(args.n_iter)):
                round_start_time = time.time()
                # Get observation
                observation = _stack_observation_queues(
                    cond_obs_queues, device)
                # Use world-model in policy to generate action
                print(f'>>> Step {itr}: generating actions ...')
                policy_noise_bundle = make_sampling_noise_bundle(
                    model, noise_shape) if args.analysis_log_metrics else None
                policy_reference_debug = None
                # Reference pass shares the same initial noise so the
                # target pass can be compared against it step-for-step.
                if args.analysis_log_metrics and args.analysis_reference_steps != args.ddim_steps:
                    _, _, _, policy_reference_debug = image_guided_synthesis_sim_mode(
                        model,
                        sample['instruction'],
                        observation,
                        noise_shape,
                        action_cond_step=args.exe_steps,
                        ddim_steps=args.analysis_reference_steps,
                        ddim_eta=args.ddim_eta,
                        unconditional_guidance_scale=args.
                        unconditional_guidance_scale,
                        fs=model_input_fs,
                        timestep_spacing=args.timestep_spacing,
                        guidance_rescale=args.guidance_rescale,
                        sim_mode=False,
                        init_noise_bundle=policy_noise_bundle,
                        decode_video=False,
                        return_debug_info=True)
                policy_pass_start = time.time()
                pred_videos_0, pred_actions, _, policy_debug = image_guided_synthesis_sim_mode(
                    model,
                    sample['instruction'],
                    observation,
                    noise_shape,
                    action_cond_step=args.exe_steps,
                    ddim_steps=args.ddim_steps,
                    ddim_eta=args.ddim_eta,
                    unconditional_guidance_scale=args.
                    unconditional_guidance_scale,
                    fs=model_input_fs,
                    timestep_spacing=args.timestep_spacing,
                    guidance_rescale=args.guidance_rescale,
                    sim_mode=False,
                    init_noise_bundle=policy_noise_bundle,
                    return_debug_info=args.analysis_log_metrics)
                policy_pass_total_time_s = time.time() - policy_pass_start
                policy_summary_row = None
                if analysis_logger is not None:
                    # When reference and target step counts match, the
                    # target pass doubles as its own reference.
                    if policy_reference_debug is None:
                        policy_reference_debug = policy_debug
                    policy_summary_row = analysis_logger.log_pass(
                        sample_id=sample_id,
                        videoid=int(sample['videoid']),
                        scene=scene,
                        pass_type='policy',
                        round_id=itr,
                        pass_total_time_s=policy_pass_total_time_s,
                        target_debug=policy_debug,
                        reference_debug=policy_reference_debug,
                    )
                # Update future actions in the observation queues
                for step_idx in range(len(pred_actions[0])):
                    observation = {
                        'action': pred_actions[0][step_idx:step_idx + 1]
                    }
                    # Zero dims beyond ori_action_dim (presumably padding).
                    # NOTE: the slice is a view, so this also edits
                    # pred_actions in place.
                    observation['action'][:, ori_action_dim:] = 0.0
                    cond_obs_queues = populate_queues(cond_obs_queues,
                                                      observation)
                # Collect data for interacting the world-model using the predicted actions
                observation = _stack_observation_queues(
                    cond_obs_queues, device)
                # Interaction with the world-model
                print(f'>>> Step {itr}: interacting with world model ...')
                world_noise_bundle = make_sampling_noise_bundle(
                    model, noise_shape) if args.analysis_log_metrics else None
                world_reference_debug = None
                if args.analysis_log_metrics and args.analysis_reference_steps != args.ddim_steps:
                    _, _, _, world_reference_debug = image_guided_synthesis_sim_mode(
                        model,
                        "",
                        observation,
                        noise_shape,
                        action_cond_step=args.exe_steps,
                        ddim_steps=args.analysis_reference_steps,
                        ddim_eta=args.ddim_eta,
                        unconditional_guidance_scale=args.
                        unconditional_guidance_scale,
                        fs=model_input_fs,
                        text_input=False,
                        timestep_spacing=args.timestep_spacing,
                        guidance_rescale=args.guidance_rescale,
                        init_noise_bundle=world_noise_bundle,
                        decode_video=False,
                        return_debug_info=True)
                world_pass_start = time.time()
                pred_videos_1, _, pred_states, world_debug = image_guided_synthesis_sim_mode(
                    model,
                    "",
                    observation,
                    noise_shape,
                    action_cond_step=args.exe_steps,
                    ddim_steps=args.ddim_steps,
                    ddim_eta=args.ddim_eta,
                    unconditional_guidance_scale=args.
                    unconditional_guidance_scale,
                    fs=model_input_fs,
                    text_input=False,
                    timestep_spacing=args.timestep_spacing,
                    guidance_rescale=args.guidance_rescale,
                    init_noise_bundle=world_noise_bundle,
                    return_debug_info=args.analysis_log_metrics)
                world_pass_total_time_s = time.time() - world_pass_start
                world_summary_row = None
                if analysis_logger is not None:
                    if world_reference_debug is None:
                        world_reference_debug = world_debug
                    world_summary_row = analysis_logger.log_pass(
                        sample_id=sample_id,
                        videoid=int(sample['videoid']),
                        scene=scene,
                        pass_type='world_model',
                        round_id=itr,
                        pass_total_time_s=world_pass_total_time_s,
                        target_debug=world_debug,
                        reference_debug=world_reference_debug,
                    )
                    analysis_logger.log_round(
                        sample_id=sample_id,
                        videoid=int(sample['videoid']),
                        scene=scene,
                        round_id=itr,
                        policy_pass_total_time_s=policy_pass_total_time_s,
                        world_model_pass_total_time_s=
                        world_pass_total_time_s,
                        round_total_time_s=time.time() - round_start_time,
                        latent_init_dist_to_prev_round=np.nan
                        if world_summary_row is None else
                        world_summary_row['latent_init_dist_to_prev_round'],
                        action_drift_vs_prev_round=np.nan
                        if policy_summary_row is None else
                        policy_summary_row['action_drift_vs_prev_round'],
                    )
                # Feed the first exe_steps world-model predictions back into
                # the queues as the next observations (with zero actions).
                for step_idx in range(args.exe_steps):
                    observation = {
                        'observation.images.top':
                        pred_videos_1[0][:, step_idx:step_idx + 1].permute(
                            1, 0, 2, 3),
                        'observation.state':
                        torch.zeros_like(pred_states[0][step_idx:step_idx + 1])
                        if args.zero_pred_state else
                        pred_states[0][step_idx:step_idx + 1],
                        'action':
                        torch.zeros_like(pred_actions[0][-1:])
                    }
                    observation['observation.state'][:, ori_state_dim:] = 0.0
                    cond_obs_queues = populate_queues(cond_obs_queues,
                                                      observation)
                # Save the imagen videos for decision-making
                sample_tag = f"{args.dataset}-vid{sample['videoid']}-dm-fs-{fs}/itr-{itr}"
                log_to_tensorboard(writer,
                                   pred_videos_0,
                                   sample_tag,
                                   fps=args.save_fps)
                # Save videos environment changes via world-model interaction
                sample_tag = f"{args.dataset}-vid{sample['videoid']}-wd-fs-{fs}/itr-{itr}"
                log_to_tensorboard(writer,
                                   pred_videos_1,
                                   sample_tag,
                                   fps=args.save_fps)
                # Save the imagen videos for decision-making
                sample_video_file = f'{video_save_dir}/dm/{fs}/itr-{itr}.mp4'
                save_results(pred_videos_0.cpu(),
                             sample_video_file,
                             fps=args.save_fps)
                # Save videos environment changes via world-model interaction
                sample_video_file = f'{video_save_dir}/wm/{fs}/itr-{itr}.mp4'
                save_results(pred_videos_1.cpu(),
                             sample_video_file,
                             fps=args.save_fps)
                print('>' * 24)
                # Collect the result of world-model interactions
                wm_video.append(pred_videos_1[:, :, :args.exe_steps].cpu())
            if not wm_video:
                # No rounds ran (n_iter <= 0): torch.cat would fail on an
                # empty list and there is no episode video to save.
                continue
            full_video = torch.cat(wm_video, dim=2)
            sample_tag = f"{args.dataset}-vid{sample['videoid']}-wd-fs-{fs}/full"
            log_to_tensorboard(writer,
                               full_video,
                               sample_tag,
                               fps=args.save_fps)
            sample_full_video_file = f"{video_save_dir}/../{sample['videoid']}_full_fs{fs}.mp4"
            save_results(full_video, sample_full_video_file, fps=args.save_fps)


def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
    """
    Run inference pipeline on prompts and image inputs.

    Validates the arguments, loads the prompt CSV, model config, checkpoint
    and data module, then runs the policy / world-model interaction loop.
    The TensorBoard writer is always closed, and the analysis logger (if
    enabled) always flushed, even when the run fails part-way.

    Args:
        args (argparse.Namespace): Parsed command-line arguments.
        gpu_num (int): Number of GPUs (not used by this function).
        gpu_no (int): Index of the current GPU.
    Returns:
        None
    Raises:
        AssertionError: If the image size is not a multiple of 16, the batch
            size is not 1, analysis step counts are non-positive while
            analysis logging is enabled, or the checkpoint path is missing.
    """
    # Validate arguments first so bad input fails before the slow
    # config / model / dataset loading.
    assert (args.height % 16 == 0) and (
        args.width % 16
        == 0), "Error: image size [h,w] should be multiples of 16!"
    assert args.bs == 1, "Current implementation only support [batch size = 1]!"
    if args.analysis_log_metrics:
        assert args.ddim_steps > 0, "analysis_log_metrics requires positive --ddim_steps."
        assert args.analysis_reference_steps > 0, (
            "analysis_log_metrics requires positive --analysis_reference_steps.")
    assert os.path.exists(args.ckpt_path), "Error: checkpoint Not Found!"
    # Create inference and tensorboard dirs
    inference_dir = args.savedir + '/inference'
    os.makedirs(inference_dir, exist_ok=True)
    log_dir = args.savedir + "/tensorboard"
    os.makedirs(log_dir, exist_ok=True)
    writer = SummaryWriter(log_dir=log_dir)
    analysis_logger = None
    try:
        if args.analysis_log_metrics:
            analysis_logger = InteractionAnalysisLogger(
                output_dir=inference_dir,
                psnr_lookup=load_psnr_lookup(args.analysis_psnr_path),
            )
        # Load prompt
        csv_path = os.path.join(args.prompt_dir, f"{args.dataset}.csv")
        df = pd.read_csv(csv_path)
        # Load config
        config = OmegaConf.load(args.config)
        config['model']['params']['wma_config']['params'][
            'use_checkpoint'] = False
        model = instantiate_from_config(config.model)
        model.perframe_ae = args.perframe_ae
        model = load_model_checkpoint(model, args.ckpt_path)
        model.eval()
        print('>>> Load pre-trained model ...')
        # Build the data module (its test dataset feeds prepare_init_input)
        logging.info("***** Configing Data *****")
        data = instantiate_from_config(config.data)
        data.setup()
        print(">>> Dataset is successfully loaded ...")
        model = model.cuda(gpu_no)
        device = get_device_from_parameters(model)
        # Get latent noise shape (spatial dims are 1/8 of the image size)
        h, w = args.height // 8, args.width // 8
        channels = model.model.diffusion_model.out_channels
        n_frames = args.video_length
        print(f'>>> Generate {n_frames} frames under each generation ...')
        noise_shape = [args.bs, channels, n_frames, h, w]
        # Start inference
        _run_interaction_loop(args, df, model, data, device, noise_shape,
                              writer, analysis_logger)
    finally:
        if analysis_logger is not None:
            analysis_logger.flush()
        writer.close()
def get_parser():
    """Build the command-line parser for world-model interaction inference.

    Returns:
        argparse.ArgumentParser: Parser for all options read by
        ``run_inference`` and the ``__main__`` entry point.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--savedir",
                        type=str,
                        default=None,
                        help="Path to save the results.")
    parser.add_argument("--ckpt_path",
                        type=str,
                        default=None,
                        help="Path to the model checkpoint.")
    parser.add_argument("--config",
                        type=str,
                        help="Path to the model config file (loaded with OmegaConf).")
    parser.add_argument(
        "--prompt_dir",
        type=str,
        default=None,
        help="Directory containing videos and corresponding prompts.")
    parser.add_argument("--dataset",
                        type=str,
                        default=None,
                        help="the name of dataset to test")
    parser.add_argument(
        "--ddim_steps",
        type=int,
        default=50,
        help="Number of DDIM steps. If non-positive, DDPM is used instead.")
    parser.add_argument(
        "--ddim_eta",
        type=float,
        default=1.0,
        help="Eta for DDIM sampling. Set to 0.0 for deterministic results.")
    parser.add_argument("--bs",
                        type=int,
                        default=1,
                        help="Batch size for inference. Must be 1.")
    parser.add_argument("--height",
                        type=int,
                        default=320,
                        help="Height of the generated images in pixels.")
    parser.add_argument("--width",
                        type=int,
                        default=512,
                        help="Width of the generated images in pixels.")
    parser.add_argument(
        "--frame_stride",
        type=int,
        nargs='+',
        required=True,
        help=
        "frame stride control for 256 model (larger->larger motion), FPS control for 512 or 1024 model (smaller->larger motion)"
    )
    parser.add_argument(
        "--unconditional_guidance_scale",
        type=float,
        default=1.0,
        help="Scale for classifier-free guidance during sampling.")
    parser.add_argument(
        "--seed",
        type=int,
        default=123,
        help="Random seed for reproducibility. A negative value picks a random seed.")
    parser.add_argument("--video_length",
                        type=int,
                        default=16,
                        help="Number of frames in the generated video.")
    # NOTE(review): not read by run_inference; confirm it is still needed.
    parser.add_argument("--num_generation",
                        type=int,
                        default=1,
                        help="Number of generations per sample.")
    parser.add_argument(
        "--timestep_spacing",
        type=str,
        default="uniform",
        help=
        "Strategy for timestep scaling. See Table 2 in the paper: 'Common Diffusion Noise Schedules and Sample Steps are Flawed' (https://huggingface.co/papers/2305.08891)."
    )
    parser.add_argument(
        "--guidance_rescale",
        type=float,
        default=0.0,
        help=
        "Rescale factor for guidance as discussed in 'Common Diffusion Noise Schedules and Sample Steps are Flawed' (https://huggingface.co/papers/2305.08891)."
    )
    parser.add_argument(
        "--perframe_ae",
        action='store_true',
        default=False,
        help=
        "Use per-frame autoencoder decoding to reduce GPU memory usage. Recommended for models with resolutions like 576x1024."
    )
    # NOTE(review): not read by run_inference; confirm it is still needed.
    parser.add_argument(
        "--n_action_steps",
        type=int,
        default=16,
        help="Number of action steps.",
    )
    parser.add_argument(
        "--exe_steps",
        type=int,
        default=16,
        help="Number of predicted steps executed (fed back as observations) per interaction round.",
    )
    parser.add_argument(
        "--n_iter",
        type=int,
        default=40,
        help="num of iteration to interact with the world model",
    )
    parser.add_argument(
        "--zero_pred_state",
        action='store_true',
        default=False,
        help="Feed zeros instead of the world model's predicted states back into the observation queue.")
    parser.add_argument("--save_fps",
                        type=int,
                        default=8,
                        help="fps for the saving video")
    parser.add_argument("--analysis_log_metrics",
                        action='store_true',
                        default=False,
                        help="Enable DDIM convergence logging and export analysis CSVs.")
    parser.add_argument(
        "--analysis_reference_steps",
        type=int,
        default=50,
        help="Reference DDIM steps used to build the full-step baseline for *_vs_full50 metrics."
    )
    parser.add_argument("--analysis_psnr_path",
                        type=str,
                        default=None,
                        help="Optional CSV/JSON file with psnr_full50 values keyed by sample_id or videoid.")
    return parser
if __name__ == '__main__':
    # Script entry point: parse the CLI, seed all RNGs, run on a single GPU.
    parser = get_parser()
    args = parser.parse_args()
    # Any negative --seed means "draw a fresh random seed for this run".
    seed = random.randint(0, 2**31) if args.seed < 0 else args.seed
    seed_everything(seed)
    # Single-process execution: this process is rank 0 out of one GPU.
    rank, gpu_num = 0, 1
    run_inference(args, gpu_num, rank)