52 lines
1.9 KiB
Markdown
52 lines
1.9 KiB
Markdown
我建议优先做这 4 类,都是跨数据集成立的:
|
||
|
||
1. 压 rollout 内环实现
|
||
见 jepa.py:127。现在每步都在做 action_encoder、切片、torch.cat、小规
|
||
模 predict 调用,这种碎片化实现对任何任务都亏。
|
||
通用改法:
|
||
|
||
- 整条 action_sequence 一次性做 action_encoder
|
||
- emb_hist / act_emb_hist 改成预分配 buffer
|
||
- 循环里只做索引覆盖或 copy_
|
||
- 去掉循环内 torch.cat
|
||
|
||
2. 减少热路径里的搬运和同步
|
||
profile 里 aten::copy_ 很重,这不是 TwoRoom 特有问题。重点看
|
||
jepa.py:67 和 jepa.py:186。
|
||
通用目标:
|
||
|
||
- 模型侧张量尽量全程留在 GPU
|
||
- 避免热路径反复 .to(device) / 隐式 layout 修复
|
||
- 到必须和环境交互的边界再一次性转 CPU / numpy
|
||
- 确保进入 predictor 的张量是 contiguous 的,少触发隐式 copy
|
||
|
||
3. 把编译成本移出正式计时
|
||
现在 torch.compile 默认开在 predictor,见 eval.py:70。102s -> 45s 很
|
||
像首轮编译预热。
|
||
通用做法:
|
||
|
||
- 在正式 start_time 前做一次 dummy predict 或 dummy rollout
|
||
- 保留只编译 predictor/predict,不要编译整个 solver
|
||
|
||
4. 减少临时对象和 shape bookkeeping
|
||
这是所有任务都会受益的。
|
||
重点看:
|
||
|
||
- jepa.py:100 到 jepa.py:106
|
||
- jepa.py:143 到 jepa.py:148
|
||
方向是:
|
||
- 能循环外做的 reshape,不放循环里
|
||
- 能原地更新,不新建张量
|
||
- 少做 dict 字段增删和中间容器组装
|
||
|
||
不建议优先做的通用性较差方案:
|
||
|
||
- 调 TwoRoom 专属 cache 规则
|
||
- 改数据集采样逻辑
|
||
- 按小数据集特点缩短 horizon
|
||
- 直接改 CEM 超参当“优化”
|
||
|
||
如果你要我直接开始改,我建议第一批只做两件事:
|
||
|
||
- 重写 jepa.py:127 这段 rollout,去掉循环内 action_encoder + cat
|
||
- 在 eval.py:306 前加 compile warmup |