diff --git a/scripts/evaluation/world_model_interaction.py b/scripts/evaluation/world_model_interaction.py index 2994dc9..284bf3c 100644 --- a/scripts/evaluation/world_model_interaction.py +++ b/scripts/evaluation/world_model_interaction.py @@ -1,28 +1,35 @@ import argparse, os, glob -import json import pandas as pd import random import torch -import torchvision -import h5py +import torchvision +import h5py import numpy as np import logging import einops import warnings import imageio +import atexit +import signal +import multiprocessing as mp import time - -from pytorch_lightning import seed_everything -from omegaconf import OmegaConf -from tqdm import tqdm -from einops import rearrange, repeat -from collections import OrderedDict -from torch import nn -import torch.nn.functional as F -from eval_utils import populate_queues, log_to_tensorboard +from concurrent.futures import ThreadPoolExecutor +from queue import Empty, Queue + +from pytorch_lightning import seed_everything +from omegaconf import OmegaConf +from tqdm import tqdm +from einops import rearrange, repeat +from collections import OrderedDict +from torch import nn +from eval_utils import populate_queues from collections import deque +from typing import Optional, List, Any +from types import SimpleNamespace + +torch.backends.cuda.matmul.allow_tf32 = True +torch.backends.cudnn.allow_tf32 = True from torch import Tensor -from torch.utils.tensorboard import SummaryWriter from PIL import Image from unifolm_wma.models.samplers.ddim import DDIMSampler @@ -41,577 +48,105 @@ def get_device_from_parameters(module: nn.Module) -> torch.device: return next(iter(module.parameters())).device -def get_scene_name(sample: pd.Series, fallback: str) -> str: - """Resolve the scene label used in analysis logs.""" - if 'data_dir' in sample and pd.notna(sample['data_dir']): - return str(sample['data_dir']) - return fallback +def clone_observation_queues( + queues: dict[str, deque]) -> dict[str, deque]: + """Deep-clone queue tensors so pipeline branches can diverge safely.""" + cloned = {} + for key, queue in queues.items(): + cloned[key] = deque( + ((item.clone() if torch.is_tensor(item) else item) + for item in queue), + maxlen=queue.maxlen) + return cloned -def build_sample_id(dataset: str, sample: pd.Series, frame_stride: int) -> str: - """Build a stable sample id while keeping the required CSV schema flat.""" - return f"{dataset}-vid{sample['videoid']}-fs{frame_stride}" +def clone_observation_queues_to_cpu( + queues: dict[str, deque]) -> dict[str, deque]: + cpu_queues = {} + for key, queue in queues.items(): + cpu_queues[key] = deque( + (item.detach().cpu().clone() if torch.is_tensor(item) else item + for item in queue), + maxlen=queue.maxlen) + return cpu_queues -def get_case_id(prompt_dir: str) -> str: - """Resolve case id from a prompt directory like */case1/world_model_interaction_prompts.""" - normalized = os.path.normpath(prompt_dir) - return os.path.basename(os.path.dirname(normalized)) +def move_observation_queues_to_device( + queues: dict[str, deque], + device: torch.device) -> dict[str, deque]: + moved = {} + for key, queue in queues.items(): + moved[key] = deque( + ((item.to(device, non_blocking=True) if torch.is_tensor(item) else + item) for item in queue), + maxlen=queue.maxlen) + return moved -def flatten_batch_tensor(tensor: torch.Tensor) -> torch.Tensor: - """Flatten all non-batch dimensions for batched metric computation.""" - return tensor.detach().float().reshape(tensor.shape[0], -1) +def sync_module_device_attributes(module: nn.Module, + device: torch.device) -> None: + """Align cached `.device` attributes with the actual target device.""" + for submodule in module.modules(): + if hasattr(submodule, 'device'): + try: + setattr(submodule, 'device', device) + except Exception: + pass -def batch_relative_l2(current: torch.Tensor, - previous: torch.Tensor) -> list[float]: - """Compute ||current-previous|| / ||previous|| for each item in the batch.""" - current_flat = flatten_batch_tensor(current) - previous_flat = flatten_batch_tensor(previous) - numerator = torch.linalg.vector_norm(current_flat - previous_flat, dim=1) - denominator = torch.linalg.vector_norm(previous_flat, dim=1).clamp_min(1e-8) - return (numerator / denominator).cpu().tolist() +def pipeline_print(message: str) -> None: + print(message, flush=True) -def batch_l2_distance(current: torch.Tensor, - reference: torch.Tensor) -> list[float]: - """Compute L2 distance against a reference tensor for each batch item.""" - current_flat = flatten_batch_tensor(current) - reference_flat = flatten_batch_tensor(reference) - return torch.linalg.vector_norm(current_flat - reference_flat, - dim=1).cpu().tolist() - - -def batch_cosine_similarity(current: torch.Tensor, - reference: torch.Tensor) -> list[float]: - """Compute cosine similarity against a reference tensor for each batch item.""" - current_flat = flatten_batch_tensor(current) - reference_flat = flatten_batch_tensor(reference) - return F.cosine_similarity(current_flat, - reference_flat, - dim=1, - eps=1e-8).cpu().tolist() - - -def first_consecutive_below(values: list[float], threshold: float, - window: int) -> float: - """Return the first 1-based index where `window` consecutive values are below threshold.""" - if window <= 0 or len(values) < window: - return np.nan - for start in range(len(values) - window + 1): - window_values = values[start:start + window] - if all(pd.notna(value) and value < threshold for value in window_values): - return float(start + 1) - return np.nan - - -def first_at_least(values: list[float], threshold: float) -> float: - """Return the first 1-based index where the series reaches the threshold.""" - for index, value in enumerate(values, start=1): - if pd.notna(value) and value >= threshold: - return float(index) - return np.nan - - -def first_at_most(values: list[float], threshold: float) -> float: - """Return the first 1-based index where the series drops below the threshold.""" - for index, value in enumerate(values, start=1): - if pd.notna(value) and value <= threshold: - return float(index) - return np.nan - - -def safe_mean(values: list[float]) -> float: - """Average numeric values while ignoring NaNs.""" - valid_values = [value for value in values if pd.notna(value)] - if not valid_values: - return np.nan - return float(np.mean(valid_values)) - - -def flatten_tensor(tensor: torch.Tensor) -> torch.Tensor: - """Flatten an arbitrary tensor into one 1D float vector.""" - return tensor.detach().float().reshape(-1) - - -def tensor_l2_distance(current: torch.Tensor, reference: torch.Tensor) -> float: - """Compute ||current-reference|| for arbitrary tensors.""" - current_flat = flatten_tensor(current) - reference_flat = flatten_tensor(reference) - return float(torch.linalg.vector_norm(current_flat - reference_flat).item()) - - -def tensor_relative_l2(current: torch.Tensor, previous: torch.Tensor) -> float: - """Compute ||current-previous|| / (||previous|| + eps) for arbitrary tensors.""" - current_flat = flatten_tensor(current) - previous_flat = flatten_tensor(previous) - numerator = torch.linalg.vector_norm(current_flat - previous_flat) - denominator = torch.linalg.vector_norm(previous_flat).clamp_min(1e-8) - return float((numerator / denominator).item()) - - -def tensor_cosine_similarity(current: torch.Tensor, - reference: torch.Tensor) -> float: - """Compute cosine similarity between arbitrary tensors.""" - current_flat = flatten_tensor(current) - reference_flat = flatten_tensor(reference) - return float( - F.cosine_similarity(current_flat, reference_flat, dim=0, - eps=1e-8).item()) - - -def make_sampling_noise_bundle(model: nn.Module, - noise_shape: list[int]) -> dict[str, torch.Tensor]: - """Create aligned initial noise for latent, action, and state diffusion streams.""" - batch_size = noise_shape[0] - horizon = noise_shape[2] - device = model.device +def build_observation_from_queues( + queues: dict[str, deque], + device: torch.device) -> dict[str, torch.Tensor]: + observation = { + 'observation.images.top': + torch.stack(list(queues['observation.images.top']), dim=1).permute( + 0, 2, 1, 3, 4), + 'observation.state': + torch.stack(list(queues['observation.state']), dim=1), + 'action': + torch.stack(list(queues['action']), dim=1), + } return { - 'img': torch.randn(noise_shape, device=device), - 'action': torch.randn((batch_size, horizon, model.agent_action_dim), - device=device), - 'state': torch.randn((batch_size, horizon, model.agent_state_dim), - device=device), + key: value.to(device, non_blocking=True) + for key, value in observation.items() } -def reset_sampling_seed(seed: int) -> None: - """Reset RNGs so repeated dense passes follow the same stochastic DDIM path.""" - random.seed(seed) - np.random.seed(seed % (2**32)) - torch.manual_seed(seed) - if torch.cuda.is_available(): - torch.cuda.manual_seed_all(seed) +def append_action_sequence( + queues: dict[str, deque], + action_seq: torch.Tensor, + ori_action_dim: int) -> dict[str, deque]: + for idx in range(action_seq.shape[1]): + action_frame = action_seq[0][idx:idx + 1].clone() + action_frame[:, ori_action_dim:] = 0.0 + queues = populate_queues(queues, {'action': action_frame}) + return queues -def load_psnr_lookup(psnr_path: str | None) -> dict[str, float]: - """Load optional PSNR values keyed by sample_id or videoid.""" - if not psnr_path: - return {} - if not os.path.exists(psnr_path): - logging.warning("PSNR file not found: %s", psnr_path) - return {} - - suffix = os.path.splitext(psnr_path)[1].lower() - lookup: dict[str, float] = {} - if suffix == '.csv': - df = pd.read_csv(psnr_path) - key_column = 'sample_id' if 'sample_id' in df.columns else 'videoid' - value_column = 'psnr_full50' if 'psnr_full50' in df.columns else 'psnr' - for _, row in df.iterrows(): - if pd.notna(row[key_column]) and pd.notna(row[value_column]): - lookup[str(row[key_column])] = float(row[value_column]) - return lookup - - if suffix == '.json': - with open(psnr_path, 'r', encoding='utf-8') as file: - data = json.load(file) - if isinstance(data, dict): - if 'sample_id' in data and ('psnr_full50' in data or 'psnr' in data): - lookup[str(data['sample_id'])] = float( - data.get('psnr_full50', data['psnr'])) - else: - for key, value in data.items(): - if isinstance(value, (int, float)): - lookup[str(key)] = float(value) - elif isinstance(data, list): - for item in data: - if not isinstance(item, dict): - continue - if 'sample_id' in item and ('psnr_full50' in item or 'psnr' in - item): - lookup[str(item['sample_id'])] = float( - item.get('psnr_full50', item['psnr'])) - return lookup - - logging.warning("Unsupported PSNR file format: %s", psnr_path) - return {} - - -class InteractionAnalysisLogger: - """Collect stepwise metrics and aggregated per-sample summaries.""" - - STEP_COLUMNS = [ - 'sample_id', - 'scene', - 'pass_type', - 'round_id', - 'step', - 'step_time_s', - 'backbone_reused_blocks_count', - 'backbone_reuse_hit_blocks', - 'latent_delta', - 'action_delta', - 'state_delta', - 'action_cosine_vs_full50', - 'state_cosine_vs_full50', - 'latent_l2_vs_full50', - ] - SUMMARY_COLUMNS = [ - 'sample_id', - 'scene', - 'pass_type', - 'pass_total_time_s', - 'action_first_stable_step', - 'state_first_stable_step', - 'latent_first_stable_step', - 'action_vs_full50_90pct_step', - 'action_vs_full50_95pct_step', - 'oracle_budget_action', - 'oracle_budget_state', - 'oracle_budget_latent', - 'latent_init_dist_to_prev_round', - 'action_drift_vs_prev_round', - 'round_total_time_s', - 'policy_pass_total_time_s', - 'world_model_pass_total_time_s', - 'psnr_full50', - ] - ROUND_COLUMNS = [ - 'sample_id', - 'scene', - 'round_id', - 'policy_pass_total_time_s', - 'world_model_pass_total_time_s', - 'round_total_time_s', - 'latent_init_dist_to_prev_round', - 'action_drift_vs_prev_round', - 'psnr_full50', - ] - - def __init__(self, output_dir: str, psnr_lookup: dict[str, float]): - self.output_dir = output_dir - self.psnr_lookup = psnr_lookup - self.step_rows: list[dict] = [] - self.summary_buckets: dict[tuple[str, str, str], dict] = {} - self.round_rows: list[dict] = [] - self.round_buckets: dict[tuple[str, str], dict] = {} - self.prev_policy_action: dict[str, torch.Tensor] = {} - self.prev_world_latent: dict[str, torch.Tensor] = {} - - def resolve_psnr(self, sample_id: str, videoid: int) -> float: - """Resolve a PSNR value by full sample id first, then by raw video id.""" - candidates = [sample_id, sample_id.rsplit('-fs', 1)[0], str(videoid)] - for candidate in candidates: - if candidate in self.psnr_lookup: - return float(self.psnr_lookup[candidate]) - return np.nan - - def append_summary_row(self, row: dict) -> None: - """Store per-round summaries and aggregate them later by sample and pass type.""" - key = (row['sample_id'], row['scene'], row['pass_type']) - metric_columns = self.SUMMARY_COLUMNS[3:] - if key not in self.summary_buckets: - self.summary_buckets[key] = { - 'sample_id': row['sample_id'], - 'scene': row['scene'], - 'pass_type': row['pass_type'], - **{column: [] for column in metric_columns}, - } - for column in metric_columns: - self.summary_buckets[key][column].append(row.get(column, np.nan)) - - def append_round_row(self, row: dict) -> None: - """Store per-round metrics and aggregate them later by sample.""" - self.round_rows.append(row) - key = (row['sample_id'], row['scene']) - metric_columns = self.ROUND_COLUMNS[3:] - if key not in self.round_buckets: - self.round_buckets[key] = { - 'sample_id': row['sample_id'], - 'scene': row['scene'], - **{column: [] for column in metric_columns}, - } - for column in metric_columns: - self.round_buckets[key][column].append(row.get(column, np.nan)) - - def collect_trace_series(self, debug_info: dict, reference_action: torch.Tensor, - reference_state: torch.Tensor, - reference_latent: torch.Tensor) -> tuple[list[float], list[float], list[float]]: - """Extract cosine/L2 curves for either the target pass or the full-50 reference pass.""" - action_cosines = [] - state_cosines = [] - latent_l2s = [] - for record in debug_info['step_records']: - action_cosines.append( - batch_cosine_similarity(record['action'], reference_action)[0]) - state_cosines.append( - batch_cosine_similarity(record['state'], reference_state)[0]) - latent_l2s.append( - batch_l2_distance(record['pred_x0'], reference_latent)[0]) - return action_cosines, state_cosines, latent_l2s - - def log_pass(self, sample_id: str, videoid: int, scene: str, pass_type: str, - round_id: int, pass_total_time_s: float, target_debug: dict, - reference_debug: dict) -> dict | None: - """Log one pass worth of stepwise and aggregated metrics.""" - if not target_debug or not target_debug.get('step_records'): - return None - if not reference_debug or not reference_debug.get('step_records'): - reference_debug = target_debug - - reference_final_action = reference_debug['step_records'][-1]['action'] - reference_final_state = reference_debug['step_records'][-1]['state'] - reference_final_latent = reference_debug['step_records'][-1]['pred_x0'] - - prev_img = target_debug['analysis_init']['img'] - prev_action = target_debug['analysis_init']['action'] - prev_state = target_debug['analysis_init']['state'] - action_deltas: list[float] = [] - state_deltas: list[float] = [] - latent_deltas: list[float] = [] - action_cosines: list[float] = [] - state_cosines: list[float] = [] - latent_l2s: list[float] = [] - - for record in target_debug['step_records']: - latent_delta = batch_relative_l2(record['img'], prev_img)[0] - action_delta = batch_relative_l2(record['action'], prev_action)[0] - state_delta = batch_relative_l2(record['state'], prev_state)[0] - action_cosine = batch_cosine_similarity(record['action'], - reference_final_action)[0] - state_cosine = batch_cosine_similarity(record['state'], - reference_final_state)[0] - latent_l2 = batch_l2_distance(record['pred_x0'], - reference_final_latent)[0] - - action_deltas.append(action_delta) - state_deltas.append(state_delta) - latent_deltas.append(latent_delta) - action_cosines.append(action_cosine) - state_cosines.append(state_cosine) - latent_l2s.append(latent_l2) - - self.step_rows.append({ - 'sample_id': sample_id, - 'scene': scene, - 'pass_type': pass_type, - 'round_id': round_id, - 'step': record['step_index'], - 'step_time_s': float(record['step_time_s']), - 'backbone_reused_blocks_count': int( - record.get('backbone_reused_blocks_count', 0)), - 'backbone_reuse_hit_blocks': - record.get('backbone_reuse_hit_blocks', ''), - 'latent_delta': latent_delta, - 'action_delta': action_delta, - 'state_delta': state_delta, - 'action_cosine_vs_full50': action_cosine, - 'state_cosine_vs_full50': state_cosine, - 'latent_l2_vs_full50': latent_l2, - }) - - prev_img = record['img'] - prev_action = record['action'] - prev_state = record['state'] - - oracle_action_cosines, oracle_state_cosines, oracle_latent_l2s = self.collect_trace_series( - reference_debug, reference_final_action, reference_final_state, - reference_final_latent) - - latent_init_dist_to_prev_round = np.nan - action_drift_vs_prev_round = np.nan - if pass_type == 'policy': - previous_action = self.prev_policy_action.get(sample_id) - if previous_action is not None: - action_drift_vs_prev_round = 1.0 - batch_cosine_similarity( - reference_final_action, previous_action)[0] - self.prev_policy_action[sample_id] = reference_final_action.clone() - elif pass_type == 'world_model': - previous_latent = self.prev_world_latent.get(sample_id) - if previous_latent is not None: - latent_init_dist_to_prev_round = batch_l2_distance( - reference_final_latent, previous_latent)[0] - self.prev_world_latent[sample_id] = reference_final_latent.clone() - - summary_row = { - 'sample_id': sample_id, - 'scene': scene, - 'pass_type': pass_type, - 'pass_total_time_s': float(pass_total_time_s), - 'action_first_stable_step': np.nan, - 'state_first_stable_step': np.nan, - 'latent_first_stable_step': np.nan, - 'action_vs_full50_90pct_step': first_at_least(action_cosines, 0.90), - 'action_vs_full50_95pct_step': first_at_least(action_cosines, 0.95), - 'oracle_budget_action': first_at_least(oracle_action_cosines, 0.95), - 'oracle_budget_state': first_at_least(oracle_state_cosines, 0.95), - 'oracle_budget_latent': np.nan, - 'latent_init_dist_to_prev_round': latent_init_dist_to_prev_round, - 'action_drift_vs_prev_round': action_drift_vs_prev_round, - 'round_total_time_s': np.nan, - 'policy_pass_total_time_s': np.nan, - 'world_model_pass_total_time_s': np.nan, - 'psnr_full50': self.resolve_psnr(sample_id, videoid), +def rollout_execution_segment( + queues: dict[str, deque], + seg_video: torch.Tensor, + pred_states: torch.Tensor, + zero_action_template: torch.Tensor, + exe_steps: int, + ori_state_dim: int, + zero_pred_state: bool) -> dict[str, deque]: + for idx in range(exe_steps): + state_frame = (torch.zeros_like(pred_states[0][idx:idx + 1]) + if zero_pred_state else pred_states[0][idx:idx + 1].clone()) + state_frame[:, ori_state_dim:] = 0.0 + observation = { + 'observation.images.top': + seg_video[0][:, idx:idx + 1].permute(1, 0, 2, 3), + 'observation.state': state_frame, + 'action': torch.zeros_like(zero_action_template), } - self.append_summary_row(summary_row) - return summary_row - - def log_round(self, sample_id: str, videoid: int, scene: str, round_id: int, - policy_pass_total_time_s: float, - world_model_pass_total_time_s: float, round_total_time_s: float, - latent_init_dist_to_prev_round: float, - action_drift_vs_prev_round: float) -> None: - """Log one interaction round consisting of one policy pass and one world-model pass.""" - self.append_round_row({ - 'sample_id': sample_id, - 'scene': scene, - 'round_id': round_id, - 'policy_pass_total_time_s': float(policy_pass_total_time_s), - 'world_model_pass_total_time_s': float(world_model_pass_total_time_s), - 'round_total_time_s': float(round_total_time_s), - 'latent_init_dist_to_prev_round': latent_init_dist_to_prev_round, - 'action_drift_vs_prev_round': action_drift_vs_prev_round, - 'psnr_full50': self.resolve_psnr(sample_id, videoid), - }) - - def flush(self) -> None: - """Write analysis CSVs to disk.""" - os.makedirs(self.output_dir, exist_ok=True) - stepwise_path = os.path.join(self.output_dir, 'stepwise_log.csv') - summary_path = os.path.join(self.output_dir, 'sample_summary.csv') - round_path = os.path.join(self.output_dir, 'round_summary.csv') - - stepwise_df = pd.DataFrame(self.step_rows, columns=self.STEP_COLUMNS) - stepwise_df.to_csv(stepwise_path, index=False) - - round_df = pd.DataFrame(self.round_rows, columns=self.ROUND_COLUMNS) - round_df.to_csv(round_path, index=False) - - summary_rows = [] - metric_columns = self.SUMMARY_COLUMNS[3:] - for bucket in self.summary_buckets.values(): - round_bucket = self.round_buckets.get((bucket['sample_id'], - bucket['scene'])) - row = { - 'sample_id': bucket['sample_id'], - 'scene': bucket['scene'], - 'pass_type': bucket['pass_type'], - } - for column in metric_columns: - if round_bucket is not None and column in round_bucket: - row[column] = safe_mean(round_bucket[column]) - else: - row[column] = safe_mean(bucket[column]) - summary_rows.append(row) - summary_df = pd.DataFrame(summary_rows, columns=self.SUMMARY_COLUMNS) - summary_df.to_csv(summary_path, index=False) - - -class BackboneBlockProfiler: - """Collect dense backbone block features and timings with a low-memory two-pass flow.""" - - COLUMNS = [ - 'sample_id', - 'case_id', - 'scene', - 'pass_type', - 'round_id', - 'step', - 'block_name', - 'block_stage', - 'block_index', - 'shape', - 'forward_time_ms', - 'l2_delta_vs_prev', - 'rel_l2_delta_vs_prev', - 'cosine_vs_prev', - 'l2_delta_vs_full50', - 'cosine_vs_full50', - ] - - def __init__(self, output_dir: str): - self.output_dir = output_dir - self.rows: list[dict] = [] - self.reference_features: dict[tuple[str, str, int], dict[str, - torch.Tensor]] = {} - self.mode: str | None = None - self.pass_key: tuple[str, str, int] | None = None - self.pass_meta: dict[str, str | int] = {} - self.current_reference: dict[str, torch.Tensor] = {} - self.previous_features: dict[str, torch.Tensor] = {} - - def _set_pass(self, mode: str, sample_id: str, case_id: str, scene: str, - pass_type: str, round_id: int) -> None: - self.mode = mode - self.pass_key = (sample_id, pass_type, int(round_id)) - self.pass_meta = { - 'sample_id': sample_id, - 'case_id': case_id, - 'scene': scene, - 'pass_type': pass_type, - 'round_id': int(round_id), - } - self.current_reference = {} - self.previous_features = {} - - def start_reference_pass(self, sample_id: str, case_id: str, scene: str, - pass_type: str, round_id: int) -> None: - self._set_pass('reference', sample_id, case_id, scene, pass_type, - round_id) - - def start_target_pass(self, sample_id: str, case_id: str, scene: str, - pass_type: str, round_id: int) -> None: - self._set_pass('target', sample_id, case_id, scene, pass_type, round_id) - - def finish_pass(self) -> None: - if self.mode == 'reference' and self.pass_key is not None: - self.reference_features[self.pass_key] = self.current_reference - elif self.mode == 'target' and self.pass_key is not None: - self.reference_features.pop(self.pass_key, None) - self.mode = None - self.pass_key = None - self.pass_meta = {} - self.current_reference = {} - self.previous_features = {} - - def record_block(self, step: int, block_name: str, block_stage: str, - block_index: int | None, output: torch.Tensor, - forward_time_ms: float) -> None: - if self.mode is None or self.pass_key is None: - return - block_output = output.detach().float().cpu() - if self.mode == 'reference': - self.current_reference[block_name] = block_output - return - - previous = self.previous_features.get(block_name) - reference = self.reference_features.get(self.pass_key, {}).get(block_name) - row = { - **self.pass_meta, - 'step': int(step), - 'block_name': block_name, - 'block_stage': block_stage, - 'block_index': -1 if block_index is None else int(block_index), - 'shape': str(tuple(block_output.shape)), - 'forward_time_ms': float(forward_time_ms), - 'l2_delta_vs_prev': np.nan, - 'rel_l2_delta_vs_prev': np.nan, - 'cosine_vs_prev': np.nan, - 'l2_delta_vs_full50': np.nan, - 'cosine_vs_full50': np.nan, - } - if previous is not None: - row['l2_delta_vs_prev'] = tensor_l2_distance(block_output, previous) - row['rel_l2_delta_vs_prev'] = tensor_relative_l2( - block_output, previous) - row['cosine_vs_prev'] = tensor_cosine_similarity( - block_output, previous) - if reference is not None: - row['l2_delta_vs_full50'] = tensor_l2_distance( - block_output, reference) - row['cosine_vs_full50'] = tensor_cosine_similarity( - block_output, reference) - - self.previous_features[block_name] = block_output - self.rows.append(row) - - def flush(self) -> None: - os.makedirs(self.output_dir, exist_ok=True) - path = os.path.join(self.output_dir, 'backbone_block_log.csv') - df = pd.DataFrame(self.rows, columns=self.COLUMNS) - df.to_csv(path, index=False) + queues = populate_queues(queues, observation) + return queues def write_video(video_path: str, stacked_frames: list, fps: int) -> None: @@ -726,6 +261,225 @@ def save_results(video: Tensor, filename: str, fps: int = 8) -> None: options={'crf': '10'}) +# ========== Async I/O ========== +_io_executor: Optional[ThreadPoolExecutor] = None +_io_futures: List[Any] = [] +_child_processes: List[mp.Process] = [] + + +def _get_io_executor() -> ThreadPoolExecutor: + global _io_executor + if _io_executor is None: + _io_executor = ThreadPoolExecutor(max_workers=2) + return _io_executor + + +def _flush_io(): + """Wait for all pending async I/O to finish.""" + global _io_futures + for fut in _io_futures: + try: + fut.result() + except Exception as e: + print(f">>> [async I/O] error: {e}") + _io_futures.clear() + + +atexit.register(_flush_io) + + +def _register_child_process(proc: mp.Process) -> None: + _child_processes.append(proc) + + +def _unregister_child_process(proc: mp.Process) -> None: + try: + _child_processes.remove(proc) + except ValueError: + pass + + +def _terminate_process(proc: mp.Process, join_timeout: float = 3.0) -> None: + if proc is None: + return + try: + alive = proc.is_alive() + except Exception: + alive = False + if not alive: + try: + proc.join(timeout=0.1) + except Exception: + pass + return + + try: + proc.terminate() + except Exception: + pass + try: + proc.join(timeout=join_timeout) + except Exception: + pass + + try: + if proc.is_alive(): + proc.kill() + proc.join(timeout=join_timeout) + except Exception: + pass + + +def _terminate_all_child_processes() -> None: + for proc in list(_child_processes): + _terminate_process(proc) + _child_processes.clear() + + +def _handle_termination_signal(signum, _frame) -> None: + signame = signal.Signals(signum).name + print(f">>> Received {signame}, terminating child processes ...") + _terminate_all_child_processes() + raise SystemExit(128 + signum) + + +signal.signal(signal.SIGINT, _handle_termination_signal) +signal.signal(signal.SIGTERM, _handle_termination_signal) +atexit.register(_terminate_all_child_processes) + + +def _save_results_sync(video_cpu: Tensor, filename: str, fps: int) -> None: + """Synchronous save on CPU tensor (runs in background thread).""" + video = torch.clamp(video_cpu.float(), -1., 1.) + n = video.shape[0] + video = video.permute(2, 0, 1, 3, 4) + frame_grids = [ + torchvision.utils.make_grid(framesheet, nrow=int(n), padding=0) + for framesheet in video + ] + grid = torch.stack(frame_grids, dim=0) + grid = (grid + 1.0) / 2.0 + grid = (grid * 255).to(torch.uint8).permute(0, 2, 3, 1) + torchvision.io.write_video(filename, + grid, + fps=fps, + video_codec='h264', + options={'crf': '10'}) + + +def save_results_async(video: Tensor, filename: str, fps: int = 8) -> None: + """Submit video saving to background thread pool.""" + video_cpu = video.detach().cpu() + fut = _get_io_executor().submit(_save_results_sync, video_cpu, filename, fps) + _io_futures.append(fut) + + +def _log_to_tb_sync(writer, video_cpu: Tensor, tag: str, fps: int) -> None: + """Synchronous TensorBoard log on CPU tensor (runs in background thread).""" + if video_cpu.dim() == 5: + n = video_cpu.shape[0] + video = video_cpu.permute(2, 0, 1, 3, 4) + frame_grids = [ + torchvision.utils.make_grid(framesheet, nrow=int(n), padding=0) + for framesheet in video + ] + grid = torch.stack(frame_grids, dim=0) + grid = (grid + 1.0) / 2.0 + grid = grid.unsqueeze(dim=0) + writer.add_video(tag, grid, fps=fps) + + +def log_to_tensorboard_async(writer, data: Tensor, tag: str, fps: int = 10) -> None: + """Submit TensorBoard logging to background thread pool.""" + if isinstance(data, torch.Tensor) and data.dim() == 5: + data_cpu = data.detach().cpu() + fut = _get_io_executor().submit(_log_to_tb_sync, writer, data_cpu, tag, fps) + _io_futures.append(fut) + + +def _video_tensor_to_frames(video: Tensor) -> np.ndarray: + video = torch.clamp(video.float(), -1., 1.) + n = video.shape[0] + if n == 1: + # Fast path for bs=1: skip make_grid and convert directly. + frames = video[0].permute(1, 2, 3, 0).contiguous() + frames = ((frames + 1.0) / 2.0 * 255).to(torch.uint8) + return frames.numpy() + + video = video.permute(2, 0, 1, 3, 4) + frame_grids = [ + torchvision.utils.make_grid(f, nrow=int(n), padding=0) for f in video + ] + grid = torch.stack(frame_grids, dim=0) + grid = ((grid + 1.0) / 2.0 * 255).to(torch.uint8).permute(0, 2, 3, 1) + return grid.numpy() + + +def _tensor_stats(name: str, tensor: Tensor) -> str: + t = tensor.detach().float() + finite_mask = torch.isfinite(t) + finite_ratio = finite_mask.float().mean().item() + if finite_mask.any(): + tf = t[finite_mask] + mean = tf.mean().item() + std = tf.std(unbiased=False).item() + min_v = tf.min().item() + max_v = tf.max().item() + abs_mean = tf.abs().mean().item() + else: + mean = std = min_v = max_v = abs_mean = float('nan') + return (f"{name}: shape={tuple(t.shape)} dtype={tensor.dtype} " + f"mean={mean:.6f} std={std:.6f} min={min_v:.6f} max={max_v:.6f} " + f"abs_mean={abs_mean:.6f} finite={finite_ratio:.6f}") + + +def _debug_world_model_stats(wm_samples: Tensor, wm_video: Tensor, + prefix: str) -> None: + print(f">>> [debug_wm_stats] {prefix} latent {_tensor_stats('wm_samples', wm_samples)}") + print(f">>> [debug_wm_stats] {prefix} decoded {_tensor_stats('wm_video', wm_video)}") + + +def _video_writer_process(q: mp.Queue, filename: str, fps: int): + writer = None + while True: + item = q.get() + if item is None: + break + frames = _video_tensor_to_frames(item) + if writer is None: + writer = imageio.get_writer( + filename, + fps=fps, + codec='libx264', + ffmpeg_params=['-crf', '10', '-pix_fmt', 'yuv420p']) + for frame in frames: + writer.append_data(frame) + if writer is not None: + writer.close() + + +def _stop_writer_process(writer_proc: mp.Process, write_q: mp.Queue) -> None: + try: + write_q.put(None, timeout=1.0) + except Exception: + pass + try: + writer_proc.join(timeout=5.0) + except Exception: + pass + try: + if writer_proc.is_alive(): + _terminate_process(writer_proc) + except Exception: + pass + try: + write_q.close() + write_q.join_thread() + except Exception: + pass + _unregister_child_process(writer_proc) + + def get_init_frame_path(data_dir: str, sample: dict) -> str: """Construct the init_frame path from directory and sample metadata. @@ -896,17 +650,21 @@ def image_guided_synthesis_sim_mode( action_cond_step: int = 16, n_samples: int = 1, ddim_steps: int = 50, - ddim_eta: float = 1.0, - unconditional_guidance_scale: float = 1.0, - fs: int | None = None, + ddim_eta: float = 1.0, + unconditional_guidance_scale: float = 1.0, + precision: int | None = 16, + fs: int | None = None, text_input: bool = True, timestep_spacing: str = 'uniform', guidance_rescale: float = 0.0, sim_mode: bool = True, - init_noise_bundle: dict[str, torch.Tensor] | None = None, decode_video: bool = True, - return_debug_info: bool = False, - **kwargs) -> tuple[torch.Tensor | None, torch.Tensor, torch.Tensor, dict | None]: + pipeline_split_step: int = 0, + pipeline_compare_full: bool = False, + handoff_callback=None, + stop_at_handoff: bool = False, + **kwargs) -> tuple[torch.Tensor | None, torch.Tensor, torch.Tensor, + torch.Tensor, dict[str, Any]]: """ Performs image-guided video generation in a simulation-style mode with optional multimodal guidance (image, state, action, text). @@ -929,26 +687,21 @@ def image_guided_synthesis_sim_mode( timestep_spacing (str): Timestep sampling method in DDIM sampler. Typically "uniform" or "linspace". guidance_rescale (float): Guidance rescaling factor to mitigate overexposure from classifier-free guidance. sim_mode (bool): Whether to perform world-model interaction or decision-making using the world-model. - init_noise_bundle (dict[str, torch.Tensor] | None): Optional aligned noise inputs for latent/action/state. - decode_video (bool): Whether to decode the final latent into pixel space. - return_debug_info (bool): Whether to return per-step traces for analysis logging. - **kwargs: Additional arguments passed to the DDIM sampler, including - sparse head controls such as `head_schedule`, `head_log_steps`, - and `head_skip_mode`, plus optional decoder block reuse settings. - - Returns: - batch_variants (torch.Tensor | None): Predicted pixel-space video frames [B, C, T, H, W]. - actions (torch.Tensor): Predicted action sequences [B, T, D] from diffusion decoding. - states (torch.Tensor): Predicted state sequences [B, T, D] from diffusion decoding. - debug_info (dict | None): Optional per-step trace used for convergence analysis. + decode_video (bool): Whether to decode latent samples to pixel-space video. + Set to False to skip VAE decode for speed when only actions/states are needed. + **kwargs: Additional arguments passed to the DDIM sampler. + + Returns: + batch_variants (torch.Tensor | None): Predicted pixel-space video frames [B, C, T, H, W], + or None when decode_video=False. + actions (torch.Tensor): Predicted action sequences [B, T, D] from diffusion decoding. + states (torch.Tensor): Predicted state sequences [B, T, D] from diffusion decoding. """ - b, _, t, _, _ = noise_shape - ddim_sampler = DDIMSampler(model) - batch_size = noise_shape[0] - batch_variants = None - debug_info = None - - fs = torch.tensor([fs] * batch_size, dtype=torch.long, device=model.device) + b, _, t, _, _ = noise_shape + ddim_sampler = DDIMSampler(model) + batch_size = noise_shape[0] + + fs = torch.tensor([fs] * batch_size, dtype=torch.long, device=model.device) img = observation['observation.images.top'].permute(0, 2, 1, 3, 4) cond_img = rearrange(img, 'b o c h w -> (b o) c h w')[-1:] @@ -989,52 +742,494 @@ def image_guided_synthesis_sim_mode( False, ] - uc = None - kwargs.update({"unconditional_conditioning_img_nonetext": None}) - cond_mask = None - cond_z0 = None + uc = None + kwargs.update({"unconditional_conditioning_img_nonetext": None}) + cond_mask = None + cond_z0 = None + batch_variants = None if ddim_sampler is not None: - samples, actions, states, intermedia = ddim_sampler.sample( + sample_kwargs = dict( S=ddim_steps, conditioning=cond, batch_size=batch_size, - shape=noise_shape[1:], - verbose=False, - unconditional_guidance_scale=unconditional_guidance_scale, - unconditional_conditioning=uc, - eta=ddim_eta, - cfg_img=None, - mask=cond_mask, + shape=noise_shape[1:], + verbose=False, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=uc, + eta=ddim_eta, + cfg_img=None, + mask=cond_mask, x0=cond_z0, + precision=precision, fs=fs, timestep_spacing=timestep_spacing, guidance_rescale=guidance_rescale, - x_T=None if init_noise_bundle is None else init_noise_bundle['img'], - action_T=None if init_noise_bundle is None else - init_noise_bundle['action'], - state_T=None if init_noise_bundle is None else - init_noise_bundle['state'], - record_step_outputs=return_debug_info, **kwargs) + samples, actions, states, intermedia = ddim_sampler.sample( + **sample_kwargs, + handoff_step=pipeline_split_step, + handoff_callback=handoff_callback, + stop_at_handoff=stop_at_handoff) - batch_variants = None if decode_video: - batch_variants = model.decode_first_stage(samples) + # Reconstruct from latent to pixel space + batch_images = model.decode_first_stage(samples) + batch_variants = batch_images - if return_debug_info or intermedia.get('head_sparse_logs'): - debug_info = { - 'analysis_init': intermedia.get('analysis_init'), - 'step_records': intermedia.get('step_records', []), - 'head_sparse_logs': intermedia.get('head_sparse_logs', {}), - 'final_latent': samples.detach().cpu(), - 'final_action': actions.detach().cpu(), - 'final_state': states.detach().cpu(), + return batch_variants, actions, states, samples, intermedia + + +def load_inference_runtime(args: argparse.Namespace, + gpu_no: int) -> SimpleNamespace: + """Load the model once and keep all GPU-side runtime state together.""" + config = OmegaConf.load(args.config) + + prepared_path = args.ckpt_path + ".prepared.pt" + if os.path.exists(prepared_path): + print(f">>> Loading prepared model from {prepared_path} ...") + model = torch.load(prepared_path, + map_location=f"cuda:{gpu_no}", + weights_only=False, + mmap=True) + model.eval() + model = model.cuda(gpu_no) + diffusion_model = model.model.diffusion_model + if not hasattr(diffusion_model, '_ctx_cache_enabled'): + diffusion_model._ctx_cache_enabled = False + if not hasattr(diffusion_model, '_ctx_cache'): + diffusion_model._ctx_cache = {} + if not hasattr(diffusion_model, '_trt_backbone'): + diffusion_model._trt_backbone = None + if not hasattr(diffusion_model, '_state_stream'): + diffusion_model._state_stream = torch.cuda.Stream( + device=torch.device(f"cuda:{gpu_no}")) + print(f">>> Prepared model loaded.") + else: + config['model']['params']['wma_config']['params'][ + 'use_checkpoint'] = False + model = instantiate_from_config(config.model) + model.perframe_ae = args.perframe_ae + assert os.path.exists(args.ckpt_path), "Error: checkpoint Not Found!" + model = load_model_checkpoint(model, args.ckpt_path) + model.eval() + model = model.cuda(gpu_no) + print(f'>>> Load pre-trained model ...') + print(f">>> Saving prepared model to {prepared_path} ...") + torch.save(model, prepared_path) + print(f">>> Prepared model saved ({os.path.getsize(prepared_path) / 1024**3:.1f} GB).") + + device = get_device_from_parameters(model) + sync_module_device_attributes(model, device) + if hasattr(model, 'cond_stage_model') and model.cond_stage_model is not None: + sync_module_device_attributes(model.cond_stage_model, device) + if hasattr(model.model.diffusion_model, '_state_stream'): + model.model.diffusion_model._state_stream = torch.cuda.Stream( + device=device) + + # Fuse KV projections in attention layers (to_k + to_v -> to_kv). + from unifolm_wma.modules.attention import CrossAttention + kv_count = sum(1 for m in model.modules() + if isinstance(m, CrossAttention) and m.fuse_kv()) + print(f" ✓ KV fused: {kv_count} attention layers") + + # Load TRT backbone if engine exists and the user did not request Torch fallback. + trt_engine_path = args.trt_engine_path + if trt_engine_path is None: + trt_engine_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), + '..', '..', 'trt_engines', + 'video_backbone.engine') + if args.disable_trt: + print(">>> TRT disabled by --disable_trt; using PyTorch video backbone.") + elif os.path.exists(trt_engine_path): + model.model.diffusion_model.load_trt_backbone(trt_engine_path) + else: + print(f">>> TRT engine not found at {trt_engine_path}; using PyTorch video backbone.") + + if torch.cuda.is_available(): + torch.cuda.synchronize(device) + + return SimpleNamespace(model=model, config=config, device=device, gpu_no=gpu_no) + + +def load_inference_data(config: OmegaConf, + args: argparse.Namespace) -> SimpleNamespace: + logging.info("***** Configing Data *****") + test_cfg = OmegaConf.create( + OmegaConf.to_container(config.data.params.test, resolve=True)) + stats_root = OmegaConf.select(config, "data.params.test.params.data_dir") + if stats_root is None: + stats_root = args.prompt_dir + test_cfg.params["data_dir"] = stats_root + test_cfg.params["meta_path"] = os.path.join(args.prompt_dir, + f"{args.dataset}.csv") + test_cfg.params["transition_dir"] = os.path.join(stats_root, + "transitions") + test_cfg.params["dataset_name"] = args.dataset + + target_dataset = instantiate_from_config(test_cfg) + data = SimpleNamespace(test_datasets={args.dataset: target_dataset}) + print(">>> Dataset is successfully loaded ...") + return data + + +def build_pipeline_iteration_input( + runtime: SimpleNamespace, + args: argparse.Namespace, + base_queues: dict[str, deque], + action_seq: torch.Tensor, + noise_shape: list[int], + model_input_fs: int, + ori_action_dim: int, + ori_state_dim: int) -> dict[str, deque]: + model = runtime.model + device = runtime.device + policy_observation = build_observation_from_queues(base_queues, device) + pipeline_action_queues = clone_observation_queues(base_queues) + pipeline_action_queues = append_action_sequence(pipeline_action_queues, + action_seq, + ori_action_dim) + wm_handoff_observation = { + 'observation.images.top': policy_observation['observation.images.top'], + 'observation.state': policy_observation['observation.state'], + 'action': torch.stack(list(pipeline_action_queues['action']), + dim=1).to(device, non_blocking=True), + } + _, _, pred_states, wm_samples, wm_intermedia = image_guided_synthesis_sim_mode( + model, + "", + wm_handoff_observation, + noise_shape, + action_cond_step=args.exe_steps, + ddim_steps=args.ddim_steps, + ddim_eta=args.ddim_eta, + unconditional_guidance_scale=args.unconditional_guidance_scale, + precision=args.precision, + fs=model_input_fs, + text_input=False, + timestep_spacing=args.timestep_spacing, + guidance_rescale=args.guidance_rescale, + decode_video=False, + pipeline_split_step=args.pipeline_split_step, + pipeline_compare_full=False, + stop_at_handoff=True) + wm_handoff = wm_intermedia.get('handoff', {}) + handoff_video_latent = wm_handoff.get('pred_x0', wm_handoff.get( + 'samples', None)) + if handoff_video_latent is None: + handoff_video_latent = wm_samples + handoff_video = model.decode_first_stage(handoff_video_latent) + handoff_state_seq = wm_handoff.get('states', pred_states) + + cond_obs_queues = clone_observation_queues(base_queues) + cond_obs_queues = append_action_sequence(cond_obs_queues, action_seq, + ori_action_dim) + cond_obs_queues = rollout_execution_segment( + queues=cond_obs_queues, + seg_video=handoff_video[:, :, :args.exe_steps], + pred_states=handoff_state_seq, + zero_action_template=action_seq[0][-1:], + exe_steps=args.exe_steps, + ori_state_dim=ori_state_dim, + zero_pred_state=args.zero_pred_state) + return cond_obs_queues + + +def run_pipeline_iteration_task( + runtime: SimpleNamespace, + args: argparse.Namespace, + iter_idx: int, + instruction: str, + input_payload: dict[str, Any], + noise_shape: list[int], + model_input_fs: int, + ori_action_dim: int, + ori_state_dim: int, + handoff_queue: Queue) -> dict[str, Any]: + model = runtime.model + device = runtime.device + + if input_payload['mode'] == 'initial': + cond_obs_queues = move_observation_queues_to_device( + input_payload['queues_cpu'], device) + elif input_payload['mode'] == 'handoff': + base_queues = move_observation_queues_to_device( + input_payload['base_queues_cpu'], device) + action_seq = input_payload['action_seq_cpu'].to( + device, non_blocking=True) + source_iter = input_payload['source_iter'] + pipeline_print( + f'>>> Step {iter_idx}@gpu{runtime.gpu_no}: building pipeline input ' + f'from step {source_iter} handoff ...') + cond_obs_queues = build_pipeline_iteration_input( + runtime=runtime, + args=args, + base_queues=base_queues, + action_seq=action_seq, + noise_shape=noise_shape, + model_input_fs=model_input_fs, + ori_action_dim=ori_action_dim, + ori_state_dim=ori_state_dim) + else: + raise ValueError(f"Unsupported input mode: {input_payload['mode']}") + + iter_input_queues = clone_observation_queues(cond_obs_queues) + iter_input_queues_cpu = clone_observation_queues_to_cpu(iter_input_queues) + handoff_sent = False + + def _on_policy_handoff(handoff: dict[str, torch.Tensor]) -> None: + nonlocal handoff_sent + if handoff_sent or (iter_idx + 1) >= args.n_iter: + return + handoff_queue.put({ + 'source_iter': + iter_idx, + 'next_iter': + iter_idx + 1, + 'base_queues_cpu': + iter_input_queues_cpu, + 'action_seq_cpu': + handoff['actions'].detach().cpu().clone(), + 'handoff_step': + handoff.get('step', args.pipeline_split_step), + }) + handoff_sent = True + handoff_step = handoff.get('step', args.pipeline_split_step) + pipeline_print( + f'>>> Step {iter_idx}@gpu{runtime.gpu_no}: emitted policy handoff ' + f'for step {iter_idx + 1} at ddim step ' + f'{handoff_step} ...') + + policy_observation = build_observation_from_queues(cond_obs_queues, device) + pipeline_print( + f'>>> Step {iter_idx}@gpu{runtime.gpu_no}: generating actions ...') + _, pred_actions, _, _, _ = image_guided_synthesis_sim_mode( + model, + instruction, + policy_observation, + noise_shape, + action_cond_step=args.exe_steps, + ddim_steps=args.ddim_steps, + ddim_eta=args.ddim_eta, + unconditional_guidance_scale=args.unconditional_guidance_scale, + precision=args.precision, + fs=model_input_fs, + timestep_spacing=args.timestep_spacing, + guidance_rescale=args.guidance_rescale, + sim_mode=False, + decode_video=False, + pipeline_split_step=args.pipeline_split_step, + pipeline_compare_full=False, + handoff_callback=_on_policy_handoff if args.pipeline_split_step > 0 else + None) + + cond_obs_queues = append_action_sequence(cond_obs_queues, pred_actions, + ori_action_dim) + wm_observation = { + 'observation.images.top': policy_observation['observation.images.top'], + 'observation.state': policy_observation['observation.state'], + 'action': torch.stack(list(cond_obs_queues['action']), dim=1).to( + device, non_blocking=True), + } + pipeline_print( + f'>>> Step {iter_idx}@gpu{runtime.gpu_no}: interacting with world model ...') + _, _, pred_states, wm_samples, _ = image_guided_synthesis_sim_mode( + model, + "", + wm_observation, + noise_shape, + action_cond_step=args.exe_steps, + ddim_steps=args.ddim_steps, + ddim_eta=args.ddim_eta, + unconditional_guidance_scale=args.unconditional_guidance_scale, + precision=args.precision, + fs=model_input_fs, + text_input=False, + timestep_spacing=args.timestep_spacing, + guidance_rescale=args.guidance_rescale, + decode_video=False, + pipeline_split_step=0, + pipeline_compare_full=False) + + wm_video = model.decode_first_stage(wm_samples) + if args.debug_wm_stats: + _debug_world_model_stats(wm_samples, + wm_video, + prefix=f"step={iter_idx}@gpu{runtime.gpu_no}") + seg_video = wm_video[:, :, :args.exe_steps] + pipeline_print('>' * 24) + return { + 'iter_idx': iter_idx, + 'gpu_no': runtime.gpu_no, + 'seg_video_cpu': seg_video.detach().cpu(), + } + + +def run_inference_multi_gpu_pipeline( + args: argparse.Namespace, + runtimes: list[SimpleNamespace]) -> None: + os.makedirs(args.savedir + '/inference', exist_ok=True) + config = runtimes[0].config + data = load_inference_data(config, args) + df = pd.read_csv(os.path.join(args.prompt_dir, f"{args.dataset}.csv")) + + model0 = runtimes[0].model + assert (args.height % 16 == 0) and ( + args.width % 16 + == 0), "Error: image size [h,w] should be multiples of 16!" + assert args.bs == 1, "Current implementation only support [batch size = 1]!" + + h, w = args.height // 8, args.width // 8 + channels = model0.model.diffusion_model.out_channels + n_frames = args.video_length + pipeline_print(f'>>> Generate {n_frames} frames under each generation ...') + noise_shape = [args.bs, channels, n_frames, h, w] + gpu_ids = [runtime.gpu_no for runtime in runtimes] + pipeline_print(f'>>> Multi-GPU pipeline enabled on GPUs {gpu_ids}.') + + for idx in range(0, len(df)): + sample = df.iloc[idx] + init_frame_path = get_init_frame_path(args.prompt_dir, sample) + ori_fps = float(sample['fps']) + video_save_dir = args.savedir + f"/inference/sample_{sample['videoid']}" + os.makedirs(video_save_dir, exist_ok=True) + + transition_path = get_transition_path(args.prompt_dir, sample) + with h5py.File(transition_path, 'r') as h5f: + transition_dict = {} + for key in h5f.keys(): + transition_dict[key] = torch.tensor(h5f[key][()]) + for key in h5f.attrs.keys(): + transition_dict[key] = h5f.attrs[key] + + for fs in args.frame_stride: + sample_full_video_file = f"{video_save_dir}/../{sample['videoid']}_full_fs{fs}.mp4" + write_q = mp.Queue(maxsize=4) + writer_proc = mp.Process(target=_video_writer_process, + args=(write_q, sample_full_video_file, + args.save_fps)) + writer_proc.daemon = True + writer_proc.start() + _register_child_process(writer_proc) + executors = { + runtime.gpu_no: ThreadPoolExecutor(max_workers=1) + for runtime in runtimes } + handoff_events = Queue() + future_meta = {} + submitted_iters = set() + completed_results = {} + next_write_idx = 0 + model_input_fs = ori_fps // fs + progress = tqdm(total=args.n_iter, + ascii=False, + desc=f'fs={fs} pipeline', + leave=True) + try: + batch, ori_state_dim, ori_action_dim = prepare_init_input( + 0, + init_frame_path, + transition_dict, + fs, + data.test_datasets[args.dataset], + n_obs_steps=model0.n_obs_steps_imagen) + initial_observation = { + 'observation.images.top': + batch['observation.image'].permute(1, 0, 2, 3)[-1].unsqueeze(0), + 'observation.state': + batch['observation.state'][-1].unsqueeze(0), + 'action': + torch.zeros_like(batch['action'][-1]).unsqueeze(0), + } + initial_queues = { + "observation.images.top": + deque(maxlen=model0.n_obs_steps_imagen), + "observation.state": deque(maxlen=model0.n_obs_steps_imagen), + "action": deque(maxlen=args.video_length), + } + initial_queues = populate_queues(initial_queues, + initial_observation) + initial_payload = { + 'mode': 'initial', + 'queues_cpu': clone_observation_queues_to_cpu(initial_queues), + } + first_gpu = runtimes[0].gpu_no + future = executors[first_gpu].submit( + run_pipeline_iteration_task, runtimes[0], args, 0, + sample['instruction'], initial_payload, noise_shape, + model_input_fs, ori_action_dim, ori_state_dim, + handoff_events) + future_meta[future] = {'iter_idx': 0, 'gpu_no': first_gpu} + submitted_iters.add(0) - return batch_variants, actions, states, debug_info + while next_write_idx < args.n_iter: + done_futures = [fut for fut in list(future_meta.keys()) + if fut.done()] + for fut in done_futures: + meta = future_meta.pop(fut) + result = fut.result() + completed_results[result['iter_idx']] = result + pipeline_print( + f'>>> Step {result["iter_idx"]}@gpu{meta["gpu_no"]}: ' + f'final segment ready.') + + while True: + try: + event = handoff_events.get_nowait() + except Empty: + break + next_iter = event['next_iter'] + if next_iter >= args.n_iter or next_iter in submitted_iters: + continue + target_runtime = runtimes[next_iter % len(runtimes)] + payload = { + 'mode': 'handoff', + 'base_queues_cpu': event['base_queues_cpu'], + 'action_seq_cpu': event['action_seq_cpu'], + 'source_iter': event['source_iter'], + 'handoff_step': event['handoff_step'], + } + pipeline_print( + f'>>> Step {event["source_iter"]}: queueing step ' + f'{next_iter} on gpu{target_runtime.gpu_no} from ' + f'ddim step {event["handoff_step"]} ...') + future = executors[target_runtime.gpu_no].submit( + run_pipeline_iteration_task, target_runtime, args, + next_iter, sample['instruction'], payload, + noise_shape, model_input_fs, ori_action_dim, + ori_state_dim, handoff_events) + future_meta[future] = { + 'iter_idx': next_iter, + 'gpu_no': target_runtime.gpu_no + } + submitted_iters.add(next_iter) + + while next_write_idx in completed_results: + result = completed_results.pop(next_write_idx) + write_q.put(result['seg_video_cpu']) + next_write_idx += 1 + progress.update(1) + progress.set_postfix_str( + f'written={next_write_idx}/{args.n_iter}', + refresh=False) + + if next_write_idx < args.n_iter: + if not future_meta and len(submitted_iters) >= args.n_iter: + raise RuntimeError( + "Pipeline scheduler is idle before all iterations finished.") + time.sleep(0.05) + finally: + progress.close() + for executor in executors.values(): + executor.shutdown(wait=True, cancel_futures=False) + _stop_writer_process(writer_proc, write_q) + + _flush_io() -def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None: +def run_inference(args: argparse.Namespace, + gpu_num: int, + gpu_no: int, + runtime: Optional[SimpleNamespace] = None) -> None: """ Run inference pipeline on prompts and image inputs. @@ -1046,74 +1241,40 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None: Returns: None """ - # Create inference and tensorboard dirs - inference_dir = args.savedir + '/inference' - os.makedirs(inference_dir, exist_ok=True) - log_dir = args.savedir + f"/tensorboard" - os.makedirs(log_dir, exist_ok=True) - writer = SummaryWriter(log_dir=log_dir) - analysis_logger = None - backbone_profiler = None - head_schedule = args.head_schedule_steps if args.head_schedule_steps else None - head_log_steps = args.head_log_steps if args.head_log_steps else None - head_skip_mode = args.head_skip_mode - backbone_reuse_blocks = (args.backbone_reuse_blocks - if args.backbone_reuse_blocks else None) - backbone_reuse_schedule_steps = ( - args.backbone_reuse_schedule_steps - if args.backbone_reuse_schedule_steps else None) - backbone_reuse_force_compute_steps = ( - args.backbone_reuse_force_compute_steps - if args.backbone_reuse_force_compute_steps else None) - backbone_reuse_enabled = ( - args.backbone_reuse_mode == "reuse_output" - and backbone_reuse_blocks is not None) - if args.analysis_log_metrics: - analysis_logger = InteractionAnalysisLogger( - output_dir=inference_dir, - psnr_lookup=load_psnr_lookup(args.analysis_psnr_path), - ) - if args.analysis_profile_backbone_blocks: - if head_schedule is not None or backbone_reuse_enabled: - raise ValueError( - "Backbone block profiling expects dense DDIM runs. " - "Do not pass sparse head or backbone reuse flags.") - backbone_profiler = BackboneBlockProfiler(output_dir=inference_dir) - case_id = get_case_id(args.prompt_dir) + # Create inference dir + os.makedirs(args.savedir + '/inference', exist_ok=True) # Load prompt csv_path = os.path.join(args.prompt_dir, f"{args.dataset}.csv") df = pd.read_csv(csv_path) - # Load config - config = OmegaConf.load(args.config) - config['model']['params']['wma_config']['params'][ - 'use_checkpoint'] = False - model = instantiate_from_config(config.model) - model.perframe_ae = args.perframe_ae - assert os.path.exists(args.ckpt_path), "Error: checkpoint Not Found!" - model = load_model_checkpoint(model, args.ckpt_path) - model.eval() - print(f'>>> Load pre-trained model ...') + # Load config (always needed for data setup) + config = runtime.config if runtime is not None else OmegaConf.load(args.config) + + # Parallel loading: model and data are independent + def _load_model(): + return load_inference_runtime(args, gpu_no).model - # Build unnomalizer - logging.info("***** Configing Data *****") - data = instantiate_from_config(config.data) - data.setup() - print(">>> Dataset is successfully loaded ...") + def _load_data(): + return load_inference_data(config, args) - model = model.cuda(gpu_no) - device = get_device_from_parameters(model) + if runtime is None: + with ThreadPoolExecutor(max_workers=2) as executor: + model_future = executor.submit(_load_model) + data_future = executor.submit(_load_data) + model = model_future.result() + data = data_future.result() + device = get_device_from_parameters(model) + else: + model = runtime.model + device = runtime.device + data = _load_data() # Run over data - assert (args.height % 16 == 0) and ( - args.width % 16 - == 0), "Error: image size [h,w] should be multiples of 16!" - assert args.bs == 1, "Current implementation only support [batch size = 1]!" - if args.analysis_log_metrics: - assert args.ddim_steps > 0, "analysis_log_metrics requires positive --ddim_steps." - assert args.analysis_reference_steps > 0, ( - "analysis_log_metrics requires positive --analysis_reference_steps.") + assert (args.height % 16 == 0) and ( + args.width % 16 + == 0), "Error: image size [h,w] should be multiples of 16!" + assert args.bs == 1, "Current implementation only support [batch size = 1]!" # Get latent noise shape h, w = args.height // 8, args.width // 8 @@ -1130,10 +1291,8 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None: init_frame_path = get_init_frame_path(args.prompt_dir, sample) ori_fps = float(sample['fps']) - video_save_dir = args.savedir + f"/inference/sample_{sample['videoid']}" - os.makedirs(video_save_dir, exist_ok=True) - os.makedirs(video_save_dir + '/dm', exist_ok=True) - os.makedirs(video_save_dir + '/wm', exist_ok=True) + video_save_dir = args.savedir + f"/inference/sample_{sample['videoid']}" + os.makedirs(video_save_dir, exist_ok=True) # Load transitions to get the initial state later transition_path = get_transition_path(args.prompt_dir, sample) @@ -1146,395 +1305,228 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None: # If many, test various frequence control and world-model generation for fs in args.frame_stride: - sample_id = build_sample_id(args.dataset, sample, fs) - scene = get_scene_name(sample, args.dataset) - - # For saving imagens in policy - sample_save_dir = f'{video_save_dir}/dm/{fs}' - os.makedirs(sample_save_dir, exist_ok=True) - # For saving environmental changes in world-model - sample_save_dir = f'{video_save_dir}/wm/{fs}' - os.makedirs(sample_save_dir, exist_ok=True) - # For collecting interaction videos - wm_video = [] - # Initialize observation queues - cond_obs_queues = { - "observation.images.top": - deque(maxlen=model.n_obs_steps_imagen), - "observation.state": deque(maxlen=model.n_obs_steps_imagen), - "action": deque(maxlen=args.video_length), - } - # Obtain initial frame and state - start_idx = 0 - model_input_fs = ori_fps // fs - batch, ori_state_dim, ori_action_dim = prepare_init_input( - start_idx, - init_frame_path, - transition_dict, - fs, - data.test_datasets[args.dataset], - n_obs_steps=model.n_obs_steps_imagen) - observation = { - 'observation.images.top': - batch['observation.image'].permute(1, 0, 2, - 3)[-1].unsqueeze(0), - 'observation.state': - batch['observation.state'][-1].unsqueeze(0), - 'action': - torch.zeros_like(batch['action'][-1]).unsqueeze(0) - } - observation = { - key: observation[key].to(device, non_blocking=True) - for key in observation - } - # Update observation queues - cond_obs_queues = populate_queues(cond_obs_queues, observation) - - # Multi-round interaction with the world-model - for itr in tqdm(range(args.n_iter)): - round_start_time = time.time() - - # Get observation + # Writer process for incremental video saving + sample_full_video_file = f"{video_save_dir}/../{sample['videoid']}_full_fs{fs}.mp4" + write_q = mp.Queue(maxsize=4) + writer_proc = mp.Process( + target=_video_writer_process, + args=(write_q, sample_full_video_file, args.save_fps)) + writer_proc.daemon = True + writer_proc.start() + _register_child_process(writer_proc) + try: + # Initialize observation queues + cond_obs_queues = { + "observation.images.top": + deque(maxlen=model.n_obs_steps_imagen), + "observation.state": deque(maxlen=model.n_obs_steps_imagen), + "action": deque(maxlen=args.video_length), + } + # Obtain initial frame and state + start_idx = 0 + model_input_fs = ori_fps // fs + batch, ori_state_dim, ori_action_dim = prepare_init_input( + start_idx, + init_frame_path, + transition_dict, + fs, + data.test_datasets[args.dataset], + n_obs_steps=model.n_obs_steps_imagen) observation = { - 'observation.images.top': - torch.stack(list( - cond_obs_queues['observation.images.top']), - dim=1).permute(0, 2, 1, 3, 4), - 'observation.state': - torch.stack(list(cond_obs_queues['observation.state']), - dim=1), - 'action': - torch.stack(list(cond_obs_queues['action']), dim=1), + 'observation.images.top': + batch['observation.image'].permute(1, 0, 2, + 3)[-1].unsqueeze(0), + 'observation.state': + batch['observation.state'][-1].unsqueeze(0), + 'action': + torch.zeros_like(batch['action'][-1]).unsqueeze(0) } observation = { key: observation[key].to(device, non_blocking=True) for key in observation } + # Update observation queues + cond_obs_queues = populate_queues(cond_obs_queues, observation) - # Use world-model in policy to generate action - print(f'>>> Step {itr}: generating actions ...') - policy_noise_bundle = ( - make_sampling_noise_bundle(model, noise_shape) - if (args.analysis_log_metrics - or backbone_profiler is not None) else None) - policy_reference_debug = None - policy_sampling_seed = int(args.seed + itr * 1000 + 11) - if backbone_profiler is not None: - reset_sampling_seed(policy_sampling_seed) - backbone_profiler.start_reference_pass( - sample_id=sample_id, - case_id=case_id, - scene=scene, - pass_type='policy', - round_id=itr, - ) - image_guided_synthesis_sim_mode( + # Multi-round interaction with the world-model + pending_handoff = None + pending_handoff_source_itr = None + pending_handoff_ddim_step = None + for itr in tqdm(range(args.n_iter), ascii=False): + + # Build observation for policy pass. + if pending_handoff is not None: + print( + f'>>> Step {itr}: consuming pipeline handoff ' + f'from step {pending_handoff_source_itr} ' + f'at ddim step {pending_handoff_ddim_step} ...') + cond_obs_queues = pending_handoff + pending_handoff = None + pending_handoff_source_itr = None + pending_handoff_ddim_step = None + + iter_input_queues = clone_observation_queues(cond_obs_queues) + policy_observation = build_observation_from_queues( + cond_obs_queues, device) + + # Use world-model in policy to generate action + print(f'>>> Step {itr}: generating actions ...') + _, pred_actions, _, _, policy_intermedia = image_guided_synthesis_sim_mode( model, sample['instruction'], - observation, + policy_observation, noise_shape, action_cond_step=args.exe_steps, ddim_steps=args.ddim_steps, ddim_eta=args.ddim_eta, unconditional_guidance_scale=args. unconditional_guidance_scale, + precision=args.precision, fs=model_input_fs, timestep_spacing=args.timestep_spacing, guidance_rescale=args.guidance_rescale, sim_mode=False, - init_noise_bundle=policy_noise_bundle, decode_video=False, - return_debug_info=False, - backbone_block_profiler=backbone_profiler, - ) - backbone_profiler.finish_pass() - need_policy_reference = args.analysis_log_metrics and ( - args.analysis_reference_steps != args.ddim_steps - or head_schedule is not None or backbone_reuse_enabled) - if need_policy_reference: - _, _, _, policy_reference_debug = image_guided_synthesis_sim_mode( - model, - sample['instruction'], - observation, - noise_shape, - action_cond_step=args.exe_steps, - ddim_steps=args.analysis_reference_steps, - ddim_eta=args.ddim_eta, - unconditional_guidance_scale=args. - unconditional_guidance_scale, - fs=model_input_fs, - timestep_spacing=args.timestep_spacing, - guidance_rescale=args.guidance_rescale, - sim_mode=False, - init_noise_bundle=policy_noise_bundle, - decode_video=False, - return_debug_info=True, - head_log_steps=head_log_steps) - policy_pass_start = time.time() - if backbone_profiler is not None: - reset_sampling_seed(policy_sampling_seed) - backbone_profiler.start_target_pass( - sample_id=sample_id, - case_id=case_id, - scene=scene, - pass_type='policy', - round_id=itr, - ) - pred_videos_0, pred_actions, _, policy_debug = image_guided_synthesis_sim_mode( - model, - sample['instruction'], - observation, - noise_shape, - action_cond_step=args.exe_steps, - ddim_steps=args.ddim_steps, - ddim_eta=args.ddim_eta, - unconditional_guidance_scale=args. - unconditional_guidance_scale, - fs=model_input_fs, - timestep_spacing=args.timestep_spacing, - guidance_rescale=args.guidance_rescale, - sim_mode=False, - init_noise_bundle=policy_noise_bundle, - return_debug_info=args.analysis_log_metrics, - head_schedule=head_schedule, - head_log_steps=head_log_steps, - head_skip_mode=head_skip_mode, - backbone_reuse_blocks=backbone_reuse_blocks, - backbone_reuse_start_step=args.backbone_reuse_start_step, - backbone_reuse_schedule_steps= - backbone_reuse_schedule_steps, - backbone_reuse_force_compute_steps= - backbone_reuse_force_compute_steps, - backbone_reuse_mode=args.backbone_reuse_mode, - backbone_block_profiler=backbone_profiler) - if backbone_profiler is not None: - backbone_profiler.finish_pass() - policy_pass_total_time_s = time.time() - policy_pass_start - policy_summary_row = None - if analysis_logger is not None: - if policy_reference_debug is None: - policy_reference_debug = policy_debug - policy_summary_row = analysis_logger.log_pass( - sample_id=sample_id, - videoid=int(sample['videoid']), - scene=scene, - pass_type='policy', - round_id=itr, - pass_total_time_s=policy_pass_total_time_s, - target_debug=policy_debug, - reference_debug=policy_reference_debug, - ) + pipeline_split_step=args.pipeline_split_step, + pipeline_compare_full=False) - # Update future actions in the observation queues - for idx in range(len(pred_actions[0])): - observation = {'action': pred_actions[0][idx:idx + 1]} - observation['action'][:, ori_action_dim:] = 0.0 - cond_obs_queues = populate_queues(cond_obs_queues, - observation) - - # Collect data for interacting the world-model using the predicted actions - observation = { - 'observation.images.top': - torch.stack(list( - cond_obs_queues['observation.images.top']), - dim=1).permute(0, 2, 1, 3, 4), - 'observation.state': - torch.stack(list(cond_obs_queues['observation.state']), - dim=1), - 'action': - torch.stack(list(cond_obs_queues['action']), dim=1), - } - observation = { - key: observation[key].to(device, non_blocking=True) - for key in observation - } + pipeline_action_seq = None + pipeline_pred_states = None + pipeline_wm_samples = None + pipeline_wm_intermedia = {} + if args.pipeline_split_step > 0 and (itr + 1) < args.n_iter: + pipeline_action_seq = policy_intermedia.get( + 'handoff', {}).get('actions', pred_actions) + pipeline_action_queues = clone_observation_queues( + iter_input_queues) + pipeline_action_queues = append_action_sequence( + pipeline_action_queues, pipeline_action_seq, + ori_action_dim) + wm_handoff_observation = { + 'observation.images.top': + policy_observation['observation.images.top'], + 'observation.state': + policy_observation['observation.state'], + 'action': + torch.stack(list(pipeline_action_queues['action']), + dim=1).to(device, non_blocking=True), + } + print( + f'>>> Step {itr}: preparing pipeline handoff branch ...' + ) + _, _, pipeline_pred_states, pipeline_wm_samples, pipeline_wm_intermedia = image_guided_synthesis_sim_mode( + model, + "", + wm_handoff_observation, + noise_shape, + action_cond_step=args.exe_steps, + ddim_steps=args.ddim_steps, + ddim_eta=args.ddim_eta, + unconditional_guidance_scale=args. + unconditional_guidance_scale, + precision=args.precision, + fs=model_input_fs, + text_input=False, + timestep_spacing=args.timestep_spacing, + guidance_rescale=args.guidance_rescale, + decode_video=False, + pipeline_split_step=args.pipeline_split_step, + pipeline_compare_full=False) - # Interaction with the world-model - print(f'>>> Step {itr}: interacting with world model ...') - world_noise_bundle = ( - make_sampling_noise_bundle(model, noise_shape) - if (args.analysis_log_metrics - or backbone_profiler is not None) else None) - world_reference_debug = None - world_sampling_seed = int(args.seed + itr * 1000 + 29) - if backbone_profiler is not None: - reset_sampling_seed(world_sampling_seed) - backbone_profiler.start_reference_pass( - sample_id=sample_id, - case_id=case_id, - scene=scene, - pass_type='world_model', - round_id=itr, - ) - image_guided_synthesis_sim_mode( + # Update future actions in the observation queues + cond_obs_queues = append_action_sequence( + cond_obs_queues, pred_actions, ori_action_dim) + + # Reuse images/state and only rebuild action for WM pass. + wm_observation = { + 'observation.images.top': policy_observation['observation.images.top'], + 'observation.state': policy_observation['observation.state'], + 'action': torch.stack(list(cond_obs_queues['action']), dim=1).to( + device, non_blocking=True), + } + + # Interaction with the world-model + print(f'>>> Step {itr}: interacting with world model ...') + _, _, pred_states, wm_samples, wm_intermedia = image_guided_synthesis_sim_mode( model, "", - observation, + wm_observation, noise_shape, action_cond_step=args.exe_steps, ddim_steps=args.ddim_steps, ddim_eta=args.ddim_eta, unconditional_guidance_scale=args. unconditional_guidance_scale, + precision=args.precision, fs=model_input_fs, text_input=False, timestep_spacing=args.timestep_spacing, guidance_rescale=args.guidance_rescale, - init_noise_bundle=world_noise_bundle, decode_video=False, - return_debug_info=False, - backbone_block_profiler=backbone_profiler, - ) - backbone_profiler.finish_pass() - need_world_reference = args.analysis_log_metrics and ( - args.analysis_reference_steps != args.ddim_steps - or head_schedule is not None or backbone_reuse_enabled) - if need_world_reference: - _, _, _, world_reference_debug = image_guided_synthesis_sim_mode( - model, - "", - observation, - noise_shape, - action_cond_step=args.exe_steps, - ddim_steps=args.analysis_reference_steps, - ddim_eta=args.ddim_eta, - unconditional_guidance_scale=args. - unconditional_guidance_scale, - fs=model_input_fs, - text_input=False, - timestep_spacing=args.timestep_spacing, - guidance_rescale=args.guidance_rescale, - init_noise_bundle=world_noise_bundle, - decode_video=False, - return_debug_info=True, - head_log_steps=head_log_steps) - world_pass_start = time.time() - if backbone_profiler is not None: - reset_sampling_seed(world_sampling_seed) - backbone_profiler.start_target_pass( - sample_id=sample_id, - case_id=case_id, - scene=scene, - pass_type='world_model', - round_id=itr, - ) - pred_videos_1, _, pred_states, world_debug = image_guided_synthesis_sim_mode( - model, - "", - observation, - noise_shape, - action_cond_step=args.exe_steps, - ddim_steps=args.ddim_steps, - ddim_eta=args.ddim_eta, - unconditional_guidance_scale=args. - unconditional_guidance_scale, - fs=model_input_fs, - text_input=False, - timestep_spacing=args.timestep_spacing, - guidance_rescale=args.guidance_rescale, - init_noise_bundle=world_noise_bundle, - return_debug_info=args.analysis_log_metrics, - head_schedule=head_schedule, - head_log_steps=head_log_steps, - head_skip_mode=head_skip_mode, - backbone_reuse_blocks=backbone_reuse_blocks, - backbone_reuse_start_step=args.backbone_reuse_start_step, - backbone_reuse_schedule_steps= - backbone_reuse_schedule_steps, - backbone_reuse_force_compute_steps= - backbone_reuse_force_compute_steps, - backbone_reuse_mode=args.backbone_reuse_mode, - backbone_block_profiler=backbone_profiler) - if backbone_profiler is not None: - backbone_profiler.finish_pass() - world_pass_total_time_s = time.time() - world_pass_start - world_summary_row = None - if analysis_logger is not None: - if world_reference_debug is None: - world_reference_debug = world_debug - world_summary_row = analysis_logger.log_pass( - sample_id=sample_id, - videoid=int(sample['videoid']), - scene=scene, - pass_type='world_model', - round_id=itr, - pass_total_time_s=world_pass_total_time_s, - target_debug=world_debug, - reference_debug=world_reference_debug, - ) - analysis_logger.log_round( - sample_id=sample_id, - videoid=int(sample['videoid']), - scene=scene, - round_id=itr, - policy_pass_total_time_s=policy_pass_total_time_s, - world_model_pass_total_time_s= - world_pass_total_time_s, - round_total_time_s=time.time() - round_start_time, - latent_init_dist_to_prev_round=np.nan - if world_summary_row is None else - world_summary_row['latent_init_dist_to_prev_round'], - action_drift_vs_prev_round=np.nan - if policy_summary_row is None else - policy_summary_row['action_drift_vs_prev_round'], - ) + pipeline_split_step=args.pipeline_split_step, + pipeline_compare_full=False) - for idx in range(args.exe_steps): - observation = { - 'observation.images.top': - pred_videos_1[0][:, idx:idx + 1].permute(1, 0, 2, 3), - 'observation.state': - torch.zeros_like(pred_states[0][idx:idx + 1]) if - args.zero_pred_state else pred_states[0][idx:idx + 1], - 'action': - torch.zeros_like(pred_actions[0][-1:]) - } - observation['observation.state'][:, ori_state_dim:] = 0.0 - cond_obs_queues = populate_queues(cond_obs_queues, - observation) - - # Save the imagen videos for decision-making - sample_tag = f"{args.dataset}-vid{sample['videoid']}-dm-fs-{fs}/itr-{itr}" - log_to_tensorboard(writer, - pred_videos_0, - sample_tag, - fps=args.save_fps) - # Save videos environment changes via world-model interaction - sample_tag = f"{args.dataset}-vid{sample['videoid']}-wd-fs-{fs}/itr-{itr}" - log_to_tensorboard(writer, - pred_videos_1, - sample_tag, - fps=args.save_fps) - - # Save the imagen videos for decision-making - sample_video_file = f'{video_save_dir}/dm/{fs}/itr-{itr}.mp4' - save_results(pred_videos_0.cpu(), - sample_video_file, - fps=args.save_fps) - # Save videos environment changes via world-model interaction - sample_video_file = f'{video_save_dir}/wm/{fs}/itr-{itr}.mp4' - save_results(pred_videos_1.cpu(), - sample_video_file, - fps=args.save_fps) - - print('>' * 24) - # Collect the result of world-model interactions - wm_video.append(pred_videos_1[:, :, :args.exe_steps].cpu()) - - full_video = torch.cat(wm_video, dim=2) - sample_tag = f"{args.dataset}-vid{sample['videoid']}-wd-fs-{fs}/full" - log_to_tensorboard(writer, - full_video, - sample_tag, - fps=args.save_fps) - sample_full_video_file = f"{video_save_dir}/../{sample['videoid']}_full_fs{fs}.mp4" - save_results(full_video, sample_full_video_file, fps=args.save_fps) + # Decode full WM clip, then take executable segment to keep behavior closer to previous path. + wm_video = model.decode_first_stage(wm_samples) + if args.debug_wm_stats: + _debug_world_model_stats( + wm_samples, + wm_video, + prefix=f"step={itr}") + seg_video = wm_video[:, :, :args.exe_steps] - if analysis_logger is not None: - analysis_logger.flush() - if backbone_profiler is not None: - backbone_profiler.flush() - writer.close() + cond_obs_queues = rollout_execution_segment( + queues=cond_obs_queues, + seg_video=seg_video, + pred_states=pred_states, + zero_action_template=pred_actions[0][-1:], + exe_steps=args.exe_steps, + ori_state_dim=ori_state_dim, + zero_pred_state=args.zero_pred_state) + + if args.pipeline_split_step > 0 and (itr + 1) < args.n_iter: + policy_handoff = policy_intermedia.get('handoff', {}) + wm_handoff = pipeline_wm_intermedia.get('handoff', {}) + handoff_ddim_step = wm_handoff.get( + 'step', + policy_handoff.get('step', args.pipeline_split_step)) + next_action_seq = pipeline_action_seq + handoff_video_latent = wm_handoff.get( + 'pred_x0', wm_handoff.get('samples', None)) + if handoff_video_latent is None: + handoff_video_latent = pipeline_wm_samples + handoff_video = model.decode_first_stage( + handoff_video_latent) + handoff_state_seq = wm_handoff.get('states', + pipeline_pred_states) + + pending_handoff = clone_observation_queues( + iter_input_queues) + pending_handoff = append_action_sequence( + pending_handoff, next_action_seq, ori_action_dim) + pending_handoff = rollout_execution_segment( + queues=pending_handoff, + seg_video=handoff_video[:, :, :args.exe_steps], + pred_states=handoff_state_seq, + zero_action_template=next_action_seq[0][-1:], + exe_steps=args.exe_steps, + ori_state_dim=ori_state_dim, + zero_pred_state=args.zero_pred_state) + pending_handoff_source_itr = itr + pending_handoff_ddim_step = handoff_ddim_step + print( + f'>>> Step {itr}: prepared pipeline handoff ' + f'for step {itr + 1} at ddim step ' + f'{handoff_ddim_step} ...') + + print('>' * 24) + # Send decoded segment to writer process + write_q.put(seg_video.detach().cpu()) + finally: + _stop_writer_process(writer_proc, write_q) + + # Wait for all async I/O to complete + _flush_io() def get_parser(): @@ -1620,13 +1612,19 @@ def get_parser(): help= "Rescale factor for guidance as discussed in 'Common Diffusion Noise Schedules and Sample Steps are Flawed' (https://huggingface.co/papers/2305.08891)." ) - parser.add_argument( - "--perframe_ae", - action='store_true', - default=False, - help= - "Use per-frame autoencoder decoding to reduce GPU memory usage. Recommended for models with resolutions like 576x1024." - ) + parser.add_argument( + "--perframe_ae", + action='store_true', + default=False, + help= + "Use per-frame autoencoder decoding to reduce GPU memory usage. Recommended for models with resolutions like 576x1024." + ) + parser.add_argument( + "--precision", + type=int, + default=16, + choices=[16, 32], + help="Sampling precision for latent/action/state noise initialization. Default is 16.") parser.add_argument( "--n_action_steps", type=int, @@ -1649,85 +1647,76 @@ def get_parser(): action='store_true', default=False, help="not using the predicted states as comparison") + parser.add_argument( + "--fast_policy_no_decode", + action='store_true', + default=False, + help="Speed mode: policy pass only predicts actions, skip policy video decode/log/save.") + parser.add_argument( + "--debug_wm_stats", + action='store_true', + default=False, + help="Print latent/decode statistics for world-model samples before writing video.") + parser.add_argument( + "--disable_trt", + action='store_true', + default=False, + help="Disable TensorRT backbone loading and force the PyTorch video backbone path.") + parser.add_argument( + "--trt_engine_path", + type=str, + default=None, + help="Optional explicit TensorRT engine path. Defaults to trt_engines/video_backbone.engine.") parser.add_argument("--save_fps", type=int, default=8, help="fps for the saving video") - parser.add_argument("--analysis_log_metrics", - action='store_true', - default=False, - help="Enable DDIM convergence logging and export analysis CSVs.") parser.add_argument( - "--analysis_reference_steps", + "--pipeline_split_step", type=int, - default=50, - help="Reference DDIM steps used to build the full-step baseline for *_vs_full50 metrics." - ) - parser.add_argument("--analysis_psnr_path", - type=str, - default=None, - help="Optional CSV/JSON file with psnr_full50 values keyed by sample_id or videoid.") + default=0, + help="Run DDIM sampling in two segments on a single GPU for pipeline experiments. " + "Set to a value like 25 to test 25+25 splitting; 0 disables it.") parser.add_argument( - "--analysis_profile_backbone_blocks", + "--pipeline_compare_full", action='store_true', default=False, - help="Run dense two-pass backbone block profiling and export backbone_block_log.csv.") + help="When pipeline_split_step is enabled, also run the original full DDIM pass " + "with the same initial noise and print max-abs diffs for validation.") parser.add_argument( - "--head_schedule_steps", + "--pipeline_multi_gpu", + action='store_true', + default=False, + help="Enable true asynchronous pipeline execution across multiple GPUs. " + "When disabled, keep the original single-GPU/serial inference path.") + parser.add_argument( + "--pipeline_gpu_ids", type=int, - nargs='*', - default=None, - help="Zero-based DDIM loop indices where action/state heads execute. Omit for dense execution.") - parser.add_argument( - "--head_log_steps", - type=int, - nargs='*', - default=None, - help="Zero-based DDIM loop indices to snapshot sparse action/state/latent outputs for dense-vs-sparse comparison.") - parser.add_argument( - "--head_skip_mode", - type=str, - default="reuse_prediction", - choices=["reuse_prediction", "freeze_state"], - help="Behavior on non-checkpoint steps: reuse cached head predictions while still running scheduler.step, or freeze action/state entirely.") - parser.add_argument( - "--backbone_reuse_blocks", - type=str, - nargs='*', - default=None, - help="Decoder block names to reuse on non-checkpoint DDIM steps, e.g. output_5 output_4 output_6.") - parser.add_argument( - "--backbone_reuse_start_step", - type=int, - default=None, - help="1-based DDIM step index after which decoder block reuse becomes eligible.") - parser.add_argument( - "--backbone_reuse_schedule_steps", - type=int, - nargs='*', - default=None, - help="1-based DDIM step indices where selected decoder blocks must be recomputed.") - parser.add_argument( - "--backbone_reuse_force_compute_steps", - type=int, - nargs='*', - default=None, - help="1-based DDIM step indices that always recompute selected decoder blocks, even if omitted from the reuse schedule.") - parser.add_argument( - "--backbone_reuse_mode", - type=str, - default="disabled", - choices=["disabled", "reuse_output"], - help="Decoder block reuse mode. 'reuse_output' reuses cached output block tensors on non-checkpoint steps.") + nargs='+', + default=[0, 1], + help="Logical CUDA device ids used by the multi-GPU pipeline scheduler.") return parser if __name__ == '__main__': parser = get_parser() args = parser.parse_args() - seed = args.seed - if seed < 0: - seed = random.randint(0, 2**31) - seed_everything(seed) - rank, gpu_num = 0, 1 - run_inference(args, gpu_num, rank) + seed = args.seed + if seed < 0: + seed = random.randint(0, 2**31) + seed_everything(seed) + if args.pipeline_multi_gpu: + if args.pipeline_split_step <= 0: + raise ValueError( + "--pipeline_multi_gpu requires --pipeline_split_step > 0.") + if len(args.pipeline_gpu_ids) < 2: + raise ValueError( + "--pipeline_multi_gpu requires at least two gpu ids.") + runtimes = [ + load_inference_runtime(args, gpu_id) + for gpu_id in args.pipeline_gpu_ids + ] + run_inference_multi_gpu_pipeline(args, runtimes[:2]) + else: + rank, gpu_num = 0, 1 + run_inference(args, gpu_num, rank) diff --git a/scripts/export_trt.py b/scripts/export_trt.py new file mode 100644 index 0000000..ab74ec9 --- /dev/null +++ b/scripts/export_trt.py @@ -0,0 +1,209 @@ +"""Export video UNet backbone to ONNX, then convert to TensorRT engine. + +Usage: + python scripts/export_trt.py \ + --ckpt ckpts/unifolm_wma_dual.ckpt.prepared.pt \ + --config configs/inference/world_model_interaction.yaml \ + --out_dir trt_engines + + python scripts/export_trt.py \ + --ckpt ckpts/unifolm_wma_dual.ckpt.prepared.pt \ + --config configs/inference/world_model_interaction.yaml \ + --engine_path trt_engines/video_backbone_multigpu.engine \ + --onnx_path trt_engines/video_backbone_multigpu.onnx +""" + +import os +import sys +import argparse +import json + +import torch +import tensorrt as trt +from omegaconf import OmegaConf + +sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'src')) +from unifolm_wma.utils.utils import instantiate_from_config +from unifolm_wma.trt_utils import export_backbone_onnx + + +class TacticRecorder(trt.IAlgorithmSelector): + """Pass 1: record all candidate tactics and the auto-selected winner.""" + + def __init__(self): + super().__init__() + self.records = {} # layer_name -> {candidates: [...], selected: ...} + + def select_algorithms(self, ctx, choices): + name = ctx.name + # Collect input/output shapes + inputs = [] + for j in range(ctx.num_inputs): + try: + inputs.append([int(d) for d in ctx.get_shape(j)]) + except Exception: + inputs.append(None) + outputs = [] + for j in range(ctx.num_outputs): + try: + outputs.append([int(d) for d in ctx.get_shape(ctx.num_inputs + j)]) + except Exception: + outputs.append(None) + self.records[name] = { + "input_shapes": inputs, + "output_shapes": outputs, + "candidates": [], + "selected": None, + } + for i, c in enumerate(choices): + v = c.algorithm_variant + self.records[name]["candidates"].append({ + "index": i, + "implementation": v.implementation, + "tactic": v.tactic, + "timing_msec": c.timing_msec, + "workspace_size": c.workspace_size, + }) + # return all indices -> let TRT auto-pick the fastest + return list(range(len(choices))) + + def report_algorithms(self, ctx, choices): + # Both ctx and choices are lists in report_algorithms + for c, alg in zip(ctx, choices): + name = c.name + if name in self.records: + v = alg.algorithm_variant + self.records[name]["selected"] = { + "implementation": v.implementation, + "tactic": v.tactic, + "timing_msec": alg.timing_msec, + "workspace_size": alg.workspace_size, + } + + def save(self, path): + with open(path, "w") as f: + json.dump(self.records, f, indent=2) + print(f">>> Tactic info saved to {path} ({len(self.records)} layers)") + + +class TacticForcer(trt.IAlgorithmSelector): + """Pass 2: force user-specified tactics from a JSON file.""" + + def __init__(self, path): + super().__init__() + with open(path) as f: + self.overrides = json.load(f) + n = sum(1 for v in self.overrides.values() if v.get("force")) + print(f">>> Loaded tactic overrides: {n} layers with 'force' set") + + def select_algorithms(self, ctx, choices): + name = ctx.name + override = self.overrides.get(name) + if override and override.get("force"): + target_impl = override["force"]["implementation"] + target_tactic = override["force"]["tactic"] + for i, c in enumerate(choices): + v = c.algorithm_variant + if v.implementation == target_impl and v.tactic == target_tactic: + return [i] + print(f" WARN: forced tactic not found for {name}, using auto") + return list(range(len(choices))) + + def report_algorithms(self, ctx, choices): + pass + + +def load_model(config_path, ckpt_path, device): + if ckpt_path.endswith('.prepared.pt'): + model = torch.load(ckpt_path, map_location='cpu') + else: + config = OmegaConf.load(config_path) + model = instantiate_from_config(config.model) + state_dict = torch.load(ckpt_path, map_location='cpu') + if 'state_dict' in state_dict: + state_dict = state_dict['state_dict'] + model.load_state_dict(state_dict, strict=False) + model.eval().to(device) + return model + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument('--ckpt', required=True) + parser.add_argument('--config', default='configs/inference/world_model_interaction.yaml') + parser.add_argument('--out_dir', default='trt_engines') + parser.add_argument('--gpu_id', + type=int, + default=0, + help='CUDA device id used for ONNX export and TRT build.') + parser.add_argument('--onnx_path', + default=None, + help='Optional explicit ONNX output path. Overrides --out_dir default name.') + parser.add_argument('--engine_path', + default=None, + help='Optional explicit TensorRT engine output path. Overrides --out_dir default name.') + parser.add_argument('--context_len', type=int, default=95) + parser.add_argument('--fp16', action='store_true', default=True) + parser.add_argument('--dump-tactics', default=None, help='Pass 1: dump tactic info to JSON') + parser.add_argument('--load-tactics', default=None, help='Pass 2: force tactics from JSON') + args = parser.parse_args() + device = torch.device('cuda', args.gpu_id) + torch.cuda.set_device(device) + + onnx_path = args.onnx_path or os.path.join(args.out_dir, + 'video_backbone.onnx') + engine_path = args.engine_path or os.path.join(args.out_dir, + 'video_backbone.engine') + os.makedirs(os.path.dirname(os.path.abspath(onnx_path)), exist_ok=True) + os.makedirs(os.path.dirname(os.path.abspath(engine_path)), exist_ok=True) + + if os.path.exists(onnx_path): + print(f">>> ONNX already exists at {onnx_path}, skipping export.") + n_outputs = 10 + else: + print(">>> Loading model ...") + model = load_model(args.config, args.ckpt, device) + print(">>> Exporting ONNX ...") + with torch.no_grad(): + n_outputs = export_backbone_onnx(model, onnx_path, context_len=args.context_len) + del model + torch.cuda.empty_cache() + + print(">>> Converting ONNX -> TensorRT engine ...") + logger = trt.Logger(trt.Logger.WARNING) + builder = trt.Builder(logger) + network = builder.create_network(1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)) + parser = trt.OnnxParser(network, logger) + + if not parser.parse_from_file(os.path.abspath(onnx_path)): + for i in range(parser.num_errors): + print(f" ONNX parse error: {parser.get_error(i)}") + raise RuntimeError("ONNX parsing failed") + + config = builder.create_builder_config() + config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, 16 << 30) + if args.fp16: + config.set_flag(trt.BuilderFlag.FP16) + + # Tactic selection + recorder = None + if args.dump_tactics: + recorder = TacticRecorder() + config.algorithm_selector = recorder + elif args.load_tactics: + config.algorithm_selector = TacticForcer(args.load_tactics) + + engine_bytes = builder.build_serialized_network(network, config) + + if recorder and args.dump_tactics: + recorder.save(args.dump_tactics) + + with open(engine_path, 'wb') as f: + f.write(engine_bytes) + + print(f"\n>>> Done! Engine saved to {engine_path}") + print(f" Outputs: 1 y + {n_outputs - 1} hs_a tensors") + + +if __name__ == '__main__': + main() diff --git a/src/unifolm_wma/models/samplers/ddim.py b/src/unifolm_wma/models/samplers/ddim.py index 23ab20e..7e08a52 100644 --- a/src/unifolm_wma/models/samplers/ddim.py +++ b/src/unifolm_wma/models/samplers/ddim.py @@ -1,12 +1,13 @@ import numpy as np import torch import copy -import time from unifolm_wma.utils.diffusion import make_ddim_sampling_parameters, make_ddim_timesteps, rescale_noise_cfg from unifolm_wma.utils.common import noise_like from unifolm_wma.utils.common import extract_into_tensor from tqdm import tqdm +from unifolm_wma.modules.attention import enable_cross_attn_kv_cache, disable_cross_attn_kv_cache +from unifolm_wma.modules.networks.wma_model import enable_ctx_cache, disable_ctx_cache class DDIMSampler(object): @@ -20,8 +21,9 @@ class DDIMSampler(object): def register_buffer(self, name, attr): if type(attr) == torch.Tensor: - if attr.device != torch.device("cuda"): - attr = attr.to(torch.device("cuda")) + target_device = self.model.device + if attr.device != target_device: + attr = attr.to(target_device) setattr(self, name, attr) def make_schedule(self, @@ -68,11 +70,12 @@ class DDIMSampler(object): ddim_timesteps=self.ddim_timesteps, eta=ddim_eta, verbose=verbose) - self.register_buffer('ddim_sigmas', ddim_sigmas) - self.register_buffer('ddim_alphas', ddim_alphas) - self.register_buffer('ddim_alphas_prev', ddim_alphas_prev) + # Ensure tensors are on correct device for efficient indexing + self.register_buffer('ddim_sigmas', to_torch(torch.as_tensor(ddim_sigmas))) + self.register_buffer('ddim_alphas', to_torch(torch.as_tensor(ddim_alphas))) + self.register_buffer('ddim_alphas_prev', to_torch(torch.as_tensor(ddim_alphas_prev))) self.register_buffer('ddim_sqrt_one_minus_alphas', - np.sqrt(1. - ddim_alphas)) + to_torch(torch.as_tensor(np.sqrt(1. - ddim_alphas)))) sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt( (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * (1 - self.alphas_cumprod / self.alphas_cumprod_prev)) @@ -107,17 +110,9 @@ class DDIMSampler(object): fs=None, timestep_spacing='uniform', #uniform_trailing for starting from last timestep guidance_rescale=0.0, - action_T=None, - state_T=None, - record_step_outputs=False, - head_schedule=None, - head_log_steps=None, - head_skip_mode="reuse_prediction", - backbone_reuse_blocks=None, - backbone_reuse_start_step=None, - backbone_reuse_schedule_steps=None, - backbone_reuse_force_compute_steps=None, - backbone_reuse_mode="disabled", + handoff_step: int = 0, + handoff_callback=None, + stop_at_handoff: bool = False, **kwargs): # Check condition bs @@ -173,17 +168,9 @@ class DDIMSampler(object): precision=precision, fs=fs, guidance_rescale=guidance_rescale, - action_T=action_T, - state_T=state_T, - record_step_outputs=record_step_outputs, - head_schedule=head_schedule, - head_log_steps=head_log_steps, - head_skip_mode=head_skip_mode, - backbone_reuse_blocks=backbone_reuse_blocks, - backbone_reuse_start_step=backbone_reuse_start_step, - backbone_reuse_schedule_steps=backbone_reuse_schedule_steps, - backbone_reuse_force_compute_steps=backbone_reuse_force_compute_steps, - backbone_reuse_mode=backbone_reuse_mode, + handoff_step=handoff_step, + handoff_callback=handoff_callback, + stop_at_handoff=stop_at_handoff, **kwargs) return samples, actions, states, intermediates @@ -210,44 +197,23 @@ class DDIMSampler(object): precision=None, fs=None, guidance_rescale=0.0, - action_T=None, - state_T=None, - record_step_outputs=False, - head_schedule=None, - head_log_steps=None, - head_skip_mode="reuse_prediction", - backbone_reuse_blocks=None, - backbone_reuse_start_step=None, - backbone_reuse_schedule_steps=None, - backbone_reuse_force_compute_steps=None, - backbone_reuse_mode="disabled", + handoff_step: int = 0, + handoff_callback=None, + stop_at_handoff: bool = False, **kwargs): device = self.model.betas.device dp_ddim_scheduler_action = self.model.dp_noise_scheduler_action dp_ddim_scheduler_state = self.model.dp_noise_scheduler_state b = shape[0] - horizon = shape[2] if len(shape) >= 3 else 16 if x_T is None: img = torch.randn(shape, device=device) + action = torch.randn((b, 16, self.model.agent_action_dim), device=device) + state = torch.randn((b, 16, self.model.agent_state_dim), device=device) else: img = x_T - if action_T is None: - action = torch.randn((b, horizon, self.model.agent_action_dim), - device=device) - else: - action = action_T - if state_T is None: - state = torch.randn((b, horizon, self.model.agent_state_dim), - device=device) - else: - state = state_T - - if precision is not None: - if precision == 16: - img = img.to(dtype=torch.float16) - action = action.to(dtype=torch.float16) - state = state.to(dtype=torch.float16) + action = torch.randn((b, 16, self.model.agent_action_dim), device=device) + state = torch.randn((b, 16, self.model.agent_state_dim), device=device) if timesteps is None: timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps @@ -265,40 +231,6 @@ class DDIMSampler(object): 'x_inter_state': [state], 'pred_x0_state': [state], } - if record_step_outputs: - intermediates['analysis_init'] = { - 'img': img.detach().cpu(), - 'action': action.detach().cpu(), - 'state': state.detach().cpu(), - } - intermediates['step_records'] = [] - head_schedule_set = None if head_schedule is None else { - int(step_index) for step_index in head_schedule - } - head_log_steps_set = None if head_log_steps is None else { - int(step_index) for step_index in head_log_steps - } - if head_log_steps_set is not None: - intermediates['head_sparse_logs'] = {} - backbone_reuse_blocks_set = None if backbone_reuse_blocks is None else { - str(block_name) for block_name in backbone_reuse_blocks - } - backbone_reuse_schedule_steps_set = ( - None if backbone_reuse_schedule_steps is None else { - int(step_index) for step_index in backbone_reuse_schedule_steps - }) - backbone_reuse_force_compute_steps_set = ( - None if backbone_reuse_force_compute_steps is None else { - int(step_index) - for step_index in backbone_reuse_force_compute_steps - }) - backbone_reuse_active = (backbone_reuse_mode == "reuse_output" - and backbone_reuse_blocks_set) - backbone_reuse_cache = { - 'single': {}, - 'cond': {}, - 'uncond': {}, - } time_range = reversed(range( 0, timesteps)) if ddim_use_original_steps else np.flip(timesteps) total_steps = timesteps if ddim_use_original_steps else timesteps.shape[ @@ -309,89 +241,50 @@ class DDIMSampler(object): iterator = time_range clean_cond = kwargs.pop("clean_cond", False) - sync_device = device if isinstance(device, torch.device) else torch.device( - device) - should_sync = record_step_outputs and sync_device.type == "cuda" - if head_skip_mode not in {"reuse_prediction", "freeze_state"}: - raise ValueError( - f"Unsupported head_skip_mode={head_skip_mode!r}. " - "Expected 'reuse_prediction' or 'freeze_state'.") - if backbone_reuse_mode not in {"disabled", "reuse_output"}: - raise ValueError( - f"Unsupported backbone_reuse_mode={backbone_reuse_mode!r}. " - "Expected 'disabled' or 'reuse_output'.") - x_action_frozen = action.detach().clone() - x_state_frozen = state.detach().clone() - action_pred_cache = None - state_pred_cache = None - dp_ddim_scheduler_action.set_timesteps(len(timesteps)) dp_ddim_scheduler_state.set_timesteps(len(timesteps)) - for i, step in enumerate(iterator): - index = total_steps - i - 1 - ts = torch.full((b, ), step, device=device, dtype=torch.long) + ts = torch.empty((b, ), device=device, dtype=torch.long) + enable_cross_attn_kv_cache(self.model) + enable_ctx_cache(self.model) + handoff = {} + try: + for i, step in enumerate(iterator): + index = total_steps - i - 1 + ts.fill_(step) - # Use mask to blend noised original latent (img_orig) & new sampled latent (img) - if mask is not None: - assert x0 is not None - if clean_cond: - img_orig = x0 - else: - img_orig = self.model.q_sample(x0, ts) - img = img_orig * mask + (1. - mask) * img + # Use mask to blend noised original latent (img_orig) & new sampled latent (img) + if mask is not None: + assert x0 is not None + if clean_cond: + img_orig = x0 + else: + img_orig = self.model.q_sample(x0, ts) + img = img_orig * mask + (1. - mask) * img - if should_sync: - torch.cuda.synchronize(sync_device) - step_start_time = time.time() - scheduled_head = head_schedule_set is None or i in head_schedule_set - if head_skip_mode == "reuse_prediction": - run_head = scheduled_head or action_pred_cache is None or state_pred_cache is None - else: - run_head = scheduled_head - backbone_reuse_step_stats: dict[str, set[str]] | None = None - if backbone_reuse_active: - backbone_reuse_step_stats = { - 'single': set(), - 'cond': set(), - 'uncond': set(), - } + outs = self.p_sample_ddim( + img, + action, + state, + cond, + ts, + index=index, + precision=precision, + use_original_steps=ddim_use_original_steps, + quantize_denoised=quantize_denoised, + temperature=temperature, + noise_dropout=noise_dropout, + score_corrector=score_corrector, + corrector_kwargs=corrector_kwargs, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=unconditional_conditioning, + mask=mask, + x0=x0, + fs=fs, + guidance_rescale=guidance_rescale, + **kwargs) - outs = self.p_sample_ddim( - img, - action, - state, - cond, - ts, - index=index, - use_original_steps=ddim_use_original_steps, - quantize_denoised=quantize_denoised, - temperature=temperature, - noise_dropout=noise_dropout, - score_corrector=score_corrector, - corrector_kwargs=corrector_kwargs, - unconditional_guidance_scale=unconditional_guidance_scale, - unconditional_conditioning=unconditional_conditioning, - mask=mask, - x0=x0, - fs=fs, - guidance_rescale=guidance_rescale, - backbone_step_index=i + 1, - run_head=run_head, - backbone_reuse_blocks=backbone_reuse_blocks_set, - backbone_reuse_start_step=backbone_reuse_start_step, - backbone_reuse_schedule_steps=backbone_reuse_schedule_steps_set, - backbone_reuse_force_compute_steps= - backbone_reuse_force_compute_steps_set, - backbone_reuse_mode=backbone_reuse_mode, - backbone_reuse_cache=backbone_reuse_cache, - backbone_reuse_step_stats=backbone_reuse_step_stats, - **kwargs) + img, pred_x0, model_output_action, model_output_state = outs - img, pred_x0, model_output_action, model_output_state = outs - - if run_head: - action_pred_cache = model_output_action.detach().clone() - state_pred_cache = model_output_state.detach().clone() action = dp_ddim_scheduler_action.step( model_output_action, step, @@ -404,73 +297,35 @@ class DDIMSampler(object): state, generator=None, ).prev_sample - x_action_frozen = action.detach().clone() - x_state_frozen = state.detach().clone() - else: - if head_skip_mode == "reuse_prediction": - action = dp_ddim_scheduler_action.step( - action_pred_cache, - step, - action, - generator=None, - ).prev_sample - state = dp_ddim_scheduler_state.step( - state_pred_cache, - step, - state, - generator=None, - ).prev_sample - x_action_frozen = action.detach().clone() - x_state_frozen = state.detach().clone() - else: - action = x_action_frozen - state = x_state_frozen - if should_sync: - torch.cuda.synchronize(sync_device) - step_time_s = time.time() - step_start_time + if callback: callback(i) + if img_callback: img_callback(pred_x0, i) - if callback: callback(i) - if img_callback: img_callback(pred_x0, i) + if handoff_step > 0 and (i + 1) == handoff_step: + handoff = { + 'samples': img.clone(), + 'actions': action.clone(), + 'states': state.clone(), + 'pred_x0': pred_x0.clone(), + 'step': i + 1, + } + if handoff_callback is not None: + handoff_callback(handoff) + if stop_at_handoff: + intermediates['handoff'] = handoff + break - reused_blocks = [] - if backbone_reuse_step_stats is not None: - if unconditional_conditioning is None or unconditional_guidance_scale == 1.: - reused_blocks = sorted(backbone_reuse_step_stats['single']) - else: - reused_blocks = sorted( - backbone_reuse_step_stats['cond'] - | backbone_reuse_step_stats['uncond']) - - if index % log_every_t == 0 or index == total_steps - 1: - intermediates['x_inter'].append(img) - intermediates['pred_x0'].append(pred_x0) - intermediates['x_inter_action'].append(action) - intermediates['x_inter_state'].append(state) - if head_log_steps_set is not None and i in head_log_steps_set: - intermediates['head_sparse_logs'][i] = { - 'step_index': i, - 'ddim_timestep': int(step), - 'head_executed': run_head, - 'img': img.detach().cpu(), - 'pred_x0': pred_x0.detach().cpu(), - 'action': action.detach().cpu(), - 'state': state.detach().cpu(), - } - if record_step_outputs: - intermediates['step_records'].append({ - 'step_index': i + 1, - 'ddim_timestep': int(step), - 'head_executed': run_head, - 'img': img.detach().cpu(), - 'pred_x0': pred_x0.detach().cpu(), - 'action': action.detach().cpu(), - 'state': state.detach().cpu(), - 'step_time_s': step_time_s, - 'backbone_reused_blocks_count': len(reused_blocks), - 'backbone_reuse_hit_blocks': ",".join(reused_blocks), - }) + if index % log_every_t == 0 or index == total_steps - 1: + intermediates['x_inter'].append(img) + intermediates['pred_x0'].append(pred_x0) + intermediates['x_inter_action'].append(action) + intermediates['x_inter_state'].append(state) + finally: + disable_cross_attn_kv_cache(self.model) + disable_ctx_cache(self.model) + if handoff_step > 0: + intermediates['handoff'] = handoff return img, action, state, intermediates @torch.no_grad() @@ -481,6 +336,7 @@ class DDIMSampler(object): c, t, index, + precision=None, repeat_noise=False, use_original_steps=False, quantize_denoised=False, @@ -495,62 +351,35 @@ class DDIMSampler(object): mask=None, x0=None, guidance_rescale=0.0, - run_head=True, **kwargs): b, *_, device = *x.shape, x.device - if x.dim() == 5: - is_video = True - else: - is_video = False - if unconditional_conditioning is None or unconditional_guidance_scale == 1.: - model_output, model_output_action, model_output_state = self.model.apply_model( - x, - x_action, - x_state, - t, - c, - run_head=run_head, - backbone_reuse_branch="single", - **kwargs) # unet denoiser - else: - # do_classifier_free_guidance - if isinstance(c, torch.Tensor) or isinstance(c, dict): - e_t_cond, e_t_cond_action, e_t_cond_state = self.model.apply_model( - x, - x_action, - x_state, - t, - c, - run_head=run_head, - backbone_reuse_branch="cond", - **kwargs) - e_t_uncond, e_t_uncond_action, e_t_uncond_state = self.model.apply_model( - x, - x_action, - x_state, - t, - unconditional_conditioning, - run_head=run_head, - backbone_reuse_branch="uncond", - **kwargs) + use_autocast = precision == 16 and device.type == 'cuda' + with torch.cuda.amp.autocast(enabled=use_autocast, + dtype=torch.float16): + if unconditional_conditioning is None or unconditional_guidance_scale == 1.: + model_output, model_output_action, model_output_state = self.model.apply_model( + x, x_action, x_state, t, c, **kwargs) # unet denoiser else: - raise NotImplementedError - model_output = e_t_uncond + unconditional_guidance_scale * ( - e_t_cond - e_t_uncond) - if run_head: + # do_classifier_free_guidance + if isinstance(c, torch.Tensor) or isinstance(c, dict): + e_t_cond, e_t_cond_action, e_t_cond_state = self.model.apply_model( + x, x_action, x_state, t, c, **kwargs) + e_t_uncond, e_t_uncond_action, e_t_uncond_state = self.model.apply_model( + x, x_action, x_state, t, unconditional_conditioning, + **kwargs) + else: + raise NotImplementedError + model_output = e_t_uncond + unconditional_guidance_scale * ( + e_t_cond - e_t_uncond) model_output_action = e_t_uncond_action + unconditional_guidance_scale * ( e_t_cond_action - e_t_uncond_action) model_output_state = e_t_uncond_state + unconditional_guidance_scale * ( e_t_cond_state - e_t_uncond_state) - else: - model_output_action = None - model_output_state = None - if guidance_rescale > 0.0: - model_output = rescale_noise_cfg( - model_output, e_t_cond, guidance_rescale=guidance_rescale) - if run_head: + if guidance_rescale > 0.0: + model_output = rescale_noise_cfg( + model_output, e_t_cond, guidance_rescale=guidance_rescale) model_output_action = rescale_noise_cfg( model_output_action, e_t_cond_action, @@ -575,17 +404,11 @@ class DDIMSampler(object): sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas sigmas = self.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas - if is_video: - size = (b, 1, 1, 1, 1) - else: - size = (b, 1, 1, 1) - - a_t = torch.full(size, alphas[index], device=device) - a_prev = torch.full(size, alphas_prev[index], device=device) - sigma_t = torch.full(size, sigmas[index], device=device) - sqrt_one_minus_at = torch.full(size, - sqrt_one_minus_alphas[index], - device=device) + # Use 0-d tensors directly (already on device); broadcasting handles shape + a_t = alphas[index] + a_prev = alphas_prev[index] + sigma_t = sigmas[index] + sqrt_one_minus_at = sqrt_one_minus_alphas[index] if self.model.parameterization != "v": pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt() @@ -593,12 +416,8 @@ class DDIMSampler(object): pred_x0 = self.model.predict_start_from_z_and_v(x, t, model_output) if self.model.use_dynamic_rescale: - scale_t = torch.full(size, - self.ddim_scale_arr[index], - device=device) - prev_scale_t = torch.full(size, - self.ddim_scale_arr_prev[index], - device=device) + scale_t = self.ddim_scale_arr[index] + prev_scale_t = self.ddim_scale_arr_prev[index] rescale = (prev_scale_t / scale_t) pred_x0 *= rescale diff --git a/src/unifolm_wma/modules/networks/wma_model.py b/src/unifolm_wma/modules/networks/wma_model.py index 79be322..ab91932 100644 --- a/src/unifolm_wma/modules/networks/wma_model.py +++ b/src/unifolm_wma/modules/networks/wma_model.py @@ -1,7 +1,6 @@ import torch import torch.nn as nn import torch.nn.functional as F -import time from torch import Tensor from functools import partial @@ -686,6 +685,37 @@ class WMAModel(nn.Module): self.action_token_projector = instantiate_from_config( stem_process_config) + # Context precomputation cache + self._ctx_cache_enabled = False + self._ctx_cache = {} + self._trt_backbone = None # TRT engine for video UNet backbone + # Reusable CUDA stream for parallel state_unet / action_unet + self._state_stream = torch.cuda.Stream() + + def __getstate__(self): + state = self.__dict__.copy() + state.pop('_state_stream', None) + return state + + def __setstate__(self, state): + self.__dict__.update(state) + if not hasattr(self, '_ctx_cache_enabled'): + self._ctx_cache_enabled = False + if not hasattr(self, '_ctx_cache'): + self._ctx_cache = {} + if not hasattr(self, '_trt_backbone'): + self._trt_backbone = None + self._state_stream = torch.cuda.Stream() + + def load_trt_backbone(self, engine_path, n_hs_a=9): + """Load a TensorRT engine for the video UNet backbone.""" + from unifolm_wma.trt_utils import TRTBackbone + device = next(self.parameters()).device + self._trt_backbone = TRTBackbone(engine_path, + n_hs_a=n_hs_a, + device=device) + print(f">>> TRT backbone loaded from {engine_path} on {device}") + def forward(self, x: Tensor, x_action: Tensor, @@ -714,80 +744,70 @@ class WMAModel(nn.Module): Tuple of Tensors for predictions: """ - b, _, t, _, _ = x.shape - run_head = kwargs.pop("run_head", True) - backbone_block_profiler = kwargs.pop("backbone_block_profiler", None) - backbone_step_index = kwargs.pop("backbone_step_index", None) - backbone_reuse_blocks = kwargs.pop("backbone_reuse_blocks", None) - backbone_reuse_start_step = kwargs.pop("backbone_reuse_start_step", - None) - backbone_reuse_schedule_steps = kwargs.pop( - "backbone_reuse_schedule_steps", None) - backbone_reuse_force_compute_steps = kwargs.pop( - "backbone_reuse_force_compute_steps", None) - backbone_reuse_mode = kwargs.pop("backbone_reuse_mode", "disabled") - backbone_reuse_cache = kwargs.pop("backbone_reuse_cache", None) - backbone_reuse_step_stats = kwargs.pop("backbone_reuse_step_stats", - None) - backbone_reuse_branch = kwargs.pop("backbone_reuse_branch", "single") t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False).type(x.dtype) emb = self.time_embed(t_emb) - bt, l_context, _ = context.shape - if self.base_model_gen_only: - assert l_context == 77 + self.n_obs_steps * 16, ">>> ERROR Context dim 1 ..." ## NOTE HANDCODE + _ctx_key = context.data_ptr() + if self._ctx_cache_enabled and _ctx_key in self._ctx_cache: + context = self._ctx_cache[_ctx_key] else: - if l_context == self.n_obs_steps + 77 + t * 16: - context_agent_state = context[:, :self.n_obs_steps] - context_text = context[:, self.n_obs_steps:self.n_obs_steps + - 77, :] - context_img = context[:, self.n_obs_steps + 77:, :] - context_agent_state = context_agent_state.repeat_interleave( - repeats=t, dim=0) - context_text = context_text.repeat_interleave(repeats=t, dim=0) - context_img = rearrange(context_img, - 'b (t l) c -> (b t) l c', - t=t) - context = torch.cat( - [context_agent_state, context_text, context_img], dim=1) - elif l_context == self.n_obs_steps + 16 + 77 + t * 16: - context_agent_state = context[:, :self.n_obs_steps] - context_agent_action = context[:, self. - n_obs_steps:self.n_obs_steps + - 16, :] - context_agent_action = rearrange( - context_agent_action.unsqueeze(2), 'b t l d -> (b t) l d') - context_agent_action = self.action_token_projector( - context_agent_action) - context_agent_action = rearrange(context_agent_action, - '(b o) l d -> b o l d', - o=t) - context_agent_action = rearrange(context_agent_action, - 'b o (t l) d -> b o t l d', - t=t) - context_agent_action = context_agent_action.permute( - 0, 2, 1, 3, 4) - context_agent_action = rearrange(context_agent_action, - 'b t o l d -> (b t) (o l) d') + bt, l_context, _ = context.shape + if self.base_model_gen_only: + assert l_context == 77 + self.n_obs_steps * 16, ">>> ERROR Context dim 1 ..." ## NOTE HANDCODE + else: + if l_context == self.n_obs_steps + 77 + t * 16: + context_agent_state = context[:, :self.n_obs_steps] + context_text = context[:, self.n_obs_steps:self.n_obs_steps + + 77, :] + context_img = context[:, self.n_obs_steps + 77:, :] + context_agent_state = context_agent_state.repeat_interleave( + repeats=t, dim=0) + context_text = context_text.repeat_interleave(repeats=t, dim=0) + context_img = rearrange(context_img, + 'b (t l) c -> (b t) l c', + t=t) + context = torch.cat( + [context_agent_state, context_text, context_img], dim=1) + elif l_context == self.n_obs_steps + 16 + 77 + t * 16: + context_agent_state = context[:, :self.n_obs_steps] + context_agent_action = context[:, self. + n_obs_steps:self.n_obs_steps + + 16, :] + context_agent_action = rearrange( + context_agent_action.unsqueeze(2), 'b t l d -> (b t) l d') + context_agent_action = self.action_token_projector( + context_agent_action) + context_agent_action = rearrange(context_agent_action, + '(b o) l d -> b o l d', + o=t) + context_agent_action = rearrange(context_agent_action, + 'b o (t l) d -> b o t l d', + t=t) + context_agent_action = context_agent_action.permute( + 0, 2, 1, 3, 4) + context_agent_action = rearrange(context_agent_action, + 'b t o l d -> (b t) (o l) d') - context_text = context[:, self.n_obs_steps + - 16:self.n_obs_steps + 16 + 77, :] - context_text = context_text.repeat_interleave(repeats=t, dim=0) + context_text = context[:, self.n_obs_steps + + 16:self.n_obs_steps + 16 + 77, :] + context_text = context_text.repeat_interleave(repeats=t, dim=0) - context_img = context[:, self.n_obs_steps + 16 + 77:, :] - context_img = rearrange(context_img, - 'b (t l) c -> (b t) l c', - t=t) - context_agent_state = context_agent_state.repeat_interleave( - repeats=t, dim=0) - context = torch.cat([ - context_agent_state, context_agent_action, context_text, - context_img - ], - dim=1) + context_img = context[:, self.n_obs_steps + 16 + 77:, :] + context_img = rearrange(context_img, + 'b (t l) c -> (b t) l c', + t=t) + context_agent_state = context_agent_state.repeat_interleave( + repeats=t, dim=0) + context = torch.cat([ + context_agent_state, context_agent_action, context_text, + context_img + ], + dim=1) + if self._ctx_cache_enabled: + self._ctx_cache[_ctx_key] = context emb = emb.repeat_interleave(repeats=t, dim=0) @@ -807,150 +827,95 @@ class WMAModel(nn.Module): fs_embed = fs_embed.repeat_interleave(repeats=t, dim=0) emb = emb + fs_embed - def run_block_with_profile(block_name: str, block_stage: str, - block_index: int | None, - fn: Callable[[], Tensor]) -> Tensor: - if backbone_block_profiler is None or backbone_step_index is None: - return fn() - if x.device.type == "cuda": - torch.cuda.synchronize(x.device) - start_time = time.perf_counter() - out = fn() - if x.device.type == "cuda": - torch.cuda.synchronize(x.device) - backbone_block_profiler.record_block( - step=int(backbone_step_index), - block_name=block_name, - block_stage=block_stage, - block_index=block_index, - output=out, - forward_time_ms=(time.perf_counter() - start_time) * 1000.0, - ) - return out - - reuse_cache_branch: Dict[str, Tensor] | None = None - if backbone_reuse_cache is not None: - reuse_cache_branch = backbone_reuse_cache.setdefault( - backbone_reuse_branch, {}) - - def should_reuse_output_block(block_name: str) -> bool: - if backbone_reuse_mode != "reuse_output": - return False - if backbone_step_index is None or backbone_reuse_start_step is None: - return False - if backbone_reuse_blocks is None or block_name not in backbone_reuse_blocks: - return False - if int(backbone_step_index) < int(backbone_reuse_start_step): - return False - if (backbone_reuse_force_compute_steps is not None - and int(backbone_step_index) - in backbone_reuse_force_compute_steps): - return False - if (backbone_reuse_schedule_steps is not None - and int(backbone_step_index) - in backbone_reuse_schedule_steps): - return False - if reuse_cache_branch is None: - return False - return block_name in reuse_cache_branch - - h = x.type(self.dtype) - adapter_idx = 0 - hs = [] - hs_a = [] - for id, module in enumerate(self.input_blocks): - def run_input_block() -> Tensor: - block_out = module(h, emb, context=context, batch_size=b) + if self._trt_backbone is not None: + # TRT path: run backbone via TensorRT engine + h_in = x.type(self.dtype).contiguous() + y, hs_a = self._trt_backbone(h_in, emb.contiguous(), context.contiguous()) + else: + # PyTorch path: original backbone + h = x.type(self.dtype) + adapter_idx = 0 + hs = [] + hs_a = [] + for id, module in enumerate(self.input_blocks): + h = module(h, emb, context=context, batch_size=b) if id == 0 and self.addition_attention: - block_out = self.init_attn(block_out, - emb, - context=context, - batch_size=b) - return block_out + h = self.init_attn(h, emb, context=context, batch_size=b) + # plug-in adapter features + if ((id + 1) % 3 == 0) and features_adapter is not None: + h = h + features_adapter[adapter_idx] + adapter_idx += 1 + if id != 0: + if isinstance(module[0], Downsample): + hs_a.append( + rearrange(hs[-1], '(b t) c h w -> b t c h w', b=b)) + hs.append(h) + hs_a.append(rearrange(h, '(b t) c h w -> b t c h w', b=b)) - h = run_block_with_profile( - block_name=f"input_{id}", - block_stage="input_blocks", - block_index=id, - fn=run_input_block, - ) - # plug-in adapter features - if ((id + 1) % 3 == 0) and features_adapter is not None: - h = h + features_adapter[adapter_idx] - adapter_idx += 1 - if id != 0: - if isinstance(module[0], Downsample): + if features_adapter is not None: + assert len( + features_adapter) == adapter_idx, 'Wrong features_adapter' + h = self.middle_block(h, emb, context=context, batch_size=b) + hs_a.append(rearrange(h, '(b t) c h w -> b t c h w', b=b)) + + hs_out = [] + for module in self.output_blocks: + h = torch.cat([h, hs.pop()], dim=1) + h = module(h, emb, context=context, batch_size=b) + if isinstance(module[-1], Upsample): hs_a.append( - rearrange(hs[-1], '(b t) c h w -> b t c h w', t=t)) - hs.append(h) - hs_a.append(rearrange(h, '(b t) c h w -> b t c h w', t=t)) + rearrange(hs_out[-1], '(b t) c h w -> b t c h w', b=b)) + hs_out.append(h) + h = h.type(x.dtype) + hs_a.append(rearrange(hs_out[-1], '(b t) c h w -> b t c h w', b=b)) - if features_adapter is not None: - assert len( - features_adapter) == adapter_idx, 'Wrong features_adapter' - h = run_block_with_profile( - block_name="middle", - block_stage="middle_block", - block_index=0, - fn=lambda: self.middle_block(h, emb, context=context, batch_size=b), - ) - hs_a.append(rearrange(h, '(b t) c h w -> b t c h w', t=t)) + y = self.out(h) + y = rearrange(y, '(b t) c h w -> b c t h w', b=b) - hs_out = [] - for id, module in enumerate(self.output_blocks): - skip_h = hs.pop() - block_name = f"output_{id}" - - def run_output_block() -> Tensor: - return module(torch.cat([h, skip_h], dim=1), - emb, - context=context, - batch_size=b) - - if should_reuse_output_block(block_name): - h = reuse_cache_branch[block_name].to(device=h.device, - dtype=h.dtype) - if backbone_reuse_step_stats is not None: - backbone_reuse_step_stats.setdefault( - backbone_reuse_branch, set()).add(block_name) - else: - h = run_block_with_profile( - block_name=block_name, - block_stage="output_blocks", - block_index=id, - fn=run_output_block, - ) - if (reuse_cache_branch is not None and backbone_reuse_mode == - "reuse_output" - and backbone_reuse_blocks is not None - and block_name in backbone_reuse_blocks): - reuse_cache_branch[block_name] = h.detach().clone() - if isinstance(module[-1], Upsample): - hs_a.append( - rearrange(hs_out[-1], '(b t) c h w -> b t c h w', t=t)) - hs_out.append(h) - h = h.type(x.dtype) - hs_a.append(rearrange(hs_out[-1], '(b t) c h w -> b t c h w', t=t)) - - y = self.out(h) - y = rearrange(y, '(b t) c h w -> b c t h w', b=b) - - if not self.base_model_gen_only and run_head: + if not self.base_model_gen_only: ba, _, _ = x_action.shape - a_y = self.action_unet(x_action, timesteps[:ba], hs_a, - context_action[:2], **kwargs) - # Predict state - if b > 1: - s_y = self.state_unet(x_state, timesteps[:ba], hs_a, + ts_state = timesteps[:ba] if b > 1 else timesteps + is_sim_mode = context_action[2] if len(context_action) > 2 else False + + if is_sim_mode: + # WM mode: only need state_unet, skip action_unet + s_y = self.state_unet(x_state, ts_state, hs_a, context_action[:2], **kwargs) + a_y = torch.zeros_like(x_action) else: - s_y = self.state_unet(x_state, timesteps, hs_a, - context_action[:2], **kwargs) - elif not self.base_model_gen_only: - a_y = None - s_y = None + # DM mode: only need action_unet, skip state_unet + a_y = self.action_unet(x_action, timesteps[:ba], hs_a, + context_action[:2], **kwargs) + s_y = torch.zeros_like(x_state) else: a_y = torch.zeros_like(x_action) s_y = torch.zeros_like(x_state) return y, a_y, s_y + + +def enable_ctx_cache(model): + """Enable context precomputation cache on WMAModel and its action/state UNets.""" + for m in model.modules(): + if isinstance(m, WMAModel): + m._ctx_cache_enabled = True + m._ctx_cache = {} + # conditional_unet1d cache + from unifolm_wma.models.diffusion_head.conditional_unet1d import ConditionalUnet1D + for m in model.modules(): + if isinstance(m, ConditionalUnet1D): + m._global_cond_cache_enabled = True + m._global_cond_cache = {} + + +def disable_ctx_cache(model): + """Disable and clear context precomputation cache.""" + for m in model.modules(): + if isinstance(m, WMAModel): + m._ctx_cache_enabled = False + m._ctx_cache = {} + from unifolm_wma.models.diffusion_head.conditional_unet1d import ConditionalUnet1D + for m in model.modules(): + if isinstance(m, ConditionalUnet1D): + m._global_cond_cache_enabled = False + m._global_cond_cache = {} diff --git a/src/unifolm_wma/trt_utils.py b/src/unifolm_wma/trt_utils.py new file mode 100644 index 0000000..6ce6939 --- /dev/null +++ b/src/unifolm_wma/trt_utils.py @@ -0,0 +1,196 @@ +"""TensorRT acceleration utilities for the video UNet backbone.""" + +import torch +import torch.nn as nn +from einops import rearrange +from unifolm_wma.modules.networks.wma_model import Downsample, Upsample + + +def _normalize_cuda_device(device) -> torch.device: + if device is None: + return torch.device('cuda', torch.cuda.current_device()) + if isinstance(device, torch.device): + if device.type != 'cuda': + raise ValueError(f"TensorRT requires a CUDA device, got {device}.") + return device + if isinstance(device, int): + return torch.device('cuda', device) + normalized = torch.device(device) + if normalized.type != 'cuda': + raise ValueError(f"TensorRT requires a CUDA device, got {normalized}.") + return normalized + + +class VideoBackboneForExport(nn.Module): + """Wrapper that isolates the video UNet backbone for ONNX export. + + Takes already-preprocessed inputs (after context/time embedding prep) + and returns y + hs_a as a flat tuple. + """ + + def __init__(self, wma_model): + super().__init__() + self.input_blocks = wma_model.input_blocks + self.middle_block = wma_model.middle_block + self.output_blocks = wma_model.output_blocks + self.out = wma_model.out + self.addition_attention = wma_model.addition_attention + if self.addition_attention: + self.init_attn = wma_model.init_attn + self.dtype = wma_model.dtype + + def forward(self, h, emb, context): + t = 16 + b = 1 + + hs = [] + hs_a = [] + h = h.type(self.dtype) + for id, module in enumerate(self.input_blocks): + h = module(h, emb, context=context, batch_size=b) + if id == 0 and self.addition_attention: + h = self.init_attn(h, emb, context=context, batch_size=b) + if id != 0: + if isinstance(module[0], Downsample): + hs_a.append(rearrange(hs[-1], '(b t) c h w -> b t c h w', t=t)) + hs.append(h) + hs_a.append(rearrange(h, '(b t) c h w -> b t c h w', t=t)) + + h = self.middle_block(h, emb, context=context, batch_size=b) + hs_a.append(rearrange(h, '(b t) c h w -> b t c h w', t=t)) + + hs_out = [] + for module in self.output_blocks: + h = torch.cat([h, hs.pop()], dim=1) + h = module(h, emb, context=context, batch_size=b) + if isinstance(module[-1], Upsample): + hs_a.append(rearrange(hs_out[-1], '(b t) c h w -> b t c h w', t=t)) + hs_out.append(h) + hs_a.append(rearrange(hs_out[-1], '(b t) c h w -> b t c h w', t=t)) + + y = self.out(h.type(h.dtype)) + y = rearrange(y, '(b t) c h w -> b c t h w', b=b) + return (y, *hs_a) + + +def export_backbone_onnx(model, save_path, context_len=95): + wma = model.model.diffusion_model + wrapper = VideoBackboneForExport(wma) + device = next(wma.parameters()).device + wrapper.eval().to(device) + + for m in wrapper.modules(): + if hasattr(m, 'checkpoint'): + m.checkpoint = False + if hasattr(m, 'use_checkpoint'): + m.use_checkpoint = False + + import xformers.ops + _orig_mea = xformers.ops.memory_efficient_attention + def _sdpa_replacement(q, k, v, attn_bias=None, op=None, **kw): + return torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_bias) + xformers.ops.memory_efficient_attention = _sdpa_replacement + + BT = 16 + emb_dim = wma.model_channels * 4 + ctx_dim = 1024 + in_ch = wma.in_channels + + dummy_h = torch.randn(BT, + in_ch, + 40, + 64, + device=device, + dtype=torch.float32) + dummy_emb = torch.randn(BT, + emb_dim, + device=device, + dtype=torch.float32) + dummy_ctx = torch.randn(BT, + context_len, + ctx_dim, + device=device, + dtype=torch.float32) + + with torch.no_grad(): + outputs = wrapper(dummy_h, dummy_emb, dummy_ctx) + n_outputs = len(outputs) + print(f">>> Backbone has {n_outputs} outputs (1 y + {n_outputs-1} hs_a)") + for i, o in enumerate(outputs): + print(f" output[{i}]: {o.shape} {o.dtype}") + + output_names = ['y'] + [f'hs_a_{i}' for i in range(n_outputs - 1)] + + torch.onnx.export( + wrapper, + (dummy_h, dummy_emb, dummy_ctx), + save_path, + input_names=['h', 'emb', 'context'], + output_names=output_names, + opset_version=17, + do_constant_folding=True, + ) + print(f">>> ONNX exported to {save_path}") + xformers.ops.memory_efficient_attention = _orig_mea + return n_outputs + + +class TRTBackbone: + """TensorRT runtime wrapper for the video UNet backbone.""" + + def __init__(self, engine_path, n_hs_a=9, device=None): + import tensorrt as trt + + self.device = _normalize_cuda_device(device) + self.logger = trt.Logger(trt.Logger.WARNING) + with torch.cuda.device(self.device): + with open(engine_path, 'rb') as f: + runtime = trt.Runtime(self.logger) + self.engine = runtime.deserialize_cuda_engine(f.read()) + self.context = self.engine.create_execution_context() + self.n_hs_a = n_hs_a + + import numpy as np + self.output_buffers = {} + with torch.cuda.device(self.device): + for i in range(self.engine.num_io_tensors): + name = self.engine.get_tensor_name(i) + if self.engine.get_tensor_mode(name) == trt.TensorIOMode.OUTPUT: + shape = self.engine.get_tensor_shape(name) + np_dtype = trt.nptype(self.engine.get_tensor_dtype(name)) + buf = torch.empty( + list(shape), + dtype=torch.from_numpy(np.empty(0, dtype=np_dtype)).dtype, + device=self.device) + self.output_buffers[name] = buf + print( + f" TRT output '{name}': {list(shape)} {buf.dtype} on {self.device}" + ) + + def __call__(self, h, emb, context): + import numpy as np + import tensorrt as trt + + bound_inputs = {} + with torch.cuda.device(self.device): + for name, tensor in [('h', h), ('emb', emb), ('context', context)]: + expected_dtype = trt.nptype(self.engine.get_tensor_dtype(name)) + torch_expected = torch.from_numpy( + np.empty(0, dtype=expected_dtype)).dtype + if tensor.device != self.device or tensor.dtype != torch_expected: + tensor = tensor.to(device=self.device, + dtype=torch_expected, + non_blocking=True) + tensor = tensor.contiguous() + bound_inputs[name] = tensor + self.context.set_tensor_address(name, tensor.data_ptr()) + + for name, buf in self.output_buffers.items(): + self.context.set_tensor_address(name, buf.data_ptr()) + + stream = torch.cuda.current_stream(device=self.device) + self.context.execute_async_v3(stream.cuda_stream) + + y = self.output_buffers['y'] + hs_a = [self.output_buffers[f'hs_a_{i}'] for i in range(self.n_hs_a)] + return y, hs_a diff --git a/unitree_z1_dual_arm_stackbox_v2/case2/run_world_model_interaction.sh b/unitree_z1_dual_arm_stackbox_v2/case2/run_world_model_interaction.sh index 44d2449..725958a 100644 --- a/unitree_z1_dual_arm_stackbox_v2/case2/run_world_model_interaction.sh +++ b/unitree_z1_dual_arm_stackbox_v2/case2/run_world_model_interaction.sh @@ -2,11 +2,11 @@ res_dir="unitree_z1_dual_arm_stackbox_v2/case2" dataset="unitree_z1_dual_arm_stackbox_v2" { - time CUDA_VISIBLE_DEVICES=0 "${PYTHON_BIN:-python}" scripts/evaluation/world_model_interaction.py \ + time CUDA_VISIBLE_DEVICES=0,1 python3 scripts/evaluation/world_model_interaction.py \ --seed 123 \ --ckpt_path ckpts/unifolm_wma_dual.ckpt \ --config configs/inference/world_model_interaction.yaml \ - --savedir "${res_dir}/output/sparse_8" \ + --savedir "${res_dir}/output" \ --bs 1 --height 320 --width 512 \ --unconditional_guidance_scale 1.0 \ --ddim_steps 50 \ @@ -21,9 +21,7 @@ dataset="unitree_z1_dual_arm_stackbox_v2" --timestep_spacing 'uniform_trailing' \ --guidance_rescale 0.7 \ --perframe_ae \ - --analysis_log_metrics \ - --analysis_reference_steps 50 \ - --head_schedule_steps 0 7 14 21 28 35 42 49 \ - --head_skip_mode reuse_prediction \ - --head_log_steps 40 43 46 47 48 49 + --pipeline_split_step 30 \ + --pipeline_multi_gpu \ + --pipeline_gpu_ids 0 1 } 2>&1 | tee "${res_dir}/output.log"