fix: use proj.device instead of hardcoded cuda, fix README typos
- Replace hardcoded device="cuda" with proj.device in SIGReg for portability (e.g. macOS MPS, CPU) - Fix "Both functions accept" → "This function accepts" (only one function is shown) - Fix "please reference in your paper" → "please reference it in your paper"
This commit is contained in:
@@ -27,7 +27,7 @@ class SIGReg(torch.nn.Module):
|
||||
proj: (T, B, D)
|
||||
"""
|
||||
# sample random projections
|
||||
A = torch.randn(proj.size(-1), self.num_proj, device="cuda")
|
||||
A = torch.randn(proj.size(-1), self.num_proj, device=proj.device)
|
||||
A = A.div_(A.norm(p=2, dim=0))
|
||||
# compute the epps-pulley statistic
|
||||
x_t = (proj @ A).unsqueeze(-1) * self.t
|
||||
|
||||
Reference in New Issue
Block a user