Mirror of https://github.com/lllyasviel/stable-diffusion-webui-forge.git (synced 2025-12-28 13:42:02 +00:00)
Merge pull request #25 from maybleMyers/oomtest
fix fp16 with cpp diffusion fix
Commit 51100a8001
@@ -8,6 +8,52 @@ import backend.args
 import huggingface_guess
 
 
+def patch_zimage_for_fp16(model):
+    import torch.nn.functional as F
+    from diffusers.models.transformers.transformer_z_image import FeedForward, ZImageTransformerBlock
+
+    def clamp_fp16(x):
+        if x.dtype == torch.float16:
+            return torch.nan_to_num(x, nan=0.0, posinf=65504, neginf=-65504)
+        return x
+
+    def patched_forward_silu_gating(self, x1, x3):
+        return clamp_fp16(F.silu(x1) * x3)
+
+    for module in model.modules():
+        if isinstance(module, FeedForward):
+            module._forward_silu_gating = patched_forward_silu_gating.__get__(module, FeedForward)
+
+    original_block_forward = ZImageTransformerBlock.forward
+
+    def patched_block_forward(self, x, attn_mask, freqs_cis, adaln_input=None):
+        if self.modulation:
+            assert adaln_input is not None
+            scale_msa, gate_msa, scale_mlp, gate_mlp = self.adaLN_modulation(adaln_input).unsqueeze(1).chunk(4, dim=2)
+            gate_msa, gate_mlp = gate_msa.tanh(), gate_mlp.tanh()
+            scale_msa, scale_mlp = 1.0 + scale_msa, 1.0 + scale_mlp
+            attn_out = self.attention(
+                self.attention_norm1(x) * scale_msa,
+                attention_mask=attn_mask,
+                freqs_cis=freqs_cis,
+            )
+            x = x + gate_msa * self.attention_norm2(clamp_fp16(attn_out))
+            x = x + gate_mlp * self.ffn_norm2(clamp_fp16(self.feed_forward(self.ffn_norm1(x) * scale_mlp)))
+        else:
+            attn_out = self.attention(
+                self.attention_norm1(x),
+                attention_mask=attn_mask,
+                freqs_cis=freqs_cis,
+            )
+            x = x + self.attention_norm2(clamp_fp16(attn_out))
+            x = x + self.ffn_norm2(clamp_fp16(self.feed_forward(self.ffn_norm1(x))))
+        return x
+
+    for module in model.modules():
+        if isinstance(module, ZImageTransformerBlock):
+            module.forward = patched_block_forward.__get__(module, ZImageTransformerBlock)
+
+
 def convert_comfy_zimage_state_dict(state_dict):
     """
     Convert ComfyUI Z-Image state dict format to Diffusers format.
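The hunk above patches per-instance rather than per-class: __get__ binds a plain function to one module object as a bound method, leaving the class and every other instance untouched. A minimal standalone sketch of that technique (TinyBlock and patched_forward are illustrative names, not part of the commit):

    import torch
    import torch.nn as nn

    class TinyBlock(nn.Module):
        def forward(self, x):
            return x * 2

    def patched_forward(self, x):
        # Same clamping idea as clamp_fp16 above: zero NaNs, map inf to fp16 max.
        out = x * 2
        if out.dtype == torch.float16:
            out = torch.nan_to_num(out, nan=0.0, posinf=65504, neginf=-65504)
        return out

    block = TinyBlock()
    # Bind patched_forward to this one instance via the descriptor protocol.
    block.forward = patched_forward.__get__(block, TinyBlock)
    # 40000 * 2 overflows fp16 to inf, then gets clamped to 65504.
    print(block(torch.tensor([40000.0], dtype=torch.float16)))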
@@ -336,6 +382,9 @@ def load_huggingface_component(guess, component_name, lib_name, cls_name, repo_p
 
     load_state_dict(model, state_dict)
 
+    if cls_name == 'ZImageTransformer2DModel':
+        patch_zimage_for_fp16(model)
+
     if hasattr(model, '_internal_dict'):
         model._internal_dict = unet_config
     else:
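This hunk runs the fix only after weights are loaded, and only for the Z-Image transformer class. One way to read it is as a post-load hook keyed on class name; a hypothetical generalization of that pattern (POST_LOAD_PATCHES and apply_post_load_patches are illustrative, not Forge's actual API):

    # Registry of per-class fixups to run after load_state_dict.
    POST_LOAD_PATCHES = {
        'ZImageTransformer2DModel': patch_zimage_for_fp16,  # from the hunk above
    }

    def apply_post_load_patches(model, cls_name):
        # Look up and run any registered fixup for this component class.
        patch = POST_LOAD_PATCHES.get(cls_name)
        if patch is not None:
            patch(model)
        return model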
@@ -102,12 +102,8 @@ def tensor2parameter(x):
 
 
 def fp16_fix(x):
-    # An interesting trick to avoid fp16 overflow
-    # Source: https://github.com/lllyasviel/stable-diffusion-webui-forge/issues/1114
-    # Related: https://github.com/comfyanonymous/ComfyUI/blob/f1d6cef71c70719cc3ed45a2455a4e5ac910cd5e/comfy/ldm/flux/layers.py#L180
-
-    if x.dtype in [torch.float16]:
-        return x.clip(-32768.0, 32768.0)
+    if x.dtype == torch.float16:
+        return torch.nan_to_num(x, nan=0.0, posinf=65504, neginf=-65504)
     return x
 
 
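65504 is the largest finite float16 value, so the new fp16_fix uses the full representable range instead of the old ±32768 clip, and unlike clip it also removes NaNs (torch.clamp propagates NaN) and replaces infinities. A standalone check of the behavioral difference (not part of the commit):

    import torch

    x = torch.tensor([float('nan'), float('inf'), 1.0]).half()
    print(x.clip(-32768.0, 32768.0))
    # NaN survives the clip; inf becomes 32768.0
    print(torch.nan_to_num(x, nan=0.0, posinf=65504, neginf=-65504))
    # NaN -> 0.0, inf -> 65504.0 (fp16 max); finite values pass through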
@@ -251,7 +251,10 @@ def refresh_model_loading_parameters():
     model_data.forge_loading_parameters = dict(
         checkpoint_info=checkpoint_info,
         additional_modules=shared.opts.forge_additional_modules,
-        unet_storage_dtype=unet_storage_dtype
+        unet_storage_dtype=unet_storage_dtype,
+        z_transformer_dtype=getattr(shared.opts, 'z_transformer_dtype', 'Automatic'),
+        z_vae_dtype=getattr(shared.opts, 'z_vae_dtype', 'Automatic'),
+        z_text_encoder_dtype=getattr(shared.opts, 'z_text_encoder_dtype', 'Automatic'),
     )
 
     print(f'Model selected: {model_data.forge_loading_parameters}')
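The getattr fallbacks let the three new z_* options degrade gracefully when they are missing from shared.opts, e.g. on settings saved before this change. A standalone illustration of the pattern (SimpleNamespace stands in for shared.opts):

    from types import SimpleNamespace

    opts = SimpleNamespace(unet_storage_dtype='fp16')  # older config: no z_* options set
    params = dict(
        unet_storage_dtype=opts.unet_storage_dtype,
        # Falls back to 'Automatic' instead of raising AttributeError.
        z_transformer_dtype=getattr(opts, 'z_transformer_dtype', 'Automatic'),
    )
    print(params)  # {'unet_storage_dtype': 'fp16', 'z_transformer_dtype': 'Automatic'}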