From cfd214bf68a081c063005384dffc7109646d38c8 Mon Sep 17 00:00:00 2001 From: Nan Wang Date: Wed, 27 Jul 2022 09:35:14 +0800 Subject: [PATCH 1/2] feat: replace traversal_paths with access_paths --- dpr_text.py | 24 ++++++++++++++++++------ tests/unit/test_encoder.py | 6 +++--- 2 files changed, 21 insertions(+), 9 deletions(-) diff --git a/dpr_text.py b/dpr_text.py index 805bbe5..03274bf 100644 --- a/dpr_text.py +++ b/dpr_text.py @@ -3,6 +3,7 @@ import torch from jina import DocumentArray, Executor, requests from jina.logging.logger import JinaLogger +from jina.helper import deprecated_alias from transformers import (DPRContextEncoder, DPRContextEncoderTokenizerFast, DPRQuestionEncoder, DPRQuestionEncoderTokenizerFast) @@ -17,6 +18,7 @@ class DPRTextEncoder(Executor): encoding method used in model training. """ + @deprecated_alias(traversal_paths=('access_paths', 0)) def __init__( self, pretrained_model_name_or_path: str = "facebook/dpr-question_encoder-single-nq-base", @@ -24,7 +26,7 @@ def __init__( base_tokenizer_model: Optional[str] = None, title_tag_key: Optional[str] = None, max_length: Optional[int] = None, - traversal_paths: str = "@r", + access_paths: str = "@r", batch_size: int = 32, device: str = "cpu", *args, @@ -45,8 +47,8 @@ def __init__( tag property. It is recommended to set this property for context encoders, to match the model pre-training. It has no effect for question encoders. :param max_length: Max length argument for the tokenizer - :param traversal_paths: Default traversal paths for encoding, used if the - traversal path is not passed as a parameter with the request. + :param access_paths: Default access paths for encoding, used if the + access path is not passed as a parameter with the request. :param batch_size: Default batch size for encoding, used if the batch size is not passed as a parameter with the request. :param device: The device (cpu or gpu) that the model should be on. @@ -102,7 +104,7 @@ def __init__( self.model = self.model.to(self.device).eval() - self.traversal_paths = traversal_paths + self.access_paths = access_paths self.batch_size = batch_size @requests @@ -117,13 +119,23 @@ def encode( ``text`` attribute. :param parameters: dictionary to define the ``traversal_path`` and the ``batch_size``. For example, - ``parameters={'traversal_paths': '@r', 'batch_size': 10}`` + ``parameters={'access_paths': '@r', 'batch_size': 10}`` """ + access_paths = parameters.get('traversal_paths', None) + if access_paths is not None: + import warnings + warnings.warn( + f'`traversal_paths` is renamed to `access_paths` with the same usage, please use the latter instead. ' + f'The old function will be removed soon.', + DeprecationWarning, + ) + parameters['access_paths'] = access_paths + document_batches_generator = DocumentArray( filter( lambda x: bool(x.text), - docs[parameters.get("traversal_paths", self.traversal_paths)], + docs[parameters.get("access_paths", self.access_paths)], ) ).batch(batch_size=parameters.get("batch_size", self.batch_size)) diff --git a/tests/unit/test_encoder.py b/tests/unit/test_encoder.py index 7b34673..b4769ee 100644 --- a/tests/unit/test_encoder.py +++ b/tests/unit/test_encoder.py @@ -83,7 +83,7 @@ def test_encoding_gpu(): @pytest.mark.parametrize( - "traversal_paths, counts", + "access_paths, counts", [ ("@r", [["@r", 1], ["@c", 0], ["@cc", 0]]), ("@c", [["@r", 0], ["@c", 3], ["@cc", 0]]), @@ -92,7 +92,7 @@ def test_encoding_gpu(): ], ) def test_traversal_path( - traversal_paths: List[str], counts: List, basic_encoder: DPRTextEncoder + access_paths: List[str], counts: List, basic_encoder: DPRTextEncoder ): text = "blah" docs = DocumentArray([Document(id="root1", text=text)]) @@ -106,7 +106,7 @@ def test_traversal_path( Document(id="chunk112", text=text), ] - basic_encoder.encode(docs=docs, parameters={"traversal_paths": traversal_paths}) + basic_encoder.encode(docs=docs, parameters={"access_paths": access_paths}) for path, count in counts: embeddings = docs[path].embeddings From 55ace8a4e9dd7b805eb543bf4a91d866608b7bca Mon Sep 17 00:00:00 2001 From: Nan Wang Date: Wed, 27 Jul 2022 17:45:51 +0800 Subject: [PATCH 2/2] fix: put traversal_paths back --- dpr_text.py | 17 +++++++++++++---- 1 file changed, 13 insertions(+), 4 deletions(-) diff --git a/dpr_text.py b/dpr_text.py index 03274bf..8258c00 100644 --- a/dpr_text.py +++ b/dpr_text.py @@ -3,7 +3,6 @@ import torch from jina import DocumentArray, Executor, requests from jina.logging.logger import JinaLogger -from jina.helper import deprecated_alias from transformers import (DPRContextEncoder, DPRContextEncoderTokenizerFast, DPRQuestionEncoder, DPRQuestionEncoderTokenizerFast) @@ -18,7 +17,6 @@ class DPRTextEncoder(Executor): encoding method used in model training. """ - @deprecated_alias(traversal_paths=('access_paths', 0)) def __init__( self, pretrained_model_name_or_path: str = "facebook/dpr-question_encoder-single-nq-base", @@ -26,7 +24,8 @@ def __init__( base_tokenizer_model: Optional[str] = None, title_tag_key: Optional[str] = None, max_length: Optional[int] = None, - access_paths: str = "@r", + access_paths: str = '@r', + traversal_paths: Optional[str] = None, batch_size: int = 32, device: str = "cpu", *args, @@ -49,6 +48,7 @@ def __init__( :param max_length: Max length argument for the tokenizer :param access_paths: Default access paths for encoding, used if the access path is not passed as a parameter with the request. + :param traversal_paths: Please use `access_paths` :param batch_size: Default batch size for encoding, used if the batch size is not passed as a parameter with the request. :param device: The device (cpu or gpu) that the model should be on. @@ -105,6 +105,15 @@ def __init__( self.model = self.model.to(self.device).eval() self.access_paths = access_paths + if traversal_paths is not None: + import warnings + warnings.warn( + f'`traversal_paths` is renamed to `access_paths` with the same usage, please use the latter instead. ' + f'`traversal_paths` will be removed soon.', + DeprecationWarning, + ) + self.access_paths = traversal_paths + self.batch_size = batch_size @requests @@ -127,7 +136,7 @@ def encode( import warnings warnings.warn( f'`traversal_paths` is renamed to `access_paths` with the same usage, please use the latter instead. ' - f'The old function will be removed soon.', + f'`traversal_paths` will be removed soon.', DeprecationWarning, ) parameters['access_paths'] = access_paths