From de42bcf70c3be6249136afb6f442aa5ad0b39678 Mon Sep 17 00:00:00 2001 From: Harsh Bafna Date: Thu, 10 Dec 2020 22:59:22 +0530 Subject: [PATCH 1/3] gRPC support for TorchServe (#687) * refactored torchserve job * added grpc server side implementation * added protobuff files * added grpc server startup * fixed valid port test case * automated server stub generation through gradle * enhanced sanity script to validate grpc inference api * Added grpcio-tools package * fixed path issue in grpc client * fixed incorrect exit logic in client script * removed json parse in python gRPC client * removed unnecessary file checkin * added regression test cases for gRPC regression APIs * added tolerance check * added python client stub cleanup * enhanced error handling for inference APIs * removed unused utility file * added support for datafile driven management api test collection * added gRPC support for management APIs * added minor fixes found during testing * enhanced grpc pytest suite to use grpc client for registering and unregistering model * updated command to generate python client stubs * removed netty http staus dependency from wlm framework * refacroted common code to utility module * added gRPC management api test cases in regression suite and minor fixes * added ping api * removed grpc metric api * added ssl support for gRPC server * added documentation * fixed issue after conflict resolution * added reference to python gRPC client, used in regression suite, in grpc doc * added validation for register and unregister model in sanity script * updated docs * minor fixes in grpc doc * updated gRPC server await termination code * refactored gRPC server startup code * added null check before terminating gRPC servers * minor refactoring of method name * skipped grpc package from jacoco verification * Fixed typo in doc * added error logs in gRPC client * added gRPC server interceptor to log api access data * added checkstyle fixes * fixed grpc command in readme * refactored test cases to removed code duplication * Fixed typo in link. 
Co-authored-by: Amit Agarwal * fixed compilation issues after conflict resolution * fixed regression suite pytest issue * fixed pytest case * fixed import * fixed sanity suite * fixed path in grpc client stub generation * fixed path for grpc client * incorporated code review comments * fixed management api newman command * fixed import issues * fixed regression pytest issues Co-authored-by: Shivam Shriwas Co-authored-by: Amit Agarwal Co-authored-by: dhanainme <60679183+dhanainme@users.noreply.github.com> Co-authored-by: Aaqib Co-authored-by: Geeta Chauhan <4461127+chauhang@users.noreply.github.com> --- README.md | 24 +- docs/README.md | 1 + docs/configuration.md | 24 + docs/grpc_api.md | 70 + docs/inference_api.md | 4 + docs/management_api.md | 12 + frontend/build.gradle | 14 +- frontend/gradle.properties | 2 + frontend/server/build.gradle | 1 + .../java/org/pytorch/serve/ModelServer.java | 52 +- .../serve/grpcimpl/GRPCInterceptor.java | 37 + .../serve/grpcimpl/GRPCServiceFactory.java | 24 + .../pytorch/serve/grpcimpl/InferenceImpl.java | 107 + .../serve/grpcimpl/ManagementImpl.java | 193 + .../serve/http/InferenceRequestHandler.java | 30 +- .../serve/http/ManagementRequestHandler.java | 330 +- .../java/org/pytorch/serve/http/Session.java | 9 + .../pytorch/serve/http/StatusResponse.java | 25 +- .../http/messages/RegisterModelRequest.java | 16 + .../java/org/pytorch/serve/job/GRPCJob.java | 75 + .../main/java/org/pytorch/serve/job/Job.java | 69 + .../serve/{wlm/Job.java => job/RestJob.java} | 72 +- .../java/org/pytorch/serve/util/ApiUtils.java | 359 ++ .../org/pytorch/serve/util/ConfigManager.java | 25 + .../org/pytorch/serve/util/GRPCUtils.java | 64 + .../org/pytorch/serve/util/JsonUtils.java | 8 + .../org/pytorch/serve/util/NettyUtils.java | 8 + .../pytorch/serve/wlm/BatchAggregator.java | 6 +- .../java/org/pytorch/serve/wlm/Model.java | 1 + .../org/pytorch/serve/wlm/ModelManager.java | 124 +- .../pytorch/serve/wlm/ModelVersionedRefs.java | 6 +- .../pytorch/serve/wlm/WorkLoadManager.java | 19 +- .../serve/wlm/WorkerStateListener.java | 7 +- .../org/pytorch/serve/wlm/WorkerThread.java | 36 +- .../src/main/resources/proto/inference.proto | 35 + .../src/main/resources/proto/management.proto | 114 + .../org/pytorch/serve/ModelServerTest.java | 2 +- .../java/org/pytorch/serve/SnapshotTest.java | 6 +- frontend/tools/conf/checkstyle.xml | 4 +- frontend/tools/gradle/check.gradle | 3 + frontend/tools/gradle/formatter.gradle | 2 +- frontend/tools/gradle/proto.gradle | 39 + requirements/developer.txt | 7 +- test/postman/inference_data.json | 4 +- .../management_api_test_collection.json | 3620 +---------------- test/postman/management_data.json | 548 +++ test/pytest/test_gRPC_inference_api.py | 85 + test/pytest/test_gRPC_management_apis.py | 107 + test/pytest/test_gRPC_utils.py | 22 + test/pytest/test_utils.py | 4 + torchserve_sanity.py | 27 +- ts_scripts/api_utils.py | 3 +- ts_scripts/regression_utils.py | 12 +- ts_scripts/sanity_utils.py | 35 + ts_scripts/torchserve_grpc_client.py | 71 + ts_scripts/tsutils.py | 11 + 56 files changed, 2520 insertions(+), 4095 deletions(-) create mode 100644 docs/grpc_api.md create mode 100644 frontend/server/src/main/java/org/pytorch/serve/grpcimpl/GRPCInterceptor.java create mode 100644 frontend/server/src/main/java/org/pytorch/serve/grpcimpl/GRPCServiceFactory.java create mode 100644 frontend/server/src/main/java/org/pytorch/serve/grpcimpl/InferenceImpl.java create mode 100644 frontend/server/src/main/java/org/pytorch/serve/grpcimpl/ManagementImpl.java create mode 
100644 frontend/server/src/main/java/org/pytorch/serve/job/GRPCJob.java create mode 100644 frontend/server/src/main/java/org/pytorch/serve/job/Job.java rename frontend/server/src/main/java/org/pytorch/serve/{wlm/Job.java => job/RestJob.java} (70%) create mode 100644 frontend/server/src/main/java/org/pytorch/serve/util/ApiUtils.java create mode 100644 frontend/server/src/main/java/org/pytorch/serve/util/GRPCUtils.java create mode 100644 frontend/server/src/main/resources/proto/inference.proto create mode 100644 frontend/server/src/main/resources/proto/management.proto create mode 100644 frontend/tools/gradle/proto.gradle create mode 100644 test/postman/management_data.json create mode 100644 test/pytest/test_gRPC_inference_api.py create mode 100644 test/pytest/test_gRPC_management_apis.py create mode 100644 test/pytest/test_gRPC_utils.py create mode 100644 ts_scripts/torchserve_grpc_client.py diff --git a/README.md b/README.md index 1ac8895bbe..79f14cf430 100644 --- a/README.md +++ b/README.md @@ -140,7 +140,29 @@ After you execute the `torchserve` command above, TorchServe runs on your host, ### Get predictions from a model -To test the model server, send a request to the server's `predictions` API. +To test the model server, send a request to the server's `predictions` API. TorchServe supports all [inference](docs/inference_api.md) and [management](docs/management_api.md) APIs through both [gRPC](docs/grpc_api.md) and [HTTP/REST](docs/rest_api.md). + +#### Using gRPC APIs through a Python client + + - Install gRPC Python dependencies: + +```bash +pip install -U grpcio protobuf grpcio-tools +``` + + - Generate the inference client using the proto files + +```bash +python -m grpc_tools.protoc --proto_path=frontend/server/src/main/resources/proto/ --python_out=ts_scripts --grpc_python_out=ts_scripts frontend/server/src/main/resources/proto/inference.proto frontend/server/src/main/resources/proto/management.proto +``` + + - Run inference using the sample [gRPC Python client](ts_scripts/torchserve_grpc_client.py) + +```bash +python ts_scripts/torchserve_grpc_client.py infer densenet161 examples/image_classifier/kitten.jpg +``` + +#### Using REST APIs Complete the following steps: diff --git a/docs/README.md b/docs/README.md index 4364fa91ac..186f0429b3 100644 --- a/docs/README.md +++ b/docs/README.md @@ -7,6 +7,7 @@ * [Installation](../README.md##install-torchserve) - Installation procedures * [Serving Models](server.md) - Explains how to use `torchserve`. * [REST API](rest_api.md) - Specification on the API endpoint for TorchServe + * [gRPC API](grpc_api.md) - Specification on the gRPC API endpoint for TorchServe * [Packaging Model Archive](../model-archiver/README.md) - Explains how to package model archive file, use `model-archiver`. * [Logging](logging.md) - How to configure logging * [Metrics](metrics.md) - How to configure metrics diff --git a/docs/configuration.md b/docs/configuration.md index 90c151364a..3deff9b1c9 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -97,6 +97,24 @@ inference_address=https://0.0.0.0:8443 inference_address=https://172.16.1.10:8080 ``` +### Configure TorchServe gRPC listening ports +By default, the inference gRPC API listens on port 9090 and the management gRPC API listens on port 9091. + +To configure different ports, use the following properties: + +* `grpc_inference_port`: Inference gRPC API binding port. Default: 9090 +* `grpc_management_port`: Management gRPC API binding port.
Default: 9091 + +Here are a couple of examples: + +```properties +grpc_inference_port=8888 +``` + +```properties +grpc_management_port=9999 +``` + ### Enable SSL To enable HTTPs, you can change `inference_address`, `management_address` or `metrics_address` protocol from http to https. For example: `inference_address=https://127.0.0.1`. @@ -201,6 +219,12 @@ By default, TorchServe uses all available GPUs for inference. Use `number_of_gpu * `metrics_format` : Use this to specify metric report format . At present, the only supported and default value for this is `prometheus' This is used in conjunction with `enable_meterics_api` option above. +### Enable metrics api +* `enable_metrics_api` : Enable or disable the metrics API, i.e. it can be either `true` or `false`. Default: true (Enabled) +* `metrics_format` : Use this to specify the metric report format. At present, the only supported and default value for this is `prometheus` + This is used in conjunction with the `enable_metrics_api` option above. + + ### Other properties Most of the following properties are designed for performance tuning. Adjusting these numbers will impact scalability and throughput. diff --git a/docs/grpc_api.md b/docs/grpc_api.md new file mode 100644 index 0000000000..88952a7e3a --- /dev/null +++ b/docs/grpc_api.md @@ -0,0 +1,70 @@ +# TorchServe gRPC API + +TorchServe also supports [gRPC APIs](../frontend/server/src/main/resources/proto) for both inference and management calls. + +TorchServe provides the following gRPC APIs: + +* [Inference API](../frontend/server/src/main/resources/proto/inference.proto) + - **Ping** : Gets the health status of the running server + - **Predictions** : Gets predictions from the served model + +* [Management API](../frontend/server/src/main/resources/proto/management.proto) + - **RegisterModel** : Serve a model/model-version on TorchServe + - **UnregisterModel** : Free up system resources by unregistering a specific version of a model from TorchServe + - **ScaleWorker** : Dynamically adjust the number of workers for any version of a model to better serve different inference request loads + - **ListModels** : Query the default versions of currently registered models + - **DescribeModel** : Get the detailed runtime status of the default version of a model + - **SetDefault** : Set any registered version of a model as the default version + +By default, TorchServe listens on port 9090 for the gRPC Inference API and 9091 for the gRPC Management API. +To configure gRPC APIs on different ports, refer to the [configuration documentation](configuration.md).
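For a sense of what these RPCs look like from a client, the sketch below calls the `Predictions` RPC directly from Python. It is illustrative only: it assumes stubs generated from `inference.proto` with `grpc_tools.protoc` (protoc's default `inference_pb2`/`inference_pb2_grpc` module naming), and the field names (`model_name`, `input`, `prediction`) mirror the message accessors used by the server implementation; the packaged client used in the walkthrough that follows wraps equivalent calls for register, infer, and unregister.

```python
import grpc

# Client stubs generated by grpc_tools.protoc from inference.proto
# (assumes protoc's default *_pb2 / *_pb2_grpc module naming).
import inference_pb2
import inference_pb2_grpc


def predict(model_name, data_path, target="localhost:9090"):
    """Call the Predictions RPC on the gRPC Inference API (default port 9090)."""
    with open(data_path, "rb") as f:
        payload = f.read()
    # Plaintext channel; use a secure channel instead if gRPC SSL is enabled.
    with grpc.insecure_channel(target) as channel:
        stub = inference_pb2_grpc.InferenceAPIsServiceStub(channel)
        # PredictionsRequest carries the model name and a map of named byte inputs;
        # the "data" key here is illustrative.
        request = inference_pb2.PredictionsRequest(
            model_name=model_name, input={"data": payload}
        )
        response = stub.Predictions(request)
        # PredictionResponse.prediction holds the raw prediction bytes.
        return response.prediction


if __name__ == "__main__":
    print(predict("densenet161", "examples/image_classifier/kitten.jpg"))
```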
+ +## Python client example for gRPC APIs + +Run the following commands to register, run inference on, and unregister the densenet161 model from the [TorchServe model zoo](model_zoo.md) using the [gRPC Python client](../ts_scripts/torchserve_grpc_client.py). + + - [Install TorchServe](../README.md#install-torchserve) + + - Clone the serve repo to run this example + +```bash +git clone +cd serve +``` + + - Install gRPC Python dependencies + +```bash +pip install -U grpcio protobuf grpcio-tools +``` + + - Start TorchServe + +```bash +mkdir model_store +torchserve --start +``` + + - Generate the Python gRPC client stubs using the proto files + +```bash +python -m grpc_tools.protoc --proto_path=frontend/server/src/main/resources/proto/ --python_out=ts_scripts --grpc_python_out=ts_scripts frontend/server/src/main/resources/proto/inference.proto frontend/server/src/main/resources/proto/management.proto +``` + + - Register the densenet161 model + +```bash +python ts_scripts/torchserve_grpc_client.py register densenet161 +``` + + - Run inference + +```bash +python ts_scripts/torchserve_grpc_client.py infer densenet161 examples/image_classifier/kitten.jpg +``` + + - Unregister the densenet161 model + +```bash +python ts_scripts/torchserve_grpc_client.py unregister densenet161 +``` diff --git a/docs/inference_api.md b/docs/inference_api.md index 9dbaf2ebc4..d4dee9ea37 100644 --- a/docs/inference_api.md +++ b/docs/inference_api.md @@ -22,6 +22,8 @@ The out is OpenAPI 3.0.1 json format. You can use it to generate client code, se ## Health check API +This API follows the [InferenceAPIsService.Ping](../frontend/server/src/main/resources/proto/inference.proto) gRPC API. It returns the health status of the running server. + TorchServe supports a `ping` API that you can call to check the health status of a running TorchServe server: ```bash @@ -38,6 +40,8 @@ If the server is running, the response is: ## Predictions API +This API follows the [InferenceAPIsService.Predictions](../frontend/server/src/main/resources/proto/inference.proto) gRPC API. It returns predictions from the served model. + To get predictions from the default version of each loaded model, make a REST call to `/predictions/{model_name}`: * POST /predictions/{model_name} diff --git a/docs/management_api.md b/docs/management_api.md index 168aed4651..9c89ea9923 100644 --- a/docs/management_api.md +++ b/docs/management_api.md @@ -15,6 +15,8 @@ Similar to the [Inference API](inference_api.md), the Management API provides a ## Register a model +This API follows the [ManagementAPIsService.RegisterModel](../frontend/server/src/main/resources/proto/management.proto) gRPC API. + `POST /models` * `url` - Model archive download url. Supports the following locations: @@ -74,6 +76,9 @@ curl -v -X POST "http://localhost:8081/models?initial_workers=1&synchronous=true ## Scale workers +This API follows the [ManagementAPIsService.ScaleWorker](../frontend/server/src/main/resources/proto/management.proto) gRPC API. It returns the status of a model in the ModelServer. + + `PUT /models/{model_name}` * `min_worker` - (optional) the minimum number of worker processes. TorchServe will try to maintain this minimum for specified model. The default value is `1`. @@ -139,6 +144,8 @@ curl -v -X PUT "http://localhost:8081/models/noop/2.0?min_worker=3&synchronous=t ## Describe model +This API follows the [ManagementAPIsService.DescribeModel](../frontend/server/src/main/resources/proto/management.proto) gRPC API. It returns the status of a model in the ModelServer.
+ `GET /models/{model_name}` Use the Describe Model API to get detail runtime status of default version of a model: @@ -251,6 +258,8 @@ curl http://localhost:8081/models/noop/all ## Unregister a model +This API follows the [ManagementAPIsService.UnregisterModel](../frontend/server/src/main/resources/proto/management.proto) gRPC API. It returns the status of a model in the ModelServer. + `DELETE /models/{model_name}/{version}` Use the Unregister Model API to free up system resources by unregistering specific version of a model from TorchServe: @@ -264,6 +273,7 @@ curl -X DELETE http://localhost:8081/models/noop/1.0 ``` ## List models +This API follows the [ManagementAPIsService.ListModels](../frontend/server/src/main/resources/proto/management.proto) gRPC API. It returns the status of a model in the ModelServer. `GET /models` @@ -320,6 +330,8 @@ Example outputs of the Inference and Management APIs: ## Set Default Version +This API follows the [ManagementAPIsService.SetDefault](../frontend/server/src/main/resources/proto/management.proto) gRPC API. It returns the status of a model in the ModelServer. + `PUT /models/{model_name}/{version}/set-default` To set any registered version of a model as default version use: diff --git a/frontend/build.gradle b/frontend/build.gradle index 8c9bc4d7ba..ecf62bcd0a 100644 --- a/frontend/build.gradle +++ b/frontend/build.gradle @@ -3,10 +3,15 @@ buildscript { spotbugsVersion = '4.0.2' toolVersion = '4.0.2' } + dependencies { + classpath 'com.google.protobuf:protobuf-gradle-plugin:0.8.13' + } } plugins { - id 'com.github.spotbugs' version '4.0.2' apply false + id 'com.google.protobuf' version '0.8.13' apply false + id 'idea' + id 'com.github.spotbugs' version '4.0.2' apply false } allprojects { @@ -25,6 +30,7 @@ allprojects { } } + def javaProjects() { return subprojects.findAll(); } @@ -63,6 +69,12 @@ configure(javaProjects()) { minimum = 0.70 } } + afterEvaluate { + classDirectories.setFrom(files(classDirectories.files.collect { + fileTree(dir: "${rootProject.projectDir}/server/src/main/java", + exclude: ['org/pytorch/serve/grpc**/**']) + })) + } } } } diff --git a/frontend/gradle.properties b/frontend/gradle.properties index d072ac8237..e7de745231 100644 --- a/frontend/gradle.properties +++ b/frontend/gradle.properties @@ -8,3 +8,5 @@ slf4j_api_version=1.7.25 slf4j_log4j12_version=1.7.25 testng_version=7.1.0 torchserve_sdk_version=0.0.3 +grpc_version=1.31.1 +protoc_version=3.13.0 diff --git a/frontend/server/build.gradle b/frontend/server/build.gradle index 23c018693a..1606fbabf3 100644 --- a/frontend/server/build.gradle +++ b/frontend/server/build.gradle @@ -8,6 +8,7 @@ dependencies { testImplementation "org.testng:testng:${testng_version}" } +apply from: file("${project.rootProject.projectDir}/tools/gradle/proto.gradle") apply from: file("${project.rootProject.projectDir}/tools/gradle/launcher.gradle") jar { diff --git a/frontend/server/src/main/java/org/pytorch/serve/ModelServer.java b/frontend/server/src/main/java/org/pytorch/serve/ModelServer.java index 06a74328ff..ad16790fdb 100644 --- a/frontend/server/src/main/java/org/pytorch/serve/ModelServer.java +++ b/frontend/server/src/main/java/org/pytorch/serve/ModelServer.java @@ -1,5 +1,8 @@ package org.pytorch.serve; +import io.grpc.Server; +import io.grpc.ServerBuilder; +import io.grpc.ServerInterceptors; import io.netty.bootstrap.ServerBootstrap; import io.netty.channel.ChannelFuture; import io.netty.channel.ChannelFutureListener; @@ -31,6 +34,8 @@ import org.pytorch.serve.archive.ModelArchive; import 
org.pytorch.serve.archive.ModelException; import org.pytorch.serve.archive.ModelNotFoundException; +import org.pytorch.serve.grpcimpl.GRPCInterceptor; +import org.pytorch.serve.grpcimpl.GRPCServiceFactory; import org.pytorch.serve.metrics.MetricManager; import org.pytorch.serve.servingsdk.ModelServerEndpoint; import org.pytorch.serve.servingsdk.annotations.Endpoint; @@ -53,6 +58,8 @@ public class ModelServer { private Logger logger = LoggerFactory.getLogger(ModelServer.class); private ServerGroups serverGroups; + private Server inferencegRPCServer; + private Server managementgRPCServer; private List futures = new ArrayList<>(2); private AtomicBoolean stopped = new AtomicBoolean(false); private ConfigManager configManager; @@ -104,7 +111,10 @@ public void startAndWait() throws InterruptedException, IOException, GeneralSecurityException, InvalidSnapshotException { try { - List channelFutures = start(); + List channelFutures = startRESTserver(); + + startGRPCServers(); + // Create and schedule metrics manager MetricManager.scheduleMetrics(configManager); System.out.println("Model server started."); // NOPMD @@ -305,7 +315,7 @@ public ChannelFuture initializeServer( * @throws InterruptedException if interrupted * @throws InvalidSnapshotException */ - public List start() + public List startRESTserver() throws InterruptedException, IOException, GeneralSecurityException, InvalidSnapshotException { stopped.set(false); @@ -363,6 +373,30 @@ public List start() return futures; } + public void startGRPCServers() throws IOException { + inferencegRPCServer = startGRPCServer(ConnectorType.INFERENCE_CONNECTOR); + managementgRPCServer = startGRPCServer(ConnectorType.MANAGEMENT_CONNECTOR); + } + + private Server startGRPCServer(ConnectorType connectorType) throws IOException { + + ServerBuilder s = + ServerBuilder.forPort(configManager.getGRPCPort(connectorType)) + .addService( + ServerInterceptors.intercept( + GRPCServiceFactory.getgRPCService(connectorType), + new GRPCInterceptor())); + + if (configManager.isGRPCSSLEnabled()) { + s.useTransportSecurity( + new File(configManager.getCertificateFile()), + new File(configManager.getPrivateKeyFile())); + } + Server server = s.build(); + server.start(); + return server; + } + private boolean validEndpoint(Annotation a, EndpointTypes type) { return a instanceof Endpoint && !((Endpoint) a).urlPattern().isEmpty() @@ -388,6 +422,16 @@ public boolean isRunning() { return !stopped.get(); } + private void stopgRPCServer(Server server) { + if (server != null) { + try { + server.shutdown().awaitTermination(); + } catch (InterruptedException e) { + e.printStackTrace(); // NOPMD + } + } + } + private void exitModelStore() throws ModelNotFoundException { ModelManager modelMgr = ModelManager.getInstance(); Map defModels = modelMgr.getDefaultModels(); @@ -420,6 +464,10 @@ public void stop() { } stopped.set(true); + + stopgRPCServer(inferencegRPCServer); + stopgRPCServer(managementgRPCServer); + for (ChannelFuture future : futures) { try { future.channel().close().sync(); diff --git a/frontend/server/src/main/java/org/pytorch/serve/grpcimpl/GRPCInterceptor.java b/frontend/server/src/main/java/org/pytorch/serve/grpcimpl/GRPCInterceptor.java new file mode 100644 index 0000000000..252427ca4c --- /dev/null +++ b/frontend/server/src/main/java/org/pytorch/serve/grpcimpl/GRPCInterceptor.java @@ -0,0 +1,37 @@ +package org.pytorch.serve.grpcimpl; + +import io.grpc.ForwardingServerCall; +import io.grpc.Grpc; +import io.grpc.Metadata; +import io.grpc.ServerCall; +import 
io.grpc.ServerCallHandler; +import io.grpc.ServerInterceptor; +import io.grpc.Status; +import org.pytorch.serve.http.Session; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +public class GRPCInterceptor implements ServerInterceptor { + + private static final Logger logger = LoggerFactory.getLogger("ACCESS_LOG"); + + @Override + public ServerCall.Listener interceptCall( + ServerCall call, Metadata headers, ServerCallHandler next) { + String inetSocketString = + call.getAttributes().get(Grpc.TRANSPORT_ATTR_REMOTE_ADDR).toString(); + String serviceName = call.getMethodDescriptor().getFullMethodName(); + Session session = new Session(inetSocketString, serviceName); + + return next.startCall( + new ForwardingServerCall.SimpleForwardingServerCall(call) { + @Override + public void close(final Status status, final Metadata trailers) { + session.setCode(status.getCode().value()); + logger.info(session.toString()); + super.close(status, trailers); + } + }, + headers); + } +} diff --git a/frontend/server/src/main/java/org/pytorch/serve/grpcimpl/GRPCServiceFactory.java b/frontend/server/src/main/java/org/pytorch/serve/grpcimpl/GRPCServiceFactory.java new file mode 100644 index 0000000000..75a3122dc2 --- /dev/null +++ b/frontend/server/src/main/java/org/pytorch/serve/grpcimpl/GRPCServiceFactory.java @@ -0,0 +1,24 @@ +package org.pytorch.serve.grpcimpl; + +import io.grpc.BindableService; +import org.pytorch.serve.util.ConnectorType; + +public final class GRPCServiceFactory { + + private GRPCServiceFactory() {} + + public static BindableService getgRPCService(ConnectorType connectorType) { + BindableService torchServeService = null; + switch (connectorType) { + case MANAGEMENT_CONNECTOR: + torchServeService = new ManagementImpl(); + break; + case INFERENCE_CONNECTOR: + torchServeService = new InferenceImpl(); + break; + default: + break; + } + return torchServeService; + } +} diff --git a/frontend/server/src/main/java/org/pytorch/serve/grpcimpl/InferenceImpl.java b/frontend/server/src/main/java/org/pytorch/serve/grpcimpl/InferenceImpl.java new file mode 100644 index 0000000000..236209a214 --- /dev/null +++ b/frontend/server/src/main/java/org/pytorch/serve/grpcimpl/InferenceImpl.java @@ -0,0 +1,107 @@ +package org.pytorch.serve.grpcimpl; + +import com.google.protobuf.ByteString; +import com.google.protobuf.Empty; +import io.grpc.Status; +import io.grpc.stub.StreamObserver; +import java.net.HttpURLConnection; +import java.util.Map; +import java.util.UUID; +import org.pytorch.serve.archive.ModelNotFoundException; +import org.pytorch.serve.archive.ModelVersionNotFoundException; +import org.pytorch.serve.grpc.inference.InferenceAPIsServiceGrpc.InferenceAPIsServiceImplBase; +import org.pytorch.serve.grpc.inference.PredictionResponse; +import org.pytorch.serve.grpc.inference.PredictionsRequest; +import org.pytorch.serve.grpc.inference.TorchServeHealthResponse; +import org.pytorch.serve.http.BadRequestException; +import org.pytorch.serve.http.InternalServerException; +import org.pytorch.serve.http.StatusResponse; +import org.pytorch.serve.job.GRPCJob; +import org.pytorch.serve.job.Job; +import org.pytorch.serve.metrics.api.MetricAggregator; +import org.pytorch.serve.util.ApiUtils; +import org.pytorch.serve.util.JsonUtils; +import org.pytorch.serve.util.messages.InputParameter; +import org.pytorch.serve.util.messages.RequestInput; +import org.pytorch.serve.util.messages.WorkerCommands; +import org.pytorch.serve.wlm.ModelManager; + +public class InferenceImpl extends InferenceAPIsServiceImplBase { + 
+ @Override + public void ping(Empty request, StreamObserver responseObserver) { + Runnable r = + () -> { + String response = ApiUtils.getWorkerStatus(); + TorchServeHealthResponse reply = + TorchServeHealthResponse.newBuilder() + .setHealth( + JsonUtils.GSON_PRETTY_EXPOSED.toJson( + new StatusResponse( + response, HttpURLConnection.HTTP_OK))) + .build(); + responseObserver.onNext(reply); + responseObserver.onCompleted(); + }; + ApiUtils.getTorchServeHealth(r); + } + + @Override + public void predictions( + PredictionsRequest request, StreamObserver responseObserver) { + String modelName = request.getModelName(); + String modelVersion = request.getModelVersion(); + + if (modelName == null || "".equals(modelName)) { + BadRequestException e = new BadRequestException("Parameter model_name is required."); + sendErrorResponse(responseObserver, Status.INTERNAL, e, "BadRequestException.()"); + return; + } + + if (modelVersion == null || "".equals(modelVersion)) { + modelVersion = null; + } + + String requestId = UUID.randomUUID().toString(); + RequestInput inputData = new RequestInput(requestId); + + for (Map.Entry entry : request.getInputMap().entrySet()) { + inputData.addParameter( + new InputParameter(entry.getKey(), entry.getValue().toByteArray())); + } + + MetricAggregator.handleInferenceMetric(modelName, modelVersion); + Job job = + new GRPCJob( + responseObserver, + modelName, + modelVersion, + WorkerCommands.PREDICT, + inputData); + + try { + if (!ModelManager.getInstance().addJob(job)) { + String responseMessage = + ApiUtils.getInferenceErrorResponseMessage(modelName, modelVersion); + InternalServerException e = new InternalServerException(responseMessage); + sendErrorResponse( + responseObserver, Status.INTERNAL, e, "InternalServerException.()"); + } + } catch (ModelNotFoundException | ModelVersionNotFoundException e) { + sendErrorResponse(responseObserver, Status.INTERNAL, e, null); + } + } + + private void sendErrorResponse( + StreamObserver responseObserver, + Status status, + Exception e, + String description) { + responseObserver.onError( + status.withDescription(e.getMessage()) + .augmentDescription( + description == null ? 
e.getClass().getCanonicalName() : description) + .withCause(e) + .asRuntimeException()); + } +} diff --git a/frontend/server/src/main/java/org/pytorch/serve/grpcimpl/ManagementImpl.java b/frontend/server/src/main/java/org/pytorch/serve/grpcimpl/ManagementImpl.java new file mode 100644 index 0000000000..59e9914833 --- /dev/null +++ b/frontend/server/src/main/java/org/pytorch/serve/grpcimpl/ManagementImpl.java @@ -0,0 +1,193 @@ +package org.pytorch.serve.grpcimpl; + +import io.grpc.Status; +import io.grpc.stub.StreamObserver; +import java.util.concurrent.ExecutionException; +import org.pytorch.serve.archive.ModelException; +import org.pytorch.serve.archive.ModelNotFoundException; +import org.pytorch.serve.archive.ModelVersionNotFoundException; +import org.pytorch.serve.grpc.management.DescribeModelRequest; +import org.pytorch.serve.grpc.management.ListModelsRequest; +import org.pytorch.serve.grpc.management.ManagementAPIsServiceGrpc.ManagementAPIsServiceImplBase; +import org.pytorch.serve.grpc.management.ManagementResponse; +import org.pytorch.serve.grpc.management.RegisterModelRequest; +import org.pytorch.serve.grpc.management.ScaleWorkerRequest; +import org.pytorch.serve.grpc.management.SetDefaultRequest; +import org.pytorch.serve.grpc.management.UnregisterModelRequest; +import org.pytorch.serve.http.BadRequestException; +import org.pytorch.serve.http.InternalServerException; +import org.pytorch.serve.http.StatusResponse; +import org.pytorch.serve.util.ApiUtils; +import org.pytorch.serve.util.GRPCUtils; +import org.pytorch.serve.util.JsonUtils; + +public class ManagementImpl extends ManagementAPIsServiceImplBase { + + @Override + public void describeModel( + DescribeModelRequest request, StreamObserver responseObserver) { + + String modelName = request.getModelName(); + String modelVersion = request.getModelVersion(); + + String resp; + try { + resp = + JsonUtils.GSON_PRETTY.toJson( + ApiUtils.getModelDescription(modelName, modelVersion)); + sendResponse(responseObserver, resp); + } catch (ModelNotFoundException | ModelVersionNotFoundException e) { + sendErrorResponse(responseObserver, Status.NOT_FOUND, e); + } + } + + @Override + public void listModels( + ListModelsRequest request, StreamObserver responseObserver) { + int limit = request.getLimit(); + int pageToken = request.getNextPageToken(); + + String modelList = JsonUtils.GSON_PRETTY.toJson(ApiUtils.getModelList(limit, pageToken)); + sendResponse(responseObserver, modelList); + } + + @Override + public void registerModel( + RegisterModelRequest request, StreamObserver responseObserver) { + org.pytorch.serve.http.messages.RegisterModelRequest registerModelRequest = + new org.pytorch.serve.http.messages.RegisterModelRequest(request); + + StatusResponse statusResponse; + try { + statusResponse = ApiUtils.registerModel(registerModelRequest); + sendStatusResponse(responseObserver, statusResponse); + } catch (InternalServerException e) { + sendException(responseObserver, e, null); + } catch (ExecutionException | InterruptedException e) { + sendException(responseObserver, e, "Error while creating workers"); + } catch (ModelNotFoundException | ModelVersionNotFoundException e) { + sendErrorResponse(responseObserver, Status.NOT_FOUND, e); + } catch (ModelException | BadRequestException e) { + sendErrorResponse(responseObserver, Status.INVALID_ARGUMENT, e); + } + } + + @Override + public void scaleWorker( + ScaleWorkerRequest request, StreamObserver responseObserver) { + int minWorkers = GRPCUtils.getRegisterParam(request.getMinWorker(), 1); + 
int maxWorkers = GRPCUtils.getRegisterParam(request.getMaxWorker(), minWorkers); + String modelName = GRPCUtils.getRegisterParam(request.getModelName(), null); + String modelVersion = GRPCUtils.getRegisterParam(request.getModelVersion(), null); + boolean synchronous = request.getSynchronous(); + + StatusResponse statusResponse; + try { + statusResponse = + ApiUtils.updateModelWorkers( + modelName, + modelVersion, + minWorkers, + maxWorkers, + synchronous, + false, + null); + sendStatusResponse(responseObserver, statusResponse); + } catch (ExecutionException | InterruptedException e) { + sendException(responseObserver, e, "Error while creating workers"); + } catch (ModelNotFoundException | ModelVersionNotFoundException e) { + sendErrorResponse(responseObserver, Status.NOT_FOUND, e); + } catch (BadRequestException e) { + sendErrorResponse(responseObserver, Status.INVALID_ARGUMENT, e); + } + } + + @Override + public void setDefault( + SetDefaultRequest request, StreamObserver responseObserver) { + String modelName = request.getModelName(); + String newModelVersion = request.getModelVersion(); + + try { + String msg = ApiUtils.setDefault(modelName, newModelVersion); + sendResponse(responseObserver, msg); + } catch (ModelNotFoundException | ModelVersionNotFoundException e) { + sendErrorResponse(responseObserver, Status.NOT_FOUND, e); + } + } + + @Override + public void unregisterModel( + UnregisterModelRequest request, StreamObserver responseObserver) { + try { + String modelName = request.getModelName(); + if (modelName == null || ("").equals(modelName)) { + sendErrorResponse( + responseObserver, + Status.INVALID_ARGUMENT, + new BadRequestException("Parameter url is required.")); + } + + String modelVersion = request.getModelVersion(); + + if (("").equals(modelVersion)) { + modelVersion = null; + } + ApiUtils.unregisterModel(modelName, modelVersion); + String msg = "Model \"" + modelName + "\" unregistered"; + sendResponse(responseObserver, msg); + } catch (ModelNotFoundException | ModelVersionNotFoundException e) { + sendErrorResponse(responseObserver, Status.NOT_FOUND, e); + } catch (BadRequestException e) { + sendErrorResponse(responseObserver, Status.INVALID_ARGUMENT, e); + } + } + + private void sendResponse(StreamObserver responseObserver, String msg) { + ManagementResponse reply = ManagementResponse.newBuilder().setMsg(msg).build(); + responseObserver.onNext(reply); + responseObserver.onCompleted(); + } + + private void sendErrorResponse( + StreamObserver responseObserver, + Status status, + String description, + String errorClass) { + responseObserver.onError( + status.withDescription(description) + .augmentDescription(errorClass) + .asRuntimeException()); + } + + private void sendErrorResponse( + StreamObserver responseObserver, Status status, Exception e) { + responseObserver.onError( + status.withDescription(e.getMessage()) + .augmentDescription(e.getClass().getCanonicalName()) + .asRuntimeException()); + } + + private void sendStatusResponse( + StreamObserver responseObserver, StatusResponse statusResponse) { + int httpResponseStatusCode = statusResponse.getHttpResponseCode(); + if (httpResponseStatusCode >= 200 && httpResponseStatusCode < 300) { + sendResponse(responseObserver, statusResponse.getStatus()); + } else { + sendErrorResponse( + responseObserver, + GRPCUtils.getGRPCStatusCode(statusResponse.getHttpResponseCode()), + statusResponse.getE().getMessage(), + statusResponse.getE().getClass().getCanonicalName()); + } + } + + private void sendException( + StreamObserver 
responseObserver, Exception e, String description) { + sendErrorResponse( + responseObserver, + Status.INTERNAL, + description == null ? e.getMessage() : description, + e.getClass().getCanonicalName()); + } +} diff --git a/frontend/server/src/main/java/org/pytorch/serve/http/InferenceRequestHandler.java b/frontend/server/src/main/java/org/pytorch/serve/http/InferenceRequestHandler.java index da3993654a..72cf498a2e 100644 --- a/frontend/server/src/main/java/org/pytorch/serve/http/InferenceRequestHandler.java +++ b/frontend/server/src/main/java/org/pytorch/serve/http/InferenceRequestHandler.java @@ -9,19 +9,22 @@ import io.netty.handler.codec.http.multipart.DefaultHttpDataFactory; import io.netty.handler.codec.http.multipart.HttpDataFactory; import io.netty.handler.codec.http.multipart.HttpPostRequestDecoder; +import java.net.HttpURLConnection; import java.util.List; import java.util.Map; import org.pytorch.serve.archive.ModelException; import org.pytorch.serve.archive.ModelNotFoundException; import org.pytorch.serve.archive.ModelVersionNotFoundException; +import org.pytorch.serve.job.Job; +import org.pytorch.serve.job.RestJob; import org.pytorch.serve.metrics.api.MetricAggregator; import org.pytorch.serve.openapi.OpenApiUtils; import org.pytorch.serve.servingsdk.ModelServerEndpoint; +import org.pytorch.serve.util.ApiUtils; import org.pytorch.serve.util.NettyUtils; import org.pytorch.serve.util.messages.InputParameter; import org.pytorch.serve.util.messages.RequestInput; import org.pytorch.serve.util.messages.WorkerCommands; -import org.pytorch.serve.wlm.Job; import org.pytorch.serve.wlm.Model; import org.pytorch.serve.wlm.ModelManager; import org.slf4j.Logger; @@ -54,7 +57,15 @@ protected void handleRequest( } else { switch (segments[1]) { case "ping": - ModelManager.getInstance().workerStatus(ctx); + Runnable r = + () -> { + String response = ApiUtils.getWorkerStatus(); + NettyUtils.sendJsonResponse( + ctx, + new StatusResponse( + response, HttpURLConnection.HTTP_OK)); + }; + ApiUtils.getTorchServeHealth(r); break; case "models": case "invocations": @@ -214,21 +225,10 @@ private void predict( } MetricAggregator.handleInferenceMetric(modelName, modelVersion); - Job job = new Job(ctx, modelName, modelVersion, WorkerCommands.PREDICT, input); + Job job = new RestJob(ctx, modelName, modelVersion, WorkerCommands.PREDICT, input); if (!ModelManager.getInstance().addJob(job)) { String responseMessage = - "Model \"" - + modelName - + "\" Version " - + modelVersion - + " has no worker to serve inference request. Please use scale workers API to add workers."; - - if (modelVersion == null) { - responseMessage = - "Model \"" - + modelName - + "\" has no worker to serve inference request. 
Please use scale workers API to add workers."; - } + ApiUtils.getInferenceErrorResponseMessage(modelName, modelVersion); throw new ServiceUnavailableException(responseMessage); } diff --git a/frontend/server/src/main/java/org/pytorch/serve/http/ManagementRequestHandler.java b/frontend/server/src/main/java/org/pytorch/serve/http/ManagementRequestHandler.java index 10f85cb81c..901f4a4623 100644 --- a/frontend/server/src/main/java/org/pytorch/serve/http/ManagementRequestHandler.java +++ b/frontend/server/src/main/java/org/pytorch/serve/http/ManagementRequestHandler.java @@ -8,24 +8,16 @@ import io.netty.handler.codec.http.HttpUtil; import io.netty.handler.codec.http.QueryStringDecoder; import io.netty.util.CharsetUtil; -import java.io.IOException; -import java.nio.file.FileAlreadyExistsException; import java.util.ArrayList; -import java.util.Collections; import java.util.List; import java.util.Map; -import java.util.concurrent.CompletableFuture; -import java.util.function.Function; -import org.apache.commons.io.FilenameUtils; -import org.pytorch.serve.archive.Manifest; -import org.pytorch.serve.archive.ModelArchive; +import java.util.concurrent.ExecutionException; import org.pytorch.serve.archive.ModelException; import org.pytorch.serve.archive.ModelNotFoundException; import org.pytorch.serve.archive.ModelVersionNotFoundException; import org.pytorch.serve.http.messages.RegisterModelRequest; import org.pytorch.serve.servingsdk.ModelServerEndpoint; -import org.pytorch.serve.snapshot.SnapshotManager; -import org.pytorch.serve.util.ConfigManager; +import org.pytorch.serve.util.ApiUtils; import org.pytorch.serve.util.JsonUtils; import org.pytorch.serve.util.NettyUtils; import org.pytorch.serve.wlm.Model; @@ -117,32 +109,8 @@ private boolean isKFV1ManagementReq(String[] segments) { private void handleListModels(ChannelHandlerContext ctx, QueryStringDecoder decoder) { int limit = NettyUtils.getIntParameter(decoder, "limit", 100); int pageToken = NettyUtils.getIntParameter(decoder, "next_page_token", 0); - if (limit > 100 || limit < 0) { - limit = 100; - } - if (pageToken < 0) { - pageToken = 0; - } - - ModelManager modelManager = ModelManager.getInstance(); - Map models = modelManager.getDefaultModels(); - - List keys = new ArrayList<>(models.keySet()); - Collections.sort(keys); - ListModelsResponse list = new ListModelsResponse(); - int last = pageToken + limit; - if (last > keys.size()) { - last = keys.size(); - } else { - list.setNextPageToken(String.valueOf(last)); - } - - for (int i = pageToken; i < last; ++i) { - String modelName = keys.get(i); - Model model = models.get(modelName); - list.addModel(modelName, model.getModelUrl()); - } + ListModelsResponse list = ApiUtils.getModelList(limit, pageToken); NettyUtils.sendJsonResponse(ctx, list); } @@ -150,53 +118,12 @@ private void handleListModels(ChannelHandlerContext ctx, QueryStringDecoder deco private void handleDescribeModel( ChannelHandlerContext ctx, String modelName, String modelVersion) throws ModelNotFoundException, ModelVersionNotFoundException { - ModelManager modelManager = ModelManager.getInstance(); - ArrayList resp = new ArrayList(); - - if ("all".equals(modelVersion)) { - for (Map.Entry m : modelManager.getAllModelVersions(modelName)) { - resp.add(createModelResponse(modelManager, modelName, m.getValue())); - } - } else { - Model model = modelManager.getModel(modelName, modelVersion); - if (model == null) { - throw new ModelNotFoundException("Model not found: " + modelName); - } - resp.add(createModelResponse(modelManager, 
modelName, model)); - } + ArrayList resp = + ApiUtils.getModelDescription(modelName, modelVersion); NettyUtils.sendJsonResponse(ctx, resp); } - private DescribeModelResponse createModelResponse( - ModelManager modelManager, String modelName, Model model) { - DescribeModelResponse resp = new DescribeModelResponse(); - resp.setModelName(modelName); - resp.setModelUrl(model.getModelUrl()); - resp.setBatchSize(model.getBatchSize()); - resp.setMaxBatchDelay(model.getMaxBatchDelay()); - resp.setMaxWorkers(model.getMaxWorkers()); - resp.setMinWorkers(model.getMinWorkers()); - resp.setLoadedAtStartup(modelManager.getStartupModels().contains(modelName)); - Manifest manifest = model.getModelArchive().getManifest(); - resp.setModelVersion(manifest.getModel().getModelVersion()); - resp.setRuntime(manifest.getRuntime().getValue()); - - List workers = modelManager.getWorkers(model.getModelVersionName()); - for (WorkerThread worker : workers) { - String workerId = worker.getWorkerId(); - long startTime = worker.getStartTime(); - boolean isRunning = worker.isRunning(); - int gpuId = worker.getGpuId(); - long memory = worker.getMemory(); - int pid = worker.getPid(); - String gpuUsage = worker.getGpuUsage(); - resp.addWorker(workerId, startTime, isRunning, gpuId, memory, pid, gpuUsage); - } - - return resp; - } - private void handleKF1ModelReady( ChannelHandlerContext ctx, String modelName, String modelVersion) throws ModelNotFoundException, ModelVersionNotFoundException { @@ -222,104 +149,31 @@ private void handleRegisterModel( ChannelHandlerContext ctx, QueryStringDecoder decoder, FullHttpRequest req) throws ModelException { RegisterModelRequest registerModelRequest = parseRequest(req, decoder); - String modelUrl = registerModelRequest.getModelUrl(); - if (modelUrl == null) { - throw new BadRequestException("Parameter url is required."); - } - - String modelName = registerModelRequest.getModelName(); - String runtime = registerModelRequest.getRuntime(); - String handler = registerModelRequest.getHandler(); - int batchSize = registerModelRequest.getBatchSize(); - int maxBatchDelay = registerModelRequest.getMaxBatchDelay(); - int initialWorkers = registerModelRequest.getInitialWorkers(); - boolean synchronous = registerModelRequest.getSynchronous(); - int responseTimeout = registerModelRequest.getResponseTimeout(); - if (responseTimeout == -1) { - responseTimeout = ConfigManager.getInstance().getDefaultResponseTimeout(); - } - Manifest.RuntimeType runtimeType = null; - if (runtime != null) { - try { - runtimeType = Manifest.RuntimeType.fromValue(runtime); - } catch (IllegalArgumentException e) { - throw new BadRequestException(e); - } - } - - ModelManager modelManager = ModelManager.getInstance(); - final ModelArchive archive; + StatusResponse statusResponse; try { - - archive = - modelManager.registerModel( - modelUrl, - modelName, - runtimeType, - handler, - batchSize, - maxBatchDelay, - responseTimeout, - null); - } catch (FileAlreadyExistsException e) { - throw new InternalServerException( - "Model file already exists " + FilenameUtils.getName(modelUrl), e); - } catch (IOException | InterruptedException e) { - throw new InternalServerException("Failed to save model: " + modelUrl, e); - } - - modelName = archive.getModelName(); - - if (initialWorkers <= 0) { - final String msg = - "Model \"" - + modelName - + "\" Version: " - + archive.getModelVersion() - + " registered with 0 initial workers. 
Use scale workers API to add workers for the model."; - SnapshotManager.getInstance().saveSnapshot(); - NettyUtils.sendJsonResponse(ctx, new StatusResponse(msg)); - return; + statusResponse = ApiUtils.registerModel(registerModelRequest); + } catch (ExecutionException | InterruptedException | InternalServerException e) { + String message; + if (e instanceof InternalServerException) { + message = e.getMessage(); + } else { + message = "Error while creating workers"; + } + statusResponse = new StatusResponse(); + statusResponse.setE(e); + statusResponse.setStatus(message); + statusResponse.setHttpResponseCode(HttpResponseStatus.INTERNAL_SERVER_ERROR.code()); } - - updateModelWorkers( - ctx, - modelName, - archive.getModelVersion(), - initialWorkers, - initialWorkers, - synchronous, - true, - f -> { - modelManager.unregisterModel(archive.getModelName(), archive.getModelVersion()); - return null; - }); + sendResponse(ctx, statusResponse); } private void handleUnregisterModel( ChannelHandlerContext ctx, String modelName, String modelVersion) throws ModelNotFoundException, InternalServerException, RequestTimeoutException, ModelVersionNotFoundException { - ModelManager modelManager = ModelManager.getInstance(); - HttpResponseStatus httpResponseStatus = - modelManager.unregisterModel(modelName, modelVersion); - if (httpResponseStatus == HttpResponseStatus.NOT_FOUND) { - throw new ModelNotFoundException("Model not found: " + modelName); - } else if (httpResponseStatus == HttpResponseStatus.BAD_REQUEST) { - throw new ModelVersionNotFoundException( - String.format( - "Model version: %s does not exist for model: %s", - modelVersion, modelName)); - } else if (httpResponseStatus == HttpResponseStatus.INTERNAL_SERVER_ERROR) { - throw new InternalServerException("Interrupted while cleaning resources: " + modelName); - } else if (httpResponseStatus == HttpResponseStatus.REQUEST_TIMEOUT) { - throw new RequestTimeoutException("Timed out while cleaning resources: " + modelName); - } else if (httpResponseStatus == HttpResponseStatus.FORBIDDEN) { - throw new InvalidModelVersionException( - "Cannot remove default version for model " + modelName); - } + ApiUtils.unregisterModel(modelName, modelVersion); String msg = "Model \"" + modelName + "\" unregistered"; - NettyUtils.sendJsonResponse(ctx, new StatusResponse(msg)); + NettyUtils.sendJsonResponse(ctx, new StatusResponse(msg, HttpResponseStatus.OK.code())); } private void handleScaleModel( @@ -333,92 +187,28 @@ private void handleScaleModel( if (modelVersion == null) { modelVersion = NettyUtils.getParameter(decoder, "model_version", null); } - if (maxWorkers < minWorkers) { - throw new BadRequestException("max_worker cannot be less than min_worker."); - } + boolean synchronous = Boolean.parseBoolean(NettyUtils.getParameter(decoder, "synchronous", null)); - ModelManager modelManager = ModelManager.getInstance(); - if (!modelManager.getDefaultModels().containsKey(modelName)) { - throw new ModelNotFoundException("Model not found: " + modelName); - } - updateModelWorkers( - ctx, modelName, modelVersion, minWorkers, maxWorkers, synchronous, false, null); - } - - private void updateModelWorkers( - final ChannelHandlerContext ctx, - final String modelName, - final String modelVersion, - int minWorkers, - int maxWorkers, - boolean synchronous, - boolean isInit, - final Function onError) - throws ModelVersionNotFoundException { - ModelManager modelManager = ModelManager.getInstance(); - CompletableFuture future = - modelManager.updateModel(modelName, modelVersion, 
minWorkers, maxWorkers); - if (!synchronous) { - NettyUtils.sendJsonResponse( - ctx, - new StatusResponse("Processing worker updates..."), - HttpResponseStatus.ACCEPTED); - return; + StatusResponse statusResponse; + try { + statusResponse = + ApiUtils.updateModelWorkers( + modelName, + modelVersion, + minWorkers, + maxWorkers, + synchronous, + false, + null); + } catch (ExecutionException | InterruptedException e) { + statusResponse = new StatusResponse(); + statusResponse.setE(e); + statusResponse.setStatus("Error while creating workers"); + statusResponse.setHttpResponseCode(HttpResponseStatus.INTERNAL_SERVER_ERROR.code()); } - future.thenApply( - v -> { - boolean status = - modelManager.scaleRequestStatus(modelName, modelVersion); - if (HttpResponseStatus.OK.equals(v)) { - if (status) { - String msg = - "Workers scaled to " - + minWorkers - + " for model: " - + modelName; - if (modelVersion != null) { - msg += ", version: " + modelVersion; // NOPMD - } - - if (isInit) { - msg = - "Model \"" - + modelName - + "\" Version: " - + modelVersion - + " registered with " - + minWorkers - + " initial workers"; - } - - NettyUtils.sendJsonResponse(ctx, new StatusResponse(msg), v); - } else { - NettyUtils.sendJsonResponse( - ctx, - new StatusResponse("Workers scaling in progress..."), - new HttpResponseStatus(210, "Partial Success")); - } - } else { - NettyUtils.sendError( - ctx, - v, - new InternalServerException("Failed to start workers")); - if (onError != null) { - onError.apply(null); - } - } - return v; - }) - .exceptionally( - (e) -> { - if (onError != null) { - onError.apply(null); - } - NettyUtils.sendError(ctx, HttpResponseStatus.INTERNAL_SERVER_ERROR, e); - return null; - }); + sendResponse(ctx, statusResponse); } private RegisterModelRequest parseRequest(FullHttpRequest req, QueryStringDecoder decoder) { @@ -435,25 +225,29 @@ private RegisterModelRequest parseRequest(FullHttpRequest req, QueryStringDecode } private void setDefaultModelVersion( - ChannelHandlerContext ctx, String modelName, String newModelVersion) - throws ModelNotFoundException, InternalServerException, RequestTimeoutException, - ModelVersionNotFoundException { - ModelManager modelManager = ModelManager.getInstance(); - HttpResponseStatus httpResponseStatus = - modelManager.setDefaultVersion(modelName, newModelVersion); - if (httpResponseStatus == HttpResponseStatus.NOT_FOUND) { - throw new ModelNotFoundException("Model not found: " + modelName); - } else if (httpResponseStatus == HttpResponseStatus.FORBIDDEN) { - throw new ModelVersionNotFoundException( - "Model version " + newModelVersion + " does not exist for model " + modelName); + ChannelHandlerContext ctx, String modelName, String newModelVersion) { + try { + String msg = ApiUtils.setDefault(modelName, newModelVersion); + NettyUtils.sendJsonResponse(ctx, new StatusResponse(msg, HttpResponseStatus.OK.code())); + } catch (ModelNotFoundException | ModelVersionNotFoundException e) { + NettyUtils.sendError(ctx, HttpResponseStatus.NOT_FOUND, e); + } + } + + private void sendResponse(ChannelHandlerContext ctx, StatusResponse statusResponse) { + if (statusResponse != null) { + if (statusResponse.getHttpResponseCode() >= 200 + && statusResponse.getHttpResponseCode() < 300) { + NettyUtils.sendJsonResponse(ctx, statusResponse); + } else { + // Re-map HTTPURLConnections HTTP_ENTITY_TOO_LARGE to Netty's INSUFFICIENT_STORAGE + int httpResponseStatus = statusResponse.getHttpResponseCode(); + NettyUtils.sendError( + ctx, + HttpResponseStatus.valueOf( + httpResponseStatus == 413 
? 507 : httpResponseStatus), + statusResponse.getE()); + } } - String msg = - "Default vesion succsesfully updated for model \"" - + modelName - + "\" to \"" - + newModelVersion - + "\""; - SnapshotManager.getInstance().saveSnapshot(); - NettyUtils.sendJsonResponse(ctx, new StatusResponse(msg)); } } diff --git a/frontend/server/src/main/java/org/pytorch/serve/http/Session.java b/frontend/server/src/main/java/org/pytorch/serve/http/Session.java index 3947ac121e..11247331f9 100644 --- a/frontend/server/src/main/java/org/pytorch/serve/http/Session.java +++ b/frontend/server/src/main/java/org/pytorch/serve/http/Session.java @@ -27,6 +27,15 @@ public Session(String remoteIp, HttpRequest request) { startTime = System.currentTimeMillis(); } + public Session(String remoteIp, String gRPCMethod) { + this.remoteIp = remoteIp; + method = "gRPC"; + protocol = "HTTP/2.0"; + this.uri = gRPCMethod; + requestId = UUID.randomUUID().toString(); + startTime = System.currentTimeMillis(); + } + public String getRequestId() { return requestId; } diff --git a/frontend/server/src/main/java/org/pytorch/serve/http/StatusResponse.java b/frontend/server/src/main/java/org/pytorch/serve/http/StatusResponse.java index e8378a0ec7..36cb1e4de8 100644 --- a/frontend/server/src/main/java/org/pytorch/serve/http/StatusResponse.java +++ b/frontend/server/src/main/java/org/pytorch/serve/http/StatusResponse.java @@ -1,13 +1,26 @@ package org.pytorch.serve.http; +import com.google.gson.annotations.Expose; + public class StatusResponse { - private String status; + private int httpResponseCode; + @Expose private String status; + private Throwable e; public StatusResponse() {} - public StatusResponse(String status) { + public StatusResponse(String status, int httpResponseCode) { this.status = status; + this.httpResponseCode = httpResponseCode; + } + + public int getHttpResponseCode() { + return httpResponseCode; + } + + public void setHttpResponseCode(int httpResponseCode) { + this.httpResponseCode = httpResponseCode; } public String getStatus() { @@ -17,4 +30,12 @@ public String getStatus() { public void setStatus(String status) { this.status = status; } + + public Throwable getE() { + return e; + } + + public void setE(Throwable e) { + this.e = e; + } } diff --git a/frontend/server/src/main/java/org/pytorch/serve/http/messages/RegisterModelRequest.java b/frontend/server/src/main/java/org/pytorch/serve/http/messages/RegisterModelRequest.java index edbcb7e4f4..9f5fe69ebc 100644 --- a/frontend/server/src/main/java/org/pytorch/serve/http/messages/RegisterModelRequest.java +++ b/frontend/server/src/main/java/org/pytorch/serve/http/messages/RegisterModelRequest.java @@ -3,6 +3,7 @@ import com.google.gson.annotations.SerializedName; import io.netty.handler.codec.http.QueryStringDecoder; import org.pytorch.serve.util.ConfigManager; +import org.pytorch.serve.util.GRPCUtils; import org.pytorch.serve.util.NettyUtils; /** Register Model Request for Model server */ @@ -50,6 +51,21 @@ public RegisterModelRequest(QueryStringDecoder decoder) { modelUrl = NettyUtils.getParameter(decoder, "url", null); } + public RegisterModelRequest(org.pytorch.serve.grpc.management.RegisterModelRequest request) { + modelName = GRPCUtils.getRegisterParam(request.getModelName(), null); + runtime = GRPCUtils.getRegisterParam(request.getRuntime(), null); + handler = GRPCUtils.getRegisterParam(request.getHandler(), null); + batchSize = GRPCUtils.getRegisterParam(request.getBatchSize(), 1); + maxBatchDelay = GRPCUtils.getRegisterParam(request.getMaxBatchDelay(), 100); + 
initialWorkers = + GRPCUtils.getRegisterParam( + request.getInitialWorkers(), + ConfigManager.getInstance().getConfiguredDefaultWorkersPerModel()); + synchronous = request.getSynchronous(); + responseTimeout = GRPCUtils.getRegisterParam(request.getResponseTimeout(), -1); + modelUrl = GRPCUtils.getRegisterParam(request.getUrl(), null); + } + public RegisterModelRequest() { batchSize = 1; maxBatchDelay = 100; diff --git a/frontend/server/src/main/java/org/pytorch/serve/job/GRPCJob.java b/frontend/server/src/main/java/org/pytorch/serve/job/GRPCJob.java new file mode 100644 index 0000000000..d7a90e01ac --- /dev/null +++ b/frontend/server/src/main/java/org/pytorch/serve/job/GRPCJob.java @@ -0,0 +1,75 @@ +package org.pytorch.serve.job; + +import com.google.protobuf.ByteString; +import io.grpc.Status; +import io.grpc.stub.StreamObserver; +import java.util.Map; +import java.util.concurrent.TimeUnit; +import org.pytorch.serve.grpc.inference.PredictionResponse; +import org.pytorch.serve.metrics.Dimension; +import org.pytorch.serve.metrics.Metric; +import org.pytorch.serve.util.ConfigManager; +import org.pytorch.serve.util.GRPCUtils; +import org.pytorch.serve.util.messages.RequestInput; +import org.pytorch.serve.util.messages.WorkerCommands; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +public class GRPCJob extends Job { + private static final Logger logger = LoggerFactory.getLogger(Job.class); + private static final org.apache.log4j.Logger loggerTsMetrics = + org.apache.log4j.Logger.getLogger(ConfigManager.MODEL_SERVER_METRICS_LOGGER); + private static final Dimension DIMENSION = new Dimension("Level", "Host"); + + private StreamObserver predictionResponseObserver; + + public GRPCJob( + StreamObserver predictionResponseObserver, + String modelName, + String version, + WorkerCommands cmd, + RequestInput input) { + super(modelName, version, cmd, input); + this.predictionResponseObserver = predictionResponseObserver; + } + + @Override + public void response( + byte[] body, + CharSequence contentType, + int statusCode, + String statusPhrase, + Map responseHeaders) { + + ByteString output = ByteString.copyFrom(body); + PredictionResponse reply = PredictionResponse.newBuilder().setPrediction(output).build(); + predictionResponseObserver.onNext(reply); + predictionResponseObserver.onCompleted(); + + logger.debug( + "Waiting time ns: {}, Backend time ns: {}", + getScheduled() - getBegin(), + System.nanoTime() - getScheduled()); + String queueTime = + String.valueOf( + TimeUnit.MILLISECONDS.convert( + getScheduled() - getBegin(), TimeUnit.NANOSECONDS)); + loggerTsMetrics.info( + new Metric( + "QueueTime", + queueTime, + "ms", + ConfigManager.getInstance().getHostName(), + DIMENSION)); + } + + @Override + public void sendError(int status, String error) { + Status responseStatus = GRPCUtils.getGRPCStatusCode(status); + predictionResponseObserver.onError( + responseStatus + .withDescription(error) + .augmentDescription("org.pytorch.serve.http.InternalServerException") + .asRuntimeException()); + } +} diff --git a/frontend/server/src/main/java/org/pytorch/serve/job/Job.java b/frontend/server/src/main/java/org/pytorch/serve/job/Job.java new file mode 100644 index 0000000000..c4c4f7e3f7 --- /dev/null +++ b/frontend/server/src/main/java/org/pytorch/serve/job/Job.java @@ -0,0 +1,69 @@ +package org.pytorch.serve.job; + +import java.util.Map; +import org.pytorch.serve.util.messages.RequestInput; +import org.pytorch.serve.util.messages.WorkerCommands; + +public abstract class Job { + + private String 
modelName; + private String modelVersion; + private WorkerCommands cmd; // Else its data msg or inf requests + private RequestInput input; + private long begin; + private long scheduled; + + public Job(String modelName, String version, WorkerCommands cmd, RequestInput input) { + this.modelName = modelName; + this.cmd = cmd; + this.input = input; + this.modelVersion = version; + begin = System.nanoTime(); + scheduled = begin; + } + + public String getJobId() { + return input.getRequestId(); + } + + public String getModelName() { + return modelName; + } + + public String getModelVersion() { + return modelVersion; + } + + public WorkerCommands getCmd() { + return cmd; + } + + public boolean isControlCmd() { + return !WorkerCommands.PREDICT.equals(cmd); + } + + public RequestInput getPayload() { + return input; + } + + public void setScheduled() { + scheduled = System.nanoTime(); + } + + public long getBegin() { + return begin; + } + + public long getScheduled() { + return scheduled; + } + + public abstract void response( + byte[] body, + CharSequence contentType, + int statusCode, + String statusPhrase, + Map responseHeaders); + + public abstract void sendError(int status, String error); +} diff --git a/frontend/server/src/main/java/org/pytorch/serve/wlm/Job.java b/frontend/server/src/main/java/org/pytorch/serve/job/RestJob.java similarity index 70% rename from frontend/server/src/main/java/org/pytorch/serve/wlm/Job.java rename to frontend/server/src/main/java/org/pytorch/serve/job/RestJob.java index dac0342f01..9c51addf9a 100644 --- a/frontend/server/src/main/java/org/pytorch/serve/wlm/Job.java +++ b/frontend/server/src/main/java/org/pytorch/serve/job/RestJob.java @@ -1,4 +1,4 @@ -package org.pytorch.serve.wlm; +package org.pytorch.serve.job; import io.netty.channel.ChannelHandlerContext; import io.netty.handler.codec.http.DefaultFullHttpResponse; @@ -19,7 +19,7 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; -public class Job { +public class RestJob extends Job { private static final Logger logger = LoggerFactory.getLogger(Job.class); private static final org.apache.log4j.Logger loggerTsMetrics = @@ -28,63 +28,24 @@ public class Job { private ChannelHandlerContext ctx; - private String modelName; - private String modelVersion; - private WorkerCommands cmd; // Else its data msg or inf requests - private RequestInput input; - private long begin; - private long scheduled; - - public Job( + public RestJob( ChannelHandlerContext ctx, String modelName, String version, WorkerCommands cmd, RequestInput input) { + super(modelName, version, cmd, input); this.ctx = ctx; - this.modelName = modelName; - this.cmd = cmd; - this.input = input; - this.modelVersion = version; - begin = System.nanoTime(); - scheduled = begin; - } - - public String getJobId() { - return input.getRequestId(); - } - - public String getModelName() { - return modelName; - } - - public String getModelVersion() { - return modelVersion; - } - - public WorkerCommands getCmd() { - return cmd; - } - - public boolean isControlCmd() { - return !WorkerCommands.PREDICT.equals(cmd); - } - - public RequestInput getPayload() { - return input; - } - - public void setScheduled() { - scheduled = System.nanoTime(); } + @Override public void response( byte[] body, CharSequence contentType, int statusCode, String statusPhrase, Map responseHeaders) { - long inferTime = System.nanoTime() - scheduled; + long inferTime = System.nanoTime() - getBegin(); HttpResponseStatus status = (statusPhrase == null) ? 
HttpResponseStatus.valueOf(statusCode) @@ -109,16 +70,17 @@ public void response( */ if (ctx != null) { MetricAggregator.handleInferenceMetric( - modelName, modelVersion, scheduled - begin, inferTime); + getModelName(), getModelVersion(), getScheduled() - getBegin(), inferTime); NettyUtils.sendHttpResponse(ctx, resp, true); } logger.debug( "Waiting time ns: {}, Backend time ns: {}", - scheduled - begin, - System.nanoTime() - scheduled); + getScheduled() - getBegin(), + System.nanoTime() - getScheduled()); String queueTime = String.valueOf( - TimeUnit.MILLISECONDS.convert(scheduled - begin, TimeUnit.NANOSECONDS)); + TimeUnit.MILLISECONDS.convert( + getScheduled() - getBegin(), TimeUnit.NANOSECONDS)); loggerTsMetrics.info( new Metric( "QueueTime", @@ -128,7 +90,8 @@ public void response( DIMENSION)); } - public void sendError(HttpResponseStatus status, String error) { + @Override + public void sendError(int status, String error) { /* * We can load the models based on the configuration file.Since this Job is * not driven by the external connections, we could have a empty context for @@ -136,12 +99,15 @@ public void sendError(HttpResponseStatus status, String error) { * by external clients. */ if (ctx != null) { - NettyUtils.sendError(ctx, status, new InternalServerException(error)); + // Mapping HTTPURLConnection's HTTP_ENTITY_TOO_LARGE to Netty's INSUFFICIENT_STORAGE + status = (status == 413) ? 507 : status; + NettyUtils.sendError( + ctx, HttpResponseStatus.valueOf(status), new InternalServerException(error)); } logger.debug( "Waiting time ns: {}, Inference time ns: {}", - scheduled - begin, - System.nanoTime() - begin); + getScheduled() - getBegin(), + System.nanoTime() - getBegin()); } } diff --git a/frontend/server/src/main/java/org/pytorch/serve/util/ApiUtils.java b/frontend/server/src/main/java/org/pytorch/serve/util/ApiUtils.java new file mode 100644 index 0000000000..aa4c65875e --- /dev/null +++ b/frontend/server/src/main/java/org/pytorch/serve/util/ApiUtils.java @@ -0,0 +1,359 @@ +package org.pytorch.serve.util; + +import io.netty.handler.codec.http.HttpResponseStatus; +import java.io.IOException; +import java.net.HttpURLConnection; +import java.nio.file.FileAlreadyExistsException; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ExecutionException; +import java.util.function.Function; +import org.apache.commons.io.FilenameUtils; +import org.pytorch.serve.archive.Manifest; +import org.pytorch.serve.archive.ModelArchive; +import org.pytorch.serve.archive.ModelException; +import org.pytorch.serve.archive.ModelNotFoundException; +import org.pytorch.serve.archive.ModelVersionNotFoundException; +import org.pytorch.serve.http.BadRequestException; +import org.pytorch.serve.http.DescribeModelResponse; +import org.pytorch.serve.http.InternalServerException; +import org.pytorch.serve.http.InvalidModelVersionException; +import org.pytorch.serve.http.ListModelsResponse; +import org.pytorch.serve.http.RequestTimeoutException; +import org.pytorch.serve.http.StatusResponse; +import org.pytorch.serve.http.messages.RegisterModelRequest; +import org.pytorch.serve.snapshot.SnapshotManager; +import org.pytorch.serve.wlm.Model; +import org.pytorch.serve.wlm.ModelManager; +import org.pytorch.serve.wlm.ModelVersionedRefs; +import org.pytorch.serve.wlm.WorkerThread; + +public final class ApiUtils { + + private ApiUtils() {} + + public static ListModelsResponse 
getModelList(int limit, int pageToken) { + if (limit > 100 || limit < 0) { + limit = 100; + } + if (pageToken < 0) { + pageToken = 0; + } + + ModelManager modelManager = ModelManager.getInstance(); + Map models = modelManager.getDefaultModels(); + + List keys = new ArrayList<>(models.keySet()); + Collections.sort(keys); + ListModelsResponse list = new ListModelsResponse(); + + int last = pageToken + limit; + if (last > keys.size()) { + last = keys.size(); + } else { + list.setNextPageToken(String.valueOf(last)); + } + + for (int i = pageToken; i < last; ++i) { + String modelName = keys.get(i); + Model model = models.get(modelName); + list.addModel(modelName, model.getModelUrl()); + } + + return list; + } + + public static ArrayList getModelDescription( + String modelName, String modelVersion) + throws ModelNotFoundException, ModelVersionNotFoundException { + ModelManager modelManager = ModelManager.getInstance(); + ArrayList resp = new ArrayList(); + + if ("all".equals(modelVersion)) { + for (Map.Entry m : modelManager.getAllModelVersions(modelName)) { + resp.add(createModelResponse(modelManager, modelName, m.getValue())); + } + } else { + Model model = modelManager.getModel(modelName, modelVersion); + if (model == null) { + throw new ModelNotFoundException("Model not found: " + modelName); + } + resp.add(createModelResponse(modelManager, modelName, model)); + } + + return resp; + } + + public static String setDefault(String modelName, String newModelVersion) + throws ModelNotFoundException, ModelVersionNotFoundException { + ModelManager modelManager = ModelManager.getInstance(); + modelManager.setDefaultVersion(modelName, newModelVersion); + String msg = + "Default version successfully updated for model \"" + + modelName + + "\" to \"" + + newModelVersion + + "\""; + SnapshotManager.getInstance().saveSnapshot(); + return msg; + } + + public static StatusResponse registerModel(RegisterModelRequest registerModelRequest) + throws ModelException, InternalServerException, ExecutionException, + InterruptedException { + String modelUrl = registerModelRequest.getModelUrl(); + if (modelUrl == null) { + throw new BadRequestException("Parameter url is required."); + } + + String modelName = registerModelRequest.getModelName(); + String runtime = registerModelRequest.getRuntime(); + String handler = registerModelRequest.getHandler(); + int batchSize = registerModelRequest.getBatchSize(); + int maxBatchDelay = registerModelRequest.getMaxBatchDelay(); + int initialWorkers = registerModelRequest.getInitialWorkers(); + int responseTimeout = registerModelRequest.getResponseTimeout(); + if (responseTimeout == -1) { + responseTimeout = ConfigManager.getInstance().getDefaultResponseTimeout(); + } + + Manifest.RuntimeType runtimeType = null; + if (runtime != null) { + try { + runtimeType = Manifest.RuntimeType.fromValue(runtime); + } catch (IllegalArgumentException e) { + throw new BadRequestException(e); + } + } + + ModelManager modelManager = ModelManager.getInstance(); + final ModelArchive archive; + try { + archive = + modelManager.registerModel( + modelUrl, + modelName, + runtimeType, + handler, + batchSize, + maxBatchDelay, + responseTimeout, + null); + } catch (FileAlreadyExistsException e) { + throw new InternalServerException( + "Model file already exists " + FilenameUtils.getName(modelUrl), e); + } catch (IOException | InterruptedException e) { + throw new InternalServerException("Failed to save model: " + modelUrl, e); + } + + modelName = archive.getModelName(); + if (initialWorkers <= 0) { + final 
String msg = + "Model \"" + + modelName + + "\" Version: " + + archive.getModelVersion() + + " registered with 0 initial workers. Use scale workers API to add workers for the model."; + SnapshotManager.getInstance().saveSnapshot(); + return new StatusResponse(msg, HttpURLConnection.HTTP_OK); + } + + return ApiUtils.updateModelWorkers( + modelName, + archive.getModelVersion(), + initialWorkers, + initialWorkers, + registerModelRequest.getSynchronous(), + true, + f -> { + modelManager.unregisterModel(archive.getModelName(), archive.getModelVersion()); + return null; + }); + } + + public static StatusResponse updateModelWorkers( + String modelName, + String modelVersion, + int minWorkers, + int maxWorkers, + boolean synchronous, + boolean isInit, + final Function onError) + throws ModelVersionNotFoundException, ModelNotFoundException, ExecutionException, + InterruptedException { + + ModelManager modelManager = ModelManager.getInstance(); + if (maxWorkers < minWorkers) { + throw new BadRequestException("max_worker cannot be less than min_worker."); + } + if (!modelManager.getDefaultModels().containsKey(modelName)) { + throw new ModelNotFoundException("Model not found: " + modelName); + } + + CompletableFuture future = + modelManager.updateModel(modelName, modelVersion, minWorkers, maxWorkers); + + StatusResponse statusResponse = new StatusResponse(); + + if (!synchronous) { + return new StatusResponse( + "Processing worker updates...", HttpURLConnection.HTTP_ACCEPTED); + } + + CompletableFuture statusResponseCompletableFuture = + future.thenApply( + v -> { + boolean status = + modelManager.scaleRequestStatus( + modelName, modelVersion); + + if (HttpURLConnection.HTTP_OK == v) { + if (status) { + String msg = + "Workers scaled to " + + minWorkers + + " for model: " + + modelName; + if (modelVersion != null) { + msg += ", version: " + modelVersion; // NOPMD + } + + if (isInit) { + msg = + "Model \"" + + modelName + + "\" Version: " + + modelVersion + + " registered with " + + minWorkers + + " initial workers"; + } + + statusResponse.setStatus(msg); + statusResponse.setHttpResponseCode(v); + } else { + statusResponse.setStatus( + "Workers scaling in progress..."); + statusResponse.setHttpResponseCode( + HttpURLConnection.HTTP_PARTIAL); + } + } else { + statusResponse.setHttpResponseCode(v); + statusResponse.setE( + new InternalServerException( + "Failed to start workers")); + if (onError != null) { + onError.apply(null); + } + } + return statusResponse; + }) + .exceptionally( + (e) -> { + if (onError != null) { + onError.apply(null); + } + statusResponse.setStatus(e.getMessage()); + statusResponse.setHttpResponseCode( + HttpURLConnection.HTTP_INTERNAL_ERROR); + statusResponse.setE(e); + return statusResponse; + }); + + return statusResponseCompletableFuture.get(); + } + + public static void unregisterModel(String modelName, String modelVersion) + throws ModelNotFoundException, ModelVersionNotFoundException { + ModelManager modelManager = ModelManager.getInstance(); + int httpResponseStatus = modelManager.unregisterModel(modelName, modelVersion); + if (httpResponseStatus == HttpResponseStatus.NOT_FOUND.code()) { + throw new ModelNotFoundException("Model not found: " + modelName); + } else if (httpResponseStatus == HttpResponseStatus.BAD_REQUEST.code()) { + throw new ModelVersionNotFoundException( + String.format( + "Model version: %s does not exist for model: %s", + modelVersion, modelName)); + } else if (httpResponseStatus == HttpResponseStatus.INTERNAL_SERVER_ERROR.code()) { + throw new 
InternalServerException("Interrupted while cleaning resources: " + modelName); + } else if (httpResponseStatus == HttpResponseStatus.REQUEST_TIMEOUT.code()) { + throw new RequestTimeoutException("Timed out while cleaning resources: " + modelName); + } else if (httpResponseStatus == HttpResponseStatus.FORBIDDEN.code()) { + throw new InvalidModelVersionException( + "Cannot remove default version for model " + modelName); + } + } + + public static void getTorchServeHealth(Runnable r) { + ModelManager modelManager = ModelManager.getInstance(); + modelManager.submitTask(r); + } + + public static String getWorkerStatus() { + ModelManager modelManager = ModelManager.getInstance(); + String response = "Healthy"; + int numWorking = 0; + int numScaled = 0; + + for (Map.Entry m : modelManager.getAllModels()) { + numScaled += m.getValue().getDefaultModel().getMinWorkers(); + numWorking += + modelManager.getNumRunningWorkers( + m.getValue().getDefaultModel().getModelVersionName()); + } + + if ((numWorking > 0) && (numWorking < numScaled)) { + response = "Partial Healthy"; + } else if ((numWorking == 0) && (numScaled > 0)) { + response = "Unhealthy"; + } + // TODO: Check if it's OK to send other 2xx errors to ALB for "Partial Healthy" and + // "Unhealthy" + return response; + } + + private static DescribeModelResponse createModelResponse( + ModelManager modelManager, String modelName, Model model) { + DescribeModelResponse resp = new DescribeModelResponse(); + resp.setModelName(modelName); + resp.setModelUrl(model.getModelUrl()); + resp.setBatchSize(model.getBatchSize()); + resp.setMaxBatchDelay(model.getMaxBatchDelay()); + resp.setMaxWorkers(model.getMaxWorkers()); + resp.setMinWorkers(model.getMinWorkers()); + resp.setLoadedAtStartup(modelManager.getStartupModels().contains(modelName)); + Manifest manifest = model.getModelArchive().getManifest(); + resp.setModelVersion(manifest.getModel().getModelVersion()); + resp.setRuntime(manifest.getRuntime().getValue()); + + List workers = modelManager.getWorkers(model.getModelVersionName()); + for (WorkerThread worker : workers) { + String workerId = worker.getWorkerId(); + long startTime = worker.getStartTime(); + boolean isRunning = worker.isRunning(); + int gpuId = worker.getGpuId(); + long memory = worker.getMemory(); + int pid = worker.getPid(); + String gpuUsage = worker.getGpuUsage(); + resp.addWorker(workerId, startTime, isRunning, gpuId, memory, pid, gpuUsage); + } + + return resp; + } + + @SuppressWarnings("PMD") + public static String getInferenceErrorResponseMessage(String modelName, String modelVersion) { + String responseMessage = "Model \"" + modelName; + + if (modelVersion != null) { + responseMessage += "\" Version " + modelVersion; + } + + responseMessage += + "\" has no worker to serve inference request. 
Please use scale workers API to add workers."; + return responseMessage; + } +} diff --git a/frontend/server/src/main/java/org/pytorch/serve/util/ConfigManager.java b/frontend/server/src/main/java/org/pytorch/serve/util/ConfigManager.java index 839ba9609e..71686a3130 100644 --- a/frontend/server/src/main/java/org/pytorch/serve/util/ConfigManager.java +++ b/frontend/server/src/main/java/org/pytorch/serve/util/ConfigManager.java @@ -83,6 +83,9 @@ public final class ConfigManager { private static final String TS_INSTALL_PY_DEP_PER_MODEL = "install_py_dep_per_model"; private static final String TS_METRICS_FORMAT = "metrics_format"; private static final String TS_ENABLE_METRICS_API = "enable_metrics_api"; + private static final String TS_GRPC_INFERENCE_PORT = "grpc_inference_port"; + private static final String TS_GRPC_MANAGEMENT_PORT = "grpc_management_port"; + private static final String TS_ENABLE_GRPC_SSL = "enable_grpc_ssl"; private static final String TS_INITIAL_WORKER_PORT = "initial_worker_port"; // Configuration which are not documented or enabled through environment variables @@ -279,6 +282,20 @@ public Connector getListener(ConnectorType connectorType) { return Connector.parse(binding, connectorType); } + public int getGRPCPort(ConnectorType connectorType) { + String port; + if (connectorType == ConnectorType.MANAGEMENT_CONNECTOR) { + port = prop.getProperty(TS_GRPC_MANAGEMENT_PORT, "9091"); + } else { + port = prop.getProperty(TS_GRPC_INFERENCE_PORT, "9090"); + } + return Integer.parseInt(port); + } + + public boolean isGRPCSSLEnabled() { + return Boolean.parseBoolean(getProperty(TS_ENABLE_GRPC_SSL, "false")); + } + public boolean getPreferDirectBuffer() { return Boolean.parseBoolean(getProperty(TS_PREFER_DIRECT_BUFFER, "false")); } @@ -404,6 +421,14 @@ public String getCorsAllowedHeaders() { return prop.getProperty(TS_CORS_ALLOWED_HEADERS); } + public String getPrivateKeyFile() { + return prop.getProperty(TS_PRIVATE_KEY_FILE); + } + + public String getCertificateFile() { + return prop.getProperty(TS_CERTIFICATE_FILE); + } + public SslContext getSslContext() throws IOException, GeneralSecurityException { List supportedCiphers = Arrays.asList( diff --git a/frontend/server/src/main/java/org/pytorch/serve/util/GRPCUtils.java b/frontend/server/src/main/java/org/pytorch/serve/util/GRPCUtils.java new file mode 100644 index 0000000000..4b5d0e0d13 --- /dev/null +++ b/frontend/server/src/main/java/org/pytorch/serve/util/GRPCUtils.java @@ -0,0 +1,64 @@ +package org.pytorch.serve.util; + +import io.grpc.Status; + +public final class GRPCUtils { + + private GRPCUtils() {} + + public static String getRegisterParam(String param, String def) { + if ("".equals(param)) { + return def; + } + return param; + } + + public static int getRegisterParam(int param, int def) { + if (param > 0) { + return param; + } + return def; + } + + public static Status getGRPCStatusCode(int httpStatusCode) { + switch (httpStatusCode) { + case 400: + return Status.INVALID_ARGUMENT; + case 401: + return Status.UNAUTHENTICATED; + case 403: + return Status.PERMISSION_DENIED; + case 404: + return Status.NOT_FOUND; + case 409: + return Status.ABORTED; + case 413: + case 429: + return Status.RESOURCE_EXHAUSTED; + case 416: + return Status.OUT_OF_RANGE; + case 499: + return Status.CANCELLED; + case 504: + return Status.DEADLINE_EXCEEDED; + case 501: + return Status.UNIMPLEMENTED; + case 503: + return Status.UNAVAILABLE; + + default: + { + if (httpStatusCode >= 200 && httpStatusCode < 300) { + return Status.OK; + } + if 
(httpStatusCode >= 400 && httpStatusCode < 500) { + return Status.FAILED_PRECONDITION; + } + if (httpStatusCode >= 500 && httpStatusCode < 600) { + return Status.INTERNAL; + } + return Status.UNKNOWN; + } + } + } +} diff --git a/frontend/server/src/main/java/org/pytorch/serve/util/JsonUtils.java b/frontend/server/src/main/java/org/pytorch/serve/util/JsonUtils.java index a4b2c86217..1ecc880bde 100644 --- a/frontend/server/src/main/java/org/pytorch/serve/util/JsonUtils.java +++ b/frontend/server/src/main/java/org/pytorch/serve/util/JsonUtils.java @@ -10,6 +10,14 @@ public final class JsonUtils { .setDateFormat("yyyy-MM-dd'T'HH:mm:ss.SSS'Z'") .setPrettyPrinting() .create(); + + public static final Gson GSON_PRETTY_EXPOSED = + new GsonBuilder() + .excludeFieldsWithoutExposeAnnotation() + .setDateFormat("yyyy-MM-dd'T'HH:mm:ss.SSS'Z'") + .setPrettyPrinting() + .create(); + public static final Gson GSON = new GsonBuilder().create(); private JsonUtils() {} diff --git a/frontend/server/src/main/java/org/pytorch/serve/util/NettyUtils.java b/frontend/server/src/main/java/org/pytorch/serve/util/NettyUtils.java index 69805e759a..c2bb18af40 100644 --- a/frontend/server/src/main/java/org/pytorch/serve/util/NettyUtils.java +++ b/frontend/server/src/main/java/org/pytorch/serve/util/NettyUtils.java @@ -26,6 +26,7 @@ import java.util.List; import org.pytorch.serve.http.ErrorResponse; import org.pytorch.serve.http.Session; +import org.pytorch.serve.http.StatusResponse; import org.pytorch.serve.metrics.Dimension; import org.pytorch.serve.metrics.Metric; import org.pytorch.serve.util.messages.InputParameter; @@ -99,6 +100,13 @@ public static void sendJsonResponse( sendJsonResponse(ctx, JsonUtils.GSON_PRETTY.toJson(json), status); } + public static void sendJsonResponse(ChannelHandlerContext ctx, StatusResponse statusResponse) { + sendJsonResponse( + ctx, + JsonUtils.GSON_PRETTY_EXPOSED.toJson(statusResponse), + HttpResponseStatus.valueOf(statusResponse.getHttpResponseCode())); + } + public static void sendJsonResponse(ChannelHandlerContext ctx, String json) { sendJsonResponse(ctx, json, HttpResponseStatus.OK); } diff --git a/frontend/server/src/main/java/org/pytorch/serve/wlm/BatchAggregator.java b/frontend/server/src/main/java/org/pytorch/serve/wlm/BatchAggregator.java index ceee681945..8c95383d96 100644 --- a/frontend/server/src/main/java/org/pytorch/serve/wlm/BatchAggregator.java +++ b/frontend/server/src/main/java/org/pytorch/serve/wlm/BatchAggregator.java @@ -1,8 +1,8 @@ package org.pytorch.serve.wlm; -import io.netty.handler.codec.http.HttpResponseStatus; import java.util.LinkedHashMap; import java.util.Map; +import org.pytorch.serve.job.Job; import org.pytorch.serve.util.messages.BaseModelRequest; import org.pytorch.serve.util.messages.ModelInferenceRequest; import org.pytorch.serve.util.messages.ModelLoadModelRequest; @@ -83,7 +83,7 @@ public void sendResponse(ModelWorkerResponse message) { if (j == null) { throw new IllegalStateException("Unexpected job: " + reqId); } - j.sendError(HttpResponseStatus.valueOf(message.getCode()), message.getMessage()); + j.sendError(message.getCode(), message.getMessage()); } if (!jobs.isEmpty()) { throw new IllegalStateException("Not all jobs get response."); @@ -91,7 +91,7 @@ public void sendResponse(ModelWorkerResponse message) { } } - public void sendError(BaseModelRequest message, String error, HttpResponseStatus status) { + public void sendError(BaseModelRequest message, String error, int status) { if (message instanceof ModelLoadModelRequest) { logger.warn("Load 
model failed: {}, error: {}", message.getModelName(), error); return; diff --git a/frontend/server/src/main/java/org/pytorch/serve/wlm/Model.java b/frontend/server/src/main/java/org/pytorch/serve/wlm/Model.java index af3c283a2e..1e365e435b 100644 --- a/frontend/server/src/main/java/org/pytorch/serve/wlm/Model.java +++ b/frontend/server/src/main/java/org/pytorch/serve/wlm/Model.java @@ -12,6 +12,7 @@ import java.util.concurrent.locks.ReentrantLock; import org.apache.commons.io.FilenameUtils; import org.pytorch.serve.archive.ModelArchive; +import org.pytorch.serve.job.Job; import org.pytorch.serve.util.ConfigManager; import org.slf4j.Logger; import org.slf4j.LoggerFactory; diff --git a/frontend/server/src/main/java/org/pytorch/serve/wlm/ModelManager.java b/frontend/server/src/main/java/org/pytorch/serve/wlm/ModelManager.java index f02ce02ecb..344b258292 100644 --- a/frontend/server/src/main/java/org/pytorch/serve/wlm/ModelManager.java +++ b/frontend/server/src/main/java/org/pytorch/serve/wlm/ModelManager.java @@ -1,10 +1,8 @@ package org.pytorch.serve.wlm; import com.google.gson.JsonObject; -import io.netty.channel.ChannelHandlerContext; -import io.netty.handler.codec.http.HttpResponseStatus; import java.io.IOException; -import java.nio.file.FileAlreadyExistsException; +import java.net.HttpURLConnection; import java.nio.file.Path; import java.nio.file.Paths; import java.util.HashSet; @@ -24,9 +22,8 @@ import org.pytorch.serve.archive.ModelVersionNotFoundException; import org.pytorch.serve.http.ConflictStatusException; import org.pytorch.serve.http.InvalidModelVersionException; -import org.pytorch.serve.http.StatusResponse; +import org.pytorch.serve.job.Job; import org.pytorch.serve.util.ConfigManager; -import org.pytorch.serve.util.NettyUtils; import org.pytorch.serve.util.messages.EnvironmentUtils; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -37,11 +34,11 @@ public final class ModelManager { private static ModelManager modelManager; - private ConfigManager configManager; - private WorkLoadManager wlm; - private ConcurrentHashMap modelsNameMap; - private HashSet startupModels; - private ScheduledExecutorService scheduler; + private final ConfigManager configManager; + private final WorkLoadManager wlm; + private final ConcurrentHashMap modelsNameMap; + private final HashSet startupModels; + private final ScheduledExecutorService scheduler; private ModelManager(ConfigManager configManager, WorkLoadManager wlm) { this.configManager = configManager; @@ -133,7 +130,7 @@ private ModelArchive createModelArchive( String handler, Manifest.RuntimeType runtime, String defaultModelName) - throws FileAlreadyExistsException, ModelException, IOException { + throws ModelException, IOException { ModelArchive archive = ModelArchive.downloadModel( configManager.getAllowedUrls(), configManager.getModelStore(), url); @@ -223,36 +220,34 @@ private void createVersionedModel(Model model, String versionId) modelsNameMap.putIfAbsent(model.getModelName(), modelVersionRef); } - public HttpResponseStatus unregisterModel(String modelName, String versionId) { + public int unregisterModel(String modelName, String versionId) { return unregisterModel(modelName, versionId, false); } - public HttpResponseStatus unregisterModel( - String modelName, String versionId, boolean isCleanUp) { + public int unregisterModel(String modelName, String versionId, boolean isCleanUp) { ModelVersionedRefs vmodel = modelsNameMap.get(modelName); if (vmodel == null) { logger.warn("Model not found: " + modelName); - return 
HttpResponseStatus.NOT_FOUND; + return HttpURLConnection.HTTP_NOT_FOUND; } if (versionId == null) { versionId = vmodel.getDefaultVersion(); } - Model model = null; - HttpResponseStatus httpResponseStatus = HttpResponseStatus.OK; + Model model; + int httpResponseStatus; try { model = vmodel.removeVersionModel(versionId); model.setMinWorkers(0); model.setMaxWorkers(0); - CompletableFuture futureStatus = - wlm.modelChanged(model, false, isCleanUp); + CompletableFuture futureStatus = wlm.modelChanged(model, false, isCleanUp); httpResponseStatus = futureStatus.get(); // Only continue cleaning if resource cleaning succeeded - if (httpResponseStatus == HttpResponseStatus.OK) { + if (httpResponseStatus == HttpURLConnection.HTTP_OK) { model.getModelArchive().clean(); startupModels.remove(modelName); logger.info("Model {} unregistered.", modelName); @@ -272,37 +267,29 @@ public HttpResponseStatus unregisterModel( } } catch (ModelVersionNotFoundException e) { logger.warn("Model {} version {} not found.", modelName, versionId); - httpResponseStatus = HttpResponseStatus.BAD_REQUEST; + httpResponseStatus = HttpURLConnection.HTTP_BAD_REQUEST; } catch (InvalidModelVersionException e) { logger.warn("Cannot remove default version {} for model {}", versionId, modelName); - httpResponseStatus = HttpResponseStatus.FORBIDDEN; + httpResponseStatus = HttpURLConnection.HTTP_FORBIDDEN; } catch (ExecutionException | InterruptedException e1) { logger.warn("Process was interrupted while cleaning resources."); - httpResponseStatus = HttpResponseStatus.INTERNAL_SERVER_ERROR; + httpResponseStatus = HttpURLConnection.HTTP_INTERNAL_ERROR; } return httpResponseStatus; } - public HttpResponseStatus setDefaultVersion(String modelName, String newModelVersion) - throws ModelVersionNotFoundException { - HttpResponseStatus httpResponseStatus = HttpResponseStatus.OK; + public void setDefaultVersion(String modelName, String newModelVersion) + throws ModelNotFoundException, ModelVersionNotFoundException { ModelVersionedRefs vmodel = modelsNameMap.get(modelName); if (vmodel == null) { logger.warn("Model not found: " + modelName); - return HttpResponseStatus.NOT_FOUND; - } - try { - vmodel.setDefaultVersion(newModelVersion); - } catch (ModelVersionNotFoundException e) { - logger.warn("Model version {} does not exist for model {}", newModelVersion, modelName); - httpResponseStatus = HttpResponseStatus.FORBIDDEN; + throw new ModelNotFoundException("Model not found: " + modelName); } - - return httpResponseStatus; + vmodel.setDefaultVersion(newModelVersion); } - private CompletableFuture updateModel( + private CompletableFuture updateModel( String modelName, String versionId, boolean isStartup) throws ModelVersionNotFoundException { Model model = getVersionModel(modelName, versionId); @@ -315,7 +302,7 @@ private CompletableFuture updateModel( false); } - public CompletableFuture updateModel( + public CompletableFuture updateModel( String modelName, String versionId, int minWorkers, @@ -345,7 +332,7 @@ private Model getVersionModel(String modelName, String versionId) { return vmodel.getVersionModel(versionId); } - public CompletableFuture updateModel( + public CompletableFuture updateModel( String modelName, String versionId, int minWorkers, int maxWorkers) throws ModelVersionNotFoundException { return updateModel(modelName, versionId, minWorkers, maxWorkers, false, false); @@ -387,59 +374,6 @@ public boolean addJob(Job job) throws ModelNotFoundException, ModelVersionNotFou return model.addJob(job); } - public void workerStatus(final 
ChannelHandlerContext ctx) { - Runnable r = - () -> { - String response = "Healthy"; - int numWorking = 0; - int numScaled = 0; - for (Map.Entry m : modelsNameMap.entrySet()) { - numScaled += m.getValue().getDefaultModel().getMinWorkers(); - numWorking += - wlm.getNumRunningWorkers( - m.getValue().getDefaultModel().getModelVersionName()); - } - - if ((numWorking > 0) && (numWorking < numScaled)) { - response = "Partial Healthy"; - } else if ((numWorking == 0) && (numScaled > 0)) { - response = "Unhealthy"; - } - - // TODO: Check if its OK to send other 2xx errors to ALB for "Partial Healthy" - // and "Unhealthy" - NettyUtils.sendJsonResponse( - ctx, new StatusResponse(response), HttpResponseStatus.OK); - }; - wlm.scheduleAsync(r); - } - - public void modelWorkerStatus(final String modelName, final ChannelHandlerContext ctx) { - Runnable r = - () -> { - String response = "Healthy"; - int numWorking = 0; - int numScaled = 0; - ModelVersionedRefs vmodel = modelsNameMap.get(modelName); - for (Map.Entry m : vmodel.getAllVersions()) { - numScaled += m.getValue().getMinWorkers(); - numWorking += wlm.getNumRunningWorkers(m.getValue().getModelVersionName()); - } - - if ((numWorking > 0) && (numWorking < numScaled)) { - response = "Partial Healthy"; - } else if ((numWorking == 0) && (numScaled > 0)) { - response = "Unhealthy"; - } - - // TODO: Check if its OK to send other 2xx errors to ALB for "Partial Healthy" - // and "Unhealthy" - NettyUtils.sendJsonResponse( - ctx, new StatusResponse(response), HttpResponseStatus.OK); - }; - wlm.scheduleAsync(r); - } - public boolean scaleRequestStatus(String modelName, String versionId) { Model model = modelsNameMap.get(modelName).getVersionModel(versionId); int numWorkers = 0; @@ -481,4 +415,12 @@ public Set> getAllModelVersions(String modelName) } return vmodel.getAllVersions(); } + + public Set> getAllModels() { + return modelsNameMap.entrySet(); + } + + public int getNumRunningWorkers(ModelVersionName modelVersionName) { + return wlm.getNumRunningWorkers(modelVersionName); + } } diff --git a/frontend/server/src/main/java/org/pytorch/serve/wlm/ModelVersionedRefs.java b/frontend/server/src/main/java/org/pytorch/serve/wlm/ModelVersionedRefs.java index b45704a69a..fd57c09bb4 100644 --- a/frontend/server/src/main/java/org/pytorch/serve/wlm/ModelVersionedRefs.java +++ b/frontend/server/src/main/java/org/pytorch/serve/wlm/ModelVersionedRefs.java @@ -67,7 +67,11 @@ public String getDefaultVersion() { public void setDefaultVersion(String versionId) throws ModelVersionNotFoundException { Model model = this.modelsVersionMap.get(versionId); if (model == null) { - throw new ModelVersionNotFoundException("Can't set default to: " + versionId); + throw new ModelVersionNotFoundException( + "Model version " + + versionId + + " does not exist for model " + + this.getDefaultModel().getModelName()); } logger.debug("Setting default version to {} for model {}", versionId, model.getModelName()); diff --git a/frontend/server/src/main/java/org/pytorch/serve/wlm/WorkLoadManager.java b/frontend/server/src/main/java/org/pytorch/serve/wlm/WorkLoadManager.java index 52bf5746c8..bfce51b447 100644 --- a/frontend/server/src/main/java/org/pytorch/serve/wlm/WorkLoadManager.java +++ b/frontend/server/src/main/java/org/pytorch/serve/wlm/WorkLoadManager.java @@ -1,8 +1,8 @@ package org.pytorch.serve.wlm; import io.netty.channel.EventLoopGroup; -import io.netty.handler.codec.http.HttpResponseStatus; import java.io.IOException; +import java.net.HttpURLConnection; import java.util.ArrayList; 
import java.util.Collections; import java.util.HashMap; @@ -85,18 +85,18 @@ public int getNumRunningWorkers(ModelVersionName modelVersionName) { return numWorking; } - public CompletableFuture modelChanged( + public CompletableFuture modelChanged( Model model, boolean isStartup, boolean isCleanUp) { synchronized (model.getModelVersionName()) { boolean isSnapshotSaved = false; - CompletableFuture future = new CompletableFuture<>(); + CompletableFuture future = new CompletableFuture<>(); int minWorker = model.getMinWorkers(); int maxWorker = model.getMaxWorkers(); List threads; if (minWorker == 0) { threads = workers.remove(model.getModelVersionName()); if (threads == null) { - future.complete(HttpResponseStatus.OK); + future.complete(HttpURLConnection.HTTP_OK); if (!isStartup && !isCleanUp) { SnapshotManager.getInstance().saveSnapshot(); } @@ -133,13 +133,13 @@ public CompletableFuture modelChanged( } catch (InterruptedException | IOException e) { logger.warn( "WorkerThread interrupted during waitFor, possible async resource cleanup."); - future.complete(HttpResponseStatus.INTERNAL_SERVER_ERROR); + future.complete(HttpURLConnection.HTTP_INTERNAL_ERROR); return future; } if (!workerDestroyed) { logger.warn( "WorkerThread timed out while cleaning, please resend request."); - future.complete(HttpResponseStatus.REQUEST_TIMEOUT); + future.complete(HttpURLConnection.HTTP_CLIENT_TIMEOUT); return future; } } @@ -148,7 +148,7 @@ public CompletableFuture modelChanged( SnapshotManager.getInstance().saveSnapshot(); isSnapshotSaved = true; } - future.complete(HttpResponseStatus.OK); + future.complete(HttpURLConnection.HTTP_OK); } if (!isStartup && !isSnapshotSaved && !isCleanUp) { SnapshotManager.getInstance().saveSnapshot(); @@ -158,10 +158,7 @@ public CompletableFuture modelChanged( } private void addThreads( - List threads, - Model model, - int count, - CompletableFuture future) { + List threads, Model model, int count, CompletableFuture future) { WorkerStateListener listener = new WorkerStateListener(future, count); int maxGpu = configManager.getNumberOfGpu(); for (int i = 0; i < count; ++i) { diff --git a/frontend/server/src/main/java/org/pytorch/serve/wlm/WorkerStateListener.java b/frontend/server/src/main/java/org/pytorch/serve/wlm/WorkerStateListener.java index 2ccc8c1bce..c1bd8b55ce 100644 --- a/frontend/server/src/main/java/org/pytorch/serve/wlm/WorkerStateListener.java +++ b/frontend/server/src/main/java/org/pytorch/serve/wlm/WorkerStateListener.java @@ -1,20 +1,19 @@ package org.pytorch.serve.wlm; -import io.netty.handler.codec.http.HttpResponseStatus; import java.util.concurrent.CompletableFuture; import java.util.concurrent.atomic.AtomicInteger; public class WorkerStateListener { - private CompletableFuture future; + private CompletableFuture future; private AtomicInteger count; - public WorkerStateListener(CompletableFuture future, int count) { + public WorkerStateListener(CompletableFuture future, int count) { this.future = future; this.count = new AtomicInteger(count); } - public void notifyChangeState(String modelName, WorkerState state, HttpResponseStatus status) { + public void notifyChangeState(String modelName, WorkerState state, Integer status) { // Update success and fail counts if (state == WorkerState.WORKER_MODEL_LOADED) { if (count.decrementAndGet() == 0) { diff --git a/frontend/server/src/main/java/org/pytorch/serve/wlm/WorkerThread.java b/frontend/server/src/main/java/org/pytorch/serve/wlm/WorkerThread.java index cdb8773b70..7bdd29f2ea 100644 --- 
a/frontend/server/src/main/java/org/pytorch/serve/wlm/WorkerThread.java +++ b/frontend/server/src/main/java/org/pytorch/serve/wlm/WorkerThread.java @@ -9,11 +9,11 @@ import io.netty.channel.ChannelPipeline; import io.netty.channel.EventLoopGroup; import io.netty.channel.SimpleChannelInboundHandler; -import io.netty.handler.codec.http.HttpResponseStatus; import java.io.BufferedReader; import java.io.IOException; import java.io.InputStream; import java.io.InputStreamReader; +import java.net.HttpURLConnection; import java.net.SocketAddress; import java.nio.charset.StandardCharsets; import java.util.UUID; @@ -22,11 +22,12 @@ import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicReference; +import org.pytorch.serve.job.Job; +import org.pytorch.serve.job.RestJob; import org.pytorch.serve.metrics.Dimension; import org.pytorch.serve.metrics.Metric; import org.pytorch.serve.util.ConfigManager; import org.pytorch.serve.util.Connector; -import org.pytorch.serve.util.NettyUtils; import org.pytorch.serve.util.codec.ModelRequestEncoder; import org.pytorch.serve.util.codec.ModelResponseDecoder; import org.pytorch.serve.util.messages.BaseModelRequest; @@ -172,7 +173,7 @@ public void run() { thread.setName(getWorkerName()); currentThread.set(thread); BaseModelRequest req = null; - HttpResponseStatus status = HttpResponseStatus.INTERNAL_SERVER_ERROR; + int status = HttpURLConnection.HTTP_INTERNAL_ERROR; try { connect(); @@ -203,13 +204,11 @@ public void run() { break; case LOAD: if (reply.getCode() == 200) { - setState(WorkerState.WORKER_MODEL_LOADED, HttpResponseStatus.OK); + setState(WorkerState.WORKER_MODEL_LOADED, HttpURLConnection.HTTP_OK); backoffIdx = 0; } else { - setState( - WorkerState.WORKER_ERROR, - HttpResponseStatus.valueOf(reply.getCode())); - status = HttpResponseStatus.valueOf(reply.getCode()); + setState(WorkerState.WORKER_ERROR, reply.getCode()); + status = reply.getCode(); } break; case UNLOAD: @@ -241,7 +240,7 @@ public void run() { logger.error("Backend worker error", e); } catch (OutOfMemoryError oom) { logger.error("Out of memory error when creating workers", oom); - status = HttpResponseStatus.INSUFFICIENT_STORAGE; + status = HttpURLConnection.HTTP_ENTITY_TOO_LARGE; } catch (Throwable t) { logger.warn("Backend worker thread exception.", t); } finally { @@ -253,7 +252,7 @@ public void run() { Integer exitValue = lifeCycle.getExitValue(); if (exitValue != null && exitValue == 137) { - status = HttpResponseStatus.INSUFFICIENT_STORAGE; + status = HttpURLConnection.HTTP_ENTITY_TOO_LARGE; } if (req != null) { @@ -284,7 +283,7 @@ private void connect() throws WorkerInitializationException, InterruptedExceptio String modelName = model.getModelName(); String modelVersion = model.getVersion(); - setState(WorkerState.WORKER_STARTED, HttpResponseStatus.OK); + setState(WorkerState.WORKER_STARTED, HttpURLConnection.HTTP_OK); final CountDownLatch latch = new CountDownLatch(1); final int responseBufferSize = configManager.getMaxResponseSize(); @@ -337,7 +336,7 @@ public void initChannel(Channel ch) { } Job job = - new Job( + new RestJob( null, modelName, modelVersion, @@ -379,7 +378,7 @@ public int getPid() { public void shutdown() { running.set(false); - setState(WorkerState.WORKER_SCALED_DOWN, HttpResponseStatus.OK); + setState(WorkerState.WORKER_SCALED_DOWN, HttpURLConnection.HTTP_OK); if (backendChannel != null) { backendChannel.close(); } @@ -388,7 +387,7 @@ public void shutdown() { if (thread != null) { 
thread.interrupt(); aggregator.sendError( - null, "Worker scaled down.", HttpResponseStatus.INTERNAL_SERVER_ERROR); + null, "Worker scaled down.", HttpURLConnection.HTTP_INTERNAL_ERROR); model.removeJobQueue(workerId); } @@ -399,7 +398,7 @@ private String getWorkerName() { return "W-" + port + '-' + modelName; } - public void setState(WorkerState newState, HttpResponseStatus status) { + public void setState(WorkerState newState, int status) { listener.notifyChangeState( model.getModelVersionName().getVersionedModelName(), newState, status); logger.debug("{} State change {} -> {}", getWorkerName(), state, newState); @@ -447,7 +446,12 @@ public void channelRead0(ChannelHandlerContext ctx, ModelWorkerResponse msg) { public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) { logger.error("Unknown exception", cause); if (cause instanceof OutOfMemoryError) { - NettyUtils.sendError(ctx, HttpResponseStatus.INSUFFICIENT_STORAGE, cause); + ModelWorkerResponse msg = new ModelWorkerResponse(); + msg.setCode(HttpURLConnection.HTTP_ENTITY_TOO_LARGE); + msg.setMessage(cause.getMessage()); + if (!replies.offer(msg)) { + throw new IllegalStateException("Reply queue is full."); + } } ctx.close(); } diff --git a/frontend/server/src/main/resources/proto/inference.proto b/frontend/server/src/main/resources/proto/inference.proto new file mode 100644 index 0000000000..1d3dad5589 --- /dev/null +++ b/frontend/server/src/main/resources/proto/inference.proto @@ -0,0 +1,35 @@ +syntax = "proto3"; + +package org.pytorch.serve.grpc.inference; + +import "google/protobuf/empty.proto"; + +option java_multiple_files = true; + +message PredictionsRequest { + // Name of model. + string model_name = 1; //required + + // Version of model to run prediction on. + string model_version = 2; //optional + + // Input data for model prediction. + map<string, bytes> input = 3; //required +} + +message PredictionResponse { + // Response content for prediction. + bytes prediction = 1; +} + +message TorchServeHealthResponse { + // TorchServe health + string health = 1; +} + +service InferenceAPIsService { + rpc Ping(google.protobuf.Empty) returns (TorchServeHealthResponse) {} + + // Predictions entry point to get inference using default model version. + rpc Predictions(PredictionsRequest) returns (PredictionResponse) {} +} \ No newline at end of file diff --git a/frontend/server/src/main/resources/proto/management.proto b/frontend/server/src/main/resources/proto/management.proto new file mode 100644 index 0000000000..54d8aa645d --- /dev/null +++ b/frontend/server/src/main/resources/proto/management.proto @@ -0,0 +1,114 @@ +syntax = "proto3"; + +package org.pytorch.serve.grpc.management; + +option java_multiple_files = true; + +message ManagementResponse { + // Response string of different management API calls. + string msg = 1; +} + +message DescribeModelRequest { + // Name of model to describe. + string model_name = 1; //required + // Version of model to describe. + string model_version = 2; //optional +} + +message ListModelsRequest { + // Use this parameter to specify the maximum number of items to return. When this value is present, TorchServe does not return more than the specified number of items, but it might return fewer. This value is optional. If you include a value, it must be between 1 and 1000, inclusive. If you do not include a value, it defaults to 100. + int32 limit = 1; //optional + + // The token to retrieve the next set of results. 
TorchServe provides the token when the response from a previous call has more results than the maximum page size. + int32 next_page_token = 2; //optional +} + +message RegisterModelRequest { + // Inference batch size, default: 1. + int32 batch_size = 1; //optional + + // Inference handler entry-point. This value will override handler in MANIFEST.json if present. + string handler = 2; //optional + + // Number of initial workers, default: 0. + int32 initial_workers = 3; //optional + + // Maximum delay for batch aggregation, default: 100. + int32 max_batch_delay = 4; //optional + + // Name of model. This value will override modelName in MANIFEST.json if present. + string model_name = 5; //optional + + // Maximum time, in seconds, that TorchServe waits for a response from the model inference code, default: 120. + int32 response_timeout = 6; //optional + + // Runtime for the model custom service code. This value will override runtime in MANIFEST.json if present. + string runtime = 7; //optional + + // Decides whether worker creation is synchronous or not, default: false. + bool synchronous = 8; //optional + + // Model archive download URL; supports local file or HTTP(s) protocol. + string url = 9; //required +} + +message ScaleWorkerRequest { + + // Name of model to scale workers. + string model_name = 1; //required + + // Model version. + string model_version = 2; //optional + + // Maximum number of worker processes. + int32 max_worker = 3; //optional + + // Minimum number of worker processes. + int32 min_worker = 4; //optional + + // Number of GPU worker processes to create. + int32 number_gpu = 5; //optional + + // Decides whether the call is synchronous or not, default: false. + bool synchronous = 6; //optional + + // Wait up to the specified time, if necessary, for a worker to complete all pending requests. Use 0 to terminate the backend worker process immediately. Use -1 to wait infinitely. + int32 timeout = 7; //optional +} + +message SetDefaultRequest { + // Name of model whose default version needs to be updated. + string model_name = 1; //required + + // Version of model to be set as the default version for the model. + string model_version = 2; //required +} + +message UnregisterModelRequest { + // Name of model to unregister. + string model_name = 1; //required + + // Version of model to unregister. + string model_version = 2; //optional +} + +service ManagementAPIsService { + // Provides detailed information about the default version of a model. + rpc DescribeModel(DescribeModelRequest) returns (ManagementResponse) {} + + // List registered models in TorchServe. + rpc ListModels(ListModelsRequest) returns (ManagementResponse) {} + + // Register a new model in TorchServe. + rpc RegisterModel(RegisterModelRequest) returns (ManagementResponse) {} + + // Configure the number of workers for the default version of a model. This is an asynchronous call by default. Callers need to call DescribeModel to check whether the model workers have been changed. + rpc ScaleWorker(ScaleWorkerRequest) returns (ManagementResponse) {} + + // Set the default version of a model. + rpc SetDefault(SetDefaultRequest) returns (ManagementResponse) {} + + // Unregister the default version of a model from TorchServe if it is the only version available. This is an asynchronous call by default. 
Caller can call listModels to confirm model is unregistered + rpc UnregisterModel(UnregisterModelRequest) returns (ManagementResponse) {} +} \ No newline at end of file diff --git a/frontend/server/src/test/java/org/pytorch/serve/ModelServerTest.java b/frontend/server/src/test/java/org/pytorch/serve/ModelServerTest.java index ce8c78a005..0e199a4c72 100644 --- a/frontend/server/src/test/java/org/pytorch/serve/ModelServerTest.java +++ b/frontend/server/src/test/java/org/pytorch/serve/ModelServerTest.java @@ -79,7 +79,7 @@ public void beforeSuite() InternalLoggerFactory.setDefaultFactory(Slf4JLoggerFactory.INSTANCE); server = new ModelServer(configManager); - server.start(); + server.startRESTserver(); String version = configManager.getProperty("version", null); try (InputStream is = new FileInputStream("src/test/resources/inference_open_api.json")) { listInferenceApisResult = diff --git a/frontend/server/src/test/java/org/pytorch/serve/SnapshotTest.java b/frontend/server/src/test/java/org/pytorch/serve/SnapshotTest.java index 3a9ad348da..ea28e00cd3 100644 --- a/frontend/server/src/test/java/org/pytorch/serve/SnapshotTest.java +++ b/frontend/server/src/test/java/org/pytorch/serve/SnapshotTest.java @@ -63,7 +63,7 @@ public void beforeSuite() InternalLoggerFactory.setDefaultFactory(Slf4JLoggerFactory.INSTANCE); configManager.setIniitialWorkerPort(9500); server = new ModelServer(configManager); - server.start(); + server.startRESTserver(); } @AfterClass @@ -264,7 +264,7 @@ public void testStartTorchServeWithLastSnapshot() ConfigManager.init(new ConfigManager.Arguments()); configManager = ConfigManager.getInstance(); server = new ModelServer(configManager); - server.start(); + server.startRESTserver(); Channel channel = null; for (int i = 0; i < 5; ++i) { channel = TestUtils.connect(ConnectorType.INFERENCE_CONNECTOR, configManager); @@ -289,7 +289,7 @@ public void testRestartTorchServeWithSnapshotAsConfig() ConfigManager.init(new ConfigManager.Arguments()); configManager = ConfigManager.getInstance(); server = new ModelServer(configManager); - server.start(); + server.startRESTserver(); Channel channel = null; for (int i = 0; i < 5; ++i) { channel = TestUtils.connect(ConnectorType.INFERENCE_CONNECTOR, configManager); diff --git a/frontend/tools/conf/checkstyle.xml b/frontend/tools/conf/checkstyle.xml index 6ca74b177f..5a6aae5a0b 100644 --- a/frontend/tools/conf/checkstyle.xml +++ b/frontend/tools/conf/checkstyle.xml @@ -416,7 +416,9 @@ - + + +