# File listing header (exported from a repository browser):
#   path:      lewm/.venv/lib/python3.10/site-packages/stable_worldmodel/solver/icem.py
#   timestamp: 2026-05-04 08:05:47 +00:00
#   size:      220 lines, 8.9 KiB
#   language:  Python
"""Improved Cross Entropy Method (iCEM) solver for model-based planning."""
import time
from typing import Any
import gymnasium as gym
import numpy as np
import torch
from gymnasium.spaces import Box
from loguru import logger as logging
from .solver import Costable
class ICEMSolver:
    """Improved Cross Entropy Method (iCEM) solver with colored noise and elite retention.

    iCEM improves the sample efficiency over standard CEM and was introduced by
    [1] for real-time planning.

    Note:
        Despite their names, the ``var`` tensors handled by this solver (``var_scale``,
        the ``"var"`` output and the second value returned by ``init_action_distrib``)
        are used as a per-dimension *standard deviation*. Candidates are drawn as
        ``mean + noise * var``, and ``var`` is refit from the elites' ``std``.

    Args:
        model: World model implementing the Costable protocol.
        batch_size: Number of environments to process in parallel.
        num_samples: Number of action candidates to sample per iteration.
        var_scale: Initial scale (used as a standard deviation) of the action distribution.
        n_steps: Number of CEM iterations (must be >= 1).
        topk: Number of elite samples to keep for distribution update.
        noise_beta: Colored noise exponent. 0 = white (standard CEM), >0 = more low-frequency noise.
        alpha: Momentum for mean/std EMA update (weight given to the previous value).
        n_elite_keep: Number of elites carried from previous iteration.
        return_mean: If False, return best single trajectory instead of mean.
        device: Device for tensor computations.
        seed: Random seed for reproducibility.

    [1] C. Pinneri, S. Sawant, S. Blaes, J. Achterhold, J. Stueckler, M. Rolinek and
        G. Martius. "Sample-efficient Cross-Entropy Method for Real-time Planning".
        Conference on Robot Learning, 2020.
    """

    def __init__(
        self,
        model: Costable,
        batch_size: int = 1,
        num_samples: int = 300,
        var_scale: float = 1,
        n_steps: int = 30,
        topk: int = 30,
        noise_beta: float = 2.0,
        alpha: float = 0.1,
        n_elite_keep: int = 5,
        return_mean: bool = True,
        device: str | torch.device = "cpu",
        seed: int = 1234,
    ) -> None:
        self.model = model
        self.batch_size = batch_size
        self.var_scale = var_scale
        self.num_samples = num_samples
        self.n_steps = n_steps
        self.topk = topk
        self.noise_beta = noise_beta
        self.alpha = alpha
        self.n_elite_keep = n_elite_keep
        self.return_mean = return_mean
        self.device = device
        # Dedicated generator: sampling is reproducible from `seed` regardless of the global RNG.
        self.torch_gen = torch.Generator(device=device).manual_seed(seed)

    def configure(self, *, action_space: gym.Space, n_envs: int, config: Any) -> None:
        """Configure the solver with environment specifications.

        Args:
            action_space: Batched action space whose first axis is the environment axis.
                ``shape[1:]`` is the per-environment action shape. For ``Box`` spaces,
                ``low[0]`` / ``high[0]`` are used as the action bounds.
            n_envs: Number of parallel environments to plan for.
            config: Planning config; ``horizon`` and ``action_block`` are read from it.
        """
        self._action_space = action_space
        self._n_envs = n_envs
        self._config = config
        self._action_dim = int(np.prod(action_space.shape[1:]))
        self._configured = True

        if isinstance(action_space, Box):
            self._action_low = torch.tensor(action_space.low[0], device=self.device, dtype=torch.float32)
            self._action_high = torch.tensor(action_space.high[0], device=self.device, dtype=torch.float32)
        else:
            logging.warning(f"Action space is discrete, got {type(action_space)}. ICEMSolver may not work as expected.")
            self._action_low = None
            self._action_high = None

    @property
    def n_envs(self) -> int:
        """Number of parallel environments."""
        return self._n_envs

    @property
    def action_dim(self) -> int:
        """Flattened action dimension including action_block grouping."""
        return self._action_dim * self._config.action_block

    @property
    def horizon(self) -> int:
        """Planning horizon in timesteps."""
        return self._config.horizon

    def __call__(self, *args: Any, **kwargs: Any) -> dict:
        """Make solver callable, forwarding to solve()."""
        return self.solve(*args, **kwargs)

    def init_action_distrib(
        self, actions: torch.Tensor | None = None
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Initialize the action distribution parameters (mean and var).

        Args:
            actions: Optional warm-start mean, assumed to be shaped
                ``(n_envs, T, action_dim)``. It is truncated to ``horizon`` steps if
                longer and zero-padded at the end if shorter.

        Returns:
            ``(mean, var)``, both ``(n_envs, horizon, action_dim)``. ``var`` is filled
            with ``var_scale`` and created on the CPU. ``mean`` lives on the device of
            ``actions`` (the CPU when ``actions`` is None).
        """
        var = self.var_scale * torch.ones([self.n_envs, self.horizon, self.action_dim])
        if actions is None:
            mean = torch.zeros([self.n_envs, 0, self.action_dim])
        else:
            # A warm start longer than the horizon would make mean/var shapes disagree in solve().
            mean = actions[:, : self.horizon]

        remaining = self.horizon - mean.shape[1]
        if remaining > 0:
            # Pad on the warm start's own device: torch.cat rejects tensors on different devices.
            padding = torch.zeros([self.n_envs, remaining, self.action_dim], device=mean.device)
            mean = torch.cat([mean, padding], dim=1)
        return mean, var

    def _colored_noise_scale(self) -> torch.Tensor | None:
        """Return the rFFT amplitude filter for colored noise, or None when horizon <= 1.

        Each frequency bin f > 0 is scaled by ``f ** (-noise_beta / 2)`` (i.e. a power
        spectrum proportional to ``1 / f ** noise_beta``). The DC bin reuses the scale of
        the first non-zero bin. A single timestep has nothing to color, and
        ``rfftfreq(1)`` has no bin 1 to copy from.
        """
        if self.horizon <= 1:
            return None
        freqs = torch.fft.rfftfreq(self.horizon, device=self.device)
        freqs[0] = 1.0  # avoid 0 ** negative before the DC bin is overwritten below
        noise_scale = freqs.pow(-self.noise_beta / 2)
        noise_scale[0] = noise_scale[1]
        return noise_scale

    def _sample_noise(self, noise_shape: tuple[int, ...], noise_scale: torch.Tensor | None) -> torch.Tensor:
        """Draw one batch of sampling noise, returned as ``(bs, num_samples, horizon, action_dim)``.

        ``noise_shape`` is ``(bs, num_samples, action_dim, horizon)``: time is the last axis
        so the rFFT filter applies along it; the result is transposed back at the end.
        """
        white = torch.randn(noise_shape, generator=self.torch_gen, device=self.device)
        if noise_scale is None:
            noise = white
        else:
            spectrum = torch.fft.rfft(white, dim=-1)
            colored = torch.fft.irfft(spectrum * noise_scale, n=self.horizon, dim=-1)
            # Renormalize every sequence to unit std over time so var keeps its meaning.
            std = colored.std(dim=-1, keepdim=True).clamp(min=1e-8)
            noise = colored / std
        return noise.transpose(-1, -2)

    def _clamp_bounds(self) -> tuple[torch.Tensor | None, torch.Tensor | None]:
        """Return action bounds laid out like a candidate's last axis, or (None, None).

        A candidate step has ``action_dim = _action_dim * action_block`` entries, so the
        per-action bounds are flattened and tiled ``action_block`` times. This assumes a
        step concatenates ``action_block`` consecutive actions (TODO confirm with the model).
        With ``action_block == 1`` the bounds are used unchanged.
        """
        if self._action_low is None:
            return None, None
        block = self._config.action_block
        return self._action_low.flatten().repeat(block), self._action_high.flatten().repeat(block)

    def _expand_infos(self, info_dict: dict, start_idx: int, end_idx: int) -> dict:
        """Slice every info value to envs ``[start_idx, end_idx)`` and repeat arrays per sample.

        Tensors get a broadcast (non-copying) sample axis at dim 1. Numpy arrays get a copied
        one. Any other value is only sliced.
        """
        current_bs = end_idx - start_idx
        expanded_infos = {}
        for k, v in info_dict.items():
            v_batch = v[start_idx:end_idx]
            if torch.is_tensor(v):
                v_batch = v_batch.unsqueeze(1)
                v_batch = v_batch.expand(current_bs, self.num_samples, *v_batch.shape[2:])
            elif isinstance(v, np.ndarray):
                v_batch = np.repeat(v_batch[:, None, ...], self.num_samples, axis=1)
            expanded_infos[k] = v_batch
        return expanded_infos

    @torch.inference_mode()
    def solve(
        self, info_dict: dict, init_action: torch.Tensor | None = None
    ) -> dict:
        """Solve the planning problem using improved Cross Entropy Method.

        Environments are processed in chunks of ``batch_size``. Each chunk runs ``n_steps``
        iterations of: sample candidates around the current mean, score them with
        ``model.get_cost``, and refit mean/std to the ``topk`` lowest-cost candidates
        with momentum ``alpha``.

        Args:
            info_dict: Per-environment inputs for ``model.get_cost`` (see ``_expand_infos``).
            init_action: Optional warm-start mean (see ``init_action_distrib``).

        Returns:
            Dict with ``"actions"`` (final plan on CPU), ``"costs"`` (per environment, the
            mean cost of the final iteration's elites) and single-element lists ``"mean"`` /
            ``"var"`` holding the final distribution parameters on CPU.

        Raises:
            ValueError: If ``n_steps < 1`` (no iteration would produce costs or elites).
        """
        if self.n_steps < 1:
            raise ValueError(f"n_steps must be >= 1, got {self.n_steps}")

        start_time = time.time()
        outputs = {
            "costs": [],
            "mean": [],
            "var": [],
        }

        mean, var = self.init_action_distrib(init_action)
        mean = mean.to(self.device)
        var = var.to(self.device)

        # Loop-invariant across env chunks and CEM iterations.
        noise_scale = self._colored_noise_scale()
        action_low, action_high = self._clamp_bounds()

        for start_idx in range(0, self.n_envs, self.batch_size):
            end_idx = min(start_idx + self.batch_size, self.n_envs)
            current_bs = end_idx - start_idx
            batch_mean = mean[start_idx:end_idx]
            batch_var = var[start_idx:end_idx]
            expanded_infos = self._expand_infos(info_dict, start_idx, end_idx)

            prev_topk_candidates = None
            batch_indices = torch.arange(current_bs, device=self.device).unsqueeze(1).expand(-1, self.topk)
            noise_shape = (current_bs, self.num_samples, self.action_dim, self.horizon)

            for _ in range(self.n_steps):
                noise = self._sample_noise(noise_shape, noise_scale)
                candidates = noise * batch_var.unsqueeze(1) + batch_mean.unsqueeze(1)
                # Always evaluate the current mean itself.
                candidates[:, 0] = batch_mean

                # Elite retention: re-evaluate the best candidates of the previous iteration
                # in the slots after the mean (bounded so the slice always fits).
                if prev_topk_candidates is not None:
                    n_inject = min(self.n_elite_keep, prev_topk_candidates.shape[1], self.num_samples - 1)
                    candidates[:, 1 : 1 + n_inject] = prev_topk_candidates[:, :n_inject]

                if action_low is not None:
                    candidates = candidates.clamp(action_low, action_high)

                # Shallow copy so the model may add/replace keys without affecting later iterations.
                current_info = expanded_infos.copy()
                costs = self.model.get_cost(current_info, candidates)

                assert isinstance(costs, torch.Tensor), f"Expected cost to be a torch.Tensor, got {type(costs)}"
                assert costs.ndim == 2 and costs.shape[0] == current_bs and costs.shape[1] == self.num_samples, (
                    f"Expected cost to be of shape ({current_bs}, {self.num_samples}), got {costs.shape}"
                )

                # Lowest cost first; topk returns the selection sorted ascending.
                topk_vals, topk_inds = torch.topk(costs, k=self.topk, dim=1, largest=False)
                topk_candidates = candidates[batch_indices, topk_inds]
                prev_topk_candidates = topk_candidates

                # Momentum update of the sampling distribution.
                elite_mean = topk_candidates.mean(dim=1)
                if self.topk > 1:
                    elite_var = topk_candidates.std(dim=1)
                else:
                    # The unbiased std of a single elite is NaN; treat it as zero spread instead.
                    elite_var = torch.zeros_like(elite_mean)
                batch_mean = self.alpha * batch_mean + (1 - self.alpha) * elite_mean
                batch_var = self.alpha * batch_var + (1 - self.alpha) * elite_var

            final_batch_cost = topk_vals.mean(dim=1).cpu().tolist()
            if self.return_mean:
                mean[start_idx:end_idx] = batch_mean
            else:
                # Best (lowest-cost) candidate of the last iteration.
                mean[start_idx:end_idx] = topk_candidates[:, 0]
            var[start_idx:end_idx] = batch_var
            outputs["costs"].extend(final_batch_cost)

        outputs["actions"] = mean.detach().cpu()
        outputs["mean"] = [mean.detach().cpu()]
        outputs["var"] = [var.detach().cpu()]
        logging.info(f"iCEM solve time: {time.time() - start_time:.4f} seconds")
        return outputs