Wrap eval inference in torch.inference_mode

This commit is contained in:
qihuanye
2026-04-09 09:18:35 +00:00
parent 0f85e39690
commit 9e2407cdc4
3 changed files with 791 additions and 12 deletions

25
eval.py
View File

@@ -239,18 +239,19 @@ def run(cfg: DictConfig):
if torch.cuda.is_available():
torch.cuda.synchronize()
start_time = time.time()
with profiler_ctx as profiler:
with inference_ctx:
with torch.profiler.record_function("eval.world_evaluate_from_dataset"):
metrics = world.evaluate_from_dataset(
dataset,
start_steps=eval_start_idx.tolist(),
goal_offset_steps=cfg.eval.goal_offset_steps,
eval_budget=cfg.eval.eval_budget,
episodes_idx=eval_episodes.tolist(),
callables=OmegaConf.to_container(cfg.eval.get("callables"), resolve=True),
video_path=output_dir,
)
with torch.inference_mode():
with profiler_ctx as profiler:
with inference_ctx:
with torch.profiler.record_function("eval.world_evaluate_from_dataset"):
metrics = world.evaluate_from_dataset(
dataset,
start_steps=eval_start_idx.tolist(),
goal_offset_steps=cfg.eval.goal_offset_steps,
eval_budget=cfg.eval.eval_budget,
episodes_idx=eval_episodes.tolist(),
callables=OmegaConf.to_container(cfg.eval.get("callables"), resolve=True),
video_path=output_dir,
)
if torch.cuda.is_available():
torch.cuda.synchronize()
end_time = time.time()

192
sth.md Normal file
View File

