1756 lines
71 KiB
Python
1756 lines
71 KiB
Python
import argparse, os, glob
|
|
import pandas as pd
|
|
import random
|
|
import torch
|
|
import torchvision
|
|
import h5py
|
|
import numpy as np
|
|
import logging
|
|
import einops
|
|
import warnings
|
|
import imageio
|
|
import time
|
|
import json
|
|
import atexit
|
|
from concurrent.futures import ThreadPoolExecutor
|
|
from contextlib import contextmanager, nullcontext
|
|
from dataclasses import dataclass, field, asdict
|
|
from typing import Optional, Dict, List, Any, Mapping
|
|
|
|
from pytorch_lightning import seed_everything
|
|
from omegaconf import OmegaConf
|
|
from tqdm import tqdm
|
|
from einops import rearrange, repeat
|
|
from collections import OrderedDict
|
|
from torch import nn
|
|
from eval_utils import populate_queues
|
|
from collections import deque
|
|
from torch import Tensor
|
|
from torch.utils.tensorboard import SummaryWriter
|
|
from PIL import Image
|
|
|
|
from unifolm_wma.models.samplers.ddim import DDIMSampler
|
|
from unifolm_wma.utils.utils import instantiate_from_config
|
|
|
|
|
|
# ========== Profiling Infrastructure ==========
|
|
@dataclass
class TimingRecord:
    """A single named timing measurement, optionally holding nested child records."""
    name: str
    start_time: float = 0.0
    end_time: float = 0.0
    cuda_time_ms: float = 0.0
    count: int = 0
    children: List['TimingRecord'] = field(default_factory=list)

    @property
    def cpu_time_ms(self) -> float:
        """Wall-clock duration between start and end, in milliseconds."""
        elapsed_s = self.end_time - self.start_time
        return elapsed_s * 1000

    def to_dict(self) -> dict:
        """Serialize this record and its children to a plain dict (JSON-friendly)."""
        payload = {
            'name': self.name,
            'cpu_time_ms': self.cpu_time_ms,
            'cuda_time_ms': self.cuda_time_ms,
            'count': self.count,
        }
        payload['children'] = [child.to_dict() for child in self.children]
        return payload
|
|
|
|
|
|
class ProfilerManager:
    """Manages macro and micro-level profiling.

    Collects three kinds of data:
      * macro timings: per-section CPU wall-time and CUDA-event durations,
      * GPU memory snapshots taken at explicit call sites,
      * per-operator statistics aggregated from torch.profiler traces.

    When ``enabled`` is False every public method is a cheap no-op.
    """

    def __init__(
        self,
        enabled: bool = False,
        output_dir: str = "./profile_output",
        profile_detail: str = "light",
    ):
        # Master switch for all profiling work.
        self.enabled = enabled
        self.output_dir = output_dir
        self.profile_detail = profile_detail
        # Section name -> list of CPU wall-times in ms (one entry per call).
        self.macro_timings: Dict[str, List[float]] = {}
        # Section name -> list of (cpu_time_ms, cuda_time_ms) tuples.
        self.cuda_events: Dict[str, List[tuple]] = {}
        self.memory_snapshots: List[Dict] = []
        self.pytorch_profiler = None
        # Tagged into memory snapshots and trace filenames; callers are
        # expected to advance it externally.
        self.current_iteration = 0
        # Operator name -> accumulated counters from profiler traces.
        self.operator_stats: Dict[str, Dict] = {}
        self.profiler_config = self._build_profiler_config(profile_detail)

        if enabled:
            os.makedirs(output_dir, exist_ok=True)

    def _build_profiler_config(self, profile_detail: str) -> Dict[str, Any]:
        """Return profiler settings based on the requested detail level.

        Raises:
            ValueError: if *profile_detail* is not "light" or "full".
        """
        if profile_detail not in ("light", "full"):
            raise ValueError(f"Unsupported profile_detail: {profile_detail}")
        if profile_detail == "full":
            # Full detail: richer traces (shapes, stacks, FLOPs) at a
            # higher runtime and memory cost.
            return {
                "record_shapes": True,
                "profile_memory": True,
                "with_stack": True,
                "with_flops": True,
                "with_modules": True,
                "group_by_input_shape": True,
            }
        return {
            "record_shapes": False,
            "profile_memory": False,
            "with_stack": False,
            "with_flops": False,
            "with_modules": False,
            "group_by_input_shape": False,
        }

    @contextmanager
    def profile_section(self, name: str, sync_cuda: bool = True):
        """Context manager for profiling a code section.

        Args:
            name: Key under which the section's timings are accumulated.
            sync_cuda: If True, synchronize CUDA before and after the
                section so CPU wall-time covers queued GPU work.
        """
        if not self.enabled:
            yield
            return

        if sync_cuda and torch.cuda.is_available():
            torch.cuda.synchronize()

        start_event = None
        end_event = None
        if torch.cuda.is_available():
            start_event = torch.cuda.Event(enable_timing=True)
            end_event = torch.cuda.Event(enable_timing=True)
            start_event.record()

        start_time = time.perf_counter()

        try:
            yield
        finally:
            if sync_cuda and torch.cuda.is_available():
                torch.cuda.synchronize()

            end_time = time.perf_counter()
            cpu_time_ms = (end_time - start_time) * 1000

            cuda_time_ms = 0.0
            if start_event is not None and end_event is not None:
                # NOTE(review): end_event is recorded after the optional
                # synchronize above, so when sync_cuda is True the sync
                # time is included in the CUDA measurement.
                end_event.record()
                torch.cuda.synchronize()
                cuda_time_ms = start_event.elapsed_time(end_event)

            if name not in self.macro_timings:
                self.macro_timings[name] = []
            self.macro_timings[name].append(cpu_time_ms)

            if name not in self.cuda_events:
                self.cuda_events[name] = []
            self.cuda_events[name].append((cpu_time_ms, cuda_time_ms))

    def record_memory(self, tag: str = ""):
        """Record current GPU memory state (no-op without CUDA)."""
        if not self.enabled or not torch.cuda.is_available():
            return

        snapshot = {
            'tag': tag,
            'iteration': self.current_iteration,
            'allocated_mb': torch.cuda.memory_allocated() / 1024**2,
            'reserved_mb': torch.cuda.memory_reserved() / 1024**2,
            'max_allocated_mb': torch.cuda.max_memory_allocated() / 1024**2,
        }
        self.memory_snapshots.append(snapshot)

    def start_pytorch_profiler(self, wait: int = 1, warmup: int = 1, active: int = 3):
        """Start PyTorch profiler for operator-level analysis.

        Returns a context manager: ``nullcontext`` when disabled, otherwise
        a ``torch.profiler.profile`` whose schedule skips *wait* steps,
        warms up for *warmup*, then records *active* steps once.
        """
        if not self.enabled:
            return nullcontext()

        self.pytorch_profiler = torch.profiler.profile(
            activities=[
                torch.profiler.ProfilerActivity.CPU,
                torch.profiler.ProfilerActivity.CUDA,
            ],
            schedule=torch.profiler.schedule(
                wait=wait, warmup=warmup, active=active, repeat=1
            ),
            on_trace_ready=self._trace_handler,
            record_shapes=self.profiler_config["record_shapes"],
            profile_memory=self.profiler_config["profile_memory"],
            with_stack=self.profiler_config["with_stack"],
            with_flops=self.profiler_config["with_flops"],
            with_modules=self.profiler_config["with_modules"],
        )
        return self.pytorch_profiler

    def _trace_handler(self, prof):
        """Handle profiler trace output.

        Exports a Chrome trace and folds per-operator averages into
        ``self.operator_stats``.
        """
        trace_path = os.path.join(
            self.output_dir,
            f"trace_iter_{self.current_iteration}.json"
        )
        prof.export_chrome_trace(trace_path)

        # Extract operator statistics
        key_averages = prof.key_averages(
            group_by_input_shape=self.profiler_config["group_by_input_shape"]
        )
        for evt in key_averages:
            op_name = evt.key
            if op_name not in self.operator_stats:
                self.operator_stats[op_name] = {
                    'count': 0,
                    'cpu_time_total_us': 0,
                    'cuda_time_total_us': 0,
                    'self_cpu_time_total_us': 0,
                    'self_cuda_time_total_us': 0,
                    'cpu_memory_usage': 0,
                    'cuda_memory_usage': 0,
                    'flops': 0,
                }
            stats = self.operator_stats[op_name]
            stats['count'] += evt.count
            stats['cpu_time_total_us'] += evt.cpu_time_total
            stats['cuda_time_total_us'] += evt.cuda_time_total
            stats['self_cpu_time_total_us'] += evt.self_cpu_time_total
            stats['self_cuda_time_total_us'] += evt.self_cuda_time_total
            # Memory/FLOP fields are optional across torch versions; guard
            # with hasattr so older profilers still work.
            if hasattr(evt, 'cpu_memory_usage'):
                stats['cpu_memory_usage'] += evt.cpu_memory_usage
            if hasattr(evt, 'cuda_memory_usage'):
                stats['cuda_memory_usage'] += evt.cuda_memory_usage
            if hasattr(evt, 'flops') and evt.flops:
                stats['flops'] += evt.flops

    def step_profiler(self):
        """Step the PyTorch profiler (advances its wait/warmup/active schedule)."""
        if self.pytorch_profiler is not None:
            self.pytorch_profiler.step()

    def generate_report(self) -> str:
        """Generate comprehensive profiling report as a single string."""
        if not self.enabled:
            return "Profiling disabled."

        report_lines = []
        report_lines.append("=" * 80)
        report_lines.append("PERFORMANCE PROFILING REPORT")
        report_lines.append("=" * 80)
        report_lines.append("")

        # Macro-level timing summary
        report_lines.append("-" * 40)
        report_lines.append("MACRO-LEVEL TIMING SUMMARY")
        report_lines.append("-" * 40)
        report_lines.append(f"{'Section':<40} {'Count':>8} {'Total(ms)':>12} {'Avg(ms)':>12} {'CUDA Avg(ms)':>14}")
        report_lines.append("-" * 86)

        total_time = 0
        # NOTE(review): timing_data is accumulated but never read after the
        # loop; kept as-is to preserve behavior.
        timing_data = []
        for name, times in sorted(self.macro_timings.items()):
            cuda_times = [ct for _, ct in self.cuda_events.get(name, [])]
            avg_time = np.mean(times)
            avg_cuda = np.mean(cuda_times) if cuda_times else 0
            total = sum(times)
            total_time += total
            timing_data.append({
                'name': name,
                'count': len(times),
                'total_ms': total,
                'avg_ms': avg_time,
                'cuda_avg_ms': avg_cuda,
                'times': times,
                'cuda_times': cuda_times,
            })
            report_lines.append(f"{name:<40} {len(times):>8} {total:>12.2f} {avg_time:>12.2f} {avg_cuda:>14.2f}")

        report_lines.append("-" * 86)
        report_lines.append(f"{'TOTAL':<40} {'':<8} {total_time:>12.2f}")
        report_lines.append("")

        # Memory summary
        if self.memory_snapshots:
            report_lines.append("-" * 40)
            report_lines.append("GPU MEMORY SUMMARY")
            report_lines.append("-" * 40)
            max_alloc = max(s['max_allocated_mb'] for s in self.memory_snapshots)
            avg_alloc = np.mean([s['allocated_mb'] for s in self.memory_snapshots])
            report_lines.append(f"Peak allocated: {max_alloc:>10.2f} MB")
            report_lines.append(f"Average allocated: {avg_alloc:>10.2f} MB")
            report_lines.append("")

        # Top operators by CUDA time
        if self.operator_stats:
            report_lines.append("-" * 40)
            report_lines.append("TOP 30 OPERATORS BY CUDA TIME")
            report_lines.append("-" * 40)
            sorted_ops = sorted(
                self.operator_stats.items(),
                key=lambda x: x[1]['cuda_time_total_us'],
                reverse=True
            )[:30]

            report_lines.append(f"{'Operator':<50} {'Count':>8} {'CUDA(ms)':>12} {'CPU(ms)':>12} {'Self CUDA(ms)':>14}")
            report_lines.append("-" * 96)

            for op_name, stats in sorted_ops:
                # Truncate long operator names
                display_name = op_name[:47] + "..." if len(op_name) > 50 else op_name
                report_lines.append(
                    f"{display_name:<50} {stats['count']:>8} "
                    f"{stats['cuda_time_total_us']/1000:>12.2f} "
                    f"{stats['cpu_time_total_us']/1000:>12.2f} "
                    f"{stats['self_cuda_time_total_us']/1000:>14.2f}"
                )
            report_lines.append("")

            # Compute category breakdown
            report_lines.append("-" * 40)
            report_lines.append("OPERATOR CATEGORY BREAKDOWN")
            report_lines.append("-" * 40)

            # Keyword buckets: an operator is assigned to the FIRST
            # category whose keyword appears in its lowercased name.
            categories = {
                'Attention': ['attention', 'softmax', 'bmm', 'baddbmm'],
                'Convolution': ['conv', 'cudnn'],
                'Normalization': ['norm', 'layer_norm', 'batch_norm', 'group_norm'],
                'Activation': ['relu', 'gelu', 'silu', 'sigmoid', 'tanh'],
                'Linear/GEMM': ['linear', 'addmm', 'mm', 'gemm'],
                'Memory': ['copy', 'contiguous', 'view', 'reshape', 'permute', 'transpose'],
                'Elementwise': ['add', 'mul', 'div', 'sub', 'pow', 'exp', 'sqrt'],
            }

            category_times = {cat: 0.0 for cat in categories}
            category_times['Other'] = 0.0

            for op_name, stats in self.operator_stats.items():
                op_lower = op_name.lower()
                categorized = False
                for cat, keywords in categories.items():
                    if any(kw in op_lower for kw in keywords):
                        category_times[cat] += stats['cuda_time_total_us']
                        categorized = True
                        break
                if not categorized:
                    category_times['Other'] += stats['cuda_time_total_us']

            total_op_time = sum(category_times.values())
            report_lines.append(f"{'Category':<30} {'CUDA Time(ms)':>15} {'Percentage':>12}")
            report_lines.append("-" * 57)
            for cat, time_us in sorted(category_times.items(), key=lambda x: -x[1]):
                pct = (time_us / total_op_time * 100) if total_op_time > 0 else 0
                report_lines.append(f"{cat:<30} {time_us/1000:>15.2f} {pct:>11.1f}%")
            report_lines.append("")

        report = "\n".join(report_lines)
        return report

    def save_results(self):
        """Save all profiling results to files (report text + raw JSON)."""
        if not self.enabled:
            return

        # Save report
        report = self.generate_report()
        report_path = os.path.join(self.output_dir, "profiling_report.txt")
        with open(report_path, 'w') as f:
            f.write(report)
        print(f">>> Profiling report saved to: {report_path}")

        # Save detailed JSON data
        data = {
            'macro_timings': {
                name: {
                    'times': times,
                    'cuda_times': [ct for _, ct in self.cuda_events.get(name, [])]
                }
                for name, times in self.macro_timings.items()
            },
            'memory_snapshots': self.memory_snapshots,
            'operator_stats': self.operator_stats,
        }
        json_path = os.path.join(self.output_dir, "profiling_data.json")
        with open(json_path, 'w') as f:
            json.dump(data, f, indent=2)
        print(f">>> Detailed profiling data saved to: {json_path}")

        # Print summary to console
        print("\n" + report)
