629 lines
21 KiB
Python
629 lines
21 KiB
Python
from collections import deque
|
|
from dataclasses import dataclass
|
|
import inspect
|
|
from pathlib import Path
|
|
from typing import Any, Protocol
|
|
from collections.abc import Callable
|
|
|
|
import numpy as np
|
|
import torch
|
|
from loguru import logger as logging
|
|
|
|
import stable_worldmodel as swm
|
|
from stable_worldmodel.solver import Solver
|
|
|
|
|
|
@dataclass(frozen=True)
|
|
class PlanConfig:
|
|
"""Configuration for the MPC planning loop.
|
|
|
|
Attributes:
|
|
horizon: Planning horizon in number of steps.
|
|
receding_horizon: Number of steps to execute before re-planning.
|
|
history_len: Number of past observations to consider.
|
|
action_block: Number of times each action is repeated (frameskip).
|
|
warm_start: Whether to use the previous plan to initialize the next one.
|
|
"""
|
|
|
|
horizon: int
|
|
receding_horizon: int
|
|
history_len: int = 1
|
|
action_block: int = 1
|
|
warm_start: bool = True
|
|
|
|
@property
|
|
def plan_len(self) -> int:
|
|
"""Total plan length in environment steps."""
|
|
return self.horizon * self.action_block
|
|
|
|
|
|
class Transformable(Protocol):
|
|
"""Protocol for reversible data transformations (e.g., normalizers, scalers)."""
|
|
|
|
def transform(self, x: np.ndarray) -> np.ndarray: # pragma: no cover
|
|
"""Apply preprocessing to input data.
|
|
|
|
Args:
|
|
x: Input data as a numpy array.
|
|
|
|
Returns:
|
|
Preprocessed data as a numpy array.
|
|
"""
|
|
...
|
|
|
|
def inverse_transform(
|
|
self, x: np.ndarray
|
|
) -> np.ndarray: # pragma: no cover
|
|
"""Reverse the preprocessing transformation.
|
|
|
|
Args:
|
|
x: Preprocessed data as a numpy array.
|
|
|
|
Returns:
|
|
Original data as a numpy array.
|
|
"""
|
|
...
|
|
|
|
|
|
class Actionable(Protocol):
|
|
"""Protocol for model action computation."""
|
|
|
|
def get_action(info) -> torch.Tensor: # pragma: no cover
|
|
"""Compute action from observation and goal"""
|
|
...
|
|
|
|
|
|
class BasePolicy:
|
|
"""Base class for agent policies.
|
|
|
|
Attributes:
|
|
env: The environment the policy is associated with.
|
|
type: A string identifier for the policy type.
|
|
"""
|
|
|
|
env: Any
|
|
type: str
|
|
|
|
def __init__(self, **kwargs: Any) -> None:
|
|
"""Initialize the base policy.
|
|
|
|
Args:
|
|
**kwargs: Additional configuration parameters.
|
|
"""
|
|
self.env = None
|
|
self.type = 'base'
|
|
for arg, value in kwargs.items():
|
|
setattr(self, arg, value)
|
|
|
|
def get_action(self, obs: Any, **kwargs: Any) -> np.ndarray:
|
|
"""Get action from the policy given the observation.
|
|
|
|
Args:
|
|
obs: The current observation from the environment.
|
|
**kwargs: Additional parameters for action selection.
|
|
|
|
Returns:
|
|
Selected action as a numpy array.
|
|
|
|
Raises:
|
|
NotImplementedError: If not implemented by a subclass.
|
|
"""
|
|
raise NotImplementedError
|
|
|
|
def set_env(self, env: Any) -> None:
|
|
"""Associate this policy with an environment.
|
|
|
|
Args:
|
|
env: The environment to associate.
|
|
"""
|
|
self.env = env
|
|
|
|
def _move_info_to_device(
|
|
self, info_dict: dict[str, Any], device: torch.device | str
|
|
) -> dict[str, Any]:
|
|
target = torch.device(device)
|
|
for k, v in info_dict.items():
|
|
if torch.is_tensor(v):
|
|
if v.device != target:
|
|
v = v.to(target, non_blocking=True)
|
|
if not v.is_contiguous():
|
|
v = v.contiguous()
|
|
info_dict[k] = v
|
|
return info_dict
|
|
|
|
def _prepare_info_for_device(
|
|
self,
|
|
info_dict: dict,
|
|
device: torch.device | str | None = None,
|
|
) -> dict[str, torch.Tensor]:
|
|
if device is None:
|
|
return self._prepare_info(info_dict)
|
|
|
|
prepare_info = self._prepare_info
|
|
try:
|
|
signature = inspect.signature(prepare_info)
|
|
except (TypeError, ValueError):
|
|
return prepare_info(info_dict)
|
|
|
|
accepts_device = any(
|
|
param.kind == inspect.Parameter.VAR_KEYWORD
|
|
or (param.name == "device" and param.kind in {
|
|
inspect.Parameter.POSITIONAL_OR_KEYWORD,
|
|
inspect.Parameter.KEYWORD_ONLY,
|
|
})
|
|
for param in signature.parameters.values()
|
|
)
|
|
if accepts_device:
|
|
return prepare_info(info_dict, device=device)
|
|
return prepare_info(info_dict)
|
|
|
|
def _prepare_info(
|
|
self,
|
|
info_dict: dict,
|
|
device: torch.device | str | None = None,
|
|
) -> dict[str, torch.Tensor]:
|
|
"""Pre-process and transform observations.
|
|
|
|
Applies preprocessing (via `self.process`) and transformations (via `self.transform`)
|
|
to observation data. Used by subclasses like FeedForwardPolicy and WorldModelPolicy.
|
|
|
|
Args:
|
|
info_dict: Raw observation dictionary from the environment.
|
|
|
|
Returns:
|
|
A dictionary of processed tensors.
|
|
|
|
Raises:
|
|
ValueError: If an expected numpy array is missing for processing.
|
|
"""
|
|
target_device = torch.device(device) if device is not None else None
|
|
for k, v in info_dict.items():
|
|
is_numpy = isinstance(v, (np.ndarray | np.generic))
|
|
|
|
if hasattr(self, 'process') and k in self.process:
|
|
if not is_numpy:
|
|
raise ValueError(
|
|
f"Expected numpy array for key '{k}' in process, got {type(v)}"
|
|
)
|
|
|
|
# flatten extra dimensions if needed
|
|
shape = v.shape
|
|
if len(shape) > 2:
|
|
v = v.reshape(-1, *shape[2:])
|
|
|
|
# process and reshape back
|
|
v = self.process[k].transform(v)
|
|
v = v.reshape(shape)
|
|
|
|
# collapse env and time dimensions for transform (e, t, ...) -> (e * t, ...)
|
|
# then restore after transform
|
|
if hasattr(self, 'transform') and k in self.transform:
|
|
shape = None
|
|
if is_numpy or torch.is_tensor(v):
|
|
if v.ndim > 2:
|
|
shape = v.shape
|
|
v = v.reshape(-1, *shape[2:])
|
|
if is_numpy:
|
|
if v.dtype.kind in 'USO':
|
|
raise ValueError(
|
|
f"Expected numeric numpy array for key '{k}', got dtype {v.dtype}"
|
|
)
|
|
v = torch.from_numpy(v)
|
|
is_numpy = False
|
|
moved_for_transform = False
|
|
if target_device is not None and target_device.type != "cpu":
|
|
if v.device != target_device:
|
|
v = v.to(target_device, non_blocking=True)
|
|
moved_for_transform = True
|
|
if k.startswith('pixels') or k.startswith('goal'):
|
|
# Vectorized image transform on the full batch.
|
|
v = v.permute(0, 3, 1, 2)
|
|
try:
|
|
v = self.transform[k](v)
|
|
except (NotImplementedError, RuntimeError):
|
|
if not moved_for_transform:
|
|
raise
|
|
v = self.transform[k](v.cpu())
|
|
|
|
if shape is not None:
|
|
v = v.reshape(*shape[:2], *v.shape[1:])
|
|
|
|
if is_numpy and v.dtype.kind not in 'USO':
|
|
v = torch.from_numpy(v)
|
|
|
|
if (
|
|
torch.cuda.is_available()
|
|
and torch.is_tensor(v)
|
|
and v.device.type == "cpu"
|
|
and not v.is_pinned()
|
|
):
|
|
v = v.pin_memory()
|
|
|
|
info_dict[k] = v
|
|
|
|
return info_dict
|
|
|
|
|
|
class RandomPolicy(BasePolicy):
|
|
"""Policy that samples random actions from the action space."""
|
|
|
|
def __init__(self, seed: int | None = None, **kwargs: Any) -> None:
|
|
"""Initialize the random policy.
|
|
|
|
Args:
|
|
seed: Optional random seed for the action space.
|
|
**kwargs: Additional configuration parameters.
|
|
"""
|
|
super().__init__(**kwargs)
|
|
self.type = 'random'
|
|
self.seed = seed
|
|
|
|
def get_action(self, obs: Any, **kwargs: Any) -> np.ndarray:
|
|
"""Get a random action from the environment's action space.
|
|
|
|
Args:
|
|
obs: The current observation (ignored).
|
|
**kwargs: Additional parameters (ignored).
|
|
|
|
Returns:
|
|
A randomly sampled action.
|
|
"""
|
|
return self.env.action_space.sample()
|
|
|
|
def set_seed(self, seed: int) -> None:
|
|
"""Set the random seed for action sampling.
|
|
|
|
Args:
|
|
seed: The seed value.
|
|
"""
|
|
if self.env is not None:
|
|
self.env.action_space.seed(seed)
|
|
|
|
|
|
class ExpertPolicy(BasePolicy):
|
|
"""Policy using expert demonstrations or heuristics."""
|
|
|
|
def __init__(self, **kwargs: Any) -> None:
|
|
"""Initialize the expert policy.
|
|
|
|
Args:
|
|
**kwargs: Additional configuration parameters.
|
|
"""
|
|
super().__init__(**kwargs)
|
|
self.type = 'expert'
|
|
|
|
def get_action(
|
|
self, obs: Any, goal_obs: Any, **kwargs: Any
|
|
) -> np.ndarray | None:
|
|
"""Get action from the expert policy.
|
|
|
|
Args:
|
|
obs: The current observation.
|
|
goal_obs: The goal observation.
|
|
**kwargs: Additional parameters.
|
|
|
|
Returns:
|
|
The expert action, or None if not available.
|
|
"""
|
|
# Implement expert policy logic here
|
|
pass
|
|
|
|
|
|
class FeedForwardPolicy(BasePolicy):
|
|
"""Feed-Forward Policy using a neural network model.
|
|
|
|
Actions are computed via a single forward pass through the model.
|
|
Useful for imitation learning policies like Goal-Conditioned Behavioral Cloning (GCBC).
|
|
|
|
Attributes:
|
|
model: Neural network model implementing the Actionable protocol.
|
|
process: Dictionary of data preprocessors for specific keys.
|
|
transform: Dictionary of tensor transformations (e.g., image transforms).
|
|
"""
|
|
|
|
def __init__(
|
|
self,
|
|
model: Actionable,
|
|
process: dict[str, Transformable] | None = None,
|
|
transform: dict[str, Callable[[torch.Tensor], torch.Tensor]]
|
|
| None = None,
|
|
**kwargs: Any,
|
|
) -> None:
|
|
"""Initialize the feed-forward policy.
|
|
|
|
Args:
|
|
model: Neural network model with a `get_action` method.
|
|
process: Dictionary of data preprocessors for specific keys.
|
|
transform: Dictionary of tensor transformations (e.g., image transforms).
|
|
**kwargs: Additional configuration parameters.
|
|
"""
|
|
super().__init__(**kwargs)
|
|
self.type = 'feed_forward'
|
|
self.model = model.eval()
|
|
self.process = process or {}
|
|
self.transform = transform or {}
|
|
|
|
def get_action(self, info_dict: dict, **kwargs: Any) -> np.ndarray:
|
|
"""Get action via a forward pass through the neural network model.
|
|
|
|
Args:
|
|
info_dict: Current state information containing at minimum a 'goal' key.
|
|
**kwargs: Additional parameters (unused).
|
|
|
|
Returns:
|
|
The selected action as a numpy array.
|
|
|
|
Raises:
|
|
AssertionError: If environment not set or 'goal' not in info_dict.
|
|
"""
|
|
assert hasattr(self, 'env'), 'Environment not set for the policy'
|
|
assert 'goal' in info_dict, "'goal' must be provided in info_dict"
|
|
|
|
# Prepare the info dict (transforms and normalizes inputs)
|
|
info_dict = self._prepare_info_for_device(
|
|
info_dict, device=next(self.model.parameters()).device
|
|
)
|
|
|
|
# Add goal_pixels key for GCBC model
|
|
if 'goal' in info_dict:
|
|
info_dict['goal_pixels'] = info_dict['goal']
|
|
|
|
# Move all tensors to the model's device
|
|
device = next(self.model.parameters()).device
|
|
info_dict = self._move_info_to_device(info_dict, device)
|
|
|
|
# Get action from model
|
|
with torch.no_grad():
|
|
action = self.model.get_action(info_dict)
|
|
|
|
# Convert to numpy
|
|
if torch.is_tensor(action):
|
|
action = action.cpu().detach().numpy()
|
|
|
|
# post-process action
|
|
if 'action' in self.process:
|
|
action = self.process['action'].inverse_transform(action)
|
|
|
|
return action
|
|
|
|
|
|
class WorldModelPolicy(BasePolicy):
|
|
"""Policy using a world model and planning solver for action selection."""
|
|
|
|
def __init__(
|
|
self,
|
|
solver: Solver,
|
|
config: PlanConfig,
|
|
process: dict[str, Transformable] | None = None,
|
|
transform: dict[str, Callable[[torch.Tensor], torch.Tensor]]
|
|
| None = None,
|
|
**kwargs: Any,
|
|
) -> None:
|
|
"""Initialize the world model policy.
|
|
|
|
Args:
|
|
solver: The planning solver to use.
|
|
config: MPC planning configuration.
|
|
process: Dictionary of data preprocessors for specific keys.
|
|
transform: Dictionary of tensor transformations (e.g., image transforms).
|
|
**kwargs: Additional configuration parameters.
|
|
"""
|
|
super().__init__(**kwargs)
|
|
|
|
self.type = 'world_model'
|
|
self.cfg = config
|
|
self.solver = solver
|
|
self.action_buffer: deque[torch.Tensor] = deque(
|
|
maxlen=self.flatten_receding_horizon
|
|
)
|
|
self.process = process or {}
|
|
self.transform = transform or {}
|
|
self._action_buffer: deque[torch.Tensor] | None = None
|
|
self._next_init: torch.Tensor | None = None
|
|
|
|
@property
|
|
def flatten_receding_horizon(self) -> int:
|
|
"""Receding horizon in environment steps (with frameskip)."""
|
|
return self.cfg.receding_horizon * self.cfg.action_block
|
|
|
|
def set_env(self, env: Any) -> None:
|
|
"""Configure the policy and solver for the given environment.
|
|
|
|
Args:
|
|
env: The environment to associate with the policy.
|
|
"""
|
|
self.env = env
|
|
n_envs = getattr(env, 'num_envs', 1)
|
|
self.solver.configure(
|
|
action_space=env.action_space, n_envs=n_envs, config=self.cfg
|
|
)
|
|
self._action_buffer = deque(maxlen=self.flatten_receding_horizon)
|
|
|
|
assert isinstance(self.solver, Solver), (
|
|
'Solver must implement the Solver protocol'
|
|
)
|
|
|
|
def _normalize_active_mask(self, active_mask: Any) -> np.ndarray | None:
|
|
if active_mask is None:
|
|
return None
|
|
active_mask = np.asarray(active_mask, dtype=bool)
|
|
if active_mask.ndim != 1 or active_mask.shape[0] != self.env.num_envs:
|
|
raise ValueError(
|
|
f"active_mask must have shape ({self.env.num_envs},), got {active_mask.shape}"
|
|
)
|
|
return active_mask
|
|
|
|
def get_action(self, info_dict: dict, **kwargs: Any) -> np.ndarray:
|
|
"""Get action via planning with the world model.
|
|
|
|
Args:
|
|
info_dict: Current state information from the environment.
|
|
**kwargs: Additional parameters for planning.
|
|
|
|
Returns:
|
|
The selected action(s) as a numpy array.
|
|
"""
|
|
assert hasattr(self, 'env'), 'Environment not set for the policy'
|
|
active_mask = self._normalize_active_mask(kwargs.get("active_mask"))
|
|
|
|
# need to replan if action buffer is empty
|
|
if len(self._action_buffer) == 0:
|
|
assert 'pixels' in info_dict, "'pixels' must be provided in info_dict"
|
|
assert 'goal' in info_dict, "'goal' must be provided in info_dict"
|
|
|
|
info_dict = self._prepare_info_for_device(
|
|
info_dict, device=self.solver.device
|
|
)
|
|
info_dict = self._move_info_to_device(info_dict, self.solver.device)
|
|
|
|
outputs = self.solver(
|
|
info_dict,
|
|
init_action=self._next_init,
|
|
active_mask=active_mask,
|
|
)
|
|
|
|
actions = outputs['actions'] # (num_envs, horizon, action_dim)
|
|
if active_mask is not None and actions.shape[0] != self.env.num_envs:
|
|
full_actions = torch.zeros(
|
|
self.env.num_envs,
|
|
actions.shape[1],
|
|
actions.shape[2],
|
|
dtype=actions.dtype,
|
|
device=actions.device,
|
|
)
|
|
full_actions[torch.as_tensor(active_mask, device=actions.device)] = actions
|
|
actions = full_actions
|
|
|
|
keep_horizon = self.cfg.receding_horizon
|
|
plan = actions[:, :keep_horizon]
|
|
rest = actions[:, keep_horizon:]
|
|
self._next_init = rest.contiguous() if self.cfg.warm_start else None
|
|
|
|
# frameskip back to timestep
|
|
plan = plan.reshape(
|
|
self.env.num_envs, self.flatten_receding_horizon, -1
|
|
).contiguous()
|
|
|
|
self._action_buffer.extend(plan.transpose(0, 1).unbind(0))
|
|
|
|
action = self._action_buffer.popleft()
|
|
action = action.reshape(*self.env.action_space.shape)
|
|
if active_mask is not None:
|
|
if torch.is_tensor(action):
|
|
inactive_mask = torch.as_tensor(
|
|
~active_mask, device=action.device, dtype=torch.bool
|
|
)
|
|
action = action.clone()
|
|
action[inactive_mask] = 0
|
|
else:
|
|
action = np.array(action, copy=True)
|
|
action[~active_mask] = 0
|
|
if torch.is_tensor(action):
|
|
action = action.detach().cpu().numpy()
|
|
else:
|
|
action = np.asarray(action)
|
|
|
|
# post-process action
|
|
if 'action' in self.process:
|
|
action = self.process['action'].inverse_transform(action)
|
|
|
|
return action # (num_envs, action_dim)
|
|
|
|
|
|
def _load_model_with_attribute(run_name, attribute_name, cache_dir=None):
|
|
"""Helper function to load a model checkpoint and find a module with the specified attribute.
|
|
|
|
Args:
|
|
run_name: Path or name of the model run
|
|
attribute_name: Name of the attribute to look for in the module (e.g., 'get_action', 'get_cost')
|
|
cache_dir: Optional cache directory path
|
|
|
|
Returns:
|
|
The module with the specified attribute
|
|
|
|
Raises:
|
|
RuntimeError: If no module with the specified attribute is found
|
|
"""
|
|
if Path(run_name).exists():
|
|
run_path = Path(run_name)
|
|
else:
|
|
run_path = Path(cache_dir or swm.data.utils.get_cache_dir(), run_name)
|
|
|
|
if run_path.is_dir():
|
|
ckpt_files = list(run_path.glob('*_object.ckpt'))
|
|
ckpt_files.sort(key=lambda x: x.stat().st_ctime, reverse=True)
|
|
path = ckpt_files[0]
|
|
logging.info(f'Loading model from checkpoint: {path}')
|
|
else:
|
|
path = Path(f'{run_path}_object.ckpt')
|
|
assert path.exists(), (
|
|
f'Checkpoint path does not exist: {path}. Launch pretraining first.'
|
|
)
|
|
|
|
spt_module = torch.load(path, weights_only=False, map_location='cpu')
|
|
|
|
def scan_module(module):
|
|
if hasattr(module, attribute_name):
|
|
if isinstance(module, torch.nn.Module):
|
|
module = module.eval()
|
|
return module
|
|
for child in module.children():
|
|
result = scan_module(child)
|
|
if result is not None:
|
|
return result
|
|
return None
|
|
|
|
result = scan_module(spt_module)
|
|
if result is not None:
|
|
return result
|
|
|
|
raise RuntimeError(
|
|
f"No module with '{attribute_name}' found in the loaded world model."
|
|
)
|
|
|
|
|
|
def AutoActionableModel(
|
|
run_name: str, cache_dir: str | Path | None = None
|
|
) -> torch.nn.Module:
|
|
"""Load a model checkpoint and return the module with a `get_action` method.
|
|
|
|
Automatically scans the checkpoint for a module implementing the Actionable
|
|
protocol (i.e., has a `get_action` method).
|
|
|
|
Args:
|
|
run_name: Path or name of the model run/checkpoint.
|
|
cache_dir: Optional cache directory path. Defaults to STABLEWM_HOME.
|
|
|
|
Returns:
|
|
The module with a `get_action` method, set to eval mode.
|
|
|
|
Raises:
|
|
RuntimeError: If no module with `get_action` is found in the checkpoint.
|
|
"""
|
|
return _load_model_with_attribute(run_name, 'get_action', cache_dir)
|
|
|
|
|
|
def AutoCostModel(
|
|
run_name: str, cache_dir: str | Path | None = None
|
|
) -> torch.nn.Module:
|
|
"""Load a model checkpoint and return the module with a `get_cost` method.
|
|
|
|
Automatically scans the checkpoint for a module implementing a cost function
|
|
(i.e., has a `get_cost` method) for use with planning solvers.
|
|
|
|
Args:
|
|
run_name: Path or name of the model run/checkpoint.
|
|
cache_dir: Optional cache directory path. Defaults to STABLEWM_HOME.
|
|
|
|
Returns:
|
|
The module with a `get_cost` method, set to eval mode.
|
|
|
|
Raises:
|
|
RuntimeError: If no module with `get_cost` is found in the checkpoint.
|
|
"""
|
|
return _load_model_with_attribute(run_name, 'get_cost', cache_dir)
|
|
|
|
|
|
# Alias for backward compatibility and type hinting
|
|
Policy = BasePolicy
|