添加三层迭代级性能分析工具 profile_iteration.py
Layer1: CUDA Events 精确测量每个itr内10个阶段耗时 Layer2: torch.profiler GPU timeline trace Layer3: CSV输出支持A/B对比 Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
975
scripts/evaluation/profile_iteration.py
Normal file
975
scripts/evaluation/profile_iteration.py
Normal file
@@ -0,0 +1,975 @@
|
||||
"""
|
||||
Profile the full iteration loop of world model interaction.
|
||||
|
||||
Three layers of profiling:
|
||||
Layer 1: Iteration-level wall-clock breakdown (CUDA events)
|
||||
Layer 2: GPU timeline trace (torch.profiler → Chrome trace)
|
||||
Layer 3: A/B comparison (standardized CSV output)
|
||||
|
||||
Usage:
|
||||
# Layer 1 only (fast, default):
|
||||
TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1 CUDA_VISIBLE_DEVICES=0 \
|
||||
python scripts/evaluation/profile_iteration.py \
|
||||
--ckpt_path ckpts/unifolm_wma_dual_mix_bf16.ckpt \
|
||||
--config configs/inference/world_model_interaction.yaml \
|
||||
--prompt_dir unitree_z1_dual_arm_cleanup_pencils/case1/world_model_interaction_prompts \
|
||||
--dataset unitree_z1_dual_arm_cleanup_pencils \
|
||||
--frame_stride 4 --n_iter 5
|
||||
|
||||
# Layer 1 + Layer 2 (GPU trace):
|
||||
... --trace --trace_dir ./profile_traces
|
||||
|
||||
# Layer 3 (A/B comparison): run twice, diff the CSVs
|
||||
... --csv baseline.csv
|
||||
... --csv optimized.csv
|
||||
python scripts/evaluation/profile_iteration.py --compare baseline.csv optimized.csv
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import csv
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
from collections import defaultdict, deque
|
||||
from contextlib import nullcontext
|
||||
|
||||
import h5py
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import torch
|
||||
import torchvision
|
||||
from einops import rearrange, repeat
|
||||
from omegaconf import OmegaConf
|
||||
from PIL import Image
|
||||
from pytorch_lightning import seed_everything
|
||||
from torch import Tensor
|
||||
|
||||
from unifolm_wma.models.samplers.ddim import DDIMSampler
|
||||
from unifolm_wma.utils.utils import instantiate_from_config
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
# ──────────────────────────────────────────────────────────────────────
|
||||
# Constants
|
||||
# ──────────────────────────────────────────────────────────────────────
|
||||
# Top-level per-iteration stages measured by run_profiled_iterations.
# "itr_total" is recorded separately as the full wall-to-wall iteration time.
STAGE_NAMES = [
    "stack_to_device_1",
    "synth_policy",
    "update_action_queue",
    "stack_to_device_2",
    "synth_world_model",
    "update_obs_queue",
    "tensorboard_log",
    "save_results",
    "cpu_transfer",
    "itr_total",
]

# Sub-stages inside image_guided_synthesis_sim_mode
# (recorded as "<prefix>/<sub_stage>" where prefix is "policy" or "wm").
SYNTH_SUB_STAGES = [
    "ddim_sampler_init",
    "image_embedding",
    "vae_encode",
    "text_conditioning",
    "projectors",
    "cond_assembly",
    "ddim_sampling",
    "vae_decode",
]
|
||||
|
||||
|
||||
# ──────────────────────────────────────────────────────────────────────
|
||||
# CudaTimer — GPU-precise timing via CUDA events
|
||||
# ──────────────────────────────────────────────────────────────────────
|
||||
class CudaTimer:
    """Measure GPU time of a ``with`` region via CUDA events.

    On exit the elapsed time in milliseconds is appended to
    ``records[name]``. The device is synchronized on both enter and exit so
    the interval brackets exactly the work issued inside the region.
    """

    def __init__(self, name, records):
        self.name = name
        self.records = records

    def __enter__(self):
        # Drain pending GPU work so the start event marks only this region.
        torch.cuda.synchronize()
        self._evt_start = torch.cuda.Event(enable_timing=True)
        self._evt_end = torch.cuda.Event(enable_timing=True)
        self._evt_start.record()
        return self

    def __exit__(self, *exc_info):
        self._evt_end.record()
        # Wait for the end event before reading the elapsed time.
        torch.cuda.synchronize()
        self.records[self.name].append(self._evt_start.elapsed_time(self._evt_end))
|
||||
|
||||
|
||||
class WallTimer:
    """Measure CPU wall-clock time of a ``with`` region (for CPU-bound stages).

    The GPU is synchronized on enter and exit so outstanding GPU work is not
    misattributed to this stage; the elapsed milliseconds are appended to
    ``records[name]``.
    """

    def __init__(self, name, records):
        self.name = name
        self.records = records

    def __enter__(self):
        torch.cuda.synchronize()
        self._start_s = time.perf_counter()
        return self

    def __exit__(self, *exc_info):
        torch.cuda.synchronize()
        elapsed_s = time.perf_counter() - self._start_s
        self.records[self.name].append(elapsed_s * 1000.0)
|
||||
|
||||
|
||||
# ──────────────────────────────────────────────────────────────────────
|
||||
# Model loading (reused from world_model_interaction.py)
|
||||
# ──────────────────────────────────────────────────────────────────────
|
||||
def patch_norm_bypass_autocast():
    """Monkey-patch GroupNorm/LayerNorm forwards to run with autocast disabled.

    Inside each patched forward, autocast is turned off and the affine
    weight/bias are cast to ``x.dtype``, so the norm executes entirely in the
    activation's dtype (e.g. bf16) instead of being upcast by autocast.

    NOTE(review): this patches ``torch.nn.GroupNorm``/``LayerNorm`` globally —
    it affects every model in the process, not just the profiled one.
    """

    def _group_norm_forward(self, x):
        with torch.amp.autocast('cuda', enabled=False):
            return F.group_norm(
                x, self.num_groups,
                self.weight.to(x.dtype) if self.weight is not None else None,
                self.bias.to(x.dtype) if self.bias is not None else None,
                self.eps)

    def _layer_norm_forward(self, x):
        with torch.amp.autocast('cuda', enabled=False):
            return F.layer_norm(
                x, self.normalized_shape,
                self.weight.to(x.dtype) if self.weight is not None else None,
                self.bias.to(x.dtype) if self.bias is not None else None,
                self.eps)

    # Install the patched forwards class-wide.
    torch.nn.GroupNorm.forward = _group_norm_forward
    torch.nn.LayerNorm.forward = _layer_norm_forward
|
||||
|
||||
|
||||
def apply_torch_compile(model, hot_indices=(5, 8, 9)):
    """Compile the ResBlock inner forwards of selected UNet output blocks.

    Args:
        model: wrapper whose ``model.model.diffusion_model`` is the UNet.
        hot_indices: indices into ``unet.output_blocks`` to compile
            (defaults chosen as the hot blocks; presumably from prior
            profiling — confirm).
    """
    from unifolm_wma.modules.networks.wma_model import ResBlock
    unet = model.model.diffusion_model
    compiled = 0
    for idx in hot_indices:
        block = unet.output_blocks[idx]
        for layer in block:
            if isinstance(layer, ResBlock):
                # Compile only the inner _forward, leaving the public forward
                # (and whatever wrapping it does) untouched.
                layer._forward = torch.compile(layer._forward, mode="default")
                compiled += 1
    print(f" torch.compile: {compiled} ResBlocks in output_blocks{list(hot_indices)}")
|
||||
|
||||
|
||||
def load_model(args):
    """Build the world model from config, load checkpoint weights, set dtypes.

    Steps: instantiate from the OmegaConf config (with gradient checkpointing
    forced off — inference only, and checkpointing would skew stage timings),
    load the state dict (remapping legacy ``framestride_embed`` keys on
    failure), switch to eval mode, cast submodules to bf16 (VAE optionally per
    ``args.vae_dtype``), compile hot ResBlocks, and move the model to CUDA.

    Args:
        args: parsed CLI namespace; uses ``config``, ``ckpt_path``,
            ``perframe_ae`` and ``vae_dtype``.

    Returns:
        (model, config): the CUDA-resident eval-mode model and parsed config.
    """
    config = OmegaConf.load(args.config)
    config['model']['params']['wma_config']['params']['use_checkpoint'] = False
    model = instantiate_from_config(config.model)
    model.perframe_ae = args.perframe_ae

    from collections import OrderedDict
    state_dict = torch.load(args.ckpt_path, map_location="cpu")
    if "state_dict" in state_dict:
        state_dict = state_dict["state_dict"]
    try:
        model.load_state_dict(state_dict, strict=True)
    except Exception:
        # Older checkpoints name the fps embedding "framestride_embed";
        # rename those keys and retry strictly.
        new_sd = OrderedDict()
        for k, v in state_dict.items():
            new_sd[k] = v
        for k in list(new_sd.keys()):
            if "framestride_embed" in k:
                new_sd[k.replace("framestride_embed", "fps_embedding")] = new_sd.pop(k)
        model.load_state_dict(new_sd, strict=True)

    model.eval()

    # Apply precision: bf16 diffusion + encoders + projectors, fp32/bf16 VAE
    model.model.to(torch.bfloat16)
    model.diffusion_autocast_dtype = torch.bfloat16
    model.embedder.to(torch.bfloat16)
    model.image_proj_model.to(torch.bfloat16)
    # None disables the corresponding autocast wrapping downstream —
    # the modules run natively in the dtype set above.
    model.encoder_autocast_dtype = None
    model.state_projector.to(torch.bfloat16)
    model.action_projector.to(torch.bfloat16)
    model.projector_autocast_dtype = None
    if args.vae_dtype == "bf16":
        model.first_stage_model.to(torch.bfloat16)

    # Compile hot ResBlocks
    apply_torch_compile(model)
    model = model.cuda()
    print(">>> Model loaded and ready.")
    return model, config
|
||||
|
||||
|
||||
# ──────────────────────────────────────────────────────────────────────
|
||||
# Data preparation (reused from world_model_interaction.py)
|
||||
# ──────────────────────────────────────────────────────────────────────
|
||||
def get_init_frame_path(data_dir, sample):
    """Return the path to the sample's initial frame: images/<data_dir>/<videoid>.png."""
    frame_name = f"{sample['videoid']}.png"
    return os.path.join(data_dir, 'images', sample['data_dir'], frame_name)
