Files
lewm/module.py
Haiyang Luo d6475e6133 fix: use proj.device instead of hardcoded cuda, fix README typos
- Replace hardcoded device="cuda" with proj.device in SIGReg for
  portability (e.g. macOS MPS, CPU)
- Fix "Both functions accept" → "This function accepts" (only one
  function is shown)
- Fix "please reference in your paper" → "please reference it in
  your paper"
2026-03-23 23:30:53 -07:00

286 lines
8.2 KiB
Python

import torch
from torch import nn
import torch.nn.functional as F
from einops import rearrange
def modulate(x, shift, scale):
    """Apply AdaLN-zero modulation: ``x * (1 + scale) + shift``.

    ``scale`` is an offset around 1, so zero ``shift`` and zero ``scale``
    return ``x`` unchanged. All arguments broadcast elementwise.
    """
    gain = 1 + scale
    return x * gain + shift
class SIGReg(torch.nn.Module):
    """Sketch Isotropic Gaussian Regularizer (single-GPU!).

    Projects samples onto random unit directions and, per direction, measures
    the squared distance between the empirical characteristic function of the
    projections and exp(-t^2 / 2) (the characteristic function of N(0, 1)),
    integrated over a quadrature grid. Only local tensors are used; nothing is
    reduced across processes.

    Args:
        knots: Number of quadrature points on [0, 3]. Must be >= 2.
        num_proj: Number of random projection directions drawn per call.
    """

    def __init__(self, knots=17, num_proj=1024):
        super().__init__()
        if knots < 2:
            raise ValueError(f"SIGReg needs knots >= 2, got {knots}")
        self.num_proj = num_proj
        # Quadrature grid on [0, 3].
        t = torch.linspace(0, 3, knots, dtype=torch.float32)
        dt = 3 / (knots - 1)
        # Trapezoid weights on [0, 3], doubled: the integrand in forward() is
        # even in t, so this equals the trapezoid rule on [-3, 3].
        weights = torch.full((knots,), 2 * dt, dtype=torch.float32)
        weights[[0, -1]] = dt
        # exp(-t^2 / 2): target characteristic function, also reused as the
        # weighting window of the integral.
        window = torch.exp(-t.square() / 2.0)
        self.register_buffer("t", t)
        self.register_buffer("phi", window)
        self.register_buffer("weights", weights * window)

    def forward(self, proj):
        """Return the scalar SIGReg statistic.

        Args:
            proj: (T, B, D) tensor. The empirical characteristic function is
                averaged over the B axis (dim -2) for each leading index.

        Returns:
            0-dim tensor: statistic averaged over projections and T.
        """
        # Random unit-norm directions, one per column. They are drawn in the
        # default dtype (same RNG draw as before), then cast to proj's dtype
        # so the matmul also works for non-float32 inputs without autocast.
        A = torch.randn(proj.size(-1), self.num_proj, device=proj.device)
        A = A.div_(A.norm(p=2, dim=0)).to(proj.dtype)
        # (T, B, num_proj, knots): each projected sample times each knot t.
        x_t = (proj @ A).unsqueeze(-1) * self.t
        # |ECF(t) - phi(t)|^2: real part vs phi, imaginary part vs 0.
        err = (x_t.cos().mean(-3) - self.phi).square() + x_t.sin().mean(-3).square()
        # Integrate over t (weights must share err's dtype for matmul), then
        # scale by the number of samples B.
        weights = self.weights.to(err.dtype)
        statistic = (err @ weights) * proj.size(-2)
        return statistic.mean()  # average over projections and time
