361 lines
14 KiB
Python
361 lines
14 KiB
Python
"""Lagrangian solver for stable world model."""
|
|
|
|
import time
|
|
from typing import Any
|
|
|
|
import gymnasium as gym
|
|
import numpy as np
|
|
import torch
|
|
import torch.nn.functional as F
|
|
from gymnasium.spaces import Box
|
|
from loguru import logger as logging
|
|
|
|
from .solver import Costable
|
|
|
|
|
|
class LagrangianSolver(torch.nn.Module):
    """Augmented-Lagrangian gradient planner for a world model.

    The solver optimizes an action tensor of shape
    ``(n_envs, num_samples, horizon, action_dim)`` by gradient descent on

        L = cost + sum_i lambda_i * g_i + rho * sum_i max(0, g_i)^2

    where ``cost`` is ``model.get_cost(info, actions)`` with shape (B, S) and
    ``g`` is ``model.get_constraints(info, actions)`` with shape (B, S, C).
    A constraint is considered satisfied when ``g_i <= 0``. If the model has no
    ``get_constraints`` attribute (or it returns ``None``) the solver performs
    plain gradient descent on the summed cost.

    After every ``n_steps`` primal updates the multipliers are updated by dual
    ascent, ``lambda <- max(0, lambda + rho * mean_over_samples(g))``, and
    ``rho`` is multiplied by ``rho_scale`` (capped at ``rho_max``).

    An equality constraint ``g_i == 0`` can be expressed as the two inequality
    constraints ``g_i <= 0`` and ``-g_i <= 0``.

    Args:
        model: World model implementing the Costable protocol. Its get_cost()
            returns a cost tensor (B, S). If it also has get_constraints(),
            that method returns constraints of shape (B, S, C).
        n_steps: Number of gradient descent steps per outer iteration (>= 1).
        n_outer_steps: Number of dual ascent (outer) iterations (>= 1).
        batch_size: Number of environments optimized together; ``None`` means
            all environments in a single batch.
        num_samples: Number of action samples optimized in parallel per
            environment.
        var_scale: Multiplier applied to the standard-normal perturbation added
            to every sample except the first one at initialization.
        action_noise: Multiplier of the standard-normal noise added to the
            actions after each gradient step (disabled when <= 0).
        rho_init: Initial penalty coefficient for the quadratic term.
        rho_max: Upper bound of the penalty coefficient.
        rho_scale: Multiplicative growth factor for rho after each outer step.
        persist_multipliers: Whether to warm-start the Lagrange multipliers
            across solve() calls.
        device: Device for tensor computations.
        seed: Seed of the solver's private random generator.
        optimizer_cls: PyTorch optimizer class used for the primal updates.
        optimizer_kwargs: Keyword arguments for the optimizer; defaults to
            ``{'lr': 1.0}``.

    Raises:
        ValueError: If ``n_steps`` or ``n_outer_steps`` is smaller than 1.
    """

    def __init__(
        self,
        model: Costable,
        n_steps: int,
        n_outer_steps: int = 5,
        batch_size: int | None = None,
        num_samples: int = 1,
        var_scale: float = 1.0,
        action_noise: float = 0.0,
        rho_init: float = 1.0,
        rho_max: float = 1e4,
        rho_scale: float = 2.0,
        persist_multipliers: bool = True,
        device: str | torch.device = 'cpu',
        seed: int = 1234,
        optimizer_cls: type[torch.optim.Optimizer] = torch.optim.Adam,
        optimizer_kwargs: dict | None = None,
    ) -> None:
        # solve() needs at least one cost evaluation to rank the samples, so
        # reject degenerate loop counts here instead of failing deep inside it.
        if n_steps < 1 or n_outer_steps < 1:
            raise ValueError(
                'n_steps and n_outer_steps must be >= 1, '
                f'got n_steps={n_steps}, n_outer_steps={n_outer_steps}'
            )

        super().__init__()
        self.model = model
        self.n_steps = n_steps
        self.n_outer_steps = n_outer_steps
        self.batch_size = batch_size
        self.num_samples = num_samples
        self.var_scale = var_scale
        self.action_noise = action_noise
        self.rho_init = rho_init
        self.rho_max = rho_max
        self.rho_scale = rho_scale
        self.persist_multipliers = persist_multipliers
        self.device = device
        # Private generator so sampling does not touch the global RNG state.
        self.torch_gen = torch.Generator(device=device).manual_seed(seed)
        self.optimizer_cls = optimizer_cls
        self.optimizer_kwargs = (
            optimizer_kwargs if optimizer_kwargs is not None else {'lr': 1.0}
        )

        self._configured = False
        self._n_envs = None
        self._action_dim = None
        self._config = None
        self._lambdas: torch.Tensor | None = None  # (n_envs, C)

    def configure(
        self, *, action_space: gym.Space, n_envs: int, config: Any
    ) -> None:
        """Configure the solver with environment specifications.

        Args:
            action_space: Batched action space; the per-environment action
                size is the product of ``action_space.shape[1:]``.
            n_envs: Number of parallel environments.
            config: Object exposing ``horizon`` and ``action_block``.
        """
        self._action_space = action_space
        self._n_envs = n_envs
        self._config = config
        self._action_dim = int(np.prod(action_space.shape[1:]))
        self._configured = True

        if not isinstance(action_space, Box):
            logging.warning(
                f'Action space is discrete, got {type(action_space)}. LagrangianSolver may not work as expected.'
            )

    @property
    def n_envs(self) -> int:
        """Number of parallel environments."""
        return self._n_envs

    @property
    def action_dim(self) -> int:
        """Flattened action dimension including action_block grouping."""
        return self._action_dim * self._config.action_block

    @property
    def horizon(self) -> int:
        """Planning horizon in timesteps."""
        return self._config.horizon

    def __call__(self, *args: Any, **kwargs: Any) -> dict:
        """Make solver callable, forwarding to solve()."""
        return self.solve(*args, **kwargs)

    @torch.no_grad()
    def init_action(self, actions: torch.Tensor | None = None) -> None:
        """Initialize the ``init`` parameter holding the action samples.

        ``actions`` (n_envs, T, action_dim) is used as a warm start and padded
        with zeros up to the horizon when ``T < horizon``. The plan is then
        replicated ``num_samples`` times; every sample except the first gets
        Gaussian noise scaled by ``var_scale``.

        Runs under ``torch.no_grad()`` because it writes in place into a leaf
        parameter that requires grad.

        Args:
            actions: Optional warm-start actions; ``None`` starts from zeros.
        """
        if actions is None:
            actions = torch.zeros(
                (self._n_envs, 0, self.action_dim), device=self.device
            )
        else:
            # Move first so the concatenation and the noise below never mix
            # devices, even when no padding is required.
            actions = actions.to(self.device)

        remaining = self.horizon - actions.shape[1]
        if remaining > 0:
            padding = torch.zeros(
                self._n_envs, remaining, self.action_dim, device=self.device
            )
            actions = torch.cat([actions, padding], dim=1)

        # (n_envs, T, A) -> (n_envs, num_samples, T, A)
        actions = actions.unsqueeze(1).repeat_interleave(
            self.num_samples, dim=1
        )
        # Sample 0 keeps the unperturbed plan.
        actions[:, 1:] += (
            torch.randn(
                actions[:, 1:].shape,
                generator=self.torch_gen,
                device=self.device,
            )
            * self.var_scale
        )

        current = getattr(self, 'init', None)
        if current is not None and current.shape == actions.shape:
            current.copy_(actions)
        else:
            # First call, or n_envs / horizon / num_samples changed since the
            # parameter was created: allocate a parameter of the new shape.
            self.register_parameter('init', torch.nn.Parameter(actions))

    def _init_multipliers(self, num_constraints: int) -> None:
        """Reset the Lagrange multipliers to zeros of shape (n_envs, C)."""
        self._lambdas = torch.zeros(
            self._n_envs, num_constraints, device=self.device
        )

    def _augmented_lagrangian_loss(
        self,
        costs: torch.Tensor,  # (B, S)
        constraints: torch.Tensor,  # (B, S, C)
        lambdas_batch: torch.Tensor,  # (B, C)
        rho: float,
    ) -> torch.Tensor:
        """Compute the scalar augmented Lagrangian loss.

        L = cost + Σ_i lambda_i * g_i + rho * Σ_i max(0, g_i)^2,
        summed over the batch and sample dimensions.
        """
        # lambdas_batch: (B, C) -> (B, 1, C) to broadcast over the sample axis.
        linear_penalty = (lambdas_batch.unsqueeze(1) * constraints).sum(
            dim=-1
        )  # (B, S)
        quadratic_penalty = rho * F.relu(constraints).pow(2).sum(
            dim=-1
        )  # (B, S)
        return (costs + linear_penalty + quadratic_penalty).sum()

    def _update_multipliers(
        self,
        constraints: torch.Tensor,  # (B, S, C) — detached, no grad
        lambdas_batch: torch.Tensor,  # (B, C)
        rho: float,
    ) -> torch.Tensor:
        """Dual ascent: lambda_i <- max(0, lambda_i + rho * mean_samples(g_i))."""
        mean_g = constraints.mean(dim=1)  # (B, C)
        return torch.clamp(lambdas_batch + rho * mean_g, min=0.0)

    def _expand_info_batch(
        self, info_dict: dict, start_idx: int, end_idx: int
    ) -> dict:
        """Slice env rows [start_idx, end_idx) and add a sample axis of size num_samples.

        Tensors are expanded (no copy), numpy arrays are repeated, and any
        other value is passed through unchanged.
        """
        current_bs = end_idx - start_idx
        expanded = {}
        for key, value in info_dict.items():
            if torch.is_tensor(value):
                chunk = value[start_idx:end_idx].unsqueeze(1)
                value = chunk.expand(
                    current_bs, self.num_samples, *chunk.shape[2:]
                )
            elif isinstance(value, np.ndarray):
                value = np.repeat(
                    value[start_idx:end_idx][:, None, ...],
                    self.num_samples,
                    axis=1,
                )
            expanded[key] = value
        return expanded

    def solve(
        self, info_dict: dict, init_action: torch.Tensor | None = None
    ) -> dict:
        """Solve the planning problem using augmented Lagrangian gradient descent.

        Args:
            info_dict: Model inputs. Tensors and numpy arrays are sliced along
                their first axis per environment batch and given a sample axis;
                other values are forwarded as-is.
            init_action: Optional warm-start actions, see init_action().

        Returns:
            Dict with keys:
            ``'cost'``: one list per environment batch with the loss value of
            every gradient step;
            ``'constraint_violation'``: one float per batch, the mean of
            ``relu(g)`` after the last outer step (empty without constraints);
            ``'actions'``: CPU tensor (n_envs, horizon, action_dim) holding,
            per environment, the sample with the lowest cost;
            ``'lambdas'``: CPU copy of the multipliers (n_envs, C), or None.

        Raises:
            RuntimeError: If configure() has not been called.
        """
        if not self._configured:
            raise RuntimeError(
                'LagrangianSolver.configure() must be called before solve().'
            )

        start_time = time.time()
        outputs: dict = {
            'cost': [],
            'constraint_violation': [],
            'actions': None,
            'lambdas': None,
        }

        self.init_action(init_action)

        if not self.persist_multipliers:
            self._lambdas = None

        has_constraints = hasattr(self.model, 'get_constraints')
        total_envs = self.n_envs
        batch_size = (
            self.batch_size if self.batch_size is not None else total_envs
        )
        batch_top_actions_list = []

        for start_idx in range(0, total_envs, batch_size):
            end_idx = min(start_idx + batch_size, total_envs)
            current_bs = end_idx - start_idx

            # Optimize a detached copy; it is written back to self.init below.
            batch_init = self.init[start_idx:end_idx].clone().detach()
            batch_init.requires_grad = True

            expanded_infos = self._expand_info_batch(
                info_dict, start_idx, end_idx
            )

            rho = self.rho_init
            batch_cost_history = []
            costs = None
            constraints = None
            lambdas_batch = None
            final_constraints = None

            for outer_idx in range(self.n_outer_steps):
                # Fresh optimizer each outer step — avoids stale momentum after dual ascent
                optim = self.optimizer_cls(
                    [batch_init], **self.optimizer_kwargs
                )

                for _ in range(self.n_steps):
                    # Each model call gets its own shallow copy of the dict so
                    # key mutations inside the model cannot leak between calls.
                    costs = self.model.get_cost(
                        expanded_infos.copy(), batch_init
                    )
                    constraints = (
                        self.model.get_constraints(
                            expanded_infos.copy(), batch_init
                        )
                        if has_constraints
                        else None
                    )

                    assert isinstance(costs, torch.Tensor), (
                        f'Got {type(costs)} cost, expect torch.Tensor'
                    )
                    assert costs.ndim == 2 and costs.shape == (
                        current_bs,
                        self.num_samples,
                    ), (
                        f'Cost should be of shape ({current_bs}, {self.num_samples}), got {costs.shape}'
                    )
                    assert costs.requires_grad, (
                        'Cost must requires_grad for LagrangianSolver.'
                    )

                    if constraints is not None:
                        assert constraints.ndim == 3 and constraints.shape[
                            :2
                        ] == (current_bs, self.num_samples), (
                            f'Constraints should be of shape ({current_bs}, {self.num_samples}, C), got {constraints.shape}'
                        )
                        num_constraints = constraints.shape[-1]
                        # Lazily create the multipliers; also recreate them if
                        # persisted ones no longer match (n_envs, C).
                        if self._lambdas is None or self._lambdas.shape != (
                            total_envs,
                            num_constraints,
                        ):
                            self._init_multipliers(num_constraints)
                        lambdas_batch = self._lambdas[start_idx:end_idx]
                        loss = self._augmented_lagrangian_loss(
                            costs, constraints, lambdas_batch, rho
                        )
                    else:
                        loss = costs.sum()

                    loss.backward()
                    optim.step()
                    optim.zero_grad(set_to_none=True)

                    if self.action_noise > 0:
                        with torch.no_grad():
                            # The noise must be created on the generator's
                            # device, otherwise torch.randn raises off-CPU.
                            batch_init.add_(
                                torch.randn(
                                    batch_init.shape,
                                    generator=self.torch_gen,
                                    device=self.device,
                                )
                                * self.action_noise
                            )

                    batch_cost_history.append(loss.item())

                # Dual ascent after the inner loop, using the current rho;
                # rho is only grown afterwards.
                if constraints is not None:
                    with torch.no_grad():
                        final_constraints = self.model.get_constraints(
                            expanded_infos.copy(), batch_init
                        )
                        lambdas_batch = self._update_multipliers(
                            final_constraints, lambdas_batch, rho
                        )
                        self._lambdas[start_idx:end_idx] = lambdas_batch
                    rho = min(self.rho_max, rho * self.rho_scale)

                with torch.no_grad():
                    mean_cost = costs.mean().item()
                    if constraints is not None:
                        viol = F.relu(final_constraints).mean(
                            dim=(0, 1)
                        )  # (C,)
                        lam = lambdas_batch.mean(dim=0)  # (C,)
                        viol_str = ', '.join(
                            f'{value:.4f}' for value in viol.tolist()
                        )
                        lam_str = ', '.join(
                            f'{value:.4f}' for value in lam.tolist()
                        )
                        print(
                            f' [outer {outer_idx+1}/{self.n_outer_steps}] '
                            f'cost={mean_cost:.4f} | '
                            f'constraint_viol=[{viol_str}] | '
                            f'lambdas=[{lam_str}] | '
                            f'rho={rho:.4f}'
                        )
                    else:
                        print(
                            f' [outer {outer_idx+1}/{self.n_outer_steps}] '
                            f'cost={mean_cost:.4f}'
                        )

            outputs['cost'].append(batch_cost_history)

            if final_constraints is not None:
                outputs['constraint_violation'].append(
                    F.relu(final_constraints).mean().item()
                )

            with torch.no_grad():
                self.init[start_idx:end_idx] = batch_init

            # NOTE: `costs` is the evaluation from the last inner step, i.e.
            # taken before the final optimizer update (and noise) was applied,
            # and the ranking uses the raw cost only, not the constraints.
            top_idx = torch.argsort(costs, dim=1)[:, 0]
            batch_indices = torch.arange(current_bs, device=top_idx.device)
            top_actions_batch = batch_init[batch_indices, top_idx]
            batch_top_actions_list.append(top_actions_batch.detach().cpu())

        outputs['actions'] = torch.cat(batch_top_actions_list, dim=0)
        # Return a copy: on CPU `.cpu()` alone would alias the internal
        # multipliers, which are mutated in place by later solve() calls.
        outputs['lambdas'] = (
            self._lambdas.detach().clone().cpu()
            if self._lambdas is not None
            else None
        )

        constraint_info = ''
        if outputs['constraint_violation']:
            mean_viol = np.mean(outputs['constraint_violation'])
            constraint_info = f' | constraint_violation={mean_viol:.4f}'
        print(
            f'LagrangianSolver.solve completed in {time.time() - start_time:.4f} seconds{constraint_info}.'
        )
        return outputs
|