调整ignore的文件
This commit is contained in:
15
.gitignore
vendored
15
.gitignore
vendored
@@ -1,3 +1,18 @@
|
||||
.venv/*
|
||||
!.venv/.gitignore
|
||||
!.venv/lib/
|
||||
.venv/lib/*
|
||||
!.venv/lib/python3.10/
|
||||
.venv/lib/python3.10/*
|
||||
!.venv/lib/python3.10/site-packages/
|
||||
.venv/lib/python3.10/site-packages/*
|
||||
!.venv/lib/python3.10/site-packages/stable_worldmodel/
|
||||
.venv/lib/python3.10/site-packages/stable_worldmodel/*
|
||||
!.venv/lib/python3.10/site-packages/stable_worldmodel/solver/
|
||||
!.venv/lib/python3.10/site-packages/stable_worldmodel/policy.py
|
||||
!.venv/lib/python3.10/site-packages/stable_worldmodel/world.py
|
||||
!.venv/lib/python3.10/site-packages/stable_worldmodel/solver/cem.py
|
||||
!.venv/lib/python3.10/site-packages/stable_worldmodel/solver/gd.py
|
||||
outputs/
|
||||
__pycache__/
|
||||
*.py[cod]
|
||||
|
||||
@@ -0,0 +1,17 @@
|
||||
from .cem import CEMSolver
|
||||
from .gd import GradientSolver
|
||||
from .icem import ICEMSolver
|
||||
from .lagrangian import LagrangianSolver
|
||||
from .mppi import MPPISolver
|
||||
from .solver import Solver
|
||||
from .discrete_solvers import PGDSolver
|
||||
|
||||
# Public names of the solver package; this list is what a star-import of the
# package exposes.
__all__ = [
    'Solver', 'GradientSolver', 'CEMSolver', 'ICEMSolver',
    'PGDSolver', 'MPPISolver', 'LagrangianSolver',
]
|
||||
@@ -0,0 +1,256 @@
|
||||
"""Projected Gradient Descent solver for discrete action spaces."""
|
||||
|
||||
import time
|
||||
from typing import Any
|
||||
|
||||
import gymnasium as gym
|
||||
import numpy as np
|
||||
import torch
|
||||
from gymnasium.spaces import Discrete
|
||||
|
||||
from .solver import Costable
|
||||
|
||||
|
||||
class PGDSolver(torch.nn.Module):
|
||||
"""Projected Gradient Descent solver for discrete action optimization.
|
||||
|
||||
Args:
|
||||
model: World model implementing the Costable protocol.
|
||||
n_steps: Number of gradient descent iterations.
|
||||
batch_size: Number of environments to process in parallel.
|
||||
var_scale: Initial variance scale for action perturbations.
|
||||
num_samples: Number of action samples to optimize in parallel.
|
||||
action_noise: Noise added to actions during optimization.
|
||||
device: Device for tensor computations.
|
||||
seed: Random seed for reproducibility.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model: Costable,
|
||||
n_steps: int,
|
||||
batch_size: int | None = None,
|
||||
var_scale: float = 1,
|
||||
num_samples: int = 1,
|
||||
action_noise: float = 0.0,
|
||||
device: str | torch.device = "cpu",
|
||||
seed: int = 1234,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.model = model
|
||||
self.n_steps = n_steps
|
||||
self.batch_size = batch_size
|
||||
self.num_samples = num_samples
|
||||
self.var_scale = var_scale
|
||||
self.action_noise = action_noise
|
||||
self.device = device
|
||||
self.torch_gen = torch.Generator(device=device).manual_seed(seed)
|
||||
|
||||
self._configured = False
|
||||
self._n_envs = None
|
||||
self._action_dim = None
|
||||
self._action_simplex_dim = None
|
||||
self._config = None
|
||||
|
||||
def configure(self, *, action_space: gym.Space, n_envs: int, config: Any) -> None:
|
||||
"""Configure the solver with environment specifications."""
|
||||
assert isinstance(action_space, Discrete), f"Action space must be discrete, got {type(action_space)}"
|
||||
|
||||
self._action_space = action_space
|
||||
self._n_envs = n_envs
|
||||
self._config = config
|
||||
self._action_dim = int(np.prod(action_space.shape[1:]))
|
||||
self._action_simplex_dim = int(action_space.n)
|
||||
self._configured = True
|
||||
|
||||
@property
|
||||
def n_envs(self) -> int:
|
||||
"""Number of parallel environments."""
|
||||
return self._n_envs
|
||||
|
||||
@property
|
||||
def action_dim(self) -> int:
|
||||
"""Flattened action dimension including action_block grouping."""
|
||||
return self._action_dim * self._config.action_block
|
||||
|
||||
@property
|
||||
def action_simplex_dim(self) -> int:
|
||||
"""Simplex dimension for discrete action probabilities."""
|
||||
return self._action_simplex_dim * self._config.action_block
|
||||
|
||||
@property
|
||||
def horizon(self) -> int:
|
||||
"""Planning horizon in timesteps."""
|
||||
return self._config.horizon
|
||||
|
||||
def __call__(self, *args: Any, **kwargs: Any) -> dict:
|
||||
"""Make solver callable, forwarding to solve()."""
|
||||
return self.solve(*args, **kwargs)
|
||||
|
||||
def init_action(
|
||||
self, actions: torch.Tensor | None = None, from_scalar: bool = False
|
||||
) -> None:
|
||||
"""Initialize the action tensor for optimization."""
|
||||
if actions is None:
|
||||
actions = torch.zeros((self._n_envs, 0, self.action_simplex_dim))
|
||||
elif from_scalar:
|
||||
# convert scalar to one-hot
|
||||
actions = torch.nn.functional.one_hot(actions, num_classes=self._action_simplex_dim).to(torch.float32)
|
||||
# merge action_block dim
|
||||
actions = actions.reshape(*actions.shape[:-2], self.action_simplex_dim)
|
||||
assert (
|
||||
actions.shape[0] == self._n_envs
|
||||
and actions.shape[1] <= self.horizon
|
||||
and actions.shape[2] == self.action_simplex_dim
|
||||
)
|
||||
|
||||
# fill remaining action
|
||||
remaining = self.horizon - actions.shape[1]
|
||||
|
||||
if remaining > 0:
|
||||
new_actions = torch.zeros(self._n_envs, remaining, self.action_simplex_dim)
|
||||
actions = torch.cat([actions, new_actions], dim=1).to(self.device)
|
||||
|
||||
actions = actions.unsqueeze(1).repeat_interleave(self.num_samples, dim=1) # add sample dim
|
||||
actions[:, 1:] += (
|
||||
torch.randn(actions[:, 1:].shape, generator=self.torch_gen, device=self.device) * self.var_scale
|
||||
) # add small noise to all samples except the first one
|
||||
|
||||
# reset actions
|
||||
if hasattr(self, "init"):
|
||||
self.init.copy_(actions)
|
||||
else:
|
||||
self.register_parameter("init", torch.nn.Parameter(actions))
|
||||
|
||||
def solve(
|
||||
self,
|
||||
info_dict: dict,
|
||||
init_action: torch.Tensor | None = None,
|
||||
from_scalar: bool = False,
|
||||
) -> dict:
|
||||
"""Solve the planning problem using projected gradient descent."""
|
||||
start_time = time.time()
|
||||
outputs = {
|
||||
"cost": [], # Will store list of cost histories per batch
|
||||
"actions": None,
|
||||
}
|
||||
|
||||
with torch.no_grad():
|
||||
self.init_action(init_action, from_scalar=from_scalar)
|
||||
|
||||
# Determine batch size (default to all envs if not specified which can cause memory issues)
|
||||
batch_size = self.batch_size if self.batch_size is not None else self.n_envs
|
||||
total_envs = self.n_envs
|
||||
|
||||
# Lists to hold results from each batch to be concatenated later
|
||||
batch_top_actions_list = []
|
||||
|
||||
# --- Outer Loop: Iterate over batches ---
|
||||
for start_idx in range(0, total_envs, batch_size):
|
||||
end_idx = min(start_idx + batch_size, total_envs)
|
||||
current_bs = end_idx - start_idx
|
||||
|
||||
batch_init = self.init[start_idx:end_idx].clone().detach()
|
||||
batch_init.requires_grad = True
|
||||
|
||||
optim = torch.optim.SGD([batch_init], lr=1.0)
|
||||
|
||||
# Prepare Batch Infos
|
||||
# Slice the input info_dict and then expand dimensions
|
||||
expanded_infos = {}
|
||||
for k, v in info_dict.items():
|
||||
# Slice the data for the current batch indices
|
||||
# Assumes input data dim 0 corresponds to n_envs
|
||||
if torch.is_tensor(v):
|
||||
batch_v = v[start_idx:end_idx]
|
||||
batch_v = batch_v.unsqueeze(1)
|
||||
batch_v = batch_v.expand(current_bs, self.num_samples, *batch_v.shape[2:])
|
||||
elif isinstance(v, np.ndarray):
|
||||
batch_v = v[start_idx:end_idx]
|
||||
batch_v = np.repeat(batch_v[:, None, ...], self.num_samples, axis=1)
|
||||
expanded_infos[k] = batch_v
|
||||
|
||||
# Perform Gradient Descent for this batch
|
||||
batch_cost_history = []
|
||||
|
||||
for step in range(self.n_steps):
|
||||
current_info = expanded_infos.copy()
|
||||
|
||||
# Calculate cost using the batch parameter
|
||||
costs = self.model.get_cost(current_info, batch_init)
|
||||
|
||||
assert isinstance(costs, torch.Tensor), f"Got {type(costs)} cost, expect torch.Tensor"
|
||||
assert costs.ndim == 2 and costs.shape[0] == current_bs and costs.shape[1] == self.num_samples, (
|
||||
f"Cost should be of shape ({current_bs}, {self.num_samples}), got {costs.shape}"
|
||||
)
|
||||
assert costs.requires_grad, "Cost must requires_grad for PGD solver."
|
||||
|
||||
cost = costs.sum() # Sum cost for this batch
|
||||
cost.backward()
|
||||
optim.step()
|
||||
optim.zero_grad(set_to_none=True)
|
||||
|
||||
# Add noise
|
||||
if self.action_noise > 0:
|
||||
batch_init.data += torch.randn(batch_init.shape, generator=self.torch_gen) * self.action_noise
|
||||
|
||||
# projection onto simplex
|
||||
with torch.no_grad():
|
||||
batch_init.copy_(self._project_action_simplex(batch_init))
|
||||
|
||||
batch_cost_history.append(cost.item())
|
||||
|
||||
# Store cost history for this batch
|
||||
outputs["cost"].append(batch_cost_history)
|
||||
|
||||
# Update the global self.init with the optimized batch values
|
||||
with torch.no_grad():
|
||||
self.init[start_idx:end_idx] = batch_init
|
||||
|
||||
top_idx = torch.argsort(costs, dim=1)[:, 0]
|
||||
batch_indices = torch.arange(current_bs)
|
||||
|
||||
top_actions_batch = batch_init[batch_indices, top_idx]
|
||||
|
||||
# convert one-hot back to discrete actions
|
||||
top_actions_batch = self._factor_action_block(top_actions_batch).argmax(dim=-1)
|
||||
batch_top_actions_list.append(top_actions_batch.detach().cpu())
|
||||
|
||||
# Concatenate all batch results
|
||||
outputs["actions"] = torch.cat(batch_top_actions_list, dim=0)
|
||||
end_time = time.time()
|
||||
print(f"PGDSolver.solve completed in {end_time - start_time:.4f} seconds.")
|
||||
|
||||
return outputs
|
||||
|
||||
def _factor_action_block(self, actions: torch.Tensor) -> torch.Tensor:
|
||||
"""Factor the action block dimension from action_simplex_dim."""
|
||||
original_shape = actions.shape
|
||||
action_block = self._config.action_block
|
||||
simplex_dim = self._action_simplex_dim
|
||||
return actions.reshape(*original_shape[:-1], action_block, simplex_dim)
|
||||
|
||||
def _project_action_simplex(self, actions: torch.Tensor) -> torch.Tensor:
|
||||
"""Project the action onto the probability simplex."""
|
||||
original_shape = actions.shape
|
||||
|
||||
s = self._factor_action_block(actions).reshape(-1, self._action_simplex_dim)
|
||||
|
||||
mu, _ = torch.sort(s, descending=True, dim=-1)
|
||||
cumulative = mu.cumsum(dim=-1)
|
||||
|
||||
d = s.size(-1)
|
||||
indices = torch.arange(1, d + 1, device=s.device, dtype=s.dtype)
|
||||
|
||||
threshold = (cumulative - 1) / indices
|
||||
|
||||
cond = (mu > threshold).to(torch.int32)
|
||||
rho = cond.cumsum(dim=-1)
|
||||
valid_rho = rho * cond
|
||||
rho_max = valid_rho.max(dim=-1, keepdim=True)[0]
|
||||
|
||||
rho_min = torch.clamp(rho_max, min=1)
|
||||
psi = (cumulative.gather(-1, rho_min - 1) - 1) / rho_min
|
||||
|
||||
projected = torch.clamp(s - psi, min=0.0).reshape(original_shape)
|
||||
return projected
|
||||
@@ -0,0 +1,219 @@
|
||||
"""Improved Cross Entropy Method (iCEM) solver for model-based planning."""
|
||||
|
||||
import time
|
||||
from typing import Any
|
||||
|
||||
import gymnasium as gym
|
||||
import numpy as np
|
||||
import torch
|
||||
from gymnasium.spaces import Box
|
||||
from loguru import logger as logging
|
||||
|
||||
from .solver import Costable
|
||||
|
||||
|
||||
class ICEMSolver:
|
||||
"""Improved Cross Entropy Method (iCEM) solver with colored noise and elite retention.
|
||||
iCEM improves the sample efficiency over standard CEM and was introduced by
|
||||
[1] for real-time planning.
|
||||
|
||||
Args:
|
||||
model: World model implementing the Costable protocol.
|
||||
batch_size: Number of environments to process in parallel.
|
||||
num_samples: Number of action candidates to sample per iteration.
|
||||
var_scale: Initial variance scale for the action distribution.
|
||||
n_steps: Number of CEM iterations.
|
||||
topk: Number of elite samples to keep for distribution update.
|
||||
noise_beta: Colored noise exponent. 0 = white (standard CEM), >0 = more low-frequency noise.
|
||||
alpha: Momentum for mean/std EMA update.
|
||||
n_elite_keep: Number of elites carried from previous iteration.
|
||||
return_mean: If False, return best single trajectory instead of mean.
|
||||
device: Device for tensor computations.
|
||||
seed: Random seed for reproducibility.
|
||||
|
||||
[1] C. Pinneri, S. Sawant, S. Blaes, J. Achterhold, J. Stueckler, M. Rolinek and
|
||||
G, Martius, Georg. "Sample-efficient Cross-Entropy Method for Real-time Planning".
|
||||
Conference on Robot Learning, 2020.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model: Costable,
|
||||
batch_size: int = 1,
|
||||
num_samples: int = 300,
|
||||
var_scale: float = 1,
|
||||
n_steps: int = 30,
|
||||
topk: int = 30,
|
||||
noise_beta: float = 2.0,
|
||||
alpha: float = 0.1,
|
||||
n_elite_keep: int = 5,
|
||||
return_mean: bool = True,
|
||||
device: str | torch.device = "cpu",
|
||||
seed: int = 1234,
|
||||
) -> None:
|
||||
self.model = model
|
||||
self.batch_size = batch_size
|
||||
self.var_scale = var_scale
|
||||
self.num_samples = num_samples
|
||||
self.n_steps = n_steps
|
||||
self.topk = topk
|
||||
self.noise_beta = noise_beta
|
||||
self.alpha = alpha
|
||||
self.n_elite_keep = n_elite_keep
|
||||
self.return_mean = return_mean
|
||||
self.device = device
|
||||
self.torch_gen = torch.Generator(device=device).manual_seed(seed)
|
||||
|
||||
def configure(self, *, action_space: gym.Space, n_envs: int, config: Any) -> None:
|
||||
"""Configure the solver with environment specifications."""
|
||||
self._action_space = action_space
|
||||
self._n_envs = n_envs
|
||||
self._config = config
|
||||
self._action_dim = int(np.prod(action_space.shape[1:]))
|
||||
self._configured = True
|
||||
|
||||
if isinstance(action_space, Box):
|
||||
self._action_low = torch.tensor(action_space.low[0], device=self.device, dtype=torch.float32)
|
||||
self._action_high = torch.tensor(action_space.high[0], device=self.device, dtype=torch.float32)
|
||||
else:
|
||||
logging.warning(f"Action space is discrete, got {type(action_space)}. ICEMSolver may not work as expected.")
|
||||
self._action_low = None
|
||||
self._action_high = None
|
||||
|
||||
@property
|
||||
def n_envs(self) -> int:
|
||||
"""Number of parallel environments."""
|
||||
return self._n_envs
|
||||
|
||||
@property
|
||||
def action_dim(self) -> int:
|
||||
"""Flattened action dimension including action_block grouping."""
|
||||
return self._action_dim * self._config.action_block
|
||||
|
||||
@property
|
||||
def horizon(self) -> int:
|
||||
"""Planning horizon in timesteps."""
|
||||
return self._config.horizon
|
||||
|
||||
def __call__(self, *args: Any, **kwargs: Any) -> dict:
|
||||
"""Make solver callable, forwarding to solve()."""
|
||||
return self.solve(*args, **kwargs)
|
||||
|
||||
def init_action_distrib(
|
||||
self, actions: torch.Tensor | None = None
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Initialize the action distribution parameters (mean and variance)."""
|
||||
var = self.var_scale * torch.ones([self.n_envs, self.horizon, self.action_dim])
|
||||
mean = torch.zeros([self.n_envs, 0, self.action_dim]) if actions is None else actions
|
||||
|
||||
remaining = self.horizon - mean.shape[1]
|
||||
if remaining > 0:
|
||||
device = mean.device
|
||||
new_mean = torch.zeros([self.n_envs, remaining, self.action_dim])
|
||||
mean = torch.cat([mean, new_mean], dim=1).to(device)
|
||||
|
||||
return mean, var
|
||||
|
||||
@torch.inference_mode()
|
||||
def solve(
|
||||
self, info_dict: dict, init_action: torch.Tensor | None = None
|
||||
) -> dict:
|
||||
"""Solve the planning problem using improved Cross Entropy Method."""
|
||||
start_time = time.time()
|
||||
outputs = {
|
||||
"costs": [],
|
||||
"mean": [],
|
||||
"var": [],
|
||||
}
|
||||
|
||||
mean, var = self.init_action_distrib(init_action)
|
||||
mean = mean.to(self.device)
|
||||
var = var.to(self.device)
|
||||
|
||||
for start_idx in range(0, self.n_envs, self.batch_size):
|
||||
end_idx = min(start_idx + self.batch_size, self.n_envs)
|
||||
current_bs = end_idx - start_idx
|
||||
|
||||
batch_mean = mean[start_idx:end_idx]
|
||||
batch_var = var[start_idx:end_idx]
|
||||
|
||||
expanded_infos = {}
|
||||
for k, v in info_dict.items():
|
||||
v_batch = v[start_idx:end_idx]
|
||||
if torch.is_tensor(v):
|
||||
v_batch = v_batch.unsqueeze(1)
|
||||
v_batch = v_batch.expand(current_bs, self.num_samples, *v_batch.shape[2:])
|
||||
elif isinstance(v, np.ndarray):
|
||||
v_batch = np.repeat(v_batch[:, None, ...], self.num_samples, axis=1)
|
||||
expanded_infos[k] = v_batch
|
||||
|
||||
prev_topk_candidates = None
|
||||
batch_indices = torch.arange(current_bs, device=self.device).unsqueeze(1).expand(-1, self.topk)
|
||||
|
||||
# Precompute FFT scale for colored noise
|
||||
noise_shape = (current_bs, self.num_samples, self.action_dim, self.horizon)
|
||||
freqs = torch.fft.rfftfreq(self.horizon, device=self.device)
|
||||
freqs[0] = 1.0
|
||||
noise_scale = freqs.pow(-self.noise_beta / 2)
|
||||
noise_scale[0] = noise_scale[1]
|
||||
|
||||
for step in range(self.n_steps):
|
||||
# Colored noise: generate with temporal axis last, then transpose
|
||||
if self.horizon <= 1:
|
||||
noise = torch.randn(noise_shape, generator=self.torch_gen, device=self.device)
|
||||
else:
|
||||
white = torch.randn(noise_shape, generator=self.torch_gen, device=self.device)
|
||||
fft = torch.fft.rfft(white, dim=-1)
|
||||
colored = torch.fft.irfft(fft * noise_scale, n=self.horizon, dim=-1)
|
||||
std = colored.std(dim=-1, keepdim=True).clamp(min=1e-8)
|
||||
noise = colored / std
|
||||
noise = noise.transpose(-1, -2) # -> (bs, num_samples, horizon, action_dim)
|
||||
|
||||
candidates = noise * batch_var.unsqueeze(1) + batch_mean.unsqueeze(1)
|
||||
candidates[:, 0] = batch_mean
|
||||
|
||||
# Inject previous elites
|
||||
if prev_topk_candidates is not None:
|
||||
n_inject = min(self.n_elite_keep, prev_topk_candidates.shape[1])
|
||||
candidates[:, 1:1 + n_inject] = prev_topk_candidates[:, :n_inject]
|
||||
|
||||
# Clip to action bounds
|
||||
if self._action_low is not None:
|
||||
candidates = candidates.clamp(self._action_low, self._action_high)
|
||||
|
||||
current_info = expanded_infos.copy()
|
||||
costs = self.model.get_cost(current_info, candidates)
|
||||
|
||||
assert isinstance(costs, torch.Tensor), f"Expected cost to be a torch.Tensor, got {type(costs)}"
|
||||
assert costs.ndim == 2 and costs.shape[0] == current_bs and costs.shape[1] == self.num_samples, (
|
||||
f"Expected cost to be of shape ({current_bs}, {self.num_samples}), got {costs.shape}"
|
||||
)
|
||||
|
||||
topk_vals, topk_inds = torch.topk(costs, k=self.topk, dim=1, largest=False)
|
||||
topk_candidates = candidates[batch_indices, topk_inds]
|
||||
|
||||
prev_topk_candidates = topk_candidates
|
||||
|
||||
# Momentum update
|
||||
elite_mean = topk_candidates.mean(dim=1)
|
||||
elite_var = topk_candidates.std(dim=1)
|
||||
batch_mean = self.alpha * batch_mean + (1 - self.alpha) * elite_mean
|
||||
batch_var = self.alpha * batch_var + (1 - self.alpha) * elite_var
|
||||
|
||||
final_batch_cost = topk_vals.mean(dim=1).cpu().tolist()
|
||||
|
||||
if self.return_mean:
|
||||
mean[start_idx:end_idx] = batch_mean
|
||||
else:
|
||||
mean[start_idx:end_idx] = topk_candidates[:, 0]
|
||||
|
||||
var[start_idx:end_idx] = batch_var
|
||||
|
||||
outputs["costs"].extend(final_batch_cost)
|
||||
|
||||
outputs["actions"] = mean.detach().cpu()
|
||||
outputs["mean"] = [mean.detach().cpu()]
|
||||
outputs["var"] = [var.detach().cpu()]
|
||||
|
||||
print(f"iCEM solve time: {time.time() - start_time:.4f} seconds")
|
||||
return outputs
|
||||
@@ -0,0 +1,360 @@
|
||||
"""Lagrangian solver for stable world model."""
|
||||
|
||||
import time
|
||||
from typing import Any
|
||||
|
||||
import gymnasium as gym
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from gymnasium.spaces import Box
|
||||
from loguru import logger as logging
|
||||
|
||||
from .solver import Costable
|
||||
|
||||
|
||||
class LagrangianSolver(torch.nn.Module):
|
||||
"""Lagrangian solver for stable world model.
|
||||
|
||||
get_cost returns the cost tensor (B, S). If the model also implements get_constraints,
|
||||
it should return the constraint violations (B, S, C), where C is the number of constraints.
|
||||
The constraint_cost should represent the cost of violating the constraints, where the constraint
|
||||
is satisfied when constraint_cost <= 0. The Lagrangian solver will optimize the following objective:
|
||||
|
||||
L = cost + sum_{i=1}^C lambda_i * constraint_cost_i + sum_{i=1}^C rho_i * max(0, constraint_cost_i)^2
|
||||
|
||||
If you want to use equality constraint, you can convert it to two inequality constraints. For example, if you want to enforce constraint_cost_i == 0, you can add two constraints: constraint_cost_i <= 0 and -constraint_cost_i <= 0.
|
||||
|
||||
Args:
|
||||
model: World model implementing the Costable protocol. Its get_cost() returns
|
||||
a plain cost tensor (B, S). If it also has get_constraints(), that method
|
||||
returns constraints of shape (B, S, C).
|
||||
n_steps: Number of gradient descent steps per outer iteration.
|
||||
n_outer_steps: Number of dual ascent (outer) iterations.
|
||||
batch_size: Number of environments to process in parallel.
|
||||
num_samples: Number of action samples to optimize in parallel.
|
||||
var_scale: Initial variance scale for action perturbations.
|
||||
action_noise: Noise added to actions during optimization.
|
||||
rho_init: Initial penalty coefficient for the quadratic constraint term.
|
||||
rho_max: Maximum value of the penalty coefficient.
|
||||
rho_scale: Multiplicative growth factor for rho after each outer step.
|
||||
persist_multipliers: Whether to warm-start Lagrange multipliers across solve() calls.
|
||||
device: Device for tensor computations.
|
||||
seed: Random seed for reproducibility.
|
||||
optimizer_cls: PyTorch optimizer class to use.
|
||||
optimizer_kwargs: Keyword arguments for the optimizer.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model: Costable,
|
||||
n_steps: int,
|
||||
n_outer_steps: int = 5,
|
||||
batch_size: int | None = None,
|
||||
num_samples: int = 1,
|
||||
var_scale: float = 1.0,
|
||||
action_noise: float = 0.0,
|
||||
rho_init: float = 1.0,
|
||||
rho_max: float = 1e4,
|
||||
rho_scale: float = 2.0,
|
||||
persist_multipliers: bool = True,
|
||||
device: str | torch.device = 'cpu',
|
||||
seed: int = 1234,
|
||||
optimizer_cls: type[torch.optim.Optimizer] = torch.optim.Adam,
|
||||
optimizer_kwargs: dict | None = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.model = model
|
||||
self.n_steps = n_steps
|
||||
self.n_outer_steps = n_outer_steps
|
||||
self.batch_size = batch_size
|
||||
self.num_samples = num_samples
|
||||
self.var_scale = var_scale
|
||||
self.action_noise = action_noise
|
||||
self.rho_init = rho_init
|
||||
self.rho_max = rho_max
|
||||
self.rho_scale = rho_scale
|
||||
self.persist_multipliers = persist_multipliers
|
||||
self.device = device
|
||||
self.torch_gen = torch.Generator(device=device).manual_seed(seed)
|
||||
self.optimizer_cls = optimizer_cls
|
||||
self.optimizer_kwargs = (
|
||||
optimizer_kwargs if optimizer_kwargs is not None else {'lr': 1.0}
|
||||
)
|
||||
|
||||
self._configured = False
|
||||
self._n_envs = None
|
||||
self._action_dim = None
|
||||
self._config = None
|
||||
self._lambdas: torch.Tensor | None = None # (n_envs, C)
|
||||
|
||||
def configure(
|
||||
self, *, action_space: gym.Space, n_envs: int, config: Any
|
||||
) -> None:
|
||||
"""Configure the solver with environment specifications."""
|
||||
self._action_space = action_space
|
||||
self._n_envs = n_envs
|
||||
self._config = config
|
||||
self._action_dim = int(np.prod(action_space.shape[1:]))
|
||||
self._configured = True
|
||||
|
||||
if not isinstance(action_space, Box):
|
||||
logging.warning(
|
||||
f'Action space is discrete, got {type(action_space)}. LagrangianSolver may not work as expected.'
|
||||
)
|
||||
|
||||
@property
|
||||
def n_envs(self) -> int:
|
||||
"""Number of parallel environments."""
|
||||
return self._n_envs
|
||||
|
||||
@property
|
||||
def action_dim(self) -> int:
|
||||
"""Flattened action dimension including action_block grouping."""
|
||||
return self._action_dim * self._config.action_block
|
||||
|
||||
@property
|
||||
def horizon(self) -> int:
|
||||
"""Planning horizon in timesteps."""
|
||||
return self._config.horizon
|
||||
|
||||
def __call__(self, *args: Any, **kwargs: Any) -> dict:
|
||||
"""Make solver callable, forwarding to solve()."""
|
||||
return self.solve(*args, **kwargs)
|
||||
|
||||
def init_action(self, actions: torch.Tensor | None = None) -> None:
|
||||
"""Initialize the action tensor for optimization."""
|
||||
if actions is None:
|
||||
actions = torch.zeros((self._n_envs, 0, self.action_dim))
|
||||
|
||||
remaining = self.horizon - actions.shape[1]
|
||||
if remaining > 0:
|
||||
new_actions = torch.zeros(self._n_envs, remaining, self.action_dim)
|
||||
actions = torch.cat([actions, new_actions], dim=1).to(self.device)
|
||||
|
||||
actions = actions.unsqueeze(1).repeat_interleave(
|
||||
self.num_samples, dim=1
|
||||
)
|
||||
actions[:, 1:] += (
|
||||
torch.randn(
|
||||
actions[:, 1:].shape,
|
||||
generator=self.torch_gen,
|
||||
device=self.device,
|
||||
)
|
||||
* self.var_scale
|
||||
)
|
||||
|
||||
if hasattr(self, 'init'):
|
||||
self.init.copy_(actions)
|
||||
else:
|
||||
self.register_parameter('init', torch.nn.Parameter(actions))
|
||||
|
||||
def _init_multipliers(self, num_constraints: int) -> None:
|
||||
"""Lazily initialize Lagrange multipliers to zeros."""
|
||||
self._lambdas = torch.zeros(
|
||||
self._n_envs, num_constraints, device=self.device
|
||||
)
|
||||
|
||||
def _augmented_lagrangian_loss(
|
||||
self,
|
||||
costs: torch.Tensor, # (B, S)
|
||||
constraints: torch.Tensor, # (B, S, C)
|
||||
lambdas_batch: torch.Tensor, # (B, C)
|
||||
rho: float,
|
||||
) -> torch.Tensor:
|
||||
"""Compute the augmented Lagrangian loss.
|
||||
|
||||
L = cost + Σ_i lambda_i * g_i + Σ_i rho * max(0, g_i)^2
|
||||
"""
|
||||
# lambdas_batch: (B, C) -> (B, 1, C) for broadcasting with constraints (B, S, C)
|
||||
linear_penalty = (lambdas_batch.unsqueeze(1) * constraints).sum(
|
||||
dim=-1
|
||||
) # (B, S)
|
||||
quadratic_penalty = rho * F.relu(constraints).pow(2).sum(
|
||||
dim=-1
|
||||
) # (B, S)
|
||||
return (costs + linear_penalty + quadratic_penalty).sum()
|
||||
|
||||
def _update_multipliers(
|
||||
self,
|
||||
constraints: torch.Tensor, # (B, S, C) — detached, no grad
|
||||
lambdas_batch: torch.Tensor, # (B, C)
|
||||
rho: float,
|
||||
) -> torch.Tensor:
|
||||
"""Dual ascent: lambda_i <- max(0, lambda_i + rho * mean_samples(g_i))."""
|
||||
mean_g = constraints.mean(dim=1) # (B, C)
|
||||
return torch.clamp(lambdas_batch + rho * mean_g, min=0.0)
|
||||
|
||||
def solve(
|
||||
self, info_dict: dict, init_action: torch.Tensor | None = None
|
||||
) -> dict:
|
||||
"""Solve the planning problem using augmented Lagrangian gradient descent."""
|
||||
start_time = time.time()
|
||||
outputs: dict = {
|
||||
'cost': [],
|
||||
'constraint_violation': [],
|
||||
'actions': None,
|
||||
'lambdas': None,
|
||||
}
|
||||
|
||||
with torch.no_grad():
|
||||
self.init_action(init_action)
|
||||
|
||||
if not self.persist_multipliers:
|
||||
self._lambdas = None
|
||||
|
||||
batch_size = (
|
||||
self.batch_size if self.batch_size is not None else self.n_envs
|
||||
)
|
||||
total_envs = self.n_envs
|
||||
batch_top_actions_list = []
|
||||
|
||||
for start_idx in range(0, total_envs, batch_size):
|
||||
end_idx = min(start_idx + batch_size, total_envs)
|
||||
current_bs = end_idx - start_idx
|
||||
|
||||
batch_init = self.init[start_idx:end_idx].clone().detach()
|
||||
batch_init.requires_grad = True
|
||||
|
||||
# Expand info_dict for current batch — same pattern as GradientSolver
|
||||
expanded_infos = {}
|
||||
for k, v in info_dict.items():
|
||||
if torch.is_tensor(v):
|
||||
batch_v = v[start_idx:end_idx]
|
||||
batch_v = batch_v.unsqueeze(1)
|
||||
batch_v = batch_v.expand(
|
||||
current_bs, self.num_samples, *batch_v.shape[2:]
|
||||
)
|
||||
elif isinstance(v, np.ndarray):
|
||||
batch_v = v[start_idx:end_idx]
|
||||
batch_v = np.repeat(
|
||||
batch_v[:, None, ...], self.num_samples, axis=1
|
||||
)
|
||||
else:
|
||||
batch_v = v
|
||||
expanded_infos[k] = batch_v
|
||||
|
||||
rho = self.rho_init
|
||||
batch_cost_history = []
|
||||
costs = None
|
||||
final_constraints = None
|
||||
|
||||
for _outer in range(self.n_outer_steps):
|
||||
# Fresh optimizer each outer step — avoids stale momentum after dual ascent
|
||||
optim = self.optimizer_cls(
|
||||
[batch_init], **self.optimizer_kwargs
|
||||
)
|
||||
|
||||
for _step in range(self.n_steps):
|
||||
current_info = expanded_infos.copy()
|
||||
costs = self.model.get_cost(current_info, batch_init)
|
||||
constraints = (
|
||||
self.model.get_constraints(
|
||||
expanded_infos.copy(), batch_init
|
||||
)
|
||||
if hasattr(self.model, 'get_constraints')
|
||||
else None
|
||||
)
|
||||
|
||||
assert isinstance(costs, torch.Tensor), (
|
||||
f'Got {type(costs)} cost, expect torch.Tensor'
|
||||
)
|
||||
assert costs.ndim == 2 and costs.shape == (
|
||||
current_bs,
|
||||
self.num_samples,
|
||||
), (
|
||||
f'Cost should be of shape ({current_bs}, {self.num_samples}), got {costs.shape}'
|
||||
)
|
||||
assert costs.requires_grad, (
|
||||
'Cost must requires_grad for LagrangianSolver.'
|
||||
)
|
||||
|
||||
if constraints is not None:
|
||||
assert constraints.ndim == 3 and constraints.shape[
|
||||
:2
|
||||
] == (current_bs, self.num_samples), (
|
||||
f'Constraints should be of shape ({current_bs}, {self.num_samples}, C), got {constraints.shape}'
|
||||
)
|
||||
if self._lambdas is None:
|
||||
self._init_multipliers(constraints.shape[-1])
|
||||
lambdas_batch = self._lambdas[start_idx:end_idx]
|
||||
loss = self._augmented_lagrangian_loss(
|
||||
costs, constraints, lambdas_batch, rho
|
||||
)
|
||||
else:
|
||||
loss = costs.sum()
|
||||
|
||||
loss.backward()
|
||||
optim.step()
|
||||
optim.zero_grad(set_to_none=True)
|
||||
|
||||
if self.action_noise > 0:
|
||||
batch_init.data += (
|
||||
torch.randn(
|
||||
batch_init.shape, generator=self.torch_gen
|
||||
)
|
||||
* self.action_noise
|
||||
)
|
||||
|
||||
batch_cost_history.append(loss.item())
|
||||
|
||||
# Dual ascent after inner loop converges
|
||||
if constraints is not None:
|
||||
with torch.no_grad():
|
||||
final_constraints = self.model.get_constraints(
|
||||
expanded_infos.copy(), batch_init
|
||||
)
|
||||
lambdas_batch = self._update_multipliers(
|
||||
final_constraints, lambdas_batch, rho
|
||||
)
|
||||
self._lambdas[start_idx:end_idx] = lambdas_batch
|
||||
rho = min(self.rho_max, rho * self.rho_scale)
|
||||
|
||||
with torch.no_grad():
|
||||
mean_cost = costs.mean().item()
|
||||
if constraints is not None:
|
||||
viol = F.relu(final_constraints).mean(dim=(0, 1)) # (C,)
|
||||
lam = lambdas_batch.mean(dim=0) # (C,)
|
||||
viol_str = ', '.join(f'{v:.4f}' for v in viol.tolist())
|
||||
lam_str = ', '.join(f'{l:.4f}' for l in lam.tolist())
|
||||
print(
|
||||
f' [outer {_outer+1}/{self.n_outer_steps}] '
|
||||
f'cost={mean_cost:.4f} | '
|
||||
f'constraint_viol=[{viol_str}] | '
|
||||
f'lambdas=[{lam_str}] | '
|
||||
f'rho={rho:.4f}'
|
||||
)
|
||||
else:
|
||||
print(
|
||||
f' [outer {_outer+1}/{self.n_outer_steps}] '
|
||||
f'cost={mean_cost:.4f}'
|
||||
)
|
||||
|
||||
outputs['cost'].append(batch_cost_history)
|
||||
|
||||
if final_constraints is not None:
|
||||
outputs['constraint_violation'].append(
|
||||
F.relu(final_constraints).mean().item()
|
||||
)
|
||||
|
||||
with torch.no_grad():
|
||||
self.init[start_idx:end_idx] = batch_init
|
||||
|
||||
top_idx = torch.argsort(costs, dim=1)[:, 0]
|
||||
batch_indices = torch.arange(current_bs)
|
||||
top_actions_batch = batch_init[batch_indices, top_idx]
|
||||
batch_top_actions_list.append(top_actions_batch.detach().cpu())
|
||||
|
||||
outputs['actions'] = torch.cat(batch_top_actions_list, dim=0)
|
||||
outputs['lambdas'] = (
|
||||
self._lambdas.cpu() if self._lambdas is not None else None
|
||||
)
|
||||
|
||||
constraint_info = ''
|
||||
if outputs['constraint_violation']:
|
||||
mean_viol = np.mean(outputs['constraint_violation'])
|
||||
constraint_info = f' | constraint_violation={mean_viol:.4f}'
|
||||
print(
|
||||
f'LagrangianSolver.solve completed in {time.time() - start_time:.4f} seconds{constraint_info}.'
|
||||
)
|
||||
return outputs
|
||||
@@ -0,0 +1,208 @@
|
||||
"""Model Predictive Path Integral solver for model-based planning."""
|
||||
|
||||
import time
|
||||
from typing import Any
|
||||
|
||||
import gymnasium as gym
|
||||
import numpy as np
|
||||
import torch
|
||||
from gymnasium.spaces import Box
|
||||
from loguru import logger as logging
|
||||
|
||||
from .solver import Costable
|
||||
|
||||
|
||||
class MPPISolver:
    """Model Predictive Path Integral solver for action optimization.

    Each iteration perturbs the current mean action sequence with Gaussian
    noise, scores the candidates with ``model.get_cost``, optionally keeps the
    ``topk`` cheapest candidates, and replaces the mean with their
    softmax-weighted average. The noise scale is never updated.

    Args:
        model: World model implementing the Costable protocol.
        batch_size: Number of environments to process in parallel.
        num_samples: Number of action candidates to sample per iteration.
        var_scale: Initial variance scale for action noise.
        n_steps: Number of MPPI iterations.
        topk: Number of elite samples for weighted averaging.
        temperature: Temperature parameter for softmax weighting.
        device: Device for tensor computations.
        seed: Random seed for reproducibility.
    """

    def __init__(
        self,
        model: Costable,
        batch_size: int = 1,
        num_samples: int = 300,
        var_scale: float = 1.0,
        n_steps: int = 30,
        topk: int = 30,
        temperature: float = 0.5,
        device: str | torch.device = "cpu",
        seed: int = 1234,
    ) -> None:
        self.model = model
        self.batch_size = batch_size
        self.num_samples = num_samples
        self.topk = topk
        self.var_scale = var_scale
        self.n_steps = n_steps
        self.temperature = temperature
        self.device = device
        # Dedicated generator so sampling is reproducible for a given seed.
        self.torch_gen = torch.Generator(device=device).manual_seed(seed)

    def configure(self, *, action_space: gym.Space, n_envs: int, config: Any) -> None:
        """Configure the solver with environment specifications.

        Args:
            action_space: Action space of the environment; its shape without
                the leading axis is flattened into the per-step action size.
            n_envs: Number of parallel environments.
            config: Planning configuration; ``horizon`` and ``action_block``
                are read from it by the properties below.
        """
        # Warn first so the message is still logged if the shape lookup below
        # fails for an unsupported space.
        if not isinstance(action_space, Box):
            logging.warning(
                f"Action space is discrete, got {type(action_space)}. MPPISolver may not work as expected."
            )

        self._action_space = action_space
        self._n_envs = n_envs
        self._config = config
        self._action_dim = int(np.prod(action_space.shape[1:]))
        self._configured = True

    @property
    def n_envs(self) -> int:
        """Number of parallel environments."""
        return self._n_envs

    @property
    def action_dim(self) -> int:
        """Flattened action dimension including action_block grouping."""
        return self._action_dim * self._config.action_block

    @property
    def horizon(self) -> int:
        """Planning horizon in timesteps."""
        return self._config.horizon

    def __call__(self, *args: Any, **kwargs: Any) -> dict:
        """Make solver callable, forwarding to solve()."""
        return self.solve(*args, **kwargs)

    def init_action_distrib(
        self, actions: torch.Tensor | None = None
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Initialize the action distribution parameters (mean and variance).

        Args:
            actions: Optional warm-start action sequence. When it covers fewer
                than ``horizon`` steps it is zero-padded along dim 1.

        Returns:
            ``(mean, var)``; ``var`` is filled with ``var_scale`` and has shape
            ``(n_envs, horizon, action_dim)``.
        """
        var = self.var_scale * torch.ones([self.n_envs, self.horizon, self.action_dim])
        if actions is None:
            mean = torch.zeros([self.n_envs, 0, self.action_dim])
        else:
            mean = actions

        remaining = self.horizon - mean.shape[1]
        if remaining > 0:
            # Build the padding on the warm start's device: concatenating
            # tensors that live on different devices raises.
            padding = torch.zeros(
                [self.n_envs, remaining, self.action_dim], device=mean.device
            )
            mean = torch.cat([mean, padding], dim=1)

        return mean, var

    @torch.inference_mode()
    def solve(
        self, info_dict: dict, init_action: torch.Tensor | None = None
    ) -> dict:
        """Solve the planning problem using MPPI.

        Args:
            info_dict: Per-environment inputs; every value is sliced along its
                first axis and tensors / numpy arrays are expanded with a
                sample axis before being handed to ``model.get_cost``.
            init_action: Optional warm-start action sequence. It is never
                modified in place.

        Returns:
            Dictionary with ``"actions"`` (final mean, on CPU), ``"costs"``
            (per-environment mean cost of the samples used in the last
            iteration; empty when ``n_steps`` is 0), and single-element lists
            ``"mean"`` and ``"var"``.
        """
        start_time = time.time()
        outputs = {
            "costs": [],
            "mean": [],
            "var": [],
        }

        # -- initialize the action distribution globally
        mean, var = self.init_action_distrib(init_action)
        mean = mean.to(self.device)
        var = var.to(self.device)
        if mean is init_action:
            # A full-horizon warm start already on the target device comes
            # back as the very same tensor; copy it so the in-place write-back
            # below does not alter the caller's data.
            mean = mean.clone()

        total_envs = self.n_envs

        # --- Iterate over batches ---
        for start_idx in range(0, total_envs, self.batch_size):
            end_idx = min(start_idx + self.batch_size, total_envs)
            current_bs = end_idx - start_idx

            # Slice distribution parameters for the current batch
            batch_mean = mean[start_idx:end_idx]
            batch_var = var[start_idx:end_idx]

            # Expand info dict for the current batch
            expanded_infos = {}
            for k, v in info_dict.items():
                v_batch = v[start_idx:end_idx]
                if torch.is_tensor(v):
                    # Add sample dim: (batch, 1, ...)
                    v_batch = v_batch.unsqueeze(1)
                    # Expand: (batch, num_samples, ...)
                    v_batch = v_batch.expand(current_bs, self.num_samples, *v_batch.shape[2:])
                elif isinstance(v, np.ndarray):
                    v_batch = np.repeat(v_batch[:, None, ...], self.num_samples, axis=1)
                expanded_infos[k] = v_batch

            # Optimization loop
            final_batch_cost = None

            for _ in range(self.n_steps):
                # Sample noise: (batch, num_samples, horizon, dim)
                noise = torch.randn(
                    current_bs,
                    self.num_samples,
                    self.horizon,
                    self.action_dim,
                    generator=self.torch_gen,
                    device=self.device,
                )

                # candidates = mean + noise * scale; note the tensor named
                # ``var`` multiplies the noise directly (no square root).
                candidates = batch_mean.unsqueeze(1) + noise * batch_var.unsqueeze(1)

                # Force the first sample to be the current mean (zero noise)
                candidates[:, 0] = batch_mean

                # Evaluate candidates
                costs = self.model.get_cost(expanded_infos, candidates)

                assert isinstance(costs, torch.Tensor), f"Expected cost to be a torch.Tensor, got {type(costs)}"
                assert costs.ndim == 2 and costs.shape[0] == current_bs and costs.shape[1] == self.num_samples, (
                    f"Expected cost to be of shape ({current_bs}, {self.num_samples}), got {costs.shape}"
                )

                # Select elites (optional, based on topk)
                if self.topk is not None and self.topk < self.num_samples:
                    # topk_vals: (batch, k), topk_inds: (batch, k)
                    topk_vals, topk_inds = torch.topk(costs, k=self.topk, dim=1, largest=False)

                    # Gather top-k candidates: (batch, k, horizon, dim)
                    batch_indices = torch.arange(current_bs, device=self.device).unsqueeze(1).expand(-1, self.topk)
                    relevant_candidates = candidates[batch_indices, topk_inds]
                    relevant_costs = topk_vals
                else:
                    relevant_candidates = candidates
                    relevant_costs = costs

                # MPPI weighting: softmax(-cost / temperature), stabilized by
                # subtracting the per-environment minimum cost.
                min_cost = relevant_costs.min(dim=1, keepdim=True)[0]
                scaled_costs = relevant_costs - min_cost
                weights = torch.softmax(-scaled_costs / self.temperature, dim=1)  # (batch, k)

                # Update mean: weighted sum of candidates, with weights
                # reshaped to (batch, k, 1, 1) for broadcasting.
                weights_expanded = weights.unsqueeze(-1).unsqueeze(-1)
                batch_mean = (weights_expanded * relevant_candidates).sum(dim=1)

                # Store average cost of the utilized samples for logging
                final_batch_cost = relevant_costs.mean(dim=1).cpu().tolist()

            # Write results back to global storage
            mean[start_idx:end_idx] = batch_mean
            # We do not update var in standard MPPI

            # Store history/metadata; no cost exists when n_steps is 0.
            if final_batch_cost is not None:
                outputs["costs"].extend(final_batch_cost)

        outputs["actions"] = mean.detach().cpu()
        outputs["mean"] = [mean.detach().cpu()]
        outputs["var"] = [var.detach().cpu()]

        print(f"MPPI solve time: {time.time() - start_time:.4f} seconds")
        return outputs
|
||||
@@ -0,0 +1,74 @@
|
||||
from typing import Any, Protocol, runtime_checkable
|
||||
|
||||
import gymnasium as gym
|
||||
import torch
|
||||
|
||||
|
||||
class Costable(Protocol):
    """Structural interface a world model must expose to be used for costing."""

    def criterion(self, info_dict: dict, action_candidates: torch.Tensor) -> torch.Tensor:
        """Evaluate the cost criterion on a set of candidate actions.

        Args:
            info_dict: Environment state information, keyed by name.
            action_candidates: Proposed actions to score.

        Returns:
            Tensor holding one cost value per action candidate.
        """
        ...

    def get_cost(self, info_dict: dict, action_candidates: torch.Tensor) -> torch.Tensor:  # pragma: no cover
        """Return the cost of each candidate action given the info dictionary.

        Args:
            info_dict: Environment state information, keyed by name.
            action_candidates: Proposed actions to score.

        Returns:
            Tensor holding one cost value per action candidate.
        """
        ...
|
||||
|
||||
|
||||
@runtime_checkable
class Solver(Protocol):
    """Structural interface shared by the model-based planning solvers."""

    def configure(self, *, action_space: gym.Space, n_envs: int, config: Any) -> None:
        """Bind the solver to an environment and a planning configuration.

        Args:
            action_space: Action space of the environment.
            n_envs: How many environments are planned for in parallel.
            config: Object carrying the planning configuration.
        """
        ...

    @property
    def action_dim(self) -> int:
        """Flattened size of one action, action_block grouping included."""
        ...

    @property
    def n_envs(self) -> int:
        """How many environments are planned for in parallel."""
        ...

    @property
    def horizon(self) -> int:
        """Length of the planning horizon, in timesteps."""
        ...

    def solve(self, info_dict: dict, init_action: torch.Tensor | None = None) -> dict:
        """Run the planning optimization and return the resulting actions.

        Args:
            info_dict: Environment state information, keyed by name.
            init_action: Optional action sequence used to warm-start the solver.

        Returns:
            Dictionary with the optimized actions plus solver-specific extras.
        """
        ...
|
||||
Reference in New Issue
Block a user