diff --git a/README.md b/README.md index 55b0396..9b1aa9e 100644 --- a/README.md +++ b/README.md @@ -16,7 +16,7 @@

-If you find this code useful, please reference in your paper: +If you find this code useful, please reference it in your paper: ``` @article{maes_lelidec2026lewm, title={LeWorldModel: Stable End-to-End Joint-Embedding Predictive Architecture from Pixels}, @@ -117,7 +117,7 @@ import stable_worldmodel as swm cost = swm.policy.AutoCostModel('pusht/lewm') ``` -Both functions accept: +This function accepts: - `run_name` — checkpoint path **relative to `$STABLEWM_HOME`**, without the `_object.ckpt` suffix - `cache_dir` — optional override for the checkpoint root (defaults to `$STABLEWM_HOME`) diff --git a/module.py b/module.py index 948567c..16c4907 100644 --- a/module.py +++ b/module.py @@ -27,7 +27,7 @@ class SIGReg(torch.nn.Module): proj: (T, B, D) """ # sample random projections - A = torch.randn(proj.size(-1), self.num_proj, device="cuda") + A = torch.randn(proj.size(-1), self.num_proj, device=proj.device) A = A.div_(A.norm(p=2, dim=0)) # compute the epps-pulley statistic x_t = (proj @ A).unsqueeze(-1) * self.t