From 7a44be7f974c0962f3023f5d064a391d2b4f20b1 Mon Sep 17 00:00:00 2001 From: Vladislav Blinov Date: Thu, 10 Sep 2020 20:13:47 +1000 Subject: [PATCH] Refactor project dependencies (#2) * Refactor project dependencies * Install latest pip and poetry --- .gitattributes | 1 + .travis.yml | 4 +- Makefile | 2 +- maru/resource/crf/__init__.py | 13 +- maru/resource/linear/__init__.py | 20 +- maru/resource/linear/coefficients.gz | 3 + maru/resource/linear/coefficients.joblib | 3 - maru/resource/linear/intercept.gz | 3 + maru/resource/linear/intercept.joblib | 3 - maru/resource/rnn/__init__.py | 28 +- maru/tagger/rnn.py | 2 +- pyproject.toml | 21 +- rnn.ipynb | 1479 ---------------------- setup.cfg | 1 + tests/tagger/base.py | 21 +- tests/tagger/test_crf.py | 92 +- tests/tagger/test_linear.py | 92 +- tests/tagger/test_numerical.py | 94 +- tests/tagger/test_punctuation.py | 53 +- tests/tagger/test_rnn.py | 89 +- 20 files changed, 342 insertions(+), 1682 deletions(-) create mode 100644 maru/resource/linear/coefficients.gz delete mode 100644 maru/resource/linear/coefficients.joblib create mode 100644 maru/resource/linear/intercept.gz delete mode 100644 maru/resource/linear/intercept.joblib delete mode 100644 rnn.ipynb diff --git a/.gitattributes b/.gitattributes index 97e8ab4..6b75148 100644 --- a/.gitattributes +++ b/.gitattributes @@ -2,3 +2,4 @@ *.joblib filter=lfs diff=lfs merge=lfs -text *.crfsuite filter=lfs diff=lfs merge=lfs -text *.h5 filter=lfs diff=lfs merge=lfs -text +*.gz filter=lfs diff=lfs merge=lfs -text diff --git a/.travis.yml b/.travis.yml index 5264f43..680c639 100644 --- a/.travis.yml +++ b/.travis.yml @@ -4,8 +4,8 @@ python: - 3.6 install: - - "pip install pip==18.1" - - "pip install poetry==0.12.10" + - "pip install -U pip" + - "pip install poetry" - "poetry install --no-interaction" jobs: diff --git a/Makefile b/Makefile index f058496..3853e5d 100644 --- a/Makefile +++ b/Makefile @@ -15,7 +15,7 @@ lint: $(PYTHON) mypy $(CODE) test: - $(PYTHON) pytest -n 8 --boxed tests + $(PYTHON) pytest tests coverage: $(PYTHON) pytest --cov=maru diff --git a/maru/resource/crf/__init__.py b/maru/resource/crf/__init__.py index ddfbded..e506fbc 100644 --- a/maru/resource/crf/__init__.py +++ b/maru/resource/crf/__init__.py @@ -1,25 +1,24 @@ -import functools -import os +import pathlib from typing import Dict +import joblib import pycrfsuite -from sklearn.externals import joblib from maru.feature.extractor import IFeatureExtractor from maru.tag import Tag -_get_path = functools.partial(os.path.join, os.path.dirname(__file__)) +_DIRECTORY = pathlib.Path(__file__).parent.absolute() def load_extractor() -> IFeatureExtractor: - return joblib.load(_get_path('extractor.joblib')) + return joblib.load(_DIRECTORY / 'extractor.joblib') def load_tags() -> Dict[int, Tag]: - return joblib.load(_get_path('tags.joblib')) + return joblib.load(_DIRECTORY / 'tags.joblib') def load_tagger() -> pycrfsuite.Tagger: tagger = pycrfsuite.Tagger() - tagger.open(_get_path('tagger.crfsuite')) + tagger.open(str(_DIRECTORY / 'tagger.crfsuite')) return tagger diff --git a/maru/resource/linear/__init__.py b/maru/resource/linear/__init__.py index 2c32236..db25b89 100644 --- a/maru/resource/linear/__init__.py +++ b/maru/resource/linear/__init__.py @@ -1,35 +1,37 @@ -import functools +import gzip import json -import os +import pathlib from typing import Dict +import joblib import numpy -from sklearn.externals import joblib from maru.feature.extractor import IFeatureExtractor from maru.feature.vocabulary import PositionalFeatureVocabulary from maru.tag import Tag -_get_path = functools.partial(os.path.join, os.path.dirname(__file__)) +_DIRECTORY = pathlib.Path(__file__).parent.absolute() def load_extractor() -> IFeatureExtractor: - return joblib.load(_get_path('extractor.joblib')) + return joblib.load(_DIRECTORY / 'extractor.joblib') def load_vocabulary() -> PositionalFeatureVocabulary: - with open(_get_path('vocabulary.json'), encoding='utf8') as f: + with (_DIRECTORY / 'vocabulary.json').open(encoding='utf8') as f: data = {int(index): mapping for index, mapping in json.load(f).items()} return PositionalFeatureVocabulary(data) def load_tags() -> Dict[int, Tag]: - return joblib.load(_get_path('tags.joblib')) + return joblib.load(_DIRECTORY / 'tags.joblib') def load_coefficients() -> numpy.array: - return joblib.load(_get_path('coefficients.joblib')) + with gzip.open(_DIRECTORY / 'coefficients.gz', 'rb') as data: + return numpy.load(data) def load_intercept() -> numpy.array: - return joblib.load(_get_path('intercept.joblib')) + with gzip.open(_DIRECTORY / 'intercept.gz', 'rb') as data: + return numpy.load(data) diff --git a/maru/resource/linear/coefficients.gz b/maru/resource/linear/coefficients.gz new file mode 100644 index 0000000..52060b4 --- /dev/null +++ b/maru/resource/linear/coefficients.gz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:31515f519ba561b7cedf2eba896228bd91d3a3ebbf312e6f0fbeac60d3b5bdfc +size 21510976 diff --git a/maru/resource/linear/coefficients.joblib b/maru/resource/linear/coefficients.joblib deleted file mode 100644 index 74410c4..0000000 --- a/maru/resource/linear/coefficients.joblib +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:6516a68a68ed47347fe49118a798323bbb9387cdceee9055a5ba9910b195dfdb -size 22331754 diff --git a/maru/resource/linear/intercept.gz b/maru/resource/linear/intercept.gz new file mode 100644 index 0000000..339c001 --- /dev/null +++ b/maru/resource/linear/intercept.gz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d22cf2dc4436b2d9ebce1493bbdb207fc828ef1729dd0eed9f022993fcd29cdd +size 2543 diff --git a/maru/resource/linear/intercept.joblib b/maru/resource/linear/intercept.joblib deleted file mode 100644 index 66b55f1..0000000 --- a/maru/resource/linear/intercept.joblib +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:5f5f4171e03c95332c55ff3bcdc0c26c8df2915885702ce24fadc8d3810151d2 -size 2645 diff --git a/maru/resource/rnn/__init__.py b/maru/resource/rnn/__init__.py index bc6e1c1..3080854 100644 --- a/maru/resource/rnn/__init__.py +++ b/maru/resource/rnn/__init__.py @@ -1,42 +1,42 @@ -import functools import json -import os +import pathlib from typing import Dict -import keras -import tensorflow -from sklearn.externals import joblib +import joblib +import tensorflow.keras from maru.feature.extractor import IFeatureExtractor from maru.feature.vocabulary import FeatureVocabulary from maru.tag import Tag -_get_path = functools.partial(os.path.join, os.path.dirname(__file__)) +_DIRECTORY = pathlib.Path(__file__).parent.absolute() def load_extractor() -> IFeatureExtractor: - return joblib.load(_get_path('extractor.joblib')) + return joblib.load(_DIRECTORY / 'extractor.joblib') def load_tags() -> Dict[int, Tag]: - return joblib.load(_get_path('tags.joblib')) + return joblib.load(_DIRECTORY / 'tags.joblib') -def load_tagger() -> keras.Model: +def load_tagger() -> tensorflow.keras.Model: # this restrains tensorflow from allocating all of available GPU memory - config = tensorflow.ConfigProto() + config = tensorflow.compat.v1.ConfigProto() config.gpu_options.allow_growth = True - keras.backend.set_session(tensorflow.Session(config=config)) + tensorflow.compat.v1.keras.backend.set_session( + tensorflow.compat.v1.Session(config=config) + ) - return keras.models.load_model(_get_path('tagger.h5')) + return tensorflow.keras.models.load_model(_DIRECTORY / 'tagger.h5') def load_char_vocabulary() -> FeatureVocabulary: - with open(_get_path('char_vocabulary.json'), encoding='utf8') as f: + with (_DIRECTORY / 'char_vocabulary.json').open(encoding='utf8') as f: return FeatureVocabulary(json.load(f)) def load_grammeme_vocabulary() -> FeatureVocabulary: - with open(_get_path('grammeme_vocabulary.json'), encoding='utf8') as f: + with (_DIRECTORY / 'grammeme_vocabulary.json').open(encoding='utf8') as f: return FeatureVocabulary(json.load(f)) diff --git a/maru/tagger/rnn.py b/maru/tagger/rnn.py index 1f410a3..7414a39 100644 --- a/maru/tagger/rnn.py +++ b/maru/tagger/rnn.py @@ -19,7 +19,7 @@ def __init__(self, cache_size: Optional[int] = 15000): self._tagger = rnn.load_tagger() self._tags = rnn.load_tags() - _, _, max_word_length = self._tagger.get_layer(_CHAR_INPUT).input_shape + _, _, max_word_length = self._tagger.get_layer(_CHAR_INPUT).input_shape[0] grammeme_vocabulary = rnn.load_grammeme_vocabulary() char_vocabulary = rnn.load_char_vocabulary() diff --git a/pyproject.toml b/pyproject.toml index 99c699e..fa6106f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "maru" -version = "0.1.2" +version = "0.1.3" description = "Morphological Analyzer for Russian 💬" license = "MIT" authors = ["Vladislav Blinov "] @@ -15,20 +15,19 @@ classifiers = [ [tool.poetry.dependencies] python = "^3.6" -numpy = "^1.15.0" -pymorphy2 = { version = "^0.8", extras = [ "fast" ] } -scipy = "^1.1.0" -keras = "^2.2.2" -tensorflow = ">=1.9.0, <1.15.0" -scikit-learn = "^0.19.0" -python-crfsuite = "^0.9.5" -lru-dict = "^1.1.6" -tensorflow-gpu = { version = ">=1.9.0, <1.15.0", optional = true } +numpy = ">=1.15.0" +pymorphy2 = { version = ">=0.8", extras = [ "fast" ] } +scipy = ">=1.1.0" +keras = ">=2.2.2" +tensorflow = ">=1.14.0" +python-crfsuite = ">=0.9.5" +lru-dict = ">=1.1.6" +tensorflow-gpu = { version = ">=1.14.0", optional = true } +joblib = ">=0.11.0" [tool.poetry.dev-dependencies] pytest = "^5.2.2" pytest-cov = "^2.8.1" -pytest-xdist = "^1.30.0" mypy = "^0.740" flake8 = "^3.7.9" flake8-isort = "^2.7.0" diff --git a/rnn.ipynb b/rnn.ipynb deleted file mode 100644 index d103f7e..0000000 --- a/rnn.ipynb +++ /dev/null @@ -1,1479 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "def read_corpus(path):\n", - " with open(path, encoding='utf8') as file:\n", - " sentence = []\n", - " for line in file:\n", - " if not line.strip():\n", - " yield sentence\n", - " sentence = []\n", - " else:\n", - " _, word, lemma, pos, tagline, *_ = line.split()\n", - " \n", - " tags = {}\n", - " if tagline != '_':\n", - " tags.update(elem.split(\"=\") for elem in tagline.split(\"|\"))\n", - "\n", - " sentence.append({'Word': word, 'POS': pos, 'Lemma': lemma.lower(), 'Tags': tags})" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "metadata": {}, - "outputs": [], - "source": [ - "# data is located at https://github.com/dialogue-evaluation/morphoRuEval-2017\n", - "data = []\n", - "for corpus in [\n", - " 'models/data/gikrya_new_train.out', \n", - " 'models/data/gikrya_new_test.out', \n", - "]:\n", - " data.extend(read_corpus(corpus))" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "metadata": {}, - "outputs": [], - "source": [ - "import re\n", - "import unicodedata\n", - "\n", - "\n", - "def preprocess(word):\n", - " word = word.strip().lower().replace('_', ' ')\n", - " word = re.sub('\\d', 'D', word)\n", - " word = word.replace('', '.')\n", - " return word\n", - "\n", - "\n", - "def is_useful_example(word):\n", - " return (\n", - " word['POS'] != 'PUNCT' and \n", - " word['Tags'].get('NumForm') != 'Digit' and \n", - " not (word['POS'] == 'VERB' and word['Word'] in ['гуля', 'МАША']) and\n", - " not all(unicodedata.category(ch)[0] == 'P' for ch in word['Word']) and\n", - " not re.match('\\d+([.,]\\d+)?$', word['Word'])\n", - " )" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "metadata": {}, - "outputs": [], - "source": [ - "import pymorphy2\n", - "\n", - "morph = pymorphy2.MorphAnalyzer()\n", - "\n", - "for i, word in enumerate(word for sent in data for word in sent if is_useful_example(word)):\n", - " if word['POS'] == 'VERB' and word['Tags'].get('Tense') == 'Notpast':\n", - " if word['Word'].lower() in ['нет', 'нету', 'мятясь', 'внемлет', 'ебу', 'упоминаеться']:\n", - " word['Tags']['Tense'] = 'Pres'\n", - " else:\n", - " parses = morph.parse(word['Word'])\n", - " for parse in parses:\n", - " if parse.tag.POS in ['VERB', 'GRND']:\n", - " if parse.tag.tense == 'futr':\n", - " word['Tags']['Tense'] = 'Fut'\n", - " break\n", - " elif parse.tag.tense == 'pres':\n", - " word['Tags']['Tense'] = 'Pres'\n", - " break\n", - " else:\n", - " word['Tags'].pop('Tense')" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "metadata": {}, - "outputs": [], - "source": [ - "for i, word in enumerate(word for sent in data for word in sent if is_useful_example(word)):\n", - " if word['POS'] == 'VERB':\n", - " parses = morph.parse(word['Word'])\n", - " for parse in parses:\n", - " if parse.tag.POS in ['VERB']:\n", - " if parse.tag.aspect == 'perf':\n", - " word['Tags']['Aspect'] = 'Perf'\n", - " break\n", - " elif parse.tag.aspect == 'impf':\n", - " word['Tags']['Aspect'] = 'Imp'\n", - " break" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "metadata": {}, - "outputs": [], - "source": [ - "for i, word in enumerate(word for sent in data for word in sent if is_useful_example(word)):\n", - " if word['POS'] == 'ADJ' and 'Variant' not in word['Tags']:\n", - " word['Tags']['Variant'] = 'Full'" - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "Using TensorFlow backend.\n" - ] - } - ], - "source": [ - "from sklearn.preprocessing import LabelEncoder\n", - "from maru.grammeme import (\n", - " Animacy,\n", - " Aspect,\n", - " Case,\n", - " Degree,\n", - " Gender,\n", - " Mood,\n", - " Number,\n", - " NumericalForm,\n", - " Person,\n", - " PartOfSpeech,\n", - " Tense,\n", - " Variant,\n", - " VerbForm,\n", - " Voice,\n", - ")\n", - "from maru.tag import Tag\n", - "\n", - "GRAMMEMES = {\n", - " 'animacy': Animacy,\n", - " 'aspect': Aspect,\n", - " 'case': Case,\n", - " 'degree': Degree,\n", - " 'gender': Gender,\n", - " 'mood': Mood,\n", - " 'number': Number,\n", - " 'numform': NumericalForm,\n", - " 'person': Person,\n", - " 'pos': PartOfSpeech,\n", - " 'tense': Tense,\n", - " 'variant': Variant,\n", - " 'verbform': VerbForm,\n", - " 'voice': Voice,\n", - "}\n", - "\n", - "\n", - "\n", - "def to_tag(parts):\n", - " grammemes = {}\n", - "\n", - " for part in parts:\n", - " label, value = part.split('=')\n", - " grammeme = GRAMMEMES[label]\n", - " grammemes[label] = grammeme(value)\n", - "\n", - " return Tag(**grammemes)\n", - "\n", - "\n", - "def get_class(word): \n", - " return to_tag([f\"pos={word['POS']}\"] + [f'{name.lower()}={value}' for name, value in word['Tags'].items()])" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "tags = {}\n", - "y = []\n", - "\n", - "for sent in data:\n", - " classes = []\n", - " for word in sent:\n", - " cls = get_class(word) if is_useful_example(word) else ''\n", - " if cls:\n", - " cls = tags.setdefault(cls, str(len(tags) + 1))\n", - " classes.append(cls)\n", - " y.append(classes)" - ] - }, - { - "cell_type": "code", - "execution_count": 11, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "83150" - ] - }, - "execution_count": 11, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "len(y)" - ] - }, - { - "cell_type": "code", - "execution_count": 12, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "343" - ] - }, - "execution_count": 12, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "len(tags)" - ] - }, - { - "cell_type": "code", - "execution_count": 13, - "metadata": {}, - "outputs": [], - "source": [ - "import re\n", - "\n", - "char_vocabulary = {}\n", - "\n", - "for sent in data:\n", - " for word in sent:\n", - " for sym in preprocess(word['Word']):\n", - " char_vocabulary.setdefault(sym, len(char_vocabulary) + 1)" - ] - }, - { - "cell_type": "code", - "execution_count": 14, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "{'ч': 1,\n", - " 'ь': 2,\n", - " 'я': 3,\n", - " '-': 4,\n", - " 'т': 5,\n", - " 'о': 6,\n", - " 'р': 7,\n", - " 'у': 8,\n", - " 'к': 9,\n", - " 'а': 10,\n", - " 'л': 11,\n", - " 'е': 12,\n", - " 'г': 13,\n", - " 'м': 14,\n", - " 'н': 15,\n", - " 'п': 16,\n", - " '.': 17,\n", - " 'д': 18,\n", - " ',': 19,\n", - " 'з': 20,\n", - " 'в': 21,\n", - " 'ж': 22,\n", - " 'и': 23,\n", - " 'б': 24,\n", - " 'с': 25,\n", - " 'ц': 26,\n", - " 'ю': 27,\n", - " 'ш': 28,\n", - " 'ы': 29,\n", - " 'х': 30,\n", - " 'э': 31,\n", - " 'й': 32,\n", - " 'щ': 33,\n", - " 'ф': 34,\n", - " 'ё': 35,\n", - " ':': 36,\n", - " '—': 37,\n", - " ' ': 38,\n", - " 'ъ': 39,\n", - " 'D': 40,\n", - " ')': 41,\n", - " '?': 42,\n", - " '!': 43,\n", - " '\"': 44,\n", - " '(': 45,\n", - " ';': 46,\n", - " '/': 47,\n", - " '[': 48,\n", - " ']': 49,\n", - " '+': 50,\n", - " '>': 51,\n", - " '<': 52,\n", - " \"'\": 53,\n", - " '|': 54}" - ] - }, - "execution_count": 14, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "char_vocabulary" - ] - }, - { - "cell_type": "code", - "execution_count": 15, - "metadata": {}, - "outputs": [], - "source": [ - "import maru.feature.extractor\n", - "import maru.feature.window\n", - "\n", - "extractor = maru.feature.extractor.Cache(\n", - " maru.feature.extractor.PymorphyExtractor(hypotheses=10000000),\n", - " size=40000,\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": 16, - "metadata": {}, - "outputs": [], - "source": [ - "from maru.feature.vocabulary import FeatureVocabulary" - ] - }, - { - "cell_type": "code", - "execution_count": 17, - "metadata": {}, - "outputs": [], - "source": [ - "features = []\n", - "for sent in data:\n", - " for word in sent:\n", - " if is_useful_example(word):\n", - " features.append(list(extractor.extract(preprocess(word['Word']))))" - ] - }, - { - "cell_type": "code", - "execution_count": 18, - "metadata": {}, - "outputs": [], - "source": [ - "grammeme_vocabulary = FeatureVocabulary.train(features, min_count=10)" - ] - }, - { - "cell_type": "code", - "execution_count": 19, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "599" - ] - }, - "execution_count": 19, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "len(grammeme_vocabulary)" - ] - }, - { - "cell_type": "code", - "execution_count": 20, - "metadata": {}, - "outputs": [], - "source": [ - "from maru.vectorizer.sparse import SparseFeatureVectorizer\n", - "\n", - "vectorizer = SparseFeatureVectorizer(grammeme_vocabulary)" - ] - }, - { - "cell_type": "code", - "execution_count": 21, - "metadata": {}, - "outputs": [], - "source": [ - "word_data = [[preprocess(word['Word']) for word in sent] for sent in data]" - ] - }, - { - "cell_type": "code", - "execution_count": 22, - "metadata": {}, - "outputs": [], - "source": [ - "from sklearn.model_selection import train_test_split\n", - "\n", - "train_data, test_data, train_labels, test_labels = train_test_split(word_data, y, test_size=0.05, random_state=12)" - ] - }, - { - "cell_type": "code", - "execution_count": 23, - "metadata": {}, - "outputs": [], - "source": [ - "from keras.layers import Input, Embedding, BatchNormalization, Activation\n", - "from keras.layers.core import Dense, Reshape, Dropout\n", - "from keras.layers.recurrent import LSTM\n", - "from keras.layers.wrappers import Bidirectional, TimeDistributed\n", - "from keras.layers.merge import concatenate\n", - "from keras.models import Model\n", - "\n", - "MAX_WORD_LENGTH = 12\n", - "\n", - "VOCABULARY_SIZE = len(char_vocabulary) + 1\n", - "CLASS_COUNT = 344\n", - "\n", - "RNN_DROPOUT = 0.2\n", - "\n", - "CHAR_EMBEDDING_SIZE = 24\n", - "CHAR_HIDDEN_LAYER_SIZE = 256\n", - "CHAR_OUTPUT_LAYER_SIZE = 256\n", - "CHAR_EMBEDDING_DROPOUT = 0.3\n", - "\n", - "GRAMMEME_EMBEDDING_SIZE = 128\n", - "GRAMMEME_DROPOUT = 0.3\n", - "\n", - "LSTM_INPUT_SIZE = 128\n", - "\n", - "WORD_LSTM_SIZE = 128\n", - "\n", - "DENSE_SIZE = 128\n", - "DENSE_DROPOUT = 0.5\n", - "\n", - "\n", - "def create_grammeme_embedding():\n", - " grammeme_input = Input(shape=(None, len(grammeme_vocabulary)), name='grammemes')\n", - " grammeme_embedding = Dropout(GRAMMEME_DROPOUT)(grammeme_input)\n", - " grammeme_embedding = Dense(GRAMMEME_EMBEDDING_SIZE, activation='relu')(grammeme_embedding)\n", - " return grammeme_input, grammeme_embedding\n", - "\n", - "\n", - "def create_char_embedding():\n", - " char_input = Input(shape=(None, MAX_WORD_LENGTH), name='chars')\n", - " char_dropout = Dropout(CHAR_EMBEDDING_DROPOUT)\n", - " char_embedding = Embedding(VOCABULARY_SIZE, CHAR_EMBEDDING_SIZE, name='char_embedding')\n", - " char_embedding = TimeDistributed(char_embedding)(char_input)\n", - " char_embedding = Reshape((-1, MAX_WORD_LENGTH * CHAR_EMBEDDING_SIZE))(char_embedding)\n", - " char_embedding = char_dropout(char_embedding)\n", - " char_embedding = char_dropout(Dense(CHAR_HIDDEN_LAYER_SIZE, activation='relu')(char_embedding))\n", - " char_embedding = char_dropout(char_embedding)\n", - " char_embedding = char_dropout(Dense(CHAR_OUTPUT_LAYER_SIZE, activation='relu')(char_embedding))\n", - " return char_input, char_embedding\n", - "\n", - "\n", - "def create_network():\n", - " grammeme_input, grammeme_embedding = create_grammeme_embedding()\n", - " char_input, char_embedding = create_char_embedding()\n", - "\n", - " embeddings = concatenate([grammeme_embedding, char_embedding], name='lstm_input')\n", - "\n", - " lstm_input = Dense(LSTM_INPUT_SIZE, activation='relu')(embeddings)\n", - " \n", - " lstm_1 = LSTM(WORD_LSTM_SIZE, dropout=RNN_DROPOUT, recurrent_dropout=RNN_DROPOUT, return_sequences=True, name='lstm_1')\n", - " lstm_1 = Bidirectional(lstm_1)(lstm_input)\n", - " \n", - " lstm_2 = LSTM(WORD_LSTM_SIZE, dropout=RNN_DROPOUT, recurrent_dropout=RNN_DROPOUT, return_sequences=True, name='lstm_2')\n", - " lstm_2 = Bidirectional(lstm_2)(lstm_1)\n", - " \n", - " dense = TimeDistributed(Dense(DENSE_SIZE))(lstm_2)\n", - " dense = TimeDistributed(Dropout(DENSE_DROPOUT))(dense)\n", - " dense = TimeDistributed(BatchNormalization())(dense)\n", - " dense = TimeDistributed(Activation('relu'))(dense)\n", - " \n", - " prob = Dense(CLASS_COUNT, activation='softmax')(dense)\n", - "\n", - " model = Model(inputs=[grammeme_input, char_input], outputs=prob)\n", - " model.compile(optimizer='rmsprop', loss='categorical_crossentropy', metrics=['accuracy'])\n", - " \n", - " return model" - ] - }, - { - "cell_type": "code", - "execution_count": 24, - "metadata": {}, - "outputs": [], - "source": [ - "model = create_network()" - ] - }, - { - "cell_type": "code", - "execution_count": 25, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "__________________________________________________________________________________________________\n", - "Layer (type) Output Shape Param # Connected to \n", - "==================================================================================================\n", - "chars (InputLayer) (None, None, 12) 0 \n", - "__________________________________________________________________________________________________\n", - "time_distributed_1 (TimeDistrib (None, None, 12, 24) 1320 chars[0][0] \n", - "__________________________________________________________________________________________________\n", - "reshape_1 (Reshape) (None, None, 288) 0 time_distributed_1[0][0] \n", - "__________________________________________________________________________________________________\n", - "dropout_2 (Dropout) multiple 0 reshape_1[0][0] \n", - " dense_2[0][0] \n", - " dropout_2[1][0] \n", - " dense_3[0][0] \n", - "__________________________________________________________________________________________________\n", - "dense_2 (Dense) (None, None, 256) 73984 dropout_2[0][0] \n", - "__________________________________________________________________________________________________\n", - "grammemes (InputLayer) (None, None, 599) 0 \n", - "__________________________________________________________________________________________________\n", - "dropout_1 (Dropout) (None, None, 599) 0 grammemes[0][0] \n", - "__________________________________________________________________________________________________\n", - "dense_3 (Dense) (None, None, 256) 65792 dropout_2[2][0] \n", - "__________________________________________________________________________________________________\n", - "dense_1 (Dense) (None, None, 128) 76800 dropout_1[0][0] \n", - "__________________________________________________________________________________________________\n", - "lstm_input (Concatenate) (None, None, 384) 0 dense_1[0][0] \n", - " dropout_2[3][0] \n", - "__________________________________________________________________________________________________\n", - "dense_4 (Dense) (None, None, 128) 49280 lstm_input[0][0] \n", - "__________________________________________________________________________________________________\n", - "bidirectional_1 (Bidirectional) (None, None, 256) 263168 dense_4[0][0] \n", - "__________________________________________________________________________________________________\n", - "bidirectional_2 (Bidirectional) (None, None, 256) 394240 bidirectional_1[0][0] \n", - "__________________________________________________________________________________________________\n", - "time_distributed_2 (TimeDistrib (None, None, 128) 32896 bidirectional_2[0][0] \n", - "__________________________________________________________________________________________________\n", - "time_distributed_3 (TimeDistrib (None, None, 128) 0 time_distributed_2[0][0] \n", - "__________________________________________________________________________________________________\n", - "time_distributed_4 (TimeDistrib (None, None, 128) 512 time_distributed_3[0][0] \n", - "__________________________________________________________________________________________________\n", - "time_distributed_5 (TimeDistrib (None, None, 128) 0 time_distributed_4[0][0] \n", - "__________________________________________________________________________________________________\n", - "dense_6 (Dense) (None, None, 344) 44376 time_distributed_5[0][0] \n", - "==================================================================================================\n", - "Total params: 1,002,368\n", - "Trainable params: 1,002,112\n", - "Non-trainable params: 256\n", - "__________________________________________________________________________________________________\n" - ] - } - ], - "source": [ - "model.summary()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "import tqdm\n", - "from keras.callbacks import Callback\n", - "\n", - "\n", - "class ModelEvaluation(Callback):\n", - " def on_epoch_end(self, epoch, logs=None):\n", - " if (epoch + 1) % 5 == 0:\n", - " predictions = []\n", - " for sent, lab in tqdm.tqdm_notebook(zip(test_data, test_labels), total=len(test_data)):\n", - " predictions.append(\n", - " model.predict_generator(\n", - " iter_batches([sent], [lab], 1),\n", - " steps=1,\n", - " ).argmax(axis=2)\n", - " )\n", - "\n", - " tag_acc = []\n", - " sent_acc = []\n", - " for pred, true in zip(predictions, test_labels):\n", - " pred = pred[0]\n", - " true = [int(x or 0) for x in true]\n", - "\n", - " sent_acc.append(all(x == y for x, y in zip(pred, true)))\n", - " tag_acc.extend(x == y for x, y in zip(pred, true))\n", - "\n", - " print(f'Tag accuracy: {numpy.mean(tag_acc)}')\n", - " print(f'Sentence accuracy: {numpy.mean(sent_acc)}')" - ] - }, - { - "cell_type": "code", - "execution_count": 27, - "metadata": {}, - "outputs": [], - "source": [ - "import random\n", - "\n", - "from keras.preprocessing.sequence import pad_sequences\n", - "from keras.utils import to_categorical\n", - "\n", - "\n", - "buckets = [\n", - " range(1, 7),\n", - " range(7, 15),\n", - " range(15, 26),\n", - " range(26, 41),\n", - " range(41, 51),\n", - " range(51, 1000000),\n", - "]\n", - "\n", - "\n", - "def get_input(batch):\n", - " batch_labels = [\n", - " to_categorical(\n", - " [int(label or 0) for label in sent], \n", - " num_classes=CLASS_COUNT,\n", - " )\n", - " for _, sent in batch\n", - " ]\n", - " batch_labels = pad_sequences(batch_labels, value=[0] * CLASS_COUNT)\n", - "\n", - " batch_data_chars = []\n", - " for sent, _ in batch:\n", - " words = []\n", - " for word in sent:\n", - " words.append([char_vocabulary.get(sym, 0) for sym in word])\n", - "\n", - " words = pad_sequences(words, maxlen=MAX_WORD_LENGTH, padding='pre')\n", - " batch_data_chars.append(words)\n", - " batch_data_chars = pad_sequences(batch_data_chars, value=[0] * MAX_WORD_LENGTH)\n", - "\n", - " batch_data_grammemes = []\n", - " for sent, _ in batch:\n", - " features = []\n", - " for word in sent:\n", - " features.append(extractor.extract(word))\n", - "\n", - " features = vectorizer.transform(features)\n", - " batch_data_grammemes.append(features.todense())\n", - " batch_data_grammemes = pad_sequences(batch_data_grammemes, value=[0] * len(grammeme_vocabulary))\n", - "\n", - " return {'grammemes': batch_data_grammemes, 'chars': batch_data_chars}, batch_labels\n", - "\n", - "\n", - "def iter_batches(data, labels, batch_size):\n", - " data_labels = list(zip(data, labels))\n", - "\n", - " while True: \n", - " random.shuffle(data_labels)\n", - " \n", - " batches = [([], sizes) for sizes in buckets]\n", - " \n", - " for x in data_labels:\n", - " for batch, size in batches:\n", - " if len(x[0]) in size:\n", - " batch.append(x)\n", - " if len(batch) == batch_size:\n", - " yield get_input(batch)\n", - " batch.clear()\n", - " break\n", - " else:\n", - " raise AssertionError(f'Bucket not found for sentence of length {x[0]}')\n", - " for batch, _ in batches:\n", - " if batch:\n", - " yield get_input(batch)\n", - " batch.clear()" - ] - }, - { - "cell_type": "code", - "execution_count": 29, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "HBox(children=(IntProgress(value=0, description='Training', max=50), HTML(value='')))" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/plain": [ - "HBox(children=(IntProgress(value=0, description='Epoch 0', max=155), HTML(value='')))" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/plain": [ - "HBox(children=(IntProgress(value=0, description='Epoch 1', max=155), HTML(value='')))" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/plain": [ - "HBox(children=(IntProgress(value=0, description='Epoch 2', max=155), HTML(value='')))" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/plain": [ - "HBox(children=(IntProgress(value=0, description='Epoch 3', max=155), HTML(value='')))" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/plain": [ - "HBox(children=(IntProgress(value=0, description='Epoch 4', max=155), HTML(value='')))" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/plain": [ - "HBox(children=(IntProgress(value=0, max=4158), HTML(value='')))" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Tag accuracy: 0.9740025978845797\n", - "Sentence accuracy: 0.7493987493987494\n" - ] - }, - { - "data": { - "text/plain": [ - "HBox(children=(IntProgress(value=0, description='Epoch 5', max=155), HTML(value='')))" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/plain": [ - "HBox(children=(IntProgress(value=0, description='Epoch 6', max=155), HTML(value='')))" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/plain": [ - "HBox(children=(IntProgress(value=0, description='Epoch 7', max=155), HTML(value='')))" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/plain": [ - "HBox(children=(IntProgress(value=0, description='Epoch 8', max=155), HTML(value='')))" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/plain": [ - "HBox(children=(IntProgress(value=0, description='Epoch 9', max=155), HTML(value='')))" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/plain": [ - "HBox(children=(IntProgress(value=0, max=4158), HTML(value='')))" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Tag accuracy: 0.9742809426609761\n", - "Sentence accuracy: 0.7532467532467533\n" - ] - }, - { - "data": { - "text/plain": [ - "HBox(children=(IntProgress(value=0, description='Epoch 10', max=155), HTML(value='')))" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/plain": [ - "HBox(children=(IntProgress(value=0, description='Epoch 11', max=155), HTML(value='')))" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/plain": [ - "HBox(children=(IntProgress(value=0, description='Epoch 12', max=155), HTML(value='')))" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/plain": [ - "HBox(children=(IntProgress(value=0, description='Epoch 13', max=155), HTML(value='')))" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/plain": [ - "HBox(children=(IntProgress(value=0, description='Epoch 14', max=155), HTML(value='')))" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/plain": [ - "HBox(children=(IntProgress(value=0, max=4158), HTML(value='')))" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Tag accuracy: 0.9747448506216366\n", - "Sentence accuracy: 0.7556517556517557\n" - ] - }, - { - "data": { - "text/plain": [ - "HBox(children=(IntProgress(value=0, description='Epoch 15', max=155), HTML(value='')))" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/plain": [ - "HBox(children=(IntProgress(value=0, description='Epoch 16', max=155), HTML(value='')))" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/plain": [ - "HBox(children=(IntProgress(value=0, description='Epoch 17', max=155), HTML(value='')))" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/plain": [ - "HBox(children=(IntProgress(value=0, description='Epoch 18', max=155), HTML(value='')))" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/plain": [ - "HBox(children=(IntProgress(value=0, description='Epoch 19', max=155), HTML(value='')))" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/plain": [ - "HBox(children=(IntProgress(value=0, max=4158), HTML(value='')))" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Tag accuracy: 0.9747819632584895\n", - "Sentence accuracy: 0.7566137566137566\n" - ] - }, - { - "data": { - "text/plain": [ - "HBox(children=(IntProgress(value=0, description='Epoch 20', max=155), HTML(value='')))" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/plain": [ - "HBox(children=(IntProgress(value=0, description='Epoch 21', max=155), HTML(value='')))" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/plain": [ - "HBox(children=(IntProgress(value=0, description='Epoch 22', max=155), HTML(value='')))" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/plain": [ - "HBox(children=(IntProgress(value=0, description='Epoch 23', max=155), HTML(value='')))" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/plain": [ - "HBox(children=(IntProgress(value=0, description='Epoch 24', max=155), HTML(value='')))" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/plain": [ - "HBox(children=(IntProgress(value=0, max=4158), HTML(value='')))" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Tag accuracy: 0.9750788643533123\n", - "Sentence accuracy: 0.7619047619047619\n" - ] - }, - { - "data": { - "text/plain": [ - "HBox(children=(IntProgress(value=0, description='Epoch 25', max=155), HTML(value='')))" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/plain": [ - "HBox(children=(IntProgress(value=0, description='Epoch 26', max=155), HTML(value='')))" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/plain": [ - "HBox(children=(IntProgress(value=0, description='Epoch 27', max=155), HTML(value='')))" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/plain": [ - "HBox(children=(IntProgress(value=0, description='Epoch 28', max=155), HTML(value='')))" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/plain": [ - "HBox(children=(IntProgress(value=0, description='Epoch 29', max=155), HTML(value='')))" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/plain": [ - "HBox(children=(IntProgress(value=0, max=4158), HTML(value='')))" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Tag accuracy: 0.9749489701243274\n", - "Sentence accuracy: 0.7602212602212602\n" - ] - }, - { - "data": { - "text/plain": [ - "HBox(children=(IntProgress(value=0, description='Epoch 30', max=155), HTML(value='')))" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/plain": [ - "HBox(children=(IntProgress(value=0, description='Epoch 31', max=155), HTML(value='')))" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/plain": [ - "HBox(children=(IntProgress(value=0, description='Epoch 32', max=155), HTML(value='')))" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/plain": [ - "HBox(children=(IntProgress(value=0, description='Epoch 33', max=155), HTML(value='')))" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/plain": [ - "HBox(children=(IntProgress(value=0, description='Epoch 34', max=155), HTML(value='')))" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/plain": [ - "HBox(children=(IntProgress(value=0, max=4158), HTML(value='')))" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Tag accuracy: 0.9748561885321952\n", - "Sentence accuracy: 0.7602212602212602\n" - ] - }, - { - "data": { - "text/plain": [ - "HBox(children=(IntProgress(value=0, description='Epoch 35', max=155), HTML(value='')))" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/plain": [ - "HBox(children=(IntProgress(value=0, description='Epoch 36', max=155), HTML(value='')))" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/plain": [ - "HBox(children=(IntProgress(value=0, description='Epoch 37', max=155), HTML(value='')))" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/plain": [ - "HBox(children=(IntProgress(value=0, description='Epoch 38', max=155), HTML(value='')))" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/plain": [ - "HBox(children=(IntProgress(value=0, description='Epoch 39', max=155), HTML(value='')))" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/plain": [ - "HBox(children=(IntProgress(value=0, max=4158), HTML(value='')))" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Tag accuracy: 0.9753200964928558\n", - "Sentence accuracy: 0.7633477633477633\n" - ] - }, - { - "data": { - "text/plain": [ - "HBox(children=(IntProgress(value=0, description='Epoch 40', max=155), HTML(value='')))" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/plain": [ - "HBox(children=(IntProgress(value=0, description='Epoch 41', max=155), HTML(value='')))" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/plain": [ - "HBox(children=(IntProgress(value=0, description='Epoch 42', max=155), HTML(value='')))" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/plain": [ - "HBox(children=(IntProgress(value=0, description='Epoch 43', max=155), HTML(value='')))" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/plain": [ - "HBox(children=(IntProgress(value=0, description='Epoch 44', max=155), HTML(value='')))" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/plain": [ - "HBox(children=(IntProgress(value=0, max=4158), HTML(value='')))" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Tag accuracy: 0.9755242159955465\n", - "Sentence accuracy: 0.765993265993266\n" - ] - }, - { - "data": { - "text/plain": [ - "HBox(children=(IntProgress(value=0, description='Epoch 45', max=155), HTML(value='')))" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/plain": [ - "HBox(children=(IntProgress(value=0, description='Epoch 46', max=155), HTML(value='')))" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/plain": [ - "HBox(children=(IntProgress(value=0, description='Epoch 47', max=155), HTML(value='')))" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/plain": [ - "HBox(children=(IntProgress(value=0, description='Epoch 48', max=155), HTML(value='')))" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/plain": [ - "HBox(children=(IntProgress(value=0, description='Epoch 49', max=155), HTML(value='')))" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/plain": [ - "HBox(children=(IntProgress(value=0, max=4158), HTML(value='')))" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Tag accuracy: 0.9751345333085916\n", - "Sentence accuracy: 0.7623857623857624\n", - "\n" - ] - }, - { - "data": { - "text/plain": [ - "" - ] - }, - "execution_count": 29, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "import math\n", - "\n", - "from keras.callbacks import ModelCheckpoint, TensorBoard\n", - "from keras_tqdm.tqdm_notebook_callback import TQDMNotebookCallback\n", - "\n", - "BATCH_SIZE = 256\n", - "\n", - "model.fit_generator(\n", - " epochs=50,\n", - " verbose=0,\n", - "\n", - " generator=iter_batches(train_data, train_labels, BATCH_SIZE),\n", - " steps_per_epoch=int(math.ceil(len(train_data) / BATCH_SIZE)),\n", - "\n", - " validation_data=iter_batches(test_data, test_labels, BATCH_SIZE),\n", - " validation_steps=int(math.ceil(len(test_data) / BATCH_SIZE)),\n", - "\n", - " callbacks=[\n", - " TQDMNotebookCallback(),\n", - " ModelEvaluation(),\n", - " ModelCheckpoint('{epoch:02d}-{val_loss:.2f}.h5', monitor='val_loss'),\n", - " TensorBoard(log_dir='rnn_logs'),\n", - " ],\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": 32, - "metadata": {}, - "outputs": [], - "source": [ - "from keras.models import load_model\n", - "\n", - "model_best = load_model('45-0.07.h5')\n", - "model_best.save('maru/model/rnn/tagger.h5')" - ] - }, - { - "cell_type": "code", - "execution_count": 33, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "['maru/model/rnn/extractor.joblib']" - ] - }, - "execution_count": 33, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "import pickle\n", - "from sklearn.externals import joblib\n", - "\n", - "joblib.dump(maru.feature.extractor.PymorphyExtractor(hypotheses=10000000), 'maru/model/rnn/extractor.joblib', compress=True, protocol=pickle.HIGHEST_PROTOCOL)" - ] - }, - { - "cell_type": "code", - "execution_count": 34, - "metadata": {}, - "outputs": [], - "source": [ - "import json\n", - "\n", - "with open('maru/model/rnn/grammeme_vocabulary.json', 'w', encoding='utf8') as f:\n", - " json.dump(grammeme_vocabulary, f, indent=4, ensure_ascii=False)" - ] - }, - { - "cell_type": "code", - "execution_count": 35, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "['maru/model/rnn/tags.joblib']" - ] - }, - "execution_count": 35, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "joblib.dump({int(num): tag for tag, num in tags.items()}, 'maru/model/rnn/tags.joblib', compress=True, protocol=pickle.HIGHEST_PROTOCOL)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "with open('maru/model/rnn/char_vocabulary.json', 'w', encoding='utf8') as f:\n", - " json.dump(char_vocabulary, f, indent=4, ensure_ascii=False)" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.6.2" - } - }, - "nbformat": 4, - "nbformat_minor": 2 -} diff --git a/setup.cfg b/setup.cfg index 1f309d3..6b34ce6 100644 --- a/setup.cfg +++ b/setup.cfg @@ -38,6 +38,7 @@ disable= E1130, ; Bad operand type for unary E203, ; Whitespace C0330, ; Wrong hanging indentation before block (add 4 spaces) + E0306, ; __repr__ does not return str (invalid-repr-returned) [mypy] python_version = 3.6 diff --git a/tests/tagger/base.py b/tests/tagger/base.py index a2bbd2b..5cfe222 100644 --- a/tests/tagger/base.py +++ b/tests/tagger/base.py @@ -1,16 +1,17 @@ -from typing import Sequence +from typing import NamedTuple, Optional, Sequence from maru.tagger.abstract import ITagger, Tagged from maru.types import Indices, Word -def assert_tags_equal( - tagger: ITagger, - expected: Sequence[Tagged], - words: Sequence[Word], - indices: Indices = None, -): - if indices is None: - indices = range(len(words)) +class TaggerTest(NamedTuple): + words: Sequence[Word] + tags: Sequence[Tagged] + indices: Optional[Indices] = None - assert expected == list(tagger.tag(words, indices)) + def run(self, tagger: ITagger): + indices = self.indices + if indices is None: + indices = range(len(self.words)) + + assert self.tags == list(tagger.tag(self.words, indices)) diff --git a/tests/tagger/test_crf.py b/tests/tagger/test_crf.py index c078ef2..7d5c67b 100644 --- a/tests/tagger/test_crf.py +++ b/tests/tagger/test_crf.py @@ -1,3 +1,5 @@ +import pytest + from maru.grammeme import ( Animacy, Case, @@ -9,34 +11,70 @@ ) from maru.tag import Tag from maru.tagger import CRFTagger -from tests.tagger.base import assert_tags_equal +from tests.tagger.base import TaggerTest + + +@pytest.fixture(name='tagger', scope='session') +def create_tagger(): + return CRFTagger() -def test(): - assert_tags_equal( - tagger=CRFTagger(), - expected=[ - ( - 0, - Tag( - pos=PartOfSpeech.ADJECTIVE, - case=Case.NOMINATIVE, - degree=Degree.POSITIVE, - gender=Gender.MASCULINE, - number=Number.SINGULAR, - variant=Variant.FULL, +@pytest.mark.parametrize( + 'test', + [ + TaggerTest( + words=['настоящий', 'детектив'], + tags=[ + ( + 0, + Tag( + pos=PartOfSpeech.ADJECTIVE, + case=Case.NOMINATIVE, + degree=Degree.POSITIVE, + gender=Gender.MASCULINE, + number=Number.SINGULAR, + variant=Variant.FULL, + ), ), - ), - ( - 1, - Tag( - pos=PartOfSpeech.NOUN, - animacy=Animacy.ANIMATE, - case=Case.NOMINATIVE, - gender=Gender.MASCULINE, - number=Number.SINGULAR, + ( + 1, + Tag( + pos=PartOfSpeech.NOUN, + animacy=Animacy.ANIMATE, + case=Case.NOMINATIVE, + gender=Gender.MASCULINE, + number=Number.SINGULAR, + ), ), - ), - ], - words=['настоящий', 'детектив'], - ) + ], + ), + TaggerTest( + words=['настоящий', 'робот'], + tags=[ + ( + 0, + Tag( + pos=PartOfSpeech.ADJECTIVE, + case=Case.NOMINATIVE, + degree=Degree.POSITIVE, + gender=Gender.MASCULINE, + number=Number.SINGULAR, + variant=Variant.FULL, + ), + ), + ( + 1, + Tag( + pos=PartOfSpeech.NOUN, + animacy=Animacy.INANIMATE, + case=Case.NOMINATIVE, + gender=Gender.MASCULINE, + number=Number.SINGULAR, + ), + ), + ], + ), + ], +) +def test_crf(test, tagger): + test.run(tagger) diff --git a/tests/tagger/test_linear.py b/tests/tagger/test_linear.py index 24ad9be..944ede4 100644 --- a/tests/tagger/test_linear.py +++ b/tests/tagger/test_linear.py @@ -1,3 +1,5 @@ +import pytest + from maru.grammeme import ( Animacy, Case, @@ -9,34 +11,70 @@ ) from maru.tag import Tag from maru.tagger import LinearTagger -from tests.tagger.base import assert_tags_equal +from tests.tagger.base import TaggerTest + + +@pytest.fixture(name='tagger', scope='session') +def create_tagger(): + return LinearTagger() -def test(): - assert_tags_equal( - tagger=LinearTagger(), - expected=[ - ( - 0, - Tag( - pos=PartOfSpeech.ADJECTIVE, - case=Case.NOMINATIVE, - degree=Degree.POSITIVE, - gender=Gender.NEUTER, - number=Number.SINGULAR, - variant=Variant.FULL, +@pytest.mark.parametrize( + 'test', + [ + TaggerTest( + words=['чёрное', 'зеркало'], + tags=[ + ( + 0, + Tag( + pos=PartOfSpeech.ADJECTIVE, + case=Case.NOMINATIVE, + degree=Degree.POSITIVE, + gender=Gender.NEUTER, + number=Number.SINGULAR, + variant=Variant.FULL, + ), ), - ), - ( - 1, - Tag( - pos=PartOfSpeech.NOUN, - animacy=Animacy.INANIMATE, - case=Case.NOMINATIVE, - gender=Gender.NEUTER, - number=Number.SINGULAR, + ( + 1, + Tag( + pos=PartOfSpeech.NOUN, + animacy=Animacy.INANIMATE, + case=Case.NOMINATIVE, + gender=Gender.NEUTER, + number=Number.SINGULAR, + ), ), - ), - ], - words=['чёрное', 'зеркало'], - ) + ], + ), + TaggerTest( + words=['чёрного', 'зеркала'], + tags=[ + ( + 0, + Tag( + pos=PartOfSpeech.ADJECTIVE, + case=Case.GENITIVE, + degree=Degree.POSITIVE, + gender=Gender.NEUTER, + number=Number.SINGULAR, + variant=Variant.FULL, + ), + ), + ( + 1, + Tag( + pos=PartOfSpeech.NOUN, + animacy=Animacy.INANIMATE, + case=Case.GENITIVE, + gender=Gender.NEUTER, + number=Number.SINGULAR, + ), + ), + ], + ), + ], +) +def test_linear(test, tagger): + test.run(tagger) diff --git a/tests/tagger/test_numerical.py b/tests/tagger/test_numerical.py index 90164c8..816cf50 100644 --- a/tests/tagger/test_numerical.py +++ b/tests/tagger/test_numerical.py @@ -1,49 +1,63 @@ +import pytest + from maru.grammeme import NumericalForm, PartOfSpeech from maru.tag import Tag from maru.tagger import NumericalTagger -from tests.tagger.base import assert_tags_equal +from tests.tagger.base import TaggerTest _INTEGER = Tag(pos=PartOfSpeech.NUMERICAL, numform=NumericalForm.INTEGER) _REAL = Tag(pos=PartOfSpeech.NUMERICAL, numform=NumericalForm.REAL) _RANGE = Tag(pos=PartOfSpeech.NUMERICAL, numform=NumericalForm.RANGE) -def test_integer(): - assert_tags_equal( - tagger=NumericalTagger(), - expected=[(0, _INTEGER), (1, _INTEGER)], - words=['123', '51515'], - ) - - -def test_real(): - assert_tags_equal( - tagger=NumericalTagger(), - expected=[(0, _REAL), (1, _REAL), (2, _REAL)], - words=['123.1231', '1231,34555', '2/3'], - ) - - -def test_indices(): - assert_tags_equal( - tagger=NumericalTagger(), - expected=[(0, _REAL), (2, _INTEGER)], - words=['1.1', '123', '567'], - indices=[0, 2], - ) - - -def test_numerical_range(): - assert_tags_equal( - tagger=NumericalTagger(), - expected=[(0, _RANGE), (1, _RANGE)], - words=['16-18', '1942—1944'], - ) - - -def test_non_numerical(): - assert_tags_equal( - tagger=NumericalTagger(), - expected=[], - words=['', ' ', '!!!!', 'XV', 'unknown', '<<123>>', '23years'], - ) +@pytest.fixture(name='tagger', scope='session') +def create_tagger(): + return NumericalTagger() + + +@pytest.mark.parametrize( + 'test', + [ + TaggerTest( + words=['123', '51515', '777'], + tags=[(0, _INTEGER), (1, _INTEGER), (2, _INTEGER)], + ), + TaggerTest( + words=['123.1231', '1231,34555', '2/3'], + tags=[(0, _REAL), (1, _REAL), (2, _REAL)], + ), + TaggerTest( + words=['1.1', '123', '567'], + tags=[(0, _REAL), (2, _INTEGER)], + indices=[0, 2], + ), + TaggerTest( + words=['1.1', '123', '567'], + tags=[(1, _INTEGER), (2, _INTEGER)], + indices=[1, 2], + ), + TaggerTest( + words=['16-18', '1942—1944', '1'], + tags=[(0, _RANGE), (1, _RANGE), (2, _INTEGER)], + ), + TaggerTest( + words=[ + '', + ' ', + '!!!!', + 'XV', + 'unknown', + '<<123>>', + '23years', + '-', + '.', + ',', + ',1', + '.2', + ], + tags=[], + ), + ], +) +def test_numerical(test, tagger): + test.run(tagger) diff --git a/tests/tagger/test_punctuation.py b/tests/tagger/test_punctuation.py index 9f4855d..2eaf115 100644 --- a/tests/tagger/test_punctuation.py +++ b/tests/tagger/test_punctuation.py @@ -3,31 +3,40 @@ from maru.grammeme import PartOfSpeech from maru.tag import Tag from maru.tagger.punctuation import PunctuationTagger -from tests.tagger.base import assert_tags_equal +from tests.tagger.base import TaggerTest _PUNCTUATION = Tag(pos=PartOfSpeech.PUNCTUATION) -@pytest.mark.parametrize( - ['word'], [['!'], ['@'], ['.....,'], ['?!'], ['"'], [':'], [';'], ['()'], ['%']] -) -def test_punctuation(word: str): - assert_tags_equal( - tagger=PunctuationTagger(), expected=[(0, _PUNCTUATION)], words=[word], - ) - +@pytest.fixture(name='tagger', scope='session') +def create_tagger(): + return PunctuationTagger() -@pytest.mark.parametrize(['word'], [['12313'], ['unknown'], ['XV'], [' '], ['']]) -def test_non_punctuation(word: str): - assert_tags_equal( - tagger=PunctuationTagger(), expected=[], words=[word], - ) - -def test_indices(): - assert_tags_equal( - tagger=PunctuationTagger(), - expected=[(1, _PUNCTUATION), (2, _PUNCTUATION)], - words=['?', ',', '!'], - indices=[1, 2], - ) +@pytest.mark.parametrize( + 'test', + [ + TaggerTest( + words=['!', '@', '.....,'], + tags=[(0, _PUNCTUATION), (1, _PUNCTUATION), (2, _PUNCTUATION)], + ), + TaggerTest( + words=['?!', '"', ':', ';'], + tags=[ + (0, _PUNCTUATION), + (1, _PUNCTUATION), + (2, _PUNCTUATION), + (3, _PUNCTUATION), + ], + ), + TaggerTest(words=['()', '%'], tags=[(0, _PUNCTUATION), (1, _PUNCTUATION)],), + TaggerTest( + words=['?', ',', '!'], + tags=[(1, _PUNCTUATION), (2, _PUNCTUATION)], + indices=[1, 2], + ), + TaggerTest(words=['12313', 'unknown', 'XV', ' ', ''], tags=[],), + ], +) +def test_punctuation(test, tagger): + test.run(tagger) diff --git a/tests/tagger/test_rnn.py b/tests/tagger/test_rnn.py index 2f72548..b04b6a8 100644 --- a/tests/tagger/test_rnn.py +++ b/tests/tagger/test_rnn.py @@ -1,3 +1,5 @@ +import pytest + from maru.grammeme import ( Animacy, Case, @@ -9,33 +11,68 @@ ) from maru.tag import Tag from maru.tagger import RNNTagger -from tests.tagger.base import assert_tags_equal +from tests.tagger.base import TaggerTest + + +@pytest.fixture(name='tagger', scope='session') +def create_tagger(): + return RNNTagger() -def test(): - assert_tags_equal( - tagger=RNNTagger(), - expected=[ - ( - 0, - Tag( - pos=PartOfSpeech.ADJECTIVE, - case=Case.NOMINATIVE, - degree=Degree.POSITIVE, - number=Number.PLURAL, - variant=Variant.FULL, +@pytest.mark.parametrize( + 'test', + [ + TaggerTest( + words=['необычные', 'дела'], + tags=[ + ( + 0, + Tag( + pos=PartOfSpeech.ADJECTIVE, + case=Case.NOMINATIVE, + degree=Degree.POSITIVE, + number=Number.PLURAL, + variant=Variant.FULL, + ), ), - ), - ( - 1, - Tag( - pos=PartOfSpeech.NOUN, - animacy=Animacy.INANIMATE, - case=Case.NOMINATIVE, - gender=Gender.NEUTER, - number=Number.PLURAL, + ( + 1, + Tag( + pos=PartOfSpeech.NOUN, + animacy=Animacy.INANIMATE, + case=Case.NOMINATIVE, + gender=Gender.NEUTER, + number=Number.PLURAL, + ), ), - ), - ], - words=['необычные', 'дела'], - ) + ], + ), + TaggerTest( + words=['необычных', 'дел'], + tags=[ + ( + 0, + Tag( + pos=PartOfSpeech.ADJECTIVE, + case=Case.GENITIVE, + degree=Degree.POSITIVE, + number=Number.PLURAL, + variant=Variant.FULL, + ), + ), + ( + 1, + Tag( + pos=PartOfSpeech.NOUN, + animacy=Animacy.INANIMATE, + case=Case.GENITIVE, + gender=Gender.NEUTER, + number=Number.PLURAL, + ), + ), + ], + ), + ], +) +def test_rnn(test, tagger): + test.run(tagger)