Files
unifolm-world-model-action/model_architecture_analysis.md

40 KiB
Raw Permalink Blame History

UnifoLM World Model Action - 模型架构详细分析

目录

  1. 整体架构概览
  2. 推理流程分析
  3. 核心组件详解
  4. 性能瓶颈分析
  5. 内核融合优化建议

1. 整体架构概览

1.1 模型层次结构

DDPM (顶层模型)
├── DiffusionWrapper (条件包装器)
│   └── UNet3D (核心扩散模型)
│       ├── 时间嵌入 (Time Embedding)
│       ├── 下采样块 (Downsampling Blocks)
│       ├── 中间块 (Middle Blocks)
│       └── 上采样块 (Upsampling Blocks)
├── VAE (变分自编码器)
│   ├── Encoder (编码器)
│   └── Decoder (解码器)
├── CLIP Image Encoder (图像编码器)
├── Text Encoder (文本编码器)
├── State Projector (状态投影器)
└── Action Projector (动作投影器)

1.2 推理阶段数据流

输入观测 (Observation)
    ↓
[1] 条件编码阶段
    ├── 图像 → CLIP Encoder → Image Embedding
    ├── 图像 → VAE Encoder → Latent Condition
    ├── 文本 → Text Encoder → Text Embedding
    ├── 状态 → State Projector → State Embedding
    └── 动作 → Action Projector → Action Embedding
    ↓
[2] DDIM采样阶段 (n步迭代)
    ├── 初始化噪声 x_T
    └── For step in [0, n]:
        ├── 模型前向传播 (UNet3D)
        │   ├── 时间步嵌入
        │   ├── 条件注入 (CrossAttention)
        │   └── 噪声预测
        ├── DDIM更新公式
        └── x_{t-1} = f(x_t, noise_pred)
    ↓
[3] VAE解码阶段
    └── Latent → VAE Decoder → 视频帧

2. 推理流程分析

2.1 阶段1: 生成动作 (sim_mode=False)

目的: 根据观测和指令生成动作序列

输入:

  • observation.images.top: 历史图像观测 [B, C, T_obs, H, W]
  • observation.state: 历史状态 [B, T_obs, state_dim]
  • action: 历史动作 [B, T_action, action_dim]
  • instruction: 文本指令

输出:

  • pred_videos: 预测视频 [B, C, T, H, W]
  • pred_actions: 预测动作序列 [B, T, action_dim]

关键特点:

  • 动作条件被置零 (cond_action_emb = torch.zeros_like(...))
  • 使用文本指令作为主要引导

2.2 阶段2: 世界模型交互 (sim_mode=True)

目的: 根据动作预测未来观测

输入:

  • observation.images.top: 当前图像
  • observation.state: 当前状态
  • action: 阶段1生成的动作序列

输出:

  • pred_videos: 预测的未来视频帧
  • pred_states: 预测的未来状态

关键特点:

  • 不使用文本指令 (text_input=False)
  • 动作条件被实际使用

3. 核心组件详解

3.1 DDIM采样器 (DDIMSampler)

代码位置: src/unifolm_wma/models/samplers/ddim.py

核心方法: ddim_sampling() (第168-300行)

实际代码结构:

def ddim_sampling(self, cond, shape, x_T=None, ddim_steps=50, ...):
    # 初始化
    timesteps = self.ddim_timesteps[:ddim_steps]
    x = x_T if x_T is not None else torch.randn(shape, device=device)

    # 主循环
    for i, step in enumerate(iterator):
        # 获取时间步
        index = total_steps - i - 1
        ts = torch.full((b,), step, device=device, dtype=torch.long)

        # 模型前向传播 (核心瓶颈)
        outs = self.p_sample_ddim(x, cond, ts, index=index, ...)
        x, pred_x0 = outs

    return x

性能数据 (来自 profiling 报告,--ddim_steps 50):

  • DDIM采样调用: 22次 (action_generation + world_model_interaction 各11次)
  • 单次采样(50步)平均耗时: 35.58s (总计 782.70s)
  • 平均每步耗时: ~0.712s (35.58s / 50)
  • 当前 unconditional_guidance_scale=1.0 时每步 1 次 UNet 前向;开启 CFG 时每步 2 次前向

3.2 DiffusionWrapper (条件路由器)

代码位置: src/unifolm_wma/models/ddpms.py:2413-2524

作用: 将输入和条件路由到内部扩散模型

实际代码 (第2469-2479行):

elif self.conditioning_key == 'hybrid':
    xc = torch.cat([x] + c_concat, dim=1)  # 拼接latent条件
    cc = torch.cat(c_crossattn, 1)          # 拼接cross-attention条件
    cc_action = c_crossattn_action
    out = self.diffusion_model(xc, x_action, x_state, t,
                               context=cc, context_action=cc_action, **kwargs)

条件类型:

  1. c_concat: 通道拼接条件 (VAE编码的图像)
  2. c_crossattn: 交叉注意力条件 (文本、图像、状态、动作embedding)
  3. c_crossattn_action: 动作头专用条件

3.3 WMAModel (核心扩散模型)

代码位置: src/unifolm_wma/modules/networks/wma_model.py:326-849

配置文件: configs/inference/world_model_interaction.yaml:69-104

实际配置参数:

