40 KiB
UnifoLM World Model Action - 模型架构详细分析
目录
1. 整体架构概览
1.1 模型层次结构
DDPM (顶层模型)
├── DiffusionWrapper (条件包装器)
│ └── UNet3D (核心扩散模型)
│ ├── 时间嵌入 (Time Embedding)
│ ├── 下采样块 (Downsampling Blocks)
│ ├── 中间块 (Middle Blocks)
│ └── 上采样块 (Upsampling Blocks)
├── VAE (变分自编码器)
│ ├── Encoder (编码器)
│ └── Decoder (解码器)
├── CLIP Image Encoder (图像编码器)
├── Text Encoder (文本编码器)
├── State Projector (状态投影器)
└── Action Projector (动作投影器)
1.2 推理阶段数据流
输入观测 (Observation)
↓
[1] 条件编码阶段
├── 图像 → CLIP Encoder → Image Embedding
├── 图像 → VAE Encoder → Latent Condition
├── 文本 → Text Encoder → Text Embedding
├── 状态 → State Projector → State Embedding
└── 动作 → Action Projector → Action Embedding
↓
[2] DDIM采样阶段 (n步迭代)
├── 初始化噪声 x_T
└── For step in [0, n]:
├── 模型前向传播 (UNet3D)
│ ├── 时间步嵌入
│ ├── 条件注入 (CrossAttention)
│ └── 噪声预测
├── DDIM更新公式
└── x_{t-1} = f(x_t, noise_pred)
↓
[3] VAE解码阶段
└── Latent → VAE Decoder → 视频帧
2. 推理流程分析
2.1 阶段1: 生成动作 (sim_mode=False)
目的: 根据观测和指令生成动作序列
输入:
observation.images.top: 历史图像观测 [B, C, T_obs, H, W]observation.state: 历史状态 [B, T_obs, state_dim]action: 历史动作 [B, T_action, action_dim]instruction: 文本指令
输出:
pred_videos: 预测视频 [B, C, T, H, W]pred_actions: 预测动作序列 [B, T, action_dim]
关键特点:
- 动作条件被置零 (
cond_action_emb = torch.zeros_like(...)) - 使用文本指令作为主要引导
2.2 阶段2: 世界模型交互 (sim_mode=True)
目的: 根据动作预测未来观测
输入:
observation.images.top: 当前图像observation.state: 当前状态action: 阶段1生成的动作序列
输出:
pred_videos: 预测的未来视频帧pred_states: 预测的未来状态
关键特点:
- 不使用文本指令 (
text_input=False) - 动作条件被实际使用
3. 核心组件详解
3.1 DDIM采样器 (DDIMSampler)
代码位置: src/unifolm_wma/models/samplers/ddim.py
核心方法: ddim_sampling() (第168-300行)
实际代码结构:
def ddim_sampling(self, cond, shape, x_T=None, ddim_steps=50, ...):
# 初始化
timesteps = self.ddim_timesteps[:ddim_steps]
x = x_T if x_T is not None else torch.randn(shape, device=device)
# 主循环
for i, step in enumerate(iterator):
# 获取时间步
index = total_steps - i - 1
ts = torch.full((b,), step, device=device, dtype=torch.long)
# 模型前向传播 (核心瓶颈)
outs = self.p_sample_ddim(x, cond, ts, index=index, ...)
x, pred_x0 = outs
return x
性能数据 (来自 profiling 报告,--ddim_steps 50):
- DDIM采样调用: 22次 (action_generation + world_model_interaction 各11次)
- 单次采样(50步)平均耗时: 35.58s (总计 782.70s)
- 平均每步耗时: ~0.712s (35.58s / 50)
- 当前
unconditional_guidance_scale=1.0时每步 1 次 UNet 前向;开启 CFG 时每步 2 次前向
3.2 DiffusionWrapper (条件路由器)
代码位置: src/unifolm_wma/models/ddpms.py:2413-2524
作用: 将输入和条件路由到内部扩散模型
实际代码 (第2469-2479行):
elif self.conditioning_key == 'hybrid':
xc = torch.cat([x] + c_concat, dim=1) # 拼接latent条件
cc = torch.cat(c_crossattn, 1) # 拼接cross-attention条件
cc_action = c_crossattn_action
out = self.diffusion_model(xc, x_action, x_state, t,
context=cc, context_action=cc_action, **kwargs)
条件类型:
- c_concat: 通道拼接条件 (VAE编码的图像)
- c_crossattn: 交叉注意力条件 (文本、图像、状态、动作embedding)
- c_crossattn_action: 动作头专用条件
3.3 WMAModel (核心扩散模型)
代码位置: src/unifolm_wma/modules/networks/wma_model.py:326-849
配置文件: configs/inference/world_model_interaction.yaml:69-104
实际配置参数:
in_channels: 8 # 输入通道 (4 latent + 4 VAE条件)
out_channels: 4 # 输出通道
model_channels: 320 # 基础通道数
channel_mult: [1, 2, 4, 4] # 通道倍增: [320, 640, 1280, 1280]
num_res_blocks: 2 # 每个分辨率2个ResBlock
attention_resolutions: [4, 2, 1] # 在这些分辨率启用注意力
num_head_channels: 64 # 每个注意力头64通道
transformer_depth: 1 # Transformer深度
context_dim: 1024 # 交叉注意力上下文维度
temporal_length: 16 # 时间序列长度
架构层次 (详见附录A.1):
- 4个下采样阶段 (每阶段2个ResBlock + Attention)
- 1个中间块 (2个ResBlock + Attention)
- 3个上采样阶段 (每阶段2个ResBlock + Attention)
- 总计: 16个ResBlock + 32个Transformer
3.4 VAE (变分自编码器)
代码位置: src/unifolm_wma/models/autoencoder.py
配置文件: configs/inference/world_model_interaction.yaml:159-180
实际配置参数:
AutoencoderKL:
embed_dim: 4 # Latent维度
z_channels: 4 # Latent通道数
in_channels: 3 # RGB输入
out_ch: 3 # RGB输出
ch: 128 # 基础通道数
ch_mult: [1, 2, 4, 4] # 通道倍增: [128, 256, 512, 512]
num_res_blocks: 2 # 每层2个ResBlock
attn_resolutions: [] # VAE中不使用注意力
编码/解码过程:
# 编码: [B, 3, 320, 512] → [B, 4, 40, 64] (8×8下采样)
z = model.encode_first_stage(img)
# 解码: [B, 4, 40, 64] → [B, 3, 320, 512]
video = model.decode_first_stage(samples)
性能数据:
- VAE编码: 0.90s (22次, 平均0.041s/次)
- VAE解码: 12.44s (22次, 平均0.566s/次)
- 压缩比: 8×8 = 64倍空间压缩
详细架构: 见附录A.4
3.5 条件编码器
性能说明: 本次 profiling 未对各条件编码器单独计时,统一计入 synthesis/conditioning_prep,总计 2.92s (22次, 平均0.133s/次)。
3.5.1 CLIP图像编码器
代码位置: src/unifolm_wma/modules/encoders/condition.py - FrozenOpenCLIPImageEmbedderV2
配置文件: configs/inference/world_model_interaction.yaml:188-204
实际配置:
FrozenOpenCLIPImageEmbedderV2:
freeze: true
# 使用OpenCLIP ViT-H/14
# 输出: [B, 1280]
Resampler (图像投影器):
dim: 1024 # 输出维度
depth: 4 # Transformer深度
heads: 12 # 12个注意力头
num_queries: 16 # 16个查询token
embedding_dim: 1280 # CLIP输出维度
数据流: 图像 [B, 3, H, W] → CLIP → [B, 1280] → Resampler → [B, 16, 1024]
3.5.2 文本编码器
代码位置: src/unifolm_wma/modules/encoders/condition.py - FrozenOpenCLIPEmbedder
配置文件: configs/inference/world_model_interaction.yaml:182-186
实际配置:
FrozenOpenCLIPEmbedder:
freeze: True
layer: "penultimate" # 使用倒数第二层
# 输出: [B, seq_len, 1024]
3.5.3 状态投影器
代码位置: src/unifolm_wma/models/ddpms.py:2014-2026 - MLPProjector
MLPProjector实现 (src/unifolm_wma/utils/projector.py:14-37):
class MLPProjector(nn.Module):
def __init__(self, input_dim: int, output_dim: int, mlp_type: str = "gelu-mlp"):
if mlp_type == "gelu-mlp":
self.projector = nn.Sequential(
nn.Linear(input_dim, output_dim, bias=True),
nn.GELU(approximate='tanh'),
nn.Linear(output_dim, output_dim, bias=True),
)
数据流: 状态 [B, T_obs, 16] → MLPProjector → [B, T_obs, 1024] + agent_state_pos_emb
3.5.4 动作投影器
代码位置: src/unifolm_wma/models/ddpms.py:2020-2024 - MLPProjector
数据流: 动作 [B, T_action, 16] → MLPProjector → [B, T_action, 1024] + agent_action_pos_emb
位置嵌入定义:
# ddpms.py:2023-2026
self.agent_action_pos_emb = nn.Parameter(torch.randn(1, 16, 1024))
self.agent_state_pos_emb = nn.Parameter(torch.randn(1, n_obs_steps, 1024))
4. 性能瓶颈分析
4.1 时间分布 (profiling 报告, 11次迭代 / 22次采样)
说明: profile_section 存在嵌套,宏观统计的总和不是 wall time,以下以每段的 total/avg 为准。
| Section | Count | Total(s) | Avg(s) | 说明 |
|---|---|---|---|---|
| iteration_total | 11 | 836.79 | 76.07 | 单次迭代总耗时 |
| action_generation | 11 | 399.71 | 36.34 | 生成动作 (DDIM 50步) |
| world_model_interaction | 11 | 398.38 | 36.22 | 世界模型交互 (DDIM 50步) |
| synthesis/ddim_sampling | 22 | 782.70 | 35.58 | 单次采样 |
| synthesis/conditioning_prep | 22 | 2.92 | 0.13 | 条件编码汇总 |
| synthesis/decode_first_stage | 22 | 12.44 | 0.57 | VAE解码 |
| save_results | 11 | 38.67 | 3.52 | I/O保存 |
| model_loading/config | 1 | 49.77 | 49.77 | 一次性开销 |
| model_loading/checkpoint | 1 | 11.83 | 11.83 | 一次性开销 |
| model_to_cuda | 1 | 8.91 | 8.91 | 一次性开销 |
4.2 DDIM采样详细分析
DDIM采样是主要瓶颈 (基于 50 步采样):
- 采样调用次数: 22 次 (11 次迭代 × 2 阶段)
- 采样总耗时: 782.70s,平均 35.58s/次
- 平均每步耗时: ~0.712s (35.58s / 50)
unconditional_guidance_scale=1.0时每步 1 次 UNet 前向;开启 CFG 时每步 2 次前向- 在
conditioning_prep + ddim_sampling + decode_first_stage中,ddim_sampling 占约 98%
4.3 瓶颈总结
关键发现:
- DDIM采样占比最高 - 单次迭代平均 76.07s,其中采样约 71.15s (≈93%)
- CUDA算子时间主要集中在 Linear/GEMM(29.8%) 与 Convolution(13.9%);Attention 约 3.0%
- CPU侧 copy/to 仍明显 (
aten::copy_,aten::to/_to_copy在报告中耗时靠前) - VAE解码为次级瓶颈 (0.57s/次)
4.4 GPU显存概览
- Peak allocated: 17890.50 MB
- Average allocated: 16129.98 MB
5. 内核融合优化建议
5.1 优化策略概览
基于性能分析,优化应聚焦于:
- UNet3D模型前向传播 (最高优先级)
- VAE解码器 (次要优先级)
- 批处理和并行化 (辅助优化)
5.2 WMAModel内核融合机会
5.2.1 时间步嵌入融合
代码位置: src/unifolm_wma/utils/diffusion.py - timestep_embedding()
当前实现 (实际代码):
# 1. 正弦位置编码
def timestep_embedding(timesteps, dim, max_period=10000):
half = dim // 2
freqs = torch.exp(-math.log(max_period) * torch.arange(0, half) / half)
args = timesteps[:, None].float() * freqs[None]
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
return embedding
# 2. 时间嵌入网络 (在WMAModel.__init__中)
self.time_embed = nn.Sequential(
nn.Linear(model_channels, time_embed_dim), # 320 → 1280
nn.SiLU(),
nn.Linear(time_embed_dim, time_embed_dim), # 1280 → 1280
)
融合机会:
Linear + SiLU + Linear可融合为单个kernel- 正弦编码计算可与第一个Linear融合
预期收益: 减少2-3次kernel启动开销
5.2.2 ResBlock内核融合
代码位置: src/unifolm_wma/modules/networks/wma_model.py:130-263 - class ResBlock
当前实现 (实际代码):
# in_layers: GroupNorm + SiLU + Conv
self.in_layers = nn.Sequential(
normalization(channels), # GroupNorm
nn.SiLU(),
conv_nd(dims, channels, out_channels, 3, padding=1)
)
# emb_layers: SiLU + Linear
self.emb_layers = nn.Sequential(
nn.SiLU(),
nn.Linear(emb_channels, out_channels)
)
# out_layers: GroupNorm + SiLU + Dropout + Conv
self.out_layers = nn.Sequential(
normalization(out_channels), # GroupNorm
nn.SiLU(),
nn.Dropout(p=dropout),
zero_module(nn.Conv2d(out_channels, out_channels, 3, padding=1))
)
融合机会:
GroupNorm + SiLU可融合 (in_layers和out_layers各一次)emb_layers的SiLU + Linear可融合- 残差加法可与下一层的GroupNorm融合
实际瓶颈: 16个ResBlock × 50步 × 2次 = 1600次ResBlock调用
预期收益: 每个ResBlock节省50-60%的kernel启动开销
5.2.3 注意力机制优化
代码位置: src/unifolm_wma/modules/attention.py
实际配置:
- SpatialTransformer: 空间维度注意力
- TemporalTransformer: 时间维度注意力
- 总计: 32个Transformer × 50步 × 2次 = 3200次注意力调用
优化方案: 使用 PyTorch 内置的 Flash Attention:
from torch.nn.functional import scaled_dot_product_attention
# 替换标准注意力计算
out = scaled_dot_product_attention(Q, K, V, is_causal=False)
预期收益: 注意力层加速2-3倍,整体加速30-40%
5.3 VAE解码器优化
代码位置: src/unifolm_wma/models/autoencoder.py
当前性能: 12.44s (22次调用, 平均0.566s/次)
优化方案:
-
混合精度: 使用FP16进行解码
with torch.cuda.amp.autocast(): video = vae.decode(latent) -
批处理优化: 确保VAE解码使用批处理而非逐帧
预期收益: 加速20-30%
5.4 实施建议
5.4.1 使用 torch.compile() (最简单)
代码位置: scripts/evaluation/world_model_interaction.py
实际实施位置: 在模型加载后添加:
# 在模型加载并移动到GPU后添加
config = OmegaConf.load(args.config)
model = instantiate_from_config(config.model)
model = load_model_checkpoint(model, args.ckpt_path)
model.eval()
model = model.cuda()
# 添加 torch.compile() 优化
model.model.diffusion_model = torch.compile(
model.model.diffusion_model,
mode='max-autotune', # 或 'reduce-overhead'
fullgraph=True
)
优点:
- 无需修改模型代码
- 自动融合操作
- 支持动态shape
预期收益: 20-40%加速
5.4.2 使用 Flash Attention
代码位置: src/unifolm_wma/modules/attention.py
当前实现分析:
代码已经支持 xformers (xformers.ops.memory_efficient_attention)。当 xformers 可用时,CrossAttention 类会自动使用 efficient_forward 方法:
# attention.py:90-91
if XFORMERS_IS_AVAILBLE and temporal_length is None:
self.forward = self.efficient_forward
进一步优化方案: 如果 xformers 不可用,可以使用 PyTorch 内置的 Flash Attention:
from torch.nn.functional import scaled_dot_product_attention
out = scaled_dot_product_attention(q, k, v, is_causal=False)
预期收益: 注意力层加速2-3倍
5.4.3 混合精度推理
代码位置: scripts/evaluation/world_model_interaction.py
实际实施位置: 在推理调用处添加混合精度上下文:
# 在 image_guided_synthesis_sim_mode 调用处添加
with torch.cuda.amp.autocast():
pred_videos, pred_actions, pred_states = image_guided_synthesis_sim_mode(
model, sample['instruction'], observation, noise_shape,
action_cond_step=args.exe_steps,
ddim_steps=args.ddim_steps, ddim_eta=args.ddim_eta,
unconditional_guidance_scale=args.unconditional_guidance_scale,
fs=model_input_fs, timestep_spacing=args.timestep_spacing,
guidance_rescale=args.guidance_rescale, sim_mode=False)
注意事项:
- 模型会自动在FP16和FP32之间切换
- 某些操作(如LayerNorm)会自动保持FP32精度
- 无需手动转换模型权重
预期收益: 30-50%加速 + 减少50%显存
5.5 优化路线图
阶段1: 快速优化
目标: 获得20-40%加速,无需修改模型代码
实施步骤:
- 启用
torch.compile()- 在模型加载后添加 - 启用
torch.backends.cudnn.benchmark = True- 在推理开始前设置 - 使用混合精度推理 (FP16) - 在推理调用处添加
实施代码:
# 在推理函数开始处添加
torch.backends.cudnn.benchmark = True
# 在模型加载后添加 torch.compile()
model = model.cuda()
model.model.diffusion_model = torch.compile(
model.model.diffusion_model,
mode='max-autotune'
)
# 在推理循环中使用混合精度
with torch.cuda.amp.autocast():
pred_videos, pred_actions, pred_states = image_guided_synthesis_sim_mode(...)
阶段2: 中级优化
目标: 获得50-70%加速
实施步骤:
- 确保 xformers 已安装并启用 - 检查
XFORMERS_IS_AVAILBLE标志 - 优化VAE解码器批处理 - 检查 src/unifolm_wma/models/autoencoder.py 中的
decode()方法 - 分析并优化内存访问模式 - 使用
torch.cuda.memory_stats()分析
关键修改点:
- 确认 xformers 已正确安装:
pip install xformers - 在
CrossAttention类中,当 xformers 可用时会自动使用efficient_forward - 确保VAE解码使用批处理而非逐帧处理
阶段3: 深度优化
目标: 获得2-3倍加速
实施步骤:
- 自定义CUDA kernel融合关键操作
- 融合 GroupNorm + SiLU + Conv (在 src/unifolm_wma/modules/networks/wma_model.py:130-263 的 ResBlock 中)
- 优化卷积操作
- 分析 Conv2D 操作的性能 (模型实际使用 Conv2D 而非 Conv3D)
- 优化数据加载和预处理pipeline
需要的技能:
- CUDA编程
- PyTorch C++扩展
- 性能分析工具 (Nsight Systems, nvprof)
5.6 预期总体收益
基于以上优化策略和实际代码分析,预期性能提升:
| 优化阶段 | 预期加速比 | 实施难度 | 主要修改文件 |
|---|---|---|---|
| 阶段1: 快速优化 | 1.2-1.4x | 低 | scripts/evaluation/world_model_interaction.py |
| 阶段2: 中级优化 | 1.5-1.7x | 中 | src/unifolm_wma/modules/attention.py, src/unifolm_wma/models/autoencoder.py |
| 阶段3: 深度优化 | 2.0-3.0x | 高 | src/unifolm_wma/modules/networks/wma_model.py + 自定义CUDA kernel |
总体目标: 通过系统性优化实现 2-3倍加速
6. 关键代码位置索引
为方便内核融合实施,以下是关键代码位置:
6.1 核心模型文件
| 组件 | 文件路径 | 关键类/函数 |
|---|---|---|
| DDPM主模型 | src/unifolm_wma/models/ddpms.py |
class DDPM |
| 条件包装器 | src/unifolm_wma/models/ddpms.py:2413 |
class DiffusionWrapper |
| DDIM采样器 | src/unifolm_wma/models/samplers/ddim.py |
class DDIMSampler |
| VAE编解码 | src/unifolm_wma/models/autoencoder.py |
encode_first_stage, decode_first_stage |
6.2 推理脚本
| 文件 | 说明 |
|---|---|
scripts/evaluation/world_model_interaction.py |
推理脚本 |
6.3 配置文件
| 文件 | 说明 |
|---|---|
configs/inference/world_model_interaction.yaml |
推理配置 |
unitree_g1_pack_camera/case1/run_world_model_interaction.sh |
运行脚本 |
7. 下一步行动建议
7.1 立即可执行的优化
最小改动,最大收益:
-
启用 torch.compile()
- 代码位置: scripts/evaluation/world_model_interaction.py
- 修改: 在模型加载后添加
# 在模型加载并移动到GPU后添加 model.model.diffusion_model = torch.compile( model.model.diffusion_model, mode='max-autotune' ) -
启用 cuDNN benchmark
- 代码位置: scripts/evaluation/world_model_interaction.py
- 修改: 在推理函数开始处添加
torch.backends.cudnn.benchmark = True -
混合精度推理
- 代码位置: scripts/evaluation/world_model_interaction.py
- 修改: 在
image_guided_synthesis_sim_mode调用处添加
with torch.cuda.amp.autocast(): pred_videos, pred_actions, pred_states = image_guided_synthesis_sim_mode(...)
预期收益: 20-40%加速,无风险
7.2 需要深入探索的部分
为了更精确的优化,建议进一步分析:
-
注意力层的具体实现
- 代码位置: src/unifolm_wma/modules/attention.py
- 分析目标:
CrossAttention类 (第48-398行) - 核心注意力实现BasicTransformerBlock类 (第400-469行) - Transformer块- 确认 xformers 是否已启用 (
XFORMERS_IS_AVAILBLE标志)
-
ResBlock的详细结构
- 代码位置: src/unifolm_wma/modules/networks/wma_model.py:130-263
- 分析目标:
- 确认 GroupNorm + SiLU + Conv 的调用顺序
- 识别可以融合的操作序列
- 评估自定义 CUDA kernel 的可行性
-
内存瓶颈分析
- 分析工具: 使用
torch.cuda.memory_stats()和torch.profiler - 分析位置: scripts/evaluation/world_model_interaction.py
- 分析目标:
- 识别内存拷贝热点
- 优化中间张量的生命周期
- 减少不必要的内存分配
- 分析工具: 使用
-
计算瓶颈定位
- 分析工具: Nsight Systems 或 PyTorch Profiler
- 分析目标:
- 识别 kernel 启动开销
- 分析 GPU 利用率
- 找到计算密集型操作
8. 参考资料
8.1 优化技术文档
8.2 相关论文
- DDIM: Denoising Diffusion Implicit Models
- Flash Attention: Fast and Memory-Efficient Exact Attention
- Efficient Diffusion Models for Vision
9. 总结
9.1 关键发现
- DDIM采样仍是主要瓶颈 - 单次采样(50步)平均 35.58s
- Linear/GEMM 与 Convolution 为主要 CUDA 时间来源 - Attention 占比相对较小
- VAE解码为次级优化目标 - 0.57s/次
9.2 优化优先级
高优先级 (立即实施):
- ✅ torch.compile()
- ✅ cuDNN benchmark
- ✅ 混合精度推理
中优先级 (1周内):
- Flash Attention集成
- VAE批处理优化
低优先级 (长期):
- 自定义CUDA kernel
- 模型架构改进
9.3 预期成果
通过系统性优化,预期可获得 1.5-3倍加速 (视采样步数与编译/混合精度策略而定)。
文档版本: v1.2 创建日期: 2026-01-17 最后更新: 2026-01-18 更新内容: 校准DDIM步数为50并替换为最新profiling数据
附录A: 实际模型架构详解
基于代码分析,以下是真实的模型实现细节。
A.1 WMAModel 实际配置
配置文件: configs/inference/world_model_interaction.yaml:69-104
WMAModel参数:
in_channels: 8 # 输入通道 (4 latent + 4 concat条件)
out_channels: 4 # 输出通道 (latent空间)
model_channels: 320 # 基础通道数
channel_mult: [1, 2, 4, 4] # 通道倍增: [320, 640, 1280, 1280]
num_res_blocks: 2 # 每个分辨率2个ResBlock
attention_resolutions: [4, 2, 1] # 在这些分辨率启用注意力
num_head_channels: 64 # 每个注意力头64通道
transformer_depth: 1 # Transformer深度
context_dim: 1024 # 交叉注意力上下文维度
temporal_length: 16 # 时间序列长度
dropout: 0.1
架构层次:
输入: [B, 8, 16, 40, 64] (8通道 = 4 latent + 4 VAE条件)
↓
下采样路径 (4个阶段):
Stage 0: [B, 320, 16, 40, 64] - 2个ResBlock + SpatialTransformer + TemporalTransformer
Stage 1: [B, 640, 16, 20, 32] - Downsample + 2个ResBlock + Attention
Stage 2: [B, 1280, 16, 10, 16] - Downsample + 2个ResBlock + Attention
Stage 3: [B, 1280, 16, 5, 8] - Downsample + 2个ResBlock + Attention
↓
中间块: [B, 1280, 16, 5, 8]
- ResBlock + SpatialTransformer + TemporalTransformer + ResBlock
↓
上采样路径 (3个阶段):
Stage 2: [B, 1280, 16, 10, 16] - Upsample + 2个ResBlock + Attention
Stage 1: [B, 640, 16, 20, 32] - Upsample + 2个ResBlock + Attention
Stage 0: [B, 320, 16, 40, 64] - Upsample + 2个ResBlock + Attention
↓
输出: [B, 4, 16, 40, 64] (预测的噪声或速度)
A.2 ResBlock 实际实现
位置: src/unifolm_wma/modules/networks/wma_model.py:130-263
实际代码结构:
class ResBlock:
def __init__(self, channels, emb_channels, dropout, ...):
# 输入层: GroupNorm + SiLU + Conv
self.in_layers = nn.Sequential(
normalization(channels), # GroupNorm
nn.SiLU(), # 激活函数
conv_nd(dims, channels, out_channels, 3, padding=1)
)
# 时间步嵌入层: SiLU + Linear
self.emb_layers = nn.Sequential(
nn.SiLU(),
nn.Linear(emb_channels, out_channels)
)
# 输出层: GroupNorm + SiLU + Dropout + Conv
self.out_layers = nn.Sequential(
normalization(out_channels), # GroupNorm
nn.SiLU(),
nn.Dropout(p=dropout),
zero_module(nn.Conv2d(out_channels, out_channels, 3, padding=1))
)
# 残差连接
self.skip_connection = ...
# 可选的时间卷积
if use_temporal_conv:
self.temporal_conv = TemporalConvBlock(...)
def forward(self, x, emb):
h = self.in_layers(x) # GroupNorm + SiLU + Conv
emb_out = self.emb_layers(emb) # 时间步嵌入
h = h + emb_out # 加入时间步信息
h = self.out_layers(h) # GroupNorm + SiLU + Dropout + Conv
h = self.skip_connection(x) + h # 残差连接
if use_temporal_conv:
h = self.temporal_conv(h) # 时间维度卷积
return h
内核融合机会:
GroupNorm + SiLU可融合 (in_layers和out_layers各一次)emb_layers(SiLU + Linear)可融合残差加法 + 下一层的GroupNorm可融合
A.3 注意力机制实际实现
SpatialTransformer (空间注意力):
- 位置:
src/unifolm_wma/modules/attention.py:472-558 - 在特征图的空间维度(H×W)上执行自注意力和交叉注意力
- 使用
transformer_depth=1,即每个位置1层Transformer - 当 xformers 可用时,使用
efficient_forward方法进行高效注意力计算
TemporalTransformer (时间注意力):
- 位置:
src/unifolm_wma/modules/attention.py:561-680 - 在时间维度(T=16帧)上执行自注意力
- 配置:
temporal_selfatt_only=True(仅时间自注意力,不做交叉注意力) - 使用相对位置编码:
use_relative_position=False(实际未启用)
CrossAttention (核心注意力层):
- 位置:
src/unifolm_wma/modules/attention.py:48-398 - 支持多种交叉注意力: 图像、文本、状态、动作
- 当 xformers 可用时自动使用
xformers.ops.memory_efficient_attention
注意力头配置:
num_head_channels=64: 每个头64通道- 对于320通道: 320/64 = 5个注意力头
- 对于640通道: 640/64 = 10个注意力头
- 对于1280通道: 1280/64 = 20个注意力头
A.4 VAE 实际配置
配置文件: configs/inference/world_model_interaction.yaml:159-180
AutoencoderKL:
embed_dim: 4 # Latent维度
z_channels: 4 # Latent通道数
resolution: 256 # 基础分辨率
in_channels: 3 # RGB输入
out_ch: 3 # RGB输出
ch: 128 # 基础通道数
ch_mult: [1, 2, 4, 4] # 通道倍增
num_res_blocks: 2 # 每层2个ResBlock
attn_resolutions: [] # VAE中不使用注意力
dropout: 0.0
编码器架构:
输入: [B, 3, 320, 512]
↓ Conv 3→128
↓ ResBlock×2 [128, 320, 512]
↓ Downsample [128, 160, 256]
↓ ResBlock×2 [256, 160, 256]
↓ Downsample [256, 80, 128]
↓ ResBlock×2 [512, 80, 128]
↓ Downsample [512, 40, 64]
↓ ResBlock×2 [512, 40, 64]
↓ ResBlock + Conv
输出: [B, 4, 40, 64] (8×8下采样)
解码器架构 (编码器的镜像):
输入: [B, 4, 40, 64]
↓ Conv + ResBlock
↓ ResBlock×2 [512, 40, 64]
↓ Upsample [512, 80, 128]
↓ ResBlock×2 [512, 80, 128]
↓ Upsample [256, 160, 256]
↓ ResBlock×2 [256, 160, 256]
↓ Upsample [128, 320, 512]
↓ ResBlock×2 [128, 320, 512]
↓ Conv 128→3
输出: [B, 3, 320, 512]
A.5 条件编码器实际配置
CLIP图像编码器
配置: configs/inference/world_model_interaction.yaml:188-191
FrozenOpenCLIPImageEmbedderV2:
freeze: true
# 使用OpenCLIP的ViT-H/14模型
# 输出维度: 1280
图像投影器 (Resampler):
Resampler:
dim: 1024 # 输出维度
depth: 4 # Transformer深度
dim_head: 64 # 注意力头维度
heads: 12 # 12个注意力头
num_queries: 16 # 16个查询token
embedding_dim: 1280 # CLIP输出维度
output_dim: 1024 # 最终输出维度
video_length: 16 # 视频长度
数据流:
图像 [B, 3, H, W]
↓ CLIP Encoder
↓ [B, 1280]
↓ Resampler (Perceiver-style)
↓ [B, 16, 1024] (16个token,每个1024维)
文本编码器
配置: configs/inference/world_model_interaction.yaml:182-186
FrozenOpenCLIPEmbedder:
freeze: True
layer: "penultimate" # 使用倒数第二层
# 输出维度: 1024
数据流:
文本指令 "pick up the box"
↓ OpenCLIP Text Encoder
↓ [B, seq_len, 1024]
动作/状态投影器
代码位置: src/unifolm_wma/models/ddpms.py:2014-2026
MLPProjector实现 (src/unifolm_wma/utils/projector.py:14-37):
class MLPProjector(nn.Module):
def __init__(self, input_dim: int, output_dim: int, mlp_type: str = "gelu-mlp"):
if mlp_type == "gelu-mlp":
self.projector = nn.Sequential(
nn.Linear(input_dim, output_dim, bias=True),
nn.GELU(approximate='tanh'),
nn.Linear(output_dim, output_dim, bias=True),
)
初始化代码 (ddpms.py:2014-2026):
# 状态投影器
self.state_projector = MLPProjector(agent_state_dim, 1024) # 16 → 1024
self.action_projector = MLPProjector(agent_action_dim, 1024) # 16 → 1024
# 位置嵌入
self.agent_action_pos_emb = nn.Parameter(torch.randn(1, 16, 1024))
self.agent_state_pos_emb = nn.Parameter(torch.randn(1, n_obs_steps, 1024))
数据流:
状态 [B, T_obs, 16]
↓ MLPProjector (Linear + GELU + Linear)
↓ [B, T_obs, 1024]
↓ + agent_state_pos_emb
↓ [B, T_obs, 1024]
动作 [B, T_action, 16]
↓ MLPProjector (Linear + GELU + Linear)
↓ [B, T_action, 1024]
↓ + agent_action_pos_emb
↓ [B, T_action, 1024]
A.6 时间步嵌入实际实现
位置: src/unifolm_wma/utils/diffusion.py:timestep_embedding
实际代码:
def timestep_embedding(timesteps, dim, max_period=10000):
"""
创建正弦位置编码
"""
half = dim // 2
freqs = torch.exp(
-math.log(max_period) * torch.arange(start=0, end=half) / half
)
args = timesteps[:, None].float() * freqs[None]
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
return embedding
在WMAModel中的使用:
# 时间步 t ∈ [0, 999]
t_emb = timestep_embedding(t, model_channels) # [B, 320]
t_emb = self.time_embed(t_emb) # Linear(320 → 1280)
# 输出: [B, 1280]
时间嵌入网络:
t ∈ [0, 999]
↓ timestep_embedding (正弦编码)
↓ [B, 320]
↓ Linear(320 → 1280)
↓ SiLU
↓ Linear(1280 → 1280)
↓ [B, 1280]
A.7 动作头 (ConditionalUnet1D) 实际配置
配置: configs/inference/world_model_interaction.yaml:106-127
ConditionalUnet1D:
input_dim: 16 # 动作维度
n_obs_steps: 2 # 观测步数
diffusion_step_embed_dim: 128 # 扩散步嵌入维度
down_dims: [256, 512, 1024, 2048] # 下采样通道
kernel_size: 5 # 卷积核大小
n_groups: 8 # GroupNorm分组数
horizon: 16 # 预测时间范围
use_linear_attn: true # 使用线性注意力
imagen_cond_gradient: true # 使用图像条件梯度
架构:
输入: 噪声动作 [B, 16, 16] (16维动作 × 16步)
条件:
- 图像特征 [B, C, H, W] 来自WMAModel中间层
- 观测编码 [B, n_obs, obs_dim]
↓ Conv1D + ResBlock
↓ [B, 256, 16]
↓ Downsample + ResBlock
↓ [B, 512, 8]
↓ Downsample + ResBlock
↓ [B, 1024, 4]
↓ Downsample + ResBlock
↓ [B, 2048, 2]
↓ Middle Block (with attention)
↓ Upsample + ResBlock
↓ [B, 1024, 4]
↓ Upsample + ResBlock
↓ [B, 512, 8]
↓ Upsample + ResBlock
↓ [B, 256, 16]
↓ Conv1D
输出: 预测噪声 [B, 16, 16]
A.8 完整前向传播流程
基于实际代码,完整的前向传播流程如下:
# 1. 条件编码 (一次性完成,可缓存)
cond_img_emb = clip_encoder(img) → resampler → [B, 16, 1024]
cond_text_emb = text_encoder(text) → [B, seq_len, 1024]
cond_state_emb = state_projector(state) + pos_emb → [B, T_obs, 1024]
cond_action_emb = action_projector(action) + pos_emb → [B, T_action, 1024]
cond_latent = vae.encode(img) → [B, 4, T, 40, 64]
# 2. 拼接条件
c_concat = [cond_latent] # 通道拼接
c_crossattn = [cond_text_emb, cond_img_emb, cond_state_emb, cond_action_emb]
c_crossattn = torch.cat(c_crossattn, dim=1) # [B, total_tokens, 1024]
# 3. DDIM采样循环 (ddim_steps 默认 50,实际由 --ddim_steps 控制)
x = torch.randn([B, 4, 16, 40, 64]) # 初始噪声
for step in range(ddim_steps):
# 3.1 时间步嵌入
t_emb = timestep_embedding(t, 320) → Linear → [B, 1280]
# 3.2 拼接输入
x_in = torch.cat([x, cond_latent], dim=1) # [B, 8, 16, 40, 64]
# 3.3 UNet前向传播 (核心瓶颈)
noise_pred = wma_model(x_in, t_emb, c_crossattn)
# 包含: 4个下采样阶段 + 中间块 + 3个上采样阶段
# 每个阶段: 2个ResBlock + SpatialTransformer + TemporalTransformer
# 3.4 DDIM更新
x = ddim_update(x, noise_pred, t, t_prev)
# 4. VAE解码
video = vae.decode(x) → [B, 3, 16, 320, 512]
A.9 基于实际架构的优化建议更新
我的理解:
本次 profiling 显示采样阶段占据绝对主导地位:单次采样(50步)平均 35.58s,且每次迭代包含 action_generation 与 world_model_interaction 各一次采样。换句话说,任何“每步的细微改进”都会被 50 步和 2 阶段放大;因此最有效的优化要么减少步数,要么显著加速 UNet 前向。CUDA 时间主要集中在 Linear/GEMM(29.8%) 与 Convolution(13.9%),而 Attention 约 3.0%,这意味着算子层面优先考虑矩阵乘/卷积路径的优化收益更稳定。CPU 侧 aten::copy_/to/_to_copy 也明显,说明循环内的数据搬运仍有成本可省。
优化点1: 采样步数与采样器 (最高优先级)
依据:
- 50步采样平均 35.58s (0.712s/步),减少步数带来近线性收益
- 单次迭代约 76.07s,其中采样约占 93%
建议:
- 在保证质量的前提下,将
--ddim_steps从 50 降到 20-30 - 评估更快采样器(如 DPM-Solver++/UniPC)以减少步数
- 若使用 CFG,注意
unconditional_guidance_scale > 1.0会使每步前向翻倍
优化点2: GEMM/Conv 主导路径加速 (高优先级)
依据:
- CUDA 时间主力来自 Linear/GEMM 与 Convolution
建议:
torch.compile()仅包裹 UNet 主干以获得融合收益- 启用混合精度 (
autocast) + TF32 (torch.backends.cuda.matmul.allow_tf32 = True) - 固定输入形状时开启
torch.backends.cudnn.benchmark = True
优化点3: ResBlock融合 (中高优先级)
实际瓶颈:
- 每个DDIM步骤调用UNet一次
- UNet包含: 4个下采样阶段 + 1个中间块 + 3个上采样阶段 = 8个阶段
- 每个阶段有2个ResBlock
- 总计: 16个ResBlock × 50步 × 2次(阶段1+2) = 1600次ResBlock调用
融合机会:
# 当前: 6次kernel启动
h = group_norm(x) # kernel 1
h = silu(h) # kernel 2
h = conv2d(h) # kernel 3
h = group_norm(h) # kernel 4
h = silu(h) # kernel 5
h = conv2d(h) # kernel 6
out = x + h # kernel 7
# 优化后: 2-3次kernel启动
h = fused_norm_silu_conv(x) # kernel 1 (融合)
h = fused_norm_silu_conv(h) # kernel 2 (融合)
out = fused_residual_add(x, h) # kernel 3 (融合)
预期收益: 每个ResBlock节省50-60%的kernel启动开销
优化点4: 注意力机制优化 (中优先级)
实际配置:
- SpatialTransformer: 在每个阶段的每个ResBlock后
- TemporalTransformer: 在每个阶段的每个ResBlock后
- 总计: 16个Spatial + 16个Temporal = 32个Transformer × 50步 × 2次 = 3200次注意力调用
理解: Attention 在算子占比中只有约 3%,不是当前主要瓶颈,但若未启用高效实现仍可获得稳定收益。
优化方案:
- 确认 xformers 已启用 (
XFORMERS_IS_AVAILBLE为 True) - 无 xformers 时替换为 PyTorch SDPA:
from torch.nn.functional import scaled_dot_product_attention
out = scaled_dot_product_attention(Q, K, V, is_causal=False)
优化点5: 数据搬运与 CPU 开销 (中优先级)
依据:
aten::copy_与aten::to/_to_copy在 CPU 侧耗时突出
建议:
- 避免在 DDIM 循环内重复
.to(device)/.float()/.half() - 将常量张量(如 timestep、sigma)提前放到 GPU
- 尽量减少临时张量创建与
clone(),尤其是 per-step 级别
优化点6: VAE 解码 (低优先级)
依据:
- VAE 解码 0.57s/次,次于采样瓶颈
建议:
- 统一使用
autocast解码 - 若可容忍轻微质量下降,可降低解码频率或分辨率