每步迭代保存异步
This commit is contained in:
@@ -4,7 +4,8 @@
|
||||
"Bash(conda env list:*)",
|
||||
"Bash(mamba env:*)",
|
||||
"Bash(micromamba env list:*)",
|
||||
"Bash(echo:*)"
|
||||
"Bash(echo:*)",
|
||||
"Bash(git show:*)"
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
3
.gitignore
vendored
3
.gitignore
vendored
@@ -130,4 +130,5 @@ Experiment/log
|
||||
|
||||
*.ckpt
|
||||
|
||||
*.0
|
||||
*.0
|
||||
ckpts/unifolm_wma_dual.ckpt.prepared.pt
|
||||
|
||||
@@ -9,6 +9,8 @@ import logging
|
||||
import einops
|
||||
import warnings
|
||||
import imageio
|
||||
import atexit
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
|
||||
from pytorch_lightning import seed_everything
|
||||
from omegaconf import OmegaConf
|
||||
@@ -16,8 +18,9 @@ from tqdm import tqdm
|
||||
from einops import rearrange, repeat
|
||||
from collections import OrderedDict
|
||||
from torch import nn
|
||||
from eval_utils import populate_queues, log_to_tensorboard
|
||||
from eval_utils import populate_queues
|
||||
from collections import deque
|
||||
from typing import Optional, List, Any
|
||||
|
||||
torch.backends.cuda.matmul.allow_tf32 = True
|
||||
torch.backends.cudnn.allow_tf32 = True
|
||||
@@ -153,6 +156,81 @@ def save_results(video: Tensor, filename: str, fps: int = 8) -> None:
|
||||
options={'crf': '10'})
|
||||
|
||||
|
||||
# ========== Async I/O ==========
|
||||
_io_executor: Optional[ThreadPoolExecutor] = None
|
||||
_io_futures: List[Any] = []
|
||||
|
||||
|
||||
def _get_io_executor() -> ThreadPoolExecutor:
|
||||
global _io_executor
|
||||
if _io_executor is None:
|
||||
_io_executor = ThreadPoolExecutor(max_workers=2)
|
||||
return _io_executor
|
||||
|
||||
|
||||
def _flush_io():
|
||||
"""Wait for all pending async I/O to finish."""
|
||||
global _io_futures
|
||||
for fut in _io_futures:
|
||||
try:
|
||||
fut.result()
|
||||
except Exception as e:
|
||||
print(f">>> [async I/O] error: {e}")
|
||||
_io_futures.clear()
|
||||
|
||||
|
||||
atexit.register(_flush_io)
|
||||
|
||||
|
||||
def _save_results_sync(video_cpu: Tensor, filename: str, fps: int) -> None:
    """Synchronous save on CPU tensor (runs in background thread)."""
    # Clamp to the model's output range before rescaling to pixels.
    clamped = torch.clamp(video_cpu.float(), -1., 1.)
    batch = clamped.shape[0]
    # Move dim 2 to the front so each step yields one frame sheet.
    # NOTE(review): assumes a 5-D (batch, channel, time, h, w) layout — confirm at call site.
    by_frame = clamped.permute(2, 0, 1, 3, 4)
    sheets = []
    for frame in by_frame:
        sheets.append(
            torchvision.utils.make_grid(frame, nrow=int(batch), padding=0))
    stacked = torch.stack(sheets, dim=0)
    # Map [-1, 1] -> [0, 255] uint8 and put channels last for write_video.
    stacked = (stacked + 1.0) / 2.0
    stacked = (stacked * 255).to(torch.uint8).permute(0, 2, 3, 1)
    torchvision.io.write_video(filename,
                               stacked,
                               fps=fps,
                               video_codec='h264',
                               options={'crf': '10'})
|
||||
|
||||
|
||||
def save_results_async(video: Tensor, filename: str, fps: int = 8) -> None:
    """Submit video saving to background thread pool."""
    # Snapshot to CPU up front so the background thread never touches a
    # live GPU tensor that the main loop may overwrite.
    snapshot = video.detach().cpu()
    pool = _get_io_executor()
    _io_futures.append(pool.submit(_save_results_sync, snapshot, filename, fps))
|
||||
|
||||
|
||||
def _log_to_tb_sync(writer, video_cpu: Tensor, tag: str, fps: int) -> None:
    """Synchronous TensorBoard log on CPU tensor (runs in background thread)."""
    # Only 5-D tensors are logged; anything else is silently ignored.
    if video_cpu.dim() != 5:
        return
    batch = video_cpu.shape[0]
    # Move dim 2 to the front so each step yields one frame sheet.
    # NOTE(review): assumes a 5-D (batch, channel, time, h, w) layout — confirm at call site.
    by_frame = video_cpu.permute(2, 0, 1, 3, 4)
    sheets = []
    for frame in by_frame:
        sheets.append(
            torchvision.utils.make_grid(frame, nrow=int(batch), padding=0))
    stacked = torch.stack(sheets, dim=0)
    # Rescale [-1, 1] -> [0, 1] and add a leading batch axis for add_video.
    stacked = (stacked + 1.0) / 2.0
    writer.add_video(tag, stacked.unsqueeze(dim=0), fps=fps)
