286 lines
8.2 KiB
Python
286 lines
8.2 KiB
Python
import torch
|
|
from torch import nn
|
|
import torch.nn.functional as F
|
|
from einops import rearrange
|
|
|
|
def modulate(x, shift, scale):
    """AdaLN-zero modulation: scale ``x`` by ``1 + scale``, then add ``shift``.

    ``shift`` and ``scale`` must broadcast against ``x``.
    """
    scaled = x * (scale + 1)
    return scaled + shift
|
|
|
|
class SIGReg(torch.nn.Module):
    """Sketch Isotropic Gaussian Regularizer (single-GPU!).

    Projects the inputs onto ``num_proj`` random unit directions. For each
    direction, it compares the empirical characteristic function of the
    projected samples with exp(-t^2/2), the characteristic function of N(0, 1).
    The squared difference is integrated with an exp(-t^2/2) window on a
    ``knots``-point grid over t in [0, 3] (an Epps-Pulley style statistic).

    Args:
        knots: number of quadrature points on [0, 3] (must be >= 2).
        num_proj: number of random projection directions drawn per call.
    """

    def __init__(self, knots=17, num_proj=1024):
        super().__init__()
        self.num_proj = num_proj
        t = torch.linspace(0, 3, knots, dtype=torch.float32)
        dt = 3 / (knots - 1)
        # Trapezoid weights on [0, 3], doubled because the integrand is even in t
        # (this stands in for integrating over [-3, 3]).
        weights = torch.full((knots,), 2 * dt, dtype=torch.float32)
        weights[[0, -1]] = dt
        # exp(-t^2/2): the N(0, 1) characteristic function, also used as window.
        window = torch.exp(-t.square() / 2.0)
        self.register_buffer("t", t)
        self.register_buffer("phi", window)
        self.register_buffer("weights", weights * window)

    def forward(self, proj):
        """
        proj: (T, B, D); the empirical characteristic function averages over B.

        Returns a scalar: the statistic averaged over projections and time.
        """
        # Sample random unit-norm directions on proj's own device. The previous
        # hard-coded device="cuda" failed on CPU and mismatched non-default GPUs.
        # Sample in float32, then match proj's dtype so the matmul also works
        # outside autocast (float64 inputs, or bf16 without autocast).
        A = torch.randn(proj.size(-1), self.num_proj, device=proj.device)
        A = (A / A.norm(p=2, dim=0)).to(proj.dtype)
        # compute the epps-pulley statistic; x_t: (T, B, num_proj, knots)
        x_t = (proj @ A).unsqueeze(-1) * self.t
        err = (x_t.cos().mean(-3) - self.phi).square() + x_t.sin().mean(-3).square()
        # Quadrature over t, scaled by the sample count B. matmul does no dtype
        # promotion, so align the float32 weights with err.
        statistic = (err @ self.weights.to(err.dtype)) * proj.size(-2)
        return statistic.mean()  # average over projections and time
