import os
from functools import partial
from pathlib import Path

import hydra
import lightning as pl
import stable_pretraining as spt
import stable_worldmodel as swm
import torch
from lightning.pytorch.loggers import WandbLogger
from omegaconf import OmegaConf, open_dict

from jepa import JEPA
from module import ARPredictor, Embedder, MLP, SIGReg
from utils import get_column_normalizer, get_img_preprocessor, ModelObjectCallBack


def lejepa_forward(self, batch, stage, cfg):
    """Run one forward pass: encode the batch, predict, and build the losses.

    Steps, in order:
      1. NaNs in ``batch["action"]`` are replaced by 0 (written back into
         ``batch``).
      2. ``self.model.encode(batch)`` produces a dict; its ``"emb"`` and
         ``"act_emb"`` entries are read.
      3. The first ``cfg.wm.history_size`` steps along dim 1 of both are
         passed to ``self.model.predict``.
      4. The prediction is compared (mean squared difference) against
         ``emb`` with the first ``cfg.wm.num_preds`` steps along dim 1
         dropped.
      5. ``self.sigreg`` is applied to ``emb`` with dims 0 and 1 swapped.

    Args:
        self: The module this function is bound to; must expose ``model``,
            ``sigreg`` and ``log_dict``.
        batch: Mapping with at least an ``"action"`` tensor; forwarded to
            ``self.model.encode``.
        stage: Prefix used for the logged loss names (``f"{stage}/{name}"``).
        cfg: Config providing ``wm.history_size``, ``wm.num_preds`` and
            ``loss.sigreg.weight``.

    Returns:
        The dict returned by ``encode``, extended with ``"pred_loss"``,
        ``"sigreg_loss"`` and ``"loss"``.
    """
    history = cfg.wm.history_size
    horizon = cfg.wm.num_preds
    sigreg_weight = cfg.loss.sigreg.weight

    # NaN actions (noted upstream as occurring at sequence boundaries) -> 0.
    batch["action"] = torch.nan_to_num(batch["action"], 0.0)

    output = self.model.encode(batch)
    emb = output["emb"]  # (B, T, D) per the original annotation
    act_emb = output["act_emb"]

    context = slice(None, history)
    ctx_emb = emb[:, context]
    ctx_act = act_emb[:, context]
    tgt_emb = emb[:, horizon:]  # regression target
    pred_emb = self.model.predict(ctx_emb, ctx_act)

    # LeWM loss: prediction error plus weighted SIGReg term.
    residual = pred_emb - tgt_emb
    output["pred_loss"] = residual.pow(2).mean()
    output["sigreg_loss"] = self.sigreg(emb.transpose(0, 1))
    output["loss"] = output["pred_loss"] + sigreg_weight * output["sigreg_loss"]

    # Log every entry whose key mentions "loss", detached from the graph.
    stage_losses = {}
    for key, value in output.items():
        if "loss" in key:
            stage_losses[f"{stage}/{key}"] = value.detach()
    self.log_dict(stage_losses, on_step=True, sync_dist=True)

    return output


def _build_dataloaders(cfg):
    """Create the dataset, attach its transforms, and return (train, val) loaders.

    Side effect: for every non-``pixels*`` column in
    ``cfg.data.dataset.keys_to_load`` this sets ``cfg.wm.<col>_dim`` from
    ``dataset.get_dim(col)``.
    """
    dataset = swm.data.HDF5Dataset(**cfg.data.dataset, transform=None)

    pipeline = [
        get_img_preprocessor(source='pixels', target='pixels', img_size=cfg.img_size)
    ]

    # open_dict lifts struct mode so new keys can be added to cfg.wm.
    with open_dict(cfg):
        for column in cfg.data.dataset.keys_to_load:
            if not column.startswith("pixels"):
                pipeline.append(get_column_normalizer(dataset, column, column))
                setattr(cfg.wm, f"{column}_dim", dataset.get_dim(column))

    dataset.transform = spt.data.transforms.Compose(*pipeline)

    # One seeded generator drives both the split and the training shuffle.
    seeded_gen = torch.Generator().manual_seed(cfg.seed)
    train_fraction = cfg.train_split
    train_subset, val_subset = spt.data.random_split(
        dataset,
        lengths=[train_fraction, 1 - train_fraction],
        generator=seeded_gen,
    )

    train_loader = torch.utils.data.DataLoader(
        train_subset,
        **cfg.loader,
        shuffle=True,
        drop_last=True,
        generator=seeded_gen,
    )
    val_loader = torch.utils.data.DataLoader(
        val_subset,
        **cfg.loader,
        shuffle=False,
        drop_last=False,
    )
    return train_loader, val_loader