|
||||
|
||||
|
||||
def get_transition_path(data_dir, sample):
    """Return the path to the sample's transition file: transitions/<data_dir>/<videoid>.h5."""
    h5_name = f"{sample['videoid']}.h5"
    return os.path.join(data_dir, 'transitions', sample['data_dir'], h5_name)
|
||||
|
||||
|
||||
def prepare_init_input(start_idx, init_frame_path, transition_dict,
                       frame_stride, wma_data, video_length=16, n_obs_steps=2):
    """Build the initial normalized model input dict from one episode.

    Args:
        start_idx: index of the current frame within the episode.
        init_frame_path: path to the initial RGB frame on disk.
        transition_dict: episode tensors/attrs loaded from the .h5 file; reads
            'observation.state', 'action', 'action_type' and 'state_type'.
        frame_stride: temporal stride between the sampled action frames.
        wma_data: dataset object providing ``normalizer``, ``get_uni_vec`` and
            ``spatial_transform``.
        video_length: number of strided action frames to gather.
        n_obs_steps: number of past state observations to stack.

    Returns:
        (data, ori_state_dim, ori_action_dim): input dict (image in [-1, 1],
        normalized state/action) plus the raw state/action feature dims
        before unified-vector conversion.
    """
    # Strided frame indices for the action sequence.
    indices = [start_idx + frame_stride * i for i in range(video_length)]
    init_frame = Image.open(init_frame_path).convert('RGB')
    # (h, w, c) uint8 -> (c, 1, h, w) float: channel-first, singleton time axis.
    init_frame = torch.tensor(np.array(init_frame)).unsqueeze(0).permute(3, 0, 1, 2).float()

    if start_idx < n_obs_steps - 1:
        # Not enough history yet: left-pad by repeating the earliest state.
        state_indices = list(range(0, start_idx + 1))
        states = transition_dict['observation.state'][state_indices, :]
        num_padding = n_obs_steps - 1 - start_idx
        padding = states[0:1, :].repeat(num_padding, 1)
        states = torch.cat((padding, states), dim=0)
    else:
        state_indices = list(range(start_idx - n_obs_steps + 1, start_idx + 1))
        states = transition_dict['observation.state'][state_indices, :]

    actions = transition_dict['action'][indices, :]
    # Raw feature dims, needed later to zero out unified-vector padding.
    ori_state_dim = states.shape[-1]
    ori_action_dim = actions.shape[-1]

    frames_action_state_dict = {
        'action': actions,
        'observation.state': states,
    }
    frames_action_state_dict = wma_data.normalizer(frames_action_state_dict)
    frames_action_state_dict = wma_data.get_uni_vec(
        frames_action_state_dict,
        transition_dict['action_type'],
        transition_dict['state_type'],
    )

    if wma_data.spatial_transform is not None:
        init_frame = wma_data.spatial_transform(init_frame)
    # Map uint8 range [0, 255] to [-1, 1].
    init_frame = (init_frame / 255 - 0.5) * 2

    data = {'observation.image': init_frame}
    data.update(frames_action_state_dict)
    return data, ori_state_dim, ori_action_dim
