From f192c8aca9f23e158dd7af22b4864494e176abba Mon Sep 17 00:00:00 2001 From: olivame Date: Mon, 9 Feb 2026 17:04:23 +0000 Subject: [PATCH] =?UTF-8?q?=E6=B7=BB=E5=8A=A0CrossAttention=20kv=E7=BC=93?= =?UTF-8?q?=E5=AD=98=EF=BC=8C=E5=87=8F=E5=B0=91=E9=87=8D=E5=A4=8D=E8=AE=A1?= =?UTF-8?q?=E7=AE=97=EF=BC=8C=E6=8F=90=E5=8D=87=E6=80=A7=E8=83=BD=EF=BC=8C?= =?UTF-8?q?psnr=3D31.8022=20dB?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/unifolm_wma/models/samplers/ddim.py | 107 +++++++++--------- src/unifolm_wma/modules/attention.py | 77 +++++++++---- .../case1/output.log | 24 ++-- 3 files changed, 126 insertions(+), 82 deletions(-) diff --git a/src/unifolm_wma/models/samplers/ddim.py b/src/unifolm_wma/models/samplers/ddim.py index c9ade02..fe50ea0 100644 --- a/src/unifolm_wma/models/samplers/ddim.py +++ b/src/unifolm_wma/models/samplers/ddim.py @@ -6,6 +6,7 @@ from unifolm_wma.utils.diffusion import make_ddim_sampling_parameters, make_ddim from unifolm_wma.utils.common import noise_like from unifolm_wma.utils.common import extract_into_tensor from tqdm import tqdm +from unifolm_wma.modules.attention import enable_cross_attn_kv_cache, disable_cross_attn_kv_cache class DDIMSampler(object): @@ -243,63 +244,67 @@ class DDIMSampler(object): dp_ddim_scheduler_action.set_timesteps(len(timesteps)) dp_ddim_scheduler_state.set_timesteps(len(timesteps)) ts = torch.empty((b, ), device=device, dtype=torch.long) - for i, step in enumerate(iterator): - index = total_steps - i - 1 - ts.fill_(step) + enable_cross_attn_kv_cache(self.model) + try: + for i, step in enumerate(iterator): + index = total_steps - i - 1 + ts.fill_(step) - # Use mask to blend noised original latent (img_orig) & new sampled latent (img) - if mask is not None: - assert x0 is not None - if clean_cond: - img_orig = x0 - else: - img_orig = self.model.q_sample(x0, ts) - img = img_orig * mask + (1. - mask) * img + # Use mask to blend noised original latent (img_orig) & new sampled latent (img) + if mask is not None: + assert x0 is not None + if clean_cond: + img_orig = x0 + else: + img_orig = self.model.q_sample(x0, ts) + img = img_orig * mask + (1. - mask) * img - outs = self.p_sample_ddim( - img, - action, - state, - cond, - ts, - index=index, - use_original_steps=ddim_use_original_steps, - quantize_denoised=quantize_denoised, - temperature=temperature, - noise_dropout=noise_dropout, - score_corrector=score_corrector, - corrector_kwargs=corrector_kwargs, - unconditional_guidance_scale=unconditional_guidance_scale, - unconditional_conditioning=unconditional_conditioning, - mask=mask, - x0=x0, - fs=fs, - guidance_rescale=guidance_rescale, - **kwargs) + outs = self.p_sample_ddim( + img, + action, + state, + cond, + ts, + index=index, + use_original_steps=ddim_use_original_steps, + quantize_denoised=quantize_denoised, + temperature=temperature, + noise_dropout=noise_dropout, + score_corrector=score_corrector, + corrector_kwargs=corrector_kwargs, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=unconditional_conditioning, + mask=mask, + x0=x0, + fs=fs, + guidance_rescale=guidance_rescale, + **kwargs) - img, pred_x0, model_output_action, model_output_state = outs + img, pred_x0, model_output_action, model_output_state = outs - action = dp_ddim_scheduler_action.step( - model_output_action, - step, - action, - generator=None, - ).prev_sample - state = dp_ddim_scheduler_state.step( - model_output_state, - step, - state, - generator=None, - ).prev_sample + action = dp_ddim_scheduler_action.step( + model_output_action, + step, + action, + generator=None, + ).prev_sample + state = dp_ddim_scheduler_state.step( + model_output_state, + step, + state, + generator=None, + ).prev_sample - if callback: callback(i) - if img_callback: img_callback(pred_x0, i) + if callback: callback(i) + if img_callback: img_callback(pred_x0, i) - if index % log_every_t == 0 or index == total_steps - 1: - intermediates['x_inter'].append(img) - intermediates['pred_x0'].append(pred_x0) - intermediates['x_inter_action'].append(action) - intermediates['x_inter_state'].append(state) + if index % log_every_t == 0 or index == total_steps - 1: + intermediates['x_inter'].append(img) + intermediates['pred_x0'].append(pred_x0) + intermediates['x_inter_action'].append(action) + intermediates['x_inter_state'].append(state) + finally: + disable_cross_attn_kv_cache(self.model) return img, action, state, intermediates diff --git a/src/unifolm_wma/modules/attention.py b/src/unifolm_wma/modules/attention.py index a126396..248a1f6 100644 --- a/src/unifolm_wma/modules/attention.py +++ b/src/unifolm_wma/modules/attention.py @@ -97,6 +97,9 @@ class CrossAttention(nn.Module): self.text_context_len = text_context_len self.agent_state_context_len = agent_state_context_len self.agent_action_context_len = agent_action_context_len + self._kv_cache = {} + self._kv_cache_enabled = False + self.cross_attention_scale_learnable = cross_attention_scale_learnable if self.image_cross_attention: self.to_k_ip = nn.Linear(context_dim, inner_dim, bias=False) @@ -243,7 +246,22 @@ class CrossAttention(nn.Module): q = self.to_q(x) context = default(context, x) - if self.image_cross_attention and not spatial_self_attn: + use_cache = self._kv_cache_enabled and not spatial_self_attn + cache_hit = use_cache and len(self._kv_cache) > 0 + + if cache_hit: + # Reuse cached K/V (already in (b*h, n, d) shape) + k = self._kv_cache['k'] + v = self._kv_cache['v'] + if 'k_ip' in self._kv_cache: + k_ip = self._kv_cache['k_ip'] + v_ip = self._kv_cache['v_ip'] + k_as = self._kv_cache['k_as'] + v_as = self._kv_cache['v_as'] + k_aa = self._kv_cache['k_aa'] + v_aa = self._kv_cache['v_aa'] + q = rearrange(q, 'b n (h d) -> (b h) n d', h=h) + elif self.image_cross_attention and not spatial_self_attn: context_agent_state = context[:, :self.agent_state_context_len, :] context_agent_action = context[:, self.agent_state_context_len:self. @@ -266,20 +284,39 @@ class CrossAttention(nn.Module): v_as = self.to_v_as(context_agent_state) k_aa = self.to_k_aa(context_agent_action) v_aa = self.to_v_aa(context_agent_action) + + q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), + (q, k, v)) + k_ip, v_ip = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), + (k_ip, v_ip)) + k_as, v_as = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), + (k_as, v_as)) + k_aa, v_aa = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), + (k_aa, v_aa)) + + if use_cache: + self._kv_cache = { + 'k': k, 'v': v, + 'k_ip': k_ip, 'v_ip': v_ip, + 'k_as': k_as, 'v_as': v_as, + 'k_aa': k_aa, 'v_aa': v_aa, + } else: if not spatial_self_attn: context = context[:, :self.text_context_len, :] k = self.to_k(context) v = self.to_v(context) - q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), - (q, k, v)) + q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), + (q, k, v)) + + if use_cache: + self._kv_cache = {'k': k, 'v': v} # baddbmm: fuse scale into GEMM → one kernel instead of matmul + mul sim = torch.baddbmm( torch.empty(q.shape[0], q.shape[1], k.shape[1], dtype=q.dtype, device=q.device), q, k.transpose(-1, -2), beta=0, alpha=self.scale) - del k if exists(mask): max_neg_value = -torch.finfo(sim.dtype).max @@ -293,40 +330,28 @@ class CrossAttention(nn.Module): out = rearrange(out, '(b h) n d -> b n (h d)', h=h) if k_ip is not None and k_as is not None and k_aa is not None: - ## image cross-attention - k_ip, v_ip = map( - lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), - (k_ip, v_ip)) + ## image cross-attention (k_ip/v_ip already in (b*h, n, d) shape) sim_ip = torch.baddbmm( torch.empty(q.shape[0], q.shape[1], k_ip.shape[1], dtype=q.dtype, device=q.device), q, k_ip.transpose(-1, -2), beta=0, alpha=self.scale) - del k_ip with torch.amp.autocast('cuda', enabled=False): sim_ip = sim_ip.softmax(dim=-1) out_ip = torch.bmm(sim_ip, v_ip) out_ip = rearrange(out_ip, '(b h) n d -> b n (h d)', h=h) - ## agent state cross-attention - k_as, v_as = map( - lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), - (k_as, v_as)) + ## agent state cross-attention (k_as/v_as already in (b*h, n, d) shape) sim_as = torch.baddbmm( torch.empty(q.shape[0], q.shape[1], k_as.shape[1], dtype=q.dtype, device=q.device), q, k_as.transpose(-1, -2), beta=0, alpha=self.scale) - del k_as with torch.amp.autocast('cuda', enabled=False): sim_as = sim_as.softmax(dim=-1) out_as = torch.bmm(sim_as, v_as) out_as = rearrange(out_as, '(b h) n d -> b n (h d)', h=h) - ## agent action cross-attention - k_aa, v_aa = map( - lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), - (k_aa, v_aa)) + ## agent action cross-attention (k_aa/v_aa already in (b*h, n, d) shape) sim_aa = torch.baddbmm( torch.empty(q.shape[0], q.shape[1], k_aa.shape[1], dtype=q.dtype, device=q.device), q, k_aa.transpose(-1, -2), beta=0, alpha=self.scale) - del k_aa with torch.amp.autocast('cuda', enabled=False): sim_aa = sim_aa.softmax(dim=-1) out_aa = torch.bmm(sim_aa, v_aa) @@ -526,6 +551,20 @@ class CrossAttention(nn.Module): return attn_mask +def enable_cross_attn_kv_cache(module): + for m in module.modules(): + if isinstance(m, CrossAttention): + m._kv_cache_enabled = True + m._kv_cache = {} + + +def disable_cross_attn_kv_cache(module): + for m in module.modules(): + if isinstance(m, CrossAttention): + m._kv_cache_enabled = False + m._kv_cache = {} + + class BasicTransformerBlock(nn.Module): def __init__(self, diff --git a/unitree_z1_dual_arm_cleanup_pencils/case1/output.log b/unitree_z1_dual_arm_cleanup_pencils/case1/output.log index 5116dc6..4286993 100644 --- a/unitree_z1_dual_arm_cleanup_pencils/case1/output.log +++ b/unitree_z1_dual_arm_cleanup_pencils/case1/output.log @@ -1,14 +1,14 @@ /mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/lightning_fabric/__init__.py:29: UserWarning: pkg_resources is deprecated as an API. See https://setuptools.pypa.io/en/latest/pkg_resources.html. The pkg_resources package is slated for removal as early as 2025-11-30. Refrain from using this package or pin to Setuptools<81. __import__("pkg_resources").declare_namespace(__name__) -2026-02-09 16:37:01.511249: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`. -2026-02-09 16:37:01.514371: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used. -2026-02-09 16:37:01.545068: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered -2026-02-09 16:37:01.545097: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered -2026-02-09 16:37:01.546937: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered -2026-02-09 16:37:01.555024: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used. -2026-02-09 16:37:01.555338: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations. +2026-02-09 16:53:59.556813: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`. +2026-02-09 16:53:59.559892: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used. +2026-02-09 16:53:59.591414: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered +2026-02-09 16:53:59.591446: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered +2026-02-09 16:53:59.593281: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered +2026-02-09 16:53:59.601486: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used. +2026-02-09 16:53:59.601838: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations. To enable the following instructions: AVX2 AVX512F AVX512_VNNI AVX512_BF16 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags. -2026-02-09 16:37:02.212554: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT +2026-02-09 16:54:00.228108: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT [rank: 0] Global seed set to 123 /mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/kornia/feature/lightglue.py:44: FutureWarning: `torch.cuda.amp.custom_fwd(args...)` is deprecated. Please use `torch.amp.custom_fwd(args..., device_type='cuda')` instead. @torch.cuda.amp.custom_fwd(cast_inputs=torch.float32) @@ -116,7 +116,7 @@ DEBUG:PIL.Image:Importing WmfImagePlugin DEBUG:PIL.Image:Importing XbmImagePlugin DEBUG:PIL.Image:Importing XpmImagePlugin DEBUG:PIL.Image:Importing XVThumbImagePlugin - 12%|█▎ | 1/8 [01:11<08:20, 71.56s/it] 25%|██▌ | 2/8 [02:19<06:56, 69.36s/it] 38%|███▊ | 3/8 [03:27<05:43, 68.67s/it] 50%|█████ | 4/8 [04:35<04:33, 68.41s/it] 62%|██████▎ | 5/8 [05:43<03:25, 68.38s/it] 75%|███████▌ | 6/8 [06:51<02:16, 68.18s/it] 88%|████████▊ | 7/8 [07:59<01:08, 68.01s/it] 100%|██████████| 8/8 [09:07<00:00, 68.02s/it] 100%|██████████| 8/8 [09:07<00:00, 68.38s/it] + 12%|█▎ | 1/8 [01:09<08:08, 69.72s/it] 25%|██▌ | 2/8 [02:15<06:45, 67.61s/it] 38%|███▊ | 3/8 [03:21<05:34, 66.92s/it] 50%|█████ | 4/8 [04:28<04:26, 66.60s/it] 62%|██████▎ | 5/8 [05:34<03:19, 66.44s/it] 75%|███████▌ | 6/8 [06:40<02:12, 66.32s/it] 88%|████████▊ | 7/8 [07:46<01:06, 66.25s/it] 100%|██████████| 8/8 [08:52<00:00, 66.23s/it] 100%|██████████| 8/8 [08:52<00:00, 66.57s/it] >>>>>>>>>>>>>>>>>>>>>>>> >>> Step 1: generating actions ... >>> Step 1: interacting with world model ... @@ -140,6 +140,6 @@ DEBUG:PIL.Image:Importing XVThumbImagePlugin >>> Step 7: interacting with world model ... >>>>>>>>>>>>>>>>>>>>>>>> -real 10m15.640s -user 11m34.152s -sys 0m48.021s +real 9m53.691s +user 11m23.200s +sys 0m42.702s