Layer 3: 延迟 decode,只解码 CLIP 需要的 1 帧

- world model 调用 decode_video=False,跳过 16 帧全量 decode
- 只 decode 最后 1 帧给 CLIP embedding / observation queue
- 存 raw latent,循环结束后统一 batch decode 生成最终视频
- 每轮省 15 次 VAE decode,8 轮共省 120 次
- 跳过中间迭代的 wm tensorboard/mp4 保存

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
2026-02-11 07:11:55 +00:00
parent 57ba85d147
commit 1d23e5d36d

View File

@@ -559,6 +559,7 @@ def image_guided_synthesis_sim_mode(
autocast_ctx = nullcontext()
batch_variants = None
samples = None
if ddim_sampler is not None:
with autocast_ctx:
samples, actions, states, intermedia = ddim_sampler.sample(
@@ -583,7 +584,7 @@ def image_guided_synthesis_sim_mode(
batch_images = model.decode_first_stage(samples)
batch_variants = batch_images
return batch_variants, actions, states
return batch_variants, actions, states, samples
def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
@@ -693,7 +694,7 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
sample_save_dir = f'{video_save_dir}/wm/{fs}'
os.makedirs(sample_save_dir, exist_ok=True)
# For collecting interaction videos
wm_video = []
wm_latent = []
# Initialize observation queues
cond_obs_queues = {
"observation.images.top":
@@ -749,7 +750,7 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
# Use world-model in policy to generate action
print(f'>>> Step {itr}: generating actions ...')
pred_videos_0, pred_actions, _ = image_guided_synthesis_sim_mode(
pred_videos_0, pred_actions, _, _ = image_guided_synthesis_sim_mode(
model,
sample['instruction'],
observation,
@@ -791,7 +792,7 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
# Interaction with the world-model
print(f'>>> Step {itr}: interacting with world model ...')
pred_videos_1, _, pred_states = image_guided_synthesis_sim_mode(
pred_videos_1, _, pred_states, wm_samples = image_guided_synthesis_sim_mode(
model,
"",
observation,
@@ -804,12 +805,16 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
fs=model_input_fs,
text_input=False,
timestep_spacing=args.timestep_spacing,
guidance_rescale=args.guidance_rescale)
guidance_rescale=args.guidance_rescale,
decode_video=False)
# Decode only the last frame for CLIP embedding in next iteration
last_frame_pixel = model.decode_first_stage(wm_samples[:, :, -1:, :, :])
for idx in range(args.exe_steps):
observation = {
'observation.images.top':
pred_videos_1[0][:, idx:idx + 1].permute(1, 0, 2, 3),
last_frame_pixel[0, :, 0:1].permute(1, 0, 2, 3),
'observation.state':
torch.zeros_like(pred_states[0][idx:idx + 1]) if
args.zero_pred_state else pred_states[0][idx:idx + 1],
@@ -827,30 +832,14 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
pred_videos_0,
sample_tag,
fps=args.save_fps)
# Save videos environment changes via world-model interaction
sample_tag = f"{args.dataset}-vid{sample['videoid']}-wd-fs-{fs}/itr-{itr}"
log_to_tensorboard(writer,
pred_videos_1,
sample_tag,
fps=args.save_fps)
# Save the imagen videos for decision-making
if pred_videos_0 is not None:
sample_video_file = f'{video_save_dir}/dm/{fs}/itr-{itr}.mp4'
save_results(pred_videos_0.cpu(),
sample_video_file,
fps=args.save_fps)
# Save videos environment changes via world-model interaction
sample_video_file = f'{video_save_dir}/wm/{fs}/itr-{itr}.mp4'
save_results(pred_videos_1.cpu(),
sample_video_file,
fps=args.save_fps)
print('>' * 24)
# Collect the result of world-model interactions
wm_video.append(pred_videos_1[:, :, :args.exe_steps].cpu())
# Store raw latent for deferred decode
wm_latent.append(wm_samples[:, :, :args.exe_steps].cpu())
full_video = torch.cat(wm_video, dim=2)
# Deferred decode: batch decode all stored latents
full_latent = torch.cat(wm_latent, dim=2).to(device)
full_video = model.decode_first_stage(full_latent).cpu()
sample_tag = f"{args.dataset}-vid{sample['videoid']}-wd-fs-{fs}/full"
log_to_tensorboard(writer,
full_video,