Mirror of https://github.com/lllyasviel/stable-diffusion-webui-forge.git (synced 2025-12-27 21:26:07 +00:00)
Add support for fp8 scaled (#2946)
parent 715c24b0e2
commit d6b1d188b9
@@ -12,6 +12,7 @@ stash = {}
 def get_weight_and_bias(layer, weight_args=None, bias_args=None, weight_fn=None, bias_fn=None):
+    scale_weight = getattr(layer, 'scale_weight', None)
     patches = getattr(layer, 'forge_online_loras', None)
     weight_patches, bias_patches = None, None
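For context (not part of the commit): because the lookup uses getattr with a None default, layers that never had a scale_weight attached fall through to the old code path unchanged. A minimal illustration with a plain torch Linear:

import torch

layer = torch.nn.Linear(4, 4)                        # ordinary layer, no fp8 scaling attached
scale_weight = getattr(layer, 'scale_weight', None)
assert scale_weight is None                          # so the scaling branch below is skipped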
@@ -32,6 +33,8 @@ def get_weight_and_bias(layer, weight_args=None, bias_args=None, weight_fn=None, bias_fn=None):
             weight = weight_fn(weight)
         if weight_args is not None:
             weight = weight.to(**weight_args)
+        if scale_weight is not None:
+            weight = weight * scale_weight.to(device=weight.device, dtype=weight.dtype)
         if weight_patches is not None:
             weight = merge_lora_to_weight(patches=weight_patches, weight=weight, key="online weight lora", computation_dtype=weight.dtype)
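This hunk is the core of the fp8-scaled path: the checkpoint stores a low-precision weight plus a separate per-tensor scale, and dequantization is a cast followed by a multiply. A minimal sketch of the same idea outside the repo's code, assuming a PyTorch build with float8 dtypes (2.1 or newer); the names, shapes, and scale value are made up:

import torch

compute_dtype = torch.bfloat16                             # assumed compute dtype
w_fp8 = torch.randn(64, 64).to(torch.float8_e4m3fn)        # stored fp8 weight
scale_weight = torch.tensor(0.031)                         # stored per-tensor scale

# Same order as the hunk above: cast the weight to the compute dtype first,
# then multiply by the scale (moved to the weight's device and dtype).
w = w_fp8.to(compute_dtype)
w = w * scale_weight.to(device=w.device, dtype=w.dtype)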
@@ -127,6 +130,7 @@ class ForgeOperations:
             self.out_features = out_features
             self.dummy = torch.nn.Parameter(torch.empty(1, device=current_device, dtype=current_dtype))
             self.weight = None
+            self.scale_weight = None
             self.bias = None
             self.parameters_manual_cast = current_manual_cast_enabled
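A side note on the surrounding pattern (my reading, not part of the diff): the dummy parameter only exists to remember the target device and dtype until real tensors arrive from the state dict, and scale_weight starts as None so non-fp8 checkpoints behave exactly as before. A standalone sketch of that lazy-init idea, with hypothetical names rather than the repo's ForgeOperations code:

import torch

class LazyLinear(torch.nn.Module):
    def __init__(self, device=None, dtype=None):
        super().__init__()
        # Placeholder that carries the target device/dtype until loading time.
        self.dummy = torch.nn.Parameter(torch.empty(1, device=device, dtype=dtype))
        self.weight = None
        self.scale_weight = None
        self.bias = None

    def load(self, state_dict, prefix):
        # .to(self.dummy) moves/casts the incoming tensor to the dummy's device and dtype.
        self.weight = torch.nn.Parameter(state_dict[prefix + 'weight'].to(self.dummy))
        if prefix + 'scale_weight' in state_dict:
            self.scale_weight = torch.nn.Parameter(state_dict[prefix + 'scale_weight'])
        del self.dummy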
@@ -134,6 +138,8 @@
             if hasattr(self, 'dummy'):
                 if prefix + 'weight' in state_dict:
                     self.weight = torch.nn.Parameter(state_dict[prefix + 'weight'].to(self.dummy))
+                if prefix + 'scale_weight' in state_dict:
+                    self.scale_weight = torch.nn.Parameter(state_dict[prefix + 'scale_weight'])
                 if prefix + 'bias' in state_dict:
                     self.bias = torch.nn.Parameter(state_dict[prefix + 'bias'].to(self.dummy))
                 del self.dummy
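Worth noting: unlike weight and bias, scale_weight is not cast through self.dummy here; it keeps its stored dtype and is only moved next to the weight at compute time in get_weight_and_bias. As a hedged illustration of the checkpoint layout this loader expects, here is one way such a scaled-fp8 entry could be built for a single layer (the 'proj.' prefix, shapes, and quantization recipe are made up; requires a PyTorch build with float8 dtypes):

import torch

fp32_weight = torch.randn(64, 64)

# One common per-tensor recipe: scale so the largest magnitude fits the fp8 e4m3 range.
scale = fp32_weight.abs().max() / torch.finfo(torch.float8_e4m3fn).max

state_dict = {
    'proj.weight': (fp32_weight / scale).to(torch.float8_e4m3fn),  # fp8 payload
    'proj.scale_weight': scale.to(torch.float32),                  # per-tensor scale
    'proj.bias': torch.zeros(64),
}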