Multithreading softinpainting (#2927)

This commit is contained in:
NEON 2025-06-24 01:39:51 +03:00 committed by GitHub
parent 963e7643f0
commit 715c24b0e2
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -5,6 +5,9 @@ from modules.ui_components import InputAccordion
import modules.scripts as scripts import modules.scripts as scripts
from modules.torch_utils import float64 from modules.torch_utils import float64
from concurrent.futures import ThreadPoolExecutor
from scipy.ndimage import convolve
from joblib import Parallel, delayed, cpu_count
class SoftInpaintingSettings: class SoftInpaintingSettings:
def __init__(self, def __init__(self,
@ -244,7 +247,76 @@ def apply_masks(
return masks_for_overlay return masks_for_overlay
def weighted_histogram_filter(img, kernel, kernel_center, percentile_min=0.0, percentile_max=1.0, min_width=1.0):
def weighted_histogram_filter_single_pixel(idx, img, kernel, kernel_center, percentile_min, percentile_max, min_width):
"""
Apply the weighted histogram filter to a single pixel.
This function is now refactored to be accessible for parallelization.
"""
idx = np.array(idx)
kernel_min = -kernel_center
kernel_max = np.array(kernel.shape) - kernel_center
# Precompute the minimum and maximum valid indices for the kernel
min_index = np.maximum(0, idx + kernel_min)
max_index = np.minimum(np.array(img.shape), idx + kernel_max)
window_shape = max_index - min_index
# Initialize values and weights arrays
values = []
weights = []
for window_tup in np.ndindex(*window_shape):
window_index = np.array(window_tup)
image_index = window_index + min_index
centered_kernel_index = image_index - idx
kernel_index = centered_kernel_index + kernel_center
values.append(img[tuple(image_index)])
weights.append(kernel[tuple(kernel_index)])
# Convert to NumPy arrays
values = np.array(values)
weights = np.array(weights)
# Sort values and weights by values
sorted_indices = np.argsort(values)
values = values[sorted_indices]
weights = weights[sorted_indices]
# Calculate cumulative weights
cumulative_weights = np.cumsum(weights)
# Define window boundaries
sum_weights = cumulative_weights[-1]
window_min = sum_weights * percentile_min
window_max = sum_weights * percentile_max
window_width = window_max - window_min
# Ensure window is at least `min_width` wide
if window_width < min_width:
window_center = (window_min + window_max) / 2
window_min = window_center - min_width / 2
window_max = window_center + min_width / 2
if window_max > sum_weights:
window_max = sum_weights
window_min = sum_weights - min_width
if window_min < 0:
window_min = 0
window_max = min_width
# Calculate overlap for each value
overlap_start = np.maximum(window_min, np.concatenate(([0], cumulative_weights[:-1])))
overlap_end = np.minimum(window_max, cumulative_weights)
overlap = np.maximum(0, overlap_end - overlap_start)
# Weighted average calculation
result = np.sum(values * overlap) / np.sum(overlap) if np.sum(overlap) > 0 else 0
return result
def weighted_histogram_filter(img, kernel, kernel_center, percentile_min=0.0, percentile_max=1.0, min_width=1.0, n_jobs=-1):
""" """
Generalization convolution filter capable of applying Generalization convolution filter capable of applying
weighted mean, median, maximum, and minimum filters weighted mean, median, maximum, and minimum filters
@ -271,101 +343,74 @@ def weighted_histogram_filter(img, kernel, kernel_center, percentile_min=0.0, pe
(nparray): A filtered copy of the input image "img", a 2-D array of floats. (nparray): A filtered copy of the input image "img", a 2-D array of floats.
""" """
# Converts an index tuple into a vector. # Ensure kernel_center is a 1D array
def vec(x): if isinstance(kernel_center, int):
return np.array(x) kernel_center = np.array([kernel_center, kernel_center])
elif len(kernel_center) == 1:
kernel_min = -kernel_center kernel_center = np.array([kernel_center[0], kernel_center[0]])
kernel_max = vec(kernel.shape) - kernel_center kernel_radius = max(kernel_center)
padded_img = np.pad(img, kernel_radius, mode='constant', constant_values=0)
img_out = np.zeros_like(img)
img_shape = img.shape
pixel_coords = [(i, j) for i in range(img_shape[0]) for j in range(img_shape[1])]
def weighted_histogram_filter_single(idx): def weighted_histogram_filter_single(idx):
idx = vec(idx) """
min_index = np.maximum(0, idx + kernel_min) Single-pixel weighted histogram calculation.
max_index = np.minimum(vec(img.shape), idx + kernel_max) """
window_shape = max_index - min_index row, col = idx
idx = (row + kernel_radius, col + kernel_radius)
min_index = np.array(idx) - kernel_center
max_index = min_index + kernel.shape
class WeightedElement: window = padded_img[min_index[0]:max_index[0], min_index[1]:max_index[1]]
""" window_values = window.flatten()
An element of the histogram, its weight window_weights = kernel.flatten()
and bounds.
"""
def __init__(self, value, weight): sorted_indices = np.argsort(window_values)
self.value: float = value values = window_values[sorted_indices]
self.weight: float = weight weights = window_weights[sorted_indices]
self.window_min: float = 0.0
self.window_max: float = 1.0
# Collect the values in the image as WeightedElements, cumulative_weights = np.cumsum(weights)
# weighted by their corresponding kernel values. sum_weights = cumulative_weights[-1]
values = [] window_min = max(0, sum_weights * percentile_min)
for window_tup in np.ndindex(tuple(window_shape)): window_max = min(sum_weights, sum_weights * percentile_max)
window_index = vec(window_tup)
image_index = window_index + min_index
centered_kernel_index = image_index - idx
kernel_index = centered_kernel_index + kernel_center
element = WeightedElement(img[tuple(image_index)], kernel[tuple(kernel_index)])
values.append(element)
def sort_key(x: WeightedElement):
return x.value
values.sort(key=sort_key)
# Calculate the height of the stack (sum)
# and each sample's range they occupy in the stack
sum = 0
for i in range(len(values)):
values[i].window_min = sum
sum += values[i].weight
values[i].window_max = sum
# Calculate what range of this stack ("window")
# we want to get the weighted average across.
window_min = sum * percentile_min
window_max = sum * percentile_max
window_width = window_max - window_min window_width = window_max - window_min
# Ensure the window is within the stack and at least a certain size.
if window_width < min_width: if window_width < min_width:
window_center = (window_min + window_max) / 2 window_center = (window_min + window_max) / 2
window_min = window_center - min_width / 2 window_min = max(0, window_center - min_width / 2)
window_max = window_center + min_width / 2 window_max = min(sum_weights, window_center + min_width / 2)
if window_max > sum: overlap_start = np.maximum(window_min, np.concatenate(([0], cumulative_weights[:-1])))
window_max = sum overlap_end = np.minimum(window_max, cumulative_weights)
window_min = sum - min_width overlap = np.maximum(0, overlap_end - overlap_start)
if window_min < 0: return np.sum(values * overlap) / np.sum(overlap) if np.sum(overlap) > 0 else 0
window_min = 0
window_max = min_width
value = 0 # Split pixel_coords into equal chunks based on n_jobs
value_weight = 0 n_jobs = -1
if cpu_count() > 6:
n_jobs = 6 # More than 6 isn't worth unless it's more than 3000x3000px
# Get the weighted average of all the samples chunk_size = len(pixel_coords) // n_jobs
# that overlap with the window, weighted pixel_chunks = [pixel_coords[i:i + chunk_size] for i in range(0, len(pixel_coords), chunk_size)]
# by the size of their overlap.
for i in range(len(values)):
if window_min >= values[i].window_max:
continue
if window_max <= values[i].window_min:
break
s = max(window_min, values[i].window_min) # joblib to process chunks in parallel
e = min(window_max, values[i].window_max) def process_chunk(chunk):
w = e - s chunk_result = {}
for idx in chunk:
chunk_result[idx] = weighted_histogram_filter_single(idx)
return chunk_result
value += values[i].value * w results = Parallel(n_jobs=n_jobs, backend="loky")( # loky is fastest in my configuration
value_weight += w delayed(process_chunk)(chunk) for chunk in pixel_chunks
)
return value / value_weight if value_weight != 0 else 0 # Combine results into the output image
for chunk_result in results:
img_out = img.copy() for (row, col), value in chunk_result.items():
img_out[row, col] = value
# Apply the kernel operation over each pixel.
for index in np.ndindex(img.shape):
img_out[index] = weighted_histogram_filter_single(index)
return img_out return img_out
@ -485,7 +530,7 @@ el_ids = SoftInpaintingSettings(
class Script(scripts.Script): class Script(scripts.Script):
def __init__(self): def __init__(self):
# self.section = "inpaint" self.section = "inpaint"
self.masks_for_overlay = None self.masks_for_overlay = None
self.overlay_images = None self.overlay_images = None