|
|
|
|
|
|
# Global profiler instance
|
|
_profiler: Optional[ProfilerManager] = None


def get_profiler() -> ProfilerManager:
    """Return the process-wide profiler, lazily creating a disabled one."""
    global _profiler
    if _profiler is None:
        # Default instance is a no-op until init_profiler() replaces it.
        _profiler = ProfilerManager(enabled=False)
    return _profiler
|
|
|
|
def init_profiler(enabled: bool, output_dir: str, profile_detail: str) -> ProfilerManager:
    """Replace the global profiler with a freshly configured instance and return it."""
    global _profiler
    _profiler = ProfilerManager(enabled=enabled,
                                output_dir=output_dir,
                                profile_detail=profile_detail)
    return _profiler
|
|
|
|
|
|
# ========== Async I/O ==========
|
|
_io_executor: Optional[ThreadPoolExecutor] = None
|
|
_io_futures: List[Any] = []
|
|
|
|
|
|
def _get_io_executor() -> ThreadPoolExecutor:
|
|
global _io_executor
|
|
if _io_executor is None:
|
|
_io_executor = ThreadPoolExecutor(max_workers=2)
|
|
return _io_executor
|
|
|
|
|
|
def _flush_io():
|
|
"""Wait for all pending async I/O to finish."""
|
|
global _io_futures
|
|
for fut in _io_futures:
|
|
try:
|
|
fut.result()
|
|
except Exception as e:
|
|
print(f">>> [async I/O] error: {e}")
|
|
_io_futures.clear()
|
|
|
|
|
|
atexit.register(_flush_io)
|
|
|
|
|
|
def _save_results_sync(video_cpu: Tensor, filename: str, fps: int) -> None:
    """Synchronous save on CPU tensor (runs in background thread)."""
    clamped = torch.clamp(video_cpu.float(), -1., 1.)
    batch = clamped.shape[0]
    # (B, C, T, H, W) -> (T, B, C, H, W): each time step becomes one grid row.
    frames_first = clamped.permute(2, 0, 1, 3, 4)
    per_frame_grids = []
    for framesheet in frames_first:
        per_frame_grids.append(
            torchvision.utils.make_grid(framesheet, nrow=int(batch), padding=0))
    grid = torch.stack(per_frame_grids, dim=0)
    # Map [-1, 1] -> [0, 255] uint8, channel-last for the video writer.
    grid = (grid + 1.0) / 2.0
    grid = (grid * 255).to(torch.uint8).permute(0, 2, 3, 1)
    torchvision.io.write_video(filename,
                               grid,
                               fps=fps,
                               video_codec='h264',
                               options={'crf': '10'})
|
|
|
|
|
|
def save_results_async(video: Tensor, filename: str, fps: int = 8) -> None:
    """Submit video saving to background thread pool."""
    # Copy to CPU on the caller's thread so the GPU tensor is released early.
    cpu_copy = video.detach().cpu()
    future = _get_io_executor().submit(_save_results_sync, cpu_copy, filename, fps)
    _io_futures.append(future)
