fix: use proj.device instead of hardcoded cuda, fix README typos

- Replace hardcoded device="cuda" with proj.device in SIGReg for
  portability (e.g. macOS MPS, CPU)
- Fix "Both functions accept" → "This function accepts" (only one
  function is shown)
- Fix "please reference in your paper" → "please reference it in
  your paper"
This commit is contained in:
Haiyang Luo
2026-03-23 23:30:53 -07:00
parent 83f97d72ad
commit d6475e6133
2 changed files with 3 additions and 3 deletions

View File

@@ -27,7 +27,7 @@ class SIGReg(torch.nn.Module):
proj: (T, B, D)
"""
# sample random projections
A = torch.randn(proj.size(-1), self.num_proj, device="cuda")
A = torch.randn(proj.size(-1), self.num_proj, device=proj.device)
A = A.div_(A.norm(p=2, dim=0))
# compute the epps-pulley statistic
x_t = (proj @ A).unsqueeze(-1) * self.t