异步保存结果
This commit is contained in:
@@ -11,6 +11,8 @@ import warnings
|
|||||||
import imageio
|
import imageio
|
||||||
import time
|
import time
|
||||||
import json
|
import json
|
||||||
|
import atexit
|
||||||
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
from contextlib import contextmanager, nullcontext
|
from contextlib import contextmanager, nullcontext
|
||||||
from dataclasses import dataclass, field, asdict
|
from dataclasses import dataclass, field, asdict
|
||||||
from typing import Optional, Dict, List, Any, Mapping
|
from typing import Optional, Dict, List, Any, Mapping
|
||||||
@@ -21,7 +23,7 @@ from tqdm import tqdm
|
|||||||
from einops import rearrange, repeat
|
from einops import rearrange, repeat
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from eval_utils import populate_queues, log_to_tensorboard
|
from eval_utils import populate_queues
|
||||||
from collections import deque
|
from collections import deque
|
||||||
from torch import Tensor
|
from torch import Tensor
|
||||||
from torch.utils.tensorboard import SummaryWriter
|
from torch.utils.tensorboard import SummaryWriter
|
||||||
@@ -393,6 +395,81 @@ def init_profiler(enabled: bool, output_dir: str, profile_detail: str) -> Profil
|
|||||||
return _profiler
|
return _profiler
|
||||||
|
|
||||||
|
|
||||||
|
# ========== Async I/O ==========
|
||||||
|
_io_executor: Optional[ThreadPoolExecutor] = None
|
||||||
|
_io_futures: List[Any] = []
|
||||||
|
|
||||||
|
|
||||||
|
def _get_io_executor() -> ThreadPoolExecutor:
|
||||||
|
global _io_executor
|
||||||
|
if _io_executor is None:
|
||||||
|
_io_executor = ThreadPoolExecutor(max_workers=2)
|
||||||
|
return _io_executor
|
||||||
|
|
||||||
|
|
||||||
|
def _flush_io():
|
||||||
|
"""Wait for all pending async I/O to finish."""
|
||||||
|
global _io_futures
|
||||||
|
for fut in _io_futures:
|
||||||
|
try:
|
||||||
|
fut.result()
|
||||||
|
except Exception as e:
|
||||||
|
print(f">>> [async I/O] error: {e}")
|
||||||
|
_io_futures.clear()
|
||||||
|
|
||||||
|
|
||||||
|
atexit.register(_flush_io)
|
||||||
|
|
||||||
|
|
||||||
|
def _save_results_sync(video_cpu: Tensor, filename: str, fps: int) -> None:
|
||||||
|
"""Synchronous save on CPU tensor (runs in background thread)."""
|
||||||
|
video = torch.clamp(video_cpu.float(), -1., 1.)
|
||||||
|
n = video.shape[0]
|
||||||
|
video = video.permute(2, 0, 1, 3, 4)
|
||||||
|
frame_grids = [
|
||||||
|
torchvision.utils.make_grid(framesheet, nrow=int(n), padding=0)
|
||||||
|
for framesheet in video
|
||||||
|
]
|
||||||
|
grid = torch.stack(frame_grids, dim=0)
|
||||||
|
grid = (grid + 1.0) / 2.0
|
||||||
|
grid = (grid * 255).to(torch.uint8).permute(0, 2, 3, 1)
|
||||||
|
torchvision.io.write_video(filename,
|
||||||
|
grid,
|
||||||
|
fps=fps,
|
||||||
|
video_codec='h264',
|
||||||
|
options={'crf': '10'})
|
||||||
|
|
||||||
|
|
||||||
|
def save_results_async(video: Tensor, filename: str, fps: int = 8) -> None:
|
||||||
|
"""Submit video saving to background thread pool."""
|
||||||
|
video_cpu = video.detach().cpu()
|
||||||
|
fut = _get_io_executor().submit(_save_results_sync, video_cpu, filename, fps)
|
||||||
|
_io_futures.append(fut)
|
||||||
|
|
||||||
|
|
||||||
|
def _log_to_tb_sync(writer, video_cpu: Tensor, tag: str, fps: int) -> None:
|
||||||
|
"""Synchronous TensorBoard log on CPU tensor (runs in background thread)."""
|
||||||
|
if video_cpu.dim() == 5:
|
||||||
|
n = video_cpu.shape[0]
|
||||||
|
video = video_cpu.permute(2, 0, 1, 3, 4)
|
||||||
|
frame_grids = [
|
||||||
|
torchvision.utils.make_grid(framesheet, nrow=int(n), padding=0)
|
||||||
|
for framesheet in video
|
||||||
|
]
|
||||||
|
grid = torch.stack(frame_grids, dim=0)
|
||||||
|
grid = (grid + 1.0) / 2.0
|
||||||
|
grid = grid.unsqueeze(dim=0)
|
||||||
|
writer.add_video(tag, grid, fps=fps)
|
||||||
|
|
||||||
|
|
||||||
|
def log_to_tensorboard_async(writer, data: Tensor, tag: str, fps: int = 10) -> None:
|
||||||
|
"""Submit TensorBoard logging to background thread pool."""
|
||||||
|
if isinstance(data, torch.Tensor) and data.dim() == 5:
|
||||||
|
data_cpu = data.detach().cpu()
|
||||||
|
fut = _get_io_executor().submit(_log_to_tb_sync, writer, data_cpu, tag, fps)
|
||||||
|
_io_futures.append(fut)
|
||||||
|
|
||||||
|
|
||||||
# ========== Original Functions ==========
|
# ========== Original Functions ==========
|
||||||
def get_device_from_parameters(module: nn.Module) -> torch.device:
|
def get_device_from_parameters(module: nn.Module) -> torch.device:
|
||||||
"""Get a module's device by checking one of its parameters.
|
"""Get a module's device by checking one of its parameters.
|
||||||
@@ -1392,30 +1469,30 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
|
|||||||
cond_obs_queues = populate_queues(cond_obs_queues,
|
cond_obs_queues = populate_queues(cond_obs_queues,
|
||||||
obs_update)
|
obs_update)
|
||||||
|
|
||||||
# Save the imagen videos for decision-making
|
# Save the imagen videos for decision-making (async)
|
||||||
with profiler.profile_section("save_results"):
|
with profiler.profile_section("save_results"):
|
||||||
sample_tag = f"{args.dataset}-vid{sample['videoid']}-dm-fs-{fs}/itr-{itr}"
|
sample_tag = f"{args.dataset}-vid{sample['videoid']}-dm-fs-{fs}/itr-{itr}"
|
||||||
log_to_tensorboard(writer,
|
log_to_tensorboard_async(writer,
|
||||||
pred_videos_0,
|
pred_videos_0,
|
||||||
sample_tag,
|
sample_tag,
|
||||||
fps=args.save_fps)
|
fps=args.save_fps)
|
||||||
# Save videos environment changes via world-model interaction
|
# Save videos environment changes via world-model interaction
|
||||||
sample_tag = f"{args.dataset}-vid{sample['videoid']}-wd-fs-{fs}/itr-{itr}"
|
sample_tag = f"{args.dataset}-vid{sample['videoid']}-wd-fs-{fs}/itr-{itr}"
|
||||||
log_to_tensorboard(writer,
|
log_to_tensorboard_async(writer,
|
||||||
pred_videos_1,
|
pred_videos_1,
|
||||||
sample_tag,
|
sample_tag,
|
||||||
fps=args.save_fps)
|
fps=args.save_fps)
|
||||||
|
|
||||||
# Save the imagen videos for decision-making
|
# Save the imagen videos for decision-making
|
||||||
sample_video_file = f'{video_save_dir}/dm/{fs}/itr-{itr}.mp4'
|
sample_video_file = f'{video_save_dir}/dm/{fs}/itr-{itr}.mp4'
|
||||||
save_results(pred_videos_0.cpu(),
|
save_results_async(pred_videos_0,
|
||||||
sample_video_file,
|
sample_video_file,
|
||||||
fps=args.save_fps)
|
fps=args.save_fps)
|
||||||
# Save videos environment changes via world-model interaction
|
# Save videos environment changes via world-model interaction
|
||||||
sample_video_file = f'{video_save_dir}/wm/{fs}/itr-{itr}.mp4'
|
sample_video_file = f'{video_save_dir}/wm/{fs}/itr-{itr}.mp4'
|
||||||
save_results(pred_videos_1.cpu(),
|
save_results_async(pred_videos_1,
|
||||||
sample_video_file,
|
sample_video_file,
|
||||||
fps=args.save_fps)
|
fps=args.save_fps)
|
||||||
|
|
||||||
print('>' * 24)
|
print('>' * 24)
|
||||||
# Collect the result of world-model interactions
|
# Collect the result of world-model interactions
|
||||||
@@ -1426,12 +1503,15 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
|
|||||||
|
|
||||||
full_video = torch.cat(wm_video, dim=2)
|
full_video = torch.cat(wm_video, dim=2)
|
||||||
sample_tag = f"{args.dataset}-vid{sample['videoid']}-wd-fs-{fs}/full"
|
sample_tag = f"{args.dataset}-vid{sample['videoid']}-wd-fs-{fs}/full"
|
||||||
log_to_tensorboard(writer,
|
log_to_tensorboard_async(writer,
|
||||||
full_video,
|
full_video,
|
||||||
sample_tag,
|
sample_tag,
|
||||||
fps=args.save_fps)
|
fps=args.save_fps)
|
||||||
sample_full_video_file = f"{video_save_dir}/../{sample['videoid']}_full_fs{fs}.mp4"
|
sample_full_video_file = f"{video_save_dir}/../{sample['videoid']}_full_fs{fs}.mp4"
|
||||||
save_results(full_video, sample_full_video_file, fps=args.save_fps)
|
save_results_async(full_video, sample_full_video_file, fps=args.save_fps)
|
||||||
|
|
||||||
|
# Wait for all async I/O to complete before profiling report
|
||||||
|
_flush_io()
|
||||||
|
|
||||||
# Save profiling results
|
# Save profiling results
|
||||||
profiler.save_results()
|
profiler.save_results()
|
||||||
|
|||||||
Reference in New Issue
Block a user