in_channels: 8              # 输入通道 (4 latent + 4 VAE条件)
out_channels: 4             # 输出通道
model_channels: 320         # 基础通道数
channel_mult: [1, 2, 4, 4]  # 通道倍增: [320, 640, 1280, 1280]
num_res_blocks: 2           # 每个分辨率2个ResBlock
attention_resolutions: [4, 2, 1]  # 在这些分辨率启用注意力
num_head_channels: 64       # 每个注意力头64通道
transformer_depth: 1        # Transformer深度
context_dim: 1024           # 交叉注意力上下文维度
temporal_length: 16         # 时间序列长度

架构层次 (详见附录A.1):

  • 4个下采样阶段 (每阶段2个ResBlock + Attention)
  • 1个中间块 (2个ResBlock + Attention)
  • 3个上采样阶段 (每阶段2个ResBlock + Attention)
  • 总计: 16个ResBlock + 32个Transformer

3.4 VAE (变分自编码器)

代码位置: src/unifolm_wma/models/autoencoder.py

配置文件: configs/inference/world_model_interaction.yaml:159-180

实际配置参数:

AutoencoderKL:
  embed_dim: 4              # Latent维度
  z_channels: 4             # Latent通道数
  in_channels: 3            # RGB输入
  out_ch: 3                 # RGB输出
  ch: 128                   # 基础通道数
  ch_mult: [1, 2, 4, 4]     # 通道倍增: [128, 256, 512, 512]
  num_res_blocks: 2         # 每层2个ResBlock
  attn_resolutions: []      # VAE中不使用注意力

编码/解码过程:

# 编码: [B, 3, 320, 512] → [B, 4, 40, 64] (8×8下采样)
z = model.encode_first_stage(img)

# 解码: [B, 4, 40, 64] → [B, 3, 320, 512]
video = model.decode_first_stage(samples)

性能数据:

  • VAE编码: 0.90s (22次, 平均0.041s/次)
  • VAE解码: 12.44s (22次, 平均0.566s/次)
  • 压缩比: 8×8 = 64倍空间压缩

详细架构: 见附录A.4

3.5 条件编码器

性能说明: 本次 profiling 未对各条件编码器单独计时,统一计入 synthesis/conditioning_prep,总计 2.92s (22次, 平均0.133s/次)。

3.5.1 CLIP图像编码器

代码位置: src/unifolm_wma/modules/encoders/condition.py - FrozenOpenCLIPImageEmbedderV2

配置文件: configs/inference/world_model_interaction.yaml:188-204

实际配置:

FrozenOpenCLIPImageEmbedderV2:
  freeze: true
  # 使用OpenCLIP ViT-H/14
  # 输出: [B, 1280]

Resampler (图像投影器):
  dim: 1024              # 输出维度
  depth: 4               # Transformer深度
  heads: 12              # 12个注意力头
  num_queries: 16        # 16个查询token
  embedding_dim: 1280    # CLIP输出维度

数据流: 图像 [B, 3, H, W] → CLIP → [B, 1280] → Resampler → [B, 16, 1024]

3.5.2 文本编码器

代码位置: src/unifolm_wma/modules/encoders/condition.py - FrozenOpenCLIPEmbedder

配置文件: configs/inference/world_model_interaction.yaml:182-186

实际配置:

FrozenOpenCLIPEmbedder:
  freeze: True
  layer: "penultimate"  # 使用倒数第二层
  # 输出: [B, seq_len, 1024]

3.5.3 状态投影器

代码位置: src/unifolm_wma/models/ddpms.py:2014-2026 - MLPProjector

MLPProjector实现 (src/unifolm_wma/utils/projector.py:14-37):

class MLPProjector(nn.Module):
    def __init__(self, input_dim: int, output_dim: int, mlp_type: str = "gelu-mlp"):
        if mlp_type == "gelu-mlp":
            self.projector = nn.Sequential(
                nn.Linear(input_dim, output_dim, bias=True),
                nn.GELU(approximate='tanh'),
                nn.Linear(output_dim, output_dim, bias=True),
            )

数据流: 状态 [B, T_obs, 16] → MLPProjector → [B, T_obs, 1024] + agent_state_pos_emb

3.5.4 动作投影器

代码位置: src/unifolm_wma/models/ddpms.py:2020-2024 - MLPProjector

数据流: 动作 [B, T_action, 16] → MLPProjector → [B, T_action, 1024] + agent_action_pos_emb

位置嵌入定义:

# ddpms.py:2023-2026
self.agent_action_pos_emb = nn.Parameter(torch.randn(1, 16, 1024))
self.agent_state_pos_emb = nn.Parameter(torch.randn(1, n_obs_steps, 1024))

4. 性能瓶颈分析

4.1 时间分布 (profiling 报告, 11次迭代 / 22次采样)

说明: profile_section 存在嵌套,宏观统计的总和不是 wall time以下以每段的 total/avg 为准。

Section Count Total(s) Avg(s) 说明
iteration_total 11 836.79 76.07 单次迭代总耗时
action_generation 11 399.71 36.34 生成动作 (DDIM 50步)
world_model_interaction 11 398.38 36.22 世界模型交互 (DDIM 50步)
synthesis/ddim_sampling 22 782.70 35.58 单次采样
synthesis/conditioning_prep 22 2.92 0.13 条件编码汇总
synthesis/decode_first_stage 22 12.44 0.57 VAE解码
save_results 11 38.67 3.52 I/O保存
model_loading/config 1 49.77 49.77 一次性开销
model_loading/checkpoint 1 11.83 11.83 一次性开销
model_to_cuda 1 8.91 8.91 一次性开销

4.2 DDIM采样详细分析

DDIM采样是主要瓶颈 (基于 50 步采样):

  • 采样调用次数: 22 次 (11 次迭代 × 2 阶段)
  • 采样总耗时: 782.70s,平均 35.58s/次
  • 平均每步耗时: ~0.712s (35.58s / 50)
  • unconditional_guidance_scale=1.0 时每步 1 次 UNet 前向;开启 CFG 时每步 2 次前向
  • conditioning_prep + ddim_sampling + decode_first_stageddim_sampling 占约 98%

4.3 瓶颈总结

关键发现:

  1. DDIM采样占比最高 - 单次迭代平均 76.07s,其中采样约 71.15s (≈93%)
  2. CUDA算子时间主要集中在 Linear/GEMM(29.8%) 与 Convolution(13.9%)Attention 约 3.0%
  3. CPU侧 copy/to 仍明显 (aten::copy_, aten::to/_to_copy 在报告中耗时靠前)
  4. VAE解码为次级瓶颈 (0.57s/次)

4.4 GPU显存概览

  • Peak allocated: 17890.50 MB
  • Average allocated: 16129.98 MB

5. 内核融合优化建议

5.1 优化策略概览

基于性能分析,优化应聚焦于:

  1. UNet3D模型前向传播 (最高优先级)
  2. VAE解码器 (次要优先级)
  3. 批处理和并行化 (辅助优化)

5.2 WMAModel内核融合机会

5.2.1 时间步嵌入融合

代码位置: src/unifolm_wma/utils/diffusion.py - timestep_embedding()

当前实现 (实际代码):

# 1. 正弦位置编码
def timestep_embedding(timesteps, dim, max_period=10000):
    half = dim // 2
    freqs = torch.exp(-math.log(max_period) * torch.arange(0, half) / half)
    args = timesteps[:, None].float() * freqs[None]
    embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
    return embedding

# 2. 时间嵌入网络 (在WMAModel.__init__中)
self.time_embed = nn.Sequential(
    nn.Linear(model_channels, time_embed_dim),  # 320 → 1280
    nn.SiLU(),
    nn.Linear(time_embed_dim, time_embed_dim),  # 1280 → 1280
)

融合机会:

  • Linear + SiLU + Linear 可融合为单个kernel
  • 正弦编码计算可与第一个Linear融合

预期收益: 减少2-3次kernel启动开销

5.2.2 ResBlock内核融合

代码位置: src/unifolm_wma/modules/networks/wma_model.py:130-263 - class ResBlock

当前实现 (实际代码):

# in_layers: GroupNorm + SiLU + Conv
self.in_layers = nn.Sequential(
    normalization(channels),      # GroupNorm
    nn.SiLU(),
    conv_nd(dims, channels, out_channels, 3, padding=1)
)

# emb_layers: SiLU + Linear
self.emb_layers = nn.Sequential(
    nn.SiLU(),
    nn.Linear(emb_channels, out_channels)
)

# out_layers: GroupNorm + SiLU + Dropout + Conv
self.out_layers = nn.Sequential(
    normalization(out_channels),  # GroupNorm
    nn.SiLU(),
    nn.Dropout(p=dropout),
    zero_module(nn.Conv2d(out_channels, out_channels, 3, padding=1))
)

融合机会:

  1. GroupNorm + SiLU 可融合 (in_layers和out_layers各一次)
  2. emb_layersSiLU + Linear 可融合
  3. 残差加法可与下一层的GroupNorm融合

实际瓶颈: 16个ResBlock × 50步 × 2次 = 1600次ResBlock调用

预期收益: 每个ResBlock节省50-60%的kernel启动开销

5.2.3 注意力机制优化

代码位置: src/unifolm_wma/modules/attention.py

实际配置:

  • SpatialTransformer: 空间维度注意力
  • TemporalTransformer: 时间维度注意力
  • 总计: 32个Transformer × 50步 × 2次 = 3200次注意力调用

优化方案: 使用 PyTorch 内置的 Flash Attention:

from torch.nn.functional import scaled_dot_product_attention

# 替换标准注意力计算
out = scaled_dot_product_attention(Q, K, V, is_causal=False)

预期收益: 注意力层加速2-3倍整体加速30-40%

5.3 VAE解码器优化

代码位置: src/unifolm_wma/models/autoencoder.py

当前性能: 12.44s (22次调用, 平均0.566s/次)

优化方案:

  1. 混合精度: 使用FP16进行解码

    with torch.cuda.amp.autocast():
        video = vae.decode(latent)
    
  2. 批处理优化: 确保VAE解码使用批处理而非逐帧

预期收益: 加速20-30%

5.4 实施建议

5.4.1 使用 torch.compile() (最简单)

代码位置: scripts/evaluation/world_model_interaction.py

实际实施位置: 在模型加载后添加:

# 在模型加载并移动到GPU后添加
config = OmegaConf.load(args.config)
model = instantiate_from_config(config.model)
model = load_model_checkpoint(model, args.ckpt_path)
model.eval()
model = model.cuda()

