diff --git a/scripts/postprocessing_focal_crop.py b/scripts/postprocessing_focal_crop.py new file mode 100644 index 00000000..90060c43 --- /dev/null +++ b/scripts/postprocessing_focal_crop.py @@ -0,0 +1,56 @@ + +from modules import scripts_postprocessing, ui_components, errors +import gradio as gr + +from modules.textual_inversion import autocrop + + +class ScriptPostprocessingFocalCrop(scripts_postprocessing.ScriptPostprocessing): + name = "Auto focal point crop" + order = 4010 + + def ui(self): + with ui_components.InputAccordion(False, label="Auto focal point crop") as enable: + face_weight = gr.Slider(label='Focal point face weight', value=0.9, minimum=0.0, maximum=1.0, step=0.05, elem_id="postprocess_focal_crop_face_weight") + entropy_weight = gr.Slider(label='Focal point entropy weight', value=0.15, minimum=0.0, maximum=1.0, step=0.05, elem_id="postprocess_focal_crop_entropy_weight") + edges_weight = gr.Slider(label='Focal point edges weight', value=0.5, minimum=0.0, maximum=1.0, step=0.05, elem_id="postprocess_focal_crop_edges_weight") + debug = gr.Checkbox(label='Create debug image', elem_id="train_process_focal_crop_debug") + + return { + "enable": enable, + "face_weight": face_weight, + "entropy_weight": entropy_weight, + "edges_weight": edges_weight, + "debug": debug, + } + + def process(self, pp: scripts_postprocessing.PostprocessedImage, enable, face_weight, entropy_weight, edges_weight, debug): + if not enable: + return + + if not pp.shared.target_width or not pp.shared.target_height: + return + + pp.image = pp.image.convert('RGB') + + dnn_model_path = None + try: + dnn_model_path = autocrop.download_and_cache_models() + except Exception: + errors.report("Unable to load face detection model for auto crop selection. Falling back to lower quality haar method.", exc_info=True) + + autocrop_settings = autocrop.Settings( + crop_width=pp.shared.target_width, + crop_height=pp.shared.target_height, + face_points_weight=face_weight, + entropy_points_weight=entropy_weight, + corner_points_weight=edges_weight, + annotate_image=debug, + dnn_model_path=dnn_model_path, + ) + + result, *others = autocrop.crop_image(pp.image, autocrop_settings) + + pp.image = result + pp.extra_images = [pp.create_copy(x, nametags=["focal-crop-debug"], disable_processing=True) for x in others] +