|
||||
|
||||
|
||||
def populate_queues(queues, batch):
    """Push batch entries into their matching bounded queues.

    Keys in ``batch`` without a queue are ignored. A queue that is not yet
    full is filled to ``maxlen`` with the new value (initial padding); a full
    queue just receives the value once, evicting its oldest entry.

    Returns:
        The same ``queues`` mapping, mutated in place.
    """
    for key, value in batch.items():
        queue = queues.get(key)
        if queue is None:
            continue
        queue.append(value)
        # Pad a not-yet-full queue with copies of the newest value.
        while len(queue) < queue.maxlen:
            queue.append(value)
    return queues
|
||||
|
||||
|
||||
# ──────────────────────────────────────────────────────────────────────
|
||||
# Instrumented image_guided_synthesis_sim_mode with sub-stage timing
|
||||
# ──────────────────────────────────────────────────────────────────────
|
||||
def get_latent_z(model, videos):
    """Encode a (b, c, t, h, w) video batch to VAE latents with the same layout."""
    b, c, t, h, w = videos.shape
    # Fold time into the batch so the 2D VAE sees individual frames.
    frames = rearrange(videos, 'b c t h w -> (b t) c h w')
    vae_dtype = next(model.first_stage_model.parameters()).dtype
    frames = frames.to(dtype=vae_dtype)
    latents = model.encode_first_stage(frames)
    # Restore the (batch, channel, time, h, w) layout.
    return rearrange(latents, '(b t) c h w -> b c t h w', b=b, t=t)
|
||||
|
||||
|
||||
def save_results(video, filename, fps=8):
    """Write a (n, c, t, h, w) video tensor in [-1, 1] to an H.264 mp4 file."""
    clip = torch.clamp(video.detach().cpu().float(), -1., 1.)
    n = clip.shape[0]
    # Time-major: one grid image per frame, samples tiled side by side.
    per_frame = clip.permute(2, 0, 1, 3, 4)
    grid = torch.stack(
        [torchvision.utils.make_grid(frame, nrow=int(n), padding=0)
         for frame in per_frame],
        dim=0)
    # [-1, 1] -> uint8 HWC for the video writer.
    grid = (grid + 1.0) / 2.0
    grid = (grid * 255).to(torch.uint8).permute(0, 2, 3, 1)
    torchvision.io.write_video(filename, grid, fps=fps,
                               video_codec='h264', options={'crf': '10'})
