mirror of
https://github.com/lllyasviel/stable-diffusion-webui-forge.git
synced 2025-12-28 05:35:00 +00:00
integrate k-diffusion
This commit is contained in:
parent
14a759b5ca
commit
a07c758658
@@ -39,8 +39,6 @@ class ForgeDiffusionEngine:
        self.is_sd2 = False
        self.is_sdxl = False
        self.is_sd3 = False
        self.parameterization = 'eps'
        self.alphas_cumprod = None

    def set_clip_skip(self, clip_skip):
        pass

@@ -53,7 +53,6 @@ class StableDiffusion(ForgeDiffusionEngine):
        # WebUI Legacy
        self.is_sd1 = True
        self.first_stage_model = vae.first_stage_model
        self.alphas_cumprod = unet.model.predictor.alphas_cumprod

    def set_clip_skip(self, clip_skip):
        self.text_processing_engine.clip_skip = clip_skip

@@ -53,10 +53,6 @@ class StableDiffusion2(ForgeDiffusionEngine):
        # WebUI Legacy
        self.is_sd2 = True
        self.first_stage_model = vae.first_stage_model
        self.alphas_cumprod = unet.model.predictor.alphas_cumprod

        if not self.is_inpaint:
            self.parameterization = 'v'

    def set_clip_skip(self, clip_skip):
        self.text_processing_engine.clip_skip = clip_skip

@@ -72,7 +72,6 @@ class StableDiffusionXL(ForgeDiffusionEngine):
        # WebUI Legacy
        self.is_sdxl = True
        self.first_stage_model = vae.first_stage_model
        self.alphas_cumprod = unet.model.predictor.alphas_cumprod

    def set_clip_skip(self, clip_skip):
        self.text_processing_engine_l.clip_skip = clip_skip

@@ -270,7 +270,7 @@ def calc_cond_uncond_batch(model, cond, uncond, x_in, timestep, model_options):
    return out_cond, out_uncond


def sampling_function_inner(model, x, timestep, uncond, cond, cond_scale, model_options={}, seed=None):
def sampling_function_inner(model, x, timestep, uncond, cond, cond_scale, model_options={}, seed=None, return_full=False):
    edit_strength = sum((item['strength'] if 'strength' in item else 1) for item in cond)

    if math.isclose(cond_scale, 1.0) and model_options.get("disable_cfg1_optimization", False) == False:

@@ -297,6 +297,9 @@ def sampling_function_inner(model, x, timestep, uncond, cond, cond_scale, model_
                "sigma": timestep, "model_options": model_options, "input": x}
        cfg_result = fn(args)

    if return_full:
        return cfg_result, cond_pred, uncond_pred

    return cfg_result


@@ -333,8 +336,8 @@ def sampling_function(self, denoiser_params, cond_scale, cond_composition):
    for modifier in model_options.get('conditioning_modifiers', []):
        model, x, timestep, uncond, cond, cond_scale, model_options, seed = modifier(model, x, timestep, uncond, cond, cond_scale, model_options, seed)

    denoised = sampling_function_inner(model, x, timestep, uncond, cond, cond_scale, model_options, seed)
    return denoised
    denoised, cond_pred, uncond_pred = sampling_function_inner(model, x, timestep, uncond, cond, cond_scale, model_options, seed, return_full=True)
    return denoised, cond_pred, uncond_pred


def sampling_prepare(unet, x):
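The new return_full flag lets sampling_function hand back the raw conditional and unconditional predictions alongside the blended CFG result. A minimal, self-contained sketch of that contract (toy_cfg is a stand-in for illustration, not a function from this repository):

    import torch

    def toy_cfg(cond_pred, uncond_pred, cond_scale, return_full=False):
        # Classifier-free guidance: blend the two predictions.
        cfg_result = uncond_pred + (cond_pred - uncond_pred) * cond_scale
        if return_full:
            # Callers that need the raw predictions (e.g. post-CFG hooks)
            # receive all three tensors instead of only the blended result.
            return cfg_result, cond_pred, uncond_pred
        return cfg_result

    cond_pred = torch.randn(1, 4, 8, 8)
    uncond_pred = torch.randn(1, 4, 8, 8)
    denoised, c, u = toy_cfg(cond_pred, uncond_pred, cond_scale=7.0, return_full=True)
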
177
k_diffusion/external.py
Normal file
@@ -0,0 +1,177 @@
import math

import torch
from torch import nn

from . import sampling, utils


class VDenoiser(nn.Module):
    """A v-diffusion-pytorch model wrapper for k-diffusion."""

    def __init__(self, inner_model):
        super().__init__()
        self.inner_model = inner_model
        self.sigma_data = 1.

    def get_scalings(self, sigma):
        c_skip = self.sigma_data ** 2 / (sigma ** 2 + self.sigma_data ** 2)
        c_out = -sigma * self.sigma_data / (sigma ** 2 + self.sigma_data ** 2) ** 0.5
        c_in = 1 / (sigma ** 2 + self.sigma_data ** 2) ** 0.5
        return c_skip, c_out, c_in

    def sigma_to_t(self, sigma):
        return sigma.atan() / math.pi * 2

    def t_to_sigma(self, t):
        return (t * math.pi / 2).tan()

    def loss(self, input, noise, sigma, **kwargs):
        c_skip, c_out, c_in = [utils.append_dims(x, input.ndim) for x in self.get_scalings(sigma)]
        noised_input = input + noise * utils.append_dims(sigma, input.ndim)
        model_output = self.inner_model(noised_input * c_in, self.sigma_to_t(sigma), **kwargs)
        target = (input - c_skip * noised_input) / c_out
        return (model_output - target).pow(2).flatten(1).mean(1)

    def forward(self, input, sigma, **kwargs):
        c_skip, c_out, c_in = [utils.append_dims(x, input.ndim) for x in self.get_scalings(sigma)]
        return self.inner_model(input * c_in, self.sigma_to_t(sigma), **kwargs) * c_out + input * c_skip


class DiscreteSchedule(nn.Module):
    """A mapping between continuous noise levels (sigmas) and a list of discrete noise
    levels."""

    def __init__(self, sigmas, quantize):
        super().__init__()
        self.register_buffer('sigmas', sigmas)
        self.register_buffer('log_sigmas', sigmas.log())
        self.quantize = quantize

    @property
    def sigma_min(self):
        return self.sigmas[0]

    @property
    def sigma_max(self):
        return self.sigmas[-1]

    def get_sigmas(self, n=None):
        if n is None:
            return sampling.append_zero(self.sigmas.flip(0))
        t_max = len(self.sigmas) - 1
        t = torch.linspace(t_max, 0, n, device=self.sigmas.device)
        return sampling.append_zero(self.t_to_sigma(t))

    def sigma_to_t(self, sigma, quantize=None):
        quantize = self.quantize if quantize is None else quantize
        log_sigma = sigma.log()
        dists = log_sigma - self.log_sigmas[:, None]
        if quantize:
            return dists.abs().argmin(dim=0).view(sigma.shape)
        low_idx = dists.ge(0).cumsum(dim=0).argmax(dim=0).clamp(max=self.log_sigmas.shape[0] - 2)
        high_idx = low_idx + 1
        low, high = self.log_sigmas[low_idx], self.log_sigmas[high_idx]
        w = (low - log_sigma) / (low - high)
        w = w.clamp(0, 1)
        t = (1 - w) * low_idx + w * high_idx
        return t.view(sigma.shape)

    def t_to_sigma(self, t):
        t = t.float()
        low_idx, high_idx, w = t.floor().long(), t.ceil().long(), t.frac()
        log_sigma = (1 - w) * self.log_sigmas[low_idx] + w * self.log_sigmas[high_idx]
        return log_sigma.exp()

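# A minimal round-trip sketch, assuming `alphas_cumprod` is the 1-D tensor of
# cumulative alphas from a discrete DDPM: with quantize=False, sigma_to_t and
# t_to_sigma are approximate inverses via piecewise-linear interpolation in
# log-sigma space.
#
#     sigmas = ((1 - alphas_cumprod) / alphas_cumprod) ** 0.5   # ascending
#     sched = DiscreteSchedule(sigmas, quantize=False)
#     t = sched.sigma_to_t(torch.tensor([5.0]))
#     sched.t_to_sigma(t)                                       # ~= tensor([5.0])
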
class DiscreteEpsDDPMDenoiser(DiscreteSchedule):
    """A wrapper for discrete schedule DDPM models that output eps (the predicted
    noise)."""

    def __init__(self, model, alphas_cumprod, quantize):
        super().__init__(((1 - alphas_cumprod) / alphas_cumprod) ** 0.5, quantize)
        self.inner_model = model
        self.sigma_data = 1.

    def get_scalings(self, sigma):
        c_out = -sigma
        c_in = 1 / (sigma ** 2 + self.sigma_data ** 2) ** 0.5
        return c_out, c_in

    def get_eps(self, *args, **kwargs):
        return self.inner_model(*args, **kwargs)

    def loss(self, input, noise, sigma, **kwargs):
        c_out, c_in = [utils.append_dims(x, input.ndim) for x in self.get_scalings(sigma)]
        noised_input = input + noise * utils.append_dims(sigma, input.ndim)
        eps = self.get_eps(noised_input * c_in, self.sigma_to_t(sigma), **kwargs)
        return (eps - noise).pow(2).flatten(1).mean(1)

    def forward(self, input, sigma, **kwargs):
        c_out, c_in = [utils.append_dims(x, input.ndim) for x in self.get_scalings(sigma)]
        eps = self.get_eps(input * c_in, self.sigma_to_t(sigma), **kwargs)
        return input + eps * c_out


class OpenAIDenoiser(DiscreteEpsDDPMDenoiser):
    """A wrapper for OpenAI diffusion models."""

    def __init__(self, model, diffusion, quantize=False, has_learned_sigmas=True, device='cpu'):
        alphas_cumprod = torch.tensor(diffusion.alphas_cumprod, device=device, dtype=torch.float32)
        super().__init__(model, alphas_cumprod, quantize=quantize)
        self.has_learned_sigmas = has_learned_sigmas

    def get_eps(self, *args, **kwargs):
        model_output = self.inner_model(*args, **kwargs)
        if self.has_learned_sigmas:
            return model_output.chunk(2, dim=1)[0]
        return model_output


class CompVisDenoiser(DiscreteEpsDDPMDenoiser):
    """A wrapper for CompVis diffusion models."""

    def __init__(self, model, quantize=False, device='cpu'):
        super().__init__(model, model.alphas_cumprod, quantize=quantize)

    def get_eps(self, *args, **kwargs):
        return self.inner_model.apply_model(*args, **kwargs)

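# A minimal wrapping sketch, assuming `ldm_model` follows the CompVis/LDM
# interface (an `alphas_cumprod` tensor plus `apply_model(x, t, cond)` that
# predicts eps) and `c` is its conditioning:
#
#     denoiser = CompVisDenoiser(ldm_model)
#     sigmas = denoiser.get_sigmas(20)                  # 20 sigmas plus a final 0
#     x = torch.randn(1, 4, 64, 64) * sigmas[0]
#     denoised = denoiser(x, sigmas[0] * x.new_ones([1]), cond=c)
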
class DiscreteVDDPMDenoiser(DiscreteSchedule):
    """A wrapper for discrete schedule DDPM models that output v."""

    def __init__(self, model, alphas_cumprod, quantize):
        super().__init__(((1 - alphas_cumprod) / alphas_cumprod) ** 0.5, quantize)
        self.inner_model = model
        self.sigma_data = 1.

    def get_scalings(self, sigma):
        c_skip = self.sigma_data ** 2 / (sigma ** 2 + self.sigma_data ** 2)
        c_out = -sigma * self.sigma_data / (sigma ** 2 + self.sigma_data ** 2) ** 0.5
        c_in = 1 / (sigma ** 2 + self.sigma_data ** 2) ** 0.5
        return c_skip, c_out, c_in

    def get_v(self, *args, **kwargs):
        return self.inner_model(*args, **kwargs)

    def loss(self, input, noise, sigma, **kwargs):
        c_skip, c_out, c_in = [utils.append_dims(x, input.ndim) for x in self.get_scalings(sigma)]
        noised_input = input + noise * utils.append_dims(sigma, input.ndim)
        model_output = self.get_v(noised_input * c_in, self.sigma_to_t(sigma), **kwargs)
        target = (input - c_skip * noised_input) / c_out
        return (model_output - target).pow(2).flatten(1).mean(1)

    def forward(self, input, sigma, **kwargs):
        c_skip, c_out, c_in = [utils.append_dims(x, input.ndim) for x in self.get_scalings(sigma)]
        return self.get_v(input * c_in, self.sigma_to_t(sigma), **kwargs) * c_out + input * c_skip


class CompVisVDenoiser(DiscreteVDDPMDenoiser):
    """A wrapper for CompVis diffusion models that output v."""

    def __init__(self, model, quantize=False, device='cpu'):
        super().__init__(model, model.alphas_cumprod, quantize=quantize)

    def get_v(self, x, t, cond, **kwargs):
        return self.inner_model.apply_model(x, t, cond)
702
k_diffusion/sampling.py
Normal file
@@ -0,0 +1,702 @@
import math

from scipy import integrate
import torch
from torch import nn
from torchdiffeq import odeint
import torchsde
from tqdm.auto import trange, tqdm

from . import utils


def append_zero(x):
    return torch.cat([x, x.new_zeros([1])])


def get_sigmas_karras(n, sigma_min, sigma_max, rho=7., device='cpu'):
    """Constructs the noise schedule of Karras et al. (2022)."""
    ramp = torch.linspace(0, 1, n)
    min_inv_rho = sigma_min ** (1 / rho)
    max_inv_rho = sigma_max ** (1 / rho)
    sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
    return append_zero(sigmas).to(device)

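# A minimal schedule sketch (SD-style bounds, values illustrative):
#
#     sigmas = get_sigmas_karras(10, sigma_min=0.0292, sigma_max=14.6146)
#     # -> 11 values decreasing from 14.6146 to 0.0292, then a final 0.0
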
def get_sigmas_exponential(n, sigma_min, sigma_max, device='cpu'):
    """Constructs an exponential noise schedule."""
    sigmas = torch.linspace(math.log(sigma_max), math.log(sigma_min), n, device=device).exp()
    return append_zero(sigmas)


def get_sigmas_polyexponential(n, sigma_min, sigma_max, rho=1., device='cpu'):
    """Constructs a polynomial in log sigma noise schedule."""
    ramp = torch.linspace(1, 0, n, device=device) ** rho
    sigmas = torch.exp(ramp * (math.log(sigma_max) - math.log(sigma_min)) + math.log(sigma_min))
    return append_zero(sigmas)


def get_sigmas_vp(n, beta_d=19.9, beta_min=0.1, eps_s=1e-3, device='cpu'):
    """Constructs a continuous VP noise schedule."""
    t = torch.linspace(1, eps_s, n, device=device)
    sigmas = torch.sqrt(torch.exp(beta_d * t ** 2 / 2 + beta_min * t) - 1)
    return append_zero(sigmas)


def to_d(x, sigma, denoised):
    """Converts a denoiser output to a Karras ODE derivative."""
    return (x - denoised) / utils.append_dims(sigma, x.ndim)


def get_ancestral_step(sigma_from, sigma_to, eta=1.):
    """Calculates the noise level (sigma_down) to step down to and the amount
    of noise to add (sigma_up) when doing an ancestral sampling step."""
    if not eta:
        return sigma_to, 0.
    sigma_up = min(sigma_to, eta * (sigma_to ** 2 * (sigma_from ** 2 - sigma_to ** 2) / sigma_from ** 2) ** 0.5)
    sigma_down = (sigma_to ** 2 - sigma_up ** 2) ** 0.5
    return sigma_down, sigma_up

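# A small numeric sketch of the variance split: sigma_down and sigma_up are
# constructed so that sigma_down ** 2 + sigma_up ** 2 == sigma_to ** 2, i.e.
# stepping down and re-adding sigma_up of fresh noise lands exactly at sigma_to.
#
#     down, up = get_ancestral_step(2.0, 1.0)   # eta=1
#     down ** 2 + up ** 2                       # == 1.0 (0.25 + 0.75)
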
def default_noise_sampler(x):
    return lambda sigma, sigma_next: torch.randn_like(x)


class BatchedBrownianTree:
    """A wrapper around torchsde.BrownianTree that enables batches of entropy."""

    def __init__(self, x, t0, t1, seed=None, **kwargs):
        t0, t1, self.sign = self.sort(t0, t1)
        w0 = kwargs.get('w0', torch.zeros_like(x))
        if seed is None:
            seed = torch.randint(0, 2 ** 63 - 1, []).item()
        self.batched = True
        try:
            assert len(seed) == x.shape[0]
            w0 = w0[0]
        except TypeError:
            seed = [seed]
            self.batched = False
        self.trees = [torchsde.BrownianTree(t0, w0, t1, entropy=s, **kwargs) for s in seed]

    @staticmethod
    def sort(a, b):
        return (a, b, 1) if a < b else (b, a, -1)

    def __call__(self, t0, t1):
        t0, t1, sign = self.sort(t0, t1)
        w = torch.stack([tree(t0, t1) for tree in self.trees]) * (self.sign * sign)
        return w if self.batched else w[0]


class BrownianTreeNoiseSampler:
    """A noise sampler backed by a torchsde.BrownianTree.

    Args:
        x (Tensor): The tensor whose shape, device and dtype to use to generate
            random samples.
        sigma_min (float): The low end of the valid interval.
        sigma_max (float): The high end of the valid interval.
        seed (int or List[int]): The random seed. If a list of seeds is
            supplied instead of a single integer, then the noise sampler will
            use one BrownianTree per batch item, each with its own seed.
        transform (callable): A function that maps sigma to the sampler's
            internal timestep.
    """

    def __init__(self, x, sigma_min, sigma_max, seed=None, transform=lambda x: x):
        self.transform = transform
        t0, t1 = self.transform(torch.as_tensor(sigma_min)), self.transform(torch.as_tensor(sigma_max))
        self.tree = BatchedBrownianTree(x, t0, t1, seed)

    def __call__(self, sigma, sigma_next):
        t0, t1 = self.transform(torch.as_tensor(sigma)), self.transform(torch.as_tensor(sigma_next))
        return self.tree(t0, t1) / (t1 - t0).abs().sqrt()

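# A minimal usage sketch: for a fixed seed the Brownian tree makes the SDE
# samplers' noise reproducible, with one tree per batch item when a seed list
# is given.
#
#     x = torch.randn(2, 4, 64, 64)
#     ns = BrownianTreeNoiseSampler(x, sigma_min=0.03, sigma_max=14.6, seed=[1, 2])
#     noise = ns(torch.tensor(10.0), torch.tensor(9.0))   # same shape as x
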
@torch.no_grad()
def sample_euler(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
    """Implements Algorithm 2 (Euler steps) from Karras et al. (2022)."""
    extra_args = {} if extra_args is None else extra_args
    s_in = x.new_ones([x.shape[0]])
    for i in trange(len(sigmas) - 1, disable=disable):
        gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
        eps = torch.randn_like(x) * s_noise
        sigma_hat = sigmas[i] * (gamma + 1)
        if gamma > 0:
            x = x + eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
        denoised = model(x, sigma_hat * s_in, **extra_args)
        d = to_d(x, sigma_hat, denoised)
        if callback is not None:
            callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
        dt = sigmas[i + 1] - sigma_hat
        # Euler method
        x = x + d * dt
    return x

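# A minimal sampling sketch, assuming `denoiser` is one of the wrappers from
# external.py (callable as denoiser(x, sigma, **extra_args)) and `c` is its
# conditioning:
#
#     sigmas = get_sigmas_karras(20, 0.0292, 14.6146, device='cuda')
#     x = torch.randn(1, 4, 64, 64, device='cuda') * sigmas[0]
#     sample = sample_euler(denoiser, x, sigmas, extra_args={'cond': c})
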
@torch.no_grad()
def sample_euler_ancestral(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
    """Ancestral sampling with Euler method steps."""
    extra_args = {} if extra_args is None else extra_args
    noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
    s_in = x.new_ones([x.shape[0]])
    for i in trange(len(sigmas) - 1, disable=disable):
        denoised = model(x, sigmas[i] * s_in, **extra_args)
        sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1], eta=eta)
        if callback is not None:
            callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
        d = to_d(x, sigmas[i], denoised)
        # Euler method
        dt = sigma_down - sigmas[i]
        x = x + d * dt
        if sigmas[i + 1] > 0:
            x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * sigma_up
    return x


@torch.no_grad()
def sample_heun(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
    """Implements Algorithm 2 (Heun steps) from Karras et al. (2022)."""
    extra_args = {} if extra_args is None else extra_args
    s_in = x.new_ones([x.shape[0]])
    for i in trange(len(sigmas) - 1, disable=disable):
        gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
        eps = torch.randn_like(x) * s_noise
        sigma_hat = sigmas[i] * (gamma + 1)
        if gamma > 0:
            x = x + eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
        denoised = model(x, sigma_hat * s_in, **extra_args)
        d = to_d(x, sigma_hat, denoised)
        if callback is not None:
            callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
        dt = sigmas[i + 1] - sigma_hat
        if sigmas[i + 1] == 0:
            # Euler method
            x = x + d * dt
        else:
            # Heun's method
            x_2 = x + d * dt
            denoised_2 = model(x_2, sigmas[i + 1] * s_in, **extra_args)
            d_2 = to_d(x_2, sigmas[i + 1], denoised_2)
            d_prime = (d + d_2) / 2
            x = x + d_prime * dt
    return x


@torch.no_grad()
def sample_dpm_2(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
    """A sampler inspired by DPM-Solver-2 and Algorithm 2 from Karras et al. (2022)."""
    extra_args = {} if extra_args is None else extra_args
    s_in = x.new_ones([x.shape[0]])
    for i in trange(len(sigmas) - 1, disable=disable):
        gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
        eps = torch.randn_like(x) * s_noise
        sigma_hat = sigmas[i] * (gamma + 1)
        if gamma > 0:
            x = x + eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
        denoised = model(x, sigma_hat * s_in, **extra_args)
        d = to_d(x, sigma_hat, denoised)
        if callback is not None:
            callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
        if sigmas[i + 1] == 0:
            # Euler method
            dt = sigmas[i + 1] - sigma_hat
            x = x + d * dt
        else:
            # DPM-Solver-2
            sigma_mid = sigma_hat.log().lerp(sigmas[i + 1].log(), 0.5).exp()
            dt_1 = sigma_mid - sigma_hat
            dt_2 = sigmas[i + 1] - sigma_hat
            x_2 = x + d * dt_1
            denoised_2 = model(x_2, sigma_mid * s_in, **extra_args)
            d_2 = to_d(x_2, sigma_mid, denoised_2)
            x = x + d_2 * dt_2
    return x


@torch.no_grad()
def sample_dpm_2_ancestral(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
    """Ancestral sampling with DPM-Solver second-order steps."""
    extra_args = {} if extra_args is None else extra_args
    noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
    s_in = x.new_ones([x.shape[0]])
    for i in trange(len(sigmas) - 1, disable=disable):
        denoised = model(x, sigmas[i] * s_in, **extra_args)
        sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1], eta=eta)
        if callback is not None:
            callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
        d = to_d(x, sigmas[i], denoised)
        if sigma_down == 0:
            # Euler method
            dt = sigma_down - sigmas[i]
            x = x + d * dt
        else:
            # DPM-Solver-2
            sigma_mid = sigmas[i].log().lerp(sigma_down.log(), 0.5).exp()
            dt_1 = sigma_mid - sigmas[i]
            dt_2 = sigma_down - sigmas[i]
            x_2 = x + d * dt_1
            denoised_2 = model(x_2, sigma_mid * s_in, **extra_args)
            d_2 = to_d(x_2, sigma_mid, denoised_2)
            x = x + d_2 * dt_2
            x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * sigma_up
    return x


def linear_multistep_coeff(order, t, i, j):
    if order - 1 > i:
        raise ValueError(f'Order {order} too high for step {i}')
    def fn(tau):
        prod = 1.
        for k in range(order):
            if j == k:
                continue
            prod *= (tau - t[i - k]) / (t[i - j] - t[i - k])
        return prod
    return integrate.quad(fn, t[i], t[i + 1], epsrel=1e-4)[0]


@torch.no_grad()
def sample_lms(model, x, sigmas, extra_args=None, callback=None, disable=None, order=4):
    extra_args = {} if extra_args is None else extra_args
    s_in = x.new_ones([x.shape[0]])
    sigmas_cpu = sigmas.detach().cpu().numpy()
    ds = []
    for i in trange(len(sigmas) - 1, disable=disable):
        denoised = model(x, sigmas[i] * s_in, **extra_args)
        d = to_d(x, sigmas[i], denoised)
        ds.append(d)
        if len(ds) > order:
            ds.pop(0)
        if callback is not None:
            callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
        cur_order = min(i + 1, order)
        coeffs = [linear_multistep_coeff(cur_order, sigmas_cpu, i, j) for j in range(cur_order)]
        x = x + sum(coeff * d for coeff, d in zip(coeffs, reversed(ds)))
    return x


@torch.no_grad()
def log_likelihood(model, x, sigma_min, sigma_max, extra_args=None, atol=1e-4, rtol=1e-4):
    extra_args = {} if extra_args is None else extra_args
    s_in = x.new_ones([x.shape[0]])
    v = torch.randint_like(x, 2) * 2 - 1
    fevals = 0
    def ode_fn(sigma, x):
        nonlocal fevals
        with torch.enable_grad():
            x = x[0].detach().requires_grad_()
            denoised = model(x, sigma * s_in, **extra_args)
            d = to_d(x, sigma, denoised)
            fevals += 1
            grad = torch.autograd.grad((d * v).sum(), x)[0]
            d_ll = (v * grad).flatten(1).sum(1)
        return d.detach(), d_ll
    x_min = x, x.new_zeros([x.shape[0]])
    t = x.new_tensor([sigma_min, sigma_max])
    sol = odeint(ode_fn, x_min, t, atol=atol, rtol=rtol, method='dopri5')
    latent, delta_ll = sol[0][-1], sol[1][-1]
    ll_prior = torch.distributions.Normal(0, sigma_max).log_prob(latent).flatten(1).sum(1)
    return ll_prior + delta_ll, {'fevals': fevals}


class PIDStepSizeController:
    """A PID controller for ODE adaptive step size control."""
    def __init__(self, h, pcoeff, icoeff, dcoeff, order=1, accept_safety=0.81, eps=1e-8):
        self.h = h
        self.b1 = (pcoeff + icoeff + dcoeff) / order
        self.b2 = -(pcoeff + 2 * dcoeff) / order
        self.b3 = dcoeff / order
        self.accept_safety = accept_safety
        self.eps = eps
        self.errs = []

    def limiter(self, x):
        return 1 + math.atan(x - 1)

    def propose_step(self, error):
        inv_error = 1 / (float(error) + self.eps)
        if not self.errs:
            self.errs = [inv_error, inv_error, inv_error]
        self.errs[0] = inv_error
        factor = self.errs[0] ** self.b1 * self.errs[1] ** self.b2 * self.errs[2] ** self.b3
        factor = self.limiter(factor)
        accept = factor >= self.accept_safety
        if accept:
            self.errs[2] = self.errs[1]
            self.errs[1] = self.errs[0]
        self.h *= factor
        return accept


class DPMSolver(nn.Module):
    """DPM-Solver. See https://arxiv.org/abs/2206.00927."""

    def __init__(self, model, extra_args=None, eps_callback=None, info_callback=None):
        super().__init__()
        self.model = model
        self.extra_args = {} if extra_args is None else extra_args
        self.eps_callback = eps_callback
        self.info_callback = info_callback

    def t(self, sigma):
        return -sigma.log()

    def sigma(self, t):
        return t.neg().exp()

    def eps(self, eps_cache, key, x, t, *args, **kwargs):
        if key in eps_cache:
            return eps_cache[key], eps_cache
        sigma = self.sigma(t) * x.new_ones([x.shape[0]])
        eps = (x - self.model(x, sigma, *args, **self.extra_args, **kwargs)) / self.sigma(t)
        if self.eps_callback is not None:
            self.eps_callback()
        return eps, {key: eps, **eps_cache}

    def dpm_solver_1_step(self, x, t, t_next, eps_cache=None):
        eps_cache = {} if eps_cache is None else eps_cache
        h = t_next - t
        eps, eps_cache = self.eps(eps_cache, 'eps', x, t)
        x_1 = x - self.sigma(t_next) * h.expm1() * eps
        return x_1, eps_cache

    def dpm_solver_2_step(self, x, t, t_next, r1=1 / 2, eps_cache=None):
        eps_cache = {} if eps_cache is None else eps_cache
        h = t_next - t
        eps, eps_cache = self.eps(eps_cache, 'eps', x, t)
        s1 = t + r1 * h
        u1 = x - self.sigma(s1) * (r1 * h).expm1() * eps
        eps_r1, eps_cache = self.eps(eps_cache, 'eps_r1', u1, s1)
        x_2 = x - self.sigma(t_next) * h.expm1() * eps - self.sigma(t_next) / (2 * r1) * h.expm1() * (eps_r1 - eps)
        return x_2, eps_cache

    def dpm_solver_3_step(self, x, t, t_next, r1=1 / 3, r2=2 / 3, eps_cache=None):
        eps_cache = {} if eps_cache is None else eps_cache
        h = t_next - t
        eps, eps_cache = self.eps(eps_cache, 'eps', x, t)
        s1 = t + r1 * h
        s2 = t + r2 * h
        u1 = x - self.sigma(s1) * (r1 * h).expm1() * eps
        eps_r1, eps_cache = self.eps(eps_cache, 'eps_r1', u1, s1)
        u2 = x - self.sigma(s2) * (r2 * h).expm1() * eps - self.sigma(s2) * (r2 / r1) * ((r2 * h).expm1() / (r2 * h) - 1) * (eps_r1 - eps)
        eps_r2, eps_cache = self.eps(eps_cache, 'eps_r2', u2, s2)
        x_3 = x - self.sigma(t_next) * h.expm1() * eps - self.sigma(t_next) / r2 * (h.expm1() / h - 1) * (eps_r2 - eps)
        return x_3, eps_cache

    def dpm_solver_fast(self, x, t_start, t_end, nfe, eta=0., s_noise=1., noise_sampler=None):
        noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
        if not t_end > t_start and eta:
            raise ValueError('eta must be 0 for reverse sampling')

        m = math.floor(nfe / 3) + 1
        ts = torch.linspace(t_start, t_end, m + 1, device=x.device)

        if nfe % 3 == 0:
            orders = [3] * (m - 2) + [2, 1]
        else:
            orders = [3] * (m - 1) + [nfe % 3]

        for i in range(len(orders)):
            eps_cache = {}
            t, t_next = ts[i], ts[i + 1]
            if eta:
                sd, su = get_ancestral_step(self.sigma(t), self.sigma(t_next), eta)
                t_next_ = torch.minimum(t_end, self.t(sd))
                su = (self.sigma(t_next) ** 2 - self.sigma(t_next_) ** 2) ** 0.5
            else:
                t_next_, su = t_next, 0.

            eps, eps_cache = self.eps(eps_cache, 'eps', x, t)
            denoised = x - self.sigma(t) * eps
            if self.info_callback is not None:
                self.info_callback({'x': x, 'i': i, 't': ts[i], 't_up': t, 'denoised': denoised})

            if orders[i] == 1:
                x, eps_cache = self.dpm_solver_1_step(x, t, t_next_, eps_cache=eps_cache)
            elif orders[i] == 2:
                x, eps_cache = self.dpm_solver_2_step(x, t, t_next_, eps_cache=eps_cache)
            else:
                x, eps_cache = self.dpm_solver_3_step(x, t, t_next_, eps_cache=eps_cache)

            x = x + su * s_noise * noise_sampler(self.sigma(t), self.sigma(t_next))

        return x

    def dpm_solver_adaptive(self, x, t_start, t_end, order=3, rtol=0.05, atol=0.0078, h_init=0.05, pcoeff=0., icoeff=1., dcoeff=0., accept_safety=0.81, eta=0., s_noise=1., noise_sampler=None):
        noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
        if order not in {2, 3}:
            raise ValueError('order should be 2 or 3')
        forward = t_end > t_start
        if not forward and eta:
            raise ValueError('eta must be 0 for reverse sampling')
        h_init = abs(h_init) * (1 if forward else -1)
        atol = torch.tensor(atol)
        rtol = torch.tensor(rtol)
        s = t_start
        x_prev = x
        accept = True
        pid = PIDStepSizeController(h_init, pcoeff, icoeff, dcoeff, 1.5 if eta else order, accept_safety)
        info = {'steps': 0, 'nfe': 0, 'n_accept': 0, 'n_reject': 0}

        while s < t_end - 1e-5 if forward else s > t_end + 1e-5:
            eps_cache = {}
            t = torch.minimum(t_end, s + pid.h) if forward else torch.maximum(t_end, s + pid.h)
            if eta:
                sd, su = get_ancestral_step(self.sigma(s), self.sigma(t), eta)
                t_ = torch.minimum(t_end, self.t(sd))
                su = (self.sigma(t) ** 2 - self.sigma(t_) ** 2) ** 0.5
            else:
                t_, su = t, 0.

            eps, eps_cache = self.eps(eps_cache, 'eps', x, s)
            denoised = x - self.sigma(s) * eps

            if order == 2:
                x_low, eps_cache = self.dpm_solver_1_step(x, s, t_, eps_cache=eps_cache)
                x_high, eps_cache = self.dpm_solver_2_step(x, s, t_, eps_cache=eps_cache)
            else:
                x_low, eps_cache = self.dpm_solver_2_step(x, s, t_, r1=1 / 3, eps_cache=eps_cache)
                x_high, eps_cache = self.dpm_solver_3_step(x, s, t_, eps_cache=eps_cache)
            delta = torch.maximum(atol, rtol * torch.maximum(x_low.abs(), x_prev.abs()))
            error = torch.linalg.norm((x_low - x_high) / delta) / x.numel() ** 0.5
            accept = pid.propose_step(error)
            if accept:
                x_prev = x_low
                x = x_high + su * s_noise * noise_sampler(self.sigma(s), self.sigma(t))
                s = t
                info['n_accept'] += 1
            else:
                info['n_reject'] += 1
            info['nfe'] += order
            info['steps'] += 1

            if self.info_callback is not None:
                self.info_callback({'x': x, 'i': info['steps'] - 1, 't': s, 't_up': s, 'denoised': denoised, 'error': error, 'h': pid.h, **info})

        return x, info


@torch.no_grad()
def sample_dpm_fast(model, x, sigma_min, sigma_max, n, extra_args=None, callback=None, disable=None, eta=0., s_noise=1., noise_sampler=None):
    """DPM-Solver-Fast (fixed step size). See https://arxiv.org/abs/2206.00927."""
    if sigma_min <= 0 or sigma_max <= 0:
        raise ValueError('sigma_min and sigma_max must not be 0')
    with tqdm(total=n, disable=disable) as pbar:
        dpm_solver = DPMSolver(model, extra_args, eps_callback=pbar.update)
        if callback is not None:
            dpm_solver.info_callback = lambda info: callback({'sigma': dpm_solver.sigma(info['t']), 'sigma_hat': dpm_solver.sigma(info['t_up']), **info})
        return dpm_solver.dpm_solver_fast(x, dpm_solver.t(torch.tensor(sigma_max)), dpm_solver.t(torch.tensor(sigma_min)), n, eta, s_noise, noise_sampler)


@torch.no_grad()
def sample_dpm_adaptive(model, x, sigma_min, sigma_max, extra_args=None, callback=None, disable=None, order=3, rtol=0.05, atol=0.0078, h_init=0.05, pcoeff=0., icoeff=1., dcoeff=0., accept_safety=0.81, eta=0., s_noise=1., noise_sampler=None, return_info=False):
    """DPM-Solver-12 and 23 (adaptive step size). See https://arxiv.org/abs/2206.00927."""
    if sigma_min <= 0 or sigma_max <= 0:
        raise ValueError('sigma_min and sigma_max must not be 0')
    with tqdm(disable=disable) as pbar:
        dpm_solver = DPMSolver(model, extra_args, eps_callback=pbar.update)
        if callback is not None:
            dpm_solver.info_callback = lambda info: callback({'sigma': dpm_solver.sigma(info['t']), 'sigma_hat': dpm_solver.sigma(info['t_up']), **info})
        x, info = dpm_solver.dpm_solver_adaptive(x, dpm_solver.t(torch.tensor(sigma_max)), dpm_solver.t(torch.tensor(sigma_min)), order, rtol, atol, h_init, pcoeff, icoeff, dcoeff, accept_safety, eta, s_noise, noise_sampler)
    if return_info:
        return x, info
    return x


@torch.no_grad()
def sample_dpmpp_2s_ancestral(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
    """Ancestral sampling with DPM-Solver++(2S) second-order steps."""
    extra_args = {} if extra_args is None else extra_args
    noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
    s_in = x.new_ones([x.shape[0]])
    sigma_fn = lambda t: t.neg().exp()
    t_fn = lambda sigma: sigma.log().neg()

    for i in trange(len(sigmas) - 1, disable=disable):
        denoised = model(x, sigmas[i] * s_in, **extra_args)
        sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1], eta=eta)
        if callback is not None:
            callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
        if sigma_down == 0:
            # Euler method
            d = to_d(x, sigmas[i], denoised)
            dt = sigma_down - sigmas[i]
            x = x + d * dt
        else:
            # DPM-Solver++(2S)
            t, t_next = t_fn(sigmas[i]), t_fn(sigma_down)
            r = 1 / 2
            h = t_next - t
            s = t + r * h
            x_2 = (sigma_fn(s) / sigma_fn(t)) * x - (-h * r).expm1() * denoised
            denoised_2 = model(x_2, sigma_fn(s) * s_in, **extra_args)
            x = (sigma_fn(t_next) / sigma_fn(t)) * x - (-h).expm1() * denoised_2
        # Noise addition
        if sigmas[i + 1] > 0:
            x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * sigma_up
    return x


@torch.no_grad()
def sample_dpmpp_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r=1 / 2):
    """DPM-Solver++ (stochastic)."""
    sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
    noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max) if noise_sampler is None else noise_sampler
    extra_args = {} if extra_args is None else extra_args
    s_in = x.new_ones([x.shape[0]])
    sigma_fn = lambda t: t.neg().exp()
    t_fn = lambda sigma: sigma.log().neg()

    for i in trange(len(sigmas) - 1, disable=disable):
        denoised = model(x, sigmas[i] * s_in, **extra_args)
        if callback is not None:
            callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
        if sigmas[i + 1] == 0:
            # Euler method
            d = to_d(x, sigmas[i], denoised)
            dt = sigmas[i + 1] - sigmas[i]
            x = x + d * dt
        else:
            # DPM-Solver++
            t, t_next = t_fn(sigmas[i]), t_fn(sigmas[i + 1])
            h = t_next - t
            s = t + h * r
            fac = 1 / (2 * r)

            # Step 1
            sd, su = get_ancestral_step(sigma_fn(t), sigma_fn(s), eta)
            s_ = t_fn(sd)
            x_2 = (sigma_fn(s_) / sigma_fn(t)) * x - (t - s_).expm1() * denoised
            x_2 = x_2 + noise_sampler(sigma_fn(t), sigma_fn(s)) * s_noise * su
            denoised_2 = model(x_2, sigma_fn(s) * s_in, **extra_args)

            # Step 2
            sd, su = get_ancestral_step(sigma_fn(t), sigma_fn(t_next), eta)
            t_next_ = t_fn(sd)
            denoised_d = (1 - fac) * denoised + fac * denoised_2
            x = (sigma_fn(t_next_) / sigma_fn(t)) * x - (t - t_next_).expm1() * denoised_d
            x = x + noise_sampler(sigma_fn(t), sigma_fn(t_next)) * s_noise * su
    return x


@torch.no_grad()
def sample_dpmpp_2m(model, x, sigmas, extra_args=None, callback=None, disable=None):
    """DPM-Solver++(2M)."""
    extra_args = {} if extra_args is None else extra_args
    s_in = x.new_ones([x.shape[0]])
    sigma_fn = lambda t: t.neg().exp()
    t_fn = lambda sigma: sigma.log().neg()
    old_denoised = None

    for i in trange(len(sigmas) - 1, disable=disable):
        denoised = model(x, sigmas[i] * s_in, **extra_args)
        if callback is not None:
            callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
        t, t_next = t_fn(sigmas[i]), t_fn(sigmas[i + 1])
        h = t_next - t
        if old_denoised is None or sigmas[i + 1] == 0:
            x = (sigma_fn(t_next) / sigma_fn(t)) * x - (-h).expm1() * denoised
        else:
            h_last = t - t_fn(sigmas[i - 1])
            r = h_last / h
            denoised_d = (1 + 1 / (2 * r)) * denoised - (1 / (2 * r)) * old_denoised
            x = (sigma_fn(t_next) / sigma_fn(t)) * x - (-h).expm1() * denoised_d
        old_denoised = denoised
    return x

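# The fixed-step samplers above share this calling convention, so they are
# drop-in replacements for one another; continuing the sample_euler sketch:
#
#     sample = sample_dpmpp_2m(denoiser, x, sigmas, extra_args={'cond': c})
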
@torch.no_grad()
def sample_dpmpp_2m_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, solver_type='midpoint'):
    """DPM-Solver++(2M) SDE."""

    if solver_type not in {'heun', 'midpoint'}:
        raise ValueError('solver_type must be \'heun\' or \'midpoint\'')

    sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
    noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max) if noise_sampler is None else noise_sampler
    extra_args = {} if extra_args is None else extra_args
    s_in = x.new_ones([x.shape[0]])

    old_denoised = None
    h_last = None

    for i in trange(len(sigmas) - 1, disable=disable):
        denoised = model(x, sigmas[i] * s_in, **extra_args)
        if callback is not None:
            callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
        if sigmas[i + 1] == 0:
            # Denoising step
            x = denoised
        else:
            # DPM-Solver++(2M) SDE
            t, s = -sigmas[i].log(), -sigmas[i + 1].log()
            h = s - t
            eta_h = eta * h

            x = sigmas[i + 1] / sigmas[i] * (-eta_h).exp() * x + (-h - eta_h).expm1().neg() * denoised

            if old_denoised is not None:
                r = h_last / h
                if solver_type == 'heun':
                    x = x + ((-h - eta_h).expm1().neg() / (-h - eta_h) + 1) * (1 / r) * (denoised - old_denoised)
                elif solver_type == 'midpoint':
                    x = x + 0.5 * (-h - eta_h).expm1().neg() * (1 / r) * (denoised - old_denoised)

            if eta:
                x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * sigmas[i + 1] * (-2 * eta_h).expm1().neg().sqrt() * s_noise

        old_denoised = denoised
        h_last = h
    return x


@torch.no_grad()
def sample_dpmpp_3m_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
    """DPM-Solver++(3M) SDE."""

    sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
    noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max) if noise_sampler is None else noise_sampler
    extra_args = {} if extra_args is None else extra_args
    s_in = x.new_ones([x.shape[0]])

    denoised_1, denoised_2 = None, None
    h_1, h_2 = None, None

    for i in trange(len(sigmas) - 1, disable=disable):
        denoised = model(x, sigmas[i] * s_in, **extra_args)
        if callback is not None:
            callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
        if sigmas[i + 1] == 0:
            # Denoising step
            x = denoised
        else:
            t, s = -sigmas[i].log(), -sigmas[i + 1].log()
            h = s - t
            h_eta = h * (eta + 1)

            x = torch.exp(-h_eta) * x + (-h_eta).expm1().neg() * denoised

            if h_2 is not None:
                r0 = h_1 / h
                r1 = h_2 / h
                d1_0 = (denoised - denoised_1) / r0
                d1_1 = (denoised_1 - denoised_2) / r1
                d1 = d1_0 + (d1_0 - d1_1) * r0 / (r0 + r1)
                d2 = (d1_0 - d1_1) / (r0 + r1)
                phi_2 = h_eta.neg().expm1() / h_eta + 1
                phi_3 = phi_2 / h_eta - 0.5
                x = x + phi_2 * d1 - phi_3 * d2
            elif h_1 is not None:
                r = h_1 / h
                d = (denoised - denoised_1) / r
                phi_2 = h_eta.neg().expm1() / h_eta + 1
                x = x + phi_2 * d

            if eta:
                x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * sigmas[i + 1] * (-2 * h * eta).expm1().neg().sqrt() * s_noise

        denoised_1, denoised_2 = denoised, denoised_1
        h_1, h_2 = h, h_1
    return x
458
k_diffusion/utils.py
Normal file
@@ -0,0 +1,458 @@
from contextlib import contextmanager
import hashlib
import math
from pathlib import Path
import shutil
import threading
import time
import urllib
import warnings

from PIL import Image
import safetensors
import torch
from torch import nn, optim
from torch.utils import data
from torchvision.transforms import functional as TF


def from_pil_image(x):
    """Converts from a PIL image to a tensor."""
    x = TF.to_tensor(x)
    if x.ndim == 2:
        x = x[..., None]
    return x * 2 - 1


def to_pil_image(x):
    """Converts from a tensor to a PIL image."""
    if x.ndim == 4:
        assert x.shape[0] == 1
        x = x[0]
    if x.shape[0] == 1:
        x = x[0]
    return TF.to_pil_image((x.clamp(-1, 1) + 1) / 2)


def hf_datasets_augs_helper(examples, transform, image_key, mode='RGB'):
    """Apply passed in transforms for HuggingFace Datasets."""
    images = [transform(image.convert(mode)) for image in examples[image_key]]
    return {image_key: images}


def append_dims(x, target_dims):
    """Appends dimensions to the end of a tensor until it has target_dims dimensions."""
    dims_to_append = target_dims - x.ndim
    if dims_to_append < 0:
        raise ValueError(f'input has {x.ndim} dims but target_dims is {target_dims}, which is less')
    return x[(...,) + (None,) * dims_to_append]

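# A minimal broadcasting sketch: append_dims pads trailing singleton dimensions
# so a per-sample sigma can scale a whole batch.
#
#     sigma = torch.tensor([1.0, 2.0])           # shape [2]
#     x = torch.randn(2, 3, 8, 8)
#     noised = x * append_dims(sigma, x.ndim)    # sigma viewed as [2, 1, 1, 1]
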
def n_params(module):
    """Returns the number of trainable parameters in a module."""
    return sum(p.numel() for p in module.parameters())


def download_file(path, url, digest=None):
    """Downloads a file if it does not exist, optionally checking its SHA-256 hash."""
    path = Path(path)
    path.parent.mkdir(parents=True, exist_ok=True)
    if not path.exists():
        with urllib.request.urlopen(url) as response, open(path, 'wb') as f:
            shutil.copyfileobj(response, f)
    if digest is not None:
        file_digest = hashlib.sha256(open(path, 'rb').read()).hexdigest()
        if digest != file_digest:
            raise OSError(f'hash of {path} (url: {url}) failed to validate')
    return path


@contextmanager
def train_mode(model, mode=True):
    """A context manager that places a model into training mode and restores
    the previous mode on exit."""
    modes = [module.training for module in model.modules()]
    try:
        yield model.train(mode)
    finally:
        for i, module in enumerate(model.modules()):
            module.training = modes[i]


def eval_mode(model):
    """A context manager that places a model into evaluation mode and restores
    the previous mode on exit."""
    return train_mode(model, False)


@torch.no_grad()
def ema_update(model, averaged_model, decay):
    """Incorporates updated model parameters into an exponential moving averaged
    version of a model. It should be called after each optimizer step."""
    model_params = dict(model.named_parameters())
    averaged_params = dict(averaged_model.named_parameters())
    assert model_params.keys() == averaged_params.keys()

    for name, param in model_params.items():
        averaged_params[name].lerp_(param, 1 - decay)

    model_buffers = dict(model.named_buffers())
    averaged_buffers = dict(averaged_model.named_buffers())
    assert model_buffers.keys() == averaged_buffers.keys()

    for name, buf in model_buffers.items():
        averaged_buffers[name].copy_(buf)


class EMAWarmup:
    """Implements an EMA warmup using an inverse decay schedule.
    If inv_gamma=1 and power=1, implements a simple average. inv_gamma=1, power=2/3 are
    good values for models you plan to train for a million or more steps (reaches decay
    factor 0.999 at 31.6K steps, 0.9999 at 1M steps), inv_gamma=1, power=3/4 for models
    you plan to train for less (reaches decay factor 0.999 at 10K steps, 0.9999 at
    215.4k steps).
    Args:
        inv_gamma (float): Inverse multiplicative factor of EMA warmup. Default: 1.
        power (float): Exponential factor of EMA warmup. Default: 1.
        min_value (float): The minimum EMA decay rate. Default: 0.
        max_value (float): The maximum EMA decay rate. Default: 1.
        start_at (int): The epoch to start averaging at. Default: 0.
        last_epoch (int): The index of last epoch. Default: 0.
    """

    def __init__(self, inv_gamma=1., power=1., min_value=0., max_value=1., start_at=0,
                 last_epoch=0):
        self.inv_gamma = inv_gamma
        self.power = power
        self.min_value = min_value
        self.max_value = max_value
        self.start_at = start_at
        self.last_epoch = last_epoch

    def state_dict(self):
        """Returns the state of the class as a :class:`dict`."""
        return dict(self.__dict__.items())

    def load_state_dict(self, state_dict):
        """Loads the class's state.
        Args:
            state_dict (dict): scaler state. Should be an object returned
                from a call to :meth:`state_dict`.
        """
        self.__dict__.update(state_dict)

    def get_value(self):
        """Gets the current EMA decay rate."""
        epoch = max(0, self.last_epoch - self.start_at)
        value = 1 - (1 + epoch / self.inv_gamma) ** -self.power
        return 0. if epoch < 0 else min(self.max_value, max(self.min_value, value))

    def step(self):
        """Updates the step count."""
        self.last_epoch += 1

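# A minimal training-loop sketch, assuming `model_ema` starts as a copy of
# `model` and the optimizer step has already run each iteration (`num_steps`
# is illustrative):
#
#     ema_sched = EMAWarmup(power=3 / 4, max_value=0.9999)
#     for step in range(num_steps):
#         ...                                    # forward/backward, optimizer.step()
#         ema_update(model, model_ema, ema_sched.get_value())
#         ema_sched.step()
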
class InverseLR(optim.lr_scheduler._LRScheduler):
    """Implements an inverse decay learning rate schedule with an optional exponential
    warmup. When last_epoch=-1, sets initial lr as lr.
    inv_gamma is the number of steps/epochs required for the learning rate to decay to
    (1 / 2)**power of its original value.
    Args:
        optimizer (Optimizer): Wrapped optimizer.
        inv_gamma (float): Inverse multiplicative factor of learning rate decay. Default: 1.
        power (float): Exponential factor of learning rate decay. Default: 1.
        warmup (float): Exponential warmup factor (0 <= warmup < 1, 0 to disable)
            Default: 0.
        min_lr (float): The minimum learning rate. Default: 0.
        last_epoch (int): The index of last epoch. Default: -1.
        verbose (bool): If ``True``, prints a message to stdout for
            each update. Default: ``False``.
    """

    def __init__(self, optimizer, inv_gamma=1., power=1., warmup=0., min_lr=0.,
                 last_epoch=-1, verbose=False):
        self.inv_gamma = inv_gamma
        self.power = power
        if not 0. <= warmup < 1:
            raise ValueError('Invalid value for warmup')
        self.warmup = warmup
        self.min_lr = min_lr
        super().__init__(optimizer, last_epoch, verbose)

    def get_lr(self):
        if not self._get_lr_called_within_step:
            warnings.warn("To get the last learning rate computed by the scheduler, "
                          "please use `get_last_lr()`.")

        return self._get_closed_form_lr()

    def _get_closed_form_lr(self):
        warmup = 1 - self.warmup ** (self.last_epoch + 1)
        lr_mult = (1 + self.last_epoch / self.inv_gamma) ** -self.power
        return [warmup * max(self.min_lr, base_lr * lr_mult)
                for base_lr in self.base_lrs]


class ExponentialLR(optim.lr_scheduler._LRScheduler):
    """Implements an exponential learning rate schedule with an optional exponential
    warmup. When last_epoch=-1, sets initial lr as lr. Decays the learning rate
    continuously by decay (default 0.5) every num_steps steps.
    Args:
        optimizer (Optimizer): Wrapped optimizer.
        num_steps (float): The number of steps to decay the learning rate by decay in.
        decay (float): The factor by which to decay the learning rate every num_steps
            steps. Default: 0.5.
        warmup (float): Exponential warmup factor (0 <= warmup < 1, 0 to disable)
            Default: 0.
        min_lr (float): The minimum learning rate. Default: 0.
        last_epoch (int): The index of last epoch. Default: -1.
        verbose (bool): If ``True``, prints a message to stdout for
            each update. Default: ``False``.
    """

    def __init__(self, optimizer, num_steps, decay=0.5, warmup=0., min_lr=0.,
                 last_epoch=-1, verbose=False):
        self.num_steps = num_steps
        self.decay = decay
        if not 0. <= warmup < 1:
            raise ValueError('Invalid value for warmup')
        self.warmup = warmup
        self.min_lr = min_lr
        super().__init__(optimizer, last_epoch, verbose)

    def get_lr(self):
        if not self._get_lr_called_within_step:
            warnings.warn("To get the last learning rate computed by the scheduler, "
                          "please use `get_last_lr()`.")

        return self._get_closed_form_lr()

    def _get_closed_form_lr(self):
        warmup = 1 - self.warmup ** (self.last_epoch + 1)
        lr_mult = (self.decay ** (1 / self.num_steps)) ** self.last_epoch
        return [warmup * max(self.min_lr, base_lr * lr_mult)
                for base_lr in self.base_lrs]


class ConstantLRWithWarmup(optim.lr_scheduler._LRScheduler):
    """Implements a constant learning rate schedule with an optional exponential
    warmup. When last_epoch=-1, sets initial lr as lr.
    Args:
        optimizer (Optimizer): Wrapped optimizer.
        warmup (float): Exponential warmup factor (0 <= warmup < 1, 0 to disable)
            Default: 0.
        last_epoch (int): The index of last epoch. Default: -1.
        verbose (bool): If ``True``, prints a message to stdout for
            each update. Default: ``False``.
    """

    def __init__(self, optimizer, warmup=0., last_epoch=-1, verbose=False):
        if not 0. <= warmup < 1:
            raise ValueError('Invalid value for warmup')
        self.warmup = warmup
        super().__init__(optimizer, last_epoch, verbose)

    def get_lr(self):
        if not self._get_lr_called_within_step:
            warnings.warn("To get the last learning rate computed by the scheduler, "
                          "please use `get_last_lr()`.")

        return self._get_closed_form_lr()

    def _get_closed_form_lr(self):
        warmup = 1 - self.warmup ** (self.last_epoch + 1)
        return [warmup * base_lr for base_lr in self.base_lrs]


def stratified_uniform(shape, group=0, groups=1, dtype=None, device=None):
    """Draws stratified samples from a uniform distribution."""
    if groups <= 0:
        raise ValueError(f"groups must be positive, got {groups}")
    if group < 0 or group >= groups:
        raise ValueError(f"group must be in [0, {groups})")
    n = shape[-1] * groups
    offsets = torch.arange(group, n, groups, dtype=dtype, device=device)
    u = torch.rand(shape, dtype=dtype, device=device)
    return (offsets + u) / n


stratified_settings = threading.local()


@contextmanager
def enable_stratified(group=0, groups=1, disable=False):
    """A context manager that enables stratified sampling."""
    try:
        stratified_settings.disable = disable
        stratified_settings.group = group
        stratified_settings.groups = groups
        yield
    finally:
        del stratified_settings.disable
        del stratified_settings.group
        del stratified_settings.groups


@contextmanager
def enable_stratified_accelerate(accelerator, disable=False):
    """A context manager that enables stratified sampling, distributing the strata across
    all processes and gradient accumulation steps using settings from Hugging Face Accelerate."""
    try:
        rank = accelerator.process_index
        world_size = accelerator.num_processes
        acc_steps = accelerator.gradient_state.num_steps
        acc_step = accelerator.step % acc_steps
        group = rank * acc_steps + acc_step
        groups = world_size * acc_steps
        with enable_stratified(group, groups, disable=disable):
            yield
    finally:
        pass


def stratified_with_settings(shape, dtype=None, device=None):
    """Draws stratified samples from a uniform distribution, using settings from a context
    manager."""
    if not hasattr(stratified_settings, 'disable') or stratified_settings.disable:
        return torch.rand(shape, dtype=dtype, device=device)
    return stratified_uniform(
        shape, stratified_settings.group, stratified_settings.groups, dtype=dtype, device=device
    )

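# A minimal stratification sketch: inside the context manager the rand_*
# helpers below draw stratified samples; outside it they fall back to plain
# torch.rand (loc/scale values are illustrative).
#
#     with enable_stratified(group=0, groups=1):
#         sigma = rand_log_normal([16], loc=-1.2, scale=1.2)
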
def rand_log_normal(shape, loc=0., scale=1., device='cpu', dtype=torch.float32):
|
||||
"""Draws samples from an lognormal distribution."""
|
||||
u = stratified_with_settings(shape, device=device, dtype=dtype) * (1 - 2e-7) + 1e-7
|
||||
return torch.distributions.Normal(loc, scale).icdf(u).exp()
|
||||
|
||||
|
||||
def rand_log_logistic(shape, loc=0., scale=1., min_value=0., max_value=float('inf'), device='cpu', dtype=torch.float32):
|
||||
"""Draws samples from an optionally truncated log-logistic distribution."""
|
||||
min_value = torch.as_tensor(min_value, device=device, dtype=torch.float64)
|
||||
max_value = torch.as_tensor(max_value, device=device, dtype=torch.float64)
|
||||
min_cdf = min_value.log().sub(loc).div(scale).sigmoid()
|
||||
max_cdf = max_value.log().sub(loc).div(scale).sigmoid()
|
||||
u = stratified_with_settings(shape, device=device, dtype=torch.float64) * (max_cdf - min_cdf) + min_cdf
|
||||
return u.logit().mul(scale).add(loc).exp().to(dtype)
|
||||
|
||||
|
||||
def rand_log_uniform(shape, min_value, max_value, device='cpu', dtype=torch.float32):
|
||||
"""Draws samples from an log-uniform distribution."""
|
||||
min_value = math.log(min_value)
|
||||
max_value = math.log(max_value)
|
||||
return (stratified_with_settings(shape, device=device, dtype=dtype) * (max_value - min_value) + min_value).exp()
|
||||
|
||||
|
||||
def rand_v_diffusion(shape, sigma_data=1., min_value=0., max_value=float('inf'), device='cpu', dtype=torch.float32):
|
||||
"""Draws samples from a truncated v-diffusion training timestep distribution."""
|
||||
min_cdf = math.atan(min_value / sigma_data) * 2 / math.pi
|
||||
max_cdf = math.atan(max_value / sigma_data) * 2 / math.pi
|
||||
u = stratified_with_settings(shape, device=device, dtype=dtype) * (max_cdf - min_cdf) + min_cdf
|
||||
return torch.tan(u * math.pi / 2) * sigma_data
|
||||
|
||||
|
||||
def rand_cosine_interpolated(shape, image_d, noise_d_low, noise_d_high, sigma_data=1., min_value=1e-3, max_value=1e3, device='cpu', dtype=torch.float32):
    """Draws samples from an interpolated cosine timestep distribution (from simple diffusion)."""

    def logsnr_schedule_cosine(t, logsnr_min, logsnr_max):
        t_min = math.atan(math.exp(-0.5 * logsnr_max))
        t_max = math.atan(math.exp(-0.5 * logsnr_min))
        return -2 * torch.log(torch.tan(t_min + t * (t_max - t_min)))

    def logsnr_schedule_cosine_shifted(t, image_d, noise_d, logsnr_min, logsnr_max):
        shift = 2 * math.log(noise_d / image_d)
        return logsnr_schedule_cosine(t, logsnr_min - shift, logsnr_max - shift) + shift

    def logsnr_schedule_cosine_interpolated(t, image_d, noise_d_low, noise_d_high, logsnr_min, logsnr_max):
        logsnr_low = logsnr_schedule_cosine_shifted(t, image_d, noise_d_low, logsnr_min, logsnr_max)
        logsnr_high = logsnr_schedule_cosine_shifted(t, image_d, noise_d_high, logsnr_min, logsnr_max)
        return torch.lerp(logsnr_low, logsnr_high, t)

    logsnr_min = -2 * math.log(min_value / sigma_data)
    logsnr_max = -2 * math.log(max_value / sigma_data)
    u = stratified_with_settings(shape, device=device, dtype=dtype)
    logsnr = logsnr_schedule_cosine_interpolated(u, image_d, noise_d_low, noise_d_high, logsnr_min, logsnr_max)
    return torch.exp(-logsnr / 2) * sigma_data


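For reference (not in the diff): the helper works in logSNR space with logsnr = -2 * log(sigma / sigma_data), and the final line inverts that mapping, so returned sigmas stay within [min_value, max_value]. The resolutions below are illustrative values only.

import k_diffusion.utils as utils

sigma = utils.rand_cosine_interpolated((10000,), image_d=64, noise_d_low=32, noise_d_high=64,
                                       sigma_data=1.0, min_value=1e-3, max_value=1e3)
print(sigma.min().item(), sigma.max().item())  # bounded by roughly 1e-3 and 1e3
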
def rand_split_log_normal(shape, loc, scale_1, scale_2, device='cpu', dtype=torch.float32):
    """Draws samples from a split lognormal distribution."""
    n = torch.randn(shape, device=device, dtype=dtype).abs()
    u = torch.rand(shape, device=device, dtype=dtype)
    n_left = n * -scale_1 + loc
    n_right = n * scale_2 + loc
    ratio = scale_1 / (scale_1 + scale_2)
    return torch.where(u < ratio, n_left, n_right).exp()


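Sketch only (not in the diff): the split lognormal mirrors a half-normal to the left of loc with width scale_1 or to the right with width scale_2, picking the left branch with probability scale_1 / (scale_1 + scale_2).

import k_diffusion.utils as utils

sigma = utils.rand_split_log_normal((100000,), loc=0.0, scale_1=0.5, scale_2=2.0)
below = (sigma.log() < 0.0).float().mean()
print(below.item())  # close to 0.5 / (0.5 + 2.0) = 0.2
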
class FolderOfImages(data.Dataset):
    """Recursively finds all images in a directory. It does not support
    classes/targets."""

    IMG_EXTENSIONS = {'.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif', '.tiff', '.webp'}

    def __init__(self, root, transform=None):
        super().__init__()
        self.root = Path(root)
        self.transform = nn.Identity() if transform is None else transform
        self.paths = sorted(path for path in self.root.rglob('*') if path.suffix.lower() in self.IMG_EXTENSIONS)

    def __repr__(self):
        return f'FolderOfImages(root="{self.root}", len: {len(self)})'

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, key):
        path = self.paths[key]
        with open(path, 'rb') as f:
            image = Image.open(f).convert('RGB')
        image = self.transform(image)
        return image,


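A usage sketch (not in the diff; the directory path and transform are made up, and torchvision is only used for the example transform). __getitem__ returns a 1-tuple, so unpack the batch accordingly.

from torch.utils import data
from torchvision import transforms
import k_diffusion.utils as utils

tf = transforms.Compose([transforms.Resize(512), transforms.CenterCrop(512), transforms.ToTensor()])
dataset = utils.FolderOfImages('/path/to/images', transform=tf)   # hypothetical directory
loader = data.DataLoader(dataset, batch_size=4, shuffle=True, num_workers=2)
for batch, in loader:      # note the trailing comma: each item is a 1-tuple
    print(batch.shape)     # e.g. torch.Size([4, 3, 512, 512])
    break
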
class CSVLogger:
    def __init__(self, filename, columns):
        self.filename = Path(filename)
        self.columns = columns
        if self.filename.exists():
            self.file = open(self.filename, 'a')
        else:
            self.file = open(self.filename, 'w')
            self.write(*self.columns)

    def write(self, *args):
        print(*args, sep=',', file=self.file, flush=True)


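A usage sketch (not in the diff; filename and columns are made up): the header row is only written when the file is first created, otherwise new rows are appended.

import k_diffusion.utils as utils

logger = utils.CSVLogger('run_metrics.csv', ['step', 'loss'])   # hypothetical log file
logger.write(1, 0.253)
logger.write(2, 0.241)
# Values are written verbatim with ',' separators, so avoid commas inside fields.
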
@contextmanager
def tf32_mode(cudnn=None, matmul=None):
    """A context manager that sets whether TF32 is allowed on cuDNN or matmul."""
    cudnn_old = torch.backends.cudnn.allow_tf32
    matmul_old = torch.backends.cuda.matmul.allow_tf32
    try:
        if cudnn is not None:
            torch.backends.cudnn.allow_tf32 = cudnn
        if matmul is not None:
            torch.backends.cuda.matmul.allow_tf32 = matmul
        yield
    finally:
        if cudnn is not None:
            torch.backends.cudnn.allow_tf32 = cudnn_old
        if matmul is not None:
            torch.backends.cuda.matmul.allow_tf32 = matmul_old


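A usage sketch (not in the diff): flags left as None are untouched, and the previous values are restored on exit even if the block raises.

import torch
import k_diffusion.utils as utils

print(torch.backends.cuda.matmul.allow_tf32)
with utils.tf32_mode(matmul=False):                  # temporarily force full-precision float32 matmuls
    print(torch.backends.cuda.matmul.allow_tf32)     # False inside the block
print(torch.backends.cuda.matmul.allow_tf32)         # restored on exit; the cuDNN flag is left alone
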
def get_safetensors_metadata(path):
    """Retrieves the metadata from a safetensors file."""
    return safetensors.safe_open(path, "pt").metadata()


def ema_update_dict(values, updates, decay):
    for k, v in updates.items():
        if k not in values:
            values[k] = v
        else:
            values[k] *= decay
            values[k] += (1 - decay) * v
    return values
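A usage sketch (not in the diff): the first update for a key is stored as-is, and later updates are blended into the running value with weight (1 - decay).

import k_diffusion.utils as utils

values = {}
utils.ema_update_dict(values, {'loss': 1.0}, decay=0.9)   # first update is taken verbatim
utils.ema_update_dict(values, {'loss': 0.0}, decay=0.9)   # afterwards: 0.9 * 1.0 + 0.1 * 0.0
print(values['loss'])                                     # 0.9
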
@ -393,14 +393,14 @@ def prepare_environment():
    assets_repo = os.environ.get('ASSETS_REPO', "https://github.com/AUTOMATIC1111/stable-diffusion-webui-assets.git")
    # stable_diffusion_repo = os.environ.get('STABLE_DIFFUSION_REPO', "https://github.com/Stability-AI/stablediffusion.git")
    # stable_diffusion_xl_repo = os.environ.get('STABLE_DIFFUSION_XL_REPO', "https://github.com/Stability-AI/generative-models.git")
    k_diffusion_repo = os.environ.get('K_DIFFUSION_REPO', 'https://github.com/crowsonkb/k-diffusion.git')
    # k_diffusion_repo = os.environ.get('K_DIFFUSION_REPO', 'https://github.com/crowsonkb/k-diffusion.git')
    huggingface_guess_repo = os.environ.get('HUGGINGFACE_GUESS_REPO', 'https://github.com/lllyasviel/huggingface_guess.git')
    blip_repo = os.environ.get('BLIP_REPO', 'https://github.com/salesforce/BLIP.git')

    assets_commit_hash = os.environ.get('ASSETS_COMMIT_HASH', "6f7db241d2f8ba7457bac5ca9753331f0c266917")
    # stable_diffusion_commit_hash = os.environ.get('STABLE_DIFFUSION_COMMIT_HASH', "cf1d67a6fd5ea1aa600c4df58e5b47da45f6bdbf")
    # stable_diffusion_xl_commit_hash = os.environ.get('STABLE_DIFFUSION_XL_COMMIT_HASH', "45c443b316737a4ab6e40413d7794a7f5657c19f")
    k_diffusion_commit_hash = os.environ.get('K_DIFFUSION_COMMIT_HASH', "ab527a9a6d347f364e3d185ba6d714e22d80cb3c")
    # k_diffusion_commit_hash = os.environ.get('K_DIFFUSION_COMMIT_HASH', "ab527a9a6d347f364e3d185ba6d714e22d80cb3c")
    huggingface_guess_commit_hash = os.environ.get('HUGGINGFACE_GUESS_HASH', "3f96b28763515dbe609792135df3615a440c66dc")
    blip_commit_hash = os.environ.get('BLIP_COMMIT_HASH', "48211a1594f1321b00f14c9f7a5b4813144b2fb9")

@ -458,7 +458,7 @@ def prepare_environment():
    git_clone(assets_repo, repo_dir('stable-diffusion-webui-assets'), "assets", assets_commit_hash)
    # git_clone(stable_diffusion_repo, repo_dir('stable-diffusion-stability-ai'), "Stable Diffusion", stable_diffusion_commit_hash)
    # git_clone(stable_diffusion_xl_repo, repo_dir('generative-models'), "Stable Diffusion XL", stable_diffusion_xl_commit_hash)
    git_clone(k_diffusion_repo, repo_dir('k-diffusion'), "K-diffusion", k_diffusion_commit_hash)
    # git_clone(k_diffusion_repo, repo_dir('k-diffusion'), "K-diffusion", k_diffusion_commit_hash)
    git_clone(huggingface_guess_repo, repo_dir('huggingface_guess'), "huggingface_guess", huggingface_guess_commit_hash)
    git_clone(blip_repo, repo_dir('BLIP'), "BLIP", blip_commit_hash)


@ -9,7 +9,7 @@ sd_path = os.path.dirname(__file__)

path_dirs = [
    (os.path.join(sd_path, '../repositories/BLIP'), 'models/blip.py', 'BLIP', []),
    (os.path.join(sd_path, '../repositories/k-diffusion'), 'k_diffusion/sampling.py', 'k_diffusion', ["atstart"]),
    # (os.path.join(sd_path, '../repositories/k-diffusion'), 'k_diffusion/sampling.py', 'k_diffusion', ["atstart"]),
    (os.path.join(sd_path, '../repositories/huggingface_guess'), 'huggingface_guess/detection.py', 'huggingface_guess', []),
]


@ -59,6 +59,9 @@ class CFGDenoiser(torch.nn.Module):
        self.model_wrap = None
        self.p = None

        self.need_last_noise_uncond = False
        self.last_noise_uncond = None

        # Backward Compatibility
        self.mask_before_denoising = False

@ -179,7 +182,10 @@ class CFGDenoiser(torch.nn.Module):
        denoiser_params = CFGDenoiserParams(x, image_cond, sigma, state.sampling_step, state.sampling_steps, cond, uncond, self)
        cfg_denoiser_callback(denoiser_params)

        denoised = sampling_function(self, denoiser_params=denoiser_params, cond_scale=cond_scale, cond_composition=cond_composition)
        denoised, cond_pred, uncond_pred = sampling_function(self, denoiser_params=denoiser_params, cond_scale=cond_scale, cond_composition=cond_composition)

        if self.need_last_noise_uncond:
            self.last_noise_uncond = (x - uncond_pred) / sigma[:, None, None, None]

        if self.mask is not None:
            blended_latent = denoised * self.nmask + self.init_latent * self.mask

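For reference (not part of the diff): with sigma-parameterized denoisers the noisy latent decomposes as x = denoised + sigma * eps, so the new last_noise_uncond field is the eps estimate implied by the unconditional prediction. A self-contained check with synthetic tensors:

import torch

x = torch.randn(2, 4, 64, 64)
eps = torch.randn_like(x)
sigma = torch.tensor([0.5, 2.0])
uncond_pred = x - sigma[:, None, None, None] * eps        # what a sigma-parameterized model would return
print(torch.allclose((x - uncond_pred) / sigma[:, None, None, None], eps))   # True
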
@ -1,6 +1,7 @@
import torch
import inspect
import k_diffusion.sampling
import k_diffusion.external
from modules import sd_samplers_common, sd_samplers_extra, sd_samplers_cfg_denoiser, sd_schedulers, devices
from modules.sd_samplers_cfg_denoiser import CFGDenoiser  # noqa: F401
from modules.script_callbacks import ExtraNoiseParams, extra_noise_callback
@ -55,13 +56,11 @@ class CFGDenoiserKDiffusion(sd_samplers_cfg_denoiser.CFGDenoiser):
    @property
    def inner_model(self):
        if self.model_wrap is None:
            denoiser_constructor = getattr(shared.sd_model, 'create_denoiser', None)

            if denoiser_constructor is not None:
                self.model_wrap = denoiser_constructor()
            else:
                denoiser = k_diffusion.external.CompVisVDenoiser if shared.sd_model.parameterization == "v" else k_diffusion.external.CompVisDenoiser
                self.model_wrap = denoiser(shared.sd_model, quantize=shared.opts.enable_quantization)
                self.model_wrap = k_diffusion.external.DiscreteSchedule(
                    sigmas=shared.sd_model.forge_objects.unet.model.predictor.sigmas,
                    quantize=shared.opts.enable_quantization
                )
                self.model_wrap.inner_model = shared.sd_model

        return self.model_wrap


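Background (not part of the diff): DiscreteSchedule only needs the discrete sigma table, so the wrapper no longer has to read a parameterization or alphas_cumprod off the model, and inner_model is attached by hand because downstream code still expects it. A standalone sketch with a made-up schedule, assuming the upstream k-diffusion DiscreteSchedule API (get_sigmas, sigma_to_t, t_to_sigma):

import math
import torch
import k_diffusion.external

sigmas = torch.exp(torch.linspace(math.log(0.03), math.log(14.6), 1000))   # made-up ascending sigma table
sched = k_diffusion.external.DiscreteSchedule(sigmas, quantize=True)
print(sched.get_sigmas(10).shape)                                # 11 values: 10 steps plus a trailing 0.0
print(sched.sigma_to_t(sched.t_to_sigma(torch.tensor(500.0))))   # round-trips to timestep 500
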
@ -13,9 +13,10 @@ class LCMCompVisDenoiser(DiscreteEpsDDPMDenoiser):
        original_timesteps = 50  # LCM Original Timesteps (default=50, for current version of LCM)
        self.skip_steps = timesteps // original_timesteps

        alphas_cumprod_valid = torch.zeros((original_timesteps), dtype=torch.float32)
        alphas_cumprod = 1.0 / (model.forge_objects.unet.model.predictor.sigmas ** 2.0 + 1.0)
        alphas_cumprod_valid = torch.zeros(original_timesteps, dtype=torch.float32)
        for x in range(original_timesteps):
            alphas_cumprod_valid[original_timesteps - 1 - x] = model.alphas_cumprod[timesteps - 1 - x * self.skip_steps]
            alphas_cumprod_valid[original_timesteps - 1 - x] = alphas_cumprod[timesteps - 1 - x * self.skip_steps]

        super().__init__(model, alphas_cumprod_valid, quantize=None)


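The recurring expression 1.0 / (sigmas ** 2.0 + 1.0) in this commit is the DDPM identity sigma_t = sqrt((1 - alphas_cumprod_t) / alphas_cumprod_t) solved for alphas_cumprod_t, so the cumulative alphas can be recovered directly from the predictor's sigma table. A quick numeric check (not in the diff):

import torch

alphas_cumprod = torch.linspace(0.9991, 0.0047, 1000)       # made-up DDPM-style schedule
sigmas = ((1 - alphas_cumprod) / alphas_cumprod) ** 0.5     # forward direction
recovered = 1.0 / (sigmas ** 2.0 + 1.0)                     # expression used throughout the diff
print(torch.allclose(recovered, alphas_cumprod, atol=1e-6))  # True
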
@ -28,31 +28,18 @@ class CompVisTimestepsDenoiser(torch.nn.Module):
    def __init__(self, model, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.inner_model = model
        self.inner_model.alphas_cumprod = 1.0 / (self.inner_model.forge_objects.unet.model.predictor.sigmas ** 2.0 + 1.0)

    def forward(self, input, timesteps, **kwargs):
        return self.inner_model.apply_model(input, timesteps, **kwargs)


class CompVisTimestepsVDenoiser(torch.nn.Module):
    def __init__(self, model, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.inner_model = model

    def predict_eps_from_z_and_v(self, x_t, t, v):
        return torch.sqrt(self.inner_model.alphas_cumprod)[t.to(torch.int), None, None, None] * v + torch.sqrt(1 - self.inner_model.alphas_cumprod)[t.to(torch.int), None, None, None] * x_t

    def forward(self, input, timesteps, **kwargs):
        model_output = self.inner_model.apply_model(input, timesteps, **kwargs)
        e_t = self.predict_eps_from_z_and_v(input, timesteps, model_output)
        return e_t


class CFGDenoiserTimesteps(CFGDenoiser):

    def __init__(self, sampler):
        super().__init__(sampler)

        self.alphas = shared.sd_model.alphas_cumprod
        self.alphas = 1.0 / (shared.sd_model.forge_objects.unet.model.predictor.sigmas ** 2.0 + 1.0)
        self.classic_ddim_eps_estimation = True

    def get_pred_x0(self, x_in, x_out, sigma):
@ -69,8 +56,7 @@ class CFGDenoiserTimesteps(CFGDenoiser):
    @property
    def inner_model(self):
        if self.model_wrap is None:
            denoiser = CompVisTimestepsVDenoiser if shared.sd_model.parameterization == "v" else CompVisTimestepsDenoiser
            self.model_wrap = denoiser(shared.sd_model)
            self.model_wrap = CompVisTimestepsDenoiser(shared.sd_model)

        return self.model_wrap