|
|
|
|
|
|
def _log_to_tb_sync(writer, video_cpu: Tensor, tag: str, fps: int) -> None:
    """Synchronous TensorBoard log on CPU tensor (runs in background thread)."""
    if video_cpu.dim() != 5:
        # Only (B, C, T, H, W) inputs are logged; anything else is ignored.
        return
    batch = video_cpu.shape[0]
    frames_first = video_cpu.permute(2, 0, 1, 3, 4)
    grids = [
        torchvision.utils.make_grid(sheet, nrow=int(batch), padding=0)
        for sheet in frames_first
    ]
    stacked = torch.stack(grids, dim=0)
    # Map [-1, 1] -> [0, 1] and add the batch axis add_video expects.
    stacked = (stacked + 1.0) / 2.0
    writer.add_video(tag, stacked.unsqueeze(dim=0), fps=fps)
|
|
|
|
|
|
def log_to_tensorboard_async(writer, data: Tensor, tag: str, fps: int = 10) -> None:
    """Submit TensorBoard logging to background thread pool."""
    if not (isinstance(data, torch.Tensor) and data.dim() == 5):
        # Silently skip non-video payloads, matching the sync writer.
        return
    cpu_copy = data.detach().cpu()
    future = _get_io_executor().submit(_log_to_tb_sync, writer, cpu_copy, tag, fps)
    _io_futures.append(future)
|
|
|
|
|
|
# ========== Original Functions ==========
|
|
def get_device_from_parameters(module: nn.Module) -> torch.device:
    """Get a module's device by checking one of its parameters.

    Args:
        module (nn.Module): The model whose device is to be inferred.

    Returns:
        torch.device: The device of the model's parameters.
    """
    first_param = next(module.parameters())
    return first_param.device
|
|
|
|
|
|
def write_video(video_path: str, stacked_frames: list, fps: int) -> None:
    """Save a list of frames to a video file.

    Args:
        video_path (str): Output path for the video.
        stacked_frames (list): List of image frames.
        fps (int): Frames per second for the video.
    """
    with warnings.catch_warnings():
        # imageio pulls in pkg_resources, which spams a DeprecationWarning.
        warnings.filterwarnings(
            "ignore",
            "pkg_resources is deprecated as an API",
            category=DeprecationWarning)
        imageio.mimsave(video_path, stacked_frames, fps=fps)
|
|
|
|
|
|
def get_filelist(data_dir: str, postfixes: list[str]) -> list[str]:
    """Return sorted list of files in a directory matching specified postfixes.

    Args:
        data_dir (str): Directory path to search in.
        postfixes (list[str]): List of file extensions to match.

    Returns:
        list[str]: Sorted list of file paths.
    """
    matches: list[str] = []
    for postfix in postfixes:
        matches.extend(glob.glob(os.path.join(data_dir, f"*.{postfix}")))
    # glob order is filesystem-dependent; sort for determinism.
    return sorted(matches)
|
|
|
|
|
|
def _load_state_dict(model: nn.Module,
|
|
state_dict: Mapping[str, torch.Tensor],
|
|
strict: bool = True,
|
|
assign: bool = False) -> None:
|
|
if assign:
|
|
try:
|
|
model.load_state_dict(state_dict, strict=strict, assign=True)
|
|
return
|
|
except TypeError:
|
|
warnings.warn(
|
|
"load_state_dict(assign=True) not supported; "
|
|
"falling back to copy load.")
|
|
model.load_state_dict(state_dict, strict=strict)
|
|
|
|
|
|
def load_model_checkpoint(model: nn.Module,
                          ckpt: str,
                          assign: bool | None = None,
                          device: str | torch.device = "cpu") -> nn.Module:
    """Load model weights from checkpoint file.

    Handles three checkpoint layouts: a dict with "state_dict" (Lightning
    style, with a legacy key-rename fallback), a dict with "module"
    (DeepSpeed style, with a fixed key-prefix strip), or a raw state dict.

    Args:
        model (nn.Module): Model instance.
        ckpt (str): Path to the checkpoint file.
        assign (bool | None): Whether to preserve checkpoint tensor dtypes
            via load_state_dict(assign=True). If None, auto-enable when a
            casted checkpoint metadata is detected.
        device (str | torch.device): Target device for loaded tensors.

    Returns:
        nn.Module: Model with loaded weights (modified in place).
    """
    # mmap=True avoids reading the whole checkpoint into RAM up front.
    ckpt_data = torch.load(ckpt, map_location=device, mmap=True)
    use_assign = False
    if assign is not None:
        use_assign = assign
    elif isinstance(ckpt_data, Mapping) and "precision_metadata" in ckpt_data:
        # Checkpoints written by save_casted_checkpoint carry this marker;
        # assign-loading preserves their (possibly reduced) dtypes.
        use_assign = True
    if isinstance(ckpt_data, Mapping) and "state_dict" in ckpt_data:
        state_dict = ckpt_data["state_dict"]
        try:
            _load_state_dict(model, state_dict, strict=True, assign=use_assign)
        except Exception:
            # Legacy fallback: rename framestride_embed -> fps_embedding
            # keys and retry the strict load.
            new_pl_sd = OrderedDict()
            for k, v in state_dict.items():
                new_pl_sd[k] = v

            for k in list(new_pl_sd.keys()):
                if "framestride_embed" in k:
                    new_key = k.replace("framestride_embed", "fps_embedding")
                    new_pl_sd[new_key] = new_pl_sd[k]
                    del new_pl_sd[k]
            _load_state_dict(model,
                             new_pl_sd,
                             strict=True,
                             assign=use_assign)
    elif isinstance(ckpt_data, Mapping) and "module" in ckpt_data:
        # NOTE(review): key[16:] strips a fixed 16-character prefix from
        # each key — presumably a wrapper prefix added by the distributed
        # trainer; confirm against the checkpoint writer.
        new_pl_sd = OrderedDict()
        for key in ckpt_data['module'].keys():
            new_pl_sd[key[16:]] = ckpt_data['module'][key]
        _load_state_dict(model, new_pl_sd, strict=True, assign=use_assign)
    else:
        # Raw state dict with no wrapper around it.
        _load_state_dict(model,
                         ckpt_data,
                         strict=True,
                         assign=use_assign)
    print('>>> model checkpoint loaded.')
    return model
|
|
|
|
|
|
def maybe_cast_module(module: nn.Module | None,
                      dtype: torch.dtype,
                      label: str,
                      profiler: Optional[ProfilerManager] = None,
                      profile_name: Optional[str] = None) -> None:
    """Cast *module* in place to *dtype*, skipping None, empty, or already-cast modules.

    When a profiler and section name are supplied, the cast is timed under
    that section.
    """
    if module is None:
        return
    try:
        param = next(module.parameters())
    except StopIteration:
        print(f">>> {label} has no parameters; skip cast")
        return
    if param.dtype == dtype:
        print(f">>> {label} already {dtype}; skip cast")
        return
    if profiler is not None and profile_name:
        ctx = profiler.profile_section(profile_name)
    else:
        ctx = nullcontext()
    with ctx:
        module.to(dtype=dtype)
    print(f">>> {label} cast to {dtype}")
|
|
|
|
|
|
def save_casted_checkpoint(model: nn.Module,
                           save_path: str,
                           metadata: Optional[Dict[str, Any]] = None) -> None:
    """Write *model*'s state dict (moved to CPU) to *save_path*.

    When *metadata* is given, it is stored under "precision_metadata" so
    loaders can detect dtype-casted checkpoints. An empty path is a no-op.
    """
    if not save_path:
        return
    parent = os.path.dirname(save_path)
    if parent:
        os.makedirs(parent, exist_ok=True)
    # Detach and move tensors to CPU; pass non-tensor entries through.
    cpu_state = {
        key: value.detach().to("cpu") if isinstance(value, torch.Tensor) else value
        for key, value in model.state_dict().items()
    }
    payload: Dict[str, Any] = {"state_dict": cpu_state}
    if metadata:
        payload["precision_metadata"] = metadata
    torch.save(payload, save_path)
    print(f">>> Saved casted checkpoint to {save_path}")
|
|
|
|
|
|
def _module_param_dtype(module: nn.Module | None) -> str:
|
|
if module is None:
|
|
return "None"
|
|
dtype_counts: Dict[str, int] = {}
|
|
for param in module.parameters():
|
|
dtype_key = str(param.dtype)
|
|
dtype_counts[dtype_key] = dtype_counts.get(dtype_key, 0) + param.numel()
|
|
if not dtype_counts:
|
|
return "no_params"
|
|
if len(dtype_counts) == 1:
|
|
return next(iter(dtype_counts))
|
|
total = sum(dtype_counts.values())
|
|
parts = []
|
|
for dtype_key in sorted(dtype_counts.keys()):
|
|
ratio = dtype_counts[dtype_key] / total
|
|
parts.append(f"{dtype_key}={ratio:.1%}")
|
|
return f"mixed({', '.join(parts)})"
|
|
|
|
|
|
def log_inference_precision(model: nn.Module) -> None:
    """Print a dtype/device summary for *model* and its well-known submodules."""
    device = "unknown"
    for param in model.parameters():
        # First parameter is representative of the model's device.
        device = str(param.device)
        break
    model_dtype = _module_param_dtype(model)

    print(f">>> inference precision: model={model_dtype}, device={device}")
    known_submodules = [
        "model", "first_stage_model", "cond_stage_model", "embedder",
        "image_proj_model"
    ]
    for attr in known_submodules:
        if hasattr(model, attr):
            submodule = getattr(model, attr)
            print(f">>> {attr} param dtype: {_module_param_dtype(submodule)}")

    print(
        ">>> autocast gpu dtype default: "
        f"{torch.get_autocast_gpu_dtype()} "
        f"(enabled={torch.is_autocast_enabled()})")