|
|
|
|
class FeedForward(nn.Module):
    """Pre-norm position-wise MLP used in the Transformer blocks.

    Layout: LayerNorm -> Linear(dim, hidden_dim) -> GELU -> Dropout
    -> Linear(hidden_dim, dim) -> Dropout.
    """

    def __init__(self, dim, hidden_dim, dropout=0.0):
        super().__init__()
        # Indices in this list become the state_dict keys (net.0, net.1, net.4).
        layers = [
            nn.LayerNorm(dim),
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """x: (..., dim) -> (..., dim)."""
        return self.net(x)
|
|
|
|
|
|
class Attention(nn.Module):
    """Multi-head scaled dot-product attention (causal by default).

    The input is layer-normalized and projected to queries/keys/values, then
    attended with ``F.scaled_dot_product_attention``. The result is projected
    back to ``dim``; that output projection is an Identity when there is a
    single head whose width already equals ``dim``.
    """

    def __init__(self, dim, heads=8, dim_head=64, dropout=0.0):
        super().__init__()
        inner_dim = heads * dim_head
        needs_out_proj = heads != 1 or dim_head != dim

        self.heads = heads
        # Not passed to SDPA below; SDPA's default scale is 1/sqrt(dim_head),
        # which equals this value.
        self.scale = dim_head**-0.5
        self.dropout = dropout
        self.norm = nn.LayerNorm(dim)
        # Unused by forward (SDPA performs the softmax); kept as an attribute.
        self.attend = nn.Softmax(dim=-1)
        self.to_qkv = nn.Linear(dim, 3 * inner_dim, bias=False)
        if needs_out_proj:
            self.to_out = nn.Sequential(nn.Linear(inner_dim, dim), nn.Dropout(dropout))
        else:
            self.to_out = nn.Identity()

    def forward(self, x, causal=True):
        """
        x : (B, T, D) -> (B, T, D). ``causal`` selects SDPA's causal mask.
        """
        normed = self.norm(x)
        p_drop = self.dropout if self.training else 0.0

        def split_heads(part):
            # (B, T, heads * dim_head) -> (B, heads, T, dim_head)
            return part.unflatten(-1, (self.heads, -1)).transpose(1, 2)

        # Each chunk is (B, T, heads * dim_head) before the head split.
        q, k, v = map(split_heads, self.to_qkv(normed).chunk(3, dim=-1))
        attended = F.scaled_dot_product_attention(
            q, k, v, dropout_p=p_drop, is_causal=causal
        )
        # (B, heads, T, dim_head) -> (B, T, heads * dim_head)
        merged = attended.transpose(1, 2).flatten(2)
        return self.to_out(merged)
|
|
|
|
|
|
class ConditionalBlock(nn.Module):
    """Pre-norm Transformer block with AdaLN-zero conditioning.

    The conditioning tensor ``c`` is mapped through SiLU -> Linear(dim, 6 * dim)
    to a (shift, scale, gate) triple for the attention branch and another for
    the MLP branch. The Linear is zero-initialized, so both gates start at zero
    and each residual branch initially contributes nothing.
    """

    def __init__(self, dim, heads, dim_head, mlp_dim, dropout=0.0):
        super().__init__()
        self.attn = Attention(dim, heads=heads, dim_head=dim_head, dropout=dropout)
        self.mlp = FeedForward(dim, mlp_dim, dropout=dropout)
        # Affine-free norms; per-feature shift/scale come from the modulation.
        self.norm1 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
        self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
        self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(dim, 6 * dim, bias=True))

        # AdaLN-zero: every shift, scale and gate is exactly 0 at initialization.
        out_layer = self.adaLN_modulation[-1]
        nn.init.zeros_(out_layer.weight)
        nn.init.zeros_(out_layer.bias)

    def forward(self, x, c):
        """
        x: token features with last dim ``dim``.
        c: conditioning with last dim ``dim``; its six modulation chunks must
           broadcast against x (e.g. both (B, T, dim)).
        """
        mod = self.adaLN_modulation(c).chunk(6, dim=-1)
        shift_msa, scale_msa, gate_msa = mod[:3]
        shift_mlp, scale_mlp, gate_mlp = mod[3:]

        attn_in = modulate(self.norm1(x), shift_msa, scale_msa)
        x = x + gate_msa * self.attn(attn_in)

        mlp_in = modulate(self.norm2(x), shift_mlp, scale_mlp)
        x = x + gate_mlp * self.mlp(mlp_in)
        return x
|
|
|
|
|
|
class Block(nn.Module):
    """Standard pre-norm Transformer block: residual attention, then residual MLP.

    NOTE: Attention and FeedForward each apply their own (affine) LayerNorm
    internally, on top of the affine-free norm1/norm2 used here.
    """

    def __init__(self, dim, heads, dim_head, mlp_dim, dropout=0.0):
        super().__init__()
        self.attn = Attention(dim, heads=heads, dim_head=dim_head, dropout=dropout)
        self.mlp = FeedForward(dim, mlp_dim, dropout=dropout)
        self.norm1 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
        self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)

    def forward(self, x):
        """x: (B, T, dim) -> (B, T, dim); attention uses its default causal mask."""
        attn_out = self.attn(self.norm1(x))
        x = x + attn_out
        mlp_out = self.mlp(self.norm2(x))
        return x + mlp_out
|
|
|
|
|
|
class Transformer(nn.Module):
    """Standard Transformer with support for AdaLN-zero blocks.

    Each of the input, condition and output projections is a Linear only when
    the two widths differ; otherwise it is an Identity. ``block_class`` picks
    the block type: ``Block`` instances are called as ``block(x)`` and any other
    block as ``block(x, c)``.
    """

    def __init__(
        self,
        input_dim,
        hidden_dim,
        output_dim,
        depth,
        heads,
        dim_head,
        mlp_dim,
        dropout=0.0,
        block_class=Block,
    ):
        super().__init__()

        def linear_or_identity(d_in, d_out):
            # A Linear is only needed when the width actually changes.
            return nn.Linear(d_in, d_out) if d_in != d_out else nn.Identity()

        # Registration order (norm, layers, projections, then the blocks) is
        # preserved so parameter order and seeded initialization are unchanged.
        self.norm = nn.LayerNorm(hidden_dim)
        self.layers = nn.ModuleList()
        self.input_proj = linear_or_identity(input_dim, hidden_dim)
        # NOTE(review): c is projected from input_dim, i.e. it is assumed to share
        # x's feature width when the widths differ; confirm against callers.
        self.cond_proj = linear_or_identity(input_dim, hidden_dim)
        self.output_proj = linear_or_identity(hidden_dim, output_dim)
        self.layers.extend(
            block_class(hidden_dim, heads, dim_head, mlp_dim, dropout)
            for _ in range(depth)
        )

    def forward(self, x, c=None):
        """
        x: (B, T, input_dim) -> (B, T, output_dim).
        c: optional conditioning; needed whenever the blocks are not ``Block``.
        """
        # The hasattr guards are always true for instances built by __init__
        # above; they only matter for instances lacking these attributes.
        h = self.input_proj(x) if hasattr(self, "input_proj") else x

        cond = c
        if cond is not None and hasattr(self, "cond_proj"):
            cond = self.cond_proj(cond)

        for layer in self.layers:
            if isinstance(layer, Block):
                h = layer(h)
            else:
                h = layer(h, cond)

        h = self.norm(h)
        return self.output_proj(h) if hasattr(self, "output_proj") else h
