commit 83f97d72ad067855bc89a1b74b4aff11d4dfdf0c Author: Lucas Maes <43337476+lucas-maes@users.noreply.github.com> Date: Thu Mar 12 22:56:21 2026 -0400 Initial commit diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..9a0803f --- /dev/null +++ b/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2026 Lucas Maes + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. 
diff --git a/README.md b/README.md new file mode 100644 index 0000000..55b0396 --- /dev/null +++ b/README.md @@ -0,0 +1,127 @@ + +# LeWorldModel +### Stable End-to-End Joint-Embedding Predictive Architecture from Pixels + +[Lucas Maes*](https://x.com/lucasmaes_), [Quentin Le Lidec*](https://quentinll.github.io/), [Damien Scieur](https://scholar.google.com/citations?user=hNscQzgAAAAJ&hl=fr), [Yann LeCun](https://yann.lecun.com/) and [Randall Balestriero](https://randallbalestriero.github.io/) + +**Abstract:** Joint Embedding Predictive Architectures (JEPAs) offer a compelling framework for learning world models in compact latent spaces, yet existing methods remain fragile, relying on complex multi-term losses, exponential moving averages, pretrained encoders, or auxiliary supervision to avoid representation collapse. In this work, we introduce LeWorldModel (LeWM), the first JEPA that trains stably end-to-end from raw pixels using only two loss terms: a next-embedding prediction loss and a regularizer enforcing Gaussian-distributed latent embeddings. This reduces tunable loss hyperparameters from six to one compared to the only existing end-to-end alternative. With ~15M parameters trainable on a single GPU in a few hours, LeWM plans up to 48× faster than foundation-model-based world models while remaining competitive across diverse 2D and 3D control tasks. Beyond control, we show that LeWM's latent space encodes meaningful physical structure through probing of physical quantities. Surprise evaluation confirms that the model reliably detects physically implausible events. + +

+ [ Paper | Data | Website ] +

+ +
+ +

+ +