# 添加 torch.compile() 优化
model.model.diffusion_model = torch.compile(
    model.model.diffusion_model,
    mode='max-autotune',  # 或 'reduce-overhead'
    fullgraph=True
)

优点:

  • 无需修改模型代码
  • 自动融合操作
  • 支持动态shape

预期收益: 20-40%加速

5.4.2 使用 Flash Attention

代码位置: src/unifolm_wma/modules/attention.py

当前实现分析: 代码已经支持 xformers (xformers.ops.memory_efficient_attention)。当 xformers 可用时,CrossAttention 类会自动使用 efficient_forward 方法:

# attention.py:90-91
if XFORMERS_IS_AVAILBLE and temporal_length is None:
    self.forward = self.efficient_forward

进一步优化方案: 如果 xformers 不可用,可以使用 PyTorch 内置的 Flash Attention:

from torch.nn.functional import scaled_dot_product_attention
out = scaled_dot_product_attention(q, k, v, is_causal=False)

预期收益: 注意力层加速2-3倍

5.4.3 混合精度推理

代码位置: scripts/evaluation/world_model_interaction.py

实际实施位置: 在推理调用处添加混合精度上下文:

# 在 image_guided_synthesis_sim_mode 调用处添加
with torch.cuda.amp.autocast():
    pred_videos, pred_actions, pred_states = image_guided_synthesis_sim_mode(
        model, sample['instruction'], observation, noise_shape,
        action_cond_step=args.exe_steps,
        ddim_steps=args.ddim_steps, ddim_eta=args.ddim_eta,
        unconditional_guidance_scale=args.unconditional_guidance_scale,
        fs=model_input_fs, timestep_spacing=args.timestep_spacing,
        guidance_rescale=args.guidance_rescale, sim_mode=False)

注意事项:

  • 模型会自动在FP16和FP32之间切换
  • 某些操作(如LayerNorm)会自动保持FP32精度
  • 无需手动转换模型权重

预期收益: 30-50%加速 + 减少50%显存

5.5 优化路线图

阶段1: 快速优化

目标: 获得20-40%加速,无需修改模型代码

实施步骤:

  1. 启用 torch.compile() - 在模型加载后添加
  2. 启用 torch.backends.cudnn.benchmark = True - 在推理开始前设置
  3. 使用混合精度推理 (FP16) - 在推理调用处添加

实施代码:

# 在推理函数开始处添加
torch.backends.cudnn.benchmark = True

# 在模型加载后添加 torch.compile()
model = model.cuda()
model.model.diffusion_model = torch.compile(
    model.model.diffusion_model,
    mode='max-autotune'
)

# 在推理循环中使用混合精度
with torch.cuda.amp.autocast():
    pred_videos, pred_actions, pred_states = image_guided_synthesis_sim_mode(...)

阶段2: 中级优化

目标: 获得50-70%加速

实施步骤:

  1. 确保 xformers 已安装并启用 - 检查 XFORMERS_IS_AVAILBLE 标志
  2. 优化VAE解码器批处理 - 检查 src/unifolm_wma/models/autoencoder.py 中的 decode() 方法
  3. 分析并优化内存访问模式 - 使用 torch.cuda.memory_stats() 分析

关键修改点:

  • 确认 xformers 已正确安装: pip install xformers
  • CrossAttention 类中,当 xformers 可用时会自动使用 efficient_forward
  • 确保VAE解码使用批处理而非逐帧处理

阶段3: 深度优化

目标: 获得2-3倍加速

实施步骤:

  1. 自定义CUDA kernel融合关键操作
  2. 优化卷积操作
    • 分析 Conv2D 操作的性能 (模型实际使用 Conv2D 而非 Conv3D)
  3. 优化数据加载和预处理pipeline

需要的技能:

  • CUDA编程
  • PyTorch C++扩展
  • 性能分析工具 (Nsight Systems, nvprof)

5.6 预期总体收益

基于以上优化策略和实际代码分析,预期性能提升:

优化阶段 预期加速比 实施难度 主要修改文件
阶段1: 快速优化 1.2-1.4x scripts/evaluation/world_model_interaction.py
阶段2: 中级优化 1.5-1.7x src/unifolm_wma/modules/attention.py, src/unifolm_wma/models/autoencoder.py
阶段3: 深度优化 2.0-3.0x src/unifolm_wma/modules/networks/wma_model.py + 自定义CUDA kernel

总体目标: 通过系统性优化实现 2-3倍加速


6. 关键代码位置索引

为方便内核融合实施,以下是关键代码位置:

6.1 核心模型文件

组件 文件路径 关键类/函数
DDPM主模型 src/unifolm_wma/models/ddpms.py class DDPM
条件包装器 src/unifolm_wma/models/ddpms.py:2413 class DiffusionWrapper
DDIM采样器 src/unifolm_wma/models/samplers/ddim.py class DDIMSampler
VAE编解码 src/unifolm_wma/models/autoencoder.py encode_first_stage, decode_first_stage

6.2 推理脚本

文件 说明
scripts/evaluation/world_model_interaction.py 推理脚本

6.3 配置文件

文件 说明
configs/inference/world_model_interaction.yaml 推理配置
unitree_g1_pack_camera/case1/run_world_model_interaction.sh 运行脚本

7. 下一步行动建议

7.1 立即可执行的优化

最小改动,最大收益:

  1. 启用 torch.compile()

    # 在模型加载并移动到GPU后添加
    model.model.diffusion_model = torch.compile(
        model.model.diffusion_model,
        mode='max-autotune'
    )
    
  2. 启用 cuDNN benchmark

    torch.backends.cudnn.benchmark = True
    
  3. 混合精度推理

    with torch.cuda.amp.autocast():
        pred_videos, pred_actions, pred_states = image_guided_synthesis_sim_mode(...)
    

预期收益: 20-40%加速,无风险

7.2 需要深入探索的部分

为了更精确的优化,建议进一步分析:

  1. 注意力层的具体实现

    • 代码位置: src/unifolm_wma/modules/attention.py
    • 分析目标:
      • CrossAttention 类 (第48-398行) - 核心注意力实现
      • BasicTransformerBlock 类 (第400-469行) - Transformer块
      • 确认 xformers 是否已启用 (XFORMERS_IS_AVAILBLE 标志)
  2. ResBlock的详细结构

  3. 内存瓶颈分析

    • 分析工具: 使用 torch.cuda.memory_stats()torch.profiler
    • 分析位置: scripts/evaluation/world_model_interaction.py
    • 分析目标:
      • 识别内存拷贝热点
      • 优化中间张量的生命周期
      • 减少不必要的内存分配
  4. 计算瓶颈定位

    • 分析工具: Nsight Systems 或 PyTorch Profiler
    • 分析目标:
      • 识别 kernel 启动开销
      • 分析 GPU 利用率
      • 找到计算密集型操作

8. 参考资料

8.1 优化技术文档

8.2 相关论文

  • DDIM: Denoising Diffusion Implicit Models
  • Flash Attention: Fast and Memory-Efficient Exact Attention
  • Efficient Diffusion Models for Vision

9. 总结

9.1 关键发现

  1. DDIM采样仍是主要瓶颈 - 单次采样(50步)平均 35.58s
  2. Linear/GEMM 与 Convolution 为主要 CUDA 时间来源 - Attention 占比相对较小
  3. VAE解码为次级优化目标 - 0.57s/次

9.2 优化优先级

高优先级 (立即实施):

  • torch.compile()
  • cuDNN benchmark
  • 混合精度推理

中优先级 (1周内):

  • Flash Attention集成
  • VAE批处理优化

低优先级 (长期):

  • 自定义CUDA kernel
  • 模型架构改进

9.3 预期成果

通过系统性优化,预期可获得 1.5-3倍加速 (视采样步数与编译/混合精度策略而定)。


文档版本: v1.2 创建日期: 2026-01-17 最后更新: 2026-01-18 更新内容: 校准DDIM步数为50并替换为最新profiling数据


附录A: 实际模型架构详解

基于代码分析,以下是真实的模型实现细节。

A.1 WMAModel 实际配置

配置文件: configs/inference/world_model_interaction.yaml:69-104

WMAModel参数:
  in_channels: 8              # 输入通道 (4 latent + 4 concat条件)
  out_channels: 4             # 输出通道 (latent空间)
  model_channels: 320         # 基础通道数
  channel_mult: [1, 2, 4, 4]  # 通道倍增: [320, 640, 1280, 1280]
  num_res_blocks: 2           # 每个分辨率2个ResBlock
  attention_resolutions: [4, 2, 1]  # 在这些分辨率启用注意力
  num_head_channels: 64       # 每个注意力头64通道
  transformer_depth: 1        # Transformer深度
  context_dim: 1024           # 交叉注意力上下文维度
  temporal_length: 16         # 时间序列长度
  dropout: 0.1

架构层次:

输入: [B, 8, 16, 40, 64]  (8通道 = 4 latent + 4 VAE条件)
  ↓
下采样路径 (4个阶段):
  Stage 0: [B, 320, 16, 40, 64]   - 2个ResBlock + SpatialTransformer + TemporalTransformer
  Stage 1: [B, 640, 16, 20, 32]   - Downsample + 2个ResBlock + Attention
  Stage 2: [B, 1280, 16, 10, 16]  - Downsample + 2个ResBlock + Attention
  Stage 3: [B, 1280, 16, 5, 8]    - Downsample + 2个ResBlock + Attention
  ↓
中间块: [B, 1280, 16, 5, 8]
  - ResBlock + SpatialTransformer + TemporalTransformer + ResBlock
  ↓
上采样路径 (3个阶段):
  Stage 2: [B, 1280, 16, 10, 16]  - Upsample + 2个ResBlock + Attention
  Stage 1: [B, 640, 16, 20, 32]   - Upsample + 2个ResBlock + Attention
  Stage 0: [B, 320, 16, 40, 64]   - Upsample + 2个ResBlock + Attention
  ↓
输出: [B, 4, 16, 40, 64]  (预测的噪声或速度)

A.2 ResBlock 实际实现

位置: src/unifolm_wma/modules/networks/wma_model.py:130-263

实际代码结构:

