fix type hints

This commit is contained in:
layerdiffusion 2024-08-15 03:08:25 -07:00
parent 2690b654fd
commit fd0d25ba8a

View File

@ -125,7 +125,11 @@ class __Quant(ABC):
cls.grid = grid.reshape((1, 1, *cls.grid_shape))
@classmethod
def dequantize_pytorch(cls, data: torch.Tensor, original_shape=torch.float16) -> torch.Tensor:
def quantize_pytorch(cls, data: torch.Tensor) -> torch.Tensor:
return cls.quantize_blocks_pytorch(data)
@classmethod
def dequantize_pytorch(cls, data: torch.Tensor, original_shape) -> torch.Tensor:
# (c) City96 || Apache-2.0 (apache.org/licenses/LICENSE-2.0)
block_size, type_size = GGML_QUANT_SIZES[cls.qtype]
rows = data.reshape((-1, data.shape[-1])).view(torch.uint8)
@ -139,6 +143,11 @@ class __Quant(ABC):
def dequantize_blocks_pytorch(cls, blocks, block_size, type_size) -> torch.Tensor:
raise NotImplementedError
@classmethod
@abstractmethod
def quantize_blocks_pytorch(cls, blocks) -> torch.Tensor:
raise NotImplementedError
@classmethod
@abstractmethod
def quantize_blocks(cls, blocks: np.ndarray) -> np.ndarray: