mirror of
https://github.com/lllyasviel/stable-diffusion-webui-forge.git
synced 2025-12-28 13:42:02 +00:00
Make Q4_K_S as fast as Q4_0
by baking the layer when model load
This commit is contained in:
parent
868f662eb6
commit
e60bb1c96f
29
packages_3rdparty/gguf/quants.py
vendored
29
packages_3rdparty/gguf/quants.py
vendored
@ -804,9 +804,10 @@ class Q4_K(__Quant, qtype=GGMLQuantizationType.Q4_K):
|
||||
return (d * qs - dm).reshape((n_blocks, QK_K))
|
||||
|
||||
@classmethod
|
||||
def dequantize_blocks_pytorch(cls, blocks, block_size, type_size, parent) -> torch.Tensor:
|
||||
# (c) City96 || Apache-2.0 (apache.org/licenses/LICENSE-2.0)
|
||||
QK_K = 256
|
||||
def bake_layer_weight(cls, layer, weight): # Only compute one time when model load
|
||||
# Copyright Forge 2024
|
||||
|
||||
blocks = weight.data
|
||||
K_SCALE_SIZE = 12
|
||||
n_blocks = blocks.shape[0]
|
||||
d, dmin, scales, qs = quick_split(blocks, [2, 2, K_SCALE_SIZE])
|
||||
@ -814,7 +815,27 @@ class Q4_K(__Quant, qtype=GGMLQuantizationType.Q4_K):
|
||||
dmin = dmin.view(torch.float16).to(cls.computation_dtype)
|
||||
sc, m = Q4_K.get_scale_min_pytorch(scales)
|
||||
d = (d * sc).reshape((n_blocks, -1, 1))
|
||||
dm = (dmin * m).reshape((n_blocks, -1, 1))
|
||||
dm = (dmin * m).reshape((n_blocks, -1, 1)).to(cls.computation_dtype)
|
||||
|
||||
weight.data = qs
|
||||
layer.quant_state_0 = torch.nn.Parameter(d, requires_grad=False)
|
||||
layer.quant_state_1 = torch.nn.Parameter(dm, requires_grad=False)
|
||||
return
|
||||
|
||||
@classmethod
|
||||
def dequantize_blocks_pytorch(cls, blocks, block_size, type_size, parent) -> torch.Tensor:
|
||||
# Compute in each diffusion iteration
|
||||
|
||||
n_blocks = blocks.shape[0]
|
||||
|
||||
d, dm, qs = parent.quant_state_0, parent.quant_state_1, blocks
|
||||
|
||||
if d.device != qs.device:
|
||||
d = d.to(device=qs.device)
|
||||
|
||||
if dm.device != qs.device:
|
||||
dm = dm.to(device=qs.device)
|
||||
|
||||
qs = qs.reshape((n_blocks, -1, 1, 32)) >> torch.tensor([0, 4], device=d.device, dtype=torch.uint8).reshape((1, 1, 2, 1))
|
||||
qs = (qs & 0x0F).reshape((n_blocks, -1, 32))
|
||||
return (d * qs - dm).reshape((n_blocks, QK_K))
|
||||
|
||||
Loading…
Reference in New Issue
Block a user