class ResBlock:
    def __init__(self, channels, emb_channels, dropout, ...):
        # 输入层: GroupNorm + SiLU + Conv
        self.in_layers = nn.Sequential(
            normalization(channels),      # GroupNorm
            nn.SiLU(),                    # 激活函数
            conv_nd(dims, channels, out_channels, 3, padding=1)
        )
        
        # 时间步嵌入层: SiLU + Linear
        self.emb_layers = nn.Sequential(
            nn.SiLU(),
            nn.Linear(emb_channels, out_channels)
        )
        
        # 输出层: GroupNorm + SiLU + Dropout + Conv
        self.out_layers = nn.Sequential(
            normalization(out_channels),  # GroupNorm
            nn.SiLU(),
            nn.Dropout(p=dropout),
            zero_module(nn.Conv2d(out_channels, out_channels, 3, padding=1))
        )
        
        # 残差连接
        self.skip_connection = ...
        
        # 可选的时间卷积
        if use_temporal_conv:
            self.temporal_conv = TemporalConvBlock(...)
    
    def forward(self, x, emb):
        h = self.in_layers(x)              # GroupNorm + SiLU + Conv
        emb_out = self.emb_layers(emb)     # 时间步嵌入
        h = h + emb_out                    # 加入时间步信息
        h = self.out_layers(h)             # GroupNorm + SiLU + Dropout + Conv
        h = self.skip_connection(x) + h    # 残差连接
        
        if use_temporal_conv:
            h = self.temporal_conv(h)      # 时间维度卷积
        return h

内核融合机会:

  1. GroupNorm + SiLU 可融合 (in_layers和out_layers各一次)
  2. emb_layers(SiLU + Linear) 可融合
  3. 残差加法 + 下一层的GroupNorm 可融合

A.3 注意力机制实际实现

SpatialTransformer (空间注意力):

  • 位置: src/unifolm_wma/modules/attention.py:472-558
  • 在特征图的空间维度(H×W)上执行自注意力和交叉注意力
  • 使用 transformer_depth=1即每个位置1层Transformer
  • 当 xformers 可用时,使用 efficient_forward 方法进行高效注意力计算

TemporalTransformer (时间注意力):

  • 位置: src/unifolm_wma/modules/attention.py:561-680
  • 在时间维度(T=16帧)上执行自注意力
  • 配置: temporal_selfatt_only=True (仅时间自注意力,不做交叉注意力)
  • 使用相对位置编码: use_relative_position=False (实际未启用)

CrossAttention (核心注意力层):

  • 位置: src/unifolm_wma/modules/attention.py:48-398
  • 支持多种交叉注意力: 图像、文本、状态、动作
  • 当 xformers 可用时自动使用 xformers.ops.memory_efficient_attention

注意力头配置:

  • num_head_channels=64: 每个头64通道
  • 对于320通道: 320/64 = 5个注意力头
  • 对于640通道: 640/64 = 10个注意力头
  • 对于1280通道: 1280/64 = 20个注意力头

A.4 VAE 实际配置

配置文件: configs/inference/world_model_interaction.yaml:159-180

AutoencoderKL:
  embed_dim: 4              # Latent维度
  z_channels: 4             # Latent通道数
  resolution: 256           # 基础分辨率
  in_channels: 3            # RGB输入
  out_ch: 3                 # RGB输出
  ch: 128                   # 基础通道数
  ch_mult: [1, 2, 4, 4]     # 通道倍增
  num_res_blocks: 2         # 每层2个ResBlock
  attn_resolutions: []      # VAE中不使用注意力
  dropout: 0.0

编码器架构:

输入: [B, 3, 320, 512]
  ↓ Conv 3→128
  ↓ ResBlock×2 [128, 320, 512]
  ↓ Downsample [128, 160, 256]
  ↓ ResBlock×2 [256, 160, 256]
  ↓ Downsample [256, 80, 128]
  ↓ ResBlock×2 [512, 80, 128]
  ↓ Downsample [512, 40, 64]
  ↓ ResBlock×2 [512, 40, 64]
  ↓ ResBlock + Conv
输出: [B, 4, 40, 64]  (8×8下采样)

解码器架构 (编码器的镜像):

输入: [B, 4, 40, 64]
  ↓ Conv + ResBlock
  ↓ ResBlock×2 [512, 40, 64]
  ↓ Upsample [512, 80, 128]
  ↓ ResBlock×2 [512, 80, 128]
  ↓ Upsample [256, 160, 256]
  ↓ ResBlock×2 [256, 160, 256]
  ↓ Upsample [128, 320, 512]
  ↓ ResBlock×2 [128, 320, 512]
  ↓ Conv 128→3
输出: [B, 3, 320, 512]

A.5 条件编码器实际配置

CLIP图像编码器

配置: configs/inference/world_model_interaction.yaml:188-191

FrozenOpenCLIPImageEmbedderV2:
  freeze: true
  # 使用OpenCLIP的ViT-H/14模型
  # 输出维度: 1280

图像投影器 (Resampler):

Resampler:
  dim: 1024              # 输出维度
  depth: 4               # Transformer深度
  dim_head: 64           # 注意力头维度
  heads: 12              # 12个注意力头
  num_queries: 16        # 16个查询token
  embedding_dim: 1280    # CLIP输出维度
  output_dim: 1024       # 最终输出维度
  video_length: 16       # 视频长度

数据流:

