多卡流水导出
This commit is contained in:
File diff suppressed because it is too large
Load Diff
209
scripts/export_trt.py
Normal file
209
scripts/export_trt.py
Normal file
@@ -0,0 +1,209 @@
|
||||
"""Export video UNet backbone to ONNX, then convert to TensorRT engine.
|
||||
|
||||
Usage:
|
||||
python scripts/export_trt.py \
|
||||
--ckpt ckpts/unifolm_wma_dual.ckpt.prepared.pt \
|
||||
--config configs/inference/world_model_interaction.yaml \
|
||||
--out_dir trt_engines
|
||||
|
||||
python scripts/export_trt.py \
|
||||
--ckpt ckpts/unifolm_wma_dual.ckpt.prepared.pt \
|
||||
--config configs/inference/world_model_interaction.yaml \
|
||||
--engine_path trt_engines/video_backbone_multigpu.engine \
|
||||
--onnx_path trt_engines/video_backbone_multigpu.onnx
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import argparse
|
||||
import json
|
||||
|
||||
import torch
|
||||
import tensorrt as trt
|
||||
from omegaconf import OmegaConf
|
||||
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'src'))
|
||||
from unifolm_wma.utils.utils import instantiate_from_config
|
||||
from unifolm_wma.trt_utils import export_backbone_onnx
|
||||
|
||||
|
||||
class TacticRecorder(trt.IAlgorithmSelector):
|
||||
"""Pass 1: record all candidate tactics and the auto-selected winner."""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.records = {} # layer_name -> {candidates: [...], selected: ...}
|
||||
|
||||
def select_algorithms(self, ctx, choices):
|
||||
name = ctx.name
|
||||
# Collect input/output shapes
|
||||
inputs = []
|
||||
for j in range(ctx.num_inputs):
|
||||
try:
|
||||
inputs.append([int(d) for d in ctx.get_shape(j)])
|
||||
except Exception:
|
||||
inputs.append(None)
|
||||
outputs = []
|
||||
for j in range(ctx.num_outputs):
|
||||
try:
|
||||
outputs.append([int(d) for d in ctx.get_shape(ctx.num_inputs + j)])
|
||||
except Exception:
|
||||
outputs.append(None)
|
||||
self.records[name] = {
|
||||
"input_shapes": inputs,
|
||||
"output_shapes": outputs,
|
||||
"candidates": [],
|
||||
"selected": None,
|
||||
}
|
||||
for i, c in enumerate(choices):
|
||||
v = c.algorithm_variant
|
||||
self.records[name]["candidates"].append({
|
||||
"index": i,
|
||||
"implementation": v.implementation,
|
||||
"tactic": v.tactic,
|
||||
"timing_msec": c.timing_msec,
|
||||
"workspace_size": c.workspace_size,
|
||||
})
|
||||
# return all indices -> let TRT auto-pick the fastest
|
||||
return list(range(len(choices)))
|
||||
|
||||
def report_algorithms(self, ctx, choices):
|
||||
# Both ctx and choices are lists in report_algorithms
|
||||
for c, alg in zip(ctx, choices):
|
||||
name = c.name
|
||||
if name in self.records:
|
||||
v = alg.algorithm_variant
|
||||
self.records[name]["selected"] = {
|
||||
"implementation": v.implementation,
|
||||
"tactic": v.tactic,
|
||||
"timing_msec": alg.timing_msec,
|
||||
"workspace_size": alg.workspace_size,
|
||||
}
|
||||
|
||||
def save(self, path):
|
||||
with open(path, "w") as f:
|
||||
json.dump(self.records, f, indent=2)
|
||||
print(f">>> Tactic info saved to {path} ({len(self.records)} layers)")
|
||||
|
||||
|
||||
class TacticForcer(trt.IAlgorithmSelector):
|
||||
"""Pass 2: force user-specified tactics from a JSON file."""
|
||||
|
||||
def __init__(self, path):
|
||||
super().__init__()
|
||||
with open(path) as f:
|
||||
self.overrides = json.load(f)
|
||||
n = sum(1 for v in self.overrides.values() if v.get("force"))
|
||||
print(f">>> Loaded tactic overrides: {n} layers with 'force' set")
|
||||
|
||||
def select_algorithms(self, ctx, choices):
|
||||
name = ctx.name
|
||||
override = self.overrides.get(name)
|
||||
if override and override.get("force"):
|
||||
target_impl = override["force"]["implementation"]
|
||||
target_tactic = override["force"]["tactic"]
|
||||
for i, c in enumerate(choices):
|
||||
v = c.algorithm_variant
|
||||
if v.implementation == target_impl and v.tactic == target_tactic:
|
||||
return [i]
|
||||
print(f" WARN: forced tactic not found for {name}, using auto")
|
||||
return list(range(len(choices)))
|
||||
|
||||
def report_algorithms(self, ctx, choices):
|
||||
pass
|
||||
|
||||
|
||||
def load_model(config_path, ckpt_path, device):
|
||||
if ckpt_path.endswith('.prepared.pt'):
|
||||
model = torch.load(ckpt_path, map_location='cpu')
|
||||
else:
|
||||
config = OmegaConf.load(config_path)
|
||||
model = instantiate_from_config(config.model)
|
||||
state_dict = torch.load(ckpt_path, map_location='cpu')
|
||||
if 'state_dict' in state_dict:
|
||||
state_dict = state_dict['state_dict']
|
||||
model.load_state_dict(state_dict, strict=False)
|
||||
model.eval().to(device)
|
||||
return model
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--ckpt', required=True)
|
||||
parser.add_argument('--config', default='configs/inference/world_model_interaction.yaml')
|
||||
parser.add_argument('--out_dir', default='trt_engines')
|
||||
parser.add_argument('--gpu_id',
|
||||
type=int,
|
||||
default=0,
|
||||
help='CUDA device id used for ONNX export and TRT build.')
|
||||
parser.add_argument('--onnx_path',
|
||||
default=None,
|
||||
help='Optional explicit ONNX output path. Overrides --out_dir default name.')
|
||||
parser.add_argument('--engine_path',
|
||||
default=None,
|
||||
help='Optional explicit TensorRT engine output path. Overrides --out_dir default name.')
|
||||
parser.add_argument('--context_len', type=int, default=95)
|
||||
parser.add_argument('--fp16', action='store_true', default=True)
|
||||
parser.add_argument('--dump-tactics', default=None, help='Pass 1: dump tactic info to JSON')
|
||||
parser.add_argument('--load-tactics', default=None, help='Pass 2: force tactics from JSON')
|
||||
args = parser.parse_args()
|
||||
device = torch.device('cuda', args.gpu_id)
|
||||
torch.cuda.set_device(device)
|
||||
|
||||
onnx_path = args.onnx_path or os.path.join(args.out_dir,
|
||||
'video_backbone.onnx')
|
||||
engine_path = args.engine_path or os.path.join(args.out_dir,
|
||||
'video_backbone.engine')
|
||||
os.makedirs(os.path.dirname(os.path.abspath(onnx_path)), exist_ok=True)
|
||||
os.makedirs(os.path.dirname(os.path.abspath(engine_path)), exist_ok=True)
|
||||
|
||||
if os.path.exists(onnx_path):
|
||||
print(f">>> ONNX already exists at {onnx_path}, skipping export.")
|
||||
n_outputs = 10
|
||||
else:
|
||||
print(">>> Loading model ...")
|
||||
model = load_model(args.config, args.ckpt, device)
|
||||
print(">>> Exporting ONNX ...")
|
||||
with torch.no_grad():
|
||||
n_outputs = export_backbone_onnx(model, onnx_path, context_len=args.context_len)
|
||||
del model
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
print(">>> Converting ONNX -> TensorRT engine ...")
|
||||
logger = trt.Logger(trt.Logger.WARNING)
|
||||
builder = trt.Builder(logger)
|
||||
network = builder.create_network(1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
|
||||
parser = trt.OnnxParser(network, logger)
|
||||
|
||||
if not parser.parse_from_file(os.path.abspath(onnx_path)):
|
||||
for i in range(parser.num_errors):
|
||||
print(f" ONNX parse error: {parser.get_error(i)}")
|
||||
raise RuntimeError("ONNX parsing failed")
|
||||
|
||||
config = builder.create_builder_config()
|
||||
config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, 16 << 30)
|
||||
if args.fp16:
|
||||
config.set_flag(trt.BuilderFlag.FP16)
|
||||
|
||||
# Tactic selection
|
||||
recorder = None
|
||||
if args.dump_tactics:
|
||||
recorder = TacticRecorder()
|
||||
config.algorithm_selector = recorder
|
||||
elif args.load_tactics:
|
||||
config.algorithm_selector = TacticForcer(args.load_tactics)
|
||||
|
||||
engine_bytes = builder.build_serialized_network(network, config)
|
||||
|
||||
if recorder and args.dump_tactics:
|
||||
recorder.save(args.dump_tactics)
|
||||
|
||||
with open(engine_path, 'wb') as f:
|
||||
f.write(engine_bytes)
|
||||
|
||||
print(f"\n>>> Done! Engine saved to {engine_path}")
|
||||
print(f" Outputs: 1 y + {n_outputs - 1} hs_a tensors")
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
@@ -1,12 +1,13 @@
|
||||
import numpy as np
|
||||
import torch
|
||||
import copy
|
||||
import time
|
||||
|
||||
from unifolm_wma.utils.diffusion import make_ddim_sampling_parameters, make_ddim_timesteps, rescale_noise_cfg
|
||||
from unifolm_wma.utils.common import noise_like
|
||||
from unifolm_wma.utils.common import extract_into_tensor
|
||||
from tqdm import tqdm
|
||||
from unifolm_wma.modules.attention import enable_cross_attn_kv_cache, disable_cross_attn_kv_cache
|
||||
from unifolm_wma.modules.networks.wma_model import enable_ctx_cache, disable_ctx_cache
|
||||
|
||||
|
||||
class DDIMSampler(object):
|
||||
@@ -20,8 +21,9 @@ class DDIMSampler(object):
|
||||
|
||||
def register_buffer(self, name, attr):
|
||||
if type(attr) == torch.Tensor:
|
||||
if attr.device != torch.device("cuda"):
|
||||
attr = attr.to(torch.device("cuda"))
|
||||
target_device = self.model.device
|
||||
if attr.device != target_device:
|
||||
attr = attr.to(target_device)
|
||||
setattr(self, name, attr)
|
||||
|
||||
def make_schedule(self,
|
||||
@@ -68,11 +70,12 @@ class DDIMSampler(object):
|
||||
ddim_timesteps=self.ddim_timesteps,
|
||||
eta=ddim_eta,
|
||||
verbose=verbose)
|
||||
self.register_buffer('ddim_sigmas', ddim_sigmas)
|
||||
self.register_buffer('ddim_alphas', ddim_alphas)
|
||||
self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
|
||||
# Ensure tensors are on correct device for efficient indexing
|
||||
self.register_buffer('ddim_sigmas', to_torch(torch.as_tensor(ddim_sigmas)))
|
||||
self.register_buffer('ddim_alphas', to_torch(torch.as_tensor(ddim_alphas)))
|
||||
self.register_buffer('ddim_alphas_prev', to_torch(torch.as_tensor(ddim_alphas_prev)))
|
||||
self.register_buffer('ddim_sqrt_one_minus_alphas',
|
||||
np.sqrt(1. - ddim_alphas))
|
||||
to_torch(torch.as_tensor(np.sqrt(1. - ddim_alphas))))
|
||||
sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
|
||||
(1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) *
|
||||
(1 - self.alphas_cumprod / self.alphas_cumprod_prev))
|
||||
@@ -107,17 +110,9 @@ class DDIMSampler(object):
|
||||
fs=None,
|
||||
timestep_spacing='uniform', #uniform_trailing for starting from last timestep
|
||||
guidance_rescale=0.0,
|
||||
action_T=None,
|
||||
state_T=None,
|
||||
record_step_outputs=False,
|
||||
head_schedule=None,
|
||||
head_log_steps=None,
|
||||
head_skip_mode="reuse_prediction",
|
||||
backbone_reuse_blocks=None,
|
||||
backbone_reuse_start_step=None,
|
||||
backbone_reuse_schedule_steps=None,
|
||||
backbone_reuse_force_compute_steps=None,
|
||||
backbone_reuse_mode="disabled",
|
||||
handoff_step: int = 0,
|
||||
handoff_callback=None,
|
||||
stop_at_handoff: bool = False,
|
||||
**kwargs):
|
||||
|
||||
# Check condition bs
|
||||
@@ -173,17 +168,9 @@ class DDIMSampler(object):
|
||||
precision=precision,
|
||||
fs=fs,
|
||||
guidance_rescale=guidance_rescale,
|
||||
action_T=action_T,
|
||||
state_T=state_T,
|
||||
record_step_outputs=record_step_outputs,
|
||||
head_schedule=head_schedule,
|
||||
head_log_steps=head_log_steps,
|
||||
head_skip_mode=head_skip_mode,
|
||||
backbone_reuse_blocks=backbone_reuse_blocks,
|
||||
backbone_reuse_start_step=backbone_reuse_start_step,
|
||||
backbone_reuse_schedule_steps=backbone_reuse_schedule_steps,
|
||||
backbone_reuse_force_compute_steps=backbone_reuse_force_compute_steps,
|
||||
backbone_reuse_mode=backbone_reuse_mode,
|
||||
handoff_step=handoff_step,
|
||||
handoff_callback=handoff_callback,
|
||||
stop_at_handoff=stop_at_handoff,
|
||||
**kwargs)
|
||||
return samples, actions, states, intermediates
|
||||
|
||||
@@ -210,44 +197,23 @@ class DDIMSampler(object):
|
||||
precision=None,
|
||||
fs=None,
|
||||
guidance_rescale=0.0,
|
||||
action_T=None,
|
||||
state_T=None,
|
||||
record_step_outputs=False,
|
||||
head_schedule=None,
|
||||
head_log_steps=None,
|
||||
head_skip_mode="reuse_prediction",
|
||||
backbone_reuse_blocks=None,
|
||||
backbone_reuse_start_step=None,
|
||||
backbone_reuse_schedule_steps=None,
|
||||
backbone_reuse_force_compute_steps=None,
|
||||
backbone_reuse_mode="disabled",
|
||||
handoff_step: int = 0,
|
||||
handoff_callback=None,
|
||||
stop_at_handoff: bool = False,
|
||||
**kwargs):
|
||||
device = self.model.betas.device
|
||||
dp_ddim_scheduler_action = self.model.dp_noise_scheduler_action
|
||||
dp_ddim_scheduler_state = self.model.dp_noise_scheduler_state
|
||||
|
||||
b = shape[0]
|
||||
horizon = shape[2] if len(shape) >= 3 else 16
|
||||
if x_T is None:
|
||||
img = torch.randn(shape, device=device)
|
||||
action = torch.randn((b, 16, self.model.agent_action_dim), device=device)
|
||||
state = torch.randn((b, 16, self.model.agent_state_dim), device=device)
|
||||
else:
|
||||
img = x_T
|
||||
if action_T is None:
|
||||
action = torch.randn((b, horizon, self.model.agent_action_dim),
|
||||
device=device)
|
||||
else:
|
||||
action = action_T
|
||||
if state_T is None:
|
||||
state = torch.randn((b, horizon, self.model.agent_state_dim),
|
||||
device=device)
|
||||
else:
|
||||
state = state_T
|
||||
|
||||
if precision is not None:
|
||||
if precision == 16:
|
||||
img = img.to(dtype=torch.float16)
|
||||
action = action.to(dtype=torch.float16)
|
||||
state = state.to(dtype=torch.float16)
|
||||
action = torch.randn((b, 16, self.model.agent_action_dim), device=device)
|
||||
state = torch.randn((b, 16, self.model.agent_state_dim), device=device)
|
||||
|
||||
if timesteps is None:
|
||||
timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps
|
||||
@@ -265,40 +231,6 @@ class DDIMSampler(object):
|
||||
'x_inter_state': [state],
|
||||
'pred_x0_state': [state],
|
||||
}
|
||||
if record_step_outputs:
|
||||
intermediates['analysis_init'] = {
|
||||
'img': img.detach().cpu(),
|
||||
'action': action.detach().cpu(),
|
||||
'state': state.detach().cpu(),
|
||||
}
|
||||
intermediates['step_records'] = []
|
||||
head_schedule_set = None if head_schedule is None else {
|
||||
int(step_index) for step_index in head_schedule
|
||||
}
|
||||
head_log_steps_set = None if head_log_steps is None else {
|
||||
int(step_index) for step_index in head_log_steps
|
||||
}
|
||||
if head_log_steps_set is not None:
|
||||
intermediates['head_sparse_logs'] = {}
|
||||
backbone_reuse_blocks_set = None if backbone_reuse_blocks is None else {
|
||||
str(block_name) for block_name in backbone_reuse_blocks
|
||||
}
|
||||
backbone_reuse_schedule_steps_set = (
|
||||
None if backbone_reuse_schedule_steps is None else {
|
||||
int(step_index) for step_index in backbone_reuse_schedule_steps
|
||||
})
|
||||
backbone_reuse_force_compute_steps_set = (
|
||||
None if backbone_reuse_force_compute_steps is None else {
|
||||
int(step_index)
|
||||
for step_index in backbone_reuse_force_compute_steps
|
||||
})
|
||||
backbone_reuse_active = (backbone_reuse_mode == "reuse_output"
|
||||
and backbone_reuse_blocks_set)
|
||||
backbone_reuse_cache = {
|
||||
'single': {},
|
||||
'cond': {},
|
||||
'uncond': {},
|
||||
}
|
||||
time_range = reversed(range(
|
||||
0, timesteps)) if ddim_use_original_steps else np.flip(timesteps)
|
||||
total_steps = timesteps if ddim_use_original_steps else timesteps.shape[
|
||||
@@ -309,89 +241,50 @@ class DDIMSampler(object):
|
||||
iterator = time_range
|
||||
|
||||
clean_cond = kwargs.pop("clean_cond", False)
|
||||
sync_device = device if isinstance(device, torch.device) else torch.device(
|
||||
device)
|
||||
should_sync = record_step_outputs and sync_device.type == "cuda"
|
||||
if head_skip_mode not in {"reuse_prediction", "freeze_state"}:
|
||||
raise ValueError(
|
||||
f"Unsupported head_skip_mode={head_skip_mode!r}. "
|
||||
"Expected 'reuse_prediction' or 'freeze_state'.")
|
||||
if backbone_reuse_mode not in {"disabled", "reuse_output"}:
|
||||
raise ValueError(
|
||||
f"Unsupported backbone_reuse_mode={backbone_reuse_mode!r}. "
|
||||
"Expected 'disabled' or 'reuse_output'.")
|
||||
x_action_frozen = action.detach().clone()
|
||||
x_state_frozen = state.detach().clone()
|
||||
action_pred_cache = None
|
||||
state_pred_cache = None
|
||||
|
||||
dp_ddim_scheduler_action.set_timesteps(len(timesteps))
|
||||
dp_ddim_scheduler_state.set_timesteps(len(timesteps))
|
||||
for i, step in enumerate(iterator):
|
||||
index = total_steps - i - 1
|
||||
ts = torch.full((b, ), step, device=device, dtype=torch.long)
|
||||
ts = torch.empty((b, ), device=device, dtype=torch.long)
|
||||
enable_cross_attn_kv_cache(self.model)
|
||||
enable_ctx_cache(self.model)
|
||||
handoff = {}
|
||||
try:
|
||||
for i, step in enumerate(iterator):
|
||||
index = total_steps - i - 1
|
||||
ts.fill_(step)
|
||||
|
||||
# Use mask to blend noised original latent (img_orig) & new sampled latent (img)
|
||||
if mask is not None:
|
||||
assert x0 is not None
|
||||
if clean_cond:
|
||||
img_orig = x0
|
||||
else:
|
||||
img_orig = self.model.q_sample(x0, ts)
|
||||
img = img_orig * mask + (1. - mask) * img
|
||||
# Use mask to blend noised original latent (img_orig) & new sampled latent (img)
|
||||
if mask is not None:
|
||||
assert x0 is not None
|
||||
if clean_cond:
|
||||
img_orig = x0
|
||||
else:
|
||||
img_orig = self.model.q_sample(x0, ts)
|
||||
img = img_orig * mask + (1. - mask) * img
|
||||
|
||||
if should_sync:
|
||||
torch.cuda.synchronize(sync_device)
|
||||
step_start_time = time.time()
|
||||
scheduled_head = head_schedule_set is None or i in head_schedule_set
|
||||
if head_skip_mode == "reuse_prediction":
|
||||
run_head = scheduled_head or action_pred_cache is None or state_pred_cache is None
|
||||
else:
|
||||
run_head = scheduled_head
|
||||
backbone_reuse_step_stats: dict[str, set[str]] | None = None
|
||||
if backbone_reuse_active:
|
||||
backbone_reuse_step_stats = {
|
||||
'single': set(),
|
||||
'cond': set(),
|
||||
'uncond': set(),
|
||||
}
|
||||
outs = self.p_sample_ddim(
|
||||
img,
|
||||
action,
|
||||
state,
|
||||
cond,
|
||||
ts,
|
||||
index=index,
|
||||
precision=precision,
|
||||
use_original_steps=ddim_use_original_steps,
|
||||
quantize_denoised=quantize_denoised,
|
||||
temperature=temperature,
|
||||
noise_dropout=noise_dropout,
|
||||
score_corrector=score_corrector,
|
||||
corrector_kwargs=corrector_kwargs,
|
||||
unconditional_guidance_scale=unconditional_guidance_scale,
|
||||
unconditional_conditioning=unconditional_conditioning,
|
||||
mask=mask,
|
||||
x0=x0,
|
||||
fs=fs,
|
||||
guidance_rescale=guidance_rescale,
|
||||
**kwargs)
|
||||
|
||||
outs = self.p_sample_ddim(
|
||||
img,
|
||||
action,
|
||||
state,
|
||||
cond,
|
||||
ts,
|
||||
index=index,
|
||||
use_original_steps=ddim_use_original_steps,
|
||||
quantize_denoised=quantize_denoised,
|
||||
temperature=temperature,
|
||||
noise_dropout=noise_dropout,
|
||||
score_corrector=score_corrector,
|
||||
corrector_kwargs=corrector_kwargs,
|
||||
unconditional_guidance_scale=unconditional_guidance_scale,
|
||||
unconditional_conditioning=unconditional_conditioning,
|
||||
mask=mask,
|
||||
x0=x0,
|
||||
fs=fs,
|
||||
guidance_rescale=guidance_rescale,
|
||||
backbone_step_index=i + 1,
|
||||
run_head=run_head,
|
||||
backbone_reuse_blocks=backbone_reuse_blocks_set,
|
||||
backbone_reuse_start_step=backbone_reuse_start_step,
|
||||
backbone_reuse_schedule_steps=backbone_reuse_schedule_steps_set,
|
||||
backbone_reuse_force_compute_steps=
|
||||
backbone_reuse_force_compute_steps_set,
|
||||
backbone_reuse_mode=backbone_reuse_mode,
|
||||
backbone_reuse_cache=backbone_reuse_cache,
|
||||
backbone_reuse_step_stats=backbone_reuse_step_stats,
|
||||
**kwargs)
|
||||
img, pred_x0, model_output_action, model_output_state = outs
|
||||
|
||||
img, pred_x0, model_output_action, model_output_state = outs
|
||||
|
||||
if run_head:
|
||||
action_pred_cache = model_output_action.detach().clone()
|
||||
state_pred_cache = model_output_state.detach().clone()
|
||||
action = dp_ddim_scheduler_action.step(
|
||||
model_output_action,
|
||||
step,
|
||||
@@ -404,73 +297,35 @@ class DDIMSampler(object):
|
||||
state,
|
||||
generator=None,
|
||||
).prev_sample
|
||||
x_action_frozen = action.detach().clone()
|
||||
x_state_frozen = state.detach().clone()
|
||||
else:
|
||||
if head_skip_mode == "reuse_prediction":
|
||||
action = dp_ddim_scheduler_action.step(
|
||||
action_pred_cache,
|
||||
step,
|
||||
action,
|
||||
generator=None,
|
||||
).prev_sample
|
||||
state = dp_ddim_scheduler_state.step(
|
||||
state_pred_cache,
|
||||
step,
|
||||
state,
|
||||
generator=None,
|
||||
).prev_sample
|
||||
x_action_frozen = action.detach().clone()
|
||||
x_state_frozen = state.detach().clone()
|
||||
else:
|
||||
action = x_action_frozen
|
||||
state = x_state_frozen
|
||||
|
||||
if should_sync:
|
||||
torch.cuda.synchronize(sync_device)
|
||||
step_time_s = time.time() - step_start_time
|
||||
if callback: callback(i)
|
||||
if img_callback: img_callback(pred_x0, i)
|
||||
|
||||
if callback: callback(i)
|
||||
if img_callback: img_callback(pred_x0, i)
|
||||
if handoff_step > 0 and (i + 1) == handoff_step:
|
||||
handoff = {
|
||||
'samples': img.clone(),
|
||||
'actions': action.clone(),
|
||||
'states': state.clone(),
|
||||
'pred_x0': pred_x0.clone(),
|
||||
'step': i + 1,
|
||||
}
|
||||
if handoff_callback is not None:
|
||||
handoff_callback(handoff)
|
||||
if stop_at_handoff:
|
||||
intermediates['handoff'] = handoff
|
||||
break
|
||||
|
||||
reused_blocks = []
|
||||
if backbone_reuse_step_stats is not None:
|
||||
if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
|
||||
reused_blocks = sorted(backbone_reuse_step_stats['single'])
|
||||
else:
|
||||
reused_blocks = sorted(
|
||||
backbone_reuse_step_stats['cond']
|
||||
| backbone_reuse_step_stats['uncond'])
|
||||
|
||||
if index % log_every_t == 0 or index == total_steps - 1:
|
||||
intermediates['x_inter'].append(img)
|
||||
intermediates['pred_x0'].append(pred_x0)
|
||||
intermediates['x_inter_action'].append(action)
|
||||
intermediates['x_inter_state'].append(state)
|
||||
if head_log_steps_set is not None and i in head_log_steps_set:
|
||||
intermediates['head_sparse_logs'][i] = {
|
||||
'step_index': i,
|
||||
'ddim_timestep': int(step),
|
||||
'head_executed': run_head,
|
||||
'img': img.detach().cpu(),
|
||||
'pred_x0': pred_x0.detach().cpu(),
|
||||
'action': action.detach().cpu(),
|
||||
'state': state.detach().cpu(),
|
||||
}
|
||||
if record_step_outputs:
|
||||
intermediates['step_records'].append({
|
||||
'step_index': i + 1,
|
||||
'ddim_timestep': int(step),
|
||||
'head_executed': run_head,
|
||||
'img': img.detach().cpu(),
|
||||
'pred_x0': pred_x0.detach().cpu(),
|
||||
'action': action.detach().cpu(),
|
||||
'state': state.detach().cpu(),
|
||||
'step_time_s': step_time_s,
|
||||
'backbone_reused_blocks_count': len(reused_blocks),
|
||||
'backbone_reuse_hit_blocks': ",".join(reused_blocks),
|
||||
})
|
||||
if index % log_every_t == 0 or index == total_steps - 1:
|
||||
intermediates['x_inter'].append(img)
|
||||
intermediates['pred_x0'].append(pred_x0)
|
||||
intermediates['x_inter_action'].append(action)
|
||||
intermediates['x_inter_state'].append(state)
|
||||
finally:
|
||||
disable_cross_attn_kv_cache(self.model)
|
||||
disable_ctx_cache(self.model)
|
||||
|
||||
if handoff_step > 0:
|
||||
intermediates['handoff'] = handoff
|
||||
return img, action, state, intermediates
|
||||
|
||||
@torch.no_grad()
|
||||
@@ -481,6 +336,7 @@ class DDIMSampler(object):
|
||||
c,
|
||||
t,
|
||||
index,
|
||||
precision=None,
|
||||
repeat_noise=False,
|
||||
use_original_steps=False,
|
||||
quantize_denoised=False,
|
||||
@@ -495,62 +351,35 @@ class DDIMSampler(object):
|
||||
mask=None,
|
||||
x0=None,
|
||||
guidance_rescale=0.0,
|
||||
run_head=True,
|
||||
**kwargs):
|
||||
b, *_, device = *x.shape, x.device
|
||||
if x.dim() == 5:
|
||||
is_video = True
|
||||
else:
|
||||
is_video = False
|
||||
|
||||
if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
|
||||
model_output, model_output_action, model_output_state = self.model.apply_model(
|
||||
x,
|
||||
x_action,
|
||||
x_state,
|
||||
t,
|
||||
c,
|
||||
run_head=run_head,
|
||||
backbone_reuse_branch="single",
|
||||
**kwargs) # unet denoiser
|
||||
else:
|
||||
# do_classifier_free_guidance
|
||||
if isinstance(c, torch.Tensor) or isinstance(c, dict):
|
||||
e_t_cond, e_t_cond_action, e_t_cond_state = self.model.apply_model(
|
||||
x,
|
||||
x_action,
|
||||
x_state,
|
||||
t,
|
||||
c,
|
||||
run_head=run_head,
|
||||
backbone_reuse_branch="cond",
|
||||
**kwargs)
|
||||
e_t_uncond, e_t_uncond_action, e_t_uncond_state = self.model.apply_model(
|
||||
x,
|
||||
x_action,
|
||||
x_state,
|
||||
t,
|
||||
unconditional_conditioning,
|
||||
run_head=run_head,
|
||||
backbone_reuse_branch="uncond",
|
||||
**kwargs)
|
||||
use_autocast = precision == 16 and device.type == 'cuda'
|
||||
with torch.cuda.amp.autocast(enabled=use_autocast,
|
||||
dtype=torch.float16):
|
||||
if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
|
||||
model_output, model_output_action, model_output_state = self.model.apply_model(
|
||||
x, x_action, x_state, t, c, **kwargs) # unet denoiser
|
||||
else:
|
||||
raise NotImplementedError
|
||||
model_output = e_t_uncond + unconditional_guidance_scale * (
|
||||
e_t_cond - e_t_uncond)
|
||||
if run_head:
|
||||
# do_classifier_free_guidance
|
||||
if isinstance(c, torch.Tensor) or isinstance(c, dict):
|
||||
e_t_cond, e_t_cond_action, e_t_cond_state = self.model.apply_model(
|
||||
x, x_action, x_state, t, c, **kwargs)
|
||||
e_t_uncond, e_t_uncond_action, e_t_uncond_state = self.model.apply_model(
|
||||
x, x_action, x_state, t, unconditional_conditioning,
|
||||
**kwargs)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
model_output = e_t_uncond + unconditional_guidance_scale * (
|
||||
e_t_cond - e_t_uncond)
|
||||
model_output_action = e_t_uncond_action + unconditional_guidance_scale * (
|
||||
e_t_cond_action - e_t_uncond_action)
|
||||
model_output_state = e_t_uncond_state + unconditional_guidance_scale * (
|
||||
e_t_cond_state - e_t_uncond_state)
|
||||
else:
|
||||
model_output_action = None
|
||||
model_output_state = None
|
||||
|
||||
if guidance_rescale > 0.0:
|
||||
model_output = rescale_noise_cfg(
|
||||
model_output, e_t_cond, guidance_rescale=guidance_rescale)
|
||||
if run_head:
|
||||
if guidance_rescale > 0.0:
|
||||
model_output = rescale_noise_cfg(
|
||||
model_output, e_t_cond, guidance_rescale=guidance_rescale)
|
||||
model_output_action = rescale_noise_cfg(
|
||||
model_output_action,
|
||||
e_t_cond_action,
|
||||
@@ -575,17 +404,11 @@ class DDIMSampler(object):
|
||||
sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
|
||||
sigmas = self.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
|
||||
|
||||
if is_video:
|
||||
size = (b, 1, 1, 1, 1)
|
||||
else:
|
||||
size = (b, 1, 1, 1)
|
||||
|
||||
a_t = torch.full(size, alphas[index], device=device)
|
||||
a_prev = torch.full(size, alphas_prev[index], device=device)
|
||||
sigma_t = torch.full(size, sigmas[index], device=device)
|
||||
sqrt_one_minus_at = torch.full(size,
|
||||
sqrt_one_minus_alphas[index],
|
||||
device=device)
|
||||
# Use 0-d tensors directly (already on device); broadcasting handles shape
|
||||
a_t = alphas[index]
|
||||
a_prev = alphas_prev[index]
|
||||
sigma_t = sigmas[index]
|
||||
sqrt_one_minus_at = sqrt_one_minus_alphas[index]
|
||||
|
||||
if self.model.parameterization != "v":
|
||||
pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
|
||||
@@ -593,12 +416,8 @@ class DDIMSampler(object):
|
||||
pred_x0 = self.model.predict_start_from_z_and_v(x, t, model_output)
|
||||
|
||||
if self.model.use_dynamic_rescale:
|
||||
scale_t = torch.full(size,
|
||||
self.ddim_scale_arr[index],
|
||||
device=device)
|
||||
prev_scale_t = torch.full(size,
|
||||
self.ddim_scale_arr_prev[index],
|
||||
device=device)
|
||||
scale_t = self.ddim_scale_arr[index]
|
||||
prev_scale_t = self.ddim_scale_arr_prev[index]
|
||||
rescale = (prev_scale_t / scale_t)
|
||||
pred_x0 *= rescale
|
||||
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import time
|
||||
|
||||
from torch import Tensor
|
||||
from functools import partial
|
||||
@@ -686,6 +685,37 @@ class WMAModel(nn.Module):
|
||||
self.action_token_projector = instantiate_from_config(
|
||||
stem_process_config)
|
||||
|
||||
# Context precomputation cache
|
||||
self._ctx_cache_enabled = False
|
||||
self._ctx_cache = {}
|
||||
self._trt_backbone = None # TRT engine for video UNet backbone
|
||||
# Reusable CUDA stream for parallel state_unet / action_unet
|
||||
self._state_stream = torch.cuda.Stream()
|
||||
|
||||
def __getstate__(self):
|
||||
state = self.__dict__.copy()
|
||||
state.pop('_state_stream', None)
|
||||
return state
|
||||
|
||||
def __setstate__(self, state):
|
||||
self.__dict__.update(state)
|
||||
if not hasattr(self, '_ctx_cache_enabled'):
|
||||
self._ctx_cache_enabled = False
|
||||
if not hasattr(self, '_ctx_cache'):
|
||||
self._ctx_cache = {}
|
||||
if not hasattr(self, '_trt_backbone'):
|
||||
self._trt_backbone = None
|
||||
self._state_stream = torch.cuda.Stream()
|
||||
|
||||
def load_trt_backbone(self, engine_path, n_hs_a=9):
|
||||
"""Load a TensorRT engine for the video UNet backbone."""
|
||||
from unifolm_wma.trt_utils import TRTBackbone
|
||||
device = next(self.parameters()).device
|
||||
self._trt_backbone = TRTBackbone(engine_path,
|
||||
n_hs_a=n_hs_a,
|
||||
device=device)
|
||||
print(f">>> TRT backbone loaded from {engine_path} on {device}")
|
||||
|
||||
def forward(self,
|
||||
x: Tensor,
|
||||
x_action: Tensor,
|
||||
@@ -714,80 +744,70 @@ class WMAModel(nn.Module):
|
||||
Tuple of Tensors for predictions:
|
||||
|
||||
"""
|
||||
|
||||
b, _, t, _, _ = x.shape
|
||||
run_head = kwargs.pop("run_head", True)
|
||||
backbone_block_profiler = kwargs.pop("backbone_block_profiler", None)
|
||||
backbone_step_index = kwargs.pop("backbone_step_index", None)
|
||||
backbone_reuse_blocks = kwargs.pop("backbone_reuse_blocks", None)
|
||||
backbone_reuse_start_step = kwargs.pop("backbone_reuse_start_step",
|
||||
None)
|
||||
backbone_reuse_schedule_steps = kwargs.pop(
|
||||
"backbone_reuse_schedule_steps", None)
|
||||
backbone_reuse_force_compute_steps = kwargs.pop(
|
||||
"backbone_reuse_force_compute_steps", None)
|
||||
backbone_reuse_mode = kwargs.pop("backbone_reuse_mode", "disabled")
|
||||
backbone_reuse_cache = kwargs.pop("backbone_reuse_cache", None)
|
||||
backbone_reuse_step_stats = kwargs.pop("backbone_reuse_step_stats",
|
||||
None)
|
||||
backbone_reuse_branch = kwargs.pop("backbone_reuse_branch", "single")
|
||||
t_emb = timestep_embedding(timesteps,
|
||||
self.model_channels,
|
||||
repeat_only=False).type(x.dtype)
|
||||
emb = self.time_embed(t_emb)
|
||||
|
||||
bt, l_context, _ = context.shape
|
||||
if self.base_model_gen_only:
|
||||
assert l_context == 77 + self.n_obs_steps * 16, ">>> ERROR Context dim 1 ..." ## NOTE HANDCODE
|
||||
_ctx_key = context.data_ptr()
|
||||
if self._ctx_cache_enabled and _ctx_key in self._ctx_cache:
|
||||
context = self._ctx_cache[_ctx_key]
|
||||
else:
|
||||
if l_context == self.n_obs_steps + 77 + t * 16:
|
||||
context_agent_state = context[:, :self.n_obs_steps]
|
||||
context_text = context[:, self.n_obs_steps:self.n_obs_steps +
|
||||
77, :]
|
||||
context_img = context[:, self.n_obs_steps + 77:, :]
|
||||
context_agent_state = context_agent_state.repeat_interleave(
|
||||
repeats=t, dim=0)
|
||||
context_text = context_text.repeat_interleave(repeats=t, dim=0)
|
||||
context_img = rearrange(context_img,
|
||||
'b (t l) c -> (b t) l c',
|
||||
t=t)
|
||||
context = torch.cat(
|
||||
[context_agent_state, context_text, context_img], dim=1)
|
||||
elif l_context == self.n_obs_steps + 16 + 77 + t * 16:
|
||||
context_agent_state = context[:, :self.n_obs_steps]
|
||||
context_agent_action = context[:, self.
|
||||
n_obs_steps:self.n_obs_steps +
|
||||
16, :]
|
||||
context_agent_action = rearrange(
|
||||
context_agent_action.unsqueeze(2), 'b t l d -> (b t) l d')
|
||||
context_agent_action = self.action_token_projector(
|
||||
context_agent_action)
|
||||
context_agent_action = rearrange(context_agent_action,
|
||||
'(b o) l d -> b o l d',
|
||||
o=t)
|
||||
context_agent_action = rearrange(context_agent_action,
|
||||
'b o (t l) d -> b o t l d',
|
||||
t=t)
|
||||
context_agent_action = context_agent_action.permute(
|
||||
0, 2, 1, 3, 4)
|
||||
context_agent_action = rearrange(context_agent_action,
|
||||
'b t o l d -> (b t) (o l) d')
|
||||
bt, l_context, _ = context.shape
|
||||
if self.base_model_gen_only:
|
||||
assert l_context == 77 + self.n_obs_steps * 16, ">>> ERROR Context dim 1 ..." ## NOTE HANDCODE
|
||||
else:
|
||||
if l_context == self.n_obs_steps + 77 + t * 16:
|
||||
context_agent_state = context[:, :self.n_obs_steps]
|
||||
context_text = context[:, self.n_obs_steps:self.n_obs_steps +
|
||||
77, :]
|
||||
context_img = context[:, self.n_obs_steps + 77:, :]
|
||||
context_agent_state = context_agent_state.repeat_interleave(
|
||||
repeats=t, dim=0)
|
||||
context_text = context_text.repeat_interleave(repeats=t, dim=0)
|
||||
context_img = rearrange(context_img,
|
||||
'b (t l) c -> (b t) l c',
|
||||
t=t)
|
||||
context = torch.cat(
|
||||
[context_agent_state, context_text, context_img], dim=1)
|
||||
elif l_context == self.n_obs_steps + 16 + 77 + t * 16:
|
||||
context_agent_state = context[:, :self.n_obs_steps]
|
||||
context_agent_action = context[:, self.
|
||||
n_obs_steps:self.n_obs_steps +
|
||||
16, :]
|
||||
context_agent_action = rearrange(
|
||||
context_agent_action.unsqueeze(2), 'b t l d -> (b t) l d')
|
||||
context_agent_action = self.action_token_projector(
|
||||
context_agent_action)
|
||||
context_agent_action = rearrange(context_agent_action,
|
||||
'(b o) l d -> b o l d',
|
||||
o=t)
|
||||
context_agent_action = rearrange(context_agent_action,
|
||||
'b o (t l) d -> b o t l d',
|
||||
t=t)
|
||||
context_agent_action = context_agent_action.permute(
|
||||
0, 2, 1, 3, 4)
|
||||
context_agent_action = rearrange(context_agent_action,
|
||||
'b t o l d -> (b t) (o l) d')
|
||||
|
||||
context_text = context[:, self.n_obs_steps +
|
||||
16:self.n_obs_steps + 16 + 77, :]
|
||||
context_text = context_text.repeat_interleave(repeats=t, dim=0)
|
||||
context_text = context[:, self.n_obs_steps +
|
||||
16:self.n_obs_steps + 16 + 77, :]
|
||||
context_text = context_text.repeat_interleave(repeats=t, dim=0)
|
||||
|
||||
context_img = context[:, self.n_obs_steps + 16 + 77:, :]
|
||||
context_img = rearrange(context_img,
|
||||
'b (t l) c -> (b t) l c',
|
||||
t=t)
|
||||
context_agent_state = context_agent_state.repeat_interleave(
|
||||
repeats=t, dim=0)
|
||||
context = torch.cat([
|
||||
context_agent_state, context_agent_action, context_text,
|
||||
context_img
|
||||
],
|
||||
dim=1)
|
||||
context_img = context[:, self.n_obs_steps + 16 + 77:, :]
|
||||
context_img = rearrange(context_img,
|
||||
'b (t l) c -> (b t) l c',
|
||||
t=t)
|
||||
context_agent_state = context_agent_state.repeat_interleave(
|
||||
repeats=t, dim=0)
|
||||
context = torch.cat([
|
||||
context_agent_state, context_agent_action, context_text,
|
||||
context_img
|
||||
],
|
||||
dim=1)
|
||||
if self._ctx_cache_enabled:
|
||||
self._ctx_cache[_ctx_key] = context
|
||||
|
||||
emb = emb.repeat_interleave(repeats=t, dim=0)
|
||||
|
||||
@@ -807,150 +827,95 @@ class WMAModel(nn.Module):
|
||||
fs_embed = fs_embed.repeat_interleave(repeats=t, dim=0)
|
||||
emb = emb + fs_embed
|
||||
|
||||
def run_block_with_profile(block_name: str, block_stage: str,
|
||||
block_index: int | None,
|
||||
fn: Callable[[], Tensor]) -> Tensor:
|
||||
if backbone_block_profiler is None or backbone_step_index is None:
|
||||
return fn()
|
||||
if x.device.type == "cuda":
|
||||
torch.cuda.synchronize(x.device)
|
||||
start_time = time.perf_counter()
|
||||
out = fn()
|
||||
if x.device.type == "cuda":
|
||||
torch.cuda.synchronize(x.device)
|
||||
backbone_block_profiler.record_block(
|
||||
step=int(backbone_step_index),
|
||||
block_name=block_name,
|
||||
block_stage=block_stage,
|
||||
block_index=block_index,
|
||||
output=out,
|
||||
forward_time_ms=(time.perf_counter() - start_time) * 1000.0,
|
||||
)
|
||||
return out
|
||||
|
||||
reuse_cache_branch: Dict[str, Tensor] | None = None
|
||||
if backbone_reuse_cache is not None:
|
||||
reuse_cache_branch = backbone_reuse_cache.setdefault(
|
||||
backbone_reuse_branch, {})
|
||||
|
||||
def should_reuse_output_block(block_name: str) -> bool:
|
||||
if backbone_reuse_mode != "reuse_output":
|
||||
return False
|
||||
if backbone_step_index is None or backbone_reuse_start_step is None:
|
||||
return False
|
||||
if backbone_reuse_blocks is None or block_name not in backbone_reuse_blocks:
|
||||
return False
|
||||
if int(backbone_step_index) < int(backbone_reuse_start_step):
|
||||
return False
|
||||
if (backbone_reuse_force_compute_steps is not None
|
||||
and int(backbone_step_index)
|
||||
in backbone_reuse_force_compute_steps):
|
||||
return False
|
||||
if (backbone_reuse_schedule_steps is not None
|
||||
and int(backbone_step_index)
|
||||
in backbone_reuse_schedule_steps):
|
||||
return False
|
||||
if reuse_cache_branch is None:
|
||||
return False
|
||||
return block_name in reuse_cache_branch
|
||||
|
||||
h = x.type(self.dtype)
|
||||
adapter_idx = 0
|
||||
hs = []
|
||||
hs_a = []
|
||||
for id, module in enumerate(self.input_blocks):
|
||||
def run_input_block() -> Tensor:
|
||||
block_out = module(h, emb, context=context, batch_size=b)
|
||||
if self._trt_backbone is not None:
|
||||
# TRT path: run backbone via TensorRT engine
|
||||
h_in = x.type(self.dtype).contiguous()
|
||||
y, hs_a = self._trt_backbone(h_in, emb.contiguous(), context.contiguous())
|
||||
else:
|
||||
# PyTorch path: original backbone
|
||||
h = x.type(self.dtype)
|
||||
adapter_idx = 0
|
||||
hs = []
|
||||
hs_a = []
|
||||
for id, module in enumerate(self.input_blocks):
|
||||
h = module(h, emb, context=context, batch_size=b)
|
||||
if id == 0 and self.addition_attention:
|
||||
block_out = self.init_attn(block_out,
|
||||
emb,
|
||||
context=context,
|
||||
batch_size=b)
|
||||
return block_out
|
||||
h = self.init_attn(h, emb, context=context, batch_size=b)
|
||||
# plug-in adapter features
|
||||
if ((id + 1) % 3 == 0) and features_adapter is not None:
|
||||
h = h + features_adapter[adapter_idx]
|
||||
adapter_idx += 1
|
||||
if id != 0:
|
||||
if isinstance(module[0], Downsample):
|
||||
hs_a.append(
|
||||
rearrange(hs[-1], '(b t) c h w -> b t c h w', b=b))
|
||||
hs.append(h)
|
||||
hs_a.append(rearrange(h, '(b t) c h w -> b t c h w', b=b))
|
||||
|
||||
h = run_block_with_profile(
|
||||
block_name=f"input_{id}",
|
||||
block_stage="input_blocks",
|
||||
block_index=id,
|
||||
fn=run_input_block,
|
||||
)
|
||||
# plug-in adapter features
|
||||
if ((id + 1) % 3 == 0) and features_adapter is not None:
|
||||
h = h + features_adapter[adapter_idx]
|
||||
adapter_idx += 1
|
||||
if id != 0:
|
||||
if isinstance(module[0], Downsample):
|
||||
if features_adapter is not None:
|
||||
assert len(
|
||||
features_adapter) == adapter_idx, 'Wrong features_adapter'
|
||||
h = self.middle_block(h, emb, context=context, batch_size=b)
|
||||
hs_a.append(rearrange(h, '(b t) c h w -> b t c h w', b=b))
|
||||
|
||||
hs_out = []
|
||||
for module in self.output_blocks:
|
||||
h = torch.cat([h, hs.pop()], dim=1)
|
||||
h = module(h, emb, context=context, batch_size=b)
|
||||
if isinstance(module[-1], Upsample):
|
||||
hs_a.append(
|
||||
rearrange(hs[-1], '(b t) c h w -> b t c h w', t=t))
|
||||
hs.append(h)
|
||||
hs_a.append(rearrange(h, '(b t) c h w -> b t c h w', t=t))
|
||||
rearrange(hs_out[-1], '(b t) c h w -> b t c h w', b=b))
|
||||
hs_out.append(h)
|
||||
h = h.type(x.dtype)
|
||||
hs_a.append(rearrange(hs_out[-1], '(b t) c h w -> b t c h w', b=b))
|
||||
|
||||
if features_adapter is not None:
|
||||
assert len(
|
||||
features_adapter) == adapter_idx, 'Wrong features_adapter'
|
||||
h = run_block_with_profile(
|
||||
block_name="middle",
|
||||
block_stage="middle_block",
|
||||
block_index=0,
|
||||
fn=lambda: self.middle_block(h, emb, context=context, batch_size=b),
|
||||
)
|
||||
hs_a.append(rearrange(h, '(b t) c h w -> b t c h w', t=t))
|
||||
y = self.out(h)
|
||||
y = rearrange(y, '(b t) c h w -> b c t h w', b=b)
|
||||
|
||||
hs_out = []
|
||||
for id, module in enumerate(self.output_blocks):
|
||||
skip_h = hs.pop()
|
||||
block_name = f"output_{id}"
|
||||
|
||||
def run_output_block() -> Tensor:
|
||||
return module(torch.cat([h, skip_h], dim=1),
|
||||
emb,
|
||||
context=context,
|
||||
batch_size=b)
|
||||
|
||||
if should_reuse_output_block(block_name):
|
||||
h = reuse_cache_branch[block_name].to(device=h.device,
|
||||
dtype=h.dtype)
|
||||
if backbone_reuse_step_stats is not None:
|
||||
backbone_reuse_step_stats.setdefault(
|
||||
backbone_reuse_branch, set()).add(block_name)
|
||||
else:
|
||||
h = run_block_with_profile(
|
||||
block_name=block_name,
|
||||
block_stage="output_blocks",
|
||||
block_index=id,
|
||||
fn=run_output_block,
|
||||
)
|
||||
if (reuse_cache_branch is not None and backbone_reuse_mode ==
|
||||
"reuse_output"
|
||||
and backbone_reuse_blocks is not None
|
||||
and block_name in backbone_reuse_blocks):
|
||||
reuse_cache_branch[block_name] = h.detach().clone()
|
||||
if isinstance(module[-1], Upsample):
|
||||
hs_a.append(
|
||||
rearrange(hs_out[-1], '(b t) c h w -> b t c h w', t=t))
|
||||
hs_out.append(h)
|
||||
h = h.type(x.dtype)
|
||||
hs_a.append(rearrange(hs_out[-1], '(b t) c h w -> b t c h w', t=t))
|
||||
|
||||
y = self.out(h)
|
||||
y = rearrange(y, '(b t) c h w -> b c t h w', b=b)
|
||||
|
||||
if not self.base_model_gen_only and run_head:
|
||||
if not self.base_model_gen_only:
|
||||
ba, _, _ = x_action.shape
|
||||
a_y = self.action_unet(x_action, timesteps[:ba], hs_a,
|
||||
context_action[:2], **kwargs)
|
||||
# Predict state
|
||||
if b > 1:
|
||||
s_y = self.state_unet(x_state, timesteps[:ba], hs_a,
|
||||
ts_state = timesteps[:ba] if b > 1 else timesteps
|
||||
is_sim_mode = context_action[2] if len(context_action) > 2 else False
|
||||
|
||||
if is_sim_mode:
|
||||
# WM mode: only need state_unet, skip action_unet
|
||||
s_y = self.state_unet(x_state, ts_state, hs_a,
|
||||
context_action[:2], **kwargs)
|
||||
a_y = torch.zeros_like(x_action)
|
||||
else:
|
||||
s_y = self.state_unet(x_state, timesteps, hs_a,
|
||||
context_action[:2], **kwargs)
|
||||
elif not self.base_model_gen_only:
|
||||
a_y = None
|
||||
s_y = None
|
||||
# DM mode: only need action_unet, skip state_unet
|
||||
a_y = self.action_unet(x_action, timesteps[:ba], hs_a,
|
||||
context_action[:2], **kwargs)
|
||||
s_y = torch.zeros_like(x_state)
|
||||
else:
|
||||
a_y = torch.zeros_like(x_action)
|
||||
s_y = torch.zeros_like(x_state)
|
||||
|
||||
return y, a_y, s_y
|
||||
|
||||
|
||||
def enable_ctx_cache(model):
|
||||
"""Enable context precomputation cache on WMAModel and its action/state UNets."""
|
||||
for m in model.modules():
|
||||
if isinstance(m, WMAModel):
|
||||
m._ctx_cache_enabled = True
|
||||
m._ctx_cache = {}
|
||||
# conditional_unet1d cache
|
||||
from unifolm_wma.models.diffusion_head.conditional_unet1d import ConditionalUnet1D
|
||||
for m in model.modules():
|
||||
if isinstance(m, ConditionalUnet1D):
|
||||
m._global_cond_cache_enabled = True
|
||||
m._global_cond_cache = {}
|
||||
|
||||
|
||||
def disable_ctx_cache(model):
|
||||
"""Disable and clear context precomputation cache."""
|
||||
for m in model.modules():
|
||||
if isinstance(m, WMAModel):
|
||||
m._ctx_cache_enabled = False
|
||||
m._ctx_cache = {}
|
||||
from unifolm_wma.models.diffusion_head.conditional_unet1d import ConditionalUnet1D
|
||||
for m in model.modules():
|
||||
if isinstance(m, ConditionalUnet1D):
|
||||
m._global_cond_cache_enabled = False
|
||||
m._global_cond_cache = {}
|
||||
|
||||
196
src/unifolm_wma/trt_utils.py
Normal file
196
src/unifolm_wma/trt_utils.py
Normal file
@@ -0,0 +1,196 @@
|
||||
"""TensorRT acceleration utilities for the video UNet backbone."""
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from einops import rearrange
|
||||
from unifolm_wma.modules.networks.wma_model import Downsample, Upsample
|
||||
|
||||
|
||||
def _normalize_cuda_device(device) -> torch.device:
|
||||
if device is None:
|
||||
return torch.device('cuda', torch.cuda.current_device())
|
||||
if isinstance(device, torch.device):
|
||||
if device.type != 'cuda':
|
||||
raise ValueError(f"TensorRT requires a CUDA device, got {device}.")
|
||||
return device
|
||||
if isinstance(device, int):
|
||||
return torch.device('cuda', device)
|
||||
normalized = torch.device(device)
|
||||
if normalized.type != 'cuda':
|
||||
raise ValueError(f"TensorRT requires a CUDA device, got {normalized}.")
|
||||
return normalized
|
||||
|
||||
|
||||
class VideoBackboneForExport(nn.Module):
|
||||
"""Wrapper that isolates the video UNet backbone for ONNX export.
|
||||
|
||||
Takes already-preprocessed inputs (after context/time embedding prep)
|
||||
and returns y + hs_a as a flat tuple.
|
||||
"""
|
||||
|
||||
def __init__(self, wma_model):
|
||||
super().__init__()
|
||||
self.input_blocks = wma_model.input_blocks
|
||||
self.middle_block = wma_model.middle_block
|
||||
self.output_blocks = wma_model.output_blocks
|
||||
self.out = wma_model.out
|
||||
self.addition_attention = wma_model.addition_attention
|
||||
if self.addition_attention:
|
||||
self.init_attn = wma_model.init_attn
|
||||
self.dtype = wma_model.dtype
|
||||
|
||||
def forward(self, h, emb, context):
|
||||
t = 16
|
||||
b = 1
|
||||
|
||||
hs = []
|
||||
hs_a = []
|
||||
h = h.type(self.dtype)
|
||||
for id, module in enumerate(self.input_blocks):
|
||||
h = module(h, emb, context=context, batch_size=b)
|
||||
if id == 0 and self.addition_attention:
|
||||
h = self.init_attn(h, emb, context=context, batch_size=b)
|
||||
if id != 0:
|
||||
if isinstance(module[0], Downsample):
|
||||
hs_a.append(rearrange(hs[-1], '(b t) c h w -> b t c h w', t=t))
|
||||
hs.append(h)
|
||||
hs_a.append(rearrange(h, '(b t) c h w -> b t c h w', t=t))
|
||||
|
||||
h = self.middle_block(h, emb, context=context, batch_size=b)
|
||||
hs_a.append(rearrange(h, '(b t) c h w -> b t c h w', t=t))
|
||||
|
||||
hs_out = []
|
||||
for module in self.output_blocks:
|
||||
h = torch.cat([h, hs.pop()], dim=1)
|
||||
h = module(h, emb, context=context, batch_size=b)
|
||||
if isinstance(module[-1], Upsample):
|
||||
hs_a.append(rearrange(hs_out[-1], '(b t) c h w -> b t c h w', t=t))
|
||||
hs_out.append(h)
|
||||
hs_a.append(rearrange(hs_out[-1], '(b t) c h w -> b t c h w', t=t))
|
||||
|
||||
y = self.out(h.type(h.dtype))
|
||||
y = rearrange(y, '(b t) c h w -> b c t h w', b=b)
|
||||
return (y, *hs_a)
|
||||
|
||||
|
||||
def export_backbone_onnx(model, save_path, context_len=95):
|
||||
wma = model.model.diffusion_model
|
||||
wrapper = VideoBackboneForExport(wma)
|
||||
device = next(wma.parameters()).device
|
||||
wrapper.eval().to(device)
|
||||
|
||||
for m in wrapper.modules():
|
||||
if hasattr(m, 'checkpoint'):
|
||||
m.checkpoint = False
|
||||
if hasattr(m, 'use_checkpoint'):
|
||||
m.use_checkpoint = False
|
||||
|
||||
import xformers.ops
|
||||
_orig_mea = xformers.ops.memory_efficient_attention
|
||||
def _sdpa_replacement(q, k, v, attn_bias=None, op=None, **kw):
|
||||
return torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_bias)
|
||||
xformers.ops.memory_efficient_attention = _sdpa_replacement
|
||||
|
||||
BT = 16
|
||||
emb_dim = wma.model_channels * 4
|
||||
ctx_dim = 1024
|
||||
in_ch = wma.in_channels
|
||||
|
||||
dummy_h = torch.randn(BT,
|
||||
in_ch,
|
||||
40,
|
||||
64,
|
||||
device=device,
|
||||
dtype=torch.float32)
|
||||
dummy_emb = torch.randn(BT,
|
||||
emb_dim,
|
||||
device=device,
|
||||
dtype=torch.float32)
|
||||
dummy_ctx = torch.randn(BT,
|
||||
context_len,
|
||||
ctx_dim,
|
||||
device=device,
|
||||
dtype=torch.float32)
|
||||
|
||||
with torch.no_grad():
|
||||
outputs = wrapper(dummy_h, dummy_emb, dummy_ctx)
|
||||
n_outputs = len(outputs)
|
||||
print(f">>> Backbone has {n_outputs} outputs (1 y + {n_outputs-1} hs_a)")
|
||||
for i, o in enumerate(outputs):
|
||||
print(f" output[{i}]: {o.shape} {o.dtype}")
|
||||
|
||||
output_names = ['y'] + [f'hs_a_{i}' for i in range(n_outputs - 1)]
|
||||
|
||||
torch.onnx.export(
|
||||
wrapper,
|
||||
(dummy_h, dummy_emb, dummy_ctx),
|
||||
save_path,
|
||||
input_names=['h', 'emb', 'context'],
|
||||
output_names=output_names,
|
||||
opset_version=17,
|
||||
do_constant_folding=True,
|
||||
)
|
||||
print(f">>> ONNX exported to {save_path}")
|
||||
xformers.ops.memory_efficient_attention = _orig_mea
|
||||
return n_outputs
|
||||
|
||||
|
||||
class TRTBackbone:
|
||||
"""TensorRT runtime wrapper for the video UNet backbone."""
|
||||
|
||||
def __init__(self, engine_path, n_hs_a=9, device=None):
|
||||
import tensorrt as trt
|
||||
|
||||
self.device = _normalize_cuda_device(device)
|
||||
self.logger = trt.Logger(trt.Logger.WARNING)
|
||||
with torch.cuda.device(self.device):
|
||||
with open(engine_path, 'rb') as f:
|
||||
runtime = trt.Runtime(self.logger)
|
||||
self.engine = runtime.deserialize_cuda_engine(f.read())
|
||||
self.context = self.engine.create_execution_context()
|
||||
self.n_hs_a = n_hs_a
|
||||
|
||||
import numpy as np
|
||||
self.output_buffers = {}
|
||||
with torch.cuda.device(self.device):
|
||||
for i in range(self.engine.num_io_tensors):
|
||||
name = self.engine.get_tensor_name(i)
|
||||
if self.engine.get_tensor_mode(name) == trt.TensorIOMode.OUTPUT:
|
||||
shape = self.engine.get_tensor_shape(name)
|
||||
np_dtype = trt.nptype(self.engine.get_tensor_dtype(name))
|
||||
buf = torch.empty(
|
||||
list(shape),
|
||||
dtype=torch.from_numpy(np.empty(0, dtype=np_dtype)).dtype,
|
||||
device=self.device)
|
||||
self.output_buffers[name] = buf
|
||||
print(
|
||||
f" TRT output '{name}': {list(shape)} {buf.dtype} on {self.device}"
|
||||
)
|
||||
|
||||
def __call__(self, h, emb, context):
|
||||
import numpy as np
|
||||
import tensorrt as trt
|
||||
|
||||
bound_inputs = {}
|
||||
with torch.cuda.device(self.device):
|
||||
for name, tensor in [('h', h), ('emb', emb), ('context', context)]:
|
||||
expected_dtype = trt.nptype(self.engine.get_tensor_dtype(name))
|
||||
torch_expected = torch.from_numpy(
|
||||
np.empty(0, dtype=expected_dtype)).dtype
|
||||
if tensor.device != self.device or tensor.dtype != torch_expected:
|
||||
tensor = tensor.to(device=self.device,
|
||||
dtype=torch_expected,
|
||||
non_blocking=True)
|
||||
tensor = tensor.contiguous()
|
||||
bound_inputs[name] = tensor
|
||||
self.context.set_tensor_address(name, tensor.data_ptr())
|
||||
|
||||
for name, buf in self.output_buffers.items():
|
||||
self.context.set_tensor_address(name, buf.data_ptr())
|
||||
|
||||
stream = torch.cuda.current_stream(device=self.device)
|
||||
self.context.execute_async_v3(stream.cuda_stream)
|
||||
|
||||
y = self.output_buffers['y']
|
||||
hs_a = [self.output_buffers[f'hs_a_{i}'] for i in range(self.n_hs_a)]
|
||||
return y, hs_a
|
||||
@@ -2,11 +2,11 @@ res_dir="unitree_z1_dual_arm_stackbox_v2/case2"
|
||||
dataset="unitree_z1_dual_arm_stackbox_v2"
|
||||
|
||||
{
|
||||
time CUDA_VISIBLE_DEVICES=0 "${PYTHON_BIN:-python}" scripts/evaluation/world_model_interaction.py \
|
||||
time CUDA_VISIBLE_DEVICES=0,1 python3 scripts/evaluation/world_model_interaction.py \
|
||||
--seed 123 \
|
||||
--ckpt_path ckpts/unifolm_wma_dual.ckpt \
|
||||
--config configs/inference/world_model_interaction.yaml \
|
||||
--savedir "${res_dir}/output/sparse_8" \
|
||||
--savedir "${res_dir}/output" \
|
||||
--bs 1 --height 320 --width 512 \
|
||||
--unconditional_guidance_scale 1.0 \
|
||||
--ddim_steps 50 \
|
||||
@@ -21,9 +21,7 @@ dataset="unitree_z1_dual_arm_stackbox_v2"
|
||||
--timestep_spacing 'uniform_trailing' \
|
||||
--guidance_rescale 0.7 \
|
||||
--perframe_ae \
|
||||
--analysis_log_metrics \
|
||||
--analysis_reference_steps 50 \
|
||||
--head_schedule_steps 0 7 14 21 28 35 42 49 \
|
||||
--head_skip_mode reuse_prediction \
|
||||
--head_log_steps 40 43 46 47 48 49
|
||||
--pipeline_split_step 30 \
|
||||
--pipeline_multi_gpu \
|
||||
--pipeline_gpu_ids 0 1
|
||||
} 2>&1 | tee "${res_dir}/output.log"
|
||||
|
||||
Reference in New Issue
Block a user