1372 lines
57 KiB
Python
1372 lines
57 KiB
Python
import argparse, os, glob
|
|
import json
|
|
import pandas as pd
|
|
import random
|
|
import torch
|
|
import torchvision
|
|
import h5py
|
|
import numpy as np
|
|
import logging
|
|
import einops
|
|
import warnings
|
|
import imageio
|
|
import time
|
|
|
|
from pytorch_lightning import seed_everything
|
|
from omegaconf import OmegaConf
|
|
from tqdm import tqdm
|
|
from einops import rearrange, repeat
|
|
from collections import OrderedDict
|
|
from torch import nn
|
|
import torch.nn.functional as F
|
|
from eval_utils import populate_queues, log_to_tensorboard
|
|
from collections import deque
|
|
from torch import Tensor
|
|
from torch.utils.tensorboard import SummaryWriter
|
|
from PIL import Image
|
|
|
|
from unifolm_wma.models.samplers.ddim import DDIMSampler
|
|
from unifolm_wma.utils.utils import instantiate_from_config
|
|
|
|
|
|
def get_device_from_parameters(module: nn.Module) -> torch.device:
    """Infer the device of a module from the first parameter it yields.

    Args:
        module (nn.Module): The model whose device is to be inferred.

    Returns:
        torch.device: Device of the first parameter returned by
            ``module.parameters()``.

    Raises:
        StopIteration: If the module has no parameters at all.
    """
    first_parameter = next(iter(module.parameters()))
    return first_parameter.device
|
|
|
|
|
|
def get_scene_name(sample: pd.Series, fallback: str) -> str:
    """Pick the scene label written to the analysis logs.

    Args:
        sample (pd.Series): Prompt row; its ``'data_dir'`` entry is used as the
            scene label when present and not NaN/None.
        fallback (str): Label returned when no usable ``'data_dir'`` exists.

    Returns:
        str: ``str(sample['data_dir'])`` or ``fallback``.
    """
    if 'data_dir' not in sample:
        return fallback
    scene = sample['data_dir']
    if pd.isna(scene):
        return fallback
    return str(scene)
|
|
|
|
|
|
def build_sample_id(dataset: str, sample: pd.Series, frame_stride: int) -> str:
    """Compose the flat ``<dataset>-vid<videoid>-fs<frame_stride>`` sample id.

    Args:
        dataset (str): Dataset name used as the id prefix.
        sample (pd.Series): Prompt row providing ``'videoid'``.
        frame_stride (int): Frame stride appended after ``-fs``.

    Returns:
        str: The assembled sample id.
    """
    video_id = sample['videoid']
    return f"{dataset}-vid{video_id}-fs{frame_stride}"
|
|
|
|
|
|
def flatten_batch_tensor(tensor: torch.Tensor) -> torch.Tensor:
    """Collapse every dimension after the first into one float vector per item.

    Args:
        tensor (torch.Tensor): Tensor whose dim 0 is the batch dimension.

    Returns:
        torch.Tensor: Detached float32 tensor of shape ``(tensor.shape[0], -1)``.
    """
    batch_size = tensor.shape[0]
    detached = tensor.detach().float()
    return detached.reshape(batch_size, -1)
|
|
|
|
|
|
def batch_relative_l2(current: torch.Tensor,
                      previous: torch.Tensor) -> list[float]:
    """Return ``||current - previous|| / ||previous||`` for every batch item.

    Args:
        current (torch.Tensor): Newer tensor, batch on dim 0.
        previous (torch.Tensor): Older tensor with the same batch size.

    Returns:
        list[float]: One relative L2 change per batch item.
    """
    prev_vectors = flatten_batch_tensor(previous)
    curr_vectors = flatten_batch_tensor(current)
    change_norm = torch.linalg.vector_norm(curr_vectors - prev_vectors, dim=1)
    # Clamp the denominator so an all-zero `previous` cannot divide by zero.
    base_norm = torch.linalg.vector_norm(prev_vectors, dim=1).clamp_min(1e-8)
    ratios = change_norm / base_norm
    return ratios.cpu().tolist()
|
|
|
|
|
|
def batch_l2_distance(current: torch.Tensor,
                      reference: torch.Tensor) -> list[float]:
    """Return the L2 distance to ``reference`` for every batch item.

    Args:
        current (torch.Tensor): Tensor being measured, batch on dim 0.
        reference (torch.Tensor): Reference tensor with the same batch size.

    Returns:
        list[float]: One L2 distance per batch item.
    """
    current_vectors = flatten_batch_tensor(current)
    reference_vectors = flatten_batch_tensor(reference)
    difference = current_vectors - reference_vectors
    distances = torch.linalg.vector_norm(difference, dim=1)
    return distances.cpu().tolist()
|
|
|
|
|
|
def batch_cosine_similarity(current: torch.Tensor,
                            reference: torch.Tensor) -> list[float]:
    """Return the cosine similarity to ``reference`` for every batch item.

    Args:
        current (torch.Tensor): Tensor being measured, batch on dim 0.
        reference (torch.Tensor): Reference tensor with the same batch size.

    Returns:
        list[float]: One cosine similarity per batch item.
    """
    current_vectors = flatten_batch_tensor(current)
    reference_vectors = flatten_batch_tensor(reference)
    similarity = F.cosine_similarity(current_vectors,
                                     reference_vectors,
                                     dim=1,
                                     eps=1e-8)
    return similarity.cpu().tolist()
|
|
|
|
|
|
def first_consecutive_below(values: list[float], threshold: float,
                            window: int) -> float:
    """Locate the first run of ``window`` consecutive values under ``threshold``.

    Args:
        values (list[float]): Series to scan; NaN/None entries break a run.
        threshold (float): Exclusive upper bound each value must stay under.
        window (int): Required run length.

    Returns:
        float: 1-based index of the first value of the first qualifying run,
            or NaN when ``window`` is not positive, the series is shorter than
            ``window``, or no run exists.
    """
    if window <= 0 or len(values) < window:
        return np.nan
    run_length = 0
    for position, value in enumerate(values):
        if pd.notna(value) and value < threshold:
            run_length += 1
            if run_length == window:
                # `position` is the 0-based end of the run; convert it to the
                # 1-based index of the run's first element.
                return float(position - window + 2)
        else:
            run_length = 0
    return np.nan
|
|
|
|
|
|
def first_at_least(values: list[float], threshold: float) -> float:
    """Find the first 1-based index whose value is ``>= threshold``.

    Args:
        values (list[float]): Series to scan; NaN/None entries are skipped.
        threshold (float): Inclusive lower bound to reach.

    Returns:
        float: 1-based index of the first qualifying value, or NaN if none.
    """
    hits = (float(position)
            for position, value in enumerate(values, start=1)
            if pd.notna(value) and value >= threshold)
    return next(hits, np.nan)
|
|
|
|
|
|
def first_at_most(values: list[float], threshold: float) -> float:
    """Find the first 1-based index whose value is at or below ``threshold``.

    Args:
        values (list[float]): Series to scan; NaN/None entries are skipped.
        threshold (float): Inclusive upper bound to reach.

    Returns:
        float: 1-based index of the first qualifying value, or NaN if none.
    """
    hits = (float(position)
            for position, value in enumerate(values, start=1)
            if pd.notna(value) and value <= threshold)
    return next(hits, np.nan)
|
|
|
|
|
|
def safe_mean(values: list[float]) -> float:
    """Return the mean of the non-missing entries of ``values``.

    Args:
        values (list[float]): Numbers to average; NaN/None entries are dropped.

    Returns:
        float: Mean of the remaining values, or NaN when nothing remains.
    """
    present = [value for value in values if not pd.isna(value)]
    return float(np.mean(present)) if present else np.nan
