
Commit

Fix type annotations for distributed_concat() (#13746)
* Fix type annotations for `distributed_concat()`

* Use `Any`
Renovamen authored Sep 27, 2021
1 parent e0d31a8 commit 3ccc270
Showing 1 changed file with 2 additions and 2 deletions.
src/transformers/trainer_pt_utils.py (4 changes: 2 additions & 2 deletions)
```diff
@@ -25,7 +25,7 @@
 from contextlib import contextmanager
 from dataclasses import dataclass
 from logging import StreamHandler
-from typing import Dict, Iterator, List, Optional, Union
+from typing import Any, Dict, Iterator, List, Optional, Union
 
 import numpy as np
 import torch
@@ -157,7 +157,7 @@ def nested_xla_mesh_reduce(tensors, name):
         raise ImportError("Torch xla must be installed to use `nested_xla_mesh_reduce`")
 
 
-def distributed_concat(tensor: "torch.Tensor", num_total_examples: Optional[int] = None) -> torch.Tensor:
+def distributed_concat(tensor: Any, num_total_examples: Optional[int] = None) -> Any:
     try:
         if isinstance(tensor, (tuple, list)):
             return type(tensor)(distributed_concat(t, num_total_examples) for t in tensor)
```
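For context on why `Any` is the right annotation here: as the second hunk shows, `distributed_concat()` accepts and returns not just a single `torch.Tensor` but arbitrarily nested tuples and lists of tensors, which it rebuilds recursively. Below is a minimal sketch of that pattern; the recursion line is taken from the diff, while the gather-and-truncate leaf case is an assumption modeled on typical `torch.distributed` usage, not quoted from the file:

```python
from typing import Any, Optional

import torch
import torch.distributed as dist


def distributed_concat(tensor: Any, num_total_examples: Optional[int] = None) -> Any:
    # Containers recurse: each element is gathered independently and the
    # original container type is rebuilt, so the input/output can be a tensor,
    # a tuple of tensors, a list of nested lists, etc. This is why the old
    # `torch.Tensor` annotation was too narrow.
    if isinstance(tensor, (tuple, list)):
        return type(tensor)(distributed_concat(t, num_total_examples) for t in tensor)
    # Assumed leaf case: gather this tensor from every process, then
    # concatenate the per-process shards along the batch dimension.
    output_tensors = [tensor.clone() for _ in range(dist.get_world_size())]
    dist.all_gather(output_tensors, tensor)
    concat = torch.cat(output_tensors, dim=0)
    # Assumed: drop any padding samples a distributed sampler added to make
    # the dataset evenly divisible across processes.
    if num_total_examples is not None:
        concat = concat[:num_total_examples]
    return concat
```

Under this sketch, a call like `distributed_concat((logits, labels))` returns a tuple of gathered tensors, a shape of result the old `-> torch.Tensor` annotation could not express.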
