Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[cherry-pick][lmi] validate inputs field is of type string for request (#2583) #2585

Merged
merged 1 commit into from
Nov 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions engines/python/setup/djl_python/input_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,6 +266,16 @@ def parse_lmi_default_request_rolling_batch(payload):
f"Invalid request payload. Request payload should be a json object specifying the 'inputs' field. Received payload {payload}"
)

if not isinstance(inputs, str):
raise ValueError(
f"Invalid request payload. The 'inputs' field must be a string. Received type {type(inputs)}"
)

if len(inputs) == 0:
raise ValueError(
f"Invalid request payload. The 'inputs' field does not contain any content. Received payload {payload}"
)

parameters = payload.get("parameters", {})
if not isinstance(parameters, dict):
raise ValueError(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -420,7 +420,11 @@ void sendOutput(Output output, ChannelHandlerContext ctx) {
// This allows inference update HTTP code.
// If this is the first and last chunk, we're in a non-streaming case and can
// use default response without chunked transfer encoding
// Note, we have to read the code and message from output AGAIN here.
// They get changed to different values from when we read them at the start of
// this method
if (first && !supplier.hasNext()) {
status = new HttpResponseStatus(output.getCode(), output.getMessage());
FullHttpResponse resp =
new DefaultFullHttpResponse(HttpVersion.HTTP_1_1, status);
for (Map.Entry<String, String> entry : output.getProperties().entrySet()) {
Expand All @@ -433,8 +437,7 @@ void sendOutput(Output output, ChannelHandlerContext ctx) {
return;
}
if (first) {
code = output.getCode();
status = new HttpResponseStatus(code, output.getMessage());
status = new HttpResponseStatus(output.getCode(), output.getMessage());
HttpResponse resp = new DefaultHttpResponse(HttpVersion.HTTP_1_1, status);
for (Map.Entry<String, String> entry : output.getProperties().entrySet()) {
resp.headers().set(entry.getKey(), entry.getValue());
Expand Down
Loading