多卡流水导出
This commit is contained in:
File diff suppressed because it is too large
Load Diff
209
scripts/export_trt.py
Normal file
209
scripts/export_trt.py
Normal file
@@ -0,0 +1,209 @@
|
|||||||
|
"""Export video UNet backbone to ONNX, then convert to TensorRT engine.
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
python scripts/export_trt.py \
|
||||||
|
--ckpt ckpts/unifolm_wma_dual.ckpt.prepared.pt \
|
||||||
|
--config configs/inference/world_model_interaction.yaml \
|
||||||
|
--out_dir trt_engines
|
||||||
|
|
||||||
|
python scripts/export_trt.py \
|
||||||
|
--ckpt ckpts/unifolm_wma_dual.ckpt.prepared.pt \
|
||||||
|
--config configs/inference/world_model_interaction.yaml \
|
||||||
|
--engine_path trt_engines/video_backbone_multigpu.engine \
|
||||||
|
--onnx_path trt_engines/video_backbone_multigpu.onnx
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
import argparse
|
||||||
|
import json
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import tensorrt as trt
|
||||||
|
from omegaconf import OmegaConf
|
||||||
|
|
||||||
|
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'src'))
|
||||||
|
from unifolm_wma.utils.utils import instantiate_from_config
|
||||||
|
from unifolm_wma.trt_utils import export_backbone_onnx
|
||||||
|
|
||||||
|
|
||||||
|
class TacticRecorder(trt.IAlgorithmSelector):
    """Pass 1: record all candidate tactics and the auto-selected winner.

    ``records`` maps a layer name to a dict with the keys ``input_shapes``,
    ``output_shapes``, ``candidates`` (every tactic offered for the layer)
    and ``selected`` (the tactic reported back by the builder, or ``None``).
    ``save`` writes that mapping to disk as JSON.
    """

    def __init__(self):
        super().__init__()
        self.records = {}  # layer_name -> {candidates: [...], selected: ...}

    @staticmethod
    def _shape_to_json(shape):
        """Convert one ``get_shape`` result into plain, JSON-friendly lists.

        Entries that convert with ``int()`` are kept as a flat list of ints.
        Entries that do not (e.g. one dimensions object per min/opt/max
        profile) are expanded into a nested list of ints instead of making
        the whole shape unusable.
        """
        converted = []
        for entry in shape:
            try:
                converted.append(int(entry))
            except (TypeError, ValueError):
                converted.append([int(d) for d in entry])
        return converted

    def _collect_shapes(self, ctx, first_index, count):
        """Return the shapes for ``count`` tensors starting at ``first_index``.

        A shape that cannot be read or converted is recorded as ``None``;
        shape capture is informational only and must not abort the build.
        """
        shapes = []
        for offset in range(count):
            try:
                shapes.append(
                    self._shape_to_json(ctx.get_shape(first_index + offset)))
            except Exception:
                # Deliberately best-effort, same as the original recorder.
                shapes.append(None)
        return shapes

    def select_algorithms(self, ctx, choices):
        """Record every candidate tactic for ``ctx.name`` and allow all of them.

        Returns the full list of candidate indices, so recording does not
        restrict the builder's own choice.
        """
        name = ctx.name
        # Outputs are indexed after the inputs in get_shape().
        self.records[name] = {
            "input_shapes": self._collect_shapes(ctx, 0, ctx.num_inputs),
            "output_shapes": self._collect_shapes(ctx, ctx.num_inputs,
                                                  ctx.num_outputs),
            "candidates": [],
            "selected": None,
        }
        for i, c in enumerate(choices):
            v = c.algorithm_variant
            self.records[name]["candidates"].append({
                "index": i,
                "implementation": v.implementation,
                "tactic": v.tactic,
                "timing_msec": c.timing_msec,
                "workspace_size": c.workspace_size,
            })
        # return all indices -> let TRT auto-pick the fastest
        return list(range(len(choices)))

    def report_algorithms(self, ctx, choices):
        """Store the final tactic of each layer seen in ``select_algorithms``."""
        # Both ctx and choices are lists in report_algorithms
        for c, alg in zip(ctx, choices):
            name = c.name
            if name in self.records:
                v = alg.algorithm_variant
                self.records[name]["selected"] = {
                    "implementation": v.implementation,
                    "tactic": v.tactic,
                    "timing_msec": alg.timing_msec,
                    "workspace_size": alg.workspace_size,
                }

    def save(self, path):
        """Write the collected records to ``path`` as indented JSON."""
        with open(path, "w") as f:
            json.dump(self.records, f, indent=2)
        print(f">>> Tactic info saved to {path} ({len(self.records)} layers)")
|
||||||
|
|
||||||
|
|
||||||
|
class TacticForcer(trt.IAlgorithmSelector):
    """Pass 2: force user-specified tactics from a JSON file.

    The JSON file maps layer names to dicts. A layer is pinned when its dict
    has a truthy ``"force"`` entry holding ``"implementation"`` and
    ``"tactic"``; all other layers keep every candidate.
    """

    def __init__(self, path):
        super().__init__()
        with open(path) as fh:
            self.overrides = json.load(fh)
        forced_count = len(
            [entry for entry in self.overrides.values() if entry.get("force")])
        print(f">>> Loaded tactic overrides: {forced_count} layers with 'force' set")

    def select_algorithms(self, ctx, choices):
        """Return ``[index]`` of the forced tactic, or all indices otherwise."""
        layer_name = ctx.name
        entry = self.overrides.get(layer_name)
        forced = entry.get("force") if entry else None
        if not forced:
            return list(range(len(choices)))

        wanted_impl = forced["implementation"]
        wanted_tactic = forced["tactic"]
        for position, candidate in enumerate(choices):
            variant = candidate.algorithm_variant
            if variant.implementation != wanted_impl:
                continue
            if variant.tactic == wanted_tactic:
                return [position]

        # The requested tactic is not among the candidates: warn and allow all.
        print(f" WARN: forced tactic not found for {layer_name}, using auto")
        return list(range(len(choices)))

    def report_algorithms(self, ctx, choices):
        """Nothing to record in the forcing pass."""
        return None
|
||||||
|
|
||||||
|
|
||||||
|
def load_model(config_path, ckpt_path, device):
    """Load the model used for export, in eval mode, moved to ``device``.

    A checkpoint whose name ends in ``.prepared.pt`` is loaded directly as
    the model object and ``config_path`` is not read. Any other checkpoint is
    treated as weights: the model is built from ``config.model`` in the YAML
    at ``config_path`` and the weights are applied with ``strict=False``
    (taken from the ``'state_dict'`` entry when the checkpoint has one).

    NOTE(review): ``torch.load`` unpickles the file, so only trusted
    checkpoints should be passed here.
    """
    is_prepared = ckpt_path.endswith('.prepared.pt')
    if not is_prepared:
        model = instantiate_from_config(OmegaConf.load(config_path).model)
        checkpoint = torch.load(ckpt_path, map_location='cpu')
        weights = (checkpoint['state_dict']
                   if 'state_dict' in checkpoint else checkpoint)
        model.load_state_dict(weights, strict=False)
    else:
        model = torch.load(ckpt_path, map_location='cpu')

    model.eval().to(device)
    return model
|
||||||
|
|
||||||
|
|
||||||
|
def main():
    """Export the video backbone to ONNX (if needed) and build a TensorRT engine.

    Steps:
      1. Parse command-line arguments and select the CUDA device.
      2. Export the model to ONNX unless a file already exists at the ONNX
         path, in which case the export (and model loading) is skipped.
      3. Parse the ONNX file, build a serialized TensorRT engine (FP16 unless
         ``--no-fp16`` is given) and write it to the engine path.

    ``--dump-tactics`` records tactic information to JSON during the build;
    ``--load-tactics`` forces tactics from such a file. If both are given,
    ``--dump-tactics`` takes precedence.

    Raises:
        RuntimeError: if the ONNX file cannot be parsed or the engine build
            fails.
    """
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument('--ckpt', required=True)
    arg_parser.add_argument('--config', default='configs/inference/world_model_interaction.yaml')
    arg_parser.add_argument('--out_dir', default='trt_engines')
    arg_parser.add_argument('--gpu_id',
                            type=int,
                            default=0,
                            help='CUDA device id used for ONNX export and TRT build.')
    arg_parser.add_argument('--onnx_path',
                            default=None,
                            help='Optional explicit ONNX output path. Overrides --out_dir default name.')
    arg_parser.add_argument('--engine_path',
                            default=None,
                            help='Optional explicit TensorRT engine output path. Overrides --out_dir default name.')
    arg_parser.add_argument('--context_len', type=int, default=95)
    # FP16 stays on by default; --fp16 is kept for existing command lines and
    # --no-fp16 is the only way to actually switch it off.
    arg_parser.add_argument('--fp16', action='store_true', default=True)
    arg_parser.add_argument('--no-fp16',
                            dest='fp16',
                            action='store_false',
                            help='Build the engine without the FP16 builder flag.')
    arg_parser.add_argument('--dump-tactics', default=None, help='Pass 1: dump tactic info to JSON')
    arg_parser.add_argument('--load-tactics', default=None, help='Pass 2: force tactics from JSON')
    args = arg_parser.parse_args()

    device = torch.device('cuda', args.gpu_id)
    torch.cuda.set_device(device)

    onnx_path = args.onnx_path or os.path.join(args.out_dir,
                                               'video_backbone.onnx')
    engine_path = args.engine_path or os.path.join(args.out_dir,
                                                   'video_backbone.engine')
    os.makedirs(os.path.dirname(os.path.abspath(onnx_path)), exist_ok=True)
    os.makedirs(os.path.dirname(os.path.abspath(engine_path)), exist_ok=True)

    # None means "not known yet"; filled in from the parsed network below.
    n_outputs = None
    if os.path.exists(onnx_path):
        print(f">>> ONNX already exists at {onnx_path}, skipping export.")
    else:
        print(">>> Loading model ...")
        model = load_model(args.config, args.ckpt, device)
        print(">>> Exporting ONNX ...")
        with torch.no_grad():
            n_outputs = export_backbone_onnx(model, onnx_path, context_len=args.context_len)
        # Release the PyTorch model before the TensorRT build starts.
        del model
        torch.cuda.empty_cache()

    print(">>> Converting ONNX -> TensorRT engine ...")
    logger = trt.Logger(trt.Logger.WARNING)
    builder = trt.Builder(logger)
    network = builder.create_network(1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
    onnx_parser = trt.OnnxParser(network, logger)

    if not onnx_parser.parse_from_file(os.path.abspath(onnx_path)):
        for i in range(onnx_parser.num_errors):
            print(f" ONNX parse error: {onnx_parser.get_error(i)}")
        raise RuntimeError("ONNX parsing failed")

    if n_outputs is None:
        # Export was skipped: take the output count from the parsed network
        # instead of assuming a fixed number.
        n_outputs = network.num_outputs

    config = builder.create_builder_config()
    config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, 16 << 30)
    if args.fp16:
        config.set_flag(trt.BuilderFlag.FP16)

    # Tactic selection
    recorder = None
    if args.dump_tactics:
        recorder = TacticRecorder()
        config.algorithm_selector = recorder
    elif args.load_tactics:
        selector = TacticForcer(args.load_tactics)
        config.algorithm_selector = selector

    engine_bytes = builder.build_serialized_network(network, config)

    if recorder is not None:
        recorder.save(args.dump_tactics)

    if engine_bytes is None:
        # Fail before opening the output so no empty engine file is left behind.
        raise RuntimeError(
            "TensorRT engine build failed: build_serialized_network returned None")

    with open(engine_path, 'wb') as f:
        f.write(engine_bytes)

    print(f"\n>>> Done! Engine saved to {engine_path}")
    print(f" Outputs: 1 y + {n_outputs - 1} hs_a tensors")
|
||||||
|
|
||||||
|
|
||||||
|
# Run the export only when this file is executed as a script.
if __name__ == "__main__":
    main()
|
||||||
@@ -1,12 +1,13 @@
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
import copy
|
import copy
|
||||||
import time
|
|
||||||
|
|
||||||
from unifolm_wma.utils.diffusion import make_ddim_sampling_parameters, make_ddim_timesteps, rescale_noise_cfg
|
from unifolm_wma.utils.diffusion import make_ddim_sampling_parameters, make_ddim_timesteps, rescale_noise_cfg
|
||||||
from unifolm_wma.utils.common import noise_like
|
from unifolm_wma.utils.common import noise_like
|
||||||
from unifolm_wma.utils.common import extract_into_tensor
|
from unifolm_wma.utils.common import extract_into_tensor
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
from unifolm_wma.modules.attention import enable_cross_attn_kv_cache, disable_cross_attn_kv_cache
|
||||||
|
from unifolm_wma.modules.networks.wma_model import enable_ctx_cache, disable_ctx_cache
|
||||||
|
|
||||||
|
|
||||||
class DDIMSampler(object):
|
class DDIMSampler(object):
|
||||||
@@ -20,8 +21,9 @@ class DDIMSampler(object):
|
|||||||
|
|
||||||
def register_buffer(self, name, attr):
|
def register_buffer(self, name, attr):
|
||||||
if type(attr) == torch.Tensor:
|
if type(attr) == torch.Tensor:
|
||||||
if attr.device != torch.device("cuda"):
|
target_device = self.model.device
|
||||||
attr = attr.to(torch.device("cuda"))
|
if attr.device != target_device:
|
||||||
|
attr = attr.to(target_device)
|
||||||
setattr(self, name, attr)
|
setattr(self, name, attr)
|
||||||
|
|
||||||
def make_schedule(self,
|
def make_schedule(self,
|
||||||
@@ -68,11 +70,12 @@ class DDIMSampler(object):
|
|||||||
ddim_timesteps=self.ddim_timesteps,
|
ddim_timesteps=self.ddim_timesteps,
|
||||||
eta=ddim_eta,
|
eta=ddim_eta,
|
||||||
verbose=verbose)
|
verbose=verbose)
|
||||||
self.register_buffer('ddim_sigmas', ddim_sigmas)
|
# Ensure tensors are on correct device for efficient indexing
|
||||||
self.register_buffer('ddim_alphas', ddim_alphas)
|
self.register_buffer('ddim_sigmas', to_torch(torch.as_tensor(ddim_sigmas)))
|
||||||
self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
|
self.register_buffer('ddim_alphas', to_torch(torch.as_tensor(ddim_alphas)))
|
||||||
|
self.register_buffer('ddim_alphas_prev', to_torch(torch.as_tensor(ddim_alphas_prev)))
|
||||||
self.register_buffer('ddim_sqrt_one_minus_alphas',
|
self.register_buffer('ddim_sqrt_one_minus_alphas',
|
||||||
np.sqrt(1. - ddim_alphas))
|
to_torch(torch.as_tensor(np.sqrt(1. - ddim_alphas))))
|
||||||
sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
|
sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
|
||||||
(1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) *
|
(1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) *
|
||||||
(1 - self.alphas_cumprod / self.alphas_cumprod_prev))
|
(1 - self.alphas_cumprod / self.alphas_cumprod_prev))
|
||||||
@@ -107,17 +110,9 @@ class DDIMSampler(object):
|
|||||||
fs=None,
|
fs=None,
|
||||||
timestep_spacing='uniform', #uniform_trailing for starting from last timestep
|
timestep_spacing='uniform', #uniform_trailing for starting from last timestep
|
||||||
guidance_rescale=0.0,
|
guidance_rescale=0.0,
|
||||||
action_T=None,
|
handoff_step: int = 0,
|
||||||
state_T=None,
|
handoff_callback=None,
|
||||||
record_step_outputs=False,
|
stop_at_handoff: bool = False,
|
||||||
head_schedule=None,
|
|
||||||
head_log_steps=None,
|
|
||||||
head_skip_mode="reuse_prediction",
|
|
||||||
backbone_reuse_blocks=None,
|
|
||||||
backbone_reuse_start_step=None,
|
|
||||||
backbone_reuse_schedule_steps=None,
|
|
||||||
backbone_reuse_force_compute_steps=None,
|
|
||||||
backbone_reuse_mode="disabled",
|
|
||||||
**kwargs):
|
**kwargs):
|
||||||
|
|
||||||
# Check condition bs
|
# Check condition bs
|
||||||
@@ -173,17 +168,9 @@ class DDIMSampler(object):
|
|||||||
precision=precision,
|
precision=precision,
|
||||||
fs=fs,
|
fs=fs,
|
||||||
guidance_rescale=guidance_rescale,
|
guidance_rescale=guidance_rescale,
|
||||||
action_T=action_T,
|
handoff_step=handoff_step,
|
||||||
state_T=state_T,
|
handoff_callback=handoff_callback,
|
||||||
record_step_outputs=record_step_outputs,
|
stop_at_handoff=stop_at_handoff,
|
||||||
head_schedule=head_schedule,
|
|
||||||
head_log_steps=head_log_steps,
|
|
||||||
head_skip_mode=head_skip_mode,
|
|
||||||
backbone_reuse_blocks=backbone_reuse_blocks,
|
|
||||||
backbone_reuse_start_step=backbone_reuse_start_step,
|
|
||||||
backbone_reuse_schedule_steps=backbone_reuse_schedule_steps,
|
|
||||||
backbone_reuse_force_compute_steps=backbone_reuse_force_compute_steps,
|
|
||||||
backbone_reuse_mode=backbone_reuse_mode,
|
|
||||||
**kwargs)
|
**kwargs)
|
||||||
return samples, actions, states, intermediates
|
return samples, actions, states, intermediates
|
||||||
|
|
||||||
@@ -210,44 +197,23 @@ class DDIMSampler(object):
|
|||||||
precision=None,
|
precision=None,
|
||||||
fs=None,
|
fs=None,
|
||||||
guidance_rescale=0.0,
|
guidance_rescale=0.0,
|
||||||
action_T=None,
|
handoff_step: int = 0,
|
||||||
state_T=None,
|
handoff_callback=None,
|
||||||
record_step_outputs=False,
|
stop_at_handoff: bool = False,
|
||||||
head_schedule=None,
|
|
||||||
head_log_steps=None,
|
|
||||||
head_skip_mode="reuse_prediction",
|
|
||||||
backbone_reuse_blocks=None,
|
|
||||||
backbone_reuse_start_step=None,
|
|
||||||
backbone_reuse_schedule_steps=None,
|
|
||||||
backbone_reuse_force_compute_steps=None,
|
|
||||||
backbone_reuse_mode="disabled",
|
|
||||||
**kwargs):
|
**kwargs):
|
||||||
device = self.model.betas.device
|
device = self.model.betas.device
|
||||||
dp_ddim_scheduler_action = self.model.dp_noise_scheduler_action
|
dp_ddim_scheduler_action = self.model.dp_noise_scheduler_action
|
||||||
dp_ddim_scheduler_state = self.model.dp_noise_scheduler_state
|
dp_ddim_scheduler_state = self.model.dp_noise_scheduler_state
|
||||||
|
|
||||||
b = shape[0]
|
b = shape[0]
|
||||||
horizon = shape[2] if len(shape) >= 3 else 16
|
|
||||||
if x_T is None:
|
if x_T is None:
|
||||||
img = torch.randn(shape, device=device)
|
img = torch.randn(shape, device=device)
|
||||||
|
action = torch.randn((b, 16, self.model.agent_action_dim), device=device)
|
||||||
|
state = torch.randn((b, 16, self.model.agent_state_dim), device=device)
|
||||||
else:
|
else:
|
||||||
img = x_T
|
img = x_T
|
||||||
if action_T is None:
|
action = torch.randn((b, 16, self.model.agent_action_dim), device=device)
|
||||||
action = torch.randn((b, horizon, self.model.agent_action_dim),
|
state = torch.randn((b, 16, self.model.agent_state_dim), device=device)
|
||||||
device=device)
|
|
||||||
else:
|
|
||||||
action = action_T
|
|
||||||
if state_T is None:
|
|
||||||
state = torch.randn((b, horizon, self.model.agent_state_dim),
|
|
||||||
device=device)
|
|
||||||
else:
|
|
||||||
state = state_T
|
|
||||||
|
|
||||||
if precision is not None:
|
|
||||||
if precision == 16:
|
|
||||||
img = img.to(dtype=torch.float16)
|
|
||||||
action = action.to(dtype=torch.float16)
|
|
||||||
state = state.to(dtype=torch.float16)
|
|
||||||
|
|
||||||
if timesteps is None:
|
if timesteps is None:
|
||||||
timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps
|
timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps
|
||||||
@@ -265,40 +231,6 @@ class DDIMSampler(object):
|
|||||||
'x_inter_state': [state],
|
'x_inter_state': [state],
|
||||||
'pred_x0_state': [state],
|
'pred_x0_state': [state],
|
||||||
}
|
}
|
||||||
if record_step_outputs:
|
|
||||||
intermediates['analysis_init'] = {
|
|
||||||
'img': img.detach().cpu(),
|
|
||||||
'action': action.detach().cpu(),
|
|
||||||
'state': state.detach().cpu(),
|
|
||||||
}
|
|
||||||
intermediates['step_records'] = []
|
|
||||||
head_schedule_set = None if head_schedule is None else {
|
|
||||||
int(step_index) for step_index in head_schedule
|
|
||||||
}
|
|
||||||
head_log_steps_set = None if head_log_steps is None else {
|
|
||||||
int(step_index) for step_index in head_log_steps
|
|
||||||
}
|
|
||||||
if head_log_steps_set is not None:
|
|
||||||
intermediates['head_sparse_logs'] = {}
|
|
||||||
backbone_reuse_blocks_set = None if backbone_reuse_blocks is None else {
|
|
||||||
str(block_name) for block_name in backbone_reuse_blocks
|
|
||||||
}
|
|
||||||
backbone_reuse_schedule_steps_set = (
|
|
||||||
None if backbone_reuse_schedule_steps is None else {
|
|
||||||
int(step_index) for step_index in backbone_reuse_schedule_steps
|
|
||||||
})
|
|
||||||
backbone_reuse_force_compute_steps_set = (
|
|
||||||
None if backbone_reuse_force_compute_steps is None else {
|
|
||||||
int(step_index)
|
|
||||||
for step_index in backbone_reuse_force_compute_steps
|
|
||||||
})
|
|
||||||
backbone_reuse_active = (backbone_reuse_mode == "reuse_output"
|
|
||||||
and backbone_reuse_blocks_set)
|
|
||||||
backbone_reuse_cache = {
|
|
||||||
'single': {},
|
|
||||||
'cond': {},
|
|
||||||
'uncond': {},
|
|
||||||
}
|
|
||||||
time_range = reversed(range(
|
time_range = reversed(range(
|
||||||
0, timesteps)) if ddim_use_original_steps else np.flip(timesteps)
|
0, timesteps)) if ddim_use_original_steps else np.flip(timesteps)
|
||||||
total_steps = timesteps if ddim_use_original_steps else timesteps.shape[
|
total_steps = timesteps if ddim_use_original_steps else timesteps.shape[
|
||||||
@@ -309,89 +241,50 @@ class DDIMSampler(object):
|
|||||||
iterator = time_range
|
iterator = time_range
|
||||||
|
|
||||||
clean_cond = kwargs.pop("clean_cond", False)
|
clean_cond = kwargs.pop("clean_cond", False)
|
||||||
sync_device = device if isinstance(device, torch.device) else torch.device(
|
|
||||||
device)
|
|
||||||
should_sync = record_step_outputs and sync_device.type == "cuda"
|
|
||||||
if head_skip_mode not in {"reuse_prediction", "freeze_state"}:
|
|
||||||
raise ValueError(
|
|
||||||
f"Unsupported head_skip_mode={head_skip_mode!r}. "
|
|
||||||
"Expected 'reuse_prediction' or 'freeze_state'.")
|
|
||||||
if backbone_reuse_mode not in {"disabled", "reuse_output"}:
|
|
||||||
raise ValueError(
|
|
||||||
f"Unsupported backbone_reuse_mode={backbone_reuse_mode!r}. "
|
|
||||||
"Expected 'disabled' or 'reuse_output'.")
|
|
||||||
x_action_frozen = action.detach().clone()
|
|
||||||
x_state_frozen = state.detach().clone()
|
|
||||||
action_pred_cache = None
|
|
||||||
state_pred_cache = None
|
|
||||||
|
|
||||||
dp_ddim_scheduler_action.set_timesteps(len(timesteps))
|
dp_ddim_scheduler_action.set_timesteps(len(timesteps))
|
||||||
dp_ddim_scheduler_state.set_timesteps(len(timesteps))
|
dp_ddim_scheduler_state.set_timesteps(len(timesteps))
|
||||||
for i, step in enumerate(iterator):
|
ts = torch.empty((b, ), device=device, dtype=torch.long)
|
||||||
index = total_steps - i - 1
|
enable_cross_attn_kv_cache(self.model)
|
||||||
ts = torch.full((b, ), step, device=device, dtype=torch.long)
|
enable_ctx_cache(self.model)
|
||||||
|
handoff = {}
|
||||||
|
try:
|
||||||
|
for i, step in enumerate(iterator):
|
||||||
|
index = total_steps - i - 1
|
||||||
|
ts.fill_(step)
|
||||||
|
|
||||||
# Use mask to blend noised original latent (img_orig) & new sampled latent (img)
|
# Use mask to blend noised original latent (img_orig) & new sampled latent (img)
|
||||||
if mask is not None:
|
if mask is not None:
|
||||||
assert x0 is not None
|
assert x0 is not None
|
||||||
if clean_cond:
|
if clean_cond:
|
||||||
img_orig = x0
|
img_orig = x0
|
||||||
else:
|
else:
|
||||||
img_orig = self.model.q_sample(x0, ts)
|
img_orig = self.model.q_sample(x0, ts)
|
||||||
img = img_orig * mask + (1. - mask) * img
|
img = img_orig * mask + (1. - mask) * img
|
||||||
|
|
||||||
if should_sync:
|
outs = self.p_sample_ddim(
|
||||||
torch.cuda.synchronize(sync_device)
|
img,
|
||||||
step_start_time = time.time()
|
action,
|
||||||
scheduled_head = head_schedule_set is None or i in head_schedule_set
|
state,
|
||||||
if head_skip_mode == "reuse_prediction":
|
cond,
|
||||||
run_head = scheduled_head or action_pred_cache is None or state_pred_cache is None
|
ts,
|
||||||
else:
|
index=index,
|
||||||
run_head = scheduled_head
|
precision=precision,
|
||||||
backbone_reuse_step_stats: dict[str, set[str]] | None = None
|
use_original_steps=ddim_use_original_steps,
|
||||||
if backbone_reuse_active:
|
quantize_denoised=quantize_denoised,
|
||||||
backbone_reuse_step_stats = {
|
temperature=temperature,
|
||||||
'single': set(),
|
noise_dropout=noise_dropout,
|
||||||
'cond': set(),
|
score_corrector=score_corrector,
|
||||||
'uncond': set(),
|
corrector_kwargs=corrector_kwargs,
|
||||||
}
|
unconditional_guidance_scale=unconditional_guidance_scale,
|
||||||
|
unconditional_conditioning=unconditional_conditioning,
|
||||||
|
mask=mask,
|
||||||
|
x0=x0,
|
||||||
|
fs=fs,
|
||||||
|
guidance_rescale=guidance_rescale,
|
||||||
|
**kwargs)
|
||||||
|
|
||||||
outs = self.p_sample_ddim(
|
img, pred_x0, model_output_action, model_output_state = outs
|
||||||
img,
|
|
||||||
action,
|
|
||||||
state,
|
|
||||||
cond,
|
|
||||||
ts,
|
|
||||||
index=index,
|
|
||||||
use_original_steps=ddim_use_original_steps,
|
|
||||||
quantize_denoised=quantize_denoised,
|
|
||||||
temperature=temperature,
|
|
||||||
noise_dropout=noise_dropout,
|
|
||||||
score_corrector=score_corrector,
|
|
||||||
corrector_kwargs=corrector_kwargs,
|
|
||||||
unconditional_guidance_scale=unconditional_guidance_scale,
|
|
||||||
unconditional_conditioning=unconditional_conditioning,
|
|
||||||
mask=mask,
|
|
||||||
x0=x0,
|
|
||||||
fs=fs,
|
|
||||||
guidance_rescale=guidance_rescale,
|
|
||||||
backbone_step_index=i + 1,
|
|
||||||
run_head=run_head,
|
|
||||||
backbone_reuse_blocks=backbone_reuse_blocks_set,
|
|
||||||
backbone_reuse_start_step=backbone_reuse_start_step,
|
|
||||||
backbone_reuse_schedule_steps=backbone_reuse_schedule_steps_set,
|
|
||||||
backbone_reuse_force_compute_steps=
|
|
||||||
backbone_reuse_force_compute_steps_set,
|
|
||||||
backbone_reuse_mode=backbone_reuse_mode,
|
|
||||||
backbone_reuse_cache=backbone_reuse_cache,
|
|
||||||
backbone_reuse_step_stats=backbone_reuse_step_stats,
|
|
||||||
**kwargs)
|
|
||||||
|
|
||||||
img, pred_x0, model_output_action, model_output_state = outs
|
|
||||||
|
|
||||||
if run_head:
|
|
||||||
action_pred_cache = model_output_action.detach().clone()
|
|
||||||
state_pred_cache = model_output_state.detach().clone()
|
|
||||||
action = dp_ddim_scheduler_action.step(
|
action = dp_ddim_scheduler_action.step(
|
||||||
model_output_action,
|
model_output_action,
|
||||||
step,
|
step,
|
||||||
@@ -404,73 +297,35 @@ class DDIMSampler(object):
|
|||||||
state,
|
state,
|
||||||
generator=None,
|
generator=None,
|
||||||
).prev_sample
|
).prev_sample
|
||||||
x_action_frozen = action.detach().clone()
|
|
||||||
x_state_frozen = state.detach().clone()
|
|
||||||
else:
|
|
||||||
if head_skip_mode == "reuse_prediction":
|
|
||||||
action = dp_ddim_scheduler_action.step(
|
|
||||||
action_pred_cache,
|
|
||||||
step,
|
|
||||||
action,
|
|
||||||
generator=None,
|
|
||||||
).prev_sample
|
|
||||||
state = dp_ddim_scheduler_state.step(
|
|
||||||
state_pred_cache,
|
|
||||||
step,
|
|
||||||
state,
|
|
||||||
generator=None,
|
|
||||||
).prev_sample
|
|
||||||
x_action_frozen = action.detach().clone()
|
|
||||||
x_state_frozen = state.detach().clone()
|
|
||||||
else:
|
|
||||||
action = x_action_frozen
|
|
||||||
state = x_state_frozen
|
|
||||||
|
|
||||||
if should_sync:
|
if callback: callback(i)
|
||||||
torch.cuda.synchronize(sync_device)
|
if img_callback: img_callback(pred_x0, i)
|
||||||
step_time_s = time.time() - step_start_time
|
|
||||||
|
|
||||||
if callback: callback(i)
|
if handoff_step > 0 and (i + 1) == handoff_step:
|
||||||
if img_callback: img_callback(pred_x0, i)
|
handoff = {
|
||||||
|
'samples': img.clone(),
|
||||||
|
'actions': action.clone(),
|
||||||
|
'states': state.clone(),
|
||||||
|
'pred_x0': pred_x0.clone(),
|
||||||
|
'step': i + 1,
|
||||||
|
}
|
||||||
|
if handoff_callback is not None:
|
||||||
|
handoff_callback(handoff)
|
||||||
|
if stop_at_handoff:
|
||||||
|
intermediates['handoff'] = handoff
|
||||||
|
break
|
||||||
|
|
||||||
reused_blocks = []
|
if index % log_every_t == 0 or index == total_steps - 1:
|
||||||
if backbone_reuse_step_stats is not None:
|
intermediates['x_inter'].append(img)
|
||||||
if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
|
intermediates['pred_x0'].append(pred_x0)
|
||||||
reused_blocks = sorted(backbone_reuse_step_stats['single'])
|
intermediates['x_inter_action'].append(action)
|
||||||
else:
|
intermediates['x_inter_state'].append(state)
|
||||||
reused_blocks = sorted(
|
finally:
|
||||||
backbone_reuse_step_stats['cond']
|
disable_cross_attn_kv_cache(self.model)
|
||||||
| backbone_reuse_step_stats['uncond'])
|
disable_ctx_cache(self.model)
|
||||||
|
|
||||||
if index % log_every_t == 0 or index == total_steps - 1:
|
|
||||||
intermediates['x_inter'].append(img)
|
|
||||||
intermediates['pred_x0'].append(pred_x0)
|
|
||||||
intermediates['x_inter_action'].append(action)
|
|
||||||
intermediates['x_inter_state'].append(state)
|
|
||||||
if head_log_steps_set is not None and i in head_log_steps_set:
|
|
||||||
intermediates['head_sparse_logs'][i] = {
|
|
||||||
'step_index': i,
|
|
||||||
'ddim_timestep': int(step),
|
|
||||||
'head_executed': run_head,
|
|
||||||
'img': img.detach().cpu(),
|
|
||||||
'pred_x0': pred_x0.detach().cpu(),
|
|
||||||
'action': action.detach().cpu(),
|
|
||||||
'state': state.detach().cpu(),
|
|
||||||
}
|
|
||||||
if record_step_outputs:
|
|
||||||
intermediates['step_records'].append({
|
|
||||||
'step_index': i + 1,
|
|
||||||
'ddim_timestep': int(step),
|
|
||||||
'head_executed': run_head,
|
|
||||||
'img': img.detach().cpu(),
|
|
||||||
'pred_x0': pred_x0.detach().cpu(),
|
|
||||||
'action': action.detach().cpu(),
|
|
||||||
'state': state.detach().cpu(),
|
|
||||||
'step_time_s': step_time_s,
|
|
||||||
'backbone_reused_blocks_count': len(reused_blocks),
|
|
||||||
'backbone_reuse_hit_blocks': ",".join(reused_blocks),
|
|
||||||
})
|
|
||||||
|
|
||||||
|
if handoff_step > 0:
|
||||||
|
intermediates['handoff'] = handoff
|
||||||
return img, action, state, intermediates
|
return img, action, state, intermediates
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
@@ -481,6 +336,7 @@ class DDIMSampler(object):
|
|||||||
c,
|
c,
|
||||||
t,
|
t,
|
||||||
index,
|
index,
|
||||||
|
precision=None,
|
||||||
repeat_noise=False,
|
repeat_noise=False,
|
||||||
use_original_steps=False,
|
use_original_steps=False,
|
||||||
quantize_denoised=False,
|
quantize_denoised=False,
|
||||||
@@ -495,62 +351,35 @@ class DDIMSampler(object):
|
|||||||
mask=None,
|
mask=None,
|
||||||
x0=None,
|
x0=None,
|
||||||
guidance_rescale=0.0,
|
guidance_rescale=0.0,
|
||||||
run_head=True,
|
|
||||||
**kwargs):
|
**kwargs):
|
||||||
b, *_, device = *x.shape, x.device
|
b, *_, device = *x.shape, x.device
|
||||||
if x.dim() == 5:
|
|
||||||
is_video = True
|
|
||||||
else:
|
|
||||||
is_video = False
|
|
||||||
|
|
||||||
if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
|
use_autocast = precision == 16 and device.type == 'cuda'
|
||||||
model_output, model_output_action, model_output_state = self.model.apply_model(
|
with torch.cuda.amp.autocast(enabled=use_autocast,
|
||||||
x,
|
dtype=torch.float16):
|
||||||
x_action,
|
if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
|
||||||
x_state,
|
model_output, model_output_action, model_output_state = self.model.apply_model(
|
||||||
t,
|
x, x_action, x_state, t, c, **kwargs) # unet denoiser
|
||||||
c,
|
|
||||||
run_head=run_head,
|
|
||||||
backbone_reuse_branch="single",
|
|
||||||
**kwargs) # unet denoiser
|
|
||||||
else:
|
|
||||||
# do_classifier_free_guidance
|
|
||||||
if isinstance(c, torch.Tensor) or isinstance(c, dict):
|
|
||||||
e_t_cond, e_t_cond_action, e_t_cond_state = self.model.apply_model(
|
|
||||||
x,
|
|
||||||
x_action,
|
|
||||||
x_state,
|
|
||||||
t,
|
|
||||||
c,
|
|
||||||
run_head=run_head,
|
|
||||||
backbone_reuse_branch="cond",
|
|
||||||
**kwargs)
|
|
||||||
e_t_uncond, e_t_uncond_action, e_t_uncond_state = self.model.apply_model(
|
|
||||||
x,
|
|
||||||
x_action,
|
|
||||||
x_state,
|
|
||||||
t,
|
|
||||||
unconditional_conditioning,
|
|
||||||
run_head=run_head,
|
|
||||||
backbone_reuse_branch="uncond",
|
|
||||||
**kwargs)
|
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError
|
# do_classifier_free_guidance
|
||||||
model_output = e_t_uncond + unconditional_guidance_scale * (
|
if isinstance(c, torch.Tensor) or isinstance(c, dict):
|
||||||
e_t_cond - e_t_uncond)
|
e_t_cond, e_t_cond_action, e_t_cond_state = self.model.apply_model(
|
||||||
if run_head:
|
x, x_action, x_state, t, c, **kwargs)
|
||||||
|
e_t_uncond, e_t_uncond_action, e_t_uncond_state = self.model.apply_model(
|
||||||
|
x, x_action, x_state, t, unconditional_conditioning,
|
||||||
|
**kwargs)
|
||||||
|
else:
|
||||||
|
raise NotImplementedError
|
||||||
|
model_output = e_t_uncond + unconditional_guidance_scale * (
|
||||||
|
e_t_cond - e_t_uncond)
|
||||||
model_output_action = e_t_uncond_action + unconditional_guidance_scale * (
|
model_output_action = e_t_uncond_action + unconditional_guidance_scale * (
|
||||||
e_t_cond_action - e_t_uncond_action)
|
e_t_cond_action - e_t_uncond_action)
|
||||||
model_output_state = e_t_uncond_state + unconditional_guidance_scale * (
|
model_output_state = e_t_uncond_state + unconditional_guidance_scale * (
|
||||||
e_t_cond_state - e_t_uncond_state)
|
e_t_cond_state - e_t_uncond_state)
|
||||||
else:
|
|
||||||
model_output_action = None
|
|
||||||
model_output_state = None
|
|
||||||
|
|
||||||
if guidance_rescale > 0.0:
|
if guidance_rescale > 0.0:
|
||||||
model_output = rescale_noise_cfg(
|
model_output = rescale_noise_cfg(
|
||||||
model_output, e_t_cond, guidance_rescale=guidance_rescale)
|
model_output, e_t_cond, guidance_rescale=guidance_rescale)
|
||||||
if run_head:
|
|
||||||
model_output_action = rescale_noise_cfg(
|
model_output_action = rescale_noise_cfg(
|
||||||
model_output_action,
|
model_output_action,
|
||||||
e_t_cond_action,
|
e_t_cond_action,
|
||||||
@@ -575,17 +404,11 @@ class DDIMSampler(object):
|
|||||||
sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
|
sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
|
||||||
sigmas = self.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
|
sigmas = self.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
|
||||||
|
|
||||||
if is_video:
|
# Use 0-d tensors directly (already on device); broadcasting handles shape
|
||||||
size = (b, 1, 1, 1, 1)
|
a_t = alphas[index]
|
||||||
else:
|
a_prev = alphas_prev[index]
|
||||||
size = (b, 1, 1, 1)
|
sigma_t = sigmas[index]
|
||||||
|
sqrt_one_minus_at = sqrt_one_minus_alphas[index]
|
||||||
a_t = torch.full(size, alphas[index], device=device)
|
|
||||||
a_prev = torch.full(size, alphas_prev[index], device=device)
|
|
||||||
sigma_t = torch.full(size, sigmas[index], device=device)
|
|
||||||
sqrt_one_minus_at = torch.full(size,
|
|
||||||
sqrt_one_minus_alphas[index],
|
|
||||||
device=device)
|
|
||||||
|
|
||||||
if self.model.parameterization != "v":
|
if self.model.parameterization != "v":
|
||||||
pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
|
pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
|
||||||
@@ -593,12 +416,8 @@ class DDIMSampler(object):
|
|||||||
pred_x0 = self.model.predict_start_from_z_and_v(x, t, model_output)
|
pred_x0 = self.model.predict_start_from_z_and_v(x, t, model_output)
|
||||||
|
|
||||||
if self.model.use_dynamic_rescale:
|
if self.model.use_dynamic_rescale:
|
||||||
scale_t = torch.full(size,
|
scale_t = self.ddim_scale_arr[index]
|
||||||
self.ddim_scale_arr[index],
|
prev_scale_t = self.ddim_scale_arr_prev[index]
|
||||||
device=device)
|
|
||||||
prev_scale_t = torch.full(size,
|
|
||||||
self.ddim_scale_arr_prev[index],
|
|
||||||
device=device)
|
|
||||||
rescale = (prev_scale_t / scale_t)
|
rescale = (prev_scale_t / scale_t)
|
||||||
pred_x0 *= rescale
|
pred_x0 *= rescale
|
||||||
|
|
||||||
|
|||||||
@@ -1,7 +1,6 @@
|
|||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
import time
|
|
||||||
|
|
||||||
from torch import Tensor
|
from torch import Tensor
|
||||||
from functools import partial
|
from functools import partial
|
||||||
@@ -686,6 +685,37 @@ class WMAModel(nn.Module):
|
|||||||
self.action_token_projector = instantiate_from_config(
|
self.action_token_projector = instantiate_from_config(
|
||||||
stem_process_config)
|
stem_process_config)
|
||||||
|
|
||||||
|
# Context precomputation cache
|
||||||
|
self._ctx_cache_enabled = False
|
||||||
|
self._ctx_cache = {}
|
||||||
|
self._trt_backbone = None # TRT engine for video UNet backbone
|
||||||
|
# Reusable CUDA stream for parallel state_unet / action_unet
|
||||||
|
self._state_stream = torch.cuda.Stream()
|
||||||
|
|
||||||
|
def __getstate__(self):
    """Return a picklable copy of the instance state.

    Runtime-only members are left out of the copy; the live object is
    not modified:

    * ``_state_stream`` -- a CUDA stream handle; ``__setstate__`` creates
      a fresh one on load.
    * ``_trt_backbone`` -- the ``TRTBackbone`` installed by
      ``load_trt_backbone``. It wraps a live TensorRT engine and execution
      context (presumably not picklable -- confirm); ``__setstate__``
      falls back to ``None`` when the key is absent.
    * ``_ctx_cache`` -- keyed by ``Tensor.data_ptr()`` values, which are
      raw addresses of the tensors seen in this process and therefore
      meaningless for a restored or copied model; ``__setstate__`` falls
      back to an empty dict when the key is absent.

    Returns:
        dict: shallow copy of ``self.__dict__`` without the entries above.
    """
    state = self.__dict__.copy()
    for runtime_only in ('_state_stream', '_trt_backbone', '_ctx_cache'):
        state.pop(runtime_only, None)
    return state
|
||||||
|
|
||||||
|
def __setstate__(self, state):
    """Restore instance state and rebuild the runtime-only members.

    Args:
        state: attribute dict, as produced by ``__getstate__``.

    ``state`` may lack the cache / TensorRT attributes; each missing one
    is given its inactive default (cache disabled and empty, no TensorRT
    backbone attached).
    """
    self.__dict__.update(state)
    if not hasattr(self, '_ctx_cache_enabled'):
        self._ctx_cache_enabled = False
    if not hasattr(self, '_ctx_cache'):
        self._ctx_cache = {}
    if not hasattr(self, '_trt_backbone'):
        self._trt_backbone = None
    # The stream is never part of the pickled state, so always create a new
    # one. Constructing a CUDA stream raises on a host without CUDA, which
    # would make the whole model impossible to unpickle there; fall back to
    # None in that case.
    if torch.cuda.is_available():
        self._state_stream = torch.cuda.Stream()
    else:
        self._state_stream = None
|
||||||
|
|
||||||
|
def load_trt_backbone(self, engine_path, n_hs_a=9):
    """Attach a TensorRT runtime for the video UNet backbone to this model.

    Args:
        engine_path: path of the serialized TensorRT engine file.
        n_hs_a: number of ``hs_a`` feature outputs the engine exposes,
            forwarded to ``TRTBackbone``.

    The engine is placed on the device of this module's first parameter
    and stored in ``self._trt_backbone``.
    """
    # Imported lazily so TensorRT is only needed when an engine is loaded.
    from unifolm_wma.trt_utils import TRTBackbone

    first_param = next(self.parameters())
    device = first_param.device
    backbone = TRTBackbone(engine_path, n_hs_a=n_hs_a, device=device)
    self._trt_backbone = backbone
    print(f">>> TRT backbone loaded from {engine_path} on {device}")
|
||||||
|
|
||||||
def forward(self,
|
def forward(self,
|
||||||
x: Tensor,
|
x: Tensor,
|
||||||
x_action: Tensor,
|
x_action: Tensor,
|
||||||
@@ -714,80 +744,70 @@ class WMAModel(nn.Module):
|
|||||||
Tuple of Tensors for predictions:
|
Tuple of Tensors for predictions:
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
b, _, t, _, _ = x.shape
|
b, _, t, _, _ = x.shape
|
||||||
run_head = kwargs.pop("run_head", True)
|
|
||||||
backbone_block_profiler = kwargs.pop("backbone_block_profiler", None)
|
|
||||||
backbone_step_index = kwargs.pop("backbone_step_index", None)
|
|
||||||
backbone_reuse_blocks = kwargs.pop("backbone_reuse_blocks", None)
|
|
||||||
backbone_reuse_start_step = kwargs.pop("backbone_reuse_start_step",
|
|
||||||
None)
|
|
||||||
backbone_reuse_schedule_steps = kwargs.pop(
|
|
||||||
"backbone_reuse_schedule_steps", None)
|
|
||||||
backbone_reuse_force_compute_steps = kwargs.pop(
|
|
||||||
"backbone_reuse_force_compute_steps", None)
|
|
||||||
backbone_reuse_mode = kwargs.pop("backbone_reuse_mode", "disabled")
|
|
||||||
backbone_reuse_cache = kwargs.pop("backbone_reuse_cache", None)
|
|
||||||
backbone_reuse_step_stats = kwargs.pop("backbone_reuse_step_stats",
|
|
||||||
None)
|
|
||||||
backbone_reuse_branch = kwargs.pop("backbone_reuse_branch", "single")
|
|
||||||
t_emb = timestep_embedding(timesteps,
|
t_emb = timestep_embedding(timesteps,
|
||||||
self.model_channels,
|
self.model_channels,
|
||||||
repeat_only=False).type(x.dtype)
|
repeat_only=False).type(x.dtype)
|
||||||
emb = self.time_embed(t_emb)
|
emb = self.time_embed(t_emb)
|
||||||
|
|
||||||
bt, l_context, _ = context.shape
|
_ctx_key = context.data_ptr()
|
||||||
if self.base_model_gen_only:
|
if self._ctx_cache_enabled and _ctx_key in self._ctx_cache:
|
||||||
assert l_context == 77 + self.n_obs_steps * 16, ">>> ERROR Context dim 1 ..." ## NOTE HANDCODE
|
context = self._ctx_cache[_ctx_key]
|
||||||
else:
|
else:
|
||||||
if l_context == self.n_obs_steps + 77 + t * 16:
|
bt, l_context, _ = context.shape
|
||||||
context_agent_state = context[:, :self.n_obs_steps]
|
if self.base_model_gen_only:
|
||||||
context_text = context[:, self.n_obs_steps:self.n_obs_steps +
|
assert l_context == 77 + self.n_obs_steps * 16, ">>> ERROR Context dim 1 ..." ## NOTE HANDCODE
|
||||||
77, :]
|
else:
|
||||||
context_img = context[:, self.n_obs_steps + 77:, :]
|
if l_context == self.n_obs_steps + 77 + t * 16:
|
||||||
context_agent_state = context_agent_state.repeat_interleave(
|
context_agent_state = context[:, :self.n_obs_steps]
|
||||||
repeats=t, dim=0)
|
context_text = context[:, self.n_obs_steps:self.n_obs_steps +
|
||||||
context_text = context_text.repeat_interleave(repeats=t, dim=0)
|
77, :]
|
||||||
context_img = rearrange(context_img,
|
context_img = context[:, self.n_obs_steps + 77:, :]
|
||||||
'b (t l) c -> (b t) l c',
|
context_agent_state = context_agent_state.repeat_interleave(
|
||||||
t=t)
|
repeats=t, dim=0)
|
||||||
context = torch.cat(
|
context_text = context_text.repeat_interleave(repeats=t, dim=0)
|
||||||
[context_agent_state, context_text, context_img], dim=1)
|
context_img = rearrange(context_img,
|
||||||
elif l_context == self.n_obs_steps + 16 + 77 + t * 16:
|
'b (t l) c -> (b t) l c',
|
||||||
context_agent_state = context[:, :self.n_obs_steps]
|
t=t)
|
||||||
context_agent_action = context[:, self.
|
context = torch.cat(
|
||||||
n_obs_steps:self.n_obs_steps +
|
[context_agent_state, context_text, context_img], dim=1)
|
||||||
16, :]
|
elif l_context == self.n_obs_steps + 16 + 77 + t * 16:
|
||||||
context_agent_action = rearrange(
|
context_agent_state = context[:, :self.n_obs_steps]
|
||||||
context_agent_action.unsqueeze(2), 'b t l d -> (b t) l d')
|
context_agent_action = context[:, self.
|
||||||
context_agent_action = self.action_token_projector(
|
n_obs_steps:self.n_obs_steps +
|
||||||
context_agent_action)
|
16, :]
|
||||||
context_agent_action = rearrange(context_agent_action,
|
context_agent_action = rearrange(
|
||||||
'(b o) l d -> b o l d',
|
context_agent_action.unsqueeze(2), 'b t l d -> (b t) l d')
|
||||||
o=t)
|
context_agent_action = self.action_token_projector(
|
||||||
context_agent_action = rearrange(context_agent_action,
|
context_agent_action)
|
||||||
'b o (t l) d -> b o t l d',
|
context_agent_action = rearrange(context_agent_action,
|
||||||
t=t)
|
'(b o) l d -> b o l d',
|
||||||
context_agent_action = context_agent_action.permute(
|
o=t)
|
||||||
0, 2, 1, 3, 4)
|
context_agent_action = rearrange(context_agent_action,
|
||||||
context_agent_action = rearrange(context_agent_action,
|
'b o (t l) d -> b o t l d',
|
||||||
'b t o l d -> (b t) (o l) d')
|
t=t)
|
||||||
|
context_agent_action = context_agent_action.permute(
|
||||||
|
0, 2, 1, 3, 4)
|
||||||
|
context_agent_action = rearrange(context_agent_action,
|
||||||
|
'b t o l d -> (b t) (o l) d')
|
||||||
|
|
||||||
context_text = context[:, self.n_obs_steps +
|
context_text = context[:, self.n_obs_steps +
|
||||||
16:self.n_obs_steps + 16 + 77, :]
|
16:self.n_obs_steps + 16 + 77, :]
|
||||||
context_text = context_text.repeat_interleave(repeats=t, dim=0)
|
context_text = context_text.repeat_interleave(repeats=t, dim=0)
|
||||||
|
|
||||||
context_img = context[:, self.n_obs_steps + 16 + 77:, :]
|
context_img = context[:, self.n_obs_steps + 16 + 77:, :]
|
||||||
context_img = rearrange(context_img,
|
context_img = rearrange(context_img,
|
||||||
'b (t l) c -> (b t) l c',
|
'b (t l) c -> (b t) l c',
|
||||||
t=t)
|
t=t)
|
||||||
context_agent_state = context_agent_state.repeat_interleave(
|
context_agent_state = context_agent_state.repeat_interleave(
|
||||||
repeats=t, dim=0)
|
repeats=t, dim=0)
|
||||||
context = torch.cat([
|
context = torch.cat([
|
||||||
context_agent_state, context_agent_action, context_text,
|
context_agent_state, context_agent_action, context_text,
|
||||||
context_img
|
context_img
|
||||||
],
|
],
|
||||||
dim=1)
|
dim=1)
|
||||||
|
if self._ctx_cache_enabled:
|
||||||
|
self._ctx_cache[_ctx_key] = context
|
||||||
|
|
||||||
emb = emb.repeat_interleave(repeats=t, dim=0)
|
emb = emb.repeat_interleave(repeats=t, dim=0)
|
||||||
|
|
||||||
@@ -807,150 +827,95 @@ class WMAModel(nn.Module):
|
|||||||
fs_embed = fs_embed.repeat_interleave(repeats=t, dim=0)
|
fs_embed = fs_embed.repeat_interleave(repeats=t, dim=0)
|
||||||
emb = emb + fs_embed
|
emb = emb + fs_embed
|
||||||
|
|
||||||
def run_block_with_profile(block_name: str, block_stage: str,
|
if self._trt_backbone is not None:
|
||||||
block_index: int | None,
|
# TRT path: run backbone via TensorRT engine
|
||||||
fn: Callable[[], Tensor]) -> Tensor:
|
h_in = x.type(self.dtype).contiguous()
|
||||||
if backbone_block_profiler is None or backbone_step_index is None:
|
y, hs_a = self._trt_backbone(h_in, emb.contiguous(), context.contiguous())
|
||||||
return fn()
|
else:
|
||||||
if x.device.type == "cuda":
|
# PyTorch path: original backbone
|
||||||
torch.cuda.synchronize(x.device)
|
h = x.type(self.dtype)
|
||||||
start_time = time.perf_counter()
|
adapter_idx = 0
|
||||||
out = fn()
|
hs = []
|
||||||
if x.device.type == "cuda":
|
hs_a = []
|
||||||
torch.cuda.synchronize(x.device)
|
for id, module in enumerate(self.input_blocks):
|
||||||
backbone_block_profiler.record_block(
|
h = module(h, emb, context=context, batch_size=b)
|
||||||
step=int(backbone_step_index),
|
|
||||||
block_name=block_name,
|
|
||||||
block_stage=block_stage,
|
|
||||||
block_index=block_index,
|
|
||||||
output=out,
|
|
||||||
forward_time_ms=(time.perf_counter() - start_time) * 1000.0,
|
|
||||||
)
|
|
||||||
return out
|
|
||||||
|
|
||||||
reuse_cache_branch: Dict[str, Tensor] | None = None
|
|
||||||
if backbone_reuse_cache is not None:
|
|
||||||
reuse_cache_branch = backbone_reuse_cache.setdefault(
|
|
||||||
backbone_reuse_branch, {})
|
|
||||||
|
|
||||||
def should_reuse_output_block(block_name: str) -> bool:
|
|
||||||
if backbone_reuse_mode != "reuse_output":
|
|
||||||
return False
|
|
||||||
if backbone_step_index is None or backbone_reuse_start_step is None:
|
|
||||||
return False
|
|
||||||
if backbone_reuse_blocks is None or block_name not in backbone_reuse_blocks:
|
|
||||||
return False
|
|
||||||
if int(backbone_step_index) < int(backbone_reuse_start_step):
|
|
||||||
return False
|
|
||||||
if (backbone_reuse_force_compute_steps is not None
|
|
||||||
and int(backbone_step_index)
|
|
||||||
in backbone_reuse_force_compute_steps):
|
|
||||||
return False
|
|
||||||
if (backbone_reuse_schedule_steps is not None
|
|
||||||
and int(backbone_step_index)
|
|
||||||
in backbone_reuse_schedule_steps):
|
|
||||||
return False
|
|
||||||
if reuse_cache_branch is None:
|
|
||||||
return False
|
|
||||||
return block_name in reuse_cache_branch
|
|
||||||
|
|
||||||
h = x.type(self.dtype)
|
|
||||||
adapter_idx = 0
|
|
||||||
hs = []
|
|
||||||
hs_a = []
|
|
||||||
for id, module in enumerate(self.input_blocks):
|
|
||||||
def run_input_block() -> Tensor:
|
|
||||||
block_out = module(h, emb, context=context, batch_size=b)
|
|
||||||
if id == 0 and self.addition_attention:
|
if id == 0 and self.addition_attention:
|
||||||
block_out = self.init_attn(block_out,
|
h = self.init_attn(h, emb, context=context, batch_size=b)
|
||||||
emb,
|
# plug-in adapter features
|
||||||
context=context,
|
if ((id + 1) % 3 == 0) and features_adapter is not None:
|
||||||
batch_size=b)
|
h = h + features_adapter[adapter_idx]
|
||||||
return block_out
|
adapter_idx += 1
|
||||||
|
if id != 0:
|
||||||
|
if isinstance(module[0], Downsample):
|
||||||
|
hs_a.append(
|
||||||
|
rearrange(hs[-1], '(b t) c h w -> b t c h w', b=b))
|
||||||
|
hs.append(h)
|
||||||
|
hs_a.append(rearrange(h, '(b t) c h w -> b t c h w', b=b))
|
||||||
|
|
||||||
h = run_block_with_profile(
|
if features_adapter is not None:
|
||||||
block_name=f"input_{id}",
|
assert len(
|
||||||
block_stage="input_blocks",
|
features_adapter) == adapter_idx, 'Wrong features_adapter'
|
||||||
block_index=id,
|
h = self.middle_block(h, emb, context=context, batch_size=b)
|
||||||
fn=run_input_block,
|
hs_a.append(rearrange(h, '(b t) c h w -> b t c h w', b=b))
|
||||||
)
|
|
||||||
# plug-in adapter features
|
hs_out = []
|
||||||
if ((id + 1) % 3 == 0) and features_adapter is not None:
|
for module in self.output_blocks:
|
||||||
h = h + features_adapter[adapter_idx]
|
h = torch.cat([h, hs.pop()], dim=1)
|
||||||
adapter_idx += 1
|
h = module(h, emb, context=context, batch_size=b)
|
||||||
if id != 0:
|
if isinstance(module[-1], Upsample):
|
||||||
if isinstance(module[0], Downsample):
|
|
||||||
hs_a.append(
|
hs_a.append(
|
||||||
rearrange(hs[-1], '(b t) c h w -> b t c h w', t=t))
|
rearrange(hs_out[-1], '(b t) c h w -> b t c h w', b=b))
|
||||||
hs.append(h)
|
hs_out.append(h)
|
||||||
hs_a.append(rearrange(h, '(b t) c h w -> b t c h w', t=t))
|
h = h.type(x.dtype)
|
||||||
|
hs_a.append(rearrange(hs_out[-1], '(b t) c h w -> b t c h w', b=b))
|
||||||
|
|
||||||
if features_adapter is not None:
|
y = self.out(h)
|
||||||
assert len(
|
y = rearrange(y, '(b t) c h w -> b c t h w', b=b)
|
||||||
features_adapter) == adapter_idx, 'Wrong features_adapter'
|
|
||||||
h = run_block_with_profile(
|
|
||||||
block_name="middle",
|
|
||||||
block_stage="middle_block",
|
|
||||||
block_index=0,
|
|
||||||
fn=lambda: self.middle_block(h, emb, context=context, batch_size=b),
|
|
||||||
)
|
|
||||||
hs_a.append(rearrange(h, '(b t) c h w -> b t c h w', t=t))
|
|
||||||
|
|
||||||
hs_out = []
|
if not self.base_model_gen_only:
|
||||||
for id, module in enumerate(self.output_blocks):
|
|
||||||
skip_h = hs.pop()
|
|
||||||
block_name = f"output_{id}"
|
|
||||||
|
|
||||||
def run_output_block() -> Tensor:
|
|
||||||
return module(torch.cat([h, skip_h], dim=1),
|
|
||||||
emb,
|
|
||||||
context=context,
|
|
||||||
batch_size=b)
|
|
||||||
|
|
||||||
if should_reuse_output_block(block_name):
|
|
||||||
h = reuse_cache_branch[block_name].to(device=h.device,
|
|
||||||
dtype=h.dtype)
|
|
||||||
if backbone_reuse_step_stats is not None:
|
|
||||||
backbone_reuse_step_stats.setdefault(
|
|
||||||
backbone_reuse_branch, set()).add(block_name)
|
|
||||||
else:
|
|
||||||
h = run_block_with_profile(
|
|
||||||
block_name=block_name,
|
|
||||||
block_stage="output_blocks",
|
|
||||||
block_index=id,
|
|
||||||
fn=run_output_block,
|
|
||||||
)
|
|
||||||
if (reuse_cache_branch is not None and backbone_reuse_mode ==
|
|
||||||
"reuse_output"
|
|
||||||
and backbone_reuse_blocks is not None
|
|
||||||
and block_name in backbone_reuse_blocks):
|
|
||||||
reuse_cache_branch[block_name] = h.detach().clone()
|
|
||||||
if isinstance(module[-1], Upsample):
|
|
||||||
hs_a.append(
|
|
||||||
rearrange(hs_out[-1], '(b t) c h w -> b t c h w', t=t))
|
|
||||||
hs_out.append(h)
|
|
||||||
h = h.type(x.dtype)
|
|
||||||
hs_a.append(rearrange(hs_out[-1], '(b t) c h w -> b t c h w', t=t))
|
|
||||||
|
|
||||||
y = self.out(h)
|
|
||||||
y = rearrange(y, '(b t) c h w -> b c t h w', b=b)
|
|
||||||
|
|
||||||
if not self.base_model_gen_only and run_head:
|
|
||||||
ba, _, _ = x_action.shape
|
ba, _, _ = x_action.shape
|
||||||
a_y = self.action_unet(x_action, timesteps[:ba], hs_a,
|
ts_state = timesteps[:ba] if b > 1 else timesteps
|
||||||
context_action[:2], **kwargs)
|
is_sim_mode = context_action[2] if len(context_action) > 2 else False
|
||||||
# Predict state
|
|
||||||
if b > 1:
|
if is_sim_mode:
|
||||||
s_y = self.state_unet(x_state, timesteps[:ba], hs_a,
|
# WM mode: only need state_unet, skip action_unet
|
||||||
|
s_y = self.state_unet(x_state, ts_state, hs_a,
|
||||||
context_action[:2], **kwargs)
|
context_action[:2], **kwargs)
|
||||||
|
a_y = torch.zeros_like(x_action)
|
||||||
else:
|
else:
|
||||||
s_y = self.state_unet(x_state, timesteps, hs_a,
|
# DM mode: only need action_unet, skip state_unet
|
||||||
context_action[:2], **kwargs)
|
a_y = self.action_unet(x_action, timesteps[:ba], hs_a,
|
||||||
elif not self.base_model_gen_only:
|
context_action[:2], **kwargs)
|
||||||
a_y = None
|
s_y = torch.zeros_like(x_state)
|
||||||
s_y = None
|
|
||||||
else:
|
else:
|
||||||
a_y = torch.zeros_like(x_action)
|
a_y = torch.zeros_like(x_action)
|
||||||
s_y = torch.zeros_like(x_state)
|
s_y = torch.zeros_like(x_state)
|
||||||
|
|
||||||
return y, a_y, s_y
|
return y, a_y, s_y
|
||||||
|
|
||||||
|
|
||||||
|
def enable_ctx_cache(model):
    """Turn on the precomputation caches for every matching submodule.

    Walks ``model.modules()`` and, for each ``WMAModel``, sets
    ``_ctx_cache_enabled`` and resets ``_ctx_cache`` to an empty dict;
    then does the same for the ``_global_cond_cache*`` attributes of each
    ``ConditionalUnet1D``.
    """
    for wma in (sub for sub in model.modules() if isinstance(sub, WMAModel)):
        wma._ctx_cache = {}
        wma._ctx_cache_enabled = True

    # Imported here, after the WMAModel pass, rather than at module level.
    from unifolm_wma.models.diffusion_head.conditional_unet1d import ConditionalUnet1D

    for head in (sub for sub in model.modules()
                 if isinstance(sub, ConditionalUnet1D)):
        head._global_cond_cache = {}
        head._global_cond_cache_enabled = True
|
||||||
|
|
||||||
|
|
||||||
|
def disable_ctx_cache(model):
    """Turn off and empty the precomputation caches of every matching submodule.

    Walks ``model.modules()`` and, for each ``WMAModel``, clears
    ``_ctx_cache_enabled`` and replaces ``_ctx_cache`` with an empty dict;
    then does the same for the ``_global_cond_cache*`` attributes of each
    ``ConditionalUnet1D``.
    """
    for sub in model.modules():
        if not isinstance(sub, WMAModel):
            continue
        sub._ctx_cache = {}
        sub._ctx_cache_enabled = False

    # Imported here, after the WMAModel pass, rather than at module level.
    from unifolm_wma.models.diffusion_head.conditional_unet1d import ConditionalUnet1D

    for sub in model.modules():
        if not isinstance(sub, ConditionalUnet1D):
            continue
        sub._global_cond_cache = {}
        sub._global_cond_cache_enabled = False
|
||||||
|
|||||||
196
src/unifolm_wma/trt_utils.py
Normal file
196
src/unifolm_wma/trt_utils.py
Normal file
@@ -0,0 +1,196 @@
|
|||||||
|
"""TensorRT acceleration utilities for the video UNet backbone."""
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from einops import rearrange
|
||||||
|
from unifolm_wma.modules.networks.wma_model import Downsample, Upsample
|
||||||
|
|
||||||
|
|
||||||
|
def _normalize_cuda_device(device) -> torch.device:
    """Coerce *device* into a CUDA ``torch.device`` with an explicit index.

    Args:
        device: ``None`` (use the current CUDA device), an ``int`` ordinal,
            a ``torch.device``, or anything ``torch.device()`` accepts
            (e.g. ``"cuda:1"``).

    Returns:
        torch.device: a CUDA device whose ``index`` is always set.

    Raises:
        ValueError: if *device* does not refer to a CUDA device.
    """
    if device is None:
        return torch.device('cuda', torch.cuda.current_device())
    if isinstance(device, int):
        return torch.device('cuda', device)

    normalized = device if isinstance(device, torch.device) else torch.device(device)
    if normalized.type != 'cuda':
        raise ValueError(f"TensorRT requires a CUDA device, got {normalized}.")
    if normalized.index is None:
        # A bare "cuda" device never compares equal to a tensor's
        # "cuda:N" device, so resolve it to the current device now.
        normalized = torch.device('cuda', torch.cuda.current_device())
    return normalized
|
||||||
|
|
||||||
|
|
||||||
|
class VideoBackboneForExport(nn.Module):
    """Wrapper that isolates the video UNet backbone for ONNX export.

    Takes already-preprocessed inputs (after context/time embedding prep)
    and returns y + hs_a as a flat tuple.

    Args:
        wma_model: model whose ``input_blocks``, ``middle_block``,
            ``output_blocks``, ``out`` and (when ``addition_attention`` is
            set) ``init_attn`` submodules are reused; they are shared, not
            copied.
        num_frames: number of frames folded into the leading ``(b t)``
            dimension of ``h``. Baked into the traced graph.
        batch_size: batch size folded into the same dimension. Baked into
            the traced graph.
    """

    def __init__(self, wma_model, num_frames=16, batch_size=1):
        super().__init__()
        self.input_blocks = wma_model.input_blocks
        self.middle_block = wma_model.middle_block
        self.output_blocks = wma_model.output_blocks
        self.out = wma_model.out
        self.addition_attention = wma_model.addition_attention
        if self.addition_attention:
            self.init_attn = wma_model.init_attn
        self.dtype = wma_model.dtype
        self.num_frames = num_frames
        self.batch_size = batch_size

    def forward(self, h, emb, context):
        """Run the backbone and collect the intermediate features.

        Args:
            h: input with frames folded into the leading ``(b t)`` dimension.
            emb: embedding passed to every block.
            context: conditioning passed to every block as ``context=``.

        Returns:
            tuple: ``(y, *hs_a)`` where ``y`` is laid out ``b c t h w`` and
            each ``hs_a`` entry is laid out ``b t c h w``.
        """
        t = self.num_frames
        b = self.batch_size

        def unfold_frames(feat):
            return rearrange(feat, '(b t) c h w -> b t c h w', t=t)

        in_dtype = h.dtype
        hs = []
        hs_a = []
        h = h.type(self.dtype)
        for idx, module in enumerate(self.input_blocks):
            h = module(h, emb, context=context, batch_size=b)
            if idx == 0 and self.addition_attention:
                h = self.init_attn(h, emb, context=context, batch_size=b)
            # Before each downsample, also export the last feature map of
            # the resolution that is about to be left.
            if idx != 0 and isinstance(module[0], Downsample):
                hs_a.append(unfold_frames(hs[-1]))
            hs.append(h)
        hs_a.append(unfold_frames(h))

        h = self.middle_block(h, emb, context=context, batch_size=b)
        hs_a.append(unfold_frames(h))

        hs_out = []
        for module in self.output_blocks:
            h = torch.cat([h, hs.pop()], dim=1)
            h = module(h, emb, context=context, batch_size=b)
            if isinstance(module[-1], Upsample):
                hs_a.append(unfold_frames(hs_out[-1]))
            hs_out.append(h)
        hs_a.append(unfold_frames(hs_out[-1]))

        # Cast back to the caller's dtype before the output head, as the
        # PyTorch path does with `h.type(x.dtype)`; the previous
        # `h.type(h.dtype)` was a no-op.
        y = self.out(h.type(in_dtype))
        y = rearrange(y, '(b t) c h w -> b c t h w', b=b)
        return (y, *hs_a)
|
||||||
|
|
||||||
|
|
||||||
|
def export_backbone_onnx(model,
                         save_path,
                         context_len=95,
                         *,
                         latent_height=40,
                         latent_width=64,
                         ctx_dim=1024):
    """Export the video UNet backbone of *model* to an ONNX file.

    Args:
        model: object exposing the backbone as ``model.model.diffusion_model``.
        save_path: destination path of the ONNX file.
        context_len: sequence length of the dummy context used for tracing.
        latent_height: height of the dummy latent input.
        latent_width: width of the dummy latent input.
        ctx_dim: feature width of the dummy context.

    Returns:
        int: number of graph outputs (``y`` plus the ``hs_a_*`` features).

    Side effects: the backbone submodules are switched to eval mode and
    their ``checkpoint`` / ``use_checkpoint`` flags are set to ``False``;
    neither is undone afterwards.
    """
    wma = model.model.diffusion_model
    wrapper = VideoBackboneForExport(wma)
    device = next(wma.parameters()).device
    wrapper.eval().to(device)

    for m in wrapper.modules():
        if hasattr(m, 'checkpoint'):
            m.checkpoint = False
        if hasattr(m, 'use_checkpoint'):
            m.use_checkpoint = False

    import xformers.ops
    _orig_mea = xformers.ops.memory_efficient_attention

    def _sdpa_replacement(q, k, v, attn_bias=None, op=None, **kw):
        return torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_bias)

    # Swap the xformers attention entry point for torch SDPA while tracing
    # (presumably because the xformers op cannot be exported -- confirm).
    # The finally-block guarantees the global is restored even if tracing
    # or export raises.
    xformers.ops.memory_efficient_attention = _sdpa_replacement
    try:
        # VideoBackboneForExport reshapes with 16 frames and batch size 1,
        # so the folded leading dimension of every input is 16.
        BT = 16
        emb_dim = wma.model_channels * 4
        in_ch = wma.in_channels

        dummy_h = torch.randn(BT,
                              in_ch,
                              latent_height,
                              latent_width,
                              device=device,
                              dtype=torch.float32)
        dummy_emb = torch.randn(BT,
                                emb_dim,
                                device=device,
                                dtype=torch.float32)
        dummy_ctx = torch.randn(BT,
                                context_len,
                                ctx_dim,
                                device=device,
                                dtype=torch.float32)

        # Dry run to learn how many outputs the graph has and to log them.
        with torch.no_grad():
            outputs = wrapper(dummy_h, dummy_emb, dummy_ctx)
        n_outputs = len(outputs)
        print(f">>> Backbone has {n_outputs} outputs (1 y + {n_outputs-1} hs_a)")
        for i, o in enumerate(outputs):
            print(f" output[{i}]: {o.shape} {o.dtype}")

        output_names = ['y'] + [f'hs_a_{i}' for i in range(n_outputs - 1)]

        torch.onnx.export(
            wrapper,
            (dummy_h, dummy_emb, dummy_ctx),
            save_path,
            input_names=['h', 'emb', 'context'],
            output_names=output_names,
            opset_version=17,
            do_constant_folding=True,
        )
        print(f">>> ONNX exported to {save_path}")
    finally:
        xformers.ops.memory_efficient_attention = _orig_mea
    return n_outputs
|
||||||
|
|
||||||
|
|
||||||
|
class TRTBackbone:
    """TensorRT runtime wrapper for the video UNet backbone.

    Loads a serialized engine, pre-allocates one output buffer per engine
    output, and runs the engine on the caller's current CUDA stream.

    Args:
        engine_path: path of the serialized TensorRT engine.
        n_hs_a: number of ``hs_a_{i}`` outputs returned by ``__call__``.
        device: CUDA device to run on (see ``_normalize_cuda_device``).
        copy_outputs: when true (default) ``__call__`` returns copies of
            the output buffers. When false it returns the buffers
            themselves, which the next call overwrites in place.

    Raises:
        RuntimeError: if the engine or its execution context cannot be
            created.
        ValueError: if the engine lacks ``y`` or one of the requested
            ``hs_a_{i}`` outputs.
    """

    def __init__(self, engine_path, n_hs_a=9, device=None, copy_outputs=True):
        import numpy as np
        import tensorrt as trt

        self.device = _normalize_cuda_device(device)
        self.logger = trt.Logger(trt.Logger.WARNING)
        self.n_hs_a = n_hs_a
        self.copy_outputs = copy_outputs

        def torch_dtype_of(name):
            # Map the engine's tensor dtype to torch via numpy.
            np_dtype = trt.nptype(self.engine.get_tensor_dtype(name))
            return torch.from_numpy(np.empty(0, dtype=np_dtype)).dtype

        with torch.cuda.device(self.device):
            with open(engine_path, 'rb') as f:
                serialized = f.read()
            # Kept on the instance so the runtime lives as long as the
            # engine (harmless if the bindings already ensure this).
            self.runtime = trt.Runtime(self.logger)
            self.engine = self.runtime.deserialize_cuda_engine(serialized)
            if self.engine is None:
                raise RuntimeError(
                    f"Failed to deserialize TensorRT engine from {engine_path}.")
            self.context = self.engine.create_execution_context()
            if self.context is None:
                raise RuntimeError(
                    f"Failed to create TensorRT execution context for {engine_path}.")

            self.output_buffers = {}
            for i in range(self.engine.num_io_tensors):
                name = self.engine.get_tensor_name(i)
                if self.engine.get_tensor_mode(name) == trt.TensorIOMode.OUTPUT:
                    shape = self.engine.get_tensor_shape(name)
                    buf = torch.empty(list(shape),
                                      dtype=torch_dtype_of(name),
                                      device=self.device)
                    self.output_buffers[name] = buf
                    print(
                        f" TRT output '{name}': {list(shape)} {buf.dtype} on {self.device}"
                    )

            # Resolved once here instead of on every call.
            self._input_dtypes = {
                name: torch_dtype_of(name) for name in ('h', 'emb', 'context')
            }

        expected_outputs = ['y'] + [f'hs_a_{i}' for i in range(n_hs_a)]
        missing = [n for n in expected_outputs if n not in self.output_buffers]
        if missing:
            raise ValueError(
                f"TensorRT engine {engine_path} has no output(s) {missing}; "
                f"check n_hs_a={n_hs_a}.")

    def __call__(self, h, emb, context):
        """Run the engine once.

        Args:
            h, emb, context: input tensors; each is moved to the engine's
                device and dtype if needed and made contiguous.

        Returns:
            tuple: ``(y, hs_a)`` where ``hs_a`` is a list of ``n_hs_a``
            tensors.

        Raises:
            RuntimeError: if TensorRT reports that enqueueing failed.
        """
        bound_inputs = {}
        with torch.cuda.device(self.device):
            for name, tensor in (('h', h), ('emb', emb), ('context', context)):
                expected = self._input_dtypes[name]
                if tensor.device != self.device or tensor.dtype != expected:
                    tensor = tensor.to(device=self.device,
                                       dtype=expected,
                                       non_blocking=True)
                tensor = tensor.contiguous()
                # Hold a reference so the converted tensor is not released
                # before the engine has been enqueued.
                bound_inputs[name] = tensor
                self.context.set_tensor_address(name, tensor.data_ptr())

            for name, buf in self.output_buffers.items():
                self.context.set_tensor_address(name, buf.data_ptr())

            stream = torch.cuda.current_stream(device=self.device)
            if not self.context.execute_async_v3(stream.cuda_stream):
                raise RuntimeError("TensorRT execute_async_v3 failed.")

            y = self.output_buffers['y']
            hs_a = [self.output_buffers[f'hs_a_{i}'] for i in range(self.n_hs_a)]
            if self.copy_outputs:
                # The buffers are reused by the next call; without a copy,
                # results of two consecutive calls (e.g. conditional and
                # unconditional passes) would share the same memory.
                y = y.clone()
                hs_a = [feat.clone() for feat in hs_a]
        return y, hs_a
|
||||||
@@ -2,11 +2,11 @@ res_dir="unitree_z1_dual_arm_stackbox_v2/case2"
|
|||||||
dataset="unitree_z1_dual_arm_stackbox_v2"
|
dataset="unitree_z1_dual_arm_stackbox_v2"
|
||||||
|
|
||||||
{
|
{
|
||||||
time CUDA_VISIBLE_DEVICES=0 "${PYTHON_BIN:-python}" scripts/evaluation/world_model_interaction.py \
|
time CUDA_VISIBLE_DEVICES=0,1 python3 scripts/evaluation/world_model_interaction.py \
|
||||||
--seed 123 \
|
--seed 123 \
|
||||||
--ckpt_path ckpts/unifolm_wma_dual.ckpt \
|
--ckpt_path ckpts/unifolm_wma_dual.ckpt \
|
||||||
--config configs/inference/world_model_interaction.yaml \
|
--config configs/inference/world_model_interaction.yaml \
|
||||||
--savedir "${res_dir}/output/sparse_8" \
|
--savedir "${res_dir}/output" \
|
||||||
--bs 1 --height 320 --width 512 \
|
--bs 1 --height 320 --width 512 \
|
||||||
--unconditional_guidance_scale 1.0 \
|
--unconditional_guidance_scale 1.0 \
|
||||||
--ddim_steps 50 \
|
--ddim_steps 50 \
|
||||||
@@ -21,9 +21,7 @@ dataset="unitree_z1_dual_arm_stackbox_v2"
|
|||||||
--timestep_spacing 'uniform_trailing' \
|
--timestep_spacing 'uniform_trailing' \
|
||||||
--guidance_rescale 0.7 \
|
--guidance_rescale 0.7 \
|
||||||
--perframe_ae \
|
--perframe_ae \
|
||||||
--analysis_log_metrics \
|
--pipeline_split_step 30 \
|
||||||
--analysis_reference_steps 50 \
|
--pipeline_multi_gpu \
|
||||||
--head_schedule_steps 0 7 14 21 28 35 42 49 \
|
--pipeline_gpu_ids 0 1
|
||||||
--head_skip_mode reuse_prediction \
|
|
||||||
--head_log_steps 40 43 46 47 48 49
|
|
||||||
} 2>&1 | tee "${res_dir}/output.log"
|
} 2>&1 | tee "${res_dir}/output.log"
|
||||||
|
|||||||
Reference in New Issue
Block a user