Merge pull request #165 from deepmipt/dev
Release v0.1.9
Showing 51 changed files with 4,306 additions and 213 deletions.
New file (4 additions) — ignore patterns:
```
metrics/*
data/*
*.ipynb
```
New Dockerfile (29 additions):
```dockerfile
FROM deeppavlov/base-gpu:0.17.2

RUN apt-key del 7fa2af80 && \
    rm -f /etc/apt/sources.list.d/cuda*.list && \
    wget -q https://developer.download.nvidia.com/compute/cuda/repos/ubuntu1804/x86_64/cuda-keyring_1.0-1_all.deb && \
    dpkg -i cuda-keyring_1.0-1_all.deb

RUN apt-get update && apt-get install -y --allow-unauthenticated wget && rm -rf /var/lib/apt/lists/*

WORKDIR /src

ARG CONFIG_NAME
ENV CONFIG_NAME ${CONFIG_NAME}
ARG SERVICE_PORT
ENV SERVICE_PORT ${SERVICE_PORT}

COPY annotators/IntentCatcherTransformers/requirements.txt /src/requirements.txt
RUN pip install -r /src/requirements.txt

COPY ./common/ ./common/
COPY annotators/IntentCatcherTransformers/ /src
WORKDIR /src

RUN python -m deeppavlov install ${CONFIG_NAME}
RUN python -m deeppavlov download ${CONFIG_NAME}
RUN python train_model_if_not_exist.py

CMD gunicorn --workers=1 server:app -b 0.0.0.0:${SERVICE_PORT} --timeout=300
```
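The image bakes model training into the build (`train_model_if_not_exist.py`) and then serves the Flask app with gunicorn. A minimal sketch of querying the running container from Python; the endpoint path and payload shape are assumptions, not taken from this diff (check `server.py` for the real API), and port 8014 merely stands in for `${SERVICE_PORT}`:

```python
# Hypothetical client call; endpoint path, payload shape, and port are
# illustrative assumptions, not taken from this diff.
import requests

resp = requests.post(
    "http://localhost:8014/detect",  # 8014 stands in for ${SERVICE_PORT}
    json={"sentences": ["hi there alexa"]},
    timeout=300,
)
print(resp.json())
```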
New file (13 additions) — README for the annotator:

## IntentCatcher based on Transformers

The English version was trained on the `intent_phrases.json` dataset with the `DeepPavlov` library via:
```
python -m deeppavlov train intents_model_dp_config.json
```

Fine-tuning consumes 3.5 GB of GPU RAM. Classification results after 5 epochs are as follows:
```json
{"train": {"eval_examples_count": 209297, "metrics": {"accuracy": 0.9997, "f1_weighted": 1.0, "f1_macro": 0.9999, "roc_auc": 1.0}, "time_spent": "0:03:46"}}
{"valid": {"eval_examples_count": 52325, "metrics": {"accuracy": 0.9995, "f1_weighted": 0.9999, "f1_macro": 0.9999, "roc_auc": 1.0}, "time_spent": "0:00:57"}}
```
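After training, the same config can be loaded for inference through DeepPavlov's standard Python API. A minimal sketch (the sample phrase is illustrative):

```python
from deeppavlov import build_model

# download=True fetches the pretrained model archive listed in the
# config's metadata.download section instead of training from scratch.
model = build_model("intents_model_dp_config.json", download=True)

# The pipeline's outputs are ["y_pred_labels", "y_pred_probas"].
labels, probas = model(["hi there alexa!"])
print(labels)
```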
annotators/IntentCatcherTransformers/intent_phrases.json — 1,594 additions, 0 deletions. Large diffs are not rendered by default.
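Although the phrase file is too large to render, its expected shape can be reconstructed from how `intents_dataset_reader.py` (next file) consumes it. A hypothetical minimal entry, for orientation only:

```python
# Hypothetical sketch of the intent_phrases.json structure, inferred
# from the keys read by IntentsJsonReader below; not the real file.
example = {
    "intent_phrases": {
        "intent_0": {
            "phrases": ["(alexa ){0,1}(hi|hello)(( there)|( alexa)){0,1}"],
            "reg_phrases": ["hi", "hello"],
            "punctuation": [".", "!"],
        }
    },
    "random_phrases": {"phrases": ["..."], "punctuation": [".", "!"]},
}
```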
annotators/IntentCatcherTransformers/intents_dataset_reader.py — 109 additions:

````python
from collections import OrderedDict
from itertools import chain
from pathlib import Path
from random import shuffle
from typing import Optional

import json
from deeppavlov.core.data.dataset_reader import DatasetReader
from xeger import Xeger


class IntentsJsonReader(DatasetReader):
    """
    Class provides reading intents dataset in .json format:
    ```json
    {
        "intent_phrases": {
            "intent_0": {
                "phrases": [
                    "(alexa ){0,1}(hi|hello)(( there)|( alexa)){0,1}"
                ],
                "reg_phrases": [
                    "hi",
                    "hello"
                ],
                "punctuation": [
                    ".",
                    "!"
                ]
            }
        }
    }
    ```
    to make it compatible with classification models in DeepPavlov pipelines:
    ```json
    [
        ("alexa hi", "intent_0"),
        ...,
    ]
    ```
    """

    @staticmethod
    def generate_phrases(template_re, punctuation, limit=2500):
        # Sample up to `limit` concrete phrases per regex template,
        # then append every punctuation variant of each phrase.
        x = Xeger(limit=limit)
        phrases = []
        for regex in template_re:
            try:
                phrases += list({x.xeger(regex) for _ in range(limit)})
            except Exception as e:
                print(e)
                print(regex)
                raise e
        phrases = [phrases] + [[phrase + punct for phrase in phrases] for punct in punctuation]
        return list(chain.from_iterable(phrases))

    def read(self, data_path: str, generated_data_path: Optional[str] = None, *args, **kwargs) -> dict:
        """
        Read dataset from `data_path` file with extension `.json`.

        Args:
            data_path: file with `.json` extension
            generated_data_path: optional directory for caching generated samples

        Returns:
            dictionary with data samples.
            Each field of dictionary is a list of tuples (x_i, y_i)
            where `x_i` is a text sample, `y_i` is a class name
        """
        data_types = ["train", "valid", "test"]
        data = {data_type: [] for data_type in data_types}

        for data_type in data_types:
            file_name = kwargs.get(data_type, f"{data_type}.json")
            if file_name is None:
                continue

            file = Path(data_path).joinpath(file_name)
            if file.exists():
                if generated_data_path and Path(generated_data_path).joinpath(file_name).exists():
                    # Reuse previously generated samples instead of re-sampling from the regexes.
                    with open(Path(generated_data_path).joinpath(file_name), "r") as fp:
                        data[data_type] = json.load(fp)
                else:
                    with open(file, "r") as fp:
                        all_data = json.load(fp)
                    intent_phrases = OrderedDict(all_data["intent_phrases"])
                    random_phrases = all_data["random_phrases"]
                    random_phrases = self.generate_phrases(random_phrases["phrases"], random_phrases["punctuation"])

                    intent_data = {}
                    for intent, intent_samples in intent_phrases.items():
                        phrases = self.generate_phrases(intent_samples["phrases"], intent_samples["punctuation"])
                        intent_data[intent] = {
                            "generated_phrases": phrases,
                            "num_punctuation": len(intent_samples["punctuation"]),
                        }

                    # Positive samples carry their intent label; random phrases get an empty label list.
                    data[data_type] = [
                        (gen_phrase, [intent])
                        for intent in intent_phrases.keys()
                        for gen_phrase in intent_data[intent]["generated_phrases"]
                    ]
                    data[data_type] += [(gen_phrase, []) for gen_phrase in random_phrases]
                    shuffle(data[data_type])
                    if generated_data_path:
                        Path(generated_data_path).mkdir(exist_ok=True)
                        with open(Path(generated_data_path).joinpath(file_name), "w") as fp:
                            json.dump(data[data_type], fp, indent=2)
            elif data_type == "train":
                raise FileNotFoundError(f"Train file `{file_name}` is not provided in `{data_path}`.")

        return data
````
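The generation step relies on the `xeger` package, which samples concrete strings from a regular expression. A small standalone sketch using the template from the docstring above:

```python
from xeger import Xeger

# limit caps how far unbounded repetitions are expanded.
x = Xeger(limit=10)
template = "(alexa ){0,1}(hi|hello)(( there)|( alexa)){0,1}"
samples = {x.xeger(template) for _ in range(20)}
print(samples)  # e.g. {'hi', 'hello there', 'alexa hello', ...}
```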
annotators/IntentCatcherTransformers/intents_model_dp_config.json — 190 additions:

```json
{
  "dataset_reader": {
    "class_name": "intents_dataset_reader:IntentsJsonReader",
    "data_path": "./",
    "train": "intent_phrases.json",
    "generated_data_path": "./generated_data"
  },
  "dataset_iterator": {
    "class_name": "basic_classification_iterator",
    "seed": 42,
    "split_seed": 23,
    "field_to_split": "train",
    "split_fields": ["train", "valid"],
    "split_proportions": [0.8, 0.2]
  },
  "chainer": {
    "in": ["x"],
    "in_y": ["y"],
    "pipe": [
      {
        "class_name": "torch_transformers_preprocessor",
        "vocab_file": "{TRANSFORMER}",
        "do_lower_case": true,
        "max_seq_length": 64,
        "in": ["x"],
        "out": ["bert_features"]
      },
      {
        "id": "classes_vocab",
        "class_name": "simple_vocab",
        "fit_on": ["y"],
        "save_path": "{MODEL_PATH}/classes.dict",
        "load_path": "{MODEL_PATH}/classes.dict",
        "in": ["y"],
        "out": ["y_ids"]
      },
      {
        "id": "my_one_hotter",
        "in": ["y_ids"],
        "out": ["y_onehot"],
        "class_name": "one_hotter",
        "depth": "#classes_vocab.len",
        "single_vector": true
      },
      {
        "class_name": "torch_transformers_classifier",
        "n_classes": "#classes_vocab.len",
        "return_probas": true,
        "one_hot_labels": true,
        "multilabel": true,
        "pretrained_bert": "{TRANSFORMER}",
        "save_path": "{MODEL_PATH}/model",
        "load_path": "{MODEL_PATH}/model",
        "optimizer": "AdamW",
        "optimizer_parameters": {"lr": 1e-05},
        "learning_rate_drop_patience": 5,
        "learning_rate_drop_div": 2.0,
        "in": ["bert_features"],
        "in_y": ["y_onehot"],
        "out": ["y_pred_probas"]
      },
      {
        "in": ["y_pred_probas"],
        "out": ["y_pred_ids"],
        "class_name": "proba2labels",
        "max_proba": false,
        "confidence_threshold": 0.5
      },
      {
        "ref": "my_one_hotter",
        "in": ["y_pred_ids"],
        "out": ["y_pred_onehot"]
      },
      {
        "in": ["y_pred_ids"],
        "out": ["y_pred_labels"],
        "ref": "classes_vocab"
      }
    ],
    "out": ["y_pred_labels", "y_pred_probas"]
  },
  "train": {
    "epochs": 5,
    "batch_size": 64,
    "metrics": [
      {"name": "accuracy", "inputs": ["y", "y_pred_labels"]},
      {"name": "f1_weighted", "inputs": ["y_onehot", "y_pred_onehot"]},
      {"name": "f1_macro", "inputs": ["y_onehot", "y_pred_onehot"]},
      {"name": "roc_auc", "inputs": ["y_onehot", "y_pred_probas"]}
    ],
    "validation_patience": 5,
    "val_every_n_epochs": 1,
    "log_every_n_epochs": 1,
    "show_examples": false,
    "evaluation_targets": ["train", "valid"],
    "class_name": "torch_trainer"
  },
  "metadata": {
    "imports": ["intents_dataset_reader"],
    "variables": {
      "TRANSFORMER": "distilbert-base-uncased",
      "ROOT_PATH": "~/.deeppavlov",
      "DOWNLOADS_PATH": "{ROOT_PATH}/downloads",
      "MODELS_PATH": "{ROOT_PATH}/models",
      "MODEL_PATH": "{MODELS_PATH}/classifiers/intents_model_v2"
    },
    "download": [
      {
        "url": "http://files.deeppavlov.ai/deeppavlov_data/intents_model_v2.tar.gz",
        "subdir": "{MODELS_PATH}/classifiers"
      }
    ]
  }
}
```
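Since the classifier runs with `multilabel: true` and `proba2labels` uses `max_proba: false`, decoding keeps every class whose probability clears `confidence_threshold` (0.5), so an utterance can carry several intents or none. A sketch of that decoding rule (not DeepPavlov's actual implementation):

```python
import numpy as np

def threshold_decode(probas: np.ndarray, threshold: float = 0.5) -> list:
    """Per row, return the class ids whose probability exceeds threshold."""
    return [np.where(row > threshold)[0].tolist() for row in probas]

print(threshold_decode(np.array([[0.9, 0.1, 0.7],
                                 [0.2, 0.3, 0.4]])))
# -> [[0, 2], []]
```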
New file (15 additions) — the `requirements.txt` installed by the Dockerfile:
```
transformers==4.6.0
sentencepiece==0.1.94
flask==1.1.1
itsdangerous==2.0.1
gunicorn==19.9.0
requests==2.22.0
sentry-sdk[flask]==0.14.1
healthcheck==1.3.3
jinja2<=3.0.3
Werkzeug<=2.0.3
pandas==0.25.3
huggingface-hub==0.0.8
datasets==1.11.0
scikit-learn==0.21.2
xeger==0.3.5
```