Commit

chore(format): run black on main (#504)
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
github-actions[bot] authored Jul 1, 2024
1 parent b330df5 commit e628c86
Showing 1 changed file with 8 additions and 10 deletions.
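
For context on what this automated commit does: black rewrites Python source into its canonical style, and the two kinds of rewrites visible in the diff below are (a) collapsing a split call onto one line when it fits black's default 88-character limit and (b) breaking a compound `if x: y()` statement onto separate lines. A minimal sketch of those rewrites using black's Python API (assuming black is installed; this snippet is illustrative and not part of the commit):

import black

# Two statements written the way they appeared before this commit.
before = (
    "idx_next = torch.multinomial(scores, num_samples=1).to(\n"
    "    finish.device\n"
    ")\n"
    "if pbar is not None: pbar.update(1)\n"
)

# format_str applies the same style rules as running black on the file:
# the call is collapsed because it fits on one line, and the compound
# `if` statement is split across two lines.
after = black.format_str(before, mode=black.FileMode())
print(after)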
ChatTTS/model/gpt.py: 18 changes (8 additions & 10 deletions)
@@ -369,7 +369,7 @@ def generate(
            attention_mask_cache.narrow(1, 0, attention_mask.shape[1]).copy_(
                attention_mask
            )

        pbar: Optional[tqdm] = None

        if show_tqdm:
@@ -473,19 +473,15 @@ def generate(

            del logits

-            idx_next = torch.multinomial(scores, num_samples=1).to(
-                finish.device
-            )
+            idx_next = torch.multinomial(scores, num_samples=1).to(finish.device)

            if not infer_text:
                # idx_next = rearrange(idx_next, "(b n) 1 -> b n", n=self.num_vq)
                idx_next = idx_next.view(-1, self.num_vq)
                finish_or = idx_next.eq(eos_token).any(1)
                finish.logical_or_(finish_or)
                del finish_or
-                inputs_ids_tmp = torch.cat(
-                    [inputs_ids, idx_next.unsqueeze_(1)], 1
-                )
+                inputs_ids_tmp = torch.cat([inputs_ids, idx_next.unsqueeze_(1)], 1)
            else:
                finish_or = idx_next.eq(eos_token).any(1)
                finish.logical_or_(finish_or)
@@ -525,9 +521,11 @@ def generate(
            if finish.all() or context.get():
                break

-            if pbar is not None: pbar.update(1)
-
-        if pbar is not None: pbar.close()
+            if pbar is not None:
+                pbar.update(1)
+
+        if pbar is not None:
+            pbar.close()

        if not finish.all():
            if context.get():
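
The reformatted lines all come from the sampling loop in GPT.generate: an optional tqdm progress bar guarded by `if pbar is not None:`, multinomial sampling of the next token, and per-sequence EOS bookkeeping. A self-contained sketch of that pattern is below; the batch size, vocabulary size, and the dummy uniform score distribution standing in for the model are illustrative assumptions, not the repository's actual implementation.

from typing import Optional

import torch
from tqdm import tqdm


def generate_sketch(show_tqdm: bool = True, max_new_token: int = 16, eos_token: int = 7):
    batch, vocab = 2, 8
    finish = torch.zeros(batch, dtype=torch.bool)
    inputs_ids = torch.zeros(batch, 1, dtype=torch.long)

    pbar: Optional[tqdm] = None
    if show_tqdm:
        pbar = tqdm(total=max_new_token)

    for _ in range(max_new_token):
        # Dummy uniform distribution standing in for the model's softmaxed logits.
        scores = torch.full((batch, vocab), 1.0 / vocab)

        # Sample one next-token id per sequence in the batch.
        idx_next = torch.multinomial(scores, num_samples=1).to(finish.device)

        # Mark sequences that just emitted the EOS token as finished.
        finish.logical_or_(idx_next.eq(eos_token).any(1))

        inputs_ids = torch.cat([inputs_ids, idx_next], 1)

        # Stop early once every sequence in the batch has finished.
        if finish.all():
            break

        if pbar is not None:
            pbar.update(1)

    if pbar is not None:
        pbar.close()

    return inputs_ids


if __name__ == "__main__":
    print(generate_sketch())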
