Comment out hack for now
Daniel Walmsley committed Jul 8, 2024
1 parent 61ec432 commit e8663dd
Showing 1 changed file with 35 additions and 30 deletions.
65 changes: 35 additions & 30 deletions TTS/tts/layers/xtts/stream_generator.py
@@ -195,36 +195,41 @@ def generate( # noqa: PLR0911
             generation_config.pad_token_id,
             generation_config.eos_token_id,
         )
-        eos_token_tensor = (
-            torch.tensor([generation_config.eos_token_id], device=inputs_tensor.device)
-            if generation_config.eos_token_id is not None
-            else None
-        )
-
-        # hack to produce attention mask for mps devices since transformers bails but pytorch supports torch.isin on mps now
-        # for this to work, you must run with PYTORCH_ENABLE_MPS_FALLBACK=1 and call model.to(mps_device) on the XttsModel
-        if inputs_tensor.device.type == "mps":
-            default_attention_mask = torch.ones(inputs_tensor.shape[:2], dtype=torch.long, device=inputs_tensor.device)
-
-            is_pad_token_in_inputs = (pad_token_tensor is not None) and (
-                custom_isin(elements=inputs_tensor, test_elements=pad_token_tensor).any()
-            )
-            is_pad_token_not_equal_to_eos_token_id = (eos_token_tensor is None) or ~(
-                custom_isin(elements=eos_token_tensor, test_elements=pad_token_tensor).any()
-            )
-            can_infer_attention_mask = is_pad_token_in_inputs * is_pad_token_not_equal_to_eos_token_id
-            attention_mask_from_padding = inputs_tensor.ne(pad_token_tensor).long()
-
-            model_kwargs["attention_mask"] = (
-                attention_mask_from_padding * can_infer_attention_mask
-                + default_attention_mask * ~can_infer_attention_mask
-            )
-        else:
-            model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation(
-                inputs_tensor,
-                pad_token_tensor,
-                eos_token_tensor,
-            )
+        # pad_token_tensor = (
+        #     torch.tensor([generation_config.pad_token_id], device=inputs_tensor.device)
+        #     if generation_config.pad_token_id is not None
+        #     else None
+        # )
+        # eos_token_tensor = (
+        #     torch.tensor([generation_config.eos_token_id], device=inputs_tensor.device)
+        #     if generation_config.eos_token_id is not None
+        #     else None
+        # )
+
+        # # hack to produce attention mask for mps devices since transformers bails but pytorch supports torch.isin on mps now
+        # # for this to work, you must run with PYTORCH_ENABLE_MPS_FALLBACK=1 and call model.to(mps_device) on the XttsModel
+        # if inputs_tensor.device.type == "mps":
+        #     default_attention_mask = torch.ones(inputs_tensor.shape[:2], dtype=torch.long, device=inputs_tensor.device)
+
+        #     is_pad_token_in_inputs = (pad_token_tensor is not None) and (
+        #         custom_isin(elements=inputs_tensor, test_elements=pad_token_tensor).any()
+        #     )
+        #     is_pad_token_not_equal_to_eos_token_id = (eos_token_tensor is None) or ~(
+        #         custom_isin(elements=eos_token_tensor, test_elements=pad_token_tensor).any()
+        #     )
+        #     can_infer_attention_mask = is_pad_token_in_inputs * is_pad_token_not_equal_to_eos_token_id
+        #     attention_mask_from_padding = inputs_tensor.ne(pad_token_tensor).long()
+
+        #     model_kwargs["attention_mask"] = (
+        #         attention_mask_from_padding * can_infer_attention_mask
+        #         + default_attention_mask * ~can_infer_attention_mask
+        #     )
+        # else:
+        #     model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation(
+        #         inputs_tensor,
+        #         pad_token_tensor,
+        #         eos_token_tensor,
+        #     )

         # decoder-only models should use left-padding for generation
         if not self.config.is_encoder_decoder:

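The hack that this commit disables reimplements the attention-mask inference that _prepare_attention_mask_for_generation normally performs, behind a check for the MPS backend. Below is a standalone sketch of that arithmetic on toy values (illustrative only, and substituting torch.isin for the file's custom_isin helper): the mask is derived from padding positions only when a pad token is present in the inputs and is distinct from the EOS token; otherwise an all-ones mask is used.

import torch

# Toy inputs: one left-padded sequence; pad id 0, eos id 2 (example values only).
inputs_tensor = torch.tensor([[0, 0, 5, 7, 2]])
pad_token_tensor = torch.tensor([0])
eos_token_tensor = torch.tensor([2])

default_attention_mask = torch.ones(inputs_tensor.shape[:2], dtype=torch.long)

# The mask can only be inferred from padding if the pad token occurs in the
# inputs and is not the same id as the EOS token.
is_pad_token_in_inputs = (pad_token_tensor is not None) and (
    torch.isin(inputs_tensor, pad_token_tensor).any()
)
is_pad_token_not_equal_to_eos_token_id = (eos_token_tensor is None) or ~(
    torch.isin(eos_token_tensor, pad_token_tensor).any()
)
can_infer_attention_mask = is_pad_token_in_inputs * is_pad_token_not_equal_to_eos_token_id

attention_mask_from_padding = inputs_tensor.ne(pad_token_tensor).long()
attention_mask = (
    attention_mask_from_padding * can_infer_attention_mask
    + default_attention_mask * ~can_infer_attention_mask
)
print(attention_mask)  # tensor([[0, 0, 1, 1, 1]])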
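The commented-out comment also records what the hack needed at runtime: the process must run with PYTORCH_ENABLE_MPS_FALLBACK=1 and the XTTS model must be moved to the MPS device. A rough usage sketch of that setup, assuming the standard high-level TTS API and the XTTS v2 model name (the sketch is not part of this commit):

import os

# Set before torch dispatches any MPS kernels; exporting it before the process
# starts works as well (this is the requirement named in the hack's comment).
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"

import torch
from TTS.api import TTS  # assumed entry point; this commit only touches stream_generator.py

if torch.backends.mps.is_available():
    mps_device = torch.device("mps")
    tts = TTS("tts_models/multilingual/multi-dataset/xtts_v2")
    tts.to(mps_device)  # "call model.to(mps_device) on the XttsModel"
    tts.tts_to_file(
        text="Testing XTTS streaming on Apple Silicon.",
        speaker_wav="reference.wav",  # placeholder path to a short reference clip
        language="en",
        file_path="out.wav",
    )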