Add support for fp8 scaled (#2946)

This commit is contained in:
Mathieu Croquelois 2025-06-25 01:19:27 +01:00 committed by GitHub
parent 715c24b0e2
commit d6b1d188b9
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -12,6 +12,7 @@ stash = {}
def get_weight_and_bias(layer, weight_args=None, bias_args=None, weight_fn=None, bias_fn=None):
scale_weight = getattr(layer, 'scale_weight', None)
patches = getattr(layer, 'forge_online_loras', None)
weight_patches, bias_patches = None, None
@ -32,6 +33,8 @@ def get_weight_and_bias(layer, weight_args=None, bias_args=None, weight_fn=None,
weight = weight_fn(weight)
if weight_args is not None:
weight = weight.to(**weight_args)
if scale_weight is not None:
weight = weight*scale_weight.to(device=weight.device, dtype=weight.dtype)
if weight_patches is not None:
weight = merge_lora_to_weight(patches=weight_patches, weight=weight, key="online weight lora", computation_dtype=weight.dtype)
@ -127,6 +130,7 @@ class ForgeOperations:
self.out_features = out_features
self.dummy = torch.nn.Parameter(torch.empty(1, device=current_device, dtype=current_dtype))
self.weight = None
self.scale_weight = None
self.bias = None
self.parameters_manual_cast = current_manual_cast_enabled
@ -134,6 +138,8 @@ class ForgeOperations:
if hasattr(self, 'dummy'):
if prefix + 'weight' in state_dict:
self.weight = torch.nn.Parameter(state_dict[prefix + 'weight'].to(self.dummy))
if prefix + 'scale_weight' in state_dict:
self.scale_weight = torch.nn.Parameter(state_dict[prefix + 'scale_weight'])
if prefix + 'bias' in state_dict:
self.bias = torch.nn.Parameter(state_dict[prefix + 'bias'].to(self.dummy))
del self.dummy