图像 [B, 3, H, W]
  ↓ CLIP Encoder
  ↓ [B, 1280]
  ↓ Resampler (Perceiver-style)
  ↓ [B, 16, 1024]  (16个token每个1024维)

文本编码器

配置: configs/inference/world_model_interaction.yaml:182-186

FrozenOpenCLIPEmbedder:
  freeze: True
  layer: "penultimate"  # 使用倒数第二层
  # 输出维度: 1024

数据流:

文本指令 "pick up the box"
  ↓ OpenCLIP Text Encoder
  ↓ [B, seq_len, 1024]

动作/状态投影器

代码位置: src/unifolm_wma/models/ddpms.py:2014-2026

MLPProjector实现 (src/unifolm_wma/utils/projector.py:14-37):

class MLPProjector(nn.Module):
    def __init__(self, input_dim: int, output_dim: int, mlp_type: str = "gelu-mlp"):
        if mlp_type == "gelu-mlp":
            self.projector = nn.Sequential(
                nn.Linear(input_dim, output_dim, bias=True),
                nn.GELU(approximate='tanh'),
                nn.Linear(output_dim, output_dim, bias=True),
            )

初始化代码 (ddpms.py:2014-2026):

# 状态投影器
self.state_projector = MLPProjector(agent_state_dim, 1024)  # 16 → 1024
self.action_projector = MLPProjector(agent_action_dim, 1024)  # 16 → 1024

# 位置嵌入
self.agent_action_pos_emb = nn.Parameter(torch.randn(1, 16, 1024))
self.agent_state_pos_emb = nn.Parameter(torch.randn(1, n_obs_steps, 1024))

数据流:

状态 [B, T_obs, 16]
  ↓ MLPProjector (Linear + GELU + Linear)
  ↓ [B, T_obs, 1024]
  ↓ + agent_state_pos_emb
  ↓ [B, T_obs, 1024]

动作 [B, T_action, 16]
  ↓ MLPProjector (Linear + GELU + Linear)
  ↓ [B, T_action, 1024]
  ↓ + agent_action_pos_emb
  ↓ [B, T_action, 1024]

A.6 时间步嵌入实际实现

位置: src/unifolm_wma/utils/diffusion.py:timestep_embedding

实际代码:

def timestep_embedding(timesteps, dim, max_period=10000):
    """
    创建正弦位置编码
    """
    half = dim // 2
    freqs = torch.exp(
        -math.log(max_period) * torch.arange(start=0, end=half) / half
    )
    args = timesteps[:, None].float() * freqs[None]
    embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
    return embedding

在WMAModel中的使用:

# 时间步 t ∈ [0, 999]
t_emb = timestep_embedding(t, model_channels)  # [B, 320]
t_emb = self.time_embed(t_emb)                 # Linear(320 → 1280)
# 输出: [B, 1280]

时间嵌入网络:

t ∈ [0, 999]
  ↓ timestep_embedding (正弦编码)
  ↓ [B, 320]
  ↓ Linear(320 → 1280)
  ↓ SiLU
  ↓ Linear(1280 → 1280)
  ↓ [B, 1280]

A.7 动作头 (ConditionalUnet1D) 实际配置

配置: configs/inference/world_model_interaction.yaml:106-127

ConditionalUnet1D:
  input_dim: 16                    # 动作维度
  n_obs_steps: 2                   # 观测步数
  diffusion_step_embed_dim: 128    # 扩散步嵌入维度
  down_dims: [256, 512, 1024, 2048]  # 下采样通道
  kernel_size: 5                   # 卷积核大小
  n_groups: 8                      # GroupNorm分组数
  horizon: 16                      # 预测时间范围
  use_linear_attn: true            # 使用线性注意力
  imagen_cond_gradient: true       # 使用图像条件梯度

架构:

输入: 噪声动作 [B, 16, 16]  (16维动作 × 16步)
条件: 
  - 图像特征 [B, C, H, W] 来自WMAModel中间层
  - 观测编码 [B, n_obs, obs_dim]
  
  ↓ Conv1D + ResBlock
  ↓ [B, 256, 16]
  ↓ Downsample + ResBlock
  ↓ [B, 512, 8]
  ↓ Downsample + ResBlock
  ↓ [B, 1024, 4]
  ↓ Downsample + ResBlock
  ↓ [B, 2048, 2]
  ↓ Middle Block (with attention)
  ↓ Upsample + ResBlock
  ↓ [B, 1024, 4]
  ↓ Upsample + ResBlock
  ↓ [B, 512, 8]
  ↓ Upsample + ResBlock
  ↓ [B, 256, 16]
  ↓ Conv1D
输出: 预测噪声 [B, 16, 16]

A.8 完整前向传播流程

基于实际代码,完整的前向传播流程如下:

# 1. 条件编码 (一次性完成,可缓存)
cond_img_emb = clip_encoder(img)  resampler  [B, 16, 1024]
cond_text_emb = text_encoder(text)  [B, seq_len, 1024]
cond_state_emb = state_projector(state) + pos_emb  [B, T_obs, 1024]
cond_action_emb = action_projector(action) + pos_emb  [B, T_action, 1024]
cond_latent = vae.encode(img)  [B, 4, T, 40, 64]