|
|
|
|
|
|
def make_sampling_noise_bundle(model: nn.Module,
                               noise_shape: list[int]) -> dict[str, torch.Tensor]:
    """Draw matching Gaussian start noise for the latent, action and state streams.

    Args:
        model (nn.Module): Provides ``device``, ``agent_action_dim`` and
            ``agent_state_dim``.
        noise_shape (list[int]): Latent noise shape; index 0 is used as the
            batch size and index 2 as the horizon of the action/state noise.

    Returns:
        dict[str, torch.Tensor]: ``'img'`` with shape ``noise_shape`` plus
            ``'action'`` and ``'state'`` with shape ``(batch, horizon, dim)``.
    """
    batch_size, horizon = noise_shape[0], noise_shape[2]
    device = model.device
    # Draw order (img, action, state) is kept so a seeded RNG reproduces runs.
    img_noise = torch.randn(noise_shape, device=device)
    action_noise = torch.randn((batch_size, horizon, model.agent_action_dim),
                               device=device)
    state_noise = torch.randn((batch_size, horizon, model.agent_state_dim),
                              device=device)
    return {'img': img_noise, 'action': action_noise, 'state': state_noise}
|
|
|
|
|
|
def load_psnr_lookup(psnr_path: str | None) -> dict[str, float]:
|
|
"""Load optional PSNR values keyed by sample_id or videoid."""
|
|
if not psnr_path:
|
|
return {}
|
|
if not os.path.exists(psnr_path):
|
|
logging.warning("PSNR file not found: %s", psnr_path)
|
|
return {}
|
|
|
|
suffix = os.path.splitext(psnr_path)[1].lower()
|
|
lookup: dict[str, float] = {}
|
|
if suffix == '.csv':
|
|
df = pd.read_csv(psnr_path)
|
|
key_column = 'sample_id' if 'sample_id' in df.columns else 'videoid'
|
|
value_column = 'psnr_full50' if 'psnr_full50' in df.columns else 'psnr'
|
|
for _, row in df.iterrows():
|
|
if pd.notna(row[key_column]) and pd.notna(row[value_column]):
|
|
lookup[str(row[key_column])] = float(row[value_column])
|
|
return lookup
|
|
|
|
if suffix == '.json':
|
|
with open(psnr_path, 'r', encoding='utf-8') as file:
|
|
data = json.load(file)
|
|
if isinstance(data, dict):
|
|
if 'sample_id' in data and ('psnr_full50' in data or 'psnr' in data):
|
|
lookup[str(data['sample_id'])] = float(
|
|
data.get('psnr_full50', data['psnr']))
|
|
else:
|
|
for key, value in data.items():
|
|
if isinstance(value, (int, float)):
|
|
lookup[str(key)] = float(value)
|
|
elif isinstance(data, list):
|
|
for item in data:
|
|
if not isinstance(item, dict):
|
|
continue
|
|
if 'sample_id' in item and ('psnr_full50' in item or 'psnr' in
|
|
item):
|
|
lookup[str(item['sample_id'])] = float(
|
|
item.get('psnr_full50', item['psnr']))
|
|
return lookup
|
|
|
|
logging.warning("Unsupported PSNR file format: %s", psnr_path)
|
|
return {}
|
|
|
|
|
|
class InteractionAnalysisLogger:
|
|
"""Collect stepwise metrics and aggregated per-sample summaries."""
|
|
|
|
STEP_COLUMNS = [
|
|
'sample_id',
|
|
'scene',
|
|
'pass_type',
|
|
'round_id',
|
|
'step',
|
|
'step_time_s',
|
|
'latent_delta',
|
|
'action_delta',
|
|
'state_delta',
|
|
'action_cosine_vs_full50',
|
|
'state_cosine_vs_full50',
|
|
'latent_l2_vs_full50',
|
|
]
|
|
SUMMARY_COLUMNS = [
|
|
'sample_id',
|
|
'scene',
|
|
'pass_type',
|
|
'pass_total_time_s',
|
|
'action_first_stable_step',
|
|
'state_first_stable_step',
|
|
'latent_first_stable_step',
|
|
'action_vs_full50_90pct_step',
|
|
'action_vs_full50_95pct_step',
|
|
'oracle_budget_action',
|
|
'oracle_budget_state',
|
|
'oracle_budget_latent',
|
|
'latent_init_dist_to_prev_round',
|
|
'action_drift_vs_prev_round',
|
|
'round_total_time_s',
|
|
'policy_pass_total_time_s',
|
|
'world_model_pass_total_time_s',
|
|
'psnr_full50',
|
|
]
|
|
ROUND_COLUMNS = [
|
|
'sample_id',
|
|
'scene',
|
|
'round_id',
|
|
'policy_pass_total_time_s',
|
|
'world_model_pass_total_time_s',
|
|
'round_total_time_s',
|
|
'latent_init_dist_to_prev_round',
|
|
'action_drift_vs_prev_round',
|
|
'psnr_full50',
|
|
]
|
|
|
|
def __init__(self, output_dir: str, psnr_lookup: dict[str, float]):
|
|
self.output_dir = output_dir
|
|
self.psnr_lookup = psnr_lookup
|
|
self.step_rows: list[dict] = []
|
|
self.summary_buckets: dict[tuple[str, str, str], dict] = {}
|
|
self.round_rows: list[dict] = []
|
|
self.round_buckets: dict[tuple[str, str], dict] = {}
|
|
self.prev_policy_action: dict[str, torch.Tensor] = {}
|
|
self.prev_world_latent: dict[str, torch.Tensor] = {}
|
|
|
|
def resolve_psnr(self, sample_id: str, videoid: int) -> float:
|
|
"""Resolve a PSNR value by full sample id first, then by raw video id."""
|
|
candidates = [sample_id, sample_id.rsplit('-fs', 1)[0], str(videoid)]
|
|
for candidate in candidates:
|
|
if candidate in self.psnr_lookup:
|
|
return float(self.psnr_lookup[candidate])
|
|
return np.nan
|
|
|
|
def append_summary_row(self, row: dict) -> None:
|
|
"""Store per-round summaries and aggregate them later by sample and pass type."""
|
|
key = (row['sample_id'], row['scene'], row['pass_type'])
|
|
metric_columns = self.SUMMARY_COLUMNS[3:]
|
|
if key not in self.summary_buckets:
|
|
self.summary_buckets[key] = {
|
|
'sample_id': row['sample_id'],
|
|
'scene': row['scene'],
|
|
'pass_type': row['pass_type'],
|
|
**{column: [] for column in metric_columns},
|
|
}
|
|
for column in metric_columns:
|
|
self.summary_buckets[key][column].append(row.get(column, np.nan))
|
|
|
|
def append_round_row(self, row: dict) -> None:
|
|
"""Store per-round metrics and aggregate them later by sample."""
|
|
self.round_rows.append(row)
|
|
key = (row['sample_id'], row['scene'])
|
|
metric_columns = self.ROUND_COLUMNS[3:]
|
|
if key not in self.round_buckets:
|
|
self.round_buckets[key] = {
|
|
'sample_id': row['sample_id'],
|
|
'scene': row['scene'],
|
|
**{column: [] for column in metric_columns},
|
|
}
|
|
for column in metric_columns:
|
|
self.round_buckets[key][column].append(row.get(column, np.nan))
|
|
|
|
def collect_trace_series(self, debug_info: dict, reference_action: torch.Tensor,
|
|
reference_state: torch.Tensor,
|
|
reference_latent: torch.Tensor) -> tuple[list[float], list[float], list[float]]:
|
|
"""Extract cosine/L2 curves for either the target pass or the full-50 reference pass."""
|
|
action_cosines = []
|
|
state_cosines = []
|
|
latent_l2s = []
|
|
for record in debug_info['step_records']:
|
|
action_cosines.append(
|
|
batch_cosine_similarity(record['action'], reference_action)[0])
|
|
state_cosines.append(
|
|
batch_cosine_similarity(record['state'], reference_state)[0])
|
|
latent_l2s.append(
|
|
batch_l2_distance(record['pred_x0'], reference_latent)[0])
|
|
return action_cosines, state_cosines, latent_l2s
|
|
|
|
def log_pass(self, sample_id: str, videoid: int, scene: str, pass_type: str,
|
|
round_id: int, pass_total_time_s: float, target_debug: dict,
|
|
reference_debug: dict) -> dict | None:
|
|
"""Log one pass worth of stepwise and aggregated metrics."""
|
|
if not target_debug or not target_debug.get('step_records'):
|
|
return None
|
|
if not reference_debug or not reference_debug.get('step_records'):
|
|
reference_debug = target_debug
|
|
|
|
reference_final_action = reference_debug['step_records'][-1]['action']
|
|
reference_final_state = reference_debug['step_records'][-1]['state']
|
|
reference_final_latent = reference_debug['step_records'][-1]['pred_x0']
|
|
|
|
prev_img = target_debug['analysis_init']['img']
|
|
prev_action = target_debug['analysis_init']['action']
|
|
prev_state = target_debug['analysis_init']['state']
|
|
action_deltas: list[float] = []
|
|
state_deltas: list[float] = []
|
|
latent_deltas: list[float] = []
|
|
action_cosines: list[float] = []
|
|
state_cosines: list[float] = []
|
|
latent_l2s: list[float] = []
|
|
|
|
for record in target_debug['step_records']:
|
|
latent_delta = batch_relative_l2(record['img'], prev_img)[0]
|
|
action_delta = batch_relative_l2(record['action'], prev_action)[0]
|
|
state_delta = batch_relative_l2(record['state'], prev_state)[0]
|
|
action_cosine = batch_cosine_similarity(record['action'],
|
|
reference_final_action)[0]
|
|
state_cosine = batch_cosine_similarity(record['state'],
|
|
reference_final_state)[0]
|
|
latent_l2 = batch_l2_distance(record['pred_x0'],
|
|
reference_final_latent)[0]
|
|
|
|
action_deltas.append(action_delta)
|
|
state_deltas.append(state_delta)
|
|
latent_deltas.append(latent_delta)
|
|
action_cosines.append(action_cosine)
|
|
state_cosines.append(state_cosine)
|
|
latent_l2s.append(latent_l2)
|
|
|
|
self.step_rows.append({
|
|
'sample_id': sample_id,
|
|
'scene': scene,
|
|
'pass_type': pass_type,
|
|
'round_id': round_id,
|
|
'step': record['step_index'],
|
|
'step_time_s': float(record['step_time_s']),
|
|
'latent_delta': latent_delta,
|
|
'action_delta': action_delta,
|
|
'state_delta': state_delta,
|
|
'action_cosine_vs_full50': action_cosine,
|
|
'state_cosine_vs_full50': state_cosine,
|
|
'latent_l2_vs_full50': latent_l2,
|
|
})
|
|
|
|
prev_img = record['img']
|
|
prev_action = record['action']
|
|
prev_state = record['state']
|
|
|
|
oracle_action_cosines, oracle_state_cosines, oracle_latent_l2s = self.collect_trace_series(
|
|
reference_debug, reference_final_action, reference_final_state,
|
|
reference_final_latent)
|
|
|
|
latent_init_dist_to_prev_round = np.nan
|
|
action_drift_vs_prev_round = np.nan
|
|
if pass_type == 'policy':
|
|
previous_action = self.prev_policy_action.get(sample_id)
|
|
if previous_action is not None:
|
|
action_drift_vs_prev_round = 1.0 - batch_cosine_similarity(
|
|
reference_final_action, previous_action)[0]
|
|
self.prev_policy_action[sample_id] = reference_final_action.clone()
|
|
elif pass_type == 'world_model':
|
|
previous_latent = self.prev_world_latent.get(sample_id)
|
|
if previous_latent is not None:
|
|
latent_init_dist_to_prev_round = batch_l2_distance(
|
|
reference_final_latent, previous_latent)[0]
|
|
self.prev_world_latent[sample_id] = reference_final_latent.clone()
|
|
|
|
summary_row = {
|
|
'sample_id': sample_id,
|
|
'scene': scene,
|
|
'pass_type': pass_type,
|
|
'pass_total_time_s': float(pass_total_time_s),
|
|
'action_first_stable_step': np.nan,
|
|
'state_first_stable_step': np.nan,
|
|
'latent_first_stable_step': np.nan,
|
|
'action_vs_full50_90pct_step': first_at_least(action_cosines, 0.90),
|
|
'action_vs_full50_95pct_step': first_at_least(action_cosines, 0.95),
|
|
'oracle_budget_action': first_at_least(oracle_action_cosines, 0.95),
|
|
'oracle_budget_state': first_at_least(oracle_state_cosines, 0.95),
|
|
'oracle_budget_latent': np.nan,
|
|
'latent_init_dist_to_prev_round': latent_init_dist_to_prev_round,
|
|
'action_drift_vs_prev_round': action_drift_vs_prev_round,
|
|
'round_total_time_s': np.nan,
|
|
'policy_pass_total_time_s': np.nan,
|
|
'world_model_pass_total_time_s': np.nan,
|
|
'psnr_full50': self.resolve_psnr(sample_id, videoid),
|
|
}
|
|
self.append_summary_row(summary_row)
|
|
return summary_row
|
|
|
|
def log_round(self, sample_id: str, videoid: int, scene: str, round_id: int,
|
|
policy_pass_total_time_s: float,
|
|
world_model_pass_total_time_s: float, round_total_time_s: float,
|
|
latent_init_dist_to_prev_round: float,
|
|
action_drift_vs_prev_round: float) -> None:
|
|
"""Log one interaction round consisting of one policy pass and one world-model pass."""
|
|
self.append_round_row({
|
|
'sample_id': sample_id,
|
|
'scene': scene,
|
|
'round_id': round_id,
|
|
'policy_pass_total_time_s': float(policy_pass_total_time_s),
|
|
'world_model_pass_total_time_s': float(world_model_pass_total_time_s),
|
|
'round_total_time_s': float(round_total_time_s),
|
|
'latent_init_dist_to_prev_round': latent_init_dist_to_prev_round,
|
|
'action_drift_vs_prev_round': action_drift_vs_prev_round,
|
|
'psnr_full50': self.resolve_psnr(sample_id, videoid),
|
|
})
|
|
|
|
def flush(self) -> None:
|
|
"""Write analysis CSVs to disk."""
|
|
os.makedirs(self.output_dir, exist_ok=True)
|
|
stepwise_path = os.path.join(self.output_dir, 'stepwise_log.csv')
|
|
summary_path = os.path.join(self.output_dir, 'sample_summary.csv')
|
|
round_path = os.path.join(self.output_dir, 'round_summary.csv')
|
|
|
|
stepwise_df = pd.DataFrame(self.step_rows, columns=self.STEP_COLUMNS)
|
|
stepwise_df.to_csv(stepwise_path, index=False)
|
|
|
|
round_df = pd.DataFrame(self.round_rows, columns=self.ROUND_COLUMNS)
|
|
round_df.to_csv(round_path, index=False)
|
|
|
|
summary_rows = []
|
|
metric_columns = self.SUMMARY_COLUMNS[3:]
|
|
for bucket in self.summary_buckets.values():
|
|
round_bucket = self.round_buckets.get((bucket['sample_id'],
|
|
bucket['scene']))
|
|
row = {
|
|
'sample_id': bucket['sample_id'],
|
|
'scene': bucket['scene'],
|
|
'pass_type': bucket['pass_type'],
|
|
}
|
|
for column in metric_columns:
|
|
if round_bucket is not None and column in round_bucket:
|
|
row[column] = safe_mean(round_bucket[column])
|
|
else:
|
|
row[column] = safe_mean(bucket[column])
|
|
summary_rows.append(row)
|
|
summary_df = pd.DataFrame(summary_rows, columns=self.SUMMARY_COLUMNS)
|
|
summary_df.to_csv(summary_path, index=False)
|
|
|
|
|
|
def write_video(video_path: str, stacked_frames: list, fps: int) -> None:
    """Write a sequence of frames to a video file with imageio.

    Args:
        video_path (str): Output path for the video.
        stacked_frames (list): List of image frames.
        fps (int): Frames per second for the video.
    """
    with warnings.catch_warnings():
        # Silence only the pkg_resources deprecation notice while saving.
        warnings.filterwarnings(action="ignore",
                                message="pkg_resources is deprecated as an API",
                                category=DeprecationWarning)
        imageio.mimsave(video_path, stacked_frames, fps=fps)
