Merge pull request #1 from Hoiyeuhng/fix/device-agnostic-and-readme

fix: use proj.device instead of hardcoded cuda, fix README typos
This commit is contained in:
Lucas Maes
2026-03-24 14:40:26 +01:00
committed by GitHub
2 changed files with 3 additions and 3 deletions

View File

@@ -16,7 +16,7 @@
<img src="assets/lewm.gif" width="80%"> <img src="assets/lewm.gif" width="80%">
</p> </p>
If you find this code useful, please reference in your paper: If you find this code useful, please reference it in your paper:
``` ```
@article{maes_lelidec2026lewm, @article{maes_lelidec2026lewm,
title={LeWorldModel: Stable End-to-End Joint-Embedding Predictive Architecture from Pixels}, title={LeWorldModel: Stable End-to-End Joint-Embedding Predictive Architecture from Pixels},
@@ -117,7 +117,7 @@ import stable_worldmodel as swm
cost = swm.policy.AutoCostModel('pusht/lewm') cost = swm.policy.AutoCostModel('pusht/lewm')
``` ```
Both functions accept: This function accepts:
- `run_name` — checkpoint path **relative to `$STABLEWM_HOME`**, without the `_object.ckpt` suffix - `run_name` — checkpoint path **relative to `$STABLEWM_HOME`**, without the `_object.ckpt` suffix
- `cache_dir` — optional override for the checkpoint root (defaults to `$STABLEWM_HOME`) - `cache_dir` — optional override for the checkpoint root (defaults to `$STABLEWM_HOME`)

View File

@@ -27,7 +27,7 @@ class SIGReg(torch.nn.Module):
proj: (T, B, D) proj: (T, B, D)
""" """
# sample random projections # sample random projections
A = torch.randn(proj.size(-1), self.num_proj, device="cuda") A = torch.randn(proj.size(-1), self.num_proj, device=proj.device)
A = A.div_(A.norm(p=2, dim=0)) A = A.div_(A.norm(p=2, dim=0))
# compute the epps-pulley statistic # compute the epps-pulley statistic
x_t = (proj @ A).unsqueeze(-1) * self.t x_t = (proj @ A).unsqueeze(-1) * self.t