Compare commits

32 Commits
main ... qhy

Author SHA1 Message Date
qihuanye
a639fdefca 没啥用 2026-05-18 02:09:19 +08:00
qihuanye
28f2fba0e8 加入一个提前停止的机制 还有减少环境步中间步骤传递至cpu 2026-05-18 00:48:59 +08:00
qihuanye
113e591899 多机调整 2026-05-17 20:49:33 +08:00
qihuanye
0164e21f48 多机 2026-05-17 19:23:31 +08:00
qihuanye
02080e2564 在正式测试前添加warm up 2026-05-16 14:53:58 +00:00
qihuanye
d86aeb2df0 修改默认求解器 2026-05-14 08:53:10 +00:00
qihuanye
5e55727901 增加脚本 2026-05-14 04:27:10 +00:00
qihuanye
02c3cea3f9 amd构建说明 2026-05-14 03:52:50 +00:00
qihuanye
f08f2b82f4 Parameter Tuning 2026-05-04 08:20:23 +00:00
qihuanye
e84074d6d6 调整ignore的文件 2026-05-04 08:05:47 +00:00
qihuanye
cf43af0729 更改求解器 step=150时 成功率更高 step=125时 速度更快 成功率持平 2026-05-04 07:55:13 +00:00
qihuanye
4c3fdbcce6 调参 2026-05-04 07:01:33 +00:00
qihuanye
75a5d86966 调整配置 启动视频写入开关 2026-04-10 03:44:09 +00:00
qihuanye
46cb2177bc pusht数据集配置,优化后测试结果 2026-04-10 03:40:20 +00:00
qihuanye
8ba5bc8b0b 多卡 2026-04-10 03:13:54 +00:00
qihuanye
e6f2b2b9d4 调高batch_size 2026-04-09 13:17:39 +00:00
qihuanye
25e4ddb628 继续做了通用性能优化,重点从 jepa.py 热路径转到实际的 stable_worldmodel
solver/policy 边界:去掉 CEM 每轮 cpu().tolist() 和结果过早回 CPU,把
  plan/warm-start 保持在 GPU,只在 env.step 前最后一步转成 numpy,同时补
  了输入张量的 contiguous 处理;
2026-04-09 12:33:50 +00:00
qihuanye
995cd8cfec 优化 jepa.py 中通用 rollout 热路径:批量预编码动
作、移除循环内
  torch.cat,并为 history_size==1 与环形缓冲区更新
  添加更轻量实现; 收益不大
2026-04-09 11:57:09 +00:00
qihuanye
cd03a0d5cb 补充结果 2026-04-09 11:11:07 +00:00
qihuanye
20ffb3492b Disable Gym passive checker by default in stable_worldmodel env creation 2026-04-09 11:11:07 +00:00
qihuanye
96e17a13af 补充结果 2026-04-09 11:11:07 +00:00
qihuanye
006102d00c 减少循环里的张量形状重排和临时对象 2026-04-09 11:11:07 +00:00
qihuanye
3a94829eac 补充评测结果 2026-04-09 11:11:07 +00:00
qihuanye
38be7d3bef Optimize inference path: add predictor-only torch.compile with reduce-overhead 2026-04-09 11:11:07 +00:00
qihuanye
f2750daace 取消视频保存 2026-04-09 11:11:07 +00:00
qihuanye
9e2407cdc4 Wrap eval inference in torch.inference_mode 2026-04-09 11:11:07 +00:00
qihuanye
0f85e39690 Reduce evaluation overhead with parallel video saving 2026-04-08 13:56:57 +00:00
qihuanye
85795bd91d Vectorize image preprocessing in stable_worldmodel policy 2026-04-08 13:48:19 +00:00
qihuanye
7c2e341d93 fp16 2026-04-08 13:40:33 +00:00
qihuanye
12ba4f4352 Optimize CEM input transfers before sample expansion 2026-04-08 13:01:24 +00:00
qihuanye
fa1c15c896 Optimize JEPA eval outputs and inference hot path 2026-04-08 12:41:21 +00:00
qihuanye
8b84251eb9 add profile frame and bf15/fp16 switch 2026-03-31 11:09:02 +00:00
28 changed files with 6938 additions and 156 deletions

34
.gitignore vendored Normal file
View File

