2 Commits

Author SHA1 Message Date
qhy
3069666a15 脚本修改 2026-02-10 14:49:26 +08:00
qhy
68369cc15f 合并后测试 2026-02-10 14:45:14 +08:00
7 changed files with 178 additions and 55 deletions

View File

@@ -222,7 +222,7 @@ data:
test:
target: unifolm_wma.data.wma_data.WMAData
params:
data_dir: '/mnt/ASC1637/unifolm-world-model-action/examples/world_model_interaction_prompts'
data_dir: '/home/qhy/unifolm-world-model-action/examples/world_model_interaction_prompts'
video_length: ${model.params.wma_config.params.temporal_length}
frame_stride: 2
load_raw_resolution: True

View File

@@ -1,5 +1,7 @@
import argparse, os, glob
from contextlib import nullcontext
import atexit
from concurrent.futures import ThreadPoolExecutor
import pandas as pd
import random
import torch
@@ -11,13 +13,15 @@ import einops
import warnings
import imageio
from typing import Optional, List, Any
from pytorch_lightning import seed_everything
from omegaconf import OmegaConf
from tqdm import tqdm
from einops import rearrange, repeat
from collections import OrderedDict
from torch import nn
from eval_utils import populate_queues, log_to_tensorboard
from eval_utils import populate_queues
from collections import deque
from torch import Tensor
from torch.utils.tensorboard import SummaryWriter
@@ -28,6 +32,80 @@ from unifolm_wma.utils.utils import instantiate_from_config
import torch.nn.functional as F
# ========== Async I/O utilities ==========
_io_executor: Optional[ThreadPoolExecutor] = None
_io_futures: List[Any] = []
def _get_io_executor() -> ThreadPoolExecutor:
global _io_executor
if _io_executor is None:
_io_executor = ThreadPoolExecutor(max_workers=2)
return _io_executor
def _flush_io():
"""Wait for all pending async I/O to finish."""
global _io_futures
for fut in _io_futures:
try:
fut.result()
except Exception as e:
print(f">>> [async I/O] error: {e}")
_io_futures.clear()
atexit.register(_flush_io)
def _save_results_sync(video_cpu: Tensor, filename: str, fps: int) -> None:
"""Synchronous save on CPU tensor (runs in background thread)."""
video = torch.clamp(video_cpu.float(), -1., 1.)
n = video.shape[0]
video = video.permute(2, 0, 1, 3, 4)
frame_grids = [
torchvision.utils.make_grid(framesheet, nrow=int(n), padding=0)
for framesheet in video
]
grid = torch.stack(frame_grids, dim=0)
grid = (grid + 1.0) / 2.0
grid = (grid * 255).to(torch.uint8).permute(0, 2, 3, 1)
torchvision.io.write_video(filename,
grid,
fps=fps,
video_codec='h264',
options={'crf': '10'})
def save_results_async(video: Tensor, filename: str, fps: int = 8) -> None:
"""Submit video saving to background thread pool."""
video_cpu = video.detach().cpu()
fut = _get_io_executor().submit(_save_results_sync, video_cpu, filename, fps)
_io_futures.append(fut)
def _log_to_tb_sync(video_cpu: Tensor, writer: SummaryWriter, tag: str, fps: int) -> None:
"""Synchronous tensorboard logging on CPU tensor (runs in background thread)."""
video = video_cpu.float()
n = video.shape[0]
video = video.permute(2, 0, 1, 3, 4)
frame_grids = [
torchvision.utils.make_grid(framesheet, nrow=int(n), padding=0)
for framesheet in video
]
grid = torch.stack(frame_grids, dim=0)
grid = (grid + 1.0) / 2.0
grid = grid.unsqueeze(dim=0)
writer.add_video(tag, grid, fps=fps)
def log_to_tensorboard_async(writer: SummaryWriter, video: Tensor, tag: str, fps: int = 10) -> None:
"""Submit tensorboard logging to background thread pool."""
video_cpu = video.detach().cpu()
fut = _get_io_executor().submit(_log_to_tb_sync, video_cpu, writer, tag, fps)
_io_futures.append(fut)
def patch_norm_bypass_autocast():
"""Monkey-patch GroupNorm and LayerNorm to bypass autocast's fp32 policy.
This eliminates bf16->fp32->bf16 dtype conversions during UNet forward."""
@@ -185,17 +263,18 @@ def get_filelist(data_dir: str, postfixes: list[str]) -> list[str]:
return file_list
def load_model_checkpoint(model: nn.Module, ckpt: str) -> nn.Module:
def load_model_checkpoint(model: nn.Module, ckpt: str, device: str = "cpu") -> nn.Module:
"""Load model weights from checkpoint file.
Args:
model (nn.Module): Model instance.
ckpt (str): Path to the checkpoint file.
device (str): Target device for loaded tensors.
Returns:
nn.Module: Model with loaded weights.
"""
state_dict = torch.load(ckpt, map_location="cpu")
state_dict = torch.load(ckpt, map_location=device)
if "state_dict" in list(state_dict.keys()):
state_dict = state_dict["state_dict"]
try:
@@ -610,21 +689,40 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
# Load config
config = OmegaConf.load(args.config)
prepared_path = args.ckpt_path + ".prepared.pt"
if os.path.exists(prepared_path):
# ---- Fast path: load the fully-prepared model ----
print(f">>> Loading prepared model from {prepared_path} ...")
model = torch.load(prepared_path,
map_location=f"cuda:{gpu_no}",
weights_only=False)
model.eval()
# Restore autocast attributes (weights already cast, just need contexts)
model.diffusion_autocast_dtype = torch.bfloat16 if args.diffusion_dtype == "bf16" else torch.bfloat16
model.projector_autocast_dtype = torch.bfloat16 if args.projector_mode == "autocast" else None
model.encoder_autocast_dtype = torch.bfloat16 if args.encoder_mode == "autocast" else None
# Compile hot ResBlocks for operator fusion
apply_torch_compile(model)
print(f">>> Prepared model loaded.")
else:
# ---- Normal path: construct + checkpoint + casting ----
config['model']['params']['wma_config']['params'][
'use_checkpoint'] = False
model = instantiate_from_config(config.model)
model.perframe_ae = args.perframe_ae
assert os.path.exists(args.ckpt_path), "Error: checkpoint Not Found!"
model = load_model_checkpoint(model, args.ckpt_path)
model = load_model_checkpoint(model, args.ckpt_path,
device=f"cuda:{gpu_no}")
model.eval()
print(f'>>> Load pre-trained model ...')
# Apply precision settings before moving to GPU
model = apply_precision_settings(model, args)
# Compile hot ResBlocks for operator fusion
apply_torch_compile(model)
# Export precision-converted checkpoint if requested
if args.export_precision_ckpt:
export_path = args.export_precision_ckpt
@@ -633,13 +731,21 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
print(f">>> Precision-converted checkpoint saved to: {export_path}")
return
# Build unnomalizer
model = model.cuda(gpu_no)
# Save prepared model for fast loading next time (before torch.compile)
print(f">>> Saving prepared model to {prepared_path} ...")
torch.save(model, prepared_path)
print(f">>> Prepared model saved ({os.path.getsize(prepared_path) / 1024**3:.1f} GB).")
# Compile hot ResBlocks for operator fusion (after save, compiled objects can't be pickled)
apply_torch_compile(model)
# Build normalizer (always needed, independent of model loading path)
logging.info("***** Configing Data *****")
data = instantiate_from_config(config.data)
data.setup()
print(">>> Dataset is successfully loaded ...")
model = model.cuda(gpu_no)
device = get_device_from_parameters(model)
# Run over data
@@ -817,13 +923,13 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
# Save the imagen videos for decision-making
if pred_videos_0 is not None:
sample_tag = f"{args.dataset}-vid{sample['videoid']}-dm-fs-{fs}/itr-{itr}"
log_to_tensorboard(writer,
log_to_tensorboard_async(writer,
pred_videos_0,
sample_tag,
fps=args.save_fps)
# Save videos environment changes via world-model interaction
sample_tag = f"{args.dataset}-vid{sample['videoid']}-wd-fs-{fs}/itr-{itr}"
log_to_tensorboard(writer,
log_to_tensorboard_async(writer,
pred_videos_1,
sample_tag,
fps=args.save_fps)
@@ -831,12 +937,12 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
# Save the imagen videos for decision-making
if pred_videos_0 is not None:
sample_video_file = f'{video_save_dir}/dm/{fs}/itr-{itr}.mp4'
save_results(pred_videos_0.cpu(),
save_results_async(pred_videos_0,
sample_video_file,
fps=args.save_fps)
# Save videos environment changes via world-model interaction
sample_video_file = f'{video_save_dir}/wm/{fs}/itr-{itr}.mp4'
save_results(pred_videos_1.cpu(),
save_results_async(pred_videos_1,
sample_video_file,
fps=args.save_fps)
@@ -846,12 +952,15 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
full_video = torch.cat(wm_video, dim=2)
sample_tag = f"{args.dataset}-vid{sample['videoid']}-wd-fs-{fs}/full"
log_to_tensorboard(writer,
log_to_tensorboard_async(writer,
full_video,
sample_tag,
fps=args.save_fps)
sample_full_video_file = f"{video_save_dir}/../{sample['videoid']}_full_fs{fs}.mp4"
save_results(full_video, sample_full_video_file, fps=args.save_fps)
save_results_async(full_video, sample_full_video_file, fps=args.save_fps)
# Wait for all async I/O to complete
_flush_io()
def get_parser():

View File

@@ -99,6 +99,8 @@ class AutoencoderKL(pl.LightningModule):
print(f"Restored from {path}")
def encode(self, x, **kwargs):
if getattr(self, '_channels_last', False):
x = x.to(memory_format=torch.channels_last)
h = self.encoder(x)
moments = self.quant_conv(h)
@@ -106,6 +108,8 @@ class AutoencoderKL(pl.LightningModule):
return posterior
def decode(self, z, **kwargs):
if getattr(self, '_channels_last', False):
z = z.to(memory_format=torch.channels_last)
z = self.post_quant_conv(z)
dec = self.decoder(z)
return dec

View File

@@ -1074,10 +1074,10 @@ class LatentDiffusion(DDPM):
encoder_posterior = self.first_stage_model.encode(x)
results = self.get_first_stage_encoding(encoder_posterior).detach()
else: ## Consume less GPU memory but slower
bs = getattr(self, 'vae_encode_bs', 1)
results = []
for index in range(x.shape[0]):
frame_batch = self.first_stage_model.encode(x[index:index +
1, :, :, :])
for i in range(0, x.shape[0], bs):
frame_batch = self.first_stage_model.encode(x[i:i + bs])
frame_result = self.get_first_stage_encoding(
frame_batch).detach()
results.append(frame_result)
@@ -1109,14 +1109,14 @@ class LatentDiffusion(DDPM):
vae_dtype = next(self.first_stage_model.parameters()).dtype
z = z.to(dtype=vae_dtype)
if not self.perframe_ae:
z = 1. / self.scale_factor * z
if not self.perframe_ae:
results = self.first_stage_model.decode(z, **kwargs)
else:
bs = getattr(self, 'vae_decode_bs', 1)
results = []
for index in range(z.shape[0]):
frame_z = 1. / self.scale_factor * z[index:index + 1, :, :, :]
frame_result = self.first_stage_model.decode(frame_z, **kwargs)
for i in range(0, z.shape[0], bs):
frame_result = self.first_stage_model.decode(z[i:i + bs], **kwargs)
results.append(frame_result)
results = torch.cat(results, dim=0)

View File

@@ -0,0 +1,7 @@
{
"permissions": {
"allow": [
"Bash(python3:*)"
]
}
}

View File

@@ -11,7 +11,7 @@ from unifolm_wma.utils.utils import instantiate_from_config
def nonlinearity(x):
# swish
return x * torch.sigmoid(x)
return torch.nn.functional.silu(x)
def Normalize(in_channels, num_groups=32):

View File

@@ -4,7 +4,7 @@ dataset="unitree_z1_dual_arm_cleanup_pencils"
{
time TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1 CUDA_VISIBLE_DEVICES=0 python3 scripts/evaluation/world_model_interaction.py \
--seed 123 \
--ckpt_path ckpts/unifolm_wma_dual_mix_bf16.ckpt \
--ckpt_path ckpts/unifolm_wma_dual_mixbf16.ckpt \
--config configs/inference/world_model_interaction.yaml \
--savedir "${res_dir}/output" \
--bs 1 --height 320 --width 512 \
@@ -21,6 +21,9 @@ dataset="unitree_z1_dual_arm_cleanup_pencils"
--timestep_spacing 'uniform_trailing' \
--guidance_rescale 0.7 \
--perframe_ae \
--vae_dtype bf16 \
--diffusion_dtype fp32 \
--projector_mode fp32 \
--encoder_mode fp32 \
--vae_dtype fp32 \
--fast_policy_no_decode
} 2>&1 | tee "${res_dir}/output.log"