多卡流水导出

This commit is contained in:
qhy
2026-05-17 15:05:30 +08:00
parent 9d2d57d96b
commit afd90e59fe
6 changed files with 1787 additions and 1611 deletions

File diff suppressed because it is too large Load Diff

209
scripts/export_trt.py Normal file
View File

@@ -0,0 +1,209 @@
"""Export video UNet backbone to ONNX, then convert to TensorRT engine.
Usage:
python scripts/export_trt.py \
--ckpt ckpts/unifolm_wma_dual.ckpt.prepared.pt \
--config configs/inference/world_model_interaction.yaml \
--out_dir trt_engines
python scripts/export_trt.py \
--ckpt ckpts/unifolm_wma_dual.ckpt.prepared.pt \
--config configs/inference/world_model_interaction.yaml \
--engine_path trt_engines/video_backbone_multigpu.engine \
--onnx_path trt_engines/video_backbone_multigpu.onnx
"""
import os
import sys
import argparse
import json
import torch
import tensorrt as trt
from omegaconf import OmegaConf
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'src'))
from unifolm_wma.utils.utils import instantiate_from_config
from unifolm_wma.trt_utils import export_backbone_onnx
class TacticRecorder(trt.IAlgorithmSelector):
"""Pass 1: record all candidate tactics and the auto-selected winner."""
def __init__(self):
super().__init__()
self.records = {} # layer_name -> {candidates: [...], selected: ...}
def select_algorithms(self, ctx, choices):
name = ctx.name
# Collect input/output shapes
inputs = []
for j in range(ctx.num_inputs):
try:
inputs.append([int(d) for d in ctx.get_shape(j)])
except Exception:
inputs.append(None)
outputs = []
for j in range(ctx.num_outputs):
try:
outputs.append([int(d) for d in ctx.get_shape(ctx.num_inputs + j)])
except Exception:
outputs.append(None)
self.records[name] = {
"input_shapes": inputs,
"output_shapes": outputs,
"candidates": [],
"selected": None,
}
for i, c in enumerate(choices):
v = c.algorithm_variant
self.records[name]["candidates"].append({
"index": i,
"implementation": v.implementation,
"tactic": v.tactic,
"timing_msec": c.timing_msec,
"workspace_size": c.workspace_size,
})
# return all indices -> let TRT auto-pick the fastest
return list(range(len(choices)))
def report_algorithms(self, ctx, choices):
# Both ctx and choices are lists in report_algorithms
for c, alg in zip(ctx, choices):
name = c.name
if name in self.records:
v = alg.algorithm_variant
self.records[name]["selected"] = {
"implementation": v.implementation,
"tactic": v.tactic,
"timing_msec": alg.timing_msec,
"workspace_size": alg.workspace_size,
}
def save(self, path):
with open(path, "w") as f:
json.dump(self.records, f, indent=2)
print(f">>> Tactic info saved to {path} ({len(self.records)} layers)")
class TacticForcer(trt.IAlgorithmSelector):
"""Pass 2: force user-specified tactics from a JSON file."""
def __init__(self, path):
super().__init__()
with open(path) as f:
self.overrides = json.load(f)
n = sum(1 for v in self.overrides.values() if v.get("force"))
print(f">>> Loaded tactic overrides: {n} layers with 'force' set")
def select_algorithms(self, ctx, choices):
name = ctx.name
override = self.overrides.get(name)
if override and override.get("force"):
target_impl = override["force"]["implementation"]
target_tactic = override["force"]["tactic"]
for i, c in enumerate(choices):
v = c.algorithm_variant
if v.implementation == target_impl and v.tactic == target_tactic:
return [i]
print(f" WARN: forced tactic not found for {name}, using auto")
return list(range(len(choices)))
def report_algorithms(self, ctx, choices):
pass
def load_model(config_path, ckpt_path, device):
if ckpt_path.endswith('.prepared.pt'):
model = torch.load(ckpt_path, map_location='cpu')
else:
config = OmegaConf.load(config_path)
model = instantiate_from_config(config.model)
state_dict = torch.load(ckpt_path, map_location='cpu')
if 'state_dict' in state_dict:
state_dict = state_dict['state_dict']
model.load_state_dict(state_dict, strict=False)
model.eval().to(device)
return model
def main():
parser = argparse.ArgumentParser()
parser.add_argument('--ckpt', required=True)
parser.add_argument('--config', default='configs/inference/world_model_interaction.yaml')
parser.add_argument('--out_dir', default='trt_engines')
parser.add_argument('--gpu_id',
type=int,
default=0,
help='CUDA device id used for ONNX export and TRT build.')
parser.add_argument('--onnx_path',
default=None,
help='Optional explicit ONNX output path. Overrides --out_dir default name.')
parser.add_argument('--engine_path',
default=None,
help='Optional explicit TensorRT engine output path. Overrides --out_dir default name.')
parser.add_argument('--context_len', type=int, default=95)
parser.add_argument('--fp16', action='store_true', default=True)
parser.add_argument('--dump-tactics', default=None, help='Pass 1: dump tactic info to JSON')
parser.add_argument('--load-tactics', default=None, help='Pass 2: force tactics from JSON')
args = parser.parse_args()
device = torch.device('cuda', args.gpu_id)
torch.cuda.set_device(device)
onnx_path = args.onnx_path or os.path.join(args.out_dir,
'video_backbone.onnx')
engine_path = args.engine_path or os.path.join(args.out_dir,
'video_backbone.engine')
os.makedirs(os.path.dirname(os.path.abspath(onnx_path)), exist_ok=True)
os.makedirs(os.path.dirname(os.path.abspath(engine_path)), exist_ok=True)
if os.path.exists(onnx_path):
print(f">>> ONNX already exists at {onnx_path}, skipping export.")
n_outputs = 10
else:
print(">>> Loading model ...")
model = load_model(args.config, args.ckpt, device)
print(">>> Exporting ONNX ...")
with torch.no_grad():
n_outputs = export_backbone_onnx(model, onnx_path, context_len=args.context_len)
del model
torch.cuda.empty_cache()
print(">>> Converting ONNX -> TensorRT engine ...")
logger = trt.Logger(trt.Logger.WARNING)
builder = trt.Builder(logger)
network = builder.create_network(1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
parser = trt.OnnxParser(network, logger)
if not parser.parse_from_file(os.path.abspath(onnx_path)):
for i in range(parser.num_errors):
print(f" ONNX parse error: {parser.get_error(i)}")
raise RuntimeError("ONNX parsing failed")
config = builder.create_builder_config()
config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, 16 << 30)
if args.fp16:
config.set_flag(trt.BuilderFlag.FP16)
# Tactic selection
recorder = None
if args.dump_tactics:
recorder = TacticRecorder()
config.algorithm_selector = recorder
elif args.load_tactics:
config.algorithm_selector = TacticForcer(args.load_tactics)
engine_bytes = builder.build_serialized_network(network, config)
if recorder and args.dump_tactics:
recorder.save(args.dump_tactics)
with open(engine_path, 'wb') as f:
f.write(engine_bytes)
print(f"\n>>> Done! Engine saved to {engine_path}")
print(f" Outputs: 1 y + {n_outputs - 1} hs_a tensors")
if __name__ == '__main__':
main()

