1050 lines
37 KiB
Python
1050 lines
37 KiB
Python
"""World environment manager for vectorized Gymnasium environments."""
|
|
|
|
import hashlib
|
|
import json
|
|
import os
|
|
from concurrent.futures import ThreadPoolExecutor
|
|
from collections import defaultdict
|
|
from collections.abc import Callable, Sequence
|
|
from copy import deepcopy
|
|
from functools import partial
|
|
from pathlib import Path
|
|
from typing import Any
|
|
|
|
import gymnasium as gym
|
|
import h5py
|
|
import hdf5plugin
|
|
import numpy as np
|
|
import torch
|
|
from gymnasium.vector import VectorEnv
|
|
from loguru import logger as logging
|
|
from rich import print
|
|
from tqdm import tqdm
|
|
|
|
from stable_worldmodel.data.utils import get_cache_dir
|
|
from stable_worldmodel.policy import Policy
|
|
|
|
from .wrapper import MegaWrapper, SyncWorld, VariationWrapper
|
|
|
|
|
|
def _make_env(env_name, max_episode_steps, wrappers, **kwargs):
    """Build a single gymnasium environment and apply wrappers to it.

    Used as the per-environment factory of the vectorized world. The base
    environment is created with ``gym.make`` and every wrapper is applied in
    list order, so the first wrapper in ``wrappers`` ends up innermost.

    Args:
        env_name: Registered gymnasium environment id.
        max_episode_steps: Step limit forwarded to ``gym.make``.
        wrappers: Sequence of callables; each takes an environment and
            returns a wrapped environment.
        **kwargs: Extra keyword arguments forwarded to ``gym.make``.
            ``disable_env_checker`` defaults to True when not supplied.

    Returns:
        The environment after all wrappers have been applied.

    Example:
        >>> from functools import partial
        >>> wrappers = [partial(MegaWrapper, image_shape=(64, 64))]
        >>> env = _make_env("CartPole-v1", max_episode_steps=500, wrappers=wrappers)
    """
    make_kwargs = dict(kwargs)
    # the env checker is skipped unless the caller explicitly asks for it
    make_kwargs.setdefault('disable_env_checker', True)

    wrapped = gym.make(
        env_name, max_episode_steps=max_episode_steps, **make_kwargs
    )
    for apply_wrapper in wrappers:
        wrapped = apply_wrapper(wrapped)
    return wrapped
|
|
|
|
|
|
def _write_eval_video(
    video_path: Path,
    env_idx: int,
    video_frames: np.ndarray,
    target_frames: np.ndarray,
    eval_budget: int,
    target_len: int,
    fps: int = 15,
):
    """Write one side-by-side evaluation video ``rollout_{env_idx}.mp4``.

    Every output frame is composed as follows:

    * left column: the rollout frame at step ``t`` stacked (vertically) over
      the dataset target frame ``t % target_len``;
    * right column: the last target frame (the goal) stacked twice, so it
      matches the height of the left column.

    Frames are assumed to be channels-last images (H, W, C) — the caller
    permutes dataset pixels to that layout before calling.

    Args:
        video_path: Existing directory the video is written into.
        env_idx: Environment index, used in the output file name.
        video_frames: Rollout frames, indexable by step for
            ``range(eval_budget)``.
        target_frames: Dataset frames for the same episode; the last one is
            treated as the goal.
        eval_budget: Number of frames to write.
        target_len: Number of target frames; indexes wrap modulo this.
        fps: Frame rate of the output video (default 15, the previous
            hard-coded value).
    """
    import imageio

    writer = imageio.get_writer(
        video_path / f'rollout_{env_idx}.mp4',
        fps=fps,
        codec='libx264',
    )
    # close the writer even if encoding a frame fails, so the file handle and
    # the ffmpeg subprocess are released
    try:
        goal_frame = target_frames[-1]
        goal_column = np.vstack([goal_frame, goal_frame])
        for t in range(eval_budget):
            left_column = np.vstack(
                [video_frames[t], target_frames[t % target_len]]
            )
            writer.append_data(np.hstack([left_column, goal_column]))
    finally:
        writer.close()
