Compare commits
32 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
a639fdefca | ||
|
|
28f2fba0e8 | ||
|
|
113e591899 | ||
|
|
0164e21f48 | ||
|
|
02080e2564 | ||
|
|
d86aeb2df0 | ||
|
|
5e55727901 | ||
|
|
02c3cea3f9 | ||
|
|
f08f2b82f4 | ||
|
|
e84074d6d6 | ||
|
|
cf43af0729 | ||
|
|
4c3fdbcce6 | ||
|
|
75a5d86966 | ||
|
|
46cb2177bc | ||
|
|
8ba5bc8b0b | ||
|
|
e6f2b2b9d4 | ||
|
|
25e4ddb628 | ||
|
|
995cd8cfec | ||
|
|
cd03a0d5cb | ||
|
|
20ffb3492b | ||
|
|
96e17a13af | ||
|
|
006102d00c | ||
|
|
3a94829eac | ||
|
|
38be7d3bef | ||
|
|
f2750daace | ||
|
|
9e2407cdc4 | ||
|
|
0f85e39690 | ||
|
|
85795bd91d | ||
|
|
7c2e341d93 | ||
|
|
12ba4f4352 | ||
|
|
fa1c15c896 | ||
|
|
8b84251eb9 |
34
.gitignore
vendored
Normal file
34
.gitignore
vendored
Normal file
@@ -0,0 +1,34 @@
|
||||
.venv/*
|
||||
!.venv/.gitignore
|
||||
!.venv/lib/
|
||||
.venv/lib/*
|
||||
!.venv/lib/python3.10/
|
||||
.venv/lib/python3.10/*
|
||||
!.venv/lib/python3.10/site-packages/
|
||||
.venv/lib/python3.10/site-packages/*
|
||||
!.venv/lib/python3.10/site-packages/stable_worldmodel/
|
||||
.venv/lib/python3.10/site-packages/stable_worldmodel/*
|
||||
!.venv/lib/python3.10/site-packages/stable_worldmodel/solver/
|
||||
!.venv/lib/python3.10/site-packages/stable_worldmodel/policy.py
|
||||
!.venv/lib/python3.10/site-packages/stable_worldmodel/world.py
|
||||
!.venv/lib/python3.10/site-packages/stable_worldmodel/solver/cem.py
|
||||
!.venv/lib/python3.10/site-packages/stable_worldmodel/solver/gd.py
|
||||
outputs/
|
||||
__pycache__/
|
||||
*.py[cod]
|
||||
|
||||
.pytest_cache/
|
||||
.mypy_cache/
|
||||
.ruff_cache/
|
||||
|
||||
torch_profile/
|
||||
trace.json
|
||||
key_averages.txt
|
||||
eval_tmp_*.npy
|
||||
*.mp4
|
||||
*.gif
|
||||
|
||||
.DS_Store
|
||||
.idea/
|
||||
.vscode/
|
||||
*.log
|
||||
0
.venv/.gitignore
vendored
Normal file
0
.venv/.gitignore
vendored
Normal file
628
.venv/lib/python3.10/site-packages/stable_worldmodel/policy.py
Normal file
628
.venv/lib/python3.10/site-packages/stable_worldmodel/policy.py
Normal file
@@ -0,0 +1,628 @@
|
||||
from collections import deque
|
||||
from dataclasses import dataclass
|
||||
import inspect
|
||||
from pathlib import Path
|
||||
from typing import Any, Protocol
|
||||
from collections.abc import Callable
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from loguru import logger as logging
|
||||
|
||||
import stable_worldmodel as swm
|
||||
from stable_worldmodel.solver import Solver
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class PlanConfig:
    """Immutable settings that drive the MPC planning loop.

    Attributes:
        horizon: Planning horizon in number of steps.
        receding_horizon: Number of planned steps executed before re-planning.
        history_len: Number of past observations to consider.
        action_block: Number of times each action is repeated (frameskip).
        warm_start: Whether the previous plan is used to initialize the next one.
    """

    horizon: int
    receding_horizon: int
    history_len: int = 1
    action_block: int = 1
    warm_start: bool = True

    @property
    def plan_len(self) -> int:
        """Plan length in environment steps (``horizon * action_block``)."""
        steps_per_action = self.action_block
        return self.horizon * steps_per_action
|
||||
|
||||
|
||||
class Transformable(Protocol):
    """Structural type for reversible preprocessors (e.g. normalizers, scalers)."""

    def transform(self, x: np.ndarray) -> np.ndarray:  # pragma: no cover
        """Preprocess ``x``.

        Args:
            x: Raw input data as a numpy array.

        Returns:
            The preprocessed data as a numpy array.
        """
        ...

    def inverse_transform(
        self, x: np.ndarray
    ) -> np.ndarray:  # pragma: no cover
        """Undo :meth:`transform`.

        Args:
            x: Preprocessed data as a numpy array.

        Returns:
            The data mapped back to its original representation.
        """
        ...
|
||||
|
||||
|
||||
class Actionable(Protocol):
    """Protocol for models that can compute an action from an info dict."""

    # The original declaration omitted ``self``, which made ``info`` the
    # implicit instance parameter and left the protocol method with no
    # argument at all from a type checker's point of view.
    def get_action(self, info: Any) -> torch.Tensor:  # pragma: no cover
        """Compute action from observation and goal.

        Args:
            info: Information passed by the policy (the policies in this
                module pass their prepared info dictionary).

        Returns:
            The action as a tensor.
        """
        ...
|
||||
|
||||
|
||||
class BasePolicy:
|
||||
"""Base class for agent policies.
|
||||
|
||||
Attributes:
|
||||
env: The environment the policy is associated with.
|
||||
type: A string identifier for the policy type.
|
||||
"""
|
||||
|
||||
env: Any
|
||||
type: str
|
||||
|
||||
def __init__(self, **kwargs: Any) -> None:
|
||||
"""Initialize the base policy.
|
||||
|
||||
Args:
|
||||
**kwargs: Additional configuration parameters.
|
||||
"""
|
||||
self.env = None
|
||||
self.type = 'base'
|
||||
for arg, value in kwargs.items():
|
||||
setattr(self, arg, value)
|
||||
|
||||
def get_action(self, obs: Any, **kwargs: Any) -> np.ndarray:
|
||||
"""Get action from the policy given the observation.
|
||||
|
||||
Args:
|
||||
obs: The current observation from the environment.
|
||||
**kwargs: Additional parameters for action selection.
|
||||
|
||||
Returns:
|
||||
Selected action as a numpy array.
|
||||
|
||||
Raises:
|
||||
NotImplementedError: If not implemented by a subclass.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def set_env(self, env: Any) -> None:
|
||||
"""Associate this policy with an environment.
|
||||
|
||||
Args:
|
||||
env: The environment to associate.
|
||||
"""
|
||||
self.env = env
|
||||
|
||||
def _move_info_to_device(
|
||||
self, info_dict: dict[str, Any], device: torch.device | str
|
||||
) -> dict[str, Any]:
|
||||
target = torch.device(device)
|
||||
for k, v in info_dict.items():
|
||||
if torch.is_tensor(v):
|
||||
if v.device != target:
|
||||
v = v.to(target, non_blocking=True)
|
||||
if not v.is_contiguous():
|
||||
v = v.contiguous()
|
||||
info_dict[k] = v
|
||||
return info_dict
|
||||
|
||||
def _prepare_info_for_device(
|
||||
self,
|
||||
info_dict: dict,
|
||||
device: torch.device | str | None = None,
|
||||
) -> dict[str, torch.Tensor]:
|
||||
if device is None:
|
||||
return self._prepare_info(info_dict)
|
||||
|
||||
prepare_info = self._prepare_info
|
||||
try:
|
||||
signature = inspect.signature(prepare_info)
|
||||
except (TypeError, ValueError):
|
||||
return prepare_info(info_dict)
|
||||
|
||||
accepts_device = any(
|
||||
param.kind == inspect.Parameter.VAR_KEYWORD
|
||||
or (param.name == "device" and param.kind in {
|
||||
inspect.Parameter.POSITIONAL_OR_KEYWORD,
|
||||
inspect.Parameter.KEYWORD_ONLY,
|
||||
})
|
||||
for param in signature.parameters.values()
|
||||
)
|
||||
if accepts_device:
|
||||
return prepare_info(info_dict, device=device)
|
||||
return prepare_info(info_dict)
|
||||
|
||||
def _prepare_info(
|
||||
self,
|
||||
info_dict: dict,
|
||||
device: torch.device | str | None = None,
|
||||
) -> dict[str, torch.Tensor]:
|
||||
"""Pre-process and transform observations.
|
||||
|
||||
Applies preprocessing (via `self.process`) and transformations (via `self.transform`)
|
||||
to observation data. Used by subclasses like FeedForwardPolicy and WorldModelPolicy.
|
||||
|
||||
Args:
|
||||
info_dict: Raw observation dictionary from the environment.
|
||||
|
||||
Returns:
|
||||
A dictionary of processed tensors.
|
||||
|
||||
Raises:
|
||||
ValueError: If an expected numpy array is missing for processing.
|
||||
"""
|
||||
target_device = torch.device(device) if device is not None else None
|
||||
for k, v in info_dict.items():
|
||||
is_numpy = isinstance(v, (np.ndarray | np.generic))
|
||||
|
||||
if hasattr(self, 'process') and k in self.process:
|
||||
if not is_numpy:
|
||||
raise ValueError(
|
||||
f"Expected numpy array for key '{k}' in process, got {type(v)}"
|
||||
)
|
||||
|
||||
# flatten extra dimensions if needed
|
||||
shape = v.shape
|
||||
if len(shape) > 2:
|
||||
v = v.reshape(-1, *shape[2:])
|
||||
|
||||
# process and reshape back
|
||||
v = self.process[k].transform(v)
|
||||
v = v.reshape(shape)
|
||||
|
||||
# collapse env and time dimensions for transform (e, t, ...) -> (e * t, ...)
|
||||
# then restore after transform
|
||||
if hasattr(self, 'transform') and k in self.transform:
|
||||
shape = None
|
||||
if is_numpy or torch.is_tensor(v):
|
||||
if v.ndim > 2:
|
||||
shape = v.shape
|
||||
v = v.reshape(-1, *shape[2:])
|
||||
if is_numpy:
|
||||
if v.dtype.kind in 'USO':
|
||||
raise ValueError(
|
||||
f"Expected numeric numpy array for key '{k}', got dtype {v.dtype}"
|
||||
)
|
||||
v = torch.from_numpy(v)
|
||||
is_numpy = False
|
||||
moved_for_transform = False
|
||||
if target_device is not None and target_device.type != "cpu":
|
||||
if v.device != target_device:
|
||||
v = v.to(target_device, non_blocking=True)
|
||||
moved_for_transform = True
|
||||
if k.startswith('pixels') or k.startswith('goal'):
|
||||
# Vectorized image transform on the full batch.
|
||||
v = v.permute(0, 3, 1, 2)
|
||||
try:
|
||||
v = self.transform[k](v)
|
||||
except (NotImplementedError, RuntimeError):
|
||||
if not moved_for_transform:
|
||||
raise
|
||||
v = self.transform[k](v.cpu())
|
||||
|
||||
if shape is not None:
|
||||
v = v.reshape(*shape[:2], *v.shape[1:])
|
||||
|
||||
if is_numpy and v.dtype.kind not in 'USO':
|
||||
v = torch.from_numpy(v)
|
||||
|
||||
if (
|
||||
torch.cuda.is_available()
|
||||
and torch.is_tensor(v)
|
||||
and v.device.type == "cpu"
|
||||
and not v.is_pinned()
|
||||
):
|
||||
v = v.pin_memory()
|
||||
|
||||
info_dict[k] = v
|
||||
|
||||
return info_dict
|
||||
|
||||
|
||||
class RandomPolicy(BasePolicy):
    """Policy that draws uniformly random actions from the env's action space."""

    def __init__(self, seed: int | None = None, **kwargs: Any) -> None:
        """Create a random policy.

        Args:
            seed: Optional random seed; stored on the instance as ``seed``.
            **kwargs: Extra configuration forwarded to :class:`BasePolicy`.
        """
        super().__init__(**kwargs)
        self.seed = seed
        self.type = 'random'

    def get_action(self, obs: Any, **kwargs: Any) -> np.ndarray:
        """Sample an action from the environment's action space.

        Args:
            obs: The current observation (ignored).
            **kwargs: Additional parameters (ignored).

        Returns:
            A randomly sampled action.
        """
        action_space = self.env.action_space
        return action_space.sample()

    def set_seed(self, seed: int) -> None:
        """Seed the action space used for sampling.

        Does nothing while no environment is attached.

        Args:
            seed: The seed value.
        """
        if self.env is None:
            return
        self.env.action_space.seed(seed)
|
||||
|
||||
|
||||
class ExpertPolicy(BasePolicy):
    """Placeholder for policies driven by expert demonstrations or heuristics."""

    def __init__(self, **kwargs: Any) -> None:
        """Create an expert policy.

        Args:
            **kwargs: Extra configuration forwarded to :class:`BasePolicy`.
        """
        super().__init__(**kwargs)
        self.type = 'expert'

    def get_action(
        self, obs: Any, goal_obs: Any, **kwargs: Any
    ) -> np.ndarray | None:
        """Return the expert action for the given observation and goal.

        No expert logic is implemented here, so this always returns None.

        Args:
            obs: The current observation.
            goal_obs: The goal observation.
            **kwargs: Additional parameters.

        Returns:
            The expert action, or None if not available.
        """
        # Implement expert policy logic here
        return None
|
||||
|
||||
|
||||
class FeedForwardPolicy(BasePolicy):
    """Policy whose action comes from one forward pass of a neural network.

    Useful for imitation learning policies like Goal-Conditioned Behavioral
    Cloning (GCBC).

    Attributes:
        model: Neural network model implementing the Actionable protocol.
        process: Dictionary of data preprocessors for specific keys.
        transform: Dictionary of tensor transformations (e.g., image transforms).
    """

    def __init__(
        self,
        model: Actionable,
        process: dict[str, Transformable] | None = None,
        transform: dict[str, Callable[[torch.Tensor], torch.Tensor]]
        | None = None,
        **kwargs: Any,
    ) -> None:
        """Create the feed-forward policy.

        Args:
            model: Neural network model with a `get_action` method; it is
                switched to eval mode.
            process: Dictionary of data preprocessors for specific keys.
            transform: Dictionary of tensor transformations (e.g., image transforms).
            **kwargs: Extra configuration forwarded to :class:`BasePolicy`.
        """
        super().__init__(**kwargs)
        self.type = 'feed_forward'
        self.model = model.eval()
        self.process = {} if not process else process
        self.transform = {} if not transform else transform

    def get_action(self, info_dict: dict, **kwargs: Any) -> np.ndarray:
        """Run the model once on the prepared inputs and return its action.

        Args:
            info_dict: Current state information containing at minimum a 'goal' key.
            **kwargs: Additional parameters (unused).

        Returns:
            The selected action as a numpy array.

        Raises:
            AssertionError: If environment not set or 'goal' not in info_dict.
        """
        assert hasattr(self, 'env'), 'Environment not set for the policy'
        assert 'goal' in info_dict, "'goal' must be provided in info_dict"

        # The model's device is both where transforms run and where inputs end up.
        model_device = next(self.model.parameters()).device

        # Transform / normalize the inputs.
        info_dict = self._prepare_info_for_device(info_dict, device=model_device)

        # The GCBC model reads the goal under the 'goal_pixels' key.
        if 'goal' in info_dict:
            info_dict['goal_pixels'] = info_dict['goal']

        info_dict = self._move_info_to_device(info_dict, model_device)

        with torch.no_grad():
            action = self.model.get_action(info_dict)

        if torch.is_tensor(action):
            action = action.detach().cpu().numpy()

        # Map the action back through the action preprocessor, if there is one.
        action_processor = self.process.get('action') if 'action' in self.process else None
        if 'action' in self.process:
            action = action_processor.inverse_transform(action)

        return action
|
||||
|
||||
|
||||
class WorldModelPolicy(BasePolicy):
    """Policy using a world model and planning solver for action selection.

    A plan is requested from the solver whenever the internal action buffer
    is empty; the first ``receding_horizon`` planned steps are queued and then
    returned one environment step at a time.
    """

    def __init__(
        self,
        solver: Solver,
        config: PlanConfig,
        process: dict[str, Transformable] | None = None,
        transform: dict[str, Callable[[torch.Tensor], torch.Tensor]]
        | None = None,
        **kwargs: Any,
    ) -> None:
        """Initialize the world model policy.

        Args:
            solver: The planning solver to use.
            config: MPC planning configuration.
            process: Dictionary of data preprocessors for specific keys.
            transform: Dictionary of tensor transformations (e.g., image transforms).
            **kwargs: Additional configuration parameters.
        """
        super().__init__(**kwargs)

        self.type = 'world_model'
        self.cfg = config
        self.solver = solver
        # Kept for backward compatibility; the planning loop itself uses
        # ``_action_buffer``, which ``set_env`` creates.
        self.action_buffer: deque[torch.Tensor] = deque(
            maxlen=self.flatten_receding_horizon
        )
        self.process = process or {}
        self.transform = transform or {}
        self._action_buffer: deque[torch.Tensor] | None = None
        self._next_init: torch.Tensor | None = None

    @property
    def flatten_receding_horizon(self) -> int:
        """Receding horizon in environment steps (with frameskip)."""
        return self.cfg.receding_horizon * self.cfg.action_block

    def set_env(self, env: Any) -> None:
        """Configure the policy and solver for the given environment.

        Queued actions and any warm-start plan from a previous environment
        are discarded.

        Args:
            env: The environment to associate with the policy.
        """
        # Validate the solver before it is used, not after.
        assert isinstance(self.solver, Solver), (
            'Solver must implement the Solver protocol'
        )

        self.env = env
        n_envs = getattr(env, 'num_envs', 1)
        self.solver.configure(
            action_space=env.action_space, n_envs=n_envs, config=self.cfg
        )
        self._action_buffer = deque(maxlen=self.flatten_receding_horizon)
        # A warm-start plan computed for a previous environment would be
        # stale (and may not match the new number of envs).
        self._next_init = None

    def _normalize_active_mask(self, active_mask: Any) -> np.ndarray | None:
        """Validate ``active_mask`` and convert it to a boolean numpy array.

        Args:
            active_mask: Per-environment flags, or None.

        Returns:
            A 1-D boolean array of length ``env.num_envs``, or None when no
            mask was given.

        Raises:
            ValueError: If the mask does not have shape ``(num_envs,)``.
        """
        if active_mask is None:
            return None
        active_mask = np.asarray(active_mask, dtype=bool)
        if active_mask.ndim != 1 or active_mask.shape[0] != self.env.num_envs:
            raise ValueError(
                f"active_mask must have shape ({self.env.num_envs},), got {active_mask.shape}"
            )
        return active_mask

    def get_action(self, info_dict: dict, **kwargs: Any) -> np.ndarray:
        """Get action via planning with the world model.

        Args:
            info_dict: Current state information from the environment.
            **kwargs: Additional parameters for planning. ``active_mask``
                (optional) flags which environments are active; actions of
                inactive environments are zeroed.

        Returns:
            The selected action(s) as a numpy array.

        Raises:
            RuntimeError: If ``set_env`` has not been called yet.
            ValueError: If ``active_mask`` has the wrong shape.
        """
        assert hasattr(self, 'env'), 'Environment not set for the policy'
        if self._action_buffer is None:
            # ``env`` always exists as an attribute, so the assert above cannot
            # catch this; fail with a clear message instead of ``len(None)``.
            raise RuntimeError(
                'Environment not set for the policy: call set_env() before get_action()'
            )
        active_mask = self._normalize_active_mask(kwargs.get("active_mask"))

        # need to replan if action buffer is empty
        if len(self._action_buffer) == 0:
            assert 'pixels' in info_dict, "'pixels' must be provided in info_dict"
            assert 'goal' in info_dict, "'goal' must be provided in info_dict"

            info_dict = self._prepare_info_for_device(
                info_dict, device=self.solver.device
            )
            info_dict = self._move_info_to_device(info_dict, self.solver.device)

            outputs = self.solver(
                info_dict,
                init_action=self._next_init,
                active_mask=active_mask,
            )

            actions = outputs['actions']  # (num_envs, horizon, action_dim)
            if active_mask is not None and actions.shape[0] != self.env.num_envs:
                # The solver returned rows for active envs only: scatter them
                # into a zero-filled tensor covering every env.
                full_actions = torch.zeros(
                    self.env.num_envs,
                    actions.shape[1],
                    actions.shape[2],
                    dtype=actions.dtype,
                    device=actions.device,
                )
                full_actions[torch.as_tensor(active_mask, device=actions.device)] = actions
                actions = full_actions

            keep_horizon = self.cfg.receding_horizon
            plan = actions[:, :keep_horizon]
            rest = actions[:, keep_horizon:]
            # The unexecuted tail of the plan seeds the next solve.
            self._next_init = rest.contiguous() if self.cfg.warm_start else None

            # frameskip back to timestep
            plan = plan.reshape(
                self.env.num_envs, self.flatten_receding_horizon, -1
            ).contiguous()

            # One buffer entry per environment step, each covering all envs.
            self._action_buffer.extend(plan.transpose(0, 1).unbind(0))

        action = self._action_buffer.popleft()
        action = action.reshape(*self.env.action_space.shape)
        if active_mask is not None:
            if torch.is_tensor(action):
                inactive_mask = torch.as_tensor(
                    ~active_mask, device=action.device, dtype=torch.bool
                )
                # Clone so the zeroing does not write into the planned tensor.
                action = action.clone()
                action[inactive_mask] = 0
            else:
                action = np.array(action, copy=True)
                action[~active_mask] = 0
        if torch.is_tensor(action):
            action = action.detach().cpu().numpy()
        else:
            action = np.asarray(action)

        # post-process action
        if 'action' in self.process:
            action = self.process['action'].inverse_transform(action)

        return action  # (num_envs, action_dim)
|
||||
|
||||
|
||||
def _load_model_with_attribute(run_name, attribute_name, cache_dir=None):
    """Helper function to load a model checkpoint and find a module with the specified attribute.

    ``run_name`` is used as-is when it exists on disk; otherwise it is
    resolved relative to ``cache_dir`` (or the default cache directory). For
    a directory, the ``*_object.ckpt`` file with the latest ``st_ctime`` is
    loaded; otherwise ``<run_path>_object.ckpt`` is loaded.

    Args:
        run_name: Path or name of the model run
        attribute_name: Name of the attribute to look for in the module (e.g., 'get_action', 'get_cost')
        cache_dir: Optional cache directory path

    Returns:
        The module with the specified attribute (in eval mode when it is a
        ``torch.nn.Module``). The loaded object is checked first, then its
        children depth-first.

    Raises:
        FileNotFoundError: If the run directory contains no ``*_object.ckpt`` file
        AssertionError: If the single-file checkpoint path does not exist
        RuntimeError: If no module with the specified attribute is found
    """
    if Path(run_name).exists():
        run_path = Path(run_name)
    else:
        run_path = Path(cache_dir or swm.data.utils.get_cache_dir(), run_name)

    if run_path.is_dir():
        ckpt_files = list(run_path.glob('*_object.ckpt'))
        if not ckpt_files:
            # Previously an opaque IndexError from ``ckpt_files[0]``.
            raise FileNotFoundError(
                f'No *_object.ckpt checkpoint found in: {run_path}. Launch pretraining first.'
            )
        # Latest checkpoint by st_ctime.
        path = max(ckpt_files, key=lambda x: x.stat().st_ctime)
        logging.info(f'Loading model from checkpoint: {path}')
    else:
        path = Path(f'{run_path}_object.ckpt')
        assert path.exists(), (
            f'Checkpoint path does not exist: {path}. Launch pretraining first.'
        )

    # NOTE(security): weights_only=False unpickles arbitrary objects; only
    # load checkpoints from trusted sources.
    spt_module = torch.load(path, weights_only=False, map_location='cpu')

    def scan_module(module):
        if hasattr(module, attribute_name):
            if isinstance(module, torch.nn.Module):
                module = module.eval()
            return module
        # Objects without ``children()`` (e.g. a plain dict checkpoint) cannot
        # be descended into; report "not found" rather than AttributeError.
        children = getattr(module, 'children', None)
        if not callable(children):
            return None
        for child in children():
            result = scan_module(child)
            if result is not None:
                return result
        return None

    result = scan_module(spt_module)
    if result is not None:
        return result

    raise RuntimeError(
        f"No module with '{attribute_name}' found in the loaded world model."
    )
|
||||
|
||||
|
||||
def AutoActionableModel(
    run_name: str, cache_dir: str | Path | None = None
) -> torch.nn.Module:
    """Load a checkpoint and return the module that exposes `get_action`.

    The checkpoint is scanned for the first module that has a `get_action`
    attribute (i.e. one implementing the Actionable protocol).

    Args:
        run_name: Path or name of the model run/checkpoint.
        cache_dir: Optional cache directory path. Defaults to STABLEWM_HOME.

    Returns:
        The module with a `get_action` method, set to eval mode.

    Raises:
        RuntimeError: If no module with `get_action` is found in the checkpoint.
    """
    wanted_attribute = 'get_action'
    return _load_model_with_attribute(run_name, wanted_attribute, cache_dir)
|
||||
|
||||
|
||||
def AutoCostModel(
    run_name: str, cache_dir: str | Path | None = None
) -> torch.nn.Module:
    """Load a checkpoint and return the module that exposes `get_cost`.

    The checkpoint is scanned for the first module that has a `get_cost`
    attribute, i.e. a cost function usable by the planning solvers.

    Args:
        run_name: Path or name of the model run/checkpoint.
        cache_dir: Optional cache directory path. Defaults to STABLEWM_HOME.

    Returns:
        The module with a `get_cost` method, set to eval mode.

    Raises:
        RuntimeError: If no module with `get_cost` is found in the checkpoint.
    """
    wanted_attribute = 'get_cost'
    return _load_model_with_attribute(run_name, wanted_attribute, cache_dir)
|
||||
|
||||
|
||||
# Alias for backward compatibility and type hinting
|
||||
Policy = BasePolicy
|
||||
@@ -0,0 +1,17 @@
|
||||
from .cem import CEMSolver
|
||||
from .gd import GradientSolver
|
||||
from .icem import ICEMSolver
|
||||
from .lagrangian import LagrangianSolver
|
||||
from .mppi import MPPISolver
|
||||
from .solver import Solver
|
||||
from .discrete_solvers import PGDSolver
|
||||
|
||||
__all__ = [
|
||||
'Solver',
|
||||
'GradientSolver',
|
||||
'CEMSolver',
|
||||
'ICEMSolver',
|
||||
'PGDSolver',
|
||||
'MPPISolver',
|
||||
'LagrangianSolver',
|
||||
]
|
||||
@@ -0,0 +1,277 @@
|
||||
"""Cross Entropy Method solver for model-based planning."""
|
||||
|
||||
import time
|
||||
from typing import Any
|
||||
|
||||
import gymnasium as gym
|
||||
import numpy as np
|
||||
import torch
|
||||
from gymnasium.spaces import Box
|
||||
from loguru import logger as logging
|
||||
|
||||
from .solver import Costable
|
||||
|
||||
|
||||
class CEMSolver:
    """Cross Entropy Method solver for action optimization.

    Args:
        model: World model implementing the Costable protocol.
        batch_size: Number of environments to process in parallel.
        num_samples: Number of action candidates to sample per iteration.
        var_scale: Initial scale of the action distribution. It multiplies
            standard-normal noise directly, i.e. it acts as a standard
            deviation.
        n_steps: Number of CEM iterations.
        topk: Number of elite samples to keep for distribution update.
        device: Device for tensor computations.
        seed: Random seed for reproducibility.
    """

    def __init__(
        self,
        model: Costable,
        batch_size: int = 1,
        num_samples: int = 300,
        var_scale: float = 1,
        n_steps: int = 30,
        topk: int = 30,
        device: str | torch.device = "cpu",
        seed: int = 1234,
    ) -> None:
        self.model = model
        self.batch_size = batch_size
        self.var_scale = var_scale
        self.num_samples = num_samples
        self.n_steps = n_steps
        self.topk = topk
        self.device = device
        self.torch_gen = torch.Generator(device=device).manual_seed(seed)

        # Filled in by configure(); initialised here so the attributes exist
        # on an unconfigured solver.
        self._configured = False
        self._action_space = None
        self._n_envs = None
        self._action_dim = None
        self._config = None

    def configure(self, *, action_space: gym.Space, n_envs: int, config: Any) -> None:
        """Configure the solver with environment specifications.

        Args:
            action_space: Action space of the environment; the per-env action
                size is taken from ``action_space.shape[1:]``.
            n_envs: Number of parallel environments.
            config: Planning configuration; ``horizon`` and ``action_block``
                are read from it.
        """
        self._action_space = action_space
        self._n_envs = n_envs
        self._config = config
        self._action_dim = int(np.prod(action_space.shape[1:]))
        self._configured = True

        if not isinstance(action_space, Box):
            logging.warning(f"Action space is discrete, got {type(action_space)}. CEMSolver may not work as expected.")

    @property
    def n_envs(self) -> int:
        """Number of parallel environments."""
        return self._n_envs

    @property
    def action_dim(self) -> int:
        """Flattened action dimension including action_block grouping."""
        return self._action_dim * self._config.action_block

    @property
    def horizon(self) -> int:
        """Planning horizon in timesteps."""
        return self._config.horizon

    def __call__(self, *args: Any, **kwargs: Any) -> dict:
        """Make solver callable, forwarding to solve()."""
        return self.solve(*args, **kwargs)

    @staticmethod
    def _normalize_active_mask(
        active_mask: torch.Tensor | np.ndarray | None,
        n_envs: int,
        device: torch.device,
    ) -> torch.Tensor | None:
        """Convert ``active_mask`` to a boolean tensor on ``device``.

        Returns:
            A 1-D boolean tensor of length ``n_envs``, or None if no mask
            was given.

        Raises:
            ValueError: If the mask does not have shape ``(n_envs,)``.
        """
        if active_mask is None:
            return None
        if not torch.is_tensor(active_mask):
            active_mask = torch.as_tensor(active_mask, dtype=torch.bool, device=device)
        else:
            active_mask = active_mask.to(device=device, dtype=torch.bool)
        if active_mask.ndim != 1 or active_mask.shape[0] != n_envs:
            raise ValueError(
                f"active_mask must have shape ({n_envs},), got {tuple(active_mask.shape)}"
            )
        return active_mask

    def init_action_distrib(
        self, actions: torch.Tensor | None = None
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Initialize the action distribution parameters (mean and variance).

        Args:
            actions: Optional warm-start actions used as the leading part of
                the mean; missing trailing steps are zero-padded.

        Returns:
            ``(mean, var)``, where ``var`` is filled with ``var_scale``.
        """
        device = torch.device(self.device)
        var = self.var_scale * torch.ones(
            [self.n_envs, self.horizon, self.action_dim],
            device=device,
        )
        mean = (
            torch.zeros([self.n_envs, 0, self.action_dim], device=device)
            if actions is None
            else actions
        )

        remaining = self.horizon - mean.shape[1]
        if remaining > 0:
            new_mean = torch.zeros(
                [self.n_envs, remaining, self.action_dim],
                device=mean.device,
            )
            mean = torch.cat([mean, new_mean], dim=1)

        return mean, var

    @torch.inference_mode()
    def solve(
        self,
        info_dict: dict,
        init_action: torch.Tensor | None = None,
        active_mask: torch.Tensor | np.ndarray | None = None,
    ) -> dict:
        """Solve the planning problem using Cross Entropy Method.

        Environments are processed in chunks of ``batch_size``. Each chunk
        runs ``n_steps`` rounds of sampling candidates, scoring them with
        ``model.get_cost`` and refitting the distribution to the lowest-cost
        ``topk`` candidates.

        Args:
            info_dict: Per-environment inputs, indexed by env along the first
                axis. Tensors and numpy arrays are repeated ``num_samples``
                times along a new second axis before reaching the model.
            init_action: Optional warm-start actions for the initial mean.
            active_mask: Optional per-env flags; inactive envs are skipped
                and get NaN costs.

        Returns:
            Dict with ``"actions"`` (the final mean), ``"costs"`` (list with
            one mean elite cost per env), ``"mean"`` and ``"var"``
            (single-element lists holding the final tensors).
        """
        start_time = time.time()
        outputs = {
            "costs": [],
            "mean": [],  # History of means
            "var": [],  # History of vars
        }
        # self.device may be a plain string; normalise once so device
        # comparisons below compare torch.device with torch.device.
        device = torch.device(self.device)

        # -- initialize the action distribution globally
        mean, var = self.init_action_distrib(init_action)
        if mean.device != device:
            mean = mean.to(device, non_blocking=True)
        if var.device != device:
            var = var.to(device, non_blocking=True)
        active_mask = self._normalize_active_mask(active_mask, self.n_envs, device)

        if active_mask is not None and not torch.any(active_mask):
            # Nothing to optimise: return the initial distribution.
            return {
                "costs": [],
                "actions": mean.detach(),
                "mean": [mean.detach()],
                "var": [var.detach()],
            }

        total_envs = self.n_envs
        # torch.topk raises when k exceeds the number of candidates.
        n_elites = min(self.topk, self.num_samples)

        # --- Iterate over batches ---
        for start_idx in range(0, total_envs, self.batch_size):
            end_idx = min(start_idx + self.batch_size, total_envs)
            current_bs = end_idx - start_idx

            # Slice Distribution Parameters for current batch
            batch_mean = mean[start_idx:end_idx]
            batch_var = var[start_idx:end_idx]

            # Expand Info Dict for current batch
            expanded_infos = {}
            for k, v in info_dict.items():
                # v is shape (n_envs, ...)
                # Slice batch
                v_batch = v[start_idx:end_idx]
                if torch.is_tensor(v):
                    if v_batch.device != device:
                        v_batch = v_batch.to(device, non_blocking=True)
                    # Add sample dim: (batch, 1, ...)
                    v_batch = v_batch.unsqueeze(1)
                    # Expand: (batch, num_samples, ...)
                    v_batch = v_batch.expand(current_bs, self.num_samples, *v_batch.shape[2:])
                elif isinstance(v, np.ndarray):
                    v_batch = np.repeat(v_batch[:, None, ...], self.num_samples, axis=1)
                expanded_infos[k] = v_batch

            if active_mask is not None:
                batch_mask = active_mask[start_idx:end_idx]
                if not torch.any(batch_mask):
                    # Whole chunk inactive: record NaN costs and move on.
                    outputs["costs"].append(
                        torch.full((current_bs,), float("nan"), device=device)
                    )
                    continue
                active_local = torch.nonzero(batch_mask, as_tuple=False).squeeze(1)
                active_local_np = active_local.detach().cpu().numpy()
                batch_mean = batch_mean[active_local]
                batch_var = batch_var[active_local]
                expanded_infos = {
                    k: (v[active_local] if torch.is_tensor(v) else v[active_local_np])
                    for k, v in expanded_infos.items()
                }
                current_bs = int(active_local.numel())
            else:
                active_local = None

            # Optimization Loop
            final_batch_cost = None
            batch_indices = torch.arange(current_bs, device=device).unsqueeze(1)

            for _ in range(self.n_steps):
                # Sample action sequences: (Batch, Num_Samples, Horizon, Dim)
                candidates = torch.randn(
                    current_bs,
                    self.num_samples,
                    self.horizon,
                    self.action_dim,
                    generator=self.torch_gen,
                    device=device,
                )

                # Scale and shift: (Batch, N, H, D) * (Batch, 1, H, D) + (Batch, 1, H, D)
                candidates = candidates * batch_var.unsqueeze(1) + batch_mean.unsqueeze(1)

                # Force the first sample to be the current mean
                candidates[:, 0] = batch_mean

                # Shallow copy: the model gets its own dict each step.
                current_info = expanded_infos.copy()

                # Evaluate candidates
                costs = self.model.get_cost(current_info, candidates)

                assert isinstance(costs, torch.Tensor), f"Expected cost to be a torch.Tensor, got {type(costs)}"
                assert costs.ndim == 2 and costs.shape[0] == current_bs and costs.shape[1] == self.num_samples, (
                    f"Expected cost to be of shape ({current_bs}, {self.num_samples}), got {costs.shape}"
                )

                # Select Top-K (lowest cost)
                # topk_vals: (Batch, K), topk_inds: (Batch, K)
                topk_vals, topk_inds = torch.topk(costs, k=n_elites, dim=1, largest=False)

                # Gather Top-K Candidates via candidates[batch_idx, sample_idx]
                # Result shape: (Batch, K, Horizon, Dim)
                topk_candidates = candidates[batch_indices, topk_inds]

                # Refit the distribution to the elites (the spread is a std).
                batch_mean = topk_candidates.mean(dim=1)
                batch_var = topk_candidates.std(dim=1)

                # Mean elite cost, kept for reporting
                final_batch_cost = topk_vals.mean(dim=1).detach()

            if final_batch_cost is None:
                # n_steps == 0: nothing was evaluated, so no cost to report.
                final_batch_cost = torch.full((current_bs,), float("nan"), device=device)

            # Write results back to global storage
            if active_mask is not None:
                global_indices = start_idx + active_local
                mean[global_indices] = batch_mean
                var[global_indices] = batch_var
                batch_costs = torch.full(
                    (end_idx - start_idx,), float("nan"), device=device
                )
                batch_costs[active_local] = final_batch_cost
            else:
                mean[start_idx:end_idx] = batch_mean
                var[start_idx:end_idx] = batch_var
                batch_costs = final_batch_cost

            # Store history/metadata
            outputs["costs"].append(batch_costs)

        if outputs["costs"]:
            outputs["costs"] = torch.cat(outputs["costs"]).cpu().tolist()
        else:
            outputs["costs"] = []
        outputs["actions"] = mean.detach()
        outputs["mean"] = [mean.detach()]
        outputs["var"] = [var.detach()]

        print(f"CEM solve time: {time.time() - start_time:.4f} seconds")
        return outputs
|
||||
@@ -0,0 +1,256 @@
|
||||
"""Projected Gradient Descent solver for discrete action spaces."""
|
||||
|
||||
import time
|
||||
from typing import Any
|
||||
|
||||
import gymnasium as gym
|
||||
import numpy as np
|
||||
import torch
|
||||
from gymnasium.spaces import Discrete
|
||||
|
||||
from .solver import Costable
|
||||
|
||||
|
||||
class PGDSolver(torch.nn.Module):
|
||||
"""Projected Gradient Descent solver for discrete action optimization.
|
||||
|
||||
Args:
|
||||
model: World model implementing the Costable protocol.
|
||||
n_steps: Number of gradient descent iterations.
|
||||
batch_size: Number of environments to process in parallel.
|
||||
var_scale: Initial variance scale for action perturbations.
|
||||
num_samples: Number of action samples to optimize in parallel.
|
||||
action_noise: Noise added to actions during optimization.
|
||||
device: Device for tensor computations.
|
||||
seed: Random seed for reproducibility.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model: Costable,
|
||||
n_steps: int,
|
||||
batch_size: int | None = None,
|
||||
var_scale: float = 1,
|
||||
num_samples: int = 1,
|
||||
action_noise: float = 0.0,
|
||||
device: str | torch.device = "cpu",
|
||||
seed: int = 1234,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.model = model
|
||||
self.n_steps = n_steps
|
||||
self.batch_size = batch_size
|
||||
self.num_samples = num_samples
|
||||
self.var_scale = var_scale
|
||||
self.action_noise = action_noise
|
||||
self.device = device
|
||||
self.torch_gen = torch.Generator(device=device).manual_seed(seed)
|
||||
|
||||
self._configured = False
|
||||
self._n_envs = None
|
||||
self._action_dim = None
|
||||
self._action_simplex_dim = None
|
||||
self._config = None
|
||||
|
||||
def configure(self, *, action_space: gym.Space, n_envs: int, config: Any) -> None:
|
||||
"""Configure the solver with environment specifications."""
|
||||
assert isinstance(action_space, Discrete), f"Action space must be discrete, got {type(action_space)}"
|
||||
|
||||
self._action_space = action_space
|
||||
self._n_envs = n_envs
|
||||
self._config = config
|
||||
self._action_dim = int(np.prod(action_space.shape[1:]))
|
||||
self._action_simplex_dim = int(action_space.n)
|
||||
self._configured = True
|
||||
|
||||
@property
|
||||
def n_envs(self) -> int:
|
||||
"""Number of parallel environments."""
|
||||
return self._n_envs
|
||||
|
||||
@property
|
||||
def action_dim(self) -> int:
|
||||
"""Flattened action dimension including action_block grouping."""
|
||||
return self._action_dim * self._config.action_block
|
||||
|
||||
@property
|
||||
def action_simplex_dim(self) -> int:
|
||||
"""Simplex dimension for discrete action probabilities."""
|
||||
return self._action_simplex_dim * self._config.action_block
|
||||
|
||||
@property
|
||||
def horizon(self) -> int:
|
||||
"""Planning horizon in timesteps."""
|
||||
return self._config.horizon
|
||||
|
||||
def __call__(self, *args: Any, **kwargs: Any) -> dict:
|
||||
"""Make solver callable, forwarding to solve()."""
|
||||
return self.solve(*args, **kwargs)
|
||||
|
||||
def init_action(
|
||||
self, actions: torch.Tensor | None = None, from_scalar: bool = False
|
||||
) -> None:
|
||||
"""Initialize the action tensor for optimization."""
|
||||
if actions is None:
|
||||
actions = torch.zeros((self._n_envs, 0, self.action_simplex_dim))
|
||||
elif from_scalar:
|
||||
# convert scalar to one-hot
|
||||
actions = torch.nn.functional.one_hot(actions, num_classes=self._action_simplex_dim).to(torch.float32)
|
||||
# merge action_block dim
|
||||
actions = actions.reshape(*actions.shape[:-2], self.action_simplex_dim)
|
||||
assert (
|
||||
actions.shape[0] == self._n_envs
|
||||
and actions.shape[1] <= self.horizon
|
||||
and actions.shape[2] == self.action_simplex_dim
|
||||
)
|
||||
|
||||
# fill remaining action
|
||||
remaining = self.horizon - actions.shape[1]
|
||||
|
||||
if remaining > 0:
|
||||
new_actions = torch.zeros(self._n_envs, remaining, self.action_simplex_dim)
|
||||
actions = torch.cat([actions, new_actions], dim=1).to(self.device)
|
||||
|
||||
actions = actions.unsqueeze(1).repeat_interleave(self.num_samples, dim=1) # add sample dim
|
||||
actions[:, 1:] += (
|
||||
torch.randn(actions[:, 1:].shape, generator=self.torch_gen, device=self.device) * self.var_scale
|
||||
) # add small noise to all samples except the first one
|
||||
|
||||
# reset actions
|
||||
if hasattr(self, "init"):
|
||||
self.init.copy_(actions)
|
||||
else:
|
||||
self.register_parameter("init", torch.nn.Parameter(actions))
|
||||
|
||||
def solve(
|
||||
self,
|
||||
info_dict: dict,
|
||||
init_action: torch.Tensor | None = None,
|
||||
from_scalar: bool = False,
|
||||
) -> dict:
|
||||
"""Solve the planning problem using projected gradient descent."""
|
||||
start_time = time.time()
|
||||
outputs = {
|
||||
"cost": [], # Will store list of cost histories per batch
|
||||
"actions": None,
|
||||
}
|
||||
|
||||
with torch.no_grad():
|
||||
self.init_action(init_action, from_scalar=from_scalar)
|
||||
|
||||
# Determine batch size (default to all envs if not specified which can cause memory issues)
|
||||
batch_size = self.batch_size if self.batch_size is not None else self.n_envs
|
||||
total_envs = self.n_envs
|
||||
|
||||
# Lists to hold results from each batch to be concatenated later
|
||||
batch_top_actions_list = []
|
||||
|
||||
# --- Outer Loop: Iterate over batches ---
|
||||
for start_idx in range(0, total_envs, batch_size):
|
||||
end_idx = min(start_idx + batch_size, total_envs)
|
||||
current_bs = end_idx - start_idx
|
||||
|
||||
batch_init = self.init[start_idx:end_idx].clone().detach()
|
||||
batch_init.requires_grad = True
|
||||
|
||||
optim = torch.optim.SGD([batch_init], lr=1.0)
|
||||
|
||||
# Prepare Batch Infos
|
||||
# Slice the input info_dict and then expand dimensions
|
||||
expanded_infos = {}
|
||||
for k, v in info_dict.items():
|
||||
# Slice the data for the current batch indices
|
||||
# Assumes input data dim 0 corresponds to n_envs
|
||||
if torch.is_tensor(v):
|
||||
batch_v = v[start_idx:end_idx]
|
||||
batch_v = batch_v.unsqueeze(1)
|
||||
batch_v = batch_v.expand(current_bs, self.num_samples, *batch_v.shape[2:])
|
||||
elif isinstance(v, np.ndarray):
|
||||
batch_v = v[start_idx:end_idx]
|
||||
batch_v = np.repeat(batch_v[:, None, ...], self.num_samples, axis=1)
|
||||
expanded_infos[k] = batch_v
|
||||
|
||||
# Perform Gradient Descent for this batch
|
||||
batch_cost_history = []
|
||||
|
||||
for step in range(self.n_steps):
|
||||
current_info = expanded_infos.copy()
|
||||
|
||||
# Calculate cost using the batch parameter
|
||||
costs = self.model.get_cost(current_info, batch_init)
|
||||
|
||||
assert isinstance(costs, torch.Tensor), f"Got {type(costs)} cost, expect torch.Tensor"
|
||||
assert costs.ndim == 2 and costs.shape[0] == current_bs and costs.shape[1] == self.num_samples, (
|
||||
f"Cost should be of shape ({current_bs}, {self.num_samples}), got {costs.shape}"
|
||||
)
|
||||
assert costs.requires_grad, "Cost must requires_grad for PGD solver."
|
||||
|
||||
cost = costs.sum() # Sum cost for this batch
|
||||
cost.backward()
|
||||
optim.step()
|
||||
optim.zero_grad(set_to_none=True)
|
||||
|
||||
# Add noise
|
||||
if self.action_noise > 0:
|
||||
batch_init.data += torch.randn(batch_init.shape, generator=self.torch_gen) * self.action_noise
|
||||
|
||||
# projection onto simplex
|
||||
with torch.no_grad():
|
||||
batch_init.copy_(self._project_action_simplex(batch_init))
|
||||
|
||||
batch_cost_history.append(cost.item())
|
||||
|
||||
# Store cost history for this batch
|
||||
outputs["cost"].append(batch_cost_history)
|
||||
|
||||
# Update the global self.init with the optimized batch values
|
||||
with torch.no_grad():
|
||||
self.init[start_idx:end_idx] = batch_init
|
||||
|
||||
top_idx = torch.argsort(costs, dim=1)[:, 0]
|
||||
batch_indices = torch.arange(current_bs)
|
||||
|
||||
top_actions_batch = batch_init[batch_indices, top_idx]
|
||||
|
||||
# convert one-hot back to discrete actions
|
||||
top_actions_batch = self._factor_action_block(top_actions_batch).argmax(dim=-1)
|
||||
batch_top_actions_list.append(top_actions_batch.detach().cpu())
|
||||
|
||||
# Concatenate all batch results
|
||||
outputs["actions"] = torch.cat(batch_top_actions_list, dim=0)
|
||||
end_time = time.time()
|
||||
print(f"PGDSolver.solve completed in {end_time - start_time:.4f} seconds.")
|
||||
|
||||
return outputs
|
||||
|
||||
def _factor_action_block(self, actions: torch.Tensor) -> torch.Tensor:
|
||||
"""Factor the action block dimension from action_simplex_dim."""
|
||||
original_shape = actions.shape
|
||||
action_block = self._config.action_block
|
||||
simplex_dim = self._action_simplex_dim
|
||||
return actions.reshape(*original_shape[:-1], action_block, simplex_dim)
|
||||
|
||||
def _project_action_simplex(self, actions: torch.Tensor) -> torch.Tensor:
|
||||
"""Project the action onto the probability simplex."""
|
||||
original_shape = actions.shape
|
||||
|
||||
s = self._factor_action_block(actions).reshape(-1, self._action_simplex_dim)
|
||||
|
||||
mu, _ = torch.sort(s, descending=True, dim=-1)
|
||||
cumulative = mu.cumsum(dim=-1)
|
||||
|
||||
d = s.size(-1)
|
||||
indices = torch.arange(1, d + 1, device=s.device, dtype=s.dtype)
|
||||
|
||||
threshold = (cumulative - 1) / indices
|
||||
|
||||
cond = (mu > threshold).to(torch.int32)
|
||||
rho = cond.cumsum(dim=-1)
|
||||
valid_rho = rho * cond
|
||||
rho_max = valid_rho.max(dim=-1, keepdim=True)[0]
|
||||
|
||||
rho_min = torch.clamp(rho_max, min=1)
|
||||
psi = (cumulative.gather(-1, rho_min - 1) - 1) / rho_min
|
||||
|
||||
projected = torch.clamp(s - psi, min=0.0).reshape(original_shape)
|
||||
return projected
|
||||
@@ -0,0 +1,252 @@
|
||||
"""Gradient-based solver for model-based planning."""
|
||||
|
||||
import time
|
||||
from typing import Any
|
||||
|
||||
import gymnasium as gym
|
||||
import numpy as np
|
||||
import torch
|
||||
from gymnasium.spaces import Box
|
||||
from loguru import logger as logging
|
||||
|
||||
from .solver import Costable
|
||||
|
||||
|
||||
class GradientSolver(torch.nn.Module):
|
||||
"""Gradient-based solver using backpropagation through the world model.
|
||||
|
||||
Args:
|
||||
model: World model implementing the Costable protocol.
|
||||
n_steps: Number of gradient descent iterations.
|
||||
batch_size: Number of environments to process in parallel.
|
||||
var_scale: Initial variance scale for action perturbations.
|
||||
num_samples: Number of action samples to optimize in parallel.
|
||||
action_noise: Noise added to actions during optimization.
|
||||
device: Device for tensor computations.
|
||||
seed: Random seed for reproducibility.
|
||||
optimizer_cls: PyTorch optimizer class to use.
|
||||
optimizer_kwargs: Keyword arguments for the optimizer.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model: Costable,
|
||||
n_steps: int,
|
||||
batch_size: int | None = None,
|
||||
var_scale: float = 1,
|
||||
num_samples: int = 1,
|
||||
action_noise: float = 0.0,
|
||||
device: str | torch.device = 'cpu',
|
||||
seed: int = 1234,
|
||||
optimizer_cls: type[torch.optim.Optimizer] = torch.optim.SGD,
|
||||
optimizer_kwargs: dict | None = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.model = model
|
||||
self.n_steps = n_steps
|
||||
self.batch_size = batch_size
|
||||
self.num_samples = num_samples
|
||||
self.var_scale = var_scale
|
||||
self.action_noise = action_noise
|
||||
self.device = device
|
||||
self.torch_gen = torch.Generator(device=device).manual_seed(seed)
|
||||
|
||||
self.optimizer_cls = optimizer_cls
|
||||
self.optimizer_kwargs = (
|
||||
optimizer_kwargs if optimizer_kwargs is not None else {'lr': 1.0}
|
||||
)
|
||||
|
||||
self._configured = False
|
||||
self._n_envs = None
|
||||
self._action_dim = None
|
||||
self._config = None
|
||||
|
||||
def configure(
|
||||
self, *, action_space: gym.Space, n_envs: int, config: Any
|
||||
) -> None:
|
||||
"""Configure the solver with environment specifications."""
|
||||
self._action_space = action_space
|
||||
self._n_envs = n_envs
|
||||
self._config = config
|
||||
self._action_dim = int(np.prod(action_space.shape[1:]))
|
||||
self._configured = True
|
||||
|
||||
if not isinstance(action_space, Box):
|
||||
logging.warning(
|
||||
f'Action space is discrete, got {type(action_space)}. GradientSolver may not work as expected.'
|
||||
)
|
||||
|
||||
@property
|
||||
def n_envs(self) -> int:
|
||||
"""Number of parallel environments."""
|
||||
return self._n_envs
|
||||
|
||||
@property
|
||||
def action_dim(self) -> int:
|
||||
"""Flattened action dimension including action_block grouping."""
|
||||
return self._action_dim * self._config.action_block
|
||||
|
||||
@property
|
||||
def horizon(self) -> int:
|
||||
"""Planning horizon in timesteps."""
|
||||
return self._config.horizon
|
||||
|
||||
def __call__(self, *args: Any, **kwargs: Any) -> dict:
|
||||
"""Make solver callable, forwarding to solve()."""
|
||||
return self.solve(*args, **kwargs)
|
||||
|
||||
def init_action(self, actions: torch.Tensor | None = None) -> None:
|
||||
"""Initialize the action tensor for optimization."""
|
||||
device = torch.device(self.device)
|
||||
if actions is None:
|
||||
actions = torch.zeros(
|
||||
(self._n_envs, 0, self.action_dim), device=device
|
||||
)
|
||||
elif actions.device != device:
|
||||
actions = actions.to(device, non_blocking=True)
|
||||
|
||||
# fill remaining action
|
||||
remaining = self.horizon - actions.shape[1]
|
||||
|
||||
if remaining > 0:
|
||||
new_actions = torch.zeros(
|
||||
self._n_envs,
|
||||
remaining,
|
||||
self.action_dim,
|
||||
device=actions.device,
|
||||
dtype=actions.dtype,
|
||||
)
|
||||
actions = torch.cat([actions, new_actions], dim=1)
|
||||
|
||||
actions = actions.unsqueeze(1).repeat_interleave(
|
||||
self.num_samples, dim=1
|
||||
) # add sample dim
|
||||
actions[:, 1:] += (
|
||||
torch.randn(
|
||||
actions[:, 1:].shape,
|
||||
generator=self.torch_gen,
|
||||
device=self.device,
|
||||
)
|
||||
* self.var_scale
|
||||
) # add small noise to all samples except the first one
|
||||
|
||||
# reset actions
|
||||
if hasattr(self, 'init'):
|
||||
self.init.copy_(actions)
|
||||
else:
|
||||
self.register_parameter('init', torch.nn.Parameter(actions))
|
||||
|
||||
def solve(
|
||||
self, info_dict: dict, init_action: torch.Tensor | None = None
|
||||
) -> dict:
|
||||
"""Solve the planning problem using gradient descent."""
|
||||
start_time = time.time()
|
||||
outputs = {
|
||||
'cost': [], # Will store list of cost histories per batch
|
||||
'actions': None,
|
||||
}
|
||||
|
||||
with torch.no_grad():
|
||||
self.init_action(init_action)
|
||||
|
||||
# Determine batch size (default to all envs if not specified which can cause memory issues)
|
||||
batch_size = (
|
||||
self.batch_size if self.batch_size is not None else self.n_envs
|
||||
)
|
||||
total_envs = self.n_envs
|
||||
|
||||
# Lists to hold results from each batch to be concatenated later
|
||||
batch_top_actions_list = []
|
||||
|
||||
# --- Outer Loop: Iterate over batches ---
|
||||
for start_idx in range(0, total_envs, batch_size):
|
||||
end_idx = min(start_idx + batch_size, total_envs)
|
||||
current_bs = end_idx - start_idx
|
||||
|
||||
batch_init = self.init[start_idx:end_idx].clone().detach()
|
||||
batch_init.requires_grad = True
|
||||
|
||||
# We initialize the optimizer class passed in __init__ with the kwargs
|
||||
optim = self.optimizer_cls([batch_init], **self.optimizer_kwargs)
|
||||
|
||||
# Prepare Batch Infos
|
||||
# Slice the input info_dict and then expand dimensions
|
||||
expanded_infos = {}
|
||||
for k, v in info_dict.items():
|
||||
# Slice the data for the current batch indices
|
||||
# Assumes input data dim 0 corresponds to n_envs
|
||||
if torch.is_tensor(v):
|
||||
batch_v = v[start_idx:end_idx]
|
||||
batch_v = batch_v.unsqueeze(1)
|
||||
batch_v = batch_v.expand(
|
||||
current_bs, self.num_samples, *batch_v.shape[2:]
|
||||
)
|
||||
elif isinstance(v, np.ndarray):
|
||||
batch_v = v[start_idx:end_idx]
|
||||
batch_v = np.repeat(
|
||||
batch_v[:, None, ...], self.num_samples, axis=1
|
||||
)
|
||||
expanded_infos[k] = batch_v
|
||||
|
||||
final_batch_cost = None
|
||||
|
||||
for step in range(self.n_steps):
|
||||
current_info = expanded_infos.copy()
|
||||
|
||||
# Calculate cost using the batch parameter
|
||||
costs = self.model.get_cost(current_info, batch_init)
|
||||
|
||||
assert isinstance(costs, torch.Tensor), (
|
||||
f'Got {type(costs)} cost, expect torch.Tensor'
|
||||
)
|
||||
assert (
|
||||
costs.ndim == 2
|
||||
and costs.shape[0] == current_bs
|
||||
and costs.shape[1] == self.num_samples
|
||||
), (
|
||||
f'Cost should be of shape ({current_bs}, {self.num_samples}), got {costs.shape}'
|
||||
)
|
||||
assert costs.requires_grad, (
|
||||
'Cost must requires_grad for GD solver.'
|
||||
)
|
||||
|
||||
cost = costs.sum() # Sum cost for this batch
|
||||
cost.backward()
|
||||
optim.step()
|
||||
optim.zero_grad(set_to_none=True)
|
||||
|
||||
# Add noise
|
||||
if self.action_noise > 0:
|
||||
batch_init.data += (
|
||||
torch.randn(
|
||||
batch_init.shape,
|
||||
generator=self.torch_gen,
|
||||
device=self.device,
|
||||
)
|
||||
* self.action_noise
|
||||
)
|
||||
|
||||
final_batch_cost = costs.detach().min(dim=1).values
|
||||
|
||||
# Store cost history for this batch
|
||||
outputs['cost'].append(final_batch_cost)
|
||||
|
||||
# Update the global self.init with the optimized batch values
|
||||
with torch.no_grad():
|
||||
self.init[start_idx:end_idx] = batch_init
|
||||
|
||||
top_idx = costs.argmin(dim=1)
|
||||
batch_indices = torch.arange(current_bs, device=self.device)
|
||||
|
||||
top_actions_batch = batch_init[batch_indices, top_idx]
|
||||
batch_top_actions_list.append(top_actions_batch.detach())
|
||||
|
||||
# Concatenate all batch results
|
||||
outputs['actions'] = torch.cat(batch_top_actions_list, dim=0)
|
||||
outputs['cost'] = torch.cat(outputs['cost']).cpu().tolist()
|
||||
end_time = time.time()
|
||||
print(
|
||||
f'GradientSolver.solve completed in {end_time - start_time:.4f} seconds.'
|
||||
)
|
||||
|
||||
return outputs
|
||||
@@ -0,0 +1,219 @@
|
||||
"""Improved Cross Entropy Method (iCEM) solver for model-based planning."""
|
||||
|
||||
import time
|
||||
from typing import Any
|
||||
|
||||
import gymnasium as gym
|
||||
import numpy as np
|
||||
import torch
|
||||
from gymnasium.spaces import Box
|
||||
from loguru import logger as logging
|
||||
|
||||
from .solver import Costable
|
||||
|
||||
|
||||
class ICEMSolver:
|
||||
"""Improved Cross Entropy Method (iCEM) solver with colored noise and elite retention.
|
||||
iCEM improves the sample efficiency over standard CEM and was introduced by
|
||||
[1] for real-time planning.
|
||||
|
||||
Args:
|
||||
model: World model implementing the Costable protocol.
|
||||
batch_size: Number of environments to process in parallel.
|
||||
num_samples: Number of action candidates to sample per iteration.
|
||||
var_scale: Initial variance scale for the action distribution.
|
||||
n_steps: Number of CEM iterations.
|
||||
topk: Number of elite samples to keep for distribution update.
|
||||
noise_beta: Colored noise exponent. 0 = white (standard CEM), >0 = more low-frequency noise.
|
||||
alpha: Momentum for mean/std EMA update.
|
||||
n_elite_keep: Number of elites carried from previous iteration.
|
||||
return_mean: If False, return best single trajectory instead of mean.
|
||||
device: Device for tensor computations.
|
||||
seed: Random seed for reproducibility.
|
||||
|
||||
[1] C. Pinneri, S. Sawant, S. Blaes, J. Achterhold, J. Stueckler, M. Rolinek and
|
||||
G, Martius, Georg. "Sample-efficient Cross-Entropy Method for Real-time Planning".
|
||||
Conference on Robot Learning, 2020.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model: Costable,
|
||||
batch_size: int = 1,
|
||||
num_samples: int = 300,
|
||||
var_scale: float = 1,
|
||||
n_steps: int = 30,
|
||||
topk: int = 30,
|
||||
noise_beta: float = 2.0,
|
||||
alpha: float = 0.1,
|
||||
n_elite_keep: int = 5,
|
||||
return_mean: bool = True,
|
||||
device: str | torch.device = "cpu",
|
||||
seed: int = 1234,
|
||||
) -> None:
|
||||
self.model = model
|
||||
self.batch_size = batch_size
|
||||
self.var_scale = var_scale
|
||||
self.num_samples = num_samples
|
||||
self.n_steps = n_steps
|
||||
self.topk = topk
|
||||
self.noise_beta = noise_beta
|
||||
self.alpha = alpha
|
||||
self.n_elite_keep = n_elite_keep
|
||||
self.return_mean = return_mean
|
||||
self.device = device
|
||||
self.torch_gen = torch.Generator(device=device).manual_seed(seed)
|
||||
|
||||
def configure(self, *, action_space: gym.Space, n_envs: int, config: Any) -> None:
|
||||
"""Configure the solver with environment specifications."""
|
||||
self._action_space = action_space
|
||||
self._n_envs = n_envs
|
||||
self._config = config
|
||||
self._action_dim = int(np.prod(action_space.shape[1:]))
|
||||
self._configured = True
|
||||
|
||||
if isinstance(action_space, Box):
|
||||
self._action_low = torch.tensor(action_space.low[0], device=self.device, dtype=torch.float32)
|
||||
self._action_high = torch.tensor(action_space.high[0], device=self.device, dtype=torch.float32)
|
||||
else:
|
||||
logging.warning(f"Action space is discrete, got {type(action_space)}. ICEMSolver may not work as expected.")
|
||||
self._action_low = None
|
||||
self._action_high = None
|
||||
|
||||
@property
|
||||
def n_envs(self) -> int:
|
||||
"""Number of parallel environments."""
|
||||
return self._n_envs
|
||||
|
||||
@property
|
||||
def action_dim(self) -> int:
|
||||
"""Flattened action dimension including action_block grouping."""
|
||||
return self._action_dim * self._config.action_block
|
||||
|
||||
@property
|
||||
def horizon(self) -> int:
|
||||
"""Planning horizon in timesteps."""
|
||||
return self._config.horizon
|
||||
|
||||
def __call__(self, *args: Any, **kwargs: Any) -> dict:
|
||||
"""Make solver callable, forwarding to solve()."""
|
||||
return self.solve(*args, **kwargs)
|
||||
|
||||
def init_action_distrib(
|
||||
self, actions: torch.Tensor | None = None
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Initialize the action distribution parameters (mean and variance)."""
|
||||
var = self.var_scale * torch.ones([self.n_envs, self.horizon, self.action_dim])
|
||||
mean = torch.zeros([self.n_envs, 0, self.action_dim]) if actions is None else actions
|
||||
|
||||
remaining = self.horizon - mean.shape[1]
|
||||
if remaining > 0:
|
||||
device = mean.device
|
||||
new_mean = torch.zeros([self.n_envs, remaining, self.action_dim])
|
||||
mean = torch.cat([mean, new_mean], dim=1).to(device)
|
||||
|
||||
return mean, var
|
||||
|
||||
@torch.inference_mode()
|
||||
def solve(
|
||||
self, info_dict: dict, init_action: torch.Tensor | None = None
|
||||
) -> dict:
|
||||
"""Solve the planning problem using improved Cross Entropy Method."""
|
||||
start_time = time.time()
|
||||
outputs = {
|
||||
"costs": [],
|
||||
"mean": [],
|
||||
"var": [],
|
||||
}
|
||||
|
||||
mean, var = self.init_action_distrib(init_action)
|
||||
mean = mean.to(self.device)
|
||||
var = var.to(self.device)
|
||||
|
||||
for start_idx in range(0, self.n_envs, self.batch_size):
|
||||
end_idx = min(start_idx + self.batch_size, self.n_envs)
|
||||
current_bs = end_idx - start_idx
|
||||
|
||||
batch_mean = mean[start_idx:end_idx]
|
||||
batch_var = var[start_idx:end_idx]
|
||||
|
||||
expanded_infos = {}
|
||||
for k, v in info_dict.items():
|
||||
v_batch = v[start_idx:end_idx]
|
||||
if torch.is_tensor(v):
|
||||
v_batch = v_batch.unsqueeze(1)
|
||||
v_batch = v_batch.expand(current_bs, self.num_samples, *v_batch.shape[2:])
|
||||
elif isinstance(v, np.ndarray):
|
||||
v_batch = np.repeat(v_batch[:, None, ...], self.num_samples, axis=1)
|
||||
expanded_infos[k] = v_batch
|
||||
|
||||
prev_topk_candidates = None
|
||||
batch_indices = torch.arange(current_bs, device=self.device).unsqueeze(1).expand(-1, self.topk)
|
||||
|
||||
# Precompute FFT scale for colored noise
|
||||
noise_shape = (current_bs, self.num_samples, self.action_dim, self.horizon)
|
||||
freqs = torch.fft.rfftfreq(self.horizon, device=self.device)
|
||||
freqs[0] = 1.0
|
||||
noise_scale = freqs.pow(-self.noise_beta / 2)
|
||||
noise_scale[0] = noise_scale[1]
|
||||
|
||||
for step in range(self.n_steps):
|
||||
# Colored noise: generate with temporal axis last, then transpose
|
||||
if self.horizon <= 1:
|
||||
noise = torch.randn(noise_shape, generator=self.torch_gen, device=self.device)
|
||||
else:
|
||||
white = torch.randn(noise_shape, generator=self.torch_gen, device=self.device)
|
||||
fft = torch.fft.rfft(white, dim=-1)
|
||||
colored = torch.fft.irfft(fft * noise_scale, n=self.horizon, dim=-1)
|
||||
std = colored.std(dim=-1, keepdim=True).clamp(min=1e-8)
|
||||
noise = colored / std
|
||||
noise = noise.transpose(-1, -2) # -> (bs, num_samples, horizon, action_dim)
|
||||
|
||||
candidates = noise * batch_var.unsqueeze(1) + batch_mean.unsqueeze(1)
|
||||
candidates[:, 0] = batch_mean
|
||||
|
||||
# Inject previous elites
|
||||
if prev_topk_candidates is not None:
|
||||
n_inject = min(self.n_elite_keep, prev_topk_candidates.shape[1])
|
||||
candidates[:, 1:1 + n_inject] = prev_topk_candidates[:, :n_inject]
|
||||
|
||||
# Clip to action bounds
|
||||
if self._action_low is not None:
|
||||
candidates = candidates.clamp(self._action_low, self._action_high)
|
||||
|
||||
current_info = expanded_infos.copy()
|
||||
costs = self.model.get_cost(current_info, candidates)
|
||||
|
||||
assert isinstance(costs, torch.Tensor), f"Expected cost to be a torch.Tensor, got {type(costs)}"
|
||||
assert costs.ndim == 2 and costs.shape[0] == current_bs and costs.shape[1] == self.num_samples, (
|
||||
f"Expected cost to be of shape ({current_bs}, {self.num_samples}), got {costs.shape}"
|
||||
)
|
||||
|
||||
topk_vals, topk_inds = torch.topk(costs, k=self.topk, dim=1, largest=False)
|
||||
topk_candidates = candidates[batch_indices, topk_inds]
|
||||
|
||||
prev_topk_candidates = topk_candidates
|
||||
|
||||
# Momentum update
|
||||
elite_mean = topk_candidates.mean(dim=1)
|
||||
elite_var = topk_candidates.std(dim=1)
|
||||
batch_mean = self.alpha * batch_mean + (1 - self.alpha) * elite_mean
|
||||
batch_var = self.alpha * batch_var + (1 - self.alpha) * elite_var
|
||||
|
||||
final_batch_cost = topk_vals.mean(dim=1).cpu().tolist()
|
||||
|
||||
if self.return_mean:
|
||||
mean[start_idx:end_idx] = batch_mean
|
||||
else:
|
||||
mean[start_idx:end_idx] = topk_candidates[:, 0]
|
||||
|
||||
var[start_idx:end_idx] = batch_var
|
||||
|
||||
outputs["costs"].extend(final_batch_cost)
|
||||
|
||||
outputs["actions"] = mean.detach().cpu()
|
||||
outputs["mean"] = [mean.detach().cpu()]
|
||||
outputs["var"] = [var.detach().cpu()]
|
||||
|
||||
print(f"iCEM solve time: {time.time() - start_time:.4f} seconds")
|
||||
return outputs
|
||||
@@ -0,0 +1,360 @@
|
||||
"""Lagrangian solver for stable world model."""
|
||||
|
||||
import time
|
||||
from typing import Any
|
||||
|
||||
import gymnasium as gym
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from gymnasium.spaces import Box
|
||||
from loguru import logger as logging
|
||||
|
||||
from .solver import Costable
|
||||
|
||||
|
||||
class LagrangianSolver(torch.nn.Module):
|
||||
"""Lagrangian solver for stable world model.
|
||||
|
||||
get_cost returns the cost tensor (B, S). If the model also implements get_constraints,
|
||||
it should return the constraint violations (B, S, C), where C is the number of constraints.
|
||||
The constraint_cost should represent the cost of violating the constraints, where the constraint
|
||||
is satisfied when constraint_cost <= 0. The Lagrangian solver will optimize the following objective:
|
||||
|
||||
L = cost + sum_{i=1}^C lambda_i * constraint_cost_i + rho * sum_{i=1}^C max(0, constraint_cost_i)^2
|
||||
|
||||
If you want to use equality constraint, you can convert it to two inequality constraints. For example, if you want to enforce constraint_cost_i == 0, you can add two constraints: constraint_cost_i <= 0 and -constraint_cost_i <= 0.
|
||||
|
||||
Args:
|
||||
model: World model implementing the Costable protocol. Its get_cost() returns
|
||||
a plain cost tensor (B, S). If it also has get_constraints(), that method
|
||||
returns constraints of shape (B, S, C).
|
||||
n_steps: Number of gradient descent steps per outer iteration.
|
||||
n_outer_steps: Number of dual ascent (outer) iterations.
|
||||
batch_size: Number of environments to process in parallel.
|
||||
num_samples: Number of action samples to optimize in parallel.
|
||||
var_scale: Initial variance scale for action perturbations.
|
||||
action_noise: Noise added to actions during optimization.
|
||||
rho_init: Initial penalty coefficient for the quadratic constraint term.
|
||||
rho_max: Maximum value of the penalty coefficient.
|
||||
rho_scale: Multiplicative growth factor for rho after each outer step.
|
||||
persist_multipliers: Whether to warm-start Lagrange multipliers across solve() calls.
|
||||
device: Device for tensor computations.
|
||||
seed: Random seed for reproducibility.
|
||||
optimizer_cls: PyTorch optimizer class to use.
|
||||
optimizer_kwargs: Keyword arguments for the optimizer.
|
||||
"""
|
||||
|
||||
def __init__(
    self,
    model: Costable,
    n_steps: int,
    n_outer_steps: int = 5,
    batch_size: int | None = None,
    num_samples: int = 1,
    var_scale: float = 1.0,
    action_noise: float = 0.0,
    rho_init: float = 1.0,
    rho_max: float = 1e4,
    rho_scale: float = 2.0,
    persist_multipliers: bool = True,
    device: str | torch.device = 'cpu',
    seed: int = 1234,
    optimizer_cls: type[torch.optim.Optimizer] = torch.optim.Adam,
    optimizer_kwargs: dict | None = None,
) -> None:
    """Record the solver hyper-parameters and seed its random generator."""
    super().__init__()

    # Model and loop sizes
    self.model = model
    self.n_steps, self.n_outer_steps = n_steps, n_outer_steps
    self.batch_size, self.num_samples = batch_size, num_samples

    # Noise scales
    self.var_scale, self.action_noise = var_scale, action_noise

    # Penalty schedule and multiplier persistence
    self.rho_init, self.rho_max, self.rho_scale = rho_init, rho_max, rho_scale
    self.persist_multipliers = persist_multipliers

    # Device-bound random generator
    self.device = device
    generator = torch.Generator(device=device)
    self.torch_gen = generator.manual_seed(seed)

    # Optimizer factory; falls back to a learning rate of 1.0
    self.optimizer_cls = optimizer_cls
    if optimizer_kwargs is None:
        optimizer_kwargs = {'lr': 1.0}
    self.optimizer_kwargs = optimizer_kwargs

    # State filled in later by configure() / multiplier initialization
    self._configured = False
    self._n_envs = None
    self._action_dim = None
    self._config = None
    self._lambdas: torch.Tensor | None = None  # (n_envs, C)
|
||||
|
||||
def configure(
    self, *, action_space: gym.Space, n_envs: int, config: Any
) -> None:
    """Bind the solver to an action space, an environment count and a planning config."""
    self._action_space = action_space
    self._n_envs = n_envs
    self._config = config

    # Size of one action: product of the space's shape without its leading axis
    per_env_shape = action_space.shape[1:]
    self._action_dim = int(np.prod(per_env_shape))
    self._configured = True

    if isinstance(action_space, Box):
        return

    # Anything other than a Box only triggers a warning, not an error
    logging.warning(
        f'Action space is discrete, got {type(action_space)}. LagrangianSolver may not work as expected.'
    )
|
||||
|
||||
@property
def n_envs(self) -> int:
    """Environment count recorded by configure()."""
    env_count = self._n_envs
    return env_count
|
||||
|
||||
@property
def action_dim(self) -> int:
    """Per-step action size multiplied by the configured action_block."""
    block = self._config.action_block
    return block * self._action_dim
|
||||
|
||||
@property
def horizon(self) -> int:
    """Planning horizon taken from the stored config."""
    config = self._config
    return config.horizon
|
||||
|
||||
def __call__(self, *args: Any, **kwargs: Any) -> dict:
    """Call-style alias: pass every argument to ``solve`` and return its result."""
    result = self.solve(*args, **kwargs)
    return result
|
||||
|
||||
def init_action(self, actions: torch.Tensor | None = None) -> None:
    """Initialize the action tensor for optimization.

    Builds a ``(n_envs, num_samples, horizon, action_dim)`` tensor on
    ``self.device`` and stores it as the ``init`` parameter (registered on
    first use, overwritten in place afterwards). ``solve`` calls this
    inside ``torch.no_grad()``.

    Args:
        actions: Optional warm start of shape ``(n_envs, T, action_dim)``
            with ``T <= horizon``; missing trailing steps are zero-filled.
            ``None`` starts from all zeros. The tensor is not modified.
    """
    # Put everything on the solver device *before* concatenating or adding
    # noise. Previously the default/padding tensors lived on the CPU and
    # the move only happened when padding was required, so a GPU warm
    # start (or a full-horizon CPU warm start with a GPU solver) raised a
    # device-mismatch error.
    if actions is None:
        actions = torch.zeros(
            (self._n_envs, 0, self.action_dim), device=self.device
        )
    else:
        actions = actions.to(self.device)

    remaining = self.horizon - actions.shape[1]
    if remaining > 0:
        padding = torch.zeros(
            self._n_envs, remaining, self.action_dim, device=self.device
        )
        actions = torch.cat([actions, padding], dim=1)

    # One copy of the plan per sample: (n_envs, num_samples, horizon, dim).
    # repeat_interleave allocates new storage, so the caller's tensor is safe.
    actions = actions.unsqueeze(1).repeat_interleave(self.num_samples, dim=1)

    # Sample 0 keeps the unperturbed plan; the others get Gaussian noise.
    actions[:, 1:] += (
        torch.randn(
            actions[:, 1:].shape,
            generator=self.torch_gen,
            device=self.device,
        )
        * self.var_scale
    )

    if hasattr(self, 'init'):
        self.init.copy_(actions)
    else:
        self.register_parameter('init', torch.nn.Parameter(actions))
|
||||
|
||||
def _init_multipliers(self, num_constraints: int) -> None:
    """Set ``_lambdas`` to a zero ``(n_envs, num_constraints)`` tensor.

    Any previous multipliers are overwritten; ``solve`` only calls this
    while ``_lambdas`` is still ``None``.
    """
    shape = (self._n_envs, num_constraints)
    self._lambdas = torch.zeros(*shape, device=self.device)
|
||||
|
||||
def _augmented_lagrangian_loss(
    self,
    costs: torch.Tensor,  # (B, S)
    constraints: torch.Tensor,  # (B, S, C)
    lambdas_batch: torch.Tensor,  # (B, C)
    rho: float,
) -> torch.Tensor:
    """Return the scalar augmented-Lagrangian objective for one batch.

    For every sample the objective is::

        cost + sum_i lambda_i * g_i + rho * sum_i max(0, g_i)^2

    and the per-sample values are summed over all environments and samples.
    """
    # Add a sample axis so the multipliers broadcast: (B, C) -> (B, 1, C).
    multipliers = lambdas_batch.unsqueeze(1)
    linear_term = (multipliers * constraints).sum(dim=-1)  # (B, S)

    # Only violated constraints (g_i > 0) contribute to the quadratic term.
    squared_violation = F.relu(constraints).pow(2)  # (B, S, C)
    quadratic_term = rho * squared_violation.sum(dim=-1)  # (B, S)

    per_sample = costs + linear_term + quadratic_term
    return per_sample.sum()
|
||||
|
||||
def _update_multipliers(
    self,
    constraints: torch.Tensor,  # (B, S, C) -- detached, no grad
    lambdas_batch: torch.Tensor,  # (B, C)
    rho: float,
) -> torch.Tensor:
    """Dual-ascent step on the multipliers.

    Each multiplier moves by ``rho`` times its constraint value averaged
    over the sample axis, then is projected back onto ``[0, inf)``.
    A new tensor is returned; ``lambdas_batch`` is not modified.
    """
    sample_average = constraints.mean(dim=1)  # (B, C)
    stepped = lambdas_batch + rho * sample_average
    return torch.clamp(stepped, min=0.0)
|
||||
|
||||
def solve(
    self, info_dict: dict, init_action: torch.Tensor | None = None
) -> dict:
    """Solve the planning problem using augmented Lagrangian gradient descent.

    Environments are processed in chunks of ``batch_size``. For each chunk
    the action samples are optimized for ``n_outer_steps`` rounds of
    ``n_steps`` gradient steps; after every round the multipliers are
    updated by dual ascent and ``rho`` is scaled up (capped at ``rho_max``).
    If the model has no ``get_constraints`` method, or it returns ``None``,
    the plain summed cost is minimized instead.

    Args:
        info_dict: Environment information. Tensors and numpy arrays are
            sliced per chunk and repeated along a new sample axis; other
            values are passed to the model unchanged.
        init_action: Optional warm start forwarded to ``init_action``.

    Returns:
        Dict with ``'cost'`` (one loss history per chunk),
        ``'constraint_violation'`` (one mean positive violation per chunk,
        empty without constraints), ``'actions'`` (lowest-cost sample per
        environment, on CPU) and ``'lambdas'`` (multipliers on CPU, or
        ``None``).

    Raises:
        ValueError: If ``n_steps`` or ``n_outer_steps`` is smaller than 1.
    """
    # Without at least one inner step there is no cost to rank samples by;
    # fail early with a clear message instead of a NameError/TypeError later.
    if self.n_outer_steps < 1 or self.n_steps < 1:
        raise ValueError(
            'LagrangianSolver requires n_steps >= 1 and n_outer_steps >= 1, '
            f'got n_steps={self.n_steps}, n_outer_steps={self.n_outer_steps}'
        )

    start_time = time.time()
    outputs: dict = {
        'cost': [],
        'constraint_violation': [],
        'actions': None,
        'lambdas': None,
    }

    with torch.no_grad():
        self.init_action(init_action)

    if not self.persist_multipliers:
        self._lambdas = None

    batch_size = (
        self.batch_size if self.batch_size is not None else self.n_envs
    )
    total_envs = self.n_envs
    batch_top_actions_list = []
    # Loop-invariant: whether the model exposes constraints at all.
    has_constraints = hasattr(self.model, 'get_constraints')

    for start_idx in range(0, total_envs, batch_size):
        end_idx = min(start_idx + batch_size, total_envs)
        current_bs = end_idx - start_idx

        # Optimize a detached copy; it is written back to self.init below.
        batch_init = self.init[start_idx:end_idx].clone().detach()
        batch_init.requires_grad = True

        # Expand info_dict for current batch -- same pattern as GradientSolver
        expanded_infos = {}
        for k, v in info_dict.items():
            if torch.is_tensor(v):
                batch_v = v[start_idx:end_idx]
                batch_v = batch_v.unsqueeze(1)
                batch_v = batch_v.expand(
                    current_bs, self.num_samples, *batch_v.shape[2:]
                )
            elif isinstance(v, np.ndarray):
                batch_v = v[start_idx:end_idx]
                batch_v = np.repeat(
                    batch_v[:, None, ...], self.num_samples, axis=1
                )
            else:
                batch_v = v
            expanded_infos[k] = batch_v

        rho = self.rho_init
        batch_cost_history = []
        costs = None
        constraints = None
        final_constraints = None

        for _outer in range(self.n_outer_steps):
            # Fresh optimizer each outer step -- avoids stale momentum after dual ascent
            optim = self.optimizer_cls(
                [batch_init], **self.optimizer_kwargs
            )

            for _step in range(self.n_steps):
                current_info = expanded_infos.copy()
                costs = self.model.get_cost(current_info, batch_init)
                constraints = (
                    self.model.get_constraints(
                        expanded_infos.copy(), batch_init
                    )
                    if has_constraints
                    else None
                )

                assert isinstance(costs, torch.Tensor), (
                    f'Got {type(costs)} cost, expect torch.Tensor'
                )
                assert costs.ndim == 2 and costs.shape == (
                    current_bs,
                    self.num_samples,
                ), (
                    f'Cost should be of shape ({current_bs}, {self.num_samples}), got {costs.shape}'
                )
                assert costs.requires_grad, (
                    'Cost must requires_grad for LagrangianSolver.'
                )

                if constraints is not None:
                    assert constraints.ndim == 3 and constraints.shape[
                        :2
                    ] == (current_bs, self.num_samples), (
                        f'Constraints should be of shape ({current_bs}, {self.num_samples}, C), got {constraints.shape}'
                    )
                    if self._lambdas is None:
                        self._init_multipliers(constraints.shape[-1])
                    lambdas_batch = self._lambdas[start_idx:end_idx]
                    loss = self._augmented_lagrangian_loss(
                        costs, constraints, lambdas_batch, rho
                    )
                else:
                    loss = costs.sum()

                loss.backward()
                optim.step()
                optim.zero_grad(set_to_none=True)

                if self.action_noise > 0:
                    # The noise must be created on the generator's device;
                    # without device= torch.randn targets the CPU and
                    # rejects a generator that lives on another device.
                    batch_init.data += (
                        torch.randn(
                            batch_init.shape,
                            generator=self.torch_gen,
                            device=self.device,
                        )
                        * self.action_noise
                    )

                batch_cost_history.append(loss.item())

            # Dual ascent after inner loop converges
            if constraints is not None:
                with torch.no_grad():
                    # Re-evaluate at the post-step actions for the update.
                    final_constraints = self.model.get_constraints(
                        expanded_infos.copy(), batch_init
                    )
                    lambdas_batch = self._update_multipliers(
                        final_constraints, lambdas_batch, rho
                    )
                    self._lambdas[start_idx:end_idx] = lambdas_batch
                    rho = min(self.rho_max, rho * self.rho_scale)

            with torch.no_grad():
                mean_cost = costs.mean().item()
                if constraints is not None:
                    viol = F.relu(final_constraints).mean(dim=(0, 1))  # (C,)
                    lam = lambdas_batch.mean(dim=0)  # (C,)
                    viol_str = ', '.join(f'{v:.4f}' for v in viol.tolist())
                    lam_str = ', '.join(
                        f'{lam_i:.4f}' for lam_i in lam.tolist()
                    )
                    print(
                        f' [outer {_outer+1}/{self.n_outer_steps}] '
                        f'cost={mean_cost:.4f} | '
                        f'constraint_viol=[{viol_str}] | '
                        f'lambdas=[{lam_str}] | '
                        f'rho={rho:.4f}'
                    )
                else:
                    print(
                        f' [outer {_outer+1}/{self.n_outer_steps}] '
                        f'cost={mean_cost:.4f}'
                    )

        outputs['cost'].append(batch_cost_history)

        if final_constraints is not None:
            outputs['constraint_violation'].append(
                F.relu(final_constraints).mean().item()
            )

        with torch.no_grad():
            self.init[start_idx:end_idx] = batch_init

        # NOTE: `costs` comes from the last inner iteration, i.e. it was
        # evaluated before that iteration's optimizer step, so the ranking
        # is one step behind the returned actions.
        top_idx = torch.argsort(costs, dim=1)[:, 0]
        batch_indices = torch.arange(current_bs)
        top_actions_batch = batch_init[batch_indices, top_idx]
        batch_top_actions_list.append(top_actions_batch.detach().cpu())

    outputs['actions'] = torch.cat(batch_top_actions_list, dim=0)
    outputs['lambdas'] = (
        self._lambdas.cpu() if self._lambdas is not None else None
    )

    constraint_info = ''
    if outputs['constraint_violation']:
        mean_viol = np.mean(outputs['constraint_violation'])
        constraint_info = f' | constraint_violation={mean_viol:.4f}'
    print(
        f'LagrangianSolver.solve completed in {time.time() - start_time:.4f} seconds{constraint_info}.'
    )
    return outputs
|
||||
@@ -0,0 +1,208 @@
|
||||
"""Model Predictive Path Integral solver for model-based planning."""
|
||||
|
||||
import time
|
||||
from typing import Any
|
||||
|
||||
import gymnasium as gym
|
||||
import numpy as np
|
||||
import torch
|
||||
from gymnasium.spaces import Box
|
||||
from loguru import logger as logging
|
||||
|
||||
from .solver import Costable
|
||||
|
||||
|
||||
class MPPISolver:
    """Model Predictive Path Integral solver for action optimization.

    Each iteration samples noisy action sequences around the current mean,
    scores them with ``model.get_cost`` and replaces the mean with a
    softmax-weighted average of the (optionally top-k) candidates.

    Args:
        model: World model implementing the Costable protocol.
        batch_size: Number of environments to process in parallel.
            ``None`` processes all environments in a single batch.
        num_samples: Number of action candidates to sample per iteration.
        var_scale: Initial variance scale for action noise.
        n_steps: Number of MPPI iterations.
        topk: Number of elite samples for weighted averaging. ``None`` or a
            value >= ``num_samples`` uses every candidate.
        temperature: Temperature parameter for softmax weighting.
        device: Device for tensor computations.
        seed: Random seed for reproducibility.
    """

    def __init__(
        self,
        model: Costable,
        batch_size: int | None = 1,
        num_samples: int = 300,
        var_scale: float = 1.0,
        n_steps: int = 30,
        topk: int = 30,
        temperature: float = 0.5,
        device: str | torch.device = "cpu",
        seed: int = 1234,
    ) -> None:
        self.model = model
        self.batch_size = batch_size
        self.num_samples = num_samples
        self.topk = topk
        self.var_scale = var_scale
        self.n_steps = n_steps
        self.temperature = temperature
        self.device = device
        self.torch_gen = torch.Generator(device=device).manual_seed(seed)

        # Filled in by configure(); defined here so the attributes exist
        # (as "not configured yet") before configure() has been called.
        self._configured = False
        self._n_envs = None
        self._action_dim = None
        self._config = None

    def configure(self, *, action_space: gym.Space, n_envs: int, config: Any) -> None:
        """Configure the solver with environment specifications.

        Args:
            action_space: Action space of the environment; every axis after
                the first is flattened into the per-step action size.
            n_envs: Number of parallel environments.
            config: Planning configuration, read later for ``horizon`` and
                ``action_block``.
        """
        self._action_space = action_space
        self._n_envs = n_envs
        self._config = config
        self._action_dim = int(np.prod(action_space.shape[1:]))
        self._configured = True

        if not isinstance(action_space, Box):
            logging.warning(
                f"Action space is discrete, got {type(action_space)}. MPPISolver may not work as expected."
            )

    @property
    def n_envs(self) -> int:
        """Number of parallel environments."""
        return self._n_envs

    @property
    def action_dim(self) -> int:
        """Flattened action dimension including action_block grouping."""
        return self._action_dim * self._config.action_block

    @property
    def horizon(self) -> int:
        """Planning horizon in timesteps."""
        return self._config.horizon

    def __call__(self, *args: Any, **kwargs: Any) -> dict:
        """Make solver callable, forwarding to solve()."""
        return self.solve(*args, **kwargs)

    def init_action_distrib(
        self, actions: torch.Tensor | None = None
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Initialize the action distribution parameters (mean and variance).

        Args:
            actions: Optional warm start of shape ``(n_envs, T, action_dim)``
                with ``T <= horizon``; missing trailing steps are zero-filled.

        Returns:
            ``(mean, var)``, both ``(n_envs, horizon, action_dim)``. ``var``
            is constant ``var_scale``. ``mean`` may be the ``actions`` object
            itself when no padding is needed.
        """
        var = self.var_scale * torch.ones([self.n_envs, self.horizon, self.action_dim])
        mean = torch.zeros([self.n_envs, 0, self.action_dim]) if actions is None else actions

        remaining = self.horizon - mean.shape[1]
        if remaining > 0:
            # Pad on the warm start's own device: concatenating a GPU warm
            # start with CPU zeros raises a device-mismatch error.
            padding = torch.zeros(
                [self.n_envs, remaining, self.action_dim], device=mean.device
            )
            mean = torch.cat([mean, padding], dim=1)

        return mean, var

    @torch.inference_mode()
    def solve(
        self, info_dict: dict, init_action: torch.Tensor | None = None
    ) -> dict:
        """Solve the planning problem using MPPI.

        Args:
            info_dict: Environment information; every value is sliced per
                batch, and tensors / numpy arrays are repeated along a new
                sample axis.
            init_action: Optional warm-start mean. It is never modified.

        Returns:
            Dict with ``"actions"`` (final mean on CPU), ``"costs"`` (mean
            cost of the utilized samples per environment, from the last
            iteration), and ``"mean"`` / ``"var"`` (single-element lists
            holding the final distribution on CPU).
        """
        start_time = time.time()
        outputs = {
            "costs": [],
            "mean": [],
            "var": [],
        }

        # -- initialize the action distribution globally
        mean, var = self.init_action_distrib(init_action)
        # Always copy: `mean` may alias the caller's init_action, and the
        # per-batch write-back below would otherwise overwrite their tensor.
        mean = mean.to(self.device, copy=True)
        var = var.to(self.device)

        total_envs = self.n_envs
        batch_size = self.batch_size if self.batch_size is not None else total_envs

        # --- Iterate over batches ---
        for start_idx in range(0, total_envs, batch_size):
            end_idx = min(start_idx + batch_size, total_envs)
            current_bs = end_idx - start_idx

            # Slice Distribution Parameters for current batch
            batch_mean = mean[start_idx:end_idx]
            batch_var = var[start_idx:end_idx]

            # Expand Info Dict for current batch (Same as CEM)
            expanded_infos = {}
            for k, v in info_dict.items():
                v_batch = v[start_idx:end_idx]
                if torch.is_tensor(v):
                    # Add sample dim: (batch, 1, ...)
                    v_batch = v_batch.unsqueeze(1)
                    # Expand: (batch, num_samples, ...)
                    v_batch = v_batch.expand(current_bs, self.num_samples, *v_batch.shape[2:])
                elif isinstance(v, np.ndarray):
                    v_batch = np.repeat(v_batch[:, None, ...], self.num_samples, axis=1)
                expanded_infos[k] = v_batch

            # Optimization Loop
            final_batch_cost = None

            for step in range(self.n_steps):
                # Sample noise: (Batch, Num_Samples, Horizon, Dim)
                noise = torch.randn(
                    current_bs,
                    self.num_samples,
                    self.horizon,
                    self.action_dim,
                    generator=self.torch_gen,
                    device=self.device,
                )

                # MPPI Logic: candidates = mean + noise * scale.
                # NOTE(review): `var` is applied directly as the noise scale
                # (no square root) -- confirm this is intended.
                candidates = batch_mean.unsqueeze(1) + noise * batch_var.unsqueeze(1)

                # Force the first sample to be the current mean (Zero noise)
                candidates[:, 0] = batch_mean

                # Evaluate candidates
                costs = self.model.get_cost(expanded_infos, candidates)

                assert isinstance(costs, torch.Tensor), f"Expected cost to be a torch.Tensor, got {type(costs)}"
                assert costs.ndim == 2 and costs.shape[0] == current_bs and costs.shape[1] == self.num_samples, (
                    f"Expected cost to be of shape ({current_bs}, {self.num_samples}), got {costs.shape}"
                )

                # Select Elites (Optional, based on topk)
                if self.topk is not None and self.topk < self.num_samples:
                    # topk_vals: (Batch, K), topk_inds: (Batch, K)
                    topk_vals, topk_inds = torch.topk(costs, k=self.topk, dim=1, largest=False)

                    # Gather Top-K Candidates
                    batch_indices = torch.arange(current_bs, device=self.device).unsqueeze(1).expand(-1, self.topk)
                    # (Batch, K, Horizon, Dim)
                    relevant_candidates = candidates[batch_indices, topk_inds]
                    relevant_costs = topk_vals
                else:
                    relevant_candidates = candidates
                    relevant_costs = costs

                # MPPI Weighting: Softmax(-cost / temperature)
                # Stabilize softmax by subtracting min cost
                min_cost = relevant_costs.min(dim=1, keepdim=True)[0]
                scaled_costs = relevant_costs - min_cost
                weights = torch.softmax(-scaled_costs / self.temperature, dim=1)  # (Batch, K)

                # Update Mean: weighted sum of candidates
                # Reshape weights for broadcasting: (Batch, K, 1, 1)
                weights_expanded = weights.unsqueeze(-1).unsqueeze(-1)
                batch_mean = (weights_expanded * relevant_candidates).sum(dim=1)

                # Store average cost of the utilized samples for logging
                final_batch_cost = relevant_costs.mean(dim=1).cpu().tolist()

            # Write results back to global storage
            mean[start_idx:end_idx] = batch_mean
            # We do not update var in standard MPPI

            # Store history/metadata (nothing to record when n_steps == 0)
            if final_batch_cost is not None:
                outputs["costs"].extend(final_batch_cost)

        outputs["actions"] = mean.detach().cpu()
        outputs["mean"] = [mean.detach().cpu()]
        outputs["var"] = [var.detach().cpu()]

        print(f"MPPI solve time: {time.time() - start_time:.4f} seconds")
        return outputs
|
||||
@@ -0,0 +1,80 @@
|
||||
from typing import Any, Protocol, runtime_checkable
|
||||
|
||||
import gymnasium as gym
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
|
||||
class Costable(Protocol):
    """Structural interface a world model must provide to be used by a solver."""

    def criterion(
        self,
        info_dict: dict,
        action_candidates: torch.Tensor,
    ) -> torch.Tensor:
        """Evaluate the cost criterion of the proposed actions.

        Args:
            info_dict: Dictionary containing environment state information.
            action_candidates: Tensor of proposed actions.

        Returns:
            A tensor of cost values for each action candidate.
        """
        ...

    def get_cost(
        self,
        info_dict: dict,
        action_candidates: torch.Tensor,
    ) -> torch.Tensor:  # pragma: no cover
        """Score the proposed actions given the environment information.

        Args:
            info_dict: Dictionary containing environment state information.
            action_candidates: Tensor of proposed actions.

        Returns:
            A tensor of cost values for each action candidate.
        """
        ...
|
||||
|
||||
|
||||
@runtime_checkable
class Solver(Protocol):
    """Structural interface shared by all model-based planning solvers."""

    def configure(
        self,
        *,
        action_space: gym.Space,
        n_envs: int,
        config: Any,
    ) -> None:
        """Bind the solver to an environment and a planning configuration.

        Args:
            action_space: The action space of the environment.
            n_envs: Number of parallel environments.
            config: Planning configuration object.
        """
        ...

    @property
    def action_dim(self) -> int:
        """Flattened action size, with action_block grouping included."""
        ...

    @property
    def n_envs(self) -> int:
        """How many parallel environments the solver plans for."""
        ...

    @property
    def horizon(self) -> int:
        """Length of the planning horizon, in timesteps."""
        ...

    def solve(
        self,
        info_dict: dict,
        init_action: torch.Tensor | None = None,
        active_mask: torch.Tensor | np.ndarray | None = None,
    ) -> dict:
        """Run the planning optimization and return the chosen actions.

        Args:
            info_dict: Dictionary containing environment state information.
            init_action: Optional initial action sequence to warm-start the solver.
            active_mask: Optional tensor or array mask. Its meaning is not
                defined here (presumably marks which environments still
                need planning -- TODO confirm against the caller).

        Returns:
            Dictionary containing optimized actions and other solver-specific info.
        """
        ...
|
||||
1049
.venv/lib/python3.10/site-packages/stable_worldmodel/world.py
Normal file
1049
.venv/lib/python3.10/site-packages/stable_worldmodel/world.py
Normal file
File diff suppressed because it is too large
Load Diff
241
AMD_SETUP.md
Normal file
241
AMD_SETUP.md
Normal file
@@ -0,0 +1,241 @@
|
||||
# AMD ROCm 环境配置说明
|
||||
|
||||
这份文档记录了在 AMD ROCm 环境下运行 LeWM 的可复现配置,重点是保留
|
||||
`torch.compile` 时的 PyTorch 版本选择。
|
||||
|
||||
目标运行命令:
|
||||
|
||||
```bash
|
||||
python eval.py --config-name=pusht.yaml policy=pusht/lewm
|
||||
```
|
||||
|
||||
## 已验证环境
|
||||
|
||||
本次验证通过的环境:
|
||||
|
||||
- Ubuntu 24.04
|
||||
- AMD Radeon PRO W7900D (`gfx1100`)
|
||||
- 系统 ROCm 7.1.1
|
||||
- Python 3.10
|
||||
- `torch==2.10.0+rocm7.1`
|
||||
- `torchvision==0.25.0+rocm7.1`
|
||||
- `triton-rocm==3.6.0`
|
||||
|
||||
注意:`torch==2.12.0+rocm7.1` 可以正常导入,也能识别 GPU,但在本项目里开启
|
||||
`torch.compile` 后会崩溃,错误类似:
|
||||
|
||||
```text
|
||||
HSA_STATUS_ERROR_INVALID_PACKET_FORMAT
|
||||
CUDA error: unspecified launch failure
|
||||
```
|
||||
|
||||
降级到 `torch==2.10.0+rocm7.1` 后,`torch.compile` 路径可以正常跑通。
|
||||
|
||||
## 检查系统 ROCm
|
||||
|
||||
在新 AMD 机器上,先确认系统能识别 GPU:
|
||||
|
||||
```bash
|
||||
rocminfo
|
||||
amd-smi version
|
||||
hipcc --version
|
||||
```
|
||||
|
||||
`rocminfo` 里应该能看到 AMD GPU agent,例如 `gfx1100`。
|
||||
|
||||
## 创建 Python 环境
|
||||
|
||||
使用 `uv` 创建 Python 3.10 虚拟环境:
|
||||
|
||||
```bash
|
||||
cd /path/to/lewm
|
||||
uv venv --python 3.10 --allow-existing .venv
|
||||
source .venv/bin/activate
|
||||
```
|
||||
|
||||
给 uv 创建的 venv 补上 pip。ROCm 版 PyTorch wheel 很大,如果 uv 解析或下载卡住,
|
||||
用 pip 安装大 wheel 更容易观察进度。
|
||||
|
||||
```bash
|
||||
uv pip install pip
|
||||
```
|
||||
|
||||
## 安装 ROCm 版 PyTorch
|
||||
|
||||
安装本项目已验证可用的 ROCm PyTorch 组合:
|
||||
|
||||
```bash
|
||||
python -m pip install --force-reinstall \
|
||||
--index-url https://download.pytorch.org/whl/rocm7.1 \
|
||||
--extra-index-url https://pypi.tuna.tsinghua.edu.cn/simple \
|
||||
"torch==2.10.0" \
|
||||
"torchvision==0.25.0"
|
||||
```
|
||||
|
||||
PyTorch wheel 有数 GB。如果网络慢,不要频繁中断重试,尽量等它下载完成。
|
||||
|
||||
## 安装项目依赖
|
||||
|
||||
普通 Python 包建议走国内 PyPI 镜像:
|
||||
|
||||
```bash
|
||||
python -m pip install \
|
||||
--index-url https://pypi.tuna.tsinghua.edu.cn/simple \
|
||||
"gymnasium[all]==1.2.2" \
|
||||
"stable-baselines3==2.8.0" \
|
||||
"stable-worldmodel[train,env]"
|
||||
```
|
||||
|
||||
然后修正两个容易被 pip 带偏的依赖版本:
|
||||
|
||||
```bash
|
||||
python -m pip install \
|
||||
--index-url https://pypi.tuna.tsinghua.edu.cn/simple \
|
||||
"fsspec==2025.3.0" \
|
||||
"pillow==11.3.0"
|
||||
```
|
||||
|
||||
检查环境:
|
||||
|
||||
```bash
|
||||
python -m pip check
|
||||
python - <<'PY'
|
||||
import torch
|
||||
import torchvision
|
||||
|
||||
print("torch:", torch.__version__)
|
||||
print("hip:", torch.version.hip)
|
||||
print("cuda available:", torch.cuda.is_available())
|
||||
print("device:", torch.cuda.get_device_name(0) if torch.cuda.is_available() else None)
|
||||
print("torchvision:", torchvision.__version__)
|
||||
PY
|
||||
```
|
||||
|
||||
期望看到类似输出:
|
||||
|
||||
```text
|
||||
torch: 2.10.0+rocm7.1
|
||||
cuda available: True
|
||||
torchvision: 0.25.0+rocm7.1
|
||||
```
|
||||
|
||||
## 恢复本仓库里的 stable-worldmodel 修改
|
||||
|
||||
这个仓库把一些本地修改后的 `stable_worldmodel` 文件纳入了 git 管控,路径在:
|
||||
|
||||
```text
|
||||
.venv/lib/python3.10/site-packages/stable_worldmodel/
|
||||
```
|
||||
|
||||
从 PyPI 安装 `stable-worldmodel` 时可能会覆盖这些文件。安装依赖后执行:
|
||||
|
||||
```bash
|
||||
git restore -- \
|
||||
.venv/lib/python3.10/site-packages/stable_worldmodel/policy.py \
|
||||
.venv/lib/python3.10/site-packages/stable_worldmodel/world.py \
|
||||
.venv/lib/python3.10/site-packages/stable_worldmodel/solver/cem.py \
|
||||
.venv/lib/python3.10/site-packages/stable_worldmodel/solver/gd.py
|
||||
```
|
||||
|
||||
然后确认没有意外修改:
|
||||
|
||||
```bash
|
||||
git status --short
|
||||
```
|
||||
|
||||
## 数据和 checkpoint 路径
|
||||
|
||||
`eval.py` 会从 `$STABLEWM_HOME` 里找数据和 checkpoint。
|
||||
|
||||
PushT 评估至少需要:
|
||||
|
||||
```text
|
||||
$STABLEWM_HOME/pusht_expert_train.h5
|
||||
$STABLEWM_HOME/pusht/lewm_object.ckpt
|
||||
```
|
||||
|
||||
例如本机使用:
|
||||
|
||||
```bash
|
||||
export STABLEWM_HOME=/mnt/ASC1637/stablewm
|
||||
```
|
||||
|
||||
如果没有正确设置,运行时会报找不到 `pusht_expert_train.h5`。
|
||||
|
||||
## 运行评估
|
||||
|
||||
默认 PushT 评估,保留 `torch.compile`:
|
||||
|
||||
```bash
|
||||
export STABLEWM_HOME=/path/to/stablewm
|
||||
python eval.py --config-name=pusht.yaml policy=pusht/lewm
|
||||
```
|
||||
|
||||
快速 smoke test:
|
||||
|
||||
```bash
|
||||
export STABLEWM_HOME=/path/to/stablewm
|
||||
python eval.py --config-name=pusht.yaml policy=pusht/lewm \
|
||||
eval.num_eval=1 \
|
||||
world.num_envs=1 \
|
||||
output.filename=/tmp/lewm_smoke_test.txt
|
||||
```
|
||||
|
||||
smoke test 应该能正常结束,并打印类似:
|
||||
|
||||
```text
|
||||
{'success_rate': 100.0, ...}
|
||||
```
|
||||
|
||||
## 常见问题
|
||||
|
||||
### `HSA_STATUS_ERROR_INVALID_PACKET_FORMAT`
|
||||
|
||||
如果开启 `torch.compile` 时出现这个错误,先检查 torch 版本:
|
||||
|
||||
```bash
|
||||
python -c "import torch; print(torch.__version__, torch.version.hip)"
|
||||
```
|
||||
|
||||
如果是 `2.12.0+rocm7.1`,建议降级到本项目验证通过的组合:
|
||||
|
||||
```bash
|
||||
python -m pip install --force-reinstall \
|
||||
--index-url https://download.pytorch.org/whl/rocm7.1 \
|
||||
--extra-index-url https://pypi.tuna.tsinghua.edu.cn/simple \
|
||||
"torch==2.10.0" \
|
||||
"torchvision==0.25.0"
|
||||
```
|
||||
|
||||
### 找不到 `pusht_expert_train.h5`
|
||||
|
||||
设置 `STABLEWM_HOME` 到包含数据和 checkpoint 的目录:
|
||||
|
||||
```bash
|
||||
export STABLEWM_HOME=/path/to/stablewm
|
||||
```
|
||||
|
||||
### pip 尝试构建旧版 `gym==0.21`
|
||||
|
||||
这是依赖解析回退导致的。先显式安装兼容版本:
|
||||
|
||||
```bash
|
||||
python -m pip install \
|
||||
--index-url https://pypi.tuna.tsinghua.edu.cn/simple \
|
||||
"gymnasium[all]==1.2.2" \
|
||||
"stable-baselines3==2.8.0"
|
||||
```
|
||||
|
||||
### uv 或 pip 访问海外源很慢
|
||||
|
||||
普通 Python 包使用国内 PyPI 镜像:
|
||||
|
||||
```bash
|
||||
--index-url https://pypi.tuna.tsinghua.edu.cn/simple
|
||||
```
|
||||
|
||||
PyTorch ROCm wheel 继续使用 PyTorch 官方 ROCm 源:
|
||||
|
||||
```bash
|
||||
--index-url https://download.pytorch.org/whl/rocm7.1
|
||||
```
|
||||
27
README.md
27
README.md
@@ -84,6 +84,33 @@ python eval.py --config-name=pusht.yaml policy=pusht/lewm
|
||||
python eval.py --config-name=pusht.yaml policy=pusht/lewm_object.ckpt
|
||||
```
|
||||
|
||||
## Profiling
|
||||
|
||||
`eval.py` now supports optional inference profiling with PyTorch's native profiler.
|
||||
|
||||
Example:
|
||||
|
||||
```bash
|
||||
python eval.py --config-name=pusht.yaml policy=pusht/lewm \
|
||||
inference_precision=bf16 \
|
||||
+profile.enabled=true \
|
||||
+profile.with_stack=true \
|
||||
+profile.record_shapes=true \
|
||||
+profile.profile_memory=true
|
||||
```
|
||||
|
||||
Supported inference precision modes:
|
||||
- `inference_precision=fp32`
|
||||
- `inference_precision=bf16`
|
||||
- `inference_precision=fp16`
|
||||
|
||||
Outputs are written next to the evaluation results:
|
||||
- `torch_profile/key_averages.txt` for the aggregated operator table
|
||||
- `torch_profile/trace.json` for Chrome tracing
|
||||
- TensorBoard trace files under `torch_profile/`
|
||||
|
||||
The trace includes custom regions such as `eval.world_evaluate_from_dataset`, `lewm.get_cost`, `lewm.rollout`, and `lewm.predict` to make the planning path easier to inspect.
|
||||
|
||||
## Pretrained Checkpoints
|
||||
|
||||
Pre-trained checkpoints are available on [Google Drive](https://drive.google.com/drive/folders/1r31os0d4-rR0mdHc7OlY_e5nh3XT4r4e). Download the checkpoint archive and place the extracted files under `$STABLEWM_HOME/`.
|
||||
|
||||
@@ -24,6 +24,7 @@ dataset:
|
||||
|
||||
seed: 42
|
||||
policy: random # ckpt name or random
|
||||
inference_precision: fp16
|
||||
|
||||
plan_config:
|
||||
horizon: 5
|
||||
@@ -36,6 +37,10 @@ eval:
|
||||
goal_offset_steps: 25
|
||||
eval_budget: 50
|
||||
img_size: 224
|
||||
save_video: false
|
||||
compile_warmup:
|
||||
enabled: true
|
||||
num_eval: 1
|
||||
dataset_name: ogbench/cube_single_expert
|
||||
callables:
|
||||
# -- set state
|
||||
@@ -56,6 +61,21 @@ eval:
|
||||
target_quat:
|
||||
value: goal_privileged_block_0_quat
|
||||
|
||||
multi_node:
|
||||
enabled: false
|
||||
backend: gloo
|
||||
rank_env: RANK
|
||||
world_size_env: WORLD_SIZE
|
||||
local_rank_env: LOCAL_RANK
|
||||
aggregate_results: true
|
||||
sync_before_return: false
|
||||
destroy_process_group: true
|
||||
shard_strategy: round_robin
|
||||
|
||||
preload_wait:
|
||||
enabled: false
|
||||
file: /tmp/lewm_preload_start
|
||||
poll_interval: 1.0
|
||||
|
||||
output:
|
||||
filename: ogb_cube_results.txt
|
||||
|
||||
|
||||
@@ -19,6 +19,7 @@ dataset:
|
||||
|
||||
seed: 42
|
||||
policy: random # ckpt name or random
|
||||
inference_precision: fp16
|
||||
|
||||
plan_config:
|
||||
horizon: 5
|
||||
@@ -31,6 +32,10 @@ eval:
|
||||
goal_offset_steps: 25
|
||||
eval_budget: 50
|
||||
img_size: 224
|
||||
save_video: false
|
||||
compile_warmup:
|
||||
enabled: true
|
||||
num_eval: 1
|
||||
dataset_name: pusht_expert_train
|
||||
callables:
|
||||
# -- set state
|
||||
@@ -43,6 +48,22 @@ eval:
|
||||
args:
|
||||
goal_state:
|
||||
value: goal_state
|
||||
|
||||
multi_node:
|
||||
enabled: false
|
||||
backend: gloo
|
||||
rank_env: RANK
|
||||
world_size_env: WORLD_SIZE
|
||||
local_rank_env: LOCAL_RANK
|
||||
aggregate_results: true
|
||||
sync_before_return: false
|
||||
destroy_process_group: true
|
||||
shard_strategy: round_robin
|
||||
|
||||
preload_wait:
|
||||
enabled: false
|
||||
file: /tmp/lewm_preload_start
|
||||
poll_interval: 1.0
|
||||
|
||||
output:
|
||||
filename: pusht_results.txt
|
||||
filename: pusht_results.txt
|
||||
|
||||
@@ -18,6 +18,7 @@ dataset:
|
||||
|
||||
seed: 42
|
||||
policy: random # ckpt name or random
|
||||
inference_precision: fp16
|
||||
|
||||
plan_config:
|
||||
horizon: 5
|
||||
@@ -30,6 +31,10 @@ eval:
|
||||
goal_offset_steps: 25
|
||||
eval_budget: 50
|
||||
img_size: 224
|
||||
save_video: false
|
||||
compile_warmup:
|
||||
enabled: true
|
||||
num_eval: 1
|
||||
dataset_name: dmc/reacher_random
|
||||
callables:
|
||||
# -- set state
|
||||
@@ -45,6 +50,21 @@ eval:
|
||||
target_qpos:
|
||||
value: goal_qpos
|
||||
|
||||
multi_node:
|
||||
enabled: false
|
||||
backend: gloo
|
||||
rank_env: RANK
|
||||
world_size_env: WORLD_SIZE
|
||||
local_rank_env: LOCAL_RANK
|
||||
aggregate_results: true
|
||||
sync_before_return: false
|
||||
destroy_process_group: true
|
||||
shard_strategy: round_robin
|
||||
|
||||
preload_wait:
|
||||
enabled: false
|
||||
file: /tmp/lewm_preload_start
|
||||
poll_interval: 1.0
|
||||
|
||||
output:
|
||||
filename: dmc_results.txt
|
||||
|
||||
|
||||
@@ -1,9 +1,10 @@
|
||||
_target_: stable_worldmodel.solver.CEMSolver
|
||||
model: ???
|
||||
batch_size: 1
|
||||
batch_size: 16
|
||||
# Original defaults: num_samples=300, n_steps=30, topk=30, batch_size=8.
|
||||
num_samples: 300
|
||||
var_scale: 1.0
|
||||
n_steps: 30
|
||||
topk: 30
|
||||
topk: 8
|
||||
device: "cuda"
|
||||
seed: ${seed}
|
||||
|
||||
14
config/eval/solver/gradient.yaml
Normal file
14
config/eval/solver/gradient.yaml
Normal file
@@ -0,0 +1,14 @@
|
||||
_target_: stable_worldmodel.solver.GradientSolver
|
||||
model: ???
|
||||
# Original adam.yaml reference: n_steps=30, num_samples=100, batch_size=1, lr=0.1.
|
||||
n_steps: 90
|
||||
batch_size: 100
|
||||
num_samples: 1
|
||||
action_noise: 0
|
||||
device: "cuda"
|
||||
seed: ${seed}
|
||||
optimizer_cls:
|
||||
_target_: hydra.utils.get_class
|
||||
path: torch.optim.AdamW
|
||||
optimizer_kwargs:
|
||||
lr: 0.075
|
||||
@@ -12,6 +12,7 @@ world:
|
||||
|
||||
seed: 42
|
||||
policy: random # ckpt name or random
|
||||
inference_precision: fp16
|
||||
|
||||
dataset:
|
||||
stats: ${eval.dataset_name}
|
||||
@@ -30,6 +31,10 @@ eval:
|
||||
goal_offset_steps: 25
|
||||
eval_budget: 50
|
||||
img_size: 224
|
||||
save_video: false
|
||||
compile_warmup:
|
||||
enabled: true
|
||||
num_eval: 1
|
||||
dataset_name: tworoom
|
||||
callables:
|
||||
# -- set state
|
||||
@@ -43,5 +48,21 @@ eval:
|
||||
goal_state:
|
||||
value: goal_proprio
|
||||
|
||||
multi_node:
|
||||
enabled: false
|
||||
backend: gloo
|
||||
rank_env: RANK
|
||||
world_size_env: WORLD_SIZE
|
||||
local_rank_env: LOCAL_RANK
|
||||
aggregate_results: true
|
||||
sync_before_return: false
|
||||
destroy_process_group: true
|
||||
shard_strategy: round_robin
|
||||
|
||||
preload_wait:
|
||||
enabled: false
|
||||
file: /tmp/lewm_preload_start
|
||||
poll_interval: 1.0
|
||||
|
||||
output:
|
||||
filename: tworoom_results.txt
|
||||
filename: tworoom_results.txt
|
||||
|
||||
836
eval.py
836
eval.py
@@ -2,8 +2,12 @@ import os
|
||||
|
||||
os.environ["MUJOCO_GL"] = "egl"
|
||||
|
||||
import multiprocessing as mp
|
||||
import time
|
||||
import traceback
|
||||
from contextlib import nullcontext
|
||||
from pathlib import Path
|
||||
import tempfile
|
||||
|
||||
import hydra
|
||||
import numpy as np
|
||||
@@ -46,76 +50,331 @@ def get_dataset(cfg, dataset_name):
|
||||
)
|
||||
return dataset
|
||||
|
||||
@hydra.main(version_base=None, config_path="./config/eval", config_name="pusht")
|
||||
def run(cfg: DictConfig):
|
||||
"""Run evaluation of dinowm vs random policy."""
|
||||
assert (
|
||||
cfg.plan_config.horizon * cfg.plan_config.action_block <= cfg.eval.eval_budget
|
||||
), "Planning horizon must be smaller than or equal to eval_budget"
|
||||
|
||||
# create world environment
|
||||
cfg.world.max_episode_steps = 2 * cfg.eval.eval_budget
|
||||
world = swm.World(**cfg.world, image_shape=(224, 224))
|
||||
def get_profile_cfg(cfg):
    """Return the profiler settings: built-in defaults overlaid with ``cfg.profile``.

    Args:
        cfg: Config object exposing ``.get``. Its optional ``profile`` node is
            converted with ``OmegaConf.to_container(..., resolve=True)``.

    Returns:
        A plain ``dict`` of profiler options.
    """
    settings = dict(
        enabled=False,
        trace_dirname="torch_profile",
        record_shapes=True,
        profile_memory=True,
        with_stack=False,
        with_flops=False,
        row_limit=40,
        worker_name="eval",
        export_chrome_trace=True,
        export_tensorboard=True,
    )
    overrides = cfg.get("profile")
    if overrides is None:
        return settings
    settings.update(OmegaConf.to_container(overrides, resolve=True))
    return settings
|
||||
|
||||
# create the transform
|
||||
transform = {
|
||||
"pixels": img_transform(cfg),
|
||||
"goal": img_transform(cfg),
|
||||
|
||||
def get_compile_cfg(cfg):
    """Return the ``torch.compile`` settings: defaults overlaid with ``cfg.compile``.

    Args:
        cfg: Config object exposing ``.get``. Its optional ``compile`` node is
            converted with ``OmegaConf.to_container(..., resolve=True)``.

    Returns:
        A plain ``dict`` with the keys ``enabled``, ``target``, ``mode``,
        ``fullgraph``, ``dynamic`` and ``cuda_only`` (plus any extra overrides).
    """
    settings = dict(
        enabled=True,
        target="predictor",
        mode="reduce-overhead",
        fullgraph=False,
        dynamic=False,
        cuda_only=True,
    )
    overrides = cfg.get("compile")
    if overrides is None:
        return settings
    settings.update(OmegaConf.to_container(overrides, resolve=True))
    return settings
|
||||
|
||||
|
||||
def get_compile_warmup_cfg(cfg):
    """Return the compile-warmup settings, defaults overlaid with config overrides.

    The override node is looked up as top-level ``compile_warmup`` first and,
    when that is absent, as ``eval.compile_warmup``.

    Args:
        cfg: Config object exposing ``.get``.

    Returns:
        A plain ``dict`` with at least ``enabled`` and ``num_eval``.
    """
    settings = dict(enabled=True, num_eval=1)

    overrides = cfg.get("compile_warmup")
    if overrides is None:
        eval_node = cfg.get("eval")
        overrides = None if eval_node is None else eval_node.get("compile_warmup")

    if overrides is None:
        return settings
    settings.update(OmegaConf.to_container(overrides, resolve=True))
    return settings
|
||||
|
||||
|
||||
def get_preload_wait_cfg(cfg):
    """Return the preload-wait settings: defaults overlaid with ``cfg.preload_wait``.

    Args:
        cfg: Config object exposing ``.get``. Its optional ``preload_wait`` node
            is converted with ``OmegaConf.to_container(..., resolve=True)``.

    Returns:
        A plain ``dict`` with at least ``enabled``, ``file`` and ``poll_interval``.
    """
    settings = dict(
        enabled=False,
        file="/tmp/lewm_preload_start",
        poll_interval=1.0,
    )
    overrides = cfg.get("preload_wait")
    if overrides is None:
        return settings
    settings.update(OmegaConf.to_container(overrides, resolve=True))
    return settings
|
||||
|
||||
|
||||
# Block evaluation start until an external "go" signal file exists.
#
# Does nothing unless the preload-wait settings (see get_preload_wait_cfg) have
# enabled=True. Otherwise:
#   1. if torch.distributed is available and initialized, call barrier();
#   2. on rank 0, print the path of the signal file, sleep `poll_interval`
#      seconds between checks until that path exists, then print a confirmation;
#   3. call barrier() again if a process group is initialized.
#
# Args:
#   cfg:  config passed to get_preload_wait_cfg (keys used: enabled, file,
#         poll_interval).
#   rank: rank of the calling process; the announcement is printed for rank 0.
#
# NOTE(review): indentation is lost in this listing. The polling loop and the
# "signal received" print presumably live inside `if rank == 0:` (other ranks
# would then be released by the second barrier) -- confirm against the real file.
def wait_for_preload_signal(cfg, rank=0):
|
||||
preload_cfg = get_preload_wait_cfg(cfg)
|
||||
if not preload_cfg["enabled"]:
|
||||
return
|
||||
|
||||
dist_ready = (
|
||||
torch.distributed.is_available()
|
||||
and torch.distributed.is_initialized()
|
||||
)
|
||||
if dist_ready:
|
||||
torch.distributed.barrier()
|
||||
|
||||
signal_path = Path(str(preload_cfg["file"])).expanduser()
|
||||
poll_interval = float(preload_cfg["poll_interval"])
|
||||
if rank == 0:
|
||||
print(
|
||||
"Preload ready. Create this file to start evaluation: "
|
||||
f"{signal_path}",
|
||||
flush=True,
|
||||
)
|
||||
while not signal_path.exists():
|
||||
time.sleep(poll_interval)
|
||||
print("Preload start signal received. Starting evaluation.", flush=True)
|
||||
|
||||
if dist_ready:
|
||||
torch.distributed.barrier()
|
||||
|
||||
|
||||
def maybe_compile_inference_target(model, cfg, device):
    """Optionally wrap ``model.predictor`` or ``model.predict`` with ``torch.compile``.

    Compilation is skipped (with a printed notice where applicable) when it is
    disabled in the config, ``torch.compile`` does not exist, ``cuda_only`` is
    set and ``device`` is not a CUDA device, the requested target is not one of
    ``predictor`` / ``predict``, or the model lacks the requested attribute.

    Args:
        model: Model whose attribute is replaced in place by its compiled version.
        cfg: Config passed to ``get_compile_cfg``.
        device: Device (or device string) the model runs on.

    Returns:
        ``(model, compile_cfg, compile_target)`` where ``compile_target`` is
        ``"predictor"``, ``"predict"`` or ``"disabled"``.
    """
    compile_cfg = get_compile_cfg(cfg)
    not_compiled = (model, compile_cfg, "disabled")

    if not compile_cfg["enabled"]:
        return not_compiled

    if not hasattr(torch, "compile"):
        print("torch.compile is unavailable, skipping inference compilation.")
        return not_compiled

    if compile_cfg["cuda_only"] and not str(device).startswith("cuda"):
        print("Skipping torch.compile because compile.cuda_only=true and device is not CUDA.")
        return not_compiled

    target = str(compile_cfg["target"]).lower()
    compile_kwargs = {
        option: compile_cfg[option] for option in ("mode", "fullgraph", "dynamic")
    }

    if target not in ("predictor", "predict"):
        print(
            f"Unsupported compile.target={target}. Expected one of: predictor, predict."
        )
        return not_compiled

    if not hasattr(model, target):
        print(f"Requested compile target '{target}' is unavailable on the model.")
        return not_compiled

    setattr(model, target, torch.compile(getattr(model, target), **compile_kwargs))
    return model, compile_cfg, target
|
||||
|
||||
|
||||
def get_inference_context(cfg, device):
    """Build the autocast context matching ``cfg.inference_precision``.

    Args:
        cfg: Config exposing ``.get``; ``inference_precision`` defaults to
            ``"fp32"`` and is compared case-insensitively.
        device: Device the model runs on, as a string or a ``torch.device``.

    Returns:
        ``(context_manager, precision_label)`` where the label is ``"fp32"``,
        ``"bf16"`` or ``"fp16"``. fp16 on a non-CUDA device falls back to fp32.

    Raises:
        ValueError: If the configured precision is not recognised.
    """
    precision = str(cfg.get("inference_precision", "fp32")).lower()
    # str() so torch.device objects are accepted as well as plain strings,
    # matching how maybe_compile_inference_target inspects the device.
    device_type = "cuda" if str(device).startswith("cuda") else "cpu"

    if precision == "fp32":
        return nullcontext(), "fp32"

    if precision in {"bf16", "bfloat16"}:
        return (
            torch.autocast(device_type=device_type, dtype=torch.bfloat16),
            "bf16",
        )

    if precision in {"fp16", "float16"}:
        if device_type != "cuda":
            print("fp16 inference is only supported on CUDA, falling back to fp32.")
            return nullcontext(), "fp32"
        return (
            torch.autocast(device_type=device_type, dtype=torch.float16),
            "fp16",
        )

    raise ValueError(
        f"Unsupported inference_precision={precision}. Expected one of: fp32, bf16, fp16."
    )
|
||||
|
||||
|
||||
def get_eval_grad_context(solver=None):
    """Pick the autograd context for evaluation.

    Returns ``torch.enable_grad()`` when ``solver`` is a
    ``swm.solver.GradientSolver`` and ``torch.inference_mode()`` otherwise.
    """
    needs_grad = isinstance(solver, swm.solver.GradientSolver)
    return torch.enable_grad() if needs_grad else torch.inference_mode()
|
||||
|
||||
|
||||
def make_profiler(cfg, results_path):
    """Create the (optional) ``torch.profiler`` context for an evaluation run.

    Args:
        cfg: Config passed to ``get_profile_cfg``.
        results_path: Directory under which the trace directory is created.

    Returns:
        ``(context, profile_dir, profile_cfg)``. When profiling is disabled the
        context is a ``nullcontext()`` and ``profile_dir`` is ``None``.
    """
    profile_cfg = get_profile_cfg(cfg)
    if not profile_cfg["enabled"]:
        return nullcontext(), None, profile_cfg

    activity = torch.profiler.ProfilerActivity
    activities = [activity.CPU] + ([activity.CUDA] if torch.cuda.is_available() else [])

    profile_dir = results_path / profile_cfg["trace_dirname"]
    profile_dir.mkdir(parents=True, exist_ok=True)

    option_names = ("record_shapes", "profile_memory", "with_stack", "with_flops")
    profiler = torch.profiler.profile(
        activities=activities,
        **{name: profile_cfg[name] for name in option_names},
    )
    return profiler, profile_dir, profile_cfg
|
||||
|
||||
|
||||
def dump_profiler_results(profiler, profile_dir, profile_cfg):
    """Write the profiler summary table and export one trace.

    Args:
        profiler: Finished profiler object, or ``None`` when profiling was off.
        profile_dir: Directory to write into, or ``None`` when profiling was off.
        profile_cfg: Settings dict (keys used: ``row_limit``,
            ``export_tensorboard``, ``worker_name``, ``export_chrome_trace``).

    Returns:
        Path of the written ``key_averages.txt``, or ``None`` if there was
        nothing to dump.
    """
    if profiler is None or profile_dir is None:
        return None

    sort_key = "self_cuda_time_total" if torch.cuda.is_available() else "self_cpu_time_total"
    summary = profiler.key_averages().table(
        sort_by=sort_key,
        row_limit=profile_cfg["row_limit"],
    )
    summary_path = profile_dir / "key_averages.txt"
    summary_path.write_text(summary)

    # The TensorBoard export takes precedence; the Chrome trace is written only
    # when the TensorBoard export is turned off.
    if profile_cfg["export_tensorboard"]:
        export = torch.profiler.tensorboard_trace_handler(
            str(profile_dir), worker_name=profile_cfg["worker_name"]
        )
        export(profiler)
    elif profile_cfg["export_chrome_trace"]:
        profiler.export_chrome_trace(str(profile_dir / "trace.json"))

    return summary_path
|
||||
|
||||
|
||||
def get_multi_gpu_cfg(cfg):
    """Return the multi-GPU settings: defaults overlaid with ``cfg.multi_gpu``.

    Args:
        cfg: Config object exposing ``.get``. Its optional ``multi_gpu`` node is
            converted with ``OmegaConf.to_container(..., resolve=True)``.

    Returns:
        A plain ``dict`` with at least ``enabled``, ``devices`` and
        ``start_method``.
    """
    settings = dict(enabled=False, devices=None, start_method="spawn")
    overrides = cfg.get("multi_gpu")
    if overrides is None:
        return settings
    settings.update(OmegaConf.to_container(overrides, resolve=True))
    return settings
|
||||
|
||||
|
||||
def get_multi_node_cfg(cfg):
    """Return the multi-node settings: defaults overlaid with ``cfg.multi_node``.

    Args:
        cfg: Config object exposing ``.get``. Its optional ``multi_node`` node is
            converted with ``OmegaConf.to_container(..., resolve=True)``.

    Returns:
        A plain ``dict`` of multi-node options (env-var names, backend,
        aggregation and sharding switches).
    """
    settings = dict(
        enabled=False,
        backend="gloo",
        rank_env="RANK",
        world_size_env="WORLD_SIZE",
        local_rank_env="LOCAL_RANK",
        output_mode="single",
        aggregate_results=True,
        sync_before_return=False,
        destroy_process_group=True,
        shard_strategy="round_robin",
    )
    overrides = cfg.get("multi_node")
    if overrides is None:
        return settings
    settings.update(OmegaConf.to_container(overrides, resolve=True))
    return settings
|
||||
|
||||
|
||||
def get_dist_env(name, default=None):
    """Read environment variable ``name`` as an ``int``.

    Args:
        name: Environment variable to read.
        default: Value used when the variable is unset.

    Returns:
        ``int(value)``, or ``None`` when the variable is unset and ``default``
        is ``None``.
    """
    raw = os.environ.get(name, default)
    return None if raw is None else int(raw)
|
||||
|
||||
|
||||
def get_rank_context(cfg):
    """Resolve ``(rank, world_size, local_rank)`` for the current process.

    When multi-node mode is disabled this is always ``(0, 1, 0)``. Otherwise the
    values are read from the environment variables named in the multi-node
    settings; the local rank defaults to 0 when its variable is unset.

    Raises:
        ValueError: If rank or world size is missing, the world size is below 1,
            or the rank lies outside ``[0, world_size)``.
    """
    settings = get_multi_node_cfg(cfg)
    if not settings["enabled"]:
        return 0, 1, 0

    rank = get_dist_env(settings["rank_env"])
    world_size = get_dist_env(settings["world_size_env"])
    local_rank = get_dist_env(settings["local_rank_env"], 0)

    if rank is None or world_size is None:
        raise ValueError(
            "multi_node.enabled=true requires torchrun env vars RANK and WORLD_SIZE"
        )
    if world_size < 1:
        raise ValueError("WORLD_SIZE must be >= 1")
    if not 0 <= rank < world_size:
        raise ValueError("RANK must be in [0, WORLD_SIZE)")
    return rank, world_size, local_rank
|
||||
|
||||
|
||||
def all_gather_eval_result(result):
    """Gather ``result`` from every rank of the default process group.

    Returns:
        A list with one entry per rank, filled by
        ``torch.distributed.all_gather_object``.
    """
    gathered = [None] * torch.distributed.get_world_size()
    torch.distributed.all_gather_object(gathered, result)
    return gathered
|
||||
|
||||
|
||||
def finalize_multi_node_process_group(cfg):
    """Destroy the default process group if the config asks for it.

    Does nothing when ``multi_node.destroy_process_group`` is false, or when
    ``torch.distributed`` is unavailable or not initialized.
    """
    if not get_multi_node_cfg(cfg)["destroy_process_group"]:
        return
    dist = torch.distributed
    if dist.is_available() and dist.is_initialized():
        dist.destroy_process_group()
|
||||
|
||||
|
||||
def get_rank_result_path(output_dir: Path, cfg: DictConfig, rank: int) -> Path:
    """Return the results file path for ``rank``.

    Rank 0 writes to ``cfg.output.filename`` unchanged. Other ranks get
    ``.rank<N>`` inserted before the file suffix, or appended when the filename
    has no suffix.
    """
    filename = str(cfg.output.filename)
    if rank == 0:
        return output_dir / filename

    name = Path(filename)
    if name.suffix:
        return output_dir / f"{name.stem}.rank{rank}{name.suffix}"
    return output_dir / f"{filename}.rank{rank}"
|
||||
|
||||
|
||||
def build_process(cfg, dataset):
    """Fit one ``StandardScaler`` per cached dataset column.

    ``pixels`` is skipped. Rows containing NaN are dropped before fitting. Every
    column except ``action`` also registers the same scaler under
    ``goal_<column>``.

    Args:
        cfg: Config providing ``dataset.keys_to_cache``.
        dataset: Dataset exposing ``get_col_data(column)``.

    Returns:
        Dict mapping column names (and ``goal_`` aliases) to fitted scalers.
    """
    process = {}
    for col in cfg.dataset.keys_to_cache:
        if col == "pixels":
            continue

        values = dataset.get_col_data(col)
        finite_rows = ~np.isnan(values).any(axis=1)

        scaler = preprocessing.StandardScaler()
        scaler.fit(values[finite_rows])
        process[col] = scaler

        if col != "action":
            process[f"goal_{col}"] = scaler
    return process
|
||||
|
||||
# -- run evaluation
|
||||
policy = cfg.get("policy", "random")
|
||||
|
||||
if policy != "random":
|
||||
model = swm.policy.AutoCostModel(cfg.policy)
|
||||
model = model.to("cuda")
|
||||
model = model.eval()
|
||||
model.requires_grad_(False)
|
||||
model.interpolate_pos_encoding = True
|
||||
config = swm.PlanConfig(**cfg.plan_config)
|
||||
solver = hydra.utils.instantiate(cfg.solver, model=model)
|
||||
policy = swm.policy.WorldModelPolicy(
|
||||
solver=solver, config=config, process=process, transform=transform
|
||||
)
|
||||
|
||||
else:
|
||||
policy = swm.policy.RandomPolicy()
|
||||
|
||||
results_path = (
|
||||
Path(swm.data.utils.get_cache_dir(), cfg.policy).parent
|
||||
if cfg.policy != "random"
|
||||
else Path(__file__).parent
|
||||
)
|
||||
|
||||
# sample the episodes and the starting indices
|
||||
def sample_eval_cases(cfg, dataset):
|
||||
stats_dataset = dataset
|
||||
col_name = "episode_idx" if "episode_idx" in dataset.column_names else "ep_idx"
|
||||
ep_indices, _ = np.unique(stats_dataset.get_col_data(col_name), return_index=True)
|
||||
episode_len = get_episodes_length(dataset, ep_indices)
|
||||
max_start_idx = episode_len - cfg.eval.goal_offset_steps - 1
|
||||
max_start_idx_dict = {ep_id: max_start_idx[i] for i, ep_id in enumerate(ep_indices)}
|
||||
# Map each dataset row’s episode_idx to its max_start_idx
|
||||
col_name = "episode_idx" if "episode_idx" in dataset.column_names else "ep_idx"
|
||||
max_start_per_row = np.array(
|
||||
[max_start_idx_dict[ep_id] for ep_id in dataset.get_col_data(col_name)]
|
||||
)
|
||||
|
||||
# remove all the lines of dataset for which dataset['step_idx'] > max_start_per_row
|
||||
valid_mask = dataset.get_col_data("step_idx") <= max_start_per_row
|
||||
valid_indices = np.nonzero(valid_mask)[0]
|
||||
print(valid_mask.sum(), "valid starting points found for evaluation.")
|
||||
@@ -124,35 +383,478 @@ def run(cfg: DictConfig):
|
||||
random_episode_indices = g.choice(
|
||||
len(valid_indices) - 1, size=cfg.eval.num_eval, replace=False
|
||||
)
|
||||
|
||||
# sort increasingly to avoid issues with HDF5Dataset indexing
|
||||
random_episode_indices = np.sort(valid_indices[random_episode_indices])
|
||||
|
||||
print(random_episode_indices)
|
||||
|
||||
eval_episodes = dataset.get_row_data(random_episode_indices)[col_name]
|
||||
eval_start_idx = dataset.get_row_data(random_episode_indices)["step_idx"]
|
||||
rows = dataset.get_row_data(random_episode_indices)
|
||||
eval_episodes = rows[col_name]
|
||||
eval_start_idx = rows["step_idx"]
|
||||
|
||||
if len(eval_episodes) < cfg.eval.num_eval:
|
||||
raise ValueError("Not enough episodes with sufficient length for evaluation.")
|
||||
|
||||
return eval_episodes, eval_start_idx
|
||||
|
||||
|
||||
def normalize_multi_gpu_devices(devices):
    """Turn a device specification into a list of device strings.

    Args:
        devices: ``None`` (use every visible CUDA device) or an iterable of
            ints, digit strings or ready-made device strings.

    Returns:
        List of strings; ints and digit strings become ``"cuda:<index>"``,
        anything else is passed through ``str``.
    """
    if devices is None:
        return [f"cuda:{idx}" for idx in range(torch.cuda.device_count())]

    def to_device_name(entry):
        if isinstance(entry, int):
            return f"cuda:{entry}"
        if isinstance(entry, str) and entry.isdigit():
            return f"cuda:{int(entry)}"
        return str(entry)

    return [to_device_name(entry) for entry in devices]
|
||||
|
||||
|
||||
def shard_eval_cases(eval_episodes, eval_start_idx, num_shards):
    """Split the evaluation cases into contiguous, near-equal shards.

    The first ``len(eval_episodes) % num_shards`` shards receive one extra
    case. Empty shards are omitted from the result.

    Returns:
        List of ``(episodes_slice, start_idx_slice)`` tuples.

    Raises:
        ValueError: If ``num_shards`` is below 1.
    """
    if num_shards < 1:
        raise ValueError("num_shards must be >= 1")

    base, extra = divmod(len(eval_episodes), num_shards)
    shards = []
    cursor = 0
    for shard_id in range(num_shards):
        size = base + 1 if shard_id < extra else base
        if size == 0:
            continue
        stop = cursor + size
        shards.append((eval_episodes[cursor:stop], eval_start_idx[cursor:stop]))
        cursor = stop
    return shards
|
||||
|
||||
|
||||
def get_rank_eval_subset(
    eval_episodes,
    eval_start_idx,
    rank,
    world_size,
    *,
    strategy="contiguous",
):
    """Select the evaluation cases that belong to ``rank``.

    Args:
        eval_episodes: Sequence of episode ids.
        eval_start_idx: Sequence of start steps, parallel to ``eval_episodes``.
        rank: Index of the requesting rank.
        world_size: Total number of ranks.
        strategy: ``"contiguous"`` (near-equal consecutive chunks, earlier ranks
            get the remainder) or ``"round_robin"`` (every ``world_size``-th
            case starting at ``rank``).

    Returns:
        ``(episodes_subset, start_idx_subset)``.

    Raises:
        ValueError: On an invalid world size, rank or strategy.
    """
    if world_size < 1:
        raise ValueError("world_size must be >= 1")
    if not 0 <= rank < world_size:
        raise ValueError("rank must be in [0, world_size)")

    if strategy == "round_robin":
        picks = slice(rank, None, world_size)
        return eval_episodes[picks], eval_start_idx[picks]
    if strategy != "contiguous":
        raise ValueError("strategy must be one of: contiguous, round_robin")

    base, extra = divmod(len(eval_episodes), world_size)
    # Ranks below `extra` hold base + 1 cases, so the offset grows by one for
    # each of them that precedes this rank.
    start = rank * base + min(rank, extra)
    end = start + base + (1 if rank < extra else 0)
    return eval_episodes[start:end], eval_start_idx[start:end]
|
||||
|
||||
|
||||
def run_eval_subset(
    cfg: DictConfig,
    eval_episodes,
    eval_start_idx,
    output_dir: Path,
    *,
    device_override: str | None = None,
    enable_profile: bool = True,
    enable_compile_warmup: bool = False,
    before_evaluate=None,
):
    """Evaluate one set of (episode, start step) cases in the current process.

    Works on a private copy of ``cfg`` whose ``eval.num_eval`` and
    ``world.num_envs`` are set to the number of cases. Builds the world,
    transforms, dataset, scalers and policy, then runs
    ``world.evaluate_from_dataset`` under the grad / profiler / autocast
    contexts.

    Args:
        cfg: Evaluation config.
        eval_episodes: Episode ids to evaluate.
        eval_start_idx: Start steps, parallel to ``eval_episodes``.
        output_dir: Directory for videos and profiler output.
        device_override: Device for the model and solver; when ``None``,
            ``"cuda"`` is used if available, else ``"cpu"``.
        enable_profile: When false, profiling is forced off in the local config.
        enable_compile_warmup: When true, ``maybe_run_compile_warmup`` runs
            before the timed evaluation call.
        before_evaluate: Optional zero-argument callable invoked once setup is
            complete, before timing starts.

    Returns:
        Dict with ``metrics``, ``evaluation_time``, ``inference_precision``,
        ``compile_target``, ``compile_mode``, ``profile_dir`` and
        ``profile_summary_path``.
    """
    local_cfg = OmegaConf.create(OmegaConf.to_container(cfg, resolve=False))
    case_count = len(eval_episodes)
    local_cfg.eval.num_eval = case_count
    local_cfg.world.num_envs = case_count
    local_cfg.world.max_episode_steps = 2 * local_cfg.eval.eval_budget

    if device_override is not None:
        local_cfg.solver.device = device_override
        if torch.cuda.is_available() and str(device_override).startswith("cuda"):
            torch.cuda.set_device(torch.device(device_override))

    if not enable_profile:
        if local_cfg.get("profile") is None:
            local_cfg.profile = OmegaConf.create({"enabled": False})
        else:
            local_cfg.profile.enabled = False

    world = swm.World(**local_cfg.world, image_shape=(224, 224))
    transform = {key: img_transform(local_cfg) for key in ("pixels", "goal")}
    dataset = get_dataset(local_cfg, local_cfg.eval.dataset_name)
    process = build_process(local_cfg, dataset)

    device = device_override or ("cuda" if torch.cuda.is_available() else "cpu")

    def sync_device():
        # Wait for queued CUDA work; a no-op for non-CUDA devices.
        if str(device).startswith("cuda") and torch.cuda.is_available():
            torch.cuda.synchronize()

    if local_cfg.get("policy", "random") == "random":
        policy = swm.policy.RandomPolicy()
        solver = None
        inference_ctx = nullcontext()
        inference_precision = "fp32"
        compile_cfg = get_compile_cfg(local_cfg)
        compile_target = "disabled"
    else:
        model = swm.policy.AutoCostModel(local_cfg.policy)
        model = model.to(device)
        model = model.eval()
        model.requires_grad_(False)
        model, compile_cfg, compile_target = maybe_compile_inference_target(
            model, local_cfg, device
        )
        inference_ctx, inference_precision = get_inference_context(local_cfg, device)
        model.interpolate_pos_encoding = True
        config = swm.PlanConfig(**local_cfg.plan_config)
        solver = hydra.utils.instantiate(local_cfg.solver, model=model)
        policy = swm.policy.WorldModelPolicy(
            solver=solver, config=config, process=process, transform=transform
        )

    profiler_ctx, profile_dir, profile_cfg = make_profiler(local_cfg, output_dir)
    world.set_policy(policy)
    sync_device()

    if before_evaluate is not None:
        before_evaluate()
        sync_device()

    start_time = time.time()
    with get_eval_grad_context(solver), profiler_ctx as profiler, inference_ctx:
        if enable_compile_warmup:
            maybe_run_compile_warmup(cfg, eval_episodes, eval_start_idx)
        with torch.profiler.record_function("eval.world_evaluate_from_dataset"):
            metrics = world.evaluate_from_dataset(
                dataset,
                start_steps=list(eval_start_idx),
                goal_offset_steps=local_cfg.eval.goal_offset_steps,
                eval_budget=local_cfg.eval.eval_budget,
                episodes_idx=list(eval_episodes),
                callables=OmegaConf.to_container(
                    local_cfg.eval.get("callables"), resolve=True
                ),
                save_video=bool(local_cfg.eval.get("save_video", False)),
                video_path=output_dir,
            )
        sync_device()
    evaluation_time = time.time() - start_time
    profile_summary_path = dump_profiler_results(profiler, profile_dir, profile_cfg)

    return {
        "metrics": metrics,
        "evaluation_time": evaluation_time,
        "inference_precision": inference_precision,
        "compile_target": compile_target,
        "compile_mode": compile_cfg["mode"] if compile_target != "disabled" else None,
        "profile_dir": profile_dir,
        "profile_summary_path": profile_summary_path,
    }
|
||||
|
||||
|
||||
def maybe_run_compile_warmup(cfg, eval_episodes, eval_start_idx):
    """Run a small throw-away evaluation before the real one.

    Skipped when warmup is disabled, when multi-GPU mode is enabled (a notice
    is printed), or when no cases are available for this process. In
    multi-node mode the warmup cases are taken from this rank's shard, using the
    configured ``multi_node.shard_strategy`` so they match what
    ``run_multi_node_eval`` assigns to the rank. Videos and profiling are
    turned off, and output goes to a temporary directory that is deleted
    afterwards.

    Args:
        cfg: Evaluation config.
        eval_episodes: All sampled episode ids.
        eval_start_idx: Start steps, parallel to ``eval_episodes``.
    """
    warmup_cfg = get_compile_warmup_cfg(cfg)
    if not warmup_cfg["enabled"]:
        return

    if get_multi_gpu_cfg(cfg)["enabled"]:
        print("Skipping compile warmup because multi_gpu.enabled=true uses spawned workers.")
        return

    multi_node_cfg = get_multi_node_cfg(cfg)
    device_override = None
    if multi_node_cfg["enabled"]:
        rank, world_size, local_rank = get_rank_context(cfg)
        eval_episodes, eval_start_idx = get_rank_eval_subset(
            eval_episodes,
            eval_start_idx,
            rank,
            world_size,
            # Shard exactly like the real multi-node run does.
            strategy=multi_node_cfg["shard_strategy"],
        )
        device_override = f"cuda:{local_rank}" if torch.cuda.is_available() else "cpu"

    warmup_count = min(int(warmup_cfg["num_eval"]), len(eval_episodes))
    if warmup_count < 1:
        return

    warmup_eval_cfg = OmegaConf.create(OmegaConf.to_container(cfg, resolve=False))
    warmup_eval_cfg.eval.num_eval = warmup_count
    warmup_eval_cfg.eval.save_video = False
    if warmup_eval_cfg.get("profile") is None:
        warmup_eval_cfg.profile = OmegaConf.create({"enabled": False})
    else:
        warmup_eval_cfg.profile.enabled = False

    with tempfile.TemporaryDirectory(prefix="lewm_compile_warmup_") as tmpdir:
        run_eval_subset(
            warmup_eval_cfg,
            list(eval_episodes[:warmup_count]),
            list(eval_start_idx[:warmup_count]),
            Path(tmpdir),
            device_override=device_override,
            enable_profile=False,
        )
|
||||
|
||||
|
||||
|
||||
def _multi_gpu_eval_worker(
    cfg_container,
    eval_episodes,
    eval_start_idx,
    output_dir,
    device,
    shard_idx,
    queue,
):
    """Worker entry point: evaluate one shard on ``device`` and report via ``queue``.

    Puts ``{"ok": True, "shard_idx", "result"}`` on success, or
    ``{"ok": False, "shard_idx", "error"}`` carrying the formatted traceback
    when an ``Exception`` is raised.
    """
    try:
        worker_cfg = OmegaConf.create(cfg_container)
        outcome = run_eval_subset(
            worker_cfg,
            eval_episodes,
            eval_start_idx,
            Path(output_dir),
            device_override=device,
            enable_profile=False,
        )
        queue.put(dict(ok=True, shard_idx=shard_idx, result=outcome))
    except Exception:
        failure = dict(ok=False, shard_idx=shard_idx, error=traceback.format_exc())
        queue.put(failure)
|
||||
|
||||
|
||||
def run_multi_gpu_eval(cfg, eval_episodes, eval_start_idx, output_dir: Path):
    """Evaluate in parallel with one worker process per device.

    The cases are split into contiguous shards (at most one per device), each
    shard is evaluated by ``_multi_gpu_eval_worker`` in its own process, and the
    per-shard results are merged in shard order.

    Args:
        cfg: Evaluation config.
        eval_episodes: Episode ids to evaluate.
        eval_start_idx: Start steps, parallel to ``eval_episodes``.
        output_dir: Directory handed to every worker.

    Returns:
        Result dict shaped like the one from ``run_eval_subset``; profiling
        fields are always ``None``.

    Raises:
        ValueError: If fewer than two devices are available/configured.
        RuntimeError: If a worker reports an exception, or exits without
            reporting anything (e.g. it was killed).
    """
    # Function-scope import: only needed for the queue.Empty exception type.
    import queue as queue_lib

    multi_gpu_cfg = get_multi_gpu_cfg(cfg)
    devices = normalize_multi_gpu_devices(multi_gpu_cfg["devices"])
    if len(devices) < 2:
        raise ValueError("multi_gpu.enabled=true requires at least 2 CUDA devices")

    shards = shard_eval_cases(eval_episodes, eval_start_idx, min(len(devices), len(eval_episodes)))
    devices = devices[: len(shards)]

    ctx = mp.get_context(multi_gpu_cfg["start_method"])
    result_queue = ctx.Queue()
    cfg_container = OmegaConf.to_container(cfg, resolve=False)
    processes = []

    start_time = time.time()
    for shard_idx, ((shard_episodes, shard_start_idx), device) in enumerate(
        zip(shards, devices, strict=True)
    ):
        process = ctx.Process(
            target=_multi_gpu_eval_worker,
            args=(
                cfg_container,
                list(shard_episodes),
                list(shard_start_idx),
                str(output_dir),
                device,
                shard_idx,
                result_queue,
            ),
        )
        process.start()
        processes.append(process)

    shard_results = {}
    errors = []
    reported = set()  # shard indices accounted for (processes[i] runs shard i)
    while len(reported) < len(processes):
        try:
            message = result_queue.get(timeout=1.0)
        except queue_lib.Empty:
            # Nothing arrived: make sure the workers we still wait for are
            # alive, otherwise a hard crash would block this loop forever.
            lost = [
                idx
                for idx, proc in enumerate(processes)
                if idx not in reported and not proc.is_alive()
            ]
            if not lost:
                continue
            # A worker may have enqueued its message right before exiting;
            # look once more before declaring it lost.
            try:
                message = result_queue.get(timeout=1.0)
            except queue_lib.Empty:
                for idx in lost:
                    reported.add(idx)
                    errors.append(
                        f"multi-GPU worker for shard {idx} exited with code "
                        f"{processes[idx].exitcode} without reporting a result"
                    )
                continue

        reported.add(message["shard_idx"])
        if message["ok"]:
            shard_results[message["shard_idx"]] = message["result"]
        else:
            errors.append(message["error"])

    for process in processes:
        process.join()

    if errors:
        raise RuntimeError(errors[0])

    ordered_results = [shard_results[idx] for idx in range(len(processes))]
    metrics, reference = combine_eval_results(ordered_results)
    return {
        "metrics": metrics,
        "evaluation_time": time.time() - start_time,
        "inference_precision": reference["inference_precision"],
        "compile_target": reference["compile_target"],
        "compile_mode": reference["compile_mode"],
        "profile_dir": None,
        "profile_summary_path": None,
    }
|
||||
|
||||
|
||||
def combine_eval_results(ordered_results):
    """Merge per-shard evaluation results into one metrics dict.

    Args:
        ordered_results: Non-empty list of result dicts, each holding a
            ``metrics`` dict with ``episode_successes`` and optionally ``seeds``.

    Returns:
        ``(metrics, reference)`` where ``metrics`` has ``success_rate`` (in
        percent), the concatenated ``episode_successes`` and the concatenated
        ``seeds`` (``None`` unless every shard provided seeds), and
        ``reference`` is the first result.
    """
    shard_metrics = [result["metrics"] for result in ordered_results]

    success_parts = [
        np.asarray(entry["episode_successes"], dtype=np.bool_)
        for entry in shard_metrics
    ]
    episode_successes = np.concatenate(success_parts)

    seed_parts = [entry.get("seeds") for entry in shard_metrics]
    has_all_seeds = all(part is not None for part in seed_parts)
    seeds = np.concatenate(seed_parts) if has_all_seeds else None

    metrics = {
        "success_rate": float(np.sum(episode_successes)) / len(episode_successes) * 100.0,
        "episode_successes": episode_successes,
        "seeds": seeds,
    }
    return metrics, ordered_results[0]
|
||||
|
||||
|
||||
def run_multi_node_eval(cfg, eval_episodes, eval_start_idx, output_dir: Path):
    """Evaluate this rank's shard and optionally aggregate across ranks.

    Args:
        cfg: Evaluation config.
        eval_episodes: All sampled episode ids.
        eval_start_idx: Start steps, parallel to ``eval_episodes``.
        output_dir: Base output directory.

    Returns:
        Without aggregation: this rank's result dict, with ``output_filename``
        set to the rank-specific file name. With aggregation: the combined
        result on rank 0 and ``None`` on every other rank.

    Raises:
        ValueError: If no cases are assigned to this rank.
        RuntimeError: If ``torch.distributed`` is needed but unavailable.
    """
    rank, world_size, local_rank = get_rank_context(cfg)
    multi_node_cfg = get_multi_node_cfg(cfg)

    def ensure_process_group(purpose):
        # Initialise the default process group on first use.
        if not torch.distributed.is_available():
            raise RuntimeError(f"torch.distributed is required for {purpose}")
        if not torch.distributed.is_initialized():
            torch.distributed.init_process_group(backend=multi_node_cfg["backend"])

    def wait_hook():
        wait_for_preload_signal(cfg, rank=rank)

    shard_episodes, shard_start_idx = get_rank_eval_subset(
        eval_episodes,
        eval_start_idx,
        rank,
        world_size,
        strategy=multi_node_cfg["shard_strategy"],
    )
    if len(shard_episodes) == 0:
        raise ValueError("No evaluation episodes assigned to this rank")

    # The per-rank run is a plain single-process evaluation.
    local_cfg = OmegaConf.create(OmegaConf.to_container(cfg, resolve=False))
    local_cfg.multi_node.enabled = False
    if local_cfg.get("multi_gpu") is None:
        local_cfg.multi_gpu = OmegaConf.create({"enabled": False})
    else:
        local_cfg.multi_gpu.enabled = False

    device = f"cuda:{local_rank}" if torch.cuda.is_available() else "cpu"
    if get_preload_wait_cfg(cfg)["enabled"]:
        ensure_process_group("preload_wait")

    rank_output_path = get_rank_result_path(output_dir, cfg, rank)
    result = run_eval_subset(
        local_cfg,
        list(shard_episodes),
        list(shard_start_idx),
        rank_output_path.parent,
        device_override=device,
        enable_profile=False,
        before_evaluate=wait_hook,
    )

    if not multi_node_cfg["aggregate_results"]:
        result["output_filename"] = rank_output_path.name
        finalize_multi_node_process_group(cfg)
        return result

    ensure_process_group("multi-node evaluation")

    gathered = all_gather_eval_result(result)
    metrics, reference = combine_eval_results(gathered)
    combined = {
        "metrics": metrics,
        "evaluation_time": max(item["evaluation_time"] for item in gathered),
        "inference_precision": reference["inference_precision"],
        "compile_target": reference["compile_target"],
        "compile_mode": reference["compile_mode"],
        "profile_dir": None,
        "profile_summary_path": None,
        "output_filename": cfg.output.filename,
    }
    if multi_node_cfg["sync_before_return"]:
        torch.distributed.barrier()
    finalize_multi_node_process_group(cfg)

    return combined if rank == 0 else None
|
||||
|
||||
@hydra.main(version_base=None, config_path="./config/eval", config_name="pusht")
|
||||
def run(cfg: DictConfig):
|
||||
"""Run evaluation of dinowm vs random policy."""
|
||||
assert (
|
||||
cfg.plan_config.horizon * cfg.plan_config.action_block <= cfg.eval.eval_budget
|
||||
), "Planning horizon must be smaller than or equal to eval_budget"
|
||||
|
||||
dataset = get_dataset(cfg, cfg.eval.dataset_name)
|
||||
eval_episodes, eval_start_idx = sample_eval_cases(cfg, dataset)
|
||||
output_dir = Path.cwd().resolve()
|
||||
profile_cfg = get_profile_cfg(cfg)
|
||||
|
||||
maybe_run_compile_warmup(cfg, eval_episodes, eval_start_idx)
|
||||
eval_wall_start = time.time()
|
||||
|
||||
if get_multi_node_cfg(cfg)["enabled"] and get_multi_gpu_cfg(cfg)["enabled"]:
|
||||
raise ValueError("multi_node.enabled and multi_gpu.enabled are mutually exclusive")
|
||||
|
||||
if get_multi_node_cfg(cfg)["enabled"]:
|
||||
eval_result = run_multi_node_eval(cfg, eval_episodes, eval_start_idx, output_dir)
|
||||
if eval_result is None:
|
||||
return
|
||||
elif get_multi_gpu_cfg(cfg)["enabled"]:
|
||||
if profile_cfg["enabled"]:
|
||||
raise ValueError("Profiling is not supported together with multi_gpu.enabled=true")
|
||||
eval_result = run_multi_gpu_eval(cfg, eval_episodes, eval_start_idx, output_dir)
|
||||
else:
|
||||
eval_result = run_eval_subset(
|
||||
cfg,
|
||||
eval_episodes.tolist(),
|
||||
eval_start_idx.tolist(),
|
||||
output_dir,
|
||||
)
|
||||
|
||||
metrics = eval_result["metrics"]
|
||||
evaluation_time = eval_result["evaluation_time"]
|
||||
inference_precision = eval_result["inference_precision"]
|
||||
compile_target = eval_result["compile_target"]
|
||||
compile_mode = eval_result["compile_mode"]
|
||||
profile_dir = eval_result["profile_dir"]
|
||||
profile_summary_path = eval_result["profile_summary_path"]
|
||||
output_filename = eval_result.get("output_filename", cfg.output.filename)
|
||||
|
||||
print(metrics)
|
||||
|
||||
results_path = results_path / cfg.output.filename
|
||||
results_path = output_dir / output_filename
|
||||
results_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
with results_path.open("a") as f:
|
||||
@@ -164,7 +866,17 @@ def run(cfg: DictConfig):
|
||||
|
||||
f.write("==== RESULTS ====\n")
|
||||
f.write(f"metrics: {metrics}\n")
|
||||
f.write(f"evaluation_time: {end_time - start_time} seconds\n")
|
||||
f.write(f"evaluation_time: {evaluation_time} seconds\n")
|
||||
f.write(f"inference_precision: {inference_precision}\n")
|
||||
f.write(f"inference_compile_target: {compile_target}\n")
|
||||
if compile_target != "disabled":
|
||||
f.write(f"inference_compile_mode: {compile_mode}\n")
|
||||
if profile_cfg["enabled"]:
|
||||
f.write(f"profile_dir: {profile_dir}\n")
|
||||
if profile_summary_path is not None:
|
||||
f.write(f"profile_summary: {profile_summary_path}\n")
|
||||
|
||||
f.write(f"total_wall_time: {time.time() - eval_wall_start} seconds\n")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
320
jepa.py
320
jepa.py
@@ -2,12 +2,8 @@
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from einops import rearrange
|
||||
from torch import nn
|
||||
|
||||
def detach_clone(v):
    """Return a detached copy of ``v`` if it is a tensor, otherwise ``v`` itself.

    Non-tensor values are passed through unchanged (same object).
    """
    if not torch.is_tensor(v):
        return v
    return v.detach().clone()
|
||||
|
||||
class JEPA(nn.Module):
|
||||
|
||||
def __init__(
|
||||
@@ -25,34 +21,124 @@ class JEPA(nn.Module):
|
||||
self.action_encoder = action_encoder
|
||||
self.projector = projector or nn.Identity()
|
||||
self.pred_proj = pred_proj or nn.Identity()
|
||||
self._cached_device_tensors = {}
|
||||
self._cached_init_signature = None
|
||||
self._cached_init_emb = None
|
||||
self._cached_goal_signature = None
|
||||
self._cached_goal_emb = None
|
||||
|
||||
def _ensure_runtime_caches(self):
    """Create any runtime-cache attribute that is missing on this instance.

    Attributes that already exist are left untouched. The device-tensor cache
    defaults to a fresh empty dict; the signature/embedding slots default to
    ``None``.
    """
    # NOTE(review): presumably needed for instances that were unpickled from
    # checkpoints saved before these attributes existed -- confirm.
    if not hasattr(self, "_cached_device_tensors"):
        self._cached_device_tensors = {}
    for slot in (
        "_cached_init_signature",
        "_cached_init_emb",
        "_cached_goal_signature",
        "_cached_goal_emb",
    ):
        if not hasattr(self, slot):
            setattr(self, slot, None)
|
||||
|
||||
@staticmethod
def _tensor_signature(tensor: torch.Tensor):
    """Build a hashable fingerprint describing where a tensor lives and how it is laid out.

    The returned tuple is, in order: device string, dtype, shape, strides,
    storage offset, data pointer and the tensor's version counter (``None``
    when the counter cannot be read).
    """
    try:
        version = tensor._version
    except RuntimeError:
        # Reading the version counter can raise for some tensors
        # (presumably inference-mode tensors -- confirm); fall back to None.
        version = None
    layout = (
        tuple(tensor.shape),
        tuple(tensor.stride()),
        tensor.storage_offset(),
    )
    return (str(tensor.device), tensor.dtype, *layout, tensor.data_ptr(), version)
|
||||
|
||||
def _get_cached_device_tensor(
    self,
    key: str,
    tensor: torch.Tensor,
    device: torch.device,
    *,
    ensure_contiguous: bool = False,
):
    """Return ``tensor`` on ``device``, reusing a cached copy when the source is unchanged.

    Args:
        key: Cache slot name; each key holds at most one prepared tensor.
        tensor: Source tensor to move/prepare.
        device: Target device.
        ensure_contiguous: When true, the returned tensor is contiguous.

    Returns:
        ``tensor`` itself when it already satisfies the request, otherwise a
        copy on ``device`` (freshly made or taken from the cache).
    """
    self._ensure_runtime_caches()
    # Fast path: nothing to move or re-layout, so bypass the cache entirely.
    if tensor.device == device and (not ensure_contiguous or tensor.is_contiguous()):
        return tensor

    signature = (self._tensor_signature(tensor), str(device), ensure_contiguous)
    cached = self._cached_device_tensors.get(key)
    if cached is not None and cached[0] == signature:
        return cached[1]

    prepared = tensor.to(device, non_blocking=True)
    if ensure_contiguous and not prepared.is_contiguous():
        prepared = prepared.contiguous()

    # The signature identifies the source by address and version, not by
    # content. Keeping a reference to the source in the entry stops its
    # storage from being freed and handed to an unrelated tensor that would
    # then produce the same signature and wrongly hit this cache entry.
    self._cached_device_tensors[key] = (signature, prepared, tensor)
    return prepared
|
||||
|
||||
def _ensure_info_device(self, info_dict: dict, device: torch.device):
    """Replace every tensor entry of ``info_dict`` with a contiguous copy on ``device``.

    Entries whose key starts with ``"_lewm_"`` and non-tensor values are left
    as they are. The dict is updated in place and also returned.
    """
    # Iterate over a snapshot so entries can be reassigned during the loop.
    for name, entry in tuple(info_dict.items()):
        if name.startswith("_lewm_") or not torch.is_tensor(entry):
            continue
        info_dict[name] = self._get_cached_device_tensor(
            name,
            entry,
            device,
            ensure_contiguous=True,
        )
    return info_dict
|
||||
|
||||
def _get_cached_init_emb(self, info_dict: dict):
    """Return the encoded embedding of ``info_dict["pixels"][:, 0]``, re-encoding only on change.

    The embedding is recomputed whenever the signature of the ``pixels``
    tensor differs from the one recorded at the previous encode; otherwise
    the stored (detached) embedding is returned.
    """
    self._ensure_runtime_caches()
    pixels = info_dict["pixels"]
    signature = self._tensor_signature(pixels)
    if self._cached_init_signature != signature:
        init_info = {"pixels": pixels[:, 0]}
        self._cached_init_emb = self.encode(init_info)["emb"].detach()
        self._cached_init_signature = signature
        # The signature identifies the tensor by address and version, not by
        # content. Holding the source keeps its storage alive, so a later
        # observation cannot be allocated at the same address and be mistaken
        # for this one (which would return a stale embedding).
        self._cached_init_source = pixels
    return self._cached_init_emb
|
||||
|
||||
def _get_cached_goal_emb(self, info_dict: dict):
    """Return the last-step embedding of ``info_dict["goal"][:, 0]``, re-encoding only on change.

    The goal is encoded through ``self.encode`` (as ``pixels``) and only the
    final step along dim 1 of the resulting embedding is kept, detached. The
    result is reused until the signature of the ``goal`` tensor changes.
    """
    self._ensure_runtime_caches()
    goal = info_dict["goal"]
    signature = self._tensor_signature(goal)
    if self._cached_goal_signature != signature:
        goal_info = {"pixels": goal[:, 0]}
        self._cached_goal_emb = self.encode(goal_info)["emb"][:, -1:, :].detach()
        self._cached_goal_signature = signature
        # The signature identifies the tensor by address and version, not by
        # content. Holding the source keeps its storage alive, so a different
        # goal tensor cannot later be allocated at the same address and
        # silently reuse this embedding.
        self._cached_goal_source = goal
    return self._cached_goal_emb
|
||||
|
||||
def encode(self, info):
|
||||
"""Encode observations and actions into embeddings.
|
||||
info: dict with pixels and action keys
|
||||
"""
|
||||
with torch.profiler.record_function("lewm.encode"):
|
||||
pixels = info['pixels'].float()
|
||||
b, t = pixels.shape[:2]
|
||||
pixels = pixels.reshape(b * t, *pixels.shape[2:]) # flatten for encoding
|
||||
output = self.encoder(pixels, interpolate_pos_encoding=True)
|
||||
pixels_emb = output.last_hidden_state[:, 0] # cls token
|
||||
emb = self.projector(pixels_emb)
|
||||
info["emb"] = emb.reshape(b, t, -1)
|
||||
|
||||
pixels = info['pixels'].float()
|
||||
b = pixels.size(0)
|
||||
pixels = rearrange(pixels, "b t ... -> (b t) ...") # flatten for encoding
|
||||
output = self.encoder(pixels, interpolate_pos_encoding=True)
|
||||
pixels_emb = output.last_hidden_state[:, 0] # cls token
|
||||
emb = self.projector(pixels_emb)
|
||||
info["emb"] = rearrange(emb, "(b t) d -> b t d", b=b)
|
||||
if "action" in info:
|
||||
info["act_emb"] = self.action_encoder(info["action"])
|
||||
|
||||
if "action" in info:
|
||||
info["act_emb"] = self.action_encoder(info["action"])
|
||||
|
||||
return info
|
||||
return info
|
||||
|
||||
def predict(self, emb, act_emb):
|
||||
"""Predict next state embedding
|
||||
emb: (B, T, D)
|
||||
act_emb: (B, T, A_emb)
|
||||
"""
|
||||
preds = self.predictor(emb, act_emb)
|
||||
preds = self.pred_proj(rearrange(preds, "b t d -> (b t) d"))
|
||||
preds = rearrange(preds, "(b t) d -> b t d", b=emb.size(0))
|
||||
return preds
|
||||
with torch.profiler.record_function("lewm.predict"):
|
||||
preds = self.predictor(emb, act_emb)
|
||||
preds = self.pred_proj(preds)
|
||||
return preds
|
||||
|
||||
####################
|
||||
## Inference only ##
|
||||
@@ -65,89 +151,151 @@ class JEPA(nn.Module):
|
||||
- S is the number of action plan samples
|
||||
- T is the time horizon
|
||||
"""
|
||||
with torch.profiler.record_function("lewm.rollout"):
|
||||
assert "pixels" in info, "pixels not in info_dict"
|
||||
if history_size < 1:
|
||||
raise ValueError("history_size must be >= 1")
|
||||
|
||||
assert "pixels" in info, "pixels not in info_dict"
|
||||
H = info["pixels"].size(2)
|
||||
B, S, T = action_sequence.shape[:3]
|
||||
act_0, act_future = torch.split(action_sequence, [H, T - H], dim=2)
|
||||
info["action"] = act_0
|
||||
n_steps = T - H
|
||||
H = info["pixels"].size(2)
|
||||
B, S, T = action_sequence.shape[:3]
|
||||
if T < H:
|
||||
raise ValueError(
|
||||
f"action_sequence horizon ({T}) must be >= history length ({H})"
|
||||
)
|
||||
|
||||
# copy and encode initial info dict
|
||||
_init = {k: v[:, 0] for k, v in info.items() if torch.is_tensor(v)}
|
||||
_init = self.encode(_init)
|
||||
emb = info["emb"] = _init["emb"].unsqueeze(1).expand(B, S, -1, -1)
|
||||
_init = {k: detach_clone(v) for k, v in _init.items()}
|
||||
# Cache the encoded initial state across solver iterations.
|
||||
init_emb = self._get_cached_init_emb(info)
|
||||
HS = history_size
|
||||
hist_len = min(HS, init_emb.size(1), H)
|
||||
if hist_len < 1:
|
||||
raise ValueError("rollout requires at least one history step")
|
||||
|
||||
# flatten batch and sample dimensions for rollout
|
||||
emb = rearrange(emb, "b s ... -> (b s) ...").clone()
|
||||
act = rearrange(act_0, "b s ... -> (b s) ...")
|
||||
act_future = rearrange(act_future, "b s ... -> (b s) ...")
|
||||
init_hist = init_emb[:, -hist_len:]
|
||||
init_hist = init_hist.unsqueeze(1).expand(-1, S, -1, -1)
|
||||
init_hist = init_hist.reshape(B * S, hist_len, init_hist.size(-1)).contiguous()
|
||||
|
||||
# rollout predictor autoregressively for n_steps
|
||||
HS = history_size
|
||||
for t in range(n_steps):
|
||||
act_emb = self.action_encoder(act)
|
||||
emb_trunc = emb[:, -HS:] # (BS, HS, D)
|
||||
act_trunc = act_emb[:, -HS:] # (BS, HS, A_emb)
|
||||
pred_emb = self.predict(emb_trunc, act_trunc)[:, -1:] # (BS, 1, D)
|
||||
emb = torch.cat([emb, pred_emb], dim=1) # (BS, T+1, D)
|
||||
flat_actions = action_sequence.contiguous().view(B * S, T, -1)
|
||||
action_emb = self.action_encoder(flat_actions)
|
||||
act_hist = action_emb[:, H - hist_len : H]
|
||||
act_future = action_emb[:, H:]
|
||||
|
||||
next_act = act_future[:, t : t + 1, :] # (BS, 1, action_dim)
|
||||
act = torch.cat([act, next_act], dim=1) # (BS, T+1, action_dim)
|
||||
if HS == 1:
|
||||
emb_hist = init_hist[:, -1:]
|
||||
act_emb_hist = act_hist[:, -1:]
|
||||
|
||||
# predict the last state
|
||||
act_emb = self.action_encoder(act) # (BS, T, A_emb)
|
||||
emb_trunc = emb[:, -HS:] # (BS, HS, D)
|
||||
act_trunc = act_emb[:, -HS:] # (BS, HS, A_emb)
|
||||
pred_emb = self.predict(emb_trunc, act_trunc)[:, -1:] # (BS, 1, D)
|
||||
emb = torch.cat([emb, pred_emb], dim=1)
|
||||
for t in range(act_future.size(1)):
|
||||
emb_hist = self.predict(emb_hist, act_emb_hist)[:, -1:]
|
||||
act_emb_hist = act_future[:, t : t + 1]
|
||||
|
||||
# unflatten batch and sample dimensions
|
||||
pred_rollout = rearrange(emb, "(b s) ... -> b s ...", b=B, s=S)
|
||||
info["predicted_emb"] = pred_rollout
|
||||
pred_rollout = self.predict(emb_hist, act_emb_hist)[:, -1:]
|
||||
else:
|
||||
if torch.is_grad_enabled() and action_sequence.requires_grad:
|
||||
emb_slots = init_hist.split(1, dim=1)
|
||||
act_slots = act_hist.split(1, dim=1)
|
||||
|
||||
return info
|
||||
for t in range(act_future.size(1)):
|
||||
emb_view = torch.cat(emb_slots[-HS:], dim=1)
|
||||
act_view = torch.cat(act_slots[-HS:], dim=1)
|
||||
pred_emb = self.predict(emb_view, act_view)[:, -1:]
|
||||
next_act_emb = act_future[:, t : t + 1]
|
||||
|
||||
emb_slots = (*emb_slots[-(HS - 1) :], pred_emb)
|
||||
act_slots = (*act_slots[-(HS - 1) :], next_act_emb)
|
||||
|
||||
emb_view = torch.cat(emb_slots[-HS:], dim=1)
|
||||
act_view = torch.cat(act_slots[-HS:], dim=1)
|
||||
pred_rollout = self.predict(emb_view, act_view)[:, -1:]
|
||||
info["predicted_emb"] = pred_rollout.reshape(
|
||||
B, S, *pred_rollout.shape[1:]
|
||||
)
|
||||
return info
|
||||
|
||||
emb_hist = init_hist.new_empty((B * S, HS, init_hist.size(-1)))
|
||||
act_emb_hist = action_emb.new_empty((B * S, HS, action_emb.size(-1)))
|
||||
emb_hist[:, :hist_len].copy_(init_hist)
|
||||
act_emb_hist[:, :hist_len].copy_(act_hist)
|
||||
|
||||
history_order = torch.stack(
|
||||
[
|
||||
(torch.arange(HS, device=action_emb.device) + offset) % HS
|
||||
for offset in range(HS)
|
||||
]
|
||||
)
|
||||
filled = hist_len
|
||||
next_slot = hist_len % HS
|
||||
|
||||
for t in range(act_future.size(1)):
|
||||
if filled < HS:
|
||||
emb_view = emb_hist[:, :filled]
|
||||
act_view = act_emb_hist[:, :filled]
|
||||
elif next_slot == 0:
|
||||
emb_view = emb_hist
|
||||
act_view = act_emb_hist
|
||||
else:
|
||||
order = history_order[next_slot]
|
||||
emb_view = emb_hist.index_select(1, order)
|
||||
act_view = act_emb_hist.index_select(1, order)
|
||||
|
||||
pred_emb = self.predict(emb_view, act_view)[:, -1:]
|
||||
next_act_emb = act_future[:, t : t + 1]
|
||||
emb_hist[:, next_slot : next_slot + 1].copy_(pred_emb)
|
||||
act_emb_hist[:, next_slot : next_slot + 1].copy_(next_act_emb)
|
||||
|
||||
if filled < HS:
|
||||
filled += 1
|
||||
next_slot = (next_slot + 1) % HS
|
||||
|
||||
if filled < HS:
|
||||
emb_view = emb_hist[:, :filled]
|
||||
act_view = act_emb_hist[:, :filled]
|
||||
elif next_slot == 0:
|
||||
emb_view = emb_hist
|
||||
act_view = act_emb_hist
|
||||
else:
|
||||
order = history_order[next_slot]
|
||||
emb_view = emb_hist.index_select(1, order)
|
||||
act_view = act_emb_hist.index_select(1, order)
|
||||
|
||||
pred_rollout = self.predict(emb_view, act_view)[:, -1:]
|
||||
info["predicted_emb"] = pred_rollout.reshape(B, S, *pred_rollout.shape[1:])
|
||||
|
||||
return info
|
||||
|
||||
def criterion(self, info_dict: dict):
|
||||
"""Compute the cost between predicted embeddings and goal embeddings."""
|
||||
pred_emb = info_dict["predicted_emb"] # (B,S, T-1, dim)
|
||||
goal_emb = info_dict["goal_emb"] # (B, S, T, dim)
|
||||
with torch.profiler.record_function("lewm.criterion"):
|
||||
pred_emb = info_dict["predicted_emb"] # (B,S, T-1, dim)
|
||||
goal_emb = info_dict["goal_emb"] # (B, S, T, dim)
|
||||
if goal_emb.ndim == pred_emb.ndim - 1:
|
||||
goal_emb = goal_emb.unsqueeze(1)
|
||||
|
||||
goal_emb = goal_emb[..., -1:, :].expand_as(pred_emb)
|
||||
# return last-step cost per action candidate
|
||||
cost = F.mse_loss(
|
||||
pred_emb[..., -1:, :],
|
||||
goal_emb[..., -1:, :].detach(),
|
||||
reduction="none",
|
||||
).sum(dim=tuple(range(2, pred_emb.ndim))) # (B, S)
|
||||
|
||||
# return last-step cost per action candidate
|
||||
cost = F.mse_loss(
|
||||
pred_emb[..., -1:, :],
|
||||
goal_emb[..., -1:, :].detach(),
|
||||
reduction="none",
|
||||
).sum(dim=tuple(range(2, pred_emb.ndim))) # (B, S)
|
||||
|
||||
return cost
|
||||
return cost
|
||||
|
||||
def get_cost(self, info_dict: dict, action_candidates: torch.Tensor):
|
||||
""" Compute the cost of action candidates given an info dict with goal and initial state."""
|
||||
with torch.profiler.record_function("lewm.get_cost"):
|
||||
assert "goal" in info_dict, "goal not in info_dict"
|
||||
|
||||
assert "goal" in info_dict, "goal not in info_dict"
|
||||
self._ensure_runtime_caches()
|
||||
device = next(self.parameters()).device
|
||||
info_dict = self._ensure_info_device(info_dict, device)
|
||||
action_candidates = self._get_cached_device_tensor(
|
||||
"_lewm_action_candidates",
|
||||
action_candidates,
|
||||
device,
|
||||
ensure_contiguous=True,
|
||||
)
|
||||
|
||||
device = next(self.parameters()).device
|
||||
for k in list(info_dict.keys()):
|
||||
if torch.is_tensor(info_dict[k]):
|
||||
info_dict[k] = info_dict[k].to(device)
|
||||
info_dict["goal_emb"] = self._get_cached_goal_emb(info_dict)
|
||||
info_dict = self.rollout(info_dict, action_candidates)
|
||||
|
||||
goal = {k: v[:, 0] for k, v in info_dict.items() if torch.is_tensor(v)}
|
||||
goal["pixels"] = goal["goal"]
|
||||
|
||||
for k in info_dict:
|
||||
if k.startswith("goal_"):
|
||||
goal[k[len("goal_") :]] = goal.pop(k)
|
||||
|
||||
goal.pop("action")
|
||||
goal = self.encode(goal)
|
||||
|
||||
info_dict["goal_emb"] = goal["emb"]
|
||||
info_dict = self.rollout(info_dict, action_candidates)
|
||||
|
||||
cost = self.criterion(info_dict)
|
||||
|
||||
return cost
|
||||
cost = self.criterion(info_dict)
|
||||
|
||||
return cost
|
||||
|
||||
@@ -236,9 +236,13 @@ class MLP(nn.Module):
|
||||
|
||||
def forward(self, x):
|
||||
"""
|
||||
x: (B*T, D)
|
||||
x: (..., D)
|
||||
"""
|
||||
return self.net(x)
|
||||
if x.ndim <= 2:
|
||||
return self.net(x)
|
||||
|
||||
output = self.net(x.reshape(-1, x.size(-1)))
|
||||
return output.reshape(*x.shape[:-1], output.size(-1))
|
||||
|
||||
|
||||
class ARPredictor(nn.Module):
|
||||
|
||||
1624
pusht_results.txt
Normal file
1624
pusht_results.txt
Normal file
File diff suppressed because it is too large
Load Diff
131
scripts/convert_hf_checkpoint.py
Normal file
131
scripts/convert_hf_checkpoint.py
Normal file
@@ -0,0 +1,131 @@
|
||||
#!/usr/bin/env python
|
||||
"""Convert LeWM HuggingFace weights into eval-compatible object checkpoints."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
import stable_pretraining as spt
|
||||
import torch
|
||||
|
||||
REPO_ROOT = Path(__file__).resolve().parents[1]
|
||||
if str(REPO_ROOT) not in sys.path:
|
||||
sys.path.insert(0, str(REPO_ROOT))
|
||||
|
||||
from jepa import JEPA
|
||||
from module import ARPredictor, Embedder, MLP
|
||||
|
||||
|
||||
def _load_json(path: Path) -> dict:
    """Read the UTF-8 file at ``path`` and return its parsed JSON content."""
    return json.loads(path.read_text(encoding="utf-8"))
|
||||
|
||||
|
||||
def _strip_target(config: dict) -> dict:
    """Return a shallow copy of ``config`` without its ``"_target_"`` entry.

    The input mapping is not modified.
    """
    remaining = dict(config)
    remaining.pop("_target_", None)
    return remaining
|
||||
|
||||
|
||||
def infer_config_from_state_dict(state_dict: dict) -> dict:
    """Build a fallback model config when no ``config.json`` is available.

    Only the action input dimension is actually read from ``state_dict``
    (second axis of ``action_encoder.patch_embed.weight``); every other
    value is a fixed constant.

    Raises:
        KeyError: if ``action_encoder.patch_embed.weight`` is absent.
    """
    embed_dim = 192
    head_hidden_dim = 2048
    action_dim = state_dict["action_encoder.patch_embed.weight"].shape[1]

    def mlp_head() -> dict:
        # A fresh dict per call so projector and pred_proj never share state.
        return {
            "input_dim": embed_dim,
            "output_dim": embed_dim,
            "hidden_dim": head_hidden_dim,
        }

    encoder_cfg = {
        "size": "tiny",
        "patch_size": 14,
        "image_size": 224,
        "pretrained": False,
        "use_mask_token": False,
    }
    predictor_cfg = {
        "num_frames": 3,
        "input_dim": embed_dim,
        "hidden_dim": embed_dim,
        "output_dim": embed_dim,
        "depth": 6,
        "heads": 16,
        "mlp_dim": 2048,
        "dim_head": 64,
        "dropout": 0.1,
        "emb_dropout": 0.0,
    }
    action_encoder_cfg = {
        "input_dim": action_dim,
        "emb_dim": embed_dim,
    }
    return {
        "encoder": encoder_cfg,
        "predictor": predictor_cfg,
        "action_encoder": action_encoder_cfg,
        "projector": mlp_head(),
        "pred_proj": mlp_head(),
    }
|
||||
|
||||
|
||||
def build_model(config: dict) -> JEPA:
    """Instantiate a ``JEPA`` model from a config dict.

    ``config`` must provide the sections ``encoder``, ``predictor``,
    ``action_encoder``, ``projector`` and ``pred_proj``; any ``_target_``
    entry in a section is dropped before the section is used as keyword
    arguments. Both MLP heads are built with ``torch.nn.BatchNorm1d`` as
    their ``norm_fn``.
    """

    def batchnorm_mlp(section):
        # Shared recipe for the projector and the prediction projector.
        kwargs = _strip_target(config[section])
        kwargs["norm_fn"] = torch.nn.BatchNorm1d
        return MLP(**kwargs)

    encoder = spt.backbone.utils.vit_hf(**_strip_target(config["encoder"]))
    predictor = ARPredictor(**_strip_target(config["predictor"]))
    action_encoder = Embedder(**_strip_target(config["action_encoder"]))

    # Keyword arguments are evaluated left to right, so the projector is
    # still constructed before pred_proj.
    return JEPA(
        encoder=encoder,
        predictor=predictor,
        action_encoder=action_encoder,
        projector=batchnorm_mlp("projector"),
        pred_proj=batchnorm_mlp("pred_proj"),
    )
|
||||
|
||||
|
||||
def convert_checkpoint(input_dir: Path, output_name: str) -> tuple[Path, Path]:
    """Turn ``weights.pt`` in ``input_dir`` into object and weight checkpoints.

    The model config is read from ``config.json`` when that file exists and
    is otherwise derived from the state dict. Two files are written next to
    the inputs: ``<output_name>_object.ckpt`` (the whole model object, in
    eval mode) and ``<output_name>_weight.ckpt`` (its state dict).

    Args:
        input_dir: Directory containing ``weights.pt`` and optionally
            ``config.json``.
        output_name: Filename prefix for the two written checkpoints.

    Returns:
        ``(object_path, weight_path)`` of the written files.

    Raises:
        FileNotFoundError: if ``weights.pt`` is missing.
        RuntimeError: if the state dict does not match the built model.
    """
    config_path = input_dir / "config.json"
    weights_path = input_dir / "weights.pt"
    if not weights_path.exists():
        raise FileNotFoundError(f"Missing weights file: {weights_path}")

    # The weights come from an external download. weights_only=True limits
    # unpickling to tensors and plain containers, so a tampered file cannot
    # execute arbitrary code while being loaded.
    state_dict = torch.load(weights_path, map_location="cpu", weights_only=True)
    if config_path.exists():
        config = _load_json(config_path)
    else:
        config = infer_config_from_state_dict(state_dict)

    model = build_model(config)
    # strict=True already raises RuntimeError on any missing or unexpected
    # key, so no separate check of the returned key lists is needed.
    model.load_state_dict(state_dict, strict=True)
    model.eval()

    object_path = input_dir / f"{output_name}_object.ckpt"
    weight_path = input_dir / f"{output_name}_weight.ckpt"
    torch.save(model, object_path)
    torch.save(model.state_dict(), weight_path)
    return object_path, weight_path
|
||||
|
||||
|
||||
def main() -> None:
    """Command-line entry point: convert one checkpoint directory and report the written files."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "input_dir",
        type=Path,
        help="Directory containing weights.pt and optionally config.json.",
    )
    parser.add_argument("--output-name", default="lewm")
    cli = parser.parse_args()

    # convert_checkpoint returns (object_path, weight_path); report both in that order.
    for written in convert_checkpoint(cli.input_dir, cli.output_name):
        print(f"wrote {written}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
255
scripts/launch_multinode_eval.sh
Executable file
255
scripts/launch_multinode_eval.sh
Executable file
@@ -0,0 +1,255 @@
|
||||
#!/usr/bin/env bash
|
||||
set -euo pipefail
|
||||
|
||||
# Launch 2-node LeWM evaluation from node-3.
|
||||
#
|
||||
# Defaults match the current cluster layout:
|
||||
# node-3: 10.16.200.9, node_rank=0
|
||||
# node-2: 10.16.200.8, node_rank=1
|
||||
# Each node runs two local torchrun processes for two visible GPUs.
|
||||
|
||||
REPO_ROOT="${REPO_ROOT:-/home/lewm/lewm}"
|
||||
REMOTE_HOST="${REMOTE_HOST:-lewm@10.16.200.8}"
|
||||
MASTER_ADDR="${MASTER_ADDR:-10.16.200.9}"
|
||||
MASTER_PORT="${MASTER_PORT:-29500}"
|
||||
|
||||
NNODES="${NNODES:-2}"
|
||||
NPROC_PER_NODE="${NPROC_PER_NODE:-2}"
|
||||
CUDA_VISIBLE_DEVICES="${CUDA_VISIBLE_DEVICES:-0,1}"
|
||||
STABLEWM_HOME="${STABLEWM_HOME:-/home/lewm/.stable-wm}"
|
||||
|
||||
CONFIG_NAME="${CONFIG_NAME:-pusht.yaml}"
|
||||
POLICY="${POLICY:-pusht/lewm}"
|
||||
OUTPUT_FILENAME="${OUTPUT_FILENAME:-pusht_multinode_results.txt}"
|
||||
EXTRA_ARGS="${EXTRA_ARGS:-}"
|
||||
DRY_RUN="${DRY_RUN:-0}"
|
||||
TAIL_LOGS="${TAIL_LOGS:-1}"
|
||||
PRELOAD_WAIT="${PRELOAD_WAIT:-0}"
|
||||
PRELOAD_SIGNAL_FILE="${PRELOAD_SIGNAL_FILE:-/tmp/lewm_preload_start}"
|
||||
PRELOAD_CLEAR_SIGNAL="${PRELOAD_CLEAR_SIGNAL:-1}"
|
||||
|
||||
LOG_DIR="${LOG_DIR:-${REPO_ROOT}/logs/multinode}"
|
||||
mkdir -p "${LOG_DIR}"
|
||||
RUN_ID="$(date +%Y%m%d_%H%M%S)"
|
||||
LOCAL_LOG="${LOG_DIR}/${RUN_ID}_node3_rank0.log"
|
||||
REMOTE_LOG="${LOG_DIR}/${RUN_ID}_node2_rank1.log"
|
||||
|
||||
SSH_OPTS=(
|
||||
-F /dev/null
|
||||
-o StrictHostKeyChecking=no
|
||||
-o ServerAliveInterval=30
|
||||
-o ServerAliveCountMax=20
|
||||
)
|
||||
|
||||
COMMON_ARGS=(
|
||||
"--config-name=${CONFIG_NAME}"
|
||||
"policy=${POLICY}"
|
||||
"multi_node.enabled=true"
|
||||
"output.filename=${OUTPUT_FILENAME}"
|
||||
)
|
||||
|
||||
if [[ "${PRELOAD_WAIT}" == "1" ]]; then
|
||||
COMMON_ARGS+=(
|
||||
"preload_wait.enabled=true"
|
||||
"preload_wait.file=${PRELOAD_SIGNAL_FILE}"
|
||||
)
|
||||
fi
|
||||
|
||||
if [[ -n "${EXTRA_ARGS}" ]]; then
|
||||
# shellcheck disable=SC2206
|
||||
COMMON_ARGS+=(${EXTRA_ARGS})
|
||||
fi
|
||||
|
||||
# Print the shell command line that starts one node's torchrun worker group.
#
# Arguments:
#   $1 - node rank, passed to torchrun as --node_rank.
# Globals read:
#   REPO_ROOT, CUDA_VISIBLE_DEVICES, STABLEWM_HOME, COMMON_ARGS, NNODES,
#   NPROC_PER_NODE, MASTER_ADDR, MASTER_PORT.
# Output:
#   The command on stdout, without a trailing newline. Every interpolated
#   value is %q-quoted so the result can be passed to `bash -lc`.
make_command() {
  local node_rank="$1"
  local repo_q cuda_q stablewm_q arg_q eval_args
  # Declared local so the loop variable does not leak into the caller's scope.
  local arg

  printf -v repo_q '%q' "${REPO_ROOT}"
  printf -v cuda_q '%q' "${CUDA_VISIBLE_DEVICES}"
  printf -v stablewm_q '%q' "${STABLEWM_HOME}"

  # Each eval.py argument is quoted on its own and appended with a leading
  # space, so the final "%s" slot attaches directly after "eval.py".
  eval_args=""
  for arg in "${COMMON_ARGS[@]}"; do
    printf -v arg_q '%q' "${arg}"
    eval_args+=" ${arg_q}"
  done

  printf 'cd %s && source .venv/bin/activate && export CUDA_VISIBLE_DEVICES=%s && export STABLEWM_HOME=%s && torchrun --nnodes=%q --nproc_per_node=%q --node_rank=%q --master_addr=%q --master_port=%q eval.py%s' \
    "${repo_q}" \
    "${cuda_q}" \
    "${stablewm_q}" \
    "${NNODES}" \
    "${NPROC_PER_NODE}" \
    "${node_rank}" \
    "${MASTER_ADDR}" \
    "${MASTER_PORT}" \
    "${eval_args}"
}
|
||||
|
||||
REMOTE_CMD="$(make_command 1)"
|
||||
LOCAL_CMD="$(make_command 0)"
|
||||
printf -v REMOTE_CMD_Q '%q' "${REMOTE_CMD}"
|
||||
|
||||
REMOTE_PID=""
|
||||
LOCAL_PID=""
|
||||
LOCAL_TAIL_PID=""
|
||||
REMOTE_TAIL_PID=""
|
||||
REMOTE_CLEANUP_CMD=""
|
||||
REMOTE_CLEANUP_CMD_Q=""
|
||||
|
||||
# Follow a log file in the background, prefixing every line with "[label] ".
#
# Arguments:
#   $1 - label shown in the prefix.
#   $2 - path of the log file to follow.
# The pipeline is started through setsid as the last background job of this
# function, so the caller can read its PID from $! right after the call.
start_log_tail() {
  local label="$1" log_file="$2"
  local label_q log_q tail_pipeline

  printf -v label_q '%q' "${label}"
  printf -v log_q '%q' "${log_file}"

  # -F keeps retrying until the file exists; -n +1 prints it from the start.
  tail_pipeline="tail -n +1 -F ${log_q} 2>/dev/null | sed -u 's/^/[${label_q}] /'"
  setsid bash -lc "${tail_pipeline}" &
}
|
||||
|
||||
# Terminate the background log followers recorded in LOCAL_TAIL_PID and
# REMOTE_TAIL_PID. Unset or already-dead PIDs are skipped; failures to
# signal are ignored.
stop_log_tails() {
  local pid
  for pid in "${LOCAL_TAIL_PID}" "${REMOTE_TAIL_PID}"; do
    [[ -n "${pid}" ]] || continue
    kill -0 "${pid}" 2>/dev/null || continue
    # Try the process group first ("-PID"), then the single process.
    kill -TERM "-${pid}" 2>/dev/null || kill -TERM "${pid}" 2>/dev/null || true
  done
}
|
||||
|
||||
# Print a shell snippet that kills leftover eval processes on the remote node.
#
# Globals read: MASTER_ADDR, MASTER_PORT, OUTPUT_FILENAME.
# Output: a single-line command list on stdout. It sends TERM to every process
#   whose command line matches one of the patterns, waits two seconds, sends
#   KILL to the same patterns, and always ends successfully (`true`).
remote_cleanup_command() {
  local pattern_q
  # Declared local so the loop variable does not leak into the caller's scope.
  local pattern
  local patterns=(
    "torchrun .*--master_addr=${MASTER_ADDR} .*--master_port=${MASTER_PORT} .*eval.py"
    "torchrun .*--master_port=${MASTER_PORT} .*eval.py"
    "python.*eval.py .*output.filename=${OUTPUT_FILENAME}"
  )

  # pkill exits non-zero when nothing matches; the snippet must keep going.
  printf 'set +e; '
  for pattern in "${patterns[@]}"; do
    printf -v pattern_q '%q' "${pattern}"
    printf 'pkill -TERM -f %s 2>/dev/null; ' "${pattern_q}"
  done
  printf 'sleep 2; '
  for pattern in "${patterns[@]}"; do
    printf -v pattern_q '%q' "${pattern}"
    printf 'pkill -KILL -f %s 2>/dev/null; ' "${pattern_q}"
  done
  printf 'true'
}
|
||||
|
||||
# Exit/signal handler: stop the local and remote workers when the launcher
# ends abnormally.
#
# Arguments:
#   $1 - optional exit status to report. Defaults to $? (the status of the
#        last command), which is what the EXIT trap relies on.
# Globals read: LOCAL_PID, REMOTE_PID, REMOTE_HOST, SSH_OPTS,
#   REMOTE_CLEANUP_CMD_Q, LOCAL_LOG, REMOTE_LOG.
# A zero status means a normal exit and nothing is torn down; otherwise the
# function exits the script with the given status.
cleanup() {
  local status="${1:-$?}"
  trap - INT TERM EXIT

  if [[ "${status}" -eq 0 ]]; then
    return 0
  fi

  echo
  echo "Stopping multi-node eval..."
  stop_log_tails

  if [[ -n "${LOCAL_PID}" ]] && kill -0 "${LOCAL_PID}" 2>/dev/null; then
    # Try the process group first ("-PID"), then the single process.
    kill -TERM "-${LOCAL_PID}" 2>/dev/null || kill -TERM "${LOCAL_PID}" 2>/dev/null || true
  fi

  if [[ -n "${REMOTE_PID}" ]] && kill -0 "${REMOTE_PID}" 2>/dev/null; then
    kill -TERM "${REMOTE_PID}" 2>/dev/null || true
  fi

  # The remote cleanup command is only built after the traps are installed;
  # skip the ssh round-trip if the handler fires before it exists.
  if [[ -n "${REMOTE_CLEANUP_CMD_Q}" ]]; then
    ssh "${SSH_OPTS[@]}" "${REMOTE_HOST}" "bash -lc ${REMOTE_CLEANUP_CMD_Q}" >/dev/null 2>&1 || true
  fi

  if [[ -n "${LOCAL_PID}" ]] && kill -0 "${LOCAL_PID}" 2>/dev/null; then
    sleep 2
    kill -KILL "-${LOCAL_PID}" 2>/dev/null || kill -KILL "${LOCAL_PID}" 2>/dev/null || true
  fi

  echo "Cleanup requested. Check logs if any process was already exiting:"
  echo "  local:  ${LOCAL_LOG}"
  echo "  remote: ${REMOTE_LOG}"
  exit "${status}"
}

# Signal traps pass an explicit non-zero status (128 + signal number). A
# signal that arrives while a foreground command goes on to finish with
# status 0 would otherwise look like a normal exit and skip the teardown.
trap 'cleanup 130' INT
trap 'cleanup 143' TERM
trap cleanup EXIT
|
||||
|
||||
REMOTE_CLEANUP_CMD="$(remote_cleanup_command)"
|
||||
printf -v REMOTE_CLEANUP_CMD_Q '%q' "${REMOTE_CLEANUP_CMD}"
|
||||
|
||||
echo "Launching multi-node eval"
|
||||
echo " master: ${MASTER_ADDR}:${MASTER_PORT}"
|
||||
echo " remote: ${REMOTE_HOST}"
|
||||
echo " repo: ${REPO_ROOT}"
|
||||
echo " stablewm: ${STABLEWM_HOME}"
|
||||
echo " config: ${CONFIG_NAME}"
|
||||
echo " policy: ${POLICY}"
|
||||
echo " output: ${OUTPUT_FILENAME}"
|
||||
echo " extra: ${EXTRA_ARGS:-<none>}"
|
||||
echo " tail logs: ${TAIL_LOGS}"
|
||||
echo " preload wait: ${PRELOAD_WAIT}"
|
||||
if [[ "${PRELOAD_WAIT}" == "1" ]]; then
|
||||
echo " preload signal: ${PRELOAD_SIGNAL_FILE}"
|
||||
echo " start command: touch ${PRELOAD_SIGNAL_FILE}"
|
||||
fi
|
||||
echo " local log: ${LOCAL_LOG}"
|
||||
echo " remote log: ${REMOTE_LOG}"
|
||||
|
||||
if [[ "${DRY_RUN}" == "1" ]]; then
|
||||
echo
|
||||
echo "Remote command:"
|
||||
echo "ssh ${SSH_OPTS[*]} ${REMOTE_HOST} bash -lc ${REMOTE_CMD_Q}"
|
||||
echo
|
||||
echo "Local command:"
|
||||
printf -v LOCAL_CMD_Q '%q' "${LOCAL_CMD}"
|
||||
echo "bash -lc ${LOCAL_CMD_Q}"
|
||||
exit 0
|
||||
fi
|
||||
|
||||
if [[ "${PRELOAD_WAIT}" == "1" && "${PRELOAD_CLEAR_SIGNAL}" == "1" ]]; then
|
||||
rm -f "${PRELOAD_SIGNAL_FILE}"
|
||||
fi
|
||||
|
||||
echo "Starting remote node_rank=1..."
|
||||
ssh "${SSH_OPTS[@]}" "${REMOTE_HOST}" "bash -lc ${REMOTE_CMD_Q}" >"${REMOTE_LOG}" 2>&1 &
|
||||
REMOTE_PID="$!"
|
||||
|
||||
if [[ "${TAIL_LOGS}" == "1" ]]; then
|
||||
start_log_tail "node2" "${REMOTE_LOG}"
|
||||
REMOTE_TAIL_PID="$!"
|
||||
fi
|
||||
|
||||
sleep 3
|
||||
|
||||
echo "Starting local node_rank=0..."
|
||||
set +e
|
||||
setsid bash -lc "${LOCAL_CMD}" >"${LOCAL_LOG}" 2>&1 &
|
||||
LOCAL_PID="$!"
|
||||
|
||||
if [[ "${TAIL_LOGS}" == "1" ]]; then
|
||||
start_log_tail "node3" "${LOCAL_LOG}"
|
||||
LOCAL_TAIL_PID="$!"
|
||||
fi
|
||||
|
||||
wait "${LOCAL_PID}"
|
||||
LOCAL_STATUS="$?"
|
||||
|
||||
wait "${REMOTE_PID}"
|
||||
REMOTE_STATUS="$?"
|
||||
set -e
|
||||
|
||||
stop_log_tails
|
||||
trap - INT TERM EXIT
|
||||
|
||||
echo "Local status: ${LOCAL_STATUS}"
|
||||
echo "Remote status: ${REMOTE_STATUS}"
|
||||
echo "Local log: ${LOCAL_LOG}"
|
||||
echo "Remote log: ${REMOTE_LOG}"
|
||||
|
||||
if [[ "${LOCAL_STATUS}" -ne 0 || "${REMOTE_STATUS}" -ne 0 ]]; then
|
||||
echo "Multi-node eval failed. Tail logs:"
|
||||
echo "===== local tail ====="
|
||||
tail -80 "${LOCAL_LOG}" || true
|
||||
echo "===== remote tail ====="
|
||||
tail -80 "${REMOTE_LOG}" || true
|
||||
exit 1
|
||||
fi
|
||||
|
||||
echo "Multi-node eval complete."
|
||||
111
scripts/warmup_eval.sh
Executable file
111
scripts/warmup_eval.sh
Executable file
@@ -0,0 +1,111 @@
|
||||
#!/usr/bin/env bash
|
||||
set -euo pipefail
|
||||
|
||||
# Warm up LeWM evaluation before a formal run.
|
||||
#
|
||||
# This script intentionally does a small eval for each task so ROCm/PyTorch can
|
||||
# initialize GPU contexts, compile predictor graphs, populate kernel caches, and
|
||||
# touch dataset/checkpoint paths before the timed run.
|
||||
#
|
||||
# Site-specific things to check before using this at the competition:
|
||||
# 1. STABLEWM_HOME points to the directory containing datasets/checkpoints.
|
||||
# 2. The policy names below match the checkpoint folders at STABLEWM_HOME.
|
||||
# 3. The dataset names in config/eval/*.yaml match the onsite dataset files.
|
||||
# 4. The GPU visibility variables match the GPUs allocated to this job.
|
||||
# 5. WARMUP_NUM_EVAL is close enough to the formal shape to trigger useful
|
||||
# compilation, but small enough not to waste much time.
|
||||
|
||||
REPO_ROOT="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)"
|
||||
cd "${REPO_ROOT}"
|
||||
|
||||
PYTHON_BIN="${PYTHON_BIN:-${REPO_ROOT}/.venv/bin/python}"
|
||||
STABLEWM_HOME="${STABLEWM_HOME:-/mnt/ASC1637/stablewm}"
|
||||
export STABLEWM_HOME
|
||||
|
||||
# If Slurm allocates multiple GPUs, set these to the allocated physical GPU ids.
|
||||
# Example for physical GPU 2 and 3:
|
||||
# ROCR_VISIBLE_DEVICES=2,3 HIP_VISIBLE_DEVICES=0,1 CUDA_VISIBLE_DEVICES=0,1
|
||||
#
|
||||
# Important ROCm detail:
|
||||
# ROCR_VISIBLE_DEVICES uses physical ids.
|
||||
# HIP_VISIBLE_DEVICES/CUDA_VISIBLE_DEVICES use ids after ROCR remapping.
|
||||
export ROCR_VISIBLE_DEVICES="${ROCR_VISIBLE_DEVICES:-0}"
|
||||
export HIP_VISIBLE_DEVICES="${HIP_VISIBLE_DEVICES:-0}"
|
||||
export CUDA_VISIBLE_DEVICES="${CUDA_VISIBLE_DEVICES:-0}"
|
||||
|
||||
WARMUP_NUM_EVAL="${WARMUP_NUM_EVAL:-10}"
|
||||
INFERENCE_PRECISION="${INFERENCE_PRECISION:-fp16}"
|
||||
OUTPUT_DIR="${OUTPUT_DIR:-/tmp/lewm_warmup}"
|
||||
mkdir -p "${OUTPUT_DIR}"
|
||||
|
||||
# Enable multi-GPU warmup by setting MULTI_GPU=1.
|
||||
# MULTI_GPU_DEVICES are process-local ids, not physical ids after ROCR remapping.
|
||||
# Example:
|
||||
# ROCR_VISIBLE_DEVICES=2,3 HIP_VISIBLE_DEVICES=0,1 MULTI_GPU=1 MULTI_GPU_DEVICES='[0,1]'
|
||||
MULTI_GPU="${MULTI_GPU:-0}"
|
||||
MULTI_GPU_DEVICES="${MULTI_GPU_DEVICES:-[0,1]}"
|
||||
MULTI_NODE="${MULTI_NODE:-0}"
|
||||
|
||||
# Multi-node warmup uses the same eval.py entrypoint under torchrun.
|
||||
# Example:
|
||||
# torchrun --nnodes=2 --nproc_per_node=2 --node_rank=0 --master_addr=<ip> --master_port=29500 \
|
||||
# eval.py --config-name=pusht.yaml policy=pusht/lewm multi_node.enabled=true
|
||||
# This script leaves multi-node launch to the caller.
|
||||
|
||||
COMMON_ARGS=(
|
||||
"eval.num_eval=${WARMUP_NUM_EVAL}"
|
||||
"inference_precision=${INFERENCE_PRECISION}"
|
||||
)
|
||||
|
||||
if [[ "${MULTI_GPU}" == "1" ]]; then
|
||||
COMMON_ARGS+=(
|
||||
"+multi_gpu.enabled=true"
|
||||
"+multi_gpu.devices=${MULTI_GPU_DEVICES}"
|
||||
)
|
||||
fi
|
||||
|
||||
if [[ "${MULTI_NODE}" == "1" ]]; then
|
||||
COMMON_ARGS+=(
|
||||
"multi_node.enabled=true"
|
||||
)
|
||||
fi
|
||||
|
||||
# Run one short warmup evaluation.
#
# Arguments:
#   $1 - Hydra config name (e.g. "pusht.yaml").
#   $2 - policy override value.
#   $3 - results file name, created under OUTPUT_DIR.
# Globals read: PYTHON_BIN, OUTPUT_DIR, COMMON_ARGS.
run_warmup() {
  local config_name="$1" policy="$2" output_name="$3"
  local -a eval_args=(
    "--config-name=${config_name}"
    "policy=${policy}"
    "output.filename=${OUTPUT_DIR}/${output_name}"
    "${COMMON_ARGS[@]}"
  )

  printf '\n== Warmup %s policy=%s ==\n' "${config_name}" "${policy}"
  "${PYTHON_BIN}" eval.py "${eval_args[@]}"
}
|
||||
|
||||
echo "LeWM warmup"
|
||||
echo " repo: ${REPO_ROOT}"
|
||||
echo " python: ${PYTHON_BIN}"
|
||||
echo " STABLEWM_HOME: ${STABLEWM_HOME}"
|
||||
echo " ROCR_VISIBLE_DEVICES: ${ROCR_VISIBLE_DEVICES}"
|
||||
echo " HIP_VISIBLE_DEVICES: ${HIP_VISIBLE_DEVICES}"
|
||||
echo " CUDA_VISIBLE_DEVICES: ${CUDA_VISIBLE_DEVICES}"
|
||||
echo " WARMUP_NUM_EVAL: ${WARMUP_NUM_EVAL}"
|
||||
echo " INFERENCE_PRECISION: ${INFERENCE_PRECISION}"
|
||||
echo " MULTI_GPU: ${MULTI_GPU}"
|
||||
if [[ "${MULTI_GPU}" == "1" ]]; then
|
||||
echo " MULTI_GPU_DEVICES: ${MULTI_GPU_DEVICES}"
|
||||
fi
|
||||
echo " MULTI_NODE: ${MULTI_NODE}"
|
||||
|
||||
# Defaults match the checkpoint names used in this repo. If onsite checkpoint
|
||||
# folders differ, override by editing these calls or passing the equivalent
|
||||
# eval.py command manually.
|
||||
run_warmup "pusht.yaml" "pusht/lewm" "warmup_pusht.txt"
|
||||
run_warmup "reacher.yaml" "reacher/lewm" "warmup_reacher.txt"
|
||||
run_warmup "cube.yaml" "cube/lewm" "warmup_cube.txt"
|
||||
run_warmup "tworoom.yaml" "tworoom/lewm" "warmup_tworoom.txt"
|
||||
|
||||
echo
|
||||
echo "Warmup complete. Logs were appended under ${OUTPUT_DIR}."
|
||||
52
sth.md
Normal file
52
sth.md
Normal file
@@ -0,0 +1,52 @@
|
||||
我建议优先做这 4 类,都是跨数据集成立的:
|
||||
|
||||
1. 压 rollout 内环实现
|
||||
见 jepa.py:127。现在每步都在做 action_encoder、切片、torch.cat、小规模 predict 调用,这种碎片化实现对任何任务都亏。
|
||||
通用改法:
|
||||
|
||||
- 整条 action_sequence 一次性做 action_encoder
|
||||
- emb_hist / act_emb_hist 改成预分配 buffer
|
||||
- 循环里只做索引覆盖或 copy_
|
||||
- 去掉循环内 torch.cat
|
||||
|
||||
2. 减少热路径里的搬运和同步
|
||||
profile 里 aten::copy_ 很重,这不是 TwoRoom 特有问题。重点看
|
||||
jepa.py:67 和 jepa.py:186。
|
||||
通用目标:
|
||||
|
||||
- 模型侧张量尽量全程留在 GPU
|
||||
- 避免热路径反复 .to(device) / 隐式 layout 修复
|
||||
- 到必须和环境交互的边界再一次性转 CPU / numpy
|
||||
- 确保进入 predictor 的张量是 contiguous 的,少触发隐式 copy
|
||||
|
||||
3. 把编译成本移出正式计时
|
||||
现在 torch.compile 默认开在 predictor,见 eval.py:70。102s -> 45s 很像首轮编译预热。
|
||||
通用做法:
|
||||
|
||||
- 在正式 start_time 前做一次 dummy predict 或 dummy rollout
|
||||
- 保留只编译 predictor/predict,不要编译整个 solver
|
||||
|
||||
4. 减少临时对象和 shape bookkeeping
|
||||
这是所有任务都会受益的。
|
||||
重点看:
|
||||
|
||||
- jepa.py:100 到 jepa.py:106
|
||||
- jepa.py:143 到 jepa.py:148
|
||||
方向是:
|
||||
- 能循环外做的 reshape,不放循环里
|
||||
- 能原地更新,不新建张量
|
||||
- 少做 dict 字段增删和中间容器组装
|
||||
|
||||
不建议优先做的通用性较差方案:
|
||||
|
||||
- 调 TwoRoom 专属 cache 规则
|
||||
- 改数据集采样逻辑
|
||||
- 按小数据集特点缩短 horizon
|
||||
- 直接改 CEM 超参当“优化”
|
||||
|
||||
如果你要我直接开始改,我建议第一批只做两件事:
|
||||
|
||||
- 重写 jepa.py:127 这段 rollout,去掉循环内 action_encoder + cat
|
||||
- 在 eval.py:306 前加 compile warmup
|
||||
Reference in New Issue
Block a user