|
|
|
|
|
|
def get_filelist(data_dir: str, postfixes: list[str]) -> list[str]:
    """List the files directly inside ``data_dir`` that end in any given postfix.

    Args:
        data_dir (str): Directory path to search in (not recursive).
        postfixes (list[str]): File extensions to match, without the leading dot.

    Returns:
        list[str]: Matching file paths in sorted order.
    """
    matched: list[str] = []
    for postfix in postfixes:
        matched += glob.glob(os.path.join(data_dir, f"*.{postfix}"))
    return sorted(matched)
|
|
|
|
|
|
def load_model_checkpoint(model: nn.Module, ckpt: str) -> nn.Module:
    """Load model weights from a checkpoint file.

    Two checkpoint layouts are handled:
      * a dict with a ``'state_dict'`` entry, loaded strictly; if that fails,
        keys containing ``framestride_embed`` are renamed to ``fps_embedding``
        and the strict load is retried;
      * otherwise a dict with a ``'module'`` entry whose keys carry a
        16-character prefix that is stripped before loading.

    Args:
        model (nn.Module): Model instance.
        ckpt (str): Path to the checkpoint file.

    Returns:
        nn.Module: The same model, with weights loaded.

    Raises:
        RuntimeError: If the weights still do not match after the key rename.
        KeyError: If the checkpoint has neither ``'state_dict'`` nor ``'module'``.
    """
    state_dict = torch.load(ckpt, map_location="cpu")
    if "state_dict" in state_dict:
        state_dict = state_dict["state_dict"]
        try:
            model.load_state_dict(state_dict, strict=True)
        except RuntimeError:
            # load_state_dict reports key/shape mismatches as RuntimeError.
            # Catching only that keeps Ctrl-C and unrelated failures visible
            # instead of silently triggering the legacy-key retry.
            renamed = OrderedDict(
                (key.replace("framestride_embed", "fps_embedding"), value)
                for key, value in state_dict.items())
            model.load_state_dict(renamed, strict=True)
    else:
        # Presumably a DeepSpeed-style checkpoint whose keys start with
        # "_forward_module." (16 characters) — TODO confirm against the trainer.
        module_weights = state_dict['module']
        stripped = OrderedDict(
            (key[16:], value) for key, value in module_weights.items())
        model.load_state_dict(stripped)
    print('>>> model checkpoint loaded.')
    return model
