调整ignore的文件

This commit is contained in:
qihuanye
2026-05-04 08:05:47 +00:00
parent cf43af0729
commit e84074d6d6
7 changed files with 1149 additions and 0 deletions

15
.gitignore vendored
View File

@@ -1,3 +1,18 @@
.venv/*
!.venv/.gitignore
!.venv/lib/
.venv/lib/*
!.venv/lib/python3.10/
.venv/lib/python3.10/*
!.venv/lib/python3.10/site-packages/
.venv/lib/python3.10/site-packages/*
!.venv/lib/python3.10/site-packages/stable_worldmodel/
.venv/lib/python3.10/site-packages/stable_worldmodel/*
!.venv/lib/python3.10/site-packages/stable_worldmodel/solver/
!.venv/lib/python3.10/site-packages/stable_worldmodel/policy.py
!.venv/lib/python3.10/site-packages/stable_worldmodel/world.py
!.venv/lib/python3.10/site-packages/stable_worldmodel/solver/cem.py
!.venv/lib/python3.10/site-packages/stable_worldmodel/solver/gd.py
outputs/
__pycache__/
*.py[cod]

View File

@@ -0,0 +1,17 @@
from .cem import CEMSolver
from .gd import GradientSolver
from .icem import ICEMSolver
from .lagrangian import LagrangianSolver
from .mppi import MPPISolver
from .solver import Solver
from .discrete_solvers import PGDSolver
__all__ = [
'Solver',
'GradientSolver',
'CEMSolver',
'ICEMSolver',
'PGDSolver',
'MPPISolver',
'LagrangianSolver',
]

View File

@@ -0,0 +1,256 @@
"""Projected Gradient Descent solver for discrete action spaces."""
import time
from typing import Any
import gymnasium as gym
import numpy as np
import torch
from gymnasium.spaces import Discrete
from .solver import Costable
class PGDSolver(torch.nn.Module):
"""Projected Gradient Descent solver for discrete action optimization.
Args:
model: World model implementing the Costable protocol.
n_steps: Number of gradient descent iterations.
batch_size: Number of environments to process in parallel.
var_scale: Initial variance scale for action perturbations.
num_samples: Number of action samples to optimize in parallel.
action_noise: Noise added to actions during optimization.
device: Device for tensor computations.
seed: Random seed for reproducibility.
"""
def __init__(
self,
model: Costable,
n_steps: int,
batch_size: int | None = None,
var_scale: float = 1,
num_samples: int = 1,
action_noise: float = 0.0,
device: str | torch.device = "cpu",
seed: int = 1234,
) -> None:
super().__init__()
self.model = model
self.n_steps = n_steps
self.batch_size = batch_size
self.num_samples = num_samples
self.var_scale = var_scale
self.action_noise = action_noise
self.device = device
self.torch_gen = torch.Generator(device=device).manual_seed(seed)
self._configured = False
self._n_envs = None
self._action_dim = None
self._action_simplex_dim = None
self._config = None
def configure(self, *, action_space: gym.Space, n_envs: int, config: Any) -> None:
"""Configure the solver with environment specifications."""
assert isinstance(action_space, Discrete), f"Action space must be discrete, got {type(action_space)}"
self._action_space = action_space
self._n_envs = n_envs
self._config = config
self._action_dim = int(np.prod(action_space.shape[1:]))
self._action_simplex_dim = int(action_space.n)
self._configured = True
@property
def n_envs(self) -> int:
"""Number of parallel environments."""
return self._n_envs
@property
def action_dim(self) -> int:
"""Flattened action dimension including action_block grouping."""
return self._action_dim * self._config.action_block
@property
def action_simplex_dim(self) -> int:
"""Simplex dimension for discrete action probabilities."""
return self._action_simplex_dim * self._config.action_block
@property
def horizon(self) -> int:
"""Planning horizon in timesteps."""
return self._config.horizon
def __call__(self, *args: Any, **kwargs: Any) -> dict:
"""Make solver callable, forwarding to solve()."""
return self.solve(*args, **kwargs)
def init_action(
self, actions: torch.Tensor | None = None, from_scalar: bool = False
) -> None:
"""Initialize the action tensor for optimization."""
if actions is None:
actions = torch.zeros((self._n_envs, 0, self.action_simplex_dim))
elif from_scalar:
# convert scalar to one-hot
actions = torch.nn.functional.one_hot(actions, num_classes=self._action_simplex_dim).to(torch.float32)
# merge action_block dim
actions = actions.reshape(*actions.shape[:-2], self.action_simplex_dim)
assert (
actions.shape[0] == self._n_envs
and actions.shape[1] <= self.horizon
and actions.shape[2] == self.action_simplex_dim
)
# fill remaining action
remaining = self.horizon - actions.shape[1]
if remaining > 0:
new_actions = torch.zeros(self._n_envs, remaining, self.action_simplex_dim)
actions = torch.cat([actions, new_actions], dim=1).to(self.device)
actions = actions.unsqueeze(1).repeat_interleave(self.num_samples, dim=1) # add sample dim
actions[:, 1:] += (
torch.randn(actions[:, 1:].shape, generator=self.torch_gen, device=self.device) * self.var_scale
) # add small noise to all samples except the first one
# reset actions
if hasattr(self, "init"):
self.init.copy_(actions)
else:
self.register_parameter("init", torch.nn.Parameter(actions))
def solve(
self,
info_dict: dict,
init_action: torch.Tensor | None = None,
from_scalar: bool = False,
) -> dict:
"""Solve the planning problem using projected gradient descent."""
start_time = time.time()
outputs = {
"cost": [], # Will store list of cost histories per batch
"actions": None,
}
with torch.no_grad():
self.init_action(init_action, from_scalar=from_scalar)
# Determine batch size (default to all envs if not specified which can cause memory issues)
batch_size = self.batch_size if self.batch_size is not None else self.n_envs
total_envs = self.n_envs
# Lists to hold results from each batch to be concatenated later
batch_top_actions_list = []
# --- Outer Loop: Iterate over batches ---
for start_idx in range(0, total_envs, batch_size):
end_idx = min(start_idx + batch_size, total_envs)
current_bs = end_idx - start_idx
batch_init = self.init[start_idx:end_idx].clone().detach()
batch_init.requires_grad = True
optim = torch.optim.SGD([batch_init], lr=1.0)
# Prepare Batch Infos
# Slice the input info_dict and then expand dimensions
expanded_infos = {}
for k, v in info_dict.items():
# Slice the data for the current batch indices
# Assumes input data dim 0 corresponds to n_envs
if torch.is_tensor(v):
batch_v = v[start_idx:end_idx]
batch_v = batch_v.unsqueeze(1)
batch_v = batch_v.expand(current_bs, self.num_samples, *batch_v.shape[2:])
elif isinstance(v, np.ndarray):
batch_v = v[start_idx:end_idx]
batch_v = np.repeat(batch_v[:, None, ...], self.num_samples, axis=1)
expanded_infos[k] = batch_v
# Perform Gradient Descent for this batch
batch_cost_history = []
for step in range(self.n_steps):
current_info = expanded_infos.copy()
# Calculate cost using the batch parameter
costs = self.model.get_cost(current_info, batch_init)
assert isinstance(costs, torch.Tensor), f"Got {type(costs)} cost, expect torch.Tensor"
assert costs.ndim == 2 and costs.shape[0] == current_bs and costs.shape[1] == self.num_samples, (
f"Cost should be of shape ({current_bs}, {self.num_samples}), got {costs.shape}"
)
assert costs.requires_grad, "Cost must requires_grad for PGD solver."
cost = costs.sum() # Sum cost for this batch
cost.backward()
optim.step()
optim.zero_grad(set_to_none=True)
# Add noise
if self.action_noise > 0:
batch_init.data += torch.randn(batch_init.shape, generator=self.torch_gen) * self.action_noise
# projection onto simplex
with torch.no_grad():
batch_init.copy_(self._project_action_simplex(batch_init))
batch_cost_history.append(cost.item())
# Store cost history for this batch
outputs["cost"].append(batch_cost_history)
# Update the global self.init with the optimized batch values
with torch.no_grad():
self.init[start_idx:end_idx] = batch_init
top_idx = torch.argsort(costs, dim=1)[:, 0]
batch_indices = torch.arange(current_bs)
top_actions_batch = batch_init[batch_indices, top_idx]
# convert one-hot back to discrete actions
top_actions_batch = self._factor_action_block(top_actions_batch).argmax(dim=-1)
batch_top_actions_list.append(top_actions_batch.detach().cpu())
# Concatenate all batch results
outputs["actions"] = torch.cat(batch_top_actions_list, dim=0)
end_time = time.time()
print(f"PGDSolver.solve completed in {end_time - start_time:.4f} seconds.")
return outputs
def _factor_action_block(self, actions: torch.Tensor) -> torch.Tensor:
"""Factor the action block dimension from action_simplex_dim."""
original_shape = actions.shape
action_block = self._config.action_block
simplex_dim = self._action_simplex_dim
return actions.reshape(*original_shape[:-1], action_block, simplex_dim)
def _project_action_simplex(self, actions: torch.Tensor) -> torch.Tensor:
"""Project the action onto the probability simplex."""
original_shape = actions.shape
s = self._factor_action_block(actions).reshape(-1, self._action_simplex_dim)
mu, _ = torch.sort(s, descending=True, dim=-1)
cumulative = mu.cumsum(dim=-1)
d = s.size(-1)
indices = torch.arange(1, d + 1, device=s.device, dtype=s.dtype)
threshold = (cumulative - 1) / indices
cond = (mu > threshold).to(torch.int32)
rho = cond.cumsum(dim=-1)
valid_rho = rho * cond
rho_max = valid_rho.max(dim=-1, keepdim=True)[0]
rho_min = torch.clamp(rho_max, min=1)
psi = (cumulative.gather(-1, rho_min - 1) - 1) / rho_min
projected = torch.clamp(s - psi, min=0.0).reshape(original_shape)
return projected

View File

@@ -0,0 +1,219 @@
"""Improved Cross Entropy Method (iCEM) solver for model-based planning."""
import time
from typing import Any
import gymnasium as gym
import numpy as np
import torch
from gymnasium.spaces import Box
from loguru import logger as logging
from .solver import Costable
class ICEMSolver:
"""Improved Cross Entropy Method (iCEM) solver with colored noise and elite retention.
iCEM improves the sample efficiency over standard CEM and was introduced by
[1] for real-time planning.
Args:
model: World model implementing the Costable protocol.
batch_size: Number of environments to process in parallel.
num_samples: Number of action candidates to sample per iteration.
var_scale: Initial variance scale for the action distribution.
n_steps: Number of CEM iterations.
topk: Number of elite samples to keep for distribution update.
noise_beta: Colored noise exponent. 0 = white (standard CEM), >0 = more low-frequency noise.
alpha: Momentum for mean/std EMA update.
n_elite_keep: Number of elites carried from previous iteration.
return_mean: If False, return best single trajectory instead of mean.
device: Device for tensor computations.
seed: Random seed for reproducibility.
[1] C. Pinneri, S. Sawant, S. Blaes, J. Achterhold, J. Stueckler, M. Rolinek and
G, Martius, Georg. "Sample-efficient Cross-Entropy Method for Real-time Planning".
Conference on Robot Learning, 2020.
"""
def __init__(
self,
model: Costable,
batch_size: int = 1,
num_samples: int = 300,
var_scale: float = 1,
n_steps: int = 30,
topk: int = 30,
noise_beta: float = 2.0,
alpha: float = 0.1,
n_elite_keep: int = 5,
return_mean: bool = True,
device: str | torch.device = "cpu",
seed: int = 1234,
) -> None:
self.model = model
self.batch_size = batch_size
self.var_scale = var_scale
self.num_samples = num_samples
self.n_steps = n_steps
self.topk = topk
self.noise_beta = noise_beta
self.alpha = alpha
self.n_elite_keep = n_elite_keep
self.return_mean = return_mean
self.device = device
self.torch_gen = torch.Generator(device=device).manual_seed(seed)
def configure(self, *, action_space: gym.Space, n_envs: int, config: Any) -> None:
"""Configure the solver with environment specifications."""
self._action_space = action_space
self._n_envs = n_envs
self._config = config
self._action_dim = int(np.prod(action_space.shape[1:]))
self._configured = True
if isinstance(action_space, Box):
self._action_low = torch.tensor(action_space.low[0], device=self.device, dtype=torch.float32)
self._action_high = torch.tensor(action_space.high[0], device=self.device, dtype=torch.float32)
else:
logging.warning(f"Action space is discrete, got {type(action_space)}. ICEMSolver may not work as expected.")
self._action_low = None
self._action_high = None
@property
def n_envs(self) -> int:
"""Number of parallel environments."""
return self._n_envs
@property
def action_dim(self) -> int:
"""Flattened action dimension including action_block grouping."""
return self._action_dim * self._config.action_block
@property
def horizon(self) -> int:
"""Planning horizon in timesteps."""
return self._config.horizon
def __call__(self, *args: Any, **kwargs: Any) -> dict:
"""Make solver callable, forwarding to solve()."""
return self.solve(*args, **kwargs)
def init_action_distrib(
self, actions: torch.Tensor | None = None
) -> tuple[torch.Tensor, torch.Tensor]:
"""Initialize the action distribution parameters (mean and variance)."""
var = self.var_scale * torch.ones([self.n_envs, self.horizon, self.action_dim])
mean = torch.zeros([self.n_envs, 0, self.action_dim]) if actions is None else actions
remaining = self.horizon - mean.shape[1]
if remaining > 0:
device = mean.device
new_mean = torch.zeros([self.n_envs, remaining, self.action_dim])
mean = torch.cat([mean, new_mean], dim=1).to(device)
return mean, var
@torch.inference_mode()
def solve(
self, info_dict: dict, init_action: torch.Tensor | None = None
) -> dict:
"""Solve the planning problem using improved Cross Entropy Method."""
start_time = time.time()
outputs = {
"costs": [],
"mean": [],
"var": [],
}
mean, var = self.init_action_distrib(init_action)
mean = mean.to(self.device)
var = var.to(self.device)
for start_idx in range(0, self.n_envs, self.batch_size):
end_idx = min(start_idx + self.batch_size, self.n_envs)
current_bs = end_idx - start_idx
batch_mean = mean[start_idx:end_idx]
batch_var = var[start_idx:end_idx]
expanded_infos = {}
for k, v in info_dict.items():
v_batch = v[start_idx:end_idx]
if torch.is_tensor(v):
v_batch = v_batch.unsqueeze(1)
v_batch = v_batch.expand(current_bs, self.num_samples, *v_batch.shape[2:])
elif isinstance(v, np.ndarray):
v_batch = np.repeat(v_batch[:, None, ...], self.num_samples, axis=1)
expanded_infos[k] = v_batch
prev_topk_candidates = None
batch_indices = torch.arange(current_bs, device=self.device).unsqueeze(1).expand(-1, self.topk)
# Precompute FFT scale for colored noise
noise_shape = (current_bs, self.num_samples, self.action_dim, self.horizon)
freqs = torch.fft.rfftfreq(self.horizon, device=self.device)
freqs[0] = 1.0
noise_scale = freqs.pow(-self.noise_beta / 2)
noise_scale[0] = noise_scale[1]
for step in range(self.n_steps):
# Colored noise: generate with temporal axis last, then transpose
if self.horizon <= 1:
noise = torch.randn(noise_shape, generator=self.torch_gen, device=self.device)
else:
white = torch.randn(noise_shape, generator=self.torch_gen, device=self.device)
fft = torch.fft.rfft(white, dim=-1)
colored = torch.fft.irfft(fft * noise_scale, n=self.horizon, dim=-1)
std = colored.std(dim=-1, keepdim=True).clamp(min=1e-8)
noise = colored / std
noise = noise.transpose(-1, -2) # -> (bs, num_samples, horizon, action_dim)
candidates = noise * batch_var.unsqueeze(1) + batch_mean.unsqueeze(1)
candidates[:, 0] = batch_mean
# Inject previous elites
if prev_topk_candidates is not None:
n_inject = min(self.n_elite_keep, prev_topk_candidates.shape[1])
candidates[:, 1:1 + n_inject] = prev_topk_candidates[:, :n_inject]
# Clip to action bounds
if self._action_low is not None:
candidates = candidates.clamp(self._action_low, self._action_high)
current_info = expanded_infos.copy()
costs = self.model.get_cost(current_info, candidates)
assert isinstance(costs, torch.Tensor), f"Expected cost to be a torch.Tensor, got {type(costs)}"
assert costs.ndim == 2 and costs.shape[0] == current_bs and costs.shape[1] == self.num_samples, (
f"Expected cost to be of shape ({current_bs}, {self.num_samples}), got {costs.shape}"
)
topk_vals, topk_inds = torch.topk(costs, k=self.topk, dim=1, largest=False)
topk_candidates = candidates[batch_indices, topk_inds]
prev_topk_candidates = topk_candidates
# Momentum update
elite_mean = topk_candidates.mean(dim=1)
elite_var = topk_candidates.std(dim=1)
batch_mean = self.alpha * batch_mean + (1 - self.alpha) * elite_mean
batch_var = self.alpha * batch_var + (1 - self.alpha) * elite_var
final_batch_cost = topk_vals.mean(dim=1).cpu().tolist()
if self.return_mean:
mean[start_idx:end_idx] = batch_mean
else:
mean[start_idx:end_idx] = topk_candidates[:, 0]
var[start_idx:end_idx] = batch_var
outputs["costs"].extend(final_batch_cost)
outputs["actions"] = mean.detach().cpu()
outputs["mean"] = [mean.detach().cpu()]
outputs["var"] = [var.detach().cpu()]
print(f"iCEM solve time: {time.time() - start_time:.4f} seconds")
return outputs

View File

@@ -0,0 +1,360 @@
"""Lagrangian solver for stable world model."""
import time
from typing import Any
import gymnasium as gym
import numpy as np
import torch
import torch.nn.functional as F
from gymnasium.spaces import Box
from loguru import logger as logging
from .solver import Costable
class LagrangianSolver(torch.nn.Module):
"""Lagrangian solver for stable world model.
get_cost returns the cost tensor (B, S). If the model also implements get_constraints,
it should return the constraint violations (B, S, C), where C is the number of constraints.
The constraint_cost should represent the cost of violating the constraints, where the constraint
is satisfied when constraint_cost <= 0. The Lagrangian solver will optimize the following objective:
L = cost + sum_{i=1}^C lambda_i * constraint_cost_i + sum_{i=1}^C rho_i * max(0, constraint_cost_i)^2
If you want to use equality constraint, you can convert it to two inequality constraints. For example, if you want to enforce constraint_cost_i == 0, you can add two constraints: constraint_cost_i <= 0 and -constraint_cost_i <= 0.
Args:
model: World model implementing the Costable protocol. Its get_cost() returns
a plain cost tensor (B, S). If it also has get_constraints(), that method
returns constraints of shape (B, S, C).
n_steps: Number of gradient descent steps per outer iteration.
n_outer_steps: Number of dual ascent (outer) iterations.
batch_size: Number of environments to process in parallel.
num_samples: Number of action samples to optimize in parallel.
var_scale: Initial variance scale for action perturbations.
action_noise: Noise added to actions during optimization.
rho_init: Initial penalty coefficient for the quadratic constraint term.
rho_max: Maximum value of the penalty coefficient.
rho_scale: Multiplicative growth factor for rho after each outer step.
persist_multipliers: Whether to warm-start Lagrange multipliers across solve() calls.
device: Device for tensor computations.
seed: Random seed for reproducibility.
optimizer_cls: PyTorch optimizer class to use.
optimizer_kwargs: Keyword arguments for the optimizer.
"""
def __init__(
self,
model: Costable,
n_steps: int,
n_outer_steps: int = 5,
batch_size: int | None = None,
num_samples: int = 1,
var_scale: float = 1.0,
action_noise: float = 0.0,
rho_init: float = 1.0,
rho_max: float = 1e4,
rho_scale: float = 2.0,
persist_multipliers: bool = True,
device: str | torch.device = 'cpu',
seed: int = 1234,
optimizer_cls: type[torch.optim.Optimizer] = torch.optim.Adam,
optimizer_kwargs: dict | None = None,
) -> None:
super().__init__()
self.model = model
self.n_steps = n_steps
self.n_outer_steps = n_outer_steps
self.batch_size = batch_size
self.num_samples = num_samples
self.var_scale = var_scale
self.action_noise = action_noise
self.rho_init = rho_init
self.rho_max = rho_max
self.rho_scale = rho_scale
self.persist_multipliers = persist_multipliers
self.device = device
self.torch_gen = torch.Generator(device=device).manual_seed(seed)
self.optimizer_cls = optimizer_cls
self.optimizer_kwargs = (
optimizer_kwargs if optimizer_kwargs is not None else {'lr': 1.0}
)
self._configured = False
self._n_envs = None
self._action_dim = None
self._config = None
self._lambdas: torch.Tensor | None = None # (n_envs, C)
def configure(
self, *, action_space: gym.Space, n_envs: int, config: Any
) -> None:
"""Configure the solver with environment specifications."""
self._action_space = action_space
self._n_envs = n_envs
self._config = config
self._action_dim = int(np.prod(action_space.shape[1:]))
self._configured = True
if not isinstance(action_space, Box):
logging.warning(
f'Action space is discrete, got {type(action_space)}. LagrangianSolver may not work as expected.'
)
@property
def n_envs(self) -> int:
"""Number of parallel environments."""
return self._n_envs
@property
def action_dim(self) -> int:
"""Flattened action dimension including action_block grouping."""
return self._action_dim * self._config.action_block
@property
def horizon(self) -> int:
"""Planning horizon in timesteps."""
return self._config.horizon
def __call__(self, *args: Any, **kwargs: Any) -> dict:
"""Make solver callable, forwarding to solve()."""
return self.solve(*args, **kwargs)
def init_action(self, actions: torch.Tensor | None = None) -> None:
"""Initialize the action tensor for optimization."""
if actions is None:
actions = torch.zeros((self._n_envs, 0, self.action_dim))
remaining = self.horizon - actions.shape[1]
if remaining > 0:
new_actions = torch.zeros(self._n_envs, remaining, self.action_dim)
actions = torch.cat([actions, new_actions], dim=1).to(self.device)
actions = actions.unsqueeze(1).repeat_interleave(
self.num_samples, dim=1
)
actions[:, 1:] += (
torch.randn(
actions[:, 1:].shape,
generator=self.torch_gen,
device=self.device,
)
* self.var_scale
)
if hasattr(self, 'init'):
self.init.copy_(actions)
else:
self.register_parameter('init', torch.nn.Parameter(actions))
def _init_multipliers(self, num_constraints: int) -> None:
"""Lazily initialize Lagrange multipliers to zeros."""
self._lambdas = torch.zeros(
self._n_envs, num_constraints, device=self.device
)
def _augmented_lagrangian_loss(
self,
costs: torch.Tensor, # (B, S)
constraints: torch.Tensor, # (B, S, C)
lambdas_batch: torch.Tensor, # (B, C)
rho: float,
) -> torch.Tensor:
"""Compute the augmented Lagrangian loss.
L = cost + Σ_i lambda_i * g_i + Σ_i rho * max(0, g_i)^2
"""
# lambdas_batch: (B, C) -> (B, 1, C) for broadcasting with constraints (B, S, C)
linear_penalty = (lambdas_batch.unsqueeze(1) * constraints).sum(
dim=-1
) # (B, S)
quadratic_penalty = rho * F.relu(constraints).pow(2).sum(
dim=-1
) # (B, S)
return (costs + linear_penalty + quadratic_penalty).sum()
def _update_multipliers(
self,
constraints: torch.Tensor, # (B, S, C) — detached, no grad
lambdas_batch: torch.Tensor, # (B, C)
rho: float,
) -> torch.Tensor:
"""Dual ascent: lambda_i <- max(0, lambda_i + rho * mean_samples(g_i))."""
mean_g = constraints.mean(dim=1) # (B, C)
return torch.clamp(lambdas_batch + rho * mean_g, min=0.0)
def solve(
self, info_dict: dict, init_action: torch.Tensor | None = None
) -> dict:
"""Solve the planning problem using augmented Lagrangian gradient descent."""
start_time = time.time()
outputs: dict = {
'cost': [],
'constraint_violation': [],
'actions': None,
'lambdas': None,
}
with torch.no_grad():
self.init_action(init_action)
if not self.persist_multipliers:
self._lambdas = None
batch_size = (
self.batch_size if self.batch_size is not None else self.n_envs
)
total_envs = self.n_envs
batch_top_actions_list = []
for start_idx in range(0, total_envs, batch_size):
end_idx = min(start_idx + batch_size, total_envs)
current_bs = end_idx - start_idx
batch_init = self.init[start_idx:end_idx].clone().detach()
batch_init.requires_grad = True
# Expand info_dict for current batch — same pattern as GradientSolver
expanded_infos = {}
for k, v in info_dict.items():
if torch.is_tensor(v):
batch_v = v[start_idx:end_idx]
batch_v = batch_v.unsqueeze(1)
batch_v = batch_v.expand(
current_bs, self.num_samples, *batch_v.shape[2:]
)
elif isinstance(v, np.ndarray):
batch_v = v[start_idx:end_idx]
batch_v = np.repeat(
batch_v[:, None, ...], self.num_samples, axis=1
)
else:
batch_v = v
expanded_infos[k] = batch_v
rho = self.rho_init
batch_cost_history = []
costs = None
final_constraints = None
for _outer in range(self.n_outer_steps):
# Fresh optimizer each outer step — avoids stale momentum after dual ascent
optim = self.optimizer_cls(
[batch_init], **self.optimizer_kwargs
)
for _step in range(self.n_steps):
current_info = expanded_infos.copy()
costs = self.model.get_cost(current_info, batch_init)
constraints = (
self.model.get_constraints(
expanded_infos.copy(), batch_init
)
if hasattr(self.model, 'get_constraints')
else None
)
assert isinstance(costs, torch.Tensor), (
f'Got {type(costs)} cost, expect torch.Tensor'
)
assert costs.ndim == 2 and costs.shape == (
current_bs,
self.num_samples,
), (
f'Cost should be of shape ({current_bs}, {self.num_samples}), got {costs.shape}'
)
assert costs.requires_grad, (
'Cost must requires_grad for LagrangianSolver.'
)
if constraints is not None:
assert constraints.ndim == 3 and constraints.shape[
:2
] == (current_bs, self.num_samples), (
f'Constraints should be of shape ({current_bs}, {self.num_samples}, C), got {constraints.shape}'
)
if self._lambdas is None:
self._init_multipliers(constraints.shape[-1])
lambdas_batch = self._lambdas[start_idx:end_idx]
loss = self._augmented_lagrangian_loss(
costs, constraints, lambdas_batch, rho
)
else:
loss = costs.sum()
loss.backward()
optim.step()
optim.zero_grad(set_to_none=True)
if self.action_noise > 0:
batch_init.data += (
torch.randn(
batch_init.shape, generator=self.torch_gen
)
* self.action_noise
)
batch_cost_history.append(loss.item())
# Dual ascent after inner loop converges
if constraints is not None:
with torch.no_grad():
final_constraints = self.model.get_constraints(
expanded_infos.copy(), batch_init
)
lambdas_batch = self._update_multipliers(
final_constraints, lambdas_batch, rho
)
self._lambdas[start_idx:end_idx] = lambdas_batch
rho = min(self.rho_max, rho * self.rho_scale)
with torch.no_grad():
mean_cost = costs.mean().item()
if constraints is not None:
viol = F.relu(final_constraints).mean(dim=(0, 1)) # (C,)
lam = lambdas_batch.mean(dim=0) # (C,)
viol_str = ', '.join(f'{v:.4f}' for v in viol.tolist())
lam_str = ', '.join(f'{l:.4f}' for l in lam.tolist())
print(
f' [outer {_outer+1}/{self.n_outer_steps}] '
f'cost={mean_cost:.4f} | '
f'constraint_viol=[{viol_str}] | '
f'lambdas=[{lam_str}] | '
f'rho={rho:.4f}'
)
else:
print(
f' [outer {_outer+1}/{self.n_outer_steps}] '
f'cost={mean_cost:.4f}'
)
outputs['cost'].append(batch_cost_history)
if final_constraints is not None:
outputs['constraint_violation'].append(
F.relu(final_constraints).mean().item()
)
with torch.no_grad():
self.init[start_idx:end_idx] = batch_init
top_idx = torch.argsort(costs, dim=1)[:, 0]
batch_indices = torch.arange(current_bs)
top_actions_batch = batch_init[batch_indices, top_idx]
batch_top_actions_list.append(top_actions_batch.detach().cpu())
outputs['actions'] = torch.cat(batch_top_actions_list, dim=0)
outputs['lambdas'] = (
self._lambdas.cpu() if self._lambdas is not None else None
)
constraint_info = ''
if outputs['constraint_violation']:
mean_viol = np.mean(outputs['constraint_violation'])
constraint_info = f' | constraint_violation={mean_viol:.4f}'
print(
f'LagrangianSolver.solve completed in {time.time() - start_time:.4f} seconds{constraint_info}.'
)
return outputs

View File

@@ -0,0 +1,208 @@
"""Model Predictive Path Integral solver for model-based planning."""
import time
from typing import Any
import gymnasium as gym
import numpy as np
import torch
from gymnasium.spaces import Box
from loguru import logger as logging
from .solver import Costable
class MPPISolver:
"""Model Predictive Path Integral solver for action optimization.
Args:
model: World model implementing the Costable protocol.
batch_size: Number of environments to process in parallel.
num_samples: Number of action candidates to sample per iteration.
var_scale: Initial variance scale for action noise.
n_steps: Number of MPPI iterations.
topk: Number of elite samples for weighted averaging.
temperature: Temperature parameter for softmax weighting.
device: Device for tensor computations.
seed: Random seed for reproducibility.
"""
def __init__(
self,
model: Costable,
batch_size: int = 1,
num_samples: int = 300,
var_scale: float = 1.0,
n_steps: int = 30,
topk: int = 30,
temperature: float = 0.5,
device: str | torch.device = "cpu",
seed: int = 1234,
) -> None:
self.model = model
self.batch_size = batch_size
self.num_samples = num_samples
self.topk = topk
self.var_scale = var_scale
self.n_steps = n_steps
self.temperature = temperature
self.device = device
self.torch_gen = torch.Generator(device=device).manual_seed(seed)
def configure(self, *, action_space: gym.Space, n_envs: int, config: Any) -> None:
"""Configure the solver with environment specifications."""
self._action_space = action_space
self._n_envs = n_envs
self._config = config
self._action_dim = int(np.prod(action_space.shape[1:]))
self._configured = True
if not isinstance(action_space, Box):
logging.warning(
f"Action space is discrete, got {type(action_space)}. MPPISolver may not work as expected."
)
@property
def n_envs(self) -> int:
"""Number of parallel environments."""
return self._n_envs
@property
def action_dim(self) -> int:
"""Flattened action dimension including action_block grouping."""
return self._action_dim * self._config.action_block
@property
def horizon(self) -> int:
"""Planning horizon in timesteps."""
return self._config.horizon
def __call__(self, *args: Any, **kwargs: Any) -> dict:
"""Make solver callable, forwarding to solve()."""
return self.solve(*args, **kwargs)
def init_action_distrib(
self, actions: torch.Tensor | None = None
) -> tuple[torch.Tensor, torch.Tensor]:
"""Initialize the action distribution parameters (mean and variance)."""
var = self.var_scale * torch.ones([self.n_envs, self.horizon, self.action_dim])
mean = torch.zeros([self.n_envs, 0, self.action_dim]) if actions is None else actions
remaining = self.horizon - mean.shape[1]
if remaining > 0:
device = mean.device
new_mean = torch.zeros([self.n_envs, remaining, self.action_dim])
mean = torch.cat([mean, new_mean], dim=1).to(device)
return mean, var
@torch.inference_mode()
def solve(
self, info_dict: dict, init_action: torch.Tensor | None = None
) -> dict:
"""Solve the planning problem using MPPI."""
start_time = time.time()
outputs = {
"costs": [],
"mean": [],
"var": [],
}
# -- initialize the action distribution globally
mean, var = self.init_action_distrib(init_action)
mean = mean.to(self.device)
var = var.to(self.device)
total_envs = self.n_envs
# --- Iterate over batches ---
for start_idx in range(0, total_envs, self.batch_size):
end_idx = min(start_idx + self.batch_size, total_envs)
current_bs = end_idx - start_idx
# Slice Distribution Parameters for current batch
batch_mean = mean[start_idx:end_idx]
batch_var = var[start_idx:end_idx]
# Expand Info Dict for current batch (Same as CEM)
expanded_infos = {}
for k, v in info_dict.items():
v_batch = v[start_idx:end_idx]
if torch.is_tensor(v):
# Add sample dim: (batch, 1, ...)
v_batch = v_batch.unsqueeze(1)
# Expand: (batch, num_samples, ...)
v_batch = v_batch.expand(current_bs, self.num_samples, *v_batch.shape[2:])
elif isinstance(v, np.ndarray):
v_batch = np.repeat(v_batch[:, None, ...], self.num_samples, axis=1)
expanded_infos[k] = v_batch
# Optimization Loop
final_batch_cost = None
for step in range(self.n_steps):
# Sample noise: (Batch, Num_Samples, Horizon, Dim)
noise = torch.randn(
current_bs,
self.num_samples,
self.horizon,
self.action_dim,
generator=self.torch_gen,
device=self.device,
)
# MPPI Logic: candidates = mean + noise * sigma
candidates = batch_mean.unsqueeze(1) + noise * batch_var.unsqueeze(1)
# Force the first sample to be the current mean (Zero noise)
candidates[:, 0] = batch_mean
# Evaluate candidates
costs = self.model.get_cost(expanded_infos, candidates)
assert isinstance(costs, torch.Tensor), f"Expected cost to be a torch.Tensor, got {type(costs)}"
assert costs.ndim == 2 and costs.shape[0] == current_bs and costs.shape[1] == self.num_samples, (
f"Expected cost to be of shape ({current_bs}, {self.num_samples}), got {costs.shape}"
)
# Select Elites (Optional, based on topk)
if self.topk is not None and self.topk < self.num_samples:
# topk_vals: (Batch, K), topk_inds: (Batch, K)
topk_vals, topk_inds = torch.topk(costs, k=self.topk, dim=1, largest=False)
# Gather Top-K Candidates
batch_indices = torch.arange(current_bs, device=self.device).unsqueeze(1).expand(-1, self.topk)
# (Batch, K, Horizon, Dim)
relevant_candidates = candidates[batch_indices, topk_inds]
relevant_costs = topk_vals
else:
relevant_candidates = candidates
relevant_costs = costs
# MPPI Weighting: Softmax(-cost / temperature)
# Stabilize softmax by subtracting min cost
min_cost = relevant_costs.min(dim=1, keepdim=True)[0]
scaled_costs = relevant_costs - min_cost
weights = torch.softmax(-scaled_costs / self.temperature, dim=1) # (Batch, K)
# Update Mean: weighted sum of candidates
# Reshape weights for broadcasting: (Batch, K, 1, 1)
weights_expanded = weights.unsqueeze(-1).unsqueeze(-1)
batch_mean = (weights_expanded * relevant_candidates).sum(dim=1)
# Store average cost of the utilized samples for logging
final_batch_cost = relevant_costs.mean(dim=1).cpu().tolist()
# Write results back to global storage
mean[start_idx:end_idx] = batch_mean
# We do not update var in standard MPPI
# Store history/metadata
outputs["costs"].extend(final_batch_cost)
outputs["actions"] = mean.detach().cpu()
outputs["mean"] = [mean.detach().cpu()]
outputs["var"] = [var.detach().cpu()]
print(f"MPPI solve time: {time.time() - start_time:.4f} seconds")
return outputs

View File

@@ -0,0 +1,74 @@
from typing import Any, Protocol, runtime_checkable
import gymnasium as gym
import torch
class Costable(Protocol):
"""Protocol for world model cost functions."""
def criterion(self, info_dict: dict, action_candidates: torch.Tensor) -> torch.Tensor:
"""Compute the cost criterion for action candidates.
Args:
info_dict: Dictionary containing environment state information.
action_candidates: Tensor of proposed actions.
Returns:
A tensor of cost values for each action candidate.
"""
...
def get_cost(self, info_dict: dict, action_candidates: torch.Tensor) -> torch.Tensor: # pragma: no cover
"""Compute cost for given action candidates based on info dictionary.
Args:
info_dict: Dictionary containing environment state information.
action_candidates: Tensor of proposed actions.
Returns:
A tensor of cost values for each action candidate.
"""
...
@runtime_checkable
class Solver(Protocol):
"""Protocol for model-based planning solvers."""
def configure(self, *, action_space: gym.Space, n_envs: int, config: Any) -> None:
"""Configure the solver with environment and planning specifications.
Args:
action_space: The action space of the environment.
n_envs: Number of parallel environments.
config: Planning configuration object.
"""
...
@property
def action_dim(self) -> int:
"""Flattened action dimension including action_block grouping."""
...
@property
def n_envs(self) -> int:
"""Number of parallel environments being planned for."""
...
@property
def horizon(self) -> int:
"""Planning horizon length in timesteps."""
...
def solve(self, info_dict: dict, init_action: torch.Tensor | None = None) -> dict:
"""Solve the planning optimization problem to find optimal actions.
Args:
info_dict: Dictionary containing environment state information.
init_action: Optional initial action sequence to warm-start the solver.
Returns:
Dictionary containing optimized actions and other solver-specific info.
"""
...