Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Dnnc with no finetuning #1630

Merged
merged 61 commits into from
Jun 30, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
61 commits
Select commit Hold shift + click to select a range
07e2cbb
Add mrpc binary head config
vaskonov Jul 18, 2022
43f8a9a
Fix binary head
vaskonov Jul 18, 2022
612c053
add few-shot infer support
nastyachizhikova Aug 23, 2022
9b49988
add few-shot metrics
nastyachizhikova Aug 23, 2022
40f9413
dnnc infer eval
Sep 21, 2022
ee36fe6
add preprocessor
Sep 21, 2022
a976260
init dnnc
Sep 28, 2022
a4dc57b
add dnnc training
Oct 19, 2022
45ee6e3
modified data processing
Nov 2, 2022
4209b05
fix imports
Nov 2, 2022
bbc83b7
change nli labels to strings
Nov 4, 2022
52e378f
add documentation
Nov 16, 2022
be99c31
fix conversion of labels to ids and ids to labels
Dec 5, 2022
612f08a
binary head dropout fix
Dec 5, 2022
50b7875
fix few-shot dos
Dec 5, 2022
96bc194
add return format flag
Dec 6, 2022
98cb660
add dataset and model downloading
Dec 6, 2022
b17791b
Fix: change paths
vaskonov Dec 7, 2022
d5c657d
Fix: change paths
vaskonov Dec 7, 2022
79ebde6
Fix: download paths
vaskonov Dec 7, 2022
a27b92a
remove skdlearn requirements
Dec 7, 2022
61db0e6
fex ix metrics
Dec 7, 2022
1d392b3
fix configs format
Dec 7, 2022
c188a55
Merge branch 'dnnc' of https://github.com/deeppavlov/DeepPavlov into …
Dec 7, 2022
39e6ad1
Fix: configs format and paths
Dec 7, 2022
dab193a
Upd: documentation
Dec 12, 2022
718d1de
Fix: typing
Dec 12, 2022
7cd75a5
Upd: add oos removal in iterator
Dec 12, 2022
e8cf40c
Fix: config format
Dec 12, 2022
72867ca
made the support dataset part of the input
Mar 7, 2023
c20d2d5
Fix: index.rst
Mar 13, 2023
5d7a198
Fix: index.rst
Mar 15, 2023
5742605
Fix: empty reference in docs
LogicZMaksimka Mar 27, 2023
6faccd3
Fix: metrics registry
LogicZMaksimka Mar 27, 2023
d711fa5
Fix: bidirectional scores averaging
LogicZMaksimka Mar 29, 2023
af23394
Fix: index.rst
LogicZMaksimka Mar 29, 2023
cdf0d9f
Conflicts resolved
LogicZMaksimka Mar 29, 2023
7702d65
refactor: minor style changes
IgnatovFedor Apr 7, 2023
fff9a9d
Fix: accuracy_oos arguments
LogicZMaksimka Apr 14, 2023
c705688
refactor: deleted a few-shot iterator that was not used anywhere
LogicZMaksimka Apr 19, 2023
a5975c8
Refactor: dnnc_preprocessor
LogicZMaksimka Apr 19, 2023
6344c34
Refactor: dnnc_proba2labels
LogicZMaksimka Apr 19, 2023
b4bd9f1
Refactor: config dnnc_infer
LogicZMaksimka Apr 19, 2023
799e64c
canceled changes in torch_transformers_classifier
LogicZMaksimka Apr 19, 2023
56ec9e3
Fix: removed few_shot_iterator from registry
LogicZMaksimka Apr 20, 2023
cbf89b3
Merge branch 'dev' with fixed bug
LogicZMaksimka Apr 28, 2023
255a425
fix: delete whitespaces
vaskonov Jun 14, 2023
8bdf2bd
fix: delete unused
vaskonov Jun 14, 2023
ebcff3a
fix: call arguments
vaskonov Jun 14, 2023
c6f57d0
fix: delete whitespaces
vaskonov Jun 14, 2023
dc55ea7
fix: remove unused
vaskonov Jun 14, 2023
243589b
fix: __call__ arguments
vaskonov Jun 14, 2023
ebcdcff
docs: optimizer few_shot_classification ipynb file
IgnatovFedor Jun 27, 2023
4537463
remove: trailing spaces
IgnatovFedor Jun 27, 2023
a05646b
fix: remove unused metrics
LogicZMaksimka Jun 27, 2023
628f1e3
remove: unused parameters
LogicZMaksimka Jun 28, 2023
f756dbe
docs: updated to new format
LogicZMaksimka Jun 28, 2023
f3783b5
Merge branch 'dev' into feat/dnnc_no_finetuning
LogicZMaksimka Jun 28, 2023
ca199dc
refactor: rename config
LogicZMaksimka Jun 28, 2023
fcc1d9d
docs: optimized few-shot classification doc
IgnatovFedor Jun 29, 2023
59348f5
feat: few-shot tests
IgnatovFedor Jun 29, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 56 additions & 0 deletions deeppavlov/configs/classifiers/few_shot_roberta.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
{
"chainer": {
"in": ["texts", "dataset"],
"in_y": ["y_true"],
"pipe": [
{
"class_name": "dnnc_pair_generator",
"in": ["texts", "dataset"],
"out": ["x", "x_support", "x_populated", "y_support"],
"bidirectional": true
},
{
"class_name": "torch_transformers_preprocessor",
"in": ["x_populated", "x_support"],
"out": ["bert_features"],
"vocab_file": "{BASE_MODEL}",
"do_lower_case": true,
"max_seq_length": 128
},
{
"class_name": "torch_transformers_classifier",
"main": true,
"in": ["bert_features"],
"out": ["simmilarity_scores"],
"n_classes": 2,
"return_probas": true,
"pretrained_bert": "{BASE_MODEL}",
"save_path": "{MODEL_PATH}/model",
"load_path": "{MODEL_PATH}/model",
"is_binary": "{BINARY_CLASSIFICATION}"
},
{
"class_name": "dnnc_proba2labels",
"is_binary": "{BINARY_CLASSIFICATION}",
"in": ["simmilarity_scores", "x", "x_populated", "x_support", "y_support"],
"out": ["y_pred"],
"confidence_threshold": 0.0
}
],
"out": ["y_pred"]
},
"metadata": {
"variables": {
"ROOT_PATH": "~/.deeppavlov",
"MODEL_PATH": "{ROOT_PATH}/models/fewshot/roberta_nli_mrpc_1_10",
"BINARY_CLASSIFICATION": true,
"BASE_MODEL": "roberta-base"
},
"download": [
{
"url": "http://files.deeppavlov.ai/v1/classifiers/fewshot/roberta_nli_mrpc_1_10.tar.gz",
"subdir": "{MODEL_PATH}"
}
]
}
}
2 changes: 2 additions & 0 deletions deeppavlov/core/common/registry.json
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
"dirty_comments_preprocessor": "deeppavlov.models.preprocessors.dirty_comments_preprocessor:DirtyCommentsPreprocessor",
"docred_reader": "deeppavlov.dataset_readers.docred_reader:DocREDDatasetReader",
"document_chunker": "deeppavlov.models.preprocessors.odqa_preprocessors:DocumentChunker",
"dnnc_pair_generator": "deeppavlov.models.preprocessors.dnnc_preprocessor:PairGenerator",
"dnnc_proba2labels": "deeppavlov.models.classifiers.dnnc_proba2labels:Proba2Labels",
"entity_detection_parser": "deeppavlov.models.entity_extraction.entity_detection_parser:EntityDetectionParser",
"entity_linker": "deeppavlov.models.entity_extraction.entity_linking:EntityLinker",
"entity_type_split": "deeppavlov.models.entity_extraction.entity_detection_parser:entity_type_split",
Expand Down
90 changes: 90 additions & 0 deletions deeppavlov/models/classifiers/dnnc_proba2labels.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
# Copyright 2017 Neural Networks and Deep Learning lab, MIPT
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from logging import getLogger
from typing import List

