Bert base batch test #1272

Merged: 21 commits, Oct 19, 2021
Changes from all commits
frontend/server/src/main/java/org/pytorch/serve/wlm/BatchAggregator.java
@@ -63,12 +63,13 @@ public void sendResponse(ModelWorkerResponse message) {
                 // this is from initial load.
                 return;
             }
 
             for (Predictions prediction : message.getPredictions()) {
                 String jobId = prediction.getRequestId();
-                Job job = jobs.remove(jobId);
+                Job job = jobs.get(jobId);
 
                 if (job == null) {
-                    throw new IllegalStateException("Unexpected job: " + jobId);
+                    throw new IllegalStateException(
+                            "Unexpected job in sendResponse() with 200 status code: " + jobId);
                 }
                 job.response(
                         prediction.getResp(),
@@ -77,18 +78,19 @@ public void sendResponse(ModelWorkerResponse message) {
                         prediction.getReasonPhrase(),
                         prediction.getHeaders());
             }
 
         } else {
-            for (String reqId : jobs.keySet()) {
-                Job j = jobs.remove(reqId);
-                if (j == null) {
-                    throw new IllegalStateException("Unexpected job: " + reqId);
+            for (Map.Entry<String, Job> j : jobs.entrySet()) {
+                if (j.getValue() == null) {
+                    throw new IllegalStateException(
+                            "Unexpected job in sendResponse() with non 200 status code: "
+                                    + j.getKey());
                 }
-                j.sendError(message.getCode(), message.getMessage());
-            }
-            if (!jobs.isEmpty()) {
-                throw new IllegalStateException("Not all jobs get response.");
+                j.getValue().sendError(message.getCode(), message.getMessage());
             }
         }
+        jobs.clear();
     }

@@ -103,20 +105,20 @@ public void sendError(BaseModelRequest message, String error, int status) {
                 String requestId = req.getRequestId();
                 Job job = jobs.remove(requestId);
                 if (job == null) {
-                    logger.error("Unexpected job: " + requestId);
+                    logger.error("Unexpected job in sendError(): " + requestId);
                 } else {
                     job.sendError(status, error);
                 }
             }
             if (!jobs.isEmpty()) {
                 jobs.clear();
-                logger.error("Not all jobs get response.");
+                logger.error("Not all jobs got an error response.");
             }
         } else {
             // Send the error message to all the jobs
             for (Map.Entry<String, Job> j : jobs.entrySet()) {
                 String jobsId = j.getValue().getJobId();
-                Job job = jobs.remove(jobsId);
+                Job job = jobs.get(jobsId);
Review comments on this line:

Collaborator (@HamidShojanazeri): @msaroufim do we have any test to repro the batch aggregator issue?

Member Author (@msaroufim): Hi @HamidShojanazeri, yes, it's in the description above, at the end of the repro section.

Collaborator: Hi @msaroufim, this test case only covers status code 200. Can you add a test case for the "else" part (i.e., the original jobs.remove() path)?

Member Author (@msaroufim, Oct 11, 2021): That's right, we already have tests for batch failures: https://github.com/pytorch/serve/blame/master/frontend/server/src/test/java/org/pytorch/serve/ModelServerTest.java#L1071

We just didn't have one for batch successes.

Collaborator: @msaroufim, ModelServerTest.java#L1071 is a frontend integration test case, which runs during the frontend build.

The test test_huggingface_bert_batch_inference is an end-to-end integration test, which runs as part of the regression tests.

So they are tests at different levels. Usually, a Java test case should be added under frontend/server/src/test when there are frontend source code changes; currently the frontend is missing a positive batch test case.

It is fine if you prefer adding batch end-to-end test cases to the regression tests. In that case, I suggest adding both positive and negative test cases there. Otherwise, the regression coverage for the batch use case has the same gap as the current Java frontend tests (i.e., a missing test case).

             if (job.isControlCmd()) {
                 job.sendError(status, error);
@@ -127,5 +129,6 @@ public void sendError(BaseModelRequest message, String error, int status) {
                 }
             }
         }
+        jobs.clear();
     }
 }
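Why the loops changed: the old code removed entries from the jobs map while iterating over it (jobs.remove() inside a loop over jobs.keySet()), which Java's fail-fast map iterators reject, and which also made it easy to leave stale entries behind. The new code iterates read-only (entrySet()/get()) and empties the map with a single jobs.clear() once every job has been answered. A minimal standalone sketch of the two patterns (illustrative names, not TorchServe code; assumes a fail-fast map such as HashMap):

import java.util.ConcurrentModificationException;
import java.util.HashMap;
import java.util.Map;

public class RemoveWhileIterating {
    public static void main(String[] args) {
        Map<String, String> jobs = new HashMap<>();
        jobs.put("req-1", "job-1");
        jobs.put("req-2", "job-2");
        jobs.put("req-3", "job-3");

        try {
            // Old pattern: structurally modifying the map while iterating
            // over its keySet() fails fast on the next iteration step.
            for (String reqId : jobs.keySet()) {
                jobs.remove(reqId);
            }
        } catch (ConcurrentModificationException e) {
            System.out.println("remove() during iteration threw: " + e);
        }

        // New pattern: read-only iteration over entrySet(), then a single
        // clear() once every entry has been handled.
        for (Map.Entry<String, String> entry : jobs.entrySet()) {
            System.out.println("responding to " + entry.getKey());
        }
        jobs.clear();
        System.out.println("jobs remaining: " + jobs.size()); // prints 0
    }
}

The entrySet()-then-clear() pattern also guarantees the map is empty after every response cycle, which the old remove-in-loop code could not.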
frontend/server/src/test/java/org/pytorch/serve/TestUtils.java (33 additions, 0 deletions)

@@ -174,6 +174,39 @@ public static void registerModel(
         }
     }
 
+    public static void registerModel(
+            Channel channel,
+            String url,
+            String modelName,
+            boolean withInitialWorkers,
+            boolean syncChannel,
+            int batchSize,
+            int maxBatchDelay)
+            throws InterruptedException {
+        String requestURL =
+                "/models?url="
+                        + url
+                        + "&model_name="
+                        + modelName
+                        + "&runtime=python"
+                        + "&batch_size="
+                        + batchSize
+                        + "&max_batch_delay="
+                        + maxBatchDelay;
+        if (withInitialWorkers) {
+            requestURL += "&initial_workers=1&synchronous=true";
+        }
+
+        HttpRequest req =
+                new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.POST, requestURL);
+        if (syncChannel) {
+            channel.writeAndFlush(req).sync();
+            channel.closeFuture().sync();
+        } else {
+            channel.writeAndFlush(req);
+        }
+    }
+
     public static void registerWorkflow(
             Channel channel, String url, String workflowName, boolean syncChannel)
             throws InterruptedException {
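For reference, a sketch of how a frontend test might call the new registerModel overload. This is a hypothetical snippet: TestUtils.connect(), the latch/result helpers, configManager, and the noop test model are assumptions patterned on the existing frontend tests, not part of this PR.

// Hypothetical usage; helper names are assumptions, not part of this PR.
Channel channel = TestUtils.connect(true, configManager); // assumed management-channel helper
Assert.assertNotNull(channel);
TestUtils.setResult(null);
TestUtils.setLatch(new CountDownLatch(1));
TestUtils.registerModel(
        channel,
        "noop.mar",   // model archive URL
        "noop_batch", // model_name
        true,         // withInitialWorkers: appends &initial_workers=1&synchronous=true
        false,        // syncChannel
        2,            // batch_size
        10000);       // max_batch_delay, in milliseconds
TestUtils.getLatch().await();
Assert.assertNotNull(TestUtils.getResult());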
test/pytest/test_handler.py (24 additions, 1 deletion)

@@ -224,7 +224,30 @@ def test_kfserving_mnist_model_register_and_inference_on_valid_model_explain():
 
     assert np.array(json.loads(response.content)['explanations']).shape == (1, 1, 28, 28)
     test_utils.unregister_model("mnist")
 
+def test_huggingface_bert_batch_inference():
+    batch_size = 2
+    batch_delay = 10000  # 10 seconds
+    params = (
+        ('model_name', 'BERTSeqClassification'),
+        ('url', 'https://torchserve.pytorch.org/mar_files/BERTSeqClassification.mar'),
+        ('initial_workers', '1'),
+        ('batch_size', str(batch_size)),
+        ('max_batch_delay', str(batch_delay))
+    )
+    test_utils.start_torchserve(no_config_snapshots=True)
+    test_utils.register_model_with_params(params)
+    input_text = os.path.join(REPO_ROOT, 'examples/Huggingface_Transformers/Seq_classification_artifacts/sample_text.txt')
+
+    # Make 2 curl requests in parallel with &
+    # curl --header "X-Forwarded-For: 1.2.3.4" won't work since you can't access localhost anymore
+    response = os.popen(f"curl http://127.0.0.1:8080/predictions/BERTSeqClassification -T {input_text} & curl http://127.0.0.1:8080/predictions/BERTSeqClassification -T {input_text}")
+    response = response.read()
+
+    # Assert that 2 responses are returned from the same batch
+    assert response == 'Not AcceptedNot Accepted'
+    test_utils.unregister_model('BERTSeqClassification')
+
 def test_MMF_activity_recognition_model_register_and_inference_on_valid_model():

@@ -245,4 +268,4 @@ def test_MMF_activity_recognition_model_register_and_inference_on_valid_model():
     response = ast.literal_eval(response)
     response = [n.strip() for n in response]
     assert response == ['Sitting at a table','Someone is sneezing','Watching a laptop or something on a laptop']
-    test_utils.unregister_model("MMF_activity_recognition_v2")
\ No newline at end of file
+    test_utils.unregister_model("MMF_activity_recognition_v2")