Multithreading softinpainting (#2927)
Repository: https://github.com/lllyasviel/stable-diffusion-webui-forge.git
commit 715c24b0e2 (parent 963e7643f0)
@@ -5,6 +5,9 @@ from modules.ui_components import InputAccordion
 import modules.scripts as scripts
 from modules.torch_utils import float64
 
+from concurrent.futures import ThreadPoolExecutor
+from scipy.ndimage import convolve
+from joblib import Parallel, delayed, cpu_count
 
 class SoftInpaintingSettings:
     def __init__(self,
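The new imports set up the parallel path: joblib's `Parallel`, `delayed`, and `cpu_count` drive the chunked per-pixel work in the hunks below, while `ThreadPoolExecutor` and scipy's `convolve` are imported but do not appear to be used in the hunks shown. As a minimal sketch of the joblib pattern this commit adopts (not part of the diff; `square` is a made-up example function):

```python
from joblib import Parallel, delayed, cpu_count

def square(x):
    return x * x

# Parallel consumes a generator of delayed() calls, fans them out over
# worker processes (the "loky" backend), and returns results in order.
results = Parallel(n_jobs=min(cpu_count(), 6), backend="loky")(
    delayed(square)(x) for x in range(10)
)
print(results)  # [0, 1, 4, 9, 16, 25, 36, 49, 64, 81]
```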
@@ -244,7 +247,76 @@ def apply_masks(
     return masks_for_overlay
 
 
-def weighted_histogram_filter(img, kernel, kernel_center, percentile_min=0.0, percentile_max=1.0, min_width=1.0):
+def weighted_histogram_filter_single_pixel(idx, img, kernel, kernel_center, percentile_min, percentile_max, min_width):
+    """
+    Apply the weighted histogram filter to a single pixel.
+    This function is now refactored to be accessible for parallelization.
+    """
+
+    idx = np.array(idx)
+    kernel_min = -kernel_center
+    kernel_max = np.array(kernel.shape) - kernel_center
+
+    # Precompute the minimum and maximum valid indices for the kernel
+    min_index = np.maximum(0, idx + kernel_min)
+    max_index = np.minimum(np.array(img.shape), idx + kernel_max)
+    window_shape = max_index - min_index
+
+    # Initialize values and weights arrays
+    values = []
+    weights = []
+
+    for window_tup in np.ndindex(*window_shape):
+        window_index = np.array(window_tup)
+        image_index = window_index + min_index
+        centered_kernel_index = image_index - idx
+        kernel_index = centered_kernel_index + kernel_center
+        values.append(img[tuple(image_index)])
+        weights.append(kernel[tuple(kernel_index)])
+
+    # Convert to NumPy arrays
+    values = np.array(values)
+    weights = np.array(weights)
+
+    # Sort values and weights by values
+    sorted_indices = np.argsort(values)
+    values = values[sorted_indices]
+    weights = weights[sorted_indices]
+
+    # Calculate cumulative weights
+    cumulative_weights = np.cumsum(weights)
+
+    # Define window boundaries
+    sum_weights = cumulative_weights[-1]
+    window_min = sum_weights * percentile_min
+    window_max = sum_weights * percentile_max
+    window_width = window_max - window_min
+
+    # Ensure window is at least `min_width` wide
+    if window_width < min_width:
+        window_center = (window_min + window_max) / 2
+        window_min = window_center - min_width / 2
+        window_max = window_center + min_width / 2
+
+        if window_max > sum_weights:
+            window_max = sum_weights
+            window_min = sum_weights - min_width
+
+        if window_min < 0:
+            window_min = 0
+            window_max = min_width
+
+    # Calculate overlap for each value
+    overlap_start = np.maximum(window_min, np.concatenate(([0], cumulative_weights[:-1])))
+    overlap_end = np.minimum(window_max, cumulative_weights)
+    overlap = np.maximum(0, overlap_end - overlap_start)
+
+    # Weighted average calculation
+    result = np.sum(values * overlap) / np.sum(overlap) if np.sum(overlap) > 0 else 0
+    return result
+
+
+def weighted_histogram_filter(img, kernel, kernel_center, percentile_min=0.0, percentile_max=1.0, min_width=1.0, n_jobs=-1):
     """
     Generalization convolution filter capable of applying
     weighted mean, median, maximum, and minimum filters
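The core of `weighted_histogram_filter_single_pixel` is an overlap-weighted percentile: neighborhood values are sorted, each value occupies a span of the cumulative-weight "stack", and the result is the average of values weighted by how much of their span falls inside the `[percentile_min, percentile_max]` window. A self-contained toy walk-through of that step (values and weights are made up for illustration):

```python
import numpy as np

values = np.array([1.0, 2.0, 3.0, 4.0])
weights = np.array([1.0, 1.0, 1.0, 1.0])

order = np.argsort(values)
values, weights = values[order], weights[order]

cumulative = np.cumsum(weights)      # [1, 2, 3, 4]; each value spans one unit
total = cumulative[-1]

# A weighted median: percentile_min = percentile_max = 0.5, widened to
# min_width = 1.0 around the center of the stack -> window [1.5, 2.5].
window_min, window_max = total * 0.5 - 0.5, total * 0.5 + 0.5

starts = np.concatenate(([0.0], cumulative[:-1]))  # span start of each value
overlap = np.maximum(0.0, np.minimum(window_max, cumulative)
                     - np.maximum(window_min, starts))
# overlap = [0, 0.5, 0.5, 0]: the window straddles the two middle values
print(np.sum(values * overlap) / np.sum(overlap))  # 2.5, the median
```

With `percentile_min=0` and `percentile_max=1` the same machinery reduces to a plain weighted mean, which is why the docstring calls this a generalization of weighted mean, median, maximum, and minimum filters.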
@@ -271,101 +343,74 @@ def weighted_histogram_filter(img, kernel, kernel_center, percentile_min=0.0, pe
         (nparray): A filtered copy of the input image "img", a 2-D array of floats.
     """
 
-    # Converts an index tuple into a vector.
-    def vec(x):
-        return np.array(x)
-
-    kernel_min = -kernel_center
-    kernel_max = vec(kernel.shape) - kernel_center
+    # Ensure kernel_center is a 1D array
+    if isinstance(kernel_center, int):
+        kernel_center = np.array([kernel_center, kernel_center])
+    elif len(kernel_center) == 1:
+        kernel_center = np.array([kernel_center[0], kernel_center[0]])
+    kernel_radius = max(kernel_center)
+    padded_img = np.pad(img, kernel_radius, mode='constant', constant_values=0)
+    img_out = np.zeros_like(img)
+    img_shape = img.shape
+    pixel_coords = [(i, j) for i in range(img_shape[0]) for j in range(img_shape[1])]
 
     def weighted_histogram_filter_single(idx):
-        idx = vec(idx)
-        min_index = np.maximum(0, idx + kernel_min)
-        max_index = np.minimum(vec(img.shape), idx + kernel_max)
-        window_shape = max_index - min_index
+        """
+        Single-pixel weighted histogram calculation.
+        """
+        row, col = idx
+        idx = (row + kernel_radius, col + kernel_radius)
+        min_index = np.array(idx) - kernel_center
+        max_index = min_index + kernel.shape
 
-        class WeightedElement:
-            """
-            An element of the histogram, its weight
-            and bounds.
-            """
-            def __init__(self, value, weight):
-                self.value: float = value
-                self.weight: float = weight
-                self.window_min: float = 0.0
-                self.window_max: float = 1.0
+        window = padded_img[min_index[0]:max_index[0], min_index[1]:max_index[1]]
+        window_values = window.flatten()
+        window_weights = kernel.flatten()
 
-        # Collect the values in the image as WeightedElements,
-        # weighted by their corresponding kernel values.
-        values = []
-        for window_tup in np.ndindex(tuple(window_shape)):
-            window_index = vec(window_tup)
-            image_index = window_index + min_index
-            centered_kernel_index = image_index - idx
-            kernel_index = centered_kernel_index + kernel_center
-            element = WeightedElement(img[tuple(image_index)], kernel[tuple(kernel_index)])
-            values.append(element)
+        sorted_indices = np.argsort(window_values)
+        values = window_values[sorted_indices]
+        weights = window_weights[sorted_indices]
 
-        def sort_key(x: WeightedElement):
-            return x.value
-
-        values.sort(key=sort_key)
-
-        # Calculate the height of the stack (sum)
-        # and each sample's range they occupy in the stack
-        sum = 0
-        for i in range(len(values)):
-            values[i].window_min = sum
-            sum += values[i].weight
-            values[i].window_max = sum
-
-        # Calculate what range of this stack ("window")
-        # we want to get the weighted average across.
-        window_min = sum * percentile_min
-        window_max = sum * percentile_max
+        cumulative_weights = np.cumsum(weights)
+        sum_weights = cumulative_weights[-1]
+        window_min = max(0, sum_weights * percentile_min)
+        window_max = min(sum_weights, sum_weights * percentile_max)
         window_width = window_max - window_min
 
-        # Ensure the window is within the stack and at least a certain size.
         if window_width < min_width:
             window_center = (window_min + window_max) / 2
-            window_min = window_center - min_width / 2
-            window_max = window_center + min_width / 2
+            window_min = max(0, window_center - min_width / 2)
+            window_max = min(sum_weights, window_center + min_width / 2)
 
-            if window_max > sum:
-                window_max = sum
-                window_min = sum - min_width
+        overlap_start = np.maximum(window_min, np.concatenate(([0], cumulative_weights[:-1])))
+        overlap_end = np.minimum(window_max, cumulative_weights)
+        overlap = np.maximum(0, overlap_end - overlap_start)
 
-            if window_min < 0:
-                window_min = 0
-                window_max = min_width
+        return np.sum(values * overlap) / np.sum(overlap) if np.sum(overlap) > 0 else 0
 
-        value = 0
-        value_weight = 0
+    # Split pixel_coords into equal chunks based on n_jobs
+    n_jobs = -1
+    if cpu_count() > 6:
+        n_jobs = 6  # More than 6 isn't worth unless it's more than 3000x3000px
 
-        # Get the weighted average of all the samples
-        # that overlap with the window, weighted
-        # by the size of their overlap.
-        for i in range(len(values)):
-            if window_min >= values[i].window_max:
-                continue
-            if window_max <= values[i].window_min:
-                break
+    chunk_size = len(pixel_coords) // n_jobs
+    pixel_chunks = [pixel_coords[i:i + chunk_size] for i in range(0, len(pixel_coords), chunk_size)]
 
-            s = max(window_min, values[i].window_min)
-            e = min(window_max, values[i].window_max)
-            w = e - s
+    # joblib to process chunks in parallel
+    def process_chunk(chunk):
+        chunk_result = {}
+        for idx in chunk:
+            chunk_result[idx] = weighted_histogram_filter_single(idx)
+        return chunk_result
 
-            value += values[i].value * w
-            value_weight += w
+    results = Parallel(n_jobs=n_jobs, backend="loky")(  # loky is fastest in my configuration
+        delayed(process_chunk)(chunk) for chunk in pixel_chunks
+    )
 
-        return value / value_weight if value_weight != 0 else 0
+    # Combine results into the output image
+    for chunk_result in results:
+        for (row, col), value in chunk_result.items():
+            img_out[row, col] = value
 
-    img_out = img.copy()
-
-    # Apply the kernel operation over each pixel.
-    for index in np.ndindex(img.shape):
-        img_out[index] = weighted_histogram_filter_single(index)
-
     return img_out
 
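Two hedged review observations on this hunk. First, zero-padding via `np.pad(..., constant_values=0)` replaces the old behavior of clamping the window at image borders, so border pixels now see zeros in their histograms instead of a smaller window. Second, `n_jobs` is unconditionally reassigned inside the function, which shadows the new `n_jobs=-1` parameter; and when `cpu_count() <= 6`, `n_jobs` stays `-1`, so `chunk_size = len(pixel_coords) // n_jobs` goes negative, `range(0, len(pixel_coords), chunk_size)` yields no chunks, and `img_out` would come back all zeros. A small guard along these lines would respect the parameter and avoid the degenerate case (a sketch, not part of the commit; `resolve_jobs_and_chunks` is a hypothetical helper):

```python
from joblib import cpu_count

def resolve_jobs_and_chunks(pixel_coords, n_jobs=-1, cap=6):
    # Interpret n_jobs <= 0 as "use all cores", capped because per the
    # commit's own comment more than ~6 workers only pays off above
    # roughly 3000x3000 px.
    if n_jobs <= 0:
        n_jobs = min(cpu_count(), cap)
    chunk_size = max(1, len(pixel_coords) // n_jobs)
    chunks = [pixel_coords[i:i + chunk_size]
              for i in range(0, len(pixel_coords), chunk_size)]
    return n_jobs, chunks
```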
@@ -485,7 +530,7 @@ el_ids = SoftInpaintingSettings(
 
 class Script(scripts.Script):
     def __init__(self):
-        # self.section = "inpaint"
+        self.section = "inpaint"
         self.masks_for_overlay = None
         self.overlay_images = None
 