保存的优化
This commit is contained in:
208
record.md
Normal file
208
record.md
Normal file
@@ -0,0 +1,208 @@
|
||||
好,下面是完整的 decode 优化讲解。
|
||||
|
||||
---
|
||||
背景:原始 decode 流程的问题
|
||||
|
||||
原始代码中,每轮迭代的 world model 交互会调用 image_guided_synthesis_sim_mode,内部对DDIM 采样得到的 latent 做全量 VAE decode:
|
||||
|
||||
# ddpms.py decode_core, perframe_ae=True
|
||||
for index in range(z.shape[0]): # 逐帧循环,16帧就调16次
|
||||
frame_z = 1. / self.scale_factor * z[index:index+1]
|
||||
frame_result = self.first_stage_model.decode(frame_z)
|
||||
results.append(frame_result)
|
||||
|
||||
假设 8 轮迭代、每轮 16 帧,WM 路径 decode 128 次,policy 路径再 decode 128 次,共 256 次 VAE decode。每次都是同步阻塞的。
|
||||
|
||||
decode 完的视频还会被逐轮写成 mp4 + tensorboard,产生大量磁盘 I/O。最后还要把所有轮的 decoded video 在内存中torch.cat
|
||||
拼接,再写一次完整视频。
|
||||
|
||||
---
|
||||
优化1:decode_video 开关——按需跳过 decode
|
||||
|
||||
文件: world_model_interaction.py函数 image_guided_synthesis_sim_mode
|
||||
|
||||
改动: 给函数加decode_video 参数(默认 False),返回值增加 raw samples:
|
||||
|
||||
def image_guided_synthesis_sim_mode(...,
|
||||
decode_video: bool = False, # 新增
|
||||
...) -> tuple[Tensor | None, Tensor, Tensor, Tensor | None]:
|
||||
|
||||
samples = None
|
||||
if ddim_sampler is not None:
|
||||
samples, actions, states, intermedia = ddim_sampler.sample(...)if decode_video: # 条件 decode
|
||||
batch_images = model.decode_first_stage(samples)
|
||||
batch_variants = batch_images
|
||||
|
||||
return batch_variants, actions, states, samples# 多返回 samples
|
||||
|
||||
调用侧:
|
||||
- Policy 路径:由 CLI 参数 --fast_policy_no_decode 控制,只需要 action 时可跳过 decode
|
||||
- WM 交互路径:传decode_video=False,只拿 raw latent
|
||||
|
||||
效果: WM 路径每轮省掉 16 帧全量 decode。
|
||||
|
||||
---
|
||||
优化2:只decode observation 需要的帧
|
||||
|
||||
问题: WM 跳过了全量 decode,但下一轮的CLIP embedding 需要 pixel-space 图像做 observation。
|
||||
|
||||
改动: 只decode exe_steps 帧(通常 1帧),而不是全部 16 帧:
|
||||
|
||||
# WM 调用,不做全量 decode
|
||||
pred_videos_1, _, pred_states, wm_samples = image_guided_synthesis_sim_mode(
|
||||
..., decode_video=False)
|
||||
|
||||
# 只 decode exe_steps 帧给 observation
|
||||
obs_pixels = model.decode_first_stage(
|
||||
wm_samples[:, :, :args.exe_steps, :, :])
|
||||
|
||||
for idx in range(args.exe_steps):
|
||||
observation = {
|
||||
'observation.images.top':obs_pixels[0, :, idx:idx + 1].permute(1, 0, 2, 3),
|
||||
...
|
||||
}
|
||||
cond_obs_queues = populate_queues(cond_obs_queues, observation)
|
||||
|
||||
关键细节: 必须逐帧填充 observation queue(idx:idx+1),不能全用最后一帧,否则 CLIP embedding 输入变了会影响精度。
|
||||
|
||||
效果: 每轮从 decode 16 帧降到 decode exe_steps 帧(省15 帧/轮)。
|
||||
|
||||
---
|
||||
优化3:decode stream——GPU 上并行 decode 和 UNet
|
||||
|
||||
问题: 写入最终视频仍需要完整 segment 的 pixel,这部分 decode 还是要做。
|
||||
|
||||
思路: 用独立 CUDA stream 做 segment decode,和下一轮 UNet 推断在 GPU 上并行。
|
||||
|
||||
改动:
|
||||
|
||||
初始化:
|
||||
decode_stream = torch.cuda.Stream(device=device)
|
||||
pending_decode = None
|
||||
|
||||
循环尾部:
|
||||
# 收集上一轮 decode 结果
|
||||
if pending_decode is not None:
|
||||
decode_stream.synchronize()
|
||||
write_q.put(pending_decode.cpu())
|
||||
pending_decode = None
|
||||
|
||||
# 在 decode stream 上启动当前轮 segment decode(不阻塞主线程)
|
||||
latent_slice = wm_samples[:, :, :args.exe_steps]
|
||||
decode_stream.wait_stream(torch.cuda.current_stream()) # 确保 latent 就绪
|
||||
with torch.cuda.stream(decode_stream):
|
||||
pending_decode = model.decode_first_stage(latent_slice)
|
||||
# 主线程立即进入下一轮 UNet
|
||||
|
||||
循环结束后收集最后一轮:
|
||||
if pending_decode is not None:
|
||||
decode_stream.synchronize()
|
||||
write_q.put(pending_decode.cpu())
|
||||
|
||||
原理: decode_stream.wait_stream() 建立 stream间依赖,确保 latent 产出后才开始 decode。两个 stream 的 kernel 可以被GPU
|
||||
调度器交错执行。
|
||||
|
||||
效果: segment decode 时间被下一轮 UNet 推断掩盖。
|
||||
|
||||
---
|
||||
优化4:Writer 进程——CPU 工作跨进程并行
|
||||
|
||||
问题: decode 完的tensor 需要转numpy + cv2 编码写盘,这是 CPU 密集型操作,Python GIL 限制线程并行。
|
||||
|
||||
改动:
|
||||
|
||||
辅助函数(主进程和子进程都能调用):
|
||||
def _video_tensor_to_frames(video: Tensor) -> np.ndarray:
|
||||
video = torch.clamp(video.float(), -1., 1.)
|
||||
n = video.shape[0]
|
||||
video = video.permute(2, 0, 1, 3, 4)
|
||||
frame_grids = [
|
||||
torchvision.utils.make_grid(f, nrow=int(n), padding=0) for f in video
|
||||
]
|
||||
grid = torch.stack(frame_grids, dim=0)
|
||||
grid = ((grid + 1.0) / 2.0 * 255).to(torch.uint8).permute(0, 2, 3, 1)
|
||||
return grid.numpy()[:, :, :, ::-1] # RGB → BGR
|
||||
|
||||
Writer 进程:
|
||||
def _video_writer_process(q: mp.Queue, filename: str, fps: int):
|
||||
vwriter = None
|
||||
while True:
|
||||
item = q.get()
|
||||
if item is None: # sentinel,退出
|
||||
break
|
||||
frames = _video_tensor_to_frames(item)
|
||||
if vwriter is None:
|
||||
h, w = frames.shape[1], frames.shape[2]
|
||||
fourcc = cv2.VideoWriter_fourcc(*'mp4v')
|
||||
vwriter = cv2.VideoWriter(filename, fourcc, fps, (w, h))
|
||||
for f in frames:
|
||||
vwriter.write(f)
|
||||
if vwriter is not None:
|
||||
vwriter.release()
|
||||
|
||||
主进程启动 writer:
|
||||
write_q = mp.Queue()
|
||||
writer_proc = mp.Process(target=_video_writer_process,
|
||||
args=(write_q, sample_full_video_file, args.save_fps))
|
||||
writer_proc.start()
|
||||
|
||||
主进程通过 write_q.put(tensor.cpu()) 发送数据,循环结束发None sentinel 并join()。
|
||||
|
||||
效果:
|
||||
- tensor→numpy 转换和cv2 编码不占主进程 CPU 时间
|
||||
- 不受 GIL 限制
|
||||
- cv2.VideoWriter 增量写入,不攒内存,不做最终 torch.cat
|
||||
- Queue 自带背压,writer 处理不过来时 put 自然阻塞,不会 OOM
|
||||
|
||||
---
|
||||
同时删除的冗余 I/O
|
||||
|
||||
- 每轮迭代的 WM 中间tensorboard log(log_to_tensorboard_async)
|
||||
- 每轮迭代的 WM 中间 mp4(save_results_async)
|
||||
- 每轮迭代的 policy 中间 mp4
|
||||
- 最终的 torch.cat + 一次性 save_results_async 写完整视频
|
||||
|
||||
---
|
||||
总结:优化前后对比
|
||||
|
||||
┌─────────────────────┬────────────────────────┬───────────────────────────┐
|
||||
│ │ 原始 │ 优化后 │
|
||||
├─────────────────────┼────────────────────────┼───────────────────────────┤
|
||||
│ WM decode帧数/轮 │ 16 │ exe_steps(通常1) │
|
||||
├─────────────────────┼────────────────────────┼───────────────────────────┤
|
||||
│ segment decode 阻塞 │ 同步阻塞主线程 │ decode stream,GPU并行 │
|
||||
├─────────────────────┼────────────────────────┼───────────────────────────┤
|
||||
│ tensor→numpy+写盘 │ 主线程同步/线程池(GIL) │ 独立进程,无GIL │
|
||||
├─────────────────────┼────────────────────────┼───────────────────────────┤
|
||||
│ 视频写入方式 │ 攒内存→cat→一次写 │ cv2增量追加写 │
|
||||
├─────────────────────┼────────────────────────┼───────────────────────────┤
|
||||
│ 中间I/O │ 每轮2×mp4+2×tb │ 无 │
|
||||
├─────────────────────┼────────────────────────┼───────────────────────────┤
|
||||
│ 内存占用 │ 随迭代数线性增长 │ 恒定(不攒decoded video) │
|
||||
└─────────────────────┴────────────────────────┴───────────────────────────┘
|
||||
|
||||
✻ Cooked for 1m 31s
|
||||
|
||||
❯ 目前哪些优化是默认开启的
|
||||
|
||||
● 让我检查各个优化的开关状态。
|
||||
|
||||
● Searched for 2 patterns (ctrl+o to expand)
|
||||
|
||||
● 当前各优化的默认状态:
|
||||
|
||||
┌───────────────────────────────────┬────────────┬────────────────────────────────────────────────────────┐
|
||||
│ 优化 │ 默认开启? │ 控制方式 │
|
||||
├───────────────────────────────────┼────────────┼────────────────────────────────────────────────────────┤
|
||||
│ WM路径跳过全量decode │ ✅ 是 │ 硬编码 decode_video=False(789行) │
|
||||
├───────────────────────────────────┼────────────┼────────────────────────────────────────────────────────┤
|
||||
│ 只decode exe_steps帧给observation │ ✅ 是 │ 无开关,直接生效(792行) │
|
||||
├───────────────────────────────────┼────────────┼────────────────────────────────────────────────────────┤
|
||||
│ decode stream GPU并行 │ ✅ 是 │ 无开关,直接生效(824-827行) │
|
||||
├───────────────────────────────────┼────────────┼────────────────────────────────────────────────────────┤
|
||||
│ Writer进程增量写盘 │ ✅ 是 │ 无开关,直接生效(672-675行) │
|
||||
├───────────────────────────────────┼────────────┼────────────────────────────────────────────────────────┤
|
||||
│ 删除中间wm mp4/tb保存 │ ✅ 是 │ 代码已删除 │
|
||||
├───────────────────────────────────┼────────────┼────────────────────────────────────────────────────────┤
|
||||
│ Policy路径跳过decode │ ❌ 否 │ 需传--fast_policy_no_decode(默认False,即默认decode) │
|
||||
└───────────────────────────────────┴────────────┴────────────────────────────────────────────────────────┘
|
||||
Reference in New Issue
Block a user