Enable batch processing in scriptable tokenizer example #2130

Merged · 2 commits · Feb 16, 2023
20 changes: 10 additions & 10 deletions examples/text_classification_with_scriptable_tokenizer/handler.py
@@ -1,6 +1,5 @@
 """
 Module for text classification with scriptable tokenizer
-DOES NOT SUPPORT BATCH!
 """
 import logging
 from abc import ABC
@@ -51,18 +50,19 @@ def preprocess(self, data):

         # Compat layer: normally the envelope should just return the data
         # directly, but older versions of Torchserve didn't have envelope.
-        # Processing only the first input, not handling batch inference
 
-        line = data[0]
-        text = line.get("data") or line.get("body")
-        # Decode text if not a str but bytes or bytearray
-        if isinstance(text, (bytes, bytearray)):
-            text = text.decode("utf-8")
+        text_batch = []
+        for line in data:
+            text = line.get("data") or line.get("body")
+            # Decode text if not a str but bytes or bytearray
+            if isinstance(text, (bytes, bytearray)):
+                text = text.decode("utf-8")
 
-        text = remove_html_tags(text)
-        text = text.lower()
+            text = remove_html_tags(text)
+            text = text.lower()
+            text_batch.append(text)
 
-        return text
+        return text_batch
 
     def inference(self, data, *args, **kwargs):
         """The Inference Request is made through this function and the user
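To show what the batched preprocess now returns, here is a minimal, self-contained sketch of the same loop run outside of TorchServe. The simplified remove_html_tags stand-in and the sample request payloads are illustrative assumptions, not part of the example's code:

import re

def remove_html_tags(text):
    # Simplified stand-in for the example's HTML-stripping helper
    return re.sub(r"<[^>]+>", "", text)

def preprocess(data):
    # Mirrors the loop added in this PR: one cleaned string per request in the batch
    text_batch = []
    for line in data:
        text = line.get("data") or line.get("body")
        # Decode text if not a str but bytes or bytearray
        if isinstance(text, (bytes, bytearray)):
            text = text.decode("utf-8")
        text = remove_html_tags(text)
        text = text.lower()
        text_batch.append(text)
    return text_batch

# Two requests arriving in a single TorchServe batch (illustrative payloads)
batch = [
    {"data": b"<p>A GREAT movie!</p>"},
    {"body": "Terrible plot, <b>bad</b> acting."},
]
print(preprocess(batch))
# ['a great movie!', 'terrible plot, bad acting.']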
@@ -76,7 +76,7 @@ def main(args):
     model = XLMR_BASE_ENCODER.get_model(head=classifier_head)
 
     # Load trained parameters and load them into the model
-    model.load_state_dict(torch.load(args.input_file))
+    model.load_state_dict(torch.load(args.input_file, map_location=torch.device("cpu")))
 
     # Chain the tokenizer, the adapter and the model
     combi_model = T.Sequential(
@@ -88,7 +88,7 @@
     combi_model.eval()
 
     # Make sure to move the model to CPU to avoid placement error during loading
-    combi_model.to("cpu")
+    combi_model.to(torch.device("cpu"))
 
     combi_model_jit = torch.jit.script(combi_model)

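For readers trying the change locally, here is a rough, self-contained sketch of the CPU-safe load-and-script flow shown above. The tiny nn.Linear model and the file names are placeholders standing in for the XLM-R encoder, classifier head, and checkpoint used by the example:

import torch
import torch.nn as nn

# Placeholder model standing in for XLMR_BASE_ENCODER + classifier head
model = nn.Sequential(nn.Linear(768, 2))

# Simulate a checkpoint written during training (possibly on a GPU machine)
torch.save(model.state_dict(), "model_checkpoint.pt")

# map_location forces every saved tensor onto the CPU, so loading also works
# on machines without CUDA, no matter where training ran
state_dict = torch.load("model_checkpoint.pt", map_location=torch.device("cpu"))
model.load_state_dict(state_dict)

model.eval()
model.to(torch.device("cpu"))  # keep all parameters on CPU before scripting

model_jit = torch.jit.script(model)
torch.jit.save(model_jit, "model_jit.pt")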