Multithreading softinpainting (#2927)

parent 963e7643f0
commit 715c24b0e2
@@ -5,6 +5,9 @@ from modules.ui_components import InputAccordion
 import modules.scripts as scripts
 from modules.torch_utils import float64
 
+from concurrent.futures import ThreadPoolExecutor
+from scipy.ndimage import convolve
+from joblib import Parallel, delayed, cpu_count
 
 class SoftInpaintingSettings:
     def __init__(self,
@@ -244,7 +247,76 @@ def apply_masks(
     return masks_for_overlay
 
 
-def weighted_histogram_filter(img, kernel, kernel_center, percentile_min=0.0, percentile_max=1.0, min_width=1.0):
+def weighted_histogram_filter_single_pixel(idx, img, kernel, kernel_center, percentile_min, percentile_max, min_width):
+    """
+    Apply the weighted histogram filter to a single pixel.
+    This function is now refactored to be accessible for parallelization.
+    """
+    idx = np.array(idx)
+    kernel_min = -kernel_center
+    kernel_max = np.array(kernel.shape) - kernel_center
+
+    # Precompute the minimum and maximum valid indices for the kernel
+    min_index = np.maximum(0, idx + kernel_min)
+    max_index = np.minimum(np.array(img.shape), idx + kernel_max)
+    window_shape = max_index - min_index
+
+    # Initialize values and weights arrays
+    values = []
+    weights = []
+
+    for window_tup in np.ndindex(*window_shape):
+        window_index = np.array(window_tup)
+        image_index = window_index + min_index
+        centered_kernel_index = image_index - idx
+        kernel_index = centered_kernel_index + kernel_center
+        values.append(img[tuple(image_index)])
+        weights.append(kernel[tuple(kernel_index)])
+
+    # Convert to NumPy arrays
+    values = np.array(values)
+    weights = np.array(weights)
+
+    # Sort values and weights by values
+    sorted_indices = np.argsort(values)
+    values = values[sorted_indices]
+    weights = weights[sorted_indices]
+
+    # Calculate cumulative weights
+    cumulative_weights = np.cumsum(weights)
+
+    # Define window boundaries
+    sum_weights = cumulative_weights[-1]
+    window_min = sum_weights * percentile_min
+    window_max = sum_weights * percentile_max
+    window_width = window_max - window_min
+
+    # Ensure window is at least `min_width` wide
+    if window_width < min_width:
+        window_center = (window_min + window_max) / 2
+        window_min = window_center - min_width / 2
+        window_max = window_center + min_width / 2
+
+        if window_max > sum_weights:
+            window_max = sum_weights
+            window_min = sum_weights - min_width
+
+        if window_min < 0:
+            window_min = 0
+            window_max = min_width
+
+    # Calculate overlap for each value
+    overlap_start = np.maximum(window_min, np.concatenate(([0], cumulative_weights[:-1])))
+    overlap_end = np.minimum(window_max, cumulative_weights)
+    overlap = np.maximum(0, overlap_end - overlap_start)
+
+    # Weighted average calculation
+    result = np.sum(values * overlap) / np.sum(overlap) if np.sum(overlap) > 0 else 0
+    return result
+
+
+def weighted_histogram_filter(img, kernel, kernel_center, percentile_min=0.0, percentile_max=1.0, min_width=1.0, n_jobs=-1):
     """
     Generalization convolution filter capable of applying
     weighted mean, median, maximum, and minimum filters
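A self-contained sketch (not part of the commit) of the percentile-window logic that weighted_histogram_filter_single_pixel applies at each pixel: sort the window's values, stack their weights, and average the values whose weight intervals overlap the [percentile_min, percentile_max] band. The function name and test data below are illustrative only.

    import numpy as np

    def percentile_window_average(values, weights, p_min, p_max, min_width=1.0):
        # Sort values and carry their weights along, as the commit does.
        order = np.argsort(values)
        values = np.asarray(values, dtype=float)[order]
        weights = np.asarray(weights, dtype=float)[order]
        cum = np.cumsum(weights)
        total = cum[-1]
        lo, hi = total * p_min, total * p_max
        if hi - lo < min_width:  # widen the window, clamped to [0, total]
            center = (lo + hi) / 2
            lo = max(0.0, center - min_width / 2)
            hi = min(total, center + min_width / 2)
        # Each sorted value occupies the weight interval [starts[i], cum[i]].
        starts = np.concatenate(([0.0], cum[:-1]))
        overlap = np.maximum(0.0, np.minimum(hi, cum) - np.maximum(lo, starts))
        return np.sum(values * overlap) / np.sum(overlap) if np.sum(overlap) > 0 else 0.0

    # p_min = p_max = 0.5 behaves like a weighted median:
    print(percentile_window_average([3.0, 1.0, 2.0], [1.0, 1.0, 1.0], 0.5, 0.5))  # -> 2.0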
@@ -271,101 +343,74 @@ def weighted_histogram_filter(img, kernel, kernel_center, percentile_min=0.0, percentile_max=1.0, min_width=1.0):
         (nparray): A filtered copy of the input image "img", a 2-D array of floats.
     """
 
-    # Converts an index tuple into a vector.
-    def vec(x):
-        return np.array(x)
-
-    kernel_min = -kernel_center
-    kernel_max = vec(kernel.shape) - kernel_center
-
-    def weighted_histogram_filter_single(idx):
-        idx = vec(idx)
-        min_index = np.maximum(0, idx + kernel_min)
-        max_index = np.minimum(vec(img.shape), idx + kernel_max)
-        window_shape = max_index - min_index
-
-        class WeightedElement:
-            """
-            An element of the histogram, its weight
-            and bounds.
-            """
-            def __init__(self, value, weight):
-                self.value: float = value
-                self.weight: float = weight
-                self.window_min: float = 0.0
-                self.window_max: float = 1.0
-
-        # Collect the values in the image as WeightedElements,
-        # weighted by their corresponding kernel values.
-        values = []
-        for window_tup in np.ndindex(tuple(window_shape)):
-            window_index = vec(window_tup)
-            image_index = window_index + min_index
-            centered_kernel_index = image_index - idx
-            kernel_index = centered_kernel_index + kernel_center
-            element = WeightedElement(img[tuple(image_index)], kernel[tuple(kernel_index)])
-            values.append(element)
-
-        def sort_key(x: WeightedElement):
-            return x.value
-
-        values.sort(key=sort_key)
-
-        # Calculate the height of the stack (sum)
-        # and each sample's range they occupy in the stack
-        sum = 0
-        for i in range(len(values)):
-            values[i].window_min = sum
-            sum += values[i].weight
-            values[i].window_max = sum
-
-        # Calculate what range of this stack ("window")
-        # we want to get the weighted average across.
-        window_min = sum * percentile_min
-        window_max = sum * percentile_max
-        window_width = window_max - window_min
-
-        # Ensure the window is within the stack and at least a certain size.
-        if window_width < min_width:
-            window_center = (window_min + window_max) / 2
-            window_min = window_center - min_width / 2
-            window_max = window_center + min_width / 2
-
-            if window_max > sum:
-                window_max = sum
-                window_min = sum - min_width
-
-            if window_min < 0:
-                window_min = 0
-                window_max = min_width
-
-        value = 0
-        value_weight = 0
-
-        # Get the weighted average of all the samples
-        # that overlap with the window, weighted
-        # by the size of their overlap.
-        for i in range(len(values)):
-            if window_min >= values[i].window_max:
-                continue
-            if window_max <= values[i].window_min:
-                break
-
-            s = max(window_min, values[i].window_min)
-            e = min(window_max, values[i].window_max)
-            w = e - s
-
-            value += values[i].value * w
-            value_weight += w
-
-        return value / value_weight if value_weight != 0 else 0
-
-    img_out = img.copy()
-
-    # Apply the kernel operation over each pixel.
-    for index in np.ndindex(img.shape):
-        img_out[index] = weighted_histogram_filter_single(index)
+    # Ensure kernel_center is a 1D array
+    if isinstance(kernel_center, int):
+        kernel_center = np.array([kernel_center, kernel_center])
+    elif len(kernel_center) == 1:
+        kernel_center = np.array([kernel_center[0], kernel_center[0]])
+    kernel_radius = max(kernel_center)
+    padded_img = np.pad(img, kernel_radius, mode='constant', constant_values=0)
+    img_out = np.zeros_like(img)
+    img_shape = img.shape
+    pixel_coords = [(i, j) for i in range(img_shape[0]) for j in range(img_shape[1])]
+
+    def weighted_histogram_filter_single(idx):
+        """
+        Single-pixel weighted histogram calculation.
+        """
+        row, col = idx
+        idx = (row + kernel_radius, col + kernel_radius)
+        min_index = np.array(idx) - kernel_center
+        max_index = min_index + kernel.shape
+
+        window = padded_img[min_index[0]:max_index[0], min_index[1]:max_index[1]]
+        window_values = window.flatten()
+        window_weights = kernel.flatten()
+
+        sorted_indices = np.argsort(window_values)
+        values = window_values[sorted_indices]
+        weights = window_weights[sorted_indices]
+
+        cumulative_weights = np.cumsum(weights)
+        sum_weights = cumulative_weights[-1]
+        window_min = max(0, sum_weights * percentile_min)
+        window_max = min(sum_weights, sum_weights * percentile_max)
+        window_width = window_max - window_min
+
+        # Ensure the window is within the stack and at least `min_width` wide.
+        if window_width < min_width:
+            window_center = (window_min + window_max) / 2
+            window_min = max(0, window_center - min_width / 2)
+            window_max = min(sum_weights, window_center + min_width / 2)
+
+        overlap_start = np.maximum(window_min, np.concatenate(([0], cumulative_weights[:-1])))
+        overlap_end = np.minimum(window_max, cumulative_weights)
+        overlap = np.maximum(0, overlap_end - overlap_start)
+
+        return np.sum(values * overlap) / np.sum(overlap) if np.sum(overlap) > 0 else 0
+
+    # Split pixel_coords into equal chunks based on n_jobs.
+    # Resolve the default (-1) to the actual core count so the chunk math below works.
+    if n_jobs == -1:
+        n_jobs = cpu_count()
+    if n_jobs > 6:
+        n_jobs = 6  # More than 6 isn't worth it unless the image is larger than ~3000x3000 px
+
+    chunk_size = max(1, len(pixel_coords) // n_jobs)
+    pixel_chunks = [pixel_coords[i:i + chunk_size] for i in range(0, len(pixel_coords), chunk_size)]
+
+    # joblib to process chunks in parallel
+    def process_chunk(chunk):
+        chunk_result = {}
+        for idx in chunk:
+            chunk_result[idx] = weighted_histogram_filter_single(idx)
+        return chunk_result
+
+    results = Parallel(n_jobs=n_jobs, backend="loky")(  # loky is fastest in my configuration
+        delayed(process_chunk)(chunk) for chunk in pixel_chunks
+    )
+
+    # Combine results into the output image
+    for chunk_result in results:
+        for (row, col), value in chunk_result.items():
+            img_out[row, col] = value
 
     return img_out
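For reference, a standalone sketch (not from the commit) of the chunk-and-merge pattern the rewritten weighted_histogram_filter uses: split the pixel list into one chunk per worker, let joblib's loky backend process the chunks in parallel, then merge the per-chunk dictionaries. process_pixel below is a stand-in for the real per-pixel filter.

    from joblib import Parallel, delayed, cpu_count

    def process_pixel(idx):
        row, col = idx
        return row + col  # placeholder for the per-pixel histogram filter

    def process_chunk(chunk):
        # One dictionary of results per chunk, keyed by pixel coordinate.
        return {idx: process_pixel(idx) for idx in chunk}

    coords = [(i, j) for i in range(64) for j in range(64)]
    n_jobs = min(cpu_count(), 6)
    chunk_size = max(1, len(coords) // n_jobs)
    chunks = [coords[i:i + chunk_size] for i in range(0, len(coords), chunk_size)]

    results = Parallel(n_jobs=n_jobs, backend="loky")(delayed(process_chunk)(c) for c in chunks)
    merged = {idx: value for chunk_result in results for idx, value in chunk_result.items()}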
@@ -485,7 +530,7 @@ el_ids = SoftInpaintingSettings(
 
 class Script(scripts.Script):
     def __init__(self):
-        # self.section = "inpaint"
+        self.section = "inpaint"
         self.masks_for_overlay = None
        self.overlay_images = None
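A usage sketch: per its docstring, weighted_histogram_filter selects weighted mean, median, maximum, or minimum behavior through percentile_min and percentile_max. The import path below is an assumption; adjust it to wherever this module lives in the repository.

    import numpy as np
    from soft_inpainting import weighted_histogram_filter  # assumed module path

    img = np.random.rand(32, 32)
    kernel = np.ones((3, 3))
    center = np.array([1, 1])  # kernel center as a 1-D index vector

    mean_img   = weighted_histogram_filter(img, kernel, center, 0.0, 1.0)  # weighted mean
    median_img = weighted_histogram_filter(img, kernel, center, 0.5, 0.5)  # weighted median
    max_img    = weighted_histogram_filter(img, kernel, center, 1.0, 1.0)  # weighted max (dilate-like)
    min_img    = weighted_histogram_filter(img, kernel, center, 0.0, 0.0)  # weighted min (erode-like)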