fix old version of pytorch

This commit is contained in:
layerdiffusion 2024-08-26 06:51:48 -07:00
parent f22b80ef94
commit acf99dd74e

View File

@ -8,7 +8,7 @@ def native_unpack_4x4bits_in_1x16bits_to_4x8bits_in_1x32bits(x):
x = x.view(torch.uint8).view(x.size(0), -1)
unpacked = torch.stack([x & 15, x >> 4], dim=-1)
reshaped = unpacked.view(x.size(0), -1)
reshaped = reshaped.to(torch.int8) - 8
reshaped = reshaped.view(torch.int8) - 8
return reshaped.view(torch.int32)
@ -19,11 +19,23 @@ def native_unpack_4x4bits_in_1x16bits_to_4x8bits_in_1x32bits_u(x):
return reshaped.view(torch.int32)
native_4bits_lookup_table = native_unpack_4x4bits_in_1x16bits_to_4x8bits_in_1x32bits(torch.arange(start=0, end=256*256, dtype=torch.long).to(torch.uint16))[:, 0]
native_4bits_lookup_table_u = native_unpack_4x4bits_in_1x16bits_to_4x8bits_in_1x32bits_u(torch.arange(start=0, end=256*256, dtype=torch.long).to(torch.uint16))[:, 0]
disable_all_optimizations = False
if not hasattr(torch, 'uint16'):
disable_all_optimizations = True
if disable_all_optimizations:
print('You are using PyTorch below version 2.3. Some optimizations will be disabled.')
if not disable_all_optimizations:
native_4bits_lookup_table = native_unpack_4x4bits_in_1x16bits_to_4x8bits_in_1x32bits(torch.arange(start=0, end=256*256, dtype=torch.long).to(torch.uint16))[:, 0]
native_4bits_lookup_table_u = native_unpack_4x4bits_in_1x16bits_to_4x8bits_in_1x32bits_u(torch.arange(start=0, end=256*256, dtype=torch.long).to(torch.uint16))[:, 0]
def quick_unpack_4bits(x):
if disable_all_optimizations:
return torch.stack([x & 15, x >> 4], dim=-1).view(x.size(0), -1).view(torch.int8) - 8
global native_4bits_lookup_table
s0 = x.size(0)
@ -40,6 +52,9 @@ def quick_unpack_4bits(x):
def quick_unpack_4bits_u(x):
if disable_all_optimizations:
return torch.stack([x & 15, x >> 4], dim=-1).view(x.size(0), -1)
global native_4bits_lookup_table_u
s0 = x.size(0)