Files
lewm/train.py
2026-03-23 14:22:26 +01:00

184 lines
5.3 KiB
Python

import os
from functools import partial
from pathlib import Path
import hydra
import lightning as pl
import stable_pretraining as spt
import stable_worldmodel as swm
import torch
from lightning.pytorch.loggers import WandbLogger
from omegaconf import OmegaConf, open_dict
from jepa import JEPA
from module import ARPredictor, Embedder, MLP, SIGReg
from utils import get_column_normalizer, get_img_preprocessor, ModelObjectCallBack
def lejepa_forward(self, batch, stage, cfg):
    """Encode a batch, predict from the context window, and compute the losses.

    Args:
        self: Module exposing ``model`` (with ``encode`` and ``predict``),
            a callable ``sigreg`` and ``log_dict``.
        batch: Mapping holding at least an ``"action"`` tensor. Its
            ``"action"`` entry is rebound to a NaN-free copy.
        stage: Prefix for the logged loss names (``"{stage}/{name}"``).
        cfg: Config providing ``wm.history_size``, ``wm.num_preds`` and
            ``loss.sigreg.weight``.

    Returns:
        The dict produced by ``self.model.encode(batch)``, extended with
        ``"pred_loss"``, ``"sigreg_loss"`` and ``"loss"``.
    """
    history = cfg.wm.history_size
    horizon = cfg.wm.num_preds
    sigreg_weight = cfg.loss.sigreg.weight

    # NaN actions show up at sequence boundaries; zero them before encoding.
    batch["action"] = torch.nan_to_num(batch["action"], 0.0)

    output = self.model.encode(batch)
    # Original author's note: emb is (B, T, D).
    frame_emb, action_emb = output["emb"], output["act_emb"]

    # Targets are the embeddings from step `horizon` onwards; the predictor
    # only sees the first `history` steps of frames and actions.
    target = frame_emb[:, horizon:]
    predicted = self.model.predict(frame_emb[:, :history], action_emb[:, :history])

    # LeWM objective: mean squared prediction error plus weighted SIGReg term.
    prediction_loss = torch.mean((predicted - target) ** 2)
    regulariser = self.sigreg(frame_emb.transpose(0, 1))
    output["pred_loss"] = prediction_loss
    output["sigreg_loss"] = regulariser
    output["loss"] = prediction_loss + sigreg_weight * regulariser

    # Log every entry whose key mentions "loss", detached from the graph.
    to_log = {}
    for name, value in output.items():
        if "loss" in name:
            to_log[f"{stage}/{name}"] = value.detach()
    self.log_dict(to_log, on_step=True, sync_dist=True)

    return output