|
|
|
|
|
|
def is_inferenced(save_dir: str, filename: str) -> bool:
    """Report whether the first sample video for ``filename`` already exists.

    Args:
        save_dir (str): Directory where results are saved.
        filename (str): Name of the file to check; its last four characters
            (e.g. a three-letter extension with its dot) are dropped.

    Returns:
        bool: True if ``<save_dir>/samples_separate/<stem>_sample0.mp4`` exists.
    """
    stem = filename[:-4]
    expected_video = os.path.join(save_dir, "samples_separate",
                                  f"{stem}_sample0.mp4")
    return os.path.exists(expected_video)
|
|
|
|
|
|
def save_results(video: Tensor, filename: str, fps: int = 8) -> None:
    """Encode a batch of videos side by side into one h264 file.

    Args:
        video (Tensor): Tensor of shape (B, C, T, H, W); values are clamped
            to [-1, 1] before being mapped to 0..255.
        filename (str): Output file path.
        fps (int, optional): Frames per second. Defaults to 8.
    """
    clipped = torch.clamp(video.detach().cpu().float(), -1., 1.)
    batch_size = clipped.shape[0]
    # (B, C, T, H, W) -> (T, B, C, H, W): iterate over time steps.
    per_timestep = clipped.permute(2, 0, 1, 3, 4)

    frame_grids = []
    for framesheet in per_timestep:
        # One row holding every batch item of this time step.
        frame_grids.append(
            torchvision.utils.make_grid(framesheet,
                                        nrow=int(batch_size),
                                        padding=0))
    grid = torch.stack(frame_grids, dim=0)
    grid = (grid + 1.0) / 2.0
    # uint8 frames in (T, H, W, C) layout for the video writer.
    frames = (grid * 255).to(torch.uint8).permute(0, 2, 3, 1)
    torchvision.io.write_video(filename,
                               frames,
                               fps=fps,
                               video_codec='h264',
                               options={'crf': '10'})
|
|
|
|
|
|
def get_init_frame_path(data_dir: str, sample: dict) -> str:
    """Build the path of a sample's initial-frame PNG.

    Args:
        data_dir (str): Base directory containing the ``images`` folder.
        sample (dict): Dictionary containing 'data_dir' and 'videoid'.

    Returns:
        str: ``<data_dir>/images/<sample data_dir>/<videoid>.png``.
    """
    scene_dir = sample['data_dir']
    image_name = str(sample['videoid']) + '.png'
    return os.path.join(data_dir, 'images', os.path.join(scene_dir, image_name))
|
|
|
|
|
|
def get_transition_path(data_dir: str, sample: dict) -> str:
    """Build the path of a sample's HDF5 transition file.

    Args:
        data_dir (str): Base directory containing the ``transitions`` folder.
        sample (dict): Dictionary containing 'data_dir' and 'videoid'.

    Returns:
        str: ``<data_dir>/transitions/<sample data_dir>/<videoid>.h5``.
    """
    scene_dir = sample['data_dir']
    transition_name = str(sample['videoid']) + '.h5'
    return os.path.join(data_dir, 'transitions',
                        os.path.join(scene_dir, transition_name))
|
|
|
|
|
|
def prepare_init_input(start_idx: int,
                       init_frame_path: str,
                       transition_dict: dict[str, torch.Tensor],
                       frame_stride: int,
                       wma_data,
                       video_length: int = 16,
                       n_obs_steps: int = 2) -> tuple[dict[str, Tensor], int, int]:
    """
    Build the model input for one clip: the initial frame plus the normalized
    state history and action sequence read from the transition data.

    Args:
        start_idx (int): Starting frame index for the current clip.
        init_frame_path (str): Path of the image used as the initial frame.
        transition_dict (Dict[str, Tensor]): Dictionary containing tensors for 'action',
            'observation.state', 'action_type', 'state_type'.
        frame_stride (int): Temporal stride between sampled action indices.
        wma_data: Object that holds configuration and utility functions like normalization,
            transformation, and resolution info.
        video_length (int, optional): Number of action steps to sample. Default is 16.
        n_obs_steps (int, optional): Number of historical steps for observations. Default is 2.

    Returns:
        tuple[dict[str, Tensor], int, int]: The input dictionary
            ('observation.image' plus whatever ``wma_data.get_uni_vec``
            returns), the state dimension and the action dimension, both
            measured before ``wma_data`` processing.
    """

    indices = [start_idx + frame_stride * i for i in range(video_length)]
    # Use a context manager so the image file handle is always released.
    with Image.open(init_frame_path) as image_file:
        init_frame = image_file.convert('RGB')
    # (H, W, C) -> (1, H, W, C) -> (C, 1, H, W)
    init_frame = torch.tensor(np.array(init_frame)).unsqueeze(0).permute(
        3, 0, 1, 2).float()

    if start_idx < n_obs_steps - 1:
        # Not enough history yet: left-pad by repeating the earliest state.
        state_indices = list(range(0, start_idx + 1))
        states = transition_dict['observation.state'][state_indices, :]
        num_padding = n_obs_steps - 1 - start_idx
        first_slice = states[0:1, :]  # (1, d)
        padding = first_slice.repeat(num_padding, 1)
        states = torch.cat((padding, states), dim=0)
    else:
        state_indices = list(range(start_idx - n_obs_steps + 1, start_idx + 1))
        states = transition_dict['observation.state'][state_indices, :]

    actions = transition_dict['action'][indices, :]

    # Remember the raw dimensions before normalization / unified-vector mapping.
    ori_state_dim = states.shape[-1]
    ori_action_dim = actions.shape[-1]

    frames_action_state_dict = {
        'action': actions,
        'observation.state': states,
    }
    frames_action_state_dict = wma_data.normalizer(frames_action_state_dict)
    frames_action_state_dict = wma_data.get_uni_vec(
        frames_action_state_dict,
        transition_dict['action_type'],
        transition_dict['state_type'],
    )

    if wma_data.spatial_transform is not None:
        init_frame = wma_data.spatial_transform(init_frame)
    # Map pixel values from [0, 255] to [-1, 1].
    init_frame = (init_frame / 255 - 0.5) * 2

    data = {
        'observation.image': init_frame,
    }
    data.update(frames_action_state_dict)
    return data, ori_state_dim, ori_action_dim
