From 6630952d2b986c76874c4a36fef9e636cc9c08eb Mon Sep 17 00:00:00 2001
From: qhy <2728290997@qq.com>
Date: Mon, 9 Feb 2026 21:23:00 +0800
Subject: [PATCH] =?UTF-8?q?=E5=BC=82=E6=AD=A5=E4=BF=9D=E5=AD=98=E7=BB=93?=
 =?UTF-8?q?=E6=9E=9C?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 scripts/evaluation/world_model_interaction.py | 122 +++++++++++++++---
 1 file changed, 101 insertions(+), 21 deletions(-)

Review notes (this section sits after the three-dash line, so it is
commentary only and is not part of the commit):

What the change does
  * Adds a module-level ThreadPoolExecutor (max_workers=2), created on
    first use by _get_io_executor(), and a module-level list _io_futures
    that collects every submitted future.
  * _flush_io() waits on each collected future, prints any exception
    instead of re-raising it, then clears the list. It is registered
    with atexit and, in a later hunk, called right before
    profiler.save_results().
  * save_results_async() and log_to_tensorboard_async() call
    .detach().cpu() on the tensor in the calling thread and then submit
    _save_results_sync() / _log_to_tb_sync() to the pool.
  * log_to_tensorboard_async() submits nothing unless `data` is a 5-D
    torch.Tensor.
  * The import of log_to_tensorboard from eval_utils is dropped; later
    hunks switch the run_inference call sites to the async variants.

Points to confirm before merging
  * Worker failures are only printed by _flush_io(), so a failed video
    write no longer raises at the call site. Confirm this best-effort
    behaviour is intended.
  * Per the PyTorch docs, Tensor.cpu() returns the same object when the
    tensor is already on the CPU. If any caller passes a CPU tensor and
    mutates it after submitting, the worker would see the mutation.
    Verify against the callers (full_video was previously passed to
    save_results without .cpu(), so it may already be a CPU tensor).
  * writer.add_video() now runs on pool threads. Confirm the
    SummaryWriter in use tolerates calls from two workers at once.
  * _save_results_sync() and _log_to_tb_sync() repeat the same
    permute / make_grid / stack / rescale steps; a shared helper would
    keep them in sync.
  * NOTE(review): this copy of the patch had lost its line breaks. The
    line structure below was restored from the +/- markers and hunk
    counts; indentation inside added lines was inferred from Python
    syntax. Verify against the original commit before applying.

diff --git a/scripts/evaluation/world_model_interaction.py b/scripts/evaluation/world_model_interaction.py
index 9784f62..00bd36b 100644
--- a/scripts/evaluation/world_model_interaction.py
+++ b/scripts/evaluation/world_model_interaction.py
@@ -11,6 +11,8 @@ import warnings
 import imageio
 import time
 import json
+import atexit
+from concurrent.futures import ThreadPoolExecutor
 from contextlib import contextmanager, nullcontext
 from dataclasses import dataclass, field, asdict
 from typing import Optional, Dict, List, Any, Mapping
@@ -21,7 +23,7 @@ from tqdm import tqdm
 from einops import rearrange, repeat
 from collections import OrderedDict
 from torch import nn
-from eval_utils import populate_queues, log_to_tensorboard
+from eval_utils import populate_queues
 from collections import deque
 from torch import Tensor
 from torch.utils.tensorboard import SummaryWriter
@@ -393,6 +395,81 @@ def init_profiler(enabled: bool, output_dir: str, profile_detail: str) -> Profil
     return _profiler
 
 
