添加CrossAttention kv缓存,减少重复计算,提升性能,psnr=31.8022 dB

This commit is contained in:
2026-02-09 17:04:23 +00:00
parent 4288c9d8c9
commit f192c8aca9
3 changed files with 126 additions and 82 deletions

View File

@@ -6,6 +6,7 @@ from unifolm_wma.utils.diffusion import make_ddim_sampling_parameters, make_ddim
from unifolm_wma.utils.common import noise_like
from unifolm_wma.utils.common import extract_into_tensor
from tqdm import tqdm
from unifolm_wma.modules.attention import enable_cross_attn_kv_cache, disable_cross_attn_kv_cache
class DDIMSampler(object):
@@ -243,63 +244,67 @@ class DDIMSampler(object):
dp_ddim_scheduler_action.set_timesteps(len(timesteps))
dp_ddim_scheduler_state.set_timesteps(len(timesteps))
ts = torch.empty((b, ), device=device, dtype=torch.long)
for i, step in enumerate(iterator):
index = total_steps - i - 1
ts.fill_(step)
enable_cross_attn_kv_cache(self.model)
try:
for i, step in enumerate(iterator):
index = total_steps - i - 1
ts.fill_(step)
# Use mask to blend noised original latent (img_orig) & new sampled latent (img)
if mask is not None:
assert x0 is not None
if clean_cond:
img_orig = x0
else:
img_orig = self.model.q_sample(x0, ts)
img = img_orig * mask + (1. - mask) * img
# Use mask to blend noised original latent (img_orig) & new sampled latent (img)
if mask is not None:
assert x0 is not None
if clean_cond:
img_orig = x0
else:
img_orig = self.model.q_sample(x0, ts)
img = img_orig * mask + (1. - mask) * img
outs = self.p_sample_ddim(
img,
action,
state,
cond,
ts,
index=index,
use_original_steps=ddim_use_original_steps,
quantize_denoised=quantize_denoised,
temperature=temperature,
noise_dropout=noise_dropout,
score_corrector=score_corrector,
corrector_kwargs=corrector_kwargs,
unconditional_guidance_scale=unconditional_guidance_scale,
unconditional_conditioning=unconditional_conditioning,
mask=mask,
x0=x0,
fs=fs,
guidance_rescale=guidance_rescale,
**kwargs)
outs = self.p_sample_ddim(
img,
action,
state,
cond,
ts,
index=index,
use_original_steps=ddim_use_original_steps,
quantize_denoised=quantize_denoised,
temperature=temperature,
noise_dropout=noise_dropout,
score_corrector=score_corrector,
corrector_kwargs=corrector_kwargs,
unconditional_guidance_scale=unconditional_guidance_scale,
unconditional_conditioning=unconditional_conditioning,
mask=mask,
x0=x0,
fs=fs,
guidance_rescale=guidance_rescale,
**kwargs)
img, pred_x0, model_output_action, model_output_state = outs
img, pred_x0, model_output_action, model_output_state = outs
action = dp_ddim_scheduler_action.step(
model_output_action,
step,
action,
generator=None,
).prev_sample
state = dp_ddim_scheduler_state.step(
model_output_state,
step,
state,
generator=None,
).prev_sample
action = dp_ddim_scheduler_action.step(
model_output_action,
step,
action,
generator=None,
).prev_sample
state = dp_ddim_scheduler_state.step(
model_output_state,
step,
state,
generator=None,
).prev_sample
if callback: callback(i)
if img_callback: img_callback(pred_x0, i)
if callback: callback(i)
if img_callback: img_callback(pred_x0, i)
if index % log_every_t == 0 or index == total_steps - 1:
intermediates['x_inter'].append(img)
intermediates['pred_x0'].append(pred_x0)
intermediates['x_inter_action'].append(action)
intermediates['x_inter_state'].append(state)
if index % log_every_t == 0 or index == total_steps - 1:
intermediates['x_inter'].append(img)
intermediates['pred_x0'].append(pred_x0)
intermediates['x_inter_action'].append(action)
intermediates['x_inter_state'].append(state)
finally:
disable_cross_attn_kv_cache(self.model)
return img, action, state, intermediates

View File

@@ -97,6 +97,9 @@ class CrossAttention(nn.Module):
self.text_context_len = text_context_len
self.agent_state_context_len = agent_state_context_len
self.agent_action_context_len = agent_action_context_len
self._kv_cache = {}
self._kv_cache_enabled = False
self.cross_attention_scale_learnable = cross_attention_scale_learnable
if self.image_cross_attention:
self.to_k_ip = nn.Linear(context_dim, inner_dim, bias=False)
@@ -243,7 +246,22 @@ class CrossAttention(nn.Module):
q = self.to_q(x)
context = default(context, x)
if self.image_cross_attention and not spatial_self_attn:
use_cache = self._kv_cache_enabled and not spatial_self_attn
cache_hit = use_cache and len(self._kv_cache) > 0
if cache_hit:
# Reuse cached K/V (already in (b*h, n, d) shape)
k = self._kv_cache['k']
v = self._kv_cache['v']
if 'k_ip' in self._kv_cache:
k_ip = self._kv_cache['k_ip']
v_ip = self._kv_cache['v_ip']
k_as = self._kv_cache['k_as']
v_as = self._kv_cache['v_as']
k_aa = self._kv_cache['k_aa']
v_aa = self._kv_cache['v_aa']
q = rearrange(q, 'b n (h d) -> (b h) n d', h=h)
elif self.image_cross_attention and not spatial_self_attn:
context_agent_state = context[:, :self.agent_state_context_len, :]
context_agent_action = context[:,
self.agent_state_context_len:self.
@@ -266,20 +284,39 @@ class CrossAttention(nn.Module):
v_as = self.to_v_as(context_agent_state)
k_aa = self.to_k_aa(context_agent_action)
v_aa = self.to_v_aa(context_agent_action)
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h),
(q, k, v))
k_ip, v_ip = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h),
(k_ip, v_ip))
k_as, v_as = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h),
(k_as, v_as))
k_aa, v_aa = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h),
(k_aa, v_aa))
if use_cache:
self._kv_cache = {
'k': k, 'v': v,
'k_ip': k_ip, 'v_ip': v_ip,
'k_as': k_as, 'v_as': v_as,
'k_aa': k_aa, 'v_aa': v_aa,
}
else:
if not spatial_self_attn:
context = context[:, :self.text_context_len, :]
k = self.to_k(context)
v = self.to_v(context)
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h),
(q, k, v))
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h),
(q, k, v))
if use_cache:
self._kv_cache = {'k': k, 'v': v}
# baddbmm: fuse scale into GEMM → one kernel instead of matmul + mul
sim = torch.baddbmm(
torch.empty(q.shape[0], q.shape[1], k.shape[1], dtype=q.dtype, device=q.device),
q, k.transpose(-1, -2), beta=0, alpha=self.scale)
del k
if exists(mask):
max_neg_value = -torch.finfo(sim.dtype).max
@@ -293,40 +330,28 @@ class CrossAttention(nn.Module):
out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
if k_ip is not None and k_as is not None and k_aa is not None:
## image cross-attention
k_ip, v_ip = map(
lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h),
(k_ip, v_ip))
## image cross-attention (k_ip/v_ip already in (b*h, n, d) shape)
sim_ip = torch.baddbmm(
torch.empty(q.shape[0], q.shape[1], k_ip.shape[1], dtype=q.dtype, device=q.device),
q, k_ip.transpose(-1, -2), beta=0, alpha=self.scale)
del k_ip
with torch.amp.autocast('cuda', enabled=False):
sim_ip = sim_ip.softmax(dim=-1)
out_ip = torch.bmm(sim_ip, v_ip)
out_ip = rearrange(out_ip, '(b h) n d -> b n (h d)', h=h)
## agent state cross-attention
k_as, v_as = map(
lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h),
(k_as, v_as))
## agent state cross-attention (k_as/v_as already in (b*h, n, d) shape)
sim_as = torch.baddbmm(
torch.empty(q.shape[0], q.shape[1], k_as.shape[1], dtype=q.dtype, device=q.device),
q, k_as.transpose(-1, -2), beta=0, alpha=self.scale)
del k_as
with torch.amp.autocast('cuda', enabled=False):
sim_as = sim_as.softmax(dim=-1)
out_as = torch.bmm(sim_as, v_as)
out_as = rearrange(out_as, '(b h) n d -> b n (h d)', h=h)
## agent action cross-attention
k_aa, v_aa = map(
lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h),
(k_aa, v_aa))
## agent action cross-attention (k_aa/v_aa already in (b*h, n, d) shape)
sim_aa = torch.baddbmm(
torch.empty(q.shape[0], q.shape[1], k_aa.shape[1], dtype=q.dtype, device=q.device),
q, k_aa.transpose(-1, -2), beta=0, alpha=self.scale)
del k_aa
with torch.amp.autocast('cuda', enabled=False):
sim_aa = sim_aa.softmax(dim=-1)
out_aa = torch.bmm(sim_aa, v_aa)
@@ -526,6 +551,20 @@ class CrossAttention(nn.Module):
return attn_mask
def enable_cross_attn_kv_cache(module):
    """Turn on K/V caching for every CrossAttention layer under ``module``.

    Walks the full submodule tree, and for each ``CrossAttention`` instance
    sets its cache-enabled flag and clears any previously cached tensors so
    the next forward pass repopulates the cache from scratch.
    """
    for layer in module.modules():
        if not isinstance(layer, CrossAttention):
            continue
        layer._kv_cache_enabled = True
        layer._kv_cache = {}
def disable_cross_attn_kv_cache(module):
    """Turn off K/V caching for every CrossAttention layer under ``module``.

    Walks the full submodule tree, and for each ``CrossAttention`` instance
    clears its cache-enabled flag and drops any cached tensors, releasing
    the memory they held.
    """
    for layer in module.modules():
        if not isinstance(layer, CrossAttention):
            continue
        layer._kv_cache_enabled = False
        layer._kv_cache = {}
class BasicTransformerBlock(nn.Module):
def __init__(self,

View File

@@ -1,14 +1,14 @@
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/lightning_fabric/__init__.py:29: UserWarning: pkg_resources is deprecated as an API. See https://setuptools.pypa.io/en/latest/pkg_resources.html. The pkg_resources package is slated for removal as early as 2025-11-30. Refrain from using this package or pin to Setuptools<81.
__import__("pkg_resources").declare_namespace(__name__)
2026-02-09 16:37:01.511249: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2026-02-09 16:37:01.514371: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
2026-02-09 16:37:01.545068: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2026-02-09 16:37:01.545097: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2026-02-09 16:37:01.546937: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2026-02-09 16:37:01.555024: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
2026-02-09 16:37:01.555338: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
2026-02-09 16:53:59.556813: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2026-02-09 16:53:59.559892: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
2026-02-09 16:53:59.591414: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2026-02-09 16:53:59.591446: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2026-02-09 16:53:59.593281: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2026-02-09 16:53:59.601486: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
2026-02-09 16:53:59.601838: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX512F AVX512_VNNI AVX512_BF16 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
2026-02-09 16:37:02.212554: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
2026-02-09 16:54:00.228108: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
[rank: 0] Global seed set to 123
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/kornia/feature/lightglue.py:44: FutureWarning: `torch.cuda.amp.custom_fwd(args...)` is deprecated. Please use `torch.amp.custom_fwd(args..., device_type='cuda')` instead.
@torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
@@ -116,7 +116,7 @@ DEBUG:PIL.Image:Importing WmfImagePlugin
DEBUG:PIL.Image:Importing WmfImagePlugin
DEBUG:PIL.Image:Importing XbmImagePlugin
DEBUG:PIL.Image:Importing XpmImagePlugin
DEBUG:PIL.Image:Importing XVThumbImagePlugin
DEBUG:PIL.Image:Importing XVThumbImagePlugin
12%|█▎ | 1/8 [01:09<08:08, 69.72s/it]
25%|██▌ | 2/8 [02:15<06:45, 67.61s/it]
@@ -140,6 +140,6 @@ DEBUG:PIL.Image:Importing XVThumbImagePlugin
>>> Step 4: generating actions ...
>>> Step 4: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 5: generating actions ...
>>> Step 5: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 5: generating actions ...
>>> Step 5: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>>