|
|
|
|
|
|
def get_latent_z(model, videos: Tensor) -> Tensor:
    """
    Encode every frame of a video batch with the model's first-stage encoder.

    Args:
        model: the world model; must provide ``encode_first_stage``.
        videos (Tensor): Input videos of shape [B, C, T, H, W].

    Returns:
        Tensor: Latents laid out as [B, C, T, H, W], with C/H/W being whatever
            the encoder produces.
    """
    # The 5-way unpacking also rejects inputs that are not 5-dimensional.
    batch_size, _, num_frames, _, _ = videos.shape
    # Fold time into the batch axis so the image encoder sees single frames.
    frames = rearrange(videos, 'b c t h w -> (b t) c h w')
    latents = model.encode_first_stage(frames)
    return rearrange(latents,
                     '(b t) c h w -> b c t h w',
                     b=batch_size,
                     t=num_frames)
|
|
|
|
|
|
def preprocess_observation(
        model, observations: dict[str, np.ndarray]) -> dict[str, Tensor]:
    """Convert environment observation to LeRobot format observation.

    Args:
        model: Provides ``normalize_inputs`` and ``device`` for the state.
        observations: Dictionary of observation batches with a ``"pixels"``
            entry (one array, or a dict of named arrays) holding uint8
            channel-last images, and an ``"agent_pos"`` array.

    Returns:
        Dictionary of observation batches with keys renamed to LeRobot format and values as tensors.

    Raises:
        AssertionError: If an image is not channel-last or not uint8.
    """
    # Map to expected inputs for the policy
    return_observations = {}

    if isinstance(observations["pixels"], dict):
        imgs = {
            f"observation.images.{key}": img
            for key, img in observations["pixels"].items()
        }
    else:
        imgs = {"observation.images.top": observations["pixels"]}

    for imgkey, img in imgs.items():
        img = torch.from_numpy(img)

        # Sanity check that images are channel last
        _, h, w, c = img.shape
        assert c < h and c < w, f"expect channel last images, but instead {img.shape}"

        # Sanity check that images are uint8
        assert img.dtype == torch.uint8, f"expect torch.uint8, but instead {img.dtype=}"

        # Convert to channel first of type float32. Values are NOT rescaled
        # here, so they keep their original 0..255 range.
        img = einops.rearrange(img, "b h w c -> b c h w").contiguous()
        img = img.type(torch.float32)

        return_observations[imgkey] = img

    return_observations["observation.state"] = torch.from_numpy(
        observations["agent_pos"]).float()
    # Normalize the state on the model's device.
    return_observations['observation.state'] = model.normalize_inputs({
        'observation.state':
        return_observations['observation.state'].to(model.device)
    })['observation.state']

    return return_observations