def _build_jepa(cfg):
    """Assemble the JEPA model (encoder, predictor, action encoder, projectors).

    Sub-modules are constructed in a fixed order: encoder, predictor,
    action encoder, projector, predictor projector.
    """
    encoder = spt.backbone.utils.vit_hf(
        cfg.encoder_scale,
        patch_size=cfg.patch_size,
        image_size=cfg.img_size,
        pretrained=False,
        use_mask_token=False,
    )

    hidden_dim = encoder.config.hidden_size
    embed_dim = cfg.wm.get("embed_dim", hidden_dim)  # falls back to encoder width
    effective_act_dim = cfg.data.dataset.frameskip * cfg.wm.action_dim

    predictor = ARPredictor(
        num_frames=cfg.wm.history_size,
        input_dim=embed_dim,
        hidden_dim=hidden_dim,
        output_dim=hidden_dim,
        **cfg.predictor,
    )
    action_encoder = Embedder(input_dim=effective_act_dim, emb_dim=embed_dim)

    # Both projection heads share the same MLP configuration.
    head_kwargs = dict(
        input_dim=hidden_dim,
        output_dim=embed_dim,
        hidden_dim=2048,
        norm_fn=torch.nn.BatchNorm1d,
    )
    projector = MLP(**head_kwargs)
    predictor_proj = MLP(**head_kwargs)

    return JEPA(
        encoder=encoder,
        predictor=predictor,
        action_encoder=action_encoder,
        projector=projector,
        pred_proj=predictor_proj,
    )


def _build_logger(cfg):
    """Return a WandbLogger with hyperparameters logged, or None if disabled."""
    if not cfg.wandb.enabled:
        return None
    wandb_logger = WandbLogger(**cfg.wandb.config)
    wandb_logger.log_hyperparams(OmegaConf.to_container(cfg))
    return wandb_logger


@hydra.main(version_base=None, config_path="./config/train", config_name="lewm")
def run(cfg):
    """Hydra entry point: build data, model and trainer, then launch training.

    Also writes the config to ``<run_dir>/config.yaml``, where ``run_dir`` is
    the stable_worldmodel cache directory joined with ``cfg.subdir`` (or the
    cache directory itself when ``subdir`` is missing or empty).
    """
    # ---- dataset ----
    train, val = _build_dataloaders(cfg)

    # ---- model / optim ----
    jepa_model = _build_jepa(cfg)

    optimizers = {
        'model_opt': {
            "modules": 'model',
            "optimizer": dict(cfg.optimizer),
            "scheduler": {"type": "LinearWarmupCosineAnnealingLR"},
            "interval": "epoch",
        },
    }

    data_module = spt.data.DataModule(train=train, val=val)
    lightning_module = spt.Module(
        model=jepa_model,
        sigreg=SIGReg(**cfg.loss.sigreg.kwargs),
        forward=partial(lejepa_forward, cfg=cfg),
        optim=optimizers,
    )

    # ---- training ----
    run_id = cfg.get("subdir") or ""
    run_dir = Path(swm.data.utils.get_cache_dir(), run_id)

    logger = _build_logger(cfg)

    # Persist the config next to the run artifacts.
    run_dir.mkdir(parents=True, exist_ok=True)
    with open(run_dir / "config.yaml", "w") as config_file:
        OmegaConf.save(cfg, config_file)

    object_dump_callback = ModelObjectCallBack(
        dirpath=run_dir,
        filename=cfg.output_model_name,
        epoch_interval=1,
    )

    trainer = pl.Trainer(
        **cfg.trainer,
        callbacks=[object_dump_callback],
        num_sanity_val_steps=1,
        logger=logger,
        enable_checkpointing=True,
    )

    manager = spt.Manager(
        trainer=trainer,
        module=lightning_module,
        data=data_module,
        ckpt_path=run_dir / f"{cfg.output_model_name}_weights.ckpt",
    )
    manager()


if __name__ == "__main__":
    run()