|
||||
|
||||
|
||||
def profiled_synthesis(model, prompts, observation, noise_shape,
                       ddim_steps, ddim_eta, unconditional_guidance_scale,
                       fs, text_input, timestep_spacing, guidance_rescale,
                       sim_mode, decode_video, records, prefix):
    """image_guided_synthesis_sim_mode with per-sub-stage CUDA event timing.

    Each sub-stage (see SYNTH_SUB_STAGES) is wrapped in a CudaTimer and its
    time is appended to ``records["<prefix>/<sub_stage>"]``.

    Args:
        prefix: "policy" or "wm" — prepended to sub-stage names in records.
        sim_mode: when False, the action embedding is zeroed (policy pass).
        decode_video: when False, VAE decoding is skipped and 0.0 is recorded
            for the vae_decode sub-stage.

    Returns:
        (batch_variants, actions, states) — decoded videos (or None), plus
        the sampled action/state trajectories.

    NOTE(review): ``cond`` is only bound inside the ``conditioning_key ==
    'hybrid'`` branch; a non-hybrid model would hit a NameError at
    cond_assembly — confirm all supported configs are hybrid.
    """
    b, _, t, _, _ = noise_shape
    batch_size = noise_shape[0]
    device = next(model.parameters()).device

    # --- sub-stage: ddim_sampler_init ---
    with CudaTimer(f"{prefix}/ddim_sampler_init", records):
        ddim_sampler = DDIMSampler(model)
        fs_t = torch.tensor([fs] * batch_size, dtype=torch.long, device=device)

    # --- sub-stage: image_embedding ---
    with CudaTimer(f"{prefix}/image_embedding", records):
        model_dtype = next(model.embedder.parameters()).dtype
        img = observation['observation.images.top'].permute(0, 2, 1, 3, 4)
        # Embed only the most recent observation frame.
        cond_img = rearrange(img, 'b o c h w -> (b o) c h w')[-1:].to(dtype=model_dtype)
        cond_img_emb = model.embedder(cond_img)
        cond_img_emb = model.image_proj_model(cond_img_emb)

    # --- sub-stage: vae_encode ---
    with CudaTimer(f"{prefix}/vae_encode", records):
        if model.model.conditioning_key == 'hybrid':
            z = get_latent_z(model, img.permute(0, 2, 1, 3, 4))
            # Repeat the last latent frame across the full video length.
            img_cat_cond = z[:, :, -1:, :, :]
            img_cat_cond = repeat(img_cat_cond,
                                  'b c t h w -> b c (repeat t) h w',
                                  repeat=noise_shape[2])
            cond = {"c_concat": [img_cat_cond]}

    # --- sub-stage: text_conditioning ---
    with CudaTimer(f"{prefix}/text_conditioning", records):
        if not text_input:
            prompts_use = [""] * batch_size
        else:
            prompts_use = prompts if isinstance(prompts, list) else [prompts] * batch_size
        cond_ins_emb = model.get_learned_conditioning(prompts_use)

    # --- sub-stage: projectors ---
    with CudaTimer(f"{prefix}/projectors", records):
        projector_dtype = next(model.state_projector.parameters()).dtype
        cond_state_emb = model.state_projector(
            observation['observation.state'].to(dtype=projector_dtype))
        cond_state_emb = cond_state_emb + model.agent_state_pos_emb
        cond_action_emb = model.action_projector(
            observation['action'].to(dtype=projector_dtype))
        cond_action_emb = cond_action_emb + model.agent_action_pos_emb
        if not sim_mode:
            # Policy pass: the model must not see actions as conditioning.
            cond_action_emb = torch.zeros_like(cond_action_emb)

    # --- sub-stage: cond_assembly ---
    with CudaTimer(f"{prefix}/cond_assembly", records):
        cond["c_crossattn"] = [
            torch.cat([cond_state_emb, cond_action_emb, cond_ins_emb, cond_img_emb], dim=1)
        ]
        cond["c_crossattn_action"] = [
            observation['observation.images.top'][:, :, -model.n_obs_steps_acting:],
            observation['observation.state'][:, -model.n_obs_steps_acting:],
            sim_mode,
            False,
        ]

    # --- sub-stage: ddim_sampling ---
    # Autocast only when the model opted in and we are actually on CUDA.
    autocast_dtype = getattr(model, 'diffusion_autocast_dtype', None)
    if autocast_dtype is not None and device.type == 'cuda':
        autocast_ctx = torch.autocast('cuda', dtype=autocast_dtype)
    else:
        autocast_ctx = nullcontext()

    with CudaTimer(f"{prefix}/ddim_sampling", records):
        with autocast_ctx:
            samples, actions, states, _ = ddim_sampler.sample(
                S=ddim_steps,
                conditioning=cond,
                batch_size=batch_size,
                shape=noise_shape[1:],
                verbose=False,
                unconditional_guidance_scale=unconditional_guidance_scale,
                unconditional_conditioning=None,
                eta=ddim_eta,
                cfg_img=None,
                mask=None,
                x0=None,
                fs=fs_t,
                timestep_spacing=timestep_spacing,
                guidance_rescale=guidance_rescale,
                unconditional_conditioning_img_nonetext=None,
            )

    # --- sub-stage: vae_decode ---
    batch_variants = None
    if decode_video:
        with CudaTimer(f"{prefix}/vae_decode", records):
            batch_variants = model.decode_first_stage(samples)
    else:
        # Keep the record shape uniform across iterations.
        records[f"{prefix}/vae_decode"].append(0.0)

    return batch_variants, actions, states
