diff --git a/.gitignore b/.gitignore index 0c5b18d..682f75a 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,18 @@ +.venv/* +!.venv/.gitignore +!.venv/lib/ +.venv/lib/* +!.venv/lib/python3.10/ +.venv/lib/python3.10/* +!.venv/lib/python3.10/site-packages/ +.venv/lib/python3.10/site-packages/* +!.venv/lib/python3.10/site-packages/stable_worldmodel/ +.venv/lib/python3.10/site-packages/stable_worldmodel/* +!.venv/lib/python3.10/site-packages/stable_worldmodel/solver/ +!.venv/lib/python3.10/site-packages/stable_worldmodel/policy.py +!.venv/lib/python3.10/site-packages/stable_worldmodel/world.py +!.venv/lib/python3.10/site-packages/stable_worldmodel/solver/cem.py +!.venv/lib/python3.10/site-packages/stable_worldmodel/solver/gd.py outputs/ __pycache__/ *.py[cod] diff --git a/.venv/lib/python3.10/site-packages/stable_worldmodel/solver/__init__.py b/.venv/lib/python3.10/site-packages/stable_worldmodel/solver/__init__.py new file mode 100644 index 0000000..bcfb0cc --- /dev/null +++ b/.venv/lib/python3.10/site-packages/stable_worldmodel/solver/__init__.py @@ -0,0 +1,17 @@ +from .cem import CEMSolver +from .gd import GradientSolver +from .icem import ICEMSolver +from .lagrangian import LagrangianSolver +from .mppi import MPPISolver +from .solver import Solver +from .discrete_solvers import PGDSolver + +__all__ = [ + 'Solver', + 'GradientSolver', + 'CEMSolver', + 'ICEMSolver', + 'PGDSolver', + 'MPPISolver', + 'LagrangianSolver', +] diff --git a/.venv/lib/python3.10/site-packages/stable_worldmodel/solver/discrete_solvers.py b/.venv/lib/python3.10/site-packages/stable_worldmodel/solver/discrete_solvers.py new file mode 100644 index 0000000..5dc97de --- /dev/null +++ b/.venv/lib/python3.10/site-packages/stable_worldmodel/solver/discrete_solvers.py @@ -0,0 +1,256 @@ +"""Projected Gradient Descent solver for discrete action spaces.""" + +import time +from typing import Any + +import gymnasium as gym +import numpy as np +import torch +from gymnasium.spaces import Discrete + +from .solver import Costable + + +class PGDSolver(torch.nn.Module): + """Projected Gradient Descent solver for discrete action optimization. + + Args: + model: World model implementing the Costable protocol. + n_steps: Number of gradient descent iterations. + batch_size: Number of environments to process in parallel. + var_scale: Initial variance scale for action perturbations. + num_samples: Number of action samples to optimize in parallel. + action_noise: Noise added to actions during optimization. + device: Device for tensor computations. + seed: Random seed for reproducibility. + """ + + def __init__( + self, + model: Costable, + n_steps: int, + batch_size: int | None = None, + var_scale: float = 1, + num_samples: int = 1, + action_noise: float = 0.0, + device: str | torch.device = "cpu", + seed: int = 1234, + ) -> None: + super().__init__() + self.model = model + self.n_steps = n_steps + self.batch_size = batch_size + self.num_samples = num_samples + self.var_scale = var_scale + self.action_noise = action_noise + self.device = device + self.torch_gen = torch.Generator(device=device).manual_seed(seed) + + self._configured = False + self._n_envs = None + self._action_dim = None + self._action_simplex_dim = None + self._config = None + + def configure(self, *, action_space: gym.Space, n_envs: int, config: Any) -> None: + """Configure the solver with environment specifications.""" + assert isinstance(action_space, Discrete), f"Action space must be discrete, got {type(action_space)}" + + self._action_space = action_space + self._n_envs = n_envs + self._config = config + self._action_dim = int(np.prod(action_space.shape[1:])) + self._action_simplex_dim = int(action_space.n) + self._configured = True + + @property + def n_envs(self) -> int: + """Number of parallel environments.""" + return self._n_envs + + @property + def action_dim(self) -> int: + """Flattened action dimension including action_block grouping.""" + return self._action_dim * self._config.action_block + + @property + def action_simplex_dim(self) -> int: + """Simplex dimension for discrete action probabilities.""" + return self._action_simplex_dim * self._config.action_block + + @property + def horizon(self) -> int: + """Planning horizon in timesteps.""" + return self._config.horizon + + def __call__(self, *args: Any, **kwargs: Any) -> dict: + """Make solver callable, forwarding to solve().""" + return self.solve(*args, **kwargs) + + def init_action( + self, actions: torch.Tensor | None = None, from_scalar: bool = False + ) -> None: + """Initialize the action tensor for optimization.""" + if actions is None: + actions = torch.zeros((self._n_envs, 0, self.action_simplex_dim)) + elif from_scalar: + # convert scalar to one-hot + actions = torch.nn.functional.one_hot(actions, num_classes=self._action_simplex_dim).to(torch.float32) + # merge action_block dim + actions = actions.reshape(*actions.shape[:-2], self.action_simplex_dim) + assert ( + actions.shape[0] == self._n_envs + and actions.shape[1] <= self.horizon + and actions.shape[2] == self.action_simplex_dim + ) + + # fill remaining action + remaining = self.horizon - actions.shape[1] + + if remaining > 0: + new_actions = torch.zeros(self._n_envs, remaining, self.action_simplex_dim) + actions = torch.cat([actions, new_actions], dim=1).to(self.device) + + actions = actions.unsqueeze(1).repeat_interleave(self.num_samples, dim=1) # add sample dim + actions[:, 1:] += ( + torch.randn(actions[:, 1:].shape, generator=self.torch_gen, device=self.device) * self.var_scale + ) # add small noise to all samples except the first one + + # reset actions + if hasattr(self, "init"): + self.init.copy_(actions) + else: + self.register_parameter("init", torch.nn.Parameter(actions)) + + def solve( + self, + info_dict: dict, + init_action: torch.Tensor | None = None, + from_scalar: bool = False, + ) -> dict: + """Solve the planning problem using projected gradient descent.""" + start_time = time.time() + outputs = { + "cost": [], # Will store list of cost histories per batch + "actions": None, + } + + with torch.no_grad(): + self.init_action(init_action, from_scalar=from_scalar) + + # Determine batch size (default to all envs if not specified which can cause memory issues) + batch_size = self.batch_size if self.batch_size is not None else self.n_envs + total_envs = self.n_envs + + # Lists to hold results from each batch to be concatenated later + batch_top_actions_list = [] + + # --- Outer Loop: Iterate over batches --- + for start_idx in range(0, total_envs, batch_size): + end_idx = min(start_idx + batch_size, total_envs) + current_bs = end_idx - start_idx + + batch_init = self.init[start_idx:end_idx].clone().detach() + batch_init.requires_grad = True + + optim = torch.optim.SGD([batch_init], lr=1.0) + + # Prepare Batch Infos + # Slice the input info_dict and then expand dimensions + expanded_infos = {} + for k, v in info_dict.items(): + # Slice the data for the current batch indices + # Assumes input data dim 0 corresponds to n_envs + if torch.is_tensor(v): + batch_v = v[start_idx:end_idx] + batch_v = batch_v.unsqueeze(1) + batch_v = batch_v.expand(current_bs, self.num_samples, *batch_v.shape[2:]) + elif isinstance(v, np.ndarray): + batch_v = v[start_idx:end_idx] + batch_v = np.repeat(batch_v[:, None, ...], self.num_samples, axis=1) + expanded_infos[k] = batch_v + + # Perform Gradient Descent for this batch + batch_cost_history = [] + + for step in range(self.n_steps): + current_info = expanded_infos.copy() + + # Calculate cost using the batch parameter + costs = self.model.get_cost(current_info, batch_init) + + assert isinstance(costs, torch.Tensor), f"Got {type(costs)} cost, expect torch.Tensor" + assert costs.ndim == 2 and costs.shape[0] == current_bs and costs.shape[1] == self.num_samples, ( + f"Cost should be of shape ({current_bs}, {self.num_samples}), got {costs.shape}" + ) + assert costs.requires_grad, "Cost must requires_grad for PGD solver." + + cost = costs.sum() # Sum cost for this batch + cost.backward() + optim.step() + optim.zero_grad(set_to_none=True) + + # Add noise + if self.action_noise > 0: + batch_init.data += torch.randn(batch_init.shape, generator=self.torch_gen) * self.action_noise + + # projection onto simplex + with torch.no_grad(): + batch_init.copy_(self._project_action_simplex(batch_init)) + + batch_cost_history.append(cost.item()) + + # Store cost history for this batch + outputs["cost"].append(batch_cost_history) + + # Update the global self.init with the optimized batch values + with torch.no_grad(): + self.init[start_idx:end_idx] = batch_init + + top_idx = torch.argsort(costs, dim=1)[:, 0] + batch_indices = torch.arange(current_bs) + + top_actions_batch = batch_init[batch_indices, top_idx] + + # convert one-hot back to discrete actions + top_actions_batch = self._factor_action_block(top_actions_batch).argmax(dim=-1) + batch_top_actions_list.append(top_actions_batch.detach().cpu()) + + # Concatenate all batch results + outputs["actions"] = torch.cat(batch_top_actions_list, dim=0) + end_time = time.time() + print(f"PGDSolver.solve completed in {end_time - start_time:.4f} seconds.") + + return outputs + + def _factor_action_block(self, actions: torch.Tensor) -> torch.Tensor: + """Factor the action block dimension from action_simplex_dim.""" + original_shape = actions.shape + action_block = self._config.action_block + simplex_dim = self._action_simplex_dim + return actions.reshape(*original_shape[:-1], action_block, simplex_dim) + + def _project_action_simplex(self, actions: torch.Tensor) -> torch.Tensor: + """Project the action onto the probability simplex.""" + original_shape = actions.shape + + s = self._factor_action_block(actions).reshape(-1, self._action_simplex_dim) + + mu, _ = torch.sort(s, descending=True, dim=-1) + cumulative = mu.cumsum(dim=-1) + + d = s.size(-1) + indices = torch.arange(1, d + 1, device=s.device, dtype=s.dtype) + + threshold = (cumulative - 1) / indices + + cond = (mu > threshold).to(torch.int32) + rho = cond.cumsum(dim=-1) + valid_rho = rho * cond + rho_max = valid_rho.max(dim=-1, keepdim=True)[0] + + rho_min = torch.clamp(rho_max, min=1) + psi = (cumulative.gather(-1, rho_min - 1) - 1) / rho_min + + projected = torch.clamp(s - psi, min=0.0).reshape(original_shape) + return projected diff --git a/.venv/lib/python3.10/site-packages/stable_worldmodel/solver/icem.py b/.venv/lib/python3.10/site-packages/stable_worldmodel/solver/icem.py new file mode 100644 index 0000000..19f3877 --- /dev/null +++ b/.venv/lib/python3.10/site-packages/stable_worldmodel/solver/icem.py @@ -0,0 +1,219 @@ +"""Improved Cross Entropy Method (iCEM) solver for model-based planning.""" + +import time +from typing import Any + +import gymnasium as gym +import numpy as np +import torch +from gymnasium.spaces import Box +from loguru import logger as logging + +from .solver import Costable + + +class ICEMSolver: + """Improved Cross Entropy Method (iCEM) solver with colored noise and elite retention. + iCEM improves the sample efficiency over standard CEM and was introduced by + [1] for real-time planning. + + Args: + model: World model implementing the Costable protocol. + batch_size: Number of environments to process in parallel. + num_samples: Number of action candidates to sample per iteration. + var_scale: Initial variance scale for the action distribution. + n_steps: Number of CEM iterations. + topk: Number of elite samples to keep for distribution update. + noise_beta: Colored noise exponent. 0 = white (standard CEM), >0 = more low-frequency noise. + alpha: Momentum for mean/std EMA update. + n_elite_keep: Number of elites carried from previous iteration. + return_mean: If False, return best single trajectory instead of mean. + device: Device for tensor computations. + seed: Random seed for reproducibility. + + [1] C. Pinneri, S. Sawant, S. Blaes, J. Achterhold, J. Stueckler, M. Rolinek and + G, Martius, Georg. "Sample-efficient Cross-Entropy Method for Real-time Planning". + Conference on Robot Learning, 2020. + """ + + def __init__( + self, + model: Costable, + batch_size: int = 1, + num_samples: int = 300, + var_scale: float = 1, + n_steps: int = 30, + topk: int = 30, + noise_beta: float = 2.0, + alpha: float = 0.1, + n_elite_keep: int = 5, + return_mean: bool = True, + device: str | torch.device = "cpu", + seed: int = 1234, + ) -> None: + self.model = model + self.batch_size = batch_size + self.var_scale = var_scale + self.num_samples = num_samples + self.n_steps = n_steps + self.topk = topk + self.noise_beta = noise_beta + self.alpha = alpha + self.n_elite_keep = n_elite_keep + self.return_mean = return_mean + self.device = device + self.torch_gen = torch.Generator(device=device).manual_seed(seed) + + def configure(self, *, action_space: gym.Space, n_envs: int, config: Any) -> None: + """Configure the solver with environment specifications.""" + self._action_space = action_space + self._n_envs = n_envs + self._config = config + self._action_dim = int(np.prod(action_space.shape[1:])) + self._configured = True + + if isinstance(action_space, Box): + self._action_low = torch.tensor(action_space.low[0], device=self.device, dtype=torch.float32) + self._action_high = torch.tensor(action_space.high[0], device=self.device, dtype=torch.float32) + else: + logging.warning(f"Action space is discrete, got {type(action_space)}. ICEMSolver may not work as expected.") + self._action_low = None + self._action_high = None + + @property + def n_envs(self) -> int: + """Number of parallel environments.""" + return self._n_envs + + @property + def action_dim(self) -> int: + """Flattened action dimension including action_block grouping.""" + return self._action_dim * self._config.action_block + + @property + def horizon(self) -> int: + """Planning horizon in timesteps.""" + return self._config.horizon + + def __call__(self, *args: Any, **kwargs: Any) -> dict: + """Make solver callable, forwarding to solve().""" + return self.solve(*args, **kwargs) + + def init_action_distrib( + self, actions: torch.Tensor | None = None + ) -> tuple[torch.Tensor, torch.Tensor]: + """Initialize the action distribution parameters (mean and variance).""" + var = self.var_scale * torch.ones([self.n_envs, self.horizon, self.action_dim]) + mean = torch.zeros([self.n_envs, 0, self.action_dim]) if actions is None else actions + + remaining = self.horizon - mean.shape[1] + if remaining > 0: + device = mean.device + new_mean = torch.zeros([self.n_envs, remaining, self.action_dim]) + mean = torch.cat([mean, new_mean], dim=1).to(device) + + return mean, var + + @torch.inference_mode() + def solve( + self, info_dict: dict, init_action: torch.Tensor | None = None + ) -> dict: + """Solve the planning problem using improved Cross Entropy Method.""" + start_time = time.time() + outputs = { + "costs": [], + "mean": [], + "var": [], + } + + mean, var = self.init_action_distrib(init_action) + mean = mean.to(self.device) + var = var.to(self.device) + + for start_idx in range(0, self.n_envs, self.batch_size): + end_idx = min(start_idx + self.batch_size, self.n_envs) + current_bs = end_idx - start_idx + + batch_mean = mean[start_idx:end_idx] + batch_var = var[start_idx:end_idx] + + expanded_infos = {} + for k, v in info_dict.items(): + v_batch = v[start_idx:end_idx] + if torch.is_tensor(v): + v_batch = v_batch.unsqueeze(1) + v_batch = v_batch.expand(current_bs, self.num_samples, *v_batch.shape[2:]) + elif isinstance(v, np.ndarray): + v_batch = np.repeat(v_batch[:, None, ...], self.num_samples, axis=1) + expanded_infos[k] = v_batch + + prev_topk_candidates = None + batch_indices = torch.arange(current_bs, device=self.device).unsqueeze(1).expand(-1, self.topk) + + # Precompute FFT scale for colored noise + noise_shape = (current_bs, self.num_samples, self.action_dim, self.horizon) + freqs = torch.fft.rfftfreq(self.horizon, device=self.device) + freqs[0] = 1.0 + noise_scale = freqs.pow(-self.noise_beta / 2) + noise_scale[0] = noise_scale[1] + + for step in range(self.n_steps): + # Colored noise: generate with temporal axis last, then transpose + if self.horizon <= 1: + noise = torch.randn(noise_shape, generator=self.torch_gen, device=self.device) + else: + white = torch.randn(noise_shape, generator=self.torch_gen, device=self.device) + fft = torch.fft.rfft(white, dim=-1) + colored = torch.fft.irfft(fft * noise_scale, n=self.horizon, dim=-1) + std = colored.std(dim=-1, keepdim=True).clamp(min=1e-8) + noise = colored / std + noise = noise.transpose(-1, -2) # -> (bs, num_samples, horizon, action_dim) + + candidates = noise * batch_var.unsqueeze(1) + batch_mean.unsqueeze(1) + candidates[:, 0] = batch_mean + + # Inject previous elites + if prev_topk_candidates is not None: + n_inject = min(self.n_elite_keep, prev_topk_candidates.shape[1]) + candidates[:, 1:1 + n_inject] = prev_topk_candidates[:, :n_inject] + + # Clip to action bounds + if self._action_low is not None: + candidates = candidates.clamp(self._action_low, self._action_high) + + current_info = expanded_infos.copy() + costs = self.model.get_cost(current_info, candidates) + + assert isinstance(costs, torch.Tensor), f"Expected cost to be a torch.Tensor, got {type(costs)}" + assert costs.ndim == 2 and costs.shape[0] == current_bs and costs.shape[1] == self.num_samples, ( + f"Expected cost to be of shape ({current_bs}, {self.num_samples}), got {costs.shape}" + ) + + topk_vals, topk_inds = torch.topk(costs, k=self.topk, dim=1, largest=False) + topk_candidates = candidates[batch_indices, topk_inds] + + prev_topk_candidates = topk_candidates + + # Momentum update + elite_mean = topk_candidates.mean(dim=1) + elite_var = topk_candidates.std(dim=1) + batch_mean = self.alpha * batch_mean + (1 - self.alpha) * elite_mean + batch_var = self.alpha * batch_var + (1 - self.alpha) * elite_var + + final_batch_cost = topk_vals.mean(dim=1).cpu().tolist() + + if self.return_mean: + mean[start_idx:end_idx] = batch_mean + else: + mean[start_idx:end_idx] = topk_candidates[:, 0] + + var[start_idx:end_idx] = batch_var + + outputs["costs"].extend(final_batch_cost) + + outputs["actions"] = mean.detach().cpu() + outputs["mean"] = [mean.detach().cpu()] + outputs["var"] = [var.detach().cpu()] + + print(f"iCEM solve time: {time.time() - start_time:.4f} seconds") + return outputs diff --git a/.venv/lib/python3.10/site-packages/stable_worldmodel/solver/lagrangian.py b/.venv/lib/python3.10/site-packages/stable_worldmodel/solver/lagrangian.py new file mode 100644 index 0000000..288c485 --- /dev/null +++ b/.venv/lib/python3.10/site-packages/stable_worldmodel/solver/lagrangian.py @@ -0,0 +1,360 @@ +"""Lagrangian solver for stable world model.""" + +import time +from typing import Any + +import gymnasium as gym +import numpy as np +import torch +import torch.nn.functional as F +from gymnasium.spaces import Box +from loguru import logger as logging + +from .solver import Costable + + +class LagrangianSolver(torch.nn.Module): + """Lagrangian solver for stable world model. + + get_cost returns the cost tensor (B, S). If the model also implements get_constraints, + it should return the constraint violations (B, S, C), where C is the number of constraints. + The constraint_cost should represent the cost of violating the constraints, where the constraint + is satisfied when constraint_cost <= 0. The Lagrangian solver will optimize the following objective: + + L = cost + sum_{i=1}^C lambda_i * constraint_cost_i + sum_{i=1}^C rho_i * max(0, constraint_cost_i)^2 + + If you want to use equality constraint, you can convert it to two inequality constraints. For example, if you want to enforce constraint_cost_i == 0, you can add two constraints: constraint_cost_i <= 0 and -constraint_cost_i <= 0. + + Args: + model: World model implementing the Costable protocol. Its get_cost() returns + a plain cost tensor (B, S). If it also has get_constraints(), that method + returns constraints of shape (B, S, C). + n_steps: Number of gradient descent steps per outer iteration. + n_outer_steps: Number of dual ascent (outer) iterations. + batch_size: Number of environments to process in parallel. + num_samples: Number of action samples to optimize in parallel. + var_scale: Initial variance scale for action perturbations. + action_noise: Noise added to actions during optimization. + rho_init: Initial penalty coefficient for the quadratic constraint term. + rho_max: Maximum value of the penalty coefficient. + rho_scale: Multiplicative growth factor for rho after each outer step. + persist_multipliers: Whether to warm-start Lagrange multipliers across solve() calls. + device: Device for tensor computations. + seed: Random seed for reproducibility. + optimizer_cls: PyTorch optimizer class to use. + optimizer_kwargs: Keyword arguments for the optimizer. + """ + + def __init__( + self, + model: Costable, + n_steps: int, + n_outer_steps: int = 5, + batch_size: int | None = None, + num_samples: int = 1, + var_scale: float = 1.0, + action_noise: float = 0.0, + rho_init: float = 1.0, + rho_max: float = 1e4, + rho_scale: float = 2.0, + persist_multipliers: bool = True, + device: str | torch.device = 'cpu', + seed: int = 1234, + optimizer_cls: type[torch.optim.Optimizer] = torch.optim.Adam, + optimizer_kwargs: dict | None = None, + ) -> None: + super().__init__() + self.model = model + self.n_steps = n_steps + self.n_outer_steps = n_outer_steps + self.batch_size = batch_size + self.num_samples = num_samples + self.var_scale = var_scale + self.action_noise = action_noise + self.rho_init = rho_init + self.rho_max = rho_max + self.rho_scale = rho_scale + self.persist_multipliers = persist_multipliers + self.device = device + self.torch_gen = torch.Generator(device=device).manual_seed(seed) + self.optimizer_cls = optimizer_cls + self.optimizer_kwargs = ( + optimizer_kwargs if optimizer_kwargs is not None else {'lr': 1.0} + ) + + self._configured = False + self._n_envs = None + self._action_dim = None + self._config = None + self._lambdas: torch.Tensor | None = None # (n_envs, C) + + def configure( + self, *, action_space: gym.Space, n_envs: int, config: Any + ) -> None: + """Configure the solver with environment specifications.""" + self._action_space = action_space + self._n_envs = n_envs + self._config = config + self._action_dim = int(np.prod(action_space.shape[1:])) + self._configured = True + + if not isinstance(action_space, Box): + logging.warning( + f'Action space is discrete, got {type(action_space)}. LagrangianSolver may not work as expected.' + ) + + @property + def n_envs(self) -> int: + """Number of parallel environments.""" + return self._n_envs + + @property + def action_dim(self) -> int: + """Flattened action dimension including action_block grouping.""" + return self._action_dim * self._config.action_block + + @property + def horizon(self) -> int: + """Planning horizon in timesteps.""" + return self._config.horizon + + def __call__(self, *args: Any, **kwargs: Any) -> dict: + """Make solver callable, forwarding to solve().""" + return self.solve(*args, **kwargs) + + def init_action(self, actions: torch.Tensor | None = None) -> None: + """Initialize the action tensor for optimization.""" + if actions is None: + actions = torch.zeros((self._n_envs, 0, self.action_dim)) + + remaining = self.horizon - actions.shape[1] + if remaining > 0: + new_actions = torch.zeros(self._n_envs, remaining, self.action_dim) + actions = torch.cat([actions, new_actions], dim=1).to(self.device) + + actions = actions.unsqueeze(1).repeat_interleave( + self.num_samples, dim=1 + ) + actions[:, 1:] += ( + torch.randn( + actions[:, 1:].shape, + generator=self.torch_gen, + device=self.device, + ) + * self.var_scale + ) + + if hasattr(self, 'init'): + self.init.copy_(actions) + else: + self.register_parameter('init', torch.nn.Parameter(actions)) + + def _init_multipliers(self, num_constraints: int) -> None: + """Lazily initialize Lagrange multipliers to zeros.""" + self._lambdas = torch.zeros( + self._n_envs, num_constraints, device=self.device + ) + + def _augmented_lagrangian_loss( + self, + costs: torch.Tensor, # (B, S) + constraints: torch.Tensor, # (B, S, C) + lambdas_batch: torch.Tensor, # (B, C) + rho: float, + ) -> torch.Tensor: + """Compute the augmented Lagrangian loss. + + L = cost + Σ_i lambda_i * g_i + Σ_i rho * max(0, g_i)^2 + """ + # lambdas_batch: (B, C) -> (B, 1, C) for broadcasting with constraints (B, S, C) + linear_penalty = (lambdas_batch.unsqueeze(1) * constraints).sum( + dim=-1 + ) # (B, S) + quadratic_penalty = rho * F.relu(constraints).pow(2).sum( + dim=-1 + ) # (B, S) + return (costs + linear_penalty + quadratic_penalty).sum() + + def _update_multipliers( + self, + constraints: torch.Tensor, # (B, S, C) — detached, no grad + lambdas_batch: torch.Tensor, # (B, C) + rho: float, + ) -> torch.Tensor: + """Dual ascent: lambda_i <- max(0, lambda_i + rho * mean_samples(g_i)).""" + mean_g = constraints.mean(dim=1) # (B, C) + return torch.clamp(lambdas_batch + rho * mean_g, min=0.0) + + def solve( + self, info_dict: dict, init_action: torch.Tensor | None = None + ) -> dict: + """Solve the planning problem using augmented Lagrangian gradient descent.""" + start_time = time.time() + outputs: dict = { + 'cost': [], + 'constraint_violation': [], + 'actions': None, + 'lambdas': None, + } + + with torch.no_grad(): + self.init_action(init_action) + + if not self.persist_multipliers: + self._lambdas = None + + batch_size = ( + self.batch_size if self.batch_size is not None else self.n_envs + ) + total_envs = self.n_envs + batch_top_actions_list = [] + + for start_idx in range(0, total_envs, batch_size): + end_idx = min(start_idx + batch_size, total_envs) + current_bs = end_idx - start_idx + + batch_init = self.init[start_idx:end_idx].clone().detach() + batch_init.requires_grad = True + + # Expand info_dict for current batch — same pattern as GradientSolver + expanded_infos = {} + for k, v in info_dict.items(): + if torch.is_tensor(v): + batch_v = v[start_idx:end_idx] + batch_v = batch_v.unsqueeze(1) + batch_v = batch_v.expand( + current_bs, self.num_samples, *batch_v.shape[2:] + ) + elif isinstance(v, np.ndarray): + batch_v = v[start_idx:end_idx] + batch_v = np.repeat( + batch_v[:, None, ...], self.num_samples, axis=1 + ) + else: + batch_v = v + expanded_infos[k] = batch_v + + rho = self.rho_init + batch_cost_history = [] + costs = None + final_constraints = None + + for _outer in range(self.n_outer_steps): + # Fresh optimizer each outer step — avoids stale momentum after dual ascent + optim = self.optimizer_cls( + [batch_init], **self.optimizer_kwargs + ) + + for _step in range(self.n_steps): + current_info = expanded_infos.copy() + costs = self.model.get_cost(current_info, batch_init) + constraints = ( + self.model.get_constraints( + expanded_infos.copy(), batch_init + ) + if hasattr(self.model, 'get_constraints') + else None + ) + + assert isinstance(costs, torch.Tensor), ( + f'Got {type(costs)} cost, expect torch.Tensor' + ) + assert costs.ndim == 2 and costs.shape == ( + current_bs, + self.num_samples, + ), ( + f'Cost should be of shape ({current_bs}, {self.num_samples}), got {costs.shape}' + ) + assert costs.requires_grad, ( + 'Cost must requires_grad for LagrangianSolver.' + ) + + if constraints is not None: + assert constraints.ndim == 3 and constraints.shape[ + :2 + ] == (current_bs, self.num_samples), ( + f'Constraints should be of shape ({current_bs}, {self.num_samples}, C), got {constraints.shape}' + ) + if self._lambdas is None: + self._init_multipliers(constraints.shape[-1]) + lambdas_batch = self._lambdas[start_idx:end_idx] + loss = self._augmented_lagrangian_loss( + costs, constraints, lambdas_batch, rho + ) + else: + loss = costs.sum() + + loss.backward() + optim.step() + optim.zero_grad(set_to_none=True) + + if self.action_noise > 0: + batch_init.data += ( + torch.randn( + batch_init.shape, generator=self.torch_gen + ) + * self.action_noise + ) + + batch_cost_history.append(loss.item()) + + # Dual ascent after inner loop converges + if constraints is not None: + with torch.no_grad(): + final_constraints = self.model.get_constraints( + expanded_infos.copy(), batch_init + ) + lambdas_batch = self._update_multipliers( + final_constraints, lambdas_batch, rho + ) + self._lambdas[start_idx:end_idx] = lambdas_batch + rho = min(self.rho_max, rho * self.rho_scale) + + with torch.no_grad(): + mean_cost = costs.mean().item() + if constraints is not None: + viol = F.relu(final_constraints).mean(dim=(0, 1)) # (C,) + lam = lambdas_batch.mean(dim=0) # (C,) + viol_str = ', '.join(f'{v:.4f}' for v in viol.tolist()) + lam_str = ', '.join(f'{l:.4f}' for l in lam.tolist()) + print( + f' [outer {_outer+1}/{self.n_outer_steps}] ' + f'cost={mean_cost:.4f} | ' + f'constraint_viol=[{viol_str}] | ' + f'lambdas=[{lam_str}] | ' + f'rho={rho:.4f}' + ) + else: + print( + f' [outer {_outer+1}/{self.n_outer_steps}] ' + f'cost={mean_cost:.4f}' + ) + + outputs['cost'].append(batch_cost_history) + + if final_constraints is not None: + outputs['constraint_violation'].append( + F.relu(final_constraints).mean().item() + ) + + with torch.no_grad(): + self.init[start_idx:end_idx] = batch_init + + top_idx = torch.argsort(costs, dim=1)[:, 0] + batch_indices = torch.arange(current_bs) + top_actions_batch = batch_init[batch_indices, top_idx] + batch_top_actions_list.append(top_actions_batch.detach().cpu()) + + outputs['actions'] = torch.cat(batch_top_actions_list, dim=0) + outputs['lambdas'] = ( + self._lambdas.cpu() if self._lambdas is not None else None + ) + + constraint_info = '' + if outputs['constraint_violation']: + mean_viol = np.mean(outputs['constraint_violation']) + constraint_info = f' | constraint_violation={mean_viol:.4f}' + print( + f'LagrangianSolver.solve completed in {time.time() - start_time:.4f} seconds{constraint_info}.' + ) + return outputs diff --git a/.venv/lib/python3.10/site-packages/stable_worldmodel/solver/mppi.py b/.venv/lib/python3.10/site-packages/stable_worldmodel/solver/mppi.py new file mode 100644 index 0000000..8f48198 --- /dev/null +++ b/.venv/lib/python3.10/site-packages/stable_worldmodel/solver/mppi.py @@ -0,0 +1,208 @@ +"""Model Predictive Path Integral solver for model-based planning.""" + +import time +from typing import Any + +import gymnasium as gym +import numpy as np +import torch +from gymnasium.spaces import Box +from loguru import logger as logging + +from .solver import Costable + + +class MPPISolver: + """Model Predictive Path Integral solver for action optimization. + + Args: + model: World model implementing the Costable protocol. + batch_size: Number of environments to process in parallel. + num_samples: Number of action candidates to sample per iteration. + var_scale: Initial variance scale for action noise. + n_steps: Number of MPPI iterations. + topk: Number of elite samples for weighted averaging. + temperature: Temperature parameter for softmax weighting. + device: Device for tensor computations. + seed: Random seed for reproducibility. + """ + + def __init__( + self, + model: Costable, + batch_size: int = 1, + num_samples: int = 300, + var_scale: float = 1.0, + n_steps: int = 30, + topk: int = 30, + temperature: float = 0.5, + device: str | torch.device = "cpu", + seed: int = 1234, + ) -> None: + self.model = model + self.batch_size = batch_size + self.num_samples = num_samples + self.topk = topk + self.var_scale = var_scale + self.n_steps = n_steps + self.temperature = temperature + self.device = device + self.torch_gen = torch.Generator(device=device).manual_seed(seed) + + def configure(self, *, action_space: gym.Space, n_envs: int, config: Any) -> None: + """Configure the solver with environment specifications.""" + self._action_space = action_space + self._n_envs = n_envs + self._config = config + self._action_dim = int(np.prod(action_space.shape[1:])) + self._configured = True + + if not isinstance(action_space, Box): + logging.warning( + f"Action space is discrete, got {type(action_space)}. MPPISolver may not work as expected." + ) + + @property + def n_envs(self) -> int: + """Number of parallel environments.""" + return self._n_envs + + @property + def action_dim(self) -> int: + """Flattened action dimension including action_block grouping.""" + return self._action_dim * self._config.action_block + + @property + def horizon(self) -> int: + """Planning horizon in timesteps.""" + return self._config.horizon + + def __call__(self, *args: Any, **kwargs: Any) -> dict: + """Make solver callable, forwarding to solve().""" + return self.solve(*args, **kwargs) + + def init_action_distrib( + self, actions: torch.Tensor | None = None + ) -> tuple[torch.Tensor, torch.Tensor]: + """Initialize the action distribution parameters (mean and variance).""" + var = self.var_scale * torch.ones([self.n_envs, self.horizon, self.action_dim]) + mean = torch.zeros([self.n_envs, 0, self.action_dim]) if actions is None else actions + + remaining = self.horizon - mean.shape[1] + if remaining > 0: + device = mean.device + new_mean = torch.zeros([self.n_envs, remaining, self.action_dim]) + mean = torch.cat([mean, new_mean], dim=1).to(device) + + return mean, var + + @torch.inference_mode() + def solve( + self, info_dict: dict, init_action: torch.Tensor | None = None + ) -> dict: + """Solve the planning problem using MPPI.""" + start_time = time.time() + outputs = { + "costs": [], + "mean": [], + "var": [], + } + + # -- initialize the action distribution globally + mean, var = self.init_action_distrib(init_action) + mean = mean.to(self.device) + var = var.to(self.device) + + total_envs = self.n_envs + + # --- Iterate over batches --- + for start_idx in range(0, total_envs, self.batch_size): + end_idx = min(start_idx + self.batch_size, total_envs) + current_bs = end_idx - start_idx + + # Slice Distribution Parameters for current batch + batch_mean = mean[start_idx:end_idx] + batch_var = var[start_idx:end_idx] + + # Expand Info Dict for current batch (Same as CEM) + expanded_infos = {} + for k, v in info_dict.items(): + v_batch = v[start_idx:end_idx] + if torch.is_tensor(v): + # Add sample dim: (batch, 1, ...) + v_batch = v_batch.unsqueeze(1) + # Expand: (batch, num_samples, ...) + v_batch = v_batch.expand(current_bs, self.num_samples, *v_batch.shape[2:]) + elif isinstance(v, np.ndarray): + v_batch = np.repeat(v_batch[:, None, ...], self.num_samples, axis=1) + expanded_infos[k] = v_batch + + # Optimization Loop + final_batch_cost = None + + for step in range(self.n_steps): + # Sample noise: (Batch, Num_Samples, Horizon, Dim) + noise = torch.randn( + current_bs, + self.num_samples, + self.horizon, + self.action_dim, + generator=self.torch_gen, + device=self.device, + ) + + # MPPI Logic: candidates = mean + noise * sigma + candidates = batch_mean.unsqueeze(1) + noise * batch_var.unsqueeze(1) + + # Force the first sample to be the current mean (Zero noise) + candidates[:, 0] = batch_mean + + # Evaluate candidates + costs = self.model.get_cost(expanded_infos, candidates) + + assert isinstance(costs, torch.Tensor), f"Expected cost to be a torch.Tensor, got {type(costs)}" + assert costs.ndim == 2 and costs.shape[0] == current_bs and costs.shape[1] == self.num_samples, ( + f"Expected cost to be of shape ({current_bs}, {self.num_samples}), got {costs.shape}" + ) + + # Select Elites (Optional, based on topk) + if self.topk is not None and self.topk < self.num_samples: + # topk_vals: (Batch, K), topk_inds: (Batch, K) + topk_vals, topk_inds = torch.topk(costs, k=self.topk, dim=1, largest=False) + + # Gather Top-K Candidates + batch_indices = torch.arange(current_bs, device=self.device).unsqueeze(1).expand(-1, self.topk) + # (Batch, K, Horizon, Dim) + relevant_candidates = candidates[batch_indices, topk_inds] + relevant_costs = topk_vals + else: + relevant_candidates = candidates + relevant_costs = costs + + # MPPI Weighting: Softmax(-cost / temperature) + # Stabilize softmax by subtracting min cost + min_cost = relevant_costs.min(dim=1, keepdim=True)[0] + scaled_costs = relevant_costs - min_cost + weights = torch.softmax(-scaled_costs / self.temperature, dim=1) # (Batch, K) + + # Update Mean: weighted sum of candidates + # Reshape weights for broadcasting: (Batch, K, 1, 1) + weights_expanded = weights.unsqueeze(-1).unsqueeze(-1) + batch_mean = (weights_expanded * relevant_candidates).sum(dim=1) + + # Store average cost of the utilized samples for logging + final_batch_cost = relevant_costs.mean(dim=1).cpu().tolist() + + # Write results back to global storage + mean[start_idx:end_idx] = batch_mean + # We do not update var in standard MPPI + + # Store history/metadata + outputs["costs"].extend(final_batch_cost) + + outputs["actions"] = mean.detach().cpu() + outputs["mean"] = [mean.detach().cpu()] + outputs["var"] = [var.detach().cpu()] + + print(f"MPPI solve time: {time.time() - start_time:.4f} seconds") + return outputs diff --git a/.venv/lib/python3.10/site-packages/stable_worldmodel/solver/solver.py b/.venv/lib/python3.10/site-packages/stable_worldmodel/solver/solver.py new file mode 100644 index 0000000..f12a62e --- /dev/null +++ b/.venv/lib/python3.10/site-packages/stable_worldmodel/solver/solver.py @@ -0,0 +1,74 @@ +from typing import Any, Protocol, runtime_checkable + +import gymnasium as gym +import torch + + +class Costable(Protocol): + """Protocol for world model cost functions.""" + + def criterion(self, info_dict: dict, action_candidates: torch.Tensor) -> torch.Tensor: + """Compute the cost criterion for action candidates. + + Args: + info_dict: Dictionary containing environment state information. + action_candidates: Tensor of proposed actions. + + Returns: + A tensor of cost values for each action candidate. + """ + ... + + def get_cost(self, info_dict: dict, action_candidates: torch.Tensor) -> torch.Tensor: # pragma: no cover + """Compute cost for given action candidates based on info dictionary. + + Args: + info_dict: Dictionary containing environment state information. + action_candidates: Tensor of proposed actions. + + Returns: + A tensor of cost values for each action candidate. + """ + ... + + +@runtime_checkable +class Solver(Protocol): + """Protocol for model-based planning solvers.""" + + def configure(self, *, action_space: gym.Space, n_envs: int, config: Any) -> None: + """Configure the solver with environment and planning specifications. + + Args: + action_space: The action space of the environment. + n_envs: Number of parallel environments. + config: Planning configuration object. + """ + ... + + @property + def action_dim(self) -> int: + """Flattened action dimension including action_block grouping.""" + ... + + @property + def n_envs(self) -> int: + """Number of parallel environments being planned for.""" + ... + + @property + def horizon(self) -> int: + """Planning horizon length in timesteps.""" + ... + + def solve(self, info_dict: dict, init_action: torch.Tensor | None = None) -> dict: + """Solve the planning optimization problem to find optimal actions. + + Args: + info_dict: Dictionary containing environment state information. + init_action: Optional initial action sequence to warm-start the solver. + + Returns: + Dictionary containing optimized actions and other solver-specific info. + """ + ...