# 2. 拼接条件
c_concat = [cond_latent]  # 通道拼接
c_crossattn = [cond_text_emb, cond_img_emb, cond_state_emb, cond_action_emb]
c_crossattn = torch.cat(c_crossattn, dim=1)  # [B, total_tokens, 1024]

# 3. DDIM采样循环 (ddim_steps 默认 50实际由 --ddim_steps 控制)
x = torch.randn([B, 4, 16, 40, 64])  # 初始噪声
for step in range(ddim_steps):
    # 3.1 时间步嵌入
    t_emb = timestep_embedding(t, 320)  Linear  [B, 1280]
    
    # 3.2 拼接输入
    x_in = torch.cat([x, cond_latent], dim=1)  # [B, 8, 16, 40, 64]
    
    # 3.3 UNet前向传播 (核心瓶颈)
    noise_pred = wma_model(x_in, t_emb, c_crossattn)
    # 包含: 4个下采样阶段 + 中间块 + 3个上采样阶段
    # 每个阶段: 2个ResBlock + SpatialTransformer + TemporalTransformer
    
    # 3.4 DDIM更新
    x = ddim_update(x, noise_pred, t, t_prev)

# 4. VAE解码
video = vae.decode(x)  [B, 3, 16, 320, 512]

A.9 基于实际架构的优化建议更新

我的理解: 本次 profiling 显示采样阶段占据绝对主导地位:单次采样(50步)平均 35.58s,且每次迭代包含 action_generation 与 world_model_interaction 各一次采样。换句话说,任何“每步的细微改进”都会被 50 步和 2 阶段放大;因此最有效的优化要么减少步数,要么显著加速 UNet 前向。CUDA 时间主要集中在 Linear/GEMM(29.8%) 与 Convolution(13.9%),而 Attention 约 3.0%,这意味着算子层面优先考虑矩阵乘/卷积路径的优化收益更稳定。CPU 侧 aten::copy_/to/_to_copy 也明显,说明循环内的数据搬运仍有成本可省。

优化点1: 采样步数与采样器 (最高优先级)

依据:

  • 50步采样平均 35.58s (0.712s/步),减少步数带来近线性收益
  • 单次迭代约 76.07s,其中采样约占 93%

建议:

  • 在保证质量的前提下,将 --ddim_steps 从 50 降到 20-30
  • 评估更快采样器(如 DPM-Solver++/UniPC)以减少步数
  • 若使用 CFG注意 unconditional_guidance_scale > 1.0 会使每步前向翻倍

优化点2: GEMM/Conv 主导路径加速 (高优先级)

依据:

  • CUDA 时间主力来自 Linear/GEMM 与 Convolution

建议:

  • torch.compile() 仅包裹 UNet 主干以获得融合收益
  • 启用混合精度 (autocast) + TF32 (torch.backends.cuda.matmul.allow_tf32 = True)
  • 固定输入形状时开启 torch.backends.cudnn.benchmark = True

优化点3: ResBlock融合 (中高优先级)

实际瓶颈:

  • 每个DDIM步骤调用UNet一次
  • UNet包含: 4个下采样阶段 + 1个中间块 + 3个上采样阶段 = 8个阶段
  • 每个阶段有2个ResBlock
  • 总计: 16个ResBlock × 50步 × 2次(阶段1+2) = 1600次ResBlock调用

融合机会:

# 当前: 6次kernel启动
h = group_norm(x)      # kernel 1
h = silu(h)            # kernel 2
h = conv2d(h)          # kernel 3
h = group_norm(h)      # kernel 4
h = silu(h)            # kernel 5
h = conv2d(h)          # kernel 6
out = x + h            # kernel 7

# 优化后: 2-3次kernel启动
h = fused_norm_silu_conv(x)     # kernel 1 (融合)
h = fused_norm_silu_conv(h)     # kernel 2 (融合)
out = fused_residual_add(x, h)  # kernel 3 (融合)

预期收益: 每个ResBlock节省50-60%的kernel启动开销

优化点4: 注意力机制优化 (中优先级)

实际配置:

  • SpatialTransformer: 在每个阶段的每个ResBlock后
  • TemporalTransformer: 在每个阶段的每个ResBlock后
  • 总计: 16个Spatial + 16个Temporal = 32个Transformer × 50步 × 2次 = 3200次注意力调用

理解: Attention 在算子占比中只有约 3%,不是当前主要瓶颈,但若未启用高效实现仍可获得稳定收益。

优化方案:

  • 确认 xformers 已启用 (XFORMERS_IS_AVAILBLE 为 True)
  • 无 xformers 时替换为 PyTorch SDPA:
from torch.nn.functional import scaled_dot_product_attention
out = scaled_dot_product_attention(Q, K, V, is_causal=False)

优化点5: 数据搬运与 CPU 开销 (中优先级)

依据:

  • aten::copy_aten::to/_to_copy 在 CPU 侧耗时突出

建议:

  • 避免在 DDIM 循环内重复 .to(device) / .float() / .half()
  • 将常量张量(如 timestep、sigma)提前放到 GPU
  • 尽量减少临时张量创建与 clone(),尤其是 per-step 级别

优化点6: VAE 解码 (低优先级)

依据:

  • VAE 解码 0.57s/次,次于采样瓶颈

建议:

  • 统一使用 autocast 解码
  • 若可容忍轻微质量下降,可降低解码频率或分辨率