Files
lewm/.venv/lib/python3.10/site-packages/stable_worldmodel/solver/gd.py

253 lines
8.8 KiB
Python

"""Gradient-based solver for model-based planning."""
import time
from typing import Any
import gymnasium as gym
import numpy as np
import torch
from gymnasium.spaces import Box
from loguru import logger as logging
from .solver import Costable
class GradientSolver(torch.nn.Module):
"""Gradient-based solver using backpropagation through the world model.
Args:
model: World model implementing the Costable protocol.
n_steps: Number of gradient descent iterations.
batch_size: Number of environments to process in parallel.
var_scale: Initial variance scale for action perturbations.
num_samples: Number of action samples to optimize in parallel.
action_noise: Noise added to actions during optimization.
device: Device for tensor computations.
seed: Random seed for reproducibility.
optimizer_cls: PyTorch optimizer class to use.
optimizer_kwargs: Keyword arguments for the optimizer.
"""
def __init__(
self,
model: Costable,
n_steps: int,
batch_size: int | None = None,
var_scale: float = 1,
num_samples: int = 1,
action_noise: float = 0.0,
device: str | torch.device = 'cpu',
seed: int = 1234,
optimizer_cls: type[torch.optim.Optimizer] = torch.optim.SGD,
optimizer_kwargs: dict | None = None,
) -> None:
super().__init__()
self.model = model
self.n_steps = n_steps
self.batch_size = batch_size
self.num_samples = num_samples
self.var_scale = var_scale
self.action_noise = action_noise
self.device = device
self.torch_gen = torch.Generator(device=device).manual_seed(seed)
self.optimizer_cls = optimizer_cls
self.optimizer_kwargs = (
optimizer_kwargs if optimizer_kwargs is not None else {'lr': 1.0}
)
self._configured = False
self._n_envs = None
self._action_dim = None
self._config = None
def configure(
self, *, action_space: gym.Space, n_envs: int, config: Any
) -> None:
"""Configure the solver with environment specifications."""
self._action_space = action_space
self._n_envs = n_envs
self._config = config
self._action_dim = int(np.prod(action_space.shape[1:]))
self._configured = True
if not isinstance(action_space, Box):
logging.warning(
f'Action space is discrete, got {type(action_space)}. GradientSolver may not work as expected.'
)
@property
def n_envs(self) -> int:
"""Number of parallel environments."""
return self._n_envs
@property
def action_dim(self) -> int:
"""Flattened action dimension including action_block grouping."""
return self._action_dim * self._config.action_block
@property
def horizon(self) -> int:
"""Planning horizon in timesteps."""
return self._config.horizon
def __call__(self, *args: Any, **kwargs: Any) -> dict:
"""Make solver callable, forwarding to solve()."""
return self.solve(*args, **kwargs)
def init_action(self, actions: torch.Tensor | None = None) -> None:
"""Initialize the action tensor for optimization."""
device = torch.device(self.device)
if actions is None:
actions = torch.zeros(
(self._n_envs, 0, self.action_dim), device=device
)
elif actions.device != device:
actions = actions.to(device, non_blocking=True)
# fill remaining action
remaining = self.horizon - actions.shape[1]
if remaining > 0:
new_actions = torch.zeros(
self._n_envs,
remaining,
self.action_dim,
device=actions.device,
dtype=actions.dtype,
)
actions = torch.cat([actions, new_actions], dim=1)
actions = actions.unsqueeze(1).repeat_interleave(
self.num_samples, dim=1
) # add sample dim
actions[:, 1:] += (
torch.randn(
actions[:, 1:].shape,
generator=self.torch_gen,
device=self.device,
)
* self.var_scale
) # add small noise to all samples except the first one
# reset actions
if hasattr(self, 'init'):
self.init.copy_(actions)
else:
self.register_parameter('init', torch.nn.Parameter(actions))
def solve(
self, info_dict: dict, init_action: torch.Tensor | None = None
) -> dict:
"""Solve the planning problem using gradient descent."""
start_time = time.time()
outputs = {
'cost': [], # Will store list of cost histories per batch
'actions': None,
}
with torch.no_grad():
self.init_action(init_action)
# Determine batch size (default to all envs if not specified which can cause memory issues)
batch_size = (
self.batch_size if self.batch_size is not None else self.n_envs
)
total_envs = self.n_envs
# Lists to hold results from each batch to be concatenated later
batch_top_actions_list = []
# --- Outer Loop: Iterate over batches ---
for start_idx in range(0, total_envs, batch_size):
end_idx = min(start_idx + batch_size, total_envs)
current_bs = end_idx - start_idx
batch_init = self.init[start_idx:end_idx].clone().detach()
batch_init.requires_grad = True
# We initialize the optimizer class passed in __init__ with the kwargs
optim = self.optimizer_cls([batch_init], **self.optimizer_kwargs)
# Prepare Batch Infos
# Slice the input info_dict and then expand dimensions
expanded_infos = {}
for k, v in info_dict.items():
# Slice the data for the current batch indices
# Assumes input data dim 0 corresponds to n_envs
if torch.is_tensor(v):
batch_v = v[start_idx:end_idx]
batch_v = batch_v.unsqueeze(1)
batch_v = batch_v.expand(
current_bs, self.num_samples, *batch_v.shape[2:]
)
elif isinstance(v, np.ndarray):
batch_v = v[start_idx:end_idx]
batch_v = np.repeat(
batch_v[:, None, ...], self.num_samples, axis=1
)
expanded_infos[k] = batch_v
final_batch_cost = None
for step in range(self.n_steps):
current_info = expanded_infos.copy()
# Calculate cost using the batch parameter
costs = self.model.get_cost(current_info, batch_init)
assert isinstance(costs, torch.Tensor), (
f'Got {type(costs)} cost, expect torch.Tensor'
)
assert (
costs.ndim == 2
and costs.shape[0] == current_bs
and costs.shape[1] == self.num_samples
), (
f'Cost should be of shape ({current_bs}, {self.num_samples}), got {costs.shape}'
)
assert costs.requires_grad, (
'Cost must requires_grad for GD solver.'
)
cost = costs.sum() # Sum cost for this batch
cost.backward()
optim.step()
optim.zero_grad(set_to_none=True)
# Add noise
if self.action_noise > 0:
batch_init.data += (
torch.randn(
batch_init.shape,
generator=self.torch_gen,
device=self.device,
)
* self.action_noise
)
final_batch_cost = costs.detach().min(dim=1).values
# Store cost history for this batch
outputs['cost'].append(final_batch_cost)
# Update the global self.init with the optimized batch values
with torch.no_grad():
self.init[start_idx:end_idx] = batch_init
top_idx = costs.argmin(dim=1)
batch_indices = torch.arange(current_bs, device=self.device)
top_actions_batch = batch_init[batch_indices, top_idx]
batch_top_actions_list.append(top_actions_batch.detach())
# Concatenate all batch results
outputs['actions'] = torch.cat(batch_top_actions_list, dim=0)
outputs['cost'] = torch.cat(outputs['cost']).cpu().tolist()
end_time = time.time()
print(
f'GradientSolver.solve completed in {end_time - start_time:.4f} seconds.'
)
return outputs