import numpy as np

from deeppavlov.core.common.registry import register
from deeppavlov.core.models.component import Component

log = getLogger(__name__)


@register('dnnc_proba2labels')
class Proba2Labels(Component):
"""
Converts pairwise simmilarity scores into class label

Args:
confidence_threshold: used to determine whether example belongs to one
of the classes in 'y_support' or not
pooling: strategy for averaging similarity scores for each label
is_binary: determines whether the similarity is a number or a probability vector
"""

IgnatovFedor marked this conversation as resolved.
Show resolved Hide resolved
def __init__(self,
confidence_threshold: float = 0.0,
pooling: str = 'max',
is_binary: bool = True,
**kwargs) -> None:

self.confidence_threshold = confidence_threshold
self.pooling = pooling
self.is_binary = is_binary

def __call__(self,
simmilarity_scores: List[float],
x: List[str],
x_populated: List[str],
x_support: List[str],
y_support: List[str]
) -> List[str]:

y_pred = []

simmilarity_scores = np.array(simmilarity_scores)
x_populated = np.array(x_populated)
x_support = np.array(x_support)
y_support = np.array(y_support)
unique_labels = np.unique(y_support)

# Transform probits vector into a simmilarity score
if not self.is_binary:
simmilarity_scores = simmilarity_scores[:, 1]

