diff --git a/.venv/lib/python3.10/site-packages/stable_worldmodel/policy.py b/.venv/lib/python3.10/site-packages/stable_worldmodel/policy.py new file mode 100644 index 0000000..af672c4 --- /dev/null +++ b/.venv/lib/python3.10/site-packages/stable_worldmodel/policy.py @@ -0,0 +1,533 @@ +from collections import deque +from dataclasses import dataclass +from pathlib import Path +from typing import Any, Protocol +from collections.abc import Callable + +import numpy as np +import torch +from loguru import logger as logging + +import stable_worldmodel as swm +from stable_worldmodel.solver import Solver + + +@dataclass(frozen=True) +class PlanConfig: + """Configuration for the MPC planning loop. + + Attributes: + horizon: Planning horizon in number of steps. + receding_horizon: Number of steps to execute before re-planning. + history_len: Number of past observations to consider. + action_block: Number of times each action is repeated (frameskip). + warm_start: Whether to use the previous plan to initialize the next one. + """ + + horizon: int + receding_horizon: int + history_len: int = 1 + action_block: int = 1 + warm_start: bool = True + + @property + def plan_len(self) -> int: + """Total plan length in environment steps.""" + return self.horizon * self.action_block + + +class Transformable(Protocol): + """Protocol for reversible data transformations (e.g., normalizers, scalers).""" + + def transform(self, x: np.ndarray) -> np.ndarray: # pragma: no cover + """Apply preprocessing to input data. + + Args: + x: Input data as a numpy array. + + Returns: + Preprocessed data as a numpy array. + """ + ... + + def inverse_transform( + self, x: np.ndarray + ) -> np.ndarray: # pragma: no cover + """Reverse the preprocessing transformation. + + Args: + x: Preprocessed data as a numpy array. + + Returns: + Original data as a numpy array. + """ + ... + + +class Actionable(Protocol): + """Protocol for model action computation.""" + + def get_action(info) -> torch.Tensor: # pragma: no cover + """Compute action from observation and goal""" + ... + + +class BasePolicy: + """Base class for agent policies. + + Attributes: + env: The environment the policy is associated with. + type: A string identifier for the policy type. + """ + + env: Any + type: str + + def __init__(self, **kwargs: Any) -> None: + """Initialize the base policy. + + Args: + **kwargs: Additional configuration parameters. + """ + self.env = None + self.type = 'base' + for arg, value in kwargs.items(): + setattr(self, arg, value) + + def get_action(self, obs: Any, **kwargs: Any) -> np.ndarray: + """Get action from the policy given the observation. + + Args: + obs: The current observation from the environment. + **kwargs: Additional parameters for action selection. + + Returns: + Selected action as a numpy array. + + Raises: + NotImplementedError: If not implemented by a subclass. + """ + raise NotImplementedError + + def set_env(self, env: Any) -> None: + """Associate this policy with an environment. + + Args: + env: The environment to associate. + """ + self.env = env + + def _move_info_to_device( + self, info_dict: dict[str, Any], device: torch.device | str + ) -> dict[str, Any]: + target = torch.device(device) + for k, v in info_dict.items(): + if torch.is_tensor(v) and v.device != target: + info_dict[k] = v.to(target, non_blocking=True) + return info_dict + + def _prepare_info(self, info_dict: dict) -> dict[str, torch.Tensor]: + """Pre-process and transform observations. + + Applies preprocessing (via `self.process`) and transformations (via `self.transform`) + to observation data. Used by subclasses like FeedForwardPolicy and WorldModelPolicy. + + Args: + info_dict: Raw observation dictionary from the environment. + + Returns: + A dictionary of processed tensors. + + Raises: + ValueError: If an expected numpy array is missing for processing. + """ + for k, v in info_dict.items(): + is_numpy = isinstance(v, (np.ndarray | np.generic)) + + if hasattr(self, 'process') and k in self.process: + if not is_numpy: + raise ValueError( + f"Expected numpy array for key '{k}' in process, got {type(v)}" + ) + + # flatten extra dimensions if needed + shape = v.shape + if len(shape) > 2: + v = v.reshape(-1, *shape[2:]) + + # process and reshape back + v = self.process[k].transform(v) + v = v.reshape(shape) + + # collapse env and time dimensions for transform (e, t, ...) -> (e * t, ...) + # then restore after transform + if hasattr(self, 'transform') and k in self.transform: + shape = None + if is_numpy or torch.is_tensor(v): + if v.ndim > 2: + shape = v.shape + v = v.reshape(-1, *shape[2:]) + if is_numpy: + if v.dtype.kind in 'USO': + raise ValueError( + f"Expected numeric numpy array for key '{k}', got dtype {v.dtype}" + ) + v = torch.from_numpy(v) + is_numpy = False + if k.startswith('pixels') or k.startswith('goal'): + # Vectorized image transform on the full batch. + v = v.permute(0, 3, 1, 2) + v = self.transform[k](v) + + if shape is not None: + v = v.reshape(*shape[:2], *v.shape[1:]) + + if is_numpy and v.dtype.kind not in 'USO': + v = torch.from_numpy(v) + + if torch.is_tensor(v) and v.device.type == "cpu" and not v.is_pinned(): + v = v.pin_memory() + + info_dict[k] = v + + return info_dict + + +class RandomPolicy(BasePolicy): + """Policy that samples random actions from the action space.""" + + def __init__(self, seed: int | None = None, **kwargs: Any) -> None: + """Initialize the random policy. + + Args: + seed: Optional random seed for the action space. + **kwargs: Additional configuration parameters. + """ + super().__init__(**kwargs) + self.type = 'random' + self.seed = seed + + def get_action(self, obs: Any, **kwargs: Any) -> np.ndarray: + """Get a random action from the environment's action space. + + Args: + obs: The current observation (ignored). + **kwargs: Additional parameters (ignored). + + Returns: + A randomly sampled action. + """ + return self.env.action_space.sample() + + def set_seed(self, seed: int) -> None: + """Set the random seed for action sampling. + + Args: + seed: The seed value. + """ + if self.env is not None: + self.env.action_space.seed(seed) + + +class ExpertPolicy(BasePolicy): + """Policy using expert demonstrations or heuristics.""" + + def __init__(self, **kwargs: Any) -> None: + """Initialize the expert policy. + + Args: + **kwargs: Additional configuration parameters. + """ + super().__init__(**kwargs) + self.type = 'expert' + + def get_action( + self, obs: Any, goal_obs: Any, **kwargs: Any + ) -> np.ndarray | None: + """Get action from the expert policy. + + Args: + obs: The current observation. + goal_obs: The goal observation. + **kwargs: Additional parameters. + + Returns: + The expert action, or None if not available. + """ + # Implement expert policy logic here + pass + + +class FeedForwardPolicy(BasePolicy): + """Feed-Forward Policy using a neural network model. + + Actions are computed via a single forward pass through the model. + Useful for imitation learning policies like Goal-Conditioned Behavioral Cloning (GCBC). + + Attributes: + model: Neural network model implementing the Actionable protocol. + process: Dictionary of data preprocessors for specific keys. + transform: Dictionary of tensor transformations (e.g., image transforms). + """ + + def __init__( + self, + model: Actionable, + process: dict[str, Transformable] | None = None, + transform: dict[str, Callable[[torch.Tensor], torch.Tensor]] + | None = None, + **kwargs: Any, + ) -> None: + """Initialize the feed-forward policy. + + Args: + model: Neural network model with a `get_action` method. + process: Dictionary of data preprocessors for specific keys. + transform: Dictionary of tensor transformations (e.g., image transforms). + **kwargs: Additional configuration parameters. + """ + super().__init__(**kwargs) + self.type = 'feed_forward' + self.model = model.eval() + self.process = process or {} + self.transform = transform or {} + + def get_action(self, info_dict: dict, **kwargs: Any) -> np.ndarray: + """Get action via a forward pass through the neural network model. + + Args: + info_dict: Current state information containing at minimum a 'goal' key. + **kwargs: Additional parameters (unused). + + Returns: + The selected action as a numpy array. + + Raises: + AssertionError: If environment not set or 'goal' not in info_dict. + """ + assert hasattr(self, 'env'), 'Environment not set for the policy' + assert 'goal' in info_dict, "'goal' must be provided in info_dict" + + # Prepare the info dict (transforms and normalizes inputs) + info_dict = self._prepare_info(info_dict) + + # Add goal_pixels key for GCBC model + if 'goal' in info_dict: + info_dict['goal_pixels'] = info_dict['goal'] + + # Move all tensors to the model's device + device = next(self.model.parameters()).device + info_dict = self._move_info_to_device(info_dict, device) + + # Get action from model + with torch.no_grad(): + action = self.model.get_action(info_dict) + + # Convert to numpy + if torch.is_tensor(action): + action = action.cpu().detach().numpy() + + # post-process action + if 'action' in self.process: + action = self.process['action'].inverse_transform(action) + + return action + + +class WorldModelPolicy(BasePolicy): + """Policy using a world model and planning solver for action selection.""" + + def __init__( + self, + solver: Solver, + config: PlanConfig, + process: dict[str, Transformable] | None = None, + transform: dict[str, Callable[[torch.Tensor], torch.Tensor]] + | None = None, + **kwargs: Any, + ) -> None: + """Initialize the world model policy. + + Args: + solver: The planning solver to use. + config: MPC planning configuration. + process: Dictionary of data preprocessors for specific keys. + transform: Dictionary of tensor transformations (e.g., image transforms). + **kwargs: Additional configuration parameters. + """ + super().__init__(**kwargs) + + self.type = 'world_model' + self.cfg = config + self.solver = solver + self.action_buffer: deque[torch.Tensor] = deque( + maxlen=self.flatten_receding_horizon + ) + self.process = process or {} + self.transform = transform or {} + self._action_buffer: deque[torch.Tensor] | None = None + self._next_init: torch.Tensor | None = None + + @property + def flatten_receding_horizon(self) -> int: + """Receding horizon in environment steps (with frameskip).""" + return self.cfg.receding_horizon * self.cfg.action_block + + def set_env(self, env: Any) -> None: + """Configure the policy and solver for the given environment. + + Args: + env: The environment to associate with the policy. + """ + self.env = env + n_envs = getattr(env, 'num_envs', 1) + self.solver.configure( + action_space=env.action_space, n_envs=n_envs, config=self.cfg + ) + self._action_buffer = deque(maxlen=self.flatten_receding_horizon) + + assert isinstance(self.solver, Solver), ( + 'Solver must implement the Solver protocol' + ) + + def get_action(self, info_dict: dict, **kwargs: Any) -> np.ndarray: + """Get action via planning with the world model. + + Args: + info_dict: Current state information from the environment. + **kwargs: Additional parameters for planning. + + Returns: + The selected action(s) as a numpy array. + """ + assert hasattr(self, 'env'), 'Environment not set for the policy' + assert 'pixels' in info_dict, "'pixels' must be provided in info_dict" + assert 'goal' in info_dict, "'goal' must be provided in info_dict" + + info_dict = self._prepare_info(info_dict) + info_dict = self._move_info_to_device(info_dict, self.solver.device) + + # need to replan if action buffer is empty + if len(self._action_buffer) == 0: + outputs = self.solver(info_dict, init_action=self._next_init) + + actions = outputs['actions'] # (num_envs, horizon, action_dim) + keep_horizon = self.cfg.receding_horizon + plan = actions[:, :keep_horizon] + rest = actions[:, keep_horizon:] + self._next_init = rest if self.cfg.warm_start else None + + # frameskip back to timestep + plan = plan.reshape( + self.env.num_envs, self.flatten_receding_horizon, -1 + ) + + self._action_buffer.extend(plan.transpose(0, 1)) + + action = self._action_buffer.popleft() + action = action.reshape(*self.env.action_space.shape) + action = action.numpy() + + # post-process action + if 'action' in self.process: + action = self.process['action'].inverse_transform(action) + + return action # (num_envs, action_dim) + + +def _load_model_with_attribute(run_name, attribute_name, cache_dir=None): + """Helper function to load a model checkpoint and find a module with the specified attribute. + + Args: + run_name: Path or name of the model run + attribute_name: Name of the attribute to look for in the module (e.g., 'get_action', 'get_cost') + cache_dir: Optional cache directory path + + Returns: + The module with the specified attribute + + Raises: + RuntimeError: If no module with the specified attribute is found + """ + if Path(run_name).exists(): + run_path = Path(run_name) + else: + run_path = Path(cache_dir or swm.data.utils.get_cache_dir(), run_name) + + if run_path.is_dir(): + ckpt_files = list(run_path.glob('*_object.ckpt')) + ckpt_files.sort(key=lambda x: x.stat().st_ctime, reverse=True) + path = ckpt_files[0] + logging.info(f'Loading model from checkpoint: {path}') + else: + path = Path(f'{run_path}_object.ckpt') + assert path.exists(), ( + f'Checkpoint path does not exist: {path}. Launch pretraining first.' + ) + + spt_module = torch.load(path, weights_only=False, map_location='cpu') + + def scan_module(module): + if hasattr(module, attribute_name): + if isinstance(module, torch.nn.Module): + module = module.eval() + return module + for child in module.children(): + result = scan_module(child) + if result is not None: + return result + return None + + result = scan_module(spt_module) + if result is not None: + return result + + raise RuntimeError( + f"No module with '{attribute_name}' found in the loaded world model." + ) + + +def AutoActionableModel( + run_name: str, cache_dir: str | Path | None = None +) -> torch.nn.Module: + """Load a model checkpoint and return the module with a `get_action` method. + + Automatically scans the checkpoint for a module implementing the Actionable + protocol (i.e., has a `get_action` method). + + Args: + run_name: Path or name of the model run/checkpoint. + cache_dir: Optional cache directory path. Defaults to STABLEWM_HOME. + + Returns: + The module with a `get_action` method, set to eval mode. + + Raises: + RuntimeError: If no module with `get_action` is found in the checkpoint. + """ + return _load_model_with_attribute(run_name, 'get_action', cache_dir) + + +def AutoCostModel( + run_name: str, cache_dir: str | Path | None = None +) -> torch.nn.Module: + """Load a model checkpoint and return the module with a `get_cost` method. + + Automatically scans the checkpoint for a module implementing a cost function + (i.e., has a `get_cost` method) for use with planning solvers. + + Args: + run_name: Path or name of the model run/checkpoint. + cache_dir: Optional cache directory path. Defaults to STABLEWM_HOME. + + Returns: + The module with a `get_cost` method, set to eval mode. + + Raises: + RuntimeError: If no module with `get_cost` is found in the checkpoint. + """ + return _load_model_with_attribute(run_name, 'get_cost', cache_dir) + + +# Alias for backward compatibility and type hinting +Policy = BasePolicy diff --git a/tworoom_results.txt b/tworoom_results.txt index 4c0c982..95b1047 100644 --- a/tworoom_results.txt +++ b/tworoom_results.txt @@ -342,3 +342,61 @@ metrics: {'success_rate': 90.0, 'episode_successes': array([ True, False, True, True, True, True, True, True]), 'seeds': None} evaluation_time: 94.01083040237427 seconds inference_precision: fp16 + +==== CONFIG ==== +cache_dir: null +solver: + _target_: stable_worldmodel.solver.CEMSolver + model: ??? + batch_size: 1 + num_samples: 300 + var_scale: 1.0 + n_steps: 30 + topk: 30 + device: cuda + seed: ${seed} +world: + env_name: swm/TwoRoom-v1 + num_envs: ${eval.num_eval} + max_episode_steps: 100 + history_size: 1 + frame_skip: 1 +seed: 42 +policy: two-room/tworoom/lejepa +inference_precision: fp16 +dataset: + stats: ${eval.dataset_name} + keys_to_cache: + - action + - proprio +plan_config: + horizon: 5 + receding_horizon: 5 + action_block: 5 +eval: + num_eval: 50 + goal_offset_steps: 25 + eval_budget: 50 + img_size: 224 + dataset_name: tworoom + callables: + - method: _set_state + args: + state: + value: proprio + - method: _set_goal_state + args: + goal_state: + value: goal_proprio +output: + filename: tworoom_results.txt + +==== RESULTS ==== +metrics: {'success_rate': 90.0, 'episode_successes': array([ True, False, True, False, True, True, True, True, False, + True, True, True, True, True, True, True, True, True, + True, True, True, True, True, True, True, True, True, + True, True, True, True, False, True, True, True, True, + True, True, True, False, True, True, True, True, True, + True, True, True, True, True]), 'seeds': None} +evaluation_time: 89.50024104118347 seconds +inference_precision: fp16