Merge pull request #25 from maybleMyers/oomtest

fix fp16 with cpp diffusion fix
benjimon 2025-12-04 00:11:32 -08:00 committed by GitHub
commit 51100a8001
3 changed files with 55 additions and 7 deletions
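
In short: the previous fp16 overflow guard (a hard clip at ±32768) is replaced with torch.nan_to_num clamped to ±65504, the largest finite float16 value, and the Z-Image transformer blocks are patched so their intermediate activations are clamped the same way. As a quick illustration of why the clamp is needed (plain PyTorch, not part of the diff):

import torch

# float16 tops out at 65504; larger magnitudes overflow to +/- inf,
# and downstream arithmetic on inf quickly produces NaN.
x = torch.tensor([70000.0, -70000.0, 1024.0], dtype=torch.float16)
print(x)            # tensor([inf, -inf, 1024.], dtype=torch.float16)
print(x[0] - x[0])  # tensor(nan, dtype=torch.float16)

# The guard used throughout this commit maps NaN to 0 and saturates at the fp16 limits.
print(torch.nan_to_num(x, nan=0.0, posinf=65504, neginf=-65504))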

View File

@@ -8,6 +8,52 @@ import backend.args
 import huggingface_guess
 
 
+def patch_zimage_for_fp16(model):
+    import torch.nn.functional as F
+    from diffusers.models.transformers.transformer_z_image import FeedForward, ZImageTransformerBlock
+
+    def clamp_fp16(x):
+        if x.dtype == torch.float16:
+            return torch.nan_to_num(x, nan=0.0, posinf=65504, neginf=-65504)
+        return x
+
+    def patched_forward_silu_gating(self, x1, x3):
+        return clamp_fp16(F.silu(x1) * x3)
+
+    for module in model.modules():
+        if isinstance(module, FeedForward):
+            module._forward_silu_gating = patched_forward_silu_gating.__get__(module, FeedForward)
+
+    original_block_forward = ZImageTransformerBlock.forward
+
+    def patched_block_forward(self, x, attn_mask, freqs_cis, adaln_input=None):
+        if self.modulation:
+            assert adaln_input is not None
+            scale_msa, gate_msa, scale_mlp, gate_mlp = self.adaLN_modulation(adaln_input).unsqueeze(1).chunk(4, dim=2)
+            gate_msa, gate_mlp = gate_msa.tanh(), gate_mlp.tanh()
+            scale_msa, scale_mlp = 1.0 + scale_msa, 1.0 + scale_mlp
+            attn_out = self.attention(
+                self.attention_norm1(x) * scale_msa,
+                attention_mask=attn_mask,
+                freqs_cis=freqs_cis,
+            )
+            x = x + gate_msa * self.attention_norm2(clamp_fp16(attn_out))
+            x = x + gate_mlp * self.ffn_norm2(clamp_fp16(self.feed_forward(self.ffn_norm1(x) * scale_mlp)))
+        else:
+            attn_out = self.attention(
+                self.attention_norm1(x),
+                attention_mask=attn_mask,
+                freqs_cis=freqs_cis,
+            )
+            x = x + self.attention_norm2(clamp_fp16(attn_out))
+            x = x + self.ffn_norm2(clamp_fp16(self.feed_forward(self.ffn_norm1(x))))
+        return x
+
+    for module in model.modules():
+        if isinstance(module, ZImageTransformerBlock):
+            module.forward = patched_block_forward.__get__(module, ZImageTransformerBlock)
+
+
 def convert_comfy_zimage_state_dict(state_dict):
     """
     Convert ComfyUI Z-Image state dict format to Diffusers format.
@@ -336,6 +382,9 @@ def load_huggingface_component(guess, component_name, lib_name, cls_name, repo_p
         load_state_dict(model, state_dict)
 
+        if cls_name == 'ZImageTransformer2DModel':
+            patch_zimage_for_fp16(model)
+
         if hasattr(model, '_internal_dict'):
             model._internal_dict = unet_config
         else:
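
patch_zimage_for_fp16 above overrides forward on modules that already exist by binding a plain function with function.__get__(instance, cls), which turns it into a bound method for that one instance. A minimal sketch of the binding trick on a dummy class (not the diffusers model):

class Box:
    def __init__(self, value):
        self.value = value

    def forward(self):
        return self.value


def patched_forward(self):
    # `self` is bound exactly as it would be for an ordinary method.
    return self.value * 2


b = Box(21)
print(b.forward())  # 21

# __get__ produces a bound method for this instance only; other Box
# instances (and the class itself) keep the original forward.
b.forward = patched_forward.__get__(b, Box)
print(b.forward())  # 42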

View File

@@ -102,12 +102,8 @@ def tensor2parameter(x):
 def fp16_fix(x):
-    # An interesting trick to avoid fp16 overflow
-    # Source: https://github.com/lllyasviel/stable-diffusion-webui-forge/issues/1114
-    # Related: https://github.com/comfyanonymous/ComfyUI/blob/f1d6cef71c70719cc3ed45a2455a4e5ac910cd5e/comfy/ldm/flux/layers.py#L180
-    if x.dtype in [torch.float16]:
-        return x.clip(-32768.0, 32768.0)
+    if x.dtype == torch.float16:
+        return torch.nan_to_num(x, nan=0.0, posinf=65504, neginf=-65504)
     return x
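
The behavioural difference in fp16_fix: the old clip(-32768, 32768) bounds infinities but lets NaN through and discards the upper half of the usable fp16 range, while the new nan_to_num zeroes NaN and saturates at ±65504. A short comparison (PyTorch only, illustration rather than repo code):

import torch

x = torch.tensor([float("inf"), float("-inf"), float("nan"), 40000.0],
                 dtype=torch.float16)

old = x.clip(-32768.0, 32768.0)
new = torch.nan_to_num(x, nan=0.0, posinf=65504, neginf=-65504)

print(old)  # tensor([ 32768., -32768.,     nan,  32768.], dtype=torch.float16)
print(new)  # tensor([ 65504., -65504.,      0.,  40000.], dtype=torch.float16)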

View File

@@ -251,7 +251,10 @@ def refresh_model_loading_parameters():
     model_data.forge_loading_parameters = dict(
         checkpoint_info=checkpoint_info,
         additional_modules=shared.opts.forge_additional_modules,
-        unet_storage_dtype=unet_storage_dtype
+        unet_storage_dtype=unet_storage_dtype,
+        z_transformer_dtype=getattr(shared.opts, 'z_transformer_dtype', 'Automatic'),
+        z_vae_dtype=getattr(shared.opts, 'z_vae_dtype', 'Automatic'),
+        z_text_encoder_dtype=getattr(shared.opts, 'z_text_encoder_dtype', 'Automatic'),
     )
 
     print(f'Model selected: {model_data.forge_loading_parameters}')
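
The new keys are read with getattr(..., 'Automatic') so settings saved before these options existed still load with a sane default instead of raising AttributeError. The pattern in isolation (dummy opts object, not the real shared.opts):

from types import SimpleNamespace

# An older options object that predates the z_* settings.
opts = SimpleNamespace(forge_additional_modules=[])

params = dict(
    additional_modules=opts.forge_additional_modules,
    # Missing attribute -> fall back to 'Automatic' rather than crash.
    z_transformer_dtype=getattr(opts, 'z_transformer_dtype', 'Automatic'),
)
print(params)  # {'additional_modules': [], 'z_transformer_dtype': 'Automatic'}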