184 lines
5.3 KiB
Python
184 lines
5.3 KiB
Python
import os
|
|
from functools import partial
|
|
from pathlib import Path
|
|
|
|
import hydra
|
|
import lightning as pl
|
|
import stable_pretraining as spt
|
|
import stable_worldmodel as swm
|
|
import torch
|
|
from lightning.pytorch.loggers import WandbLogger
|
|
from omegaconf import OmegaConf, open_dict
|
|
|
|
from jepa import JEPA
|
|
from module import ARPredictor, Embedder, MLP, SIGReg
|
|
from utils import get_column_normalizer, get_img_preprocessor, ModelObjectCallBack
|
|
|
|
|
|
def lejepa_forward(self, batch, stage, cfg):
    """Run one LeWM step: encode the batch, predict ahead, and score it.

    Args:
        self: object this function is bound to as ``forward``; it must expose
            ``model`` (with ``encode`` and ``predict``), ``sigreg`` and
            ``log_dict``.
        batch: mapping holding at least an ``"action"`` tensor. That entry is
            overwritten in place with a NaN-free copy before encoding.
        stage: prefix for the logged loss names (``"{stage}/<name>"``).
        cfg: config providing ``wm.history_size``, ``wm.num_preds`` and
            ``loss.sigreg.weight``.

    Returns:
        The mapping returned by ``self.model.encode`` with ``"pred_loss"``,
        ``"sigreg_loss"`` and ``"loss"`` added.
    """
    history = cfg.wm.history_size
    horizon = cfg.wm.num_preds
    sigreg_weight = cfg.loss.sigreg.weight

    # NaN actions occur at sequence boundaries; zero them before encoding.
    batch["action"] = torch.nan_to_num(batch["action"], 0.0)

    output = self.model.encode(batch)
    embeddings = output["emb"]  # (B, T, D)
    action_embeddings = output["act_emb"]

    # The predictor only sees the first `history` steps of both streams.
    context = embeddings[:, :history]
    context_actions = action_embeddings[:, :history]

    # Targets are the same embeddings, starting `horizon` steps later.
    target = embeddings[:, horizon:]
    predicted = self.model.predict(context, context_actions)

    # LeWM loss: squared prediction error plus the weighted SIGReg term.
    prediction_error = (predicted - target).pow(2).mean()
    output["pred_loss"] = prediction_error
    regularizer = self.sigreg(embeddings.transpose(0, 1))
    output["sigreg_loss"] = regularizer
    output["loss"] = prediction_error + sigreg_weight * regularizer

    # Log every entry whose key mentions "loss", detached from the graph.
    to_log = {}
    for key, value in output.items():
        if "loss" in key:
            to_log[f"{stage}/{key}"] = value.detach()
    self.log_dict(to_log, on_step=True, sync_dist=True)

    return output
