Skip to content

Commit

Permalink
Optional tf (#3)
Browse files Browse the repository at this point in the history
Move TensorFlow to extra dependencies
  • Loading branch information
ojomio authored Jul 2, 2021
1 parent 7a44be7 commit dae41e2
Show file tree
Hide file tree
Showing 7 changed files with 23 additions and 6 deletions.
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -63,3 +63,6 @@ ENV/

# idea
.idea/

# not needed for libraries
poetry.lock
13 changes: 11 additions & 2 deletions maru/resource/rnn/__init__.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,17 @@
import json
import pathlib
import typing
from typing import Dict

import joblib
import tensorflow.keras

from maru.feature.extractor import IFeatureExtractor
from maru.feature.vocabulary import FeatureVocabulary
from maru.tag import Tag

if typing.TYPE_CHECKING:
import tensorflow.keras

_DIRECTORY = pathlib.Path(__file__).parent.absolute()


Expand All @@ -20,7 +23,13 @@ def load_tags() -> Dict[int, Tag]:
return joblib.load(_DIRECTORY / 'tags.joblib')


def load_tagger() -> tensorflow.keras.Model:
def load_tagger() -> 'tensorflow.keras.Model':
try:
import tensorflow.keras
except ModuleNotFoundError:
raise ImportError(
'RNN tagger requires TensorFlow. You can install it with "pip install maru[tf]"'
)
# this restrains tensorflow from allocating all of available GPU memory
config = tensorflow.compat.v1.ConfigProto()
config.gpu_options.allow_growth = True
Expand Down
5 changes: 3 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "maru"
version = "0.1.3"
version = "0.2.0"
description = "Morphological Analyzer for Russian 💬"
license = "MIT"
authors = ["Vladislav Blinov <cunningplan@yandex.ru>"]
Expand All @@ -19,7 +19,7 @@ numpy = ">=1.15.0"
pymorphy2 = { version = ">=0.8", extras = [ "fast" ] }
scipy = ">=1.1.0"
keras = ">=2.2.2"
tensorflow = ">=1.14.0"
tensorflow = { version = ">=1.14.0", optional = true }
python-crfsuite = ">=0.9.5"
lru-dict = ">=1.1.6"
tensorflow-gpu = { version = ">=1.14.0", optional = true }
Expand All @@ -46,6 +46,7 @@ codecov = "^2.0.15"

[tool.poetry.extras]
gpu = ["tensorflow-gpu"]
tf = ["tensorflow"]

[build-system]
requires = ["poetry>=0.12"]
Expand Down
1 change: 1 addition & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ disable=
E203, ; Whitespace
C0330, ; Wrong hanging indentation before block (add 4 spaces)
E0306, ; __repr__ does not return str (invalid-repr-returned)
R0801, ; Similar lines in 2 files

[mypy]
python_version = 3.6
Expand Down
3 changes: 3 additions & 0 deletions tests/tagger/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
import pytest

pytest.register_assert_rewrite('tests.tagger.base')
2 changes: 1 addition & 1 deletion tests/tagger/test_crf.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def create_tagger():
'test',
[
TaggerTest(
words=['настоящий', 'детектив'],
words=['настоящий', 'полковник'],
tags=[
(
0,
Expand Down
2 changes: 1 addition & 1 deletion tox.ini
Original file line number Diff line number Diff line change
Expand Up @@ -3,5 +3,5 @@ envlist = py36,py37,py38

[testenv]
commands =
poetry install
poetry install -E tf
poetry run pytest tests

0 comments on commit dae41e2

Please sign in to comment.