diff --git a/k_diffusion/sampling.py b/k_diffusion/sampling.py
index 0071cc87..257ad313 100644
--- a/k_diffusion/sampling.py
+++ b/k_diffusion/sampling.py
@@ -7,6 +7,7 @@ from torchdiffeq import odeint
 import torchsde
 from tqdm.auto import trange, tqdm
 from k_diffusion import deis
+from backend.modules.k_prediction import PredictionFlux
 from . import utils
 
 
@@ -139,6 +140,8 @@ def sample_euler(model, x, sigmas, extra_args=None, callback=None, disable=None,
 @torch.no_grad()
 def sample_euler_ancestral(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
     """Ancestral sampling with Euler method steps."""
+    if isinstance(model.inner_model.predictor, PredictionFlux):
+        return sample_euler_ancestral_RF(model, x, sigmas, extra_args, callback, disable, eta, s_noise, noise_sampler)
     extra_args = {} if extra_args is None else extra_args
     noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
     s_in = x.new_ones([x.shape[0]])
@@ -155,6 +158,32 @@ def sample_euler_ancestral(model, x, sigmas, extra_args=None, callback=None, dis
         x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * sigma_up
     return x
 
+@torch.no_grad()
+def sample_euler_ancestral_RF(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
+    """Ancestral sampling with Euler method steps (rectified-flow variant)."""
+    extra_args = {} if extra_args is None else extra_args
+    noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
+    s_in = x.new_ones([x.shape[0]])
+    for i in trange(len(sigmas) - 1, disable=disable):
+        denoised = model(x, sigmas[i] * s_in, **extra_args)
+        if callback is not None:
+            callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
+
+        if sigmas[i + 1] == 0:
+            x = denoised
+        else:
+            # Shrink the deterministic step to sigma_down by eta, then re-noise back up to sigmas[i + 1].
+            downstep_ratio = 1 + (sigmas[i + 1] / sigmas[i] - 1) * eta
+            sigma_down = sigmas[i + 1] * downstep_ratio
+            alpha_ip1 = 1 - sigmas[i + 1]
+            alpha_down = 1 - sigma_down
+            renoise_coeff = (sigmas[i + 1]**2 - sigma_down**2 * alpha_ip1**2 / alpha_down**2)**0.5
+            # Euler method
+            sigma_down_i_ratio = sigma_down / sigmas[i]
+            x = sigma_down_i_ratio * x + (1 - sigma_down_i_ratio) * denoised
+            if eta > 0:
+                x = (alpha_ip1 / alpha_down) * x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * renoise_coeff
+    return x
 
 @torch.no_grad()
 def sample_heun(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
@@ -219,6 +248,8 @@ def sample_dpm_2(model, x, sigmas, extra_args=None, callback=None, disable=None,
 @torch.no_grad()
 def sample_dpm_2_ancestral(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
     """Ancestral sampling with DPM-Solver second-order steps."""
+    if isinstance(model.inner_model.predictor, PredictionFlux):
+        return sample_dpm_2_ancestral_RF(model, x, sigmas, extra_args, callback, disable, eta, s_noise, noise_sampler)
     extra_args = {} if extra_args is None else extra_args
     noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
     s_in = x.new_ones([x.shape[0]])
@@ -244,6 +275,38 @@ def sample_dpm_2_ancestral(model, x, sigmas, extra_args=None, callback=None, dis
         x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * sigma_up
     return x
 
+@torch.no_grad()
+def sample_dpm_2_ancestral_RF(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
+    """Ancestral sampling with DPM-Solver second-order steps (rectified-flow variant)."""
+    extra_args = {} if extra_args is None else extra_args
+    noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
+    s_in = x.new_ones([x.shape[0]])
+    for i in trange(len(sigmas) - 1, disable=disable):
+        denoised = model(x, sigmas[i] * s_in, **extra_args)
+        downstep_ratio = 1 + (sigmas[i + 1] / sigmas[i] - 1) * eta
+        sigma_down = sigmas[i + 1] * downstep_ratio
+        alpha_ip1 = 1 - sigmas[i + 1]
+        alpha_down = 1 - sigma_down
+        renoise_coeff = (sigmas[i + 1]**2 - sigma_down**2 * alpha_ip1**2 / alpha_down**2)**0.5
+
+        if callback is not None:
+            callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
+        d = to_d(x, sigmas[i], denoised)
+        if sigma_down == 0:
+            # Euler method
+            dt = sigma_down - sigmas[i]
+            x = x + d * dt
+        else:
+            # DPM-Solver-2
+            sigma_mid = sigmas[i].log().lerp(sigma_down.log(), 0.5).exp()
+            dt_1 = sigma_mid - sigmas[i]
+            dt_2 = sigma_down - sigmas[i]
+            x_2 = x + d * dt_1
+            denoised_2 = model(x_2, sigma_mid * s_in, **extra_args)
+            d_2 = to_d(x_2, sigma_mid, denoised_2)
+            x = x + d_2 * dt_2
+            x = (alpha_ip1 / alpha_down) * x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * renoise_coeff
+    return x
 
 def linear_multistep_coeff(order, t, i, j):
     if order - 1 > i:
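For reviewers: the `renoise_coeff` used in both RF variants follows from variance matching under the rectified-flow marginal `x_t = (1 - sigma_t) * x_0 + sigma_t * eps`, i.e. `alpha_t = 1 - sigma_t`. After the deterministic step to `sigma_down`, rescaling by `alpha_ip1 / alpha_down` carries noise with standard deviation `sigma_down * alpha_ip1 / alpha_down` forward, so adding fresh noise with standard deviation `renoise_coeff` restores the total to `sigmas[i + 1]`. A minimal sketch of that identity (the helper `rf_ancestral_coeffs` is illustrative, not part of the patch):

```python
import torch

def rf_ancestral_coeffs(sigma, sigma_next, eta=1.0):
    """Split an RF ancestral step into a deterministic down-step and a re-noise coefficient."""
    downstep_ratio = 1 + (sigma_next / sigma - 1) * eta
    sigma_down = sigma_next * downstep_ratio
    alpha_next, alpha_down = 1 - sigma_next, 1 - sigma_down
    renoise_coeff = (sigma_next**2 - sigma_down**2 * alpha_next**2 / alpha_down**2) ** 0.5
    return sigma_down, renoise_coeff

sigma, sigma_next = torch.tensor(0.8), torch.tensor(0.6)
sigma_down, renoise_coeff = rf_ancestral_coeffs(sigma, sigma_next)
# Noise carried through the rescaled down-step plus fresh noise should total sigma_next.
carried = sigma_down * (1 - sigma_next) / (1 - sigma_down)
assert torch.allclose((carried**2 + renoise_coeff**2) ** 0.5, sigma_next)
```

Note that with `eta = 0` the down-step collapses to `sigma_down = sigmas[i + 1]`, `renoise_coeff = 0`, and both samplers reduce to their deterministic Euler / DPM-Solver-2 steps.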