|
|
|
|
|
|
class World:
    """High-level manager for vectorized Gymnasium environments.

    Manages a set of synchronized vectorized environments with automatic
    preprocessing (resizing, frame stacking, goal conditioning), and provides
    helpers to roll out a policy, record videos, record HDF5 datasets and
    evaluate a policy (from fresh resets or from dataset states).

    Args:
        env_name: Name of the Gymnasium environment to create.
        num_envs: Number of parallel environments.
        image_shape: Target shape for image observations (H, W).
        goal_transform: Optional callable to transform goal observations.
        image_transform: Optional callable to transform image observations.
        seed: Random seed for reproducibility.
        history_size: Number of frames to stack.
        frame_skip: Number of frames to skip per step.
        max_episode_steps: Maximum steps per episode before truncation.
        verbose: Verbosity level (0: silent, >0: info).
        extra_wrappers: List of additional wrappers to apply to each env.
        goal_conditioned: Whether to separate goal from observation.
        **kwargs: Additional keyword arguments passed to `gym.make`
            (via `_make_env`).
    """

    def __init__(
        self,
        env_name: str,
        num_envs: int,
        image_shape: tuple[int, int],
        goal_transform: Callable[[Any], Any] | None = None,
        image_transform: Callable[[Any], Any] | None = None,
        seed: int = 2349867,
        history_size: int = 1,
        frame_skip: int = 1,
        max_episode_steps: int = 100,
        verbose: int = 1,
        extra_wrappers: list[Callable] | None = None,
        goal_conditioned: bool = True,
        **kwargs: Any,
    ) -> None:
        # MegaWrapper is applied first (innermost); user wrappers go on top.
        wrappers = [
            partial(
                MegaWrapper,
                image_shape=image_shape,
                pixels_transform=image_transform,
                goal_transform=goal_transform,
                history_size=history_size,
                frame_skip=frame_skip,
                separate_goal=goal_conditioned,
            ),
            *(extra_wrappers or []),
        ]

        env_fn = partial(
            _make_env, env_name, max_episode_steps, wrappers, **kwargs
        )
        env_fns = [env_fn for _ in range(num_envs)]
        self.envs: VectorEnv = VariationWrapper(SyncWorld(env_fns))
        # Resets are issued explicitly by this class (with per-env seeds and
        # options), so the vector env must not auto-reset on its own.
        self.envs.unwrapped.autoreset_mode = gym.vector.AutoresetMode.DISABLED

        self._history_size = history_size
        self.policy: Policy | None = None
        # latest results of reset()/step(), batched over environments
        self.states: dict | None = None
        self.infos: dict = {}
        self.rewards: np.ndarray | None = None
        self.terminateds: np.ndarray | None = None
        self.truncateds: np.ndarray | None = None

        if verbose > 0:
            logging.info(f'🌍🌍🌍 World {env_name} initialized 🌍🌍🌍')

            logging.info('🕹️ 🕹️ 🕹️ Action space 🕹️ 🕹️ 🕹️')
            logging.info(f'{self.envs.action_space}')

            logging.info('👁️ 👁️ 👁️ Observation space 👁️ 👁️ 👁️')
            logging.info(f'{str(self.envs.observation_space)}')

            if self.envs.variation_space is not None:
                logging.info('⚗️ ⚗️ ⚗️ Variation space ⚗️ ⚗️ ⚗️')
                print(self.single_variation_space.to_str())
            else:
                logging.warning('No variation space provided!')

        self.seed = seed

    @property
    def num_envs(self) -> int:
        """Number of parallel environment instances."""
        return self.envs.num_envs

    @property
    def observation_space(self) -> gym.Space:
        """Batched observation space for all environments."""
        return self.envs.observation_space

    @property
    def action_space(self) -> gym.Space:
        """Batched action space for all environments."""
        return self.envs.action_space

    @property
    def variation_space(self) -> gym.Space | None:
        """Batched variation space for domain randomization."""
        return self.envs.variation_space

    @property
    def single_variation_space(self) -> gym.Space | None:
        """Variation space for a single environment instance."""
        return self.envs.single_variation_space

    @property
    def single_action_space(self) -> gym.Space:
        """Action space for a single environment instance."""
        return self.envs.single_action_space

    @property
    def single_observation_space(self) -> gym.Space:
        """Observation space for a single environment instance."""
        return self.envs.single_observation_space

    def close(self, **kwargs: Any) -> None:
        """Close all environments and clean up resources."""
        return self.envs.close(**kwargs)

    def step(self) -> None:
        """Advance all environments by one step using the current policy.

        Raises:
            RuntimeError: If no policy has been attached via `set_policy`.
        """
        # note: reset happens before because of auto-reset, should fix that
        if self.policy is None:
            raise RuntimeError('No policy set. Call set_policy() first.')

        actions = self.policy.get_action(self.infos)
        (
            self.states,
            self.rewards,
            self.terminateds,
            self.truncateds,
            self.infos,
        ) = self.envs.step(actions)

    def reset(
        self,
        seed: int | list[int] | None = None,
        options: dict | None = None,
    ) -> None:
        """Reset all environments to initial states.

        Args:
            seed: Random seed(s) for the environments.
            options: Additional options passed to the environment reset.
        """
        self.states, self.infos = self.envs.reset(seed=seed, options=options)

    def set_policy(self, policy: Policy) -> None:
        """Attach a policy to the world.

        The policy is bound to the vector env, and re-seeded with its own
        `seed` attribute when it has a non-None one.

        Args:
            policy: The policy instance to use for determining actions.
        """
        self.policy = policy
        self.policy.set_env(self.envs)

        if getattr(self.policy, 'seed', None) is not None:
            self.policy.set_seed(self.policy.seed)

    @staticmethod
    def _latest_frame(frame: np.ndarray) -> np.ndarray:
        """Drop a leading history axis (ndim > 3) by keeping its last entry."""
        return frame[-1] if frame.ndim > 3 else frame

    def _append_video_frames(
        self, writers: list[Any], viewnames: list[str]
    ) -> None:
        """Append the current frame of every env to its video writer.

        The images under each key of `viewnames` are stacked vertically,
        followed by the `goal` image when `infos` contains one.
        """
        for i, writer in enumerate(writers):
            panels = [
                self._latest_frame(self.infos[name][i]) for name in viewnames
            ]
            if 'goal' in self.infos:
                panels.append(self._latest_frame(self.infos['goal'][i]))
            writer.append_data(np.vstack(panels))

    def record_video(
        self,
        video_path: str | Path,
        max_steps: int = 500,
        fps: int = 30,
        viewname: str | list[str] = 'pixels',
        seed: int | None = None,
        extension: str = 'mp4',
        options: dict | None = None,
    ) -> None:
        """Record rollout videos for each environment under the current policy.

        One file `env_{i}.{extension}` is written per environment. Recording
        stops after `max_steps` steps or as soon as any environment
        terminates or truncates; the frame of that final step is included.

        Args:
            video_path: Directory path to save the videos (created if
                missing).
            max_steps: Maximum steps to record per environment.
            fps: Frames per second for the output video.
            viewname: Key(s) in `infos` containing image data to render.
            seed: Random seed for reset.
            extension: Video file format ('mp4' or 'gif').
            options: Options for reset.
        """
        assert extension in ['mp4', 'gif'], (
            'Unsupported video format. Use "mp4" or "gif".'
        )

        import imageio

        viewnames = [viewname] if isinstance(viewname, str) else list(viewname)
        video_dir = Path(video_path)
        video_dir.mkdir(parents=True, exist_ok=True)

        # `codec` is an ffmpeg (mp4) option; GIF writers do not accept it.
        writer_kwargs: dict[str, Any] = {'fps': fps}
        if extension == 'mp4':
            writer_kwargs['codec'] = 'libx264'

        writers = [
            imageio.get_writer(
                video_dir / f'env_{i}.{extension}', **writer_kwargs
            )
            for i in range(self.num_envs)
        ]

        try:
            self.reset(seed, options)
            self._append_video_frames(writers, viewnames)

            for _ in range(max_steps):
                self.step()
                # record before stopping so the terminal frame is kept
                self._append_video_frames(writers, viewnames)
                if np.any(self.terminateds) or np.any(self.truncateds):
                    break
        finally:
            for writer in writers:
                writer.close()

        print(f'Video saved to {video_path}')

    def record_dataset(
        self,
        dataset_name: str,
        episodes: int = 10,
        seed: int | None = None,
        cache_dir: os.PathLike | str | None = None,
        options: dict | None = None,
    ) -> None:
        """Records episodes from the environment into an HDF5 dataset.

        Episodes are written contiguously as they finish. An existing file is
        resumed: recording continues until the file holds `episodes`
        episodes in total.

        Args:
            dataset_name: Name of the dataset file (without extension).
            episodes: Total number of episodes to record.
            seed: Initial random seed.
            cache_dir: Directory to save the dataset. Defaults to standard cache.
            options: Reset options passed to environments.

        Raises:
            NotImplementedError: If history_size > 1.
        """
        if self._history_size > 1:
            raise NotImplementedError(
                'Frame history > 1 not supported for dataset recording.'
            )

        path = Path(cache_dir or get_cache_dir()) / f'{dataset_name}.h5'
        path.parent.mkdir(parents=True, exist_ok=True)

        self.terminateds = np.zeros(self.num_envs, dtype=bool)
        self.truncateds = np.zeros(self.num_envs, dtype=bool)

        episode_buffers = [defaultdict(list) for _ in range(self.num_envs)]

        file_exists = path.exists()
        h5_kwargs = {
            'name': str(path),
            'mode': 'a' if file_exists else 'w',
            'libver': 'latest',  # required for SWMR
        }
        if not file_exists:  # creation only args
            h5_kwargs.update(
                {'fs_strategy': 'page', 'fs_page_size': 4 * 1024 * 1024}
            )

        with h5py.File(**h5_kwargs) as f:
            initialized = 'ep_len' in f
            if initialized:
                n_ep_recorded = f['ep_len'].shape[0]
                global_step_ptr = (
                    f['ep_offset'][-1] + f['ep_len'][-1]
                    if n_ep_recorded > 0
                    else 0
                )
                seed = None if seed is None else (seed + n_ep_recorded)
                logging.info(
                    f'Resuming: {n_ep_recorded} episodes already on disk.'
                )
                # SWMR writers may not create new objects, so SWMR is only
                # enabled once all datasets exist (avoid issue when killed)
                f.swmr_mode = True
            else:
                n_ep_recorded = 0
                global_step_ptr = 0

            self.reset(seed, options=options)
            seed = None if seed is None else (seed + self.num_envs)
            self._dump_step_data(episode_buffers)  # record initial state

            with tqdm(
                total=episodes, initial=n_ep_recorded, desc='Recording'
            ) as pbar:
                while n_ep_recorded < episodes:
                    self.step()
                    self._dump_step_data(episode_buffers)

                    for i in range(self.num_envs):
                        if not (self.terminateds[i] or self.truncateds[i]):
                            continue

                        finished_ep = self._handle_done_ep(
                            episode_buffers, i, n_ep_recorded
                        )

                        # lazy dataset initialization (shapes come from data)
                        if not initialized:
                            self._init_h5_datasets(f, finished_ep)
                            initialized = True
                            f.swmr_mode = True

                        # contiguous writing
                        steps_written = self._write_episode(
                            f, finished_ep, global_step_ptr
                        )
                        global_step_ptr += steps_written
                        n_ep_recorded += 1
                        pbar.update(1)

                        f.flush()  # flush metadata to avoid corruption

                        if n_ep_recorded >= episodes:
                            break

                        # reset terminated env and record initial state
                        n_seed = (
                            None if seed is None else (seed + n_ep_recorded)
                        )
                        self._reset_single_env(i, n_seed, options)
                        self._dump_step_data(episode_buffers, env_idx=i)

        logging.info(f'Recording complete. Total frames: {global_step_ptr}')

    def _init_h5_datasets(
        self, f: h5py.File, sample_episode: dict[str, list[Any]]
    ) -> None:
        """Initialize resizable HDF5 datasets based on the first episode.

        One dataset is created per recorded column ('/' in names replaced by
        '_'), plus `ep_offset`/`ep_len` (one row per episode) and `ep_idx`
        (one row per step). Columns with >= 2-D samples are Blosc/LZ4
        compressed in chunks of 100 steps; other columns are uncompressed in
        chunks of 1000 steps. String columns use h5py's variable-length
        string dtype.

        Args:
            f: The open HDF5 file handle.
            sample_episode: A dictionary containing data from a single episode,
                used to determine shapes and dtypes.
        """
        for key, data_list in sample_episode.items():
            if key in ['ep_len', 'ep_idx', 'policy']:
                continue

            key = key.replace('/', '_')  # sanitize keys for h5

            # determine array shape and dtype from sample data
            sample_data = np.array(data_list[0])
            shape = (0,) + sample_data.shape
            maxshape = (None,) + sample_data.shape

            # determine chunk size and compression
            if sample_data.ndim >= 2:
                chunks = (100,) + sample_data.shape
                compression = hdf5plugin.Blosc(
                    cname='lz4', clevel=5, shuffle=hdf5plugin.Blosc.SHUFFLE
                )
            else:
                chunks = (1000,) + sample_data.shape
                compression = None

            dtype = sample_data.dtype
            if np.issubdtype(dtype, np.str_) or np.issubdtype(
                dtype, np.bytes_
            ):
                dtype = h5py.string_dtype()

            f.create_dataset(
                key,
                shape=shape,
                maxshape=maxshape,
                dtype=dtype,
                chunks=chunks,
                compression=compression,
            )

        # index metadata
        f.create_dataset(
            'ep_offset', shape=(0,), maxshape=(None,), dtype=np.int64
        )
        f.create_dataset(
            'ep_len', shape=(0,), maxshape=(None,), dtype=np.int32
        )

        # per-step episode index
        f.create_dataset(
            'ep_idx',
            shape=(0,),
            maxshape=(None,),
            dtype=np.int32,
            chunks=(1000,),
        )

    def _reset_single_env(
        self,
        env_idx: int,
        seed: int | None = None,
        options: dict | None = None,
    ) -> None:
        """Reset a single environment and update infos dict.

        Only keys already present in `self.infos` are overwritten, at row
        `env_idx`.

        Args:
            env_idx: Index of the environment to reset.
            seed: Random seed for this specific environment.
            options: Reset options.
        """
        # clear the vector env's pending auto-reset flags before resetting
        self.envs.unwrapped._autoreset_envs = np.zeros(
            (self.num_envs,), dtype=bool
        )
        _, infos = self.envs.envs[env_idx].reset(seed=seed, options=options)

        for k, v in infos.items():
            if k in self.infos:
                self.infos[k][env_idx] = v

    def _handle_done_ep(
        self,
        tmp_buffer: list[dict[str, list[Any]]],
        env_idx: int,
        n_ep_recorded: int,
    ) -> dict[str, list[Any]]:
        """Prepare the episode buffer for writing.

        Rotates the action column left by one, copies the buffer out, clears
        it, clears the env's done flags and adds a per-step `ep_idx` column.

        Args:
            tmp_buffer: List of dictionaries accumulating step data per env.
            env_idx: Index of the environment that finished an episode.
            n_ep_recorded: Number of episodes recorded so far (used as the
                episode index).

        Returns:
            A dictionary containing the complete episode data.
        """
        ep_buffer = tmp_buffer[env_idx]

        # left-shift actions to align with observations i.e. (o_t, a_t);
        # the entry recorded at reset moves to the end
        if 'action' in ep_buffer:
            actions = ep_buffer['action']
            nan = actions.pop(0)
            actions.append(nan)

        # Extract a copy and clear the temporary buffer
        out = {k: list(v) for k, v in ep_buffer.items()}
        ep_buffer.clear()
        self.terminateds[env_idx] = False
        self.truncateds[env_idx] = False

        # Add episode index to all steps
        ep_len = len(out['step_idx'])
        out['ep_idx'] = [n_ep_recorded] * ep_len

        return out

    def _write_episode(
        self, f: h5py.File, ep_data: dict[str, list[Any]], global_ptr: int
    ) -> int:
        """Write a single contiguous episode to the HDF5 file.

        Args:
            f: The open HDF5 file handle.
            ep_data: The episode data dictionary.
            global_ptr: The global step index where this episode starts.

        Returns:
            The length of the episode written.
        """
        ep_len = len(ep_data['step_idx'])

        # append data to each dataset
        for key in ep_data:
            h5_key = key.replace('/', '_')  # sanitize keys for h5
            if h5_key in ['ep_len', 'policy']:
                continue

            values = np.asarray(ep_data[key])
            if values.dtype.kind == 'U':
                # h5py cannot convert numpy fixed-width unicode ('<U…');
                # variable-length string datasets accept object arrays of str
                values = values.astype(object)

            ds = f[h5_key]
            curr_size = ds.shape[0]
            ds.resize(curr_size + ep_len, axis=0)
            ds[curr_size:] = values

        # update metadata
        meta_idx = f['ep_offset'].shape[0]
        f['ep_offset'].resize(meta_idx + 1, axis=0)
        f['ep_len'].resize(meta_idx + 1, axis=0)

        f['ep_offset'][meta_idx] = global_ptr
        f['ep_len'][meta_idx] = ep_len

        return ep_len

    def _dump_step_data(
        self,
        tmp_buffer: list[dict[str, list[Any]]],
        env_idx: int | None = None,
    ) -> None:
        """Append current step data to temporary episode buffers.

        Keys of `infos` starting with '_' are skipped. Arrays whose second
        axis has size 1 are squeezed on that axis, and object arrays are
        concatenated into a flat list before being split per environment.

        Args:
            tmp_buffer: List of dictionaries accumulating step data.
            env_idx: Optional index to dump data for a single environment.
                If None, dumps for all environments.
        """
        env_indices = range(self.num_envs) if env_idx is None else [env_idx]

        for col, data in self.infos.items():
            if col.startswith('_'):
                continue

            # normalize data shape and type
            if isinstance(data, np.ndarray):
                data = (
                    np.squeeze(data, axis=1)
                    if data.ndim > 1 and data.shape[1] == 1
                    else data
                )
                if data.dtype == object:
                    data = np.concatenate(data).tolist()

            # append to buffers (copy arrays so later steps don't alias them)
            for i in env_indices:
                env_data = (
                    data[i].copy()
                    if isinstance(data[i], np.ndarray)
                    else data[i]
                )
                tmp_buffer[i][col].append(env_data)

    def evaluate(
        self,
        episodes: int = 10,
        eval_keys: list[str] | None = None,
        seed: int | None = None,
        options: dict | None = None,
        dump_every: int = -1,
    ) -> dict:
        """Evaluate the current policy over multiple episodes.

        Episode `j` is reset with seed `seed + j` when `seed` is given
        (assuming the vector reset seeds env `i` with `seed + i`). An episode
        counts as a success when it ends by termination. Envs may start
        episodes beyond the requested number while others are still running;
        the outcomes of those extra episodes are discarded.

        Args:
            episodes: Number of episodes to evaluate (must be >= 1).
            eval_keys: List of keys in `infos` to collect and return.
            seed: Random seed for evaluation.
            options: Reset options.
            dump_every: Interval (in completed episodes) to save intermediate
                results to `eval_tmp_<hash>.npy` (for long evals); <= 0
                disables dumping.

        Returns:
            Dictionary containing success rates, seeds, and collected keys.

        Raises:
            ValueError: If `episodes` < 1.
        """
        if episodes < 1:
            raise ValueError(f'episodes must be >= 1, got {episodes}')

        options = options or {}

        results: dict = {
            'episode_count': 0,
            'success_rate': 0.0,
            'episode_successes': np.zeros(episodes),
            'seeds': np.zeros(episodes, dtype=np.int32),
        }
        for key in eval_keys or []:
            results[key] = np.zeros(episodes)

        self.terminateds = np.zeros(self.num_envs)
        self.truncateds = np.zeros(self.num_envs)

        # episode_idx[i] is the global index of the episode env i is running
        episode_idx = np.arange(self.num_envs)
        self.reset(seed=seed, options=options)

        # determine "unique" hash for this eval run
        config = {
            'episodes': episodes,
            'eval_keys': tuple(sorted(eval_keys)) if eval_keys else None,
            'seed': seed,
            'options': tuple(sorted(options.items())) if options else None,
            'dump_every': dump_every,
        }
        config_str = json.dumps(config, sort_keys=True)
        run_hash = hashlib.sha256(config_str.encode()).hexdigest()[:8]
        run_tmp_path = Path(f'eval_tmp_{run_hash}.npy')

        # load back intermediate results if file exists
        # (allow_pickle: the file is one this method wrote itself)
        # NOTE(review): episodes that were in flight when the file was dumped
        # are not re-run on resume — verify this is acceptable.
        if run_tmp_path.exists():
            tmp_results = np.load(run_tmp_path, allow_pickle=True).item()
            results.update(tmp_results)

            ep_count = results['episode_count']
            episode_idx = np.arange(ep_count, ep_count + self.num_envs)

            # reset seed where we left off
            last_seed = seed + ep_count if seed is not None else None
            self.reset(seed=last_seed, options=options)

            logging.success(
                f'Found existing eval tmp file {run_tmp_path}, resuming from episode {ep_count}/{episodes}'
            )

        eval_done = results['episode_count'] >= episodes

        while not eval_done:
            self.step()

            # start new episode for done envs
            for i in range(self.num_envs):
                if not (self.terminateds[i] or self.truncateds[i]):
                    continue

                ep_idx = episode_idx[i]

                # only the first `episodes` episodes are scored; indexing the
                # result arrays with an overflow index would be out of range
                if ep_idx < episodes:
                    results['episode_successes'][ep_idx] = self.terminateds[i]
                    results['seeds'][ep_idx] = self.envs.envs[
                        i
                    ].unwrapped.np_random_seed

                    if eval_keys:
                        for key in eval_keys:
                            assert key in self.infos, (
                                f'key {key} not found in infos'
                            )
                            results[key][ep_idx] = self.infos[key][i]

                    results['episode_count'] += 1

                    # break if enough episodes evaluated
                    if results['episode_count'] >= episodes:
                        eval_done = True
                        if run_tmp_path.exists():
                            logging.info(
                                f'Eval done, deleting tmp file {run_tmp_path}'
                            )
                            os.remove(run_tmp_path)
                        break

                    # dump temporary results in a file
                    if dump_every > 0 and (
                        results['episode_count'] % dump_every == 0
                    ):
                        np.save(run_tmp_path, results)
                        logging.success(
                            f'Dumped intermediate eval results to {run_tmp_path} ({results["episode_count"]}/{episodes})'
                        )

                # assign the next episode index; its seed is seed + index
                # (cast to int: gymnasium seeding requires a Python int)
                next_ep_idx = int(episode_idx.max()) + 1
                episode_idx[i] = next_ep_idx
                new_seed = seed + next_ep_idx if seed is not None else None

                # re-reset env with seed and options (not supported by
                # auto-reset)
                self.envs.unwrapped._autoreset_envs = np.zeros(
                    (self.num_envs,), dtype=bool
                )
                _, infos = self.envs.envs[i].reset(
                    seed=new_seed, options=options
                )

                for k, v in infos.items():
                    if k not in self.infos:
                        continue
                    # convert to array so the row assignment keeps the dtype
                    self.infos[k][i] = np.asarray(v)

        # compute success rate
        results['success_rate'] = (
            float(np.sum(results['episode_successes'])) / episodes * 100.0
        )

        assert results['episode_count'] == episodes, (
            f'episode_count {results["episode_count"]} != episodes {episodes}'
        )

        assert np.unique(results['seeds']).shape[0] == episodes, (
            'Some episode seeds are identical!'
        )

        return results

    @staticmethod
    def _as_numpy(value: Any) -> Any:
        """Convert a torch tensor to numpy; return anything else unchanged."""
        return value.numpy() if isinstance(value, torch.Tensor) else value

    def _apply_callables(
        self, callables: list[dict], init_step: dict[str, np.ndarray]
    ) -> None:
        """Invoke setup methods on every unwrapped env with dataset values.

        Each spec has the form
        `{'method': name, 'args': {arg_name: {'value': v, 'in_dataset': b}}}`.
        When 'args' is absent, the spec's other keys (all but 'method') are
        used as the argument mapping. With `in_dataset` true (the default),
        `v` names a column of `init_step` and that env's row is passed;
        otherwise `v` is passed as-is. A callable whose dataset column is
        missing, or whose method does not exist, is skipped for that env.
        """
        for i, env in enumerate(self.envs.unwrapped.envs):
            env_unwrapped = env.unwrapped

            for spec in callables:
                method_name = spec['method']
                if not hasattr(env_unwrapped, method_name):
                    logging.warning(
                        f'Env {env_unwrapped} has no method {method_name}, skipping callable'
                    )
                    continue

                method = getattr(env_unwrapped, method_name)
                if 'args' in spec:
                    args = spec['args']
                else:
                    # inline form: every key except 'method' is an argument
                    args = {k: v for k, v in spec.items() if k != 'method'}

                # prepare args
                prepared_args = {}
                missing_column = False
                for arg_name, arg_spec in args.items():
                    value = arg_spec.get('value', None)
                    if arg_spec.get('in_dataset', True):
                        if value not in init_step:
                            logging.warning(
                                f'Col {value} not found in dataset, skipping callable for env {env_unwrapped}'
                            )
                            missing_column = True
                            break
                        prepared_args[arg_name] = deepcopy(
                            init_step[value][i]
                        )
                    else:
                        prepared_args[arg_name] = value

                if missing_column:
                    continue

                # call method with prepared args
                method(**prepared_args)

    def evaluate_from_dataset(
        self,
        dataset: Any,
        episodes_idx: Sequence[int],
        start_steps: Sequence[int],
        goal_offset_steps: int,
        eval_budget: int,
        callables: list[dict] | None = None,
        save_video: bool = True,
        video_path: str | Path = './',
    ) -> dict:
        """Evaluate the policy starting from states sampled from a dataset.

        Each env is reset to the dataset state at `start_steps[i]` of episode
        `episodes_idx[i]` (via seeds, variations and optional callables) and
        must reach the dataset state `goal_offset_steps` later within
        `eval_budget` steps. Success is termination at any step.

        Args:
            dataset: The source dataset to sample initial states/goals from.
            episodes_idx: Indices of episodes to sample from.
            start_steps: Step indices within those episodes to start from.
            goal_offset_steps: Number of steps ahead to look for the goal.
            eval_budget: Maximum steps allowed for the agent to reach the goal.
            callables: Optional list of method calls to setup the env.
            save_video: Whether to save rollout videos.
            video_path: Path to save videos.

        Returns:
            Dictionary containing success rates and other metrics.

        Raises:
            ValueError: If input sequence lengths mismatch or don't match num_envs.
            RuntimeError: If no policy has been attached via `set_policy`.
        """
        max_env_steps = self.envs.envs[0].spec.max_episode_steps
        assert max_env_steps is None or max_env_steps >= goal_offset_steps, (
            'env max_episode_steps must be >= goal_offset_steps'
        )

        if self.policy is None:
            raise RuntimeError('No policy set. Call set_policy() first.')

        ep_idx_arr = np.array(episodes_idx)
        start_steps_arr = np.array(start_steps)
        end_steps = start_steps_arr + goal_offset_steps

        if len(ep_idx_arr) != len(start_steps_arr):
            raise ValueError(
                'episodes_idx and start_steps must have the same length'
            )

        if len(ep_idx_arr) != self.num_envs:
            raise ValueError(
                'Number of episodes to evaluate must match number of envs'
            )

        data = dataset.load_chunk(ep_idx_arr, start_steps_arr, end_steps)
        columns = dataset.column_names

        # keep relevant part of the chunk: first step (init) and last (goal)
        init_step_per_env: dict[str, list[Any]] = defaultdict(list)
        goal_step_per_env: dict[str, list[Any]] = defaultdict(list)

        for ep in data:
            for col in columns:
                if col.startswith('goal'):
                    continue
                if col.startswith('pixels'):
                    # permute channel to be last (also used for the video)
                    ep[col] = ep[col].permute(0, 2, 3, 1)

                if not isinstance(ep[col], (torch.Tensor | np.ndarray)):
                    continue

                init_data = ep[col][0]
                goal_data = ep[col][-1]

                # TODO handle that better
                if not isinstance(init_data, (np.ndarray | torch.Tensor)):
                    logging.warning(
                        f'Data type {type(init_data)} for column {col} not supported, yet skipping conversion'
                    )
                    continue

                init_step_per_env[col].append(self._as_numpy(init_data))
                goal_step_per_env[col].append(self._as_numpy(goal_data))

        # np.stack allocates new arrays, so no defensive copy is needed here
        init_step = {k: np.stack(v) for k, v in init_step_per_env.items()}
        goal_step = {
            ('goal' if key == 'pixels' else f'goal_{key}'): np.stack(value)
            for key, value in goal_step_per_env.items()
        }

        # get dataset info
        seeds = init_step.get('seed')

        # get dataset variation
        vkey = 'variation.'
        variations_dict = {
            k.removeprefix(vkey): v
            for k, v in init_step.items()
            if k.startswith(vkey)
        }

        options = [{} for _ in range(self.num_envs)]
        if variations_dict:
            for i in range(self.num_envs):
                options[i]['variation'] = list(variations_dict.keys())
                options[i]['variation_values'] = {
                    k: v[i] for k, v in variations_dict.items()
                }

        init_step.update(deepcopy(goal_step))
        self.reset(seed=seeds, options=options)  # set seeds for all envs

        # apply callable list (e.g used for set initial position if not
        # access to seed)
        self._apply_callables(callables or [], init_step)

        # TODO remove this
        if 'goal_state' in init_step and 'goal_state' in goal_step:
            for i in range(len(self.envs.unwrapped.envs)):
                assert np.array_equal(
                    init_step['goal_state'][i], goal_step['goal_state'][i]
                ), 'Goal state info does not match at reset'

        results: dict = {
            'success_rate': 0.0,
            'episode_successes': np.zeros(len(episodes_idx)),
            'seeds': seeds,
        }

        # expand all data to the right shape (num_envs, history, (original_shape))
        shape_prefix = self.infos['pixels'].shape[:2]

        # TODO get the data from the previous step in the dataset for history
        init_step = {
            k: np.broadcast_to(v[:, None, ...], shape_prefix + v.shape[1:])
            for k, v in init_step.items()
        }
        goal_step = {
            k: np.broadcast_to(v[:, None, ...], shape_prefix + v.shape[1:])
            for k, v in goal_step.items()
        }

        # update the reset with our new init and goal infos
        self.infos.update(deepcopy(init_step))
        self.infos.update(deepcopy(goal_step))

        if 'goal' in goal_step and 'goal' in self.infos:
            assert np.allclose(self.infos['goal'], goal_step['goal']), (
                'Goal info does not match'
            )

        target_frames = torch.stack([ep['pixels'] for ep in data]).numpy()
        video_frames = np.empty(
            (self.num_envs, eval_budget, *self.infos['pixels'].shape[-3:]),
            dtype=np.uint8,
        )

        # run normal evaluation for eval_budget and record video
        active_mask = np.ones(self.num_envs, dtype=bool)
        last_eval_step = 0
        for step_idx in range(eval_budget):
            video_frames[:, step_idx] = self.infos['pixels'][:, -1]
            last_eval_step = step_idx
            self.infos.update(goal_step)
            actions = self.policy.get_action(
                self.infos, active_mask=active_mask
            )
            (
                self.states,
                self.rewards,
                self.terminateds,
                self.truncateds,
                self.infos,
            ) = self.envs.step(actions)
            results['episode_successes'] = np.logical_or(
                results['episode_successes'], self.terminateds
            )
            active_mask = np.logical_not(results['episode_successes'])
            if not np.any(active_mask):
                break
            # for auto-reset
            self.envs.unwrapped._autoreset_envs = np.zeros(
                (self.num_envs,), dtype=bool
            )

        if eval_budget > 0:
            # show the state after the last step, then freeze it for the
            # remaining (unused) budget
            video_frames[:, last_eval_step] = self.infos['pixels'][:, -1]
            if last_eval_step + 1 < eval_budget:
                video_frames[:, last_eval_step + 1 :] = video_frames[
                    :, last_eval_step : last_eval_step + 1
                ]

        n_episodes = len(episodes_idx)

        # compute success rate
        results['success_rate'] = (
            float(np.sum(results['episode_successes'])) / n_episodes * 100.0
        )

        # save video if required
        if save_video:
            target_len = target_frames.shape[1]
            video_path_obj = Path(video_path)
            video_path_obj.mkdir(parents=True, exist_ok=True)
            max_workers = min(self.num_envs, os.cpu_count() or 1, 8)
            with ThreadPoolExecutor(max_workers=max_workers) as executor:
                futures = [
                    executor.submit(
                        _write_eval_video,
                        video_path_obj,
                        i,
                        video_frames[i],
                        target_frames[i],
                        eval_budget,
                        target_len,
                    )
                    for i in range(self.num_envs)
                ]
                for future in futures:
                    future.result()  # re-raise any writer error
            print(f'Video saved to {video_path_obj}')

        if results['seeds'] is not None:
            assert np.unique(results['seeds']).shape[0] == n_episodes, (
                'Some episode seeds are identical!'
            )

        return results
|