@@ -0,0 +1,34 @@
.venv/*
!.venv/.gitignore
!.venv/lib/
.venv/lib/*
!.venv/lib/python3.10/
.venv/lib/python3.10/*
!.venv/lib/python3.10/site-packages/
.venv/lib/python3.10/site-packages/*
!.venv/lib/python3.10/site-packages/stable_worldmodel/
.venv/lib/python3.10/site-packages/stable_worldmodel/*
!.venv/lib/python3.10/site-packages/stable_worldmodel/solver/
!.venv/lib/python3.10/site-packages/stable_worldmodel/policy.py
!.venv/lib/python3.10/site-packages/stable_worldmodel/world.py
!.venv/lib/python3.10/site-packages/stable_worldmodel/solver/cem.py
!.venv/lib/python3.10/site-packages/stable_worldmodel/solver/gd.py
outputs/
__pycache__/
*.py[cod]
.pytest_cache/
.mypy_cache/
.ruff_cache/
torch_profile/
trace.json
key_averages.txt
eval_tmp_*.npy
*.mp4
*.gif
.DS_Store
.idea/
.vscode/
*.log

0
.venv/.gitignore vendored Normal file
View File

View File

@@ -0,0 +1,628 @@
from collections import deque
from dataclasses import dataclass
import inspect
from pathlib import Path
from typing import Any, Protocol
from collections.abc import Callable
import numpy as np
import torch
from loguru import logger as logging
import stable_worldmodel as swm
from stable_worldmodel.solver import Solver
@dataclass(frozen=True)
class PlanConfig:
"""Configuration for the MPC planning loop.
Attributes:
horizon: Planning horizon in number of steps.
receding_horizon: Number of steps to execute before re-planning.
history_len: Number of past observations to consider.
action_block: Number of times each action is repeated (frameskip).
warm_start: Whether to use the previous plan to initialize the next one.
"""
horizon: int
receding_horizon: int
history_len: int = 1
action_block: int = 1
warm_start: bool = True
@property
def plan_len(self) -> int:
"""Total plan length in environment steps."""
return self.horizon * self.action_block
class Transformable(Protocol):
"""Protocol for reversible data transformations (e.g., normalizers, scalers)."""
def transform(self, x: np.ndarray) -> np.ndarray: # pragma: no cover
"""Apply preprocessing to input data.
Args:
x: Input data as a numpy array.
Returns:
Preprocessed data as a numpy array.
"""
...
def inverse_transform(
self, x: np.ndarray
) -> np.ndarray: # pragma: no cover
"""Reverse the preprocessing transformation.
Args:
x: Preprocessed data as a numpy array.
Returns:
Original data as a numpy array.
"""
...
class Actionable(Protocol):
"""Protocol for model action computation."""
def get_action(info) -> torch.Tensor: # pragma: no cover
"""Compute action from observation and goal"""
...
class BasePolicy:
"""Base class for agent policies.
Attributes:
env: The environment the policy is associated with.
type: A string identifier for the policy type.
"""
env: Any
type: str
def __init__(self, **kwargs: Any) -> None:
"""Initialize the base policy.
Args:
**kwargs: Additional configuration parameters.
"""
self.env = None
self.type = 'base'
for arg, value in kwargs.items():
setattr(self, arg, value)
def get_action(self, obs: Any, **kwargs: Any) -> np.ndarray:
"""Get action from the policy given the observation.
Args:
obs: The current observation from the environment.
**kwargs: Additional parameters for action selection.
Returns:
Selected action as a numpy array.
Raises:
NotImplementedError: If not implemented by a subclass.
"""
raise NotImplementedError
def set_env(self, env: Any) -> None:
"""Associate this policy with an environment.
Args:
env: The environment to associate.
"""
self.env = env
def _move_info_to_device(
self, info_dict: dict[str, Any], device: torch.device | str
) -> dict[str, Any]:
target = torch.device(device)
for k, v in info_dict.items():
if torch.is_tensor(v):
if v.device != target:
v = v.to(target, non_blocking=True)
if not v.is_contiguous():
v = v.contiguous()
info_dict[k] = v
return info_dict
def _prepare_info_for_device(
self,
info_dict: dict,
device: torch.device | str | None = None,
) -> dict[str, torch.Tensor]:
if device is None:
return self._prepare_info(info_dict)
prepare_info = self._prepare_info
try:
signature = inspect.signature(prepare_info)
except (TypeError, ValueError):
return prepare_info(info_dict)
accepts_device = any(
param.kind == inspect.Parameter.VAR_KEYWORD
or (param.name == "device" and param.kind in {
inspect.Parameter.POSITIONAL_OR_KEYWORD,
inspect.Parameter.KEYWORD_ONLY,
})
for param in signature.parameters.values()
)
if accepts_device:
return prepare_info(info_dict, device=device)
return prepare_info(info_dict)
def _prepare_info(
self,
info_dict: dict,
device: torch.device | str | None = None,
) -> dict[str, torch.Tensor]:
"""Pre-process and transform observations.
Applies preprocessing (via `self.process`) and transformations (via `self.transform`)
to observation data. Used by subclasses like FeedForwardPolicy and WorldModelPolicy.
Args:
info_dict: Raw observation dictionary from the environment.
Returns:
A dictionary of processed tensors.
Raises:
ValueError: If an expected numpy array is missing for processing.
"""
target_device = torch.device(device) if device is not None else None
for k, v in info_dict.items():
is_numpy = isinstance(v, (np.ndarray | np.generic))
if hasattr(self, 'process') and k in self.process:
if not is_numpy:
raise ValueError(
f"Expected numpy array for key '{k}' in process, got {type(v)}"
)
# flatten extra dimensions if needed
shape = v.shape
if len(shape) > 2:
v = v.reshape(-1, *shape[2:])
# process and reshape back
v = self.process[k].transform(v)
v = v.reshape(shape)
# collapse env and time dimensions for transform (e, t, ...) -> (e * t, ...)
# then restore after transform
if hasattr(self, 'transform') and k in self.transform:
shape = None
if is_numpy or torch.is_tensor(v):
if v.ndim > 2:
shape = v.shape
v = v.reshape(-1, *shape[2:])
if is_numpy:
if v.dtype.kind in 'USO':
raise ValueError(
f"Expected numeric numpy array for key '{k}', got dtype {v.dtype}"
)
v = torch.from_numpy(v)
is_numpy = False
moved_for_transform = False
if target_device is not None and target_device.type != "cpu":
if v.device != target_device:
v = v.to(target_device, non_blocking=True)
moved_for_transform = True
if k.startswith('pixels') or k.startswith('goal'):
# Vectorized image transform on the full batch.
v = v.permute(0, 3, 1, 2)
try:
v = self.transform[k](v)
except (NotImplementedError, RuntimeError):
if not moved_for_transform:
raise
v = self.transform[k](v.cpu())
if shape is not None:
v = v.reshape(*shape[:2], *v.shape[1:])
if is_numpy and v.dtype.kind not in 'USO':
v = torch.from_numpy(v)
if (
torch.cuda.is_available()
and torch.is_tensor(v)
and v.device.type == "cpu"
and not v.is_pinned()
):
v = v.pin_memory()
info_dict[k] = v
return info_dict
class RandomPolicy(BasePolicy):
"""Policy that samples random actions from the action space."""
def __init__(self, seed: int | None = None, **kwargs: Any) -> None:
"""Initialize the random policy.
Args:
seed: Optional random seed for the action space.
**kwargs: Additional configuration parameters.
"""
super().__init__(**kwargs)
self.type = 'random'
self.seed = seed
def get_action(self, obs: Any, **kwargs: Any) -> np.ndarray:
"""Get a random action from the environment's action space.
Args:
obs: The current observation (ignored).
**kwargs: Additional parameters (ignored).
Returns:
A randomly sampled action.
"""
return self.env.action_space.sample()
def set_seed(self, seed: int) -> None:
"""Set the random seed for action sampling.
Args:
seed: The seed value.
"""
if self.env is not None:
self.env.action_space.seed(seed)
class ExpertPolicy(BasePolicy):
"""Policy using expert demonstrations or heuristics."""
def __init__(self, **kwargs: Any) -> None:
"""Initialize the expert policy.
Args:
**kwargs: Additional configuration parameters.
"""
super().__init__(**kwargs)
self.type = 'expert'
def get_action(
self, obs: Any, goal_obs: Any, **kwargs: Any
) -> np.ndarray | None:
"""Get action from the expert policy.
Args:
obs: The current observation.
goal_obs: The goal observation.
**kwargs: Additional parameters.
Returns:
The expert action, or None if not available.
"""
# Implement expert policy logic here
pass
class FeedForwardPolicy(BasePolicy):
"""Feed-Forward Policy using a neural network model.
Actions are computed via a single forward pass through the model.
Useful for imitation learning policies like Goal-Conditioned Behavioral Cloning (GCBC).
Attributes:
model: Neural network model implementing the Actionable protocol.
process: Dictionary of data preprocessors for specific keys.
transform: Dictionary of tensor transformations (e.g., image transforms).
"""
def __init__(
self,
model: Actionable,
process: dict[str, Transformable] | None = None,
transform: dict[str, Callable[[torch.Tensor], torch.Tensor]]
| None = None,
**kwargs: Any,
) -> None:
"""Initialize the feed-forward policy.
Args:
model: Neural network model with a `get_action` method.
process: Dictionary of data preprocessors for specific keys.
transform: Dictionary of tensor transformations (e.g., image transforms).
**kwargs: Additional configuration parameters.
"""
super().__init__(**kwargs)
self.type = 'feed_forward'
self.model = model.eval()
self.process = process or {}
self.transform = transform or {}
def get_action(self, info_dict: dict, **kwargs: Any) -> np.ndarray:
"""Get action via a forward pass through the neural network model.
Args:
info_dict: Current state information containing at minimum a 'goal' key.
**kwargs: Additional parameters (unused).
Returns:
The selected action as a numpy array.
Raises:
AssertionError: If environment not set or 'goal' not in info_dict.
"""
assert hasattr(self, 'env'), 'Environment not set for the policy'
assert 'goal' in info_dict, "'goal' must be provided in info_dict"
# Prepare the info dict (transforms and normalizes inputs)
info_dict = self._prepare_info_for_device(
info_dict, device=next(self.model.parameters()).device
)
# Add goal_pixels key for GCBC model
if 'goal' in info_dict:
info_dict['goal_pixels'] = info_dict['goal']
# Move all tensors to the model's device
device = next(self.model.parameters()).device
info_dict = self._move_info_to_device(info_dict, device)
# Get action from model
with torch.no_grad():
action = self.model.get_action(info_dict)
# Convert to numpy
if torch.is_tensor(action):
action = action.cpu().detach().numpy()
# post-process action
if 'action' in self.process:
action = self.process['action'].inverse_transform(action)
return action
class WorldModelPolicy(BasePolicy):
"""Policy using a world model and planning solver for action selection."""
def __init__(
self,
solver: Solver,
config: PlanConfig,
process: dict[str, Transformable] | None = None,
transform: dict[str, Callable[[torch.Tensor], torch.Tensor]]
| None = None,
**kwargs: Any,
) -> None:
"""Initialize the world model policy.
Args:
solver: The planning solver to use.
config: MPC planning configuration.
process: Dictionary of data preprocessors for specific keys.
transform: Dictionary of tensor transformations (e.g., image transforms).
**kwargs: Additional configuration parameters.
"""
super().__init__(**kwargs)
self.type = 'world_model'
self.cfg = config
self.solver = solver
self.action_buffer: deque[torch.Tensor] = deque(
maxlen=self.flatten_receding_horizon
)
self.process = process or {}
self.transform = transform or {}
self._action_buffer: deque[torch.Tensor] | None = None
self._next_init: torch.Tensor | None = None
@property
def flatten_receding_horizon(self) -> int:
"""Receding horizon in environment steps (with frameskip)."""
return self.cfg.receding_horizon * self.cfg.action_block
def set_env(self, env: Any) -> None:
"""Configure the policy and solver for the given environment.
Args:
env: The environment to associate with the policy.
"""
self.env = env
n_envs = getattr(env, 'num_envs', 1)
self.solver.configure(
action_space=env.action_space, n_envs=n_envs, config=self.cfg
)
self._action_buffer = deque(maxlen=self.flatten_receding_horizon)
assert isinstance(self.solver, Solver), (
'Solver must implement the Solver protocol'
)
def _normalize_active_mask(self, active_mask: Any) -> np.ndarray | None:
if active_mask is None:
return None
active_mask = np.asarray(active_mask, dtype=bool)
if active_mask.ndim != 1 or active_mask.shape[0] != self.env.num_envs:
raise ValueError(
f"active_mask must have shape ({self.env.num_envs},), got {active_mask.shape}"
)
return active_mask
def get_action(self, info_dict: dict, **kwargs: Any) -> np.ndarray:
"""Get action via planning with the world model.
Args:
info_dict: Current state information from the environment.
**kwargs: Additional parameters for planning.
Returns:
The selected action(s) as a numpy array.
"""
assert hasattr(self, 'env'), 'Environment not set for the policy'
active_mask = self._normalize_active_mask(kwargs.get("active_mask"))
# need to replan if action buffer is empty
if len(self._action_buffer) == 0:
assert 'pixels' in info_dict, "'pixels' must be provided in info_dict"
assert 'goal' in info_dict, "'goal' must be provided in info_dict"
info_dict = self._prepare_info_for_device(
info_dict, device=self.solver.device
)
info_dict = self._move_info_to_device(info_dict, self.solver.device)
outputs = self.solver(
info_dict,
init_action=self._next_init,
active_mask=active_mask,
)
actions = outputs['actions'] # (num_envs, horizon, action_dim)
if active_mask is not None and actions.shape[0] != self.env.num_envs:
full_actions = torch.zeros(
self.env.num_envs,
actions.shape[1],
actions.shape[2],
dtype=actions.dtype,
device=actions.device,
)
full_actions[torch.as_tensor(active_mask, device=actions.device)] = actions
actions = full_actions
keep_horizon = self.cfg.receding_horizon
plan = actions[:, :keep_horizon]
rest = actions[:, keep_horizon:]
self._next_init = rest.contiguous() if self.cfg.warm_start else None
# frameskip back to timestep
plan = plan.reshape(
self.env.num_envs, self.flatten_receding_horizon, -1
).contiguous()
self._action_buffer.extend(plan.transpose(0, 1).unbind(0))
action = self._action_buffer.popleft()
action = action.reshape(*self.env.action_space.shape)
if active_mask is not None:
if torch.is_tensor(action):
inactive_mask = torch.as_tensor(
~active_mask, device=action.device, dtype=torch.bool
)
action = action.clone()
action[inactive_mask] = 0
else:
action = np.array(action, copy=True)
action[~active_mask] = 0
if torch.is_tensor(action):
action = action.detach().cpu().numpy()
else:
action = np.asarray(action)
# post-process action
if 'action' in self.process:
action = self.process['action'].inverse_transform(action)
return action # (num_envs, action_dim)
def _load_model_with_attribute(run_name, attribute_name, cache_dir=None):
"""Helper function to load a model checkpoint and find a module with the specified attribute.
Args:
run_name: Path or name of the model run
attribute_name: Name of the attribute to look for in the module (e.g., 'get_action', 'get_cost')
cache_dir: Optional cache directory path
Returns:
The module with the specified attribute
Raises:
RuntimeError: If no module with the specified attribute is found
"""
if Path(run_name).exists():
run_path = Path(run_name)
else:
run_path = Path(cache_dir or swm.data.utils.get_cache_dir(), run_name)
if run_path.is_dir():
ckpt_files = list(run_path.glob('*_object.ckpt'))
ckpt_files.sort(key=lambda x: x.stat().st_ctime, reverse=True)
path = ckpt_files[0]
logging.info(f'Loading model from checkpoint: {path}')
else:
path = Path(f'{run_path}_object.ckpt')
assert path.exists(), (
f'Checkpoint path does not exist: {path}. Launch pretraining first.'
)
spt_module = torch.load(path, weights_only=False, map_location='cpu')
def scan_module(module):
if hasattr(module, attribute_name):
if isinstance(module, torch.nn.Module):
module = module.eval()
return module
for child in module.children():
result = scan_module(child)
if result is not None:
return result
return None
result = scan_module(spt_module)
if result is not None:
return result
raise RuntimeError(
f"No module with '{attribute_name}' found in the loaded world model."
)
def AutoActionableModel(
run_name: str, cache_dir: str | Path | None = None
) -> torch.nn.Module:
"""Load a model checkpoint and return the module with a `get_action` method.
Automatically scans the checkpoint for a module implementing the Actionable
protocol (i.e., has a `get_action` method).
Args:
run_name: Path or name of the model run/checkpoint.
cache_dir: Optional cache directory path. Defaults to STABLEWM_HOME.
Returns:
The module with a `get_action` method, set to eval mode.
Raises:
RuntimeError: If no module with `get_action` is found in the checkpoint.
"""
return _load_model_with_attribute(run_name, 'get_action', cache_dir)
def AutoCostModel(
run_name: str, cache_dir: str | Path | None = None
) -> torch.nn.Module:
"""Load a model checkpoint and return the module with a `get_cost` method.
Automatically scans the checkpoint for a module implementing a cost function
(i.e., has a `get_cost` method) for use with planning solvers.
Args:
run_name: Path or name of the model run/checkpoint.
cache_dir: Optional cache directory path. Defaults to STABLEWM_HOME.
Returns:
The module with a `get_cost` method, set to eval mode.
Raises:
RuntimeError: If no module with `get_cost` is found in the checkpoint.
"""
return _load_model_with_attribute(run_name, 'get_cost', cache_dir)
# Alias for backward compatibility and type hinting
Policy = BasePolicy

View File

@@ -0,0 +1,17 @@
from .cem import CEMSolver
from .gd import GradientSolver
from .icem import ICEMSolver
from .lagrangian import LagrangianSolver
from .mppi import MPPISolver
from .solver import Solver
from .discrete_solvers import PGDSolver
__all__ = [
'Solver',
'GradientSolver',
'CEMSolver',
'ICEMSolver',
'PGDSolver',
'MPPISolver',
'LagrangianSolver',
]

View File

@@ -0,0 +1,277 @@
"""Cross Entropy Method solver for model-based planning."""
import time
from typing import Any
import gymnasium as gym
import numpy as np
import torch
from gymnasium.spaces import Box
from loguru import logger as logging
from .solver import Costable
class CEMSolver:
"""Cross Entropy Method solver for action optimization.
Args:
model: World model implementing the Costable protocol.
batch_size: Number of environments to process in parallel.
num_samples: Number of action candidates to sample per iteration.
var_scale: Initial variance scale for the action distribution.
n_steps: Number of CEM iterations.
topk: Number of elite samples to keep for distribution update.
device: Device for tensor computations.
seed: Random seed for reproducibility.
"""
def __init__(
self,
model: Costable,
batch_size: int = 1,
num_samples: int = 300,
var_scale: float = 1,
n_steps: int = 30,
topk: int = 30,
device: str | torch.device = "cpu",
seed: int = 1234,
) -> None:
self.model = model
self.batch_size = batch_size
self.var_scale = var_scale
self.num_samples = num_samples
self.n_steps = n_steps
self.topk = topk
self.device = device
self.torch_gen = torch.Generator(device=device).manual_seed(seed)
def configure(self, *, action_space: gym.Space, n_envs: int, config: Any) -> None:
"""Configure the solver with environment specifications."""
self._action_space = action_space
self._n_envs = n_envs
self._config = config
self._action_dim = int(np.prod(action_space.shape[1:]))
self._configured = True
if not isinstance(action_space, Box):
logging.warning(f"Action space is discrete, got {type(action_space)}. CEMSolver may not work as expected.")
@property
def n_envs(self) -> int:
"""Number of parallel environments."""
return self._n_envs
@property
def action_dim(self) -> int:
"""Flattened action dimension including action_block grouping."""
return self._action_dim * self._config.action_block
@property
def horizon(self) -> int:
"""Planning horizon in timesteps."""
return self._config.horizon
def __call__(self, *args: Any, **kwargs: Any) -> dict:
"""Make solver callable, forwarding to solve()."""
return self.solve(*args, **kwargs)
@staticmethod
def _normalize_active_mask(
active_mask: torch.Tensor | np.ndarray | None,
n_envs: int,
device: torch.device,
) -> torch.Tensor | None:
if active_mask is None:
return None
if not torch.is_tensor(active_mask):
active_mask = torch.as_tensor(active_mask, dtype=torch.bool, device=device)
else:
active_mask = active_mask.to(device=device, dtype=torch.bool)
if active_mask.ndim != 1 or active_mask.shape[0] != n_envs:
raise ValueError(
f"active_mask must have shape ({n_envs},), got {tuple(active_mask.shape)}"
)
return active_mask
def init_action_distrib(
self, actions: torch.Tensor | None = None
) -> tuple[torch.Tensor, torch.Tensor]:
"""Initialize the action distribution parameters (mean and variance)."""
device = torch.device(self.device)
var = self.var_scale * torch.ones(
[self.n_envs, self.horizon, self.action_dim],
device=device,
)
mean = (
torch.zeros([self.n_envs, 0, self.action_dim], device=device)
if actions is None
else actions
)
remaining = self.horizon - mean.shape[1]
if remaining > 0:
new_mean = torch.zeros(
[self.n_envs, remaining, self.action_dim],
device=mean.device,
)
mean = torch.cat([mean, new_mean], dim=1)
return mean, var
@torch.inference_mode()
def solve(
self,
info_dict: dict,
init_action: torch.Tensor | None = None,
active_mask: torch.Tensor | np.ndarray | None = None,
) -> dict:
"""Solve the planning problem using Cross Entropy Method."""
start_time = time.time()
outputs = {
"costs": [],
"mean": [], # History of means
"var": [], # History of vars
}
# -- initialize the action distribution globally
mean, var = self.init_action_distrib(init_action)
if mean.device != torch.device(self.device):
mean = mean.to(self.device, non_blocking=True)
if var.device != torch.device(self.device):
var = var.to(self.device, non_blocking=True)
active_mask = self._normalize_active_mask(
active_mask, self.n_envs, torch.device(self.device)
)
if active_mask is not None and not torch.any(active_mask):
return {
"costs": [],
"actions": mean.detach(),
"mean": [mean.detach()],
"var": [var.detach()],
}
total_envs = self.n_envs
# --- Iterate over batches ---
for start_idx in range(0, total_envs, self.batch_size):
end_idx = min(start_idx + self.batch_size, total_envs)
current_bs = end_idx - start_idx
# Slice Distribution Parameters for current batch
batch_mean = mean[start_idx:end_idx]
batch_var = var[start_idx:end_idx]
# Expand Info Dict for current batch
expanded_infos = {}
for k, v in info_dict.items():
# v is shape (n_envs, ...)
# Slice batch
v_batch = v[start_idx:end_idx]
if torch.is_tensor(v):
if v_batch.device != self.device:
v_batch = v_batch.to(self.device, non_blocking=True)
# Add sample dim: (batch, 1, ...)
v_batch = v_batch.unsqueeze(1)
# Expand: (batch, num_samples, ...)
v_batch = v_batch.expand(current_bs, self.num_samples, *v_batch.shape[2:])
elif isinstance(v, np.ndarray):
v_batch = np.repeat(v_batch[:, None, ...], self.num_samples, axis=1)
expanded_infos[k] = v_batch
if active_mask is not None:
batch_mask = active_mask[start_idx:end_idx]
if not torch.any(batch_mask):
outputs["costs"].append(
torch.full((current_bs,), float("nan"), device=self.device)
)
continue
active_local = torch.nonzero(batch_mask, as_tuple=False).squeeze(1)
active_local_np = active_local.detach().cpu().numpy()
batch_mean = batch_mean[active_local]
batch_var = batch_var[active_local]
expanded_infos = {
k: (v[active_local] if torch.is_tensor(v) else v[active_local_np])
for k, v in expanded_infos.items()
}
current_bs = int(active_local.numel())
else:
active_local = None
# Optimization Loop
final_batch_cost = None
batch_indices = torch.arange(current_bs, device=self.device).unsqueeze(1)
for step in range(self.n_steps):
# Sample action sequences: (Batch, Num_Samples, Horizon, Dim)
candidates = torch.randn(
current_bs,
self.num_samples,
self.horizon,
self.action_dim,
generator=self.torch_gen,
device=self.device,
)
# Scale and shift: (Batch, N, H, D) * (Batch, 1, H, D) + (Batch, 1, H, D)
candidates = candidates * batch_var.unsqueeze(1) + batch_mean.unsqueeze(1)
# Force the first sample to be the current mean
candidates[:, 0] = batch_mean
current_info = expanded_infos.copy()
# Evaluate candidates
costs = self.model.get_cost(current_info, candidates)
assert isinstance(costs, torch.Tensor), f"Expected cost to be a torch.Tensor, got {type(costs)}"
assert costs.ndim == 2 and costs.shape[0] == current_bs and costs.shape[1] == self.num_samples, (
f"Expected cost to be of shape ({current_bs}, {self.num_samples}), got {costs.shape}"
)
# Select Top-K
# topk_vals: (Batch, K), topk_inds: (Batch, K)
topk_vals, topk_inds = torch.topk(costs, k=self.topk, dim=1, largest=False)
# Gather Top-K Candidates
# We need to select the specific candidates corresponding to topk_inds
# Indexing: candidates[batch_idx, sample_idx]
# Result shape: (Batch, K, Horizon, Dim)
topk_candidates = candidates[batch_indices, topk_inds]
# Update Mean and Variance based on Top-K
batch_mean = topk_candidates.mean(dim=1)
batch_var = topk_candidates.std(dim=1)
# Update final cost for logging
# We average the cost of the top elites
final_batch_cost = topk_vals.mean(dim=1).detach()
# Write results back to global storage
if active_mask is not None:
global_indices = start_idx + active_local
mean[global_indices] = batch_mean
var[global_indices] = batch_var
batch_costs = torch.full(
(end_idx - start_idx,), float("nan"), device=self.device
)
batch_costs[active_local] = final_batch_cost
else:
mean[start_idx:end_idx] = batch_mean
var[start_idx:end_idx] = batch_var
batch_costs = final_batch_cost
# Store history/metadata
outputs["costs"].append(batch_costs)
if outputs["costs"]:
outputs["costs"] = torch.cat(outputs["costs"]).cpu().tolist()
else:
outputs["costs"] = []
outputs["actions"] = mean.detach()
outputs["mean"] = [mean.detach()]
outputs["var"] = [var.detach()]
print(f"CEM solve time: {time.time() - start_time:.4f} seconds")
return outputs

View File

@@ -0,0 +1,256 @@
"""Projected Gradient Descent solver for discrete action spaces."""
import time
from typing import Any
import gymnasium as gym
import numpy as np
import torch
from gymnasium.spaces import Discrete
from .solver import Costable
class PGDSolver(torch.nn.Module):
"""Projected Gradient Descent solver for discrete action optimization.
Args:
model: World model implementing the Costable protocol.
n_steps: Number of gradient descent iterations.
batch_size: Number of environments to process in parallel.
var_scale: Initial variance scale for action perturbations.
num_samples: Number of action samples to optimize in parallel.
action_noise: Noise added to actions during optimization.
device: Device for tensor computations.
seed: Random seed for reproducibility.
"""
def __init__(
self,
model: Costable,
n_steps: int,
batch_size: int | None = None,
var_scale: float = 1,
num_samples: int = 1,
action_noise: float = 0.0,
device: str | torch.device = "cpu",
seed: int = 1234,
) -> None:
super().__init__()
self.model = model
self.n_steps = n_steps
self.batch_size = batch_size
self.num_samples = num_samples
self.var_scale = var_scale
self.action_noise = action_noise
self.device = device
self.torch_gen = torch.Generator(device=device).manual_seed(seed)
self._configured = False
self._n_envs = None
self._action_dim = None
self._action_simplex_dim = None
self._config = None
def configure(self, *, action_space: gym.Space, n_envs: int, config: Any) -> None:
"""Configure the solver with environment specifications."""
assert isinstance(action_space, Discrete), f"Action space must be discrete, got {type(action_space)}"
self._action_space = action_space
self._n_envs = n_envs
self._config = config
self._action_dim = int(np.prod(action_space.shape[1:]))
self._action_simplex_dim = int(action_space.n)
self._configured = True
@property
def n_envs(self) -> int:
"""Number of parallel environments."""
return self._n_envs
@property
def action_dim(self) -> int:
"""Flattened action dimension including action_block grouping."""
return self._action_dim * self._config.action_block
@property
def action_simplex_dim(self) -> int:
"""Simplex dimension for discrete action probabilities."""
return self._action_simplex_dim * self._config.action_block
@property
def horizon(self) -> int:
"""Planning horizon in timesteps."""
return self._config.horizon
def __call__(self, *args: Any, **kwargs: Any) -> dict:
"""Make solver callable, forwarding to solve()."""
return self.solve(*args, **kwargs)
def init_action(
self, actions: torch.Tensor | None = None, from_scalar: bool = False
) -> None:
"""Initialize the action tensor for optimization."""
if actions is None:
actions = torch.zeros((self._n_envs, 0, self.action_simplex_dim))
elif from_scalar:
# convert scalar to one-hot
actions = torch.nn.functional.one_hot(actions, num_classes=self._action_simplex_dim).to(torch.float32)
# merge action_block dim
actions = actions.reshape(*actions.shape[:-2], self.action_simplex_dim)
assert (
actions.shape[0] == self._n_envs
and actions.shape[1] <= self.horizon
and actions.shape[2] == self.action_simplex_dim
)
# fill remaining action
remaining = self.horizon - actions.shape[1]
if remaining > 0:
new_actions = torch.zeros(self._n_envs, remaining, self.action_simplex_dim)
actions = torch.cat([actions, new_actions], dim=1).to(self.device)
actions = actions.unsqueeze(1).repeat_interleave(self.num_samples, dim=1) # add sample dim
actions[:, 1:] += (
torch.randn(actions[:, 1:].shape, generator=self.torch_gen, device=self.device) * self.var_scale
) # add small noise to all samples except the first one
# reset actions
if hasattr(self, "init"):
self.init.copy_(actions)
else:
self.register_parameter("init", torch.nn.Parameter(actions))
def solve(
self,
info_dict: dict,
init_action: torch.Tensor | None = None,
from_scalar: bool = False,
) -> dict:
"""Solve the planning problem using projected gradient descent."""
start_time = time.time()
outputs = {
"cost": [], # Will store list of cost histories per batch
"actions": None,
}
with torch.no_grad():
self.init_action(init_action, from_scalar=from_scalar)
# Determine batch size (default to all envs if not specified which can cause memory issues)
batch_size = self.batch_size if self.batch_size is not None else self.n_envs
total_envs = self.n_envs
# Lists to hold results from each batch to be concatenated later
batch_top_actions_list = []
# --- Outer Loop: Iterate over batches ---
for start_idx in range(0, total_envs, batch_size):
end_idx = min(start_idx + batch_size, total_envs)
current_bs = end_idx - start_idx
batch_init = self.init[start_idx:end_idx].clone().detach()
batch_init.requires_grad = True
optim = torch.optim.SGD([batch_init], lr=1.0)
# Prepare Batch Infos
# Slice the input info_dict and then expand dimensions
expanded_infos = {}
for k, v in info_dict.items():
# Slice the data for the current batch indices
# Assumes input data dim 0 corresponds to n_envs
if torch.is_tensor(v):
batch_v = v[start_idx:end_idx]
batch_v = batch_v.unsqueeze(1)
batch_v = batch_v.expand(current_bs, self.num_samples, *batch_v.shape[2:])
elif isinstance(v, np.ndarray):
batch_v = v[start_idx:end_idx]
batch_v = np.repeat(batch_v[:, None, ...], self.num_samples, axis=1)
expanded_infos[k] = batch_v
# Perform Gradient Descent for this batch
batch_cost_history = []
for step in range(self.n_steps):
current_info = expanded_infos.copy()
# Calculate cost using the batch parameter
costs = self.model.get_cost(current_info, batch_init)
assert isinstance(costs, torch.Tensor), f"Got {type(costs)} cost, expect torch.Tensor"
assert costs.ndim == 2 and costs.shape[0] == current_bs and costs.shape[1] == self.num_samples, (
f"Cost should be of shape ({current_bs}, {self.num_samples}), got {costs.shape}"
)
assert costs.requires_grad, "Cost must requires_grad for PGD solver."
cost = costs.sum() # Sum cost for this batch
cost.backward()
optim.step()
optim.zero_grad(set_to_none=True)
# Add noise
if self.action_noise > 0:
batch_init.data += torch.randn(batch_init.shape, generator=self.torch_gen) * self.action_noise
# projection onto simplex
with torch.no_grad():
batch_init.copy_(self._project_action_simplex(batch_init))
batch_cost_history.append(cost.item())
# Store cost history for this batch
outputs["cost"].append(batch_cost_history)
# Update the global self.init with the optimized batch values
with torch.no_grad():
self.init[start_idx:end_idx] = batch_init
top_idx = torch.argsort(costs, dim=1)[:, 0]
batch_indices = torch.arange(current_bs)
top_actions_batch = batch_init[batch_indices, top_idx]
# convert one-hot back to discrete actions
top_actions_batch = self._factor_action_block(top_actions_batch).argmax(dim=-1)
batch_top_actions_list.append(top_actions_batch.detach().cpu())
# Concatenate all batch results
outputs["actions"] = torch.cat(batch_top_actions_list, dim=0)
end_time = time.time()
print(f"PGDSolver.solve completed in {end_time - start_time:.4f} seconds.")
return outputs
def _factor_action_block(self, actions: torch.Tensor) -> torch.Tensor:
"""Factor the action block dimension from action_simplex_dim."""
original_shape = actions.shape
action_block = self._config.action_block
simplex_dim = self._action_simplex_dim
return actions.reshape(*original_shape[:-1], action_block, simplex_dim)
def _project_action_simplex(self, actions: torch.Tensor) -> torch.Tensor:
"""Project the action onto the probability simplex."""
original_shape = actions.shape
s = self._factor_action_block(actions).reshape(-1, self._action_simplex_dim)
mu, _ = torch.sort(s, descending=True, dim=-1)
cumulative = mu.cumsum(dim=-1)
d = s.size(-1)
indices = torch.arange(1, d + 1, device=s.device, dtype=s.dtype)
threshold = (cumulative - 1) / indices
cond = (mu > threshold).to(torch.int32)
rho = cond.cumsum(dim=-1)
valid_rho = rho * cond
rho_max = valid_rho.max(dim=-1, keepdim=True)[0]
rho_min = torch.clamp(rho_max, min=1)
psi = (cumulative.gather(-1, rho_min - 1) - 1) / rho_min
projected = torch.clamp(s - psi, min=0.0).reshape(original_shape)
return projected

View File

@@ -0,0 +1,252 @@
"""Gradient-based solver for model-based planning."""
import time
from typing import Any
import gymnasium as gym
import numpy as np
import torch
from gymnasium.spaces import Box
from loguru import logger as logging
from .solver import Costable
class GradientSolver(torch.nn.Module):
"""Gradient-based solver using backpropagation through the world model.
Args:
model: World model implementing the Costable protocol.
n_steps: Number of gradient descent iterations.
batch_size: Number of environments to process in parallel.
var_scale: Initial variance scale for action perturbations.
num_samples: Number of action samples to optimize in parallel.
action_noise: Noise added to actions during optimization.
device: Device for tensor computations.
seed: Random seed for reproducibility.
optimizer_cls: PyTorch optimizer class to use.
optimizer_kwargs: Keyword arguments for the optimizer.
"""
def __init__(
self,
model: Costable,
n_steps: int,
batch_size: int | None = None,
var_scale: float = 1,
num_samples: int = 1,
action_noise: float = 0.0,
device: str | torch.device = 'cpu',
seed: int = 1234,
optimizer_cls: type[torch.optim.Optimizer] = torch.optim.SGD,
optimizer_kwargs: dict | None = None,
) -> None:
super().__init__()
self.model = model
self.n_steps = n_steps
self.batch_size = batch_size
self.num_samples = num_samples
self.var_scale = var_scale
self.action_noise = action_noise
self.device = device
self.torch_gen = torch.Generator(device=device).manual_seed(seed)
self.optimizer_cls = optimizer_cls
self.optimizer_kwargs = (
optimizer_kwargs if optimizer_kwargs is not None else {'lr': 1.0}
)
self._configured = False
self._n_envs = None
self._action_dim = None
self._config = None
def configure(
self, *, action_space: gym.Space, n_envs: int, config: Any
) -> None:
"""Configure the solver with environment specifications."""
self._action_space = action_space
self._n_envs = n_envs
self._config = config
self._action_dim = int(np.prod(action_space.shape[1:]))
self._configured = True
if not isinstance(action_space, Box):
logging.warning(
f'Action space is discrete, got {type(action_space)}. GradientSolver may not work as expected.'
)
@property
def n_envs(self) -> int:
"""Number of parallel environments."""
return self._n_envs
@property
def action_dim(self) -> int:
"""Flattened action dimension including action_block grouping."""
return self._action_dim * self._config.action_block
@property
def horizon(self) -> int:
"""Planning horizon in timesteps."""
return self._config.horizon
def __call__(self, *args: Any, **kwargs: Any) -> dict:
"""Make solver callable, forwarding to solve()."""
return self.solve(*args, **kwargs)
def init_action(self, actions: torch.Tensor | None = None) -> None:
"""Initialize the action tensor for optimization."""
device = torch.device(self.device)
if actions is None:
actions = torch.zeros(
(self._n_envs, 0, self.action_dim), device=device
)
elif actions.device != device:
actions = actions.to(device, non_blocking=True)
# fill remaining action
remaining = self.horizon - actions.shape[1]
if remaining > 0:
new_actions = torch.zeros(
self._n_envs,
remaining,
self.action_dim,
device=actions.device,
dtype=actions.dtype,
)
actions = torch.cat([actions, new_actions], dim=1)
actions = actions.unsqueeze(1).repeat_interleave(
self.num_samples, dim=1
) # add sample dim
actions[:, 1:] += (
torch.randn(
actions[:, 1:].shape,
generator=self.torch_gen,
device=self.device,
)
* self.var_scale
) # add small noise to all samples except the first one
# reset actions
if hasattr(self, 'init'):
self.init.copy_(actions)
else:
self.register_parameter('init', torch.nn.Parameter(actions))
def solve(
self, info_dict: dict, init_action: torch.Tensor | None = None
) -> dict:
"""Solve the planning problem using gradient descent."""
start_time = time.time()
outputs = {
'cost': [], # Will store list of cost histories per batch
'actions': None,
}
with torch.no_grad():
self.init_action(init_action)
# Determine batch size (default to all envs if not specified which can cause memory issues)
batch_size = (
self.batch_size if self.batch_size is not None else self.n_envs
)
total_envs = self.n_envs
# Lists to hold results from each batch to be concatenated later
batch_top_actions_list = []
# --- Outer Loop: Iterate over batches ---
for start_idx in range(0, total_envs, batch_size):
end_idx = min(start_idx + batch_size, total_envs)
current_bs = end_idx - start_idx
batch_init = self.init[start_idx:end_idx].clone().detach()
batch_init.requires_grad = True
# We initialize the optimizer class passed in __init__ with the kwargs
optim = self.optimizer_cls([batch_init], **self.optimizer_kwargs)
# Prepare Batch Infos
# Slice the input info_dict and then expand dimensions
expanded_infos = {}
for k, v in info_dict.items():
# Slice the data for the current batch indices
# Assumes input data dim 0 corresponds to n_envs
if torch.is_tensor(v):
batch_v = v[start_idx:end_idx]
batch_v = batch_v.unsqueeze(1)
batch_v = batch_v.expand(
current_bs, self.num_samples, *batch_v.shape[2:]
)
elif isinstance(v, np.ndarray):
batch_v = v[start_idx:end_idx]
batch_v = np.repeat(
batch_v[:, None, ...], self.num_samples, axis=1
)
expanded_infos[k] = batch_v
final_batch_cost = None
for step in range(self.n_steps):
current_info = expanded_infos.copy()
# Calculate cost using the batch parameter
costs = self.model.get_cost(current_info, batch_init)
assert isinstance(costs, torch.Tensor), (
f'Got {type(costs)} cost, expect torch.Tensor'
)
assert (
costs.ndim == 2
and costs.shape[0] == current_bs
and costs.shape[1] == self.num_samples
), (
f'Cost should be of shape ({current_bs}, {self.num_samples}), got {costs.shape}'
)
assert costs.requires_grad, (
'Cost must requires_grad for GD solver.'
)
cost = costs.sum() # Sum cost for this batch
cost.backward()
optim.step()
optim.zero_grad(set_to_none=True)
# Add noise
if self.action_noise > 0:
batch_init.data += (
torch.randn(
batch_init.shape,
generator=self.torch_gen,
device=self.device,
)
* self.action_noise
)
final_batch_cost = costs.detach().min(dim=1).values
# Store cost history for this batch
outputs['cost'].append(final_batch_cost)
# Update the global self.init with the optimized batch values
with torch.no_grad():
self.init[start_idx:end_idx] = batch_init
top_idx = costs.argmin(dim=1)
batch_indices = torch.arange(current_bs, device=self.device)
top_actions_batch = batch_init[batch_indices, top_idx]
batch_top_actions_list.append(top_actions_batch.detach())
# Concatenate all batch results
outputs['actions'] = torch.cat(batch_top_actions_list, dim=0)
outputs['cost'] = torch.cat(outputs['cost']).cpu().tolist()
end_time = time.time()
print(
f'GradientSolver.solve completed in {end_time - start_time:.4f} seconds.'
)
return outputs

View File

@@ -0,0 +1,219 @@
"""Improved Cross Entropy Method (iCEM) solver for model-based planning."""
import time
from typing import Any
import gymnasium as gym
import numpy as np
import torch
from gymnasium.spaces import Box
from loguru import logger as logging
from .solver import Costable
class ICEMSolver:
"""Improved Cross Entropy Method (iCEM) solver with colored noise and elite retention.
iCEM improves the sample efficiency over standard CEM and was introduced by
[1] for real-time planning.
Args:
model: World model implementing the Costable protocol.
batch_size: Number of environments to process in parallel.
num_samples: Number of action candidates to sample per iteration.
var_scale: Initial variance scale for the action distribution.
n_steps: Number of CEM iterations.
topk: Number of elite samples to keep for distribution update.
noise_beta: Colored noise exponent. 0 = white (standard CEM), >0 = more low-frequency noise.
alpha: Momentum for mean/std EMA update.
n_elite_keep: Number of elites carried from previous iteration.
return_mean: If False, return best single trajectory instead of mean.
device: Device for tensor computations.
seed: Random seed for reproducibility.
[1] C. Pinneri, S. Sawant, S. Blaes, J. Achterhold, J. Stueckler, M. Rolinek and
G, Martius, Georg. "Sample-efficient Cross-Entropy Method for Real-time Planning".
Conference on Robot Learning, 2020.
"""
def __init__(
self,
model: Costable,
batch_size: int = 1,
num_samples: int = 300,
var_scale: float = 1,
n_steps: int = 30,
topk: int = 30,
noise_beta: float = 2.0,
alpha: float = 0.1,
n_elite_keep: int = 5,
return_mean: bool = True,
device: str | torch.device = "cpu",
seed: int = 1234,
) -> None:
self.model = model
self.batch_size = batch_size
self.var_scale = var_scale
self.num_samples = num_samples
self.n_steps = n_steps
self.topk = topk
self.noise_beta = noise_beta
self.alpha = alpha
self.n_elite_keep = n_elite_keep
self.return_mean = return_mean
self.device = device
self.torch_gen = torch.Generator(device=device).manual_seed(seed)
def configure(self, *, action_space: gym.Space, n_envs: int, config: Any) -> None:
"""Configure the solver with environment specifications."""
self._action_space = action_space
self._n_envs = n_envs
self._config = config
self._action_dim = int(np.prod(action_space.shape[1:]))
self._configured = True
if isinstance(action_space, Box):
self._action_low = torch.tensor(action_space.low[0], device=self.device, dtype=torch.float32)
self._action_high = torch.tensor(action_space.high[0], device=self.device, dtype=torch.float32)
else:
logging.warning(f"Action space is discrete, got {type(action_space)}. ICEMSolver may not work as expected.")
self._action_low = None
self._action_high = None
@property
def n_envs(self) -> int:
"""Number of parallel environments."""
return self._n_envs
@property
def action_dim(self) -> int:
"""Flattened action dimension including action_block grouping."""
return self._action_dim * self._config.action_block
@property
def horizon(self) -> int:
"""Planning horizon in timesteps."""
return self._config.horizon
def __call__(self, *args: Any, **kwargs: Any) -> dict:
"""Make solver callable, forwarding to solve()."""
return self.solve(*args, **kwargs)
def init_action_distrib(
self, actions: torch.Tensor | None = None
) -> tuple[torch.Tensor, torch.Tensor]:
"""Initialize the action distribution parameters (mean and variance)."""
var = self.var_scale * torch.ones([self.n_envs, self.horizon, self.action_dim])
mean = torch.zeros([self.n_envs, 0, self.action_dim]) if actions is None else actions
remaining = self.horizon - mean.shape[1]
if remaining > 0:
device = mean.device
new_mean = torch.zeros([self.n_envs, remaining, self.action_dim])
mean = torch.cat([mean, new_mean], dim=1).to(device)
return mean, var
@torch.inference_mode()
def solve(
self, info_dict: dict, init_action: torch.Tensor | None = None
) -> dict:
"""Solve the planning problem using improved Cross Entropy Method."""
start_time = time.time()
outputs = {
"costs": [],
"mean": [],
"var": [],
}
mean, var = self.init_action_distrib(init_action)
mean = mean.to(self.device)
var = var.to(self.device)
for start_idx in range(0, self.n_envs, self.batch_size):
end_idx = min(start_idx + self.batch_size, self.n_envs)
current_bs = end_idx - start_idx
batch_mean = mean[start_idx:end_idx]
batch_var = var[start_idx:end_idx]
expanded_infos = {}
for k, v in info_dict.items():
v_batch = v[start_idx:end_idx]
if torch.is_tensor(v):
v_batch = v_batch.unsqueeze(1)
v_batch = v_batch.expand(current_bs, self.num_samples, *v_batch.shape[2:])
elif isinstance(v, np.ndarray):
v_batch = np.repeat(v_batch[:, None, ...], self.num_samples, axis=1)
expanded_infos[k] = v_batch
prev_topk_candidates = None
batch_indices = torch.arange(current_bs, device=self.device).unsqueeze(1).expand(-1, self.topk)
# Precompute FFT scale for colored noise
noise_shape = (current_bs, self.num_samples, self.action_dim, self.horizon)
freqs = torch.fft.rfftfreq(self.horizon, device=self.device)
freqs[0] = 1.0
noise_scale = freqs.pow(-self.noise_beta / 2)
noise_scale[0] = noise_scale[1]
for step in range(self.n_steps):
# Colored noise: generate with temporal axis last, then transpose
if self.horizon <= 1:
noise = torch.randn(noise_shape, generator=self.torch_gen, device=self.device)
else:
white = torch.randn(noise_shape, generator=self.torch_gen, device=self.device)
fft = torch.fft.rfft(white, dim=-1)
colored = torch.fft.irfft(fft * noise_scale, n=self.horizon, dim=-1)
std = colored.std(dim=-1, keepdim=True).clamp(min=1e-8)
noise = colored / std
noise = noise.transpose(-1, -2) # -> (bs, num_samples, horizon, action_dim)
candidates = noise * batch_var.unsqueeze(1) + batch_mean.unsqueeze(1)
candidates[:, 0] = batch_mean
# Inject previous elites
if prev_topk_candidates is not None:
n_inject = min(self.n_elite_keep, prev_topk_candidates.shape[1])
candidates[:, 1:1 + n_inject] = prev_topk_candidates[:, :n_inject]
# Clip to action bounds
if self._action_low is not None:
candidates = candidates.clamp(self._action_low, self._action_high)
current_info = expanded_infos.copy()
costs = self.model.get_cost(current_info, candidates)
assert isinstance(costs, torch.Tensor), f"Expected cost to be a torch.Tensor, got {type(costs)}"
assert costs.ndim == 2 and costs.shape[0] == current_bs and costs.shape[1] == self.num_samples, (
f"Expected cost to be of shape ({current_bs}, {self.num_samples}), got {costs.shape}"
)
topk_vals, topk_inds = torch.topk(costs, k=self.topk, dim=1, largest=False)
topk_candidates = candidates[batch_indices, topk_inds]
prev_topk_candidates = topk_candidates
# Momentum update
elite_mean = topk_candidates.mean(dim=1)
elite_var = topk_candidates.std(dim=1)
batch_mean = self.alpha * batch_mean + (1 - self.alpha) * elite_mean
batch_var = self.alpha * batch_var + (1 - self.alpha) * elite_var
final_batch_cost = topk_vals.mean(dim=1).cpu().tolist()
if self.return_mean:
mean[start_idx:end_idx] = batch_mean
else:
mean[start_idx:end_idx] = topk_candidates[:, 0]
var[start_idx:end_idx] = batch_var
outputs["costs"].extend(final_batch_cost)
outputs["actions"] = mean.detach().cpu()
outputs["mean"] = [mean.detach().cpu()]
outputs["var"] = [var.detach().cpu()]
print(f"iCEM solve time: {time.time() - start_time:.4f} seconds")
return outputs

View File

@@ -0,0 +1,360 @@
"""Lagrangian solver for stable world model."""
import time
from typing import Any
import gymnasium as gym
import numpy as np
import torch
import torch.nn.functional as F
from gymnasium.spaces import Box
from loguru import logger as logging
from .solver import Costable
class LagrangianSolver(torch.nn.Module):
"""Lagrangian solver for stable world model.
get_cost returns the cost tensor (B, S). If the model also implements get_constraints,
it should return the constraint violations (B, S, C), where C is the number of constraints.
The constraint_cost should represent the cost of violating the constraints, where the constraint
is satisfied when constraint_cost <= 0. The Lagrangian solver will optimize the following objective:
L = cost + sum_{i=1}^C lambda_i * constraint_cost_i + sum_{i=1}^C rho_i * max(0, constraint_cost_i)^2
If you want to use equality constraint, you can convert it to two inequality constraints. For example, if you want to enforce constraint_cost_i == 0, you can add two constraints: constraint_cost_i <= 0 and -constraint_cost_i <= 0.
Args:
model: World model implementing the Costable protocol. Its get_cost() returns
a plain cost tensor (B, S). If it also has get_constraints(), that method
returns constraints of shape (B, S, C).
n_steps: Number of gradient descent steps per outer iteration.
n_outer_steps: Number of dual ascent (outer) iterations.
batch_size: Number of environments to process in parallel.
num_samples: Number of action samples to optimize in parallel.
var_scale: Initial variance scale for action perturbations.
action_noise: Noise added to actions during optimization.
rho_init: Initial penalty coefficient for the quadratic constraint term.
rho_max: Maximum value of the penalty coefficient.
rho_scale: Multiplicative growth factor for rho after each outer step.
persist_multipliers: Whether to warm-start Lagrange multipliers across solve() calls.
device: Device for tensor computations.
seed: Random seed for reproducibility.
optimizer_cls: PyTorch optimizer class to use.
optimizer_kwargs: Keyword arguments for the optimizer.
"""
def __init__(
self,
model: Costable,
n_steps: int,
n_outer_steps: int = 5,
batch_size: int | None = None,
num_samples: int = 1,
var_scale: float = 1.0,
action_noise: float = 0.0,
rho_init: float = 1.0,
rho_max: float = 1e4,
rho_scale: float = 2.0,
persist_multipliers: bool = True,
device: str | torch.device = 'cpu',
seed: int = 1234,
optimizer_cls: type[torch.optim.Optimizer] = torch.optim.Adam,
optimizer_kwargs: dict | None = None,
) -> None:
super().__init__()
self.model = model
self.n_steps = n_steps
self.n_outer_steps = n_outer_steps
self.batch_size = batch_size
self.num_samples = num_samples
self.var_scale = var_scale
self.action_noise = action_noise
self.rho_init = rho_init
self.rho_max = rho_max
self.rho_scale = rho_scale
self.persist_multipliers = persist_multipliers
self.device = device
self.torch_gen = torch.Generator(device=device).manual_seed(seed)
self.optimizer_cls = optimizer_cls
self.optimizer_kwargs = (
optimizer_kwargs if optimizer_kwargs is not None else {'lr': 1.0}
)
self._configured = False
self._n_envs = None
self._action_dim = None
self._config = None
self._lambdas: torch.Tensor | None = None # (n_envs, C)
def configure(
self, *, action_space: gym.Space, n_envs: int, config: Any
) -> None:
"""Configure the solver with environment specifications."""
self._action_space = action_space
self._n_envs = n_envs
self._config = config
self._action_dim = int(np.prod(action_space.shape[1:]))
self._configured = True
if not isinstance(action_space, Box):
logging.warning(
f'Action space is discrete, got {type(action_space)}. LagrangianSolver may not work as expected.'
)
@property
def n_envs(self) -> int:
"""Number of parallel environments."""
return self._n_envs
@property
def action_dim(self) -> int:
"""Flattened action dimension including action_block grouping."""
return self._action_dim * self._config.action_block
@property
def horizon(self) -> int:
"""Planning horizon in timesteps."""
return self._config.horizon
def __call__(self, *args: Any, **kwargs: Any) -> dict:
"""Make solver callable, forwarding to solve()."""
return self.solve(*args, **kwargs)
def init_action(self, actions: torch.Tensor | None = None) -> None:
"""Initialize the action tensor for optimization."""
if actions is None:
actions = torch.zeros((self._n_envs, 0, self.action_dim))
remaining = self.horizon - actions.shape[1]
if remaining > 0:
new_actions = torch.zeros(self._n_envs, remaining, self.action_dim)
actions = torch.cat([actions, new_actions], dim=1).to(self.device)
actions = actions.unsqueeze(1).repeat_interleave(
self.num_samples, dim=1
)
actions[:, 1:] += (
torch.randn(
actions[:, 1:].shape,
generator=self.torch_gen,
device=self.device,
)
* self.var_scale
)
if hasattr(self, 'init'):
self.init.copy_(actions)
else:
self.register_parameter('init', torch.nn.Parameter(actions))
def _init_multipliers(self, num_constraints: int) -> None:
"""Lazily initialize Lagrange multipliers to zeros."""
self._lambdas = torch.zeros(
self._n_envs, num_constraints, device=self.device
)
def _augmented_lagrangian_loss(
self,
costs: torch.Tensor, # (B, S)
constraints: torch.Tensor, # (B, S, C)
lambdas_batch: torch.Tensor, # (B, C)
rho: float,
) -> torch.Tensor:
"""Compute the augmented Lagrangian loss.
L = cost + Σ_i lambda_i * g_i + Σ_i rho * max(0, g_i)^2
"""
# lambdas_batch: (B, C) -> (B, 1, C) for broadcasting with constraints (B, S, C)
linear_penalty = (lambdas_batch.unsqueeze(1) * constraints).sum(
dim=-1
) # (B, S)
quadratic_penalty = rho * F.relu(constraints).pow(2).sum(
dim=-1
) # (B, S)
return (costs + linear_penalty + quadratic_penalty).sum()
def _update_multipliers(
self,
constraints: torch.Tensor, # (B, S, C) — detached, no grad
lambdas_batch: torch.Tensor, # (B, C)
rho: float,
) -> torch.Tensor:
"""Dual ascent: lambda_i <- max(0, lambda_i + rho * mean_samples(g_i))."""
mean_g = constraints.mean(dim=1) # (B, C)
return torch.clamp(lambdas_batch + rho * mean_g, min=0.0)
def solve(
self, info_dict: dict, init_action: torch.Tensor | None = None
) -> dict:
"""Solve the planning problem using augmented Lagrangian gradient descent."""
start_time = time.time()
outputs: dict = {
'cost': [],
'constraint_violation': [],
'actions': None,
'lambdas': None,
}
with torch.no_grad():
self.init_action(init_action)
if not self.persist_multipliers:
self._lambdas = None
batch_size = (
self.batch_size if self.batch_size is not None else self.n_envs
)
total_envs = self.n_envs
batch_top_actions_list = []
for start_idx in range(0, total_envs, batch_size):
end_idx = min(start_idx + batch_size, total_envs)
current_bs = end_idx - start_idx
batch_init = self.init[start_idx:end_idx].clone().detach()
batch_init.requires_grad = True
# Expand info_dict for current batch — same pattern as GradientSolver
expanded_infos = {}
for k, v in info_dict.items():
if torch.is_tensor(v):
batch_v = v[start_idx:end_idx]
batch_v = batch_v.unsqueeze(1)
batch_v = batch_v.expand(
current_bs, self.num_samples, *batch_v.shape[2:]
)
elif isinstance(v, np.ndarray):
batch_v = v[start_idx:end_idx]
batch_v = np.repeat(
batch_v[:, None, ...], self.num_samples, axis=1
)
else:
batch_v = v
expanded_infos[k] = batch_v
rho = self.rho_init
batch_cost_history = []
costs = None
final_constraints = None
for _outer in range(self.n_outer_steps):
# Fresh optimizer each outer step — avoids stale momentum after dual ascent
optim = self.optimizer_cls(
[batch_init], **self.optimizer_kwargs
)
for _step in range(self.n_steps):
current_info = expanded_infos.copy()
costs = self.model.get_cost(current_info, batch_init)
constraints = (
self.model.get_constraints(
expanded_infos.copy(), batch_init
)
if hasattr(self.model, 'get_constraints')
else None
)
assert isinstance(costs, torch.Tensor), (
f'Got {type(costs)} cost, expect torch.Tensor'
)
assert costs.ndim == 2 and costs.shape == (
current_bs,
self.num_samples,
), (
f'Cost should be of shape ({current_bs}, {self.num_samples}), got {costs.shape}'
)
assert costs.requires_grad, (
'Cost must requires_grad for LagrangianSolver.'
)
if constraints is not None:
assert constraints.ndim == 3 and constraints.shape[
:2
] == (current_bs, self.num_samples), (
f'Constraints should be of shape ({current_bs}, {self.num_samples}, C), got {constraints.shape}'
)
if self._lambdas is None:
self._init_multipliers(constraints.shape[-1])
lambdas_batch = self._lambdas[start_idx:end_idx]
loss = self._augmented_lagrangian_loss(
costs, constraints, lambdas_batch, rho
)
else:
loss = costs.sum()
loss.backward()
optim.step()
optim.zero_grad(set_to_none=True)
if self.action_noise > 0:
batch_init.data += (
torch.randn(
batch_init.shape, generator=self.torch_gen
)
* self.action_noise
)
batch_cost_history.append(loss.item())
# Dual ascent after inner loop converges
if constraints is not None:
with torch.no_grad():
final_constraints = self.model.get_constraints(
expanded_infos.copy(), batch_init
)
lambdas_batch = self._update_multipliers(
final_constraints, lambdas_batch, rho
)
self._lambdas[start_idx:end_idx] = lambdas_batch
rho = min(self.rho_max, rho * self.rho_scale)
with torch.no_grad():
mean_cost = costs.mean().item()
if constraints is not None:
viol = F.relu(final_constraints).mean(dim=(0, 1)) # (C,)
lam = lambdas_batch.mean(dim=0) # (C,)
viol_str = ', '.join(f'{v:.4f}' for v in viol.tolist())
lam_str = ', '.join(f'{l:.4f}' for l in lam.tolist())
print(
f' [outer {_outer+1}/{self.n_outer_steps}] '
f'cost={mean_cost:.4f} | '
f'constraint_viol=[{viol_str}] | '
f'lambdas=[{lam_str}] | '
f'rho={rho:.4f}'
)
else:
print(
f' [outer {_outer+1}/{self.n_outer_steps}] '
f'cost={mean_cost:.4f}'
)
outputs['cost'].append(batch_cost_history)
if final_constraints is not None:
outputs['constraint_violation'].append(
F.relu(final_constraints).mean().item()
)
with torch.no_grad():
self.init[start_idx:end_idx] = batch_init
top_idx = torch.argsort(costs, dim=1)[:, 0]
batch_indices = torch.arange(current_bs)
top_actions_batch = batch_init[batch_indices, top_idx]
batch_top_actions_list.append(top_actions_batch.detach().cpu())
outputs['actions'] = torch.cat(batch_top_actions_list, dim=0)
outputs['lambdas'] = (
self._lambdas.cpu() if self._lambdas is not None else None
)
constraint_info = ''
if outputs['constraint_violation']:
mean_viol = np.mean(outputs['constraint_violation'])
constraint_info = f' | constraint_violation={mean_viol:.4f}'
print(
f'LagrangianSolver.solve completed in {time.time() - start_time:.4f} seconds{constraint_info}.'
)
return outputs

View File

@@ -0,0 +1,208 @@
"""Model Predictive Path Integral solver for model-based planning."""
import time
from typing import Any
import gymnasium as gym
import numpy as np
import torch
from gymnasium.spaces import Box
from loguru import logger as logging
from .solver import Costable
class MPPISolver:
"""Model Predictive Path Integral solver for action optimization.
Args:
model: World model implementing the Costable protocol.
batch_size: Number of environments to process in parallel.
num_samples: Number of action candidates to sample per iteration.
var_scale: Initial variance scale for action noise.
n_steps: Number of MPPI iterations.
topk: Number of elite samples for weighted averaging.
temperature: Temperature parameter for softmax weighting.
device: Device for tensor computations.
seed: Random seed for reproducibility.
"""
def __init__(
self,
model: Costable,
batch_size: int = 1,
num_samples: int = 300,
var_scale: float = 1.0,
n_steps: int = 30,
topk: int = 30,
temperature: float = 0.5,
device: str | torch.device = "cpu",
seed: int = 1234,
) -> None:
self.model = model
self.batch_size = batch_size
self.num_samples = num_samples
self.topk = topk
self.var_scale = var_scale
self.n_steps = n_steps
self.temperature = temperature
self.device = device
self.torch_gen = torch.Generator(device=device).manual_seed(seed)
def configure(self, *, action_space: gym.Space, n_envs: int, config: Any) -> None:
"""Configure the solver with environment specifications."""
self._action_space = action_space
self._n_envs = n_envs
self._config = config
self._action_dim = int(np.prod(action_space.shape[1:]))
self._configured = True
if not isinstance(action_space, Box):
logging.warning(
f"Action space is discrete, got {type(action_space)}. MPPISolver may not work as expected."
)
@property
def n_envs(self) -> int:
"""Number of parallel environments."""
return self._n_envs
@property
def action_dim(self) -> int:
"""Flattened action dimension including action_block grouping."""
return self._action_dim * self._config.action_block
@property
def horizon(self) -> int:
"""Planning horizon in timesteps."""
return self._config.horizon
def __call__(self, *args: Any, **kwargs: Any) -> dict:
"""Make solver callable, forwarding to solve()."""
return self.solve(*args, **kwargs)
def init_action_distrib(
self, actions: torch.Tensor | None = None
) -> tuple[torch.Tensor, torch.Tensor]:
"""Initialize the action distribution parameters (mean and variance)."""
var = self.var_scale * torch.ones([self.n_envs, self.horizon, self.action_dim])
mean = torch.zeros([self.n_envs, 0, self.action_dim]) if actions is None else actions
remaining = self.horizon - mean.shape[1]
if remaining > 0:
device = mean.device
new_mean = torch.zeros([self.n_envs, remaining, self.action_dim])
mean = torch.cat([mean, new_mean], dim=1).to(device)
return mean, var
@torch.inference_mode()
def solve(
self, info_dict: dict, init_action: torch.Tensor | None = None
) -> dict:
"""Solve the planning problem using MPPI."""
start_time = time.time()
outputs = {
"costs": [],
"mean": [],
"var": [],
}
# -- initialize the action distribution globally
mean, var = self.init_action_distrib(init_action)
mean = mean.to(self.device)
var = var.to(self.device)
total_envs = self.n_envs
# --- Iterate over batches ---
for start_idx in range(0, total_envs, self.batch_size):
end_idx = min(start_idx + self.batch_size, total_envs)
current_bs = end_idx - start_idx
# Slice Distribution Parameters for current batch
batch_mean = mean[start_idx:end_idx]
batch_var = var[start_idx:end_idx]
# Expand Info Dict for current batch (Same as CEM)
expanded_infos = {}
for k, v in info_dict.items():
v_batch = v[start_idx:end_idx]
if torch.is_tensor(v):
# Add sample dim: (batch, 1, ...)
v_batch = v_batch.unsqueeze(1)
# Expand: (batch, num_samples, ...)
v_batch = v_batch.expand(current_bs, self.num_samples, *v_batch.shape[2:])
elif isinstance(v, np.ndarray):
v_batch = np.repeat(v_batch[:, None, ...], self.num_samples, axis=1)
expanded_infos[k] = v_batch
# Optimization Loop
final_batch_cost = None
for step in range(self.n_steps):
# Sample noise: (Batch, Num_Samples, Horizon, Dim)
noise = torch.randn(
current_bs,
self.num_samples,
self.horizon,
self.action_dim,
generator=self.torch_gen,
device=self.device,
)
# MPPI Logic: candidates = mean + noise * sigma
candidates = batch_mean.unsqueeze(1) + noise * batch_var.unsqueeze(1)
# Force the first sample to be the current mean (Zero noise)
candidates[:, 0] = batch_mean
# Evaluate candidates
costs = self.model.get_cost(expanded_infos, candidates)
assert isinstance(costs, torch.Tensor), f"Expected cost to be a torch.Tensor, got {type(costs)}"
assert costs.ndim == 2 and costs.shape[0] == current_bs and costs.shape[1] == self.num_samples, (
f"Expected cost to be of shape ({current_bs}, {self.num_samples}), got {costs.shape}"
)
# Select Elites (Optional, based on topk)
if self.topk is not None and self.topk < self.num_samples:
# topk_vals: (Batch, K), topk_inds: (Batch, K)
topk_vals, topk_inds = torch.topk(costs, k=self.topk, dim=1, largest=False)
# Gather Top-K Candidates
batch_indices = torch.arange(current_bs, device=self.device).unsqueeze(1).expand(-1, self.topk)
# (Batch, K, Horizon, Dim)
relevant_candidates = candidates[batch_indices, topk_inds]
relevant_costs = topk_vals
else:
relevant_candidates = candidates
relevant_costs = costs
# MPPI Weighting: Softmax(-cost / temperature)
# Stabilize softmax by subtracting min cost
min_cost = relevant_costs.min(dim=1, keepdim=True)[0]
scaled_costs = relevant_costs - min_cost
weights = torch.softmax(-scaled_costs / self.temperature, dim=1) # (Batch, K)
# Update Mean: weighted sum of candidates
# Reshape weights for broadcasting: (Batch, K, 1, 1)
weights_expanded = weights.unsqueeze(-1).unsqueeze(-1)
batch_mean = (weights_expanded * relevant_candidates).sum(dim=1)
# Store average cost of the utilized samples for logging
final_batch_cost = relevant_costs.mean(dim=1).cpu().tolist()
# Write results back to global storage
mean[start_idx:end_idx] = batch_mean
# We do not update var in standard MPPI
# Store history/metadata
outputs["costs"].extend(final_batch_cost)
outputs["actions"] = mean.detach().cpu()
outputs["mean"] = [mean.detach().cpu()]
outputs["var"] = [var.detach().cpu()]
print(f"MPPI solve time: {time.time() - start_time:.4f} seconds")
return outputs

View File

@@ -0,0 +1,80 @@
from typing import Any, Protocol, runtime_checkable
import gymnasium as gym
import numpy as np
import torch
class Costable(Protocol):
"""Protocol for world model cost functions."""
def criterion(self, info_dict: dict, action_candidates: torch.Tensor) -> torch.Tensor:
"""Compute the cost criterion for action candidates.
Args:
info_dict: Dictionary containing environment state information.
action_candidates: Tensor of proposed actions.
Returns:
A tensor of cost values for each action candidate.
"""
...
def get_cost(self, info_dict: dict, action_candidates: torch.Tensor) -> torch.Tensor: # pragma: no cover
"""Compute cost for given action candidates based on info dictionary.
Args:
info_dict: Dictionary containing environment state information.
action_candidates: Tensor of proposed actions.
Returns:
A tensor of cost values for each action candidate.
"""
...
@runtime_checkable
class Solver(Protocol):
"""Protocol for model-based planning solvers."""
def configure(self, *, action_space: gym.Space, n_envs: int, config: Any) -> None:
"""Configure the solver with environment and planning specifications.
Args:
action_space: The action space of the environment.
n_envs: Number of parallel environments.
config: Planning configuration object.
"""
...
@property
def action_dim(self) -> int:
"""Flattened action dimension including action_block grouping."""
...
@property
def n_envs(self) -> int:
"""Number of parallel environments being planned for."""
...
@property
def horizon(self) -> int:
"""Planning horizon length in timesteps."""
...
def solve(
self,
info_dict: dict,
init_action: torch.Tensor | None = None,
active_mask: torch.Tensor | np.ndarray | None = None,
) -> dict:
"""Solve the planning optimization problem to find optimal actions.
Args:
info_dict: Dictionary containing environment state information.
init_action: Optional initial action sequence to warm-start the solver.
Returns:
Dictionary containing optimized actions and other solver-specific info.
"""
...

File diff suppressed because it is too large Load Diff

241
AMD_SETUP.md Normal file
View File

@@ -0,0 +1,241 @@
# AMD ROCm 环境配置说明
这份文档记录了在 AMD ROCm 环境下运行 LeWM 的可复现配置,重点是保留
`torch.compile` 时的 PyTorch 版本选择。
目标运行命令:
```bash
python eval.py --config-name=pusht.yaml policy=pusht/lewm
```
## 已验证环境
本次验证通过的环境:
- Ubuntu 24.04
- AMD Radeon PRO W7900D (`gfx1100`)
- 系统 ROCm 7.1.1
- Python 3.10
- `torch==2.10.0+rocm7.1`
- `torchvision==0.25.0+rocm7.1`
- `triton-rocm==3.6.0`
注意:`torch==2.12.0+rocm7.1` 可以正常导入,也能识别 GPU但在本项目里开启
`torch.compile` 后会崩溃,错误类似:
```text
HSA_STATUS_ERROR_INVALID_PACKET_FORMAT
CUDA error: unspecified launch failure
```
降级到 `torch==2.10.0+rocm7.1` 后,`torch.compile` 路径可以正常跑通。
## 检查系统 ROCm
在新 AMD 机器上,先确认系统能识别 GPU
```bash
rocminfo
amd-smi version
hipcc --version
```
`rocminfo` 里应该能看到 AMD GPU agent例如 `gfx1100`
## 创建 Python 环境
使用 `uv` 创建 Python 3.10 虚拟环境:
```bash
cd /path/to/lewm
uv venv --python 3.10 --allow-existing .venv
source .venv/bin/activate
```
给 uv 创建的 venv 补上 pip。ROCm 版 PyTorch wheel 很大,如果 uv 解析或下载卡住,
用 pip 安装大 wheel 更容易观察进度。
```bash
uv pip install pip
```
## 安装 ROCm 版 PyTorch
安装本项目已验证可用的 ROCm PyTorch 组合:
```bash
python -m pip install --force-reinstall \
--index-url https://download.pytorch.org/whl/rocm7.1 \
--extra-index-url https://pypi.tuna.tsinghua.edu.cn/simple \
"torch==2.10.0" \
"torchvision==0.25.0"
```
PyTorch wheel 有数 GB。如果网络慢不要频繁中断重试尽量等它下载完成。
## 安装项目依赖
普通 Python 包建议走国内 PyPI 镜像:
```bash
python -m pip install \
--index-url https://pypi.tuna.tsinghua.edu.cn/simple \
"gymnasium[all]==1.2.2" \
"stable-baselines3==2.8.0" \
"stable-worldmodel[train,env]"
```
然后修正两个容易被 pip 带偏的依赖版本:
```bash
python -m pip install \
--index-url https://pypi.tuna.tsinghua.edu.cn/simple \
"fsspec==2025.3.0" \
"pillow==11.3.0"
```
检查环境:
```bash
python -m pip check
python - <<'PY'
import torch
import torchvision
print("torch:", torch.__version__)
print("hip:", torch.version.hip)
print("cuda available:", torch.cuda.is_available())
print("device:", torch.cuda.get_device_name(0) if torch.cuda.is_available() else None)
print("torchvision:", torchvision.__version__)
PY
```
期望看到类似输出:
```text
torch: 2.10.0+rocm7.1
cuda available: True
torchvision: 0.25.0+rocm7.1
```
## 恢复本仓库里的 stable-worldmodel 修改
这个仓库把一些本地修改后的 `stable_worldmodel` 文件纳入了 git 管控,路径在:
```text
.venv/lib/python3.10/site-packages/stable_worldmodel/
```
从 PyPI 安装 `stable-worldmodel` 时可能会覆盖这些文件。安装依赖后执行:
```bash
git restore -- \
.venv/lib/python3.10/site-packages/stable_worldmodel/policy.py \
.venv/lib/python3.10/site-packages/stable_worldmodel/world.py \
.venv/lib/python3.10/site-packages/stable_worldmodel/solver/cem.py \
.venv/lib/python3.10/site-packages/stable_worldmodel/solver/gd.py
```
然后确认没有意外修改:
```bash
git status --short
```
## 数据和 checkpoint 路径
`eval.py` 会从 `$STABLEWM_HOME` 里找数据和 checkpoint。
PushT 评估至少需要:
```text
$STABLEWM_HOME/pusht_expert_train.h5
$STABLEWM_HOME/pusht/lewm_object.ckpt
```
例如本机使用:
```bash
export STABLEWM_HOME=/mnt/ASC1637/stablewm
```
如果没有正确设置,运行时会报找不到 `pusht_expert_train.h5`
## 运行评估
默认 PushT 评估,保留 `torch.compile`
```bash
export STABLEWM_HOME=/path/to/stablewm
python eval.py --config-name=pusht.yaml policy=pusht/lewm
```
快速 smoke test
```bash
export STABLEWM_HOME=/path/to/stablewm
python eval.py --config-name=pusht.yaml policy=pusht/lewm \
eval.num_eval=1 \
world.num_envs=1 \
output.filename=/tmp/lewm_smoke_test.txt
```
smoke test 应该能正常结束,并打印类似:
```text
{'success_rate': 100.0, ...}
```
## 常见问题
### `HSA_STATUS_ERROR_INVALID_PACKET_FORMAT`
如果开启 `torch.compile` 时出现这个错误,先检查 torch 版本:
```bash
python -c "import torch; print(torch.__version__, torch.version.hip)"
```
如果是 `2.12.0+rocm7.1`,建议降级到本项目验证通过的组合:
```bash
python -m pip install --force-reinstall \
--index-url https://download.pytorch.org/whl/rocm7.1 \
--extra-index-url https://pypi.tuna.tsinghua.edu.cn/simple \
"torch==2.10.0" \
"torchvision==0.25.0"
```
### 找不到 `pusht_expert_train.h5`
设置 `STABLEWM_HOME` 到包含数据和 checkpoint 的目录:
```bash
export STABLEWM_HOME=/path/to/stablewm
```
### pip 尝试构建旧版 `gym==0.21`
这是依赖解析回退导致的。先显式安装兼容版本:
```bash
python -m pip install \
--index-url https://pypi.tuna.tsinghua.edu.cn/simple \
"gymnasium[all]==1.2.2" \
"stable-baselines3==2.8.0"
```
### uv 或 pip 访问海外源很慢
普通 Python 包使用国内 PyPI 镜像:
```bash
--index-url https://pypi.tuna.tsinghua.edu.cn/simple
```
PyTorch ROCm wheel 继续使用 PyTorch 官方 ROCm 源:
```bash
--index-url https://download.pytorch.org/whl/rocm7.1
```

View File

@@ -84,6 +84,33 @@ python eval.py --config-name=pusht.yaml policy=pusht/lewm
python eval.py --config-name=pusht.yaml policy=pusht/lewm_object.ckpt
```
## Profiling
`eval.py` now supports optional inference profiling with PyTorch's native profiler.
Example:
```bash
python eval.py --config-name=pusht.yaml policy=pusht/lewm \
inference_precision=bf16 \
+profile.enabled=true \
+profile.with_stack=true \
+profile.record_shapes=true \
+profile.profile_memory=true
```
Supported inference precision modes:
- `inference_precision=fp32`
- `inference_precision=bf16`
- `inference_precision=fp16`
Outputs are written next to the evaluation results:
- `torch_profile/key_averages.txt` for the aggregated operator table
- `torch_profile/trace.json` for Chrome tracing
- TensorBoard trace files under `torch_profile/`
The trace includes custom regions such as `eval.world_evaluate_from_dataset`, `lewm.get_cost`, `lewm.rollout`, and `lewm.predict` to make the planning path easier to inspect.
## Pretrained Checkpoints
Pre-trained checkpoints are available on [Google Drive](https://drive.google.com/drive/folders/1r31os0d4-rR0mdHc7OlY_e5nh3XT4r4e). Download the checkpoint archive and place the extracted files under `$STABLEWM_HOME/`.

View File

@@ -24,6 +24,7 @@ dataset:
seed: 42
policy: random # ckpt name or random
inference_precision: fp16
plan_config:
horizon: 5
@@ -36,6 +37,10 @@ eval:
goal_offset_steps: 25
eval_budget: 50
img_size: 224
save_video: false
compile_warmup:
enabled: true
num_eval: 1
dataset_name: ogbench/cube_single_expert
callables:
# -- set state
@@ -56,6 +61,21 @@ eval:
target_quat:
value: goal_privileged_block_0_quat
multi_node:
enabled: false
backend: gloo
rank_env: RANK
world_size_env: WORLD_SIZE
local_rank_env: LOCAL_RANK
aggregate_results: true
sync_before_return: false
destroy_process_group: true
shard_strategy: round_robin
preload_wait:
enabled: false
file: /tmp/lewm_preload_start
poll_interval: 1.0
output:
filename: ogb_cube_results.txt

View File

@@ -19,6 +19,7 @@ dataset:
seed: 42
policy: random # ckpt name or random
inference_precision: fp16
plan_config:
horizon: 5
@@ -31,6 +32,10 @@ eval:
goal_offset_steps: 25
eval_budget: 50
img_size: 224
save_video: false
compile_warmup:
enabled: true
num_eval: 1
dataset_name: pusht_expert_train
callables:
# -- set state
@@ -43,6 +48,22 @@ eval:
args:
goal_state:
value: goal_state
multi_node:
enabled: false
backend: gloo
rank_env: RANK
world_size_env: WORLD_SIZE
local_rank_env: LOCAL_RANK
aggregate_results: true
sync_before_return: false
destroy_process_group: true
shard_strategy: round_robin
preload_wait:
enabled: false
file: /tmp/lewm_preload_start
poll_interval: 1.0
output:
filename: pusht_results.txt
filename: pusht_results.txt

View File

@@ -18,6 +18,7 @@ dataset:
seed: 42
policy: random # ckpt name or random
inference_precision: fp16
plan_config:
horizon: 5
@@ -30,6 +31,10 @@ eval:
goal_offset_steps: 25
eval_budget: 50
img_size: 224
save_video: false
compile_warmup:
enabled: true
num_eval: 1
dataset_name: dmc/reacher_random
callables:
# -- set state
@@ -45,6 +50,21 @@ eval:
target_qpos:
value: goal_qpos
multi_node:
enabled: false
backend: gloo
rank_env: RANK
world_size_env: WORLD_SIZE
local_rank_env: LOCAL_RANK
aggregate_results: true
sync_before_return: false
destroy_process_group: true
shard_strategy: round_robin
preload_wait:
enabled: false
file: /tmp/lewm_preload_start
poll_interval: 1.0
output:
filename: dmc_results.txt

View File

@@ -1,9 +1,10 @@
_target_: stable_worldmodel.solver.CEMSolver
model: ???
batch_size: 1
batch_size: 16
# Original defaults: num_samples=300, n_steps=30, topk=30, batch_size=8.
num_samples: 300
var_scale: 1.0
n_steps: 30
topk: 30
topk: 8
device: "cuda"
seed: ${seed}

View File

@@ -0,0 +1,14 @@
_target_: stable_worldmodel.solver.GradientSolver
model: ???
# Original adam.yaml reference: n_steps=30, num_samples=100, batch_size=1, lr=0.1.
n_steps: 90
batch_size: 100
num_samples: 1
action_noise: 0
device: "cuda"
seed: ${seed}
optimizer_cls:
_target_: hydra.utils.get_class
path: torch.optim.AdamW
optimizer_kwargs:
lr: 0.075

View File

@@ -12,6 +12,7 @@ world:
seed: 42
policy: random # ckpt name or random
inference_precision: fp16
dataset:
stats: ${eval.dataset_name}
@@ -30,6 +31,10 @@ eval:
goal_offset_steps: 25
eval_budget: 50
img_size: 224
save_video: false
compile_warmup:
enabled: true
num_eval: 1
dataset_name: tworoom
callables:
# -- set state
@@ -43,5 +48,21 @@ eval:
goal_state:
value: goal_proprio
multi_node:
enabled: false
backend: gloo
rank_env: RANK
world_size_env: WORLD_SIZE
local_rank_env: LOCAL_RANK
aggregate_results: true
sync_before_return: false
destroy_process_group: true
shard_strategy: round_robin
preload_wait:
enabled: false
file: /tmp/lewm_preload_start
poll_interval: 1.0
output:
filename: tworoom_results.txt
filename: tworoom_results.txt

836
eval.py
View File

@@ -2,8 +2,12 @@ import os
os.environ["MUJOCO_GL"] = "egl"
import multiprocessing as mp
import time
import traceback
from contextlib import nullcontext
from pathlib import Path
import tempfile
import hydra
import numpy as np
@@ -46,76 +50,331 @@ def get_dataset(cfg, dataset_name):
)
return dataset
@hydra.main(version_base=None, config_path="./config/eval", config_name="pusht")
def run(cfg: DictConfig):
"""Run evaluation of dinowm vs random policy."""
assert (
cfg.plan_config.horizon * cfg.plan_config.action_block <= cfg.eval.eval_budget
), "Planning horizon must be smaller than or equal to eval_budget"
# create world environment
cfg.world.max_episode_steps = 2 * cfg.eval.eval_budget
world = swm.World(**cfg.world, image_shape=(224, 224))
def get_profile_cfg(cfg):
profile_cfg = {
"enabled": False,
"trace_dirname": "torch_profile",
"record_shapes": True,
"profile_memory": True,
"with_stack": False,
"with_flops": False,
"row_limit": 40,
"worker_name": "eval",
"export_chrome_trace": True,
"export_tensorboard": True,
}
cfg_profile = cfg.get("profile")
if cfg_profile is not None:
profile_cfg.update(OmegaConf.to_container(cfg_profile, resolve=True))
return profile_cfg
# create the transform
transform = {
"pixels": img_transform(cfg),
"goal": img_transform(cfg),
def get_compile_cfg(cfg):
compile_cfg = {
"enabled": True,
"target": "predictor",
"mode": "reduce-overhead",
"fullgraph": False,
"dynamic": False,
"cuda_only": True,
}
cfg_compile = cfg.get("compile")
if cfg_compile is not None:
compile_cfg.update(OmegaConf.to_container(cfg_compile, resolve=True))
return compile_cfg
def get_compile_warmup_cfg(cfg):
warmup_cfg = {
"enabled": True,
"num_eval": 1,
}
cfg_warmup = cfg.get("compile_warmup")
if cfg_warmup is None:
cfg_eval = cfg.get("eval")
if cfg_eval is not None:
cfg_warmup = cfg_eval.get("compile_warmup")
if cfg_warmup is not None:
warmup_cfg.update(OmegaConf.to_container(cfg_warmup, resolve=True))
return warmup_cfg
def get_preload_wait_cfg(cfg):
preload_cfg = {
"enabled": False,
"file": "/tmp/lewm_preload_start",
"poll_interval": 1.0,
}
cfg_preload = cfg.get("preload_wait")
if cfg_preload is not None:
preload_cfg.update(OmegaConf.to_container(cfg_preload, resolve=True))
return preload_cfg
def wait_for_preload_signal(cfg, rank=0):
preload_cfg = get_preload_wait_cfg(cfg)
if not preload_cfg["enabled"]:
return
dist_ready = (
torch.distributed.is_available()
and torch.distributed.is_initialized()
)
if dist_ready:
torch.distributed.barrier()
signal_path = Path(str(preload_cfg["file"])).expanduser()
poll_interval = float(preload_cfg["poll_interval"])
if rank == 0:
print(
"Preload ready. Create this file to start evaluation: "
f"{signal_path}",
flush=True,
)
while not signal_path.exists():
time.sleep(poll_interval)
print("Preload start signal received. Starting evaluation.", flush=True)
if dist_ready:
torch.distributed.barrier()
def maybe_compile_inference_target(model, cfg, device):
compile_cfg = get_compile_cfg(cfg)
compile_target = "disabled"
if not compile_cfg["enabled"]:
return model, compile_cfg, compile_target
if not hasattr(torch, "compile"):
print("torch.compile is unavailable, skipping inference compilation.")
return model, compile_cfg, compile_target
if compile_cfg["cuda_only"] and not str(device).startswith("cuda"):
print("Skipping torch.compile because compile.cuda_only=true and device is not CUDA.")
return model, compile_cfg, compile_target
target = str(compile_cfg["target"]).lower()
compile_kwargs = {
"mode": compile_cfg["mode"],
"fullgraph": compile_cfg["fullgraph"],
"dynamic": compile_cfg["dynamic"],
}
dataset = get_dataset(cfg, cfg.eval.dataset_name)
stats_dataset = dataset # get_dataset(cfg, cfg.dataset.stats)
col_name = "episode_idx" if "episode_idx" in dataset.column_names else "ep_idx"
ep_indices, _ = np.unique(stats_dataset.get_col_data(col_name), return_index=True)
if target == "predictor":
if not hasattr(model, "predictor"):
print("Requested compile target 'predictor' is unavailable on the model.")
return model, compile_cfg, compile_target
model.predictor = torch.compile(model.predictor, **compile_kwargs)
compile_target = "predictor"
elif target == "predict":
if not hasattr(model, "predict"):
print("Requested compile target 'predict' is unavailable on the model.")
return model, compile_cfg, compile_target
model.predict = torch.compile(model.predict, **compile_kwargs)
compile_target = "predict"
else:
print(
f"Unsupported compile.target={target}. Expected one of: predictor, predict."
)
return model, compile_cfg, compile_target
def get_inference_context(cfg, device):
precision = str(cfg.get("inference_precision", "fp32")).lower()
device_type = "cuda" if device.startswith("cuda") else "cpu"
if precision == "fp32":
return nullcontext(), "fp32"
if precision in {"bf16", "bfloat16"}:
return (
torch.autocast(device_type=device_type, dtype=torch.bfloat16),
"bf16",
)
if precision in {"fp16", "float16"}:
if device_type != "cuda":
print("fp16 inference is only supported on CUDA, falling back to fp32.")
return nullcontext(), "fp32"
return (
torch.autocast(device_type=device_type, dtype=torch.float16),
"fp16",
)
raise ValueError(
f"Unsupported inference_precision={precision}. Expected one of: fp32, bf16, fp16."
)
def get_eval_grad_context(solver=None):
if isinstance(solver, swm.solver.GradientSolver):
return torch.enable_grad()
return torch.inference_mode()
def make_profiler(cfg, results_path):
profile_cfg = get_profile_cfg(cfg)
if not profile_cfg["enabled"]:
return nullcontext(), None, profile_cfg
activities = [torch.profiler.ProfilerActivity.CPU]
if torch.cuda.is_available():
activities.append(torch.profiler.ProfilerActivity.CUDA)
profile_dir = results_path / profile_cfg["trace_dirname"]
profile_dir.mkdir(parents=True, exist_ok=True)
profiler = torch.profiler.profile(
activities=activities,
record_shapes=profile_cfg["record_shapes"],
profile_memory=profile_cfg["profile_memory"],
with_stack=profile_cfg["with_stack"],
with_flops=profile_cfg["with_flops"],
)
return profiler, profile_dir, profile_cfg
def dump_profiler_results(profiler, profile_dir, profile_cfg):
if profiler is None or profile_dir is None:
return None
has_cuda = torch.cuda.is_available()
table = profiler.key_averages().table(
sort_by="self_cuda_time_total" if has_cuda else "self_cpu_time_total",
row_limit=profile_cfg["row_limit"],
)
summary_path = profile_dir / "key_averages.txt"
summary_path.write_text(table)
if profile_cfg["export_tensorboard"]:
trace_handler = torch.profiler.tensorboard_trace_handler(
str(profile_dir), worker_name=profile_cfg["worker_name"]
)
trace_handler(profiler)
elif profile_cfg["export_chrome_trace"]:
profiler.export_chrome_trace(str(profile_dir / "trace.json"))
return summary_path
def get_multi_gpu_cfg(cfg):
multi_gpu_cfg = {
"enabled": False,
"devices": None,
"start_method": "spawn",
}
cfg_multi_gpu = cfg.get("multi_gpu")
if cfg_multi_gpu is not None:
multi_gpu_cfg.update(OmegaConf.to_container(cfg_multi_gpu, resolve=True))
return multi_gpu_cfg
def get_multi_node_cfg(cfg):
multi_node_cfg = {
"enabled": False,
"backend": "gloo",
"rank_env": "RANK",
"world_size_env": "WORLD_SIZE",
"local_rank_env": "LOCAL_RANK",
"output_mode": "single",
"aggregate_results": True,
"sync_before_return": False,
"destroy_process_group": True,
"shard_strategy": "round_robin",
}
cfg_multi_node = cfg.get("multi_node")
if cfg_multi_node is not None:
multi_node_cfg.update(OmegaConf.to_container(cfg_multi_node, resolve=True))
return multi_node_cfg
def get_dist_env(name, default=None):
value = os.environ.get(name, default)
if value is None:
return None
return int(value)
def get_rank_context(cfg):
multi_node_cfg = get_multi_node_cfg(cfg)
if not multi_node_cfg["enabled"]:
return 0, 1, 0
rank = get_dist_env(multi_node_cfg["rank_env"])
world_size = get_dist_env(multi_node_cfg["world_size_env"])
local_rank = get_dist_env(multi_node_cfg["local_rank_env"], 0)
if rank is None or world_size is None:
raise ValueError(
"multi_node.enabled=true requires torchrun env vars RANK and WORLD_SIZE"
)
if world_size < 1:
raise ValueError("WORLD_SIZE must be >= 1")
if rank < 0 or rank >= world_size:
raise ValueError("RANK must be in [0, WORLD_SIZE)")
return rank, world_size, local_rank
def all_gather_eval_result(result):
world_size = torch.distributed.get_world_size()
payload = [None for _ in range(world_size)]
torch.distributed.all_gather_object(payload, result)
return payload
def finalize_multi_node_process_group(cfg):
multi_node_cfg = get_multi_node_cfg(cfg)
if not multi_node_cfg["destroy_process_group"]:
return
if torch.distributed.is_available() and torch.distributed.is_initialized():
torch.distributed.destroy_process_group()
def get_rank_result_path(output_dir: Path, cfg: DictConfig, rank: int) -> Path:
filename = str(cfg.output.filename)
if rank == 0:
return output_dir / filename
suffix = Path(filename).suffix
stem = Path(filename).stem
if suffix:
ranked_filename = f"{stem}.rank{rank}{suffix}"
else:
ranked_filename = f"{filename}.rank{rank}"
return output_dir / ranked_filename
def build_process(cfg, dataset):
process = {}
for col in cfg.dataset.keys_to_cache:
if col in ["pixels"]:
continue
processor = preprocessing.StandardScaler()
col_data = stats_dataset.get_col_data(col)
col_data = dataset.get_col_data(col)
col_data = col_data[~np.isnan(col_data).any(axis=1)]
processor.fit(col_data)
process[col] = processor
if col != "action":
process[f"goal_{col}"] = process[col]
return process
# -- run evaluation
policy = cfg.get("policy", "random")
if policy != "random":
model = swm.policy.AutoCostModel(cfg.policy)
model = model.to("cuda")
model = model.eval()
model.requires_grad_(False)
model.interpolate_pos_encoding = True
config = swm.PlanConfig(**cfg.plan_config)
solver = hydra.utils.instantiate(cfg.solver, model=model)
policy = swm.policy.WorldModelPolicy(
solver=solver, config=config, process=process, transform=transform
)
else:
policy = swm.policy.RandomPolicy()
results_path = (
Path(swm.data.utils.get_cache_dir(), cfg.policy).parent
if cfg.policy != "random"
else Path(__file__).parent
)
# sample the episodes and the starting indices
def sample_eval_cases(cfg, dataset):
stats_dataset = dataset
col_name = "episode_idx" if "episode_idx" in dataset.column_names else "ep_idx"
ep_indices, _ = np.unique(stats_dataset.get_col_data(col_name), return_index=True)
episode_len = get_episodes_length(dataset, ep_indices)
max_start_idx = episode_len - cfg.eval.goal_offset_steps - 1
max_start_idx_dict = {ep_id: max_start_idx[i] for i, ep_id in enumerate(ep_indices)}
# Map each dataset rows episode_idx to its max_start_idx
col_name = "episode_idx" if "episode_idx" in dataset.column_names else "ep_idx"
max_start_per_row = np.array(
[max_start_idx_dict[ep_id] for ep_id in dataset.get_col_data(col_name)]
)
# remove all the lines of dataset for which dataset['step_idx'] > max_start_per_row
valid_mask = dataset.get_col_data("step_idx") <= max_start_per_row
valid_indices = np.nonzero(valid_mask)[0]
print(valid_mask.sum(), "valid starting points found for evaluation.")
@@ -124,35 +383,478 @@ def run(cfg: DictConfig):
random_episode_indices = g.choice(
len(valid_indices) - 1, size=cfg.eval.num_eval, replace=False
)
# sort increasingly to avoid issues with HDF5Dataset indexing
random_episode_indices = np.sort(valid_indices[random_episode_indices])
print(random_episode_indices)
eval_episodes = dataset.get_row_data(random_episode_indices)[col_name]
eval_start_idx = dataset.get_row_data(random_episode_indices)["step_idx"]
rows = dataset.get_row_data(random_episode_indices)
eval_episodes = rows[col_name]
eval_start_idx = rows["step_idx"]
if len(eval_episodes) < cfg.eval.num_eval:
raise ValueError("Not enough episodes with sufficient length for evaluation.")
return eval_episodes, eval_start_idx
def normalize_multi_gpu_devices(devices):
if devices is None:
return [f"cuda:{idx}" for idx in range(torch.cuda.device_count())]
normalized = []
for device in devices:
if isinstance(device, int):
normalized.append(f"cuda:{device}")
elif isinstance(device, str) and device.isdigit():
normalized.append(f"cuda:{int(device)}")
else:
normalized.append(str(device))
return normalized
def shard_eval_cases(eval_episodes, eval_start_idx, num_shards):
if num_shards < 1:
raise ValueError("num_shards must be >= 1")
total = len(eval_episodes)
shard_sizes = [total // num_shards] * num_shards
for idx in range(total % num_shards):
shard_sizes[idx] += 1
shards = []
start = 0
for size in shard_sizes:
end = start + size
if size > 0:
shards.append((eval_episodes[start:end], eval_start_idx[start:end]))
start = end
return shards
def get_rank_eval_subset(
eval_episodes,
eval_start_idx,
rank,
world_size,
*,
strategy="contiguous",
):
if world_size < 1:
raise ValueError("world_size must be >= 1")
if rank < 0 or rank >= world_size:
raise ValueError("rank must be in [0, world_size)")
if strategy == "round_robin":
episode_subset = eval_episodes[rank::world_size]
start_subset = eval_start_idx[rank::world_size]
return episode_subset, start_subset
if strategy != "contiguous":
raise ValueError("strategy must be one of: contiguous, round_robin")
total = len(eval_episodes)
shard_sizes = [total // world_size] * world_size
for idx in range(total % world_size):
shard_sizes[idx] += 1
start = sum(shard_sizes[:rank])
end = start + shard_sizes[rank]
return eval_episodes[start:end], eval_start_idx[start:end]
def run_eval_subset(
cfg: DictConfig,
eval_episodes,
eval_start_idx,
output_dir: Path,
*,
device_override: str | None = None,
enable_profile: bool = True,
enable_compile_warmup: bool = False,
before_evaluate=None,
):
local_cfg = OmegaConf.create(OmegaConf.to_container(cfg, resolve=False))
local_cfg.eval.num_eval = len(eval_episodes)
local_cfg.world.num_envs = len(eval_episodes)
local_cfg.world.max_episode_steps = 2 * local_cfg.eval.eval_budget
if device_override is not None:
local_cfg.solver.device = device_override
if torch.cuda.is_available() and str(device_override).startswith("cuda"):
torch.cuda.set_device(torch.device(device_override))
if not enable_profile:
if local_cfg.get("profile") is None:
local_cfg.profile = OmegaConf.create({"enabled": False})
else:
local_cfg.profile.enabled = False
world = swm.World(**local_cfg.world, image_shape=(224, 224))
transform = {
"pixels": img_transform(local_cfg),
"goal": img_transform(local_cfg),
}
dataset = get_dataset(local_cfg, local_cfg.eval.dataset_name)
process = build_process(local_cfg, dataset)
policy_name = local_cfg.get("policy", "random")
if policy_name != "random":
model = swm.policy.AutoCostModel(local_cfg.policy)
device = device_override or ("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
model = model.eval()
model.requires_grad_(False)
model, compile_cfg, compile_target = maybe_compile_inference_target(
model, local_cfg, device
)
inference_ctx, inference_precision = get_inference_context(local_cfg, device)
model.interpolate_pos_encoding = True
config = swm.PlanConfig(**local_cfg.plan_config)
solver = hydra.utils.instantiate(local_cfg.solver, model=model)
policy = swm.policy.WorldModelPolicy(
solver=solver, config=config, process=process, transform=transform
)
else:
policy = swm.policy.RandomPolicy()
solver = None
inference_ctx = nullcontext()
inference_precision = "fp32"
compile_cfg = get_compile_cfg(local_cfg)
compile_target = "disabled"
device = device_override or ("cuda" if torch.cuda.is_available() else "cpu")
profiler_ctx, profile_dir, profile_cfg = make_profiler(local_cfg, output_dir)
world.set_policy(policy)
if str(device).startswith("cuda") and torch.cuda.is_available():
torch.cuda.synchronize()
if before_evaluate is not None:
before_evaluate()
if str(device).startswith("cuda") and torch.cuda.is_available():
torch.cuda.synchronize()
def evaluate_subset(episodes, start_indices, *, eval_cfg=local_cfg):
return world.evaluate_from_dataset(
dataset,
start_steps=list(start_indices),
goal_offset_steps=eval_cfg.eval.goal_offset_steps,
eval_budget=eval_cfg.eval.eval_budget,
episodes_idx=list(episodes),
callables=OmegaConf.to_container(
eval_cfg.eval.get("callables"), resolve=True
),
save_video=bool(eval_cfg.eval.get("save_video", False)),
video_path=output_dir,
)
start_time = time.time()
metrics = world.evaluate_from_dataset(
dataset,
start_steps=eval_start_idx.tolist(),
goal_offset_steps=cfg.eval.goal_offset_steps,
eval_budget=cfg.eval.eval_budget,
episodes_idx=eval_episodes.tolist(),
callables=OmegaConf.to_container(cfg.eval.get("callables"), resolve=True),
video_path=results_path,
with get_eval_grad_context(solver):
with profiler_ctx as profiler:
with inference_ctx:
if enable_compile_warmup:
maybe_run_compile_warmup(cfg, eval_episodes, eval_start_idx)
with torch.profiler.record_function("eval.world_evaluate_from_dataset"):
metrics = evaluate_subset(eval_episodes, eval_start_idx)
if str(device).startswith("cuda") and torch.cuda.is_available():
torch.cuda.synchronize()
evaluation_time = time.time() - start_time
profile_summary_path = dump_profiler_results(profiler, profile_dir, profile_cfg)
return {
"metrics": metrics,
"evaluation_time": evaluation_time,
"inference_precision": inference_precision,
"compile_target": compile_target,
"compile_mode": compile_cfg["mode"] if compile_target != "disabled" else None,
"profile_dir": profile_dir,
"profile_summary_path": profile_summary_path,
}
def maybe_run_compile_warmup(cfg, eval_episodes, eval_start_idx):
warmup_cfg = get_compile_warmup_cfg(cfg)
if not warmup_cfg["enabled"]:
return
if get_multi_gpu_cfg(cfg)["enabled"]:
print("Skipping compile warmup because multi_gpu.enabled=true uses spawned workers.")
return
if get_multi_node_cfg(cfg)["enabled"]:
rank, world_size, local_rank = get_rank_context(cfg)
eval_episodes, eval_start_idx = get_rank_eval_subset(
eval_episodes, eval_start_idx, rank, world_size
)
device_override = f"cuda:{local_rank}" if torch.cuda.is_available() else "cpu"
else:
device_override = None
warmup_count = min(int(warmup_cfg["num_eval"]), len(eval_episodes))
if warmup_count < 1:
return
warmup_eval_cfg = OmegaConf.create(OmegaConf.to_container(cfg, resolve=False))
warmup_eval_cfg.eval.num_eval = warmup_count
warmup_eval_cfg.eval.save_video = False
if warmup_eval_cfg.get("profile") is None:
warmup_eval_cfg.profile = OmegaConf.create({"enabled": False})
else:
warmup_eval_cfg.profile.enabled = False
with tempfile.TemporaryDirectory(prefix="lewm_compile_warmup_") as tmpdir:
run_eval_subset(
warmup_eval_cfg,
list(eval_episodes[:warmup_count]),
list(eval_start_idx[:warmup_count]),
Path(tmpdir),
device_override=device_override,
enable_profile=False,
)
def _multi_gpu_eval_worker(
cfg_container,
eval_episodes,
eval_start_idx,
output_dir,
device,
shard_idx,
queue,
):
try:
cfg = OmegaConf.create(cfg_container)
result = run_eval_subset(
cfg,
eval_episodes,
eval_start_idx,
Path(output_dir),
device_override=device,
enable_profile=False,
)
queue.put({"ok": True, "shard_idx": shard_idx, "result": result})
except Exception:
queue.put(
{
"ok": False,
"shard_idx": shard_idx,
"error": traceback.format_exc(),
}
)
def run_multi_gpu_eval(cfg, eval_episodes, eval_start_idx, output_dir: Path):
multi_gpu_cfg = get_multi_gpu_cfg(cfg)
devices = normalize_multi_gpu_devices(multi_gpu_cfg["devices"])
if len(devices) < 2:
raise ValueError("multi_gpu.enabled=true requires at least 2 CUDA devices")
shards = shard_eval_cases(eval_episodes, eval_start_idx, min(len(devices), len(eval_episodes)))
devices = devices[: len(shards)]
ctx = mp.get_context(multi_gpu_cfg["start_method"])
queue = ctx.Queue()
cfg_container = OmegaConf.to_container(cfg, resolve=False)
processes = []
start_time = time.time()
for shard_idx, ((shard_episodes, shard_start_idx), device) in enumerate(
zip(shards, devices, strict=True)
):
process = ctx.Process(
target=_multi_gpu_eval_worker,
args=(
cfg_container,
list(shard_episodes),
list(shard_start_idx),
str(output_dir),
device,
shard_idx,
queue,
),
)
process.start()
processes.append(process)
shard_results = {}
errors = []
for _ in processes:
message = queue.get()
if message["ok"]:
shard_results[message["shard_idx"]] = message["result"]
else:
errors.append(message["error"])
for process in processes:
process.join()
if errors:
raise RuntimeError(errors[0])
ordered_results = [shard_results[idx] for idx in range(len(processes))]
episode_successes = np.concatenate(
[
np.asarray(result["metrics"]["episode_successes"], dtype=np.bool_)
for result in ordered_results
]
)
end_time = time.time()
seeds = None
shard_seeds = [result["metrics"].get("seeds") for result in ordered_results]
if all(seed is not None for seed in shard_seeds):
seeds = np.concatenate(shard_seeds)
metrics = {
"success_rate": float(np.sum(episode_successes)) / len(episode_successes) * 100.0,
"episode_successes": episode_successes,
"seeds": seeds,
}
reference = ordered_results[0]
return {
"metrics": metrics,
"evaluation_time": time.time() - start_time,
"inference_precision": reference["inference_precision"],
"compile_target": reference["compile_target"],
"compile_mode": reference["compile_mode"],
"profile_dir": None,
"profile_summary_path": None,
}
def combine_eval_results(ordered_results):
episode_successes = np.concatenate(
[
np.asarray(result["metrics"]["episode_successes"], dtype=np.bool_)
for result in ordered_results
]
)
seeds = None
shard_seeds = [result["metrics"].get("seeds") for result in ordered_results]
if all(seed is not None for seed in shard_seeds):
seeds = np.concatenate(shard_seeds)
metrics = {
"success_rate": float(np.sum(episode_successes)) / len(episode_successes) * 100.0,
"episode_successes": episode_successes,
"seeds": seeds,
}
reference = ordered_results[0]
return metrics, reference
def run_multi_node_eval(cfg, eval_episodes, eval_start_idx, output_dir: Path):
rank, world_size, local_rank = get_rank_context(cfg)
multi_node_cfg = get_multi_node_cfg(cfg)
shard_episodes, shard_start_idx = get_rank_eval_subset(
eval_episodes,
eval_start_idx,
rank,
world_size,
strategy=multi_node_cfg["shard_strategy"],
)
if len(shard_episodes) == 0:
raise ValueError("No evaluation episodes assigned to this rank")
local_cfg = OmegaConf.create(OmegaConf.to_container(cfg, resolve=False))
local_cfg.multi_node.enabled = False
if local_cfg.get("multi_gpu") is None:
local_cfg.multi_gpu = OmegaConf.create({"enabled": False})
else:
local_cfg.multi_gpu.enabled = False
device = f"cuda:{local_rank}" if torch.cuda.is_available() else "cpu"
preload_cfg = get_preload_wait_cfg(cfg)
if preload_cfg["enabled"]:
if not torch.distributed.is_available():
raise RuntimeError("torch.distributed is required for preload_wait")
if not torch.distributed.is_initialized():
torch.distributed.init_process_group(backend=multi_node_cfg["backend"])
rank_output_path = get_rank_result_path(output_dir, cfg, rank)
result = run_eval_subset(
local_cfg,
list(shard_episodes),
list(shard_start_idx),
rank_output_path.parent,
device_override=device,
enable_profile=False,
before_evaluate=lambda: wait_for_preload_signal(cfg, rank=rank),
)
if not multi_node_cfg["aggregate_results"]:
result["output_filename"] = rank_output_path.name
finalize_multi_node_process_group(cfg)
return result
if not torch.distributed.is_available():
raise RuntimeError("torch.distributed is required for multi-node evaluation")
if not torch.distributed.is_initialized():
torch.distributed.init_process_group(backend=multi_node_cfg["backend"])
gathered = all_gather_eval_result(result)
metrics, reference = combine_eval_results(gathered)
combined = {
"metrics": metrics,
"evaluation_time": max(item["evaluation_time"] for item in gathered),
"inference_precision": reference["inference_precision"],
"compile_target": reference["compile_target"],
"compile_mode": reference["compile_mode"],
"profile_dir": None,
"profile_summary_path": None,
"output_filename": cfg.output.filename,
}
if multi_node_cfg["sync_before_return"]:
torch.distributed.barrier()
finalize_multi_node_process_group(cfg)
if rank != 0:
return None
return combined
@hydra.main(version_base=None, config_path="./config/eval", config_name="pusht")
def run(cfg: DictConfig):
"""Run evaluation of dinowm vs random policy."""
assert (
cfg.plan_config.horizon * cfg.plan_config.action_block <= cfg.eval.eval_budget
), "Planning horizon must be smaller than or equal to eval_budget"
dataset = get_dataset(cfg, cfg.eval.dataset_name)
eval_episodes, eval_start_idx = sample_eval_cases(cfg, dataset)
output_dir = Path.cwd().resolve()
profile_cfg = get_profile_cfg(cfg)
maybe_run_compile_warmup(cfg, eval_episodes, eval_start_idx)
eval_wall_start = time.time()
if get_multi_node_cfg(cfg)["enabled"] and get_multi_gpu_cfg(cfg)["enabled"]:
raise ValueError("multi_node.enabled and multi_gpu.enabled are mutually exclusive")
if get_multi_node_cfg(cfg)["enabled"]:
eval_result = run_multi_node_eval(cfg, eval_episodes, eval_start_idx, output_dir)
if eval_result is None:
return
elif get_multi_gpu_cfg(cfg)["enabled"]:
if profile_cfg["enabled"]:
raise ValueError("Profiling is not supported together with multi_gpu.enabled=true")
eval_result = run_multi_gpu_eval(cfg, eval_episodes, eval_start_idx, output_dir)
else:
eval_result = run_eval_subset(
cfg,
eval_episodes.tolist(),
eval_start_idx.tolist(),
output_dir,
)
metrics = eval_result["metrics"]
evaluation_time = eval_result["evaluation_time"]
inference_precision = eval_result["inference_precision"]
compile_target = eval_result["compile_target"]
compile_mode = eval_result["compile_mode"]
profile_dir = eval_result["profile_dir"]
profile_summary_path = eval_result["profile_summary_path"]
output_filename = eval_result.get("output_filename", cfg.output.filename)
print(metrics)
results_path = results_path / cfg.output.filename
results_path = output_dir / output_filename
results_path.parent.mkdir(parents=True, exist_ok=True)
with results_path.open("a") as f:
@@ -164,7 +866,17 @@ def run(cfg: DictConfig):
f.write("==== RESULTS ====\n")
f.write(f"metrics: {metrics}\n")
f.write(f"evaluation_time: {end_time - start_time} seconds\n")
f.write(f"evaluation_time: {evaluation_time} seconds\n")
f.write(f"inference_precision: {inference_precision}\n")
f.write(f"inference_compile_target: {compile_target}\n")
if compile_target != "disabled":
f.write(f"inference_compile_mode: {compile_mode}\n")
if profile_cfg["enabled"]:
f.write(f"profile_dir: {profile_dir}\n")
if profile_summary_path is not None:
f.write(f"profile_summary: {profile_summary_path}\n")
f.write(f"total_wall_time: {time.time() - eval_wall_start} seconds\n")
if __name__ == "__main__":

320
jepa.py
View File

@@ -2,12 +2,8 @@
import torch
import torch.nn.functional as F
from einops import rearrange
from torch import nn
def detach_clone(v):
return v.detach().clone() if torch.is_tensor(v) else v
class JEPA(nn.Module):
def __init__(
@@ -25,34 +21,124 @@ class JEPA(nn.Module):
self.action_encoder = action_encoder
self.projector = projector or nn.Identity()
self.pred_proj = pred_proj or nn.Identity()
self._cached_device_tensors = {}
self._cached_init_signature = None
self._cached_init_emb = None
self._cached_goal_signature = None
self._cached_goal_emb = None
def _ensure_runtime_caches(self):
if not hasattr(self, "_cached_device_tensors"):
self._cached_device_tensors = {}
if not hasattr(self, "_cached_init_signature"):
self._cached_init_signature = None
if not hasattr(self, "_cached_init_emb"):
self._cached_init_emb = None
if not hasattr(self, "_cached_goal_signature"):
self._cached_goal_signature = None
if not hasattr(self, "_cached_goal_emb"):
self._cached_goal_emb = None
@staticmethod
def _tensor_signature(tensor: torch.Tensor):
try:
version = tensor._version
except RuntimeError:
version = None
return (
str(tensor.device),
tensor.dtype,
tuple(tensor.shape),
tuple(tensor.stride()),
tensor.storage_offset(),
tensor.data_ptr(),
version,
)
def _get_cached_device_tensor(
self,
key: str,
tensor: torch.Tensor,
device: torch.device,
*,
ensure_contiguous: bool = False,
):
self._ensure_runtime_caches()
if tensor.device == device and (not ensure_contiguous or tensor.is_contiguous()):
return tensor
signature = (self._tensor_signature(tensor), str(device), ensure_contiguous)
cached = self._cached_device_tensors.get(key)
if cached is None or cached[0] != signature:
prepared = tensor.to(device, non_blocking=True)
if ensure_contiguous and not prepared.is_contiguous():
prepared = prepared.contiguous()
self._cached_device_tensors[key] = (
signature,
prepared,
)
return self._cached_device_tensors[key][1]
def _ensure_info_device(self, info_dict: dict, device: torch.device):
for key, value in list(info_dict.items()):
if key.startswith("_lewm_"):
continue
if torch.is_tensor(value):
info_dict[key] = self._get_cached_device_tensor(
key,
value,
device,
ensure_contiguous=True,
)
return info_dict
def _get_cached_init_emb(self, info_dict: dict):
self._ensure_runtime_caches()
pixels = info_dict["pixels"]
signature = self._tensor_signature(pixels)
if self._cached_init_signature != signature:
init_info = {"pixels": pixels[:, 0]}
self._cached_init_emb = self.encode(init_info)["emb"].detach()
self._cached_init_signature = signature
return self._cached_init_emb
def _get_cached_goal_emb(self, info_dict: dict):
self._ensure_runtime_caches()
goal = info_dict["goal"]
signature = self._tensor_signature(goal)
if self._cached_goal_signature != signature:
goal_info = {"pixels": goal[:, 0]}
self._cached_goal_emb = self.encode(goal_info)["emb"][:, -1:, :].detach()
self._cached_goal_signature = signature
return self._cached_goal_emb
def encode(self, info):
"""Encode observations and actions into embeddings.
info: dict with pixels and action keys
"""
with torch.profiler.record_function("lewm.encode"):
pixels = info['pixels'].float()
b, t = pixels.shape[:2]
pixels = pixels.reshape(b * t, *pixels.shape[2:]) # flatten for encoding
output = self.encoder(pixels, interpolate_pos_encoding=True)
pixels_emb = output.last_hidden_state[:, 0] # cls token
emb = self.projector(pixels_emb)
info["emb"] = emb.reshape(b, t, -1)
pixels = info['pixels'].float()
b = pixels.size(0)
pixels = rearrange(pixels, "b t ... -> (b t) ...") # flatten for encoding
output = self.encoder(pixels, interpolate_pos_encoding=True)
pixels_emb = output.last_hidden_state[:, 0] # cls token
emb = self.projector(pixels_emb)
info["emb"] = rearrange(emb, "(b t) d -> b t d", b=b)
if "action" in info:
info["act_emb"] = self.action_encoder(info["action"])
if "action" in info:
info["act_emb"] = self.action_encoder(info["action"])
return info
return info
def predict(self, emb, act_emb):
"""Predict next state embedding
emb: (B, T, D)
act_emb: (B, T, A_emb)
"""
preds = self.predictor(emb, act_emb)
preds = self.pred_proj(rearrange(preds, "b t d -> (b t) d"))
preds = rearrange(preds, "(b t) d -> b t d", b=emb.size(0))
return preds
with torch.profiler.record_function("lewm.predict"):
preds = self.predictor(emb, act_emb)
preds = self.pred_proj(preds)
return preds
####################
## Inference only ##
@@ -65,89 +151,151 @@ class JEPA(nn.Module):
- S is the number of action plan samples
- T is the time horizon
"""
with torch.profiler.record_function("lewm.rollout"):
assert "pixels" in info, "pixels not in info_dict"
if history_size < 1:
raise ValueError("history_size must be >= 1")
assert "pixels" in info, "pixels not in info_dict"
H = info["pixels"].size(2)
B, S, T = action_sequence.shape[:3]
act_0, act_future = torch.split(action_sequence, [H, T - H], dim=2)
info["action"] = act_0
n_steps = T - H
H = info["pixels"].size(2)
B, S, T = action_sequence.shape[:3]
if T < H:
raise ValueError(
f"action_sequence horizon ({T}) must be >= history length ({H})"
)
# copy and encode initial info dict
_init = {k: v[:, 0] for k, v in info.items() if torch.is_tensor(v)}
_init = self.encode(_init)
emb = info["emb"] = _init["emb"].unsqueeze(1).expand(B, S, -1, -1)
_init = {k: detach_clone(v) for k, v in _init.items()}
# Cache the encoded initial state across solver iterations.
init_emb = self._get_cached_init_emb(info)
HS = history_size
hist_len = min(HS, init_emb.size(1), H)
if hist_len < 1:
raise ValueError("rollout requires at least one history step")
# flatten batch and sample dimensions for rollout
emb = rearrange(emb, "b s ... -> (b s) ...").clone()
act = rearrange(act_0, "b s ... -> (b s) ...")
act_future = rearrange(act_future, "b s ... -> (b s) ...")
init_hist = init_emb[:, -hist_len:]
init_hist = init_hist.unsqueeze(1).expand(-1, S, -1, -1)
init_hist = init_hist.reshape(B * S, hist_len, init_hist.size(-1)).contiguous()
# rollout predictor autoregressively for n_steps
HS = history_size
for t in range(n_steps):
act_emb = self.action_encoder(act)
emb_trunc = emb[:, -HS:] # (BS, HS, D)
act_trunc = act_emb[:, -HS:] # (BS, HS, A_emb)
pred_emb = self.predict(emb_trunc, act_trunc)[:, -1:] # (BS, 1, D)
emb = torch.cat([emb, pred_emb], dim=1) # (BS, T+1, D)
flat_actions = action_sequence.contiguous().view(B * S, T, -1)
action_emb = self.action_encoder(flat_actions)
act_hist = action_emb[:, H - hist_len : H]
act_future = action_emb[:, H:]
next_act = act_future[:, t : t + 1, :] # (BS, 1, action_dim)
act = torch.cat([act, next_act], dim=1) # (BS, T+1, action_dim)
if HS == 1:
emb_hist = init_hist[:, -1:]
act_emb_hist = act_hist[:, -1:]
# predict the last state
act_emb = self.action_encoder(act) # (BS, T, A_emb)
emb_trunc = emb[:, -HS:] # (BS, HS, D)
act_trunc = act_emb[:, -HS:] # (BS, HS, A_emb)
pred_emb = self.predict(emb_trunc, act_trunc)[:, -1:] # (BS, 1, D)
emb = torch.cat([emb, pred_emb], dim=1)
for t in range(act_future.size(1)):
emb_hist = self.predict(emb_hist, act_emb_hist)[:, -1:]
act_emb_hist = act_future[:, t : t + 1]
# unflatten batch and sample dimensions
pred_rollout = rearrange(emb, "(b s) ... -> b s ...", b=B, s=S)
info["predicted_emb"] = pred_rollout
pred_rollout = self.predict(emb_hist, act_emb_hist)[:, -1:]
else:
if torch.is_grad_enabled() and action_sequence.requires_grad:
emb_slots = init_hist.split(1, dim=1)
act_slots = act_hist.split(1, dim=1)
return info
for t in range(act_future.size(1)):
emb_view = torch.cat(emb_slots[-HS:], dim=1)
act_view = torch.cat(act_slots[-HS:], dim=1)
pred_emb = self.predict(emb_view, act_view)[:, -1:]
next_act_emb = act_future[:, t : t + 1]
emb_slots = (*emb_slots[-(HS - 1) :], pred_emb)
act_slots = (*act_slots[-(HS - 1) :], next_act_emb)
emb_view = torch.cat(emb_slots[-HS:], dim=1)
act_view = torch.cat(act_slots[-HS:], dim=1)
pred_rollout = self.predict(emb_view, act_view)[:, -1:]
info["predicted_emb"] = pred_rollout.reshape(
B, S, *pred_rollout.shape[1:]
)
return info
emb_hist = init_hist.new_empty((B * S, HS, init_hist.size(-1)))
act_emb_hist = action_emb.new_empty((B * S, HS, action_emb.size(-1)))
emb_hist[:, :hist_len].copy_(init_hist)
act_emb_hist[:, :hist_len].copy_(act_hist)
history_order = torch.stack(
[
(torch.arange(HS, device=action_emb.device) + offset) % HS
for offset in range(HS)
]
)
filled = hist_len
next_slot = hist_len % HS
for t in range(act_future.size(1)):
if filled < HS:
emb_view = emb_hist[:, :filled]
act_view = act_emb_hist[:, :filled]
elif next_slot == 0:
emb_view = emb_hist
act_view = act_emb_hist
else:
order = history_order[next_slot]
emb_view = emb_hist.index_select(1, order)
act_view = act_emb_hist.index_select(1, order)
pred_emb = self.predict(emb_view, act_view)[:, -1:]
next_act_emb = act_future[:, t : t + 1]
emb_hist[:, next_slot : next_slot + 1].copy_(pred_emb)
act_emb_hist[:, next_slot : next_slot + 1].copy_(next_act_emb)
if filled < HS:
filled += 1
next_slot = (next_slot + 1) % HS
if filled < HS:
emb_view = emb_hist[:, :filled]
act_view = act_emb_hist[:, :filled]
elif next_slot == 0:
emb_view = emb_hist
act_view = act_emb_hist
else:
order = history_order[next_slot]
emb_view = emb_hist.index_select(1, order)
act_view = act_emb_hist.index_select(1, order)
pred_rollout = self.predict(emb_view, act_view)[:, -1:]
info["predicted_emb"] = pred_rollout.reshape(B, S, *pred_rollout.shape[1:])
return info
def criterion(self, info_dict: dict):
"""Compute the cost between predicted embeddings and goal embeddings."""
pred_emb = info_dict["predicted_emb"] # (B,S, T-1, dim)
goal_emb = info_dict["goal_emb"] # (B, S, T, dim)
with torch.profiler.record_function("lewm.criterion"):
pred_emb = info_dict["predicted_emb"] # (B,S, T-1, dim)
goal_emb = info_dict["goal_emb"] # (B, S, T, dim)
if goal_emb.ndim == pred_emb.ndim - 1:
goal_emb = goal_emb.unsqueeze(1)
goal_emb = goal_emb[..., -1:, :].expand_as(pred_emb)
# return last-step cost per action candidate
cost = F.mse_loss(
pred_emb[..., -1:, :],
goal_emb[..., -1:, :].detach(),
reduction="none",
).sum(dim=tuple(range(2, pred_emb.ndim))) # (B, S)
# return last-step cost per action candidate
cost = F.mse_loss(
pred_emb[..., -1:, :],
goal_emb[..., -1:, :].detach(),
reduction="none",
).sum(dim=tuple(range(2, pred_emb.ndim))) # (B, S)
return cost
return cost
def get_cost(self, info_dict: dict, action_candidates: torch.Tensor):
""" Compute the cost of action candidates given an info dict with goal and initial state."""
with torch.profiler.record_function("lewm.get_cost"):
assert "goal" in info_dict, "goal not in info_dict"
assert "goal" in info_dict, "goal not in info_dict"
self._ensure_runtime_caches()
device = next(self.parameters()).device
info_dict = self._ensure_info_device(info_dict, device)
action_candidates = self._get_cached_device_tensor(
"_lewm_action_candidates",
action_candidates,
device,
ensure_contiguous=True,
)
device = next(self.parameters()).device
for k in list(info_dict.keys()):
if torch.is_tensor(info_dict[k]):
info_dict[k] = info_dict[k].to(device)
info_dict["goal_emb"] = self._get_cached_goal_emb(info_dict)
info_dict = self.rollout(info_dict, action_candidates)
goal = {k: v[:, 0] for k, v in info_dict.items() if torch.is_tensor(v)}
goal["pixels"] = goal["goal"]
for k in info_dict:
if k.startswith("goal_"):
goal[k[len("goal_") :]] = goal.pop(k)
goal.pop("action")
goal = self.encode(goal)
info_dict["goal_emb"] = goal["emb"]
info_dict = self.rollout(info_dict, action_candidates)
cost = self.criterion(info_dict)
return cost
cost = self.criterion(info_dict)
return cost

View File

@@ -236,9 +236,13 @@ class MLP(nn.Module):
def forward(self, x):
"""
x: (B*T, D)
x: (..., D)
"""
return self.net(x)
if x.ndim <= 2:
return self.net(x)
output = self.net(x.reshape(-1, x.size(-1)))
return output.reshape(*x.shape[:-1], output.size(-1))
class ARPredictor(nn.Module):

1624
pusht_results.txt Normal file

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,131 @@
#!/usr/bin/env python
"""Convert LeWM HuggingFace weights into eval-compatible object checkpoints."""
from __future__ import annotations
import argparse
import json
import sys
from pathlib import Path
import stable_pretraining as spt
import torch
REPO_ROOT = Path(__file__).resolve().parents[1]
if str(REPO_ROOT) not in sys.path:
sys.path.insert(0, str(REPO_ROOT))
from jepa import JEPA
from module import ARPredictor, Embedder, MLP
def _load_json(path: Path) -> dict:
with path.open("r", encoding="utf-8") as f:
return json.load(f)
def _strip_target(config: dict) -> dict:
return {key: value for key, value in config.items() if key != "_target_"}
def infer_config_from_state_dict(state_dict: dict) -> dict:
action_dim = state_dict["action_encoder.patch_embed.weight"].shape[1]
return {
"encoder": {
"size": "tiny",
"patch_size": 14,
"image_size": 224,
"pretrained": False,
"use_mask_token": False,
},
"predictor": {
"num_frames": 3,
"input_dim": 192,
"hidden_dim": 192,
"output_dim": 192,
"depth": 6,
"heads": 16,
"mlp_dim": 2048,
"dim_head": 64,
"dropout": 0.1,
"emb_dropout": 0.0,
},
"action_encoder": {
"input_dim": action_dim,
"emb_dim": 192,
},
"projector": {
"input_dim": 192,
"output_dim": 192,
"hidden_dim": 2048,
},
"pred_proj": {
"input_dim": 192,
"output_dim": 192,
"hidden_dim": 2048,
},
}
def build_model(config: dict) -> JEPA:
encoder = spt.backbone.utils.vit_hf(**_strip_target(config["encoder"]))
predictor = ARPredictor(**_strip_target(config["predictor"]))
action_encoder = Embedder(**_strip_target(config["action_encoder"]))
projector_cfg = _strip_target(config["projector"])
projector_cfg["norm_fn"] = torch.nn.BatchNorm1d
projector = MLP(**projector_cfg)
pred_proj_cfg = _strip_target(config["pred_proj"])
pred_proj_cfg["norm_fn"] = torch.nn.BatchNorm1d
pred_proj = MLP(**pred_proj_cfg)
return JEPA(
encoder=encoder,
predictor=predictor,
action_encoder=action_encoder,
projector=projector,
pred_proj=pred_proj,
)
def convert_checkpoint(input_dir: Path, output_name: str) -> tuple[Path, Path]:
config_path = input_dir / "config.json"
weights_path = input_dir / "weights.pt"
if not weights_path.exists():
raise FileNotFoundError(f"Missing weights file: {weights_path}")
state_dict = torch.load(weights_path, map_location="cpu")
config = _load_json(config_path) if config_path.exists() else infer_config_from_state_dict(state_dict)
model = build_model(config)
missing, unexpected = model.load_state_dict(state_dict, strict=True)
if missing or unexpected:
raise RuntimeError(
f"State dict mismatch: missing={missing}, unexpected={unexpected}"
)
model.eval()
object_path = input_dir / f"{output_name}_object.ckpt"
weight_path = input_dir / f"{output_name}_weight.ckpt"
torch.save(model, object_path)
torch.save(model.state_dict(), weight_path)
return object_path, weight_path
def main() -> None:
parser = argparse.ArgumentParser()
parser.add_argument(
"input_dir",
type=Path,
help="Directory containing weights.pt and optionally config.json.",
)
parser.add_argument("--output-name", default="lewm")
args = parser.parse_args()
object_path, weight_path = convert_checkpoint(args.input_dir, args.output_name)
print(f"wrote {object_path}")
print(f"wrote {weight_path}")
if __name__ == "__main__":
main()

255
scripts/launch_multinode_eval.sh Executable file
View File

@@ -0,0 +1,255 @@
#!/usr/bin/env bash
set -euo pipefail
# Launch 2-node LeWM evaluation from node-3.
#
# Defaults match the current cluster layout:
# node-3: 10.16.200.9, node_rank=0
# node-2: 10.16.200.8, node_rank=1
# Each node runs two local torchrun processes for two visible GPUs.
REPO_ROOT="${REPO_ROOT:-/home/lewm/lewm}"
REMOTE_HOST="${REMOTE_HOST:-lewm@10.16.200.8}"
MASTER_ADDR="${MASTER_ADDR:-10.16.200.9}"
MASTER_PORT="${MASTER_PORT:-29500}"
NNODES="${NNODES:-2}"
NPROC_PER_NODE="${NPROC_PER_NODE:-2}"
CUDA_VISIBLE_DEVICES="${CUDA_VISIBLE_DEVICES:-0,1}"
STABLEWM_HOME="${STABLEWM_HOME:-/home/lewm/.stable-wm}"
CONFIG_NAME="${CONFIG_NAME:-pusht.yaml}"
POLICY="${POLICY:-pusht/lewm}"
OUTPUT_FILENAME="${OUTPUT_FILENAME:-pusht_multinode_results.txt}"
EXTRA_ARGS="${EXTRA_ARGS:-}"
DRY_RUN="${DRY_RUN:-0}"
TAIL_LOGS="${TAIL_LOGS:-1}"
PRELOAD_WAIT="${PRELOAD_WAIT:-0}"
PRELOAD_SIGNAL_FILE="${PRELOAD_SIGNAL_FILE:-/tmp/lewm_preload_start}"
PRELOAD_CLEAR_SIGNAL="${PRELOAD_CLEAR_SIGNAL:-1}"
LOG_DIR="${LOG_DIR:-${REPO_ROOT}/logs/multinode}"
mkdir -p "${LOG_DIR}"
RUN_ID="$(date +%Y%m%d_%H%M%S)"
LOCAL_LOG="${LOG_DIR}/${RUN_ID}_node3_rank0.log"
REMOTE_LOG="${LOG_DIR}/${RUN_ID}_node2_rank1.log"
SSH_OPTS=(
-F /dev/null
-o StrictHostKeyChecking=no
-o ServerAliveInterval=30
-o ServerAliveCountMax=20
)
COMMON_ARGS=(
"--config-name=${CONFIG_NAME}"
"policy=${POLICY}"
"multi_node.enabled=true"
"output.filename=${OUTPUT_FILENAME}"
)
if [[ "${PRELOAD_WAIT}" == "1" ]]; then
COMMON_ARGS+=(
"preload_wait.enabled=true"
"preload_wait.file=${PRELOAD_SIGNAL_FILE}"
)
fi
if [[ -n "${EXTRA_ARGS}" ]]; then
# shellcheck disable=SC2206
COMMON_ARGS+=(${EXTRA_ARGS})
fi
make_command() {
local node_rank="$1"
local repo_q cuda_q stablewm_q arg_q eval_args
printf -v repo_q '%q' "${REPO_ROOT}"
printf -v cuda_q '%q' "${CUDA_VISIBLE_DEVICES}"
printf -v stablewm_q '%q' "${STABLEWM_HOME}"
eval_args=""
for arg in "${COMMON_ARGS[@]}"; do
printf -v arg_q '%q' "${arg}"
eval_args+=" ${arg_q}"
done
printf 'cd %s && source .venv/bin/activate && export CUDA_VISIBLE_DEVICES=%s && export STABLEWM_HOME=%s && torchrun --nnodes=%q --nproc_per_node=%q --node_rank=%q --master_addr=%q --master_port=%q eval.py%s' \
"${repo_q}" \
"${cuda_q}" \
"${stablewm_q}" \
"${NNODES}" \
"${NPROC_PER_NODE}" \
"${node_rank}" \
"${MASTER_ADDR}" \
"${MASTER_PORT}" \
"${eval_args}"
}
REMOTE_CMD="$(make_command 1)"
LOCAL_CMD="$(make_command 0)"
printf -v REMOTE_CMD_Q '%q' "${REMOTE_CMD}"
REMOTE_PID=""
LOCAL_PID=""
LOCAL_TAIL_PID=""
REMOTE_TAIL_PID=""
REMOTE_CLEANUP_CMD=""
REMOTE_CLEANUP_CMD_Q=""
start_log_tail() {
local label="$1"
local log_file="$2"
local label_q log_q
printf -v label_q '%q' "${label}"
printf -v log_q '%q' "${log_file}"
setsid bash -lc "tail -n +1 -F ${log_q} 2>/dev/null | sed -u 's/^/[${label_q}] /'" &
}
stop_log_tails() {
local pid
for pid in "${LOCAL_TAIL_PID}" "${REMOTE_TAIL_PID}"; do
if [[ -n "${pid}" ]] && kill -0 "${pid}" 2>/dev/null; then
kill -TERM "-${pid}" 2>/dev/null || kill -TERM "${pid}" 2>/dev/null || true
fi
done
}
remote_cleanup_command() {
local pattern_q
local patterns=(
"torchrun .*--master_addr=${MASTER_ADDR} .*--master_port=${MASTER_PORT} .*eval.py"
"torchrun .*--master_port=${MASTER_PORT} .*eval.py"
"python.*eval.py .*output.filename=${OUTPUT_FILENAME}"
)
printf 'set +e; '
for pattern in "${patterns[@]}"; do
printf -v pattern_q '%q' "${pattern}"
printf 'pkill -TERM -f %s 2>/dev/null; ' "${pattern_q}"
done
printf 'sleep 2; '
for pattern in "${patterns[@]}"; do
printf -v pattern_q '%q' "${pattern}"
printf 'pkill -KILL -f %s 2>/dev/null; ' "${pattern_q}"
done
printf 'true'
}
cleanup() {
local status="$?"
trap - INT TERM EXIT
if [[ "${status}" -eq 0 ]]; then
return 0
fi
echo
echo "Stopping multi-node eval..."
stop_log_tails
if [[ -n "${LOCAL_PID}" ]] && kill -0 "${LOCAL_PID}" 2>/dev/null; then
kill -TERM "-${LOCAL_PID}" 2>/dev/null || kill -TERM "${LOCAL_PID}" 2>/dev/null || true
fi
if [[ -n "${REMOTE_PID}" ]] && kill -0 "${REMOTE_PID}" 2>/dev/null; then
kill -TERM "${REMOTE_PID}" 2>/dev/null || true
fi
ssh "${SSH_OPTS[@]}" "${REMOTE_HOST}" "bash -lc ${REMOTE_CLEANUP_CMD_Q}" >/dev/null 2>&1 || true
if [[ -n "${LOCAL_PID}" ]] && kill -0 "${LOCAL_PID}" 2>/dev/null; then
sleep 2
kill -KILL "-${LOCAL_PID}" 2>/dev/null || kill -KILL "${LOCAL_PID}" 2>/dev/null || true
fi
echo "Cleanup requested. Check logs if any process was already exiting:"
echo " local: ${LOCAL_LOG}"
echo " remote: ${REMOTE_LOG}"
exit "${status}"
}
trap cleanup INT TERM EXIT
REMOTE_CLEANUP_CMD="$(remote_cleanup_command)"
printf -v REMOTE_CLEANUP_CMD_Q '%q' "${REMOTE_CLEANUP_CMD}"
echo "Launching multi-node eval"
echo " master: ${MASTER_ADDR}:${MASTER_PORT}"
echo " remote: ${REMOTE_HOST}"
echo " repo: ${REPO_ROOT}"
echo " stablewm: ${STABLEWM_HOME}"
echo " config: ${CONFIG_NAME}"
echo " policy: ${POLICY}"
echo " output: ${OUTPUT_FILENAME}"
echo " extra: ${EXTRA_ARGS:-<none>}"
echo " tail logs: ${TAIL_LOGS}"
echo " preload wait: ${PRELOAD_WAIT}"
if [[ "${PRELOAD_WAIT}" == "1" ]]; then
echo " preload signal: ${PRELOAD_SIGNAL_FILE}"
echo " start command: touch ${PRELOAD_SIGNAL_FILE}"
fi
echo " local log: ${LOCAL_LOG}"
echo " remote log: ${REMOTE_LOG}"
if [[ "${DRY_RUN}" == "1" ]]; then
echo
echo "Remote command:"
echo "ssh ${SSH_OPTS[*]} ${REMOTE_HOST} bash -lc ${REMOTE_CMD_Q}"
echo
echo "Local command:"
printf -v LOCAL_CMD_Q '%q' "${LOCAL_CMD}"
echo "bash -lc ${LOCAL_CMD_Q}"
exit 0
fi
if [[ "${PRELOAD_WAIT}" == "1" && "${PRELOAD_CLEAR_SIGNAL}" == "1" ]]; then
rm -f "${PRELOAD_SIGNAL_FILE}"
fi
echo "Starting remote node_rank=1..."
ssh "${SSH_OPTS[@]}" "${REMOTE_HOST}" "bash -lc ${REMOTE_CMD_Q}" >"${REMOTE_LOG}" 2>&1 &
REMOTE_PID="$!"
if [[ "${TAIL_LOGS}" == "1" ]]; then
start_log_tail "node2" "${REMOTE_LOG}"
REMOTE_TAIL_PID="$!"
fi
sleep 3
echo "Starting local node_rank=0..."
set +e
setsid bash -lc "${LOCAL_CMD}" >"${LOCAL_LOG}" 2>&1 &
LOCAL_PID="$!"
if [[ "${TAIL_LOGS}" == "1" ]]; then
start_log_tail "node3" "${LOCAL_LOG}"
LOCAL_TAIL_PID="$!"
fi
wait "${LOCAL_PID}"
LOCAL_STATUS="$?"
wait "${REMOTE_PID}"
REMOTE_STATUS="$?"
set -e
stop_log_tails
trap - INT TERM EXIT
echo "Local status: ${LOCAL_STATUS}"
echo "Remote status: ${REMOTE_STATUS}"
echo "Local log: ${LOCAL_LOG}"
echo "Remote log: ${REMOTE_LOG}"
if [[ "${LOCAL_STATUS}" -ne 0 || "${REMOTE_STATUS}" -ne 0 ]]; then
echo "Multi-node eval failed. Tail logs:"
echo "===== local tail ====="
tail -80 "${LOCAL_LOG}" || true
echo "===== remote tail ====="
tail -80 "${REMOTE_LOG}" || true
exit 1
fi
echo "Multi-node eval complete."

111
scripts/warmup_eval.sh Executable file
View File

@@ -0,0 +1,111 @@
#!/usr/bin/env bash
set -euo pipefail
# Warm up LeWM evaluation before a formal run.
#
# This script intentionally does a small eval for each task so ROCm/PyTorch can
# initialize GPU contexts, compile predictor graphs, populate kernel caches, and
# touch dataset/checkpoint paths before the timed run.
#
# Site-specific things to check before using this at the competition:
# 1. STABLEWM_HOME points to the directory containing datasets/checkpoints.
# 2. The policy names below match the checkpoint folders at STABLEWM_HOME.
# 3. The dataset names in config/eval/*.yaml match the onsite dataset files.
# 4. The GPU visibility variables match the GPUs allocated to this job.
# 5. WARMUP_NUM_EVAL is close enough to the formal shape to trigger useful
# compilation, but small enough not to waste much time.
REPO_ROOT="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)"
cd "${REPO_ROOT}"
PYTHON_BIN="${PYTHON_BIN:-${REPO_ROOT}/.venv/bin/python}"
STABLEWM_HOME="${STABLEWM_HOME:-/mnt/ASC1637/stablewm}"
export STABLEWM_HOME
# If Slurm allocates multiple GPUs, set these to the allocated physical GPU ids.
# Example for physical GPU 2 and 3:
# ROCR_VISIBLE_DEVICES=2,3 HIP_VISIBLE_DEVICES=0,1 CUDA_VISIBLE_DEVICES=0,1
#
# Important ROCm detail:
# ROCR_VISIBLE_DEVICES uses physical ids.
# HIP_VISIBLE_DEVICES/CUDA_VISIBLE_DEVICES use ids after ROCR remapping.
export ROCR_VISIBLE_DEVICES="${ROCR_VISIBLE_DEVICES:-0}"
export HIP_VISIBLE_DEVICES="${HIP_VISIBLE_DEVICES:-0}"
export CUDA_VISIBLE_DEVICES="${CUDA_VISIBLE_DEVICES:-0}"
WARMUP_NUM_EVAL="${WARMUP_NUM_EVAL:-10}"
INFERENCE_PRECISION="${INFERENCE_PRECISION:-fp16}"
OUTPUT_DIR="${OUTPUT_DIR:-/tmp/lewm_warmup}"
mkdir -p "${OUTPUT_DIR}"
# Enable multi-GPU warmup by setting MULTI_GPU=1.
# MULTI_GPU_DEVICES are process-local ids, not physical ids after ROCR remapping.
# Example:
# ROCR_VISIBLE_DEVICES=2,3 HIP_VISIBLE_DEVICES=0,1 MULTI_GPU=1 MULTI_GPU_DEVICES='[0,1]'
MULTI_GPU="${MULTI_GPU:-0}"
MULTI_GPU_DEVICES="${MULTI_GPU_DEVICES:-[0,1]}"
MULTI_NODE="${MULTI_NODE:-0}"
# Multi-node warmup uses the same eval.py entrypoint under torchrun.
# Example:
# torchrun --nnodes=2 --nproc_per_node=2 --node_rank=0 --master_addr=<ip> --master_port=29500 \
# eval.py --config-name=pusht.yaml policy=pusht/lewm multi_node.enabled=true
# This script leaves multi-node launch to the caller.
COMMON_ARGS=(
"eval.num_eval=${WARMUP_NUM_EVAL}"
"inference_precision=${INFERENCE_PRECISION}"
)
if [[ "${MULTI_GPU}" == "1" ]]; then
COMMON_ARGS+=(
"+multi_gpu.enabled=true"
"+multi_gpu.devices=${MULTI_GPU_DEVICES}"
)
fi
if [[ "${MULTI_NODE}" == "1" ]]; then
COMMON_ARGS+=(
"multi_node.enabled=true"
)
fi
run_warmup() {
local config_name="$1"
local policy="$2"
local output_name="$3"
echo
echo "== Warmup ${config_name} policy=${policy} =="
"${PYTHON_BIN}" eval.py \
"--config-name=${config_name}" \
"policy=${policy}" \
"output.filename=${OUTPUT_DIR}/${output_name}" \
"${COMMON_ARGS[@]}"
}
echo "LeWM warmup"
echo " repo: ${REPO_ROOT}"
echo " python: ${PYTHON_BIN}"
echo " STABLEWM_HOME: ${STABLEWM_HOME}"
echo " ROCR_VISIBLE_DEVICES: ${ROCR_VISIBLE_DEVICES}"
echo " HIP_VISIBLE_DEVICES: ${HIP_VISIBLE_DEVICES}"
echo " CUDA_VISIBLE_DEVICES: ${CUDA_VISIBLE_DEVICES}"
echo " WARMUP_NUM_EVAL: ${WARMUP_NUM_EVAL}"
echo " INFERENCE_PRECISION: ${INFERENCE_PRECISION}"
echo " MULTI_GPU: ${MULTI_GPU}"
if [[ "${MULTI_GPU}" == "1" ]]; then
echo " MULTI_GPU_DEVICES: ${MULTI_GPU_DEVICES}"
fi
echo " MULTI_NODE: ${MULTI_NODE}"
# Defaults match the checkpoint names used in this repo. If onsite checkpoint
# folders differ, override by editing these calls or passing the equivalent
# eval.py command manually.
run_warmup "pusht.yaml" "pusht/lewm" "warmup_pusht.txt"
run_warmup "reacher.yaml" "reacher/lewm" "warmup_reacher.txt"
run_warmup "cube.yaml" "cube/lewm" "warmup_cube.txt"
run_warmup "tworoom.yaml" "tworoom/lewm" "warmup_tworoom.txt"
echo
echo "Warmup complete. Logs were appended under ${OUTPUT_DIR}."

52
sth.md Normal file
View File

@@ -0,0 +1,52 @@
我建议优先做这 4 类,都是跨数据集成立的:
1. 压 rollout 内环实现
见 jepa.py:127。现在每步都在做 action_encoder、切片、torch.cat、小规
模 predict 调用,这种碎片化实现对任何任务都亏。
通用改法:
- 整条 action_sequence 一次性做 action_encoder
- emb_hist / act_emb_hist 改成预分配 buffer
- 循环里只做索引覆盖或 copy_
- 去掉循环内 torch.cat
2. 减少热路径里的搬运和同步
profile 里 aten::copy_ 很重,这不是 TwoRoom 特有问题。重点看
jepa.py:67 和 jepa.py:186。
通用目标:
- 模型侧张量尽量全程留在 GPU
- 避免热路径反复 .to(device) / 隐式 layout 修复
- 到必须和环境交互的边界再一次性转 CPU / numpy
- 确保进入 predictor 的张量是 contiguous 的,少触发隐式 copy
3. 把编译成本移出正式计时
现在 torch.compile 默认开在 predictor见 eval.py:70。102s -> 45s 很
像首轮编译预热。
通用做法:
- 在正式 start_time 前做一次 dummy predict 或 dummy rollout
- 保留只编译 predictor/predict不要编译整个 solver
4. 减少临时对象和 shape bookkeeping
这是所有任务都会受益的。
重点看:
- jepa.py:100 到 jepa.py:106
- jepa.py:143 到 jepa.py:148
方向是:
- 能循环外做的 reshape不放循环里
- 能原地更新,不新建张量
- 少做 dict 字段增删和中间容器组装
不建议优先做的通用性较差方案:
- 调 TwoRoom 专属 cache 规则
- 改数据集采样逻辑
- 按小数据集特点缩短 horizon
- 直接改 CEM 超参当“优化”
如果你要我直接开始改,我建议第一批只做两件事:
- 重写 jepa.py:127 这段 rollout去掉循环内 action_encoder + cat
- 在 eval.py:306 前加 compile warmup