View File

@@ -1,12 +1,13 @@
import numpy as np import numpy as np
import torch import torch
import copy import copy
import time
from unifolm_wma.utils.diffusion import make_ddim_sampling_parameters, make_ddim_timesteps, rescale_noise_cfg from unifolm_wma.utils.diffusion import make_ddim_sampling_parameters, make_ddim_timesteps, rescale_noise_cfg
from unifolm_wma.utils.common import noise_like from unifolm_wma.utils.common import noise_like
from unifolm_wma.utils.common import extract_into_tensor from unifolm_wma.utils.common import extract_into_tensor
from tqdm import tqdm from tqdm import tqdm
from unifolm_wma.modules.attention import enable_cross_attn_kv_cache, disable_cross_attn_kv_cache
from unifolm_wma.modules.networks.wma_model import enable_ctx_cache, disable_ctx_cache
class DDIMSampler(object): class DDIMSampler(object):
@@ -20,8 +21,9 @@ class DDIMSampler(object):
def register_buffer(self, name, attr): def register_buffer(self, name, attr):
if type(attr) == torch.Tensor: if type(attr) == torch.Tensor:
if attr.device != torch.device("cuda"): target_device = self.model.device
attr = attr.to(torch.device("cuda")) if attr.device != target_device:
attr = attr.to(target_device)
setattr(self, name, attr) setattr(self, name, attr)
def make_schedule(self, def make_schedule(self,
@@ -68,11 +70,12 @@ class DDIMSampler(object):
ddim_timesteps=self.ddim_timesteps, ddim_timesteps=self.ddim_timesteps,
eta=ddim_eta, eta=ddim_eta,
verbose=verbose) verbose=verbose)
self.register_buffer('ddim_sigmas', ddim_sigmas) # Ensure tensors are on correct device for efficient indexing
self.register_buffer('ddim_alphas', ddim_alphas) self.register_buffer('ddim_sigmas', to_torch(torch.as_tensor(ddim_sigmas)))
self.register_buffer('ddim_alphas_prev', ddim_alphas_prev) self.register_buffer('ddim_alphas', to_torch(torch.as_tensor(ddim_alphas)))
self.register_buffer('ddim_alphas_prev', to_torch(torch.as_tensor(ddim_alphas_prev)))
self.register_buffer('ddim_sqrt_one_minus_alphas', self.register_buffer('ddim_sqrt_one_minus_alphas',
np.sqrt(1. - ddim_alphas)) to_torch(torch.as_tensor(np.sqrt(1. - ddim_alphas))))
sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt( sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
(1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) *
(1 - self.alphas_cumprod / self.alphas_cumprod_prev)) (1 - self.alphas_cumprod / self.alphas_cumprod_prev))
@@ -107,17 +110,9 @@ class DDIMSampler(object):
fs=None, fs=None,
timestep_spacing='uniform', #uniform_trailing for starting from last timestep timestep_spacing='uniform', #uniform_trailing for starting from last timestep
guidance_rescale=0.0, guidance_rescale=0.0,
action_T=None, handoff_step: int = 0,
state_T=None, handoff_callback=None,
record_step_outputs=False, stop_at_handoff: bool = False,
head_schedule=None,
head_log_steps=None,
head_skip_mode="reuse_prediction",
backbone_reuse_blocks=None,
backbone_reuse_start_step=None,
backbone_reuse_schedule_steps=None,
backbone_reuse_force_compute_steps=None,
backbone_reuse_mode="disabled",
**kwargs): **kwargs):
# Check condition bs # Check condition bs
@@ -173,17 +168,9 @@ class DDIMSampler(object):
precision=precision, precision=precision,
fs=fs, fs=fs,
guidance_rescale=guidance_rescale, guidance_rescale=guidance_rescale,
action_T=action_T, handoff_step=handoff_step,
state_T=state_T, handoff_callback=handoff_callback,
record_step_outputs=record_step_outputs, stop_at_handoff=stop_at_handoff,
head_schedule=head_schedule,
head_log_steps=head_log_steps,
head_skip_mode=head_skip_mode,
backbone_reuse_blocks=backbone_reuse_blocks,
backbone_reuse_start_step=backbone_reuse_start_step,
backbone_reuse_schedule_steps=backbone_reuse_schedule_steps,
backbone_reuse_force_compute_steps=backbone_reuse_force_compute_steps,
backbone_reuse_mode=backbone_reuse_mode,
**kwargs) **kwargs)
return samples, actions, states, intermediates return samples, actions, states, intermediates
@@ -210,44 +197,23 @@ class DDIMSampler(object):
precision=None, precision=None,
fs=None, fs=None,
guidance_rescale=0.0, guidance_rescale=0.0,
action_T=None, handoff_step: int = 0,
state_T=None, handoff_callback=None,
record_step_outputs=False, stop_at_handoff: bool = False,
head_schedule=None,
head_log_steps=None,
head_skip_mode="reuse_prediction",
backbone_reuse_blocks=None,
backbone_reuse_start_step=None,
backbone_reuse_schedule_steps=None,
backbone_reuse_force_compute_steps=None,
backbone_reuse_mode="disabled",
**kwargs): **kwargs):
device = self.model.betas.device device = self.model.betas.device
dp_ddim_scheduler_action = self.model.dp_noise_scheduler_action dp_ddim_scheduler_action = self.model.dp_noise_scheduler_action
dp_ddim_scheduler_state = self.model.dp_noise_scheduler_state dp_ddim_scheduler_state = self.model.dp_noise_scheduler_state
b = shape[0] b = shape[0]
horizon = shape[2] if len(shape) >= 3 else 16
if x_T is None: if x_T is None:
img = torch.randn(shape, device=device) img = torch.randn(shape, device=device)
action = torch.randn((b, 16, self.model.agent_action_dim), device=device)
state = torch.randn((b, 16, self.model.agent_state_dim), device=device)
else: else:
img = x_T img = x_T
if action_T is None: action = torch.randn((b, 16, self.model.agent_action_dim), device=device)
action = torch.randn((b, horizon, self.model.agent_action_dim), state = torch.randn((b, 16, self.model.agent_state_dim), device=device)
device=device)
else:
action = action_T
if state_T is None:
state = torch.randn((b, horizon, self.model.agent_state_dim),
device=device)
else:
state = state_T
if precision is not None:
if precision == 16:
img = img.to(dtype=torch.float16)
action = action.to(dtype=torch.float16)
state = state.to(dtype=torch.float16)
if timesteps is None: if timesteps is None:
timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps
@@ -265,40 +231,6 @@ class DDIMSampler(object):
'x_inter_state': [state], 'x_inter_state': [state],
'pred_x0_state': [state], 'pred_x0_state': [state],
} }
if record_step_outputs:
intermediates['analysis_init'] = {
'img': img.detach().cpu(),
'action': action.detach().cpu(),
'state': state.detach().cpu(),
}
intermediates['step_records'] = []
head_schedule_set = None if head_schedule is None else {
int(step_index) for step_index in head_schedule
}
head_log_steps_set = None if head_log_steps is None else {
int(step_index) for step_index in head_log_steps
}
if head_log_steps_set is not None:
intermediates['head_sparse_logs'] = {}
backbone_reuse_blocks_set = None if backbone_reuse_blocks is None else {
str(block_name) for block_name in backbone_reuse_blocks
}
backbone_reuse_schedule_steps_set = (
None if backbone_reuse_schedule_steps is None else {
int(step_index) for step_index in backbone_reuse_schedule_steps
})
backbone_reuse_force_compute_steps_set = (
None if backbone_reuse_force_compute_steps is None else {
int(step_index)
for step_index in backbone_reuse_force_compute_steps
})
backbone_reuse_active = (backbone_reuse_mode == "reuse_output"
and backbone_reuse_blocks_set)
backbone_reuse_cache = {
'single': {},
'cond': {},
'uncond': {},
}
time_range = reversed(range( time_range = reversed(range(
0, timesteps)) if ddim_use_original_steps else np.flip(timesteps) 0, timesteps)) if ddim_use_original_steps else np.flip(timesteps)
total_steps = timesteps if ddim_use_original_steps else timesteps.shape[ total_steps = timesteps if ddim_use_original_steps else timesteps.shape[
@@ -309,27 +241,16 @@ class DDIMSampler(object):
iterator = time_range iterator = time_range
clean_cond = kwargs.pop("clean_cond", False) clean_cond = kwargs.pop("clean_cond", False)
sync_device = device if isinstance(device, torch.device) else torch.device(
device)
should_sync = record_step_outputs and sync_device.type == "cuda"
if head_skip_mode not in {"reuse_prediction", "freeze_state"}:
raise ValueError(
f"Unsupported head_skip_mode={head_skip_mode!r}. "
"Expected 'reuse_prediction' or 'freeze_state'.")
if backbone_reuse_mode not in {"disabled", "reuse_output"}:
raise ValueError(
f"Unsupported backbone_reuse_mode={backbone_reuse_mode!r}. "
"Expected 'disabled' or 'reuse_output'.")
x_action_frozen = action.detach().clone()
x_state_frozen = state.detach().clone()
action_pred_cache = None
state_pred_cache = None
dp_ddim_scheduler_action.set_timesteps(len(timesteps)) dp_ddim_scheduler_action.set_timesteps(len(timesteps))
dp_ddim_scheduler_state.set_timesteps(len(timesteps)) dp_ddim_scheduler_state.set_timesteps(len(timesteps))
ts = torch.empty((b, ), device=device, dtype=torch.long)
enable_cross_attn_kv_cache(self.model)
enable_ctx_cache(self.model)
handoff = {}
try:
for i, step in enumerate(iterator): for i, step in enumerate(iterator):
index = total_steps - i - 1 index = total_steps - i - 1
ts = torch.full((b, ), step, device=device, dtype=torch.long) ts.fill_(step)
# Use mask to blend noised original latent (img_orig) & new sampled latent (img) # Use mask to blend noised original latent (img_orig) & new sampled latent (img)
if mask is not None: if mask is not None:
@@ -340,22 +261,6 @@ class DDIMSampler(object):
img_orig = self.model.q_sample(x0, ts) img_orig = self.model.q_sample(x0, ts)
img = img_orig * mask + (1. - mask) * img img = img_orig * mask + (1. - mask) * img
if should_sync:
torch.cuda.synchronize(sync_device)
step_start_time = time.time()
scheduled_head = head_schedule_set is None or i in head_schedule_set
if head_skip_mode == "reuse_prediction":
run_head = scheduled_head or action_pred_cache is None or state_pred_cache is None
else:
run_head = scheduled_head
backbone_reuse_step_stats: dict[str, set[str]] | None = None
if backbone_reuse_active:
backbone_reuse_step_stats = {
'single': set(),
'cond': set(),
'uncond': set(),
}
outs = self.p_sample_ddim( outs = self.p_sample_ddim(
img, img,
action, action,
@@ -363,6 +268,7 @@ class DDIMSampler(object):
cond, cond,
ts, ts,
index=index, index=index,
precision=precision,
use_original_steps=ddim_use_original_steps, use_original_steps=ddim_use_original_steps,
quantize_denoised=quantize_denoised, quantize_denoised=quantize_denoised,
temperature=temperature, temperature=temperature,
@@ -375,23 +281,10 @@ class DDIMSampler(object):
x0=x0, x0=x0,
fs=fs, fs=fs,
guidance_rescale=guidance_rescale, guidance_rescale=guidance_rescale,
backbone_step_index=i + 1,
run_head=run_head,
backbone_reuse_blocks=backbone_reuse_blocks_set,
backbone_reuse_start_step=backbone_reuse_start_step,
backbone_reuse_schedule_steps=backbone_reuse_schedule_steps_set,
backbone_reuse_force_compute_steps=
backbone_reuse_force_compute_steps_set,
backbone_reuse_mode=backbone_reuse_mode,
backbone_reuse_cache=backbone_reuse_cache,
backbone_reuse_step_stats=backbone_reuse_step_stats,
**kwargs) **kwargs)
img, pred_x0, model_output_action, model_output_state = outs img, pred_x0, model_output_action, model_output_state = outs
if run_head:
action_pred_cache = model_output_action.detach().clone()
state_pred_cache = model_output_state.detach().clone()
action = dp_ddim_scheduler_action.step( action = dp_ddim_scheduler_action.step(
model_output_action, model_output_action,
step, step,
@@ -404,73 +297,35 @@ class DDIMSampler(object):
state, state,
generator=None, generator=None,
).prev_sample ).prev_sample
x_action_frozen = action.detach().clone()
x_state_frozen = state.detach().clone()
else:
if head_skip_mode == "reuse_prediction":
action = dp_ddim_scheduler_action.step(
action_pred_cache,
step,
action,
generator=None,
).prev_sample
state = dp_ddim_scheduler_state.step(
state_pred_cache,
step,
state,
generator=None,
).prev_sample
x_action_frozen = action.detach().clone()
x_state_frozen = state.detach().clone()
else:
action = x_action_frozen
state = x_state_frozen
if should_sync:
torch.cuda.synchronize(sync_device)
step_time_s = time.time() - step_start_time
if callback: callback(i) if callback: callback(i)
if img_callback: img_callback(pred_x0, i) if img_callback: img_callback(pred_x0, i)
reused_blocks = [] if handoff_step > 0 and (i + 1) == handoff_step:
if backbone_reuse_step_stats is not None: handoff = {
if unconditional_conditioning is None or unconditional_guidance_scale == 1.: 'samples': img.clone(),
reused_blocks = sorted(backbone_reuse_step_stats['single']) 'actions': action.clone(),
else: 'states': state.clone(),
reused_blocks = sorted( 'pred_x0': pred_x0.clone(),
backbone_reuse_step_stats['cond'] 'step': i + 1,
| backbone_reuse_step_stats['uncond']) }
if handoff_callback is not None:
handoff_callback(handoff)
if stop_at_handoff:
intermediates['handoff'] = handoff
break
if index % log_every_t == 0 or index == total_steps - 1: if index % log_every_t == 0 or index == total_steps - 1:
intermediates['x_inter'].append(img) intermediates['x_inter'].append(img)
intermediates['pred_x0'].append(pred_x0) intermediates['pred_x0'].append(pred_x0)
intermediates['x_inter_action'].append(action) intermediates['x_inter_action'].append(action)
intermediates['x_inter_state'].append(state) intermediates['x_inter_state'].append(state)
if head_log_steps_set is not None and i in head_log_steps_set: finally:
intermediates['head_sparse_logs'][i] = { disable_cross_attn_kv_cache(self.model)
'step_index': i, disable_ctx_cache(self.model)
'ddim_timestep': int(step),
'head_executed': run_head,
'img': img.detach().cpu(),
'pred_x0': pred_x0.detach().cpu(),
'action': action.detach().cpu(),
'state': state.detach().cpu(),
}
if record_step_outputs:
intermediates['step_records'].append({
'step_index': i + 1,
'ddim_timestep': int(step),
'head_executed': run_head,
'img': img.detach().cpu(),
'pred_x0': pred_x0.detach().cpu(),
'action': action.detach().cpu(),
'state': state.detach().cpu(),
'step_time_s': step_time_s,
'backbone_reused_blocks_count': len(reused_blocks),
'backbone_reuse_hit_blocks': ",".join(reused_blocks),
})
if handoff_step > 0:
intermediates['handoff'] = handoff
return img, action, state, intermediates return img, action, state, intermediates
@torch.no_grad() @torch.no_grad()
@@ -481,6 +336,7 @@ class DDIMSampler(object):
c, c,
t, t,
index, index,
precision=None,
repeat_noise=False, repeat_noise=False,
use_original_steps=False, use_original_steps=False,
quantize_denoised=False, quantize_denoised=False,
@@ -495,62 +351,35 @@ class DDIMSampler(object):
mask=None, mask=None,
x0=None, x0=None,
guidance_rescale=0.0, guidance_rescale=0.0,
run_head=True,
**kwargs): **kwargs):
b, *_, device = *x.shape, x.device b, *_, device = *x.shape, x.device
if x.dim() == 5:
is_video = True
else:
is_video = False
use_autocast = precision == 16 and device.type == 'cuda'
with torch.cuda.amp.autocast(enabled=use_autocast,
dtype=torch.float16):
if unconditional_conditioning is None or unconditional_guidance_scale == 1.: if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
model_output, model_output_action, model_output_state = self.model.apply_model( model_output, model_output_action, model_output_state = self.model.apply_model(
x, x, x_action, x_state, t, c, **kwargs) # unet denoiser
x_action,
x_state,
t,
c,
run_head=run_head,
backbone_reuse_branch="single",
**kwargs) # unet denoiser
else: else:
# do_classifier_free_guidance # do_classifier_free_guidance
if isinstance(c, torch.Tensor) or isinstance(c, dict): if isinstance(c, torch.Tensor) or isinstance(c, dict):
e_t_cond, e_t_cond_action, e_t_cond_state = self.model.apply_model( e_t_cond, e_t_cond_action, e_t_cond_state = self.model.apply_model(
x, x, x_action, x_state, t, c, **kwargs)
x_action,
x_state,
t,
c,
run_head=run_head,
backbone_reuse_branch="cond",
**kwargs)
e_t_uncond, e_t_uncond_action, e_t_uncond_state = self.model.apply_model( e_t_uncond, e_t_uncond_action, e_t_uncond_state = self.model.apply_model(
x, x, x_action, x_state, t, unconditional_conditioning,
x_action,
x_state,
t,
unconditional_conditioning,
run_head=run_head,
backbone_reuse_branch="uncond",
**kwargs) **kwargs)
else: else:
raise NotImplementedError raise NotImplementedError
model_output = e_t_uncond + unconditional_guidance_scale * ( model_output = e_t_uncond + unconditional_guidance_scale * (
e_t_cond - e_t_uncond) e_t_cond - e_t_uncond)
if run_head:
model_output_action = e_t_uncond_action + unconditional_guidance_scale * ( model_output_action = e_t_uncond_action + unconditional_guidance_scale * (
e_t_cond_action - e_t_uncond_action) e_t_cond_action - e_t_uncond_action)
model_output_state = e_t_uncond_state + unconditional_guidance_scale * ( model_output_state = e_t_uncond_state + unconditional_guidance_scale * (
e_t_cond_state - e_t_uncond_state) e_t_cond_state - e_t_uncond_state)
else:
model_output_action = None
model_output_state = None
if guidance_rescale > 0.0: if guidance_rescale > 0.0:
model_output = rescale_noise_cfg( model_output = rescale_noise_cfg(
model_output, e_t_cond, guidance_rescale=guidance_rescale) model_output, e_t_cond, guidance_rescale=guidance_rescale)
if run_head:
model_output_action = rescale_noise_cfg( model_output_action = rescale_noise_cfg(
model_output_action, model_output_action,
e_t_cond_action, e_t_cond_action,
@@ -575,17 +404,11 @@ class DDIMSampler(object):
sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
sigmas = self.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas sigmas = self.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
if is_video: # Use 0-d tensors directly (already on device); broadcasting handles shape
size = (b, 1, 1, 1, 1) a_t = alphas[index]
else: a_prev = alphas_prev[index]
size = (b, 1, 1, 1) sigma_t = sigmas[index]
sqrt_one_minus_at = sqrt_one_minus_alphas[index]
a_t = torch.full(size, alphas[index], device=device)
a_prev = torch.full(size, alphas_prev[index], device=device)
sigma_t = torch.full(size, sigmas[index], device=device)
sqrt_one_minus_at = torch.full(size,
sqrt_one_minus_alphas[index],
device=device)
if self.model.parameterization != "v": if self.model.parameterization != "v":
pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt() pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
@@ -593,12 +416,8 @@ class DDIMSampler(object):
pred_x0 = self.model.predict_start_from_z_and_v(x, t, model_output) pred_x0 = self.model.predict_start_from_z_and_v(x, t, model_output)
if self.model.use_dynamic_rescale: if self.model.use_dynamic_rescale:
scale_t = torch.full(size, scale_t = self.ddim_scale_arr[index]
self.ddim_scale_arr[index], prev_scale_t = self.ddim_scale_arr_prev[index]
device=device)
prev_scale_t = torch.full(size,
self.ddim_scale_arr_prev[index],
device=device)
rescale = (prev_scale_t / scale_t) rescale = (prev_scale_t / scale_t)
pred_x0 *= rescale pred_x0 *= rescale

View File

@@ -1,7 +1,6 @@
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
import time
from torch import Tensor from torch import Tensor
from functools import partial from functools import partial
@@ -686,6 +685,37 @@ class WMAModel(nn.Module):
self.action_token_projector = instantiate_from_config( self.action_token_projector = instantiate_from_config(
stem_process_config) stem_process_config)
# Context precomputation cache
self._ctx_cache_enabled = False
self._ctx_cache = {}
self._trt_backbone = None # TRT engine for video UNet backbone
# Reusable CUDA stream for parallel state_unet / action_unet
self._state_stream = torch.cuda.Stream()
def __getstate__(self):
state = self.__dict__.copy()
state.pop('_state_stream', None)
return state
def __setstate__(self, state):
self.__dict__.update(state)
if not hasattr(self, '_ctx_cache_enabled'):
self._ctx_cache_enabled = False
if not hasattr(self, '_ctx_cache'):
self._ctx_cache = {}
if not hasattr(self, '_trt_backbone'):
self._trt_backbone = None
self._state_stream = torch.cuda.Stream()
def load_trt_backbone(self, engine_path, n_hs_a=9):
"""Load a TensorRT engine for the video UNet backbone."""
from unifolm_wma.trt_utils import TRTBackbone
device = next(self.parameters()).device
self._trt_backbone = TRTBackbone(engine_path,
n_hs_a=n_hs_a,
device=device)
print(f">>> TRT backbone loaded from {engine_path} on {device}")
def forward(self, def forward(self,
x: Tensor, x: Tensor,
x_action: Tensor, x_action: Tensor,
@@ -714,28 +744,16 @@ class WMAModel(nn.Module):
Tuple of Tensors for predictions: Tuple of Tensors for predictions:
""" """
b, _, t, _, _ = x.shape b, _, t, _, _ = x.shape
run_head = kwargs.pop("run_head", True)
backbone_block_profiler = kwargs.pop("backbone_block_profiler", None)
backbone_step_index = kwargs.pop("backbone_step_index", None)
backbone_reuse_blocks = kwargs.pop("backbone_reuse_blocks", None)
backbone_reuse_start_step = kwargs.pop("backbone_reuse_start_step",
None)
backbone_reuse_schedule_steps = kwargs.pop(
"backbone_reuse_schedule_steps", None)
backbone_reuse_force_compute_steps = kwargs.pop(
"backbone_reuse_force_compute_steps", None)
backbone_reuse_mode = kwargs.pop("backbone_reuse_mode", "disabled")
backbone_reuse_cache = kwargs.pop("backbone_reuse_cache", None)
backbone_reuse_step_stats = kwargs.pop("backbone_reuse_step_stats",
None)
backbone_reuse_branch = kwargs.pop("backbone_reuse_branch", "single")
t_emb = timestep_embedding(timesteps, t_emb = timestep_embedding(timesteps,
self.model_channels, self.model_channels,
repeat_only=False).type(x.dtype) repeat_only=False).type(x.dtype)
emb = self.time_embed(t_emb) emb = self.time_embed(t_emb)
_ctx_key = context.data_ptr()
if self._ctx_cache_enabled and _ctx_key in self._ctx_cache:
context = self._ctx_cache[_ctx_key]
else:
bt, l_context, _ = context.shape bt, l_context, _ = context.shape
if self.base_model_gen_only: if self.base_model_gen_only:
assert l_context == 77 + self.n_obs_steps * 16, ">>> ERROR Context dim 1 ..." ## NOTE HANDCODE assert l_context == 77 + self.n_obs_steps * 16, ">>> ERROR Context dim 1 ..." ## NOTE HANDCODE
@@ -788,6 +806,8 @@ class WMAModel(nn.Module):
context_img context_img
], ],
dim=1) dim=1)
if self._ctx_cache_enabled:
self._ctx_cache[_ctx_key] = context
emb = emb.repeat_interleave(repeats=t, dim=0) emb = emb.repeat_interleave(repeats=t, dim=0)
@@ -807,73 +827,20 @@ class WMAModel(nn.Module):
fs_embed = fs_embed.repeat_interleave(repeats=t, dim=0) fs_embed = fs_embed.repeat_interleave(repeats=t, dim=0)
emb = emb + fs_embed emb = emb + fs_embed
def run_block_with_profile(block_name: str, block_stage: str, if self._trt_backbone is not None:
block_index: int | None, # TRT path: run backbone via TensorRT engine
fn: Callable[[], Tensor]) -> Tensor: h_in = x.type(self.dtype).contiguous()
if backbone_block_profiler is None or backbone_step_index is None: y, hs_a = self._trt_backbone(h_in, emb.contiguous(), context.contiguous())
return fn() else:
if x.device.type == "cuda": # PyTorch path: original backbone
torch.cuda.synchronize(x.device)
start_time = time.perf_counter()
out = fn()
if x.device.type == "cuda":
torch.cuda.synchronize(x.device)
backbone_block_profiler.record_block(
step=int(backbone_step_index),
block_name=block_name,
block_stage=block_stage,
block_index=block_index,
output=out,
forward_time_ms=(time.perf_counter() - start_time) * 1000.0,
)
return out
reuse_cache_branch: Dict[str, Tensor] | None = None
if backbone_reuse_cache is not None:
reuse_cache_branch = backbone_reuse_cache.setdefault(
backbone_reuse_branch, {})
def should_reuse_output_block(block_name: str) -> bool:
if backbone_reuse_mode != "reuse_output":
return False
if backbone_step_index is None or backbone_reuse_start_step is None:
return False
if backbone_reuse_blocks is None or block_name not in backbone_reuse_blocks:
return False
if int(backbone_step_index) < int(backbone_reuse_start_step):
return False
if (backbone_reuse_force_compute_steps is not None
and int(backbone_step_index)
in backbone_reuse_force_compute_steps):
return False
if (backbone_reuse_schedule_steps is not None
and int(backbone_step_index)
in backbone_reuse_schedule_steps):
return False
if reuse_cache_branch is None:
return False
return block_name in reuse_cache_branch
h = x.type(self.dtype) h = x.type(self.dtype)
adapter_idx = 0 adapter_idx = 0
hs = [] hs = []
hs_a = [] hs_a = []
for id, module in enumerate(self.input_blocks): for id, module in enumerate(self.input_blocks):
def run_input_block() -> Tensor: h = module(h, emb, context=context, batch_size=b)
block_out = module(h, emb, context=context, batch_size=b)
if id == 0 and self.addition_attention: if id == 0 and self.addition_attention:
block_out = self.init_attn(block_out, h = self.init_attn(h, emb, context=context, batch_size=b)
emb,
context=context,
batch_size=b)
return block_out
h = run_block_with_profile(
block_name=f"input_{id}",
block_stage="input_blocks",
block_index=id,
fn=run_input_block,
)
# plug-in adapter features # plug-in adapter features
if ((id + 1) % 3 == 0) and features_adapter is not None: if ((id + 1) % 3 == 0) and features_adapter is not None:
h = h + features_adapter[adapter_idx] h = h + features_adapter[adapter_idx]
@@ -881,76 +848,74 @@ class WMAModel(nn.Module):
if id != 0: if id != 0:
if isinstance(module[0], Downsample): if isinstance(module[0], Downsample):
hs_a.append( hs_a.append(
rearrange(hs[-1], '(b t) c h w -> b t c h w', t=t)) rearrange(hs[-1], '(b t) c h w -> b t c h w', b=b))
hs.append(h) hs.append(h)
hs_a.append(rearrange(h, '(b t) c h w -> b t c h w', t=t)) hs_a.append(rearrange(h, '(b t) c h w -> b t c h w', b=b))
if features_adapter is not None: if features_adapter is not None:
assert len( assert len(
features_adapter) == adapter_idx, 'Wrong features_adapter' features_adapter) == adapter_idx, 'Wrong features_adapter'
h = run_block_with_profile( h = self.middle_block(h, emb, context=context, batch_size=b)
block_name="middle", hs_a.append(rearrange(h, '(b t) c h w -> b t c h w', b=b))
block_stage="middle_block",
block_index=0,
fn=lambda: self.middle_block(h, emb, context=context, batch_size=b),
)
hs_a.append(rearrange(h, '(b t) c h w -> b t c h w', t=t))
hs_out = [] hs_out = []
for id, module in enumerate(self.output_blocks): for module in self.output_blocks:
skip_h = hs.pop() h = torch.cat([h, hs.pop()], dim=1)
block_name = f"output_{id}" h = module(h, emb, context=context, batch_size=b)
def run_output_block() -> Tensor:
return module(torch.cat([h, skip_h], dim=1),
emb,
context=context,
batch_size=b)
if should_reuse_output_block(block_name):
h = reuse_cache_branch[block_name].to(device=h.device,
dtype=h.dtype)
if backbone_reuse_step_stats is not None:
backbone_reuse_step_stats.setdefault(
backbone_reuse_branch, set()).add(block_name)
else:
h = run_block_with_profile(
block_name=block_name,
block_stage="output_blocks",
block_index=id,
fn=run_output_block,
)
if (reuse_cache_branch is not None and backbone_reuse_mode ==
"reuse_output"
and backbone_reuse_blocks is not None
and block_name in backbone_reuse_blocks):
reuse_cache_branch[block_name] = h.detach().clone()
if isinstance(module[-1], Upsample): if isinstance(module[-1], Upsample):
hs_a.append( hs_a.append(
rearrange(hs_out[-1], '(b t) c h w -> b t c h w', t=t)) rearrange(hs_out[-1], '(b t) c h w -> b t c h w', b=b))
hs_out.append(h) hs_out.append(h)
h = h.type(x.dtype) h = h.type(x.dtype)
hs_a.append(rearrange(hs_out[-1], '(b t) c h w -> b t c h w', t=t)) hs_a.append(rearrange(hs_out[-1], '(b t) c h w -> b t c h w', b=b))
y = self.out(h) y = self.out(h)
y = rearrange(y, '(b t) c h w -> b c t h w', b=b) y = rearrange(y, '(b t) c h w -> b c t h w', b=b)
if not self.base_model_gen_only and run_head: if not self.base_model_gen_only:
ba, _, _ = x_action.shape ba, _, _ = x_action.shape
ts_state = timesteps[:ba] if b > 1 else timesteps
is_sim_mode = context_action[2] if len(context_action) > 2 else False
if is_sim_mode:
# WM mode: only need state_unet, skip action_unet
s_y = self.state_unet(x_state, ts_state, hs_a,
context_action[:2], **kwargs)
a_y = torch.zeros_like(x_action)
else:
# DM mode: only need action_unet, skip state_unet
a_y = self.action_unet(x_action, timesteps[:ba], hs_a, a_y = self.action_unet(x_action, timesteps[:ba], hs_a,
context_action[:2], **kwargs) context_action[:2], **kwargs)
# Predict state s_y = torch.zeros_like(x_state)
if b > 1:
s_y = self.state_unet(x_state, timesteps[:ba], hs_a,
context_action[:2], **kwargs)
else:
s_y = self.state_unet(x_state, timesteps, hs_a,
context_action[:2], **kwargs)
elif not self.base_model_gen_only:
a_y = None
s_y = None
else: else:
a_y = torch.zeros_like(x_action) a_y = torch.zeros_like(x_action)
s_y = torch.zeros_like(x_state) s_y = torch.zeros_like(x_state)
return y, a_y, s_y return y, a_y, s_y
def enable_ctx_cache(model):
"""Enable context precomputation cache on WMAModel and its action/state UNets."""
for m in model.modules():
if isinstance(m, WMAModel):
m._ctx_cache_enabled = True
m._ctx_cache = {}
# conditional_unet1d cache
from unifolm_wma.models.diffusion_head.conditional_unet1d import ConditionalUnet1D
for m in model.modules():
if isinstance(m, ConditionalUnet1D):
m._global_cond_cache_enabled = True
m._global_cond_cache = {}
def disable_ctx_cache(model):
"""Disable and clear context precomputation cache."""
for m in model.modules():
if isinstance(m, WMAModel):
m._ctx_cache_enabled = False
m._ctx_cache = {}
from unifolm_wma.models.diffusion_head.conditional_unet1d import ConditionalUnet1D
for m in model.modules():
if isinstance(m, ConditionalUnet1D):
m._global_cond_cache_enabled = False
m._global_cond_cache = {}

