Operator fusion

Author: qhy
Date:   2026-02-07 16:40:33 +08:00
Parent: 25de36b9bc
Commit: aba2a90045


@@ -55,16 +55,13 @@ class DDIMSampler(object):
to_torch(self.model.alphas_cumprod_prev))
# Calculations for diffusion q(x_t | x_{t-1}) and others
self.register_buffer('sqrt_alphas_cumprod',
to_torch(np.sqrt(alphas_cumprod.cpu())))
self.register_buffer('sqrt_one_minus_alphas_cumprod',
to_torch(np.sqrt(1. - alphas_cumprod.cpu())))
self.register_buffer('log_one_minus_alphas_cumprod',
to_torch(np.log(1. - alphas_cumprod.cpu())))
self.register_buffer('sqrt_recip_alphas_cumprod',
to_torch(np.sqrt(1. / alphas_cumprod.cpu())))
self.register_buffer('sqrt_recipm1_alphas_cumprod',
to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1)))
# Computed directly on GPU to avoid CPU↔GPU transfers
ac = to_torch(alphas_cumprod)
self.register_buffer('sqrt_alphas_cumprod', ac.sqrt())
self.register_buffer('sqrt_one_minus_alphas_cumprod', (1. - ac).sqrt())
self.register_buffer('log_one_minus_alphas_cumprod', (1. - ac).log())
self.register_buffer('sqrt_recip_alphas_cumprod', ac.rsqrt())
self.register_buffer('sqrt_recipm1_alphas_cumprod', (1. / ac - 1).sqrt())
# DDIM sampling parameters
ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(
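The buffer registrations above previously round-tripped through NumPy on the CPU (`alphas_cumprod.cpu()` → `np.sqrt` → `to_torch`); the replacement keeps everything in torch so the expressions run on whatever device `alphas_cumprod` already lives on. A minimal sketch, with a made-up schedule, showing that the old CPU path and the fused tensor path produce the same values:

```python
import numpy as np
import torch

# Made-up cumulative-alpha schedule; the real values come from the model.
alphas_cumprod = torch.linspace(0.9999, 0.01, 1000, dtype=torch.float64)

# Old path: detour through NumPy on the CPU, then wrap the result back into a tensor.
old = torch.from_numpy(np.sqrt(1. / alphas_cumprod.cpu().numpy() - 1))

# New path: stay in torch, so the same expression can run directly on the GPU
# (e.g. after ac = alphas_cumprod.to("cuda")).
ac = alphas_cumprod
new = (1. / ac - 1).sqrt()

assert torch.allclose(old, new)
```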
@@ -86,6 +83,11 @@ class DDIMSampler(object):
self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
self.register_buffer('ddim_sqrt_one_minus_alphas',
torch.sqrt(1. - ddim_alphas))
# Precomputed coefficients for DDIM update formula
self.register_buffer('ddim_sqrt_alphas', ddim_alphas.sqrt())
self.register_buffer('ddim_sqrt_alphas_prev', ddim_alphas_prev.sqrt())
self.register_buffer('ddim_dir_coeff',
(1. - ddim_alphas_prev - ddim_sigmas**2).sqrt())
sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
(1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) *
(1 - self.alphas_cumprod / self.alphas_cumprod_prev))
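The three new buffers hold the per-index constants of the DDIM update x_prev = sqrt(alpha_prev) * pred_x0 + sqrt(1 - alpha_prev - sigma^2) * e_t + sigma * noise, so each sampling step can index a precomputed tensor instead of recomputing square roots. A small sketch of the precompute-once pattern, with placeholder schedule values:

```python
import torch

# Placeholder DDIM schedule; the real tensors come from make_ddim_sampling_parameters.
ddim_alphas = torch.linspace(0.99, 0.05, 50)
ddim_alphas_prev = torch.cat([torch.tensor([0.999]), ddim_alphas[:-1]])
ddim_sigmas = torch.zeros(50)  # eta = 0 -> deterministic DDIM

# Precomputed once when the schedule is built (what the new buffers hold) ...
ddim_sqrt_alphas_prev = ddim_alphas_prev.sqrt()
ddim_dir_coeff = (1. - ddim_alphas_prev - ddim_sigmas ** 2).sqrt()

# ... versus redoing the same math inside every sampling step.
index = 10
assert torch.allclose(ddim_dir_coeff[index],
                      (1. - ddim_alphas_prev[index] - ddim_sigmas[index] ** 2).sqrt())
```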
@@ -208,18 +210,11 @@ class DDIMSampler(object):
dp_ddim_scheduler_state = self.model.dp_noise_scheduler_state
b = shape[0]
if x_T is None:
img = torch.randn(shape, device=device)
action = torch.randn((b, 16, self.model.agent_action_dim),
device=device)
state = torch.randn((b, 16, self.model.agent_state_dim),
device=device)
else:
img = x_T
action = torch.randn((b, 16, self.model.agent_action_dim),
device=device)
state = torch.randn((b, 16, self.model.agent_state_dim),
device=device)
img = torch.randn(shape, device=device) if x_T is None else x_T
if precision is not None:
if precision == 16:
@@ -362,12 +357,13 @@ class DDIMSampler(object):
**kwargs)
else:
raise NotImplementedError
model_output = e_t_uncond + unconditional_guidance_scale * (
e_t_cond - e_t_uncond)
model_output_action = e_t_uncond_action + unconditional_guidance_scale * (
e_t_cond_action - e_t_uncond_action)
model_output_state = e_t_uncond_state + unconditional_guidance_scale * (
e_t_cond_state - e_t_uncond_state)
model_output = torch.lerp(e_t_uncond, e_t_cond,
unconditional_guidance_scale)
model_output_action = torch.lerp(e_t_uncond_action,
e_t_cond_action,
unconditional_guidance_scale)
model_output_state = torch.lerp(e_t_uncond_state, e_t_cond_state,
unconditional_guidance_scale)
if guidance_rescale > 0.0:
model_output = rescale_noise_cfg(
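torch.lerp(start, end, weight) computes start + weight * (end - start), which is exactly the classifier-free guidance combination e_t_uncond + scale * (e_t_cond - e_t_uncond), expressed as a single fused kernel instead of a separate subtract, multiply, and add. A quick check with random tensors (shapes are placeholders) that the two forms agree:

```python
import torch

e_t_uncond = torch.randn(2, 4, 8, 8)
e_t_cond = torch.randn(2, 4, 8, 8)
scale = 7.5

# Three-op formulation: subtract, scale, add.
manual = e_t_uncond + scale * (e_t_cond - e_t_uncond)

# Fused formulation: one lerp kernel; weights > 1 extrapolate, which is what CFG needs.
fused = torch.lerp(e_t_uncond, e_t_cond, scale)

assert torch.allclose(manual, fused, atol=1e-5)
```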
@@ -396,18 +392,28 @@ class DDIMSampler(object):
sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
sigmas = self.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
if use_original_steps:
sqrt_alphas = alphas.sqrt()
sqrt_alphas_prev = alphas_prev.sqrt()
dir_coeffs = (1. - alphas_prev - sigmas**2).sqrt()
else:
sqrt_alphas = self.ddim_sqrt_alphas
sqrt_alphas_prev = self.ddim_sqrt_alphas_prev
dir_coeffs = self.ddim_dir_coeff
if is_video:
size = (1, 1, 1, 1, 1)
else:
size = (1, 1, 1, 1)
a_t = alphas[index].view(size)
a_prev = alphas_prev[index].view(size)
sqrt_at = sqrt_alphas[index].view(size)
sqrt_a_prev = sqrt_alphas_prev[index].view(size)
sigma_t = sigmas[index].view(size)
sqrt_one_minus_at = sqrt_one_minus_alphas[index].view(size)
dir_coeff = dir_coeffs[index].view(size)
if self.model.parameterization != "v":
pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
pred_x0 = (x - sqrt_one_minus_at * e_t) / sqrt_at
else:
pred_x0 = self.model.predict_start_from_z_and_v(x, t, model_output)
@@ -420,14 +426,11 @@ class DDIMSampler(object):
if quantize_denoised:
pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
noise = sigma_t * noise_like(x.shape, device,
repeat_noise) * temperature
if noise_dropout > 0.:
noise = torch.nn.functional.dropout(noise, p=noise_dropout)
x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
x_prev = sqrt_a_prev * pred_x0 + dir_coeff * e_t + noise
return x_prev, pred_x0, model_output_action, model_output_state
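With sqrt_at, sqrt_a_prev, and dir_coeff read from the precomputed buffers, the eps-parameterized step reduces to pred_x0 = (x - sqrt(1 - a_t) * e_t) / sqrt(a_t) followed by x_prev = sqrt(a_prev) * pred_x0 + dir_coeff * e_t + sigma * noise, with no square roots left inside the sampling loop. A self-contained sketch of one such step, using made-up coefficients and omitting the v-parameterization, quantization, temperature, and noise-dropout branches:

```python
import torch

def ddim_step(x, e_t, sqrt_at, sqrt_one_minus_at, sqrt_a_prev, dir_coeff, sigma_t):
    # One eps-parameterized DDIM update built from precomputed per-index coefficients.
    pred_x0 = (x - sqrt_one_minus_at * e_t) / sqrt_at           # current x0 estimate
    noise = sigma_t * torch.randn_like(x)                       # stochastic term (eta > 0)
    x_prev = sqrt_a_prev * pred_x0 + dir_coeff * e_t + noise    # direction to x_{t-1} plus noise
    return x_prev, pred_x0

# Made-up coefficients for a single schedule index.
a_t, a_prev, sigma = 0.5, 0.6, 0.0
x = torch.randn(1, 4, 8, 8)
e_t = torch.randn_like(x)

x_prev, pred_x0 = ddim_step(
    x, e_t,
    sqrt_at=a_t ** 0.5,
    sqrt_one_minus_at=(1. - a_t) ** 0.5,
    sqrt_a_prev=a_prev ** 0.5,
    dir_coeff=(1. - a_prev - sigma ** 2) ** 0.5,
    sigma_t=sigma,
)
```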
@@ -475,7 +478,7 @@ class DDIMSampler(object):
sqrt_alphas_cumprod = self.sqrt_alphas_cumprod
sqrt_one_minus_alphas_cumprod = self.sqrt_one_minus_alphas_cumprod
else:
sqrt_alphas_cumprod = torch.sqrt(self.ddim_alphas)
sqrt_alphas_cumprod = self.ddim_sqrt_alphas
sqrt_one_minus_alphas_cumprod = self.ddim_sqrt_one_minus_alphas
if noise is None: