DDIM loop 内小张量分配优化,attention mask 缓存到 GPU

This commit is contained in:
2026-02-08 14:20:48 +00:00
parent e588182642
commit 75c798ded0
3 changed files with 44 additions and 43 deletions

View File

@@ -67,11 +67,12 @@ class DDIMSampler(object):
ddim_timesteps=self.ddim_timesteps,
eta=ddim_eta,
verbose=verbose)
self.register_buffer('ddim_sigmas', ddim_sigmas)
self.register_buffer('ddim_alphas', ddim_alphas)
self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
# Ensure tensors are on correct device for efficient indexing
self.register_buffer('ddim_sigmas', to_torch(torch.as_tensor(ddim_sigmas)))
self.register_buffer('ddim_alphas', to_torch(torch.as_tensor(ddim_alphas)))
self.register_buffer('ddim_alphas_prev', to_torch(torch.as_tensor(ddim_alphas_prev)))
self.register_buffer('ddim_sqrt_one_minus_alphas',
np.sqrt(1. - ddim_alphas))
to_torch(torch.as_tensor(np.sqrt(1. - ddim_alphas))))
sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
(1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) *
(1 - self.alphas_cumprod / self.alphas_cumprod_prev))
@@ -241,9 +242,10 @@ class DDIMSampler(object):
dp_ddim_scheduler_action.set_timesteps(len(timesteps))
dp_ddim_scheduler_state.set_timesteps(len(timesteps))
ts = torch.empty((b, ), device=device, dtype=torch.long)
for i, step in enumerate(iterator):
index = total_steps - i - 1
ts = torch.full((b, ), step, device=device, dtype=torch.long)
ts.fill_(step)
# Use mask to blend noised original latent (img_orig) & new sampled latent (img)
if mask is not None:
@@ -325,10 +327,6 @@ class DDIMSampler(object):
guidance_rescale=0.0,
**kwargs):
b, *_, device = *x.shape, x.device
if x.dim() == 5:
is_video = True
else:
is_video = False
if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
model_output, model_output_action, model_output_state = self.model.apply_model(
@@ -377,17 +375,11 @@ class DDIMSampler(object):
sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
sigmas = self.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
if is_video:
size = (b, 1, 1, 1, 1)
else:
size = (b, 1, 1, 1)
a_t = torch.full(size, alphas[index], device=device)
a_prev = torch.full(size, alphas_prev[index], device=device)
sigma_t = torch.full(size, sigmas[index], device=device)
sqrt_one_minus_at = torch.full(size,
sqrt_one_minus_alphas[index],
device=device)
# Use 0-d tensors directly (already on device); broadcasting handles shape
a_t = alphas[index]
a_prev = alphas_prev[index]
sigma_t = sigmas[index]
sqrt_one_minus_at = sqrt_one_minus_alphas[index]
if self.model.parameterization != "v":
pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
@@ -395,12 +387,8 @@ class DDIMSampler(object):
pred_x0 = self.model.predict_start_from_z_and_v(x, t, model_output)
if self.model.use_dynamic_rescale:
scale_t = torch.full(size,
self.ddim_scale_arr[index],
device=device)
prev_scale_t = torch.full(size,
self.ddim_scale_arr_prev[index],
device=device)
scale_t = self.ddim_scale_arr[index]
prev_scale_t = self.ddim_scale_arr_prev[index]
rescale = (prev_scale_t / scale_t)
pred_x0 *= rescale

View File

@@ -275,7 +275,8 @@ class CrossAttention(nn.Module):
attn_mask_aa = self._get_attn_mask_aa(x.shape[0],
q.shape[1],
k_aa.shape[1],
block_size=16).to(k_aa.device)
block_size=16,
device=k_aa.device)
else:
if not spatial_self_attn:
assert 1 > 2, ">>> ERROR: you should never go into here ..."
@@ -386,14 +387,26 @@ class CrossAttention(nn.Module):
return self.to_out(out)
def _get_attn_mask_aa(self, b, l1, l2, block_size=16):
def _get_attn_mask_aa(self, b, l1, l2, block_size=16, device=None):
cache_key = (b, l1, l2, block_size)
if hasattr(self, '_attn_mask_aa_cache_key') and self._attn_mask_aa_cache_key == cache_key:
cached = self._attn_mask_aa_cache
if device is not None and cached.device != torch.device(device):
cached = cached.to(device)
self._attn_mask_aa_cache = cached
return cached
target_device = device if device is not None else 'cpu'
num_token = l2 // block_size
start_positions = ((torch.arange(b) % block_size) + 1) * num_token
col_indices = torch.arange(l2)
start_positions = ((torch.arange(b, device=target_device) % block_size) + 1) * num_token
col_indices = torch.arange(l2, device=target_device)
mask_2d = col_indices.unsqueeze(0) >= start_positions.unsqueeze(1)
mask = mask_2d.unsqueeze(1).expand(b, l1, l2)
attn_mask = torch.zeros_like(mask, dtype=torch.float)
attn_mask = torch.zeros(b, l1, l2, dtype=torch.float, device=target_device)
attn_mask[mask] = float('-inf')
self._attn_mask_aa_cache_key = cache_key
self._attn_mask_aa_cache = attn_mask
return attn_mask

View File

@@ -1,12 +1,12 @@
2026-02-08 12:22:55.885867: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2026-02-08 12:22:55.890510: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
2026-02-08 12:22:55.938683: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2026-02-08 12:22:55.938759: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2026-02-08 12:22:55.941091: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2026-02-08 12:22:55.952450: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
2026-02-08 12:22:55.952933: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
2026-02-08 13:59:02.578826: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2026-02-08 13:59:02.581891: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
2026-02-08 13:59:02.613088: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2026-02-08 13:59:02.613125: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2026-02-08 13:59:02.614961: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2026-02-08 13:59:02.623180: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
2026-02-08 13:59:02.623460: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX512F AVX512_VNNI AVX512_BF16 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
2026-02-08 12:22:56.593653: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
2026-02-08 13:59:03.306638: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
[rank: 0] Global seed set to 123
/mnt/ASC1637/miniconda3/envs/unifolm-wma/lib/python3.10/site-packages/kornia/feature/lightglue.py:44: FutureWarning: `torch.cuda.amp.custom_fwd(args...)` is deprecated. Please use `torch.amp.custom_fwd(args..., device_type='cuda')` instead.
@torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
@@ -115,7 +115,7 @@ DEBUG:PIL.Image:Importing WmfImagePlugin
DEBUG:PIL.Image:Importing WmfImagePlugin
DEBUG:PIL.Image:Importing XbmImagePlugin
DEBUG:PIL.Image:Importing XpmImagePlugin
DEBUG:PIL.Image:Importing XVThumbImagePlugin
DEBUG:PIL.Image:Importing XVThumbImagePlugin
12%|█▎ | 1/8 [01:24<09:54, 84.90s/it]
25%|██▌ | 2/8 [02:49<08:27, 84.55s/it]
@@ -139,6 +139,6 @@ DEBUG:PIL.Image:Importing XVThumbImagePlugin
>>> Step 4: generating actions ...
>>> Step 4: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 5: generating actions ...
>>> Step 5: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 5: generating actions ...
>>> Step 5: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>>