添加三层迭代级性能分析工具 profile_iteration.py

Layer1: CUDA Events 精确测量每个itr内10个阶段耗时
Layer2: torch.profiler GPU timeline trace
Layer3: CSV输出支持A/B对比

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
2026-02-10 05:42:11 +00:00
parent 125b85ce68
commit b0ebb7006e
2 changed files with 980 additions and 0 deletions

View File

@@ -0,0 +1,975 @@
"""
Profile the full iteration loop of world model interaction.
Three layers of profiling:
Layer 1: Iteration-level wall-clock breakdown (CUDA events)
Layer 2: GPU timeline trace (torch.profiler → Chrome trace)
Layer 3: A/B comparison (standardized CSV output)
Usage:
# Layer 1 only (fast, default):
TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1 CUDA_VISIBLE_DEVICES=0 \
python scripts/evaluation/profile_iteration.py \
--ckpt_path ckpts/unifolm_wma_dual_mix_bf16.ckpt \
--config configs/inference/world_model_interaction.yaml \
--prompt_dir unitree_z1_dual_arm_cleanup_pencils/case1/world_model_interaction_prompts \
--dataset unitree_z1_dual_arm_cleanup_pencils \
--frame_stride 4 --n_iter 5
# Layer 1 + Layer 2 (GPU trace):
... --trace --trace_dir ./profile_traces
# Layer 3 (A/B comparison): run twice, diff the CSVs
... --csv baseline.csv
... --csv optimized.csv
python scripts/evaluation/profile_iteration.py --compare baseline.csv optimized.csv
"""
import argparse
import csv
import os
import sys
import time
from collections import defaultdict, deque
from contextlib import nullcontext
import h5py
import numpy as np
import pandas as pd
import torch
import torchvision
from einops import rearrange, repeat
from omegaconf import OmegaConf
from PIL import Image
from pytorch_lightning import seed_everything
from torch import Tensor
from unifolm_wma.models.samplers.ddim import DDIMSampler
from unifolm_wma.utils.utils import instantiate_from_config
import torch.nn.functional as F
# ──────────────────────────────────────────────────────────────────────
# Constants
# ──────────────────────────────────────────────────────────────────────
STAGE_NAMES = [
"stack_to_device_1",
"synth_policy",
"update_action_queue",
"stack_to_device_2",
"synth_world_model",
"update_obs_queue",
"tensorboard_log",
"save_results",
"cpu_transfer",
"itr_total",
]
# Sub-stages inside image_guided_synthesis_sim_mode
SYNTH_SUB_STAGES = [
"ddim_sampler_init",
"image_embedding",
"vae_encode",
"text_conditioning",
"projectors",
"cond_assembly",
"ddim_sampling",
"vae_decode",
]
# ──────────────────────────────────────────────────────────────────────
# CudaTimer — GPU-precise timing via CUDA events
# ──────────────────────────────────────────────────────────────────────
class CudaTimer:
"""Context manager that records GPU time between enter/exit using CUDA events."""
def __init__(self, name, records):
self.name = name
self.records = records
def __enter__(self):
torch.cuda.synchronize()
self._start = torch.cuda.Event(enable_timing=True)
self._end = torch.cuda.Event(enable_timing=True)
self._start.record()
return self
def __exit__(self, *args):
self._end.record()
torch.cuda.synchronize()
elapsed_ms = self._start.elapsed_time(self._end)
self.records[self.name].append(elapsed_ms)
class WallTimer:
"""Context manager that records CPU wall-clock time (for pure-CPU stages)."""
def __init__(self, name, records):
self.name = name
self.records = records
def __enter__(self):
torch.cuda.synchronize()
self._t0 = time.perf_counter()
return self
def __exit__(self, *args):
torch.cuda.synchronize()
elapsed_ms = (time.perf_counter() - self._t0) * 1000.0
self.records[self.name].append(elapsed_ms)
# ──────────────────────────────────────────────────────────────────────
# Model loading (reused from world_model_interaction.py)
# ──────────────────────────────────────────────────────────────────────
def patch_norm_bypass_autocast():
def _group_norm_forward(self, x):
with torch.amp.autocast('cuda', enabled=False):
return F.group_norm(
x, self.num_groups,
self.weight.to(x.dtype) if self.weight is not None else None,
self.bias.to(x.dtype) if self.bias is not None else None,
self.eps)
def _layer_norm_forward(self, x):
with torch.amp.autocast('cuda', enabled=False):
return F.layer_norm(
x, self.normalized_shape,
self.weight.to(x.dtype) if self.weight is not None else None,
self.bias.to(x.dtype) if self.bias is not None else None,
self.eps)
torch.nn.GroupNorm.forward = _group_norm_forward
torch.nn.LayerNorm.forward = _layer_norm_forward
def apply_torch_compile(model, hot_indices=(5, 8, 9)):
from unifolm_wma.modules.networks.wma_model import ResBlock
unet = model.model.diffusion_model
compiled = 0
for idx in hot_indices:
block = unet.output_blocks[idx]
for layer in block:
if isinstance(layer, ResBlock):
layer._forward = torch.compile(layer._forward, mode="default")
compiled += 1
print(f" torch.compile: {compiled} ResBlocks in output_blocks{list(hot_indices)}")
def load_model(args):
config = OmegaConf.load(args.config)
config['model']['params']['wma_config']['params']['use_checkpoint'] = False
model = instantiate_from_config(config.model)
model.perframe_ae = args.perframe_ae
from collections import OrderedDict
state_dict = torch.load(args.ckpt_path, map_location="cpu")
if "state_dict" in state_dict:
state_dict = state_dict["state_dict"]
try:
model.load_state_dict(state_dict, strict=True)
except Exception:
new_sd = OrderedDict()
for k, v in state_dict.items():
new_sd[k] = v
for k in list(new_sd.keys()):
if "framestride_embed" in k:
new_sd[k.replace("framestride_embed", "fps_embedding")] = new_sd.pop(k)
model.load_state_dict(new_sd, strict=True)
model.eval()
# Apply precision: bf16 diffusion + encoders + projectors, fp32/bf16 VAE
model.model.to(torch.bfloat16)
model.diffusion_autocast_dtype = torch.bfloat16
model.embedder.to(torch.bfloat16)
model.image_proj_model.to(torch.bfloat16)
model.encoder_autocast_dtype = None
model.state_projector.to(torch.bfloat16)
model.action_projector.to(torch.bfloat16)
model.projector_autocast_dtype = None
if args.vae_dtype == "bf16":
model.first_stage_model.to(torch.bfloat16)
# Compile hot ResBlocks
apply_torch_compile(model)
model = model.cuda()
print(">>> Model loaded and ready.")
return model, config
# ──────────────────────────────────────────────────────────────────────
# Data preparation (reused from world_model_interaction.py)
# ──────────────────────────────────────────────────────────────────────
def get_init_frame_path(data_dir, sample):
rel = os.path.join(sample['data_dir'], str(sample['videoid']) + '.png')
return os.path.join(data_dir, 'images', rel)
def get_transition_path(data_dir, sample):
rel = os.path.join(sample['data_dir'], str(sample['videoid']) + '.h5')
return os.path.join(data_dir, 'transitions', rel)
def prepare_init_input(start_idx, init_frame_path, transition_dict,
frame_stride, wma_data, video_length=16, n_obs_steps=2):
indices = [start_idx + frame_stride * i for i in range(video_length)]
init_frame = Image.open(init_frame_path).convert('RGB')
init_frame = torch.tensor(np.array(init_frame)).unsqueeze(0).permute(3, 0, 1, 2).float()
if start_idx < n_obs_steps - 1:
state_indices = list(range(0, start_idx + 1))
states = transition_dict['observation.state'][state_indices, :]
num_padding = n_obs_steps - 1 - start_idx
padding = states[0:1, :].repeat(num_padding, 1)
states = torch.cat((padding, states), dim=0)
else:
state_indices = list(range(start_idx - n_obs_steps + 1, start_idx + 1))
states = transition_dict['observation.state'][state_indices, :]
actions = transition_dict['action'][indices, :]
ori_state_dim = states.shape[-1]
ori_action_dim = actions.shape[-1]
frames_action_state_dict = {
'action': actions,
'observation.state': states,
}
frames_action_state_dict = wma_data.normalizer(frames_action_state_dict)
frames_action_state_dict = wma_data.get_uni_vec(
frames_action_state_dict,
transition_dict['action_type'],
transition_dict['state_type'],
)
if wma_data.spatial_transform is not None:
init_frame = wma_data.spatial_transform(init_frame)
init_frame = (init_frame / 255 - 0.5) * 2
data = {'observation.image': init_frame}
data.update(frames_action_state_dict)
return data, ori_state_dim, ori_action_dim
def populate_queues(queues, batch):
for key in batch:
if key not in queues:
continue
if len(queues[key]) != queues[key].maxlen:
while len(queues[key]) != queues[key].maxlen:
queues[key].append(batch[key])
else:
queues[key].append(batch[key])
return queues
# ──────────────────────────────────────────────────────────────────────
# Instrumented image_guided_synthesis_sim_mode with sub-stage timing
# ──────────────────────────────────────────────────────────────────────
def get_latent_z(model, videos):
b, c, t, h, w = videos.shape
x = rearrange(videos, 'b c t h w -> (b t) c h w')
vae_dtype = next(model.first_stage_model.parameters()).dtype
x = x.to(dtype=vae_dtype)
z = model.encode_first_stage(x)
z = rearrange(z, '(b t) c h w -> b c t h w', b=b, t=t)
return z
def save_results(video, filename, fps=8):
video = video.detach().cpu()
video = torch.clamp(video.float(), -1., 1.)
n = video.shape[0]
video = video.permute(2, 0, 1, 3, 4)
frame_grids = [
torchvision.utils.make_grid(framesheet, nrow=int(n), padding=0)
for framesheet in video
]
grid = torch.stack(frame_grids, dim=0)
grid = (grid + 1.0) / 2.0
grid = (grid * 255).to(torch.uint8).permute(0, 2, 3, 1)
torchvision.io.write_video(filename, grid, fps=fps,
video_codec='h264', options={'crf': '10'})
def profiled_synthesis(model, prompts, observation, noise_shape,
ddim_steps, ddim_eta, unconditional_guidance_scale,
fs, text_input, timestep_spacing, guidance_rescale,
sim_mode, decode_video, records, prefix):
"""image_guided_synthesis_sim_mode with per-sub-stage CUDA event timing.
Args:
prefix: "policy" or "wm" — prepended to sub-stage names in records.
"""
b, _, t, _, _ = noise_shape
batch_size = noise_shape[0]
device = next(model.parameters()).device
# --- sub-stage: ddim_sampler_init ---
with CudaTimer(f"{prefix}/ddim_sampler_init", records):
ddim_sampler = DDIMSampler(model)
fs_t = torch.tensor([fs] * batch_size, dtype=torch.long, device=device)
# --- sub-stage: image_embedding ---
with CudaTimer(f"{prefix}/image_embedding", records):
model_dtype = next(model.embedder.parameters()).dtype
img = observation['observation.images.top'].permute(0, 2, 1, 3, 4)
cond_img = rearrange(img, 'b o c h w -> (b o) c h w')[-1:].to(dtype=model_dtype)
cond_img_emb = model.embedder(cond_img)
cond_img_emb = model.image_proj_model(cond_img_emb)
# --- sub-stage: vae_encode ---
with CudaTimer(f"{prefix}/vae_encode", records):
if model.model.conditioning_key == 'hybrid':
z = get_latent_z(model, img.permute(0, 2, 1, 3, 4))
img_cat_cond = z[:, :, -1:, :, :]
img_cat_cond = repeat(img_cat_cond,
'b c t h w -> b c (repeat t) h w',
repeat=noise_shape[2])
cond = {"c_concat": [img_cat_cond]}
# --- sub-stage: text_conditioning ---
with CudaTimer(f"{prefix}/text_conditioning", records):
if not text_input:
prompts_use = [""] * batch_size
else:
prompts_use = prompts if isinstance(prompts, list) else [prompts] * batch_size
cond_ins_emb = model.get_learned_conditioning(prompts_use)
# --- sub-stage: projectors ---
with CudaTimer(f"{prefix}/projectors", records):
projector_dtype = next(model.state_projector.parameters()).dtype
cond_state_emb = model.state_projector(
observation['observation.state'].to(dtype=projector_dtype))
cond_state_emb = cond_state_emb + model.agent_state_pos_emb
cond_action_emb = model.action_projector(
observation['action'].to(dtype=projector_dtype))
cond_action_emb = cond_action_emb + model.agent_action_pos_emb
if not sim_mode:
cond_action_emb = torch.zeros_like(cond_action_emb)
# --- sub-stage: cond_assembly ---
with CudaTimer(f"{prefix}/cond_assembly", records):
cond["c_crossattn"] = [
torch.cat([cond_state_emb, cond_action_emb, cond_ins_emb, cond_img_emb], dim=1)
]
cond["c_crossattn_action"] = [
observation['observation.images.top'][:, :, -model.n_obs_steps_acting:],
observation['observation.state'][:, -model.n_obs_steps_acting:],
sim_mode,
False,
]
# --- sub-stage: ddim_sampling ---
autocast_dtype = getattr(model, 'diffusion_autocast_dtype', None)
if autocast_dtype is not None and device.type == 'cuda':
autocast_ctx = torch.autocast('cuda', dtype=autocast_dtype)
else:
autocast_ctx = nullcontext()
with CudaTimer(f"{prefix}/ddim_sampling", records):
with autocast_ctx:
samples, actions, states, _ = ddim_sampler.sample(
S=ddim_steps,
conditioning=cond,
batch_size=batch_size,
shape=noise_shape[1:],
verbose=False,
unconditional_guidance_scale=unconditional_guidance_scale,
unconditional_conditioning=None,
eta=ddim_eta,
cfg_img=None,
mask=None,
x0=None,
fs=fs_t,
timestep_spacing=timestep_spacing,
guidance_rescale=guidance_rescale,
unconditional_conditioning_img_nonetext=None,
)
# --- sub-stage: vae_decode ---
batch_variants = None
if decode_video:
with CudaTimer(f"{prefix}/vae_decode", records):
batch_variants = model.decode_first_stage(samples)
else:
records[f"{prefix}/vae_decode"].append(0.0)
return batch_variants, actions, states
# ──────────────────────────────────────────────────────────────────────
# Instrumented iteration loop
# ──────────────────────────────────────────────────────────────────────
def run_profiled_iterations(model, args, config, noise_shape, device):
"""Run the full iteration loop with per-stage timing.
Returns:
all_records: list of dicts, one per itr, {stage_name: ms}
"""
# Load data
csv_path = os.path.join(args.prompt_dir, f"{args.dataset}.csv")
df = pd.read_csv(csv_path)
sample = df.iloc[0]
data_module = instantiate_from_config(config.data)
data_module.setup()
init_frame_path = get_init_frame_path(args.prompt_dir, sample)
ori_fps = float(sample['fps'])
fs = args.frame_stride
model_input_fs = ori_fps // fs
transition_path = get_transition_path(args.prompt_dir, sample)
with h5py.File(transition_path, 'r') as h5f:
transition_dict = {}
for key in h5f.keys():
transition_dict[key] = torch.tensor(h5f[key][()])
for key in h5f.attrs.keys():
transition_dict[key] = h5f.attrs[key]
# Prepare initial observation
batch, ori_state_dim, ori_action_dim = prepare_init_input(
0, init_frame_path, transition_dict, fs,
data_module.test_datasets[args.dataset],
n_obs_steps=model.n_obs_steps_imagen)
observation = {
'observation.images.top':
batch['observation.image'].permute(1, 0, 2, 3)[-1].unsqueeze(0),
'observation.state':
batch['observation.state'][-1].unsqueeze(0),
'action':
torch.zeros_like(batch['action'][-1]).unsqueeze(0),
}
observation = {k: v.to(device, non_blocking=True) for k, v in observation.items()}
cond_obs_queues = {
"observation.images.top": deque(maxlen=model.n_obs_steps_imagen),
"observation.state": deque(maxlen=model.n_obs_steps_imagen),
"action": deque(maxlen=args.video_length),
}
cond_obs_queues = populate_queues(cond_obs_queues, observation)
# Temp dir for save_results profiling
tmp_dir = os.path.join(args.savedir, "profile_tmp")
os.makedirs(tmp_dir, exist_ok=True)
prompt_text = sample['instruction']
all_records = []
print(f">>> Running {args.n_iter} profiled iterations ...")
for itr in range(args.n_iter):
rec = defaultdict(list)
# ── itr_total start ──
torch.cuda.synchronize()
itr_start = torch.cuda.Event(enable_timing=True)
itr_end = torch.cuda.Event(enable_timing=True)
itr_start.record()
# ① stack_to_device_1
with CudaTimer("stack_to_device_1", rec):
observation = {
'observation.images.top':
torch.stack(list(cond_obs_queues['observation.images.top']),
dim=1).permute(0, 2, 1, 3, 4),
'observation.state':
torch.stack(list(cond_obs_queues['observation.state']), dim=1),
'action':
torch.stack(list(cond_obs_queues['action']), dim=1),
}
observation = {k: v.to(device, non_blocking=True) for k, v in observation.items()}
# ② synth_policy
with CudaTimer("synth_policy", rec):
pred_videos_0, pred_actions, _ = profiled_synthesis(
model, prompt_text, observation, noise_shape,
ddim_steps=args.ddim_steps, ddim_eta=args.ddim_eta,
unconditional_guidance_scale=args.unconditional_guidance_scale,
fs=model_input_fs, text_input=True,
timestep_spacing=args.timestep_spacing,
guidance_rescale=args.guidance_rescale,
sim_mode=False,
decode_video=not args.fast_policy_no_decode,
records=rec, prefix="policy")
# ③ update_action_queue
with WallTimer("update_action_queue", rec):
for idx in range(len(pred_actions[0])):
obs_a = {'action': pred_actions[0][idx:idx + 1]}
obs_a['action'][:, ori_action_dim:] = 0.0
cond_obs_queues = populate_queues(cond_obs_queues, obs_a)
# ④ stack_to_device_2
with CudaTimer("stack_to_device_2", rec):
observation = {
'observation.images.top':
torch.stack(list(cond_obs_queues['observation.images.top']),
dim=1).permute(0, 2, 1, 3, 4),
'observation.state':
torch.stack(list(cond_obs_queues['observation.state']), dim=1),
'action':
torch.stack(list(cond_obs_queues['action']), dim=1),
}
observation = {k: v.to(device, non_blocking=True) for k, v in observation.items()}
# ⑤ synth_world_model
with CudaTimer("synth_world_model", rec):
pred_videos_1, _, pred_states = profiled_synthesis(
model, "", observation, noise_shape,
ddim_steps=args.ddim_steps, ddim_eta=args.ddim_eta,
unconditional_guidance_scale=args.unconditional_guidance_scale,
fs=model_input_fs, text_input=False,
timestep_spacing=args.timestep_spacing,
guidance_rescale=args.guidance_rescale,
sim_mode=True, decode_video=True,
records=rec, prefix="wm")
# ⑥ update_obs_queue
with WallTimer("update_obs_queue", rec):
for idx in range(args.exe_steps):
obs_u = {
'observation.images.top':
pred_videos_1[0][:, idx:idx + 1].permute(1, 0, 2, 3),
'observation.state':
pred_states[0][idx:idx + 1],
'action':
torch.zeros_like(pred_actions[0][-1:]),
}
obs_u['observation.state'][:, ori_state_dim:] = 0.0
cond_obs_queues = populate_queues(cond_obs_queues, obs_u)
# ⑦ tensorboard_log (simulate — no actual writer, measure make_grid cost)
with WallTimer("tensorboard_log", rec):
for vid in [pred_videos_0, pred_videos_1]:
if vid is not None and vid.dim() == 5:
v = vid.permute(2, 0, 1, 3, 4)
grids = [torchvision.utils.make_grid(f, nrow=1, padding=0) for f in v]
_ = torch.stack(grids, dim=0)
# ⑧ save_results
with WallTimer("save_results", rec):
if pred_videos_0 is not None:
save_results(pred_videos_0.cpu(),
os.path.join(tmp_dir, f"dm_{itr}.mp4"),
fps=args.save_fps)
save_results(pred_videos_1.cpu(),
os.path.join(tmp_dir, f"wm_{itr}.mp4"),
fps=args.save_fps)
# ⑨ cpu_transfer
with CudaTimer("cpu_transfer", rec):
_ = pred_videos_1[:, :, :args.exe_steps].cpu()
# ── itr_total end ──
itr_end.record()
torch.cuda.synchronize()
itr_total_ms = itr_start.elapsed_time(itr_end)
rec["itr_total"].append(itr_total_ms)
# Flatten: each stage has exactly one entry per itr
itr_rec = {k: v[0] for k, v in rec.items()}
all_records.append(itr_rec)
# Print live progress
print(f" itr {itr}: {itr_total_ms:.0f} ms total | "
f"policy={itr_rec.get('synth_policy', 0):.0f} | "
f"wm={itr_rec.get('synth_world_model', 0):.0f} | "
f"save={itr_rec.get('save_results', 0):.0f} | "
f"tb={itr_rec.get('tensorboard_log', 0):.0f}")
return all_records
# ──────────────────────────────────────────────────────────────────────
# Layer 1: Console report
# ──────────────────────────────────────────────────────────────────────
def print_iteration_report(all_records, warmup=1):
"""Print a structured table of per-stage timing across iterations."""
if len(all_records) <= warmup:
records = all_records
else:
records = all_records[warmup:]
print(f"\n(Skipping first {warmup} itr(s) as warmup)\n")
# Collect all stage keys in a stable order
all_keys = []
seen = set()
for rec in records:
for k in rec:
if k not in seen:
all_keys.append(k)
seen.add(k)
# Separate top-level stages from sub-stages
top_keys = [k for k in all_keys if '/' not in k]
sub_keys = [k for k in all_keys if '/' in k]
def _print_table(keys, title):
if not keys:
return
print("=" * 82)
print(title)
print("=" * 82)
print(f"{'Stage':<35} {'Mean(ms)':>10} {'Std':>8} {'Min':>10} {'Max':>10} {'%':>7}")
print("-" * 82)
total_mean = np.mean([rec.get("itr_total", 0) for rec in records])
for k in keys:
vals = [rec.get(k, 0) for rec in records]
mean = np.mean(vals)
std = np.std(vals)
mn = np.min(vals)
mx = np.max(vals)
pct = mean / total_mean * 100 if total_mean > 0 else 0
print(f"{k:<35} {mean:>10.1f} {std:>8.1f} {mn:>10.1f} {mx:>10.1f} {pct:>6.1f}%")
print("-" * 82)
print()
_print_table(top_keys, "TABLE 1: ITERATION-LEVEL BREAKDOWN")
_print_table(sub_keys, "TABLE 2: SYNTHESIS SUB-STAGE BREAKDOWN")
# ──────────────────────────────────────────────────────────────────────
# Layer 3: CSV output for A/B comparison
# ──────────────────────────────────────────────────────────────────────
def write_csv(all_records, csv_path, warmup=1):
"""Write per-iteration timing to CSV for later comparison."""
records = all_records[warmup:] if len(all_records) > warmup else all_records
# Collect all keys
all_keys = []
seen = set()
for rec in records:
for k in rec:
if k not in seen:
all_keys.append(k)
seen.add(k)
with open(csv_path, 'w', newline='') as f:
writer = csv.DictWriter(f, fieldnames=['itr'] + all_keys)
writer.writeheader()
for i, rec in enumerate(records):
row = {'itr': i}
row.update({k: f"{rec.get(k, 0):.2f}" for k in all_keys})
writer.writerow(row)
# Also write a summary row
summary_path = csv_path.replace('.csv', '_summary.csv')
with open(summary_path, 'w', newline='') as f:
writer = csv.DictWriter(f, fieldnames=['stat'] + all_keys)
writer.writeheader()
for stat_name, stat_fn in [('mean', np.mean), ('std', np.std),
('min', np.min), ('max', np.max)]:
row = {'stat': stat_name}
row.update({k: f"{stat_fn([r.get(k, 0) for r in records]):.2f}"
for k in all_keys})
writer.writerow(row)
print(f">>> CSV written to: {csv_path}")
print(f">>> Summary written to: {summary_path}")
def compare_csvs(path_a, path_b):
"""Compare two summary CSVs and print a diff table."""
df_a = pd.read_csv(path_a, index_col='stat')
df_b = pd.read_csv(path_b, index_col='stat')
# Use mean row for comparison
mean_a = df_a.loc['mean'].astype(float)
mean_b = df_b.loc['mean'].astype(float)
print("=" * 90)
print(f"A/B COMPARISON: {os.path.basename(path_a)} vs {os.path.basename(path_b)}")
print("=" * 90)
print(f"{'Stage':<35} {'A(ms)':>10} {'B(ms)':>10} {'Diff':>10} {'Speedup':>10}")
print("-" * 90)
for col in mean_a.index:
if col not in mean_b.index:
continue
a_val = mean_a[col]
b_val = mean_b[col]
diff = b_val - a_val
speedup = a_val / b_val if b_val > 0 else float('inf')
marker = " <<<" if abs(diff) > 50 else ""
print(f"{col:<35} {a_val:>10.1f} {b_val:>10.1f} {diff:>+10.1f} {speedup:>9.2f}x{marker}")
print("-" * 90)
total_a = mean_a.get('itr_total', 0)
total_b = mean_b.get('itr_total', 0)
print(f"{'itr_total':<35} {total_a:>10.1f} {total_b:>10.1f} "
f"{total_b - total_a:>+10.1f} {total_a / total_b if total_b > 0 else 0:>9.2f}x")
print()
# ──────────────────────────────────────────────────────────────────────
# Layer 2: GPU timeline trace wrapper
# ──────────────────────────────────────────────────────────────────────
def run_with_trace(model, args, config, noise_shape, device):
"""Run iterations under torch.profiler to generate Chrome/TensorBoard traces."""
trace_dir = args.trace_dir
os.makedirs(trace_dir, exist_ok=True)
# We need the same data setup as run_profiled_iterations
csv_path = os.path.join(args.prompt_dir, f"{args.dataset}.csv")
df = pd.read_csv(csv_path)
sample = df.iloc[0]
data_module = instantiate_from_config(config.data)
data_module.setup()
init_frame_path = get_init_frame_path(args.prompt_dir, sample)
ori_fps = float(sample['fps'])
fs = args.frame_stride
model_input_fs = ori_fps // fs
transition_path = get_transition_path(args.prompt_dir, sample)
with h5py.File(transition_path, 'r') as h5f:
transition_dict = {}
for key in h5f.keys():
transition_dict[key] = torch.tensor(h5f[key][()])
for key in h5f.attrs.keys():
transition_dict[key] = h5f.attrs[key]
batch, ori_state_dim, ori_action_dim = prepare_init_input(
0, init_frame_path, transition_dict, fs,
data_module.test_datasets[args.dataset],
n_obs_steps=model.n_obs_steps_imagen)
observation = {
'observation.images.top':
batch['observation.image'].permute(1, 0, 2, 3)[-1].unsqueeze(0),
'observation.state':
batch['observation.state'][-1].unsqueeze(0),
'action':
torch.zeros_like(batch['action'][-1]).unsqueeze(0),
}
observation = {k: v.to(device, non_blocking=True) for k, v in observation.items()}
cond_obs_queues = {
"observation.images.top": deque(maxlen=model.n_obs_steps_imagen),
"observation.state": deque(maxlen=model.n_obs_steps_imagen),
"action": deque(maxlen=args.video_length),
}
cond_obs_queues = populate_queues(cond_obs_queues, observation)
tmp_dir = os.path.join(args.savedir, "profile_tmp")
os.makedirs(tmp_dir, exist_ok=True)
prompt_text = sample['instruction']
# Total iterations: warmup + active
n_warmup = 1
n_active = min(args.n_iter, 2) # trace 2 active iterations max
n_total = n_warmup + n_active
print(f">>> GPU trace: {n_warmup} warmup + {n_active} active iterations")
print(f">>> Trace output: {trace_dir}")
with torch.no_grad(), torch.profiler.profile(
activities=[
torch.profiler.ProfilerActivity.CPU,
torch.profiler.ProfilerActivity.CUDA,
],
schedule=torch.profiler.schedule(
wait=0, warmup=n_warmup, active=n_active, repeat=1),
on_trace_ready=torch.profiler.tensorboard_trace_handler(trace_dir),
record_shapes=True,
with_stack=True,
) as prof:
for itr_idx in range(n_total):
phase = "warmup" if itr_idx < n_warmup else "active"
print(f" trace itr {itr_idx} ({phase})...")
# ── One full iteration (same logic as run_inference) ──
obs_loc = {
'observation.images.top':
torch.stack(list(cond_obs_queues['observation.images.top']),
dim=1).permute(0, 2, 1, 3, 4),
'observation.state':
torch.stack(list(cond_obs_queues['observation.state']), dim=1),
'action':
torch.stack(list(cond_obs_queues['action']), dim=1),
}
obs_loc = {k: v.to(device) for k, v in obs_loc.items()}
# Policy pass
dummy_rec = defaultdict(list)
pv0, pa, _ = profiled_synthesis(
model, prompt_text, obs_loc, noise_shape,
ddim_steps=args.ddim_steps, ddim_eta=args.ddim_eta,
unconditional_guidance_scale=args.unconditional_guidance_scale,
fs=model_input_fs, text_input=True,
timestep_spacing=args.timestep_spacing,
guidance_rescale=args.guidance_rescale,
sim_mode=False,
decode_video=not args.fast_policy_no_decode,
records=dummy_rec, prefix="policy")
for idx in range(len(pa[0])):
oa = {'action': pa[0][idx:idx + 1]}
oa['action'][:, ori_action_dim:] = 0.0
populate_queues(cond_obs_queues, oa)
# Re-stack for world model
obs_loc2 = {
'observation.images.top':
torch.stack(list(cond_obs_queues['observation.images.top']),
dim=1).permute(0, 2, 1, 3, 4),
'observation.state':
torch.stack(list(cond_obs_queues['observation.state']), dim=1),
'action':
torch.stack(list(cond_obs_queues['action']), dim=1),
}
obs_loc2 = {k: v.to(device) for k, v in obs_loc2.items()}
# World model pass
pv1, _, ps = profiled_synthesis(
model, "", obs_loc2, noise_shape,
ddim_steps=args.ddim_steps, ddim_eta=args.ddim_eta,
unconditional_guidance_scale=args.unconditional_guidance_scale,
fs=model_input_fs, text_input=False,
timestep_spacing=args.timestep_spacing,
guidance_rescale=args.guidance_rescale,
sim_mode=True, decode_video=True,
records=dummy_rec, prefix="wm")
# Update obs queue
for idx in range(args.exe_steps):
ou = {
'observation.images.top':
pv1[0][:, idx:idx + 1].permute(1, 0, 2, 3),
'observation.state': ps[0][idx:idx + 1],
'action': torch.zeros_like(pa[0][-1:]),
}
ou['observation.state'][:, ori_state_dim:] = 0.0
populate_queues(cond_obs_queues, ou)
# Save results (captures CPU stall in trace)
if pv0 is not None:
save_results(pv0.cpu(),
os.path.join(tmp_dir, f"trace_dm_{itr_idx}.mp4"),
fps=args.save_fps)
save_results(pv1.cpu(),
os.path.join(tmp_dir, f"trace_wm_{itr_idx}.mp4"),
fps=args.save_fps)
prof.step()
print(f">>> Trace saved to {trace_dir}")
print(" View with: tensorboard --logdir", trace_dir)
print(" Or open the .json file in chrome://tracing")
# ──────────────────────────────────────────────────────────────────────
# Argument parser
# ──────────────────────────────────────────────────────────────────────
def get_parser():
p = argparse.ArgumentParser(description="Profile full iteration loop")
# Compare mode (no model needed)
p.add_argument("--compare", nargs=2, metavar=("A_SUMMARY", "B_SUMMARY"),
help="Compare two summary CSVs and exit")
# Model / data
p.add_argument("--ckpt_path", type=str, default=None)
p.add_argument("--config", type=str, default=None)
p.add_argument("--prompt_dir", type=str, default=None)
p.add_argument("--dataset", type=str, default=None)
p.add_argument("--savedir", type=str, default="profile_output")
# Inference params (match world_model_interaction.py)
p.add_argument("--ddim_steps", type=int, default=50)
p.add_argument("--ddim_eta", type=float, default=1.0)
p.add_argument("--bs", type=int, default=1)
p.add_argument("--height", type=int, default=320)
p.add_argument("--width", type=int, default=512)
p.add_argument("--frame_stride", type=int, default=4)
p.add_argument("--unconditional_guidance_scale", type=float, default=1.0)
p.add_argument("--video_length", type=int, default=16)
p.add_argument("--timestep_spacing", type=str, default="uniform_trailing")
p.add_argument("--guidance_rescale", type=float, default=0.7)
p.add_argument("--exe_steps", type=int, default=16)
p.add_argument("--n_iter", type=int, default=5)
p.add_argument("--save_fps", type=int, default=8)
p.add_argument("--seed", type=int, default=123)
p.add_argument("--perframe_ae", action='store_true', default=False)
p.add_argument("--vae_dtype", type=str, choices=["fp32", "bf16"], default="bf16")
p.add_argument("--fast_policy_no_decode", action='store_true', default=False)
# Profiling control
p.add_argument("--warmup", type=int, default=1,
help="Number of warmup iterations to skip in statistics")
p.add_argument("--csv", type=str, default=None,
help="Write per-iteration timing to this CSV file")
p.add_argument("--trace", action='store_true', default=False,
help="Enable Layer 2: GPU timeline trace")
p.add_argument("--trace_dir", type=str, default="./profile_traces",
help="Directory for trace output")
return p
# ──────────────────────────────────────────────────────────────────────
# Main
# ──────────────────────────────────────────────────────────────────────
def main():
patch_norm_bypass_autocast()
parser = get_parser()
args = parser.parse_args()
# ── Compare mode: no model needed ──
if args.compare:
compare_csvs(args.compare[0], args.compare[1])
return
# ── Validate required args ──
for required in ['ckpt_path', 'config', 'prompt_dir', 'dataset']:
if getattr(args, required) is None:
parser.error(f"--{required} is required for profiling mode")
seed_everything(args.seed)
os.makedirs(args.savedir, exist_ok=True)
# ── Load model ──
print("=" * 60)
print("PROFILE ITERATION — Loading model...")
print("=" * 60)
model, config = load_model(args)
device = next(model.parameters()).device
h, w = args.height // 8, args.width // 8
channels = model.model.diffusion_model.out_channels
noise_shape = [args.bs, channels, args.video_length, h, w]
print(f">>> Noise shape: {noise_shape}")
print(f">>> DDIM steps: {args.ddim_steps}")
print(f">>> fast_policy_no_decode: {args.fast_policy_no_decode}")
# ── Layer 2: GPU trace (optional) ──
if args.trace:
with torch.no_grad():
run_with_trace(model, args, config, noise_shape, device)
print()
# ── Layer 1: Iteration-level breakdown ──
print("=" * 60)
print("LAYER 1: ITERATION-LEVEL PROFILING")
print("=" * 60)
with torch.no_grad():
all_records = run_profiled_iterations(
model, args, config, noise_shape, device)
# Print report
print_iteration_report(all_records, warmup=args.warmup)
# ── Layer 3: CSV output for A/B comparison ──
if args.csv:
write_csv(all_records, args.csv, warmup=args.warmup)
print("Done.")
if __name__ == '__main__':
main()

View File

@@ -0,0 +1,5 @@
#\!/bin/bash
res_dir="unitree_z1_dual_arm_cleanup_pencils/case1"
dataset="unitree_z1_dual_arm_cleanup_pencils"
TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1 CUDA_VISIBLE_DEVICES=0 python3 scripts/evaluation/profile_iteration.py --seed 123 --ckpt_path ckpts/unifolm_wma_dual_mix_bf16.ckpt --config configs/inference/world_model_interaction.yaml --savedir "${res_dir}/profile_output" --prompt_dir "${res_dir}/world_model_interaction_prompts" --dataset ${dataset} --bs 1 --height 320 --width 512 --unconditional_guidance_scale 1.0 --ddim_steps 50 --ddim_eta 1.0 --video_length 16 --frame_stride 4 --exe_steps 16 --n_iter 5 --warmup 1 --timestep_spacing uniform_trailing --guidance_rescale 0.7 --perframe_ae --vae_dtype bf16 --fast_policy_no_decode --csv "${res_dir}/profile_output/baseline.csv" 2>&1 | tee "${res_dir}/profile_output/profile.log"