1.9 KiB
1.9 KiB
我建议优先做这 4 类,都是跨数据集成立的:
- 压 rollout 内环实现 见 jepa.py:127。现在每步都在做 action_encoder、切片、torch.cat、小规 模 predict 调用,这种碎片化实现对任何任务都亏。 通用改法:
- 整条 action_sequence 一次性做 action_encoder
- emb_hist / act_emb_hist 改成预分配 buffer
- 循环里只做索引覆盖或 copy_
- 去掉循环内 torch.cat
- 减少热路径里的搬运和同步 profile 里 aten::copy_ 很重,这不是 TwoRoom 特有问题。重点看 jepa.py:67 和 jepa.py:186。 通用目标:
- 模型侧张量尽量全程留在 GPU
- 避免热路径反复 .to(device) / 隐式 layout 修复
- 到必须和环境交互的边界再一次性转 CPU / numpy
- 确保进入 predictor 的张量是 contiguous 的,少触发隐式 copy
- 把编译成本移出正式计时 现在 torch.compile 默认开在 predictor,见 eval.py:70。102s -> 45s 很 像首轮编译预热。 通用做法:
- 在正式 start_time 前做一次 dummy predict 或 dummy rollout
- 保留只编译 predictor/predict,不要编译整个 solver
- 减少临时对象和 shape bookkeeping 这是所有任务都会受益的。 重点看:
- jepa.py:100 到 jepa.py:106
- jepa.py:143 到 jepa.py:148 方向是:
- 能循环外做的 reshape,不放循环里
- 能原地更新,不新建张量
- 少做 dict 字段增删和中间容器组装
不建议优先做的通用性较差方案:
- 调 TwoRoom 专属 cache 规则
- 改数据集采样逻辑
- 按小数据集特点缩短 horizon
- 直接改 CEM 超参当“优化”
如果你要我直接开始改,我建议第一批只做两件事:
- 重写 jepa.py:127 这段 rollout,去掉循环内 action_encoder + cat
- 在 eval.py:306 前加 compile warmup