|
|
|
|
class Embedder(nn.Module):
    """Per-timestep embedding of a (B, T, input_dim) sequence.

    A kernel-size-1 Conv1d mixes features independently at each timestep
    (input_dim -> smoothed_dim). A two-layer SiLU MLP with hidden width
    ``mlp_scale * emb_dim`` then maps the result to ``emb_dim``.
    """

    def __init__(
        self,
        input_dim=10,
        smoothed_dim=10,
        emb_dim=10,
        mlp_scale=4,
    ):
        super().__init__()
        hidden = mlp_scale * emb_dim
        self.patch_embed = nn.Conv1d(input_dim, smoothed_dim, kernel_size=1, stride=1)
        self.embed = nn.Sequential(
            nn.Linear(smoothed_dim, hidden),
            nn.SiLU(),
            nn.Linear(hidden, emb_dim),
        )

    def forward(self, x):
        """
        x: (B, T, input_dim), cast to float32 first -> (B, T, emb_dim).
        """
        # Conv1d wants channels on dim 1, so swap the time/feature axes around it.
        channels_first = x.float().transpose(1, 2)
        smoothed = self.patch_embed(channels_first).transpose(1, 2)
        return self.embed(smoothed)
|
|
|
|
|
|
class MLP(nn.Module):
    """Simple two-layer MLP: Linear -> norm (optional) -> activation -> Linear.

    Args:
        input_dim: input feature width.
        hidden_dim: hidden width; the norm is constructed with this width.
        output_dim: output width; falls back to ``input_dim`` when falsy.
        norm_fn: norm constructor called as ``norm_fn(hidden_dim)``, or None for
            no normalization (an Identity is used instead).
        act_fn: activation constructor, called with no arguments.
    """

    def __init__(
        self,
        input_dim,
        hidden_dim,
        output_dim=None,
        norm_fn=nn.LayerNorm,
        act_fn=nn.GELU,
    ):
        super().__init__()
        # The norm is built before the Linears, as in the original ordering.
        norm = nn.Identity() if norm_fn is None else norm_fn(hidden_dim)
        layers = [
            nn.Linear(input_dim, hidden_dim),
            norm,
            act_fn(),
            nn.Linear(hidden_dim, output_dim or input_dim),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """
        x: (B*T, D) -> (B*T, output_dim or D).
        """
        return self.net(x)
|
|
|
|
|
|
class ARPredictor(nn.Module):
    """Autoregressive predictor for next-step embedding prediction.

    Adds a learned positional embedding (one vector per frame, for up to
    ``num_frames`` frames) to the input sequence, applies embedding dropout, and
    runs a Transformer built from ``ConditionalBlock``s conditioned on ``c``.
    """

    def __init__(
        self,
        *,
        num_frames,
        depth,
        heads,
        mlp_dim,
        input_dim,
        hidden_dim,
        output_dim=None,
        dim_head=64,
        dropout=0.0,
        emb_dropout=0.0,
    ):
        super().__init__()
        # (1, num_frames, input_dim); sliced to the actual sequence length in forward.
        self.pos_embedding = nn.Parameter(torch.randn(1, num_frames, input_dim))
        self.dropout = nn.Dropout(emb_dropout)
        self.transformer = Transformer(
            input_dim=input_dim,
            hidden_dim=hidden_dim,
            output_dim=output_dim or input_dim,
            depth=depth,
            heads=heads,
            dim_head=dim_head,
            mlp_dim=mlp_dim,
            dropout=dropout,
            block_class=ConditionalBlock,
        )

    def forward(self, x, c):
        """
        x: (B, T, d)
        c: (B, T, act_dim)

        NOTE(review): T is expected to be at most num_frames; otherwise the
        positional-embedding addition fails to broadcast (unless num_frames == 1).
        """
        seq_len = x.size(1)
        positions = self.pos_embedding[:, :seq_len]
        h = self.dropout(x + positions)
        return self.transformer(h, c)
|