View File

@@ -0,0 +1,196 @@
"""TensorRT acceleration utilities for the video UNet backbone."""
import torch
import torch.nn as nn
from einops import rearrange
from unifolm_wma.modules.networks.wma_model import Downsample, Upsample
def _normalize_cuda_device(device) -> torch.device:
if device is None:
return torch.device('cuda', torch.cuda.current_device())
if isinstance(device, torch.device):
if device.type != 'cuda':
raise ValueError(f"TensorRT requires a CUDA device, got {device}.")
return device
if isinstance(device, int):
return torch.device('cuda', device)
normalized = torch.device(device)
if normalized.type != 'cuda':
raise ValueError(f"TensorRT requires a CUDA device, got {normalized}.")
return normalized
class VideoBackboneForExport(nn.Module):
"""Wrapper that isolates the video UNet backbone for ONNX export.
Takes already-preprocessed inputs (after context/time embedding prep)
and returns y + hs_a as a flat tuple.
"""
def __init__(self, wma_model):
super().__init__()
self.input_blocks = wma_model.input_blocks
self.middle_block = wma_model.middle_block
self.output_blocks = wma_model.output_blocks
self.out = wma_model.out
self.addition_attention = wma_model.addition_attention
if self.addition_attention:
self.init_attn = wma_model.init_attn
self.dtype = wma_model.dtype
def forward(self, h, emb, context):
t = 16
b = 1
hs = []
hs_a = []
h = h.type(self.dtype)
for id, module in enumerate(self.input_blocks):
h = module(h, emb, context=context, batch_size=b)
if id == 0 and self.addition_attention:
h = self.init_attn(h, emb, context=context, batch_size=b)
if id != 0:
if isinstance(module[0], Downsample):
hs_a.append(rearrange(hs[-1], '(b t) c h w -> b t c h w', t=t))
hs.append(h)
hs_a.append(rearrange(h, '(b t) c h w -> b t c h w', t=t))
h = self.middle_block(h, emb, context=context, batch_size=b)
hs_a.append(rearrange(h, '(b t) c h w -> b t c h w', t=t))
hs_out = []
for module in self.output_blocks:
h = torch.cat([h, hs.pop()], dim=1)
h = module(h, emb, context=context, batch_size=b)
if isinstance(module[-1], Upsample):
hs_a.append(rearrange(hs_out[-1], '(b t) c h w -> b t c h w', t=t))
hs_out.append(h)
hs_a.append(rearrange(hs_out[-1], '(b t) c h w -> b t c h w', t=t))
y = self.out(h.type(h.dtype))
y = rearrange(y, '(b t) c h w -> b c t h w', b=b)
return (y, *hs_a)
def export_backbone_onnx(model, save_path, context_len=95):
wma = model.model.diffusion_model
wrapper = VideoBackboneForExport(wma)
device = next(wma.parameters()).device
wrapper.eval().to(device)
for m in wrapper.modules():
if hasattr(m, 'checkpoint'):
m.checkpoint = False
if hasattr(m, 'use_checkpoint'):
m.use_checkpoint = False
import xformers.ops
_orig_mea = xformers.ops.memory_efficient_attention
def _sdpa_replacement(q, k, v, attn_bias=None, op=None, **kw):
return torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_bias)
xformers.ops.memory_efficient_attention = _sdpa_replacement
BT = 16
emb_dim = wma.model_channels * 4
ctx_dim = 1024
in_ch = wma.in_channels
dummy_h = torch.randn(BT,
in_ch,
40,
64,
device=device,
dtype=torch.float32)
dummy_emb = torch.randn(BT,
emb_dim,
device=device,
dtype=torch.float32)
dummy_ctx = torch.randn(BT,
context_len,
ctx_dim,
device=device,
dtype=torch.float32)
with torch.no_grad():
outputs = wrapper(dummy_h, dummy_emb, dummy_ctx)
n_outputs = len(outputs)
print(f">>> Backbone has {n_outputs} outputs (1 y + {n_outputs-1} hs_a)")
for i, o in enumerate(outputs):
print(f" output[{i}]: {o.shape} {o.dtype}")
output_names = ['y'] + [f'hs_a_{i}' for i in range(n_outputs - 1)]
torch.onnx.export(
wrapper,
(dummy_h, dummy_emb, dummy_ctx),
save_path,
input_names=['h', 'emb', 'context'],
output_names=output_names,
opset_version=17,
do_constant_folding=True,
)
print(f">>> ONNX exported to {save_path}")
xformers.ops.memory_efficient_attention = _orig_mea
return n_outputs
class TRTBackbone:
"""TensorRT runtime wrapper for the video UNet backbone."""
def __init__(self, engine_path, n_hs_a=9, device=None):
import tensorrt as trt
self.device = _normalize_cuda_device(device)
self.logger = trt.Logger(trt.Logger.WARNING)
with torch.cuda.device(self.device):
with open(engine_path, 'rb') as f:
runtime = trt.Runtime(self.logger)
self.engine = runtime.deserialize_cuda_engine(f.read())
self.context = self.engine.create_execution_context()
self.n_hs_a = n_hs_a
import numpy as np
self.output_buffers = {}
with torch.cuda.device(self.device):
for i in range(self.engine.num_io_tensors):
name = self.engine.get_tensor_name(i)
if self.engine.get_tensor_mode(name) == trt.TensorIOMode.OUTPUT:
shape = self.engine.get_tensor_shape(name)
np_dtype = trt.nptype(self.engine.get_tensor_dtype(name))
buf = torch.empty(
list(shape),
dtype=torch.from_numpy(np.empty(0, dtype=np_dtype)).dtype,
device=self.device)
self.output_buffers[name] = buf
print(
f" TRT output '{name}': {list(shape)} {buf.dtype} on {self.device}"
)
def __call__(self, h, emb, context):
import numpy as np
import tensorrt as trt
bound_inputs = {}
with torch.cuda.device(self.device):
for name, tensor in [('h', h), ('emb', emb), ('context', context)]:
expected_dtype = trt.nptype(self.engine.get_tensor_dtype(name))
torch_expected = torch.from_numpy(
np.empty(0, dtype=expected_dtype)).dtype
if tensor.device != self.device or tensor.dtype != torch_expected:
tensor = tensor.to(device=self.device,
dtype=torch_expected,
non_blocking=True)
tensor = tensor.contiguous()
bound_inputs[name] = tensor
self.context.set_tensor_address(name, tensor.data_ptr())
for name, buf in self.output_buffers.items():
self.context.set_tensor_address(name, buf.data_ptr())
stream = torch.cuda.current_stream(device=self.device)
self.context.execute_async_v3(stream.cuda_stream)
y = self.output_buffers['y']
hs_a = [self.output_buffers[f'hs_a_{i}'] for i in range(self.n_hs_a)]
return y, hs_a

