mirror of
https://github.com/lllyasviel/stable-diffusion-webui-forge.git
synced 2025-12-28 05:35:00 +00:00
[revised] change some dtype behaviors based on community feedbacks
only influence old devices like 1080/70/60/50. please remove cmd flags if you are on 1080/70/60/50 and previously used many cmd flags to tune performance
This commit is contained in:
parent
1419ef29aa
commit
4e3c78178a
@ -107,10 +107,10 @@ def load_huggingface_component(guess, component_name, lib_name, cls_name, repo_p
|
||||
model_loader = lambda c: IntegratedFluxTransformer2DModel(**c)
|
||||
|
||||
unet_config = guess.unet_config.copy()
|
||||
state_dict_size = memory_management.state_dict_size(state_dict)
|
||||
state_dict_parameters = memory_management.state_dict_parameters(state_dict)
|
||||
state_dict_dtype = memory_management.state_dict_dtype(state_dict)
|
||||
|
||||
storage_dtype = memory_management.unet_dtype(model_params=state_dict_size, supported_dtypes=guess.supported_inference_dtypes)
|
||||
storage_dtype = memory_management.unet_dtype(model_params=state_dict_parameters, supported_dtypes=guess.supported_inference_dtypes)
|
||||
|
||||
unet_storage_dtype_overwrite = backend.args.dynamic_args.get('forge_unet_storage_dtype')
|
||||
|
||||
@ -140,15 +140,15 @@ def load_huggingface_component(guess, component_name, lib_name, cls_name, repo_p
|
||||
print(f'Using GGUF state dict: {type_counts}')
|
||||
|
||||
load_device = memory_management.get_torch_device()
|
||||
computation_dtype = memory_management.get_computation_dtype(load_device, supported_dtypes=guess.supported_inference_dtypes)
|
||||
computation_dtype = memory_management.get_computation_dtype(load_device, parameters=state_dict_parameters, supported_dtypes=guess.supported_inference_dtypes)
|
||||
offload_device = memory_management.unet_offload_device()
|
||||
|
||||
if storage_dtype in ['nf4', 'fp4', 'gguf']:
|
||||
initial_device = memory_management.unet_inital_load_device(parameters=state_dict_size, dtype=computation_dtype)
|
||||
initial_device = memory_management.unet_inital_load_device(parameters=state_dict_parameters, dtype=computation_dtype)
|
||||
with using_forge_operations(device=initial_device, dtype=computation_dtype, manual_cast_enabled=False, bnb_dtype=storage_dtype):
|
||||
model = model_loader(unet_config)
|
||||
else:
|
||||
initial_device = memory_management.unet_inital_load_device(parameters=state_dict_size, dtype=storage_dtype)
|
||||
initial_device = memory_management.unet_inital_load_device(parameters=state_dict_parameters, dtype=storage_dtype)
|
||||
need_manual_cast = storage_dtype != computation_dtype
|
||||
to_args = dict(device=initial_device, dtype=storage_dtype)
|
||||
|
||||
|
||||
@ -301,6 +301,13 @@ def state_dict_size(sd, exclude_device=None):
|
||||
return module_mem
|
||||
|
||||
|
||||
def state_dict_parameters(sd):
|
||||
module_mem = 0
|
||||
for k, v in sd.items():
|
||||
module_mem += v.nelement()
|
||||
return module_mem
|
||||
|
||||
|
||||
def state_dict_dtype(state_dict):
|
||||
for k, v in state_dict.items():
|
||||
if hasattr(v, 'is_gguf'):
|
||||
@ -653,44 +660,22 @@ def unet_dtype(device=None, model_params=0, supported_dtypes=[torch.float16, tor
|
||||
|
||||
for candidate in supported_dtypes:
|
||||
if candidate == torch.float16:
|
||||
if should_use_fp16(device=device, model_params=model_params, manual_cast=True):
|
||||
if should_use_fp16(device, model_params=model_params, prioritize_performance=True, manual_cast=True):
|
||||
return candidate
|
||||
if candidate == torch.bfloat16:
|
||||
if should_use_bf16(device, model_params=model_params, manual_cast=True):
|
||||
if should_use_bf16(device, model_params=model_params, prioritize_performance=True, manual_cast=True):
|
||||
return candidate
|
||||
|
||||
return torch.float32
|
||||
|
||||
|
||||
# None means no manual cast
|
||||
def unet_manual_cast(weight_dtype, inference_device, supported_dtypes=[torch.float16, torch.bfloat16, torch.float32]):
|
||||
if weight_dtype == torch.float32:
|
||||
return None
|
||||
|
||||
fp16_supported = should_use_fp16(inference_device, prioritize_performance=False)
|
||||
if fp16_supported and weight_dtype == torch.float16:
|
||||
return None
|
||||
|
||||
bf16_supported = should_use_bf16(inference_device)
|
||||
if bf16_supported and weight_dtype == torch.bfloat16:
|
||||
return None
|
||||
|
||||
if fp16_supported and torch.float16 in supported_dtypes:
|
||||
return torch.float16
|
||||
|
||||
elif bf16_supported and torch.bfloat16 in supported_dtypes:
|
||||
return torch.bfloat16
|
||||
else:
|
||||
return torch.float32
|
||||
|
||||
|
||||
def get_computation_dtype(inference_device, supported_dtypes=[torch.float16, torch.bfloat16, torch.float32]):
|
||||
def get_computation_dtype(inference_device, parameters=0, supported_dtypes=[torch.float16, torch.bfloat16, torch.float32]):
|
||||
for candidate in supported_dtypes:
|
||||
if candidate == torch.float16:
|
||||
if should_use_fp16(inference_device, prioritize_performance=False):
|
||||
if should_use_fp16(inference_device, model_params=parameters, prioritize_performance=True, manual_cast=False):
|
||||
return candidate
|
||||
if candidate == torch.bfloat16:
|
||||
if should_use_bf16(inference_device):
|
||||
if should_use_bf16(inference_device, model_params=parameters, prioritize_performance=True, manual_cast=False):
|
||||
return candidate
|
||||
|
||||
return torch.float32
|
||||
@ -1020,19 +1005,17 @@ def should_use_fp16(device=None, model_params=0, prioritize_performance=True, ma
|
||||
if props.major < 6:
|
||||
return False
|
||||
|
||||
fp16_works = False
|
||||
# FP16 is confirmed working on a 1080 (GP104) but it's a bit slower than FP32 so it should only be enabled
|
||||
# when the model doesn't actually fit on the card
|
||||
# TODO: actually test if GP106 and others have the same type of behavior
|
||||
nvidia_10_series = ["1080", "1070", "titan x", "p3000", "p3200", "p4000", "p4200", "p5000", "p5200", "p6000", "1060", "1050", "p40", "p100", "p6", "p4"]
|
||||
for x in nvidia_10_series:
|
||||
if x in props.name.lower():
|
||||
fp16_works = True
|
||||
|
||||
if fp16_works or manual_cast:
|
||||
free_model_memory = (get_free_memory() * 0.9 - minimum_inference_memory())
|
||||
if (not prioritize_performance) or model_params * 4 > free_model_memory:
|
||||
return True
|
||||
if manual_cast:
|
||||
# For storage dtype
|
||||
free_model_memory = (get_free_memory() * 0.9 - minimum_inference_memory())
|
||||
if (not prioritize_performance) or model_params * 4 > free_model_memory:
|
||||
return True
|
||||
else:
|
||||
# For computation dtype
|
||||
return False # Flux on 1080 can store model in fp16 to reduce swap, but computation must be fp32, otherwise super slow.
|
||||
|
||||
if props.major < 7:
|
||||
return False
|
||||
@ -1077,12 +1060,14 @@ def should_use_bf16(device=None, model_params=0, prioritize_performance=True, ma
|
||||
if props.major >= 8:
|
||||
return True
|
||||
|
||||
bf16_works = torch.cuda.is_bf16_supported()
|
||||
|
||||
if bf16_works or manual_cast:
|
||||
free_model_memory = (get_free_memory() * 0.9 - minimum_inference_memory())
|
||||
if (not prioritize_performance) or model_params * 4 > free_model_memory:
|
||||
return True
|
||||
if torch.cuda.is_bf16_supported():
|
||||
# This device is an old enough device but bf16 somewhat reports supported.
|
||||
# So in this case bf16 should only be used as storge dtype
|
||||
if manual_cast:
|
||||
# For storage dtype
|
||||
free_model_memory = (get_free_memory() * 0.9 - minimum_inference_memory())
|
||||
if (not prioritize_performance) or model_params * 4 > free_model_memory:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
@ -1116,43 +1101,3 @@ def soft_empty_cache(force=False):
|
||||
|
||||
def unload_all_models():
|
||||
free_memory(1e30, get_torch_device())
|
||||
|
||||
|
||||
def resolve_lowvram_weight(weight, model, key): # TODO: remove
|
||||
return weight
|
||||
|
||||
|
||||
# TODO: might be cleaner to put this somewhere else
|
||||
import threading
|
||||
|
||||
|
||||
class InterruptProcessingException(Exception):
|
||||
pass
|
||||
|
||||
|
||||
interrupt_processing_mutex = threading.RLock()
|
||||
|
||||
interrupt_processing = False
|
||||
|
||||
|
||||
def interrupt_current_processing(value=True):
|
||||
global interrupt_processing
|
||||
global interrupt_processing_mutex
|
||||
with interrupt_processing_mutex:
|
||||
interrupt_processing = value
|
||||
|
||||
|
||||
def processing_interrupted():
|
||||
global interrupt_processing
|
||||
global interrupt_processing_mutex
|
||||
with interrupt_processing_mutex:
|
||||
return interrupt_processing
|
||||
|
||||
|
||||
def throw_exception_if_processing_interrupted():
|
||||
global interrupt_processing
|
||||
global interrupt_processing_mutex
|
||||
with interrupt_processing_mutex:
|
||||
if interrupt_processing:
|
||||
interrupt_processing = False
|
||||
raise InterruptProcessingException()
|
||||
|
||||
@ -438,7 +438,7 @@ class ControlLora(ControlNet):
|
||||
|
||||
self.manual_cast_dtype = model.computation_dtype
|
||||
|
||||
with using_forge_operations(operations=ControlLoraOps, dtype=dtype):
|
||||
with using_forge_operations(operations=ControlLoraOps, dtype=dtype, manual_cast_enabled=self.manual_cast_dtype != dtype):
|
||||
self.control_model = cldm.ControlNet(**controlnet_config)
|
||||
|
||||
self.control_model.to(device=memory_management.get_torch_device(), dtype=dtype)
|
||||
|
||||
@ -110,12 +110,12 @@ class ControlNetPatcher(ControlModelPatcher):
|
||||
controlnet_config['dtype'] = unet_dtype
|
||||
|
||||
load_device = memory_management.get_torch_device()
|
||||
manual_cast_dtype = memory_management.unet_manual_cast(unet_dtype, load_device)
|
||||
computation_dtype = memory_management.get_computation_dtype(load_device)
|
||||
|
||||
controlnet_config.pop("out_channels")
|
||||
controlnet_config["hint_channels"] = controlnet_data["{}input_hint_block.0.weight".format(prefix)].shape[1]
|
||||
|
||||
with using_forge_operations(dtype=unet_dtype):
|
||||
with using_forge_operations(dtype=unet_dtype, manual_cast_enabled=computation_dtype != unet_dtype):
|
||||
control_model = cldm.ControlNet(**controlnet_config).to(dtype=unet_dtype)
|
||||
|
||||
if pth:
|
||||
@ -139,7 +139,7 @@ class ControlNetPatcher(ControlModelPatcher):
|
||||
# TODO: smarter way of enabling global_average_pooling
|
||||
global_average_pooling = True
|
||||
|
||||
control = ControlNet(control_model, global_average_pooling=global_average_pooling, load_device=load_device, manual_cast_dtype=manual_cast_dtype)
|
||||
control = ControlNet(control_model, global_average_pooling=global_average_pooling, load_device=load_device, manual_cast_dtype=computation_dtype)
|
||||
return ControlNetPatcher(control)
|
||||
|
||||
def __init__(self, model_patcher):
|
||||
|
||||
Loading…
Reference in New Issue
Block a user