Files
lewm/sth.md
qihuanye 995cd8cfec 优化 jepa.py 中通用 rollout 热路径:批量预编码动
作、移除循环内
  torch.cat,并为 history_size==1 与环形缓冲区更新
  添加更轻量实现; 收益不大
2026-04-09 11:57:09 +00:00

52 lines
1.9 KiB
Markdown
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
我建议优先做这 4 类,都是跨数据集成立的:
1. 压 rollout 内环实现
见 jepa.py:127。现在每步都在做 action_encoder、切片、torch.cat、小规
模 predict 调用,这种碎片化实现对任何任务都亏。
通用改法:
- 整条 action_sequence 一次性做 action_encoder
- emb_hist / act_emb_hist 改成预分配 buffer
- 循环里只做索引覆盖或 copy_
- 去掉循环内 torch.cat
2. 减少热路径里的搬运和同步
profile 里 aten::copy_ 很重,这不是 TwoRoom 特有问题。重点看
jepa.py:67 和 jepa.py:186。
通用目标:
- 模型侧张量尽量全程留在 GPU
- 避免热路径反复 .to(device) / 隐式 layout 修复
- 到必须和环境交互的边界再一次性转 CPU / numpy
- 确保进入 predictor 的张量是 contiguous 的,少触发隐式 copy
3. 把编译成本移出正式计时
现在 torch.compile 默认开在 predictor见 eval.py:70。102s -> 45s 很
像首轮编译预热。
通用做法:
- 在正式 start_time 前做一次 dummy predict 或 dummy rollout
- 保留只编译 predictor/predict不要编译整个 solver
4. 减少临时对象和 shape bookkeeping
这是所有任务都会受益的。
重点看:
- jepa.py:100 到 jepa.py:106
- jepa.py:143 到 jepa.py:148
方向是:
- 能循环外做的 reshape不放循环里
- 能原地更新,不新建张量
- 少做 dict 字段增删和中间容器组装
不建议优先做的通用性较差方案:
- 调 TwoRoom 专属 cache 规则
- 改数据集采样逻辑
- 按小数据集特点缩短 horizon
- 直接改 CEM 超参当“优化”
如果你要我直接开始改,我建议第一批只做两件事:
- 重写 jepa.py:127 这段 rollout去掉循环内 action_encoder + cat
- 在 eval.py:306 前加 compile warmup