class FeedForward(nn.Module):
    """Pre-norm position-wise MLP used inside the Transformer blocks.

    LayerNorm(dim) -> Linear(dim, hidden_dim) -> GELU -> Dropout
    -> Linear(hidden_dim, dim) -> Dropout. The output has the same last
    dimension as the input.
    """

    def __init__(self, dim, hidden_dim, dropout=0.0):
        super().__init__()
        layers = [
            nn.LayerNorm(dim),
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout),
        ]
        # A single Sequential named ``net`` keeps parameter names
        # (net.0.*, net.1.*, net.4.*) checkpoint-compatible.
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the MLP along the last dimension of ``x``."""
        return self.net(x)
class Attention(nn.Module):
    """Multi-head self-attention with its own pre-LayerNorm and a causal mask by default.

    Args:
        dim: Token feature size (last dim of input and output).
        heads: Number of attention heads.
        dim_head: Per-head width; ``heads * dim_head`` is the inner width.
        dropout: Attention-weight dropout probability (applied only in
            training mode); also used by the output Dropout when an output
            projection exists.
    """

    def __init__(self, dim, heads=8, dim_head=64, dropout=0.0):
        super().__init__()
        inner_dim = heads * dim_head
        # The output projection is skipped only for a single head whose width
        # already equals ``dim``.
        needs_out_proj = heads != 1 or dim_head != dim
        self.heads = heads
        # Not passed to SDPA; SDPA's default scaling is 1/sqrt(dim_head),
        # which equals this value. Kept as an attribute for compatibility.
        self.scale = dim_head**-0.5
        self.dropout = dropout
        self.norm = nn.LayerNorm(dim)
        # Unused by forward (SDPA does its own softmax); kept for compatibility.
        self.attend = nn.Softmax(dim=-1)
        self.to_qkv = nn.Linear(dim, 3 * inner_dim, bias=False)
        if needs_out_proj:
            self.to_out = nn.Sequential(nn.Linear(inner_dim, dim), nn.Dropout(dropout))
        else:
            self.to_out = nn.Identity()

    def forward(self, x, causal=True):
        """Attend over the sequence axis.

        Args:
            x: (B, T, D) tokens.
            causal: If True, position i attends only to positions <= i.

        Returns:
            (B, T, D) tensor.
        """
        normed = self.norm(x)
        p = self.dropout if self.training else 0.0
        # to_qkv gives (B, T, 3 * heads * dim_head); split it into q, k, v and
        # reshape each to (B, heads, T, dim_head).
        q, k, v = [
            rearrange(part, "b t (h d) -> b h t d", h=self.heads)
            for part in self.to_qkv(normed).chunk(3, dim=-1)
        ]
        attended = F.scaled_dot_product_attention(q, k, v, dropout_p=p, is_causal=causal)
        merged = rearrange(attended, "b h t d -> b t (h d)")
        return self.to_out(merged)
class ConditionalBlock(nn.Module):
    """Pre-norm Transformer block conditioned through AdaLN-zero.

    ``adaLN_modulation(c)`` produces six chunks: shift/scale/gate for the
    attention branch and shift/scale/gate for the MLP branch. Its final
    Linear is zero-initialised, so at construction every gate is zero and
    each residual branch is multiplied by 0.
    """

    def __init__(self, dim, heads, dim_head, mlp_dim, dropout=0.0):
        super().__init__()
        self.attn = Attention(dim, heads=heads, dim_head=dim_head, dropout=dropout)
        self.mlp = FeedForward(dim, mlp_dim, dropout=dropout)
        # Parameter-free LayerNorms; the affine part comes from modulation.
        self.norm1 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
        self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(), nn.Linear(dim, 6 * dim, bias=True)
        )
        head = self.adaLN_modulation[-1]
        nn.init.constant_(head.weight, 0.0)
        nn.init.constant_(head.bias, 0.0)

    def forward(self, x, c):
        """Apply the block.

        Args:
            x: Token tensor whose last dim is ``dim``.
            c: Conditioning tensor whose last dim is ``dim``; the six
                modulation chunks must broadcast against ``x``.
        """
        mods = self.adaLN_modulation(c).chunk(6, dim=-1)
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = mods
        # NOTE: Attention and FeedForward each start with their own LayerNorm,
        # so the modulated input is normalised again inside those modules.
        h = modulate(self.norm1(x), shift_msa, scale_msa)
        x = x + gate_msa * self.attn(h)
        h = modulate(self.norm2(x), shift_mlp, scale_mlp)
        x = x + gate_mlp * self.mlp(h)
        return x
class Block(nn.Module):
    """Unconditioned pre-norm Transformer block.

    Computes ``x + attn(norm1(x))`` and then adds ``mlp(norm2(...))`` as a
    second residual. Attention and FeedForward additionally apply their own
    LayerNorm internally.
    """

    def __init__(self, dim, heads, dim_head, mlp_dim, dropout=0.0):
        super().__init__()
        self.attn = Attention(dim, heads=heads, dim_head=dim_head, dropout=dropout)
        self.mlp = FeedForward(dim, mlp_dim, dropout=dropout)
        # Parameter-free LayerNorms (elementwise_affine=False).
        self.norm1 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
        self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)

    def forward(self, x):
        attn_out = self.attn(self.norm1(x))
        x = x + attn_out
        mlp_out = self.mlp(self.norm2(x))
        return x + mlp_out
class Transformer(nn.Module):
    """Standard Transformer with support for AdaLN-zero blocks.

    Pipeline: ``input_proj`` -> ``depth`` blocks -> final LayerNorm ->
    ``output_proj``. Each projection is an ``nn.Linear`` when the two sizes
    differ and ``nn.Identity`` otherwise. Blocks that are instances of
    ``Block`` are called as ``block(x)``. Any other block type is called as
    ``block(x, c)``, after ``c`` has been passed through ``cond_proj``.

    Args:
        input_dim: Feature size of ``x``, and of ``c`` (``cond_proj`` maps
            ``input_dim`` -> ``hidden_dim``).
        hidden_dim: Width used inside the blocks.
        output_dim: Feature size of the returned tensor.
        depth: Number of blocks.
        heads, dim_head, mlp_dim, dropout: Forwarded positionally to
            ``block_class``.
        block_class: Block constructor, e.g. ``Block`` or ``ConditionalBlock``.
    """

    def __init__(
        self,
        input_dim,
        hidden_dim,
        output_dim,
        depth,
        heads,
        dim_head,
        mlp_dim,
        dropout=0.0,
        block_class=Block,
    ):
        super().__init__()

        def _proj(d_in, d_out):
            # Identity when sizes already match, so no parameters are added.
            return nn.Linear(d_in, d_out) if d_in != d_out else nn.Identity()

        # Creation order matches the original so parameter init is unchanged.
        self.norm = nn.LayerNorm(hidden_dim)
        self.layers = nn.ModuleList([])
        self.input_proj = _proj(input_dim, hidden_dim)
        self.cond_proj = _proj(input_dim, hidden_dim)
        self.output_proj = _proj(hidden_dim, output_dim)
        for _ in range(depth):
            self.layers.append(
                block_class(hidden_dim, heads, dim_head, mlp_dim, dropout)
            )

    def forward(self, x, c=None):
        """Run the stack.

        Args:
            x: Input tokens with last dim ``input_dim``.
            c: Optional conditioning with last dim ``input_dim``; required
                by conditional blocks.

        Raises:
            ValueError: if a ``ConditionalBlock`` is present and ``c`` is None.
        """
        x = self.input_proj(x)
        if c is not None:
            c = self.cond_proj(c)
        for block in self.layers:
            if isinstance(block, Block):
                x = block(x)
                continue
            if c is None and isinstance(block, ConditionalBlock):
                # Fail clearly rather than with an opaque error from
                # adaLN_modulation(None).
                raise ValueError(
                    f"{type(block).__name__} needs a conditioning tensor, but c is None"
                )
            x = block(x, c)
        x = self.norm(x)
        return self.output_proj(x)
class Embedder(nn.Module):
    """Embed a (B, T, input_dim) sequence into (B, T, emb_dim).

    A kernel-size-1 Conv1d (a per-timestep linear map input_dim ->
    smoothed_dim) is followed by a SiLU MLP:
    smoothed_dim -> mlp_scale * emb_dim -> emb_dim.
    """

    def __init__(
        self,
        input_dim=10,
        smoothed_dim=10,
        emb_dim=10,
        mlp_scale=4,
    ):
        super().__init__()
        self.patch_embed = nn.Conv1d(input_dim, smoothed_dim, kernel_size=1, stride=1)
        self.embed = nn.Sequential(
            nn.Linear(smoothed_dim, mlp_scale * emb_dim),
            nn.SiLU(),
            nn.Linear(mlp_scale * emb_dim, emb_dim),
        )

    def forward(self, x):
        """
        Args:
            x: (B, T, input_dim), any numeric dtype.

        Returns:
            (B, T, emb_dim) in the dtype of the module's parameters.
        """
        # Cast to the parameter dtype instead of hard-coding float32, so the
        # module also works after .double() / .half() / .bfloat16().
        # For float32 weights this is identical to x.float().
        x = x.to(self.patch_embed.weight.dtype)
        # Conv1d expects channels first: (B, input_dim, T).
        x = x.permute(0, 2, 1)
        x = self.patch_embed(x)
        x = x.permute(0, 2, 1)
        return self.embed(x)
class MLP(nn.Module):
    """Two-layer MLP with optional normalization and activation.

    Linear(input_dim, hidden_dim) -> norm -> activation
    -> Linear(hidden_dim, output_dim).

    Args:
        input_dim: Input feature size.
        hidden_dim: Hidden feature size.
        output_dim: Output feature size; ``input_dim`` is used when this is
            None (or otherwise falsy).
        norm_fn: Callable ``norm_fn(hidden_dim) -> nn.Module``, or None for
            no normalization.
        act_fn: Zero-argument callable returning an activation module, or
            None for no activation.
    """

    def __init__(
        self,
        input_dim,
        hidden_dim,
        output_dim=None,
        norm_fn=nn.LayerNorm,
        act_fn=nn.GELU,
    ):
        super().__init__()
        norm = norm_fn(hidden_dim) if norm_fn is not None else nn.Identity()
        # Modules are built in the original order (norm, Linear, act, Linear),
        # and Identity placeholders keep the state_dict indices (net.0, net.3) stable.
        self.net = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            norm,
            act_fn() if act_fn is not None else nn.Identity(),
            nn.Linear(hidden_dim, output_dim or input_dim),
        )

    def forward(self, x):
        """
        x: (B*T, input_dim). nn.Linear accepts extra leading dims, but a
        custom ``norm_fn`` may not.
        """
        return self.net(x)
class ARPredictor(nn.Module):
    """Autoregressive predictor for next-step embedding prediction.

    Adds a learned positional embedding to the input sequence, applies
    embedding dropout, and runs a Transformer built from ``ConditionalBlock``
    layers (AdaLN-zero) conditioned on ``c``.

    Args:
        num_frames: Size of the learned positional table, i.e. the maximum
            sequence length T accepted by ``forward``.
        depth, heads, mlp_dim, dim_head, dropout: Transformer configuration.
        input_dim: Feature size of ``x`` (and of ``c``).
        hidden_dim: Transformer width.
        output_dim: Output feature size; ``input_dim`` when None/falsy.
        emb_dropout: Dropout applied after adding positional embeddings.
    """

    def __init__(
        self,
        *,
        num_frames,
        depth,
        heads,
        mlp_dim,
        input_dim,
        hidden_dim,
        output_dim=None,
        dim_head=64,
        dropout=0.0,
        emb_dropout=0.0,
    ):
        super().__init__()
        self.pos_embedding = nn.Parameter(torch.randn(1, num_frames, input_dim))
        self.dropout = nn.Dropout(emb_dropout)
        self.transformer = Transformer(
            input_dim,
            hidden_dim,
            output_dim or input_dim,
            depth,
            heads,
            dim_head,
            mlp_dim,
            dropout,
            block_class=ConditionalBlock,
        )

    def forward(self, x, c):
        """
        Args:
            x: (B, T, input_dim) embeddings, with T <= num_frames.
            c: (B, T, C) conditioning (e.g. embedded actions). The
                Transformer's ``cond_proj`` maps input_dim -> hidden_dim, so C
                must equal ``input_dim``.

        Returns:
            (B, T, output_dim) predictions.

        Raises:
            ValueError: if T exceeds ``num_frames``.
        """
        T = x.size(1)
        max_frames = self.pos_embedding.size(1)
        if T > max_frames:
            # Slicing would silently truncate the table and then fail with an
            # opaque broadcast error; report the real cause instead.
            raise ValueError(
                f"sequence length {T} exceeds num_frames={max_frames}"
            )
        x = x + self.pos_embedding[:, :T]
        x = self.dropout(x)
        return self.transformer(x, c)