Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix importing torchtext batch #6365

Merged
merged 6 commits into from
Mar 5, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions .github/workflows/docs-checks.yml
Original file line number Diff line number Diff line change
Expand Up @@ -41,15 +41,15 @@ jobs:
- name: Install dependencies
run: |
python --version
pip --version
# remove Horovod from requirements
python -c "fname = 'requirements/extra.txt' ; lines = [line for line in open(fname).readlines() if not line.startswith('horovod')] ; open(fname, 'w').writelines(lines)"
# python -m pip install --upgrade --user pip
pip install --requirement requirements.txt --upgrade-strategy only-if-needed --find-links https://download.pytorch.org/whl/cpu/torch_stable.html --quiet
pip install --requirement requirements/extra.txt
pip install --requirement requirements/loggers.txt
pip install --requirement requirements/docs.txt
python --version
pip --version
pip list
shell: bash

Expand Down Expand Up @@ -84,12 +84,12 @@ jobs:
- name: Install dependencies
run: |
pip install --requirement requirements.txt --upgrade-strategy only-if-needed --find-links https://download.pytorch.org/whl/cpu/torch_stable.html --quiet
python --version
pip --version
# pip install --requirement requirements.txt --upgrade-strategy only-if-needed --find-links https://download.pytorch.org/whl/cpu/torch_stable.html --quiet
pip install --requirement requirements/docs.txt
# install Texlive, see https://linuxconfig.org/how-to-install-latex-on-ubuntu-20-04-focal-fossa-linux
sudo apt-get update && sudo apt-get install -y texlive-latex-extra dvipng texlive-pictures
python --version
pip --version
pip list
shell: bash

Expand Down
6 changes: 3 additions & 3 deletions pytorch_lightning/utilities/apply_func.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import operator
from abc import ABC
from collections.abc import Mapping, Sequence
from copy import copy
Expand All @@ -22,10 +22,10 @@
import torch

from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.imports import _module_available, _TORCHTEXT_AVAILABLE
from pytorch_lightning.utilities.imports import _compare_version, _TORCHTEXT_AVAILABLE

if _TORCHTEXT_AVAILABLE:
if _module_available("torchtext.legacy.data"):
if _compare_version("torchtext", operator.ge, "0.9.0"):
from torchtext.legacy.data import Batch
else:
from torchtext.data import Batch
Expand Down
8 changes: 8 additions & 0 deletions tests/helpers/imports.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
import operator

from pytorch_lightning.utilities.imports import _compare_version

if _compare_version("torchtext", operator.ge, "0.9.0"):
from torchtext.legacy.data import Batch, Dataset, Example, Field, Iterator, LabelField # noqa: F401
else:
from torchtext.data import Batch, Dataset, Example, Field, Iterator, LabelField # noqa: F401
2 changes: 1 addition & 1 deletion tests/models/test_gpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@

import pytest
import torch
from torchtext.data import Batch, Dataset, Example, Field, LabelField

import tests.helpers.pipelines as tpipes
import tests.helpers.utils as tutils
Expand All @@ -25,6 +24,7 @@
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests.helpers import BoringModel
from tests.helpers.datamodules import ClassifDataModule
from tests.helpers.imports import Batch, Dataset, Example, Field, LabelField
from tests.helpers.runif import RunIf
from tests.helpers.simple_models import ClassificationModel

Expand Down
9 changes: 4 additions & 5 deletions tests/utilities/test_apply_func_torchtext.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,14 @@
# limitations under the License.
import pytest
import torch
import torchtext
from torchtext.data.example import Example

from pytorch_lightning.utilities.apply_func import move_data_to_device
from tests.helpers.imports import Dataset, Example, Field, Iterator
from tests.helpers.runif import RunIf


def _get_torchtext_data_iterator(include_lengths=False):
text_field = torchtext.data.Field(
text_field = Field(
sequential=True,
pad_first=False, # nosec
init_token="<s>",
Expand All @@ -33,13 +32,13 @@ def _get_torchtext_data_iterator(include_lengths=False):
example2 = Example.fromdict({"text": "b c a a"}, {"text": ("text", text_field)})
example3 = Example.fromdict({"text": "c b a"}, {"text": ("text", text_field)})

dataset = torchtext.data.Dataset(
dataset = Dataset(
[example1, example2, example3],
{"text": text_field},
)
text_field.build_vocab(dataset)

iterator = torchtext.data.Iterator(
iterator = Iterator(
dataset,
batch_size=3,
sort_key=None,
Expand Down