|
||||
|
||||
|
||||
# ──────────────────────────────────────────────────────────────────────
|
||||
# Instrumented iteration loop
|
||||
# ──────────────────────────────────────────────────────────────────────
|
||||
def run_profiled_iterations(model, args, config, noise_shape, device):
    """Run the full iteration loop with per-stage timing.

    Loads the first prompt sample, prepares the initial observation queues,
    then runs ``args.n_iter`` iterations; each iteration times the stages in
    STAGE_NAMES (plus the synthesis sub-stages, via profiled_synthesis).

    Returns:
        all_records: list of dicts, one per itr, {stage_name: ms}
    """
    # Load data
    csv_path = os.path.join(args.prompt_dir, f"{args.dataset}.csv")
    df = pd.read_csv(csv_path)
    sample = df.iloc[0]

    data_module = instantiate_from_config(config.data)
    data_module.setup()

    init_frame_path = get_init_frame_path(args.prompt_dir, sample)
    ori_fps = float(sample['fps'])
    fs = args.frame_stride
    # Effective fps fed to the model after temporal striding.
    model_input_fs = ori_fps // fs

    transition_path = get_transition_path(args.prompt_dir, sample)
    with h5py.File(transition_path, 'r') as h5f:
        transition_dict = {}
        for key in h5f.keys():
            transition_dict[key] = torch.tensor(h5f[key][()])
        for key in h5f.attrs.keys():
            transition_dict[key] = h5f.attrs[key]

    # Prepare initial observation
    batch, ori_state_dim, ori_action_dim = prepare_init_input(
        0, init_frame_path, transition_dict, fs,
        data_module.test_datasets[args.dataset],
        n_obs_steps=model.n_obs_steps_imagen)

    observation = {
        'observation.images.top':
        batch['observation.image'].permute(1, 0, 2, 3)[-1].unsqueeze(0),
        'observation.state':
        batch['observation.state'][-1].unsqueeze(0),
        'action':
        torch.zeros_like(batch['action'][-1]).unsqueeze(0),
    }
    observation = {k: v.to(device, non_blocking=True) for k, v in observation.items()}

    # Bounded history queues; populate_queues pads them to full length.
    cond_obs_queues = {
        "observation.images.top": deque(maxlen=model.n_obs_steps_imagen),
        "observation.state": deque(maxlen=model.n_obs_steps_imagen),
        "action": deque(maxlen=args.video_length),
    }
    cond_obs_queues = populate_queues(cond_obs_queues, observation)

    # Temp dir for save_results profiling
    tmp_dir = os.path.join(args.savedir, "profile_tmp")
    os.makedirs(tmp_dir, exist_ok=True)

    prompt_text = sample['instruction']
    all_records = []

    print(f">>> Running {args.n_iter} profiled iterations ...")
    for itr in range(args.n_iter):
        rec = defaultdict(list)

        # ── itr_total start ──
        torch.cuda.synchronize()
        itr_start = torch.cuda.Event(enable_timing=True)
        itr_end = torch.cuda.Event(enable_timing=True)
        itr_start.record()

        # ① stack_to_device_1: stack queue history into batched tensors.
        with CudaTimer("stack_to_device_1", rec):
            observation = {
                'observation.images.top':
                torch.stack(list(cond_obs_queues['observation.images.top']),
                            dim=1).permute(0, 2, 1, 3, 4),
                'observation.state':
                torch.stack(list(cond_obs_queues['observation.state']), dim=1),
                'action':
                torch.stack(list(cond_obs_queues['action']), dim=1),
            }
            observation = {k: v.to(device, non_blocking=True) for k, v in observation.items()}

        # ② synth_policy: policy pass (sim_mode=False — actions zeroed inside).
        with CudaTimer("synth_policy", rec):
            pred_videos_0, pred_actions, _ = profiled_synthesis(
                model, prompt_text, observation, noise_shape,
                ddim_steps=args.ddim_steps, ddim_eta=args.ddim_eta,
                unconditional_guidance_scale=args.unconditional_guidance_scale,
                fs=model_input_fs, text_input=True,
                timestep_spacing=args.timestep_spacing,
                guidance_rescale=args.guidance_rescale,
                sim_mode=False,
                decode_video=not args.fast_policy_no_decode,
                records=rec, prefix="policy")

        # ③ update_action_queue: feed predicted actions (padding dims zeroed).
        with WallTimer("update_action_queue", rec):
            for idx in range(len(pred_actions[0])):
                obs_a = {'action': pred_actions[0][idx:idx + 1]}
                obs_a['action'][:, ori_action_dim:] = 0.0
                cond_obs_queues = populate_queues(cond_obs_queues, obs_a)

        # ④ stack_to_device_2: re-stack queues after the action update.
        with CudaTimer("stack_to_device_2", rec):
            observation = {
                'observation.images.top':
                torch.stack(list(cond_obs_queues['observation.images.top']),
                            dim=1).permute(0, 2, 1, 3, 4),
                'observation.state':
                torch.stack(list(cond_obs_queues['observation.state']), dim=1),
                'action':
                torch.stack(list(cond_obs_queues['action']), dim=1),
            }
            observation = {k: v.to(device, non_blocking=True) for k, v in observation.items()}

        # ⑤ synth_world_model: world-model pass (sim_mode=True, no text).
        with CudaTimer("synth_world_model", rec):
            pred_videos_1, _, pred_states = profiled_synthesis(
                model, "", observation, noise_shape,
                ddim_steps=args.ddim_steps, ddim_eta=args.ddim_eta,
                unconditional_guidance_scale=args.unconditional_guidance_scale,
                fs=model_input_fs, text_input=False,
                timestep_spacing=args.timestep_spacing,
                guidance_rescale=args.guidance_rescale,
                sim_mode=True, decode_video=True,
                records=rec, prefix="wm")

        # ⑥ update_obs_queue: roll predicted frames/states into the queues.
        with WallTimer("update_obs_queue", rec):
            for idx in range(args.exe_steps):
                obs_u = {
                    'observation.images.top':
                    pred_videos_1[0][:, idx:idx + 1].permute(1, 0, 2, 3),
                    'observation.state':
                    pred_states[0][idx:idx + 1],
                    'action':
                    torch.zeros_like(pred_actions[0][-1:]),
                }
                obs_u['observation.state'][:, ori_state_dim:] = 0.0
                cond_obs_queues = populate_queues(cond_obs_queues, obs_u)

        # ⑦ tensorboard_log (simulate — no actual writer, measure make_grid cost)
        with WallTimer("tensorboard_log", rec):
            for vid in [pred_videos_0, pred_videos_1]:
                if vid is not None and vid.dim() == 5:
                    v = vid.permute(2, 0, 1, 3, 4)
                    grids = [torchvision.utils.make_grid(f, nrow=1, padding=0) for f in v]
                    _ = torch.stack(grids, dim=0)

        # ⑧ save_results: mp4 encoding cost (dominant CPU stall candidate).
        with WallTimer("save_results", rec):
            if pred_videos_0 is not None:
                save_results(pred_videos_0.cpu(),
                             os.path.join(tmp_dir, f"dm_{itr}.mp4"),
                             fps=args.save_fps)
            save_results(pred_videos_1.cpu(),
                         os.path.join(tmp_dir, f"wm_{itr}.mp4"),
                         fps=args.save_fps)

        # ⑨ cpu_transfer: device→host copy of the executed frames.
        with CudaTimer("cpu_transfer", rec):
            _ = pred_videos_1[:, :, :args.exe_steps].cpu()

        # ── itr_total end ──
        itr_end.record()
        torch.cuda.synchronize()
        itr_total_ms = itr_start.elapsed_time(itr_end)
        rec["itr_total"].append(itr_total_ms)

        # Flatten: each stage has exactly one entry per itr
        itr_rec = {k: v[0] for k, v in rec.items()}
        all_records.append(itr_rec)

        # Print live progress
        print(f" itr {itr}: {itr_total_ms:.0f} ms total | "
              f"policy={itr_rec.get('synth_policy', 0):.0f} | "
              f"wm={itr_rec.get('synth_world_model', 0):.0f} | "
              f"save={itr_rec.get('save_results', 0):.0f} | "
              f"tb={itr_rec.get('tensorboard_log', 0):.0f}")

    return all_records