@@ -0,0 +1,192 @@
1. 压 rollout 内环
这条最通用,而且基本不改算法语义,只是把实现做对。
在 jepa.py:129 这段里,当前问题是:
- 循环里每步都 action_encoder(next_act),见 jepa.py:159
- history 每步用 torch.cat 重建,见 jepa.py:155 和 jepa.py:162
- 每步都走一次很短的 predict()host 调度比例很高
通用改法:
- 整条 action_sequence 一次性做 action_encoder
- emb_hist / act_emb_hist 改成预分配 buffer
- 用 ring buffer 或 index rotate 更新历史
- 循环里只做 copy_ / 索引覆盖,不做 cat
这个优化对任何数据集都成立因为它优化的是“rolling inference 实现方
式”,不是任务参数。
2. 用 torch.inference_mode()
你现在在 eval.py:242 这里只用了 autocast没有 inference_mode()。
建议推理主路径外层直接包:
with torch.inference_mode():
with inference_ctx:
...
这是纯通用优化,所有数据集都受益。
3. 只编译 predictor / predict不要编译整个 solver
当前热点是大量小 predict() 调用,不是整条 eval graph。
通用建议:
- 先只编译 self.predictor
- 或只编译 JEPA.predict()
- 模式优先试 reduce-overhead
不要先编译整个 WorldModelPolicy 或 CEM solver那通常图不稳定泛化收益
反而差。
4. 减少循环里的张量形状重排和临时对象
这也是实现层通用优化。
可以继续查:
- rearrange 是否能前移到循环外
- 是否有重复的 slice/view 触发隐式拷贝
- pred_proj(rearrange(...)) 这类 reshape 往返是否能合并
这类优化对所有任务都有效,因为是在降 Python 和 tensor bookkeeping 成
本。
5. 再考虑结构级优化,但放后面
比如 predictor 深度、MLP 宽度、heads 数量。这也通用,但已经开始碰模型容
量和精度,不该是第一刀。
不建议优先做的
这些更偏任务/数据集相关,不算你要的“泛用优化”:
- 先调 num_samples/topk/n_steps
- 先缩 horizon
- 先按 tworoom 特性做 shortcut
- 先针对某个 dataset 做 cache 规则
一句话判断
你现在最像是“算法没错,但 rollout 实现过于碎片化”,所以第一优先级应该
是:
一次性 action encode + 预分配历史 buffer + 去掉循环内 torch.cat +
inference_mode + compile predictor
如果你要,我下一步就直接改 jepa.py 做这套通用优化,不碰任何数据集特化逻
辑。
除开 CEM solve 本体,剩下这些杂项可以这样优化。
最高优先级
1. 保证传给环境的是 numpy不要让 Gym 代转
你日志已经说明 env step() 收到了 torch.Tensor。这会带来拷贝、同步、
checker 额外开销。
做法:
- 在 policy 输出动作、准备喂给 env 的那一层,显式转成
action.detach().cpu().numpy()
- 最好一次性转好,别在 env 内部或 wrapper 内隐式转换
收益:
- 去掉 Gym warning
- 减少同步和类型检查开销
- 通常是最直接的非模型提速点
2. 关掉 Gym passive checker
这些 warning 本身就说明 checker 在持续检查类型和空间匹配。
做法:
- 尽量用禁 checker 的构造方式
- 或在你自己的 env wrapper 里保证输入输出符合 Gym 预期,避免它每步检查
收益:
- 每步少一层 Python 校验
- 对长 episode / 多 episode 累积明显
中优先级
3. 把预处理前移,避免每步重复做能缓存的东西
如果 goal、初始条件、某些 dataset 字段在 episode 里不变,就不要每次
都重新组织。
做法:
- goal 相关 embedding 已经有缓存,继续扩展到更多静态字段
- 固定的 callables 参数尽量预解析
- 能在 episode 开头准备好的,不要放在 step 循环里
收益:
- 降低 Python dict 操作和小张量处理开销
4. 避免频繁 CPU/GPU 来回切
如果模型在 GPU但环境在 CPU就要非常小心中间格式。
做法:
- 模型侧尽量连续留在 GPU
- 到真正 env step 前再一次性转 CPU numpy
- 不要中间反复 .cpu() / .to(device) / np.array(...)
收益:
- 减少隐式同步
- 稳定延迟
5. 缩减 Python 层对象操作
dict 组装、字段拷贝、wrapper 嵌套太多时,端到端会慢。
做法:
- 关键热路径里少做深拷贝
- 少重复构造新的 info / obs 容器
- 固定结构优先原地更新
收益:
- 对小步高频调用路径有效
如果你要继续压评测时间
6. 降低日志和 warning 输出
频繁 warning 会拖慢,而且污染 timing。
做法:
- 修掉类型不匹配后 warning 自然消失
- 非必要的 print 尤其是 step 内 print 要去掉
7. 针对环境 step 做批量化检查
如果 num_envs=50尽量确认 env wrapper 没有在内部退化成逐环境 Python
for-loop。
做法:
- 查 world.evaluate_from_dataset() 到 env step() 之间是不是 batch 接口
- 如果 batch env 里还有逐个 env 转换/检查,尽量前移或向量化
收益:
- 这类经常能解释“为什么 solver 时间之外还有很多时间”
8. 把 callables 的执行成本单独看
你这里有:
- _set_state
- _set_goal_state
- 确认它们只在必要时执行
- 能批量设置就别逐条 Python 调
2. 消掉 Gym warning
3. 单独量 env.step 总时间
4. 检查是否有反复 CPU/GPU 转换
5. 再看 wrapper / callable / obs 组装
一句话总结
剩下的杂项优化,核心不是“再多上几张卡”,而是:
- 去掉隐式类型转换
- 去掉多余检查
- 去掉重复数据整理
- 减少 CPU/GPU 往返
- 减少 Python 高频小开销
如果你要,我下一步可以直接帮你定位“动作是在哪一层以 torch.Tensor 传进
env 的”,给你指出具体应该改哪个函数。

View File