|
|
|
|
|
|
def is_inferenced(save_dir: str, filename: str) -> bool:
    """Check if a given filename has already been processed and saved.

    Args:
        save_dir (str): Directory where results are saved.
        filename (str): Name of the file to check (e.g. ``clip.mp4``).

    Returns:
        bool: True if processed file exists, False otherwise.
    """
    # Use splitext rather than filename[:-4] so extensions of any length
    # (".mp4", ".jpeg", ".h5", ...) are stripped correctly.
    stem, _ = os.path.splitext(filename)
    video_file = os.path.join(save_dir, "samples_separate",
                              f"{stem}_sample0.mp4")
    return os.path.exists(video_file)
|
|
|
|
|
|
def save_results(video: Tensor, filename: str, fps: int = 8) -> None:
    """Save video tensor to file using torchvision.

    Args:
        video (Tensor): Tensor of shape (B, C, T, H, W).
        filename (str): Output file path.
        fps (int, optional): Frames per second. Defaults to 8.
    """
    clamped = torch.clamp(video.detach().cpu().float(), -1., 1.)
    batch = clamped.shape[0]
    # (B, C, T, H, W) -> (T, B, C, H, W): one grid per time step.
    frames_first = clamped.permute(2, 0, 1, 3, 4)

    per_frame_grids = []
    for framesheet in frames_first:
        per_frame_grids.append(
            torchvision.utils.make_grid(framesheet, nrow=int(batch), padding=0))
    grid = torch.stack(per_frame_grids, dim=0)
    # Map [-1, 1] -> [0, 255] uint8, channel-last for the video writer.
    grid = (grid + 1.0) / 2.0
    grid = (grid * 255).to(torch.uint8).permute(0, 2, 3, 1)
    torchvision.io.write_video(filename,
                               grid,
                               fps=fps,
                               video_codec='h264',
                               options={'crf': '10'})
|
|
|
|
|
|
def get_init_frame_path(data_dir: str, sample: dict) -> str:
    """Construct the init_frame path from directory and sample metadata.

    Args:
        data_dir (str): Base directory containing videos.
        sample (dict): Dictionary containing 'data_dir' and 'videoid'.

    Returns:
        str: Full path to the video file.
    """
    frame_name = f"{sample['videoid']}.png"
    return os.path.join(data_dir, 'images', sample['data_dir'], frame_name)
|
|
|
|
|
|
def get_transition_path(data_dir: str, sample: dict) -> str:
    """Construct the full transition file path from directory and sample metadata.

    Args:
        data_dir (str): Base directory containing transition files.
        sample (dict): Dictionary containing 'data_dir' and 'videoid'.

    Returns:
        str: Full path to the HDF5 transition file.
    """
    transition_name = f"{sample['videoid']}.h5"
    return os.path.join(data_dir, 'transitions', sample['data_dir'],
                        transition_name)
|
|
|
|
|
|
def prepare_init_input(start_idx: int,
                       init_frame_path: str,
                       transition_dict: dict[str, torch.Tensor],
                       frame_stride: int,
                       wma_data,
                       video_length: int = 16,
                       n_obs_steps: int = 2) -> tuple[dict[str, Tensor], int, int]:
    """
    Build a model-input sample from an initial frame image plus a transition
    record: normalized actions, (left-padded) state history, and the
    preprocessed init frame.

    Args:
        start_idx (int): Starting frame index for the current clip.
        init_frame_path (str): Path to the initial RGB frame image.
        transition_dict (Dict[str, Tensor]): Dictionary containing tensors for
            'action', 'observation.state', 'action_type', 'state_type'.
        frame_stride (int): Temporal stride between sampled frames.
        wma_data: Object that holds configuration and utility functions like
            normalization, transformation, and resolution info.
        video_length (int, optional): Number of frames to sample from the
            video. Default is 16.
        n_obs_steps (int, optional): Number of historical steps for
            observations. Default is 2.

    Returns:
        tuple: (data, ori_state_dim, ori_action_dim) where *data* maps
        'observation.image' plus the normalized action/state keys to tensors,
        and the two ints are the raw (pre-normalization) feature dims.
    """

    # Strided action indices for the clip being generated.
    indices = [start_idx + frame_stride * i for i in range(video_length)]
    init_frame = Image.open(init_frame_path).convert('RGB')
    # HWC uint8 image -> (C, 1, H, W) float tensor (time axis of length 1).
    init_frame = torch.tensor(np.array(init_frame)).unsqueeze(0).permute(
        3, 0, 1, 2).float()

    if start_idx < n_obs_steps - 1:
        # Not enough history: repeat the earliest state on the left so the
        # observation window always has exactly n_obs_steps entries.
        state_indices = list(range(0, start_idx + 1))
        states = transition_dict['observation.state'][state_indices, :]
        num_padding = n_obs_steps - 1 - start_idx
        first_slice = states[0:1, :]  # (t, d)
        padding = first_slice.repeat(num_padding, 1)
        states = torch.cat((padding, states), dim=0)
    else:
        # Full history available: take the last n_obs_steps states.
        state_indices = list(range(start_idx - n_obs_steps + 1, start_idx + 1))
        states = transition_dict['observation.state'][state_indices, :]

    actions = transition_dict['action'][indices, :]

    # Raw feature dims, reported back before get_uni_vec re-projects them.
    ori_state_dim = states.shape[-1]
    ori_action_dim = actions.shape[-1]

    frames_action_state_dict = {
        'action': actions,
        'observation.state': states,
    }
    frames_action_state_dict = wma_data.normalizer(frames_action_state_dict)
    frames_action_state_dict = wma_data.get_uni_vec(
        frames_action_state_dict,
        transition_dict['action_type'],
        transition_dict['state_type'],
    )

    if wma_data.spatial_transform is not None:
        init_frame = wma_data.spatial_transform(init_frame)
    # Map uint8 pixel range [0, 255] to [-1, 1].
    init_frame = (init_frame / 255 - 0.5) * 2

    data = {
        'observation.image': init_frame,
    }
    data.update(frames_action_state_dict)
    return data, ori_state_dim, ori_action_dim
|
|
|
|
|
|
def get_latent_z(model, videos: Tensor) -> Tensor:
    """
    Encode a video batch into latents using the model's first-stage encoder.

    Args:
        model: the world model.
        videos (Tensor): Input videos of shape [B, C, T, H, W].

    Returns:
        Tensor: Latent video tensor of shape [B, C, T, H, W].
    """
    profiler = get_profiler()
    with profiler.profile_section("get_latent_z/encode"):
        b, c, t, h, w = videos.shape
        # Fold time into batch so the 2D VAE sees individual frames.
        frames = rearrange(videos, 'b c t h w -> (b t) c h w')
        if getattr(model, "vae_bf16", False) and model.device.type == "cuda":
            vae_ctx = torch.autocast("cuda", dtype=torch.bfloat16)
        else:
            vae_ctx = nullcontext()
        with vae_ctx:
            encoded = model.encode_first_stage(frames)
        z = rearrange(encoded, '(b t) c h w -> b c t h w', b=b, t=t)
    return z
|
|
|
|
|
|
def preprocess_observation(
        model, observations: dict[str, np.ndarray]) -> dict[str, Tensor]:
    """Convert environment observation to LeRobot format observation.

    Args:
        model: Policy exposing ``normalize_inputs`` and ``device``.
        observations: Dictionary of observation batches from a Gym vector
            environment.

    Returns:
        Dictionary of observation batches with keys renamed to LeRobot
        format and values as tensors.
    """
    return_observations = {}

    pixels = observations["pixels"]
    if isinstance(pixels, dict):
        # One camera feed per key.
        imgs = {f"observation.images.{key}": img for key, img in pixels.items()}
    else:
        imgs = {"observation.images.top": pixels}

    for imgkey, img in imgs.items():
        img = torch.from_numpy(img)

        # Sanity check that images are channel last
        _, h, w, c = img.shape
        assert c < h and c < w, f"expect channel first images, but instead {img.shape}"

        # Sanity check that images are uint8
        assert img.dtype == torch.uint8, f"expect torch.uint8, but instead {img.dtype=}"

        # Convert to channel first, float32.
        # NOTE(review): no /255 is applied here, so values stay in 0..255 —
        # presumably normalize_inputs handles scaling downstream; confirm.
        img = einops.rearrange(img, "b h w c -> b c h w").contiguous()
        img = img.type(torch.float32)

        return_observations[imgkey] = img

    state = torch.from_numpy(observations["agent_pos"]).float()
    return_observations["observation.state"] = state
    # Normalize the state on the model's device.
    return_observations['observation.state'] = model.normalize_inputs({
        'observation.state': state.to(model.device)
    })['observation.state']

    return return_observations