|
|
|
|
|
|
def image_guided_synthesis_sim_mode(
        model: torch.nn.Module,
        prompts: list[str],
        observation: dict,
        noise_shape: tuple[int, int, int, int, int],
        action_cond_step: int = 16,
        n_samples: int = 1,
        ddim_steps: int = 50,
        ddim_eta: float = 1.0,
        unconditional_guidance_scale: float = 1.0,
        fs: int | None = None,
        text_input: bool = True,
        timestep_spacing: str = 'uniform',
        guidance_rescale: float = 0.0,
        sim_mode: bool = True,
        init_noise_bundle: dict[str, torch.Tensor] | None = None,
        decode_video: bool = True,
        return_debug_info: bool = False,
        **kwargs) -> tuple[torch.Tensor | None, torch.Tensor, torch.Tensor, dict | None]:
    """
    Performs image-guided video generation in a simulation-style mode with optional multimodal guidance (image, state, action, text).

    The function builds image, text, state and action conditioning from
    `observation` and `prompts`, runs one DDIM sampling pass, and optionally
    decodes the sampled latent and collects per-step traces.

    Args:
        model (torch.nn.Module): The diffusion-based generative model with multimodal conditioning.
        prompts (list[str]): A list of textual prompts to guide the synthesis process.
        observation (dict): A dictionary containing observed inputs including:
            - 'observation.images.top': image history tensor. It is permuted with
              (0, 2, 1, 3, 4) and then read as 'b o c h w', so the incoming layout
              is presumably [B, C, O, H, W] — TODO confirm against the caller.
            - 'observation.state': Tensor of shape [B, O, D] (state vector)
            - 'action': Tensor of shape [B, T, D] (action sequence)
        noise_shape (tuple[int, int, int, int, int]): Shape of the latent variable to generate,
            typically (B, C, T, H, W).
        action_cond_step (int): Number of time steps where action conditioning is applied. Default is 16.
            Not referenced in this function body.
        n_samples (int): Number of samples to generate (unused here, always generates 1). Default is 1.
        ddim_steps (int): Number of DDIM sampling steps. Default is 50.
        ddim_eta (float): DDIM eta parameter controlling the stochasticity. Default is 1.0.
        unconditional_guidance_scale (float): Scale for classifier-free guidance. If 1.0, guidance is off.
        fs (int | None): Value repeated across the batch and passed to the sampler as `fs`.
            NOTE(review): the default None is fed to `torch.tensor(..., dtype=torch.long)`,
            which looks unsupported — confirm callers always pass an int.
        text_input (bool): Whether to use text prompt as conditioning. If False, uses empty strings. Default is True.
        timestep_spacing (str): Timestep sampling method in DDIM sampler. Typically "uniform" or "linspace".
        guidance_rescale (float): Guidance rescaling factor to mitigate overexposure from classifier-free guidance.
        sim_mode (bool): Whether to perform world-model interaction or decision-making using the world-model.
            When False the action conditioning embedding is replaced by zeros.
        init_noise_bundle (dict[str, torch.Tensor] | None): Optional aligned noise inputs for latent/action/state
            (keys 'img', 'action', 'state').
        decode_video (bool): Whether to decode the final latent into pixel space.
        return_debug_info (bool): Whether to return per-step traces for analysis logging.
        **kwargs: Additional arguments passed to the DDIM sampler.

    Returns:
        batch_variants (torch.Tensor | None): Predicted pixel-space video frames [B, C, T, H, W];
            None when `decode_video` is False.
        actions (torch.Tensor): Predicted action sequences [B, T, D] from diffusion decoding.
        states (torch.Tensor): Predicted state sequences [B, T, D] from diffusion decoding.
        debug_info (dict | None): Optional per-step trace used for convergence analysis.
    """
    # b and t are unpacked but not used below; batch_size is read separately.
    b, _, t, _, _ = noise_shape
    ddim_sampler = DDIMSampler(model)
    batch_size = noise_shape[0]
    batch_variants = None
    debug_info = None

    # Repeat the scalar fs once per batch item.
    fs = torch.tensor([fs] * batch_size, dtype=torch.long, device=model.device)

    # Image conditioning: embed only the last row of the flattened (b o) axis.
    img = observation['observation.images.top'].permute(0, 2, 1, 3, 4)
    cond_img = rearrange(img, 'b o c h w -> (b o) c h w')[-1:]
    cond_img_emb = model.embedder(cond_img)
    cond_img_emb = model.image_proj_model(cond_img_emb)

    # NOTE(review): `cond` is only created in this branch; with any other
    # conditioning_key the `cond["c_crossattn"]` assignment below would hit an
    # unbound name — confirm only 'hybrid' is used.
    if model.model.conditioning_key == 'hybrid':
        # Permute back to the incoming layout before encoding to latents.
        z = get_latent_z(model, img.permute(0, 2, 1, 3, 4))
        # Tile the last latent frame across all noise_shape[2] time steps.
        img_cat_cond = z[:, :, -1:, :, :]
        img_cat_cond = repeat(img_cat_cond,
                              'b c t h w -> b c (repeat t) h w',
                              repeat=noise_shape[2])
        cond = {"c_concat": [img_cat_cond]}

    if not text_input:
        prompts = [""] * batch_size
    cond_ins_emb = model.get_learned_conditioning(prompts)

    # State and action conditioning: project, then add positional embeddings.
    cond_state_emb = model.state_projector(observation['observation.state'])
    cond_state_emb = cond_state_emb + model.agent_state_pos_emb

    cond_action_emb = model.action_projector(observation['action'])
    cond_action_emb = cond_action_emb + model.agent_action_pos_emb

    if not sim_mode:
        # Outside sim mode the action conditioning is blanked out.
        cond_action_emb = torch.zeros_like(cond_action_emb)

    # Cross-attention context: state, action, text and image tokens joined on dim 1.
    cond["c_crossattn"] = [
        torch.cat(
            [cond_state_emb, cond_action_emb, cond_ins_emb, cond_img_emb],
            dim=1)
    ]
    # Inputs for the action branch: the last n_obs_steps_acting observations
    # plus two flags whose meaning is defined by the consumer (not visible here).
    cond["c_crossattn_action"] = [
        observation['observation.images.top'][:, :,
                                              -model.n_obs_steps_acting:],
        observation['observation.state'][:, -model.n_obs_steps_acting:],
        sim_mode,
        False,
    ]

    # No unconditional branch, mask or x0 conditioning is supplied.
    uc = None
    kwargs.update({"unconditional_conditioning_img_nonetext": None})
    cond_mask = None
    cond_z0 = None
    # ddim_sampler is always constructed above, so this guard always passes;
    # samples/actions/states would be unbound at the return otherwise.
    if ddim_sampler is not None:
        samples, actions, states, intermedia = ddim_sampler.sample(
            S=ddim_steps,
            conditioning=cond,
            batch_size=batch_size,
            shape=noise_shape[1:],
            verbose=False,
            unconditional_guidance_scale=unconditional_guidance_scale,
            unconditional_conditioning=uc,
            eta=ddim_eta,
            cfg_img=None,
            mask=cond_mask,
            x0=cond_z0,
            fs=fs,
            timestep_spacing=timestep_spacing,
            guidance_rescale=guidance_rescale,
            x_T=None if init_noise_bundle is None else init_noise_bundle['img'],
            action_T=None if init_noise_bundle is None else
            init_noise_bundle['action'],
            state_T=None if init_noise_bundle is None else
            init_noise_bundle['state'],
            record_step_outputs=return_debug_info,
            **kwargs)

        # Decoding to pixel space is optional.
        batch_variants = None
        if decode_video:
            batch_variants = model.decode_first_stage(samples)

        if return_debug_info:
            # Final outputs are copied to CPU; the step records are passed
            # through exactly as the sampler produced them.
            debug_info = {
                'analysis_init': intermedia.get('analysis_init'),
                'step_records': intermedia.get('step_records', []),
                'final_latent': samples.detach().cpu(),
                'final_action': actions.detach().cpu(),
                'final_state': states.detach().cpu(),
            }

    return batch_variants, actions, states, debug_info
|
|
|
|
|
|
def _stack_observation_queues(cond_obs_queues: dict,
                              device: torch.device) -> dict:
    """Stack the queued per-step tensors into one observation dict on `device`."""
    stacked = {
        # Queue entries are stacked on a new dim 1, then dims 1 and 2 are
        # swapped for the image stream.
        'observation.images.top':
        torch.stack(list(cond_obs_queues['observation.images.top']),
                    dim=1).permute(0, 2, 1, 3, 4),
        'observation.state':
        torch.stack(list(cond_obs_queues['observation.state']), dim=1),
        'action':
        torch.stack(list(cond_obs_queues['action']), dim=1),
    }
    return {
        key: value.to(device, non_blocking=True)
        for key, value in stacked.items()
    }


def _run_synthesis_pass(model, prompt, observation, noise_shape, args,
                        model_input_fs, **mode_kwargs):
    """Run one timed sampling pass, preceded by a reference pass when analysing.

    When ``args.analysis_log_metrics`` is set, a shared initial-noise bundle is
    drawn first and, if ``args.analysis_reference_steps`` differs from
    ``args.ddim_steps``, an extra (undecoded) reference pass is run with the
    same noise. Only the target pass is timed.

    Returns:
        tuple: ``(videos, actions, states, debug, reference_debug, elapsed_s)``.
        ``reference_debug`` falls back to ``debug`` when no separate reference
        pass was run.
    """
    log_metrics = args.analysis_log_metrics
    shared_kwargs = dict(
        action_cond_step=args.exe_steps,
        ddim_eta=args.ddim_eta,
        unconditional_guidance_scale=args.unconditional_guidance_scale,
        fs=model_input_fs,
        timestep_spacing=args.timestep_spacing,
        guidance_rescale=args.guidance_rescale,
        **mode_kwargs)

    noise_bundle = (make_sampling_noise_bundle(model, noise_shape)
                    if log_metrics else None)

    reference_debug = None
    if log_metrics and args.analysis_reference_steps != args.ddim_steps:
        _, _, _, reference_debug = image_guided_synthesis_sim_mode(
            model,
            prompt,
            observation,
            noise_shape,
            ddim_steps=args.analysis_reference_steps,
            init_noise_bundle=noise_bundle,
            decode_video=False,
            return_debug_info=True,
            **shared_kwargs)

    pass_start = time.time()
    videos, actions, states, debug = image_guided_synthesis_sim_mode(
        model,
        prompt,
        observation,
        noise_shape,
        ddim_steps=args.ddim_steps,
        init_noise_bundle=noise_bundle,
        return_debug_info=log_metrics,
        **shared_kwargs)
    elapsed_s = time.time() - pass_start

    if reference_debug is None:
        reference_debug = debug
    return videos, actions, states, debug, reference_debug, elapsed_s