|
|
|
|
@hydra.main(version_base=None, config_path="./config/train", config_name="lewm")
def run(cfg):
    """Train the world model described by the Hydra config ``cfg``.

    In order: build the HDF5 dataset and attach its transforms, split it into
    train/val loaders, assemble the JEPA model with its optimizer spec, write
    the config to ``config.yaml`` in the run directory, and launch training
    through ``spt.Manager``.

    Args:
        cfg: Hydra/OmegaConf config. ``cfg.wm`` gains a ``<col>_dim`` entry
            for every non-pixel column in ``cfg.data.dataset.keys_to_load``.
    """
    # ------------------------------------------------------------------
    # dataset
    # ------------------------------------------------------------------
    dataset = swm.data.HDF5Dataset(**cfg.data.dataset, transform=None)

    # Image preprocessing goes first; per-column normalizers are appended.
    pipeline = [
        get_img_preprocessor(source='pixels', target='pixels', img_size=cfg.img_size)
    ]

    with open_dict(cfg):
        for column in cfg.data.dataset.keys_to_load:
            if not column.startswith("pixels"):
                pipeline.append(get_column_normalizer(dataset, column, column))
                # Store the column's dimension on the config; presumably this
                # is what populates ``cfg.wm.action_dim`` read further down.
                setattr(cfg.wm, f"{column}_dim", dataset.get_dim(column))

    dataset.transform = spt.data.transforms.Compose(*pipeline)

    # One seeded generator drives both the split and the training shuffle.
    split_rng = torch.Generator().manual_seed(cfg.seed)
    train_fraction = cfg.train_split
    train_set, val_set = spt.data.random_split(
        dataset,
        lengths=[train_fraction, 1 - train_fraction],
        generator=split_rng,
    )

    train_loader = torch.utils.data.DataLoader(
        train_set, **cfg.loader, shuffle=True, drop_last=True, generator=split_rng
    )
    val_loader = torch.utils.data.DataLoader(
        val_set, **cfg.loader, shuffle=False, drop_last=False
    )

    # ------------------------------------------------------------------
    # model / optim
    # ------------------------------------------------------------------
    vit = spt.backbone.utils.vit_hf(
        cfg.encoder_scale,
        patch_size=cfg.patch_size,
        image_size=cfg.img_size,
        pretrained=False,
        use_mask_token=False,
    )

    encoder_width = vit.config.hidden_size
    # Fall back to the encoder width when no embedding size is configured.
    latent_dim = cfg.wm.get("embed_dim", encoder_width)
    stacked_action_dim = cfg.data.dataset.frameskip * cfg.wm.action_dim

    ar_predictor = ARPredictor(
        num_frames=cfg.wm.history_size,
        input_dim=latent_dim,
        hidden_dim=encoder_width,
        output_dim=encoder_width,
        **cfg.predictor,
    )

    action_embedder = Embedder(input_dim=stacked_action_dim, emb_dim=latent_dim)

    # Both projection heads share the same architecture arguments.
    head_kwargs = dict(
        input_dim=encoder_width,
        output_dim=latent_dim,
        hidden_dim=2048,
        norm_fn=torch.nn.BatchNorm1d,
    )
    encoder_head = MLP(**head_kwargs)
    predictor_head = MLP(**head_kwargs)

    jepa = JEPA(
        encoder=vit,
        predictor=ar_predictor,
        action_encoder=action_embedder,
        projector=encoder_head,
        pred_proj=predictor_head,
    )

    model_opt = {
        "modules": 'model',
        "optimizer": dict(cfg.optimizer),
        "scheduler": {"type": "LinearWarmupCosineAnnealingLR"},
        "interval": "epoch",
    }
    optim_spec = {'model_opt': model_opt}

    data_module = spt.data.DataModule(train=train_loader, val=val_loader)
    lit_module = spt.Module(
        model=jepa,
        sigreg=SIGReg(**cfg.loss.sigreg.kwargs),
        forward=partial(lejepa_forward, cfg=cfg),
        optim=optim_spec,
    )

    # ------------------------------------------------------------------
    # training
    # ------------------------------------------------------------------
    # A missing or empty ``subdir`` puts the run directly in the cache dir.
    subdir = cfg.get("subdir") or ""
    run_dir = Path(swm.data.utils.get_cache_dir(), subdir)

    if cfg.wandb.enabled:
        logger = WandbLogger(**cfg.wandb.config)
        logger.log_hyperparams(OmegaConf.to_container(cfg))
    else:
        logger = None

    # Save the config next to the run outputs.
    run_dir.mkdir(parents=True, exist_ok=True)
    with open(run_dir / "config.yaml", "w") as config_file:
        OmegaConf.save(cfg, config_file)

    dump_callback = ModelObjectCallBack(
        dirpath=run_dir,
        filename=cfg.output_model_name,
        epoch_interval=1,
    )

    trainer = pl.Trainer(
        **cfg.trainer,
        callbacks=[dump_callback],
        num_sanity_val_steps=1,
        logger=logger,
        enable_checkpointing=True,
    )

    weights_path = run_dir / f"{cfg.output_model_name}_weights.ckpt"
    manager = spt.Manager(
        trainer=trainer,
        module=lit_module,
        data=data_module,
        ckpt_path=weights_path,
    )

    manager()
|
|
|
|
|
|
# Script entry point: only start training when executed directly, not on import.
if __name__ == "__main__":
    run()
|