+ +If you find this code useful, please reference in your paper: +``` +@article{maes_lelidec2026lewm, + title={LeWorldModel: Stable End-to-End Joint-Embedding Predictive Architecture from Pixels}, + author={Maes, Lucas and Le Lidec, Quentin and Scieur, Damien and LeCun, Yann and Balestriero, Randall}, + journal={arXiv preprint}, + year={2026} +} +``` + +## Using the code +This codebase builds on [stable-worldmodel](https://github.com/galilai-group/stable-worldmodel) for environment management, planning, and evaluation, and [stable-pretraining](https://github.com/galilai-group/stable-pretraining) for training. Together they reduce this repository to its core contribution: the model architecture and training objective. + +**Installation:** +```bash +uv venv --python=3.10 +source .venv/bin/activate +uv pip install stable-worldmodel[train,env] +``` + +## Data + +Datasets use the HDF5 format for fast loading. Download the data from the [Drive](https://drive.google.com/drive/folders/1r31os0d4-rR0mdHc7OlY_e5nh3XT4r4e?usp=sharing) and decompress with: + +```bash +tar --zstd -xvf archive.tar.zst +``` + +Place the extracted `.h5` files under `$STABLEWM_HOME` (defaults to `~/.stable-wm/`). You can override this path: +```bash +export STABLEWM_HOME=/path/to/your/storage +``` + +Dataset names are specified without the `.h5` extension. For example, `config/train/data/pusht.yaml` references `pusht_expert_train`, which resolves to `$STABLEWM_HOME/pusht_expert_train.h5`. + +## Training + +`jepa.py` contains the PyTorch implementation of LeWM. Training is configured via [Hydra](https://hydra.cc/) config files under `config/train/`. + +Before training, set your WandB `entity` and `project` in `config/train/lewm.yaml`: +```yaml +wandb: + config: + entity: your_entity + project: your_project +``` + +To launch training: +```bash +python train.py data=pusht +``` + +Checkpoints are saved to `$STABLEWM_HOME` upon completion. 
+ +For baseline scripts, see the stable-worldmodel [scripts](https://github.com/galilai-group/stable-worldmodel/tree/main/scripts/train) folder. + +## Planning + +Evaluation configs live under `config/eval/`. Set the `policy` field to the checkpoint path **relative to `$STABLEWM_HOME`**, without the `_object.ckpt` suffix: + +```bash +# ✓ correct +python eval.py --config-name=pusht.yaml policy=pusht/lewm + +# ✗ incorrect +python eval.py --config-name=pusht.yaml policy=pusht/lewm_object.ckpt +``` + +## Pretrained Checkpoints + +Pre-trained checkpoints are available on [Google Drive](https://drive.google.com/drive/folders/1r31os0d4-rR0mdHc7OlY_e5nh3XT4r4e). Download the checkpoint archive and place the extracted files under `$STABLEWM_HOME/`. + +
+ +| Method | two-room | pusht | cube | reacher | +|:---:|:---:|:---:|:---:|:---:| +| pldm | ✓ | ✓ | ✓ | ✓ | +| lejepa | ✓ | ✓ | ✓ | ✓ | +| ivl | ✓ | ✓ | ✓ | — | +| iql | ✓ | ✓ | ✓ | — | +| gcbc | ✓ | ✓ | ✓ | — | +| dinowm | ✓ | ✓ | — | — | +| dinowm_noprop | ✓ | ✓ | ✓ | ✓ | + +
+ +## Loading a checkpoint + +Each tar archive contains two files per checkpoint: +- `_object.ckpt` — a serialized Python object for convenient loading; this is what `eval.py` and the `stable_worldmodel` API use +- `_weight.ckpt` — a weights-only checkpoint (`state_dict`) for cases where you want to load weights into your own model instance + +To load the object checkpoint via the `stable_worldmodel` API: + +```python +import stable_worldmodel as swm + +# Load the cost model (for MPC) +cost = swm.policy.AutoCostModel('pusht/lewm') +``` + +`swm.policy.AutoCostModel` accepts: +- `run_name` — checkpoint path **relative to `$STABLEWM_HOME`**, without the `_object.ckpt` suffix +- `cache_dir` — optional override for the checkpoint root (defaults to `$STABLEWM_HOME`) + +The returned module is in `eval` mode with its PyTorch weights accessible via `.state_dict()`. + +## Contact & Contributions +Feel free to open [issues](https://github.com/lucas-maes/le-wm/issues)! For questions or collaborations, please contact `lucas.maes@mila.quebec` diff --git a/assets/lewm.gif b/assets/lewm.gif new file mode 100644 index 0000000..a85d681 Binary files /dev/null and b/assets/lewm.gif differ diff --git a/config/eval/cube.yaml b/config/eval/cube.yaml new file mode 100644 index 0000000..3ba34bf --- /dev/null +++ b/config/eval/cube.yaml @@ -0,0 +1,61 @@ +defaults: + - launcher: local + - solver: cem + - _self_ + +world: + env_name: swm/OGBCube-v0 + num_envs: ${eval.num_eval} + max_episode_steps: ??? 
# make sure it's >= eval_budget + history_size: 1 + frame_skip: 1 + env_type: single + ob_type: states + multiview: False + width: 224 + height: 224 + visualize_info: False + terminate_at_goal: True + +dataset: + stats: ${eval.dataset_name} + keys_to_cache: + - action + +seed: 42 +policy: random # ckpt name or random + +plan_config: + horizon: 5 + receding_horizon: 5 + action_block: 5 # frameskip + +# evaluation from dataset (replay expert trajectories) +eval: + num_eval: 50 + goal_offset_steps: 25 + eval_budget: 50 + img_size: 224 + dataset_name: ogbench/cube_single_expert + callables: + # -- set state + - method: set_state + args: + qpos: + value: qpos + qvel: + value: qvel + # -- set target pos + - method: set_target_pos + args: + cube_id: + value: 0 + in_dataset: False + target_pos: + value: goal_privileged_block_0_pos + target_quat: + value: goal_privileged_block_0_quat + +output: + filename: ogb_cube_results.txt + diff --git a/config/eval/launcher/local.yaml b/config/eval/launcher/local.yaml new file mode 100644 index 0000000..a1c7f4b --- /dev/null +++ b/config/eval/launcher/local.yaml @@ -0,0 +1,7 @@ +# @package _global_ +# Local launcher configuration (no SLURM) + +defaults: + - override /hydra/launcher: basic + +cache_dir: null # use stable-worldmodel default cache diff --git a/config/eval/pusht.yaml b/config/eval/pusht.yaml new file mode 100644 index 0000000..6584ef0 --- /dev/null +++ b/config/eval/pusht.yaml @@ -0,0 +1,48 @@ +defaults: + - launcher: local + - solver: cem + - _self_ + +world: + env_name: swm/PushT-v1 + num_envs: ${eval.num_eval} + max_episode_steps: ??? 
# make sure it's >= eval_budget + history_size: 1 + frame_skip: 1 + +dataset: + stats: ${eval.dataset_name} + keys_to_cache: + - action + - proprio + - state + +seed: 42 +policy: random # ckpt name or random + +plan_config: + horizon: 5 + receding_horizon: 5 + action_block: 5 # frameskip + +# evaluation from dataset (replay expert trajectories) +eval: + num_eval: 50 + goal_offset_steps: 25 + eval_budget: 50 + img_size: 224 + dataset_name: pusht_expert_train + callables: + # -- set state + - method: _set_state + args: + state: + value: state + # -- set goal state + - method: _set_goal_state + args: + goal_state: + value: goal_state + +output: + filename: pusht_results.txt \ No newline at end of file diff --git a/config/eval/reacher.yaml b/config/eval/reacher.yaml new file mode 100644 index 0000000..d0c62dc --- /dev/null +++ b/config/eval/reacher.yaml @@ -0,0 +1,50 @@ +defaults: + - launcher: local + - solver: cem + - _self_ + +world: + env_name: swm/ReacherDMControl-v0 + num_envs: ${eval.num_eval} + max_episode_steps: ??? # make sure it's >= eval_budget + history_size: 1 + frame_skip: 1 + task: qpos_match + +dataset: + stats: ${eval.dataset_name} + keys_to_cache: + - action + +seed: 42 +policy: random # ckpt name or random + +plan_config: + horizon: 5 + receding_horizon: 5 + action_block: 5 # frameskip + +# evaluation from dataset (replay expert trajectories) +eval: + num_eval: 50 + goal_offset_steps: 25 + eval_budget: 50 + img_size: 224 + dataset_name: dmc/reacher_random + callables: + # -- set state + - method: set_state + args: + qpos: + value: qpos + qvel: + value: qvel + + - method: set_target_qpos + args: + target_qpos: + value: goal_qpos + +output: + filename: dmc_results.txt + diff --git a/config/eval/solver/adam.yaml b/config/eval/solver/adam.yaml new file mode 100644 index 0000000..763b0f0 --- /dev/null +++ b/config/eval/solver/adam.yaml @@ -0,0 +1,13 @@ +_target_: stable_worldmodel.solver.GradientSolver +model: ??? 
+n_steps: 30 +batch_size: 1 +num_samples: 100 +action_noise: 0 +device: "cuda" +seed: ${seed} +optimizer_cls: + _target_: hydra.utils.get_class + path: torch.optim.AdamW +optimizer_kwargs: + lr: 0.1 \ No newline at end of file diff --git a/config/eval/solver/cem.yaml b/config/eval/solver/cem.yaml new file mode 100644 index 0000000..8d24fda --- /dev/null +++ b/config/eval/solver/cem.yaml @@ -0,0 +1,9 @@ +_target_: stable_worldmodel.solver.CEMSolver +model: ??? +batch_size: 1 +num_samples: 300 +var_scale: 1.0 +n_steps: 30 +topk: 30 +device: "cuda" +seed: ${seed} diff --git a/config/eval/tworoom.yaml b/config/eval/tworoom.yaml new file mode 100644 index 0000000..dd20571 --- /dev/null +++ b/config/eval/tworoom.yaml @@ -0,0 +1,47 @@ +defaults: + - launcher: local + - solver: cem + - _self_ + +world: + env_name: swm/TwoRoom-v1 + num_envs: ${eval.num_eval} + max_episode_steps: ??? # make sure it's >= eval_budget + history_size: 1 + frame_skip: 1 + +seed: 42 +policy: random # ckpt name or random + +dataset: + stats: ${eval.dataset_name} + keys_to_cache: + - action + - proprio + +plan_config: + horizon: 5 + receding_horizon: 5 + action_block: 5 # frameskip + +# evaluation from dataset (replay expert trajectories) +eval: + num_eval: 50 + goal_offset_steps: 25 + eval_budget: 50 + img_size: 224 + dataset_name: tworoom + callables: + # -- set state + - method: _set_state + args: + state: + value: proprio + # -- set goal state + - method: _set_goal_state + args: + goal_state: + value: goal_proprio + +output: + filename: tworoom_results.txt \ No newline at end of file diff --git a/config/train/data/dmc.yaml b/config/train/data/dmc.yaml new file mode 100644 index 0000000..2c0a4e7 --- /dev/null +++ b/config/train/data/dmc.yaml @@ -0,0 +1,11 @@ +dataset: + num_steps: ${eval:'${wm.num_preds} + ${wm.history_size}'} + frameskip: 5 + name: reacher + keys_to_load: + - pixels + - action + - observation + keys_to_cache: + - action + - observation \ No newline at end of file diff --git 
a/config/train/data/ogb.yaml b/config/train/data/ogb.yaml new file mode 100644 index 0000000..d825958 --- /dev/null +++ b/config/train/data/ogb.yaml @@ -0,0 +1,13 @@ +dataset: + name: ogbench/cube_single_expert + num_steps: ${eval:'${wm.num_preds} + ${wm.history_size}'} + frameskip: 5 + keys_to_load: + - pixels + - action + - observation + keys_to_cache: + - action + - observation + keys_to_merge: + proprio: proprio \ No newline at end of file diff --git a/config/train/data/pusht.yaml b/config/train/data/pusht.yaml new file mode 100644 index 0000000..41979dd --- /dev/null +++ b/config/train/data/pusht.yaml @@ -0,0 +1,13 @@ +dataset: + num_steps: ${eval:'${wm.num_preds} + ${wm.history_size}'} + frameskip: 5 + name: pusht_expert_train + keys_to_load: + - pixels + - action + - proprio + - state + keys_to_cache: + - action + - proprio + - state \ No newline at end of file diff --git a/config/train/data/tworoom.yaml b/config/train/data/tworoom.yaml new file mode 100644 index 0000000..460d5fb --- /dev/null +++ b/config/train/data/tworoom.yaml @@ -0,0 +1,11 @@ +dataset: + num_steps: ${eval:'${wm.num_preds} + ${wm.history_size}'} + frameskip: 5 + name: tworoom + keys_to_load: + - pixels + - action + - proprio + keys_to_cache: + - action + - proprio \ No newline at end of file diff --git a/config/train/launcher/local.yaml b/config/train/launcher/local.yaml new file mode 100644 index 0000000..0f350cd --- /dev/null +++ b/config/train/launcher/local.yaml @@ -0,0 +1,11 @@ +# @package _global_ +# Local launcher configuration (no SLURM) + +defaults: + - override /hydra/launcher: basic + +wandb: + enabled: True + config: + project: le-wm + entity: le-wm \ No newline at end of file diff --git a/config/train/lewm.yaml b/config/train/lewm.yaml new file mode 100644 index 0000000..0cb29af --- /dev/null +++ b/config/train/lewm.yaml @@ -0,0 +1,64 @@ +defaults: + - _self_ + - data: pusht + +output_model_name: lewm +subdir: ${hydra:job.id} + +num_workers: 6 +train_split: 0.9 +seed: 3072 
+img_size: 224 +patch_size: 14 +encoder_scale: tiny +dump_object: True + +trainer: + max_epochs: 100 + devices: auto + accelerator: gpu + precision: bf16 + gradient_clip_val: 1.0 + +loader: + batch_size: 128 + num_workers: ${num_workers} + persistent_workers: True + prefetch_factor: 3 + pin_memory: True + +optimizer: + type: AdamW + lr: 5e-5 + weight_decay: 1e-3 + +wandb: + enabled: True + config: + entity: lewm + project: lewm + name: ${output_model_name} + id: ${subdir} + resume: allow + log_model: False + +wm: + type: lewm + history_size: 3 + num_preds: 1 + embed_dim: 192 + +predictor: + depth: 6 + heads: 16 + mlp_dim: 2048 + dim_head: 64 + dropout: 0.1 + emb_dropout: 0.0 + +loss: + sigreg: + weight: 0.09 + kwargs: + knots: 17 + num_proj: 1024 diff --git a/eval.py b/eval.py new file mode 100644 index 0000000..859afd1 --- /dev/null +++ b/eval.py @@ -0,0 +1,171 @@ +import os + +os.environ["MUJOCO_GL"] = "egl" + +import time +from pathlib import Path + +import hydra +import numpy as np +import stable_pretraining as spt +import torch +from omegaconf import DictConfig, OmegaConf +from sklearn import preprocessing +from torchvision.transforms import v2 as transforms +import stable_worldmodel as swm + +def img_transform(cfg): + transform = transforms.Compose( + [ + transforms.ToImage(), + transforms.ToDtype(torch.float32, scale=True), + transforms.Normalize(**spt.data.dataset_stats.ImageNet), + transforms.Resize(size=cfg.eval.img_size), + ] + ) + return transform + + +def get_episodes_length(dataset, episodes): + col_name = "episode_idx" if "episode_idx" in dataset.column_names else "ep_idx" + + episode_idx = dataset.get_col_data(col_name) + step_idx = dataset.get_col_data("step_idx") + lengths = [] + for ep_id in episodes: + lengths.append(np.max(step_idx[episode_idx == ep_id]) + 1) + return np.array(lengths) + + +def get_dataset(cfg, dataset_name): + dataset_path = Path(cfg.cache_dir or swm.data.utils.get_cache_dir()) + dataset = swm.data.HDF5Dataset( + dataset_name, 
+ keys_to_cache=cfg.dataset.keys_to_cache, + cache_dir=dataset_path, + ) + return dataset + +@hydra.main(version_base=None, config_path="./config/eval", config_name="pusht") +def run(cfg: DictConfig): + """Evaluate a world-model planning policy (or the random baseline when cfg.policy is "random") on episodes replayed from a dataset.""" + assert ( + cfg.plan_config.horizon * cfg.plan_config.action_block <= cfg.eval.eval_budget + ), "Planning horizon must be smaller than or equal to eval_budget" + + # create world environment + cfg.world.max_episode_steps = 2 * cfg.eval.eval_budget + world = swm.World(**cfg.world, image_shape=(224, 224)) + + # create the transform + transform = { + "pixels": img_transform(cfg), + "goal": img_transform(cfg), + } + + dataset = get_dataset(cfg, cfg.eval.dataset_name) + stats_dataset = dataset # get_dataset(cfg, cfg.dataset.stats) + col_name = "episode_idx" if "episode_idx" in dataset.column_names else "ep_idx" + ep_indices, _ = np.unique(stats_dataset.get_col_data(col_name), return_index=True) + + process = {} + for col in cfg.dataset.keys_to_cache: + if col in ["pixels"]: + continue + processor = preprocessing.StandardScaler() + col_data = stats_dataset.get_col_data(col) + col_data = col_data[~np.isnan(col_data).any(axis=1)] + processor.fit(col_data) + process[col] = processor + + if col != "action": + process[f"goal_{col}"] = process[col] + + # -- run evaluation + policy = cfg.get("policy", "random") + + if policy != "random": + model = swm.policy.AutoCostModel(cfg.policy) + model = model.to("cuda") + model = model.eval() + model.requires_grad_(False) + model.interpolate_pos_encoding = True + config = swm.PlanConfig(**cfg.plan_config) + solver = hydra.utils.instantiate(cfg.solver, model=model) + policy = swm.policy.WorldModelPolicy( + solver=solver, config=config, process=process, transform=transform + ) + + else: + policy = swm.policy.RandomPolicy() + + results_path = ( + Path(swm.data.utils.get_cache_dir(), cfg.policy).parent + if cfg.policy != "random" + else Path(__file__).parent + ) + + # sample the episodes and 
the starting indices + episode_len = get_episodes_length(dataset, ep_indices) + max_start_idx = episode_len - cfg.eval.goal_offset_steps - 1 + max_start_idx_dict = {ep_id: max_start_idx[i] for i, ep_id in enumerate(ep_indices)} + # Map each dataset row's episode index to the last admissible start step of its episode + col_name = "episode_idx" if "episode_idx" in dataset.column_names else "ep_idx" + max_start_per_row = np.array( + [max_start_idx_dict[ep_id] for ep_id in dataset.get_col_data(col_name)] + ) + + # keep only the rows whose goal step (step_idx + goal_offset_steps) still lies inside the episode + valid_mask = dataset.get_col_data("step_idx") <= max_start_per_row + valid_indices = np.nonzero(valid_mask)[0] + print(valid_mask.sum(), "valid starting points found for evaluation.") + + # fail early with a clear message: sampling without replacement below needs num_eval distinct rows + if len(valid_indices) < cfg.eval.num_eval: + raise ValueError("Not enough episodes with sufficient length for evaluation.") + + g = np.random.default_rng(cfg.seed) + # sample positions among ALL valid rows (len(valid_indices) - 1 would never pick the last valid start) + random_episode_indices = g.choice( + len(valid_indices), size=cfg.eval.num_eval, replace=False + ) + + # sort increasingly to avoid issues with HDF5Dataset indexing + random_episode_indices = np.sort(valid_indices[random_episode_indices]) + + print(random_episode_indices) + + eval_episodes = dataset.get_row_data(random_episode_indices)[col_name] + eval_start_idx = dataset.get_row_data(random_episode_indices)["step_idx"] + + world.set_policy(policy) + + start_time = time.time() + metrics = world.evaluate_from_dataset( + dataset, + start_steps=eval_start_idx.tolist(), + goal_offset_steps=cfg.eval.goal_offset_steps, + eval_budget=cfg.eval.eval_budget, + episodes_idx=eval_episodes.tolist(), + callables=OmegaConf.to_container(cfg.eval.get("callables"), resolve=True), + video_path=results_path, + ) + end_time = time.time() + + print(metrics) + + results_path = results_path / cfg.output.filename + results_path.parent.mkdir(parents=True, exist_ok=True) + + with results_path.open("a") as f: + f.write("\n") # separate from previous runs + + f.write("==== CONFIG 
====\n") + f.write(OmegaConf.to_yaml(cfg)) + f.write("\n") + + f.write("==== RESULTS ====\n") + f.write(f"metrics: {metrics}\n") + f.write(f"evaluation_time: {end_time - start_time} seconds\n") + + +if __name__ == "__main__": + run() diff --git a/jepa.py b/jepa.py new file mode 100644 index 0000000..486fe93 --- /dev/null +++ b/jepa.py @@ -0,0 +1,153 @@ +"""JEPA Implementation""" + +import torch +import torch.nn.functional as F +from einops import rearrange +from torch import nn + +def detach_clone(v): + return v.detach().clone() if torch.is_tensor(v) else v + +class JEPA(nn.Module): + + def __init__( + self, + encoder, + predictor, + action_encoder, + projector=None, + pred_proj=None, + ): + super().__init__() + + self.encoder = encoder + self.predictor = predictor + self.action_encoder = action_encoder + self.projector = projector or nn.Identity() + self.pred_proj = pred_proj or nn.Identity() + + def encode(self, info): + """Encode observations and actions into embeddings. + info: dict with pixels and action keys + """ + + pixels = info['pixels'].float() + b = pixels.size(0) + pixels = rearrange(pixels, "b t ... -> (b t) ...") # flatten for encoding + output = self.encoder(pixels, interpolate_pos_encoding=True) + pixels_emb = output.last_hidden_state[:, 0] # cls token + emb = self.projector(pixels_emb) + info["emb"] = rearrange(emb, "(b t) d -> b t d", b=b) + + if "action" in info: + info["act_emb"] = self.action_encoder(info["action"]) + + return info + + def predict(self, emb, act_emb): + """Predict next state embedding + emb: (B, T, D) + act_emb: (B, T, A_emb) + """ + preds = self.predictor(emb, act_emb) + preds = self.pred_proj(rearrange(preds, "b t d -> (b t) d")) + preds = rearrange(preds, "(b t) d -> b t d", b=emb.size(0)) + return preds + + #################### + ## Inference only ## + #################### + + def rollout(self, info, action_sequence, history_size: int = 3): + """Rollout the model given an initial info dict and action sequence. 
+ pixels: (B, S, T, C, H, W) + action_sequence: (B, S, T, action_dim) + - S is the number of action plan samples + - T is the time horizon + """ + + assert "pixels" in info, "pixels not in info_dict" + H = info["pixels"].size(2) + B, S, T = action_sequence.shape[:3] + act_0, act_future = torch.split(action_sequence, [H, T - H], dim=2) + info["action"] = act_0 + n_steps = T - H + + # copy and encode initial info dict + _init = {k: v[:, 0] for k, v in info.items() if torch.is_tensor(v)} + _init = self.encode(_init) + emb = info["emb"] = _init["emb"].unsqueeze(1).expand(B, S, -1, -1) + _init = {k: detach_clone(v) for k, v in _init.items()} + + # flatten batch and sample dimensions for rollout + emb = rearrange(emb, "b s ... -> (b s) ...").clone() + act = rearrange(act_0, "b s ... -> (b s) ...") + act_future = rearrange(act_future, "b s ... -> (b s) ...") + + # rollout predictor autoregressively for n_steps + HS = history_size + for t in range(n_steps): + act_emb = self.action_encoder(act) + emb_trunc = emb[:, -HS:] # (BS, HS, D) + act_trunc = act_emb[:, -HS:] # (BS, HS, A_emb) + pred_emb = self.predict(emb_trunc, act_trunc)[:, -1:] # (BS, 1, D) + emb = torch.cat([emb, pred_emb], dim=1) # (BS, T+1, D) + + next_act = act_future[:, t : t + 1, :] # (BS, 1, action_dim) + act = torch.cat([act, next_act], dim=1) # (BS, T+1, action_dim) + + # predict the last state + act_emb = self.action_encoder(act) # (BS, T, A_emb) + emb_trunc = emb[:, -HS:] # (BS, HS, D) + act_trunc = act_emb[:, -HS:] # (BS, HS, A_emb) + pred_emb = self.predict(emb_trunc, act_trunc)[:, -1:] # (BS, 1, D) + emb = torch.cat([emb, pred_emb], dim=1) + + # unflatten batch and sample dimensions + pred_rollout = rearrange(emb, "(b s) ... 
-> b s ...", b=B, s=S) + info["predicted_emb"] = pred_rollout + + return info + + def criterion(self, info_dict: dict): + """Compute the cost between predicted embeddings and goal embeddings.""" + pred_emb = info_dict["predicted_emb"] # (B,S, T-1, dim) + goal_emb = info_dict["goal_emb"] # (B, S, T, dim) + + goal_emb = goal_emb[..., -1:, :].expand_as(pred_emb) + + # return last-step cost per action candidate + cost = F.mse_loss( + pred_emb[..., -1:, :], + goal_emb[..., -1:, :].detach(), + reduction="none", + ).sum(dim=tuple(range(2, pred_emb.ndim))) # (B, S) + + return cost + + def get_cost(self, info_dict: dict, action_candidates: torch.Tensor): + """Compute the cost of action candidates given an info dict with goal and initial state. + + Moves every tensor of info_dict to the model's device, encodes the goal observation, rolls out action_candidates and returns the result of criterion. info_dict is modified in place. + """ + + assert "goal" in info_dict, "goal not in info_dict" + + device = next(self.parameters()).device + for k in list(info_dict.keys()): + if torch.is_tensor(info_dict[k]): + info_dict[k] = info_dict[k].to(device) + + goal = {k: v[:, 0] for k, v in info_dict.items() if torch.is_tensor(v)} + goal["pixels"] = goal["goal"] + + # strip the "goal_" prefix; iterate over goal's own keys so that non-tensor "goal_*" entries of info_dict (which are absent from goal) cannot raise KeyError + for k in list(goal): + if k.startswith("goal_"): + goal[k[len("goal_") :]] = goal.pop(k) + + # actions are not used to encode the goal; tolerate info dicts that carry no tensor "action" + goal.pop("action", None) + goal = self.encode(goal) + + info_dict["goal_emb"] = goal["emb"] + info_dict = self.rollout(info_dict, action_candidates) + + cost = self.criterion(info_dict) + + return cost diff --git a/module.py b/module.py new file mode 100644 index 0000000..948567c --- /dev/null +++ b/module.py @@ -0,0 +1,285 @@ +import torch +from torch import nn +import torch.nn.functional as F +from einops import rearrange + +def modulate(x, shift, scale): + """AdaLN-zero modulation""" + return x * (1 + scale) + shift + +class SIGReg(torch.nn.Module): + """Sketch Isotropic Gaussian Regularizer (single-GPU!)""" + + def __init__(self, knots=17, num_proj=1024): + super().__init__() + self.num_proj = num_proj + t = torch.linspace(0, 3, knots, dtype=torch.float32) + dt = 3 / (knots - 1) + weights = torch.full((knots,), 2 * dt, 
dtype=torch.float32) + weights[[0, -1]] = dt + window = torch.exp(-t.square() / 2.0) + self.register_buffer("t", t) + self.register_buffer("phi", window) + self.register_buffer("weights", weights * window) + + def forward(self, proj): + """ + proj: (T, B, D) + """ + # sample random projections on the same device as the input (works on CPU and on any GPU index) + A = torch.randn(proj.size(-1), self.num_proj, device=proj.device) + A = A.div_(A.norm(p=2, dim=0)) + # compute the epps-pulley statistic + x_t = (proj @ A).unsqueeze(-1) * self.t + err = (x_t.cos().mean(-3) - self.phi).square() + x_t.sin().mean(-3).square() + statistic = (err @ self.weights) * proj.size(-2) + return statistic.mean() # average over projections and time + +class FeedForward(nn.Module): + """FeedForward network used in Transformers""" + + def __init__(self, dim, hidden_dim, dropout=0.0): + super().__init__() + self.net = nn.Sequential( + nn.LayerNorm(dim), + nn.Linear(dim, hidden_dim), + nn.GELU(), + nn.Dropout(dropout), + nn.Linear(hidden_dim, dim), + nn.Dropout(dropout), + ) + + def forward(self, x): + return self.net(x) + + +class Attention(nn.Module): + """Scaled dot-product attention with causal masking""" + + def __init__(self, dim, heads=8, dim_head=64, dropout=0.0): + super().__init__() + inner_dim = dim_head * heads + project_out = not (heads == 1 and dim_head == dim) + self.heads = heads + self.scale = dim_head**-0.5 + self.dropout = dropout + self.norm = nn.LayerNorm(dim) + self.attend = nn.Softmax(dim=-1) + self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False) + self.to_out = ( + nn.Sequential(nn.Linear(inner_dim, dim), nn.Dropout(dropout)) + if project_out + else nn.Identity() + ) + + def forward(self, x, causal=True): + """ + x : (B, T, D) + """ + x = self.norm(x) + drop = self.dropout if self.training else 0.0 + qkv = self.to_qkv(x).chunk(3, dim=-1) # three chunks of (B, T, heads * dim_head) + q, k, v = (rearrange(t, "b t (h d) -> b h t d", h=self.heads) for t in qkv) + out = F.scaled_dot_product_attention(q, k, v, dropout_p=drop, is_causal=causal) + out = 
rearrange(out, "b h t d -> b t (h d)") + return self.to_out(out) + + +class ConditionalBlock(nn.Module): + """Transformer block with AdaLN-zero conditioning""" + + def __init__(self, dim, heads, dim_head, mlp_dim, dropout=0.0): + super().__init__() + + self.attn = Attention(dim, heads=heads, dim_head=dim_head, dropout=dropout) + self.mlp = FeedForward(dim, mlp_dim, dropout=dropout) + self.norm1 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6) + self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6) + self.adaLN_modulation = nn.Sequential( + nn.SiLU(), nn.Linear(dim, 6 * dim, bias=True) + ) + + nn.init.constant_(self.adaLN_modulation[-1].weight, 0) + nn.init.constant_(self.adaLN_modulation[-1].bias, 0) + + def forward(self, x, c): + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ( + self.adaLN_modulation(c).chunk(6, dim=-1) + ) + x = x + gate_msa * self.attn(modulate(self.norm1(x), shift_msa, scale_msa)) + x = x + gate_mlp * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp)) + return x + + +class Block(nn.Module): + """Standard Transformer block""" + + def __init__(self, dim, heads, dim_head, mlp_dim, dropout=0.0): + super().__init__() + + self.attn = Attention(dim, heads=heads, dim_head=dim_head, dropout=dropout) + self.mlp = FeedForward(dim, mlp_dim, dropout=dropout) + self.norm1 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6) + self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6) + + def forward(self, x): + x = x + self.attn(self.norm1(x)) + x = x + self.mlp(self.norm2(x)) + return x + + +class Transformer(nn.Module): + """Standard Transformer with support for AdaLN-zero blocks""" + + def __init__( + self, + input_dim, + hidden_dim, + output_dim, + depth, + heads, + dim_head, + mlp_dim, + dropout=0.0, + block_class=Block, + ): + super().__init__() + self.norm = nn.LayerNorm(hidden_dim) + self.layers = nn.ModuleList([]) + + self.input_proj = ( + nn.Linear(input_dim, hidden_dim) + if input_dim != 
hidden_dim + else nn.Identity() + ) + + self.cond_proj = ( + nn.Linear(input_dim, hidden_dim) + if input_dim != hidden_dim + else nn.Identity() + ) + + self.output_proj = ( + nn.Linear(hidden_dim, output_dim) + if hidden_dim != output_dim + else nn.Identity() + ) + + for _ in range(depth): + self.layers.append( + block_class(hidden_dim, heads, dim_head, mlp_dim, dropout) + ) + + def forward(self, x, c=None): + + if hasattr(self, "input_proj"): + x = self.input_proj(x) + + if c is not None and hasattr(self, "cond_proj"): + c = self.cond_proj(c) + + for block in self.layers: + x = block(x) if isinstance(block, Block) else block(x, c) + x = self.norm(x) + + if hasattr(self, "output_proj"): + x = self.output_proj(x) + return x + +class Embedder(nn.Module): + def __init__( + self, + input_dim=10, + smoothed_dim=10, + emb_dim=10, + mlp_scale=4, + ): + super().__init__() + self.patch_embed = nn.Conv1d(input_dim, smoothed_dim, kernel_size=1, stride=1) + self.embed = nn.Sequential( + nn.Linear(smoothed_dim, mlp_scale * emb_dim), + nn.SiLU(), + nn.Linear(mlp_scale * emb_dim, emb_dim), + ) + + def forward(self, x): + """ + x: (B, T, D) + """ + x = x.float() + x = x.permute(0, 2, 1) + x = self.patch_embed(x) + x = x.permute(0, 2, 1) + x = self.embed(x) + return x + + +class MLP(nn.Module): + """Simple MLP with optional normalization and activation""" + + def __init__( + self, + input_dim, + hidden_dim, + output_dim=None, + norm_fn=nn.LayerNorm, + act_fn=nn.GELU, + ): + super().__init__() + norm_fn = norm_fn(hidden_dim) if norm_fn is not None else nn.Identity() + self.net = nn.Sequential( + nn.Linear(input_dim, hidden_dim), + norm_fn, + act_fn(), + nn.Linear(hidden_dim, output_dim or input_dim), + ) + + def forward(self, x): + """ + x: (B*T, D) + """ + return self.net(x) + + +class ARPredictor(nn.Module): + """Autoregressive predictor for next-step embedding prediction.""" + + def __init__( + self, + *, + num_frames, + depth, + heads, + mlp_dim, + input_dim, + hidden_dim, + 
def lejepa_forward(self, batch, stage, cfg):
    """Run one LeWM step: encode the batch, predict embeddings, build the losses.

    Reads ``cfg.wm.history_size``, ``cfg.wm.num_preds`` and
    ``cfg.loss.sigreg.weight``. ``batch["action"]`` is overwritten in place
    with its ``nan_to_num`` version. The dict returned by
    ``self.model.encode`` is extended with ``pred_loss``, ``sigreg_loss`` and
    ``loss`` and returned; every entry whose key contains ``"loss"`` is also
    logged under ``"{stage}/{key}"``.
    """
    history = cfg.wm.history_size
    horizon = cfg.wm.num_preds
    sigreg_weight = cfg.loss.sigreg.weight

    # NaN actions occur at sequence boundaries; zero them before encoding.
    batch["action"] = torch.nan_to_num(batch["action"], 0.0)

    output = self.model.encode(batch)
    emb = output["emb"]  # (B, T, D)
    act_emb = output["act_emb"]

    # Context: the first `history` steps of both embeddings.
    ctx_emb = emb[:, :history]
    ctx_act = act_emb[:, :history]
    pred_emb = self.model.predict(ctx_emb, ctx_act)

    # Labels: the embedding sequence shifted forward by `horizon` steps.
    tgt_emb = emb[:, horizon:]

    # LeWM loss = squared prediction error + weighted SIGReg on the embeddings.
    residual = pred_emb - tgt_emb
    output["pred_loss"] = residual.pow(2).mean()
    output["sigreg_loss"] = self.sigreg(emb.transpose(0, 1))
    output["loss"] = output["pred_loss"] + sigreg_weight * output["sigreg_loss"]

    logged = {}
    for name, value in output.items():
        if "loss" in name:
            logged[f"{stage}/{name}"] = value.detach()
    self.log_dict(logged, on_step=True, sync_dist=True)
    return output
def get_img_preprocessor(source: str, target: str, img_size: int = 224):
    """Build the image transform: ``ToImage`` followed by ``Resize``.

    ``dt.dataset_stats.ImageNet`` is unpacked as keyword arguments into
    ``ToImage`` (presumably the normalisation mean/std - confirm against
    stable_pretraining). Both transforms read ``source`` and write ``target``.

    Args:
        source: Key the transforms read from.
        target: Key the transforms write to.
        img_size: Size passed to ``Resize``.

    Returns:
        A ``dt.transforms.Compose`` of the two transforms.
    """
    imagenet_stats = dt.dataset_stats.ImageNet
    to_image = dt.transforms.ToImage(**imagenet_stats, source=source, target=target)
    resize = dt.transforms.Resize(img_size, source=source, target=target)
    return dt.transforms.Compose(to_image, resize)


def get_column_normalizer(dataset, source: str, target: str):
    """Build a z-score normaliser for one dataset column.

    Mean and standard deviation are computed once, over the rows of
    ``dataset.get_col_data(source)`` that contain no NaN, and captured by the
    returned transform.

    Args:
        dataset: Object exposing ``get_col_data(name)``; the column is expected
            to be 2-D (rows x features) because NaN rows are detected along
            ``dim=1``.
        source: Column to read statistics from; also the transform's input key.
        target: Key the normalised value is written to.

    Returns:
        A ``dt.transforms.WrapTorchTransform`` applying
        ``((x - mean) / std).float()``.
    """
    col_data = dataset.get_col_data(source)
    data = torch.from_numpy(np.array(col_data))
    # Drop every row that has at least one NaN entry.
    data = data[~torch.isnan(data).any(dim=1)]
    mean = data.mean(0, keepdim=True).clone()
    std = data.std(0, keepdim=True).clone()
    # A constant feature has std == 0 and a single remaining row gives NaN;
    # dividing by either would fill the output with NaN/inf. Use a unit scale
    # for those features so they normalise to (x - mean) instead.
    std[~(std > 0)] = 1.0

    def norm_fn(x):
        return ((x - mean) / std).float()

    normalizer = dt.transforms.WrapTorchTransform(norm_fn, source=source, target=target)
    return normalizer


class ModelObjectCallBack(Callback):
    """Callback that saves the whole model object at the end of training epochs.

    A dump is written on global rank zero when the 1-based epoch number is a
    multiple of ``epoch_interval`` or equals ``trainer.max_epochs``. Files are
    named ``{filename}_epoch_{n}_object.ckpt`` inside ``dirpath``.
    """

    def __init__(self, dirpath, filename="model_object", epoch_interval: int = 1):
        super().__init__()
        self.dirpath = Path(dirpath)
        self.filename = filename
        self.epoch_interval = epoch_interval

    def on_train_epoch_end(self, trainer, pl_module):
        """Dump ``pl_module.model`` if this epoch is an interval or the final one."""
        super().on_train_epoch_end(trainer, pl_module)

        epoch = trainer.current_epoch + 1
        output_path = self.dirpath / f"{self.filename}_epoch_{epoch}_object.ckpt"

        if not trainer.is_global_zero:
            return

        on_interval = epoch % self.epoch_interval == 0
        is_final = epoch == trainer.max_epochs
        # Both conditions target the same file, so save once when either holds
        # (previously the final epoch was written twice if it was also an
        # interval epoch).
        if on_interval or is_final:
            self._dump_model(pl_module.model, output_path)

    def _dump_model(self, model, path):
        """Best-effort ``torch.save`` of the full model object to ``path``.

        Failures are reported on stdout and do not interrupt training.
        """
        # NOTE: torch.save on a module pickles the entire object; only load
        # these files from trusted sources.
        try:
            Path(path).parent.mkdir(parents=True, exist_ok=True)
            torch.save(model, path)
        except Exception as e:
            print(f"Error saving model object: {e}")