|
||||
|
||||
|
||||
# ──────────────────────────────────────────────────────────────────────
|
||||
# Layer 1: Console report
|
||||
# ──────────────────────────────────────────────────────────────────────
|
||||
def print_iteration_report(all_records, warmup=1):
    """Print mean/std/min/max tables of per-stage timings across iterations.

    Skips the first ``warmup`` iterations when enough records exist; prints
    one table for top-level stages and one for "prefix/sub" synthesis stages.
    """
    if len(all_records) > warmup:
        records = all_records[warmup:]
        print(f"\n(Skipping first {warmup} itr(s) as warmup)\n")
    else:
        records = all_records

    # First-seen, stable ordering of every stage key present.
    ordered_keys = list(dict.fromkeys(k for rec in records for k in rec))
    top_keys = [k for k in ordered_keys if '/' not in k]
    sub_keys = [k for k in ordered_keys if '/' in k]

    def _print_table(keys, title):
        # One formatted stats table; no-op when there is nothing to show.
        if not keys:
            return
        print("=" * 82)
        print(title)
        print("=" * 82)
        print(f"{'Stage':<35} {'Mean(ms)':>10} {'Std':>8} {'Min':>10} {'Max':>10} {'%':>7}")
        print("-" * 82)

        total_mean = np.mean([rec.get("itr_total", 0) for rec in records])
        for k in keys:
            vals = [rec.get(k, 0) for rec in records]
            mean, std = np.mean(vals), np.std(vals)
            mn, mx = np.min(vals), np.max(vals)
            pct = mean / total_mean * 100 if total_mean > 0 else 0
            print(f"{k:<35} {mean:>10.1f} {std:>8.1f} {mn:>10.1f} {mx:>10.1f} {pct:>6.1f}%")
        print("-" * 82)
        print()

    _print_table(top_keys, "TABLE 1: ITERATION-LEVEL BREAKDOWN")
    _print_table(sub_keys, "TABLE 2: SYNTHESIS SUB-STAGE BREAKDOWN")
|
||||
|
||||
|
||||
# ──────────────────────────────────────────────────────────────────────
|
||||
# Layer 3: CSV output for A/B comparison
|
||||
# ──────────────────────────────────────────────────────────────────────
|
||||
def write_csv(all_records, csv_path, warmup=1):
    """Write per-iteration timings to ``csv_path`` plus a ``*_summary.csv``.

    The main CSV has one row per post-warmup iteration; the summary CSV has
    mean/std/min/max rows over the same iterations. Values are formatted to
    two decimals; missing stages default to 0.
    """
    records = all_records if len(all_records) <= warmup else all_records[warmup:]

    # First-seen, stable ordering of every stage key present.
    ordered_keys = list(dict.fromkeys(k for rec in records for k in rec))

    with open(csv_path, 'w', newline='') as f:
        writer = csv.DictWriter(f, fieldnames=['itr'] + ordered_keys)
        writer.writeheader()
        for i, rec in enumerate(records):
            row = {'itr': i}
            for k in ordered_keys:
                row[k] = f"{rec.get(k, 0):.2f}"
            writer.writerow(row)

    # Also write a summary row
    summary_path = csv_path.replace('.csv', '_summary.csv')
    stat_fns = [('mean', np.mean), ('std', np.std), ('min', np.min), ('max', np.max)]
    with open(summary_path, 'w', newline='') as f:
        writer = csv.DictWriter(f, fieldnames=['stat'] + ordered_keys)
        writer.writeheader()
        for stat_name, stat_fn in stat_fns:
            row = {'stat': stat_name}
            for k in ordered_keys:
                row[k] = f"{stat_fn([r.get(k, 0) for r in records]):.2f}"
            writer.writerow(row)

    print(f">>> CSV written to: {csv_path}")
    print(f">>> Summary written to: {summary_path}")
|
||||
|
||||
|
||||
def compare_csvs(path_a, path_b):
    """Print a per-stage A/B diff of two summary CSVs (mean rows only).

    Stages present in A but absent in B are skipped; rows with a diff over
    50 ms are flagged with a ``<<<`` marker. An itr_total summary line is
    printed at the bottom.
    """
    mean_a = pd.read_csv(path_a, index_col='stat').loc['mean'].astype(float)
    mean_b = pd.read_csv(path_b, index_col='stat').loc['mean'].astype(float)

    print("=" * 90)
    print(f"A/B COMPARISON: {os.path.basename(path_a)} vs {os.path.basename(path_b)}")
    print("=" * 90)
    print(f"{'Stage':<35} {'A(ms)':>10} {'B(ms)':>10} {'Diff':>10} {'Speedup':>10}")
    print("-" * 90)

    for col in mean_a.index:
        if col not in mean_b.index:
            continue
        a_val, b_val = mean_a[col], mean_b[col]
        diff = b_val - a_val
        speedup = a_val / b_val if b_val > 0 else float('inf')
        marker = " <<<" if abs(diff) > 50 else ""
        print(f"{col:<35} {a_val:>10.1f} {b_val:>10.1f} {diff:>+10.1f} {speedup:>9.2f}x{marker}")

    print("-" * 90)
    total_a = mean_a.get('itr_total', 0)
    total_b = mean_b.get('itr_total', 0)
    print(f"{'itr_total':<35} {total_a:>10.1f} {total_b:>10.1f} "
          f"{total_b - total_a:>+10.1f} {total_a / total_b if total_b > 0 else 0:>9.2f}x")
    print()
