Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

DEID-2609 Added NERTextResponse class and changes to support the NER … #41

Merged
merged 9 commits into from
Jul 22, 2024
5 changes: 5 additions & 0 deletions src/privateai_client/components/pai_requests.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,11 @@ def process_text(self, request_object):
return self.make_request(
self.request_type, self.uris.process_text, request_object
)

def process_ner_text(self, request_object):
return self.make_request(
self.request_type, self.uris.process_ner_text, request_object
AmirPAI marked this conversation as resolved.
Show resolved Hide resolved
)

def process_files_uri(self, request_object):
return self.make_request(
Expand Down
21 changes: 21 additions & 0 deletions src/privateai_client/components/pai_responses.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,27 @@ def get_reidentify_request(self):
return ReidentifyTextRequest(self.processed_text, entities)


class NerTextResponse(BaseResponse):
def __init__(self, response_object: Response = None):
super(NerTextResponse, self).__init__(response_object, True)

@property
def entities(self):
return self.get_attribute_entries("entities")

@property
def entities_present(self):
return self.get_attribute_entries("entities_present")

@property
def characters_processed(self):
return self.get_attribute_entries("characters_processed")

@property
def languages_detected(self):
return self.get_attribute_entries("languages_detected")


class TextResponse(DemiTextResponse):
def __init__(self, response_object: Response = None):
super(TextResponse, self).__init__(response_object)
Expand Down
31 changes: 31 additions & 0 deletions src/privateai_client/components/request_objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -759,6 +759,37 @@ def fromdict(cls, values: dict):
)


class NerTextRequest(BaseRequestObject):
default_link_batch = False
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't see this being used. Do you want the initializer to default link_batch to this?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hey @a-guiducci , Thanks for the approval.
The NerText has most of its properties similar to ProcessText, so I have set the initializer to the same default values.
@guyd can you verify this?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah I see what you're saying. Ok np. I think the default value for process text was removed at some point but never really cleaned up :/

If you don't mind, let's remove the default_link_batch from the NerTextRequest because we'll never use it 🙏

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh ok. Sure


def __init__(
self,
text: List[str],
link_batch: Optional[bool] = None,
entity_detection: Optional[EntityDetection] = None,
project_id: Optional[str] = None,
):
self.text = text
self.link_batch = link_batch
self.entity_detection = entity_detection
self.project_id = project_id

@classmethod
def fromdict(cls, values: dict):
try:
initializer_dict = {}
for key, value in values.items():
if key == "entity_detection":
initializer_dict[key] = EntityDetection.fromdict(value)
else:
initializer_dict[key] = value
return cls._fromdict(initializer_dict)
except TypeError:
raise TypeError(
"NerTextRequest can only accept the values 'text', 'link_batch' and 'entity_detection'"
)


class ProcessFileUriRequest(BaseRequestObject):
def __init__(
self,
Expand Down
112 changes: 112 additions & 0 deletions src/privateai_client/tests/test_request_objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -794,6 +794,118 @@ def test_process_text_request_to_dict():
assert process_text_request["processed_text"]["pattern"] == processed_text.pattern


# NER Text Request Tests
def test_ner_text_request_default_initializer():
ner_text_request = NerTextRequest(text=["hey"])
assert ner_text_request.text == ["hey"]
assert ner_text_request.link_batch is None
assert ner_text_request.entity_detection is None


def test_ner_text_request_initializer():
text = ["hey!"]
link_batch = True
entity_type = EntityTypeSelector(type="ENABLE", value=["NAME"])
filter = FilterSelector(type="ALLOW", pattern="hey")
entity_detection = EntityDetection(
accuracy="standard",
entity_types=[entity_type],
filter=[filter],
return_entity=False,
)

ner_text_request = NerTextRequest(
text=text,
link_batch=link_batch,
entity_detection=entity_detection,
)

assert ner_text_request.text == text
assert ner_text_request.link_batch == link_batch
assert ner_text_request.entity_detection.accuracy == entity_detection.accuracy
assert ner_text_request.entity_detection.entity_types[0].type == entity_type.type
assert ner_text_request.entity_detection.entity_types[0].value == entity_type.value
assert ner_text_request.entity_detection.filter[0].type == filter.type
assert ner_text_request.entity_detection.filter[0].pattern == filter.pattern


def test_ner_text_request_initialize_fromdict():
request_obj = {
"text": ["hey!"],
"link_batch": False,
"entity_detection": {
"accuracy": "standard",
"entity_types": [{"type": "DISABLE", "value": ["LOCATION"]}],
"filter": [{"type": "BLOCK", "pattern": "Roger", "entity_type": "TEST"}],
"return_entity": False,
},
}
ner_text_request = NerTextRequest.fromdict(request_obj)
assert ner_text_request.text == request_obj["text"]
assert ner_text_request.link_batch == request_obj["link_batch"]
assert ner_text_request.entity_detection.accuracy == request_obj["entity_detection"]["accuracy"]
assert (
ner_text_request.entity_detection.entity_types[0].type
== request_obj["entity_detection"]["entity_types"][0]["type"]
)
assert (
ner_text_request.entity_detection.entity_types[0].value
== request_obj["entity_detection"]["entity_types"][0]["value"]
)
assert ner_text_request.entity_detection.filter[0].type == request_obj["entity_detection"]["filter"][0]["type"]
assert (
ner_text_request.entity_detection.filter[0].pattern
== request_obj["entity_detection"]["filter"][0]["pattern"]
)


def test_ner_text_request_invalid_initialize_fromdict():
error_msg = (
"NerTextRequest can only accept the values 'text', 'link_batch' and 'entity_detection'"
)
request_obj = {
"text": ["hey!"],
"link_batch": False,
"entity_detection": {
"accuracy": "standard",
"entity_types": [{"type": "DISABLE", "value": ["LOCATION"]}],
"filter": [{"type": "BLOCK", "pattern": "Roger"}],
"return_entity": False,
},
"junk": "value",
}
with pytest.raises(TypeError) as excinfo:
NerTextRequest.fromdict(request_obj)
assert error_msg in str(excinfo.value)


def test_ner_text_request_to_dict():
text = ["hey!"]
link_batch = True
entity_type = EntityTypeSelector(type="ENABLE", value=["NAME"])
filter = FilterSelector(type="ALLOW", pattern="hey")
entity_detection = EntityDetection(
accuracy="standard",
entity_types=[entity_type],
filter=[filter],
return_entity=False,
)

ner_text_request = NerTextRequest(
text=text,
link_batch=link_batch,
entity_detection=entity_detection,
).to_dict()
print(ner_text_request)
assert ner_text_request["text"] == text
assert ner_text_request["link_batch"] == link_batch
assert ner_text_request["entity_detection"]["accuracy"] == entity_detection.accuracy
assert ner_text_request["entity_detection"]["entity_types"][0]["type"] == entity_type.type
assert ner_text_request["entity_detection"]["entity_types"][0]["value"] == entity_type.value
assert ner_text_request["entity_detection"]["filter"][0]["type"] == filter.type
assert ner_text_request["entity_detection"]["filter"][0]["pattern"] == filter.pattern


# Process File URI Request Tests
def test_process_file_uri_request_default_initializer():
process_file_uri_obj = ProcessFileUriRequest(uri="this/location/right/here.png")
Expand Down