Skip to content

Commit

Permalink
added concurrency batch aggregator fix
Browse files Browse the repository at this point in the history
  • Loading branch information
Ubuntu committed Oct 7, 2021
1 parent 1e5b901 commit 7e751c5
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 29 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -63,12 +63,13 @@ public void sendResponse(ModelWorkerResponse message) {
// this is from initial load.
return;
}

for (Predictions prediction : message.getPredictions()) {
String jobId = prediction.getRequestId();
Job job = jobs.remove(jobId);
Job job = jobs.get(jobId);

if (job == null) {
throw new IllegalStateException("Unexpected job: " + jobId);
throw new IllegalStateException(
"Unexpected job in sendResponse() with 200 status code: " + jobId);
}
job.response(
prediction.getResp(),
Expand All @@ -77,18 +78,19 @@ public void sendResponse(ModelWorkerResponse message) {
prediction.getReasonPhrase(),
prediction.getHeaders());
}

} else {
for (String reqId : jobs.keySet()) {
Job j = jobs.remove(reqId);
if (j == null) {
throw new IllegalStateException("Unexpected job: " + reqId);
for (Map.Entry<String, Job> j : jobs.entrySet()) {

if (j.getValue() == null) {
throw new IllegalStateException(
"Unexpected job in sendResponse() with non 200 status code: "
+ j.getKey());
}
j.sendError(message.getCode(), message.getMessage());
}
if (!jobs.isEmpty()) {
throw new IllegalStateException("Not all jobs get response.");
j.getValue().sendError(message.getCode(), message.getMessage());
}
}
jobs.clear();
}

public void sendError(BaseModelRequest message, String error, int status) {
Expand All @@ -103,20 +105,20 @@ public void sendError(BaseModelRequest message, String error, int status) {
String requestId = req.getRequestId();
Job job = jobs.remove(requestId);
if (job == null) {
logger.error("Unexpected job: " + requestId);
logger.error("Unexpected job in sendError(): " + requestId);
} else {
job.sendError(status, error);
}
}
if (!jobs.isEmpty()) {
jobs.clear();
logger.error("Not all jobs get response.");
logger.error("Not all jobs got an error response.");
}
} else {
// Send the error message to all the jobs
for (Map.Entry<String, Job> j : jobs.entrySet()) {
String jobsId = j.getValue().getJobId();
Job job = jobs.remove(jobsId);
Job job = jobs.get(jobsId);

if (job.isControlCmd()) {
job.sendError(status, error);
Expand All @@ -127,5 +129,6 @@ public void sendError(BaseModelRequest message, String error, int status) {
}
}
}
jobs.clear();
}
}
}
24 changes: 10 additions & 14 deletions test/pytest/test_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
import requests
import json
import test_utils
import asyncio
import multiprocessing
import numpy as np
import ast
REPO_ROOT = os.path.join(os.path.dirname(os.path.abspath(__file__)), "../../")
Expand Down Expand Up @@ -226,8 +228,8 @@ def test_kfserving_mnist_model_register_and_inference_on_valid_model_explain():
test_utils.unregister_model("mnist")

def test_huggingface_bert_batch_inference():
batch_size = 4
batch_delay = 20000 # 20 seconds
batch_size = 2
batch_delay = 10000 # 10 seconds
params = (
('model_name', 'BERTSeqClassification'),
('url', 'https://bert-mar-file.s3.us-west-2.amazonaws.com/BERTSeqClassification.mar'),
Expand All @@ -238,18 +240,12 @@ def test_huggingface_bert_batch_inference():
test_utils.start_torchserve(no_config_snapshots=True)
test_utils.register_model_with_params(params)
input_text = os.path.join(REPO_ROOT, 'examples/Huggingface_Transformers/Seq_classification_artifacts/sample_text.txt')
files = {
'data': (input_text,
open(input_text, 'rb')),
}

for _ in range(batch_size):
response = run_inference_using_url_with_data(TF_INFERENCE_API + '/v1/models/BERTSeqClassification:predict', pfiles=files)

response = response.content
# response = ast.literal_eval(response)
# custom handler returns number of responses not the actual responses
assert int(response) == batch_size

response = os.popen(f"curl http://127.0.0.1:8080/predictions/BERTSeqClassification -T {input_text} & curl http://127.0.0.1:8080/predictions/BERTSeqClassification -T {input_text}")

# handler responds with number of inferences
response = response.read()
assert response == batch_size
test_utils.unregister_model('BERTSeqClassification')

def test_MMF_activity_recognition_model_register_and_inference_on_valid_model():
Expand Down

0 comments on commit 7e751c5

Please sign in to comment.