DDIM loop 内小张量分配优化,attention mask 缓存到 GPU
This commit is contained in:
@@ -67,11 +67,12 @@ class DDIMSampler(object):
|
||||
ddim_timesteps=self.ddim_timesteps,
|
||||
eta=ddim_eta,
|
||||
verbose=verbose)
|
||||
self.register_buffer('ddim_sigmas', ddim_sigmas)
|
||||
self.register_buffer('ddim_alphas', ddim_alphas)
|
||||
self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
|
||||
# Ensure tensors are on correct device for efficient indexing
|
||||
self.register_buffer('ddim_sigmas', to_torch(torch.as_tensor(ddim_sigmas)))
|
||||
self.register_buffer('ddim_alphas', to_torch(torch.as_tensor(ddim_alphas)))
|
||||
self.register_buffer('ddim_alphas_prev', to_torch(torch.as_tensor(ddim_alphas_prev)))
|
||||
self.register_buffer('ddim_sqrt_one_minus_alphas',
|
||||
np.sqrt(1. - ddim_alphas))
|
||||
to_torch(torch.as_tensor(np.sqrt(1. - ddim_alphas))))
|
||||
sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
|
||||
(1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) *
|
||||
(1 - self.alphas_cumprod / self.alphas_cumprod_prev))
|
||||
@@ -241,9 +242,10 @@ class DDIMSampler(object):
|
||||
|
||||
dp_ddim_scheduler_action.set_timesteps(len(timesteps))
|
||||
dp_ddim_scheduler_state.set_timesteps(len(timesteps))
|
||||
ts = torch.empty((b, ), device=device, dtype=torch.long)
|
||||
for i, step in enumerate(iterator):
|
||||
index = total_steps - i - 1
|
||||
ts = torch.full((b, ), step, device=device, dtype=torch.long)
|
||||
ts.fill_(step)
|
||||
|
||||
# Use mask to blend noised original latent (img_orig) & new sampled latent (img)
|
||||
if mask is not None:
|
||||
@@ -325,10 +327,6 @@ class DDIMSampler(object):
|
||||
guidance_rescale=0.0,
|
||||
**kwargs):
|
||||
b, *_, device = *x.shape, x.device
|
||||
if x.dim() == 5:
|
||||
is_video = True
|
||||
else:
|
||||
is_video = False
|
||||
|
||||
if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
|
||||
model_output, model_output_action, model_output_state = self.model.apply_model(
|
||||
@@ -377,17 +375,11 @@ class DDIMSampler(object):
|
||||
sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
|
||||
sigmas = self.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
|
||||
|
||||
if is_video:
|
||||
size = (b, 1, 1, 1, 1)
|
||||
else:
|
||||
size = (b, 1, 1, 1)
|
||||
|
||||
a_t = torch.full(size, alphas[index], device=device)
|
||||
a_prev = torch.full(size, alphas_prev[index], device=device)
|
||||
sigma_t = torch.full(size, sigmas[index], device=device)
|
||||
sqrt_one_minus_at = torch.full(size,
|
||||
sqrt_one_minus_alphas[index],
|
||||
device=device)
|
||||
# Use 0-d tensors directly (already on device); broadcasting handles shape
|
||||
a_t = alphas[index]
|
||||
a_prev = alphas_prev[index]
|
||||
sigma_t = sigmas[index]
|
||||
sqrt_one_minus_at = sqrt_one_minus_alphas[index]
|
||||
|
||||
if self.model.parameterization != "v":
|
||||
pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
|
||||
@@ -395,12 +387,8 @@ class DDIMSampler(object):
|
||||
pred_x0 = self.model.predict_start_from_z_and_v(x, t, model_output)
|
||||
|
||||
if self.model.use_dynamic_rescale:
|
||||
scale_t = torch.full(size,
|
||||
self.ddim_scale_arr[index],
|
||||
device=device)
|
||||
prev_scale_t = torch.full(size,
|
||||
self.ddim_scale_arr_prev[index],
|
||||
device=device)
|
||||
scale_t = self.ddim_scale_arr[index]
|
||||
prev_scale_t = self.ddim_scale_arr_prev[index]
|
||||
rescale = (prev_scale_t / scale_t)
|
||||
pred_x0 *= rescale
|
||||
|
||||
|
||||
@@ -275,7 +275,8 @@ class CrossAttention(nn.Module):
|
||||
attn_mask_aa = self._get_attn_mask_aa(x.shape[0],
|
||||
q.shape[1],
|
||||
k_aa.shape[1],
|
||||
block_size=16).to(k_aa.device)
|
||||
block_size=16,
|
||||
device=k_aa.device)
|
||||
else:
|
||||
if not spatial_self_attn:
|
||||
assert 1 > 2, ">>> ERROR: you should never go into here ..."
|
||||
@@ -386,14 +387,26 @@ class CrossAttention(nn.Module):
|
||||
|
||||
return self.to_out(out)
|
||||
|
||||
def _get_attn_mask_aa(self, b, l1, l2, block_size=16):
|
||||
def _get_attn_mask_aa(self, b, l1, l2, block_size=16, device=None):
|
||||
cache_key = (b, l1, l2, block_size)
|
||||
if hasattr(self, '_attn_mask_aa_cache_key') and self._attn_mask_aa_cache_key == cache_key:
|
||||
cached = self._attn_mask_aa_cache
|
||||
if device is not None and cached.device != torch.device(device):
|
||||
cached = cached.to(device)
|
||||
self._attn_mask_aa_cache = cached
|
||||
return cached
|
||||
|
||||
target_device = device if device is not None else 'cpu'
|
||||
num_token = l2 // block_size
|
||||
start_positions = ((torch.arange(b) % block_size) + 1) * num_token
|
||||
col_indices = torch.arange(l2)
|
||||
start_positions = ((torch.arange(b, device=target_device) % block_size) + 1) * num_token
|
||||
col_indices = torch.arange(l2, device=target_device)
|
||||
mask_2d = col_indices.unsqueeze(0) >= start_positions.unsqueeze(1)
|
||||
mask = mask_2d.unsqueeze(1).expand(b, l1, l2)
|
||||
attn_mask = torch.zeros_like(mask, dtype=torch.float)
|
||||
attn_mask = torch.zeros(b, l1, l2, dtype=torch.float, device=target_device)
|
||||
attn_mask[mask] = float('-inf')
|
||||
|
||||
self._attn_mask_aa_cache_key = cache_key
|
||||
self._attn_mask_aa_cache = attn_mask
|
||||
return attn_mask
|
||||
|
||||
|
||||
|
||||
@@ -1,12 +1,12 @@
|
||||
2026-02-08 12:22:55.885867: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
|
||||
2026-02-08 12:22:55.890510: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
|
||||
2026-02-08 12:22:55.938683: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
|
||||
2026-02-08 12:22:55.938759: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
|
||||
2026-02-08 12:22:55.941091: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
|
||||
2026-02-08 12:22:55.952450: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
|
||||
2026-02-08 12:22:55.952933: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
|
||||
2026-02-08 13:59:02.578826: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
|
||||
2026-02-08 13:59:02.581891: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
|
||||
2026-02-08 13:59:02.613088: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
|
||||
2026-02-08 13:59:02.613125: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
|
||||
2026-02-08 13:59:02.614961: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
|
||||
2026-02-08 13:59:02.623180: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
|
||||
2026-02-08 13:59:02.623460: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
|
||||
To enable the following instructions: AVX2 AVX512F AVX512_VNNI AVX512_BF16 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
|
||||
2026-02-08 12:22:56.593653: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
|
||||
2026-02-08 13:59:03.306638: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
|
||||
[rank: 0] Global seed set to 123
|
||||
/mnt/ASC1637/miniconda3/envs/unifolm-wma/lib/python3.10/site-packages/kornia/feature/lightglue.py:44: FutureWarning: `torch.cuda.amp.custom_fwd(args...)` is deprecated. Please use `torch.amp.custom_fwd(args..., device_type='cuda')` instead.
|
||||
@torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
|
||||
@@ -115,7 +115,7 @@ DEBUG:PIL.Image:Importing WmfImagePlugin
|
||||
DEBUG:PIL.Image:Importing WmfImagePlugin
|
||||
DEBUG:PIL.Image:Importing XbmImagePlugin
|
||||
DEBUG:PIL.Image:Importing XpmImagePlugin
|
||||
DEBUG:PIL.Image:Importing XVThumbImagePlugin
|
||||
DEBUG:PIL.Image:Importing XVThumbImagePlugin
|
||||
|
||||
12%|█▎ | 1/8 [01:24<09:54, 84.90s/it]
|
||||
25%|██▌ | 2/8 [02:49<08:27, 84.55s/it]
|
||||
@@ -139,6 +139,6 @@ DEBUG:PIL.Image:Importing XVThumbImagePlugin
|
||||
>>> Step 4: generating actions ...
|
||||
>>> Step 4: interacting with world model ...
|
||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||
>>> Step 5: generating actions ...
|
||||
>>> Step 5: interacting with world model ...
|
||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||
>>> Step 5: generating actions ...
|
||||
>>> Step 5: interacting with world model ...
|
||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||
|
||||
Reference in New Issue
Block a user