Mirror of https://github.com/lllyasviel/stable-diffusion-webui-forge.git (synced 2025-12-28 05:35:00 +00:00)

add lora support for z-image

This commit is contained in:
parent 01f2c2876f
commit 214f890c3c
@@ -199,11 +199,16 @@ class ZImage(ForgeDiffusionEngine):
             timesteps=1000
         )

+        # Create config object for Z-Image identification (used by LoRA loader)
+        class ZImageModelConfig:
+            is_zimage = True
+            huggingface_repo = 'Z-Image'
+
         unet = UnetPatcher.from_model(
             model=wrapped_transformer,
             diffusers_scheduler=components_dict['scheduler'],
             k_predictor=k_predictor,
-            config=None
+            config=ZImageModelConfig()
         )

         self.text_processing_engine_qwen = QwenTextProcessingEngine(
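For orientation, the marker class added above is what the LoRA code paths later in this commit look for. A minimal sketch of that check, based on the load_lora_for_models change further down (the unet_patcher name is only a placeholder for the UnetPatcher created in this hunk):

# Sketch: how the Z-Image marker is read downstream (see the load_lora_for_models
# changes later in this commit). `unet_patcher` is a placeholder name.
model_config = getattr(unet_patcher.model, 'config', None)
if model_config is not None and getattr(model_config, 'is_zimage', False):
    print('[LORA] Z-Image model detected via config.is_zimage')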
backend/lora/__init__.py (new file, 17 lines)
@@ -0,0 +1,17 @@
"""
Backend LoRA utilities.
"""

from backend.lora.zimage_lora import (
    load_zimage_lora_patches,
    load_zimage_lora,
    apply_zimage_lora_to_state_dict,
    lora_key_to_model_key,
)

__all__ = [
    'load_zimage_lora_patches',
    'load_zimage_lora',
    'apply_zimage_lora_to_state_dict',
    'lora_key_to_model_key',
]
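As a quick illustration of the package surface above, a minimal sketch of importing the key-mapping helper and converting two LoRA keys; the expected outputs follow the docstring examples in backend/lora/zimage_lora.py below:

# Sketch: exercising the exported key-mapping helper.
from backend.lora import lora_key_to_model_key

print(lora_key_to_model_key('diffusion_model.layers.0.attention.out'))
# -> 'layers.0.attention.to_out.0.weight'
print(lora_key_to_model_key('diffusion_model.layers.0.attention.to_q'))
# -> 'layers.0.attention.to_q.weight'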
backend/lora/zimage_lora.py (new file, 345 lines)
@@ -0,0 +1,345 @@
"""
Z-Image LoRA Loader

A dedicated LoRA loader for Z-Image models that handles the specific key mappings
required between LoRA files and the Z-Image transformer architecture.

Key Mappings:
    LoRA: diffusion_model.layers.N.attention.out   -> Model: layers.N.attention.to_out.0.weight
    LoRA: diffusion_model.layers.N.attention.to_q  -> Model: layers.N.attention.to_q.weight
    LoRA: diffusion_model.layers.N.feed_forward.w1 -> Model: layers.N.feed_forward.w1.weight
    etc.
"""

import torch
try:
    from safetensors.torch import load_file
except ImportError:
    load_file = None


def lora_key_to_model_key(lora_key: str) -> str:
    """
    Convert a LoRA key to the corresponding model key.

    Example:
        'diffusion_model.layers.0.attention.out'  -> 'layers.0.attention.to_out.0.weight'
        'diffusion_model.layers.0.attention.to_q' -> 'layers.0.attention.to_q.weight'
    """
    # Remove diffusion_model. prefix
    if lora_key.startswith('diffusion_model.'):
        model_key = lora_key[len('diffusion_model.'):]
    else:
        model_key = lora_key

    # Handle attention.out -> attention.to_out.0 transformation
    model_key = model_key.replace('.attention.out', '.attention.to_out.0')

    # Add .weight suffix
    model_key = model_key + '.weight'

    return model_key


def load_zimage_lora_patches(lora_state: dict, model) -> tuple[dict, dict]:
    """
    Create LoRA patches for a Z-Image model in Forge's patch format.

    This function is compatible with Forge's LoRA loading system and returns
    patches in the same format as the regular load_lora function.

    Args:
        lora_state: Dictionary of LoRA weights (already loaded from file)
        model: The Z-Image model (KModel wrapping ZImageTransformerWrapper)

    Returns:
        (patch_dict, remaining_dict) where:
            - patch_dict: {model_key: ("lora", (lora_up, lora_down, alpha, mid, dora_scale))}
            - remaining_dict: Unmatched LoRA keys
    """
    # Get model state dict keys for matching
    # KModel structure: KModel.diffusion_model (wrapper) -> wrapper.transformer (actual model)
    model_state_keys = set(model.state_dict().keys())

    # Debug: Print sample keys to understand structure
    sample_keys = list(model_state_keys)[:5]
    print(f"[Z-Image LoRA] Model state dict sample keys: {sample_keys}")

    # Group LoRA keys by base key
    lora_groups = {}
    for key in lora_state.keys():
        result = extract_lora_base_key(key)
        if result:
            base_key, key_type = result
            if base_key not in lora_groups:
                lora_groups[base_key] = {}
            lora_groups[base_key][key_type] = key  # Store the full key name

    patch_dict = {}
    remaining_dict = {}
    loaded_keys = set()

    for lora_base_key, lora_key_names in lora_groups.items():
        # Convert LoRA key to model key (returns key without diffusion_model prefix)
        # e.g., 'diffusion_model.layers.0.attention.out' -> 'layers.0.attention.to_out.0.weight'
        model_key = lora_key_to_model_key(lora_base_key)

        # Try multiple prefixes to find the key in the model
        # KModel wraps: diffusion_model.transformer.{actual_key}
        actual_model_key = None
        prefixes_to_try = [
            '',  # Direct match
            'diffusion_model.',  # Standard prefix
            'diffusion_model.transformer.',  # Wrapped transformer
        ]

        for prefix in prefixes_to_try:
            candidate = prefix + model_key
            if candidate in model_state_keys:
                actual_model_key = candidate
                break

        if actual_model_key is None:
            # Key not found - add all related keys to remaining
            for key_type, full_key in lora_key_names.items():
                remaining_dict[full_key] = lora_state.get(full_key)
            continue

        # Get LoRA components
        up_key = lora_key_names.get('up')
        down_key = lora_key_names.get('down')
        alpha_key = lora_key_names.get('alpha')

        if up_key is None or down_key is None:
            continue

        lora_up = lora_state[up_key]
        lora_down = lora_state[down_key]
        alpha = lora_state[alpha_key].item() if alpha_key else None

        # Mark keys as loaded
        loaded_keys.add(up_key)
        loaded_keys.add(down_key)
        if alpha_key:
            loaded_keys.add(alpha_key)

        # Create patch in Forge format: ("lora", (up, down, alpha, mid, dora_scale))
        patch_dict[actual_model_key] = ("lora", (lora_up, lora_down, alpha, None, None))

    # Add unloaded keys to remaining_dict
    for key in lora_state.keys():
        if key not in loaded_keys:
            remaining_dict[key] = lora_state[key]

    print(f"[Z-Image LoRA] Created {len(patch_dict)} patches, {len(remaining_dict)} unmatched keys")

    return patch_dict, remaining_dict


def extract_lora_base_key(full_key: str) -> tuple[str, str] | None:
    """
    Extract the base key and type from a full LoRA key.

    Returns (base_key, type) where type is 'up', 'down', or 'alpha'.
    Returns None if not a valid LoRA key.
    """
    if full_key.endswith('.lora_up.weight'):
        return full_key[:-len('.lora_up.weight')], 'up'
    elif full_key.endswith('.lora_down.weight'):
        return full_key[:-len('.lora_down.weight')], 'down'
    elif full_key.endswith('.alpha'):
        return full_key[:-len('.alpha')], 'alpha'
    return None


def load_zimage_lora(lora_path: str, model: torch.nn.Module, strength: float = 1.0) -> dict:
    """
    Load a LoRA file and apply it to a Z-Image model.

    Args:
        lora_path: Path to the LoRA safetensors file
        model: The Z-Image transformer model
        strength: LoRA strength multiplier (default 1.0)

    Returns:
        dict with 'matched', 'unmatched', and 'applied' counts
    """
    # Load LoRA weights
    lora_state = load_file(lora_path)

    # Get model state dict
    model_state = model.state_dict()

    # Group LoRA keys by base key
    lora_groups = {}
    for key in lora_state.keys():
        result = extract_lora_base_key(key)
        if result:
            base_key, key_type = result
            if base_key not in lora_groups:
                lora_groups[base_key] = {}
            lora_groups[base_key][key_type] = lora_state[key]

    matched = 0
    unmatched = 0
    applied = 0
    unmatched_keys = []

    # Apply LoRA weights
    for lora_base_key, lora_weights in lora_groups.items():
        model_key = lora_key_to_model_key(lora_base_key)

        if model_key not in model_state:
            unmatched += 1
            unmatched_keys.append(lora_base_key)
            continue

        matched += 1

        # Get LoRA components
        lora_up = lora_weights.get('up')
        lora_down = lora_weights.get('down')
        alpha = lora_weights.get('alpha')

        if lora_up is None or lora_down is None:
            continue

        # Calculate LoRA delta: up @ down * scale
        # Alpha scaling: if alpha exists, scale = alpha / rank
        rank = lora_down.shape[0]
        if alpha is not None:
            scale = alpha.item() / rank
        else:
            scale = 1.0

        scale *= strength

        # Compute the delta weight
        # lora_up shape: [out_features, rank]
        # lora_down shape: [rank, in_features]
        # delta shape: [out_features, in_features]
        delta = (lora_up @ lora_down) * scale

        # Get the parameter and apply delta
        param_name = model_key

        # Navigate to the actual parameter in the model
        parts = param_name.split('.')
        target = model
        for part in parts[:-1]:
            if part.isdigit():
                target = target[int(part)]
            else:
                target = getattr(target, part)

        param = getattr(target, parts[-1])

        # Apply the delta
        with torch.no_grad():
            if param.shape == delta.shape:
                param.add_(delta.to(param.device, param.dtype))
                applied += 1
            else:
                print(f"[Z-Image LoRA] Shape mismatch for {model_key}: param={param.shape}, delta={delta.shape}")

    if unmatched_keys:
        print(f"[Z-Image LoRA] Unmatched keys ({len(unmatched_keys)}): {unmatched_keys[:5]}...")

    return {
        'matched': matched,
        'unmatched': unmatched,
        'applied': applied,
        'total_lora_layers': len(lora_groups)
    }


def apply_zimage_lora_to_state_dict(lora_path: str, model_state: dict, strength: float = 1.0) -> tuple[dict, dict]:
    """
    Apply LoRA weights to a model state dict (without requiring the model instance).

    Args:
        lora_path: Path to the LoRA safetensors file
        model_state: The model's state dict
        strength: LoRA strength multiplier

    Returns:
        (modified_state_dict, stats_dict)
    """
    lora_state = load_file(lora_path)

    # Group LoRA keys by base key
    lora_groups = {}
    for key in lora_state.keys():
        result = extract_lora_base_key(key)
        if result:
            base_key, key_type = result
            if base_key not in lora_groups:
                lora_groups[base_key] = {}
            lora_groups[base_key][key_type] = lora_state[key]

    matched = 0
    applied = 0
    unmatched_keys = []

    # Create a copy of the state dict
    new_state = {k: v.clone() for k, v in model_state.items()}

    for lora_base_key, lora_weights in lora_groups.items():
        model_key = lora_key_to_model_key(lora_base_key)

        if model_key not in new_state:
            unmatched_keys.append(lora_base_key)
            continue

        matched += 1

        lora_up = lora_weights.get('up')
        lora_down = lora_weights.get('down')
        alpha = lora_weights.get('alpha')

        if lora_up is None or lora_down is None:
            continue

        rank = lora_down.shape[0]
        scale = (alpha.item() / rank if alpha is not None else 1.0) * strength

        delta = (lora_up @ lora_down) * scale

        if new_state[model_key].shape == delta.shape:
            new_state[model_key] = new_state[model_key] + delta.to(new_state[model_key].dtype)
            applied += 1

    stats = {
        'matched': matched,
        'unmatched': len(unmatched_keys),
        'applied': applied,
        'unmatched_keys': unmatched_keys[:10] if unmatched_keys else []
    }

    return new_state, stats


# Standalone test
if __name__ == '__main__':
    import sys

    if len(sys.argv) < 3:
        print("Usage: python zimage_lora.py <model.safetensors> <lora.safetensors> [strength]")
        sys.exit(1)

    model_path = sys.argv[1]
    lora_path = sys.argv[2]
    strength = float(sys.argv[3]) if len(sys.argv) > 3 else 1.0

    print(f"Loading model: {model_path}")
    model_state = load_file(model_path)

    print(f"Applying LoRA: {lora_path} (strength={strength})")
    new_state, stats = apply_zimage_lora_to_state_dict(lora_path, model_state, strength)

    print(f"\nResults:")
    print(f"  Matched: {stats['matched']}")
    print(f"  Applied: {stats['applied']}")
    print(f"  Unmatched: {stats['unmatched']}")
    if stats['unmatched_keys']:
        print(f"  Sample unmatched: {stats['unmatched_keys']}")
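For context, a minimal sketch of driving the new module directly; the file path and the k_model and transformer names below are placeholders, not part of this commit:

# Sketch: using the new Z-Image LoRA helpers outside of Forge's patch pipeline.
from safetensors.torch import load_file
from backend.lora import load_zimage_lora_patches, load_zimage_lora

lora_state = load_file('/path/to/zimage_lora.safetensors')  # placeholder path

# Forge-format patches keyed by the model's state-dict names
patches, unmatched = load_zimage_lora_patches(lora_state, k_model)  # k_model: KModel instance (placeholder)

# Or bake the LoRA weights straight into a transformer's parameters
stats = load_zimage_lora('/path/to/zimage_lora.safetensors', transformer, strength=0.8)
print(stats['matched'], stats['applied'], stats['unmatched'])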
@@ -10,17 +10,37 @@ from backend.args import dynamic_args
 from modules import shared, sd_models, errors, scripts
 from backend.utils import load_torch_file
 from backend.patcher.lora import model_lora_keys_clip, model_lora_keys_unet, load_lora
+from backend.lora.zimage_lora import load_zimage_lora_patches


 def load_lora_for_models(model, clip, lora, strength_model, strength_clip, filename='default', online_mode=False):
     model_flag = type(model.model).__name__ if model is not None else 'default'

-    unet_keys = model_lora_keys_unet(model.model) if model is not None else {}
-    clip_keys = model_lora_keys_clip(clip.cond_stage_model) if clip is not None else {}
+    # Check if this is a Z-Image model (detect by wrapper class or model config)
+    is_zimage = False
+    if model is not None:
+        model_config = getattr(model.model, 'config', None)
+        if model_config is not None and getattr(model_config, 'is_zimage', False):
+            is_zimage = True
+            print(f'[LORA] Z-Image model detected via config.is_zimage')
+        # Also check by class name pattern
+        if 'ZImage' in model_flag:
+            is_zimage = True
+            print(f'[LORA] Z-Image model detected via class name: {model_flag}')

-    lora_unmatch = lora
-    lora_unet, lora_unmatch = load_lora(lora_unmatch, unet_keys)
-    lora_clip, lora_unmatch = load_lora(lora_unmatch, clip_keys)
+    # Use dedicated Z-Image LoRA loader if applicable
+    if is_zimage and model is not None:
+        print(f'[LORA] Using Z-Image LoRA loader for {filename}')
+        lora_unet, lora_unmatch = load_zimage_lora_patches(lora, model.model)
+        clip_keys = model_lora_keys_clip(clip.cond_stage_model) if clip is not None else {}
+        lora_clip, lora_unmatch = load_lora(lora_unmatch, clip_keys)
+    else:
+        unet_keys = model_lora_keys_unet(model.model) if model is not None else {}
+        clip_keys = model_lora_keys_clip(clip.cond_stage_model) if clip is not None else {}
+
+        lora_unmatch = lora
+        lora_unet, lora_unmatch = load_lora(lora_unmatch, unet_keys)
+        lora_clip, lora_unmatch = load_lora(lora_unmatch, clip_keys)

     #if len(lora_unmatch) > 12:
         #print(f'[LORA] LoRA version mismatch for {model_flag}: {filename}')

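The Z-Image branch above returns patch entries keyed by full state-dict names, in the same ('lora', (up, down, alpha, mid, dora_scale)) format as regular LoRAs, so they can flow through the existing patching path. A rough sketch of that hand-off, assuming Forge keeps its ModelPatcher-style add_patches interface (the clone/add_patches calls here are assumptions, not part of this diff):

# Sketch (assumption): how the caller would typically consume lora_unet further
# down in load_lora_for_models, mirroring the regular LoRA path.
new_modelpatcher = model.clone()
loaded_keys = new_modelpatcher.add_patches(lora_unet, strength_model)
print(f'[LORA] Patched {len(loaded_keys)} UNet keys from {filename}')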
@@ -339,7 +339,13 @@ def model_lora_keys_unet(model, key_map={}):
         # key_lora = k[len("diffusion_model."):-len(".weight")]
         # key_map["base_model.model.{}".format(key_lora)] = k #official hunyuan lora format

-    if 'flux' in model.config.huggingface_repo.lower() or 'chroma' in model.config.huggingface_repo.lower(): #Diffusers lora Flux and Chroma
+    # Safe access to huggingface_repo (may be None for some models like Z-Image)
+    repo = ''
+    config = getattr(model, 'config', None)
+    if config is not None:
+        repo = (getattr(config, 'huggingface_repo', '') or '').lower()
+
+    if 'flux' in repo or 'chroma' in repo:  # Diffusers lora Flux and Chroma
         diffusers_keys = utils.flux_to_diffusers(model.diffusion_model.config, output_prefix="diffusion_model.")
         for k in diffusers_keys:
             if k.endswith(".weight"):
@@ -348,4 +354,26 @@ def model_lora_keys_unet(model, key_map={}):
                 key_map["lycoris_{}".format(k[:-len(".weight")].replace(".", "_"))] = to # simpletrainer lycoris
                 key_map["lora_transformer_{}".format(k[:-len(".weight")].replace(".", "_"))] = to # onetrainer

+    # Z-Image LoRA support
+    is_zimage = (config is not None and getattr(config, 'is_zimage', False)) or repo == 'z-image'
+    if not is_zimage:
+        # Fallback: detect by checking for Z-Image layer structure
+        if any('layers.0.feed_forward.w1' in k for k in sdk):
+            is_zimage = True
+
+    if is_zimage:
+        for k in sdk:
+            if k.endswith(".weight"):
+                # Handle both cases: with or without diffusion_model. prefix
+                if k.startswith("diffusion_model."):
+                    model_key = k
+                    lora_key = k[:-len(".weight")]
+                else:
+                    # Add diffusion_model. prefix for model key (Forge wrapping)
+                    model_key = "diffusion_model." + k
+                    lora_key = "diffusion_model." + k[:-len(".weight")]
+                # Handle the to_out.0 -> out naming difference in LoRA files
+                lora_key = lora_key.replace(".attention.to_out.0", ".attention.out")
+                key_map[lora_key] = model_key
+
     return sdk, key_map
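To make the resulting mapping concrete, a small sketch tracing the branch above for one attention projection; the layer index follows the docstring examples in backend/lora/zimage_lora.py:

# Sketch: key_map entry produced for a state-dict key without the
# diffusion_model. prefix.
k = 'layers.0.attention.to_out.0.weight'
model_key = 'diffusion_model.' + k
lora_key = ('diffusion_model.' + k[:-len('.weight')]).replace('.attention.to_out.0', '.attention.out')
# key_map[lora_key] = model_key gives:
# 'diffusion_model.layers.0.attention.out' -> 'diffusion_model.layers.0.attention.to_out.0.weight'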