for example in x:
example_mask = np.where(np.logical_xor(x_populated == example, x_support == example))
example_simmilarity_scores = simmilarity_scores[example_mask]
example_y_support = y_support[example_mask]

probability_by_label = []
for label in unique_labels:
label_mask = np.where(example_y_support == label)
label_simmilarity_scores = example_simmilarity_scores[label_mask]
if self.pooling == 'avg':
label_probability = np.mean(label_simmilarity_scores)
elif self.pooling == 'max':
label_probability = np.max(label_simmilarity_scores)
probability_by_label.append(label_probability)

probability_by_label = np.array(probability_by_label)
max_probability = max(probability_by_label)
max_probability_label = unique_labels[np.argmax(probability_by_label)]
prediction = "oos" if max_probability < self.confidence_threshold else max_probability_label

y_pred.append(prediction)

return y_pred
55 changes: 55 additions & 0 deletions deeppavlov/models/preprocessors/dnnc_preprocessor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
# Copyright 2017 Neural Networks and Deep Learning lab, MIPT
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from logging import getLogger
from typing import List, Tuple

import numpy as np

from deeppavlov.core.common.registry import register
from deeppavlov.core.models.component import Component

log = getLogger(__name__)


@register('dnnc_pair_generator')
class PairGenerator(Component):
"""
Generates all possible ordered pairs from 'texts_batch' and 'support_dataset'

Args:
bidirectional: adds pairs in reverse order
"""

def __init__(self, bidirectional: bool = False, **kwargs) -> None:
self.bidirectional = bidirectional

def __call__(self,
texts: List[str],
dataset: List[List[str]],
) -> Tuple[List[str], List[str], List[str], List[str]]:
hypotesis_batch = []
premise_batch = []
hypotesis_labels_batch = []
for [premise, [hypotesis, hypotesis_labels]] in zip(texts * len(dataset),
np.repeat(dataset, len(texts), axis=0)):
premise_batch.append(premise)
hypotesis_batch.append(hypotesis)
hypotesis_labels_batch.append(hypotesis_labels)

if self.bidirectional:
premise_batch.append(hypotesis)
hypotesis_batch.append(premise)
hypotesis_labels_batch.append(hypotesis_labels)
return texts, hypotesis_batch, premise_batch, hypotesis_labels_batch
Loading