Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add attention processor #961

Merged
merged 1 commit into from
Dec 3, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 43 additions & 4 deletions library/original_unet.py
Original file line number Diff line number Diff line change
Expand Up @@ -569,6 +569,9 @@ def __init__(
self.use_memory_efficient_attention_mem_eff = False
self.use_sdpa = False

# Attention processor
self.processor = None

def set_use_memory_efficient_attention(self, xformers, mem_eff):
self.use_memory_efficient_attention_xformers = xformers
self.use_memory_efficient_attention_mem_eff = mem_eff
Expand All @@ -590,7 +593,28 @@ def reshape_batch_dim_to_heads(self, tensor):
tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
return tensor

def forward(self, hidden_states, context=None, mask=None):
def set_processor(self):
return self.processor

def get_processor(self):
return self.processor

def forward(self, hidden_states, context=None, mask=None, **kwargs):
if self.processor is not None:
(
hidden_states,
encoder_hidden_states,
attention_mask,
) = translate_attention_names_from_diffusers(
hidden_states=hidden_states, context=context, mask=mask, **kwargs
)
return self.processor(
attn=self,
hidden_states=hidden_states,
encoder_hidden_states=context,
attention_mask=mask,
**kwargs
)
if self.use_memory_efficient_attention_xformers:
return self.forward_memory_efficient_xformers(hidden_states, context, mask)
if self.use_memory_efficient_attention_mem_eff:
Expand Down Expand Up @@ -703,6 +727,21 @@ def forward_sdpa(self, x, context=None, mask=None):
out = self.to_out[0](out)
return out

def translate_attention_names_from_diffusers(
hidden_states: torch.FloatTensor,
context: Optional[torch.FloatTensor] = None,
mask: Optional[torch.FloatTensor] = None,
# HF naming
encoder_hidden_states: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None
):
# translate from hugging face diffusers
context = context if context is not None else encoder_hidden_states

# translate from hugging face diffusers
mask = mask if mask is not None else attention_mask

return hidden_states, context, mask

# feedforward
class GEGLU(nn.Module):
Expand Down Expand Up @@ -1331,7 +1370,7 @@ def __init__(
self.out_channels = OUT_CHANNELS

self.sample_size = sample_size
self.prepare_config()
self.prepare_config(sample_size=sample_size)

# state_dictの書式が変わるのでmoduleの持ち方は変えられない

Expand Down Expand Up @@ -1418,8 +1457,8 @@ def __init__(
self.conv_out = nn.Conv2d(BLOCK_OUT_CHANNELS[0], OUT_CHANNELS, kernel_size=3, padding=1)

# region diffusers compatibility
def prepare_config(self):
self.config = SimpleNamespace()
def prepare_config(self, *args, **kwargs):
self.config = SimpleNamespace(**kwargs)

@property
def dtype(self) -> torch.dtype:
Expand Down