Files
lewm/sth.md
qihuanye 995cd8cfec 优化 jepa.py 中通用 rollout 热路径:批量预编码动
作、移除循环内
  torch.cat,并为 history_size==1 与环形缓冲区更新
  添加更轻量实现; 收益不大
2026-04-09 11:57:09 +00:00

1.9 KiB
Raw Blame History

我建议优先做这 4 类,都是跨数据集成立的:

  1. 压 rollout 内环实现 见 jepa.py:127。现在每步都在做 action_encoder、切片、torch.cat、小规 模 predict 调用,这种碎片化实现对任何任务都亏。 通用改法:
  • 整条 action_sequence 一次性做 action_encoder
  • emb_hist / act_emb_hist 改成预分配 buffer
  • 循环里只做索引覆盖或 copy_
  • 去掉循环内 torch.cat
  1. 减少热路径里的搬运和同步 profile 里 aten::copy_ 很重,这不是 TwoRoom 特有问题。重点看 jepa.py:67 和 jepa.py:186。 通用目标:
  • 模型侧张量尽量全程留在 GPU
  • 避免热路径反复 .to(device) / 隐式 layout 修复
  • 到必须和环境交互的边界再一次性转 CPU / numpy
  • 确保进入 predictor 的张量是 contiguous 的,少触发隐式 copy
  1. 把编译成本移出正式计时 现在 torch.compile 默认开在 predictor见 eval.py:70。102s -> 45s 很 像首轮编译预热。 通用做法:
  • 在正式 start_time 前做一次 dummy predict 或 dummy rollout
  • 保留只编译 predictor/predict不要编译整个 solver
  1. 减少临时对象和 shape bookkeeping 这是所有任务都会受益的。 重点看:
  • jepa.py:100 到 jepa.py:106
  • jepa.py:143 到 jepa.py:148 方向是:
  • 能循环外做的 reshape不放循环里
  • 能原地更新,不新建张量
  • 少做 dict 字段增删和中间容器组装

不建议优先做的通用性较差方案:

  • 调 TwoRoom 专属 cache 规则
  • 改数据集采样逻辑
  • 按小数据集特点缩短 horizon
  • 直接改 CEM 超参当“优化”

如果你要我直接开始改,我建议第一批只做两件事:

  • 重写 jepa.py:127 这段 rollout去掉循环内 action_encoder + cat
  • 在 eval.py:306 前加 compile warmup