def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
    """
    Run the multi-round policy / world-model rollout for every CSV sample.

    The function creates ``<savedir>/inference`` and ``<savedir>/tensorboard``,
    reads ``<prompt_dir>/<dataset>.csv``, builds the model from ``args.config``
    and ``args.ckpt_path`` and moves it to ``cuda(gpu_no)``. For every CSV row
    and every value in ``args.frame_stride`` it runs ``args.n_iter`` rounds.
    Each round is a policy pass (``sim_mode=False``) that predicts actions,
    followed by a world-model pass (empty prompt, ``text_input=False``) whose
    first ``args.exe_steps`` predicted frames/states are fed back into the
    observation queues. Per-round and concatenated videos are written to
    TensorBoard and to ``.mp4`` files.

    Args:
        args (argparse.Namespace): Parsed command-line arguments.
        gpu_num (int): Number of GPUs (not referenced in this function).
        gpu_no (int): Index of the current GPU.

    Returns:
        None

    Raises:
        AssertionError: If the checkpoint path does not exist, height/width
            are not multiples of 16, ``args.bs != 1``, or analysis logging is
            enabled with non-positive step counts.
    """
    # Create inference and tensorboard dirs
    inference_dir = args.savedir + '/inference'
    os.makedirs(inference_dir, exist_ok=True)
    log_dir = args.savedir + "/tensorboard"
    os.makedirs(log_dir, exist_ok=True)
    writer = SummaryWriter(log_dir=log_dir)

    # Everything below runs under try/finally so the event writer is closed
    # even when model loading or a rollout step raises.
    try:
        analysis_logger = None
        if args.analysis_log_metrics:
            analysis_logger = InteractionAnalysisLogger(
                output_dir=inference_dir,
                psnr_lookup=load_psnr_lookup(args.analysis_psnr_path),
            )

        # Load prompt
        csv_path = os.path.join(args.prompt_dir, f"{args.dataset}.csv")
        df = pd.read_csv(csv_path)

        # Load config
        config = OmegaConf.load(args.config)
        config['model']['params']['wma_config']['params'][
            'use_checkpoint'] = False
        model = instantiate_from_config(config.model)
        model.perframe_ae = args.perframe_ae
        assert os.path.exists(args.ckpt_path), "Error: checkpoint Not Found!"
        model = load_model_checkpoint(model, args.ckpt_path)
        model.eval()
        print('>>> Load pre-trained model ...')

        # Instantiate the dataset; its test split is handed to
        # prepare_init_input below.
        logging.info("***** Configing Data *****")
        data = instantiate_from_config(config.data)
        data.setup()
        print(">>> Dataset is successfully loaded ...")

        model = model.cuda(gpu_no)
        device = get_device_from_parameters(model)

        # Run over data
        assert (args.height % 16 == 0) and (
            args.width % 16
            == 0), "Error: image size [h,w] should be multiples of 16!"
        assert args.bs == 1, "Current implementation only support [batch size = 1]!"
        if args.analysis_log_metrics:
            assert args.ddim_steps > 0, "analysis_log_metrics requires positive --ddim_steps."
            assert args.analysis_reference_steps > 0, (
                "analysis_log_metrics requires positive --analysis_reference_steps.")

        # Get latent noise shape
        h, w = args.height // 8, args.width // 8
        channels = model.model.diffusion_model.out_channels
        n_frames = args.video_length
        print(f'>>> Generate {n_frames} frames under each generation ...')
        noise_shape = [args.bs, channels, n_frames, h, w]

        # Start inference
        for row_idx in range(len(df)):
            sample = df.iloc[row_idx]

            # Get initial frame path
            init_frame_path = get_init_frame_path(args.prompt_dir, sample)
            ori_fps = float(sample['fps'])

            video_save_dir = args.savedir + f"/inference/sample_{sample['videoid']}"
            os.makedirs(video_save_dir, exist_ok=True)
            os.makedirs(video_save_dir + '/dm', exist_ok=True)
            os.makedirs(video_save_dir + '/wm', exist_ok=True)

            # Load transitions to get the initial state later
            transition_path = get_transition_path(args.prompt_dir, sample)
            with h5py.File(transition_path, 'r') as h5f:
                transition_dict = {}
                for key in h5f.keys():
                    transition_dict[key] = torch.tensor(h5f[key][()])
                for key in h5f.attrs.keys():
                    transition_dict[key] = h5f.attrs[key]

            # If several strides are given, run the full rollout once per stride
            for fs in args.frame_stride:
                sample_id = build_sample_id(args.dataset, sample, fs)
                scene = get_scene_name(sample, args.dataset)

                # For saving imagined videos from the policy pass
                os.makedirs(f'{video_save_dir}/dm/{fs}', exist_ok=True)
                # For saving environmental changes from the world-model pass
                os.makedirs(f'{video_save_dir}/wm/{fs}', exist_ok=True)
                # For collecting interaction videos
                wm_video = []
                # Initialize observation queues
                cond_obs_queues = {
                    "observation.images.top":
                    deque(maxlen=model.n_obs_steps_imagen),
                    "observation.state":
                    deque(maxlen=model.n_obs_steps_imagen),
                    "action":
                    deque(maxlen=args.video_length),
                }
                # Obtain initial frame and state
                start_idx = 0
                model_input_fs = ori_fps // fs
                batch, ori_state_dim, ori_action_dim = prepare_init_input(
                    start_idx,
                    init_frame_path,
                    transition_dict,
                    fs,
                    data.test_datasets[args.dataset],
                    n_obs_steps=model.n_obs_steps_imagen)
                observation = {
                    'observation.images.top':
                    batch['observation.image'].permute(1, 0, 2,
                                                       3)[-1].unsqueeze(0),
                    'observation.state':
                    batch['observation.state'][-1].unsqueeze(0),
                    'action':
                    torch.zeros_like(batch['action'][-1]).unsqueeze(0)
                }
                observation = {
                    key: observation[key].to(device, non_blocking=True)
                    for key in observation
                }
                # Update observation queues
                cond_obs_queues = populate_queues(cond_obs_queues, observation)

                # Multi-round interaction with the world-model
                for itr in tqdm(range(args.n_iter)):
                    round_start_time = time.time()

                    # Get observation
                    observation = _stack_observation_queues(
                        cond_obs_queues, device)

                    # Use world-model in policy to generate action
                    print(f'>>> Step {itr}: generating actions ...')
                    (pred_videos_0, pred_actions, _, policy_debug,
                     policy_reference_debug,
                     policy_pass_total_time_s) = _run_synthesis_pass(
                         model,
                         sample['instruction'],
                         observation,
                         noise_shape,
                         args,
                         model_input_fs,
                         sim_mode=False)
                    policy_summary_row = None
                    if analysis_logger is not None:
                        policy_summary_row = analysis_logger.log_pass(
                            sample_id=sample_id,
                            videoid=int(sample['videoid']),
                            scene=scene,
                            pass_type='policy',
                            round_id=itr,
                            pass_total_time_s=policy_pass_total_time_s,
                            target_debug=policy_debug,
                            reference_debug=policy_reference_debug,
                        )

                    # Update future actions in the observation queues
                    for step in range(len(pred_actions[0])):
                        observation = {
                            'action': pred_actions[0][step:step + 1]
                        }
                        # The slice is a view, so this also zeroes the
                        # trailing action dims of pred_actions in place.
                        observation['action'][:, ori_action_dim:] = 0.0
                        cond_obs_queues = populate_queues(
                            cond_obs_queues, observation)

                    # Collect data for interacting the world-model using the predicted actions
                    observation = _stack_observation_queues(
                        cond_obs_queues, device)

                    # Interaction with the world-model
                    print(f'>>> Step {itr}: interacting with world model ...')
                    (pred_videos_1, _, pred_states, world_debug,
                     world_reference_debug,
                     world_pass_total_time_s) = _run_synthesis_pass(
                         model,
                         "",
                         observation,
                         noise_shape,
                         args,
                         model_input_fs,
                         text_input=False)
                    world_summary_row = None
                    if analysis_logger is not None:
                        world_summary_row = analysis_logger.log_pass(
                            sample_id=sample_id,
                            videoid=int(sample['videoid']),
                            scene=scene,
                            pass_type='world_model',
                            round_id=itr,
                            pass_total_time_s=world_pass_total_time_s,
                            target_debug=world_debug,
                            reference_debug=world_reference_debug,
                        )
                        analysis_logger.log_round(
                            sample_id=sample_id,
                            videoid=int(sample['videoid']),
                            scene=scene,
                            round_id=itr,
                            policy_pass_total_time_s=policy_pass_total_time_s,
                            world_model_pass_total_time_s=
                            world_pass_total_time_s,
                            round_total_time_s=time.time() - round_start_time,
                            latent_init_dist_to_prev_round=np.nan
                            if world_summary_row is None else
                            world_summary_row['latent_init_dist_to_prev_round'],
                            action_drift_vs_prev_round=np.nan
                            if policy_summary_row is None else
                            policy_summary_row['action_drift_vs_prev_round'],
                        )

                    # Feed the first exe_steps predicted frames/states back
                    # into the queues as the next observations.
                    for step in range(args.exe_steps):
                        observation = {
                            'observation.images.top':
                            pred_videos_1[0][:, step:step + 1].permute(
                                1, 0, 2, 3),
                            'observation.state':
                            torch.zeros_like(pred_states[0][step:step + 1])
                            if args.zero_pred_state else
                            pred_states[0][step:step + 1],
                            'action':
                            torch.zeros_like(pred_actions[0][-1:])
                        }
                        observation['observation.state'][:,
                                                         ori_state_dim:] = 0.0
                        cond_obs_queues = populate_queues(
                            cond_obs_queues, observation)

                    # Log the imagined videos for decision-making
                    sample_tag = f"{args.dataset}-vid{sample['videoid']}-dm-fs-{fs}/itr-{itr}"
                    log_to_tensorboard(writer,
                                       pred_videos_0,
                                       sample_tag,
                                       fps=args.save_fps)
                    # Log the environment changes from world-model interaction
                    sample_tag = f"{args.dataset}-vid{sample['videoid']}-wd-fs-{fs}/itr-{itr}"
                    log_to_tensorboard(writer,
                                       pred_videos_1,
                                       sample_tag,
                                       fps=args.save_fps)

                    # Save the imagined videos for decision-making
                    sample_video_file = f'{video_save_dir}/dm/{fs}/itr-{itr}.mp4'
                    save_results(pred_videos_0.cpu(),
                                 sample_video_file,
                                 fps=args.save_fps)
                    # Save the environment changes from world-model interaction
                    sample_video_file = f'{video_save_dir}/wm/{fs}/itr-{itr}.mp4'
                    save_results(pred_videos_1.cpu(),
                                 sample_video_file,
                                 fps=args.save_fps)

                    print('>' * 24)
                    # Collect the result of world-model interactions
                    wm_video.append(pred_videos_1[:, :, :args.exe_steps].cpu())

                # Concatenate the executed segments of every round on dim 2
                full_video = torch.cat(wm_video, dim=2)
                sample_tag = f"{args.dataset}-vid{sample['videoid']}-wd-fs-{fs}/full"
                log_to_tensorboard(writer,
                                   full_video,
                                   sample_tag,
                                   fps=args.save_fps)
                sample_full_video_file = f"{video_save_dir}/../{sample['videoid']}_full_fs{fs}.mp4"
                save_results(full_video,
                             sample_full_video_file,
                             fps=args.save_fps)

        if analysis_logger is not None:
            analysis_logger.flush()
    finally:
        writer.close()
