From 184bb04f8d4ba78c83f4fabc66727e9813f28f37 Mon Sep 17 00:00:00 2001
From: DenOfEquity <166248528+DenOfEquity@users.noreply.github.com>
Date: Fri, 21 Feb 2025 12:01:39 +0000
Subject: [PATCH] increased support for custom CLIPs (#2642)

increased support for custom CLIPs

More forms are recognised; they can now be applied to sd1.5, sdxl,
and (partially) sd3.
---
 backend/loader.py | 216 ++++++++++++++++++++++++++++++++++++++++++++--
 1 file changed, 207 insertions(+), 9 deletions(-)

diff --git a/backend/loader.py b/backend/loader.py
index d5b6ca90..baf0d347 100644
--- a/backend/loader.py
+++ b/backend/loader.py
@@ -209,12 +209,208 @@ def replace_state_dict(sd, asd, guess):
         for k, v in asd.items():
             sd[vae_key_prefix + k] = v
 
-    if 'text_model.encoder.layers.0.layer_norm1.weight' in asd:
-        keys_to_delete = [k for k in sd if k.startswith(f"{text_encoder_key_prefix}clip_l.")]
-        for k in keys_to_delete:
-            del sd[k]
-        for k, v in asd.items():
-            sd[f"{text_encoder_key_prefix}clip_l.transformer.{k}"] = v
+
+    ## identify model type
+    flux_test_key = "model.diffusion_model.double_blocks.0.img_attn.norm.key_norm.scale"
+    sd3_test_key = "model.diffusion_model.final_layer.adaLN_modulation.1.bias"
+    legacy_test_key = "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn2.to_k.weight"
+
+    model_type = "-"
+    if legacy_test_key in sd:
+        match sd[legacy_test_key].shape[1]:
+            case 768:
+                model_type = "sd1"
+            case 1024:
+                model_type = "sd2"
+            case 2048:
+                model_type = "sdxl"
+    elif flux_test_key in sd:
+        model_type = "flux"
+    elif sd3_test_key in sd:
+        model_type = "sd3"
+
+    ## prefixes used by various model types for CLIP-L
+    prefix_L = {
+        "-"   : None,
+        "sd1" : "cond_stage_model.transformer.",
+        "sd2" : None,
+        "sdxl": "conditioner.embedders.0.transformer.",
+        "flux": "text_encoders.clip_l.transformer.",
+        "sd3" : "text_encoders.clip_l.transformer.",
+    }
+    ## prefixes used by various model types for CLIP-G
+    prefix_G = {
+        "-"   : None,
+        "sd1" : None,
+        "sd2" : None,
+        "sdxl": "conditioner.embedders.1.model.",
+        "flux": None,
+        "sd3" : "text_encoders.clip_g.",
+    }
+    ## prefixes used by various model types for CLIP-H
+    prefix_H = {
+        "-"   : None,
+        "sd1" : None,
+        "sd2" : "conditioner.embedders.0.model.",
+        "sdxl": None,
+        "flux": None,
+        "sd3" : None,
+    }
+
+    ## VAE format 0 (extracted from model, could be sd1, sd2, sdxl, sd3).
+    if "first_stage_model.decoder.conv_in.weight" in asd:
+        channels = asd["first_stage_model.decoder.conv_in.weight"].shape[1]
+        if model_type in ("sd1", "sd2", "sdxl"):
+            if channels == 4:
+                for k, v in asd.items():
+                    sd[k] = v
+        elif model_type == "sd3":
+            if channels == 16:
+                for k, v in asd.items():
+                    sd[k] = v
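+
+    # Each CLIP_* table below maps a signature key, used to recognise the
+    # layout of the incoming state dict, to the prefix that key carries in
+    # the source file. The hidden size of the matched tensor distinguishes
+    # the encoder (768 = CLIP-L, 1024 = CLIP-H, 1280 = CLIP-G); keys are then
+    # re-prefixed, converting between layouts where required, to match the
+    # detected model type.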
+
+    ## CLIP-H
+    CLIP_H = {  # key to identify source model : old_prefix
+        'cond_stage_model.model.ln_final.weight' : 'cond_stage_model.model.',
+#       'text_model.encoder.layers.0.layer_norm1.bias' : '',  # would need converting
+    }
+    for CLIP_key in CLIP_H.keys():
+        if CLIP_key in asd and asd[CLIP_key].shape[0] == 1024:
+            new_prefix = prefix_H[model_type]
+            old_prefix = CLIP_H[CLIP_key]
+
+            if new_prefix is not None:
+                for k, v in asd.items():
+                    new_k = k.replace(old_prefix, new_prefix)
+                    sd[new_k] = v
+
+    ## CLIP-G
+    CLIP_G = {  # key to identify source model : old_prefix
+        'conditioner.embedders.1.model.transformer.resblocks.0.ln_1.bias' : 'conditioner.embedders.1.model.',
+        'text_encoders.clip_g.transformer.text_model.encoder.layers.0.layer_norm1.bias' : 'text_encoders.clip_g.',
+        'text_model.encoder.layers.0.layer_norm1.bias' : '',
+        'transformer.resblocks.0.ln_1.bias' : '',
+    }
+    for CLIP_key in CLIP_G.keys():
+        if CLIP_key in asd and asd[CLIP_key].shape[0] == 1280:
+            new_prefix = prefix_G[model_type]
+            old_prefix = CLIP_G[CLIP_key]
+
+            if new_prefix is not None:
+                if "resblocks" not in CLIP_key:  # transformers layout: needs converting to open_clip resblocks layout
+                    def convert_transformers(statedict, prefix_from, prefix_to, number):
+                        keys_to_replace = {
+                            "{}text_model.embeddings.position_embedding.weight" : "{}positional_embedding",
+                            "{}text_model.embeddings.token_embedding.weight"    : "{}token_embedding.weight",
+                            "{}text_model.final_layer_norm.weight"              : "{}ln_final.weight",
+                            "{}text_model.final_layer_norm.bias"                : "{}ln_final.bias",
+                            "text_projection.weight"                            : "{}text_projection",
+                        }
+                        resblock_to_replace = {
+                            "layer_norm1"        : "ln_1",
+                            "layer_norm2"        : "ln_2",
+                            "mlp.fc1"            : "mlp.c_fc",
+                            "mlp.fc2"            : "mlp.c_proj",
+                            "self_attn.out_proj" : "attn.out_proj",
+                        }
+
+                        for x in keys_to_replace:
+                            k = x.format(prefix_from)
+                            statedict[keys_to_replace[x].format(prefix_to)] = statedict.pop(k)
+
+                        for resblock in range(number):
+                            for y in ["weight", "bias"]:
+                                for x in resblock_to_replace:
+                                    k = "{}text_model.encoder.layers.{}.{}.{}".format(prefix_from, resblock, x, y)
+                                    k_to = "{}transformer.resblocks.{}.{}.{}".format(prefix_to, resblock, resblock_to_replace[x], y)
+                                    statedict[k_to] = statedict.pop(k)
+
+                                k_from = "{}text_model.encoder.layers.{}.{}.{}".format(prefix_from, resblock, "self_attn.q_proj", y)
+                                weightsQ = statedict.pop(k_from)
+                                k_from = "{}text_model.encoder.layers.{}.{}.{}".format(prefix_from, resblock, "self_attn.k_proj", y)
+                                weightsK = statedict.pop(k_from)
+                                k_from = "{}text_model.encoder.layers.{}.{}.{}".format(prefix_from, resblock, "self_attn.v_proj", y)
+                                weightsV = statedict.pop(k_from)
+
+                                # open_clip fuses the attention projections: in_proj is Q, K, V stacked on dim 0
+                                k_to = "{}transformer.resblocks.{}.attn.in_proj_{}".format(prefix_to, resblock, y)
+                                statedict[k_to] = torch.cat((weightsQ, weightsK, weightsV))
+                        return statedict
+
+                    asd = convert_transformers(asd, old_prefix, new_prefix, 32)
+                    new_prefix = ""
+
+                if old_prefix == "":
+                    for k, v in asd.items():
+                        new_k = new_prefix + k
+                        sd[new_k] = v
+                else:
+                    for k, v in asd.items():
+                        new_k = k.replace(old_prefix, new_prefix)
+                        sd[new_k] = v
+
+    ## CLIP-L
+    CLIP_L = {  # key to identify source model : old_prefix
+        'cond_stage_model.transformer.text_model.encoder.layers.0.layer_norm1.bias'        : 'cond_stage_model.transformer.',
+        'conditioner.embedders.0.transformer.text_model.encoder.layers.0.layer_norm1.bias' : 'conditioner.embedders.0.transformer.',
+        'text_encoders.clip_l.transformer.text_model.encoder.layers.0.layer_norm1.bias'    : 'text_encoders.clip_l.transformer.',
+        'text_model.encoder.layers.0.layer_norm1.bias' : '',
+        'transformer.resblocks.0.ln_1.bias' : '',
+    }
+
+    for CLIP_key in CLIP_L.keys():
+        if CLIP_key in asd and asd[CLIP_key].shape[0] == 768:
+            new_prefix = prefix_L[model_type]
+            old_prefix = CLIP_L[CLIP_key]
+
+            if new_prefix is not None:
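+                # a resblocks signature means the source is in the fused
+                # open_clip layout; split its in_proj tensors into the
+                # separate q/k/v projections of the transformers layout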
+                if "resblocks" in CLIP_key:  # open_clip layout: needs converting to transformers layout
+                    def transformers_convert(statedict, prefix_from, prefix_to, number):
+                        keys_to_replace = {
+                            "positional_embedding"  : "{}text_model.embeddings.position_embedding.weight",
+                            "token_embedding.weight": "{}text_model.embeddings.token_embedding.weight",
+                            "ln_final.weight"       : "{}text_model.final_layer_norm.weight",
+                            "ln_final.bias"         : "{}text_model.final_layer_norm.bias",
+                        }
+                        resblock_to_replace = {
+                            "ln_1"          : "layer_norm1",
+                            "ln_2"          : "layer_norm2",
+                            "mlp.c_fc"      : "mlp.fc1",
+                            "mlp.c_proj"    : "mlp.fc2",
+                            "attn.out_proj" : "self_attn.out_proj",
+                        }
+
+                        for k in keys_to_replace:
+                            statedict[keys_to_replace[k].format(prefix_to)] = statedict.pop(k)
+
+                        for resblock in range(number):
+                            for y in ["weight", "bias"]:
+                                for x in resblock_to_replace:
+                                    k = "{}transformer.resblocks.{}.{}.{}".format(prefix_from, resblock, x, y)
+                                    k_to = "{}text_model.encoder.layers.{}.{}.{}".format(prefix_to, resblock, resblock_to_replace[x], y)
+                                    statedict[k_to] = statedict.pop(k)
+
+                                # fused in_proj is Q, K, V stacked on dim 0: split it into thirds
+                                k_from = "{}transformer.resblocks.{}.attn.in_proj_{}".format(prefix_from, resblock, y)
+                                weights = statedict.pop(k_from)
+                                shape_from = weights.shape[0] // 3
+                                p = ["self_attn.q_proj", "self_attn.k_proj", "self_attn.v_proj"]
+                                for x in range(3):
+                                    k_to = "{}text_model.encoder.layers.{}.{}.{}".format(prefix_to, resblock, p[x], y)
+                                    statedict[k_to] = weights[shape_from * x : shape_from * (x + 1)]
+                        return statedict
+
+                    asd = transformers_convert(asd, old_prefix, new_prefix, 12)
+                    new_prefix = ""
+
+                if old_prefix == "":
+                    for k, v in asd.items():
+                        new_k = new_prefix + k
+                        sd[new_k] = v
+                else:
+                    for k, v in asd.items():
+                        new_k = k.replace(old_prefix, new_prefix)
+                        sd[new_k] = v
 
     if 'encoder.block.0.layer.0.SelfAttention.k.weight' in asd:
         keys_to_delete = [k for k in sd if k.startswith(f"{text_encoder_key_prefix}t5xxl.")]
@@ -227,9 +423,8 @@ def replace_state_dict(sd, asd, guess):
 
 
 def preprocess_state_dict(sd):
-    if any("double_block" in k for k in sd.keys()):
-        if not any(k.startswith("model.diffusion_model") for k in sd.keys()):
-            sd = {f"model.diffusion_model.{k}": v for k, v in sd.items()}
+    if not any(k.startswith("model.diffusion_model") for k in sd.keys()):
+        sd = {f"model.diffusion_model.{k}": v for k, v in sd.items()}
 
     return sd
 
@@ -243,11 +438,14 @@ def split_state_dict(sd, additional_state_dicts: list = None):
     for asd in additional_state_dicts:
         asd = load_torch_file(asd)
         sd = replace_state_dict(sd, asd, guess)
+        del asd
 
     guess.clip_target = guess.clip_target(sd)
     guess.model_type = guess.model_type(sd)
     guess.ztsnr = 'ztsnr' in sd
 
+    sd = guess.process_vae_state_dict(sd)
+
     state_dict = {
         guess.unet_target: try_filter_state_dict(sd, guess.unet_key_prefix),
         guess.vae_target: try_filter_state_dict(sd, guess.vae_key_prefix)
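
For reference, a minimal, self-contained sketch (not part of the patch; all
names are illustrative) of the in_proj fusion and split that
convert_transformers and transformers_convert perform on each attention block:

    import torch

    def fuse_qkv(q, k, v):
        # transformers keeps separate q/k/v projections;
        # open_clip stacks them into one in_proj tensor, ordered Q, K, V
        return torch.cat((q, k, v))

    def split_qkv(in_proj):
        # invert the fusion: each projection is one third of dim 0
        step = in_proj.shape[0] // 3
        return in_proj[0:step], in_proj[step:2 * step], in_proj[2 * step:3 * step]

    # CLIP-L hidden size is 768 (CLIP-H: 1024, CLIP-G: 1280)
    q, k, v = (torch.randn(768, 768) for _ in range(3))
    fused = fuse_qkv(q, k, v)      # shape (2304, 768), like attn.in_proj_weight
    q2, k2, v2 = split_qkv(fused)  # recovers the self_attn.{q,k,v}_proj weights
    assert all(torch.equal(a, b) for a, b in zip((q, k, v), (q2, k2, v2)))

The round trip is lossless, which is why the patch can re-key a state dict
between the two layouts without altering any tensor values beyond the
concatenation or split itself.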