diff --git a/docs/token_authorization_api.md b/docs/token_authorization_api.md
new file mode 100644
index 0000000000..10e346e53b
--- /dev/null
+++ b/docs/token_authorization_api.md
@@ -0,0 +1,50 @@
+# TorchServe token authorization API
+
+## Configuration
+1. Enable token authorization by providing the plugin at startup with the `--plugins-path` option.
+2. TorchServe enables token authorization whenever the plugin is provided and generates a `key_file.json` file in the current working directory.
+    1. Example key file:
+
+```json
+{
+  "management": {
+    "key": "B-E5KSRM",
+    "expiration time": "2024-02-16T21:12:24.801167Z"
+  },
+  "inference": {
+    "key": "gNRuA7dS",
+    "expiration time": "2024-02-16T21:12:24.801148Z"
+  },
+  "API": {
+    "key": "yv9uQajP"
+  }
+}
+```
+
+3. There are three keys, each with a different use.
+    1. Management key: used for the management APIs. Example:
+    `curl http://localhost:8081/models/densenet161 -H "Authorization: Bearer I_J_ItMb"`
+    2. Inference key: used for the inference APIs. Example:
+    `curl http://127.0.0.1:8080/predictions/densenet161 -T examples/image_classifier/kitten.jpg -H "Authorization: Bearer FINhR1fj"`
+    3. API key: used for the token authorization API (see item 4 below).
+4. The plugin also exposes an API for generating a new key to replace either the management or inference key.
+    1. Management example:
+    `curl "localhost:8081/token?type=management" -H "Authorization: Bearer m4M-5IBY"` replaces the current management key in the key file with a new one and updates the expiration time.
+    2. Inference example:
+    `curl "localhost:8081/token?type=inference" -H "Authorization: Bearer m4M-5IBY"`
+
+    To regenerate a key, use one of the two requests above; both must be authorized with the API key from the key file.
+
+5. When the server is shut down, the key file is deleted.
+
+
+## Customization
+TorchServe offers several ways to customize token authorization so that server owners can achieve the desired behavior.
+1. The expiration time defaults to 60 minutes and can be changed in config.properties with `token_expiration_min`. Example: `token_expiration_min=30`
+2. The token authorization code is consolidated in the plugin and can therefore be changed without impacting the frontend or the end result. The only parts the user cannot change are:
+    1. The `urlPattern` for the plugin must be `token`, and the class name must not change.
+    2. The return types and signatures of `generateKeyFile`, `checkTokenAuthorization`, and `setTime` must not change; the bodies of these functions can be modified as needed.
+
+## Notes
+1. DO NOT MODIFY THE KEY FILE. Editing the file by hand can break reading and writing to it and prevent new keys from being written correctly.
+2. Three separate keys give the server owner the most flexibility. The owner can hand out only the inference key to users who should not manage models, and can additionally share the management key with users who are allowed to add and remove models.
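+
+For reference, the snippet below is a minimal Python sketch of a client that reads the generated key file and calls the APIs with the corresponding keys. It assumes the server is running locally on the default ports with the plugin enabled, that `key_file.json` is in the current working directory, and that `densenet161` and the kitten image are only example inputs.
+
+```python
+import json
+
+import requests
+
+
+def read_key(section: str) -> str:
+    """Return the key for 'management', 'inference', or 'API' from key_file.json."""
+    with open("key_file.json") as f:
+        return json.load(f)[section]["key"]
+
+
+# Management API call, authorized with the management key.
+requests.get(
+    "http://localhost:8081/models/densenet161",
+    headers={"Authorization": f"Bearer {read_key('management')}"},
+)
+
+# Regenerate the inference key; the token API itself is authorized with the API key.
+requests.get(
+    "http://localhost:8081/token",
+    params={"type": "inference"},
+    headers={"Authorization": f"Bearer {read_key('API')}"},
+)
+
+# Inference call with the freshly generated inference key.
+with open("examples/image_classifier/kitten.jpg", "rb") as img:
+    requests.post(
+        "http://localhost:8080/predictions/densenet161",
+        files={"data": img},
+        headers={"Authorization": f"Bearer {read_key('inference')}"},
+    )
+```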
diff --git a/frontend/archive/src/main/java/org/pytorch/serve/archive/model/InvalidKeyException.java b/frontend/archive/src/main/java/org/pytorch/serve/archive/model/InvalidKeyException.java new file mode 100644 index 0000000000..1045264e3e --- /dev/null +++ b/frontend/archive/src/main/java/org/pytorch/serve/archive/model/InvalidKeyException.java @@ -0,0 +1,32 @@ +package org.pytorch.serve.archive.model; + +public class InvalidKeyException extends ModelException { + + private static final long serialVersionUID = 1L; + + /** + * Constructs an {@code InvalidKeyException} with the specified detail message. + * + * @param message The detail message (which is saved for later retrieval by the {@link + * #getMessage()} method) + */ + public InvalidKeyException(String message) { + super(message); + } + + /** + * Constructs an {@code InvalidKeyException} with the specified detail message and cause. + * + *

Note that the detail message associated with {@code cause} is not automatically + * incorporated into this exception's detail message. + * + * @param message The detail message (which is saved for later retrieval by the {@link + * #getMessage()} method) + * @param cause The cause (which is saved for later retrieval by the {@link #getCause()} + * method). (A null value is permitted, and indicates that the cause is nonexistent or + * unknown.) + */ + public InvalidKeyException(String message, Throwable cause) { + super(message, cause); + } +} diff --git a/frontend/server/src/main/java/org/pytorch/serve/ServerInitializer.java b/frontend/server/src/main/java/org/pytorch/serve/ServerInitializer.java index 38b81fdb4c..b362ceb958 100644 --- a/frontend/server/src/main/java/org/pytorch/serve/ServerInitializer.java +++ b/frontend/server/src/main/java/org/pytorch/serve/ServerInitializer.java @@ -10,6 +10,7 @@ import org.pytorch.serve.http.HttpRequestHandler; import org.pytorch.serve.http.HttpRequestHandlerChain; import org.pytorch.serve.http.InvalidRequestHandler; +import org.pytorch.serve.http.TokenAuthorizationHandler; import org.pytorch.serve.http.api.rest.ApiDescriptionRequestHandler; import org.pytorch.serve.http.api.rest.InferenceRequestHandler; import org.pytorch.serve.http.api.rest.ManagementRequestHandler; @@ -18,6 +19,7 @@ import org.pytorch.serve.servingsdk.impl.PluginsManager; import org.pytorch.serve.util.ConfigManager; import org.pytorch.serve.util.ConnectorType; +import org.pytorch.serve.util.TokenType; import org.pytorch.serve.workflow.api.http.WorkflowInferenceRequestHandler; import org.pytorch.serve.workflow.api.http.WorkflowMgmtRequestHandler; import org.slf4j.Logger; @@ -63,6 +65,9 @@ public void initChannel(Channel ch) { HttpRequestHandlerChain httpRequestHandlerChain = apiDescriptionRequestHandler; if (ConnectorType.ALL.equals(connectorType) || ConnectorType.INFERENCE_CONNECTOR.equals(connectorType)) { + httpRequestHandlerChain = + httpRequestHandlerChain.setNextHandler( + new TokenAuthorizationHandler(TokenType.INFERENCE)); httpRequestHandlerChain = httpRequestHandlerChain.setNextHandler( new InferenceRequestHandler( @@ -80,6 +85,9 @@ public void initChannel(Channel ch) { } if (ConnectorType.ALL.equals(connectorType) || ConnectorType.MANAGEMENT_CONNECTOR.equals(connectorType)) { + httpRequestHandlerChain = + httpRequestHandlerChain.setNextHandler( + new TokenAuthorizationHandler(TokenType.MANAGEMENT)); httpRequestHandlerChain = httpRequestHandlerChain.setNextHandler( new ManagementRequestHandler( diff --git a/frontend/server/src/main/java/org/pytorch/serve/http/TokenAuthorizationHandler.java b/frontend/server/src/main/java/org/pytorch/serve/http/TokenAuthorizationHandler.java new file mode 100644 index 0000000000..cab59f11b5 --- /dev/null +++ b/frontend/server/src/main/java/org/pytorch/serve/http/TokenAuthorizationHandler.java @@ -0,0 +1,103 @@ +package org.pytorch.serve.http; + +import io.netty.channel.ChannelHandlerContext; +import io.netty.handler.codec.http.FullHttpRequest; +import io.netty.handler.codec.http.QueryStringDecoder; +import java.lang.reflect.*; +import org.pytorch.serve.archive.DownloadArchiveException; +import org.pytorch.serve.archive.model.InvalidKeyException; +import org.pytorch.serve.archive.model.ModelException; +import org.pytorch.serve.archive.workflow.WorkflowException; +import org.pytorch.serve.util.ConfigManager; +import org.pytorch.serve.util.TokenType; +import org.pytorch.serve.wlm.WorkerInitializationException; +import org.slf4j.Logger; +import 
org.slf4j.LoggerFactory; + +/** + * A class handling token check for all inbound HTTP requests + * + *

This class // + */ +public class TokenAuthorizationHandler extends HttpRequestHandlerChain { + + private static final Logger logger = LoggerFactory.getLogger(TokenAuthorizationHandler.class); + private static TokenType tokenType; + private static Boolean tokenEnabled = false; + private static Class tokenClass; + private static Object tokenObject; + private static Double timeToExpirationMinutes = 60.0; + + /** Creates a new {@code InferenceRequestHandler} instance. */ + public TokenAuthorizationHandler(TokenType type) { + tokenType = type; + } + + @Override + public void handleRequest( + ChannelHandlerContext ctx, + FullHttpRequest req, + QueryStringDecoder decoder, + String[] segments) + throws ModelException, DownloadArchiveException, WorkflowException, + WorkerInitializationException { + if (tokenEnabled) { + if (tokenType == TokenType.MANAGEMENT) { + if (req.toString().contains("/token")) { + checkTokenAuthorization(req, "token"); + } else { + checkTokenAuthorization(req, "management"); + } + } else if (tokenType == TokenType.INFERENCE) { + checkTokenAuthorization(req, "inference"); + } + } + chain.handleRequest(ctx, req, decoder, segments); + } + + public static void setupTokenClass() { + try { + tokenClass = Class.forName("org.pytorch.serve.plugins.endpoint.Token"); + tokenObject = tokenClass.getDeclaredConstructor().newInstance(); + Method method = tokenClass.getMethod("setTime", Double.class); + Double time = ConfigManager.getInstance().getTimeToExpiration(); + if (time != 0.0) { + timeToExpirationMinutes = time; + } + method.invoke(tokenObject, timeToExpirationMinutes); + method = tokenClass.getMethod("generateKeyFile", String.class); + if ((boolean) method.invoke(tokenObject, "token")) { + logger.info("TOKEN CLASS IMPORTED SUCCESSFULLY"); + } + } catch (NoSuchMethodException + | IllegalAccessException + | InstantiationException + | InvocationTargetException + | ClassNotFoundException e) { + e.printStackTrace(); + logger.error("TOKEN CLASS IMPORTED UNSUCCESSFULLY"); + throw new IllegalStateException("Unable to import token class", e); + } + tokenEnabled = true; + } + + private void checkTokenAuthorization(FullHttpRequest req, String type) throws ModelException { + + try { + Method method = + tokenClass.getMethod( + "checkTokenAuthorization", + io.netty.handler.codec.http.FullHttpRequest.class, + String.class); + boolean result = (boolean) (method.invoke(tokenObject, req, type)); + if (!result) { + throw new InvalidKeyException( + "Token Authentication failed. Token either incorrect, expired, or not provided correctly"); + } + } catch (NoSuchMethodException | IllegalAccessException | InvocationTargetException e) { + e.printStackTrace(); + throw new InvalidKeyException( + "Token Authentication failed. 
Token either incorrect, expired, or not provided correctly"); + } + } +} diff --git a/frontend/server/src/main/java/org/pytorch/serve/http/api/rest/ApiDescriptionRequestHandler.java b/frontend/server/src/main/java/org/pytorch/serve/http/api/rest/ApiDescriptionRequestHandler.java index 05422dafa8..def4be404c 100644 --- a/frontend/server/src/main/java/org/pytorch/serve/http/api/rest/ApiDescriptionRequestHandler.java +++ b/frontend/server/src/main/java/org/pytorch/serve/http/api/rest/ApiDescriptionRequestHandler.java @@ -30,7 +30,6 @@ public void handleRequest( String[] segments) throws ModelException, DownloadArchiveException, WorkflowException, WorkerInitializationException { - if (isApiDescription(segments)) { String path = decoder.path(); if (("/".equals(path) && HttpMethod.OPTIONS.equals(req.method())) diff --git a/frontend/server/src/main/java/org/pytorch/serve/servingsdk/impl/PluginsManager.java b/frontend/server/src/main/java/org/pytorch/serve/servingsdk/impl/PluginsManager.java index bac20bf32d..fdf6b01dd0 100644 --- a/frontend/server/src/main/java/org/pytorch/serve/servingsdk/impl/PluginsManager.java +++ b/frontend/server/src/main/java/org/pytorch/serve/servingsdk/impl/PluginsManager.java @@ -5,6 +5,7 @@ import java.util.Map; import java.util.ServiceLoader; import org.pytorch.serve.http.InvalidPluginException; +import org.pytorch.serve.http.TokenAuthorizationHandler; import org.pytorch.serve.servingsdk.ModelServerEndpoint; import org.pytorch.serve.servingsdk.annotations.Endpoint; import org.pytorch.serve.servingsdk.annotations.helpers.EndpointTypes; @@ -30,6 +31,9 @@ public void initialize() { logger.info("Initializing plugins manager..."); inferenceEndpoints = initInferenceEndpoints(); managementEndpoints = initManagementEndpoints(); + if (managementEndpoints.containsKey("token")) { + TokenAuthorizationHandler.setupTokenClass(); + } } private boolean validateEndpointPlugin(Annotation a, EndpointTypes type) { diff --git a/frontend/server/src/main/java/org/pytorch/serve/util/ConfigManager.java b/frontend/server/src/main/java/org/pytorch/serve/util/ConfigManager.java index b106c03b3e..849b895ddd 100644 --- a/frontend/server/src/main/java/org/pytorch/serve/util/ConfigManager.java +++ b/frontend/server/src/main/java/org/pytorch/serve/util/ConfigManager.java @@ -107,6 +107,7 @@ public final class ConfigManager { private static final String TS_WORKFLOW_STORE = "workflow_store"; private static final String TS_CPP_LOG_CONFIG = "cpp_log_config"; private static final String TS_OPEN_INFERENCE_PROTOCOL = "ts_open_inference_protocol"; + private static final String TS_TOKEN_EXPIRATION_TIME_MIN = "token_expiration_min"; // Configuration which are not documented or enabled through environment variables private static final String USE_NATIVE_IO = "use_native_io"; @@ -859,6 +860,17 @@ public boolean isSnapshotDisabled() { return snapshotDisabled; } + public Double getTimeToExpiration() { + if (prop.getProperty(TS_TOKEN_EXPIRATION_TIME_MIN) != null) { + try { + return Double.valueOf(prop.getProperty(TS_TOKEN_EXPIRATION_TIME_MIN)); + } catch (NumberFormatException e) { + logger.error("Token expiration not a valid integer"); + } + } + return 0.0; + } + public boolean isSSLEnabled(ConnectorType connectorType) { String address = prop.getProperty(TS_INFERENCE_ADDRESS, "http://127.0.0.1:8080"); switch (connectorType) { diff --git a/frontend/server/src/main/java/org/pytorch/serve/util/TokenType.java b/frontend/server/src/main/java/org/pytorch/serve/util/TokenType.java new file mode 100644 index 
0000000000..9bf6318d2e --- /dev/null +++ b/frontend/server/src/main/java/org/pytorch/serve/util/TokenType.java @@ -0,0 +1,7 @@ +package org.pytorch.serve.util; + +public enum TokenType { + INFERENCE, + MANAGEMENT, + TOKEN_API +} diff --git a/frontend/server/src/main/java/org/pytorch/serve/workflow/api/http/WorkflowInferenceRequestHandler.java b/frontend/server/src/main/java/org/pytorch/serve/workflow/api/http/WorkflowInferenceRequestHandler.java index 042bef68e1..916e723c81 100644 --- a/frontend/server/src/main/java/org/pytorch/serve/workflow/api/http/WorkflowInferenceRequestHandler.java +++ b/frontend/server/src/main/java/org/pytorch/serve/workflow/api/http/WorkflowInferenceRequestHandler.java @@ -80,6 +80,7 @@ public void handleRequest( String[] segments) throws ModelException, DownloadArchiveException, WorkflowException, WorkerInitializationException { + if ("wfpredict".equalsIgnoreCase(segments[1])) { if (segments.length < 3) { throw new ResourceNotFoundException(); diff --git a/frontend/server/src/main/java/org/pytorch/serve/workflow/api/http/WorkflowMgmtRequestHandler.java b/frontend/server/src/main/java/org/pytorch/serve/workflow/api/http/WorkflowMgmtRequestHandler.java index b50f5891b7..6125de61e7 100644 --- a/frontend/server/src/main/java/org/pytorch/serve/workflow/api/http/WorkflowMgmtRequestHandler.java +++ b/frontend/server/src/main/java/org/pytorch/serve/workflow/api/http/WorkflowMgmtRequestHandler.java @@ -63,6 +63,7 @@ public void handleRequest( String[] segments) throws ModelException, DownloadArchiveException, WorkflowException, WorkerInitializationException { + if (isManagementReq(segments)) { if (!"workflows".equals(segments[1])) { throw new ResourceNotFoundException(); diff --git a/plugins/endpoints/build.gradle b/plugins/endpoints/build.gradle index 36c5191a6d..a162cc2ae8 100644 --- a/plugins/endpoints/build.gradle +++ b/plugins/endpoints/build.gradle @@ -1,6 +1,7 @@ dependencies { implementation "com.google.code.gson:gson:${gson_version}" implementation "org.pytorch:torchserve-plugins-sdk:${torchserve_sdk_version}" + implementation "io.netty:netty-all:4.1.53.Final" } project.ext{ @@ -16,4 +17,3 @@ jar { exclude "META-INF//LICENSE*" exclude "META-INF//NOTICE*" } - diff --git a/plugins/endpoints/src/main/java/org/pytorch/serve/plugins/endpoint/Token.java b/plugins/endpoints/src/main/java/org/pytorch/serve/plugins/endpoint/Token.java new file mode 100644 index 0000000000..805b6169ff --- /dev/null +++ b/plugins/endpoints/src/main/java/org/pytorch/serve/plugins/endpoint/Token.java @@ -0,0 +1,223 @@ +package org.pytorch.serve.plugins.endpoint; + +// import java.util.Properties; +import com.google.gson.GsonBuilder; +import com.google.gson.JsonObject; +import io.netty.handler.codec.http.FullHttpRequest; +import io.netty.handler.codec.http.QueryStringDecoder; +import java.io.File; +import java.io.IOException; +import java.nio.charset.StandardCharsets; +import java.nio.file.Files; +import java.nio.file.Path; +import java.nio.file.Paths; +import java.nio.file.attribute.PosixFilePermission; +import java.nio.file.attribute.PosixFilePermissions; +import java.security.SecureRandom; +import java.time.Instant; +import java.util.Base64; +import java.util.List; +import java.util.Map; +import java.util.Set; +import org.pytorch.serve.servingsdk.Context; +import org.pytorch.serve.servingsdk.ModelServerEndpoint; +import org.pytorch.serve.servingsdk.annotations.Endpoint; +import org.pytorch.serve.servingsdk.annotations.helpers.EndpointTypes; +import org.pytorch.serve.servingsdk.http.Request; 
+import org.pytorch.serve.servingsdk.http.Response; + +// import org.pytorch.serve.util.TokenType; + +@Endpoint( + urlPattern = "token", + endpointType = EndpointTypes.MANAGEMENT, + description = "Token authentication endpoint") +public class Token extends ModelServerEndpoint { + private static String apiKey; + private static String managementKey; + private static String inferenceKey; + private static Instant managementExpirationTimeMinutes; + private static Instant inferenceExpirationTimeMinutes; + private static Double timeToExpirationMinutes; + private SecureRandom secureRandom = new SecureRandom(); + private Base64.Encoder baseEncoder = Base64.getUrlEncoder(); + private String fileName = "key_file.json"; + + @Override + public void doGet(Request req, Response rsp, Context ctx) throws IOException { + String queryResponse = parseQuery(req); + String test = ""; + if ("management".equals(queryResponse)) { + generateKeyFile("management"); + } else if ("inference".equals(queryResponse)) { + generateKeyFile("inference"); + } else { + test = "{\n\t\"Error\": " + queryResponse + "\n}\n"; + } + rsp.getOutputStream().write(test.getBytes(StandardCharsets.UTF_8)); + } + + // parses query and either returns management/inference or a wrong type error + public String parseQuery(Request req) { + QueryStringDecoder decoder = new QueryStringDecoder(req.getRequestURI()); + Map> parameters = decoder.parameters(); + List values = parameters.get("type"); + if (values != null && !values.isEmpty()) { + if ("management".equals(values.get(0)) || "inference".equals(values.get(0))) { + return values.get(0); + } else { + return "WRONG TYPE"; + } + } + return "NO TYPE PROVIDED"; + } + + public String generateKey() { + byte[] randomBytes = new byte[6]; + secureRandom.nextBytes(randomBytes); + return baseEncoder.encodeToString(randomBytes); + } + + public Instant generateTokenExpiration() { + long secondsToAdd = (long) (timeToExpirationMinutes * 60); + return Instant.now().plusSeconds(secondsToAdd); + } + + // generates a key file with new keys depending on the parameter provided + public boolean generateKeyFile(String type) throws IOException { + String userDirectory = System.getProperty("user.dir") + "/" + fileName; + File file = new File(userDirectory); + if (!file.createNewFile() && !file.exists()) { + return false; + } + if (apiKey == null) { + apiKey = generateKey(); + } + switch (type) { + case "management": + managementKey = generateKey(); + managementExpirationTimeMinutes = generateTokenExpiration(); + break; + case "inference": + inferenceKey = generateKey(); + inferenceExpirationTimeMinutes = generateTokenExpiration(); + break; + default: + managementKey = generateKey(); + inferenceKey = generateKey(); + inferenceExpirationTimeMinutes = generateTokenExpiration(); + managementExpirationTimeMinutes = generateTokenExpiration(); + } + + JsonObject parentObject = new JsonObject(); + + JsonObject managementObject = new JsonObject(); + managementObject.addProperty("key", managementKey); + managementObject.addProperty("expiration time", managementExpirationTimeMinutes.toString()); + parentObject.add("management", managementObject); + + JsonObject inferenceObject = new JsonObject(); + inferenceObject.addProperty("key", inferenceKey); + inferenceObject.addProperty("expiration time", inferenceExpirationTimeMinutes.toString()); + parentObject.add("inference", inferenceObject); + + JsonObject apiObject = new JsonObject(); + apiObject.addProperty("key", apiKey); + parentObject.add("API", apiObject); + + Files.write( + 
Paths.get(fileName), + new GsonBuilder() + .setPrettyPrinting() + .create() + .toJson(parentObject) + .getBytes(StandardCharsets.UTF_8)); + + if (!setFilePermissions()) { + try { + Files.delete(Paths.get(fileName)); + } catch (IOException e) { + return false; + } + return false; + } + return true; + } + + public boolean setFilePermissions() { + Path path = Paths.get(fileName); + try { + Set permissions = PosixFilePermissions.fromString("rw-------"); + Files.setPosixFilePermissions(path, permissions); + } catch (Exception e) { + return false; + } + return true; + } + + // checks the token provided in the http with the saved keys depening on parameters + public boolean checkTokenAuthorization(FullHttpRequest req, String type) { + String key; + Instant expiration; + switch (type) { + case "token": + key = apiKey; + expiration = null; + break; + case "management": + key = managementKey; + expiration = managementExpirationTimeMinutes; + break; + default: + key = inferenceKey; + expiration = inferenceExpirationTimeMinutes; + } + + String tokenBearer = req.headers().get("Authorization"); + if (tokenBearer == null) { + return false; + } + String[] arrOfStr = tokenBearer.split(" ", 2); + if (arrOfStr.length == 1) { + return false; + } + String token = arrOfStr[1]; + + if (token.equals(key)) { + if (expiration != null && isTokenExpired(expiration)) { + return false; + } + } else { + return false; + } + return true; + } + + public boolean isTokenExpired(Instant expirationTime) { + return !(Instant.now().isBefore(expirationTime)); + } + + public String getManagementKey() { + return managementKey; + } + + public String getInferenceKey() { + return inferenceKey; + } + + public String getKey() { + return apiKey; + } + + public Instant getInferenceExpirationTime() { + return inferenceExpirationTimeMinutes; + } + + public Instant getManagementExpirationTime() { + return managementExpirationTimeMinutes; + } + + public void setTime(Double time) { + timeToExpirationMinutes = time; + } +} diff --git a/plugins/endpoints/src/main/resources/META-INF/services/org.pytorch.serve.servingsdk.ModelServerEndpoint b/plugins/endpoints/src/main/resources/META-INF/services/org.pytorch.serve.servingsdk.ModelServerEndpoint index 88099c457c..887b1a0d3d 100644 --- a/plugins/endpoints/src/main/resources/META-INF/services/org.pytorch.serve.servingsdk.ModelServerEndpoint +++ b/plugins/endpoints/src/main/resources/META-INF/services/org.pytorch.serve.servingsdk.ModelServerEndpoint @@ -1 +1,2 @@ org.pytorch.serve.plugins.endpoint.ExecutionParameters +org.pytorch.serve.plugins.endpoint.Token diff --git a/plugins/gradle/wrapper/gradle-wrapper.properties b/plugins/gradle/wrapper/gradle-wrapper.properties index 7191a69876..62036368e7 100644 --- a/plugins/gradle/wrapper/gradle-wrapper.properties +++ b/plugins/gradle/wrapper/gradle-wrapper.properties @@ -3,4 +3,4 @@ distributionBase=GRADLE_USER_HOME distributionPath=wrapper/dists zipStoreBase=GRADLE_USER_HOME zipStorePath=wrapper/dists -distributionUrl=https\://services.gradle.org/distributions/gradle-6.4-bin.zip +distributionUrl=https\://services.gradle.org/distributions/gradle-7.3-all.zip diff --git a/plugins/settings.gradle b/plugins/settings.gradle index fd3dc1373c..1e949f6cf7 100644 --- a/plugins/settings.gradle +++ b/plugins/settings.gradle @@ -9,5 +9,4 @@ rootProject.name = 'plugins' include 'endpoints' -include 'DDBEndPoint' - +// include 'DDBEndPoint' diff --git a/test/pytest/test_data/0.png b/test/pytest/test_data/0.png new file mode 100644 index 0000000000..a193c47ba4 Binary 
files /dev/null and b/test/pytest/test_data/0.png differ diff --git a/test/pytest/test_token_authorization.py b/test/pytest/test_token_authorization.py new file mode 100644 index 0000000000..4fa6a4799e --- /dev/null +++ b/test/pytest/test_token_authorization.py @@ -0,0 +1,218 @@ +import json +import os +import shutil +import subprocess +import tempfile +import time +from pathlib import Path + +import pytest +import requests +import test_utils + +ROOT_DIR = os.path.join(tempfile.gettempdir(), "workspace") +REPO_ROOT = os.path.join(os.path.dirname(os.path.abspath(__file__)), "../../") +data_file_zero = os.path.join(REPO_ROOT, "test/pytest/test_data/0.png") +config_file = os.path.join(REPO_ROOT, "test/resources/config_token.properties") + + +# Set up token plugin +def get_plugin_jar(): + new_folder_path = os.path.join(ROOT_DIR, "plugins-path") + plugin_folder = os.path.join(REPO_ROOT, "plugins") + os.makedirs(new_folder_path, exist_ok=True) + os.chdir(plugin_folder) + subprocess.run(["./gradlew", "formatJava"]) + result = subprocess.run(["./gradlew", "build"]) + jar_path = os.path.join(plugin_folder, "endpoints/build/libs") + jar_file = [file for file in os.listdir(jar_path) if file.endswith(".jar")] + if jar_file: + shutil.move( + os.path.join(jar_path, jar_file[0]), + os.path.join(new_folder_path, jar_file[0]), + ) + os.chdir(REPO_ROOT) + + +# Parse json file and return key +def read_key_file(type): + json_file_path = os.path.join(REPO_ROOT, "key_file.json") + with open(json_file_path) as json_file: + json_data = json.load(json_file) + + options = { + "management": json_data.get("management", {}).get("key", "NOT_PRESENT"), + "inference": json_data.get("inference", {}).get("key", "NOT_PRESENT"), + "token": json_data.get("API", {}).get("key", "NOT_PRESENT"), + } + key = options.get(type, "Invalid data type") + return key + + +@pytest.fixture(scope="module") +def setup_torchserve(): + get_plugin_jar() + MODEL_STORE = os.path.join(ROOT_DIR, "model_store/") + PLUGIN_STORE = os.path.join(ROOT_DIR, "plugins-path") + + Path(test_utils.MODEL_STORE).mkdir(parents=True, exist_ok=True) + + test_utils.start_torchserve(no_config_snapshots=True, plugin_folder=PLUGIN_STORE) + + key = read_key_file("management") + header = {"Authorization": f"Bearer {key}"} + + params = ( + ("model_name", "mnist"), + ("url", "mnist.mar"), + ("initial_workers", "1"), + ("synchronous", "true"), + ) + response = requests.post( + "http://localhost:8081/models", params=params, headers=header + ) + file_content = Path(f"{REPO_ROOT}/key_file.json").read_text() + print(file_content) + + yield "test" + + test_utils.stop_torchserve() + + +@pytest.fixture(scope="module") +def setup_torchserve_expiration(): + get_plugin_jar() + MODEL_STORE = os.path.join(ROOT_DIR, "model_store/") + PLUGIN_STORE = os.path.join(ROOT_DIR, "plugins-path") + + Path(test_utils.MODEL_STORE).mkdir(parents=True, exist_ok=True) + + test_utils.start_torchserve( + snapshot_file=config_file, no_config_snapshots=True, plugin_folder=PLUGIN_STORE + ) + + key = read_key_file("management") + header = {"Authorization": f"Bearer {key}"} + + params = ( + ("model_name", "mnist"), + ("url", "mnist.mar"), + ("initial_workers", "1"), + ("synchronous", "true"), + ) + response = requests.post( + "http://localhost:8081/models", params=params, headers=header + ) + file_content = Path(f"{REPO_ROOT}/key_file.json").read_text() + print(file_content) + + yield "test" + + test_utils.stop_torchserve() + + +# Test describe model API with token enabled +def 
test_managament_api_with_token(setup_torchserve): + key = read_key_file("management") + header = {"Authorization": f"Bearer {key}"} + response = requests.get("http://localhost:8081/models/mnist", headers=header) + + assert response.status_code == 200, "Token check failed" + + +# Test describe model API with incorrect token and no token +def test_managament_api_with_incorrect_token(setup_torchserve): + # Using random key + header = {"Authorization": "Bearer abcd1234"} + response = requests.get(f"http://localhost:8081/models/mnist", headers=header) + + assert response.status_code == 400, "Token check failed" + + +# Test inference API with token enabled +def test_inference_api_with_token(setup_torchserve): + key = read_key_file("inference") + header = {"Authorization": f"Bearer {key}"} + + response = requests.post( + url="http://localhost:8080/predictions/mnist", + files={"data": open(data_file_zero, "rb")}, + headers=header, + ) + + assert response.status_code == 200, "Token check failed" + + +# Test inference API with incorrect token +def test_inference_api_with_incorrect_token(setup_torchserve): + # Using random key + header = {"Authorization": "Bearer abcd1234"} + + response = requests.post( + url="http://localhost:8080/predictions/mnist", + files={"data": open(data_file_zero, "rb")}, + headers=header, + ) + + assert response.status_code == 400, "Token check failed" + + +# Test Token API for regenerating new inference key +def test_token_inference_api(setup_torchserve): + token_key = read_key_file("token") + inference_key = read_key_file("inference") + header_inference = {"Authorization": f"Bearer {inference_key}"} + header_token = {"Authorization": f"Bearer {token_key}"} + params = {"type": "inference"} + + # check inference works with current token + response = requests.post( + url="http://localhost:8080/predictions/mnist", + files={"data": open(data_file_zero, "rb")}, + headers=header_inference, + ) + assert response.status_code == 200, "Token check failed" + + # generate new inference token and check it is different + response = requests.get( + url="http://localhost:8081/token", params=params, headers=header_token + ) + assert response.status_code == 200, "Token check failed" + assert inference_key != read_key_file("inference"), "Key file not updated" + + # check inference does not works with original token + response = requests.post( + url="http://localhost:8080/predictions/mnist", + files={"data": open(data_file_zero, "rb")}, + headers=header_inference, + ) + assert response.status_code == 400, "Token check failed" + + +# Test Token API for regenerating new management key +def test_token_management_api(setup_torchserve): + token_key = read_key_file("token") + management_key = read_key_file("management") + header = {"Authorization": f"Bearer {token_key}"} + params = {"type": "management"} + + response = requests.get( + url="http://localhost:8081/token", params=params, headers=header + ) + + assert management_key != read_key_file("management"), "Key file not updated" + assert response.status_code == 200, "Token check failed" + + +# Test expiration time +@pytest.mark.module2 +def test_token_expiration_time(setup_torchserve_expiration): + key = read_key_file("management") + header = {"Authorization": f"Bearer {key}"} + response = requests.get("http://localhost:8081/models/mnist", headers=header) + assert response.status_code == 200, "Token check failed" + + time.sleep(15) + + response = requests.get("http://localhost:8081/models/mnist", headers=header) + assert response.status_code == 
400, "Token check failed" diff --git a/test/pytest/test_utils.py b/test/pytest/test_utils.py index 417bba460c..70317b27df 100644 --- a/test/pytest/test_utils.py +++ b/test/pytest/test_utils.py @@ -54,7 +54,11 @@ def run(self): def start_torchserve( - model_store=None, snapshot_file=None, no_config_snapshots=False, gen_mar=True + model_store=None, + snapshot_file=None, + no_config_snapshots=False, + gen_mar=True, + plugin_folder=None, ): stop_torchserve() crate_mar_file_table() @@ -63,6 +67,8 @@ def start_torchserve( if gen_mar: mg.gen_mar(model_store) cmd.extend(["--model-store", model_store]) + if plugin_folder: + cmd.extend(["--plugins-path", plugin_folder]) if snapshot_file: cmd.extend(["--ts-config", snapshot_file]) if no_config_snapshots: diff --git a/test/resources/config_token.properties b/test/resources/config_token.properties new file mode 100644 index 0000000000..b62cd26870 --- /dev/null +++ b/test/resources/config_token.properties @@ -0,0 +1 @@ +token_expiration_min=0.25 diff --git a/ts/model_server.py b/ts/model_server.py index a5ab224f7b..7291c250c4 100644 --- a/ts/model_server.py +++ b/ts/model_server.py @@ -3,6 +3,7 @@ """ import os +import pathlib import platform import re import subprocess @@ -48,6 +49,7 @@ def start() -> None: try: parent = psutil.Process(pid) parent.terminate() + pathlib.Path("key_file.json").unlink(missing_ok=True) if args.foreground: try: parent.wait(timeout=60) diff --git a/ts_scripts/spellcheck_conf/wordlist.txt b/ts_scripts/spellcheck_conf/wordlist.txt index 3cfa9e6840..aa050a38fb 100644 --- a/ts_scripts/spellcheck_conf/wordlist.txt +++ b/ts_scripts/spellcheck_conf/wordlist.txt @@ -1187,3 +1187,11 @@ FxGraphCache TorchInductor fx locustapache +FINhR +IBY +ItMb +checkTokenAuthorization +fj +generateKeyFile +setTime +urlPattern
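For context, the pytest fixtures above start TorchServe through the extended `test_utils.start_torchserve` helper. Below is a minimal sketch of that call, assuming the endpoints plugin jar has already been copied into a plugins folder; the paths are illustrative and mirror the `ROOT_DIR` layout used by the tests.

```python
import os
import tempfile

import test_utils

# Folder holding the built endpoints plugin jar (illustrative path).
plugin_store = os.path.join(tempfile.gettempdir(), "workspace", "plugins-path")

test_utils.start_torchserve(
    no_config_snapshots=True,
    plugin_folder=plugin_store,
    # config_token.properties sets token_expiration_min=0.25 (15 seconds),
    # which lets the expiration test observe a key timing out.
    snapshot_file="test/resources/config_token.properties",
)
# ... issue requests authorized with keys read from key_file.json ...
test_utils.stop_torchserve()
```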