减少了一路视频vae解码
This commit is contained in:
@@ -330,7 +330,8 @@ def image_guided_synthesis_sim_mode(
|
||||
timestep_spacing: str = 'uniform',
|
||||
guidance_rescale: float = 0.0,
|
||||
sim_mode: bool = True,
|
||||
**kwargs) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
decode_video: bool = True,
|
||||
**kwargs) -> tuple[torch.Tensor | None, torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Performs image-guided video generation in a simulation-style mode with optional multimodal guidance (image, state, action, text).
|
||||
|
||||
@@ -353,10 +354,13 @@ def image_guided_synthesis_sim_mode(
|
||||
timestep_spacing (str): Timestep sampling method in DDIM sampler. Typically "uniform" or "linspace".
|
||||
guidance_rescale (float): Guidance rescaling factor to mitigate overexposure from classifier-free guidance.
|
||||
sim_mode (bool): Whether to perform world-model interaction or decision-making using the world-model.
|
||||
decode_video (bool): Whether to decode latent samples to pixel-space video.
|
||||
Set to False to skip VAE decode for speed when only actions/states are needed.
|
||||
**kwargs: Additional arguments passed to the DDIM sampler.
|
||||
|
||||
Returns:
|
||||
batch_variants (torch.Tensor): Predicted pixel-space video frames [B, C, T, H, W].
|
||||
batch_variants (torch.Tensor | None): Predicted pixel-space video frames [B, C, T, H, W],
|
||||
or None when decode_video=False.
|
||||
actions (torch.Tensor): Predicted action sequences [B, T, D] from diffusion decoding.
|
||||
states (torch.Tensor): Predicted state sequences [B, T, D] from diffusion decoding.
|
||||
"""
|
||||
@@ -409,6 +413,7 @@ def image_guided_synthesis_sim_mode(
|
||||
kwargs.update({"unconditional_conditioning_img_nonetext": None})
|
||||
cond_mask = None
|
||||
cond_z0 = None
|
||||
batch_variants = None
|
||||
if ddim_sampler is not None:
|
||||
samples, actions, states, intermedia = ddim_sampler.sample(
|
||||
S=ddim_steps,
|
||||
@@ -427,9 +432,10 @@ def image_guided_synthesis_sim_mode(
|
||||
guidance_rescale=guidance_rescale,
|
||||
**kwargs)
|
||||
|
||||
# Reconstruct from latent to pixel space
|
||||
batch_images = model.decode_first_stage(samples)
|
||||
batch_variants = batch_images
|
||||
if decode_video:
|
||||
# Reconstruct from latent to pixel space
|
||||
batch_images = model.decode_first_stage(samples)
|
||||
batch_variants = batch_images
|
||||
|
||||
return batch_variants, actions, states
|
||||
|
||||
@@ -590,7 +596,8 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
|
||||
fs=model_input_fs,
|
||||
timestep_spacing=args.timestep_spacing,
|
||||
guidance_rescale=args.guidance_rescale,
|
||||
sim_mode=False)
|
||||
sim_mode=False,
|
||||
decode_video=not args.fast_policy_no_decode)
|
||||
|
||||
# Update future actions in the observation queues
|
||||
for idx in range(len(pred_actions[0])):
|
||||
@@ -648,11 +655,12 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
|
||||
observation)
|
||||
|
||||
# Save the imagen videos for decision-making
|
||||
sample_tag = f"{args.dataset}-vid{sample['videoid']}-dm-fs-{fs}/itr-{itr}"
|
||||
log_to_tensorboard(writer,
|
||||
pred_videos_0,
|
||||
sample_tag,
|
||||
fps=args.save_fps)
|
||||
if pred_videos_0 is not None:
|
||||
sample_tag = f"{args.dataset}-vid{sample['videoid']}-dm-fs-{fs}/itr-{itr}"
|
||||
log_to_tensorboard(writer,
|
||||
pred_videos_0,
|
||||
sample_tag,
|
||||
fps=args.save_fps)
|
||||
# Save videos environment changes via world-model interaction
|
||||
sample_tag = f"{args.dataset}-vid{sample['videoid']}-wd-fs-{fs}/itr-{itr}"
|
||||
log_to_tensorboard(writer,
|
||||
@@ -661,10 +669,11 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
|
||||
fps=args.save_fps)
|
||||
|
||||
# Save the imagen videos for decision-making
|
||||
sample_video_file = f'{video_save_dir}/dm/{fs}/itr-{itr}.mp4'
|
||||
save_results(pred_videos_0.cpu(),
|
||||
sample_video_file,
|
||||
fps=args.save_fps)
|
||||
if pred_videos_0 is not None:
|
||||
sample_video_file = f'{video_save_dir}/dm/{fs}/itr-{itr}.mp4'
|
||||
save_results(pred_videos_0.cpu(),
|
||||
sample_video_file,
|
||||
fps=args.save_fps)
|
||||
# Save videos environment changes via world-model interaction
|
||||
sample_video_file = f'{video_save_dir}/wm/{fs}/itr-{itr}.mp4'
|
||||
save_results(pred_videos_1.cpu(),
|
||||
@@ -797,6 +806,11 @@ def get_parser():
|
||||
action='store_true',
|
||||
default=False,
|
||||
help="not using the predicted states as comparison")
|
||||
parser.add_argument(
|
||||
"--fast_policy_no_decode",
|
||||
action='store_true',
|
||||
default=False,
|
||||
help="Speed mode: policy pass only predicts actions, skip policy video decode/log/save.")
|
||||
parser.add_argument("--save_fps",
|
||||
type=int,
|
||||
default=8,
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
2026-02-10 16:42:59.052755: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
|
||||
2026-02-10 16:42:59.102749: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
|
||||
2026-02-10 16:42:59.102803: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
|
||||
2026-02-10 16:42:59.104125: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
|
||||
2026-02-10 16:42:59.111711: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
|
||||
2026-02-10 17:03:42.057881: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
|
||||
2026-02-10 17:03:42.107520: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
|
||||
2026-02-10 17:03:42.107564: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
|
||||
2026-02-10 17:03:42.108900: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
|
||||
2026-02-10 17:03:42.116404: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
|
||||
To enable the following instructions: AVX2 AVX512F AVX512_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
|
||||
2026-02-10 16:43:00.040735: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
|
||||
2026-02-10 17:03:43.044539: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
|
||||
Global seed set to 123
|
||||
INFO:mainlogger:LatentVisualDiffusion: Running in v-prediction mode
|
||||
INFO:unifolm_wma.models.diffusion_head.conditional_unet1d:number of parameters: 5.010531e+08
|
||||
@@ -92,7 +92,7 @@ DEBUG:PIL.Image:Importing WmfImagePlugin
|
||||
DEBUG:PIL.Image:Importing WmfImagePlugin
|
||||
DEBUG:PIL.Image:Importing XbmImagePlugin
|
||||
DEBUG:PIL.Image:Importing XpmImagePlugin
|
||||
DEBUG:PIL.Image:Importing XVThumbImagePlugin
|
||||
DEBUG:PIL.Image:Importing XVThumbImagePlugin
|
||||
|
||||
9%|▉ | 1/11 [00:37<06:15, 37.55s/it]
|
||||
18%|█▊ | 2/11 [01:15<05:39, 37.71s/it]
|
||||
@@ -125,6 +125,6 @@ DEBUG:PIL.Image:Importing XVThumbImagePlugin
|
||||
>>> Step 6: generating actions ...
|
||||
>>> Step 6: interacting with world model ...
|
||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||
>>> Step 7: generating actions ...
|
||||
>>> Step 7: interacting with world model ...
|
||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||
>>> Step 7: generating actions ...
|
||||
>>> Step 7: interacting with world model ...
|
||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||
|
||||
@@ -20,5 +20,6 @@ dataset="unitree_z1_dual_arm_stackbox_v2"
|
||||
--n_iter 11 \
|
||||
--timestep_spacing 'uniform_trailing' \
|
||||
--guidance_rescale 0.7 \
|
||||
--perframe_ae
|
||||
--perframe_ae \
|
||||
--fast_policy_no_decode
|
||||
} 2>&1 | tee "${res_dir}/output.log"
|
||||
|
||||
Reference in New Issue
Block a user