|
|
|
|
|
|
def _move_to_device(batch: Mapping[str, Any],
|
|
device: torch.device) -> dict[str, Any]:
|
|
moved = {}
|
|
for key, value in batch.items():
|
|
if isinstance(value, torch.Tensor) and value.device != device:
|
|
moved[key] = value.to(device, non_blocking=True)
|
|
else:
|
|
moved[key] = value
|
|
return moved
|
|
|
|
|
|
def image_guided_synthesis_sim_mode(
        model: torch.nn.Module,
        prompts: list[str],
        observation: dict,
        noise_shape: tuple[int, int, int, int, int],
        action_cond_step: int = 16,
        n_samples: int = 1,
        ddim_steps: int = 50,
        ddim_eta: float = 1.0,
        unconditional_guidance_scale: float = 1.0,
        fs: int | None = None,
        text_input: bool = True,
        timestep_spacing: str = 'uniform',
        guidance_rescale: float = 0.0,
        sim_mode: bool = True,
        diffusion_autocast_dtype: Optional[torch.dtype] = None,
        **kwargs) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Performs image-guided video generation in a simulation-style mode with optional multimodal guidance (image, state, action, text).

    Args:
        model (torch.nn.Module): The diffusion-based generative model with multimodal conditioning.
        prompts (list[str]): A list of textual prompts to guide the synthesis process.
        observation (dict): A dictionary containing observed inputs including:
            - 'observation.images.top': Tensor of shape [B, O, C, H, W] (top-down images)
            - 'observation.state': Tensor of shape [B, O, D] (state vector)
            - 'action': Tensor of shape [B, T, D] (action sequence)
        noise_shape (tuple[int, int, int, int, int]): Shape of the latent variable to generate,
            typically (B, C, T, H, W).
        action_cond_step (int): Number of time steps where action conditioning is applied. Default is 16.
            NOTE(review): not referenced in this function body.
        n_samples (int): Number of samples to generate (unused here, always generates 1). Default is 1.
        ddim_steps (int): Number of DDIM sampling steps. Default is 50.
        ddim_eta (float): DDIM eta parameter controlling the stochasticity. Default is 1.0.
        unconditional_guidance_scale (float): Scale for classifier-free guidance. If 1.0, guidance is off.
        fs (int | None): Frame index to condition on, broadcasted across the batch if specified. Default is None.
        text_input (bool): Whether to use text prompt as conditioning. If False, uses empty strings. Default is True.
        timestep_spacing (str): Timestep sampling method in DDIM sampler. Typically "uniform" or "linspace".
        guidance_rescale (float): Guidance rescaling factor to mitigate overexposure from classifier-free guidance.
        sim_mode (bool): Whether to perform world-model interaction or decision-making using the world-model.
        diffusion_autocast_dtype (Optional[torch.dtype]): Autocast dtype for diffusion sampling (e.g., torch.bfloat16).
        **kwargs: Additional arguments passed to the DDIM sampler.

    Returns:
        batch_variants (torch.Tensor): Predicted pixel-space video frames [B, C, T, H, W].
        actions (torch.Tensor): Predicted action sequences [B, T, D] from diffusion decoding.
        states (torch.Tensor): Predicted state sequences [B, T, D] from diffusion decoding.
    """
    profiler = get_profiler()

    b, _, t, _, _ = noise_shape
    # Reuse a cached DDIM sampler stashed on the model to avoid rebuilding
    # the sampler (and its schedule) on every call.
    ddim_sampler = getattr(model, "_ddim_sampler", None)
    if ddim_sampler is None:
        ddim_sampler = DDIMSampler(model)
        model._ddim_sampler = ddim_sampler
    batch_size = noise_shape[0]

    # Broadcast the scalar frame-stride conditioning across the batch.
    fs = torch.tensor([fs] * batch_size, dtype=torch.long, device=model.device)

    with profiler.profile_section("synthesis/conditioning_prep"):
        # [B, C, O, H, W] -> [B, O, C, H, W]; keep only the latest frame as
        # the image condition for the embedder.
        img = observation['observation.images.top'].permute(0, 2, 1, 3, 4)
        cond_img = rearrange(img, 'b o c h w -> (b o) c h w')[-1:]
        if getattr(model, "encoder_bf16", False) and model.device.type == "cuda":
            if getattr(model, "encoder_mode", "autocast") == "autocast":
                # Run the embedder's preprocessing in fp32 (autocast disabled)
                # for numerical stability, then encode under bf16 autocast.
                preprocess_ctx = torch.autocast("cuda", enabled=False)
                with preprocess_ctx:
                    cond_img_fp32 = cond_img.float()
                    if hasattr(model.embedder, "preprocess"):
                        preprocessed = model.embedder.preprocess(cond_img_fp32)
                    else:
                        preprocessed = cond_img_fp32

                if hasattr(model.embedder,
                           "encode_with_vision_transformer") and hasattr(
                               model.embedder, "preprocess"):
                    # Preprocessing was already done above; temporarily
                    # replace the embedder's preprocess with identity so it
                    # is not applied twice, and restore it afterwards.
                    original_preprocess = model.embedder.preprocess
                    try:
                        model.embedder.preprocess = lambda x: x
                        with torch.autocast("cuda", dtype=torch.bfloat16):
                            cond_img_emb = model.embedder.encode_with_vision_transformer(
                                preprocessed)
                    finally:
                        model.embedder.preprocess = original_preprocess
                else:
                    with torch.autocast("cuda", dtype=torch.bfloat16):
                        cond_img_emb = model.embedder(preprocessed)
            else:
                # "bf16_full" mode: weights are already bf16, just autocast.
                with torch.autocast("cuda", dtype=torch.bfloat16):
                    cond_img_emb = model.embedder(cond_img)
        else:
            # Plain fp32 path (or non-CUDA device).
            cond_img_emb = model.embedder(cond_img)

        if model.model.conditioning_key == 'hybrid':
            # Encode the observed frames to latent space and repeat the last
            # latent frame across the full generation horizon as the
            # concatenation condition.
            z = get_latent_z(model, img.permute(0, 2, 1, 3, 4))
            img_cat_cond = z[:, :, -1:, :, :]
            img_cat_cond = repeat(img_cat_cond,
                                  'b c t h w -> b c (repeat t) h w',
                                  repeat=noise_shape[2])
            # NOTE(review): `cond` is only created in this branch; later
            # writes to cond["c_crossattn"] assume conditioning_key=='hybrid'.
            cond = {"c_concat": [img_cat_cond]}

        if not text_input:
            prompts = [""] * batch_size
        encoder_ctx = nullcontext()
        if getattr(model, "encoder_bf16", False) and model.device.type == "cuda":
            encoder_ctx = torch.autocast("cuda", dtype=torch.bfloat16)
        with encoder_ctx:
            cond_ins_emb = model.get_learned_conditioning(prompts)
        # All other conditioning embeddings are cast to the text embedding's
        # dtype so the final torch.cat succeeds.
        target_dtype = cond_ins_emb.dtype

        cond_img_emb = model._projector_forward(model.image_proj_model,
                                                cond_img_emb, target_dtype)

        cond_state_emb = model._projector_forward(
            model.state_projector, observation['observation.state'],
            target_dtype)
        cond_state_emb = cond_state_emb + model.agent_state_pos_emb.to(
            dtype=target_dtype)

        cond_action_emb = model._projector_forward(
            model.action_projector, observation['action'], target_dtype)
        cond_action_emb = cond_action_emb + model.agent_action_pos_emb.to(
            dtype=target_dtype)

        # In decision-making mode (sim_mode=False) the action embedding is
        # zeroed out so the model predicts actions rather than conditions
        # on them.
        if not sim_mode:
            cond_action_emb = torch.zeros_like(cond_action_emb)

        # Cross-attention context: [state, action, instruction, image]
        # embeddings concatenated along the token dimension.
        cond["c_crossattn"] = [
            torch.cat(
                [cond_state_emb, cond_action_emb, cond_ins_emb, cond_img_emb],
                dim=1)
        ]
        # Extra conditioning consumed by the action head: the most recent
        # n_obs_steps_acting observations plus mode flags.
        cond["c_crossattn_action"] = [
            observation['observation.images.top'][:, :,
                                                  -model.n_obs_steps_acting:],
            observation['observation.state'][:, -model.n_obs_steps_acting:],
            sim_mode,
            False,
        ]

    uc = None
    kwargs.update({"unconditional_conditioning_img_nonetext": None})
    cond_mask = None
    cond_z0 = None

    if ddim_sampler is not None:
        with profiler.profile_section("synthesis/ddim_sampling"):
            # Optionally run the whole sampling loop under autocast (bf16).
            autocast_ctx = nullcontext()
            if diffusion_autocast_dtype is not None and model.device.type == "cuda":
                autocast_ctx = torch.autocast("cuda", dtype=diffusion_autocast_dtype)
            with autocast_ctx:
                samples, actions, states, intermedia = ddim_sampler.sample(
                    S=ddim_steps,
                    conditioning=cond,
                    batch_size=batch_size,
                    shape=noise_shape[1:],
                    verbose=False,
                    unconditional_guidance_scale=unconditional_guidance_scale,
                    unconditional_conditioning=uc,
                    eta=ddim_eta,
                    cfg_img=None,
                    mask=cond_mask,
                    x0=cond_z0,
                    fs=fs,
                    timestep_spacing=timestep_spacing,
                    guidance_rescale=guidance_rescale,
                    **kwargs)

    # Reconstruct from latent to pixel space
    with profiler.profile_section("synthesis/decode_first_stage"):
        if getattr(model, "vae_bf16", False):
            # bf16 VAE: cast latents and decode under autocast on CUDA.
            if samples.dtype != torch.bfloat16:
                samples = samples.to(dtype=torch.bfloat16)
            vae_ctx = nullcontext()
            if model.device.type == "cuda":
                vae_ctx = torch.autocast("cuda", dtype=torch.bfloat16)
            with vae_ctx:
                batch_images = model.decode_first_stage(samples)
        else:
            # fp32 VAE: ensure latents are float32 before decoding.
            if samples.dtype != torch.float32:
                samples = samples.float()
            batch_images = model.decode_first_stage(samples)
    batch_variants = batch_images

    return batch_variants, actions, states
|
|
|
|
|
|
def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
    """
    Run inference pipeline on prompts and image inputs.

    Loads (or restores a pre-prepared) model, applies the requested
    precision/performance settings, then for every prompt row and every
    frame stride runs a multi-round loop of action generation and
    world-model interaction, saving videos to disk and TensorBoard.

    Args:
        args (argparse.Namespace): Parsed command-line arguments.
        gpu_num (int): Number of GPUs.  NOTE(review): not used in the body.
        gpu_no (int): Index of the current GPU.

    Returns:
        None
    """
    profiler = get_profiler()

    # Create inference and tensorboard dirs
    os.makedirs(args.savedir + '/inference', exist_ok=True)
    log_dir = args.savedir + f"/tensorboard"
    os.makedirs(log_dir, exist_ok=True)
    writer = SummaryWriter(log_dir=log_dir)

    # Load prompt
    csv_path = os.path.join(args.prompt_dir, f"{args.dataset}.csv")
    df = pd.read_csv(csv_path)

    # Load config (always needed for data setup)
    config = OmegaConf.load(args.config)

    # A ".prepared.pt" file is a fully-cast, ready-to-run pickled model saved
    # by a previous run; loading it skips construction + checkpoint + casting.
    prepared_path = args.ckpt_path + ".prepared.pt"
    if os.path.exists(prepared_path):
        # ---- Fast path: load the fully-prepared model ----
        print(f">>> Loading prepared model from {prepared_path} ...")
        with profiler.profile_section("model_loading/prepared"):
            model = torch.load(prepared_path,
                               map_location=f"cuda:{gpu_no}",
                               weights_only=False,
                               mmap=True)
        model.eval()
        diffusion_autocast_dtype = (torch.bfloat16
                                    if args.diffusion_dtype == "bf16"
                                    else None)
        print(f">>> Prepared model loaded.")
    else:
        # ---- Normal path: construct + checkpoint + casting ----
        with profiler.profile_section("model_loading/config"):
            # Disable gradient checkpointing for inference.
            config['model']['params']['wma_config']['params'][
                'use_checkpoint'] = False
            model = instantiate_from_config(config.model)
        model.perframe_ae = args.perframe_ae

        assert os.path.exists(args.ckpt_path), "Error: checkpoint Not Found!"

        with profiler.profile_section("model_loading/checkpoint"):
            model = load_model_checkpoint(model, args.ckpt_path,
                                          device=f"cuda:{gpu_no}")
        model.eval()
        model = model.cuda(gpu_no)  # move residual buffers not in state_dict
        print(f'>>> Load pre-trained model ...')

        # Optionally cast the diffusion backbone to bf16 and enable bf16
        # autocast during sampling.
        diffusion_autocast_dtype = None
        if args.diffusion_dtype == "bf16":
            maybe_cast_module(
                model.model,
                torch.bfloat16,
                "diffusion backbone",
                profiler=profiler,
                profile_name="model_loading/diffusion_bf16",
            )
            diffusion_autocast_dtype = torch.bfloat16
            print(">>> diffusion backbone set to bfloat16")

        # Cast the VAE (first-stage model) to the requested dtype.
        if hasattr(model, "first_stage_model") and model.first_stage_model is not None:
            vae_weight_dtype = torch.bfloat16 if args.vae_dtype == "bf16" else torch.float32
            maybe_cast_module(
                model.first_stage_model,
                vae_weight_dtype,
                "VAE",
                profiler=profiler,
                profile_name="model_loading/vae_cast",
            )
            model.vae_bf16 = args.vae_dtype == "bf16"
            print(f">>> VAE dtype set to {args.vae_dtype}")

        # --- VAE performance optimizations ---
        if hasattr(model, "first_stage_model") and model.first_stage_model is not None:
            vae = model.first_stage_model

            # Channels-last memory format: cuDNN uses faster NHWC kernels
            if args.vae_channels_last:
                vae = vae.to(memory_format=torch.channels_last)
                vae._channels_last = True
                model.first_stage_model = vae
                print(">>> VAE converted to channels_last (NHWC) memory format")

            # torch.compile: fuses GroupNorm+SiLU, conv chains, etc.
            if args.vae_compile:
                vae.decoder = torch.compile(vae.decoder, mode="reduce-overhead")
                vae.encoder = torch.compile(vae.encoder, mode="reduce-overhead")
                print(">>> VAE encoder/decoder compiled with torch.compile (reduce-overhead)")

        # Batch decode size (0 means "all frames at once"; 9999 is the
        # effectively-unbounded sentinel stored on the model).
        vae_decode_bs = args.vae_decode_bs if args.vae_decode_bs > 0 else 9999
        model.vae_decode_bs = vae_decode_bs
        model.vae_encode_bs = vae_decode_bs
        if args.vae_decode_bs > 0:
            print(f">>> VAE encode/decode batch size set to {args.vae_decode_bs}")
        else:
            print(">>> VAE encode/decode batch size: all frames at once")

        # Encoder (text cond stage + image embedder) precision:
        # fp32, autocast (fp32 weights, bf16 forward), or bf16_full.
        encoder_mode = args.encoder_mode
        encoder_bf16 = encoder_mode in ("autocast", "bf16_full")
        encoder_weight_dtype = torch.bfloat16 if encoder_mode == "bf16_full" else torch.float32
        if hasattr(model, "cond_stage_model") and model.cond_stage_model is not None:
            maybe_cast_module(
                model.cond_stage_model,
                encoder_weight_dtype,
                "cond_stage_model",
                profiler=profiler,
                profile_name="model_loading/encoder_cond_cast",
            )
        if hasattr(model, "embedder") and model.embedder is not None:
            maybe_cast_module(
                model.embedder,
                encoder_weight_dtype,
                "embedder",
                profiler=profiler,
                profile_name="model_loading/encoder_embedder_cast",
            )
        model.encoder_bf16 = encoder_bf16
        model.encoder_mode = encoder_mode
        print(
            f">>> encoder mode set to {encoder_mode} (weights={encoder_weight_dtype})"
        )

        # Projector (image/state/action) precision, same three modes.
        projector_mode = args.projector_mode
        projector_bf16 = projector_mode in ("autocast", "bf16_full")
        projector_weight_dtype = torch.bfloat16 if projector_mode == "bf16_full" else torch.float32
        if hasattr(model, "image_proj_model") and model.image_proj_model is not None:
            maybe_cast_module(
                model.image_proj_model,
                projector_weight_dtype,
                "image_proj_model",
                profiler=profiler,
                profile_name="model_loading/projector_image_cast",
            )
        if hasattr(model, "state_projector") and model.state_projector is not None:
            maybe_cast_module(
                model.state_projector,
                projector_weight_dtype,
                "state_projector",
                profiler=profiler,
                profile_name="model_loading/projector_state_cast",
            )
        if hasattr(model, "action_projector") and model.action_projector is not None:
            maybe_cast_module(
                model.action_projector,
                projector_weight_dtype,
                "action_projector",
                profiler=profiler,
                profile_name="model_loading/projector_action_cast",
            )
        if hasattr(model, "projector_bf16"):
            model.projector_bf16 = projector_bf16
        model.projector_mode = projector_mode
        print(
            f">>> projector mode set to {projector_mode} (weights={projector_weight_dtype})"
        )

        log_inference_precision(model)

        # Optionally export the cast checkpoint (and stop if export_only).
        if args.export_casted_ckpt:
            metadata = {
                "diffusion_dtype": args.diffusion_dtype,
                "vae_dtype": args.vae_dtype,
                "encoder_mode": args.encoder_mode,
                "projector_mode": args.projector_mode,
                "perframe_ae": args.perframe_ae,
            }
            save_casted_checkpoint(model, args.export_casted_ckpt, metadata)
            if args.export_only:
                print(">>> export_only set; skipping inference.")
                return

        # Save prepared model for fast loading next time
        # NOTE(review): prepared_path is always a non-empty string here, so
        # this branch is effectively unconditional on the slow path.
        if prepared_path:
            print(f">>> Saving prepared model to {prepared_path} ...")
            torch.save(model, prepared_path)
            print(f">>> Prepared model saved ({os.path.getsize(prepared_path) / 1024**3:.1f} GB).")

    # Build normalizer (always needed, independent of model loading path)
    logging.info("***** Configing Data *****")
    with profiler.profile_section("data_loading"):
        data = instantiate_from_config(config.data)
        data.setup()
    print(">>> Dataset is successfully loaded ...")
    device = get_device_from_parameters(model)

    profiler.record_memory("after_model_load")

    # Run over data
    assert (args.height % 16 == 0) and (
        args.width % 16
        == 0), "Error: image size [h,w] should be multiples of 16!"
    assert args.bs == 1, "Current implementation only support [batch size = 1]!"

    # Get latent noise shape (spatial dims are downscaled 8x by the VAE).
    h, w = args.height // 8, args.width // 8
    channels = model.model.diffusion_model.out_channels
    n_frames = args.video_length
    print(f'>>> Generate {n_frames} frames under each generation ...')
    noise_shape = [args.bs, channels, n_frames, h, w]

    # Determine profiler iterations
    profile_active_iters = getattr(args, 'profile_iterations', 3)
    use_pytorch_profiler = profiler.enabled and profile_active_iters > 0

    # Start inference: one outer pass per prompt row in the CSV.
    for idx in range(0, len(df)):
        sample = df.iloc[idx]

        # Got initial frame path
        init_frame_path = get_init_frame_path(args.prompt_dir, sample)
        ori_fps = float(sample['fps'])

        video_save_dir = args.savedir + f"/inference/sample_{sample['videoid']}"
        os.makedirs(video_save_dir, exist_ok=True)
        os.makedirs(video_save_dir + '/dm', exist_ok=True)
        os.makedirs(video_save_dir + '/wm', exist_ok=True)

        # Load transitions to get the initial state later
        transition_path = get_transition_path(args.prompt_dir, sample)
        with profiler.profile_section("load_transitions"):
            with h5py.File(transition_path, 'r') as h5f:
                transition_dict = {}
                for key in h5f.keys():
                    transition_dict[key] = torch.tensor(h5f[key][()])
                for key in h5f.attrs.keys():
                    transition_dict[key] = h5f.attrs[key]

        # If many, test various frequence control and world-model generation
        for fs in args.frame_stride:

            # For saving imagens in policy
            sample_save_dir = f'{video_save_dir}/dm/{fs}'
            os.makedirs(sample_save_dir, exist_ok=True)
            # For saving environmental changes in world-model
            sample_save_dir = f'{video_save_dir}/wm/{fs}'
            os.makedirs(sample_save_dir, exist_ok=True)
            # For collecting interaction videos
            wm_video = []
            # Initialize observation queues (bounded deques: the newest
            # n_obs_steps_imagen observations / video_length actions).
            cond_obs_queues = {
                "observation.images.top":
                deque(maxlen=model.n_obs_steps_imagen),
                "observation.state": deque(maxlen=model.n_obs_steps_imagen),
                "action": deque(maxlen=args.video_length),
            }

            # Obtain initial frame and state
            with profiler.profile_section("prepare_init_input"):
                start_idx = 0
                # Model conditioning "fs" is the effective FPS after striding.
                model_input_fs = ori_fps // fs
                batch, ori_state_dim, ori_action_dim = prepare_init_input(
                    start_idx,
                    init_frame_path,
                    transition_dict,
                    fs,
                    data.test_datasets[args.dataset],
                    n_obs_steps=model.n_obs_steps_imagen)
                observation = {
                    'observation.images.top':
                    batch['observation.image'].permute(1, 0, 2,
                                                       3)[-1].unsqueeze(0),
                    'observation.state':
                    batch['observation.state'][-1].unsqueeze(0),
                    'action':
                    torch.zeros_like(batch['action'][-1]).unsqueeze(0)
                }
                observation = _move_to_device(observation, device)
                # Update observation queues
                cond_obs_queues = populate_queues(cond_obs_queues, observation)

            # Setup PyTorch profiler context if enabled
            pytorch_prof_ctx = nullcontext()
            if use_pytorch_profiler:
                pytorch_prof_ctx = profiler.start_pytorch_profiler(
                    wait=1, warmup=1, active=profile_active_iters
                )

            # Multi-round interaction with the world-model
            with pytorch_prof_ctx:
                for itr in tqdm(range(args.n_iter)):
                    log_every = max(1, args.step_log_every)
                    log_step = (itr % log_every == 0)
                    profiler.current_iteration = itr
                    profiler.record_memory(f"iter_{itr}_start")

                    with profiler.profile_section("iteration_total"):
                        # Get observation (stack queue contents into batch
                        # tensors; images become [B, C, O, H, W]).
                        with profiler.profile_section("prepare_observation"):
                            observation = {
                                'observation.images.top':
                                torch.stack(list(
                                    cond_obs_queues['observation.images.top']),
                                            dim=1).permute(0, 2, 1, 3, 4),
                                'observation.state':
                                torch.stack(list(cond_obs_queues['observation.state']),
                                            dim=1),
                                'action':
                                torch.stack(list(cond_obs_queues['action']), dim=1),
                            }
                            observation = _move_to_device(observation, device)

                        # Use world-model in policy to generate action
                        # (sim_mode=False => action conditioning is zeroed).
                        if log_step:
                            print(f'>>> Step {itr}: generating actions ...')
                        with profiler.profile_section("action_generation"):
                            pred_videos_0, pred_actions, _ = image_guided_synthesis_sim_mode(
                                model,
                                sample['instruction'],
                                observation,
                                noise_shape,
                                action_cond_step=args.exe_steps,
                                ddim_steps=args.ddim_steps,
                                ddim_eta=args.ddim_eta,
                                unconditional_guidance_scale=args.unconditional_guidance_scale,
                                fs=model_input_fs,
                                timestep_spacing=args.timestep_spacing,
                                guidance_rescale=args.guidance_rescale,
                                sim_mode=False,
                                diffusion_autocast_dtype=diffusion_autocast_dtype)

                        # Update future actions in the observation queues
                        with profiler.profile_section("update_action_queues"):
                            for act_idx in range(len(pred_actions[0])):
                                obs_update = {'action': pred_actions[0][act_idx:act_idx + 1]}
                                # Zero out padded action dims beyond the
                                # dataset's native action dimensionality.
                                obs_update['action'][:, ori_action_dim:] = 0.0
                                cond_obs_queues = populate_queues(cond_obs_queues,
                                                                  obs_update)

                        # Collect data for interacting the world-model using the predicted actions
                        with profiler.profile_section("prepare_wm_observation"):
                            observation = {
                                'observation.images.top':
                                torch.stack(list(
                                    cond_obs_queues['observation.images.top']),
                                            dim=1).permute(0, 2, 1, 3, 4),
                                'observation.state':
                                torch.stack(list(cond_obs_queues['observation.state']),
                                            dim=1),
                                'action':
                                torch.stack(list(cond_obs_queues['action']), dim=1),
                            }
                            observation = _move_to_device(observation, device)

                        # Interaction with the world-model (text_input=False:
                        # the world model is driven by actions, not language).
                        if log_step:
                            print(f'>>> Step {itr}: interacting with world model ...')
                        with profiler.profile_section("world_model_interaction"):
                            pred_videos_1, _, pred_states = image_guided_synthesis_sim_mode(
                                model,
                                "",
                                observation,
                                noise_shape,
                                action_cond_step=args.exe_steps,
                                ddim_steps=args.ddim_steps,
                                ddim_eta=args.ddim_eta,
                                unconditional_guidance_scale=args.unconditional_guidance_scale,
                                fs=model_input_fs,
                                text_input=False,
                                timestep_spacing=args.timestep_spacing,
                                guidance_rescale=args.guidance_rescale,
                                diffusion_autocast_dtype=diffusion_autocast_dtype)

                        # Feed the first exe_steps predicted frames/states
                        # back into the observation queues.
                        with profiler.profile_section("update_state_queues"):
                            for step_idx in range(args.exe_steps):
                                obs_update = {
                                    'observation.images.top':
                                    pred_videos_1[0][:, step_idx:step_idx + 1].permute(1, 0, 2, 3),
                                    'observation.state':
                                    torch.zeros_like(pred_states[0][step_idx:step_idx + 1]) if
                                    args.zero_pred_state else pred_states[0][step_idx:step_idx + 1],
                                    'action':
                                    torch.zeros_like(pred_actions[0][-1:])
                                }
                                # Zero out padded state dims beyond the
                                # dataset's native state dimensionality.
                                obs_update['observation.state'][:, ori_state_dim:] = 0.0
                                cond_obs_queues = populate_queues(cond_obs_queues,
                                                                  obs_update)

                        # Save the imagen videos for decision-making (async)
                        with profiler.profile_section("save_results"):
                            sample_tag = f"{args.dataset}-vid{sample['videoid']}-dm-fs-{fs}/itr-{itr}"
                            log_to_tensorboard_async(writer,
                                                     pred_videos_0,
                                                     sample_tag,
                                                     fps=args.save_fps)
                            # Save videos environment changes via world-model interaction
                            sample_tag = f"{args.dataset}-vid{sample['videoid']}-wd-fs-{fs}/itr-{itr}"
                            log_to_tensorboard_async(writer,
                                                     pred_videos_1,
                                                     sample_tag,
                                                     fps=args.save_fps)

                            # Save the imagen videos for decision-making
                            sample_video_file = f'{video_save_dir}/dm/{fs}/itr-{itr}.mp4'
                            save_results_async(pred_videos_0,
                                               sample_video_file,
                                               fps=args.save_fps)
                            # Save videos environment changes via world-model interaction
                            sample_video_file = f'{video_save_dir}/wm/{fs}/itr-{itr}.mp4'
                            save_results_async(pred_videos_1,
                                               sample_video_file,
                                               fps=args.save_fps)

                        print('>' * 24)
                        # Collect the result of world-model interactions
                        wm_video.append(pred_videos_1[:, :, :args.exe_steps].cpu())

                    profiler.record_memory(f"iter_{itr}_end")
                    profiler.step_profiler()

            # Stitch the executed chunks of every iteration into one video.
            full_video = torch.cat(wm_video, dim=2)
            sample_tag = f"{args.dataset}-vid{sample['videoid']}-wd-fs-{fs}/full"
            log_to_tensorboard_async(writer,
                                     full_video,
                                     sample_tag,
                                     fps=args.save_fps)
            sample_full_video_file = f"{video_save_dir}/../{sample['videoid']}_full_fs{fs}.mp4"
            save_results_async(full_video, sample_full_video_file, fps=args.save_fps)

    # Wait for all async I/O to complete before profiling report
    _flush_io()

    # Save profiling results
    profiler.save_results()
|
|
|
|
|
|
def get_parser():
    """Build the command-line argument parser for the inference script.

    Fixes two copy-pasted help strings: ``--config`` previously claimed to be
    the checkpoint path, and ``--num_generation`` previously repeated the
    ``--seed`` help text.

    Returns:
        argparse.ArgumentParser: Configured parser (call ``parse_args`` on it).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--savedir",
                        type=str,
                        default=None,
                        help="Path to save the results.")
    parser.add_argument("--ckpt_path",
                        type=str,
                        default=None,
                        help="Path to the model checkpoint.")
    parser.add_argument("--config",
                        type=str,
                        help="Path to the model config file.")
    parser.add_argument(
        "--prompt_dir",
        type=str,
        default=None,
        help="Directory containing videos and corresponding prompts.")
    parser.add_argument("--dataset",
                        type=str,
                        default=None,
                        help="the name of dataset to test")
    parser.add_argument(
        "--ddim_steps",
        type=int,
        default=50,
        help="Number of DDIM steps. If non-positive, DDPM is used instead.")
    parser.add_argument(
        "--ddim_eta",
        type=float,
        default=1.0,
        help="Eta for DDIM sampling. Set to 0.0 for deterministic results.")
    parser.add_argument("--bs",
                        type=int,
                        default=1,
                        help="Batch size for inference. Must be 1.")
    parser.add_argument("--height",
                        type=int,
                        default=320,
                        help="Height of the generated images in pixels.")
    parser.add_argument("--width",
                        type=int,
                        default=512,
                        help="Width of the generated images in pixels.")
    parser.add_argument(
        "--frame_stride",
        type=int,
        nargs='+',
        required=True,
        help=
        "frame stride control for 256 model (larger->larger motion), FPS control for 512 or 1024 model (smaller->larger motion)"
    )
    parser.add_argument(
        "--unconditional_guidance_scale",
        type=float,
        default=1.0,
        help="Scale for classifier-free guidance during sampling.")
    parser.add_argument("--seed",
                        type=int,
                        default=123,
                        help="Random seed for reproducibility.")
    parser.add_argument("--video_length",
                        type=int,
                        default=16,
                        help="Number of frames in the generated video.")
    parser.add_argument("--num_generation",
                        type=int,
                        default=1,
                        help="Number of generations per prompt.")
    parser.add_argument(
        "--timestep_spacing",
        type=str,
        default="uniform",
        help=
        "Strategy for timestep scaling. See Table 2 in the paper: 'Common Diffusion Noise Schedules and Sample Steps are Flawed' (https://huggingface.co/papers/2305.08891)."
    )
    parser.add_argument(
        "--guidance_rescale",
        type=float,
        default=0.0,
        help=
        "Rescale factor for guidance as discussed in 'Common Diffusion Noise Schedules and Sample Steps are Flawed' (https://huggingface.co/papers/2305.08891)."
    )
    parser.add_argument(
        "--perframe_ae",
        action='store_true',
        default=False,
        help=
        "Use per-frame autoencoder decoding to reduce GPU memory usage. Recommended for models with resolutions like 576x1024."
    )
    parser.add_argument(
        "--diffusion_dtype",
        type=str,
        choices=["fp32", "bf16"],
        default="fp32",
        help="Dtype for diffusion backbone weights and sampling autocast."
    )
    parser.add_argument(
        "--projector_mode",
        type=str,
        choices=["fp32", "autocast", "bf16_full"],
        default="fp32",
        help=
        "Projector precision mode for image/state/action projectors: "
        "fp32=full fp32, autocast=fp32 weights + bf16 autocast in forward, "
        "bf16_full=bf16 weights + bf16 forward."
    )
    parser.add_argument(
        "--encoder_mode",
        type=str,
        choices=["fp32", "autocast", "bf16_full"],
        default="fp32",
        help=
        "Encoder precision mode for cond_stage_model/embedder: "
        "fp32=full fp32, autocast=fp32 weights + bf16 autocast in forward, "
        "bf16_full=bf16 weights + bf16 forward."
    )
    parser.add_argument(
        "--vae_dtype",
        type=str,
        choices=["fp32", "bf16"],
        default="fp32",
        help="Dtype for VAE/first_stage_model weights and forward autocast."
    )
    parser.add_argument(
        "--vae_compile",
        action='store_true',
        default=False,
        help="Apply torch.compile to VAE decoder for kernel fusion."
    )
    parser.add_argument(
        "--vae_channels_last",
        action='store_true',
        default=False,
        help="Convert VAE to channels-last (NHWC) memory format for faster cuDNN convolutions."
    )
    parser.add_argument(
        "--vae_decode_bs",
        type=int,
        default=0,
        help="VAE decode batch size (0=all frames at once). Reduces kernel launch overhead."
    )
    parser.add_argument(
        "--export_casted_ckpt",
        type=str,
        default=None,
        help=
        "Save a checkpoint after applying precision settings (mixed dtypes preserved)."
    )
    parser.add_argument(
        "--export_only",
        action='store_true',
        default=False,
        help="Exit after exporting the casted checkpoint."
    )
    parser.add_argument(
        "--step_log_every",
        type=int,
        default=1,
        help="Print per-iteration step logs every N iterations."
    )
    parser.add_argument(
        "--n_action_steps",
        type=int,
        default=16,
        help="num of samples per prompt",
    )
    parser.add_argument(
        "--exe_steps",
        type=int,
        default=16,
        help="num of samples to execute",
    )
    parser.add_argument(
        "--n_iter",
        type=int,
        default=40,
        help="num of iteration to interact with the world model",
    )
    parser.add_argument("--zero_pred_state",
                        action='store_true',
                        default=False,
                        help="not using the predicted states as comparison")
    parser.add_argument("--save_fps",
                        type=int,
                        default=8,
                        help="fps for the saving video")
    # Profiling arguments
    parser.add_argument(
        "--profile",
        action='store_true',
        default=False,
        help="Enable performance profiling (macro and operator-level analysis)."
    )
    parser.add_argument(
        "--profile_output_dir",
        type=str,
        default=None,
        help="Directory to save profiling results. Defaults to {savedir}/profile_output."
    )
    parser.add_argument(
        "--profile_iterations",
        type=int,
        default=3,
        help="Number of iterations to run PyTorch profiler's active phase for operator-level analysis."
    )
    parser.add_argument(
        "--profile_detail",
        type=str,
        choices=["light", "full"],
        default="light",
        help="Profiling detail level. Use 'full' for shapes/stacks/memory/flops."
    )
    return parser
|
|
|
|
|
|
if __name__ == '__main__':
    # Parse CLI arguments and seed all RNGs (torch / numpy / python) for
    # reproducibility; a negative seed requests a random one.
    parser = get_parser()
    args = parser.parse_args()
    seed = args.seed
    if seed < 0:
        seed = random.randint(0, 2**31)
    seed_everything(seed)

    # Initialize profiler
    # NOTE(review): if --profile_output_dir is omitted, this joins onto
    # args.savedir, which defaults to None and would raise — confirm
    # --savedir is always provided in practice.
    profile_output_dir = args.profile_output_dir
    if profile_output_dir is None:
        profile_output_dir = os.path.join(args.savedir, "profile_output")
    init_profiler(
        enabled=args.profile,
        output_dir=profile_output_dir,
        profile_detail=args.profile_detail,
    )

    # Single-process, single-GPU entry point (rank 0 of 1).
    rank, gpu_num = 0, 1
    run_inference(args, gpu_num, rank)
|