|
||||
|
||||
|
||||
# ──────────────────────────────────────────────────────────────────────
|
||||
# Layer 2: GPU timeline trace wrapper
|
||||
# ──────────────────────────────────────────────────────────────────────
|
||||
def run_with_trace(model, args, config, noise_shape, device):
    """Run iterations under torch.profiler to generate Chrome/TensorBoard traces.

    Repeats the same data setup and iteration body as run_profiled_iterations
    (without per-stage timing; CudaTimer records go to a throwaway dict) under
    a torch.profiler schedule of 1 warmup + up to 2 active iterations, writing
    a TensorBoard trace to ``args.trace_dir``.
    """
    trace_dir = args.trace_dir
    os.makedirs(trace_dir, exist_ok=True)

    # We need the same data setup as run_profiled_iterations
    csv_path = os.path.join(args.prompt_dir, f"{args.dataset}.csv")
    df = pd.read_csv(csv_path)
    sample = df.iloc[0]

    data_module = instantiate_from_config(config.data)
    data_module.setup()

    init_frame_path = get_init_frame_path(args.prompt_dir, sample)
    ori_fps = float(sample['fps'])
    fs = args.frame_stride
    model_input_fs = ori_fps // fs

    transition_path = get_transition_path(args.prompt_dir, sample)
    with h5py.File(transition_path, 'r') as h5f:
        transition_dict = {}
        for key in h5f.keys():
            transition_dict[key] = torch.tensor(h5f[key][()])
        for key in h5f.attrs.keys():
            transition_dict[key] = h5f.attrs[key]

    batch, ori_state_dim, ori_action_dim = prepare_init_input(
        0, init_frame_path, transition_dict, fs,
        data_module.test_datasets[args.dataset],
        n_obs_steps=model.n_obs_steps_imagen)

    observation = {
        'observation.images.top':
        batch['observation.image'].permute(1, 0, 2, 3)[-1].unsqueeze(0),
        'observation.state':
        batch['observation.state'][-1].unsqueeze(0),
        'action':
        torch.zeros_like(batch['action'][-1]).unsqueeze(0),
    }
    observation = {k: v.to(device, non_blocking=True) for k, v in observation.items()}

    cond_obs_queues = {
        "observation.images.top": deque(maxlen=model.n_obs_steps_imagen),
        "observation.state": deque(maxlen=model.n_obs_steps_imagen),
        "action": deque(maxlen=args.video_length),
    }
    cond_obs_queues = populate_queues(cond_obs_queues, observation)

    tmp_dir = os.path.join(args.savedir, "profile_tmp")
    os.makedirs(tmp_dir, exist_ok=True)
    prompt_text = sample['instruction']

    # Total iterations: warmup + active
    n_warmup = 1
    n_active = min(args.n_iter, 2)  # trace 2 active iterations max
    n_total = n_warmup + n_active

    print(f">>> GPU trace: {n_warmup} warmup + {n_active} active iterations")
    print(f">>> Trace output: {trace_dir}")

    with torch.no_grad(), torch.profiler.profile(
            activities=[
                torch.profiler.ProfilerActivity.CPU,
                torch.profiler.ProfilerActivity.CUDA,
            ],
            schedule=torch.profiler.schedule(
                wait=0, warmup=n_warmup, active=n_active, repeat=1),
            on_trace_ready=torch.profiler.tensorboard_trace_handler(trace_dir),
            record_shapes=True,
            with_stack=True,
    ) as prof:
        for itr_idx in range(n_total):
            phase = "warmup" if itr_idx < n_warmup else "active"
            print(f" trace itr {itr_idx} ({phase})...")

            # ── One full iteration (same logic as run_inference) ──
            obs_loc = {
                'observation.images.top':
                torch.stack(list(cond_obs_queues['observation.images.top']),
                            dim=1).permute(0, 2, 1, 3, 4),
                'observation.state':
                torch.stack(list(cond_obs_queues['observation.state']), dim=1),
                'action':
                torch.stack(list(cond_obs_queues['action']), dim=1),
            }
            obs_loc = {k: v.to(device) for k, v in obs_loc.items()}

            # Policy pass — sub-stage records are discarded here; only the
            # profiler trace matters in this mode.
            dummy_rec = defaultdict(list)
            pv0, pa, _ = profiled_synthesis(
                model, prompt_text, obs_loc, noise_shape,
                ddim_steps=args.ddim_steps, ddim_eta=args.ddim_eta,
                unconditional_guidance_scale=args.unconditional_guidance_scale,
                fs=model_input_fs, text_input=True,
                timestep_spacing=args.timestep_spacing,
                guidance_rescale=args.guidance_rescale,
                sim_mode=False,
                decode_video=not args.fast_policy_no_decode,
                records=dummy_rec, prefix="policy")

            # Feed predicted actions (padding dims zeroed) into the queue.
            for idx in range(len(pa[0])):
                oa = {'action': pa[0][idx:idx + 1]}
                oa['action'][:, ori_action_dim:] = 0.0
                populate_queues(cond_obs_queues, oa)

            # Re-stack for world model
            obs_loc2 = {
                'observation.images.top':
                torch.stack(list(cond_obs_queues['observation.images.top']),
                            dim=1).permute(0, 2, 1, 3, 4),
                'observation.state':
                torch.stack(list(cond_obs_queues['observation.state']), dim=1),
                'action':
                torch.stack(list(cond_obs_queues['action']), dim=1),
            }
            obs_loc2 = {k: v.to(device) for k, v in obs_loc2.items()}

            # World model pass
            pv1, _, ps = profiled_synthesis(
                model, "", obs_loc2, noise_shape,
                ddim_steps=args.ddim_steps, ddim_eta=args.ddim_eta,
                unconditional_guidance_scale=args.unconditional_guidance_scale,
                fs=model_input_fs, text_input=False,
                timestep_spacing=args.timestep_spacing,
                guidance_rescale=args.guidance_rescale,
                sim_mode=True, decode_video=True,
                records=dummy_rec, prefix="wm")

            # Update obs queue
            for idx in range(args.exe_steps):
                ou = {
                    'observation.images.top':
                    pv1[0][:, idx:idx + 1].permute(1, 0, 2, 3),
                    'observation.state': ps[0][idx:idx + 1],
                    'action': torch.zeros_like(pa[0][-1:]),
                }
                ou['observation.state'][:, ori_state_dim:] = 0.0
                populate_queues(cond_obs_queues, ou)

            # Save results (captures CPU stall in trace)
            if pv0 is not None:
                save_results(pv0.cpu(),
                             os.path.join(tmp_dir, f"trace_dm_{itr_idx}.mp4"),
                             fps=args.save_fps)
            save_results(pv1.cpu(),
                         os.path.join(tmp_dir, f"trace_wm_{itr_idx}.mp4"),
                         fps=args.save_fps)

            # Advance the profiler schedule (warmup -> active -> done).
            prof.step()

    print(f">>> Trace saved to {trace_dir}")
    print(" View with: tensorboard --logdir", trace_dir)
    print(" Or open the .json file in chrome://tracing")
