算子融合
This commit is contained in:
@@ -55,16 +55,13 @@ class DDIMSampler(object):
|
||||
to_torch(self.model.alphas_cumprod_prev))
|
||||
|
||||
# Calculations for diffusion q(x_t | x_{t-1}) and others
|
||||
self.register_buffer('sqrt_alphas_cumprod',
|
||||
to_torch(np.sqrt(alphas_cumprod.cpu())))
|
||||
self.register_buffer('sqrt_one_minus_alphas_cumprod',
|
||||
to_torch(np.sqrt(1. - alphas_cumprod.cpu())))
|
||||
self.register_buffer('log_one_minus_alphas_cumprod',
|
||||
to_torch(np.log(1. - alphas_cumprod.cpu())))
|
||||
self.register_buffer('sqrt_recip_alphas_cumprod',
|
||||
to_torch(np.sqrt(1. / alphas_cumprod.cpu())))
|
||||
self.register_buffer('sqrt_recipm1_alphas_cumprod',
|
||||
to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1)))
|
||||
# Computed directly on GPU to avoid CPU↔GPU transfers
|
||||
ac = to_torch(alphas_cumprod)
|
||||
self.register_buffer('sqrt_alphas_cumprod', ac.sqrt())
|
||||
self.register_buffer('sqrt_one_minus_alphas_cumprod', (1. - ac).sqrt())
|
||||
self.register_buffer('log_one_minus_alphas_cumprod', (1. - ac).log())
|
||||
self.register_buffer('sqrt_recip_alphas_cumprod', ac.rsqrt())
|
||||
self.register_buffer('sqrt_recipm1_alphas_cumprod', (1. / ac - 1).sqrt())
|
||||
|
||||
# DDIM sampling parameters
|
||||
ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(
|
||||
@@ -86,6 +83,11 @@ class DDIMSampler(object):
|
||||
self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
|
||||
self.register_buffer('ddim_sqrt_one_minus_alphas',
|
||||
torch.sqrt(1. - ddim_alphas))
|
||||
# Precomputed coefficients for DDIM update formula
|
||||
self.register_buffer('ddim_sqrt_alphas', ddim_alphas.sqrt())
|
||||
self.register_buffer('ddim_sqrt_alphas_prev', ddim_alphas_prev.sqrt())
|
||||
self.register_buffer('ddim_dir_coeff',
|
||||
(1. - ddim_alphas_prev - ddim_sigmas**2).sqrt())
|
||||
sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
|
||||
(1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) *
|
||||
(1 - self.alphas_cumprod / self.alphas_cumprod_prev))
|
||||
@@ -208,18 +210,11 @@ class DDIMSampler(object):
|
||||
dp_ddim_scheduler_state = self.model.dp_noise_scheduler_state
|
||||
|
||||
b = shape[0]
|
||||
if x_T is None:
|
||||
img = torch.randn(shape, device=device)
|
||||
action = torch.randn((b, 16, self.model.agent_action_dim),
|
||||
device=device)
|
||||
state = torch.randn((b, 16, self.model.agent_state_dim),
|
||||
device=device)
|
||||
else:
|
||||
img = x_T
|
||||
action = torch.randn((b, 16, self.model.agent_action_dim),
|
||||
device=device)
|
||||
state = torch.randn((b, 16, self.model.agent_state_dim),
|
||||
device=device)
|
||||
action = torch.randn((b, 16, self.model.agent_action_dim),
|
||||
device=device)
|
||||
state = torch.randn((b, 16, self.model.agent_state_dim),
|
||||
device=device)
|
||||
img = torch.randn(shape, device=device) if x_T is None else x_T
|
||||
|
||||
if precision is not None:
|
||||
if precision == 16:
|
||||
@@ -362,12 +357,13 @@ class DDIMSampler(object):
|
||||
**kwargs)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
model_output = e_t_uncond + unconditional_guidance_scale * (
|
||||
e_t_cond - e_t_uncond)
|
||||
model_output_action = e_t_uncond_action + unconditional_guidance_scale * (
|
||||
e_t_cond_action - e_t_uncond_action)
|
||||
model_output_state = e_t_uncond_state + unconditional_guidance_scale * (
|
||||
e_t_cond_state - e_t_uncond_state)
|
||||
model_output = torch.lerp(e_t_uncond, e_t_cond,
|
||||
unconditional_guidance_scale)
|
||||
model_output_action = torch.lerp(e_t_uncond_action,
|
||||
e_t_cond_action,
|
||||
unconditional_guidance_scale)
|
||||
model_output_state = torch.lerp(e_t_uncond_state, e_t_cond_state,
|
||||
unconditional_guidance_scale)
|
||||
|
||||
if guidance_rescale > 0.0:
|
||||
model_output = rescale_noise_cfg(
|
||||
@@ -396,18 +392,28 @@ class DDIMSampler(object):
|
||||
sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
|
||||
sigmas = self.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
|
||||
|
||||
if use_original_steps:
|
||||
sqrt_alphas = alphas.sqrt()
|
||||
sqrt_alphas_prev = alphas_prev.sqrt()
|
||||
dir_coeffs = (1. - alphas_prev - sigmas**2).sqrt()
|
||||
else:
|
||||
sqrt_alphas = self.ddim_sqrt_alphas
|
||||
sqrt_alphas_prev = self.ddim_sqrt_alphas_prev
|
||||
dir_coeffs = self.ddim_dir_coeff
|
||||
|
||||
if is_video:
|
||||
size = (1, 1, 1, 1, 1)
|
||||
else:
|
||||
size = (1, 1, 1, 1)
|
||||
|
||||
a_t = alphas[index].view(size)
|
||||
a_prev = alphas_prev[index].view(size)
|
||||
sqrt_at = sqrt_alphas[index].view(size)
|
||||
sqrt_a_prev = sqrt_alphas_prev[index].view(size)
|
||||
sigma_t = sigmas[index].view(size)
|
||||
sqrt_one_minus_at = sqrt_one_minus_alphas[index].view(size)
|
||||
dir_coeff = dir_coeffs[index].view(size)
|
||||
|
||||
if self.model.parameterization != "v":
|
||||
pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
|
||||
pred_x0 = (x - sqrt_one_minus_at * e_t) / sqrt_at
|
||||
else:
|
||||
pred_x0 = self.model.predict_start_from_z_and_v(x, t, model_output)
|
||||
|
||||
@@ -420,14 +426,11 @@ class DDIMSampler(object):
|
||||
if quantize_denoised:
|
||||
pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
|
||||
|
||||
dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
|
||||
|
||||
noise = sigma_t * noise_like(x.shape, device,
|
||||
repeat_noise) * temperature
|
||||
if noise_dropout > 0.:
|
||||
noise = torch.nn.functional.dropout(noise, p=noise_dropout)
|
||||
|
||||
x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
|
||||
x_prev = sqrt_a_prev * pred_x0 + dir_coeff * e_t + noise
|
||||
|
||||
return x_prev, pred_x0, model_output_action, model_output_state
|
||||
|
||||
@@ -475,7 +478,7 @@ class DDIMSampler(object):
|
||||
sqrt_alphas_cumprod = self.sqrt_alphas_cumprod
|
||||
sqrt_one_minus_alphas_cumprod = self.sqrt_one_minus_alphas_cumprod
|
||||
else:
|
||||
sqrt_alphas_cumprod = torch.sqrt(self.ddim_alphas)
|
||||
sqrt_alphas_cumprod = self.ddim_sqrt_alphas
|
||||
sqrt_one_minus_alphas_cumprod = self.ddim_sqrt_one_minus_alphas
|
||||
|
||||
if noise is None:
|
||||
|
||||
Reference in New Issue
Block a user