From aba2a90045976de7c7d4c227b069c4ffd4738c4e Mon Sep 17 00:00:00 2001 From: qhy <2728290997@qq.com> Date: Sat, 7 Feb 2026 16:40:33 +0800 Subject: [PATCH] =?UTF-8?q?=E7=AE=97=E5=AD=90=E8=9E=8D=E5=90=88?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/unifolm_wma/models/samplers/ddim.py | 75 +++++++++++++------------ 1 file changed, 39 insertions(+), 36 deletions(-) diff --git a/src/unifolm_wma/models/samplers/ddim.py b/src/unifolm_wma/models/samplers/ddim.py index cd761c4..71b01a9 100644 --- a/src/unifolm_wma/models/samplers/ddim.py +++ b/src/unifolm_wma/models/samplers/ddim.py @@ -55,16 +55,13 @@ class DDIMSampler(object): to_torch(self.model.alphas_cumprod_prev)) # Calculations for diffusion q(x_t | x_{t-1}) and others - self.register_buffer('sqrt_alphas_cumprod', - to_torch(np.sqrt(alphas_cumprod.cpu()))) - self.register_buffer('sqrt_one_minus_alphas_cumprod', - to_torch(np.sqrt(1. - alphas_cumprod.cpu()))) - self.register_buffer('log_one_minus_alphas_cumprod', - to_torch(np.log(1. - alphas_cumprod.cpu()))) - self.register_buffer('sqrt_recip_alphas_cumprod', - to_torch(np.sqrt(1. / alphas_cumprod.cpu()))) - self.register_buffer('sqrt_recipm1_alphas_cumprod', - to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1))) + # Computed directly on GPU to avoid CPU↔GPU transfers + ac = to_torch(alphas_cumprod) + self.register_buffer('sqrt_alphas_cumprod', ac.sqrt()) + self.register_buffer('sqrt_one_minus_alphas_cumprod', (1. - ac).sqrt()) + self.register_buffer('log_one_minus_alphas_cumprod', (1. - ac).log()) + self.register_buffer('sqrt_recip_alphas_cumprod', ac.rsqrt()) + self.register_buffer('sqrt_recipm1_alphas_cumprod', (1. / ac - 1).sqrt()) # DDIM sampling parameters ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters( @@ -86,6 +83,11 @@ class DDIMSampler(object): self.register_buffer('ddim_alphas_prev', ddim_alphas_prev) self.register_buffer('ddim_sqrt_one_minus_alphas', torch.sqrt(1. - ddim_alphas)) + # Precomputed coefficients for DDIM update formula + self.register_buffer('ddim_sqrt_alphas', ddim_alphas.sqrt()) + self.register_buffer('ddim_sqrt_alphas_prev', ddim_alphas_prev.sqrt()) + self.register_buffer('ddim_dir_coeff', + (1. - ddim_alphas_prev - ddim_sigmas**2).sqrt()) sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt( (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * (1 - self.alphas_cumprod / self.alphas_cumprod_prev)) @@ -208,18 +210,11 @@ class DDIMSampler(object): dp_ddim_scheduler_state = self.model.dp_noise_scheduler_state b = shape[0] - if x_T is None: - img = torch.randn(shape, device=device) - action = torch.randn((b, 16, self.model.agent_action_dim), - device=device) - state = torch.randn((b, 16, self.model.agent_state_dim), - device=device) - else: - img = x_T - action = torch.randn((b, 16, self.model.agent_action_dim), - device=device) - state = torch.randn((b, 16, self.model.agent_state_dim), - device=device) + action = torch.randn((b, 16, self.model.agent_action_dim), + device=device) + state = torch.randn((b, 16, self.model.agent_state_dim), + device=device) + img = torch.randn(shape, device=device) if x_T is None else x_T if precision is not None: if precision == 16: @@ -362,12 +357,13 @@ class DDIMSampler(object): **kwargs) else: raise NotImplementedError - model_output = e_t_uncond + unconditional_guidance_scale * ( - e_t_cond - e_t_uncond) - model_output_action = e_t_uncond_action + unconditional_guidance_scale * ( - e_t_cond_action - e_t_uncond_action) - model_output_state = e_t_uncond_state + unconditional_guidance_scale * ( - e_t_cond_state - e_t_uncond_state) + model_output = torch.lerp(e_t_uncond, e_t_cond, + unconditional_guidance_scale) + model_output_action = torch.lerp(e_t_uncond_action, + e_t_cond_action, + unconditional_guidance_scale) + model_output_state = torch.lerp(e_t_uncond_state, e_t_cond_state, + unconditional_guidance_scale) if guidance_rescale > 0.0: model_output = rescale_noise_cfg( @@ -396,18 +392,28 @@ class DDIMSampler(object): sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas sigmas = self.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas + if use_original_steps: + sqrt_alphas = alphas.sqrt() + sqrt_alphas_prev = alphas_prev.sqrt() + dir_coeffs = (1. - alphas_prev - sigmas**2).sqrt() + else: + sqrt_alphas = self.ddim_sqrt_alphas + sqrt_alphas_prev = self.ddim_sqrt_alphas_prev + dir_coeffs = self.ddim_dir_coeff + if is_video: size = (1, 1, 1, 1, 1) else: size = (1, 1, 1, 1) - a_t = alphas[index].view(size) - a_prev = alphas_prev[index].view(size) + sqrt_at = sqrt_alphas[index].view(size) + sqrt_a_prev = sqrt_alphas_prev[index].view(size) sigma_t = sigmas[index].view(size) sqrt_one_minus_at = sqrt_one_minus_alphas[index].view(size) + dir_coeff = dir_coeffs[index].view(size) if self.model.parameterization != "v": - pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt() + pred_x0 = (x - sqrt_one_minus_at * e_t) / sqrt_at else: pred_x0 = self.model.predict_start_from_z_and_v(x, t, model_output) @@ -420,14 +426,11 @@ class DDIMSampler(object): if quantize_denoised: pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0) - dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t - noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature if noise_dropout > 0.: noise = torch.nn.functional.dropout(noise, p=noise_dropout) - - x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise + x_prev = sqrt_a_prev * pred_x0 + dir_coeff * e_t + noise return x_prev, pred_x0, model_output_action, model_output_state @@ -475,7 +478,7 @@ class DDIMSampler(object): sqrt_alphas_cumprod = self.sqrt_alphas_cumprod sqrt_one_minus_alphas_cumprod = self.sqrt_one_minus_alphas_cumprod else: - sqrt_alphas_cumprod = torch.sqrt(self.ddim_alphas) + sqrt_alphas_cumprod = self.ddim_sqrt_alphas sqrt_one_minus_alphas_cumprod = self.ddim_sqrt_one_minus_alphas if noise is None: