fix: use proj.device instead of hardcoded cuda, fix README typos
- Replace hardcoded device="cuda" with proj.device in SIGReg for portability (e.g. macOS MPS, CPU) - Fix "Both functions accept" → "This function accepts" (only one function is shown) - Fix "please reference in your paper" → "please reference it in your paper"
This commit is contained in:
@@ -16,7 +16,7 @@
|
|||||||
<img src="assets/lewm.gif" width="80%">
|
<img src="assets/lewm.gif" width="80%">
|
||||||
</p>
|
</p>
|
||||||
|
|
||||||
If you find this code useful, please reference in your paper:
|
If you find this code useful, please reference it in your paper:
|
||||||
```
|
```
|
||||||
@article{maes_lelidec2026lewm,
|
@article{maes_lelidec2026lewm,
|
||||||
title={LeWorldModel: Stable End-to-End Joint-Embedding Predictive Architecture from Pixels},
|
title={LeWorldModel: Stable End-to-End Joint-Embedding Predictive Architecture from Pixels},
|
||||||
@@ -117,7 +117,7 @@ import stable_worldmodel as swm
|
|||||||
cost = swm.policy.AutoCostModel('pusht/lewm')
|
cost = swm.policy.AutoCostModel('pusht/lewm')
|
||||||
```
|
```
|
||||||
|
|
||||||
Both functions accept:
|
This function accepts:
|
||||||
- `run_name` — checkpoint path **relative to `$STABLEWM_HOME`**, without the `_object.ckpt` suffix
|
- `run_name` — checkpoint path **relative to `$STABLEWM_HOME`**, without the `_object.ckpt` suffix
|
||||||
- `cache_dir` — optional override for the checkpoint root (defaults to `$STABLEWM_HOME`)
|
- `cache_dir` — optional override for the checkpoint root (defaults to `$STABLEWM_HOME`)
|
||||||
|
|
||||||
|
|||||||
@@ -27,7 +27,7 @@ class SIGReg(torch.nn.Module):
|
|||||||
proj: (T, B, D)
|
proj: (T, B, D)
|
||||||
"""
|
"""
|
||||||
# sample random projections
|
# sample random projections
|
||||||
A = torch.randn(proj.size(-1), self.num_proj, device="cuda")
|
A = torch.randn(proj.size(-1), self.num_proj, device=proj.device)
|
||||||
A = A.div_(A.norm(p=2, dim=0))
|
A = A.div_(A.norm(p=2, dim=0))
|
||||||
# compute the epps-pulley statistic
|
# compute the epps-pulley statistic
|
||||||
x_t = (proj @ A).unsqueeze(-1) * self.t
|
x_t = (proj @ A).unsqueeze(-1) * self.t
|
||||||
|
|||||||
Reference in New Issue
Block a user