diff --git a/src/unifolm_wma/models/samplers/ddim.py b/src/unifolm_wma/models/samplers/ddim.py index 77602a1..2e88f0b 100644 --- a/src/unifolm_wma/models/samplers/ddim.py +++ b/src/unifolm_wma/models/samplers/ddim.py @@ -67,11 +67,12 @@ class DDIMSampler(object): ddim_timesteps=self.ddim_timesteps, eta=ddim_eta, verbose=verbose) - self.register_buffer('ddim_sigmas', ddim_sigmas) - self.register_buffer('ddim_alphas', ddim_alphas) - self.register_buffer('ddim_alphas_prev', ddim_alphas_prev) + # Ensure tensors are on correct device for efficient indexing + self.register_buffer('ddim_sigmas', to_torch(torch.as_tensor(ddim_sigmas))) + self.register_buffer('ddim_alphas', to_torch(torch.as_tensor(ddim_alphas))) + self.register_buffer('ddim_alphas_prev', to_torch(torch.as_tensor(ddim_alphas_prev))) self.register_buffer('ddim_sqrt_one_minus_alphas', - np.sqrt(1. - ddim_alphas)) + to_torch(torch.as_tensor(np.sqrt(1. - ddim_alphas)))) sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt( (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * (1 - self.alphas_cumprod / self.alphas_cumprod_prev)) @@ -241,9 +242,10 @@ class DDIMSampler(object): dp_ddim_scheduler_action.set_timesteps(len(timesteps)) dp_ddim_scheduler_state.set_timesteps(len(timesteps)) + ts = torch.empty((b, ), device=device, dtype=torch.long) for i, step in enumerate(iterator): index = total_steps - i - 1 - ts = torch.full((b, ), step, device=device, dtype=torch.long) + ts.fill_(step) # Use mask to blend noised original latent (img_orig) & new sampled latent (img) if mask is not None: @@ -325,10 +327,6 @@ class DDIMSampler(object): guidance_rescale=0.0, **kwargs): b, *_, device = *x.shape, x.device - if x.dim() == 5: - is_video = True - else: - is_video = False if unconditional_conditioning is None or unconditional_guidance_scale == 1.: model_output, model_output_action, model_output_state = self.model.apply_model( @@ -377,17 +375,11 @@ class DDIMSampler(object): sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas sigmas = self.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas - if is_video: - size = (b, 1, 1, 1, 1) - else: - size = (b, 1, 1, 1) - - a_t = torch.full(size, alphas[index], device=device) - a_prev = torch.full(size, alphas_prev[index], device=device) - sigma_t = torch.full(size, sigmas[index], device=device) - sqrt_one_minus_at = torch.full(size, - sqrt_one_minus_alphas[index], - device=device) + # Use 0-d tensors directly (already on device); broadcasting handles shape + a_t = alphas[index] + a_prev = alphas_prev[index] + sigma_t = sigmas[index] + sqrt_one_minus_at = sqrt_one_minus_alphas[index] if self.model.parameterization != "v": pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt() @@ -395,12 +387,8 @@ class DDIMSampler(object): pred_x0 = self.model.predict_start_from_z_and_v(x, t, model_output) if self.model.use_dynamic_rescale: - scale_t = torch.full(size, - self.ddim_scale_arr[index], - device=device) - prev_scale_t = torch.full(size, - self.ddim_scale_arr_prev[index], - device=device) + scale_t = self.ddim_scale_arr[index] + prev_scale_t = self.ddim_scale_arr_prev[index] rescale = (prev_scale_t / scale_t) pred_x0 *= rescale diff --git a/src/unifolm_wma/modules/attention.py b/src/unifolm_wma/modules/attention.py index eab69d4..9c9c5b7 100644 --- a/src/unifolm_wma/modules/attention.py +++ b/src/unifolm_wma/modules/attention.py @@ -275,7 +275,8 @@ class CrossAttention(nn.Module): attn_mask_aa = self._get_attn_mask_aa(x.shape[0], q.shape[1], k_aa.shape[1], - block_size=16).to(k_aa.device) + block_size=16, + device=k_aa.device) else: if not spatial_self_attn: assert 1 > 2, ">>> ERROR: you should never go into here ..." @@ -386,14 +387,26 @@ class CrossAttention(nn.Module): return self.to_out(out) - def _get_attn_mask_aa(self, b, l1, l2, block_size=16): + def _get_attn_mask_aa(self, b, l1, l2, block_size=16, device=None): + cache_key = (b, l1, l2, block_size) + if hasattr(self, '_attn_mask_aa_cache_key') and self._attn_mask_aa_cache_key == cache_key: + cached = self._attn_mask_aa_cache + if device is not None and cached.device != torch.device(device): + cached = cached.to(device) + self._attn_mask_aa_cache = cached + return cached + + target_device = device if device is not None else 'cpu' num_token = l2 // block_size - start_positions = ((torch.arange(b) % block_size) + 1) * num_token - col_indices = torch.arange(l2) + start_positions = ((torch.arange(b, device=target_device) % block_size) + 1) * num_token + col_indices = torch.arange(l2, device=target_device) mask_2d = col_indices.unsqueeze(0) >= start_positions.unsqueeze(1) mask = mask_2d.unsqueeze(1).expand(b, l1, l2) - attn_mask = torch.zeros_like(mask, dtype=torch.float) + attn_mask = torch.zeros(b, l1, l2, dtype=torch.float, device=target_device) attn_mask[mask] = float('-inf') + + self._attn_mask_aa_cache_key = cache_key + self._attn_mask_aa_cache = attn_mask return attn_mask diff --git a/unitree_z1_dual_arm_cleanup_pencils/case1/output.log b/unitree_z1_dual_arm_cleanup_pencils/case1/output.log index 9ee412e..4ad0ac4 100644 --- a/unitree_z1_dual_arm_cleanup_pencils/case1/output.log +++ b/unitree_z1_dual_arm_cleanup_pencils/case1/output.log @@ -1,12 +1,12 @@ -2026-02-08 12:22:55.885867: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`. -2026-02-08 12:22:55.890510: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used. -2026-02-08 12:22:55.938683: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered -2026-02-08 12:22:55.938759: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered -2026-02-08 12:22:55.941091: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered -2026-02-08 12:22:55.952450: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used. -2026-02-08 12:22:55.952933: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations. +2026-02-08 13:59:02.578826: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`. +2026-02-08 13:59:02.581891: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used. +2026-02-08 13:59:02.613088: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered +2026-02-08 13:59:02.613125: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered +2026-02-08 13:59:02.614961: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered +2026-02-08 13:59:02.623180: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used. +2026-02-08 13:59:02.623460: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations. To enable the following instructions: AVX2 AVX512F AVX512_VNNI AVX512_BF16 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags. -2026-02-08 12:22:56.593653: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT +2026-02-08 13:59:03.306638: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT [rank: 0] Global seed set to 123 /mnt/ASC1637/miniconda3/envs/unifolm-wma/lib/python3.10/site-packages/kornia/feature/lightglue.py:44: FutureWarning: `torch.cuda.amp.custom_fwd(args...)` is deprecated. Please use `torch.amp.custom_fwd(args..., device_type='cuda')` instead. @torch.cuda.amp.custom_fwd(cast_inputs=torch.float32) @@ -115,7 +115,7 @@ DEBUG:PIL.Image:Importing WmfImagePlugin DEBUG:PIL.Image:Importing XbmImagePlugin DEBUG:PIL.Image:Importing XpmImagePlugin DEBUG:PIL.Image:Importing XVThumbImagePlugin - 12%|█▎ | 1/8 [01:24<09:53, 84.82s/it] 25%|██▌ | 2/8 [02:49<08:26, 84.48s/it] 38%|███▊ | 3/8 [04:13<07:01, 84.40s/it] 50%|█████ | 4/8 [05:37<05:37, 84.43s/it] 62%|██████▎ | 5/8 [07:02<04:13, 84.44s/it] 75%|███████▌ | 6/8 [08:26<02:48, 84.44s/it] 88%|████████▊ | 7/8 [09:50<01:24, 84.36s/it] 100%|██████████| 8/8 [11:15<00:00, 84.41s/it] 100%|██████████| 8/8 [11:15<00:00, 84.43s/it] + 12%|█▎ | 1/8 [01:24<09:54, 84.90s/it] 25%|██▌ | 2/8 [02:49<08:27, 84.55s/it] 38%|███▊ | 3/8 [04:13<07:02, 84.46s/it] 50%|█████ | 4/8 [05:38<05:38, 84.50s/it] 62%|██████▎ | 5/8 [07:02<04:13, 84.52s/it] 75%|███████▌ | 6/8 [08:27<02:49, 84.52s/it] 88%|████████▊ | 7/8 [09:51<01:24, 84.44s/it] 100%|██████████| 8/8 [11:16<00:00, 84.47s/it] 100%|██████████| 8/8 [11:16<00:00, 84.50s/it] >>>>>>>>>>>>>>>>>>>>>>>> >>> Step 1: generating actions ... >>> Step 1: interacting with world model ... @@ -139,6 +139,6 @@ DEBUG:PIL.Image:Importing XVThumbImagePlugin >>> Step 7: interacting with world model ... >>>>>>>>>>>>>>>>>>>>>>>> -real 12m19.457s -user 12m13.197s -sys 0m38.223s +real 12m14.598s +user 12m18.424s +sys 0m45.306s