fix importing torchtext batch #6365

Merged
merged 6 commits on Mar 5, 2021
Changes from 5 commits
6 changes: 3 additions & 3 deletions pytorch_lightning/utilities/apply_func.py
@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import operator
from abc import ABC
from collections.abc import Mapping, Sequence
from copy import copy
@@ -22,10 +22,10 @@
import torch

from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.imports import _module_available, _TORCHTEXT_AVAILABLE
from pytorch_lightning.utilities.imports import _compare_version, _TORCHTEXT_AVAILABLE

if _TORCHTEXT_AVAILABLE:
if _module_available("torchtext.legacy.data"):
if _compare_version("torchtext", operator.ge, "0.9.0"):
from torchtext.legacy.data import Batch
else:
from torchtext.data import Batch
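Note on the technique: `_compare_version` is provided by `pytorch_lightning.utilities.imports`; the sketch below is only an illustration of the version-gated import idea, not the library's actual implementation. It assumes the `packaging` library is installed, and the helper name `compare_version` and its internals are hypothetical.

# Illustrative sketch only; the real helper is pytorch_lightning.utilities.imports._compare_version.
import importlib
import operator

from packaging.version import Version


def compare_version(package: str, op, version: str) -> bool:
    """Return True if `package` is installed and `op(installed_version, version)` holds."""
    try:
        pkg = importlib.import_module(package)
    except ImportError:
        return False
    try:
        pkg_version = Version(pkg.__version__)
    except (AttributeError, TypeError):
        # Package is importable but exposes no usable __version__ attribute.
        return False
    return op(pkg_version, Version(version))


# Usage mirroring the diff above (in apply_func.py this sits under `if _TORCHTEXT_AVAILABLE:`):
if compare_version("torchtext", operator.ge, "0.9.0"):
    from torchtext.legacy.data import Batch  # noqa: F401  # torchtext >= 0.9 moved the old API under `legacy`
else:
    from torchtext.data import Batch  # noqa: F401

Checking the installed version is more robust here than probing for `torchtext.legacy.data` with `_module_available`, because importing the legacy module as a side effect of the probe is exactly what triggered the import problem this PR fixes.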
8 changes: 8 additions & 0 deletions tests/helpers/imports.py
@@ -0,0 +1,8 @@
import operator

from pytorch_lightning.utilities.imports import _compare_version

if _compare_version("torchtext", operator.ge, "0.9.0"):
from torchtext.legacy.data import Batch, Dataset, Example, Field, Iterator, LabelField # noqa: F401
else:
from torchtext.data import Batch, Dataset, Example, Field, Iterator, LabelField # noqa: F401
2 changes: 1 addition & 1 deletion tests/models/test_gpu.py
@@ -16,7 +16,6 @@

import pytest
import torch
from torchtext.data import Batch, Dataset, Example, Field, LabelField

import tests.helpers.pipelines as tpipes
import tests.helpers.utils as tutils
@@ -25,6 +24,7 @@
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests.helpers import BoringModel
from tests.helpers.datamodules import ClassifDataModule
from tests.helpers.imports import Batch, Dataset, Example, Field, LabelField
from tests.helpers.runif import RunIf
from tests.helpers.simple_models import ClassificationModel

9 changes: 4 additions & 5 deletions tests/utilities/test_apply_func_torchtext.py
@@ -13,15 +13,14 @@
# limitations under the License.
import pytest
import torch
import torchtext
from torchtext.data.example import Example

from pytorch_lightning.utilities.apply_func import move_data_to_device
from tests.helpers.imports import Dataset, Example, Field, Iterator
from tests.helpers.runif import RunIf


def _get_torchtext_data_iterator(include_lengths=False):
text_field = torchtext.data.Field(
text_field = Field(
sequential=True,
pad_first=False, # nosec
init_token="<s>",
@@ -33,13 +32,13 @@ def _get_torchtext_data_iterator(include_lengths=False):
example2 = Example.fromdict({"text": "b c a a"}, {"text": ("text", text_field)})
example3 = Example.fromdict({"text": "c b a"}, {"text": ("text", text_field)})

dataset = torchtext.data.Dataset(
dataset = Dataset(
[example1, example2, example3],
{"text": text_field},
)
text_field.build_vocab(dataset)

iterator = torchtext.data.Iterator(
iterator = Iterator(
dataset,
batch_size=3,
sort_key=None,
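Note: the rest of this test module is truncated above. As a rough sketch of how these pieces fit together, assuming `_get_torchtext_data_iterator` returns the iterator together with the text field (the return statement is not shown in the diff), moving a torchtext batch to a device might look like this; the function name and assertions here are illustrative, not the actual test body.

import torch

from pytorch_lightning.utilities.apply_func import move_data_to_device


def _check_batch_moves_to_device(device=torch.device("cpu")):
    # Assumed return shape: (iterator, text_field); the actual return is truncated above.
    data_iter, _ = _get_torchtext_data_iterator(include_lengths=True)
    for batch in data_iter:
        # apply_func.py now resolves the correct Batch class for the installed torchtext,
        # so move_data_to_device can recognize and transfer the batch.
        batch_on_device = move_data_to_device(batch, device)
        text, lengths = batch_on_device.text  # include_lengths=True yields (tensor, lengths)
        assert text.device == device and lengths.device == device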