From d6475e61338087d0158924c7452adede025b05d2 Mon Sep 17 00:00:00 2001 From: Haiyang Luo <41023868+Hoiyeuhng@users.noreply.github.com> Date: Mon, 23 Mar 2026 23:30:53 -0700 Subject: [PATCH] fix: use proj.device instead of hardcoded cuda, fix README typos MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Replace hardcoded device="cuda" with proj.device in SIGReg for portability (e.g. macOS MPS, CPU) - Fix "Both functions accept" → "This function accepts" (only one function is shown) - Fix "please reference in your paper" → "please reference it in your paper" --- README.md | 4 ++-- module.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index 55b0396..9b1aa9e 100644 --- a/README.md +++ b/README.md @@ -16,7 +16,7 @@

-If you find this code useful, please reference in your paper: +If you find this code useful, please reference it in your paper: ``` @article{maes_lelidec2026lewm, title={LeWorldModel: Stable End-to-End Joint-Embedding Predictive Architecture from Pixels}, @@ -117,7 +117,7 @@ import stable_worldmodel as swm cost = swm.policy.AutoCostModel('pusht/lewm') ``` -Both functions accept: +This function accepts: - `run_name` — checkpoint path **relative to `$STABLEWM_HOME`**, without the `_object.ckpt` suffix - `cache_dir` — optional override for the checkpoint root (defaults to `$STABLEWM_HOME`) diff --git a/module.py b/module.py index 948567c..16c4907 100644 --- a/module.py +++ b/module.py @@ -27,7 +27,7 @@ class SIGReg(torch.nn.Module): proj: (T, B, D) """ # sample random projections - A = torch.randn(proj.size(-1), self.num_proj, device="cuda") + A = torch.randn(proj.size(-1), self.num_proj, device=proj.device) A = A.div_(A.norm(p=2, dim=0)) # compute the epps-pulley statistic x_t = (proj @ A).unsqueeze(-1) * self.t