diff --git a/.gitignore b/.gitignore index 9ee6302..326137d 100644 --- a/.gitignore +++ b/.gitignore @@ -63,3 +63,6 @@ ENV/ # idea .idea/ + +# not needed for libraries +poetry.lock diff --git a/maru/resource/rnn/__init__.py b/maru/resource/rnn/__init__.py index 3080854..90401c5 100644 --- a/maru/resource/rnn/__init__.py +++ b/maru/resource/rnn/__init__.py @@ -1,14 +1,17 @@ import json import pathlib +import typing from typing import Dict import joblib -import tensorflow.keras from maru.feature.extractor import IFeatureExtractor from maru.feature.vocabulary import FeatureVocabulary from maru.tag import Tag +if typing.TYPE_CHECKING: + import tensorflow.keras + _DIRECTORY = pathlib.Path(__file__).parent.absolute() @@ -20,7 +23,13 @@ def load_tags() -> Dict[int, Tag]: return joblib.load(_DIRECTORY / 'tags.joblib') -def load_tagger() -> tensorflow.keras.Model: +def load_tagger() -> 'tensorflow.keras.Model': + try: + import tensorflow.keras + except ModuleNotFoundError: + raise ImportError( + 'RNN tagger requires TensorFlow. You can install it with "pip install maru[tf]"' + ) # this restrains tensorflow from allocating all of available GPU memory config = tensorflow.compat.v1.ConfigProto() config.gpu_options.allow_growth = True diff --git a/pyproject.toml b/pyproject.toml index fa6106f..53b09f7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "maru" -version = "0.1.3" +version = "0.2.0" description = "Morphological Analyzer for Russian πŸ’¬" license = "MIT" authors = ["Vladislav Blinov "] @@ -19,7 +19,7 @@ numpy = ">=1.15.0" pymorphy2 = { version = ">=0.8", extras = [ "fast" ] } scipy = ">=1.1.0" keras = ">=2.2.2" -tensorflow = ">=1.14.0" +tensorflow = { version = ">=1.14.0", optional = true } python-crfsuite = ">=0.9.5" lru-dict = ">=1.1.6" tensorflow-gpu = { version = ">=1.14.0", optional = true } @@ -46,6 +46,7 @@ codecov = "^2.0.15" [tool.poetry.extras] gpu = ["tensorflow-gpu"] +tf = ["tensorflow"] [build-system] requires = ["poetry>=0.12"] diff --git a/setup.cfg b/setup.cfg index 6b34ce6..30c7577 100644 --- a/setup.cfg +++ b/setup.cfg @@ -39,6 +39,7 @@ disable= E203, ; Whitespace C0330, ; Wrong hanging indentation before block (add 4 spaces) E0306, ; __repr__ does not return str (invalid-repr-returned) + R0801, ; Similar lines in 2 files [mypy] python_version = 3.6 diff --git a/tests/tagger/__init__.py b/tests/tagger/__init__.py index e69de29..3d32e69 100644 --- a/tests/tagger/__init__.py +++ b/tests/tagger/__init__.py @@ -0,0 +1,3 @@ +import pytest + +pytest.register_assert_rewrite('tests.tagger.base') diff --git a/tests/tagger/test_crf.py b/tests/tagger/test_crf.py index 7d5c67b..d6658bc 100644 --- a/tests/tagger/test_crf.py +++ b/tests/tagger/test_crf.py @@ -23,7 +23,7 @@ def create_tagger(): 'test', [ TaggerTest( - words=['настоящий', 'Π΄Π΅Ρ‚Π΅ΠΊΡ‚ΠΈΠ²'], + words=['настоящий', 'ΠΏΠΎΠ»ΠΊΠΎΠ²Π½ΠΈΠΊ'], tags=[ ( 0, diff --git a/tox.ini b/tox.ini index ce52b85..b6a7257 100644 --- a/tox.ini +++ b/tox.ini @@ -3,5 +3,5 @@ envlist = py36,py37,py38 [testenv] commands = - poetry install + poetry install -E tf poetry run pytest tests