@hydra.main(version_base=None, config_path="./config/train", config_name="lewm")
def run(cfg):
    """Assemble data, model and trainer from the Hydra config and start training.

    The function loads the HDF5 dataset, attaches the preprocessing
    transforms, splits it into train/validation loaders, builds the JEPA
    world model with its optimizer spec, writes the config to
    ``config.yaml`` in the run directory and hands everything to
    ``spt.Manager``.

    Args:
        cfg: Hydra/OmegaConf config. Per-column ``<col>_dim`` entries are
            added to ``cfg.wm`` as a side effect.
    """
    # ---------------------------- dataset ----------------------------
    dataset = swm.data.HDF5Dataset(**cfg.data.dataset, transform=None)
    pipeline = [
        get_img_preprocessor(source="pixels", target="pixels", img_size=cfg.img_size)
    ]

    # The loop adds new `<col>_dim` keys to cfg.wm, hence the open_dict scope.
    with open_dict(cfg):
        for column in cfg.data.dataset.keys_to_load:
            # Pixel columns are already covered by the image preprocessor.
            if not column.startswith("pixels"):
                pipeline.append(get_column_normalizer(dataset, column, column))
                setattr(cfg.wm, f"{column}_dim", dataset.get_dim(column))

    dataset.transform = spt.data.transforms.Compose(*pipeline)

    # One seeded generator drives both the split and the train shuffling.
    split_rng = torch.Generator().manual_seed(cfg.seed)
    train_set, val_set = spt.data.random_split(
        dataset,
        lengths=[cfg.train_split, 1 - cfg.train_split],
        generator=split_rng,
    )

    make_loader = torch.utils.data.DataLoader
    train_loader = make_loader(
        train_set,
        **cfg.loader,
        shuffle=True,
        drop_last=True,
        generator=split_rng,
    )
    val_loader = make_loader(
        val_set,
        **cfg.loader,
        shuffle=False,
        drop_last=False,
    )

    # ------------------------- model / optim -------------------------
    encoder = spt.backbone.utils.vit_hf(
        cfg.encoder_scale,
        patch_size=cfg.patch_size,
        image_size=cfg.img_size,
        pretrained=False,
        use_mask_token=False,
    )

    hidden_dim = encoder.config.hidden_size
    # The embedding width falls back to the encoder width when unset.
    embed_dim = cfg.wm.get("embed_dim", hidden_dim)
    action_input_dim = cfg.data.dataset.frameskip * cfg.wm.action_dim

    predictor = ARPredictor(
        num_frames=cfg.wm.history_size,
        input_dim=embed_dim,
        hidden_dim=hidden_dim,
        output_dim=hidden_dim,
        **cfg.predictor,
    )
    action_encoder = Embedder(input_dim=action_input_dim, emb_dim=embed_dim)

    # Both projection heads use the same MLP configuration.
    head_kwargs = dict(
        input_dim=hidden_dim,
        output_dim=embed_dim,
        hidden_dim=2048,
        norm_fn=torch.nn.BatchNorm1d,
    )
    projector = MLP(**head_kwargs)
    predictor_proj = MLP(**head_kwargs)

    jepa = JEPA(
        encoder=encoder,
        predictor=predictor,
        action_encoder=action_encoder,
        projector=projector,
        pred_proj=predictor_proj,
    )

    optim_spec = {
        "model_opt": {
            "modules": "model",
            "optimizer": dict(cfg.optimizer),
            "scheduler": {"type": "LinearWarmupCosineAnnealingLR"},
            "interval": "epoch",
        },
    }

    data_module = spt.data.DataModule(train=train_loader, val=val_loader)
    lightning_module = spt.Module(
        model=jepa,
        sigreg=SIGReg(**cfg.loss.sigreg.kwargs),
        forward=partial(lejepa_forward, cfg=cfg),
        optim=optim_spec,
    )

    # ---------------------------- training ---------------------------
    # An unset or empty `subdir` places the run directly in the cache dir.
    run_dir = Path(swm.data.utils.get_cache_dir(), cfg.get("subdir") or "")

    wandb_logger = WandbLogger(**cfg.wandb.config) if cfg.wandb.enabled else None
    if wandb_logger is not None:
        wandb_logger.log_hyperparams(OmegaConf.to_container(cfg))

    # Write the config, including the added dims, into the run directory.
    run_dir.mkdir(parents=True, exist_ok=True)
    with open(run_dir / "config.yaml", "w") as config_file:
        OmegaConf.save(cfg, config_file)

    dump_callback = ModelObjectCallBack(
        dirpath=run_dir,
        filename=cfg.output_model_name,
        epoch_interval=1,
    )

    trainer = pl.Trainer(
        **cfg.trainer,
        callbacks=[dump_callback],
        num_sanity_val_steps=1,
        logger=wandb_logger,
        enable_checkpointing=True,
    )

    launch = spt.Manager(
        trainer=trainer,
        module=lightning_module,
        data=data_module,
        ckpt_path=run_dir / f"{cfg.output_model_name}_weights.ckpt",
    )
    launch()
# Invoke the Hydra-decorated `run` only when this file is executed as a script.
if __name__ == "__main__":
    run()