Initial commit

This commit is contained in:
Lucas Maes
2026-03-12 22:56:21 -04:00
committed by lucas-maes
commit 83f97d72ad
21 changed files with 1355 additions and 0 deletions

21
LICENSE Normal file
View File

@@ -0,0 +1,21 @@
MIT License
Copyright (c) 2026 Lucas Maes
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

127
README.md Normal file
View File

@@ -0,0 +1,127 @@
# LeWorldModel
### Stable End-to-End Joint-Embedding Predictive Architecture from Pixels
[Lucas Maes*](https://x.com/lucasmaes_), [Quentin Le Lidec*](https://quentinll.github.io/), [Damien Scieur](https://scholar.google.com/citations?user=hNscQzgAAAAJ&hl=fr), [Yann LeCun](https://yann.lecun.com/) and [Randall Balestriero](https://randallbalestriero.github.io/)
**Abstract:** Joint Embedding Predictive Architectures (JEPAs) offer a compelling framework for learning world models in compact latent spaces, yet existing methods remain fragile, relying on complex multi-term losses, exponential moving averages, pretrained encoders, or auxiliary supervision to avoid representation collapse. In this work, we introduce LeWorldModel (LeWM), the first JEPA that trains stably end-to-end from raw pixels using only two loss terms: a next-embedding prediction loss and a regularizer enforcing Gaussian-distributed latent embeddings. This reduces tunable loss hyperparameters from six to one compared to the only existing end-to-end alternative. With ~15M parameters trainable on a single GPU in a few hours, LeWM plans up to 48× faster than foundation-model-based world models while remaining competitive across diverse 2D and 3D control tasks. Beyond control, we show that LeWM's latent space encodes meaningful physical structure through probing of physical quantities. Surprise evaluation confirms that the model reliably detects physically implausible events.
<p align="center">
<b>[ <a href="https://arxiv.org/pdf/2603.19312v1">Paper</a> | <a href="https://drive.google.com/drive/folders/1r31os0d4-rR0mdHc7OlY_e5nh3XT4r4e?usp=sharing">Data</a> | <a href="https://le-wm.github.io/">Website</a> ]</b>
</p>
<br>
<p align="center">
<img src="assets/lewm.gif" width="80%">
</p>
If you find this code useful, please reference in your paper:
```
@article{maes_lelidec2026lewm,
title={LeWorldModel: Stable End-to-End Joint-Embedding Predictive Architecture from Pixels},
author={Maes, Lucas and Le Lidec, Quentin and Scieur, Damien and LeCun, Yann and Balestriero, Randall},
journal={arXiv preprint},
year={2026}
}
```
## Using the code
This codebase builds on [stable-worldmodel](https://github.com/galilai-group/stable-worldmodel) for environment management, planning, and evaluation, and [stable-pretraining](https://github.com/galilai-group/stable-pretraining) for training. Together they reduce this repository to its core contribution: the model architecture and training objective.
**Installation:**
```bash
uv venv --python=3.10
source .venv/bin/activate
uv pip install stable-worldmodel[train,env]
```
## Data
Datasets use the HDF5 format for fast loading. Download the data from the [Drive](https://drive.google.com/drive/folders/1r31os0d4-rR0mdHc7OlY_e5nh3XT4r4e?usp=sharing) and decompress with:
```bash
tar --zstd -xvf archive.tar.zst
```
Place the extracted `.h5` files under `$STABLEWM_HOME` (defaults to `~/.stable-wm/`). You can override this path:
```bash
export STABLEWM_HOME=/path/to/your/storage
```
Dataset names are specified without the `.h5` extension. For example, `config/train/data/pusht.yaml` references `pusht_expert_train`, which resolves to `$STABLEWM_HOME/pusht_expert_train.h5`.
## Training
`jepa.py` contains the PyTorch implementation of LeWM. Training is configured via [Hydra](https://hydra.cc/) config files under `config/train/`.
Before training, set your WandB `entity` and `project` in `config/train/lewm.yaml`:
```yaml
wandb:
config:
entity: your_entity
project: your_project
```
To launch training:
```bash
python train.py data=pusht
```
Checkpoints are saved to `$STABLEWM_HOME` upon completion.
For baseline scripts, see the stable-worldmodel [scripts](https://github.com/galilai-group/stable-worldmodel/tree/main/scripts/train) folder.
## Planning
Evaluation configs live under `config/eval/`. Set the `policy` field to the checkpoint path **relative to `$STABLEWM_HOME`**, without the `_object.ckpt` suffix:
```bash
# ✓ correct
python eval.py --config-name=pusht.yaml policy=pusht/lewm
# ✗ incorrect
python eval.py --config-name=pusht.yaml policy=pusht/lewm_object.ckpt
```
## Pretrained Checkpoints
Pre-trained checkpoints are available on [Google Drive](https://drive.google.com/drive/folders/1r31os0d4-rR0mdHc7OlY_e5nh3XT4r4e). Download the checkpoint archive and place the extracted files under `$STABLEWM_HOME/`.
<div align="center">
| Method | two-room | pusht | cube | reacher |
|:---:|:---:|:---:|:---:|:---:|
| pldm | ✓ | ✓ | ✓ | ✓ |
| lejepa | ✓ | ✓ | ✓ | ✓ |
| ivl | ✓ | ✓ | ✓ | — |
| iql | ✓ | ✓ | ✓ | — |
| gcbc | ✓ | ✓ | ✓ | — |
| dinowm | ✓ | ✓ | — | — |
| dinowm_noprop | ✓ | ✓ | ✓ | ✓ |
</div>
## Loading a checkpoint
Each tar archive contains two files per checkpoint:
- `<name>_object.ckpt` — a serialized Python object for convenient loading; this is what `eval.py` and the `stable_worldmodel` API use
- `<name>_weight.ckpt` — a weights-only checkpoint (`state_dict`) for cases where you want to load weights into your own model instance
To load the object checkpoint via the `stable_worldmodel` API:
```python
import stable_worldmodel as swm
# Load the cost model (for MPC)
cost = swm.policy.AutoCostModel('pusht/lewm')
```
Both functions accept:
- `run_name` — checkpoint path **relative to `$STABLEWM_HOME`**, without the `_object.ckpt` suffix
- `cache_dir` — optional override for the checkpoint root (defaults to `$STABLEWM_HOME`)
The returned module is in `eval` mode with its PyTorch weights accessible via `.state_dict()`.
## Contact & Contributions
Feel free to open [issues](https://github.com/lucas-maes/le-wm/issues)! For questions or collaborations, please contact `lucas.maes@mila.quebec`

BIN
assets/lewm.gif Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 5.7 MiB

61
config/eval/cube.yaml Normal file
View File

@@ -0,0 +1,61 @@
defaults:
- launcher: local
- solver: cem
- _self_
world:
env_name: swm/OGBCube-v0
num_envs: ${eval.num_eval}
max_episode_steps: ??? # make sure it's >= eval_budget
history_size: 1
frame_skip: 1
env_type: single
ob_type: states
multiview: False
width: 224
height: 224
visualize_info: False
terminate_at_goal: True
dataset:
stats: ${eval.dataset_name}
keys_to_cache:
- action
seed: 42
policy: random # ckpt name or random
plan_config:
horizon: 5
receding_horizon: 5
action_block: 5 # frameskip
# evaluation from dataset (replay expert trajectories)
eval:
num_eval: 50
goal_offset_steps: 25
eval_budget: 50
img_size: 224
dataset_name: ogbench/cube_single_expert
callables:
# -- set state
- method: set_state
args:
qpos:
value: qpos
qvel:
value: qvel
# -- set target pos
- method: set_target_pos
args:
cube_id:
value: 0
in_dataset: False
target_pos:
value: goal_privileged_block_0_pos
target_quat:
value: goal_privileged_block_0_quat
output:
filename: ogb_cube_results.txt

View File

@@ -0,0 +1,7 @@
# @package _global_
# Local launcher configuration (no SLURM)
defaults:
- override /hydra/launcher: basic
cache_dir: null # use stable-worldmodel default cache

48
config/eval/pusht.yaml Normal file
View File

@@ -0,0 +1,48 @@
defaults:
- launcher: local
- solver: cem
- _self_
world:
env_name: swm/PushT-v1
num_envs: ${eval.num_eval}
max_episode_steps: ??? # make sure it's >= eval_budget
history_size: 1
frame_skip: 1
dataset:
stats: ${eval.dataset_name}
keys_to_cache:
- action
- proprio
- state
seed: 42
policy: random # ckpt name or random
plan_config:
horizon: 5
receding_horizon: 5
action_block: 5 # frameskip
# evaluation from dataset (replay expert trajectories)
eval:
num_eval: 50
goal_offset_steps: 25
eval_budget: 50
img_size: 224
dataset_name: pusht_expert_train
callables:
# -- set state
- method: _set_state
args:
state:
value: state
# -- set goal state
- method: _set_goal_state
args:
goal_state:
value: goal_state
output:
filename: pusht_results.txt

50
config/eval/reacher.yaml Normal file
View File

@@ -0,0 +1,50 @@
defaults:
- launcher: local
- solver: cem
- _self_
world:
env_name: swm/ReacherDMControl-v0
num_envs: ${eval.num_eval}
max_episode_steps: ??? # make sure it's >= eval_budget
history_size: 1
frame_skip: 1
task: qpos_match
dataset:
stats: ${eval.dataset_name}
keys_to_cache:
- action
seed: 42
policy: random # ckpt name or random
plan_config:
horizon: 5
receding_horizon: 5
action_block: 5 # frameskip
# evaluation from dataset (replay expert trajectories)
eval:
num_eval: 50
goal_offset_steps: 25
eval_budget: 50
img_size: 224
dataset_name: dmc/reacher_random
callables:
# -- set state
- method: set_state
args:
qpos:
value: qpos
qvel:
value: qvel
- method: set_target_qpos
args:
target_qpos:
value: goal_qpos
output:
filename: dmc_results.txt

View File

@@ -0,0 +1,13 @@
_target_: stable_worldmodel.solver.GradientSolver
model: ???
n_steps: 30
batch_size: 1
num_samples: 100
action_noise: 0
device: "cuda"
seed: ${seed}
optimizer_cls:
_target_: hydra.utils.get_class
path: torch.optim.AdamW
optimizer_kwargs:
lr: 0.1

View File

@@ -0,0 +1,9 @@
_target_: stable_worldmodel.solver.CEMSolver
model: ???
batch_size: 1
num_samples: 300
var_scale: 1.0
n_steps: 30
topk: 30
device: "cuda"
seed: ${seed}

47
config/eval/tworoom.yaml Normal file
View File

@@ -0,0 +1,47 @@
defaults:
- launcher: local
- solver: cem
- _self_
world:
env_name: swm/TwoRoom-v1
num_envs: ${eval.num_eval}
max_episode_steps: ??? # make sure it's >= eval_budget
history_size: 1
frame_skip: 1
seed: 42
policy: random # ckpt name or random
dataset:
stats: ${eval.dataset_name}
keys_to_cache:
- action
- proprio
plan_config:
horizon: 5
receding_horizon: 5
action_block: 5 # frameskip
# evaluation from dataset (replay expert trajectories)
eval:
num_eval: 50
goal_offset_steps: 25
eval_budget: 50
img_size: 224
dataset_name: tworoom
callables:
# -- set state
- method: _set_state
args:
state:
value: proprio
# -- set goal state
- method: _set_goal_state
args:
goal_state:
value: goal_proprio
output:
filename: tworoom_results.txt

View File

@@ -0,0 +1,11 @@
dataset:
num_steps: ${eval:'${wm.num_preds} + ${wm.history_size}'}
frameskip: 5
name: reacher
keys_to_load:
- pixels
- action
- observation
keys_to_cache:
- action
- observation

View File

@@ -0,0 +1,13 @@
dataset:
name: ogbench/cube_single_expert
num_steps: ${eval:'${wm.num_preds} + ${wm.history_size}'}
frameskip: 5
keys_to_load:
- pixels
- action
- observation
keys_to_cache:
- action
- observation
keys_to_merge:
proprio: proprio

View File

@@ -0,0 +1,13 @@
dataset:
num_steps: ${eval:'${wm.num_preds} + ${wm.history_size}'}
frameskip: 5
name: pusht_expert_train
keys_to_load:
- pixels
- action
- proprio
- state
keys_to_cache:
- action
- proprio
- state

View File

@@ -0,0 +1,11 @@
dataset:
num_steps: ${eval:'${wm.num_preds} + ${wm.history_size}'}
frameskip: 5
name: tworoom
keys_to_load:
- pixels
- action
- proprio
keys_to_cache:
- action
- proprio

View File

@@ -0,0 +1,11 @@
# @package _global_
# Local launcher configuration (no SLURM)
defaults:
- override /hydra/launcher: basic
wandb:
enabled: True
config:
project: le-wm
entity: le-wm

64
config/train/lewm.yaml Normal file
View File

@@ -0,0 +1,64 @@
defaults:
- _self_
- data: pusht
output_model_name: lewm
subdir: ${hydra:job.id}
num_workers: 6
train_split: 0.9
seed: 3072
img_size: 224
patch_size: 14
encoder_scale: tiny
dump_object: True
trainer:
max_epochs: 100
devices: auto
accelerator: gpu
precision: bf16
gradient_clip_val: 1.0
loader:
batch_size: 128
num_workers: ${num_workers}
persistent_workers: True
prefetch_factor: 3
pin_memory: True
optimizer:
type: AdamW
lr: 5e-5
weight_decay: 1e-3
wandb:
enabled: True
config:
entity: lewm
project: lewm
name: ${output_model_name}
id: ${subdir}
resume: allow
log_model: False
wm:
type: lewm
history_size: 3
num_preds: 1
embed_dim: 192
predictor:
depth: 6
heads: 16
mlp_dim: 2048
dim_head: 64
dropout: 0.1
emb_dropout: 0.0
loss:
sigreg:
weight: 0.09
kwargs:
knots: 17
num_proj: 1024

171
eval.py Normal file
View File

@@ -0,0 +1,171 @@
import os
os.environ["MUJOCO_GL"] = "egl"
import time
from pathlib import Path
import hydra
import numpy as np
import stable_pretraining as spt
import torch
from omegaconf import DictConfig, OmegaConf
from sklearn import preprocessing
from torchvision.transforms import v2 as transforms
import stable_worldmodel as swm
def img_transform(cfg):
transform = transforms.Compose(
[
transforms.ToImage(),
transforms.ToDtype(torch.float32, scale=True),
transforms.Normalize(**spt.data.dataset_stats.ImageNet),
transforms.Resize(size=cfg.eval.img_size),
]
)
return transform
def get_episodes_length(dataset, episodes):
col_name = "episode_idx" if "episode_idx" in dataset.column_names else "ep_idx"
episode_idx = dataset.get_col_data(col_name)
step_idx = dataset.get_col_data("step_idx")
lengths = []
for ep_id in episodes:
lengths.append(np.max(step_idx[episode_idx == ep_id]) + 1)
return np.array(lengths)
def get_dataset(cfg, dataset_name):
dataset_path = Path(cfg.cache_dir or swm.data.utils.get_cache_dir())
dataset = swm.data.HDF5Dataset(
dataset_name,
keys_to_cache=cfg.dataset.keys_to_cache,
cache_dir=dataset_path,
)
return dataset
@hydra.main(version_base=None, config_path="./config/eval", config_name="pusht")
def run(cfg: DictConfig):
"""Run evaluation of dinowm vs random policy."""
assert (
cfg.plan_config.horizon * cfg.plan_config.action_block <= cfg.eval.eval_budget
), "Planning horizon must be smaller than or equal to eval_budget"
# create world environment
cfg.world.max_episode_steps = 2 * cfg.eval.eval_budget
world = swm.World(**cfg.world, image_shape=(224, 224))
# create the transform
transform = {
"pixels": img_transform(cfg),
"goal": img_transform(cfg),
}
dataset = get_dataset(cfg, cfg.eval.dataset_name)
stats_dataset = dataset # get_dataset(cfg, cfg.dataset.stats)
col_name = "episode_idx" if "episode_idx" in dataset.column_names else "ep_idx"
ep_indices, _ = np.unique(stats_dataset.get_col_data(col_name), return_index=True)
process = {}
for col in cfg.dataset.keys_to_cache:
if col in ["pixels"]:
continue
processor = preprocessing.StandardScaler()
col_data = stats_dataset.get_col_data(col)
col_data = col_data[~np.isnan(col_data).any(axis=1)]
processor.fit(col_data)
process[col] = processor
if col != "action":
process[f"goal_{col}"] = process[col]
# -- run evaluation
policy = cfg.get("policy", "random")
if policy != "random":
model = swm.policy.AutoCostModel(cfg.policy)
model = model.to("cuda")
model = model.eval()
model.requires_grad_(False)
model.interpolate_pos_encoding = True
config = swm.PlanConfig(**cfg.plan_config)
solver = hydra.utils.instantiate(cfg.solver, model=model)
policy = swm.policy.WorldModelPolicy(
solver=solver, config=config, process=process, transform=transform
)
else:
policy = swm.policy.RandomPolicy()
results_path = (
Path(swm.data.utils.get_cache_dir(), cfg.policy).parent
if cfg.policy != "random"
else Path(__file__).parent
)
# sample the episodes and the starting indices
episode_len = get_episodes_length(dataset, ep_indices)
max_start_idx = episode_len - cfg.eval.goal_offset_steps - 1
max_start_idx_dict = {ep_id: max_start_idx[i] for i, ep_id in enumerate(ep_indices)}
# Map each dataset rows episode_idx to its max_start_idx
col_name = "episode_idx" if "episode_idx" in dataset.column_names else "ep_idx"
max_start_per_row = np.array(
[max_start_idx_dict[ep_id] for ep_id in dataset.get_col_data(col_name)]
)
# remove all the lines of dataset for which dataset['step_idx'] > max_start_per_row
valid_mask = dataset.get_col_data("step_idx") <= max_start_per_row
valid_indices = np.nonzero(valid_mask)[0]
print(valid_mask.sum(), "valid starting points found for evaluation.")
g = np.random.default_rng(cfg.seed)
random_episode_indices = g.choice(
len(valid_indices) - 1, size=cfg.eval.num_eval, replace=False
)
# sort increasingly to avoid issues with HDF5Dataset indexing
random_episode_indices = np.sort(valid_indices[random_episode_indices])
print(random_episode_indices)
eval_episodes = dataset.get_row_data(random_episode_indices)[col_name]
eval_start_idx = dataset.get_row_data(random_episode_indices)["step_idx"]
if len(eval_episodes) < cfg.eval.num_eval:
raise ValueError("Not enough episodes with sufficient length for evaluation.")
world.set_policy(policy)
start_time = time.time()
metrics = world.evaluate_from_dataset(
dataset,
start_steps=eval_start_idx.tolist(),
goal_offset_steps=cfg.eval.goal_offset_steps,
eval_budget=cfg.eval.eval_budget,
episodes_idx=eval_episodes.tolist(),
callables=OmegaConf.to_container(cfg.eval.get("callables"), resolve=True),
video_path=results_path,
)
end_time = time.time()
print(metrics)
results_path = results_path / cfg.output.filename
results_path.parent.mkdir(parents=True, exist_ok=True)
with results_path.open("a") as f:
f.write("\n") # separate from previous runs
f.write("==== CONFIG ====\n")
f.write(OmegaConf.to_yaml(cfg))
f.write("\n")
f.write("==== RESULTS ====\n")
f.write(f"metrics: {metrics}\n")
f.write(f"evaluation_time: {end_time - start_time} seconds\n")
if __name__ == "__main__":
run()

153
jepa.py Normal file
View File

@@ -0,0 +1,153 @@
"""JEPA Implementation"""
import torch
import torch.nn.functional as F
from einops import rearrange
from torch import nn
def detach_clone(v):
return v.detach().clone() if torch.is_tensor(v) else v
class JEPA(nn.Module):
def __init__(
self,
encoder,
predictor,
action_encoder,
projector=None,
pred_proj=None,
):
super().__init__()
self.encoder = encoder
self.predictor = predictor
self.action_encoder = action_encoder
self.projector = projector or nn.Identity()
self.pred_proj = pred_proj or nn.Identity()
def encode(self, info):
"""Encode observations and actions into embeddings.
info: dict with pixels and action keys
"""
pixels = info['pixels'].float()
b = pixels.size(0)
pixels = rearrange(pixels, "b t ... -> (b t) ...") # flatten for encoding
output = self.encoder(pixels, interpolate_pos_encoding=True)
pixels_emb = output.last_hidden_state[:, 0] # cls token
emb = self.projector(pixels_emb)
info["emb"] = rearrange(emb, "(b t) d -> b t d", b=b)
if "action" in info:
info["act_emb"] = self.action_encoder(info["action"])
return info
def predict(self, emb, act_emb):
"""Predict next state embedding
emb: (B, T, D)
act_emb: (B, T, A_emb)
"""
preds = self.predictor(emb, act_emb)
preds = self.pred_proj(rearrange(preds, "b t d -> (b t) d"))
preds = rearrange(preds, "(b t) d -> b t d", b=emb.size(0))
return preds
####################
## Inference only ##
####################
def rollout(self, info, action_sequence, history_size: int = 3):
"""Rollout the model given an initial info dict and action sequence.
pixels: (B, S, T, C, H, W)
action_sequence: (B, S, T, action_dim)
- S is the number of action plan samples
- T is the time horizon
"""
assert "pixels" in info, "pixels not in info_dict"
H = info["pixels"].size(2)
B, S, T = action_sequence.shape[:3]
act_0, act_future = torch.split(action_sequence, [H, T - H], dim=2)
info["action"] = act_0
n_steps = T - H
# copy and encode initial info dict
_init = {k: v[:, 0] for k, v in info.items() if torch.is_tensor(v)}
_init = self.encode(_init)
emb = info["emb"] = _init["emb"].unsqueeze(1).expand(B, S, -1, -1)
_init = {k: detach_clone(v) for k, v in _init.items()}
# flatten batch and sample dimensions for rollout
emb = rearrange(emb, "b s ... -> (b s) ...").clone()
act = rearrange(act_0, "b s ... -> (b s) ...")
act_future = rearrange(act_future, "b s ... -> (b s) ...")
# rollout predictor autoregressively for n_steps
HS = history_size
for t in range(n_steps):
act_emb = self.action_encoder(act)
emb_trunc = emb[:, -HS:] # (BS, HS, D)
act_trunc = act_emb[:, -HS:] # (BS, HS, A_emb)
pred_emb = self.predict(emb_trunc, act_trunc)[:, -1:] # (BS, 1, D)
emb = torch.cat([emb, pred_emb], dim=1) # (BS, T+1, D)
next_act = act_future[:, t : t + 1, :] # (BS, 1, action_dim)
act = torch.cat([act, next_act], dim=1) # (BS, T+1, action_dim)
# predict the last state
act_emb = self.action_encoder(act) # (BS, T, A_emb)
emb_trunc = emb[:, -HS:] # (BS, HS, D)
act_trunc = act_emb[:, -HS:] # (BS, HS, A_emb)
pred_emb = self.predict(emb_trunc, act_trunc)[:, -1:] # (BS, 1, D)
emb = torch.cat([emb, pred_emb], dim=1)
# unflatten batch and sample dimensions
pred_rollout = rearrange(emb, "(b s) ... -> b s ...", b=B, s=S)
info["predicted_emb"] = pred_rollout
return info
def criterion(self, info_dict: dict):
"""Compute the cost between predicted embeddings and goal embeddings."""
pred_emb = info_dict["predicted_emb"] # (B,S, T-1, dim)
goal_emb = info_dict["goal_emb"] # (B, S, T, dim)
goal_emb = goal_emb[..., -1:, :].expand_as(pred_emb)
# return last-step cost per action candidate
cost = F.mse_loss(
pred_emb[..., -1:, :],
goal_emb[..., -1:, :].detach(),
reduction="none",
).sum(dim=tuple(range(2, pred_emb.ndim))) # (B, S)
return cost
def get_cost(self, info_dict: dict, action_candidates: torch.Tensor):
""" Compute the cost of action candidates given an info dict with goal and initial state."""
assert "goal" in info_dict, "goal not in info_dict"
device = next(self.parameters()).device
for k in list(info_dict.keys()):
if torch.is_tensor(info_dict[k]):
info_dict[k] = info_dict[k].to(device)
goal = {k: v[:, 0] for k, v in info_dict.items() if torch.is_tensor(v)}
goal["pixels"] = goal["goal"]
for k in info_dict:
if k.startswith("goal_"):
goal[k[len("goal_") :]] = goal.pop(k)
goal.pop("action")
goal = self.encode(goal)
info_dict["goal_emb"] = goal["emb"]
info_dict = self.rollout(info_dict, action_candidates)
cost = self.criterion(info_dict)
return cost

285
module.py Normal file
View File

@@ -0,0 +1,285 @@
import torch
from torch import nn
import torch.nn.functional as F
from einops import rearrange
def modulate(x, shift, scale):
"""AdaLN-zero modulation"""
return x * (1 + scale) + shift
class SIGReg(torch.nn.Module):
"""Sketch Isotropic Gaussian Regularizer (single-GPU!)"""
def __init__(self, knots=17, num_proj=1024):
super().__init__()
self.num_proj = num_proj
t = torch.linspace(0, 3, knots, dtype=torch.float32)
dt = 3 / (knots - 1)
weights = torch.full((knots,), 2 * dt, dtype=torch.float32)
weights[[0, -1]] = dt
window = torch.exp(-t.square() / 2.0)
self.register_buffer("t", t)
self.register_buffer("phi", window)
self.register_buffer("weights", weights * window)
def forward(self, proj):
"""
proj: (T, B, D)
"""
# sample random projections
A = torch.randn(proj.size(-1), self.num_proj, device="cuda")
A = A.div_(A.norm(p=2, dim=0))
# compute the epps-pulley statistic
x_t = (proj @ A).unsqueeze(-1) * self.t
err = (x_t.cos().mean(-3) - self.phi).square() + x_t.sin().mean(-3).square()
statistic = (err @ self.weights) * proj.size(-2)
return statistic.mean() # average over projections and time
class FeedForward(nn.Module):
"""FeedForward network used in Transformers"""
def __init__(self, dim, hidden_dim, dropout=0.0):
super().__init__()
self.net = nn.Sequential(
nn.LayerNorm(dim),
nn.Linear(dim, hidden_dim),
nn.GELU(),
nn.Dropout(dropout),
nn.Linear(hidden_dim, dim),
nn.Dropout(dropout),
)
def forward(self, x):
return self.net(x)
class Attention(nn.Module):
"""Scaled dot-product attention with causal masking"""
def __init__(self, dim, heads=8, dim_head=64, dropout=0.0):
super().__init__()
inner_dim = dim_head * heads
project_out = not (heads == 1 and dim_head == dim)
self.heads = heads
self.scale = dim_head**-0.5
self.dropout = dropout
self.norm = nn.LayerNorm(dim)
self.attend = nn.Softmax(dim=-1)
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)
self.to_out = (
nn.Sequential(nn.Linear(inner_dim, dim), nn.Dropout(dropout))
if project_out
else nn.Identity()
)
def forward(self, x, causal=True):
"""
x : (B, T, D)
"""
x = self.norm(x)
drop = self.dropout if self.training else 0.0
qkv = self.to_qkv(x).chunk(3, dim=-1) # q, k, v: (B, heads, T, dim_head)
q, k, v = (rearrange(t, "b t (h d) -> b h t d", h=self.heads) for t in qkv)
out = F.scaled_dot_product_attention(q, k, v, dropout_p=drop, is_causal=causal)
out = rearrange(out, "b h t d -> b t (h d)")
return self.to_out(out)
class ConditionalBlock(nn.Module):
"""Transformer block with AdaLN-zero conditioning"""
def __init__(self, dim, heads, dim_head, mlp_dim, dropout=0.0):
super().__init__()
self.attn = Attention(dim, heads=heads, dim_head=dim_head, dropout=dropout)
self.mlp = FeedForward(dim, mlp_dim, dropout=dropout)
self.norm1 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
self.adaLN_modulation = nn.Sequential(
nn.SiLU(), nn.Linear(dim, 6 * dim, bias=True)
)
nn.init.constant_(self.adaLN_modulation[-1].weight, 0)
nn.init.constant_(self.adaLN_modulation[-1].bias, 0)
def forward(self, x, c):
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
self.adaLN_modulation(c).chunk(6, dim=-1)
)
x = x + gate_msa * self.attn(modulate(self.norm1(x), shift_msa, scale_msa))
x = x + gate_mlp * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
return x
class Block(nn.Module):
"""Standard Transformer block"""
def __init__(self, dim, heads, dim_head, mlp_dim, dropout=0.0):
super().__init__()
self.attn = Attention(dim, heads=heads, dim_head=dim_head, dropout=dropout)
self.mlp = FeedForward(dim, mlp_dim, dropout=dropout)
self.norm1 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
def forward(self, x):
x = x + self.attn(self.norm1(x))
x = x + self.mlp(self.norm2(x))
return x
class Transformer(nn.Module):
"""Standard Transformer with support for AdaLN-zero blocks"""
def __init__(
self,
input_dim,
hidden_dim,
output_dim,
depth,
heads,
dim_head,
mlp_dim,
dropout=0.0,
block_class=Block,
):
super().__init__()
self.norm = nn.LayerNorm(hidden_dim)
self.layers = nn.ModuleList([])
self.input_proj = (
nn.Linear(input_dim, hidden_dim)
if input_dim != hidden_dim
else nn.Identity()
)
self.cond_proj = (
nn.Linear(input_dim, hidden_dim)
if input_dim != hidden_dim
else nn.Identity()
)
self.output_proj = (
nn.Linear(hidden_dim, output_dim)
if hidden_dim != output_dim
else nn.Identity()
)
for _ in range(depth):
self.layers.append(
block_class(hidden_dim, heads, dim_head, mlp_dim, dropout)
)
def forward(self, x, c=None):
if hasattr(self, "input_proj"):
x = self.input_proj(x)
if c is not None and hasattr(self, "cond_proj"):
c = self.cond_proj(c)
for block in self.layers:
x = block(x) if isinstance(block, Block) else block(x, c)
x = self.norm(x)
if hasattr(self, "output_proj"):
x = self.output_proj(x)
return x
class Embedder(nn.Module):
def __init__(
self,
input_dim=10,
smoothed_dim=10,
emb_dim=10,
mlp_scale=4,
):
super().__init__()
self.patch_embed = nn.Conv1d(input_dim, smoothed_dim, kernel_size=1, stride=1)
self.embed = nn.Sequential(
nn.Linear(smoothed_dim, mlp_scale * emb_dim),
nn.SiLU(),
nn.Linear(mlp_scale * emb_dim, emb_dim),
)
def forward(self, x):
"""
x: (B, T, D)
"""
x = x.float()
x = x.permute(0, 2, 1)
x = self.patch_embed(x)
x = x.permute(0, 2, 1)
x = self.embed(x)
return x
class MLP(nn.Module):
"""Simple MLP with optional normalization and activation"""
def __init__(
self,
input_dim,
hidden_dim,
output_dim=None,
norm_fn=nn.LayerNorm,
act_fn=nn.GELU,
):
super().__init__()
norm_fn = norm_fn(hidden_dim) if norm_fn is not None else nn.Identity()
self.net = nn.Sequential(
nn.Linear(input_dim, hidden_dim),
norm_fn,
act_fn(),
nn.Linear(hidden_dim, output_dim or input_dim),
)
def forward(self, x):
"""
x: (B*T, D)
"""
return self.net(x)
class ARPredictor(nn.Module):
"""Autoregressive predictor for next-step embedding prediction."""
def __init__(
self,
*,
num_frames,
depth,
heads,
mlp_dim,
input_dim,
hidden_dim,
output_dim=None,
dim_head=64,
dropout=0.0,
emb_dropout=0.0,
):
super().__init__()
self.pos_embedding = nn.Parameter(torch.randn(1, num_frames, input_dim))
self.dropout = nn.Dropout(emb_dropout)
self.transformer = Transformer(
input_dim,
hidden_dim,
output_dim or input_dim,
depth,
heads,
dim_head,
mlp_dim,
dropout,
block_class=ConditionalBlock,
)
def forward(self, x, c):
"""
x: (B, T, d)
c: (B, T, act_dim)
"""
T = x.size(1)
x = x + self.pos_embedding[:, :T]
x = self.dropout(x)
x = self.transformer(x, c)
return x

183
train.py Normal file
View File

@@ -0,0 +1,183 @@
import os
from functools import partial
from pathlib import Path
import hydra
import lightning as pl
import stable_pretraining as spt
import stable_worldmodel as swm
import torch
from lightning.pytorch.loggers import WandbLogger
from omegaconf import OmegaConf, open_dict
from jepa import JEPA
from module import ARPredictor, Embedder, MLP, SIGReg
from utils import get_column_normalizer, get_img_preprocessor, ModelObjectCallBack
def lejepa_forward(self, batch, stage, cfg):
"""encode observations, predict next states, compute losses."""
ctx_len = cfg.wm.history_size
n_preds = cfg.wm.num_preds
lambd = cfg.loss.sigreg.weight
# Replace NaN values with 0 (occurs at sequence boundaries)
batch["action"] = torch.nan_to_num(batch["action"], 0.0)
output = self.model.encode(batch)
emb = output["emb"] # (B, T, D)
act_emb = output["act_emb"]
ctx_emb = emb[:, :ctx_len]
ctx_act = act_emb[:, : ctx_len]
tgt_emb = emb[:, n_preds:] # label
pred_emb = self.model.predict(ctx_emb, ctx_act) # pred
# LeWM loss
output["pred_loss"] = (pred_emb - tgt_emb).pow(2).mean()
output["sigreg_loss"]= self.sigreg(emb.transpose(0, 1))
output["loss"] = output["pred_loss"] + lambd * output["sigreg_loss"]
losses_dict = {f"{stage}/{k}": v.detach() for k, v in output.items() if "loss" in k}
self.log_dict(losses_dict, on_step=True, sync_dist=True)
return output
@hydra.main(version_base=None, config_path="./config/train", config_name="lewm")
def run(cfg):
#########################
## dataset ##
#########################
dataset = swm.data.HDF5Dataset(**cfg.data.dataset, transform=None)
transforms = [get_img_preprocessor(source='pixels', target='pixels', img_size=cfg.img_size)]
with open_dict(cfg):
for col in cfg.data.dataset.keys_to_load:
if col.startswith("pixels"):
continue
normalizer = get_column_normalizer(dataset, col, col)
transforms.append(normalizer)
setattr(cfg.wm, f"{col}_dim", dataset.get_dim(col))
transform = spt.data.transforms.Compose(*transforms)
dataset.transform = transform
rnd_gen = torch.Generator().manual_seed(cfg.seed)
train_set, val_set = spt.data.random_split(
dataset, lengths=[cfg.train_split, 1 - cfg.train_split], generator=rnd_gen
)
train = torch.utils.data.DataLoader(train_set, **cfg.loader,shuffle=True, drop_last=True, generator=rnd_gen)
val = torch.utils.data.DataLoader(val_set, **cfg.loader, shuffle=False, drop_last=False)
##############################
## model / optim ##
##############################
encoder = spt.backbone.utils.vit_hf(
cfg.encoder_scale,
patch_size=cfg.patch_size,
image_size=cfg.img_size,
pretrained=False,
use_mask_token=False,
)
hidden_dim = encoder.config.hidden_size
embed_dim = cfg.wm.get("embed_dim", hidden_dim)
effective_act_dim = cfg.data.dataset.frameskip * cfg.wm.action_dim
predictor = ARPredictor(
num_frames=cfg.wm.history_size,
input_dim=embed_dim,
hidden_dim=hidden_dim,
output_dim=hidden_dim,
**cfg.predictor,
)
action_encoder = Embedder(input_dim=effective_act_dim, emb_dim=embed_dim)
projector = MLP(
input_dim=hidden_dim,
output_dim=embed_dim,
hidden_dim=2048,
norm_fn=torch.nn.BatchNorm1d,
)
predictor_proj = MLP(
input_dim=hidden_dim,
output_dim=embed_dim,
hidden_dim=2048,
norm_fn=torch.nn.BatchNorm1d,
)
world_model = JEPA(
encoder=encoder,
predictor=predictor,
action_encoder=action_encoder,
projector=projector,
pred_proj=predictor_proj,
)
optimizers = {
'model_opt': {
"modules": 'model',
"optimizer": dict(cfg.optimizer),
"scheduler": {"type": "LinearWarmupCosineAnnealingLR"},
"interval": "epoch",
},
}
data_module = spt.data.DataModule(train=train, val=val)
world_model = spt.Module(
model = world_model,
sigreg = SIGReg(**cfg.loss.sigreg.kwargs),
forward=partial(lejepa_forward, cfg=cfg),
optim=optimizers,
)
##########################
## training ##
##########################
run_id = cfg.get("subdir") or ""
run_dir = Path(swm.data.utils.get_cache_dir(), run_id)
logger = None
if cfg.wandb.enabled:
logger = WandbLogger(**cfg.wandb.config)
logger.log_hyperparams(OmegaConf.to_container(cfg))
run_dir.mkdir(parents=True, exist_ok=True)
with open(run_dir / "config.yaml", "w") as f:
OmegaConf.save(cfg, f)
object_dump_callback = ModelObjectCallBack(
dirpath=run_dir, filename=cfg.output_model_name, epoch_interval=1,
)
trainer = pl.Trainer(
**cfg.trainer,
callbacks=[object_dump_callback],
num_sanity_val_steps=1,
logger=logger,
enable_checkpointing=True,
)
manager = spt.Manager(
trainer=trainer,
module=world_model,
data=data_module,
ckpt_path=run_dir / f"{cfg.output_model_name}_weights.ckpt",
)
manager()
return
if __name__ == "__main__":
run()

57
utils.py Normal file
View File

@@ -0,0 +1,57 @@
import numpy as np
import torch
from pathlib import Path
from stable_pretraining import data as dt
from lightning.pytorch.callbacks import Callback
def get_img_preprocessor(source: str, target: str, img_size: int = 224):
imagenet_stats = dt.dataset_stats.ImageNet
to_image = dt.transforms.ToImage(**imagenet_stats, source=source, target=target)
resize = dt.transforms.Resize(img_size, source=source, target=target)
return dt.transforms.Compose(to_image, resize)
def get_column_normalizer(dataset, source: str, target: str):
"""Get normalizer for a specific column in the dataset."""
col_data = dataset.get_col_data(source)
data = torch.from_numpy(np.array(col_data))
data = data[~torch.isnan(data).any(dim=1)]
mean = data.mean(0, keepdim=True).clone()
std = data.std(0, keepdim=True).clone()
def norm_fn(x):
return ((x - mean) / std).float()
normalizer = dt.transforms.WrapTorchTransform(norm_fn, source=source, target=target)
return normalizer
class ModelObjectCallBack(Callback):
"""Callback to pickle model object after each epoch."""
def __init__(self, dirpath, filename="model_object", epoch_interval: int = 1):
super().__init__()
self.dirpath = Path(dirpath)
self.filename = filename
self.epoch_interval = epoch_interval
def on_train_epoch_end(self, trainer, pl_module):
super().on_train_epoch_end(trainer, pl_module)
output_path = (
self.dirpath
/ f"{self.filename}_epoch_{trainer.current_epoch + 1}_object.ckpt"
)
if trainer.is_global_zero:
if (trainer.current_epoch + 1) % self.epoch_interval == 0:
self._dump_model(pl_module.model, output_path)
# save final epoch
if (trainer.current_epoch + 1) == trainer.max_epochs:
self._dump_model(pl_module.model, output_path)
def _dump_model(self, model, path):
try:
torch.save(model, path)
except Exception as e:
print(f"Error saving model object: {e}")