Disable xFormers attention for the T5 text encoder

This commit is contained in:
layerdiffusion 2024-08-15 00:55:49 -07:00
parent d336597fa5
commit ce16d34d03
2 changed files with 3 additions and 13 deletions

View File

@ -284,18 +284,8 @@ def attention_xformers(q, k, v, heads, mask=None, attn_precision=None, skip_resh
b, _, dim_head = q.shape
dim_head //= heads
disabled_xformers = False
if BROKEN_XFORMERS:
if b * heads > 65535:
disabled_xformers = True
if not disabled_xformers:
if torch.jit.is_tracing() or torch.jit.is_scripting():
disabled_xformers = True
if disabled_xformers:
return attention_pytorch(q, k, v, heads, mask)
if BROKEN_XFORMERS and b * heads > 65535:
return attention_pytorch(q, k, v, heads, mask, skip_reshape=skip_reshape)
if skip_reshape:
q, k, v = map(

View File

@ -1,7 +1,7 @@
import torch
import math
from backend.attention import attention_function
from backend.attention import attention_pytorch as attention_function
activations = {