diff --git a/docs/token_authorization_api.md b/docs/token_authorization_api.md new file mode 100644 index 0000000000..10e346e53b --- /dev/null +++ b/docs/token_authorization_api.md @@ -0,0 +1,50 @@ +# TorchServe token authorization API + +## Configuration +1. Enable token authorization by adding the provided plugin at start using the `--plugins-path` command. +2. Torchserve will enable token authorization if the plugin is provided. In the current working directory a file `key_file.json` will be generated. + 1. Example key file: + +```python + { + "management": { + "key": "B-E5KSRM", + "expiration time": "2024-02-16T21:12:24.801167Z" + }, + "inference": { + "key": "gNRuA7dS", + "expiration time": "2024-02-16T21:12:24.801148Z" + }, + "API": { + "key": "yv9uQajP" + } +} +``` + +3. There are 3 keys and each have a different use. + 1. Management key: Used for management APIs. Example: + `curl http://localhost:8081/models/densenet161 -H "Authorization: Bearer I_J_ItMb"` + 2. Inference key: Used for inference APIs. Example: + `curl http://127.0.0.1:8080/predictions/densenet161 -T examples/image_classifier/kitten.jpg -H "Authorization: Bearer FINhR1fj"` + 3. API key: Used for the token authorization API. Check section 4 for API use. +4. The plugin also includes an API in order to generate a new key to replace either the management or inference key. + 1. Management Example: + `curl localhost:8081/token?type=management -H "Authorization: Bearer m4M-5IBY"` will replace the current management key in the key_file with a new one and will update the expiration time. + 2. Inference example: + `curl localhost:8081/token?type=inference -H "Authorization: Bearer m4M-5IBY"` + + Users will have to use either one of the APIs above. + +5. When users shut down the server the key_file will be deleted. + + +## Customization +Torchserve offers various ways to customize the token authorization to allow owners to reach the desired result. +1. Time to expiration is set to default at 60 minutes but can be changed in the config.properties by adding `token_expiration_min`. Ex:`token_expiration_min=30` +2. The token authorization code is consolidated in the plugin and thus can be changed without impacting the frontend or end result. The only thing the user cannot change is: + 1. The urlPattern for the plugin must be 'token' and the class name must not change + 2. The `generateKeyFile`, `checkTokenAuthorization`, and `setTime` functions return type and signature must not change. However, the code in the functions can be modified depending on user necessity. + +## Notes +1. DO NOT MODIFY THE KEY FILE. Modifying the key file might impact reading and writing to the file thus preventing new keys from properly being displayed in the file. +2. 3 tokens allow the owner with the most flexibility in use and enables them to adapt the tokens to their use. Owners of the server can provide users with the inference token if users should not mess with models. The owner can also provide owners with the management key if owners want users to add and remove models. diff --git a/frontend/archive/src/main/java/org/pytorch/serve/archive/model/InvalidKeyException.java b/frontend/archive/src/main/java/org/pytorch/serve/archive/model/InvalidKeyException.java new file mode 100644 index 0000000000..1045264e3e --- /dev/null +++ b/frontend/archive/src/main/java/org/pytorch/serve/archive/model/InvalidKeyException.java @@ -0,0 +1,32 @@ +package org.pytorch.serve.archive.model; + +public class InvalidKeyException extends ModelException { + + private static final long serialVersionUID = 1L; + + /** + * Constructs an {@code InvalidKeyException} with the specified detail message. + * + * @param message The detail message (which is saved for later retrieval by the {@link + * #getMessage()} method) + */ + public InvalidKeyException(String message) { + super(message); + } + + /** + * Constructs an {@code InvalidKeyException} with the specified detail message and cause. + * + *
Note that the detail message associated with {@code cause} is not automatically + * incorporated into this exception's detail message. + * + * @param message The detail message (which is saved for later retrieval by the {@link + * #getMessage()} method) + * @param cause The cause (which is saved for later retrieval by the {@link #getCause()} + * method). (A null value is permitted, and indicates that the cause is nonexistent or + * unknown.) + */ + public InvalidKeyException(String message, Throwable cause) { + super(message, cause); + } +} diff --git a/frontend/server/src/main/java/org/pytorch/serve/ServerInitializer.java b/frontend/server/src/main/java/org/pytorch/serve/ServerInitializer.java index 38b81fdb4c..b362ceb958 100644 --- a/frontend/server/src/main/java/org/pytorch/serve/ServerInitializer.java +++ b/frontend/server/src/main/java/org/pytorch/serve/ServerInitializer.java @@ -10,6 +10,7 @@ import org.pytorch.serve.http.HttpRequestHandler; import org.pytorch.serve.http.HttpRequestHandlerChain; import org.pytorch.serve.http.InvalidRequestHandler; +import org.pytorch.serve.http.TokenAuthorizationHandler; import org.pytorch.serve.http.api.rest.ApiDescriptionRequestHandler; import org.pytorch.serve.http.api.rest.InferenceRequestHandler; import org.pytorch.serve.http.api.rest.ManagementRequestHandler; @@ -18,6 +19,7 @@ import org.pytorch.serve.servingsdk.impl.PluginsManager; import org.pytorch.serve.util.ConfigManager; import org.pytorch.serve.util.ConnectorType; +import org.pytorch.serve.util.TokenType; import org.pytorch.serve.workflow.api.http.WorkflowInferenceRequestHandler; import org.pytorch.serve.workflow.api.http.WorkflowMgmtRequestHandler; import org.slf4j.Logger; @@ -63,6 +65,9 @@ public void initChannel(Channel ch) { HttpRequestHandlerChain httpRequestHandlerChain = apiDescriptionRequestHandler; if (ConnectorType.ALL.equals(connectorType) || ConnectorType.INFERENCE_CONNECTOR.equals(connectorType)) { + httpRequestHandlerChain = + httpRequestHandlerChain.setNextHandler( + new TokenAuthorizationHandler(TokenType.INFERENCE)); httpRequestHandlerChain = httpRequestHandlerChain.setNextHandler( new InferenceRequestHandler( @@ -80,6 +85,9 @@ public void initChannel(Channel ch) { } if (ConnectorType.ALL.equals(connectorType) || ConnectorType.MANAGEMENT_CONNECTOR.equals(connectorType)) { + httpRequestHandlerChain = + httpRequestHandlerChain.setNextHandler( + new TokenAuthorizationHandler(TokenType.MANAGEMENT)); httpRequestHandlerChain = httpRequestHandlerChain.setNextHandler( new ManagementRequestHandler( diff --git a/frontend/server/src/main/java/org/pytorch/serve/http/TokenAuthorizationHandler.java b/frontend/server/src/main/java/org/pytorch/serve/http/TokenAuthorizationHandler.java new file mode 100644 index 0000000000..cab59f11b5 --- /dev/null +++ b/frontend/server/src/main/java/org/pytorch/serve/http/TokenAuthorizationHandler.java @@ -0,0 +1,103 @@ +package org.pytorch.serve.http; + +import io.netty.channel.ChannelHandlerContext; +import io.netty.handler.codec.http.FullHttpRequest; +import io.netty.handler.codec.http.QueryStringDecoder; +import java.lang.reflect.*; +import org.pytorch.serve.archive.DownloadArchiveException; +import org.pytorch.serve.archive.model.InvalidKeyException; +import org.pytorch.serve.archive.model.ModelException; +import org.pytorch.serve.archive.workflow.WorkflowException; +import org.pytorch.serve.util.ConfigManager; +import org.pytorch.serve.util.TokenType; +import org.pytorch.serve.wlm.WorkerInitializationException; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * A class handling token check for all inbound HTTP requests + * + *
This class //
+ */
+public class TokenAuthorizationHandler extends HttpRequestHandlerChain {
+
+ private static final Logger logger = LoggerFactory.getLogger(TokenAuthorizationHandler.class);
+ private static TokenType tokenType;
+ private static Boolean tokenEnabled = false;
+ private static Class> tokenClass;
+ private static Object tokenObject;
+ private static Double timeToExpirationMinutes = 60.0;
+
+ /** Creates a new {@code InferenceRequestHandler} instance. */
+ public TokenAuthorizationHandler(TokenType type) {
+ tokenType = type;
+ }
+
+ @Override
+ public void handleRequest(
+ ChannelHandlerContext ctx,
+ FullHttpRequest req,
+ QueryStringDecoder decoder,
+ String[] segments)
+ throws ModelException, DownloadArchiveException, WorkflowException,
+ WorkerInitializationException {
+ if (tokenEnabled) {
+ if (tokenType == TokenType.MANAGEMENT) {
+ if (req.toString().contains("/token")) {
+ checkTokenAuthorization(req, "token");
+ } else {
+ checkTokenAuthorization(req, "management");
+ }
+ } else if (tokenType == TokenType.INFERENCE) {
+ checkTokenAuthorization(req, "inference");
+ }
+ }
+ chain.handleRequest(ctx, req, decoder, segments);
+ }
+
+ public static void setupTokenClass() {
+ try {
+ tokenClass = Class.forName("org.pytorch.serve.plugins.endpoint.Token");
+ tokenObject = tokenClass.getDeclaredConstructor().newInstance();
+ Method method = tokenClass.getMethod("setTime", Double.class);
+ Double time = ConfigManager.getInstance().getTimeToExpiration();
+ if (time != 0.0) {
+ timeToExpirationMinutes = time;
+ }
+ method.invoke(tokenObject, timeToExpirationMinutes);
+ method = tokenClass.getMethod("generateKeyFile", String.class);
+ if ((boolean) method.invoke(tokenObject, "token")) {
+ logger.info("TOKEN CLASS IMPORTED SUCCESSFULLY");
+ }
+ } catch (NoSuchMethodException
+ | IllegalAccessException
+ | InstantiationException
+ | InvocationTargetException
+ | ClassNotFoundException e) {
+ e.printStackTrace();
+ logger.error("TOKEN CLASS IMPORTED UNSUCCESSFULLY");
+ throw new IllegalStateException("Unable to import token class", e);
+ }
+ tokenEnabled = true;
+ }
+
+ private void checkTokenAuthorization(FullHttpRequest req, String type) throws ModelException {
+
+ try {
+ Method method =
+ tokenClass.getMethod(
+ "checkTokenAuthorization",
+ io.netty.handler.codec.http.FullHttpRequest.class,
+ String.class);
+ boolean result = (boolean) (method.invoke(tokenObject, req, type));
+ if (!result) {
+ throw new InvalidKeyException(
+ "Token Authentication failed. Token either incorrect, expired, or not provided correctly");
+ }
+ } catch (NoSuchMethodException | IllegalAccessException | InvocationTargetException e) {
+ e.printStackTrace();
+ throw new InvalidKeyException(
+ "Token Authentication failed. Token either incorrect, expired, or not provided correctly");
+ }
+ }
+}
diff --git a/frontend/server/src/main/java/org/pytorch/serve/http/api/rest/ApiDescriptionRequestHandler.java b/frontend/server/src/main/java/org/pytorch/serve/http/api/rest/ApiDescriptionRequestHandler.java
index 05422dafa8..def4be404c 100644
--- a/frontend/server/src/main/java/org/pytorch/serve/http/api/rest/ApiDescriptionRequestHandler.java
+++ b/frontend/server/src/main/java/org/pytorch/serve/http/api/rest/ApiDescriptionRequestHandler.java
@@ -30,7 +30,6 @@ public void handleRequest(
String[] segments)
throws ModelException, DownloadArchiveException, WorkflowException,
WorkerInitializationException {
-
if (isApiDescription(segments)) {
String path = decoder.path();
if (("/".equals(path) && HttpMethod.OPTIONS.equals(req.method()))
diff --git a/frontend/server/src/main/java/org/pytorch/serve/servingsdk/impl/PluginsManager.java b/frontend/server/src/main/java/org/pytorch/serve/servingsdk/impl/PluginsManager.java
index bac20bf32d..fdf6b01dd0 100644
--- a/frontend/server/src/main/java/org/pytorch/serve/servingsdk/impl/PluginsManager.java
+++ b/frontend/server/src/main/java/org/pytorch/serve/servingsdk/impl/PluginsManager.java
@@ -5,6 +5,7 @@
import java.util.Map;
import java.util.ServiceLoader;
import org.pytorch.serve.http.InvalidPluginException;
+import org.pytorch.serve.http.TokenAuthorizationHandler;
import org.pytorch.serve.servingsdk.ModelServerEndpoint;
import org.pytorch.serve.servingsdk.annotations.Endpoint;
import org.pytorch.serve.servingsdk.annotations.helpers.EndpointTypes;
@@ -30,6 +31,9 @@ public void initialize() {
logger.info("Initializing plugins manager...");
inferenceEndpoints = initInferenceEndpoints();
managementEndpoints = initManagementEndpoints();
+ if (managementEndpoints.containsKey("token")) {
+ TokenAuthorizationHandler.setupTokenClass();
+ }
}
private boolean validateEndpointPlugin(Annotation a, EndpointTypes type) {
diff --git a/frontend/server/src/main/java/org/pytorch/serve/util/ConfigManager.java b/frontend/server/src/main/java/org/pytorch/serve/util/ConfigManager.java
index b106c03b3e..849b895ddd 100644
--- a/frontend/server/src/main/java/org/pytorch/serve/util/ConfigManager.java
+++ b/frontend/server/src/main/java/org/pytorch/serve/util/ConfigManager.java
@@ -107,6 +107,7 @@ public final class ConfigManager {
private static final String TS_WORKFLOW_STORE = "workflow_store";
private static final String TS_CPP_LOG_CONFIG = "cpp_log_config";
private static final String TS_OPEN_INFERENCE_PROTOCOL = "ts_open_inference_protocol";
+ private static final String TS_TOKEN_EXPIRATION_TIME_MIN = "token_expiration_min";
// Configuration which are not documented or enabled through environment variables
private static final String USE_NATIVE_IO = "use_native_io";
@@ -859,6 +860,17 @@ public boolean isSnapshotDisabled() {
return snapshotDisabled;
}
+ public Double getTimeToExpiration() {
+ if (prop.getProperty(TS_TOKEN_EXPIRATION_TIME_MIN) != null) {
+ try {
+ return Double.valueOf(prop.getProperty(TS_TOKEN_EXPIRATION_TIME_MIN));
+ } catch (NumberFormatException e) {
+ logger.error("Token expiration not a valid integer");
+ }
+ }
+ return 0.0;
+ }
+
public boolean isSSLEnabled(ConnectorType connectorType) {
String address = prop.getProperty(TS_INFERENCE_ADDRESS, "http://127.0.0.1:8080");
switch (connectorType) {
diff --git a/frontend/server/src/main/java/org/pytorch/serve/util/TokenType.java b/frontend/server/src/main/java/org/pytorch/serve/util/TokenType.java
new file mode 100644
index 0000000000..9bf6318d2e
--- /dev/null
+++ b/frontend/server/src/main/java/org/pytorch/serve/util/TokenType.java
@@ -0,0 +1,7 @@
+package org.pytorch.serve.util;
+
+public enum TokenType {
+ INFERENCE,
+ MANAGEMENT,
+ TOKEN_API
+}
diff --git a/frontend/server/src/main/java/org/pytorch/serve/workflow/api/http/WorkflowInferenceRequestHandler.java b/frontend/server/src/main/java/org/pytorch/serve/workflow/api/http/WorkflowInferenceRequestHandler.java
index 042bef68e1..916e723c81 100644
--- a/frontend/server/src/main/java/org/pytorch/serve/workflow/api/http/WorkflowInferenceRequestHandler.java
+++ b/frontend/server/src/main/java/org/pytorch/serve/workflow/api/http/WorkflowInferenceRequestHandler.java
@@ -80,6 +80,7 @@ public void handleRequest(
String[] segments)
throws ModelException, DownloadArchiveException, WorkflowException,
WorkerInitializationException {
+
if ("wfpredict".equalsIgnoreCase(segments[1])) {
if (segments.length < 3) {
throw new ResourceNotFoundException();
diff --git a/frontend/server/src/main/java/org/pytorch/serve/workflow/api/http/WorkflowMgmtRequestHandler.java b/frontend/server/src/main/java/org/pytorch/serve/workflow/api/http/WorkflowMgmtRequestHandler.java
index b50f5891b7..6125de61e7 100644
--- a/frontend/server/src/main/java/org/pytorch/serve/workflow/api/http/WorkflowMgmtRequestHandler.java
+++ b/frontend/server/src/main/java/org/pytorch/serve/workflow/api/http/WorkflowMgmtRequestHandler.java
@@ -63,6 +63,7 @@ public void handleRequest(
String[] segments)
throws ModelException, DownloadArchiveException, WorkflowException,
WorkerInitializationException {
+
if (isManagementReq(segments)) {
if (!"workflows".equals(segments[1])) {
throw new ResourceNotFoundException();
diff --git a/plugins/endpoints/build.gradle b/plugins/endpoints/build.gradle
index 36c5191a6d..a162cc2ae8 100644
--- a/plugins/endpoints/build.gradle
+++ b/plugins/endpoints/build.gradle
@@ -1,6 +1,7 @@
dependencies {
implementation "com.google.code.gson:gson:${gson_version}"
implementation "org.pytorch:torchserve-plugins-sdk:${torchserve_sdk_version}"
+ implementation "io.netty:netty-all:4.1.53.Final"
}
project.ext{
@@ -16,4 +17,3 @@ jar {
exclude "META-INF//LICENSE*"
exclude "META-INF//NOTICE*"
}
-
diff --git a/plugins/endpoints/src/main/java/org/pytorch/serve/plugins/endpoint/Token.java b/plugins/endpoints/src/main/java/org/pytorch/serve/plugins/endpoint/Token.java
new file mode 100644
index 0000000000..805b6169ff
--- /dev/null
+++ b/plugins/endpoints/src/main/java/org/pytorch/serve/plugins/endpoint/Token.java
@@ -0,0 +1,223 @@
+package org.pytorch.serve.plugins.endpoint;
+
+// import java.util.Properties;
+import com.google.gson.GsonBuilder;
+import com.google.gson.JsonObject;
+import io.netty.handler.codec.http.FullHttpRequest;
+import io.netty.handler.codec.http.QueryStringDecoder;
+import java.io.File;
+import java.io.IOException;
+import java.nio.charset.StandardCharsets;
+import java.nio.file.Files;
+import java.nio.file.Path;
+import java.nio.file.Paths;
+import java.nio.file.attribute.PosixFilePermission;
+import java.nio.file.attribute.PosixFilePermissions;
+import java.security.SecureRandom;
+import java.time.Instant;
+import java.util.Base64;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+import org.pytorch.serve.servingsdk.Context;
+import org.pytorch.serve.servingsdk.ModelServerEndpoint;
+import org.pytorch.serve.servingsdk.annotations.Endpoint;
+import org.pytorch.serve.servingsdk.annotations.helpers.EndpointTypes;
+import org.pytorch.serve.servingsdk.http.Request;
+import org.pytorch.serve.servingsdk.http.Response;
+
+// import org.pytorch.serve.util.TokenType;
+
+@Endpoint(
+ urlPattern = "token",
+ endpointType = EndpointTypes.MANAGEMENT,
+ description = "Token authentication endpoint")
+public class Token extends ModelServerEndpoint {
+ private static String apiKey;
+ private static String managementKey;
+ private static String inferenceKey;
+ private static Instant managementExpirationTimeMinutes;
+ private static Instant inferenceExpirationTimeMinutes;
+ private static Double timeToExpirationMinutes;
+ private SecureRandom secureRandom = new SecureRandom();
+ private Base64.Encoder baseEncoder = Base64.getUrlEncoder();
+ private String fileName = "key_file.json";
+
+ @Override
+ public void doGet(Request req, Response rsp, Context ctx) throws IOException {
+ String queryResponse = parseQuery(req);
+ String test = "";
+ if ("management".equals(queryResponse)) {
+ generateKeyFile("management");
+ } else if ("inference".equals(queryResponse)) {
+ generateKeyFile("inference");
+ } else {
+ test = "{\n\t\"Error\": " + queryResponse + "\n}\n";
+ }
+ rsp.getOutputStream().write(test.getBytes(StandardCharsets.UTF_8));
+ }
+
+ // parses query and either returns management/inference or a wrong type error
+ public String parseQuery(Request req) {
+ QueryStringDecoder decoder = new QueryStringDecoder(req.getRequestURI());
+ Map