Update sub_quadratic_attention.py

layerdiffusion 2024-08-03 15:19:45 -07:00
parent 4add428e25
commit c7b1789892

@@ -16,55 +16,60 @@ from torch.utils.checkpoint import checkpoint
import math

try:
    from typing import Optional, NamedTuple, List, Protocol
except ImportError:
    from typing import Optional, NamedTuple, List
    from typing_extensions import Protocol

from torch import Tensor
from typing import List

-from ldm_patched.modules import model_management
+from backend import memory_management


def dynamic_slice(
    x: Tensor,
    starts: List[int],
    sizes: List[int],
) -> Tensor:
    slicing = [slice(start, start + size) for start, size in zip(starts, sizes)]
    return x[slicing]


class AttnChunk(NamedTuple):
    exp_values: Tensor
    exp_weights_sum: Tensor
    max_score: Tensor


class SummarizeChunk(Protocol):
    @staticmethod
    def __call__(
        query: Tensor,
        key_t: Tensor,
        value: Tensor,
    ) -> AttnChunk: ...


class ComputeQueryChunkAttn(Protocol):
    @staticmethod
    def __call__(
        query: Tensor,
        key_t: Tensor,
        value: Tensor,
    ) -> Tensor: ...


def _summarize_chunk(
    query: Tensor,
    key_t: Tensor,
    value: Tensor,
    scale: float,
    upcast_attention: bool,
    mask,
) -> AttnChunk:
    if upcast_attention:
-        with torch.autocast(enabled=False, device_type = 'cuda'):
+        with torch.autocast(enabled=False, device_type='cuda'):
            query = query.float()
            key_t = key_t.float()
            attn_weights = torch.baddbmm(
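
For context on the slicing helper in this hunk: dynamic_slice builds one Python slice per dimension and indexes the tensor with them, which is equivalent to chaining Tensor.narrow calls. A minimal standalone sketch of that behaviour (tensor shape and start/size values are arbitrary, chosen only for illustration):

    import torch

    x = torch.arange(24).reshape(2, 3, 4)
    starts, sizes = [0, 1, 0], [2, 2, 3]

    # one slice object per dimension, indexed in a single call (returns a view, no copy)
    slicing = tuple(slice(start, start + size) for start, size in zip(starts, sizes))
    chunk = x[slicing]

    assert chunk.shape == (2, 2, 3)
    assert torch.equal(chunk, x.narrow(0, 0, 2).narrow(1, 1, 2).narrow(2, 0, 3))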
@@ -93,13 +98,14 @@ def _summarize_chunk
    max_score = max_score.squeeze(-1)
    return AttnChunk(exp_values, exp_weights.sum(dim=-1), max_score)


def _query_chunk_attention(
    query: Tensor,
    key_t: Tensor,
    value: Tensor,
    summarize_chunk: SummarizeChunk,
    kv_chunk_size: int,
    mask,
) -> Tensor:
    batch_x_heads, k_channels_per_head, k_tokens = key_t.shape
    _, _, v_channels_per_head = value.shape
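
The AttnChunk triple returned above (exp-weighted values, exp-weight sums, and the per-chunk max score) is what lets _query_chunk_attention walk the key/value sequence one chunk at a time without materializing the full attention matrix: each chunk is later rescaled onto a common maximum before the partial sums are combined, the usual log-sum-exp trick. A rough sketch of that combine step, with invented variable names rather than the ones used further down in this file:

    import torch

    def combine_attn_chunks(chunk_values, chunk_weights, chunk_max):
        # chunk_values:  (n_chunks, batch*heads, q_tokens, v_channels_per_head)
        # chunk_weights: (n_chunks, batch*heads, q_tokens)
        # chunk_max:     (n_chunks, batch*heads, q_tokens)
        global_max, _ = torch.max(chunk_max, dim=0, keepdim=True)
        scale = torch.exp(chunk_max - global_max)           # re-base every chunk onto the global max
        values = (chunk_values * scale.unsqueeze(-1)).sum(dim=0)
        weights = (chunk_weights * scale).sum(dim=0)
        return values / weights.unsqueeze(-1)               # softmax-normalized attention output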
@@ -116,7 +122,7 @@ def _query_chunk_attention
            (batch_x_heads, kv_chunk_size, v_channels_per_head)
        )
        if mask is not None:
-            mask = mask[:,:,chunk_idx:chunk_idx + kv_chunk_size]
+            mask = mask[:, :, chunk_idx:chunk_idx + kv_chunk_size]
        return summarize_chunk(query, key_chunk, value_chunk, mask=mask)
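
The mask handling here simply narrows the attention mask along the key-token axis so each key/value chunk only sees its own columns. A toy illustration, assuming a (batch*heads, q_tokens, k_tokens) mask layout (consistent with the [:, :, ...] indexing above, but an assumption nonetheless):

    import torch

    mask = torch.zeros(2, 16, 32, dtype=torch.bool)     # (batch*heads, q_tokens, k_tokens), assumed
    kv_chunk_size = 8

    for chunk_idx in range(0, mask.shape[-1], kv_chunk_size):
        mask_chunk = mask[:, :, chunk_idx:chunk_idx + kv_chunk_size]
        assert mask_chunk.shape == (2, 16, kv_chunk_size)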
@@ -135,17 +141,18 @@ def _query_chunk_attention
    all_weights = torch.unsqueeze(chunk_weights, -1).sum(dim=0)
    return all_values / all_weights


# TODO: refactor CrossAttention#get_attention_scores to share code with this
def _get_attention_scores_no_kv_chunking(
    query: Tensor,
    key_t: Tensor,
    value: Tensor,
    scale: float,
    upcast_attention: bool,
    mask,
) -> Tensor:
    if upcast_attention:
-        with torch.autocast(enabled=False, device_type = 'cuda'):
+        with torch.autocast(enabled=False, device_type='cuda'):
            query = query.float()
            key_t = key_t.float()
            attn_scores = torch.baddbmm(
@@ -169,7 +176,7 @@ def _get_attention_scores_no_kv_chunking
    try:
        attn_probs = attn_scores.softmax(dim=-1)
        del attn_scores
-    except model_management.OOM_EXCEPTION:
+    except memory_management.OOM_EXCEPTION:
        print("ran out of memory while running softmax in _get_attention_scores_no_kv_chunking, trying slower in place softmax instead")
        attn_scores -= attn_scores.max(dim=-1, keepdim=True).values
        torch.exp(attn_scores, out=attn_scores)
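
The fallback path above rebuilds the softmax in place rather than allocating a second tensor of the same size: subtract the row max for stability, exponentiate into the same storage, then normalize by the row sums. A small standalone check that this reproduces softmax(dim=-1) (shapes are arbitrary):

    import torch

    scores = torch.randn(2, 8, 8)
    reference = scores.softmax(dim=-1)

    # in-place variant: max-subtraction for numerical stability, then exp and normalize
    inplace = scores.clone()
    inplace -= inplace.max(dim=-1, keepdim=True).values
    torch.exp(inplace, out=inplace)
    inplace /= inplace.sum(dim=-1, keepdim=True)

    assert torch.allclose(inplace, reference, atol=1e-6)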
@@ -180,20 +187,22 @@ def _get_attention_scores_no_kv_chunking
    hidden_states_slice = torch.bmm(attn_probs.to(value.dtype), value)
    return hidden_states_slice


class ScannedChunk(NamedTuple):
    chunk_idx: int
    attn_chunk: AttnChunk


def efficient_dot_product_attention(
    query: Tensor,
    key_t: Tensor,
    value: Tensor,
    query_chunk_size=1024,
    kv_chunk_size: Optional[int] = None,
    kv_chunk_size_min: Optional[int] = None,
    use_checkpoint=True,
    upcast_attention=False,
-    mask = None,
+    mask=None,
):
    """Computes efficient dot-product attention given query, transposed key, and value.
    This is efficient version of attention presented in
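
For orientation, a hypothetical call to this entry point could look like the sketch below. The (batch * heads, tokens, channels_per_head) layout and the pre-transposed key are assumptions inferred from the parameter names (key_t, query_chunk_size), not something this hunk states explicitly:

    import torch

    # assumed layout: (batch * heads, tokens, channels_per_head)
    q = torch.randn(2, 4096, 64)
    k = torch.randn(2, 4096, 64)
    v = torch.randn(2, 4096, 64)

    out = efficient_dot_product_attention(
        q,
        k.transpose(1, 2),      # key_t: key passed already transposed to (batch * heads, channels, tokens)
        v,
        query_chunk_size=1024,
        kv_chunk_size=None,     # None defers to the function's internal chunk-size heuristic
    )
    print(out.shape)            # expected: torch.Size([2, 4096, 64])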
@@ -234,7 +243,7 @@ efficient_dot_product_attention
        if mask is None:
            return None
        chunk = min(query_chunk_size, q_tokens)
-        return mask[:,chunk_idx:chunk_idx + chunk]
+        return mask[:, chunk_idx:chunk_idx + chunk]

    summarize_chunk: SummarizeChunk = partial(_summarize_chunk, scale=scale, upcast_attention=upcast_attention)
    summarize_chunk: SummarizeChunk = partial(checkpoint, summarize_chunk) if use_checkpoint else summarize_chunk
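
The partial(checkpoint, summarize_chunk) wrapping above runs each chunk summary under gradient checkpointing, so its intermediate activations are recomputed during the backward pass instead of being kept alive. A minimal illustration of the same pattern with a throwaway function (all names here are invented for the example; the use_reentrant keyword assumes a reasonably recent PyTorch):

    import torch
    from functools import partial
    from torch.utils.checkpoint import checkpoint

    def scaled_sum(a, b):
        return (a * b).sum()

    checkpointed = partial(checkpoint, scaled_sum)

    a = torch.randn(16, requires_grad=True)
    b = torch.randn(16, requires_grad=True)

    # forward discards intermediates; backward re-runs scaled_sum to recover them
    loss = checkpointed(a, b, use_reentrant=False)
    loss.backward()
    print(a.grad.shape)     # torch.Size([16])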
@@ -259,7 +268,7 @@
            value=value,
            mask=mask,
        )

    # TODO: maybe we should use torch.empty_like(query) to allocate storage in-advance,
    # and pass slices to be mutated, instead of torch.cat()ing the returned slices
    res = torch.cat([