|
|
|
|
|
|
def get_parser():
    """Build the command-line parser for the world-model interaction script.

    Returns:
        argparse.ArgumentParser: Parser exposing output/checkpoint/config
        paths, sampling options, rollout options and the optional analysis
        logging flags. ``--frame_stride`` is the only required argument.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--savedir",
                        type=str,
                        default=None,
                        help="Path to save the results.")
    parser.add_argument("--ckpt_path",
                        type=str,
                        default=None,
                        help="Path to the model checkpoint.")
    # The help text previously duplicated the --ckpt_path description.
    parser.add_argument("--config",
                        type=str,
                        help="Path to the model config file.")
    parser.add_argument(
        "--prompt_dir",
        type=str,
        default=None,
        help="Directory containing videos and corresponding prompts.")
    parser.add_argument("--dataset",
                        type=str,
                        default=None,
                        help="the name of dataset to test")
    parser.add_argument(
        "--ddim_steps",
        type=int,
        default=50,
        help="Number of DDIM steps. If non-positive, DDPM is used instead.")
    parser.add_argument(
        "--ddim_eta",
        type=float,
        default=1.0,
        help="Eta for DDIM sampling. Set to 0.0 for deterministic results.")
    parser.add_argument("--bs",
                        type=int,
                        default=1,
                        help="Batch size for inference. Must be 1.")
    parser.add_argument("--height",
                        type=int,
                        default=320,
                        help="Height of the generated images in pixels.")
    parser.add_argument("--width",
                        type=int,
                        default=512,
                        help="Width of the generated images in pixels.")
    parser.add_argument(
        "--frame_stride",
        type=int,
        nargs='+',
        required=True,
        help=
        "frame stride control for 256 model (larger->larger motion), FPS control for 512 or 1024 model (smaller->larger motion)"
    )
    parser.add_argument(
        "--unconditional_guidance_scale",
        type=float,
        default=1.0,
        help="Scale for classifier-free guidance during sampling.")
    parser.add_argument("--seed",
                        type=int,
                        default=123,
                        help="Random seed for reproducibility.")
    parser.add_argument("--video_length",
                        type=int,
                        default=16,
                        help="Number of frames in the generated video.")
    # NOTE(review): the help previously read "seed for seed_everything", which
    # describes --seed. Confirm the intended meaning of this flag.
    parser.add_argument("--num_generation",
                        type=int,
                        default=1,
                        help="Number of generations per sample.")
    parser.add_argument(
        "--timestep_spacing",
        type=str,
        default="uniform",
        help=
        "Strategy for timestep scaling. See Table 2 in the paper: 'Common Diffusion Noise Schedules and Sample Steps are Flawed' (https://huggingface.co/papers/2305.08891)."
    )
    parser.add_argument(
        "--guidance_rescale",
        type=float,
        default=0.0,
        help=
        "Rescale factor for guidance as discussed in 'Common Diffusion Noise Schedules and Sample Steps are Flawed' (https://huggingface.co/papers/2305.08891)."
    )
    parser.add_argument(
        "--perframe_ae",
        action='store_true',
        default=False,
        help=
        "Use per-frame autoencoder decoding to reduce GPU memory usage. Recommended for models with resolutions like 576x1024."
    )
    # NOTE(review): this help text looks copied from another flag; left as-is
    # because the intended meaning cannot be confirmed from this function.
    parser.add_argument(
        "--n_action_steps",
        type=int,
        default=16,
        help="num of samples per prompt",
    )
    parser.add_argument(
        "--exe_steps",
        type=int,
        default=16,
        help="num of samples to execute",
    )
    parser.add_argument(
        "--n_iter",
        type=int,
        default=40,
        help="num of iteration to interact with the world model",
    )
    parser.add_argument("--zero_pred_state",
                        action='store_true',
                        default=False,
                        help="not using the predicted states as comparison")
    parser.add_argument("--save_fps",
                        type=int,
                        default=8,
                        help="fps for the saving video")
    parser.add_argument(
        "--analysis_log_metrics",
        action='store_true',
        default=False,
        help="Enable DDIM convergence logging and export analysis CSVs.")
    parser.add_argument(
        "--analysis_reference_steps",
        type=int,
        default=50,
        help=
        "Reference DDIM steps used to build the full-step baseline for *_vs_full50 metrics."
    )
    parser.add_argument(
        "--analysis_psnr_path",
        type=str,
        default=None,
        help=
        "Optional CSV/JSON file with psnr_full50 values keyed by sample_id or videoid."
    )
    return parser
|
|
|
|
|
|
if __name__ == '__main__':
    # Script entry point: parse the CLI, seed via seed_everything, then run
    # inference as rank 0 of a single-GPU job.
    parser = get_parser()
    args = parser.parse_args()
    # A negative --seed means "draw a seed at random".
    seed = args.seed if args.seed >= 0 else random.randint(0, 2**31)
    seed_everything(seed)
    rank, gpu_num = 0, 1
    run_inference(args, gpu_num, rank)
|