Layer 3: 延迟 decode,只解码 CLIP 需要的 1 帧
- world model 调用 decode_video=False,跳过 16 帧全量 decode - 只 decode 最后 1 帧给 CLIP embedding / observation queue - 存 raw latent,循环结束后统一 batch decode 生成最终视频 - 每轮省 15 次 VAE decode,8 轮共省 120 次 - 跳过中间迭代的 wm tensorboard/mp4 保存 Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -559,6 +559,7 @@ def image_guided_synthesis_sim_mode(
|
||||
autocast_ctx = nullcontext()
|
||||
|
||||
batch_variants = None
|
||||
samples = None
|
||||
if ddim_sampler is not None:
|
||||
with autocast_ctx:
|
||||
samples, actions, states, intermedia = ddim_sampler.sample(
|
||||
@@ -583,7 +584,7 @@ def image_guided_synthesis_sim_mode(
|
||||
batch_images = model.decode_first_stage(samples)
|
||||
batch_variants = batch_images
|
||||
|
||||
return batch_variants, actions, states
|
||||
return batch_variants, actions, states, samples
|
||||
|
||||
|
||||
def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
|
||||
@@ -693,7 +694,7 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
|
||||
sample_save_dir = f'{video_save_dir}/wm/{fs}'
|
||||
os.makedirs(sample_save_dir, exist_ok=True)
|
||||
# For collecting interaction videos
|
||||
wm_video = []
|
||||
wm_latent = []
|
||||
# Initialize observation queues
|
||||
cond_obs_queues = {
|
||||
"observation.images.top":
|
||||
@@ -749,7 +750,7 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
|
||||
|
||||
# Use world-model in policy to generate action
|
||||
print(f'>>> Step {itr}: generating actions ...')
|
||||
pred_videos_0, pred_actions, _ = image_guided_synthesis_sim_mode(
|
||||
pred_videos_0, pred_actions, _, _ = image_guided_synthesis_sim_mode(
|
||||
model,
|
||||
sample['instruction'],
|
||||
observation,
|
||||
@@ -791,7 +792,7 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
|
||||
|
||||
# Interaction with the world-model
|
||||
print(f'>>> Step {itr}: interacting with world model ...')
|
||||
pred_videos_1, _, pred_states = image_guided_synthesis_sim_mode(
|
||||
pred_videos_1, _, pred_states, wm_samples = image_guided_synthesis_sim_mode(
|
||||
model,
|
||||
"",
|
||||
observation,
|
||||
@@ -804,12 +805,16 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
|
||||
fs=model_input_fs,
|
||||
text_input=False,
|
||||
timestep_spacing=args.timestep_spacing,
|
||||
guidance_rescale=args.guidance_rescale)
|
||||
guidance_rescale=args.guidance_rescale,
|
||||
decode_video=False)
|
||||
|
||||
# Decode only the last frame for CLIP embedding in next iteration
|
||||
last_frame_pixel = model.decode_first_stage(wm_samples[:, :, -1:, :, :])
|
||||
|
||||
for idx in range(args.exe_steps):
|
||||
observation = {
|
||||
'observation.images.top':
|
||||
pred_videos_1[0][:, idx:idx + 1].permute(1, 0, 2, 3),
|
||||
last_frame_pixel[0, :, 0:1].permute(1, 0, 2, 3),
|
||||
'observation.state':
|
||||
torch.zeros_like(pred_states[0][idx:idx + 1]) if
|
||||
args.zero_pred_state else pred_states[0][idx:idx + 1],
|
||||
@@ -827,30 +832,14 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
|
||||
pred_videos_0,
|
||||
sample_tag,
|
||||
fps=args.save_fps)
|
||||
# Save videos environment changes via world-model interaction
|
||||
sample_tag = f"{args.dataset}-vid{sample['videoid']}-wd-fs-{fs}/itr-{itr}"
|
||||
log_to_tensorboard(writer,
|
||||
pred_videos_1,
|
||||
sample_tag,
|
||||
fps=args.save_fps)
|
||||
|
||||
# Save the imagen videos for decision-making
|
||||
if pred_videos_0 is not None:
|
||||
sample_video_file = f'{video_save_dir}/dm/{fs}/itr-{itr}.mp4'
|
||||
save_results(pred_videos_0.cpu(),
|
||||
sample_video_file,
|
||||
fps=args.save_fps)
|
||||
# Save videos environment changes via world-model interaction
|
||||
sample_video_file = f'{video_save_dir}/wm/{fs}/itr-{itr}.mp4'
|
||||
save_results(pred_videos_1.cpu(),
|
||||
sample_video_file,
|
||||
fps=args.save_fps)
|
||||
|
||||
print('>' * 24)
|
||||
# Collect the result of world-model interactions
|
||||
wm_video.append(pred_videos_1[:, :, :args.exe_steps].cpu())
|
||||
# Store raw latent for deferred decode
|
||||
wm_latent.append(wm_samples[:, :, :args.exe_steps].cpu())
|
||||
|
||||
full_video = torch.cat(wm_video, dim=2)
|
||||
# Deferred decode: batch decode all stored latents
|
||||
full_latent = torch.cat(wm_latent, dim=2).to(device)
|
||||
full_video = model.decode_first_stage(full_latent).cpu()
|
||||
sample_tag = f"{args.dataset}-vid{sample['videoid']}-wd-fs-{fs}/full"
|
||||
log_to_tensorboard(writer,
|
||||
full_video,
|
||||
|
||||
Reference in New Issue
Block a user