From d6b1d188b989bc458551486be3b135994410aac8 Mon Sep 17 00:00:00 2001
From: Mathieu Croquelois
Date: Wed, 25 Jun 2025 01:19:27 +0100
Subject: [PATCH] Add support for fp8 scaled (#2946)

---
 backend/operations.py | 6 ++++++
 1 file changed, 6 insertions(+)

diff --git a/backend/operations.py b/backend/operations.py
index 5200093b..7760756a 100644
--- a/backend/operations.py
+++ b/backend/operations.py
@@ -12,6 +12,7 @@ stash = {}
 
 
 def get_weight_and_bias(layer, weight_args=None, bias_args=None, weight_fn=None, bias_fn=None):
+    scale_weight = getattr(layer, 'scale_weight', None)
     patches = getattr(layer, 'forge_online_loras', None)
     weight_patches, bias_patches = None, None
 
@@ -32,6 +33,8 @@ def get_weight_and_bias(layer, weight_args=None, bias_args=None, weight_fn=None
             weight = weight_fn(weight)
         if weight_args is not None:
             weight = weight.to(**weight_args)
+        if scale_weight is not None:
+            weight = weight*scale_weight.to(device=weight.device, dtype=weight.dtype)
         if weight_patches is not None:
             weight = merge_lora_to_weight(patches=weight_patches, weight=weight, key="online weight lora", computation_dtype=weight.dtype)
 
@@ -127,6 +130,7 @@ class ForgeOperations:
             self.out_features = out_features
             self.dummy = torch.nn.Parameter(torch.empty(1, device=current_device, dtype=current_dtype))
             self.weight = None
+            self.scale_weight = None
             self.bias = None
             self.parameters_manual_cast = current_manual_cast_enabled
 
@@ -134,6 +138,8 @@ class ForgeOperations:
             if hasattr(self, 'dummy'):
                 if prefix + 'weight' in state_dict:
                     self.weight = torch.nn.Parameter(state_dict[prefix + 'weight'].to(self.dummy))
+                if prefix + 'scale_weight' in state_dict:
+                    self.scale_weight = torch.nn.Parameter(state_dict[prefix + 'scale_weight'])
                 if prefix + 'bias' in state_dict:
                     self.bias = torch.nn.Parameter(state_dict[prefix + 'bias'].to(self.dummy))
                 del self.dummy
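
For reference, a minimal standalone sketch of the idea behind this patch: the checkpoint stores an fp8 weight together with a scale_weight tensor, and the weight is dequantized at compute time by casting it up and multiplying by the scale, mirroring the added "weight*scale_weight" line in get_weight_and_bias. The module name ScaledFP8Linear and its layout are hypothetical illustrations, not part of Forge's code, and PyTorch >= 2.1 is assumed for float8_e4m3fn.

# Illustrative sketch only; ScaledFP8Linear is a hypothetical stand-in, not a Forge API.
import torch

class ScaledFP8Linear(torch.nn.Module):
    def __init__(self, in_features, out_features):
        super().__init__()
        # Weight stored in fp8 (e4m3) alongside a single float32 scale factor,
        # analogous to the 'weight' / 'scale_weight' pair read from the state dict.
        self.weight = torch.nn.Parameter(
            torch.zeros(out_features, in_features, dtype=torch.float8_e4m3fn),
            requires_grad=False)
        self.scale_weight = torch.nn.Parameter(torch.ones(1), requires_grad=False)

    def forward(self, x):
        # Dequantize on the fly: cast the fp8 weight to the activation dtype,
        # then apply the scale, as in the patched get_weight_and_bias().
        w = self.weight.to(device=x.device, dtype=x.dtype)
        w = w * self.scale_weight.to(device=w.device, dtype=w.dtype)
        return torch.nn.functional.linear(x, w)

layer = ScaledFP8Linear(8, 4)
x = torch.randn(2, 8)
y = layer(x)  # fp8 weight is scaled back to the compute dtype before the matmul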