|
||||
|
||||
|
||||
# ──────────────────────────────────────────────────────────────────────
|
||||
# Argument parser
|
||||
# ──────────────────────────────────────────────────────────────────────
|
||||
def get_parser():
    """Build the CLI parser for the iteration profiler.

    Two modes share one parser:
      * ``--compare A B``: diff two previously written summary CSVs (no model).
      * profiling mode: load the model and benchmark the iteration loop.

    Returns:
        argparse.ArgumentParser: fully configured parser.
    """
    parser = argparse.ArgumentParser(description="Profile full iteration loop")

    # Compare mode (no model needed)
    parser.add_argument("--compare", nargs=2, metavar=("A_SUMMARY", "B_SUMMARY"),
                        help="Compare two summary CSVs and exit")

    # Model / data locations. None defaults are validated later in main()
    # because compare mode does not need them.
    for flag, default in [("--ckpt_path", None), ("--config", None),
                          ("--prompt_dir", None), ("--dataset", None),
                          ("--savedir", "profile_output")]:
        parser.add_argument(flag, type=str, default=default)

    # Inference parameters (kept in sync with world_model_interaction.py).
    # Declared as a (flag, type, default) table to keep defaults easy to audit.
    inference_opts = [
        ("--ddim_steps", int, 50),
        ("--ddim_eta", float, 1.0),
        ("--bs", int, 1),
        ("--height", int, 320),
        ("--width", int, 512),
        ("--frame_stride", int, 4),
        ("--unconditional_guidance_scale", float, 1.0),
        ("--video_length", int, 16),
        ("--timestep_spacing", str, "uniform_trailing"),
        ("--guidance_rescale", float, 0.7),
        ("--exe_steps", int, 16),
        ("--n_iter", int, 5),
        ("--save_fps", int, 8),
        ("--seed", int, 123),
    ]
    for flag, opt_type, default in inference_opts:
        parser.add_argument(flag, type=opt_type, default=default)
    parser.add_argument("--perframe_ae", action='store_true', default=False)
    parser.add_argument("--vae_dtype", type=str, choices=["fp32", "bf16"],
                        default="bf16")
    parser.add_argument("--fast_policy_no_decode", action='store_true',
                        default=False)

    # Profiling control
    parser.add_argument("--warmup", type=int, default=1,
                        help="Number of warmup iterations to skip in statistics")
    parser.add_argument("--csv", type=str, default=None,
                        help="Write per-iteration timing to this CSV file")
    parser.add_argument("--trace", action='store_true', default=False,
                        help="Enable Layer 2: GPU timeline trace")
    parser.add_argument("--trace_dir", type=str, default="./profile_traces",
                        help="Directory for trace output")

    return parser
|
||||
|
||||
|
||||
# ──────────────────────────────────────────────────────────────────────
|
||||
# Main
|
||||
# ──────────────────────────────────────────────────────────────────────
|
||||
def main():
    """Entry point: dispatch to CSV comparison or run the profiling pipeline."""
    patch_norm_bypass_autocast()
    parser = get_parser()
    args = parser.parse_args()

    # ── Compare mode: needs no model or GPU — diff two CSVs and stop ──
    if args.compare:
        a_csv, b_csv = args.compare
        compare_csvs(a_csv, b_csv)
        return

    # ── Profiling mode requires the model/data arguments ──
    for name in ['ckpt_path', 'config', 'prompt_dir', 'dataset']:
        if getattr(args, name) is None:
            parser.error(f"--{name} is required for profiling mode")

    seed_everything(args.seed)
    os.makedirs(args.savedir, exist_ok=True)

    def banner(text):
        # 60-char '=' rulers around a section title, matching report style.
        print("=" * 60)
        print(text)
        print("=" * 60)

    # ── Load model ──
    banner("PROFILE ITERATION — Loading model...")
    model, config = load_model(args)
    device = next(model.parameters()).device

    # Latent spatial dims are pixel dims // 8; channel count is read off the
    # diffusion UNet so noise_shape always matches the loaded checkpoint.
    latent_h, latent_w = args.height // 8, args.width // 8
    channels = model.model.diffusion_model.out_channels
    noise_shape = [args.bs, channels, args.video_length, latent_h, latent_w]
    print(f">>> Noise shape: {noise_shape}")
    print(f">>> DDIM steps: {args.ddim_steps}")
    print(f">>> fast_policy_no_decode: {args.fast_policy_no_decode}")

    # ── Layer 2 (optional): GPU timeline trace ──
    if args.trace:
        with torch.no_grad():
            run_with_trace(model, args, config, noise_shape, device)
        print()

    # ── Layer 1: iteration-level breakdown ──
    banner("LAYER 1: ITERATION-LEVEL PROFILING")
    with torch.no_grad():
        records = run_profiled_iterations(
            model, args, config, noise_shape, device)

    print_iteration_report(records, warmup=args.warmup)

    # ── Layer 3 (optional): standardized CSV for later A/B comparison ──
    if args.csv:
        write_csv(records, args.csv, warmup=args.warmup)

    print("Done.")
|
||||
|
||||
|
||||
# Standard script entry-point guard: run profiling (or --compare) only when
# invoked directly, not when imported as a module.
if __name__ == '__main__':
    main()
|
||||
5
unitree_z1_dual_arm_cleanup_pencils/case1/run_profile.sh
Normal file
5
unitree_z1_dual_arm_cleanup_pencils/case1/run_profile.sh
Normal file
@@ -0,0 +1,5 @@
|
||||
#!/bin/bash
# Run the baseline (Layer 1 + CSV) profiling pass for the case1 prompts.
# Fix: shebang was '#\!/bin/bash' (escaped '!', invalid interpreter line).
res_dir="unitree_z1_dual_arm_cleanup_pencils/case1"
dataset="unitree_z1_dual_arm_cleanup_pencils"

# tee writes to profile_output before the Python script can create it,
# so make sure the directory exists first.
mkdir -p "${res_dir}/profile_output"

TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1 CUDA_VISIBLE_DEVICES=0 python3 scripts/evaluation/profile_iteration.py --seed 123 --ckpt_path ckpts/unifolm_wma_dual_mix_bf16.ckpt --config configs/inference/world_model_interaction.yaml --savedir "${res_dir}/profile_output" --prompt_dir "${res_dir}/world_model_interaction_prompts" --dataset ${dataset} --bs 1 --height 320 --width 512 --unconditional_guidance_scale 1.0 --ddim_steps 50 --ddim_eta 1.0 --video_length 16 --frame_stride 4 --exe_steps 16 --n_iter 5 --warmup 1 --timestep_spacing uniform_trailing --guidance_rescale 0.7 --perframe_ae --vae_dtype bf16 --fast_policy_no_decode --csv "${res_dir}/profile_output/baseline.csv" 2>&1 | tee "${res_dir}/profile_output/profile.log"
|
||||
Reference in New Issue
Block a user