From 1d23e5d36dc0ccce7b5c5c734593a9154bb0bb29 Mon Sep 17 00:00:00 2001 From: olivame Date: Wed, 11 Feb 2026 07:11:55 +0000 Subject: [PATCH] =?UTF-8?q?Layer=203:=20=E5=BB=B6=E8=BF=9F=20decode?= =?UTF-8?q?=EF=BC=8C=E5=8F=AA=E8=A7=A3=E7=A0=81=20CLIP=20=E9=9C=80?= =?UTF-8?q?=E8=A6=81=E7=9A=84=201=20=E5=B8=A7?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - world model 调用 decode_video=False,跳过 16 帧全量 decode - 只 decode 最后 1 帧给 CLIP embedding / observation queue - 存 raw latent,循环结束后统一 batch decode 生成最终视频 - 每轮省 15 次 VAE decode,8 轮共省 120 次 - 跳过中间迭代的 wm tensorboard 保存,以及 dm/wm 的中间 mp4 保存(dm tensorboard 保留) Co-Authored-By: Claude Opus 4.6 (1M context) --- scripts/evaluation/world_model_interaction.py | 43 +++++++------------ 1 file changed, 16 insertions(+), 27 deletions(-) diff --git a/scripts/evaluation/world_model_interaction.py b/scripts/evaluation/world_model_interaction.py index cb25f2e..50182df 100644 --- a/scripts/evaluation/world_model_interaction.py +++ b/scripts/evaluation/world_model_interaction.py @@ -559,6 +559,7 @@ def image_guided_synthesis_sim_mode( autocast_ctx = nullcontext() batch_variants = None + samples = None if ddim_sampler is not None: with autocast_ctx: samples, actions, states, intermedia = ddim_sampler.sample( @@ -583,7 +584,7 @@ def image_guided_synthesis_sim_mode( batch_images = model.decode_first_stage(samples) batch_variants = batch_images - return batch_variants, actions, states + return batch_variants, actions, states, samples def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None: @@ -693,7 +694,7 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None: sample_save_dir = f'{video_save_dir}/wm/{fs}' os.makedirs(sample_save_dir, exist_ok=True) # For collecting interaction videos - wm_video = [] + wm_latent = [] # Initialize observation queues cond_obs_queues = { "observation.images.top": @@ -749,7 +750,7 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> 
None: # Use world-model in policy to generate action print(f'>>> Step {itr}: generating actions ...') - pred_videos_0, pred_actions, _ = image_guided_synthesis_sim_mode( + pred_videos_0, pred_actions, _, _ = image_guided_synthesis_sim_mode( model, sample['instruction'], observation, @@ -791,7 +792,7 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None: # Interaction with the world-model print(f'>>> Step {itr}: interacting with world model ...') - pred_videos_1, _, pred_states = image_guided_synthesis_sim_mode( + pred_videos_1, _, pred_states, wm_samples = image_guided_synthesis_sim_mode( model, "", observation, @@ -804,12 +805,16 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None: fs=model_input_fs, text_input=False, timestep_spacing=args.timestep_spacing, - guidance_rescale=args.guidance_rescale) + guidance_rescale=args.guidance_rescale, + decode_video=False) + + # Decode only the last frame for CLIP embedding in next iteration + last_frame_pixel = model.decode_first_stage(wm_samples[:, :, -1:, :, :]) for idx in range(args.exe_steps): observation = { 'observation.images.top': - pred_videos_1[0][:, idx:idx + 1].permute(1, 0, 2, 3), + last_frame_pixel[0, :, 0:1].permute(1, 0, 2, 3), 'observation.state': torch.zeros_like(pred_states[0][idx:idx + 1]) if args.zero_pred_state else pred_states[0][idx:idx + 1], @@ -827,30 +832,14 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None: pred_videos_0, sample_tag, fps=args.save_fps) - # Save videos environment changes via world-model interaction - sample_tag = f"{args.dataset}-vid{sample['videoid']}-wd-fs-{fs}/itr-{itr}" - log_to_tensorboard(writer, - pred_videos_1, - sample_tag, - fps=args.save_fps) - - # Save the imagen videos for decision-making - if pred_videos_0 is not None: - sample_video_file = f'{video_save_dir}/dm/{fs}/itr-{itr}.mp4' - save_results(pred_videos_0.cpu(), - sample_video_file, - fps=args.save_fps) - # Save videos 
environment changes via world-model interaction - sample_video_file = f'{video_save_dir}/wm/{fs}/itr-{itr}.mp4' - save_results(pred_videos_1.cpu(), - sample_video_file, - fps=args.save_fps) print('>' * 24) - # Collect the result of world-model interactions - wm_video.append(pred_videos_1[:, :, :args.exe_steps].cpu()) + # Store raw latent for deferred decode + wm_latent.append(wm_samples[:, :, :args.exe_steps].cpu()) - full_video = torch.cat(wm_video, dim=2) + # Deferred decode: batch decode all stored latents + full_latent = torch.cat(wm_latent, dim=2).to(device) + full_video = model.decode_first_stage(full_latent).cpu() sample_tag = f"{args.dataset}-vid{sample['videoid']}-wd-fs-{fs}/full" log_to_tensorboard(writer, full_video,