+# ========== Async I/O ==========
+_io_executor: Optional[ThreadPoolExecutor] = None
+_io_futures: List[Any] = []
+
+
+def _get_io_executor() -> ThreadPoolExecutor:
+    global _io_executor
+    if _io_executor is None:
+        _io_executor = ThreadPoolExecutor(max_workers=2)
+    return _io_executor
+
+
+def _flush_io():
+    """Wait for all pending async I/O to finish."""
+    global _io_futures
+    for fut in _io_futures:
+        try:
+            fut.result()
+        except Exception as e:
+            print(f">>> [async I/O] error: {e}")
+    _io_futures.clear()
+
+
+atexit.register(_flush_io)
+
+
+def _save_results_sync(video_cpu: Tensor, filename: str, fps: int) -> None:
+    """Synchronous save on CPU tensor (runs in background thread)."""
+    video = torch.clamp(video_cpu.float(), -1., 1.)
+    n = video.shape[0]
+    video = video.permute(2, 0, 1, 3, 4)
+    frame_grids = [
+        torchvision.utils.make_grid(framesheet, nrow=int(n), padding=0)
+        for framesheet in video
+    ]
+    grid = torch.stack(frame_grids, dim=0)
+    grid = (grid + 1.0) / 2.0
+    grid = (grid * 255).to(torch.uint8).permute(0, 2, 3, 1)
+    torchvision.io.write_video(filename,
+                               grid,
+                               fps=fps,
+                               video_codec='h264',
+                               options={'crf': '10'})
+
+
+def save_results_async(video: Tensor, filename: str, fps: int = 8) -> None:
+    """Submit video saving to background thread pool."""
+    video_cpu = video.detach().cpu()
+    fut = _get_io_executor().submit(_save_results_sync, video_cpu, filename, fps)
+    _io_futures.append(fut)
+
+
+def _log_to_tb_sync(writer, video_cpu: Tensor, tag: str, fps: int) -> None:
+    """Synchronous TensorBoard log on CPU tensor (runs in background thread)."""
+    if video_cpu.dim() == 5:
+        n = video_cpu.shape[0]
+        video = video_cpu.permute(2, 0, 1, 3, 4)
+        frame_grids = [
+            torchvision.utils.make_grid(framesheet, nrow=int(n), padding=0)
+            for framesheet in video
+        ]
+        grid = torch.stack(frame_grids, dim=0)
+        grid = (grid + 1.0) / 2.0
+        grid = grid.unsqueeze(dim=0)
+        writer.add_video(tag, grid, fps=fps)
+
+
+def log_to_tensorboard_async(writer, data: Tensor, tag: str, fps: int = 10) -> None:
+    """Submit TensorBoard logging to background thread pool."""
+    if isinstance(data, torch.Tensor) and data.dim() == 5:
+        data_cpu = data.detach().cpu()
+        fut = _get_io_executor().submit(_log_to_tb_sync, writer, data_cpu, tag, fps)
+        _io_futures.append(fut)
+
+
 # ========== Original Functions ==========
 def get_device_from_parameters(module: nn.Module) -> torch.device:
     """Get a module's device by checking one of its parameters.
@@ -1392,30 +1469,30 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None: cond_obs_queues = populate_queues(cond_obs_queues, obs_update) - # Save the imagen videos for decision-making + # Save the imagen videos for decision-making (async) with profiler.profile_section("save_results"): sample_tag = f"{args.dataset}-vid{sample['videoid']}-dm-fs-{fs}/itr-{itr}" - log_to_tensorboard(writer, - pred_videos_0, - sample_tag, - fps=args.save_fps) + log_to_tensorboard_async(writer, + pred_videos_0, + sample_tag, + fps=args.save_fps) # Save videos environment changes via world-model interaction sample_tag = f"{args.dataset}-vid{sample['videoid']}-wd-fs-{fs}/itr-{itr}" - log_to_tensorboard(writer, - pred_videos_1, - sample_tag, - fps=args.save_fps) + log_to_tensorboard_async(writer, + pred_videos_1, + sample_tag, + fps=args.save_fps) # Save the imagen videos for decision-making sample_video_file = f'{video_save_dir}/dm/{fs}/itr-{itr}.mp4' - save_results(pred_videos_0.cpu(), - sample_video_file, - fps=args.save_fps) + save_results_async(pred_videos_0, + sample_video_file, + fps=args.save_fps) # Save videos environment changes via world-model interaction sample_video_file = f'{video_save_dir}/wm/{fs}/itr-{itr}.mp4' - save_results(pred_videos_1.cpu(), - sample_video_file, - fps=args.save_fps) + save_results_async(pred_videos_1, + sample_video_file, + fps=args.save_fps) print('>' * 24) # Collect the result of world-model interactions @@ -1426,12 +1503,15 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None: full_video = torch.cat(wm_video, dim=2) sample_tag = f"{args.dataset}-vid{sample['videoid']}-wd-fs-{fs}/full" - log_to_tensorboard(writer, - full_video, - sample_tag, - fps=args.save_fps) + log_to_tensorboard_async(writer, + full_video, + sample_tag, + fps=args.save_fps) sample_full_video_file = f"{video_save_dir}/../{sample['videoid']}_full_fs{fs}.mp4" - save_results(full_video, sample_full_video_file, 
fps=args.save_fps) + save_results_async(full_video, sample_full_video_file, fps=args.save_fps) + + # Wait for all async I/O to complete before profiling report + _flush_io() # Save profiling results profiler.save_results()