@@ -458,3 +458,589 @@ metrics: {'success_rate': 90.0, 'episode_successes': array([ True, False, True,
True, True, True, True, True]), 'seeds': None}
evaluation_time: 86.20240807533264 seconds
inference_precision: fp16
==== CONFIG ====
cache_dir: null
solver:
_target_: stable_worldmodel.solver.CEMSolver
model: ???
batch_size: 1
num_samples: 300
var_scale: 1.0
n_steps: 30
topk: 30
device: cuda
seed: ${seed}
world:
env_name: swm/TwoRoom-v1
num_envs: ${eval.num_eval}
max_episode_steps: 100
history_size: 1
frame_skip: 1
seed: 42
policy: two-room/tworoom/lejepa
inference_precision: fp16
dataset:
stats: ${eval.dataset_name}
keys_to_cache:
- action
- proprio
plan_config:
horizon: 5
receding_horizon: 5
action_block: 5
eval:
num_eval: 50
goal_offset_steps: 25
eval_budget: 50
img_size: 224
dataset_name: tworoom
callables:
- method: _set_state
args:
state:
value: proprio
- method: _set_goal_state
args:
goal_state:
value: goal_proprio
output:
filename: tworoom_results.txt
profile:
enabled: true
export_tensorboard: false
export_chrome_trace: false
==== RESULTS ====
metrics: {'success_rate': 90.0, 'episode_successes': array([ True, False, True, False, True, True, True, True, False,
True, True, True, True, True, True, True, True, True,
True, True, True, True, True, True, True, True, True,
True, True, True, True, False, True, True, True, True,
True, True, True, False, True, True, True, True, True,
True, True, True, True, True]), 'seeds': None}
evaluation_time: 518.512722492218 seconds
inference_precision: fp16
profile_dir: /mnt/ASC1637/lewm_baseline/le-wm/torch_profile
profile_summary: /mnt/ASC1637/lewm_baseline/le-wm/torch_profile/key_averages.txt
==== CONFIG ====
cache_dir: null
solver:
_target_: stable_worldmodel.solver.CEMSolver
model: ???
batch_size: 1
num_samples: 300
var_scale: 1.0
n_steps: 30
topk: 30
device: cuda
seed: ${seed}
world:
env_name: swm/TwoRoom-v1
num_envs: ${eval.num_eval}
max_episode_steps: 100
history_size: 1
frame_skip: 1
seed: 42
policy: two-room/tworoom/lejepa
inference_precision: fp16
dataset:
stats: ${eval.dataset_name}
keys_to_cache:
- action
- proprio
plan_config:
horizon: 5
receding_horizon: 5
action_block: 5
eval:
num_eval: 50
goal_offset_steps: 25
eval_budget: 50
img_size: 224
dataset_name: tworoom
callables:
- method: _set_state
args:
state:
value: proprio
- method: _set_goal_state
args:
goal_state:
value: goal_proprio
output:
filename: tworoom_results.txt
==== RESULTS ====
metrics: {'success_rate': 86.0, 'episode_successes': array([ True, False, True, False, True, True, True, True, False,
True, True, True, True, True, True, True, True, True,
True, True, True, True, True, True, True, True, True,
True, True, True, True, False, True, True, True, False,
True, True, False, False, True, True, True, True, True,
True, True, True, True, True]), 'seeds': None}
evaluation_time: 89.49835586547852 seconds
inference_precision: fp16
==== CONFIG ====
cache_dir: null
solver:
_target_: stable_worldmodel.solver.CEMSolver
model: ???
batch_size: 1
num_samples: 300
var_scale: 1.0
n_steps: 30
topk: 30
device: cuda
seed: ${seed}
world:
env_name: swm/TwoRoom-v1
num_envs: ${eval.num_eval}
max_episode_steps: 100
history_size: 1
frame_skip: 1
seed: 42
policy: two-room/tworoom/lejepa
inference_precision: fp16
dataset:
stats: ${eval.dataset_name}
keys_to_cache:
- action
- proprio
plan_config:
horizon: 5
receding_horizon: 5
action_block: 5
eval:
num_eval: 50
goal_offset_steps: 25
eval_budget: 50
img_size: 224
dataset_name: tworoom
callables:
- method: _set_state
args:
state:
value: proprio
- method: _set_goal_state
args:
goal_state:
value: goal_proprio
output:
filename: tworoom_results.txt
==== RESULTS ====
metrics: {'success_rate': 86.0, 'episode_successes': array([ True, False, True, False, True, True, True, True, False,
True, True, True, True, True, True, True, True, True,
True, True, True, True, True, True, True, True, True,
True, True, True, True, False, True, True, True, False,
True, True, False, False, True, True, True, True, True,
True, True, True, True, True]), 'seeds': None}
evaluation_time: 105.07861399650574 seconds
inference_precision: fp16
==== CONFIG ====
cache_dir: null
solver:
_target_: stable_worldmodel.solver.CEMSolver
model: ???
batch_size: 1
num_samples: 300
var_scale: 1.0
n_steps: 30
topk: 30
device: cuda
seed: ${seed}
world:
env_name: swm/TwoRoom-v1
num_envs: ${eval.num_eval}
max_episode_steps: 100
history_size: 1
frame_skip: 1
seed: 42
policy: two-room/tworoom/lejepa
inference_precision: fp16
dataset:
stats: ${eval.dataset_name}
keys_to_cache:
- action
- proprio
plan_config:
horizon: 5
receding_horizon: 5
action_block: 5
eval:
num_eval: 50
goal_offset_steps: 25
eval_budget: 50
img_size: 224
dataset_name: tworoom
callables:
- method: _set_state
args:
state:
value: proprio
- method: _set_goal_state
args:
goal_state:
value: goal_proprio
output:
filename: tworoom_results.txt
==== RESULTS ====
metrics: {'success_rate': 90.0, 'episode_successes': array([ True, False, True, False, True, True, True, True, False,
True, True, True, True, True, True, True, True, True,
True, True, True, True, True, True, True, True, True,
True, True, True, True, False, True, True, True, True,
True, True, True, False, True, True, True, True, True,
True, True, True, True, True]), 'seeds': None}
evaluation_time: 152.31250739097595 seconds
inference_precision: fp16
==== CONFIG ====
cache_dir: null
solver:
_target_: stable_worldmodel.solver.CEMSolver
model: ???
batch_size: 1
num_samples: 300
var_scale: 1.0
n_steps: 30
topk: 30
device: cuda
seed: ${seed}
world:
env_name: swm/TwoRoom-v1
num_envs: ${eval.num_eval}
max_episode_steps: 100
history_size: 1
frame_skip: 1
seed: 42
policy: two-room/tworoom/lejepa
inference_precision: fp16
dataset:
stats: ${eval.dataset_name}
keys_to_cache:
- action
- proprio
plan_config:
horizon: 5
receding_horizon: 5
action_block: 5
eval:
num_eval: 50
goal_offset_steps: 25
eval_budget: 50
img_size: 224
dataset_name: tworoom
callables:
- method: _set_state
args:
state:
value: proprio
- method: _set_goal_state
args:
goal_state:
value: goal_proprio
output:
filename: tworoom_results.txt
==== RESULTS ====
metrics: {'success_rate': 90.0, 'episode_successes': array([ True, False, True, False, True, True, True, True, False,
True, True, True, True, True, True, True, True, True,
True, True, True, True, True, True, True, True, True,
True, True, True, True, False, True, True, True, True,
True, True, True, False, True, True, True, True, True,
True, True, True, True, True]), 'seeds': None}
evaluation_time: 122.81560277938843 seconds
inference_precision: fp16
==== CONFIG ====
cache_dir: null
solver:
_target_: stable_worldmodel.solver.CEMSolver
model: ???
batch_size: 1
num_samples: 300
var_scale: 1.0
n_steps: 30
topk: 30
device: cuda
seed: ${seed}
world:
env_name: swm/TwoRoom-v1
num_envs: ${eval.num_eval}
max_episode_steps: 100
history_size: 1
frame_skip: 1
seed: 42
policy: two-room/tworoom/lejepa
inference_precision: fp16
dataset:
stats: ${eval.dataset_name}
keys_to_cache:
- action
- proprio
plan_config:
horizon: 5
receding_horizon: 5
action_block: 5
eval:
num_eval: 50
goal_offset_steps: 25
eval_budget: 50
img_size: 224
dataset_name: tworoom
callables:
- method: _set_state
args:
state:
value: proprio
- method: _set_goal_state
args:
goal_state:
value: goal_proprio
output:
filename: tworoom_results.txt
==== RESULTS ====
metrics: {'success_rate': 90.0, 'episode_successes': array([ True, False, True, False, True, True, True, True, False,
True, True, True, True, True, True, True, True, True,
True, True, True, True, True, True, True, True, True,
True, True, True, True, False, True, True, True, True,
True, True, True, False, True, True, True, True, True,
True, True, True, True, True]), 'seeds': None}
evaluation_time: 101.30036067962646 seconds
inference_precision: fp16
==== CONFIG ====
cache_dir: null
solver:
_target_: stable_worldmodel.solver.CEMSolver
model: ???
batch_size: 1
num_samples: 300
var_scale: 1.0
n_steps: 30
topk: 30
device: cuda
seed: ${seed}
world:
env_name: swm/TwoRoom-v1
num_envs: ${eval.num_eval}
max_episode_steps: 100
history_size: 1
frame_skip: 1
seed: 42
policy: two-room/tworoom/lejepa
inference_precision: fp16
dataset:
stats: ${eval.dataset_name}
keys_to_cache:
- action
- proprio
plan_config:
horizon: 5
receding_horizon: 5
action_block: 5
eval:
num_eval: 50
goal_offset_steps: 25
eval_budget: 50
img_size: 224
dataset_name: tworoom
callables:
- method: _set_state
args:
state:
value: proprio
- method: _set_goal_state
args:
goal_state:
value: goal_proprio
output:
filename: tworoom_results.txt
==== RESULTS ====
metrics: {'success_rate': 90.0, 'episode_successes': array([ True, False, True, False, True, True, True, True, False,
True, True, True, True, True, True, True, True, True,
True, True, True, True, True, True, True, True, True,
True, True, True, True, False, True, True, True, True,
True, True, True, False, True, True, True, True, True,
True, True, True, True, True]), 'seeds': None}
evaluation_time: 122.01387643814087 seconds
inference_precision: fp16
==== CONFIG ====
cache_dir: null
solver:
_target_: stable_worldmodel.solver.CEMSolver
model: ???
batch_size: 1
num_samples: 300
var_scale: 1.0
n_steps: 30
topk: 30
device: cuda
seed: ${seed}
world:
env_name: swm/TwoRoom-v1
num_envs: ${eval.num_eval}
max_episode_steps: 100
history_size: 1
frame_skip: 1
seed: 42
policy: two-room/tworoom/lejepa
inference_precision: fp16
dataset:
stats: ${eval.dataset_name}
keys_to_cache:
- action
- proprio
plan_config:
horizon: 5
receding_horizon: 5
action_block: 5
eval:
num_eval: 50
goal_offset_steps: 25
eval_budget: 50
img_size: 224
dataset_name: tworoom
callables:
- method: _set_state
args:
state:
value: proprio
- method: _set_goal_state
args:
goal_state:
value: goal_proprio
output:
filename: tworoom_results.txt
==== RESULTS ====
metrics: {'success_rate': 86.0, 'episode_successes': array([ True, False, True, False, True, True, True, True, False,
True, True, True, True, True, True, True, True, True,
True, True, True, True, True, True, True, True, True,
True, True, True, True, False, True, True, True, False,
True, True, False, False, True, True, True, True, True,
True, True, True, True, True]), 'seeds': None}
evaluation_time: 110.37948775291443 seconds
inference_precision: fp16
==== CONFIG ====
cache_dir: null
solver:
_target_: stable_worldmodel.solver.CEMSolver
model: ???
batch_size: 1
num_samples: 300
var_scale: 1.0
n_steps: 30
topk: 30
device: cuda
seed: ${seed}
world:
env_name: swm/TwoRoom-v1
num_envs: ${eval.num_eval}
max_episode_steps: 100
history_size: 1
frame_skip: 1
seed: 42
policy: two-room/tworoom/lejepa
inference_precision: fp16
dataset:
stats: ${eval.dataset_name}
keys_to_cache:
- action
- proprio
plan_config:
horizon: 5
receding_horizon: 5
action_block: 5
eval:
num_eval: 50
goal_offset_steps: 25
eval_budget: 50
img_size: 224
dataset_name: tworoom
callables:
- method: _set_state
args:
state:
value: proprio
- method: _set_goal_state
args:
goal_state:
value: goal_proprio
output:
filename: tworoom_results.txt
==== RESULTS ====
metrics: {'success_rate': 90.0, 'episode_successes': array([ True, False, True, False, True, True, True, True, False,
True, True, True, True, True, True, True, True, True,
True, True, True, True, True, True, True, True, True,
True, True, True, True, False, True, True, True, True,
True, True, True, False, True, True, True, True, True,
True, True, True, True, True]), 'seeds': None}
evaluation_time: 94.35640263557434 seconds
inference_precision: fp16
==== CONFIG ====
cache_dir: null
solver:
_target_: stable_worldmodel.solver.CEMSolver
model: ???
batch_size: 1
num_samples: 300
var_scale: 1.0
n_steps: 30
topk: 30
device: cuda
seed: ${seed}
world:
env_name: swm/TwoRoom-v1
num_envs: ${eval.num_eval}
max_episode_steps: 100
history_size: 1
frame_skip: 1
seed: 42
policy: two-room/tworoom/lejepa
inference_precision: fp16
dataset:
stats: ${eval.dataset_name}
keys_to_cache:
- action
- proprio
plan_config:
horizon: 5
receding_horizon: 5
action_block: 5
eval:
num_eval: 50
goal_offset_steps: 25
eval_budget: 50
img_size: 224
dataset_name: tworoom
callables:
- method: _set_state
args:
state:
value: proprio
- method: _set_goal_state
args:
goal_state:
value: goal_proprio
output:
filename: tworoom_results.txt
==== RESULTS ====
metrics: {'success_rate': 90.0, 'episode_successes': array([ True, False, True, False, True, True, True, True, False,
True, True, True, True, True, True, True, True, True,
True, True, True, True, True, True, True, True, True,
True, True, True, True, False, True, True, True, True,
True, True, True, False, True, True, True, True, True,
True, True, True, True, True]), 'seeds': None}
evaluation_time: 98.5384590625763 seconds
inference_precision: fp16