Compare commits: 57ba85d147 ... qhy-merged

2 Commits

| Author | SHA1 | Date |
|---|---|---|
| | 3069666a15 | |
| | 68369cc15f | |
.gitignore (vendored), 3 lines changed
@@ -129,5 +129,4 @@ Experiment/checkpoint
 Experiment/log
 *.ckpt
-*.0
-unitree_z1_dual_arm_cleanup_pencils/case1/profile_output/traces/wx-ms-w7900d-0032_742306.1770698186047591119.pt.trace.json
+*.0
@@ -222,7 +222,7 @@ data:
   test:
     target: unifolm_wma.data.wma_data.WMAData
     params:
-      data_dir: '/mnt/ASC1637/unifolm-world-model-action/examples/world_model_interaction_prompts'
+      data_dir: '/home/qhy/unifolm-world-model-action/examples/world_model_interaction_prompts'
      video_length: ${model.params.wma_config.params.temporal_length}
      frame_stride: 2
      load_raw_resolution: True
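The `${...}` value in the hunk above is OmegaConf interpolation: `video_length` is resolved from the model section at access time, so the two settings cannot drift apart. A minimal sketch of how such a reference resolves (simplified keys, not the project's actual config):

```python
from omegaconf import OmegaConf

cfg = OmegaConf.create({
    "model": {"temporal_length": 16},
    "data": {"video_length": "${model.temporal_length}"},
})
print(cfg.data.video_length)  # -> 16, resolved when the key is accessed
```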
@@ -1,5 +1,7 @@
 import argparse, os, glob
 from contextlib import nullcontext
+import atexit
+from concurrent.futures import ThreadPoolExecutor
 import pandas as pd
 import random
 import torch

@@ -11,13 +13,15 @@ import einops
 import warnings
 import imageio

+from typing import Optional, List, Any
+
 from pytorch_lightning import seed_everything
 from omegaconf import OmegaConf
 from tqdm import tqdm
 from einops import rearrange, repeat
 from collections import OrderedDict
 from torch import nn
-from eval_utils import populate_queues, log_to_tensorboard
+from eval_utils import populate_queues
 from collections import deque
 from torch import Tensor
 from torch.utils.tensorboard import SummaryWriter
@@ -28,6 +32,80 @@ from unifolm_wma.utils.utils import instantiate_from_config
 import torch.nn.functional as F


+# ========== Async I/O utilities ==========
+_io_executor: Optional[ThreadPoolExecutor] = None
+_io_futures: List[Any] = []
+
+
+def _get_io_executor() -> ThreadPoolExecutor:
+    global _io_executor
+    if _io_executor is None:
+        _io_executor = ThreadPoolExecutor(max_workers=2)
+    return _io_executor
+
+
+def _flush_io():
+    """Wait for all pending async I/O to finish."""
+    global _io_futures
+    for fut in _io_futures:
+        try:
+            fut.result()
+        except Exception as e:
+            print(f">>> [async I/O] error: {e}")
+    _io_futures.clear()
+
+
+atexit.register(_flush_io)
+
+
+def _save_results_sync(video_cpu: Tensor, filename: str, fps: int) -> None:
+    """Synchronous save on CPU tensor (runs in background thread)."""
+    video = torch.clamp(video_cpu.float(), -1., 1.)
+    n = video.shape[0]
+    video = video.permute(2, 0, 1, 3, 4)
+    frame_grids = [
+        torchvision.utils.make_grid(framesheet, nrow=int(n), padding=0)
+        for framesheet in video
+    ]
+    grid = torch.stack(frame_grids, dim=0)
+    grid = (grid + 1.0) / 2.0
+    grid = (grid * 255).to(torch.uint8).permute(0, 2, 3, 1)
+    torchvision.io.write_video(filename,
+                               grid,
+                               fps=fps,
+                               video_codec='h264',
+                               options={'crf': '10'})
+
+
+def save_results_async(video: Tensor, filename: str, fps: int = 8) -> None:
+    """Submit video saving to background thread pool."""
+    video_cpu = video.detach().cpu()
+    fut = _get_io_executor().submit(_save_results_sync, video_cpu, filename, fps)
+    _io_futures.append(fut)
+
+
+def _log_to_tb_sync(video_cpu: Tensor, writer: SummaryWriter, tag: str, fps: int) -> None:
+    """Synchronous tensorboard logging on CPU tensor (runs in background thread)."""
+    video = video_cpu.float()
+    n = video.shape[0]
+    video = video.permute(2, 0, 1, 3, 4)
+    frame_grids = [
+        torchvision.utils.make_grid(framesheet, nrow=int(n), padding=0)
+        for framesheet in video
+    ]
+    grid = torch.stack(frame_grids, dim=0)
+    grid = (grid + 1.0) / 2.0
+    grid = grid.unsqueeze(dim=0)
+    writer.add_video(tag, grid, fps=fps)
+
+
+def log_to_tensorboard_async(writer: SummaryWriter, video: Tensor, tag: str, fps: int = 10) -> None:
+    """Submit tensorboard logging to background thread pool."""
+    video_cpu = video.detach().cpu()
+    fut = _get_io_executor().submit(_log_to_tb_sync, video_cpu, writer, tag, fps)
+    _io_futures.append(fut)
+
+
+def patch_norm_bypass_autocast():
+    """Monkey-patch GroupNorm and LayerNorm to bypass autocast's fp32 policy.
+    This eliminates bf16->fp32->bf16 dtype conversions during UNet forward."""
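The added block defers video encoding and tensorboard writes to a two-worker thread pool so the GPU loop never blocks on disk. The load-bearing details are snapshotting the tensor with detach().cpu() before submission and draining futures at exit. A minimal self-contained sketch of the same pattern, with a hypothetical slow_write() standing in for torchvision.io.write_video:

```python
import atexit
import time
from concurrent.futures import ThreadPoolExecutor

_pool = ThreadPoolExecutor(max_workers=2)
_pending = []

def slow_write(path: str, payload: bytes) -> None:
    time.sleep(0.1)  # stands in for H.264 encoding
    with open(path, "wb") as f:
        f.write(payload)

def write_async(path: str, payload: bytes) -> None:
    # Copy the payload before submitting, mirroring detach().cpu() above:
    # the worker must never read buffers the main loop still mutates.
    _pending.append(_pool.submit(slow_write, path, bytes(payload)))

def flush() -> None:
    for fut in _pending:
        fut.result()  # re-raises any exception from the worker thread
    _pending.clear()

atexit.register(flush)
```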
@@ -185,17 +263,18 @@ def get_filelist(data_dir: str, postfixes: list[str]) -> list[str]:
     return file_list


-def load_model_checkpoint(model: nn.Module, ckpt: str) -> nn.Module:
+def load_model_checkpoint(model: nn.Module, ckpt: str, device: str = "cpu") -> nn.Module:
     """Load model weights from checkpoint file.

     Args:
         model (nn.Module): Model instance.
         ckpt (str): Path to the checkpoint file.
+        device (str): Target device for loaded tensors.

     Returns:
         nn.Module: Model with loaded weights.
     """
-    state_dict = torch.load(ckpt, map_location="cpu")
+    state_dict = torch.load(ckpt, map_location=device)
     if "state_dict" in list(state_dict.keys()):
         state_dict = state_dict["state_dict"]
     try:
@@ -610,42 +689,63 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:

     # Load config
     config = OmegaConf.load(args.config)
-    config['model']['params']['wma_config']['params'][
-        'use_checkpoint'] = False
-    model = instantiate_from_config(config.model)
-    model.perframe_ae = args.perframe_ae
-    assert os.path.exists(args.ckpt_path), "Error: checkpoint Not Found!"
-    model = load_model_checkpoint(model, args.ckpt_path)
-    model.eval()
-    print(f'>>> Load pre-trained model ...')
-
-    # Apply precision settings before moving to GPU
-    model = apply_precision_settings(model, args)
+    prepared_path = args.ckpt_path + ".prepared.pt"
+    if os.path.exists(prepared_path):
+        # ---- Fast path: load the fully-prepared model ----
+        print(f">>> Loading prepared model from {prepared_path} ...")
+        model = torch.load(prepared_path,
+                           map_location=f"cuda:{gpu_no}",
+                           weights_only=False)
+        model.eval()
+
+        # Compile hot ResBlocks for operator fusion
+        apply_torch_compile(model)
+        # Restore autocast attributes (weights already cast, just need contexts)
+        model.diffusion_autocast_dtype = torch.bfloat16 if args.diffusion_dtype == "bf16" else torch.bfloat16
+        model.projector_autocast_dtype = torch.bfloat16 if args.projector_mode == "autocast" else None
+        model.encoder_autocast_dtype = torch.bfloat16 if args.encoder_mode == "autocast" else None
+
+        # Fuse KV projections in attention layers (to_k + to_v → to_kv)
+        from unifolm_wma.modules.attention import CrossAttention
+        kv_count = sum(1 for m in model.modules()
+                       if isinstance(m, CrossAttention) and m.fuse_kv())
+        print(f" ✓ KV fused: {kv_count} attention layers")
+        # Compile hot ResBlocks for operator fusion
+        apply_torch_compile(model)
+
+        # Export precision-converted checkpoint if requested
+        if args.export_precision_ckpt:
+            export_path = args.export_precision_ckpt
+            os.makedirs(os.path.dirname(export_path) or '.', exist_ok=True)
+            torch.save({"state_dict": model.state_dict()}, export_path)
+            print(f">>> Precision-converted checkpoint saved to: {export_path}")
+            return
+        print(f">>> Prepared model loaded.")
+    else:
+        # ---- Normal path: construct + checkpoint + casting ----
+        config['model']['params']['wma_config']['params'][
+            'use_checkpoint'] = False
+        model = instantiate_from_config(config.model)
+        model.perframe_ae = args.perframe_ae
+        assert os.path.exists(args.ckpt_path), "Error: checkpoint Not Found!"
+        model = load_model_checkpoint(model, args.ckpt_path,
+                                      device=f"cuda:{gpu_no}")
+        model.eval()
+        print(f'>>> Load pre-trained model ...')
+
+        # Apply precision settings before moving to GPU
+        model = apply_precision_settings(model, args)
+
+        # Export precision-converted checkpoint if requested
+        if args.export_precision_ckpt:
+            export_path = args.export_precision_ckpt
+            os.makedirs(os.path.dirname(export_path) or '.', exist_ok=True)
+            torch.save({"state_dict": model.state_dict()}, export_path)
+            print(f">>> Precision-converted checkpoint saved to: {export_path}")
+            return
+
+        model = model.cuda(gpu_no)
+
+        # Save prepared model for fast loading next time (before torch.compile)
+        print(f">>> Saving prepared model to {prepared_path} ...")
+        torch.save(model, prepared_path)
+        print(f">>> Prepared model saved ({os.path.getsize(prepared_path) / 1024**3:.1f} GB).")
+
+        # Compile hot ResBlocks for operator fusion (after save, compiled objects can't be pickled)
+        apply_torch_compile(model)

-    # Build unnomalizer
+    # Build normalizer (always needed, independent of model loading path)
     logging.info("***** Configing Data *****")
     data = instantiate_from_config(config.data)
     data.setup()
     print(">>> Dataset is successfully loaded ...")

-    model = model.cuda(gpu_no)
     device = get_device_from_parameters(model)

     # Run over data
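The fast path above pickles the entire prepared nn.Module so later runs skip construction, checkpoint loading, and weight casting; the comment in the normal path notes that saving must happen before torch.compile, since compiled wrappers can't be pickled. A reduced sketch of the idea, with a hypothetical TinyNet standing in for the WMA model:

```python
import os
import torch
import torch.nn as nn

class TinyNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(8, 8)

def load_prepared(path: str) -> nn.Module:
    if os.path.exists(path):
        # Whole pickled module, so weights_only must be False;
        # only safe for cache files this process created itself.
        return torch.load(path, map_location="cpu", weights_only=False)
    model = TinyNet().bfloat16()  # stands in for build + precision casting
    torch.save(model, path)       # cache before any torch.compile wrapping
    return model

model = load_prepared("/tmp/tiny.prepared.pt")
```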
@@ -823,28 +923,28 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
                 # Save the imagen videos for decision-making
                 if pred_videos_0 is not None:
                     sample_tag = f"{args.dataset}-vid{sample['videoid']}-dm-fs-{fs}/itr-{itr}"
-                    log_to_tensorboard(writer,
-                                       pred_videos_0,
-                                       sample_tag,
-                                       fps=args.save_fps)
+                    log_to_tensorboard_async(writer,
+                                             pred_videos_0,
+                                             sample_tag,
+                                             fps=args.save_fps)
                 # Save videos environment changes via world-model interaction
                 sample_tag = f"{args.dataset}-vid{sample['videoid']}-wd-fs-{fs}/itr-{itr}"
-                log_to_tensorboard(writer,
-                                   pred_videos_1,
-                                   sample_tag,
-                                   fps=args.save_fps)
+                log_to_tensorboard_async(writer,
+                                         pred_videos_1,
+                                         sample_tag,
+                                         fps=args.save_fps)

                 # Save the imagen videos for decision-making
                 if pred_videos_0 is not None:
                     sample_video_file = f'{video_save_dir}/dm/{fs}/itr-{itr}.mp4'
-                    save_results(pred_videos_0.cpu(),
-                                 sample_video_file,
-                                 fps=args.save_fps)
+                    save_results_async(pred_videos_0,
+                                       sample_video_file,
+                                       fps=args.save_fps)
                 # Save videos environment changes via world-model interaction
                 sample_video_file = f'{video_save_dir}/wm/{fs}/itr-{itr}.mp4'
-                save_results(pred_videos_1.cpu(),
-                             sample_video_file,
-                             fps=args.save_fps)
+                save_results_async(pred_videos_1,
+                                   sample_video_file,
+                                   fps=args.save_fps)

                 print('>' * 24)
             # Collect the result of world-model interactions
@@ -852,12 +952,15 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:

             full_video = torch.cat(wm_video, dim=2)
             sample_tag = f"{args.dataset}-vid{sample['videoid']}-wd-fs-{fs}/full"
-            log_to_tensorboard(writer,
-                               full_video,
-                               sample_tag,
-                               fps=args.save_fps)
+            log_to_tensorboard_async(writer,
+                                     full_video,
+                                     sample_tag,
+                                     fps=args.save_fps)
             sample_full_video_file = f"{video_save_dir}/../{sample['videoid']}_full_fs{fs}.mp4"
-            save_results(full_video, sample_full_video_file, fps=args.save_fps)
+            save_results_async(full_video, sample_full_video_file, fps=args.save_fps)
+
+    # Wait for all async I/O to complete
+    _flush_io()


 def get_parser():
@@ -99,6 +99,8 @@ class AutoencoderKL(pl.LightningModule):
             print(f"Restored from {path}")

     def encode(self, x, **kwargs):
+        if getattr(self, '_channels_last', False):
+            x = x.to(memory_format=torch.channels_last)
         h = self.encoder(x)
         moments = self.quant_conv(h)

@@ -106,6 +108,8 @@ class AutoencoderKL(pl.LightningModule):
         return posterior

     def decode(self, z, **kwargs):
+        if getattr(self, '_channels_last', False):
+            z = z.to(memory_format=torch.channels_last)
         z = self.post_quant_conv(z)
         dec = self.decoder(z)
         return dec
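The `_channels_last` guard converts VAE inputs to NHWC layout so activations match weights that were presumably converted elsewhere in the model setup. A standalone illustration of the mechanics (arbitrary shapes, not the model's):

```python
import torch
import torch.nn as nn

conv = nn.Conv2d(3, 8, kernel_size=3, padding=1)
conv = conv.to(memory_format=torch.channels_last)  # weights reordered to NHWC

x = torch.randn(2, 3, 32, 32).to(memory_format=torch.channels_last)
y = conv(x)
# The logical shape is unchanged; only the stride order differs.
print(y.shape, y.is_contiguous(memory_format=torch.channels_last))
```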
@@ -1074,10 +1074,10 @@ class LatentDiffusion(DDPM):
             encoder_posterior = self.first_stage_model.encode(x)
             results = self.get_first_stage_encoding(encoder_posterior).detach()
         else:  ## Consume less GPU memory but slower
+            bs = getattr(self, 'vae_encode_bs', 1)
             results = []
-            for index in range(x.shape[0]):
-                frame_batch = self.first_stage_model.encode(x[index:index +
-                                                              1, :, :, :])
+            for i in range(0, x.shape[0], bs):
+                frame_batch = self.first_stage_model.encode(x[i:i + bs])
                 frame_result = self.get_first_stage_encoding(
                     frame_batch).detach()
                 results.append(frame_result)

@@ -1109,14 +1109,14 @@ class LatentDiffusion(DDPM):
         vae_dtype = next(self.first_stage_model.parameters()).dtype
         z = z.to(dtype=vae_dtype)

-        z = 1. / self.scale_factor * z
         if not self.perframe_ae:
+            z = 1. / self.scale_factor * z
             results = self.first_stage_model.decode(z, **kwargs)
         else:
+            bs = getattr(self, 'vae_decode_bs', 1)
             results = []
-            for index in range(z.shape[0]):
-                frame_z = 1. / self.scale_factor * z[index:index + 1, :, :, :]
-                frame_result = self.first_stage_model.decode(frame_z, **kwargs)
+            for i in range(0, z.shape[0], bs):
+                frame_result = self.first_stage_model.decode(z[i:i + bs], **kwargs)
                 results.append(frame_result)
             results = torch.cat(results, dim=0)
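Both hunks swap a per-frame loop for a strided chunk loop controlled by `vae_encode_bs` / `vae_decode_bs`. The results are identical, just produced in fewer, larger kernel launches; a check with a stand-in encoder:

```python
import torch

def encode(x: torch.Tensor) -> torch.Tensor:
    return x * 2.0  # stand-in for the VAE encoder

x = torch.randn(10, 4, 8, 8)
bs = 4  # cf. vae_encode_bs

per_frame = torch.cat([encode(x[i:i + 1]) for i in range(x.shape[0])], dim=0)
chunked = torch.cat([encode(x[i:i + bs]) for i in range(0, x.shape[0], bs)], dim=0)
assert torch.equal(per_frame, chunked)
```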
@@ -567,11 +567,6 @@ class ConditionalUnet1D(nn.Module):
         # Broadcast to batch dimension in a way that's compatible with ONNX/Core ML
         timesteps = timesteps.expand(sample.shape[0])
         global_feature = self.diffusion_step_encoder(timesteps)
-        # Pre-expand global_feature once (reused in every down/mid/up block)
-        if self.use_linear_act_proj:
-            global_feature_expanded = global_feature.unsqueeze(1).expand(-1, T, -1)
-        else:
-            global_feature_expanded = global_feature.unsqueeze(1).expand(-1, 2, -1)
         (imagen_cond_down, imagen_cond_mid, imagen_cond_up
          ) = imagen_cond[0:4], imagen_cond[4], imagen_cond[5:]  #NOTE HAND CODE

@@ -608,11 +603,15 @@ class ConditionalUnet1D(nn.Module):

             if self.use_linear_act_proj:
                 imagen_cond = imagen_cond.reshape(B, T, -1)
+                cur_global_feature = global_feature.unsqueeze(
+                    1).repeat_interleave(repeats=T, dim=1)
             else:
                 imagen_cond = imagen_cond.permute(0, 3, 1, 2)
                 imagen_cond = imagen_cond.reshape(B, 2, -1)
+                cur_global_feature = global_feature.unsqueeze(
+                    1).repeat_interleave(repeats=2, dim=1)
             cur_global_feature = torch.cat(
-                [global_feature_expanded, global_cond, imagen_cond], axis=-1)
+                [cur_global_feature, global_cond, imagen_cond], axis=-1)
             x = resnet(x, cur_global_feature)
             x = resnet2(x, cur_global_feature)
             h.append(x)

@@ -639,11 +638,15 @@ class ConditionalUnet1D(nn.Module):
         imagen_cond = rearrange(imagen_cond, '(b t) c d -> b t c d', b=B)
         if self.use_linear_act_proj:
             imagen_cond = imagen_cond.reshape(B, T, -1)
+            cur_global_feature = global_feature.unsqueeze(1).repeat_interleave(
+                repeats=T, dim=1)
         else:
             imagen_cond = imagen_cond.permute(0, 3, 1, 2)
             imagen_cond = imagen_cond.reshape(B, 2, -1)
+            cur_global_feature = global_feature.unsqueeze(1).repeat_interleave(
+                repeats=2, dim=1)
         cur_global_feature = torch.cat(
-            [global_feature_expanded, global_cond, imagen_cond], axis=-1)
+            [cur_global_feature, global_cond, imagen_cond], axis=-1)
         x = resnet(x, cur_global_feature)
         x = resnet2(x, cur_global_feature)

@@ -680,12 +683,16 @@ class ConditionalUnet1D(nn.Module):

             if self.use_linear_act_proj:
                 imagen_cond = imagen_cond.reshape(B, T, -1)
+                cur_global_feature = global_feature.unsqueeze(
+                    1).repeat_interleave(repeats=T, dim=1)
             else:
                 imagen_cond = imagen_cond.permute(0, 3, 1, 2)
                 imagen_cond = imagen_cond.reshape(B, 2, -1)
+                cur_global_feature = global_feature.unsqueeze(
+                    1).repeat_interleave(repeats=2, dim=1)

             cur_global_feature = torch.cat(
-                [global_feature_expanded, global_cond, imagen_cond], axis=-1)
+                [cur_global_feature, global_cond, imagen_cond], axis=-1)

             x = torch.cat((x, h.pop()), dim=1)
             x = resnet(x, cur_global_feature)
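The revert above trades the one-time `expand()` for per-block `repeat_interleave()`. Numerically the two broadcasts agree; the difference is that expand returns a zero-copy view while repeat_interleave materializes memory:

```python
import torch

B, T, D = 2, 5, 7
g = torch.randn(B, D)

via_expand = g.unsqueeze(1).expand(-1, T, -1)                    # view, no copy
via_repeat = g.unsqueeze(1).repeat_interleave(repeats=T, dim=1)  # real allocation
assert torch.equal(via_expand, via_repeat)
```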
@@ -0,0 +1,7 @@
+{
+  "permissions": {
+    "allow": [
+      "Bash(python3:*)"
+    ]
+  }
+}
@@ -251,13 +251,6 @@ class DDIMSampler(object):
         dp_ddim_scheduler_action.set_timesteps(len(timesteps))
         dp_ddim_scheduler_state.set_timesteps(len(timesteps))
         ts = torch.empty((b, ), device=device, dtype=torch.long)
-        noise_buf = torch.empty_like(img)
-        # Pre-convert schedule arrays to inference dtype (avoid per-step .to())
-        _dtype = img.dtype
-        _alphas = (self.model.alphas_cumprod if ddim_use_original_steps else self.ddim_alphas).to(_dtype)
-        _alphas_prev = (self.model.alphas_cumprod_prev if ddim_use_original_steps else self.ddim_alphas_prev).to(_dtype)
-        _sqrt_one_minus = (self.model.sqrt_one_minus_alphas_cumprod if ddim_use_original_steps else self.ddim_sqrt_one_minus_alphas).to(_dtype)
-        _sigmas = (self.ddim_sigmas_for_original_num_steps if ddim_use_original_steps else self.ddim_sigmas).to(_dtype)
         enable_cross_attn_kv_cache(self.model)
         enable_ctx_cache(self.model)
         try:

@@ -293,8 +286,6 @@ class DDIMSampler(object):
                                    x0=x0,
                                    fs=fs,
                                    guidance_rescale=guidance_rescale,
-                                   noise_buf=noise_buf,
-                                   schedule_arrays=(_alphas, _alphas_prev, _sqrt_one_minus, _sigmas),
                                    **kwargs)

                 img, pred_x0, model_output_action, model_output_state = outs

@@ -348,8 +339,6 @@ class DDIMSampler(object):
                       mask=None,
                       x0=None,
                       guidance_rescale=0.0,
-                      noise_buf=None,
-                      schedule_arrays=None,
                       **kwargs):
         b, *_, device = *x.shape, x.device

@@ -395,18 +384,16 @@ class DDIMSampler(object):
             e_t = score_corrector.modify_score(self.model, e_t, x, t, c,
                                                **corrector_kwargs)

-        if schedule_arrays is not None:
-            alphas, alphas_prev, sqrt_one_minus_alphas, sigmas = schedule_arrays
-        else:
-            alphas = (self.model.alphas_cumprod if use_original_steps else self.ddim_alphas).to(x.dtype)
-            alphas_prev = (self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev).to(x.dtype)
-            sqrt_one_minus_alphas = (self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas).to(x.dtype)
-            sigmas = (self.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas).to(x.dtype)
+        alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
+        alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
+        sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
+        sigmas = self.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas

-        a_t = alphas[index]
-        a_prev = alphas_prev[index]
-        sigma_t = sigmas[index]
-        sqrt_one_minus_at = sqrt_one_minus_alphas[index]
+        # Use 0-d tensors directly (already on device); broadcasting handles shape
+        a_t = alphas[index].to(x.dtype)
+        a_prev = alphas_prev[index].to(x.dtype)
+        sigma_t = sigmas[index].to(x.dtype)
+        sqrt_one_minus_at = sqrt_one_minus_alphas[index].to(x.dtype)

         if self.model.parameterization != "v":
             pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()

@@ -424,12 +411,8 @@ class DDIMSampler(object):

         dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t

-        if noise_buf is not None:
-            noise_buf.normal_()
-            noise = sigma_t * noise_buf * temperature
-        else:
-            noise = sigma_t * noise_like(x.shape, device,
-                                         repeat_noise) * temperature
+        noise = sigma_t * noise_like(x.shape, device,
+                                     repeat_noise) * temperature
         if noise_dropout > 0.:
             noise = torch.nn.functional.dropout(noise, p=noise_dropout)
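The retained coefficient code relies on 0-d tensor broadcasting: a scalar tensor indexed out of the schedule multiplies the whole latent without any `torch.full()` materialization. A small demonstration with a made-up schedule:

```python
import torch

alphas = torch.linspace(0.9, 0.1, 50)  # stand-in for alphas_cumprod
x = torch.randn(1, 4, 16, 40, 64)

a_t = alphas[10].to(x.dtype)  # 0-d tensor; stays on x's device in practice
scaled = a_t.sqrt() * x       # broadcasts over every element of x
print(scaled.shape)           # torch.Size([1, 4, 16, 40, 64])
```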
@@ -99,7 +99,6 @@ class CrossAttention(nn.Module):
         self.agent_action_context_len = agent_action_context_len
         self._kv_cache = {}
         self._kv_cache_enabled = False
-        self._kv_fused = False

         self.cross_attention_scale_learnable = cross_attention_scale_learnable
         if self.image_cross_attention:

@@ -117,27 +116,6 @@ class CrossAttention(nn.Module):
             self.register_parameter('alpha_caa',
                                     nn.Parameter(torch.tensor(0.)))

-    def fuse_kv(self):
-        """Fuse to_k/to_v into to_kv (2 Linear → 1). Works for all layers."""
-        k_w = self.to_k.weight  # (inner_dim, context_dim)
-        v_w = self.to_v.weight
-        self.to_kv = nn.Linear(k_w.shape[1], k_w.shape[0] * 2, bias=False)
-        self.to_kv.weight = nn.Parameter(torch.cat([k_w, v_w], dim=0))
-        del self.to_k, self.to_v
-        if self.image_cross_attention:
-            for suffix in ('_ip', '_as', '_aa'):
-                k_attr = f'to_k{suffix}'
-                v_attr = f'to_v{suffix}'
-                kw = getattr(self, k_attr).weight
-                vw = getattr(self, v_attr).weight
-                fused = nn.Linear(kw.shape[1], kw.shape[0] * 2, bias=False)
-                fused.weight = nn.Parameter(torch.cat([kw, vw], dim=0))
-                setattr(self, f'to_kv{suffix}', fused)
-                delattr(self, k_attr)
-                delattr(self, v_attr)
-        self._kv_fused = True
-        return True
-
     def forward(self, x, context=None, mask=None):
         spatial_self_attn = (context is None)
         k_ip, v_ip, out_ip = None, None, None

@@ -298,20 +276,14 @@ class CrossAttention(nn.Module):
                                              self.agent_action_context_len +
                                              self.text_context_len:, :]

-            if self._kv_fused:
-                k, v = self.to_kv(context_ins).chunk(2, dim=-1)
-                k_ip, v_ip = self.to_kv_ip(context_image).chunk(2, dim=-1)
-                k_as, v_as = self.to_kv_as(context_agent_state).chunk(2, dim=-1)
-                k_aa, v_aa = self.to_kv_aa(context_agent_action).chunk(2, dim=-1)
-            else:
-                k = self.to_k(context_ins)
-                v = self.to_v(context_ins)
-                k_ip = self.to_k_ip(context_image)
-                v_ip = self.to_v_ip(context_image)
-                k_as = self.to_k_as(context_agent_state)
-                v_as = self.to_v_as(context_agent_state)
-                k_aa = self.to_k_aa(context_agent_action)
-                v_aa = self.to_v_aa(context_agent_action)
+            k = self.to_k(context_ins)
+            v = self.to_v(context_ins)
+            k_ip = self.to_k_ip(context_image)
+            v_ip = self.to_v_ip(context_image)
+            k_as = self.to_k_as(context_agent_state)
+            v_as = self.to_v_as(context_agent_state)
+            k_aa = self.to_k_aa(context_agent_action)
+            v_aa = self.to_v_aa(context_agent_action)

             q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h),
                           (q, k, v))

@@ -332,11 +304,8 @@ class CrossAttention(nn.Module):
         else:
             if not spatial_self_attn:
                 context = context[:, :self.text_context_len, :]
-            if self._kv_fused:
-                k, v = self.to_kv(context).chunk(2, dim=-1)
-            else:
-                k = self.to_k(context)
-                v = self.to_v(context)
+            k = self.to_k(context)
+            v = self.to_v(context)

             q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h),
                           (q, k, v))
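The deleted `fuse_kv()` relied on a standard identity: stacking the K and V weight matrices row-wise turns two matmuls into one, and `chunk(2, dim=-1)` recovers the halves. A standalone check of that equivalence:

```python
import torch
import torch.nn as nn

d_model, d_inner = 16, 32
to_k = nn.Linear(d_model, d_inner, bias=False)
to_v = nn.Linear(d_model, d_inner, bias=False)

# Fused projection: concatenated weights along the output dimension.
to_kv = nn.Linear(d_model, d_inner * 2, bias=False)
to_kv.weight = nn.Parameter(torch.cat([to_k.weight, to_v.weight], dim=0))

ctx = torch.randn(2, 10, d_model)
k, v = to_kv(ctx).chunk(2, dim=-1)
assert torch.allclose(k, to_k(ctx), atol=1e-6)
assert torch.allclose(v, to_v(ctx), atol=1e-6)
```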
@@ -11,7 +11,7 @@ from unifolm_wma.utils.utils import instantiate_from_config

 def nonlinearity(x):
     # swish
-    return x * torch.sigmoid(x)
+    return torch.nn.functional.silu(x)


 def Normalize(in_channels, num_groups=32):
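`x * torch.sigmoid(x)` and `F.silu(x)` compute the same swish function; silu dispatches to one fused kernel instead of two elementwise ops. A quick equivalence check:

```python
import torch
import torch.nn.functional as F

x = torch.randn(1024)
assert torch.allclose(x * torch.sigmoid(x), F.silu(x), atol=1e-6)
```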
@@ -690,8 +690,6 @@ class WMAModel(nn.Module):
         self._ctx_cache = {}
         # fs_embed cache
         self._fs_embed_cache = None
-        # Pre-created CUDA stream for parallel action/state UNet
-        self._side_stream = torch.cuda.Stream() if not self.base_model_gen_only else None

     def forward(self,
                 x: Tensor,

@@ -850,16 +848,15 @@ class WMAModel(nn.Module):

         if not self.base_model_gen_only:
             ba, _, _ = x_action.shape
-            ts_state = timesteps[:ba] if b > 1 else timesteps
-            # Run action_unet and state_unet in parallel via pre-created CUDA stream
-            s_stream = self._side_stream
-            s_stream.wait_stream(torch.cuda.current_stream())
-            with torch.cuda.stream(s_stream):
-                s_y = self.state_unet(x_state, ts_state, hs_a,
-                                      context_action[:2], **kwargs)
             a_y = self.action_unet(x_action, timesteps[:ba], hs_a,
                                    context_action[:2], **kwargs)
-            torch.cuda.current_stream().wait_stream(s_stream)
+            # Predict state
+            if b > 1:
+                s_y = self.state_unet(x_state, timesteps[:ba], hs_a,
+                                      context_action[:2], **kwargs)
+            else:
+                s_y = self.state_unet(x_state, timesteps, hs_a,
+                                      context_action[:2], **kwargs)
         else:
             a_y = torch.zeros_like(x_action)
             s_y = torch.zeros_like(x_state)
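The removed lines used the usual two-stream overlap recipe: the side stream waits on the default stream before launching, and the default stream waits on the side stream before consuming its output. In general form (a sketch of the pattern, not the model's exact code):

```python
import torch

def run_overlapped(fn_a, fn_b):
    """Run fn_b on a side CUDA stream while fn_a runs on the default stream."""
    if not torch.cuda.is_available():
        return fn_a(), fn_b()  # fallback: plain sequential execution
    side = torch.cuda.Stream()
    side.wait_stream(torch.cuda.current_stream())  # see prior writes
    with torch.cuda.stream(side):
        out_b = fn_b()  # launches concurrently with fn_a below
    out_a = fn_a()
    torch.cuda.current_stream().wait_stream(side)  # join before using out_b
    return out_a, out_b
```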
@@ -1,14 +1,14 @@
 /mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/lightning_fabric/__init__.py:29: UserWarning: pkg_resources is deprecated as an API. See https://setuptools.pypa.io/en/latest/pkg_resources.html. The pkg_resources package is slated for removal as early as 2025-11-30. Refrain from using this package or pin to Setuptools<81.
 __import__("pkg_resources").declare_namespace(__name__)
-2026-02-10 17:57:48.047156: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
-2026-02-10 17:57:48.050303: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
-2026-02-10 17:57:48.081710: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
-2026-02-10 17:57:48.081741: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
-2026-02-10 17:57:48.083577: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
-2026-02-10 17:57:48.091772: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
-2026-02-10 17:57:48.092045: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
+2026-02-09 18:39:50.119842: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
+2026-02-09 18:39:50.123128: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
+2026-02-09 18:39:50.156652: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
+2026-02-09 18:39:50.156708: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
+2026-02-09 18:39:50.158926: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
+2026-02-09 18:39:50.167779: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
+2026-02-09 18:39:50.168073: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
 To enable the following instructions: AVX2 AVX512F AVX512_VNNI AVX512_BF16 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
-2026-02-10 17:57:48.787960: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
+2026-02-09 18:39:50.915144: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
 [rank: 0] Global seed set to 123
 /mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/kornia/feature/lightglue.py:44: FutureWarning: `torch.cuda.amp.custom_fwd(args...)` is deprecated. Please use `torch.amp.custom_fwd(args..., device_type='cuda')` instead.
 @torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
@@ -41,7 +41,6 @@ INFO:root:Loading pretrained ViT-H-14 weights (laion2b_s32b_b79k).
 ⚠ Found 601 fp32 params, converting to bf16
 ✓ All parameters converted to bfloat16
 ✓ torch.compile: 3 ResBlocks in output_blocks[5, 8, 9]
-✓ KV fused: 66 attention layers
 INFO:root:***** Configing Data *****
 >>> unitree_z1_stackbox: 1 data samples loaded.
 >>> unitree_z1_stackbox: data stats loaded.
@@ -117,7 +116,7 @@ DEBUG:PIL.Image:Importing WmfImagePlugin
 DEBUG:PIL.Image:Importing WmfImagePlugin
 DEBUG:PIL.Image:Importing XbmImagePlugin
 DEBUG:PIL.Image:Importing XpmImagePlugin
 DEBUG:PIL.Image:Importing XVThumbImagePlugin
 DEBUG:PIL.Image:Importing XVThumbImagePlugin

 12%|█▎        | 1/8 [01:08<07:58, 68.38s/it]
 25%|██▌       | 2/8 [02:13<06:38, 66.48s/it]

@@ -141,6 +140,6 @@ DEBUG:PIL.Image:Importing XVThumbImagePlugin
->>> Step 4: generating actions ...
->>> Step 4: interacting with world model ...
->>>>>>>>>>>>>>>>>>>>>>>>
 >>> Step 5: generating actions ...
 >>> Step 5: interacting with world model ...
 >>>>>>>>>>>>>>>>>>>>>>>>
+>>> Step 5: generating actions ...
+>>> Step 5: interacting with world model ...
+>>>>>>>>>>>>>>>>>>>>>>>>
||||
@@ -1,5 +0,0 @@
|
||||
itr,stack_to_device_1,policy/ddim_sampler_init,policy/image_embedding,policy/vae_encode,policy/text_conditioning,policy/projectors,policy/cond_assembly,policy/ddim_sampling,policy/vae_decode,synth_policy,update_action_queue,stack_to_device_2,wm/ddim_sampler_init,wm/image_embedding,wm/vae_encode,wm/text_conditioning,wm/projectors,wm/cond_assembly,wm/ddim_sampling,wm/vae_decode,synth_world_model,update_obs_queue,tensorboard_log,save_results,cpu_transfer,itr_total
|
||||
0,0.16,0.08,20.98,49.56,14.51,0.29,0.07,31005.48,0.00,31094.51,0.39,0.13,0.09,20.62,48.76,14.17,0.28,0.07,31011.17,775.40,31875.87,0.61,0.31,97.28,7.19,63077.50
|
||||
1,0.16,0.09,20.97,49.63,14.52,0.30,0.07,31035.49,0.00,31125.16,0.54,0.17,0.14,21.46,49.26,14.88,0.49,0.12,31047.54,777.56,31918.60,0.75,0.60,109.89,6.21,63163.18
|
||||
2,0.18,0.10,21.44,49.71,15.05,0.34,0.07,31047.64,0.00,31138.56,0.58,0.16,0.13,21.03,48.74,14.69,0.32,0.08,31036.47,776.96,31905.96,0.67,0.39,116.96,7.43,63171.90
|
||||
3,0.18,0.10,21.38,49.47,15.02,0.35,0.08,31041.05,0.00,31132.03,0.48,0.16,0.12,20.81,49.34,14.41,0.47,0.11,31051.98,777.11,31920.42,0.64,0.38,121.67,7.29,63184.26
|
||||
|
@@ -1,5 +0,0 @@
|
||||
stat,stack_to_device_1,policy/ddim_sampler_init,policy/image_embedding,policy/vae_encode,policy/text_conditioning,policy/projectors,policy/cond_assembly,policy/ddim_sampling,policy/vae_decode,synth_policy,update_action_queue,stack_to_device_2,wm/ddim_sampler_init,wm/image_embedding,wm/vae_encode,wm/text_conditioning,wm/projectors,wm/cond_assembly,wm/ddim_sampling,wm/vae_decode,synth_world_model,update_obs_queue,tensorboard_log,save_results,cpu_transfer,itr_total
|
||||
mean,0.17,0.09,21.19,49.59,14.78,0.32,0.07,31032.42,0.00,31122.56,0.49,0.15,0.12,20.98,49.03,14.53,0.39,0.10,31036.79,776.76,31905.21,0.67,0.42,111.45,7.03,63149.21
|
||||
std,0.01,0.01,0.22,0.09,0.26,0.03,0.00,16.13,0.00,16.88,0.07,0.01,0.02,0.31,0.28,0.27,0.09,0.02,15.83,0.82,17.84,0.05,0.11,9.19,0.48,42.08
|
||||
min,0.16,0.08,20.97,49.47,14.51,0.29,0.07,31005.48,0.00,31094.51,0.39,0.13,0.09,20.62,48.74,14.17,0.28,0.07,31011.17,775.40,31875.87,0.61,0.31,97.28,6.21,63077.50
|
||||
max,0.18,0.10,21.44,49.71,15.05,0.35,0.08,31047.64,0.00,31138.56,0.58,0.17,0.14,21.46,49.34,14.88,0.49,0.12,31051.98,777.56,31920.42,0.75,0.60,121.67,7.43,63184.26
|
||||
|
@@ -1,45 +0,0 @@
|
||||
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/lightning_fabric/__init__.py:29: UserWarning: pkg_resources is deprecated as an API. See https://setuptools.pypa.io/en/latest/pkg_resources.html. The pkg_resources package is slated for removal as early as 2025-11-30. Refrain from using this package or pin to Setuptools<81.
|
||||
__import__("pkg_resources").declare_namespace(__name__)
|
||||
[rank: 0] Global seed set to 123
|
||||
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/kornia/feature/lightglue.py:44: FutureWarning: `torch.cuda.amp.custom_fwd(args...)` is deprecated. Please use `torch.amp.custom_fwd(args..., device_type='cuda')` instead.
|
||||
@torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
|
||||
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/open_clip/factory.py:88: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
|
||||
checkpoint = torch.load(checkpoint_path, map_location=map_location)
|
||||
/mnt/ASC1637/unifolm-world-model-action/scripts/evaluation/profile_iteration.py:168: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
|
||||
state_dict = torch.load(args.ckpt_path, map_location="cpu")
|
||||
============================================================
|
||||
PROFILE ITERATION — Loading model...
|
||||
============================================================
|
||||
AE working on z of shape (1, 4, 32, 32) = 4096 dimensions.
|
||||
torch.compile: 3 ResBlocks in output_blocks[5, 8, 9]
|
||||
>>> Model loaded and ready.
|
||||
>>> Noise shape: [1, 4, 16, 40, 64]
|
||||
>>> DDIM steps: 50
|
||||
>>> fast_policy_no_decode: True
|
||||
============================================================
|
||||
LAYER 1: ITERATION-LEVEL PROFILING
|
||||
============================================================
|
||||
>>> unitree_z1_stackbox: 1 data samples loaded.
|
||||
>>> unitree_z1_stackbox: data stats loaded.
|
||||
>>> unitree_z1_stackbox: normalizer initiated.
|
||||
>>> unitree_z1_dual_arm_stackbox: 1 data samples loaded.
|
||||
>>> unitree_z1_dual_arm_stackbox: data stats loaded.
|
||||
>>> unitree_z1_dual_arm_stackbox: normalizer initiated.
|
||||
>>> unitree_z1_dual_arm_stackbox_v2: 1 data samples loaded.
|
||||
>>> unitree_z1_dual_arm_stackbox_v2: data stats loaded.
|
||||
>>> unitree_z1_dual_arm_stackbox_v2: normalizer initiated.
|
||||
>>> unitree_z1_dual_arm_cleanup_pencils: 1 data samples loaded.
|
||||
>>> unitree_z1_dual_arm_cleanup_pencils: data stats loaded.
|
||||
>>> unitree_z1_dual_arm_cleanup_pencils: normalizer initiated.
|
||||
>>> unitree_g1_pack_camera: 1 data samples loaded.
|
||||
>>> unitree_g1_pack_camera: data stats loaded.
|
||||
>>> unitree_g1_pack_camera: normalizer initiated.
|
||||
>>> Running 5 profiled iterations ...
|
||||
Traceback (most recent call last):
|
||||
File "/mnt/ASC1637/unifolm-world-model-action/scripts/evaluation/profile_iteration.py", line 981, in <module>
|
||||
main()
|
||||
File "/mnt/ASC1637/unifolm-world-model-action/scripts/evaluation/profile_iteration.py", line 967, in main
|
||||
all_records = run_profiled_iterations(
|
||||
File "/mnt/ASC1637/unifolm-world-model-action/scripts/evaluation/profile_iteration.py", line 502, in run_profiled_iterations
|
||||
sampler_type=args.sampler_type)
|
||||
AttributeError: 'Namespace' object has no attribute 'sampler_type'
|
||||
@@ -1,5 +1,5 @@
 {
   "gt_video": "/mnt/ASC1637/unifolm-world-model-action/unitree_z1_dual_arm_cleanup_pencils/case1/output/inference/unitree_z1_dual_arm_cleanup_pencils_case1_amd.mp4",
   "pred_video": "/mnt/ASC1637/unifolm-world-model-action/unitree_z1_dual_arm_cleanup_pencils/case1/output/inference/0_full_fs4.mp4",
-  "psnr": 32.442113263955434
+  "psnr": 31.802224855380352
 }
@@ -4,7 +4,7 @@ dataset="unitree_z1_dual_arm_cleanup_pencils"
 {
 time TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1 CUDA_VISIBLE_DEVICES=0 python3 scripts/evaluation/world_model_interaction.py \
     --seed 123 \
-    --ckpt_path ckpts/unifolm_wma_dual_mix_bf16.ckpt \
+    --ckpt_path ckpts/unifolm_wma_dual_mixbf16.ckpt \
     --config configs/inference/world_model_interaction.yaml \
     --savedir "${res_dir}/output" \
     --bs 1 --height 320 --width 512 \

@@ -21,6 +21,9 @@ dataset="unitree_z1_dual_arm_cleanup_pencils"
     --timestep_spacing 'uniform_trailing' \
     --guidance_rescale 0.7 \
     --perframe_ae \
-    --vae_dtype bf16 \
+    --diffusion_dtype fp32 \
+    --projector_mode fp32 \
+    --encoder_mode fp32 \
+    --vae_dtype fp32 \
     --fast_policy_no_decode
 } 2>&1 | tee "${res_dir}/output.log"