|
||||
|
||||
|
||||
def log_to_tensorboard_async(writer, data: Tensor, tag: str, fps: int = 10) -> None:
    """Submit TensorBoard logging to background thread pool."""
    # Skip anything that is not a 5-D tensor: non-tensor inputs and other
    # shapes would not be handled by the synchronous logger anyway.
    if not (isinstance(data, torch.Tensor) and data.dim() == 5):
        return
    # Snapshot to CPU before handing the tensor to the worker thread.
    snapshot = data.detach().cpu()
    pool = _get_io_executor()
    _io_futures.append(pool.submit(_log_to_tb_sync, writer, snapshot, tag, fps))
|
||||
|
||||
|
||||
def get_init_frame_path(data_dir: str, sample: dict) -> str:
|
||||
"""Construct the init_frame path from directory and sample metadata.
|
||||
|
||||
@@ -673,31 +751,31 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
|
||||
cond_obs_queues = populate_queues(cond_obs_queues,
|
||||
observation)
|
||||
|
||||
# Save the imagen videos for decision-making
|
||||
# Save the imagen videos for decision-making (async)
|
||||
if pred_videos_0 is not None:
|
||||
sample_tag = f"{args.dataset}-vid{sample['videoid']}-dm-fs-{fs}/itr-{itr}"
|
||||
log_to_tensorboard(writer,
|
||||
pred_videos_0,
|
||||
sample_tag,
|
||||
fps=args.save_fps)
|
||||
log_to_tensorboard_async(writer,
|
||||
pred_videos_0,
|
||||
sample_tag,
|
||||
fps=args.save_fps)
|
||||
# Save videos environment changes via world-model interaction
|
||||
sample_tag = f"{args.dataset}-vid{sample['videoid']}-wd-fs-{fs}/itr-{itr}"
|
||||
log_to_tensorboard(writer,
|
||||
pred_videos_1,
|
||||
sample_tag,
|
||||
fps=args.save_fps)
|
||||
log_to_tensorboard_async(writer,
|
||||
pred_videos_1,
|
||||
sample_tag,
|
||||
fps=args.save_fps)
|
||||
|
||||
# Save the imagen videos for decision-making
|
||||
if pred_videos_0 is not None:
|
||||
sample_video_file = f'{video_save_dir}/dm/{fs}/itr-{itr}.mp4'
|
||||
save_results(pred_videos_0.cpu(),
|
||||
sample_video_file,
|
||||
fps=args.save_fps)
|
||||
save_results_async(pred_videos_0,
|
||||
sample_video_file,
|
||||
fps=args.save_fps)
|
||||
# Save videos environment changes via world-model interaction
|
||||
sample_video_file = f'{video_save_dir}/wm/{fs}/itr-{itr}.mp4'
|
||||
save_results(pred_videos_1.cpu(),
|
||||
sample_video_file,
|
||||
fps=args.save_fps)
|
||||
save_results_async(pred_videos_1,
|
||||
sample_video_file,
|
||||
fps=args.save_fps)
|
||||
|
||||
print('>' * 24)
|
||||
# Collect the result of world-model interactions
|
||||
@@ -705,12 +783,15 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
|
||||
|
||||
full_video = torch.cat(wm_video, dim=2)
|
||||
sample_tag = f"{args.dataset}-vid{sample['videoid']}-wd-fs-{fs}/full"
|
||||
log_to_tensorboard(writer,
|
||||
full_video,
|
||||
sample_tag,
|
||||
fps=args.save_fps)
|
||||
log_to_tensorboard_async(writer,
|
||||
full_video,
|
||||
sample_tag,
|
||||
fps=args.save_fps)
|
||||
sample_full_video_file = f"{video_save_dir}/../{sample['videoid']}_full_fs{fs}.mp4"
|
||||
save_results(full_video, sample_full_video_file, fps=args.save_fps)
|
||||
save_results_async(full_video, sample_full_video_file, fps=args.save_fps)
|
||||
|
||||
# Wait for all async I/O to complete
|
||||
_flush_io()
|
||||
|
||||
|
||||
def get_parser():
|
||||
|
||||
@@ -1,24 +1,13 @@
|
||||
2026-02-10 17:39:22.590654: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
|
||||
2026-02-10 17:39:22.640645: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
|
||||
2026-02-10 17:39:22.640689: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
|
||||
2026-02-10 17:39:22.642010: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
|
||||
2026-02-10 17:39:22.649530: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
|
||||
2026-02-10 19:43:34.679819: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
|
||||
2026-02-10 19:43:34.729245: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
|
||||
2026-02-10 19:43:34.729298: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
|
||||
2026-02-10 19:43:34.730600: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
|
||||
2026-02-10 19:43:34.738078: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
|
||||
To enable the following instructions: AVX2 AVX512F AVX512_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
|
||||
2026-02-10 17:39:23.575804: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
|
||||
2026-02-10 19:43:35.659490: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
|
||||
Global seed set to 123
|
||||
INFO:mainlogger:LatentVisualDiffusion: Running in v-prediction mode
|
||||
INFO:unifolm_wma.models.diffusion_head.conditional_unet1d:number of parameters: 5.010531e+08
|
||||
INFO:unifolm_wma.models.diffusion_head.conditional_unet1d:number of parameters: 5.010531e+08
|
||||
AE working on z of shape (1, 4, 32, 32) = 4096 dimensions.
|
||||
INFO:root:Loaded ViT-H-14 model config.
|
||||
DEBUG:urllib3.connectionpool:Starting new HTTPS connection (1): hf-mirror.com:443
|
||||
DEBUG:urllib3.connectionpool:https://hf-mirror.com:443 "HEAD /laion/CLIP-ViT-H-14-laion2B-s32B-b79K/resolve/main/open_clip_pytorch_model.bin HTTP/1.1" 302 0
|
||||
INFO:root:Loading pretrained ViT-H-14 weights (laion2b_s32b_b79k).
|
||||
INFO:root:Loaded ViT-H-14 model config.
|
||||
DEBUG:urllib3.connectionpool:https://hf-mirror.com:443 "HEAD /laion/CLIP-ViT-H-14-laion2B-s32B-b79K/resolve/main/open_clip_pytorch_model.bin HTTP/1.1" 302 0
|
||||
INFO:root:Loading pretrained ViT-H-14 weights (laion2b_s32b_b79k).
|
||||
>>> model checkpoint loaded.
|
||||
>>> Load pre-trained model ...
|
||||
>>> Loading prepared model from ckpts/unifolm_wma_dual.ckpt.prepared.pt ...
|
||||
>>> Prepared model loaded.
|
||||
INFO:root:***** Configing Data *****
|
||||
>>> unitree_z1_stackbox: 1 data samples loaded.
|
||||
>>> unitree_z1_stackbox: data stats loaded.
|
||||
@@ -41,8 +30,10 @@ DEBUG:h5py._conv:Creating converter from 3 to 5
|
||||
DEBUG:PIL.PngImagePlugin:STREAM b'IHDR' 16 13
|
||||
DEBUG:PIL.PngImagePlugin:STREAM b'pHYs' 41 9
|
||||
DEBUG:PIL.PngImagePlugin:STREAM b'IDAT' 62 4096
|
||||
|
||||
|
||||
0%| | 0/11 [00:00<?, ?it/s]
|
||||
9%|▉ | 1/11 [00:33<05:38, 33.86s/it]>>> Step 0: generating actions ...
|
||||
>>> Step 0: interacting with world model ...
|
||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||
>>> Step 1: generating actions ...
|
||||
DEBUG:PIL.Image:Importing BlpImagePlugin
|
||||
@@ -92,9 +83,7 @@ DEBUG:PIL.Image:Importing WmfImagePlugin
|
||||
DEBUG:PIL.Image:Importing WebPImagePlugin
|
||||
DEBUG:PIL.Image:Importing WmfImagePlugin
|
||||
DEBUG:PIL.Image:Importing XbmImagePlugin
|
||||
DEBUG:PIL.Image:Importing XVThumbImagePlugin
|
||||
|
||||
9%|▉ | 1/11 [00:35<05:55, 35.52s/it]
|
||||
DEBUG:PIL.Image:Importing XpmImagePlugin
|
||||
DEBUG:PIL.Image:Importing XVThumbImagePlugin
|
||||
|
||||
18%|█▊ | 2/11 [01:08<05:06, 34.03s/it]
|
||||
@@ -125,6 +114,6 @@ DEBUG:PIL.Image:Importing XVThumbImagePlugin
|
||||
>>> Step 6: generating actions ...
|
||||
>>> Step 6: interacting with world model ...
|
||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||
>>> Step 7: generating actions ...
|
||||
>>> Step 7: interacting with world model ...
|
||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||
>>> Step 7: generating actions ...
|
||||
>>> Step 7: interacting with world model ...
|
||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||
|
||||
Reference in New Issue
Block a user