View File

@@ -2,11 +2,11 @@ res_dir="unitree_z1_dual_arm_stackbox_v2/case2"
dataset="unitree_z1_dual_arm_stackbox_v2" dataset="unitree_z1_dual_arm_stackbox_v2"
{ {
time CUDA_VISIBLE_DEVICES=0 "${PYTHON_BIN:-python}" scripts/evaluation/world_model_interaction.py \ time CUDA_VISIBLE_DEVICES=0,1 python3 scripts/evaluation/world_model_interaction.py \
--seed 123 \ --seed 123 \
--ckpt_path ckpts/unifolm_wma_dual.ckpt \ --ckpt_path ckpts/unifolm_wma_dual.ckpt \
--config configs/inference/world_model_interaction.yaml \ --config configs/inference/world_model_interaction.yaml \
--savedir "${res_dir}/output/sparse_8" \ --savedir "${res_dir}/output" \
--bs 1 --height 320 --width 512 \ --bs 1 --height 320 --width 512 \
--unconditional_guidance_scale 1.0 \ --unconditional_guidance_scale 1.0 \
--ddim_steps 50 \ --ddim_steps 50 \
@@ -21,9 +21,7 @@ dataset="unitree_z1_dual_arm_stackbox_v2"
--timestep_spacing 'uniform_trailing' \ --timestep_spacing 'uniform_trailing' \
--guidance_rescale 0.7 \ --guidance_rescale 0.7 \
--perframe_ae \ --perframe_ae \
--analysis_log_metrics \ --pipeline_split_step 30 \
--analysis_reference_steps 50 \ --pipeline_multi_gpu \
--head_schedule_steps 0 7 14 21 28 35 42 49 \ --pipeline_gpu_ids 0 1
--head_skip_mode reuse_prediction \
--head_log_steps 40 43 46 47 48 49
} 2>&1 | tee "${res_dir}/output.log" } 2>&1 | tee "${res_dir}/output.log"