Compare commits
2 Commits
57ba85d147
...
qhy-merged
| Author | SHA1 | Date | |
|---|---|---|---|
| 3069666a15 | |||
| 68369cc15f |
@@ -222,7 +222,7 @@ data:
|
|||||||
test:
|
test:
|
||||||
target: unifolm_wma.data.wma_data.WMAData
|
target: unifolm_wma.data.wma_data.WMAData
|
||||||
params:
|
params:
|
||||||
data_dir: '/mnt/ASC1637/unifolm-world-model-action/examples/world_model_interaction_prompts'
|
data_dir: '/home/qhy/unifolm-world-model-action/examples/world_model_interaction_prompts'
|
||||||
video_length: ${model.params.wma_config.params.temporal_length}
|
video_length: ${model.params.wma_config.params.temporal_length}
|
||||||
frame_stride: 2
|
frame_stride: 2
|
||||||
load_raw_resolution: True
|
load_raw_resolution: True
|
||||||
|
|||||||
@@ -1,5 +1,7 @@
|
|||||||
import argparse, os, glob
|
import argparse, os, glob
|
||||||
from contextlib import nullcontext
|
from contextlib import nullcontext
|
||||||
|
import atexit
|
||||||
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
import random
|
import random
|
||||||
import torch
|
import torch
|
||||||
@@ -11,13 +13,15 @@ import einops
|
|||||||
import warnings
|
import warnings
|
||||||
import imageio
|
import imageio
|
||||||
|
|
||||||
|
from typing import Optional, List, Any
|
||||||
|
|
||||||
from pytorch_lightning import seed_everything
|
from pytorch_lightning import seed_everything
|
||||||
from omegaconf import OmegaConf
|
from omegaconf import OmegaConf
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
from einops import rearrange, repeat
|
from einops import rearrange, repeat
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from eval_utils import populate_queues, log_to_tensorboard
|
from eval_utils import populate_queues
|
||||||
from collections import deque
|
from collections import deque
|
||||||
from torch import Tensor
|
from torch import Tensor
|
||||||
from torch.utils.tensorboard import SummaryWriter
|
from torch.utils.tensorboard import SummaryWriter
|
||||||
@@ -28,6 +32,80 @@ from unifolm_wma.utils.utils import instantiate_from_config
|
|||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
|
||||||
|
# ========== Async I/O utilities ==========
|
||||||
|
_io_executor: Optional[ThreadPoolExecutor] = None
|
||||||
|
_io_futures: List[Any] = []
|
||||||
|
|
||||||
|
|
||||||
|
def _get_io_executor() -> ThreadPoolExecutor:
|
||||||
|
global _io_executor
|
||||||
|
if _io_executor is None:
|
||||||
|
_io_executor = ThreadPoolExecutor(max_workers=2)
|
||||||
|
return _io_executor
|
||||||
|
|
||||||
|
|
||||||
|
def _flush_io():
|
||||||
|
"""Wait for all pending async I/O to finish."""
|
||||||
|
global _io_futures
|
||||||
|
for fut in _io_futures:
|
||||||
|
try:
|
||||||
|
fut.result()
|
||||||
|
except Exception as e:
|
||||||
|
print(f">>> [async I/O] error: {e}")
|
||||||
|
_io_futures.clear()
|
||||||
|
|
||||||
|
|
||||||
|
atexit.register(_flush_io)
|
||||||
|
|
||||||
|
|
||||||
|
def _save_results_sync(video_cpu: Tensor, filename: str, fps: int) -> None:
|
||||||
|
"""Synchronous save on CPU tensor (runs in background thread)."""
|
||||||
|
video = torch.clamp(video_cpu.float(), -1., 1.)
|
||||||
|
n = video.shape[0]
|
||||||
|
video = video.permute(2, 0, 1, 3, 4)
|
||||||
|
frame_grids = [
|
||||||
|
torchvision.utils.make_grid(framesheet, nrow=int(n), padding=0)
|
||||||
|
for framesheet in video
|
||||||
|
]
|
||||||
|
grid = torch.stack(frame_grids, dim=0)
|
||||||
|
grid = (grid + 1.0) / 2.0
|
||||||
|
grid = (grid * 255).to(torch.uint8).permute(0, 2, 3, 1)
|
||||||
|
torchvision.io.write_video(filename,
|
||||||
|
grid,
|
||||||
|
fps=fps,
|
||||||
|
video_codec='h264',
|
||||||
|
options={'crf': '10'})
|
||||||
|
|
||||||
|
|
||||||
|
def save_results_async(video: Tensor, filename: str, fps: int = 8) -> None:
|
||||||
|
"""Submit video saving to background thread pool."""
|
||||||
|
video_cpu = video.detach().cpu()
|
||||||
|
fut = _get_io_executor().submit(_save_results_sync, video_cpu, filename, fps)
|
||||||
|
_io_futures.append(fut)
|
||||||
|
|
||||||
|
|
||||||
|
def _log_to_tb_sync(video_cpu: Tensor, writer: SummaryWriter, tag: str, fps: int) -> None:
|
||||||
|
"""Synchronous tensorboard logging on CPU tensor (runs in background thread)."""
|
||||||
|
video = video_cpu.float()
|
||||||
|
n = video.shape[0]
|
||||||
|
video = video.permute(2, 0, 1, 3, 4)
|
||||||
|
frame_grids = [
|
||||||
|
torchvision.utils.make_grid(framesheet, nrow=int(n), padding=0)
|
||||||
|
for framesheet in video
|
||||||
|
]
|
||||||
|
grid = torch.stack(frame_grids, dim=0)
|
||||||
|
grid = (grid + 1.0) / 2.0
|
||||||
|
grid = grid.unsqueeze(dim=0)
|
||||||
|
writer.add_video(tag, grid, fps=fps)
|
||||||
|
|
||||||
|
|
||||||
|
def log_to_tensorboard_async(writer: SummaryWriter, video: Tensor, tag: str, fps: int = 10) -> None:
|
||||||
|
"""Submit tensorboard logging to background thread pool."""
|
||||||
|
video_cpu = video.detach().cpu()
|
||||||
|
fut = _get_io_executor().submit(_log_to_tb_sync, video_cpu, writer, tag, fps)
|
||||||
|
_io_futures.append(fut)
|
||||||
|
|
||||||
|
|
||||||
def patch_norm_bypass_autocast():
|
def patch_norm_bypass_autocast():
|
||||||
"""Monkey-patch GroupNorm and LayerNorm to bypass autocast's fp32 policy.
|
"""Monkey-patch GroupNorm and LayerNorm to bypass autocast's fp32 policy.
|
||||||
This eliminates bf16->fp32->bf16 dtype conversions during UNet forward."""
|
This eliminates bf16->fp32->bf16 dtype conversions during UNet forward."""
|
||||||
@@ -185,17 +263,18 @@ def get_filelist(data_dir: str, postfixes: list[str]) -> list[str]:
|
|||||||
return file_list
|
return file_list
|
||||||
|
|
||||||
|
|
||||||
def load_model_checkpoint(model: nn.Module, ckpt: str) -> nn.Module:
|
def load_model_checkpoint(model: nn.Module, ckpt: str, device: str = "cpu") -> nn.Module:
|
||||||
"""Load model weights from checkpoint file.
|
"""Load model weights from checkpoint file.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
model (nn.Module): Model instance.
|
model (nn.Module): Model instance.
|
||||||
ckpt (str): Path to the checkpoint file.
|
ckpt (str): Path to the checkpoint file.
|
||||||
|
device (str): Target device for loaded tensors.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
nn.Module: Model with loaded weights.
|
nn.Module: Model with loaded weights.
|
||||||
"""
|
"""
|
||||||
state_dict = torch.load(ckpt, map_location="cpu")
|
state_dict = torch.load(ckpt, map_location=device)
|
||||||
if "state_dict" in list(state_dict.keys()):
|
if "state_dict" in list(state_dict.keys()):
|
||||||
state_dict = state_dict["state_dict"]
|
state_dict = state_dict["state_dict"]
|
||||||
try:
|
try:
|
||||||
@@ -610,36 +689,63 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
|
|||||||
|
|
||||||
# Load config
|
# Load config
|
||||||
config = OmegaConf.load(args.config)
|
config = OmegaConf.load(args.config)
|
||||||
config['model']['params']['wma_config']['params'][
|
|
||||||
'use_checkpoint'] = False
|
|
||||||
model = instantiate_from_config(config.model)
|
|
||||||
model.perframe_ae = args.perframe_ae
|
|
||||||
assert os.path.exists(args.ckpt_path), "Error: checkpoint Not Found!"
|
|
||||||
model = load_model_checkpoint(model, args.ckpt_path)
|
|
||||||
model.eval()
|
|
||||||
print(f'>>> Load pre-trained model ...')
|
|
||||||
|
|
||||||
# Apply precision settings before moving to GPU
|
prepared_path = args.ckpt_path + ".prepared.pt"
|
||||||
model = apply_precision_settings(model, args)
|
if os.path.exists(prepared_path):
|
||||||
|
# ---- Fast path: load the fully-prepared model ----
|
||||||
|
print(f">>> Loading prepared model from {prepared_path} ...")
|
||||||
|
model = torch.load(prepared_path,
|
||||||
|
map_location=f"cuda:{gpu_no}",
|
||||||
|
weights_only=False)
|
||||||
|
model.eval()
|
||||||
|
|
||||||
# Compile hot ResBlocks for operator fusion
|
# Restore autocast attributes (weights already cast, just need contexts)
|
||||||
apply_torch_compile(model)
|
model.diffusion_autocast_dtype = torch.bfloat16 if args.diffusion_dtype == "bf16" else torch.bfloat16
|
||||||
|
model.projector_autocast_dtype = torch.bfloat16 if args.projector_mode == "autocast" else None
|
||||||
|
model.encoder_autocast_dtype = torch.bfloat16 if args.encoder_mode == "autocast" else None
|
||||||
|
|
||||||
# Export precision-converted checkpoint if requested
|
# Compile hot ResBlocks for operator fusion
|
||||||
if args.export_precision_ckpt:
|
apply_torch_compile(model)
|
||||||
export_path = args.export_precision_ckpt
|
|
||||||
os.makedirs(os.path.dirname(export_path) or '.', exist_ok=True)
|
|
||||||
torch.save({"state_dict": model.state_dict()}, export_path)
|
|
||||||
print(f">>> Precision-converted checkpoint saved to: {export_path}")
|
|
||||||
return
|
|
||||||
|
|
||||||
# Build unnomalizer
|
print(f">>> Prepared model loaded.")
|
||||||
|
else:
|
||||||
|
# ---- Normal path: construct + checkpoint + casting ----
|
||||||
|
config['model']['params']['wma_config']['params'][
|
||||||
|
'use_checkpoint'] = False
|
||||||
|
model = instantiate_from_config(config.model)
|
||||||
|
model.perframe_ae = args.perframe_ae
|
||||||
|
assert os.path.exists(args.ckpt_path), "Error: checkpoint Not Found!"
|
||||||
|
model = load_model_checkpoint(model, args.ckpt_path,
|
||||||
|
device=f"cuda:{gpu_no}")
|
||||||
|
model.eval()
|
||||||
|
print(f'>>> Load pre-trained model ...')
|
||||||
|
|
||||||
|
# Apply precision settings before moving to GPU
|
||||||
|
model = apply_precision_settings(model, args)
|
||||||
|
|
||||||
|
# Export precision-converted checkpoint if requested
|
||||||
|
if args.export_precision_ckpt:
|
||||||
|
export_path = args.export_precision_ckpt
|
||||||
|
os.makedirs(os.path.dirname(export_path) or '.', exist_ok=True)
|
||||||
|
torch.save({"state_dict": model.state_dict()}, export_path)
|
||||||
|
print(f">>> Precision-converted checkpoint saved to: {export_path}")
|
||||||
|
return
|
||||||
|
|
||||||
|
model = model.cuda(gpu_no)
|
||||||
|
|
||||||
|
# Save prepared model for fast loading next time (before torch.compile)
|
||||||
|
print(f">>> Saving prepared model to {prepared_path} ...")
|
||||||
|
torch.save(model, prepared_path)
|
||||||
|
print(f">>> Prepared model saved ({os.path.getsize(prepared_path) / 1024**3:.1f} GB).")
|
||||||
|
|
||||||
|
# Compile hot ResBlocks for operator fusion (after save, compiled objects can't be pickled)
|
||||||
|
apply_torch_compile(model)
|
||||||
|
|
||||||
|
# Build normalizer (always needed, independent of model loading path)
|
||||||
logging.info("***** Configing Data *****")
|
logging.info("***** Configing Data *****")
|
||||||
data = instantiate_from_config(config.data)
|
data = instantiate_from_config(config.data)
|
||||||
data.setup()
|
data.setup()
|
||||||
print(">>> Dataset is successfully loaded ...")
|
print(">>> Dataset is successfully loaded ...")
|
||||||
|
|
||||||
model = model.cuda(gpu_no)
|
|
||||||
device = get_device_from_parameters(model)
|
device = get_device_from_parameters(model)
|
||||||
|
|
||||||
# Run over data
|
# Run over data
|
||||||
@@ -817,28 +923,28 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
|
|||||||
# Save the imagen videos for decision-making
|
# Save the imagen videos for decision-making
|
||||||
if pred_videos_0 is not None:
|
if pred_videos_0 is not None:
|
||||||
sample_tag = f"{args.dataset}-vid{sample['videoid']}-dm-fs-{fs}/itr-{itr}"
|
sample_tag = f"{args.dataset}-vid{sample['videoid']}-dm-fs-{fs}/itr-{itr}"
|
||||||
log_to_tensorboard(writer,
|
log_to_tensorboard_async(writer,
|
||||||
pred_videos_0,
|
pred_videos_0,
|
||||||
sample_tag,
|
sample_tag,
|
||||||
fps=args.save_fps)
|
fps=args.save_fps)
|
||||||
# Save videos environment changes via world-model interaction
|
# Save videos environment changes via world-model interaction
|
||||||
sample_tag = f"{args.dataset}-vid{sample['videoid']}-wd-fs-{fs}/itr-{itr}"
|
sample_tag = f"{args.dataset}-vid{sample['videoid']}-wd-fs-{fs}/itr-{itr}"
|
||||||
log_to_tensorboard(writer,
|
log_to_tensorboard_async(writer,
|
||||||
pred_videos_1,
|
pred_videos_1,
|
||||||
sample_tag,
|
sample_tag,
|
||||||
fps=args.save_fps)
|
fps=args.save_fps)
|
||||||
|
|
||||||
# Save the imagen videos for decision-making
|
# Save the imagen videos for decision-making
|
||||||
if pred_videos_0 is not None:
|
if pred_videos_0 is not None:
|
||||||
sample_video_file = f'{video_save_dir}/dm/{fs}/itr-{itr}.mp4'
|
sample_video_file = f'{video_save_dir}/dm/{fs}/itr-{itr}.mp4'
|
||||||
save_results(pred_videos_0.cpu(),
|
save_results_async(pred_videos_0,
|
||||||
sample_video_file,
|
sample_video_file,
|
||||||
fps=args.save_fps)
|
fps=args.save_fps)
|
||||||
# Save videos environment changes via world-model interaction
|
# Save videos environment changes via world-model interaction
|
||||||
sample_video_file = f'{video_save_dir}/wm/{fs}/itr-{itr}.mp4'
|
sample_video_file = f'{video_save_dir}/wm/{fs}/itr-{itr}.mp4'
|
||||||
save_results(pred_videos_1.cpu(),
|
save_results_async(pred_videos_1,
|
||||||
sample_video_file,
|
sample_video_file,
|
||||||
fps=args.save_fps)
|
fps=args.save_fps)
|
||||||
|
|
||||||
print('>' * 24)
|
print('>' * 24)
|
||||||
# Collect the result of world-model interactions
|
# Collect the result of world-model interactions
|
||||||
@@ -846,12 +952,15 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
|
|||||||
|
|
||||||
full_video = torch.cat(wm_video, dim=2)
|
full_video = torch.cat(wm_video, dim=2)
|
||||||
sample_tag = f"{args.dataset}-vid{sample['videoid']}-wd-fs-{fs}/full"
|
sample_tag = f"{args.dataset}-vid{sample['videoid']}-wd-fs-{fs}/full"
|
||||||
log_to_tensorboard(writer,
|
log_to_tensorboard_async(writer,
|
||||||
full_video,
|
full_video,
|
||||||
sample_tag,
|
sample_tag,
|
||||||
fps=args.save_fps)
|
fps=args.save_fps)
|
||||||
sample_full_video_file = f"{video_save_dir}/../{sample['videoid']}_full_fs{fs}.mp4"
|
sample_full_video_file = f"{video_save_dir}/../{sample['videoid']}_full_fs{fs}.mp4"
|
||||||
save_results(full_video, sample_full_video_file, fps=args.save_fps)
|
save_results_async(full_video, sample_full_video_file, fps=args.save_fps)
|
||||||
|
|
||||||
|
# Wait for all async I/O to complete
|
||||||
|
_flush_io()
|
||||||
|
|
||||||
|
|
||||||
def get_parser():
|
def get_parser():
|
||||||
|
|||||||
@@ -99,6 +99,8 @@ class AutoencoderKL(pl.LightningModule):
|
|||||||
print(f"Restored from {path}")
|
print(f"Restored from {path}")
|
||||||
|
|
||||||
def encode(self, x, **kwargs):
|
def encode(self, x, **kwargs):
|
||||||
|
if getattr(self, '_channels_last', False):
|
||||||
|
x = x.to(memory_format=torch.channels_last)
|
||||||
|
|
||||||
h = self.encoder(x)
|
h = self.encoder(x)
|
||||||
moments = self.quant_conv(h)
|
moments = self.quant_conv(h)
|
||||||
@@ -106,6 +108,8 @@ class AutoencoderKL(pl.LightningModule):
|
|||||||
return posterior
|
return posterior
|
||||||
|
|
||||||
def decode(self, z, **kwargs):
|
def decode(self, z, **kwargs):
|
||||||
|
if getattr(self, '_channels_last', False):
|
||||||
|
z = z.to(memory_format=torch.channels_last)
|
||||||
z = self.post_quant_conv(z)
|
z = self.post_quant_conv(z)
|
||||||
dec = self.decoder(z)
|
dec = self.decoder(z)
|
||||||
return dec
|
return dec
|
||||||
|
|||||||
@@ -1074,10 +1074,10 @@ class LatentDiffusion(DDPM):
|
|||||||
encoder_posterior = self.first_stage_model.encode(x)
|
encoder_posterior = self.first_stage_model.encode(x)
|
||||||
results = self.get_first_stage_encoding(encoder_posterior).detach()
|
results = self.get_first_stage_encoding(encoder_posterior).detach()
|
||||||
else: ## Consume less GPU memory but slower
|
else: ## Consume less GPU memory but slower
|
||||||
|
bs = getattr(self, 'vae_encode_bs', 1)
|
||||||
results = []
|
results = []
|
||||||
for index in range(x.shape[0]):
|
for i in range(0, x.shape[0], bs):
|
||||||
frame_batch = self.first_stage_model.encode(x[index:index +
|
frame_batch = self.first_stage_model.encode(x[i:i + bs])
|
||||||
1, :, :, :])
|
|
||||||
frame_result = self.get_first_stage_encoding(
|
frame_result = self.get_first_stage_encoding(
|
||||||
frame_batch).detach()
|
frame_batch).detach()
|
||||||
results.append(frame_result)
|
results.append(frame_result)
|
||||||
@@ -1109,14 +1109,14 @@ class LatentDiffusion(DDPM):
|
|||||||
vae_dtype = next(self.first_stage_model.parameters()).dtype
|
vae_dtype = next(self.first_stage_model.parameters()).dtype
|
||||||
z = z.to(dtype=vae_dtype)
|
z = z.to(dtype=vae_dtype)
|
||||||
|
|
||||||
|
z = 1. / self.scale_factor * z
|
||||||
if not self.perframe_ae:
|
if not self.perframe_ae:
|
||||||
z = 1. / self.scale_factor * z
|
|
||||||
results = self.first_stage_model.decode(z, **kwargs)
|
results = self.first_stage_model.decode(z, **kwargs)
|
||||||
else:
|
else:
|
||||||
|
bs = getattr(self, 'vae_decode_bs', 1)
|
||||||
results = []
|
results = []
|
||||||
for index in range(z.shape[0]):
|
for i in range(0, z.shape[0], bs):
|
||||||
frame_z = 1. / self.scale_factor * z[index:index + 1, :, :, :]
|
frame_result = self.first_stage_model.decode(z[i:i + bs], **kwargs)
|
||||||
frame_result = self.first_stage_model.decode(frame_z, **kwargs)
|
|
||||||
results.append(frame_result)
|
results.append(frame_result)
|
||||||
results = torch.cat(results, dim=0)
|
results = torch.cat(results, dim=0)
|
||||||
|
|
||||||
|
|||||||
@@ -0,0 +1,7 @@
|
|||||||
|
{
|
||||||
|
"permissions": {
|
||||||
|
"allow": [
|
||||||
|
"Bash(python3:*)"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -11,7 +11,7 @@ from unifolm_wma.utils.utils import instantiate_from_config
|
|||||||
|
|
||||||
def nonlinearity(x):
|
def nonlinearity(x):
|
||||||
# swish
|
# swish
|
||||||
return x * torch.sigmoid(x)
|
return torch.nn.functional.silu(x)
|
||||||
|
|
||||||
|
|
||||||
def Normalize(in_channels, num_groups=32):
|
def Normalize(in_channels, num_groups=32):
|
||||||
|
|||||||
@@ -4,7 +4,7 @@ dataset="unitree_z1_dual_arm_cleanup_pencils"
|
|||||||
{
|
{
|
||||||
time TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1 CUDA_VISIBLE_DEVICES=0 python3 scripts/evaluation/world_model_interaction.py \
|
time TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1 CUDA_VISIBLE_DEVICES=0 python3 scripts/evaluation/world_model_interaction.py \
|
||||||
--seed 123 \
|
--seed 123 \
|
||||||
--ckpt_path ckpts/unifolm_wma_dual_mix_bf16.ckpt \
|
--ckpt_path ckpts/unifolm_wma_dual_mixbf16.ckpt \
|
||||||
--config configs/inference/world_model_interaction.yaml \
|
--config configs/inference/world_model_interaction.yaml \
|
||||||
--savedir "${res_dir}/output" \
|
--savedir "${res_dir}/output" \
|
||||||
--bs 1 --height 320 --width 512 \
|
--bs 1 --height 320 --width 512 \
|
||||||
@@ -21,6 +21,9 @@ dataset="unitree_z1_dual_arm_cleanup_pencils"
|
|||||||
--timestep_spacing 'uniform_trailing' \
|
--timestep_spacing 'uniform_trailing' \
|
||||||
--guidance_rescale 0.7 \
|
--guidance_rescale 0.7 \
|
||||||
--perframe_ae \
|
--perframe_ae \
|
||||||
--vae_dtype bf16 \
|
--diffusion_dtype fp32 \
|
||||||
|
--projector_mode fp32 \
|
||||||
|
--encoder_mode fp32 \
|
||||||
|
--vae_dtype fp32 \
|
||||||
--fast_policy_no_decode
|
--fast_policy_no_decode
|
||||||
} 2>&1 | tee "${res_dir}/output.log"
|
} 2>&1 | tee "${res_dir}/output.log"
|
||||||
|
|||||||
Reference in New Issue
Block a user