From 4ddb74e59ff3039ed4c140b5ea1f7908dfd6eddd Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Tue, 5 Dec 2023 18:10:56 +0000 Subject: [PATCH 01/28] token authorization update --- .../archive/model/KeyTimeOutException.java | 32 +++++ .../archive/model/s3/InvalidKeyException.java | 32 +++++ .../rest/ApiDescriptionRequestHandler.java | 4 +- .../api/rest/InferenceRequestHandler.java | 2 + .../api/rest/ManagementRequestHandler.java | 5 + .../rest/PrometheusMetricsRequestHandler.java | 3 + .../org/pytorch/serve/util/ConfigManager.java | 117 ++++++++++++++++++ .../http/WorkflowInferenceRequestHandler.java | 4 + .../api/http/WorkflowMgmtRequestHandler.java | 4 + ts/arg_parser.py | 6 + ts/model_server.py | 3 + 11 files changed, 211 insertions(+), 1 deletion(-) create mode 100644 frontend/archive/src/main/java/org/pytorch/serve/archive/model/KeyTimeOutException.java create mode 100644 frontend/archive/src/main/java/org/pytorch/serve/archive/model/s3/InvalidKeyException.java diff --git a/frontend/archive/src/main/java/org/pytorch/serve/archive/model/KeyTimeOutException.java b/frontend/archive/src/main/java/org/pytorch/serve/archive/model/KeyTimeOutException.java new file mode 100644 index 0000000000..ee9da9c3bb --- /dev/null +++ b/frontend/archive/src/main/java/org/pytorch/serve/archive/model/KeyTimeOutException.java @@ -0,0 +1,32 @@ +package org.pytorch.serve.archive.model; + +public class KeyTimeOutException extends ModelException { + + private static final long serialVersionUID = 1L; + + /** + * Constructs an {@code KeyTimeOutException} with the specified detail message. + * + * @param message The detail message (which is saved for later retrieval by the {@link + * #getMessage()} method) + */ + public KeyTimeOutException(String message) { + super(message); + } + + /** + * Constructs an {@code KeyTimeOutException} with the specified detail message and cause. + * + *

Note that the detail message associated with {@code cause} is not automatically + * incorporated into this exception's detail message. + * + * @param message The detail message (which is saved for later retrieval by the {@link + * #getMessage()} method) + * @param cause The cause (which is saved for later retrieval by the {@link #getCause()} + * method). (A null value is permitted, and indicates that the cause is nonexistent or + * unknown.) + */ + public KeyTimeOutException(String message, Throwable cause) { + super(message, cause); + } +} diff --git a/frontend/archive/src/main/java/org/pytorch/serve/archive/model/s3/InvalidKeyException.java b/frontend/archive/src/main/java/org/pytorch/serve/archive/model/s3/InvalidKeyException.java new file mode 100644 index 0000000000..1045264e3e --- /dev/null +++ b/frontend/archive/src/main/java/org/pytorch/serve/archive/model/s3/InvalidKeyException.java @@ -0,0 +1,32 @@ +package org.pytorch.serve.archive.model; + +public class InvalidKeyException extends ModelException { + + private static final long serialVersionUID = 1L; + + /** + * Constructs an {@code InvalidKeyException} with the specified detail message. + * + * @param message The detail message (which is saved for later retrieval by the {@link + * #getMessage()} method) + */ + public InvalidKeyException(String message) { + super(message); + } + + /** + * Constructs an {@code InvalidKeyException} with the specified detail message and cause. + * + *

Note that the detail message associated with {@code cause} is not automatically + * incorporated into this exception's detail message. + * + * @param message The detail message (which is saved for later retrieval by the {@link + * #getMessage()} method) + * @param cause The cause (which is saved for later retrieval by the {@link #getCause()} + * method). (A null value is permitted, and indicates that the cause is nonexistent or + * unknown.) + */ + public InvalidKeyException(String message, Throwable cause) { + super(message, cause); + } +} diff --git a/frontend/server/src/main/java/org/pytorch/serve/http/api/rest/ApiDescriptionRequestHandler.java b/frontend/server/src/main/java/org/pytorch/serve/http/api/rest/ApiDescriptionRequestHandler.java index 05422dafa8..013a53622f 100644 --- a/frontend/server/src/main/java/org/pytorch/serve/http/api/rest/ApiDescriptionRequestHandler.java +++ b/frontend/server/src/main/java/org/pytorch/serve/http/api/rest/ApiDescriptionRequestHandler.java @@ -12,6 +12,7 @@ import org.pytorch.serve.openapi.OpenApiUtils; import org.pytorch.serve.util.ConnectorType; import org.pytorch.serve.util.NettyUtils; +import org.pytorch.serve.util.ConfigManager; import org.pytorch.serve.wlm.WorkerInitializationException; public class ApiDescriptionRequestHandler extends HttpRequestHandlerChain { @@ -30,7 +31,8 @@ public void handleRequest( String[] segments) throws ModelException, DownloadArchiveException, WorkflowException, WorkerInitializationException { - + ConfigManager configManager = ConfigManager.getInstance(); + configManager.checkTokenAuthorization(req); if (isApiDescription(segments)) { String path = decoder.path(); if (("/".equals(path) && HttpMethod.OPTIONS.equals(req.method())) diff --git a/frontend/server/src/main/java/org/pytorch/serve/http/api/rest/InferenceRequestHandler.java b/frontend/server/src/main/java/org/pytorch/serve/http/api/rest/InferenceRequestHandler.java index 363717cb7f..413c7816a7 100644 --- a/frontend/server/src/main/java/org/pytorch/serve/http/api/rest/InferenceRequestHandler.java +++ b/frontend/server/src/main/java/org/pytorch/serve/http/api/rest/InferenceRequestHandler.java @@ -59,6 +59,8 @@ public void handleRequest( String[] segments) throws ModelException, DownloadArchiveException, WorkflowException, WorkerInitializationException { + ConfigManager configManager = ConfigManager.getInstance(); + configManager.checkTokenAuthorization(req); if (isInferenceReq(segments)) { if (endpointMap.getOrDefault(segments[1], null) != null) { handleCustomEndpoint(ctx, req, segments, decoder); diff --git a/frontend/server/src/main/java/org/pytorch/serve/http/api/rest/ManagementRequestHandler.java b/frontend/server/src/main/java/org/pytorch/serve/http/api/rest/ManagementRequestHandler.java index 29a6f156cf..577e1dc269 100644 --- a/frontend/server/src/main/java/org/pytorch/serve/http/api/rest/ManagementRequestHandler.java +++ b/frontend/server/src/main/java/org/pytorch/serve/http/api/rest/ManagementRequestHandler.java @@ -15,6 +15,8 @@ import org.pytorch.serve.archive.DownloadArchiveException; import org.pytorch.serve.archive.model.ModelException; import org.pytorch.serve.archive.model.ModelNotFoundException; +import org.pytorch.serve.archive.model.KeyTimeOutException; +import org.pytorch.serve.archive.model.InvalidKeyException; import org.pytorch.serve.archive.model.ModelVersionNotFoundException; import org.pytorch.serve.archive.workflow.WorkflowException; import org.pytorch.serve.http.HttpRequestHandlerChain; @@ -34,6 +36,7 @@ import org.pytorch.serve.util.ApiUtils; import org.pytorch.serve.util.JsonUtils; import org.pytorch.serve.util.NettyUtils; +import org.pytorch.serve.util.ConfigManager; import org.pytorch.serve.util.messages.RequestInput; import org.pytorch.serve.util.messages.WorkerCommands; import org.pytorch.serve.wlm.Model; @@ -61,6 +64,8 @@ public void handleRequest( String[] segments) throws ModelException, DownloadArchiveException, WorkflowException, WorkerInitializationException { + ConfigManager configManager = ConfigManager.getInstance(); + configManager.checkTokenAuthorization(req); if (isManagementReq(segments)) { if (endpointMap.getOrDefault(segments[1], null) != null) { handleCustomEndpoint(ctx, req, segments, decoder); diff --git a/frontend/server/src/main/java/org/pytorch/serve/http/api/rest/PrometheusMetricsRequestHandler.java b/frontend/server/src/main/java/org/pytorch/serve/http/api/rest/PrometheusMetricsRequestHandler.java index 9760babd46..8cb9622618 100644 --- a/frontend/server/src/main/java/org/pytorch/serve/http/api/rest/PrometheusMetricsRequestHandler.java +++ b/frontend/server/src/main/java/org/pytorch/serve/http/api/rest/PrometheusMetricsRequestHandler.java @@ -25,6 +25,7 @@ import org.pytorch.serve.archive.workflow.WorkflowException; import org.pytorch.serve.http.HttpRequestHandlerChain; import org.pytorch.serve.util.NettyUtils; +import org.pytorch.serve.util.ConfigManager; import org.pytorch.serve.wlm.WorkerInitializationException; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -47,6 +48,8 @@ public void handleRequest( String[] segments) throws ModelException, DownloadArchiveException, WorkflowException, WorkerInitializationException { + ConfigManager configManager = ConfigManager.getInstance(); + configManager.checkTokenAuthorization(req); if (segments.length >= 2 && "metrics".equals(segments[1])) { ByteBuf resBuf = Unpooled.directBuffer(); List params = diff --git a/frontend/server/src/main/java/org/pytorch/serve/util/ConfigManager.java b/frontend/server/src/main/java/org/pytorch/serve/util/ConfigManager.java index 6449f0d662..27739adca1 100644 --- a/frontend/server/src/main/java/org/pytorch/serve/util/ConfigManager.java +++ b/frontend/server/src/main/java/org/pytorch/serve/util/ConfigManager.java @@ -5,6 +5,17 @@ import io.netty.handler.ssl.SslContext; import io.netty.handler.ssl.SslContextBuilder; import io.netty.handler.ssl.util.SelfSignedCertificate; +import java.security.Key; +import java.security.SecureRandom; +import java.time.Instant; +import java.util.concurrent.TimeUnit; +import io.netty.handler.codec.http.FullHttpRequest; +import io.netty.handler.codec.http.HttpHeaderValues; +import io.netty.handler.codec.http.HttpMethod; +import org.pytorch.serve.archive.model.KeyTimeOutException; +import org.pytorch.serve.archive.model.InvalidKeyException; +import org.pytorch.serve.archive.model.ModelException; + import java.io.File; import java.io.IOException; import java.io.InputStream; @@ -104,6 +115,7 @@ public final class ConfigManager { private static final String TS_INITIAL_WORKER_PORT = "initial_worker_port"; private static final String TS_INITIAL_DISTRIBUTION_PORT = "initial_distribution_port"; private static final String TS_WORKFLOW_STORE = "workflow_store"; + private static final String TS_TOKEN_EXPIRATION_TIME = "token_expiration"; // minutes // Configuration which are not documented or enabled through environment variables private static final String USE_NATIVE_IO = "use_native_io"; @@ -138,6 +150,11 @@ public final class ConfigManager { private Properties prop; private boolean snapshotDisabled; + private static boolean tokenAuthEnabled; + private static String currSecretKey; + private static Instant tokenExpires; + private static final SecureRandom secureRandom = new SecureRandom(); //threadsafe + private static final Base64.Encoder base64Encoder = Base64.getUrlEncoder(); //threadsafe private static ConfigManager instance; private String hostName; @@ -150,6 +167,15 @@ private ConfigManager(Arguments args) throws IOException { prop = new Properties(); this.snapshotDisabled = args.isSnapshotDisabled(); + this.tokenAuthEnabled = args.isTokenEnabled(); + // if (this.tokenAuthEnabled){ + // String key = generateToken(); + // System.out.println("-----TEST-----"); + // System.out.println("-----KEY------: " + key); + // // this.tokenExpires = Instant.now().plusSeconds(TimeUnit.MINUTES.toSeconds(5)); + // // this.currSecretKey = key; + // System.out.println(Instant.now() + " TIME: " + this.tokenExpires); + // } String version = readFile(getModelServerHome() + "/ts/version.txt"); if (version != null) { version = version.replaceAll("[\\n\\t ]", ""); @@ -251,6 +277,17 @@ private ConfigManager(Arguments args) throws IOException { setModelConfig(); + // Check for token authorization and setup if needed. + if (this.tokenAuthEnabled){ + String key = generateToken(); + // System.out.println("-----AUTHORIZATION-----"); + // System.out.println("-----TOKEN------: " + key); + // System.out.println("-----AUTHORIZATION-----"); + // this.tokenExpires = Instant.now().plusSeconds(TimeUnit.MINUTES.toSeconds(5)); + // this.currSecretKey = key; + System.out.println(Instant.now() + " TIME: " + this.tokenExpires); + } + // Issue warnining about URLs that can be accessed when loading models if (prop.getProperty(TS_ALLOWED_URLS, DEFAULT_TS_ALLOWED_URLS) == DEFAULT_TS_ALLOWED_URLS) { logger.warn( @@ -710,6 +747,8 @@ public String dumpConfigurations() { + isSystemMetricsDisabled() + "\nWorkflow Store: " + (getWorkflowStore() == null ? "N/A" : getWorkflowStore()) + + "\nToken Authorization: " + + (isTokenEnabled() == false ? "N/A" : getKey()) + "\nModel config: " + prop.getProperty(MODEL_CONFIG, "N/A"); } @@ -837,6 +876,72 @@ public boolean isSnapshotDisabled() { return snapshotDisabled; } + public boolean isTokenEnabled() { + return tokenAuthEnabled; + } + + public String getKey(){ + return currSecretKey; + } + + public boolean isTokenExpired(){ + return !(Instant.now().isBefore(tokenExpires)); + } + + public String generateToken() { + byte[] randomBytes = new byte[6]; + secureRandom.nextBytes(randomBytes); + setTokenExpiration(); + String key = base64Encoder.encodeToString(randomBytes); + this.currSecretKey = key; + return key; + } + + public void setTokenExpiration(){ + Integer time = 5; + if (prop.getProperty(TS_TOKEN_EXPIRATION_TIME) != null){ + time = Integer.valueOf(prop.getProperty(TS_TOKEN_EXPIRATION_TIME)); + } + this.tokenExpires = Instant.now().plusSeconds(TimeUnit.MINUTES.toSeconds(time)); + } + + public void checkTokenAuthorization(FullHttpRequest req) throws ModelException { + HttpMethod method = req.method(); + if (isTokenEnabled()){ + System.out.println("TOKEN AUTHORIZATION IS ENABLED." + ); + // VERIFY IF THE TOKEN IS VALID + String tokenBearer = req.headers().get("Authorization"); + if (tokenBearer == null){ + throw new InvalidKeyException("NO TOKEN PROVIDED"); + } + String[] arrOfStr = tokenBearer.split(" ", 2); + String token = arrOfStr[1]; + if (token.equals(getKey())){ + if (isTokenExpired()){ + // System.out.println("TOKEN IS EXPIRED"); + String newToken = generateToken(); + if (!isTokenExpired()){ + System.out.println("This worked"); + }else { + System.out.println("This did not worked"); + } + throw new KeyTimeOutException("THE CURRENT TOKEN IS EXPIRED, NEW TOKEN : " + newToken); + } + else { + System.out.println("TOKEN AUTHORIZATION WORKED"); + } + } else { + throw new InvalidKeyException("TOKEN IS INCORRECT"); + // System.out.println("INCORRECT TOKEN"); + } + }else { + System.out.println("TOKEN AUTHORIZATION IS NOT ENABLED"); + } + } + + // FUNCTION THAT RECEIVES A TOKEN AND CHECKS TO SEE IF IT MATCHES SAVED TOKEN + public boolean isSSLEnabled(ConnectorType connectorType) { String address = prop.getProperty(TS_INFERENCE_ADDRESS, "http://127.0.0.1:8080"); switch (connectorType) { @@ -928,6 +1033,7 @@ public static final class Arguments { private String[] models; private boolean snapshotDisabled; private String workflowStore; + private boolean tokenAuthEnabled; public Arguments() {} @@ -938,6 +1044,7 @@ public Arguments(CommandLine cmd) { models = cmd.getOptionValues("models"); snapshotDisabled = cmd.hasOption("no-config-snapshot"); workflowStore = cmd.getOptionValue("workflow-store"); + tokenAuthEnabled = cmd.hasOption("token"); } public static Options getOptions() { @@ -983,6 +1090,12 @@ public static Options getOptions() { .argName("WORKFLOW-STORE") .desc("Workflow store location where workflow can be loaded.") .build()); + options.addOption( + Option.builder("token") + .longOpt("token") + .argName("TOKEN") + .desc("enables token authorization") + .build()); return options; } @@ -1006,6 +1119,10 @@ public String getWorkflowStore() { return workflowStore; } + public boolean isTokenEnabled() { + return tokenAuthEnabled; + } + public void setModelStore(String modelStore) { this.modelStore = modelStore; } diff --git a/frontend/server/src/main/java/org/pytorch/serve/workflow/api/http/WorkflowInferenceRequestHandler.java b/frontend/server/src/main/java/org/pytorch/serve/workflow/api/http/WorkflowInferenceRequestHandler.java index 042bef68e1..6266d01b31 100644 --- a/frontend/server/src/main/java/org/pytorch/serve/workflow/api/http/WorkflowInferenceRequestHandler.java +++ b/frontend/server/src/main/java/org/pytorch/serve/workflow/api/http/WorkflowInferenceRequestHandler.java @@ -20,6 +20,7 @@ import org.pytorch.serve.http.StatusResponse; import org.pytorch.serve.util.ConfigManager; import org.pytorch.serve.util.NettyUtils; +import org.pytorch.serve.util.ConfigManager; import org.pytorch.serve.util.messages.InputParameter; import org.pytorch.serve.util.messages.RequestInput; import org.pytorch.serve.wlm.WorkerInitializationException; @@ -80,6 +81,9 @@ public void handleRequest( String[] segments) throws ModelException, DownloadArchiveException, WorkflowException, WorkerInitializationException { + + ConfigManager configManager = ConfigManager.getInstance(); + configManager.checkTokenAuthorization(req); if ("wfpredict".equalsIgnoreCase(segments[1])) { if (segments.length < 3) { throw new ResourceNotFoundException(); diff --git a/frontend/server/src/main/java/org/pytorch/serve/workflow/api/http/WorkflowMgmtRequestHandler.java b/frontend/server/src/main/java/org/pytorch/serve/workflow/api/http/WorkflowMgmtRequestHandler.java index b50f5891b7..a9adec4e6d 100644 --- a/frontend/server/src/main/java/org/pytorch/serve/workflow/api/http/WorkflowMgmtRequestHandler.java +++ b/frontend/server/src/main/java/org/pytorch/serve/workflow/api/http/WorkflowMgmtRequestHandler.java @@ -26,6 +26,7 @@ import org.pytorch.serve.http.StatusResponse; import org.pytorch.serve.util.JsonUtils; import org.pytorch.serve.util.NettyUtils; +import org.pytorch.serve.util.ConfigManager; import org.pytorch.serve.wlm.WorkerInitializationException; import org.pytorch.serve.workflow.WorkflowManager; import org.pytorch.serve.workflow.messages.DescribeWorkflowResponse; @@ -63,6 +64,9 @@ public void handleRequest( String[] segments) throws ModelException, DownloadArchiveException, WorkflowException, WorkerInitializationException { + + ConfigManager configManager = ConfigManager.getInstance(); + configManager.checkTokenAuthorization(req); if (isManagementReq(segments)) { if (!"workflows".equals(segments[1])) { throw new ResourceNotFoundException(); diff --git a/ts/arg_parser.py b/ts/arg_parser.py index 49aea5fcf1..2844ded554 100644 --- a/ts/arg_parser.py +++ b/ts/arg_parser.py @@ -77,6 +77,12 @@ def ts_parser(): dest="plugins_path", help="plugin jars to be included in torchserve class path", ) + parser.add_argument( + "--token", + dest="token_auth", + help="token authorization", + action="store_true", + ) return parser diff --git a/ts/model_server.py b/ts/model_server.py index 5e3334dbab..f582486438 100644 --- a/ts/model_server.py +++ b/ts/model_server.py @@ -181,6 +181,9 @@ def start() -> None: if args.no_config_snapshots: cmd.append("-ncs") + if args.token_auth: + cmd.append("-token") + if args.models: cmd.append("-m") cmd.extend(args.models) From 1304ca3d16fa7279513848f541581191ebfcd556 Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Wed, 20 Dec 2023 22:44:21 +0000 Subject: [PATCH 02/28] update token --- .../model/{s3 => }/InvalidKeyException.java | 0 .../rest/ApiDescriptionRequestHandler.java | 2 - .../api/rest/InferenceRequestHandler.java | 4 +- .../api/rest/ManagementRequestHandler.java | 2 +- .../rest/PrometheusMetricsRequestHandler.java | 2 +- .../org/pytorch/serve/util/ConfigManager.java | 150 +++++++----------- .../http/WorkflowInferenceRequestHandler.java | 2 +- .../api/http/WorkflowMgmtRequestHandler.java | 2 +- .../pytorch/serve/plugins/endpoint/Token.java | 118 ++++++++++++++ ...torch.serve.servingsdk.ModelServerEndpoint | 1 + ts/arg_parser.py | 6 - ts/model_server.py | 5 +- 12 files changed, 189 insertions(+), 105 deletions(-) rename frontend/archive/src/main/java/org/pytorch/serve/archive/model/{s3 => }/InvalidKeyException.java (100%) create mode 100644 plugins/endpoints/src/main/java/org/pytorch/serve/plugins/endpoint/Token.java diff --git a/frontend/archive/src/main/java/org/pytorch/serve/archive/model/s3/InvalidKeyException.java b/frontend/archive/src/main/java/org/pytorch/serve/archive/model/InvalidKeyException.java similarity index 100% rename from frontend/archive/src/main/java/org/pytorch/serve/archive/model/s3/InvalidKeyException.java rename to frontend/archive/src/main/java/org/pytorch/serve/archive/model/InvalidKeyException.java diff --git a/frontend/server/src/main/java/org/pytorch/serve/http/api/rest/ApiDescriptionRequestHandler.java b/frontend/server/src/main/java/org/pytorch/serve/http/api/rest/ApiDescriptionRequestHandler.java index 013a53622f..e4fc701d1b 100644 --- a/frontend/server/src/main/java/org/pytorch/serve/http/api/rest/ApiDescriptionRequestHandler.java +++ b/frontend/server/src/main/java/org/pytorch/serve/http/api/rest/ApiDescriptionRequestHandler.java @@ -31,8 +31,6 @@ public void handleRequest( String[] segments) throws ModelException, DownloadArchiveException, WorkflowException, WorkerInitializationException { - ConfigManager configManager = ConfigManager.getInstance(); - configManager.checkTokenAuthorization(req); if (isApiDescription(segments)) { String path = decoder.path(); if (("/".equals(path) && HttpMethod.OPTIONS.equals(req.method())) diff --git a/frontend/server/src/main/java/org/pytorch/serve/http/api/rest/InferenceRequestHandler.java b/frontend/server/src/main/java/org/pytorch/serve/http/api/rest/InferenceRequestHandler.java index 413c7816a7..25aeffd10f 100644 --- a/frontend/server/src/main/java/org/pytorch/serve/http/api/rest/InferenceRequestHandler.java +++ b/frontend/server/src/main/java/org/pytorch/serve/http/api/rest/InferenceRequestHandler.java @@ -60,11 +60,13 @@ public void handleRequest( throws ModelException, DownloadArchiveException, WorkflowException, WorkerInitializationException { ConfigManager configManager = ConfigManager.getInstance(); - configManager.checkTokenAuthorization(req); if (isInferenceReq(segments)) { if (endpointMap.getOrDefault(segments[1], null) != null) { + configManager.checkTokenAuthorization(req, false); + System.out.println("THIS IS A TEST NUMBER 3.1"); handleCustomEndpoint(ctx, req, segments, decoder); } else { + configManager.checkTokenAuthorization(req, true); switch (segments[1]) { case "ping": Runnable r = diff --git a/frontend/server/src/main/java/org/pytorch/serve/http/api/rest/ManagementRequestHandler.java b/frontend/server/src/main/java/org/pytorch/serve/http/api/rest/ManagementRequestHandler.java index 577e1dc269..12e797017e 100644 --- a/frontend/server/src/main/java/org/pytorch/serve/http/api/rest/ManagementRequestHandler.java +++ b/frontend/server/src/main/java/org/pytorch/serve/http/api/rest/ManagementRequestHandler.java @@ -65,7 +65,7 @@ public void handleRequest( throws ModelException, DownloadArchiveException, WorkflowException, WorkerInitializationException { ConfigManager configManager = ConfigManager.getInstance(); - configManager.checkTokenAuthorization(req); + configManager.checkTokenAuthorization(req, false); if (isManagementReq(segments)) { if (endpointMap.getOrDefault(segments[1], null) != null) { handleCustomEndpoint(ctx, req, segments, decoder); diff --git a/frontend/server/src/main/java/org/pytorch/serve/http/api/rest/PrometheusMetricsRequestHandler.java b/frontend/server/src/main/java/org/pytorch/serve/http/api/rest/PrometheusMetricsRequestHandler.java index 8cb9622618..76168f5eac 100644 --- a/frontend/server/src/main/java/org/pytorch/serve/http/api/rest/PrometheusMetricsRequestHandler.java +++ b/frontend/server/src/main/java/org/pytorch/serve/http/api/rest/PrometheusMetricsRequestHandler.java @@ -49,7 +49,7 @@ public void handleRequest( throws ModelException, DownloadArchiveException, WorkflowException, WorkerInitializationException { ConfigManager configManager = ConfigManager.getInstance(); - configManager.checkTokenAuthorization(req); + configManager.checkTokenAuthorization(req, true); if (segments.length >= 2 && "metrics".equals(segments[1])) { ByteBuf resBuf = Unpooled.directBuffer(); List params = diff --git a/frontend/server/src/main/java/org/pytorch/serve/util/ConfigManager.java b/frontend/server/src/main/java/org/pytorch/serve/util/ConfigManager.java index 27739adca1..b79d7a65a6 100644 --- a/frontend/server/src/main/java/org/pytorch/serve/util/ConfigManager.java +++ b/frontend/server/src/main/java/org/pytorch/serve/util/ConfigManager.java @@ -50,6 +50,8 @@ import java.util.regex.Matcher; import java.util.regex.Pattern; import java.util.regex.PatternSyntaxException; +import java.time.DateTimeException; +import java.lang.ArrayIndexOutOfBoundsException; import org.apache.commons.cli.CommandLine; import org.apache.commons.cli.Option; import org.apache.commons.cli.Options; @@ -150,11 +152,6 @@ public final class ConfigManager { private Properties prop; private boolean snapshotDisabled; - private static boolean tokenAuthEnabled; - private static String currSecretKey; - private static Instant tokenExpires; - private static final SecureRandom secureRandom = new SecureRandom(); //threadsafe - private static final Base64.Encoder base64Encoder = Base64.getUrlEncoder(); //threadsafe private static ConfigManager instance; private String hostName; @@ -167,15 +164,6 @@ private ConfigManager(Arguments args) throws IOException { prop = new Properties(); this.snapshotDisabled = args.isSnapshotDisabled(); - this.tokenAuthEnabled = args.isTokenEnabled(); - // if (this.tokenAuthEnabled){ - // String key = generateToken(); - // System.out.println("-----TEST-----"); - // System.out.println("-----KEY------: " + key); - // // this.tokenExpires = Instant.now().plusSeconds(TimeUnit.MINUTES.toSeconds(5)); - // // this.currSecretKey = key; - // System.out.println(Instant.now() + " TIME: " + this.tokenExpires); - // } String version = readFile(getModelServerHome() + "/ts/version.txt"); if (version != null) { version = version.replaceAll("[\\n\\t ]", ""); @@ -277,16 +265,6 @@ private ConfigManager(Arguments args) throws IOException { setModelConfig(); - // Check for token authorization and setup if needed. - if (this.tokenAuthEnabled){ - String key = generateToken(); - // System.out.println("-----AUTHORIZATION-----"); - // System.out.println("-----TOKEN------: " + key); - // System.out.println("-----AUTHORIZATION-----"); - // this.tokenExpires = Instant.now().plusSeconds(TimeUnit.MINUTES.toSeconds(5)); - // this.currSecretKey = key; - System.out.println(Instant.now() + " TIME: " + this.tokenExpires); - } // Issue warnining about URLs that can be accessed when loading models if (prop.getProperty(TS_ALLOWED_URLS, DEFAULT_TS_ALLOWED_URLS) == DEFAULT_TS_ALLOWED_URLS) { @@ -747,8 +725,6 @@ public String dumpConfigurations() { + isSystemMetricsDisabled() + "\nWorkflow Store: " + (getWorkflowStore() == null ? "N/A" : getWorkflowStore()) - + "\nToken Authorization: " - + (isTokenEnabled() == false ? "N/A" : getKey()) + "\nModel config: " + prop.getProperty(MODEL_CONFIG, "N/A"); } @@ -876,72 +852,80 @@ public boolean isSnapshotDisabled() { return snapshotDisabled; } - public boolean isTokenEnabled() { - return tokenAuthEnabled; - } - - public String getKey(){ - return currSecretKey; - } - - public boolean isTokenExpired(){ - return !(Instant.now().isBefore(tokenExpires)); - } - public String generateToken() { - byte[] randomBytes = new byte[6]; - secureRandom.nextBytes(randomBytes); - setTokenExpiration(); - String key = base64Encoder.encodeToString(randomBytes); - this.currSecretKey = key; - return key; + public boolean isTokenExpired(Instant expirationTime){ + return !(Instant.now().isBefore(expirationTime)); } - public void setTokenExpiration(){ - Integer time = 5; - if (prop.getProperty(TS_TOKEN_EXPIRATION_TIME) != null){ - time = Integer.valueOf(prop.getProperty(TS_TOKEN_EXPIRATION_TIME)); + public List parseFile(File tsTokenFile) { + List parsedTokens = new ArrayList<>(); + try { + InputStream stream = Files.newInputStream(tsTokenFile.toPath()); + byte[] array = new byte[100]; + stream.read(array); + // Convert byte array into string + String data = new String(array); + String[] arrOfData = data.split("\n", 2); + String[] managementArr = arrOfData[0].split(" ", 3); + String[] inferenceArr = arrOfData[1].split(" ", 7); + parsedTokens.add(managementArr[2]); + parsedTokens.add(inferenceArr[2]); + String[] expirationArr = inferenceArr[6].split("\n", 2); + parsedTokens.add(expirationArr[0]); + } + catch (IOException | ArrayIndexOutOfBoundsException e) { + System.out.println("Unable to read key file or key file has been modified"); + return null; } - this.tokenExpires = Instant.now().plusSeconds(TimeUnit.MINUTES.toSeconds(time)); + return parsedTokens; } - public void checkTokenAuthorization(FullHttpRequest req) throws ModelException { + public void checkTokenAuthorization(FullHttpRequest req, boolean inferenceRequest) throws ModelException { HttpMethod method = req.method(); - if (isTokenEnabled()){ - System.out.println("TOKEN AUTHORIZATION IS ENABLED." - ); - // VERIFY IF THE TOKEN IS VALID - String tokenBearer = req.headers().get("Authorization"); - if (tokenBearer == null){ - throw new InvalidKeyException("NO TOKEN PROVIDED"); - } - String[] arrOfStr = tokenBearer.split(" ", 2); - String token = arrOfStr[1]; - if (token.equals(getKey())){ - if (isTokenExpired()){ - // System.out.println("TOKEN IS EXPIRED"); - String newToken = generateToken(); - if (!isTokenExpired()){ - System.out.println("This worked"); - }else { - System.out.println("This did not worked"); - } - throw new KeyTimeOutException("THE CURRENT TOKEN IS EXPIRED, NEW TOKEN : " + newToken); + // Will change to get file path rather then being set defaulty + String filePath = "/home/ubuntu/serve/key_file.txt"; + if (filePath != null) { + File tsTokenFile = new File(filePath); + if (tsTokenFile.exists()) { + // try (InputStream stream = Files.newInputStream(tsTokenFile.toPath())) { + List parsedTokens = parseFile(tsTokenFile); + String managementToken = parsedTokens.get(0); + String inferenceToken = parsedTokens.get(1); + Instant expirationTime = Instant.now(); + try { + expirationTime = Instant.parse(parsedTokens.get(2)); + }catch(DateTimeException e){ + e.printStackTrace(); + System.out.println("{\n\t\"Error\": Key File has been modified \n}\n"); + } + String tokenBearer = req.headers().get("Authorization"); + if (tokenBearer == null){ + throw new InvalidKeyException("NO TOKEN PROVIDED"); + } + String[] arrOfStr = tokenBearer.split(" ", 2); + if (arrOfStr.length == 1){ + throw new InvalidKeyException("NO TOKEN PROVIDED"); } - else { + String token = arrOfStr[1]; + String key = managementToken; + if (inferenceRequest){ + key = inferenceToken; + } + + if (token.equals(key)){ + if (isTokenExpired(expirationTime) && inferenceRequest){ + throw new KeyTimeOutException("THE CURRENT TOKEN IS EXPIRED"); + } System.out.println("TOKEN AUTHORIZATION WORKED"); + } else { + throw new InvalidKeyException("TOKEN IS INCORRECT "); } } else { - throw new InvalidKeyException("TOKEN IS INCORRECT"); - // System.out.println("INCORRECT TOKEN"); + System.out.println("TOKEN AUTHORIZATION IS NOT ENABLED"); } - }else { - System.out.println("TOKEN AUTHORIZATION IS NOT ENABLED"); } } - // FUNCTION THAT RECEIVES A TOKEN AND CHECKS TO SEE IF IT MATCHES SAVED TOKEN - public boolean isSSLEnabled(ConnectorType connectorType) { String address = prop.getProperty(TS_INFERENCE_ADDRESS, "http://127.0.0.1:8080"); switch (connectorType) { @@ -1033,7 +1017,6 @@ public static final class Arguments { private String[] models; private boolean snapshotDisabled; private String workflowStore; - private boolean tokenAuthEnabled; public Arguments() {} @@ -1044,7 +1027,6 @@ public Arguments(CommandLine cmd) { models = cmd.getOptionValues("models"); snapshotDisabled = cmd.hasOption("no-config-snapshot"); workflowStore = cmd.getOptionValue("workflow-store"); - tokenAuthEnabled = cmd.hasOption("token"); } public static Options getOptions() { @@ -1090,12 +1072,6 @@ public static Options getOptions() { .argName("WORKFLOW-STORE") .desc("Workflow store location where workflow can be loaded.") .build()); - options.addOption( - Option.builder("token") - .longOpt("token") - .argName("TOKEN") - .desc("enables token authorization") - .build()); return options; } @@ -1119,10 +1095,6 @@ public String getWorkflowStore() { return workflowStore; } - public boolean isTokenEnabled() { - return tokenAuthEnabled; - } - public void setModelStore(String modelStore) { this.modelStore = modelStore; } diff --git a/frontend/server/src/main/java/org/pytorch/serve/workflow/api/http/WorkflowInferenceRequestHandler.java b/frontend/server/src/main/java/org/pytorch/serve/workflow/api/http/WorkflowInferenceRequestHandler.java index 6266d01b31..c8339219bd 100644 --- a/frontend/server/src/main/java/org/pytorch/serve/workflow/api/http/WorkflowInferenceRequestHandler.java +++ b/frontend/server/src/main/java/org/pytorch/serve/workflow/api/http/WorkflowInferenceRequestHandler.java @@ -83,7 +83,7 @@ public void handleRequest( WorkerInitializationException { ConfigManager configManager = ConfigManager.getInstance(); - configManager.checkTokenAuthorization(req); + configManager.checkTokenAuthorization(req, true); if ("wfpredict".equalsIgnoreCase(segments[1])) { if (segments.length < 3) { throw new ResourceNotFoundException(); diff --git a/frontend/server/src/main/java/org/pytorch/serve/workflow/api/http/WorkflowMgmtRequestHandler.java b/frontend/server/src/main/java/org/pytorch/serve/workflow/api/http/WorkflowMgmtRequestHandler.java index a9adec4e6d..98e4326e14 100644 --- a/frontend/server/src/main/java/org/pytorch/serve/workflow/api/http/WorkflowMgmtRequestHandler.java +++ b/frontend/server/src/main/java/org/pytorch/serve/workflow/api/http/WorkflowMgmtRequestHandler.java @@ -66,7 +66,7 @@ public void handleRequest( WorkerInitializationException { ConfigManager configManager = ConfigManager.getInstance(); - configManager.checkTokenAuthorization(req); + configManager.checkTokenAuthorization(req, false); if (isManagementReq(segments)) { if (!"workflows".equals(segments[1])) { throw new ResourceNotFoundException(); diff --git a/plugins/endpoints/src/main/java/org/pytorch/serve/plugins/endpoint/Token.java b/plugins/endpoints/src/main/java/org/pytorch/serve/plugins/endpoint/Token.java new file mode 100644 index 0000000000..298fca38fd --- /dev/null +++ b/plugins/endpoints/src/main/java/org/pytorch/serve/plugins/endpoint/Token.java @@ -0,0 +1,118 @@ +package org.pytorch.serve.plugins.endpoint; + +// import java.util.Properties; +import com.google.gson.annotations.SerializedName; +import java.io.File; +import java.io.IOException; +import java.nio.charset.StandardCharsets; +import java.nio.file.Files; +import java.nio.file.Paths; +import java.security.SecureRandom; +import java.time.Instant; +import java.util.Base64; +import java.util.concurrent.TimeUnit; +import org.pytorch.serve.servingsdk.Context; +import org.pytorch.serve.servingsdk.ModelServerEndpoint; +import org.pytorch.serve.servingsdk.annotations.Endpoint; +import org.pytorch.serve.servingsdk.annotations.helpers.EndpointTypes; +import org.pytorch.serve.servingsdk.http.Request; +import org.pytorch.serve.servingsdk.http.Response; + +@Endpoint( + urlPattern = "token", + endpointType = EndpointTypes.INFERENCE, + description = "Execution parameters endpoint") +public class Token extends ModelServerEndpoint { + private boolean managementFlag; + private String managementToken; + + @Override + public void doGet(Request req, Response rsp, Context ctx) throws IOException { + // Properties prop = ctx.getConfig(); + TokenResponse r = new TokenResponse(); + if (!managementFlag) { + managementFlag = true; + r.setKey(); + String output = "{\n\t\"Manager Key\": " + r.getKey() + "\n}\n"; + rsp.getOutputStream().write(output.getBytes(StandardCharsets.UTF_8)); + managementToken = r.getKey(); + } + String test = ""; + if (r.keyFile(managementToken)) { + test = "{\n\t\"File Updated\": successfully \n}\n"; + } else { + test = "{\n\t\"File\": failed \n}\n"; + } + rsp.getOutputStream().write(test.getBytes(StandardCharsets.UTF_8)); + } + + /** Response for Model server endpoint */ + public static class TokenResponse { + private SecureRandom secureRandom = new SecureRandom(); + private Base64.Encoder baseEncoder = Base64.getUrlEncoder(); + + @SerializedName("Key") + private String key; + + @SerializedName("TokenExpiration") + private Instant tokenExpiration; + + public TokenResponse() { + key = "test 2"; + tokenExpiration = Instant.now(); + } + + public String getKey() { + return key; + } + + public Instant getTokenExpiration() { + return tokenExpiration; + } + + public void setKey() { + key = generateKey(); + } + + public boolean keyFile(String managementToken) throws IOException { + String fileSeparator = System.getProperty("file.separator"); + String fileData = " "; + // Will change to get file path rather then being set defaulty + String absoluteFilePath = + fileSeparator + + "home" + + fileSeparator + + "ubuntu" + + fileSeparator + + "serve/key_file.txt"; + File file = new File(absoluteFilePath); + + if (!file.createNewFile() && !file.exists()) { + return false; + } + fileData = + "Management Key: " + + managementToken + + "\n" + + "Inference Key: " + + generateKey() + + " --- Expiration time: " + + tokenExpiration.toString() + + "\n"; + Files.write(Paths.get("key_file.txt"), fileData.getBytes()); + return true; + } + + public String generateKey() { + byte[] randomBytes = new byte[6]; + secureRandom.nextBytes(randomBytes); + setTokenExpiration(); + return baseEncoder.encodeToString(randomBytes); + } + + public void setTokenExpiration() { + Integer time = 3; + tokenExpiration = Instant.now().plusSeconds(TimeUnit.MINUTES.toSeconds(time)); + } + } +} diff --git a/plugins/endpoints/src/main/resources/META-INF/services/org.pytorch.serve.servingsdk.ModelServerEndpoint b/plugins/endpoints/src/main/resources/META-INF/services/org.pytorch.serve.servingsdk.ModelServerEndpoint index 88099c457c..887b1a0d3d 100644 --- a/plugins/endpoints/src/main/resources/META-INF/services/org.pytorch.serve.servingsdk.ModelServerEndpoint +++ b/plugins/endpoints/src/main/resources/META-INF/services/org.pytorch.serve.servingsdk.ModelServerEndpoint @@ -1 +1,2 @@ org.pytorch.serve.plugins.endpoint.ExecutionParameters +org.pytorch.serve.plugins.endpoint.Token diff --git a/ts/arg_parser.py b/ts/arg_parser.py index 2844ded554..49aea5fcf1 100644 --- a/ts/arg_parser.py +++ b/ts/arg_parser.py @@ -77,12 +77,6 @@ def ts_parser(): dest="plugins_path", help="plugin jars to be included in torchserve class path", ) - parser.add_argument( - "--token", - dest="token_auth", - help="token authorization", - action="store_true", - ) return parser diff --git a/ts/model_server.py b/ts/model_server.py index f582486438..7311e0489c 100644 --- a/ts/model_server.py +++ b/ts/model_server.py @@ -48,6 +48,8 @@ def start() -> None: try: parent = psutil.Process(pid) parent.terminate() + # Will change to get file path rather then being set defaulty + os.remove("/home/ubuntu/serve/key_file.txt") if args.foreground: try: parent.wait(timeout=60) @@ -181,9 +183,6 @@ def start() -> None: if args.no_config_snapshots: cmd.append("-ncs") - if args.token_auth: - cmd.append("-token") - if args.models: cmd.append("-m") cmd.extend(args.models) From fde5f4d8ce80df7bedcd460ec67af4dd5c9b033c Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Wed, 20 Dec 2023 22:51:13 +0000 Subject: [PATCH 03/28] token authorization plugin test --- .../serve/http/api/rest/ApiDescriptionRequestHandler.java | 1 - .../org/pytorch/serve/http/api/rest/InferenceRequestHandler.java | 1 - .../src/main/java/org/pytorch/serve/util/ConfigManager.java | 1 - 3 files changed, 3 deletions(-) diff --git a/frontend/server/src/main/java/org/pytorch/serve/http/api/rest/ApiDescriptionRequestHandler.java b/frontend/server/src/main/java/org/pytorch/serve/http/api/rest/ApiDescriptionRequestHandler.java index e4fc701d1b..def4be404c 100644 --- a/frontend/server/src/main/java/org/pytorch/serve/http/api/rest/ApiDescriptionRequestHandler.java +++ b/frontend/server/src/main/java/org/pytorch/serve/http/api/rest/ApiDescriptionRequestHandler.java @@ -12,7 +12,6 @@ import org.pytorch.serve.openapi.OpenApiUtils; import org.pytorch.serve.util.ConnectorType; import org.pytorch.serve.util.NettyUtils; -import org.pytorch.serve.util.ConfigManager; import org.pytorch.serve.wlm.WorkerInitializationException; public class ApiDescriptionRequestHandler extends HttpRequestHandlerChain { diff --git a/frontend/server/src/main/java/org/pytorch/serve/http/api/rest/InferenceRequestHandler.java b/frontend/server/src/main/java/org/pytorch/serve/http/api/rest/InferenceRequestHandler.java index 25aeffd10f..51db8fc8d4 100644 --- a/frontend/server/src/main/java/org/pytorch/serve/http/api/rest/InferenceRequestHandler.java +++ b/frontend/server/src/main/java/org/pytorch/serve/http/api/rest/InferenceRequestHandler.java @@ -63,7 +63,6 @@ public void handleRequest( if (isInferenceReq(segments)) { if (endpointMap.getOrDefault(segments[1], null) != null) { configManager.checkTokenAuthorization(req, false); - System.out.println("THIS IS A TEST NUMBER 3.1"); handleCustomEndpoint(ctx, req, segments, decoder); } else { configManager.checkTokenAuthorization(req, true); diff --git a/frontend/server/src/main/java/org/pytorch/serve/util/ConfigManager.java b/frontend/server/src/main/java/org/pytorch/serve/util/ConfigManager.java index b79d7a65a6..bd1d55e272 100644 --- a/frontend/server/src/main/java/org/pytorch/serve/util/ConfigManager.java +++ b/frontend/server/src/main/java/org/pytorch/serve/util/ConfigManager.java @@ -117,7 +117,6 @@ public final class ConfigManager { private static final String TS_INITIAL_WORKER_PORT = "initial_worker_port"; private static final String TS_INITIAL_DISTRIBUTION_PORT = "initial_distribution_port"; private static final String TS_WORKFLOW_STORE = "workflow_store"; - private static final String TS_TOKEN_EXPIRATION_TIME = "token_expiration"; // minutes // Configuration which are not documented or enabled through environment variables private static final String USE_NATIVE_IO = "use_native_io"; From 8ec3c4045cf7210af65819a0572244874ec703c2 Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Wed, 20 Dec 2023 23:38:19 +0000 Subject: [PATCH 04/28] fix format --- .../api/rest/ManagementRequestHandler.java | 4 +- .../rest/PrometheusMetricsRequestHandler.java | 2 +- .../org/pytorch/serve/util/ConfigManager.java | 44 ++++++++----------- .../http/WorkflowInferenceRequestHandler.java | 1 - .../api/http/WorkflowMgmtRequestHandler.java | 2 +- 5 files changed, 21 insertions(+), 32 deletions(-) diff --git a/frontend/server/src/main/java/org/pytorch/serve/http/api/rest/ManagementRequestHandler.java b/frontend/server/src/main/java/org/pytorch/serve/http/api/rest/ManagementRequestHandler.java index 12e797017e..663249c1ea 100644 --- a/frontend/server/src/main/java/org/pytorch/serve/http/api/rest/ManagementRequestHandler.java +++ b/frontend/server/src/main/java/org/pytorch/serve/http/api/rest/ManagementRequestHandler.java @@ -15,8 +15,6 @@ import org.pytorch.serve.archive.DownloadArchiveException; import org.pytorch.serve.archive.model.ModelException; import org.pytorch.serve.archive.model.ModelNotFoundException; -import org.pytorch.serve.archive.model.KeyTimeOutException; -import org.pytorch.serve.archive.model.InvalidKeyException; import org.pytorch.serve.archive.model.ModelVersionNotFoundException; import org.pytorch.serve.archive.workflow.WorkflowException; import org.pytorch.serve.http.HttpRequestHandlerChain; @@ -34,9 +32,9 @@ import org.pytorch.serve.openapi.OpenApiUtils; import org.pytorch.serve.servingsdk.ModelServerEndpoint; import org.pytorch.serve.util.ApiUtils; +import org.pytorch.serve.util.ConfigManager; import org.pytorch.serve.util.JsonUtils; import org.pytorch.serve.util.NettyUtils; -import org.pytorch.serve.util.ConfigManager; import org.pytorch.serve.util.messages.RequestInput; import org.pytorch.serve.util.messages.WorkerCommands; import org.pytorch.serve.wlm.Model; diff --git a/frontend/server/src/main/java/org/pytorch/serve/http/api/rest/PrometheusMetricsRequestHandler.java b/frontend/server/src/main/java/org/pytorch/serve/http/api/rest/PrometheusMetricsRequestHandler.java index 76168f5eac..024990c7c4 100644 --- a/frontend/server/src/main/java/org/pytorch/serve/http/api/rest/PrometheusMetricsRequestHandler.java +++ b/frontend/server/src/main/java/org/pytorch/serve/http/api/rest/PrometheusMetricsRequestHandler.java @@ -24,8 +24,8 @@ import org.pytorch.serve.archive.model.ModelException; import org.pytorch.serve.archive.workflow.WorkflowException; import org.pytorch.serve.http.HttpRequestHandlerChain; -import org.pytorch.serve.util.NettyUtils; import org.pytorch.serve.util.ConfigManager; +import org.pytorch.serve.util.NettyUtils; import org.pytorch.serve.wlm.WorkerInitializationException; import org.slf4j.Logger; import org.slf4j.LoggerFactory; diff --git a/frontend/server/src/main/java/org/pytorch/serve/util/ConfigManager.java b/frontend/server/src/main/java/org/pytorch/serve/util/ConfigManager.java index bd1d55e272..e6b90fae0c 100644 --- a/frontend/server/src/main/java/org/pytorch/serve/util/ConfigManager.java +++ b/frontend/server/src/main/java/org/pytorch/serve/util/ConfigManager.java @@ -2,20 +2,11 @@ import com.google.gson.JsonObject; import com.google.gson.reflect.TypeToken; +import io.netty.handler.codec.http.FullHttpRequest; +import io.netty.handler.codec.http.HttpMethod; import io.netty.handler.ssl.SslContext; import io.netty.handler.ssl.SslContextBuilder; import io.netty.handler.ssl.util.SelfSignedCertificate; -import java.security.Key; -import java.security.SecureRandom; -import java.time.Instant; -import java.util.concurrent.TimeUnit; -import io.netty.handler.codec.http.FullHttpRequest; -import io.netty.handler.codec.http.HttpHeaderValues; -import io.netty.handler.codec.http.HttpMethod; -import org.pytorch.serve.archive.model.KeyTimeOutException; -import org.pytorch.serve.archive.model.InvalidKeyException; -import org.pytorch.serve.archive.model.ModelException; - import java.io.File; import java.io.IOException; import java.io.InputStream; @@ -36,6 +27,8 @@ import java.security.cert.X509Certificate; import java.security.spec.InvalidKeySpecException; import java.security.spec.PKCS8EncodedKeySpec; +import java.time.DateTimeException; +import java.time.Instant; import java.util.ArrayList; import java.util.Arrays; import java.util.Base64; @@ -50,12 +43,13 @@ import java.util.regex.Matcher; import java.util.regex.Pattern; import java.util.regex.PatternSyntaxException; -import java.time.DateTimeException; -import java.lang.ArrayIndexOutOfBoundsException; import org.apache.commons.cli.CommandLine; import org.apache.commons.cli.Option; import org.apache.commons.cli.Options; import org.apache.commons.io.IOUtils; +import org.pytorch.serve.archive.model.InvalidKeyException; +import org.pytorch.serve.archive.model.KeyTimeOutException; +import org.pytorch.serve.archive.model.ModelException; import org.pytorch.serve.metrics.MetricBuilder; import org.pytorch.serve.servingsdk.snapshot.SnapshotSerializer; import org.pytorch.serve.snapshot.SnapshotSerializerFactory; @@ -264,7 +258,6 @@ private ConfigManager(Arguments args) throws IOException { setModelConfig(); - // Issue warnining about URLs that can be accessed when loading models if (prop.getProperty(TS_ALLOWED_URLS, DEFAULT_TS_ALLOWED_URLS) == DEFAULT_TS_ALLOWED_URLS) { logger.warn( @@ -851,8 +844,7 @@ public boolean isSnapshotDisabled() { return snapshotDisabled; } - - public boolean isTokenExpired(Instant expirationTime){ + public boolean isTokenExpired(Instant expirationTime) { return !(Instant.now().isBefore(expirationTime)); } @@ -866,20 +858,20 @@ public List parseFile(File tsTokenFile) { String data = new String(array); String[] arrOfData = data.split("\n", 2); String[] managementArr = arrOfData[0].split(" ", 3); - String[] inferenceArr = arrOfData[1].split(" ", 7); + String[] inferenceArr = arrOfData[1].split(" ", 7); parsedTokens.add(managementArr[2]); parsedTokens.add(inferenceArr[2]); String[] expirationArr = inferenceArr[6].split("\n", 2); parsedTokens.add(expirationArr[0]); - } - catch (IOException | ArrayIndexOutOfBoundsException e) { + } catch (IOException | ArrayIndexOutOfBoundsException e) { System.out.println("Unable to read key file or key file has been modified"); return null; } return parsedTokens; } - public void checkTokenAuthorization(FullHttpRequest req, boolean inferenceRequest) throws ModelException { + public void checkTokenAuthorization(FullHttpRequest req, boolean inferenceRequest) + throws ModelException { HttpMethod method = req.method(); // Will change to get file path rather then being set defaulty String filePath = "/home/ubuntu/serve/key_file.txt"; @@ -893,26 +885,26 @@ public void checkTokenAuthorization(FullHttpRequest req, boolean inferenceReques Instant expirationTime = Instant.now(); try { expirationTime = Instant.parse(parsedTokens.get(2)); - }catch(DateTimeException e){ + } catch (DateTimeException e) { e.printStackTrace(); System.out.println("{\n\t\"Error\": Key File has been modified \n}\n"); } String tokenBearer = req.headers().get("Authorization"); - if (tokenBearer == null){ + if (tokenBearer == null) { throw new InvalidKeyException("NO TOKEN PROVIDED"); } String[] arrOfStr = tokenBearer.split(" ", 2); - if (arrOfStr.length == 1){ + if (arrOfStr.length == 1) { throw new InvalidKeyException("NO TOKEN PROVIDED"); } String token = arrOfStr[1]; String key = managementToken; - if (inferenceRequest){ + if (inferenceRequest) { key = inferenceToken; } - if (token.equals(key)){ - if (isTokenExpired(expirationTime) && inferenceRequest){ + if (token.equals(key)) { + if (isTokenExpired(expirationTime) && inferenceRequest) { throw new KeyTimeOutException("THE CURRENT TOKEN IS EXPIRED"); } System.out.println("TOKEN AUTHORIZATION WORKED"); diff --git a/frontend/server/src/main/java/org/pytorch/serve/workflow/api/http/WorkflowInferenceRequestHandler.java b/frontend/server/src/main/java/org/pytorch/serve/workflow/api/http/WorkflowInferenceRequestHandler.java index c8339219bd..fd0f9a843b 100644 --- a/frontend/server/src/main/java/org/pytorch/serve/workflow/api/http/WorkflowInferenceRequestHandler.java +++ b/frontend/server/src/main/java/org/pytorch/serve/workflow/api/http/WorkflowInferenceRequestHandler.java @@ -20,7 +20,6 @@ import org.pytorch.serve.http.StatusResponse; import org.pytorch.serve.util.ConfigManager; import org.pytorch.serve.util.NettyUtils; -import org.pytorch.serve.util.ConfigManager; import org.pytorch.serve.util.messages.InputParameter; import org.pytorch.serve.util.messages.RequestInput; import org.pytorch.serve.wlm.WorkerInitializationException; diff --git a/frontend/server/src/main/java/org/pytorch/serve/workflow/api/http/WorkflowMgmtRequestHandler.java b/frontend/server/src/main/java/org/pytorch/serve/workflow/api/http/WorkflowMgmtRequestHandler.java index 98e4326e14..6b55c5c96f 100644 --- a/frontend/server/src/main/java/org/pytorch/serve/workflow/api/http/WorkflowMgmtRequestHandler.java +++ b/frontend/server/src/main/java/org/pytorch/serve/workflow/api/http/WorkflowMgmtRequestHandler.java @@ -24,9 +24,9 @@ import org.pytorch.serve.http.MethodNotAllowedException; import org.pytorch.serve.http.ResourceNotFoundException; import org.pytorch.serve.http.StatusResponse; +import org.pytorch.serve.util.ConfigManager; import org.pytorch.serve.util.JsonUtils; import org.pytorch.serve.util.NettyUtils; -import org.pytorch.serve.util.ConfigManager; import org.pytorch.serve.wlm.WorkerInitializationException; import org.pytorch.serve.workflow.WorkflowManager; import org.pytorch.serve.workflow.messages.DescribeWorkflowResponse; From d50beaeb3de24b73e80308b403d2d26dedc6354b Mon Sep 17 00:00:00 2001 From: udaij12 Date: Wed, 3 Jan 2024 11:00:45 -0800 Subject: [PATCH 05/28] add key file generation at default --- .../java/org/pytorch/serve/ModelServer.java | 4 ++ .../org/pytorch/serve/util/ConfigManager.java | 39 +++++++++++++++++-- .../pytorch/serve/plugins/endpoint/Token.java | 36 +++++++++-------- ts/model_server.py | 3 +- 4 files changed, 60 insertions(+), 22 deletions(-) diff --git a/frontend/server/src/main/java/org/pytorch/serve/ModelServer.java b/frontend/server/src/main/java/org/pytorch/serve/ModelServer.java index b7c9a2823d..ec5c06b8be 100644 --- a/frontend/server/src/main/java/org/pytorch/serve/ModelServer.java +++ b/frontend/server/src/main/java/org/pytorch/serve/ModelServer.java @@ -83,6 +83,10 @@ public static void main(String[] args) { ConfigManager.init(arguments); ConfigManager configManager = ConfigManager.getInstance(); PluginsManager.getInstance().initialize(); + Map plugins = PluginsManager.getInstance().getInferenceEndpoints(); + if (plugins.containsKey("token")){ + configManager.generateKeyFile(); + } MetricCache.init(); InternalLoggerFactory.setDefaultFactory(Slf4JLoggerFactory.INSTANCE); ModelServer modelServer = new ModelServer(configManager); diff --git a/frontend/server/src/main/java/org/pytorch/serve/util/ConfigManager.java b/frontend/server/src/main/java/org/pytorch/serve/util/ConfigManager.java index e6b90fae0c..4e24df725e 100644 --- a/frontend/server/src/main/java/org/pytorch/serve/util/ConfigManager.java +++ b/frontend/server/src/main/java/org/pytorch/serve/util/ConfigManager.java @@ -22,6 +22,7 @@ import java.security.KeyFactory; import java.security.KeyStore; import java.security.PrivateKey; +import java.security.SecureRandom; import java.security.cert.Certificate; import java.security.cert.CertificateFactory; import java.security.cert.X509Certificate; @@ -40,6 +41,7 @@ import java.util.Map; import java.util.Properties; import java.util.Set; +import java.util.concurrent.TimeUnit; import java.util.regex.Matcher; import java.util.regex.Pattern; import java.util.regex.PatternSyntaxException; @@ -153,6 +155,10 @@ public final class ConfigManager { private boolean telemetryEnabled; private Logger logger = LoggerFactory.getLogger(ConfigManager.class); + private SecureRandom secureRandom = new SecureRandom(); + private Base64.Encoder baseEncoder = Base64.getUrlEncoder(); + public String keyFileLocation; + private ConfigManager(Arguments args) throws IOException { prop = new Properties(); @@ -848,13 +854,40 @@ public boolean isTokenExpired(Instant expirationTime) { return !(Instant.now().isBefore(expirationTime)); } + public String generateToken(){ + byte[] randomBytes = new byte[6]; + secureRandom.nextBytes(randomBytes); + return baseEncoder.encodeToString(randomBytes); + } + + public boolean generateKeyFile() throws IOException { + String fileData = " "; + String absoluteFilePath = getCanonicalPath(".") + "/key_file.txt"; + keyFileLocation = absoluteFilePath; + File file = new File(absoluteFilePath); + if (!file.createNewFile() && !file.exists()) { + return false; + } + Integer timeToExpiration = 30; // in minutes + fileData = + "Management Key: " + + generateToken() + + "\n" + + "Inference Key: " + + generateToken() + + " --- Expiration time: " + + Instant.now().plusSeconds(TimeUnit.MINUTES.toSeconds(timeToExpiration)) + + "\n"; + Files.write(Paths.get("key_file.txt"), fileData.getBytes()); + return true; + } + public List parseFile(File tsTokenFile) { List parsedTokens = new ArrayList<>(); try { InputStream stream = Files.newInputStream(tsTokenFile.toPath()); byte[] array = new byte[100]; stream.read(array); - // Convert byte array into string String data = new String(array); String[] arrOfData = data.split("\n", 2); String[] managementArr = arrOfData[0].split(" ", 3); @@ -873,12 +906,10 @@ public List parseFile(File tsTokenFile) { public void checkTokenAuthorization(FullHttpRequest req, boolean inferenceRequest) throws ModelException { HttpMethod method = req.method(); - // Will change to get file path rather then being set defaulty - String filePath = "/home/ubuntu/serve/key_file.txt"; + String filePath = keyFileLocation; if (filePath != null) { File tsTokenFile = new File(filePath); if (tsTokenFile.exists()) { - // try (InputStream stream = Files.newInputStream(tsTokenFile.toPath())) { List parsedTokens = parseFile(tsTokenFile); String managementToken = parsedTokens.get(0); String inferenceToken = parsedTokens.get(1); diff --git a/plugins/endpoints/src/main/java/org/pytorch/serve/plugins/endpoint/Token.java b/plugins/endpoints/src/main/java/org/pytorch/serve/plugins/endpoint/Token.java index 298fca38fd..75781e802e 100644 --- a/plugins/endpoints/src/main/java/org/pytorch/serve/plugins/endpoint/Token.java +++ b/plugins/endpoints/src/main/java/org/pytorch/serve/plugins/endpoint/Token.java @@ -4,6 +4,7 @@ import com.google.gson.annotations.SerializedName; import java.io.File; import java.io.IOException; +import java.io.InputStream; import java.nio.charset.StandardCharsets; import java.nio.file.Files; import java.nio.file.Paths; @@ -28,14 +29,10 @@ public class Token extends ModelServerEndpoint { @Override public void doGet(Request req, Response rsp, Context ctx) throws IOException { - // Properties prop = ctx.getConfig(); TokenResponse r = new TokenResponse(); if (!managementFlag) { managementFlag = true; - r.setKey(); - String output = "{\n\t\"Manager Key\": " + r.getKey() + "\n}\n"; - rsp.getOutputStream().write(output.getBytes(StandardCharsets.UTF_8)); - managementToken = r.getKey(); + managementToken = r.findManagementKey(); } String test = ""; if (r.keyFile(managementToken)) { @@ -74,19 +71,26 @@ public void setKey() { key = generateKey(); } + public String findManagementKey() { + String userDirectory = System.getProperty("user.dir"); + File file = new File(userDirectory + "/key_file.txt"); + try { + InputStream stream = Files.newInputStream(file.toPath()); + byte[] array = new byte[100]; + stream.read(array); + String data = new String(array); + String[] arrOfData = data.split("\n", 2); + String[] managementArr = arrOfData[0].split(" ", 3); + return managementArr[2]; + } catch (IOException | ArrayIndexOutOfBoundsException e) { + return null; + } + } + public boolean keyFile(String managementToken) throws IOException { - String fileSeparator = System.getProperty("file.separator"); String fileData = " "; - // Will change to get file path rather then being set defaulty - String absoluteFilePath = - fileSeparator - + "home" - + fileSeparator - + "ubuntu" - + fileSeparator - + "serve/key_file.txt"; - File file = new File(absoluteFilePath); - + String userDirectory = System.getProperty("user.dir"); + File file = new File(userDirectory + "/key_file.txt"); if (!file.createNewFile() && !file.exists()) { return false; } diff --git a/ts/model_server.py b/ts/model_server.py index 7311e0489c..51f13cad93 100644 --- a/ts/model_server.py +++ b/ts/model_server.py @@ -48,8 +48,7 @@ def start() -> None: try: parent = psutil.Process(pid) parent.terminate() - # Will change to get file path rather then being set defaulty - os.remove("/home/ubuntu/serve/key_file.txt") + os.remove(os.getcwd() + "/key_file.txt") if args.foreground: try: parent.wait(timeout=60) From d5c25e67f9463a9d062c40afd0197f6e75998438 Mon Sep 17 00:00:00 2001 From: udaij12 Date: Wed, 3 Jan 2024 13:18:51 -0800 Subject: [PATCH 06/28] fix format --- .../server/src/main/java/org/pytorch/serve/ModelServer.java | 5 +++-- .../src/main/java/org/pytorch/serve/util/ConfigManager.java | 2 +- .../main/java/org/pytorch/serve/plugins/endpoint/Token.java | 2 +- 3 files changed, 5 insertions(+), 4 deletions(-) diff --git a/frontend/server/src/main/java/org/pytorch/serve/ModelServer.java b/frontend/server/src/main/java/org/pytorch/serve/ModelServer.java index ec5c06b8be..d5ce183c0e 100644 --- a/frontend/server/src/main/java/org/pytorch/serve/ModelServer.java +++ b/frontend/server/src/main/java/org/pytorch/serve/ModelServer.java @@ -83,8 +83,9 @@ public static void main(String[] args) { ConfigManager.init(arguments); ConfigManager configManager = ConfigManager.getInstance(); PluginsManager.getInstance().initialize(); - Map plugins = PluginsManager.getInstance().getInferenceEndpoints(); - if (plugins.containsKey("token")){ + Map plugins = + PluginsManager.getInstance().getInferenceEndpoints(); + if (plugins.containsKey("token")) { configManager.generateKeyFile(); } MetricCache.init(); diff --git a/frontend/server/src/main/java/org/pytorch/serve/util/ConfigManager.java b/frontend/server/src/main/java/org/pytorch/serve/util/ConfigManager.java index 4e24df725e..344f000514 100644 --- a/frontend/server/src/main/java/org/pytorch/serve/util/ConfigManager.java +++ b/frontend/server/src/main/java/org/pytorch/serve/util/ConfigManager.java @@ -854,7 +854,7 @@ public boolean isTokenExpired(Instant expirationTime) { return !(Instant.now().isBefore(expirationTime)); } - public String generateToken(){ + public String generateToken() { byte[] randomBytes = new byte[6]; secureRandom.nextBytes(randomBytes); return baseEncoder.encodeToString(randomBytes); diff --git a/plugins/endpoints/src/main/java/org/pytorch/serve/plugins/endpoint/Token.java b/plugins/endpoints/src/main/java/org/pytorch/serve/plugins/endpoint/Token.java index 75781e802e..e8da27c92a 100644 --- a/plugins/endpoints/src/main/java/org/pytorch/serve/plugins/endpoint/Token.java +++ b/plugins/endpoints/src/main/java/org/pytorch/serve/plugins/endpoint/Token.java @@ -115,7 +115,7 @@ public String generateKey() { } public void setTokenExpiration() { - Integer time = 3; + Integer time = 30; tokenExpiration = Instant.now().plusSeconds(TimeUnit.MINUTES.toSeconds(time)); } } From 62cadc81f2b3a88ac41138ddd023b80792ff5895 Mon Sep 17 00:00:00 2001 From: udaij12 Date: Wed, 17 Jan 2024 14:09:14 -0800 Subject: [PATCH 07/28] updated token plugin --- docs/token_authorization_api.md | 39 ++++ .../java/org/pytorch/serve/ModelServer.java | 4 +- .../api/rest/InferenceRequestHandler.java | 3 +- .../api/rest/ManagementRequestHandler.java | 5 +- .../rest/PrometheusMetricsRequestHandler.java | 3 - .../org/pytorch/serve/util/ConfigManager.java | 177 +++++++-------- .../http/WorkflowInferenceRequestHandler.java | 2 - .../api/http/WorkflowMgmtRequestHandler.java | 3 - .../pytorch/serve/plugins/endpoint/Token.java | 212 ++++++++++++------ 9 files changed, 270 insertions(+), 178 deletions(-) create mode 100644 docs/token_authorization_api.md diff --git a/docs/token_authorization_api.md b/docs/token_authorization_api.md new file mode 100644 index 0000000000..a26e1153d1 --- /dev/null +++ b/docs/token_authorization_api.md @@ -0,0 +1,39 @@ +# TorchServe token authorization API + +## Customer Use +1. Enable token authorization by adding the provided plugin at start using the `--plugin-path` command. +2. Torchserve will enable token authorization if the plugin is provided. In the model server home folder a file `key_file.txt` will be generated. + 1. Example key file: + + `Management Key: aadJv_R6 --- Expiration time: 2024-01-16T22:23:32.952499Z` + + `Inference Key: poZXAlqe --- Expiration time: 2024-01-16T22:23:50.621298Z` + + `API Key: xryL_Vzs` +3. There are 3 keys and each have a different use. + 1. Management key: Used for management apis. Example: + `curl http://localhost:8081/models/densenet161 -H "Authorization: Bearer aadJv_R6"` + 2. Inference key: Used for inference apis. Example: + `curl http://127.0.0.1:8080/predictions/densenet161 -T examples/image_classifier/kitten.jpg -H "Authorization: Bearer poZXAlqe"` + 3. API key: Used for the token authorization api. Check section 4 for api use. + 4. 3 tokens allow the owner with the most flexibility in use and enables them to adapt the tokens to their use. Owners of the server can provide users with the inference token if users should not mess with models. The owner can also provide owners with the management key if owners want users to add and remove models. +4. The plugin also includes an api in order to generate a new key to replace either the management or inference key. + 1. Management Example: + `curl localhost:8081/token?type=management -H "Authorization: Bearer xryL_Vzs"` will replace the current management key in the key_file with a new one and will update the expiration time. + 2. Inference example: + `curl localhost:8081/token?type=inference -H "Authorization: Bearer xryL_Vzs"` + + Users will have to use either one of the apis above. + +5. When users shut down the server the key_file will be deleted. + + +## Customization +Torchserve offers various ways to customize the token authorization to allow owners to reach the desired result. +1. Time to expiration is set to default at 60 minutes but can be changed in the config.properties by adding `token_expiration`. Ex:`token_expiration=30` +2. The token authorization code is consolidated in the plugin and thus can be changed without impacting the frontend or end result. The only thing the user cannot change is: + 1. The urlPattern for the plugin must be 'token' and the class name must not change + 2. The `generateKeyFile`, `checkTokenAuthorization`, and `setTime` functions return type and signuture must not change. However, the code in the functions can be modified depending on user necessity. + +## Notes +1. DO NOT MODIFY THE KEY FILE. Modifying the key file might impact reading and writing to the file thus preventing new keys from properly being displayed in the file. diff --git a/frontend/server/src/main/java/org/pytorch/serve/ModelServer.java b/frontend/server/src/main/java/org/pytorch/serve/ModelServer.java index d5ce183c0e..084c57c593 100644 --- a/frontend/server/src/main/java/org/pytorch/serve/ModelServer.java +++ b/frontend/server/src/main/java/org/pytorch/serve/ModelServer.java @@ -84,9 +84,9 @@ public static void main(String[] args) { ConfigManager configManager = ConfigManager.getInstance(); PluginsManager.getInstance().initialize(); Map plugins = - PluginsManager.getInstance().getInferenceEndpoints(); + PluginsManager.getInstance().getManagementEndpoints(); if (plugins.containsKey("token")) { - configManager.generateKeyFile(); + configManager.setupTokenClass(); } MetricCache.init(); InternalLoggerFactory.setDefaultFactory(Slf4JLoggerFactory.INSTANCE); diff --git a/frontend/server/src/main/java/org/pytorch/serve/http/api/rest/InferenceRequestHandler.java b/frontend/server/src/main/java/org/pytorch/serve/http/api/rest/InferenceRequestHandler.java index 51db8fc8d4..34c0e49261 100644 --- a/frontend/server/src/main/java/org/pytorch/serve/http/api/rest/InferenceRequestHandler.java +++ b/frontend/server/src/main/java/org/pytorch/serve/http/api/rest/InferenceRequestHandler.java @@ -60,12 +60,11 @@ public void handleRequest( throws ModelException, DownloadArchiveException, WorkflowException, WorkerInitializationException { ConfigManager configManager = ConfigManager.getInstance(); + configManager.checkTokenAuthorization(req, 2); if (isInferenceReq(segments)) { if (endpointMap.getOrDefault(segments[1], null) != null) { - configManager.checkTokenAuthorization(req, false); handleCustomEndpoint(ctx, req, segments, decoder); } else { - configManager.checkTokenAuthorization(req, true); switch (segments[1]) { case "ping": Runnable r = diff --git a/frontend/server/src/main/java/org/pytorch/serve/http/api/rest/ManagementRequestHandler.java b/frontend/server/src/main/java/org/pytorch/serve/http/api/rest/ManagementRequestHandler.java index 663249c1ea..55dc8efd89 100644 --- a/frontend/server/src/main/java/org/pytorch/serve/http/api/rest/ManagementRequestHandler.java +++ b/frontend/server/src/main/java/org/pytorch/serve/http/api/rest/ManagementRequestHandler.java @@ -63,11 +63,14 @@ public void handleRequest( throws ModelException, DownloadArchiveException, WorkflowException, WorkerInitializationException { ConfigManager configManager = ConfigManager.getInstance(); - configManager.checkTokenAuthorization(req, false); if (isManagementReq(segments)) { if (endpointMap.getOrDefault(segments[1], null) != null) { + if (req.toString().contains("/token")) { + configManager.checkTokenAuthorization(req, 0); + } handleCustomEndpoint(ctx, req, segments, decoder); } else { + configManager.checkTokenAuthorization(req, 1); if (!"models".equals(segments[1])) { throw new ResourceNotFoundException(); } diff --git a/frontend/server/src/main/java/org/pytorch/serve/http/api/rest/PrometheusMetricsRequestHandler.java b/frontend/server/src/main/java/org/pytorch/serve/http/api/rest/PrometheusMetricsRequestHandler.java index 024990c7c4..9760babd46 100644 --- a/frontend/server/src/main/java/org/pytorch/serve/http/api/rest/PrometheusMetricsRequestHandler.java +++ b/frontend/server/src/main/java/org/pytorch/serve/http/api/rest/PrometheusMetricsRequestHandler.java @@ -24,7 +24,6 @@ import org.pytorch.serve.archive.model.ModelException; import org.pytorch.serve.archive.workflow.WorkflowException; import org.pytorch.serve.http.HttpRequestHandlerChain; -import org.pytorch.serve.util.ConfigManager; import org.pytorch.serve.util.NettyUtils; import org.pytorch.serve.wlm.WorkerInitializationException; import org.slf4j.Logger; @@ -48,8 +47,6 @@ public void handleRequest( String[] segments) throws ModelException, DownloadArchiveException, WorkflowException, WorkerInitializationException { - ConfigManager configManager = ConfigManager.getInstance(); - configManager.checkTokenAuthorization(req, true); if (segments.length >= 2 && "metrics".equals(segments[1])) { ByteBuf resBuf = Unpooled.directBuffer(); List params = diff --git a/frontend/server/src/main/java/org/pytorch/serve/util/ConfigManager.java b/frontend/server/src/main/java/org/pytorch/serve/util/ConfigManager.java index 344f000514..cc273dc4ac 100644 --- a/frontend/server/src/main/java/org/pytorch/serve/util/ConfigManager.java +++ b/frontend/server/src/main/java/org/pytorch/serve/util/ConfigManager.java @@ -3,13 +3,13 @@ import com.google.gson.JsonObject; import com.google.gson.reflect.TypeToken; import io.netty.handler.codec.http.FullHttpRequest; -import io.netty.handler.codec.http.HttpMethod; import io.netty.handler.ssl.SslContext; import io.netty.handler.ssl.SslContextBuilder; import io.netty.handler.ssl.util.SelfSignedCertificate; import java.io.File; import java.io.IOException; import java.io.InputStream; +import java.lang.reflect.*; import java.lang.reflect.Field; import java.lang.reflect.Type; import java.net.InetAddress; @@ -28,8 +28,6 @@ import java.security.cert.X509Certificate; import java.security.spec.InvalidKeySpecException; import java.security.spec.PKCS8EncodedKeySpec; -import java.time.DateTimeException; -import java.time.Instant; import java.util.ArrayList; import java.util.Arrays; import java.util.Base64; @@ -41,7 +39,6 @@ import java.util.Map; import java.util.Properties; import java.util.Set; -import java.util.concurrent.TimeUnit; import java.util.regex.Matcher; import java.util.regex.Pattern; import java.util.regex.PatternSyntaxException; @@ -50,7 +47,6 @@ import org.apache.commons.cli.Options; import org.apache.commons.io.IOUtils; import org.pytorch.serve.archive.model.InvalidKeyException; -import org.pytorch.serve.archive.model.KeyTimeOutException; import org.pytorch.serve.archive.model.ModelException; import org.pytorch.serve.metrics.MetricBuilder; import org.pytorch.serve.servingsdk.snapshot.SnapshotSerializer; @@ -113,6 +109,7 @@ public final class ConfigManager { private static final String TS_INITIAL_WORKER_PORT = "initial_worker_port"; private static final String TS_INITIAL_DISTRIBUTION_PORT = "initial_distribution_port"; private static final String TS_WORKFLOW_STORE = "workflow_store"; + private static final String TS_TOKEN_EXPIRATION_TIME = "token_expiration"; // minutes // Configuration which are not documented or enabled through environment variables private static final String USE_NATIVE_IO = "use_native_io"; @@ -157,7 +154,10 @@ public final class ConfigManager { private SecureRandom secureRandom = new SecureRandom(); private Base64.Encoder baseEncoder = Base64.getUrlEncoder(); - public String keyFileLocation; + private boolean tokenAuthorizationEnabled; + private Class tokenClass; + private Object tokenObject; + private Integer timeToExpiration = 60; private ConfigManager(Arguments args) throws IOException { prop = new Properties(); @@ -850,100 +850,93 @@ public boolean isSnapshotDisabled() { return snapshotDisabled; } - public boolean isTokenExpired(Instant expirationTime) { - return !(Instant.now().isBefore(expirationTime)); - } - - public String generateToken() { - byte[] randomBytes = new byte[6]; - secureRandom.nextBytes(randomBytes); - return baseEncoder.encodeToString(randomBytes); + // Imports the token class and sets the expiration time either default or custom + // calls generate key file in token api to create 3 keys and logs the result + public void setupTokenClass() { + try { + tokenClass = Class.forName("org.pytorch.serve.plugins.endpoint.Token"); + tokenObject = tokenClass.getDeclaredConstructor().newInstance(); + Method method = tokenClass.getMethod("setTime", Integer.class); + if (prop.getProperty(TS_TOKEN_EXPIRATION_TIME) != null) { + timeToExpiration = Integer.valueOf(prop.getProperty(TS_TOKEN_EXPIRATION_TIME)); + } + method.invoke(tokenObject, timeToExpiration); + method = tokenClass.getMethod("generateKeyFile", Integer.class); + if ((boolean) method.invoke(tokenObject, 0)) { + System.out.println("TOKEN CLASS IMPORTED SUCCESSFULLY"); + dumpKeyLogs(); + } + } catch (ClassNotFoundException e) { + e.printStackTrace(); + } catch (NoSuchMethodException + | IllegalAccessException + | InstantiationException + | InvocationTargetException e) { + e.printStackTrace(); + } + tokenAuthorizationEnabled = true; } - public boolean generateKeyFile() throws IOException { - String fileData = " "; - String absoluteFilePath = getCanonicalPath(".") + "/key_file.txt"; - keyFileLocation = absoluteFilePath; - File file = new File(absoluteFilePath); - if (!file.createNewFile() && !file.exists()) { - return false; - } - Integer timeToExpiration = 30; // in minutes - fileData = - "Management Key: " - + generateToken() - + "\n" - + "Inference Key: " - + generateToken() - + " --- Expiration time: " - + Instant.now().plusSeconds(TimeUnit.MINUTES.toSeconds(timeToExpiration)) - + "\n"; - Files.write(Paths.get("key_file.txt"), fileData.getBytes()); - return true; - } - - public List parseFile(File tsTokenFile) { - List parsedTokens = new ArrayList<>(); + public void dumpKeyLogs() { + String managementKey = ""; + String inferenceKey = ""; + String apiKey = ""; try { - InputStream stream = Files.newInputStream(tsTokenFile.toPath()); - byte[] array = new byte[100]; - stream.read(array); - String data = new String(array); - String[] arrOfData = data.split("\n", 2); - String[] managementArr = arrOfData[0].split(" ", 3); - String[] inferenceArr = arrOfData[1].split(" ", 7); - parsedTokens.add(managementArr[2]); - parsedTokens.add(inferenceArr[2]); - String[] expirationArr = inferenceArr[6].split("\n", 2); - parsedTokens.add(expirationArr[0]); - } catch (IOException | ArrayIndexOutOfBoundsException e) { - System.out.println("Unable to read key file or key file has been modified"); - return null; + Method method = tokenClass.getMethod("getManagementKey"); + managementKey = (String) method.invoke(tokenObject); + method = tokenClass.getMethod("getInferenceKey"); + inferenceKey = (String) method.invoke(tokenObject); + method = tokenClass.getMethod("getKey"); + apiKey = (String) method.invoke(tokenObject); + } catch (NoSuchMethodException | IllegalAccessException | InvocationTargetException e) { + e.printStackTrace(); } - return parsedTokens; - } - public void checkTokenAuthorization(FullHttpRequest req, boolean inferenceRequest) + logger.info("KEY FILE PATH: " + System.getProperty("user.dir") + "/key_file.txt"); + logger.info("MANAGEMENT KEY: " + managementKey); + logger.info("INFERNCE KEY: " + inferenceKey); + logger.info("API KEY: " + apiKey); + logger.info( + "MANAGEMENT API Example: curl http://localhost:8081/models/ -H \"Authorization: Bearer " + + managementKey + + "\""); + logger.info( + "INFERNCE API Example: curl http://127.0.0.1:8080/predictions/ -T -H \"Authorization: Bearer " + + inferenceKey + + "\""); + logger.info( + "API API Example: curl localhost:8081/token?type=management -H \"Authorization: Bearer " + + apiKey + + "\""); + } + + public boolean isTokenEnabled() { + return tokenAuthorizationEnabled; + } + + // Calls the checkTokenAuthorization function in the token plugin + // expects two inputs: the fullhttpRequest and an integer which is associated with the type + // 0: token api + // 1: management api + // 2: inference api + public void checkTokenAuthorization(FullHttpRequest req, Integer requestType) throws ModelException { - HttpMethod method = req.method(); - String filePath = keyFileLocation; - if (filePath != null) { - File tsTokenFile = new File(filePath); - if (tsTokenFile.exists()) { - List parsedTokens = parseFile(tsTokenFile); - String managementToken = parsedTokens.get(0); - String inferenceToken = parsedTokens.get(1); - Instant expirationTime = Instant.now(); - try { - expirationTime = Instant.parse(parsedTokens.get(2)); - } catch (DateTimeException e) { - e.printStackTrace(); - System.out.println("{\n\t\"Error\": Key File has been modified \n}\n"); - } - String tokenBearer = req.headers().get("Authorization"); - if (tokenBearer == null) { - throw new InvalidKeyException("NO TOKEN PROVIDED"); - } - String[] arrOfStr = tokenBearer.split(" ", 2); - if (arrOfStr.length == 1) { - throw new InvalidKeyException("NO TOKEN PROVIDED"); - } - String token = arrOfStr[1]; - String key = managementToken; - if (inferenceRequest) { - key = inferenceToken; - } - if (token.equals(key)) { - if (isTokenExpired(expirationTime) && inferenceRequest) { - throw new KeyTimeOutException("THE CURRENT TOKEN IS EXPIRED"); - } - System.out.println("TOKEN AUTHORIZATION WORKED"); - } else { - throw new InvalidKeyException("TOKEN IS INCORRECT "); + if (tokenAuthorizationEnabled) { + try { + Method method = + tokenClass.getMethod( + "checkTokenAuthorization", + io.netty.handler.codec.http.FullHttpRequest.class, + Integer.class); + boolean result = (boolean) (method.invoke(tokenObject, req, requestType)); + if (!result) { + throw new InvalidKeyException( + "Token Authenticaation failed. Token either incorrect, expired, or not provided correctly"); } - } else { - System.out.println("TOKEN AUTHORIZATION IS NOT ENABLED"); + System.out.println("TOKEN AUTHORIZATION WORKED"); + } catch (NoSuchMethodException | IllegalAccessException | InvocationTargetException e) { + e.printStackTrace(); } } } diff --git a/frontend/server/src/main/java/org/pytorch/serve/workflow/api/http/WorkflowInferenceRequestHandler.java b/frontend/server/src/main/java/org/pytorch/serve/workflow/api/http/WorkflowInferenceRequestHandler.java index fd0f9a843b..916e723c81 100644 --- a/frontend/server/src/main/java/org/pytorch/serve/workflow/api/http/WorkflowInferenceRequestHandler.java +++ b/frontend/server/src/main/java/org/pytorch/serve/workflow/api/http/WorkflowInferenceRequestHandler.java @@ -81,8 +81,6 @@ public void handleRequest( throws ModelException, DownloadArchiveException, WorkflowException, WorkerInitializationException { - ConfigManager configManager = ConfigManager.getInstance(); - configManager.checkTokenAuthorization(req, true); if ("wfpredict".equalsIgnoreCase(segments[1])) { if (segments.length < 3) { throw new ResourceNotFoundException(); diff --git a/frontend/server/src/main/java/org/pytorch/serve/workflow/api/http/WorkflowMgmtRequestHandler.java b/frontend/server/src/main/java/org/pytorch/serve/workflow/api/http/WorkflowMgmtRequestHandler.java index 6b55c5c96f..6125de61e7 100644 --- a/frontend/server/src/main/java/org/pytorch/serve/workflow/api/http/WorkflowMgmtRequestHandler.java +++ b/frontend/server/src/main/java/org/pytorch/serve/workflow/api/http/WorkflowMgmtRequestHandler.java @@ -24,7 +24,6 @@ import org.pytorch.serve.http.MethodNotAllowedException; import org.pytorch.serve.http.ResourceNotFoundException; import org.pytorch.serve.http.StatusResponse; -import org.pytorch.serve.util.ConfigManager; import org.pytorch.serve.util.JsonUtils; import org.pytorch.serve.util.NettyUtils; import org.pytorch.serve.wlm.WorkerInitializationException; @@ -65,8 +64,6 @@ public void handleRequest( throws ModelException, DownloadArchiveException, WorkflowException, WorkerInitializationException { - ConfigManager configManager = ConfigManager.getInstance(); - configManager.checkTokenAuthorization(req, false); if (isManagementReq(segments)) { if (!"workflows".equals(segments[1])) { throw new ResourceNotFoundException(); diff --git a/plugins/endpoints/src/main/java/org/pytorch/serve/plugins/endpoint/Token.java b/plugins/endpoints/src/main/java/org/pytorch/serve/plugins/endpoint/Token.java index e8da27c92a..d2aa37b138 100644 --- a/plugins/endpoints/src/main/java/org/pytorch/serve/plugins/endpoint/Token.java +++ b/plugins/endpoints/src/main/java/org/pytorch/serve/plugins/endpoint/Token.java @@ -1,16 +1,18 @@ package org.pytorch.serve.plugins.endpoint; // import java.util.Properties; -import com.google.gson.annotations.SerializedName; +import io.netty.handler.codec.http.FullHttpRequest; +import io.netty.handler.codec.http.QueryStringDecoder; import java.io.File; import java.io.IOException; -import java.io.InputStream; import java.nio.charset.StandardCharsets; import java.nio.file.Files; import java.nio.file.Paths; import java.security.SecureRandom; import java.time.Instant; import java.util.Base64; +import java.util.List; +import java.util.Map; import java.util.concurrent.TimeUnit; import org.pytorch.serve.servingsdk.Context; import org.pytorch.serve.servingsdk.ModelServerEndpoint; @@ -21,102 +23,166 @@ @Endpoint( urlPattern = "token", - endpointType = EndpointTypes.INFERENCE, - description = "Execution parameters endpoint") + endpointType = EndpointTypes.MANAGEMENT, + description = "Token authentication endpoint") public class Token extends ModelServerEndpoint { - private boolean managementFlag; - private String managementToken; + private static String apiKey; + private static String managementKey; + private static String inferenceKey; + private static Instant managementExpirationTime; + private static Instant inferenceExpirationTime; + private static Integer timeToExpiration = 30; + private SecureRandom secureRandom = new SecureRandom(); + private Base64.Encoder baseEncoder = Base64.getUrlEncoder(); @Override public void doGet(Request req, Response rsp, Context ctx) throws IOException { - TokenResponse r = new TokenResponse(); - if (!managementFlag) { - managementFlag = true; - managementToken = r.findManagementKey(); - } + String queryResponse = parseQuery(req); String test = ""; - if (r.keyFile(managementToken)) { - test = "{\n\t\"File Updated\": successfully \n}\n"; + if ("management".equals(queryResponse)) { + generateKeyFile(1); + } else if ("inference".equals(queryResponse)) { + generateKeyFile(2); } else { - test = "{\n\t\"File\": failed \n}\n"; + test = "{\n\t\"Error\": " + queryResponse + "\n}\n"; } rsp.getOutputStream().write(test.getBytes(StandardCharsets.UTF_8)); } - /** Response for Model server endpoint */ - public static class TokenResponse { - private SecureRandom secureRandom = new SecureRandom(); - private Base64.Encoder baseEncoder = Base64.getUrlEncoder(); + // parses query and either returns "management"/"inference" or a wrong type error + public String parseQuery(Request req) { + QueryStringDecoder decoder = new QueryStringDecoder(req.getRequestURI()); + Map> parameters = decoder.parameters(); + List values = parameters.get("type"); + if (values != null && !values.isEmpty()) { + if ("management".equals(values.get(0)) || "inference".equals(values.get(0))) { + return values.get(0); + } else { + return "WRONG TYPE"; + } + } + return "NO TYPE PROVIDED"; + } - @SerializedName("Key") - private String key; + public String generateKey() { + byte[] randomBytes = new byte[6]; + secureRandom.nextBytes(randomBytes); + return baseEncoder.encodeToString(randomBytes); + } - @SerializedName("TokenExpiration") - private Instant tokenExpiration; + public Instant generateTokenExpiration(Integer time) { + return Instant.now().plusSeconds(TimeUnit.MINUTES.toSeconds(time)); + } - public TokenResponse() { - key = "test 2"; - tokenExpiration = Instant.now(); + // generates a key file with new keys depending on the parameter provided + // 0: generates all 3 keys + // 1: generates management key and keeps other 2 the same + // 2: generates inference key and keeps other 2 the same + public boolean generateKeyFile(Integer keyCase) throws IOException { + String fileData = " "; + String userDirectory = System.getProperty("user.dir") + "/key_file.txt"; + File file = new File(userDirectory); + if (!file.createNewFile() && !file.exists()) { + return false; } - - public String getKey() { - return key; + if (apiKey == null) { + apiKey = generateKey(); } - - public Instant getTokenExpiration() { - return tokenExpiration; + switch (keyCase) { + case 1: + managementKey = generateKey(); + managementExpirationTime = generateTokenExpiration(timeToExpiration); + break; + case 2: + inferenceKey = generateKey(); + inferenceExpirationTime = generateTokenExpiration(timeToExpiration); + break; + default: + managementKey = generateKey(); + inferenceKey = generateKey(); + inferenceExpirationTime = generateTokenExpiration(timeToExpiration); + managementExpirationTime = generateTokenExpiration(timeToExpiration); } - public void setKey() { - key = generateKey(); + fileData = + "Management Key: " + + managementKey + + " --- Expiration time: " + + managementExpirationTime + + "\nInference Key: " + + inferenceKey + + " --- Expiration time: " + + inferenceExpirationTime + + "\nAPI Key: " + + apiKey + + "\n"; + Files.write(Paths.get("key_file.txt"), fileData.getBytes()); + return true; + } + + // checks the token provided in the http with the saved keys depening on parameters + public boolean checkTokenAuthorization(FullHttpRequest req, Integer keyCase) { + String key; + Instant expiration; + switch (keyCase) { + case 0: + key = apiKey; + expiration = null; + break; + case 1: + key = managementKey; + expiration = managementExpirationTime; + break; + default: + key = inferenceKey; + expiration = inferenceExpirationTime; } - public String findManagementKey() { - String userDirectory = System.getProperty("user.dir"); - File file = new File(userDirectory + "/key_file.txt"); - try { - InputStream stream = Files.newInputStream(file.toPath()); - byte[] array = new byte[100]; - stream.read(array); - String data = new String(array); - String[] arrOfData = data.split("\n", 2); - String[] managementArr = arrOfData[0].split(" ", 3); - return managementArr[2]; - } catch (IOException | ArrayIndexOutOfBoundsException e) { - return null; - } + String tokenBearer = req.headers().get("Authorization"); + if (tokenBearer == null) { + return false; } + String[] arrOfStr = tokenBearer.split(" ", 2); + if (arrOfStr.length == 1) { + return false; + } + String token = arrOfStr[1]; - public boolean keyFile(String managementToken) throws IOException { - String fileData = " "; - String userDirectory = System.getProperty("user.dir"); - File file = new File(userDirectory + "/key_file.txt"); - if (!file.createNewFile() && !file.exists()) { + if (token.equals(key)) { + if (expiration != null && isTokenExpired(expiration)) { return false; } - fileData = - "Management Key: " - + managementToken - + "\n" - + "Inference Key: " - + generateKey() - + " --- Expiration time: " - + tokenExpiration.toString() - + "\n"; - Files.write(Paths.get("key_file.txt"), fileData.getBytes()); - return true; + } else { + return false; } + return true; + } - public String generateKey() { - byte[] randomBytes = new byte[6]; - secureRandom.nextBytes(randomBytes); - setTokenExpiration(); - return baseEncoder.encodeToString(randomBytes); - } + public boolean isTokenExpired(Instant expirationTime) { + return !(Instant.now().isBefore(expirationTime)); + } - public void setTokenExpiration() { - Integer time = 30; - tokenExpiration = Instant.now().plusSeconds(TimeUnit.MINUTES.toSeconds(time)); - } + public String getManagementKey() { + return managementKey; + } + + public String getInferenceKey() { + return inferenceKey; + } + + public String getKey() { + return apiKey; + } + + public Instant getInferenceExpirationTime() { + return inferenceExpirationTime; + } + + public Instant getManagementExpirationTime() { + return managementExpirationTime; + } + + public void setTime(Integer time) { + timeToExpiration = time; } } From 129b65271296bee80ec580dc6412fd82b76904ec Mon Sep 17 00:00:00 2001 From: udaij12 Date: Wed, 17 Jan 2024 15:06:49 -0800 Subject: [PATCH 08/28] fixed file delete --- docs/token_authorization_api.md | 2 +- .../archive/model/InvalidKeyException.java | 32 ------------------- .../archive/model/KeyTimeOutException.java | 32 ------------------- ts/model_server.py | 5 ++- 4 files changed, 5 insertions(+), 66 deletions(-) delete mode 100644 frontend/archive/src/main/java/org/pytorch/serve/archive/model/InvalidKeyException.java delete mode 100644 frontend/archive/src/main/java/org/pytorch/serve/archive/model/KeyTimeOutException.java diff --git a/docs/token_authorization_api.md b/docs/token_authorization_api.md index a26e1153d1..5ac1c0472b 100644 --- a/docs/token_authorization_api.md +++ b/docs/token_authorization_api.md @@ -33,7 +33,7 @@ Torchserve offers various ways to customize the token authorization to allow own 1. Time to expiration is set to default at 60 minutes but can be changed in the config.properties by adding `token_expiration`. Ex:`token_expiration=30` 2. The token authorization code is consolidated in the plugin and thus can be changed without impacting the frontend or end result. The only thing the user cannot change is: 1. The urlPattern for the plugin must be 'token' and the class name must not change - 2. The `generateKeyFile`, `checkTokenAuthorization`, and `setTime` functions return type and signuture must not change. However, the code in the functions can be modified depending on user necessity. + 2. The `generateKeyFile`, `checkTokenAuthorization`, and `setTime` functions return type and signature must not change. However, the code in the functions can be modified depending on user necessity. ## Notes 1. DO NOT MODIFY THE KEY FILE. Modifying the key file might impact reading and writing to the file thus preventing new keys from properly being displayed in the file. diff --git a/frontend/archive/src/main/java/org/pytorch/serve/archive/model/InvalidKeyException.java b/frontend/archive/src/main/java/org/pytorch/serve/archive/model/InvalidKeyException.java deleted file mode 100644 index 1045264e3e..0000000000 --- a/frontend/archive/src/main/java/org/pytorch/serve/archive/model/InvalidKeyException.java +++ /dev/null @@ -1,32 +0,0 @@ -package org.pytorch.serve.archive.model; - -public class InvalidKeyException extends ModelException { - - private static final long serialVersionUID = 1L; - - /** - * Constructs an {@code InvalidKeyException} with the specified detail message. - * - * @param message The detail message (which is saved for later retrieval by the {@link - * #getMessage()} method) - */ - public InvalidKeyException(String message) { - super(message); - } - - /** - * Constructs an {@code InvalidKeyException} with the specified detail message and cause. - * - *

Note that the detail message associated with {@code cause} is not automatically - * incorporated into this exception's detail message. - * - * @param message The detail message (which is saved for later retrieval by the {@link - * #getMessage()} method) - * @param cause The cause (which is saved for later retrieval by the {@link #getCause()} - * method). (A null value is permitted, and indicates that the cause is nonexistent or - * unknown.) - */ - public InvalidKeyException(String message, Throwable cause) { - super(message, cause); - } -} diff --git a/frontend/archive/src/main/java/org/pytorch/serve/archive/model/KeyTimeOutException.java b/frontend/archive/src/main/java/org/pytorch/serve/archive/model/KeyTimeOutException.java deleted file mode 100644 index ee9da9c3bb..0000000000 --- a/frontend/archive/src/main/java/org/pytorch/serve/archive/model/KeyTimeOutException.java +++ /dev/null @@ -1,32 +0,0 @@ -package org.pytorch.serve.archive.model; - -public class KeyTimeOutException extends ModelException { - - private static final long serialVersionUID = 1L; - - /** - * Constructs an {@code KeyTimeOutException} with the specified detail message. - * - * @param message The detail message (which is saved for later retrieval by the {@link - * #getMessage()} method) - */ - public KeyTimeOutException(String message) { - super(message); - } - - /** - * Constructs an {@code KeyTimeOutException} with the specified detail message and cause. - * - *

Note that the detail message associated with {@code cause} is not automatically - * incorporated into this exception's detail message. - * - * @param message The detail message (which is saved for later retrieval by the {@link - * #getMessage()} method) - * @param cause The cause (which is saved for later retrieval by the {@link #getCause()} - * method). (A null value is permitted, and indicates that the cause is nonexistent or - * unknown.) - */ - public KeyTimeOutException(String message, Throwable cause) { - super(message, cause); - } -} diff --git a/ts/model_server.py b/ts/model_server.py index 51f13cad93..b2c5fa3948 100644 --- a/ts/model_server.py +++ b/ts/model_server.py @@ -48,7 +48,10 @@ def start() -> None: try: parent = psutil.Process(pid) parent.terminate() - os.remove(os.getcwd() + "/key_file.txt") + try: + os.remove(os.getcwd() + "/key_file.txt") + except FileNotFoundError: + print("Token authorization not enabled") if args.foreground: try: parent.wait(timeout=60) From cecc0865654d33cc18f774eab69c042de78ad68a Mon Sep 17 00:00:00 2001 From: udaij12 Date: Wed, 17 Jan 2024 15:23:25 -0800 Subject: [PATCH 09/28] fixed imports --- .../org/pytorch/serve/util/ConfigManager.java | 62 ++++++++----------- 1 file changed, 26 insertions(+), 36 deletions(-) diff --git a/frontend/server/src/main/java/org/pytorch/serve/util/ConfigManager.java b/frontend/server/src/main/java/org/pytorch/serve/util/ConfigManager.java index cc273dc4ac..ec3d83a70f 100644 --- a/frontend/server/src/main/java/org/pytorch/serve/util/ConfigManager.java +++ b/frontend/server/src/main/java/org/pytorch/serve/util/ConfigManager.java @@ -3,13 +3,13 @@ import com.google.gson.JsonObject; import com.google.gson.reflect.TypeToken; import io.netty.handler.codec.http.FullHttpRequest; +import io.netty.handler.codec.http.HttpMethod; import io.netty.handler.ssl.SslContext; import io.netty.handler.ssl.SslContextBuilder; import io.netty.handler.ssl.util.SelfSignedCertificate; import java.io.File; import java.io.IOException; import java.io.InputStream; -import java.lang.reflect.*; import java.lang.reflect.Field; import java.lang.reflect.Type; import java.net.InetAddress; @@ -28,6 +28,8 @@ import java.security.cert.X509Certificate; import java.security.spec.InvalidKeySpecException; import java.security.spec.PKCS8EncodedKeySpec; +import java.time.DateTimeException; +import java.time.Instant; import java.util.ArrayList; import java.util.Arrays; import java.util.Base64; @@ -39,6 +41,7 @@ import java.util.Map; import java.util.Properties; import java.util.Set; +import java.util.concurrent.TimeUnit; import java.util.regex.Matcher; import java.util.regex.Pattern; import java.util.regex.PatternSyntaxException; @@ -46,13 +49,13 @@ import org.apache.commons.cli.Option; import org.apache.commons.cli.Options; import org.apache.commons.io.IOUtils; -import org.pytorch.serve.archive.model.InvalidKeyException; import org.pytorch.serve.archive.model.ModelException; import org.pytorch.serve.metrics.MetricBuilder; import org.pytorch.serve.servingsdk.snapshot.SnapshotSerializer; import org.pytorch.serve.snapshot.SnapshotSerializerFactory; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import java.lang.reflect.*; public final class ConfigManager { // Variables that can be configured through config.properties and Environment Variables @@ -159,6 +162,7 @@ public final class ConfigManager { private Object tokenObject; private Integer timeToExpiration = 60; + private ConfigManager(Arguments args) throws IOException { prop = new Properties(); @@ -850,14 +854,15 @@ public boolean isSnapshotDisabled() { return snapshotDisabled; } + // Imports the token class and sets the expiration time either default or custom // calls generate key file in token api to create 3 keys and logs the result - public void setupTokenClass() { + public void setupTokenClass(){ try { tokenClass = Class.forName("org.pytorch.serve.plugins.endpoint.Token"); tokenObject = tokenClass.getDeclaredConstructor().newInstance(); Method method = tokenClass.getMethod("setTime", Integer.class); - if (prop.getProperty(TS_TOKEN_EXPIRATION_TIME) != null) { + if (prop.getProperty(TS_TOKEN_EXPIRATION_TIME) != null){ timeToExpiration = Integer.valueOf(prop.getProperty(TS_TOKEN_EXPIRATION_TIME)); } method.invoke(tokenObject, timeToExpiration); @@ -866,18 +871,16 @@ public void setupTokenClass() { System.out.println("TOKEN CLASS IMPORTED SUCCESSFULLY"); dumpKeyLogs(); } - } catch (ClassNotFoundException e) { + } + catch (ClassNotFoundException e) { e.printStackTrace(); - } catch (NoSuchMethodException - | IllegalAccessException - | InstantiationException - | InvocationTargetException e) { + } catch (NoSuchMethodException | IllegalAccessException | InstantiationException | InvocationTargetException e) { e.printStackTrace(); } tokenAuthorizationEnabled = true; } - public void dumpKeyLogs() { + public void dumpKeyLogs(){ String managementKey = ""; String inferenceKey = ""; String apiKey = ""; @@ -896,21 +899,13 @@ public void dumpKeyLogs() { logger.info("MANAGEMENT KEY: " + managementKey); logger.info("INFERNCE KEY: " + inferenceKey); logger.info("API KEY: " + apiKey); - logger.info( - "MANAGEMENT API Example: curl http://localhost:8081/models/ -H \"Authorization: Bearer " - + managementKey - + "\""); - logger.info( - "INFERNCE API Example: curl http://127.0.0.1:8080/predictions/ -T -H \"Authorization: Bearer " - + inferenceKey - + "\""); - logger.info( - "API API Example: curl localhost:8081/token?type=management -H \"Authorization: Bearer " - + apiKey - + "\""); - } - - public boolean isTokenEnabled() { + logger.info("MANAGEMENT API Example: curl http://localhost:8081/models/ -H \"Authorization: Bearer " + managementKey+"\""); + logger.info("INFERNCE API Example: curl http://127.0.0.1:8080/predictions/ -T -H \"Authorization: Bearer " + inferenceKey + "\""); + logger.info("API API Example: curl localhost:8081/token?type=management -H \"Authorization: Bearer " + apiKey + "\""); + + } + + public boolean isTokenEnabled(){ return tokenAuthorizationEnabled; } @@ -922,20 +917,15 @@ public boolean isTokenEnabled() { public void checkTokenAuthorization(FullHttpRequest req, Integer requestType) throws ModelException { - if (tokenAuthorizationEnabled) { + if (tokenAuthorizationEnabled){ try { - Method method = - tokenClass.getMethod( - "checkTokenAuthorization", - io.netty.handler.codec.http.FullHttpRequest.class, - Integer.class); - boolean result = (boolean) (method.invoke(tokenObject, req, requestType)); - if (!result) { - throw new InvalidKeyException( - "Token Authenticaation failed. Token either incorrect, expired, or not provided correctly"); + Method method = tokenClass.getMethod("checkTokenAuthorization", io.netty.handler.codec.http.FullHttpRequest.class, Integer.class); + boolean result = (boolean)(method.invoke(tokenObject, req, requestType)); + if (!result){ + throw new InvalidKeyException("Token Authenticaation failed. Token either incorrect, expired, or not provided correctly"); } System.out.println("TOKEN AUTHORIZATION WORKED"); - } catch (NoSuchMethodException | IllegalAccessException | InvocationTargetException e) { + }catch (NoSuchMethodException | IllegalAccessException | InvocationTargetException e) { e.printStackTrace(); } } From 68767818973ccd7f4e586b7fe920ed159566adf5 Mon Sep 17 00:00:00 2001 From: udaij12 Date: Wed, 17 Jan 2024 15:30:33 -0800 Subject: [PATCH 10/28] added custom expection --- .../archive/model/s3/InvalidKeyException.java | 32 +++++++++++++++++++ .../org/pytorch/serve/util/ConfigManager.java | 1 + 2 files changed, 33 insertions(+) create mode 100644 frontend/archive/src/main/java/org/pytorch/serve/archive/model/s3/InvalidKeyException.java diff --git a/frontend/archive/src/main/java/org/pytorch/serve/archive/model/s3/InvalidKeyException.java b/frontend/archive/src/main/java/org/pytorch/serve/archive/model/s3/InvalidKeyException.java new file mode 100644 index 0000000000..1045264e3e --- /dev/null +++ b/frontend/archive/src/main/java/org/pytorch/serve/archive/model/s3/InvalidKeyException.java @@ -0,0 +1,32 @@ +package org.pytorch.serve.archive.model; + +public class InvalidKeyException extends ModelException { + + private static final long serialVersionUID = 1L; + + /** + * Constructs an {@code InvalidKeyException} with the specified detail message. + * + * @param message The detail message (which is saved for later retrieval by the {@link + * #getMessage()} method) + */ + public InvalidKeyException(String message) { + super(message); + } + + /** + * Constructs an {@code InvalidKeyException} with the specified detail message and cause. + * + *

Note that the detail message associated with {@code cause} is not automatically + * incorporated into this exception's detail message. + * + * @param message The detail message (which is saved for later retrieval by the {@link + * #getMessage()} method) + * @param cause The cause (which is saved for later retrieval by the {@link #getCause()} + * method). (A null value is permitted, and indicates that the cause is nonexistent or + * unknown.) + */ + public InvalidKeyException(String message, Throwable cause) { + super(message, cause); + } +} diff --git a/frontend/server/src/main/java/org/pytorch/serve/util/ConfigManager.java b/frontend/server/src/main/java/org/pytorch/serve/util/ConfigManager.java index ec3d83a70f..9a351bba7a 100644 --- a/frontend/server/src/main/java/org/pytorch/serve/util/ConfigManager.java +++ b/frontend/server/src/main/java/org/pytorch/serve/util/ConfigManager.java @@ -49,6 +49,7 @@ import org.apache.commons.cli.Option; import org.apache.commons.cli.Options; import org.apache.commons.io.IOUtils; +import org.pytorch.serve.archive.model.InvalidKeyException; import org.pytorch.serve.archive.model.ModelException; import org.pytorch.serve.metrics.MetricBuilder; import org.pytorch.serve.servingsdk.snapshot.SnapshotSerializer; From 4848114adfbae27e71c6c1b218731f580f386ccd Mon Sep 17 00:00:00 2001 From: udaij12 Date: Wed, 17 Jan 2024 16:12:47 -0800 Subject: [PATCH 11/28] fix format --- .../org/pytorch/serve/util/ConfigManager.java | 61 +++++++++++-------- 1 file changed, 35 insertions(+), 26 deletions(-) diff --git a/frontend/server/src/main/java/org/pytorch/serve/util/ConfigManager.java b/frontend/server/src/main/java/org/pytorch/serve/util/ConfigManager.java index 9a351bba7a..cc273dc4ac 100644 --- a/frontend/server/src/main/java/org/pytorch/serve/util/ConfigManager.java +++ b/frontend/server/src/main/java/org/pytorch/serve/util/ConfigManager.java @@ -3,13 +3,13 @@ import com.google.gson.JsonObject; import com.google.gson.reflect.TypeToken; import io.netty.handler.codec.http.FullHttpRequest; -import io.netty.handler.codec.http.HttpMethod; import io.netty.handler.ssl.SslContext; import io.netty.handler.ssl.SslContextBuilder; import io.netty.handler.ssl.util.SelfSignedCertificate; import java.io.File; import java.io.IOException; import java.io.InputStream; +import java.lang.reflect.*; import java.lang.reflect.Field; import java.lang.reflect.Type; import java.net.InetAddress; @@ -28,8 +28,6 @@ import java.security.cert.X509Certificate; import java.security.spec.InvalidKeySpecException; import java.security.spec.PKCS8EncodedKeySpec; -import java.time.DateTimeException; -import java.time.Instant; import java.util.ArrayList; import java.util.Arrays; import java.util.Base64; @@ -41,7 +39,6 @@ import java.util.Map; import java.util.Properties; import java.util.Set; -import java.util.concurrent.TimeUnit; import java.util.regex.Matcher; import java.util.regex.Pattern; import java.util.regex.PatternSyntaxException; @@ -56,7 +53,6 @@ import org.pytorch.serve.snapshot.SnapshotSerializerFactory; import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import java.lang.reflect.*; public final class ConfigManager { // Variables that can be configured through config.properties and Environment Variables @@ -163,7 +159,6 @@ public final class ConfigManager { private Object tokenObject; private Integer timeToExpiration = 60; - private ConfigManager(Arguments args) throws IOException { prop = new Properties(); @@ -855,15 +850,14 @@ public boolean isSnapshotDisabled() { return snapshotDisabled; } - // Imports the token class and sets the expiration time either default or custom // calls generate key file in token api to create 3 keys and logs the result - public void setupTokenClass(){ + public void setupTokenClass() { try { tokenClass = Class.forName("org.pytorch.serve.plugins.endpoint.Token"); tokenObject = tokenClass.getDeclaredConstructor().newInstance(); Method method = tokenClass.getMethod("setTime", Integer.class); - if (prop.getProperty(TS_TOKEN_EXPIRATION_TIME) != null){ + if (prop.getProperty(TS_TOKEN_EXPIRATION_TIME) != null) { timeToExpiration = Integer.valueOf(prop.getProperty(TS_TOKEN_EXPIRATION_TIME)); } method.invoke(tokenObject, timeToExpiration); @@ -872,16 +866,18 @@ public void setupTokenClass(){ System.out.println("TOKEN CLASS IMPORTED SUCCESSFULLY"); dumpKeyLogs(); } - } - catch (ClassNotFoundException e) { + } catch (ClassNotFoundException e) { e.printStackTrace(); - } catch (NoSuchMethodException | IllegalAccessException | InstantiationException | InvocationTargetException e) { + } catch (NoSuchMethodException + | IllegalAccessException + | InstantiationException + | InvocationTargetException e) { e.printStackTrace(); } tokenAuthorizationEnabled = true; } - public void dumpKeyLogs(){ + public void dumpKeyLogs() { String managementKey = ""; String inferenceKey = ""; String apiKey = ""; @@ -900,13 +896,21 @@ public void dumpKeyLogs(){ logger.info("MANAGEMENT KEY: " + managementKey); logger.info("INFERNCE KEY: " + inferenceKey); logger.info("API KEY: " + apiKey); - logger.info("MANAGEMENT API Example: curl http://localhost:8081/models/ -H \"Authorization: Bearer " + managementKey+"\""); - logger.info("INFERNCE API Example: curl http://127.0.0.1:8080/predictions/ -T -H \"Authorization: Bearer " + inferenceKey + "\""); - logger.info("API API Example: curl localhost:8081/token?type=management -H \"Authorization: Bearer " + apiKey + "\""); - - } - - public boolean isTokenEnabled(){ + logger.info( + "MANAGEMENT API Example: curl http://localhost:8081/models/ -H \"Authorization: Bearer " + + managementKey + + "\""); + logger.info( + "INFERNCE API Example: curl http://127.0.0.1:8080/predictions/ -T -H \"Authorization: Bearer " + + inferenceKey + + "\""); + logger.info( + "API API Example: curl localhost:8081/token?type=management -H \"Authorization: Bearer " + + apiKey + + "\""); + } + + public boolean isTokenEnabled() { return tokenAuthorizationEnabled; } @@ -918,15 +922,20 @@ public boolean isTokenEnabled(){ public void checkTokenAuthorization(FullHttpRequest req, Integer requestType) throws ModelException { - if (tokenAuthorizationEnabled){ + if (tokenAuthorizationEnabled) { try { - Method method = tokenClass.getMethod("checkTokenAuthorization", io.netty.handler.codec.http.FullHttpRequest.class, Integer.class); - boolean result = (boolean)(method.invoke(tokenObject, req, requestType)); - if (!result){ - throw new InvalidKeyException("Token Authenticaation failed. Token either incorrect, expired, or not provided correctly"); + Method method = + tokenClass.getMethod( + "checkTokenAuthorization", + io.netty.handler.codec.http.FullHttpRequest.class, + Integer.class); + boolean result = (boolean) (method.invoke(tokenObject, req, requestType)); + if (!result) { + throw new InvalidKeyException( + "Token Authenticaation failed. Token either incorrect, expired, or not provided correctly"); } System.out.println("TOKEN AUTHORIZATION WORKED"); - }catch (NoSuchMethodException | IllegalAccessException | InvocationTargetException e) { + } catch (NoSuchMethodException | IllegalAccessException | InvocationTargetException e) { e.printStackTrace(); } } From fa32c282d3b91b577da17b3c1691bc3208c3b86a Mon Sep 17 00:00:00 2001 From: udaij12 Date: Tue, 30 Jan 2024 14:20:03 -0800 Subject: [PATCH 12/28] token handler --- docs/token_authorization_api.md | 14 +-- .../model/{s3 => }/InvalidKeyException.java | 0 .../java/org/pytorch/serve/ModelServer.java | 5 - .../org/pytorch/serve/ServerInitializer.java | 8 ++ .../serve/http/TokenAuthorizationHandler.java | 107 ++++++++++++++++++ .../api/rest/InferenceRequestHandler.java | 2 - .../api/rest/ManagementRequestHandler.java | 6 - .../serve/servingsdk/impl/PluginsManager.java | 4 + .../org/pytorch/serve/util/ConfigManager.java | 91 +-------------- .../org/pytorch/serve/util/TokenType.java | 7 ++ plugins/endpoints/build.gradle | 2 +- .../pytorch/serve/plugins/endpoint/Token.java | 86 +++++++++----- ts/model_server.py | 2 +- 13 files changed, 199 insertions(+), 135 deletions(-) rename frontend/archive/src/main/java/org/pytorch/serve/archive/model/{s3 => }/InvalidKeyException.java (100%) create mode 100644 frontend/server/src/main/java/org/pytorch/serve/http/TokenAuthorizationHandler.java create mode 100644 frontend/server/src/main/java/org/pytorch/serve/util/TokenType.java diff --git a/docs/token_authorization_api.md b/docs/token_authorization_api.md index 5ac1c0472b..7e2b3ce982 100644 --- a/docs/token_authorization_api.md +++ b/docs/token_authorization_api.md @@ -1,8 +1,8 @@ # TorchServe token authorization API -## Customer Use +## Configuration 1. Enable token authorization by adding the provided plugin at start using the `--plugin-path` command. -2. Torchserve will enable token authorization if the plugin is provided. In the model server home folder a file `key_file.txt` will be generated. +2. Torchserve will enable token authorization if the plugin is provided. In the current working directory a file `key_file.txt` will be generated. 1. Example key file: `Management Key: aadJv_R6 --- Expiration time: 2024-01-16T22:23:32.952499Z` @@ -11,19 +11,19 @@ `API Key: xryL_Vzs` 3. There are 3 keys and each have a different use. - 1. Management key: Used for management apis. Example: + 1. Management key: Used for management APIs. Example: `curl http://localhost:8081/models/densenet161 -H "Authorization: Bearer aadJv_R6"` - 2. Inference key: Used for inference apis. Example: + 2. Inference key: Used for inference APIs. Example: `curl http://127.0.0.1:8080/predictions/densenet161 -T examples/image_classifier/kitten.jpg -H "Authorization: Bearer poZXAlqe"` - 3. API key: Used for the token authorization api. Check section 4 for api use. + 3. API key: Used for the token authorization API. Check section 4 for API use. 4. 3 tokens allow the owner with the most flexibility in use and enables them to adapt the tokens to their use. Owners of the server can provide users with the inference token if users should not mess with models. The owner can also provide owners with the management key if owners want users to add and remove models. -4. The plugin also includes an api in order to generate a new key to replace either the management or inference key. +4. The plugin also includes an API in order to generate a new key to replace either the management or inference key. 1. Management Example: `curl localhost:8081/token?type=management -H "Authorization: Bearer xryL_Vzs"` will replace the current management key in the key_file with a new one and will update the expiration time. 2. Inference example: `curl localhost:8081/token?type=inference -H "Authorization: Bearer xryL_Vzs"` - Users will have to use either one of the apis above. + Users will have to use either one of the APIs above. 5. When users shut down the server the key_file will be deleted. diff --git a/frontend/archive/src/main/java/org/pytorch/serve/archive/model/s3/InvalidKeyException.java b/frontend/archive/src/main/java/org/pytorch/serve/archive/model/InvalidKeyException.java similarity index 100% rename from frontend/archive/src/main/java/org/pytorch/serve/archive/model/s3/InvalidKeyException.java rename to frontend/archive/src/main/java/org/pytorch/serve/archive/model/InvalidKeyException.java diff --git a/frontend/server/src/main/java/org/pytorch/serve/ModelServer.java b/frontend/server/src/main/java/org/pytorch/serve/ModelServer.java index 084c57c593..b7c9a2823d 100644 --- a/frontend/server/src/main/java/org/pytorch/serve/ModelServer.java +++ b/frontend/server/src/main/java/org/pytorch/serve/ModelServer.java @@ -83,11 +83,6 @@ public static void main(String[] args) { ConfigManager.init(arguments); ConfigManager configManager = ConfigManager.getInstance(); PluginsManager.getInstance().initialize(); - Map plugins = - PluginsManager.getInstance().getManagementEndpoints(); - if (plugins.containsKey("token")) { - configManager.setupTokenClass(); - } MetricCache.init(); InternalLoggerFactory.setDefaultFactory(Slf4JLoggerFactory.INSTANCE); ModelServer modelServer = new ModelServer(configManager); diff --git a/frontend/server/src/main/java/org/pytorch/serve/ServerInitializer.java b/frontend/server/src/main/java/org/pytorch/serve/ServerInitializer.java index 51c5c1787a..e69806dcef 100644 --- a/frontend/server/src/main/java/org/pytorch/serve/ServerInitializer.java +++ b/frontend/server/src/main/java/org/pytorch/serve/ServerInitializer.java @@ -10,6 +10,7 @@ import org.pytorch.serve.http.HttpRequestHandler; import org.pytorch.serve.http.HttpRequestHandlerChain; import org.pytorch.serve.http.InvalidRequestHandler; +import org.pytorch.serve.http.TokenAuthorizationHandler; import org.pytorch.serve.http.api.rest.ApiDescriptionRequestHandler; import org.pytorch.serve.http.api.rest.InferenceRequestHandler; import org.pytorch.serve.http.api.rest.ManagementRequestHandler; @@ -17,6 +18,7 @@ import org.pytorch.serve.servingsdk.impl.PluginsManager; import org.pytorch.serve.util.ConfigManager; import org.pytorch.serve.util.ConnectorType; +import org.pytorch.serve.util.TokenType; import org.pytorch.serve.workflow.api.http.WorkflowInferenceRequestHandler; import org.pytorch.serve.workflow.api.http.WorkflowMgmtRequestHandler; @@ -59,6 +61,9 @@ public void initChannel(Channel ch) { HttpRequestHandlerChain httpRequestHandlerChain = apiDescriptionRequestHandler; if (ConnectorType.ALL.equals(connectorType) || ConnectorType.INFERENCE_CONNECTOR.equals(connectorType)) { + httpRequestHandlerChain = + httpRequestHandlerChain.setNextHandler( + new TokenAuthorizationHandler(TokenType.INFERENCE)); httpRequestHandlerChain = httpRequestHandlerChain.setNextHandler( new InferenceRequestHandler( @@ -68,6 +73,9 @@ public void initChannel(Channel ch) { } if (ConnectorType.ALL.equals(connectorType) || ConnectorType.MANAGEMENT_CONNECTOR.equals(connectorType)) { + httpRequestHandlerChain = + httpRequestHandlerChain.setNextHandler( + new TokenAuthorizationHandler(TokenType.MANAGEMENT)); httpRequestHandlerChain = httpRequestHandlerChain.setNextHandler( new ManagementRequestHandler( diff --git a/frontend/server/src/main/java/org/pytorch/serve/http/TokenAuthorizationHandler.java b/frontend/server/src/main/java/org/pytorch/serve/http/TokenAuthorizationHandler.java new file mode 100644 index 0000000000..62fbdc1f67 --- /dev/null +++ b/frontend/server/src/main/java/org/pytorch/serve/http/TokenAuthorizationHandler.java @@ -0,0 +1,107 @@ +package org.pytorch.serve.http; + +import io.netty.channel.ChannelHandlerContext; +import io.netty.handler.codec.http.FullHttpRequest; +import io.netty.handler.codec.http.QueryStringDecoder; +import java.lang.reflect.*; +import org.pytorch.serve.archive.DownloadArchiveException; +import org.pytorch.serve.archive.model.InvalidKeyException; +import org.pytorch.serve.archive.model.ModelException; +import org.pytorch.serve.archive.workflow.WorkflowException; +import org.pytorch.serve.util.ConfigManager; +import org.pytorch.serve.util.TokenType; +import org.pytorch.serve.wlm.WorkerInitializationException; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * A class handling inbound HTTP requests to the inference API. + * + *

This class // + */ +public class TokenAuthorizationHandler extends HttpRequestHandlerChain { + + private static final Logger logger = LoggerFactory.getLogger(TokenAuthorizationHandler.class); + private static TokenType tokenType; + private static Boolean tokenEnabled; + private static Class tokenClass; + private static Object tokenObject; + private static Integer timeToExpirationMinutes = 60; + + /** Creates a new {@code InferenceRequestHandler} instance. */ + public TokenAuthorizationHandler(TokenType type) { + tokenType = type; + } + + @Override + public void handleRequest( + ChannelHandlerContext ctx, + FullHttpRequest req, + QueryStringDecoder decoder, + String[] segments) + throws ModelException, DownloadArchiveException, WorkflowException, + WorkerInitializationException { + ConfigManager configManager = ConfigManager.getInstance(); + if (tokenType == TokenType.MANAGEMENT) { + if (req.toString().contains("/token")) { + checkTokenAuthorization(req, 0); + } else { + checkTokenAuthorization(req, 1); + } + } else if (tokenType == TokenType.INFERENCE) { + checkTokenAuthorization(req, 2); + } + chain.handleRequest(ctx, req, decoder, segments); + } + + public static void setupTokenClass() { + try { + tokenClass = Class.forName("org.pytorch.serve.plugins.endpoint.Token"); + tokenObject = tokenClass.getDeclaredConstructor().newInstance(); + Method method = tokenClass.getMethod("setTime", Integer.class); + Integer time = ConfigManager.getInstance().getTimeToExpiration(); + if (time == 0) { + timeToExpirationMinutes = time; + } + method.invoke(tokenObject, timeToExpirationMinutes); + method = tokenClass.getMethod("generateKeyFile", Integer.class); + if ((boolean) method.invoke(tokenObject, 0)) { + logger.info("TOKEN CLASS IMPORTED SUCCESSFULLY"); + } + } catch (ClassNotFoundException e) { + logger.error("TOKEN CLASS IMPORTED UNSUCCESSFULLY"); + e.printStackTrace(); + return; + } catch (NoSuchMethodException + | IllegalAccessException + | InstantiationException + | InvocationTargetException e) { + e.printStackTrace(); + logger.error("TOKEN CLASS IMPORTED UNSUCCESSFULLY"); + return; + } + tokenEnabled = true; + } + + private void checkTokenAuthorization(FullHttpRequest req, Integer type) throws ModelException { + + if (tokenEnabled) { + try { + Method method = + tokenClass.getMethod( + "checkTokenAuthorization", + io.netty.handler.codec.http.FullHttpRequest.class, + Integer.class); + boolean result = (boolean) (method.invoke(tokenObject, req, type)); + if (!result) { + throw new InvalidKeyException( + "Token Authenticaation failed. Token either incorrect, expired, or not provided correctly"); + } + } catch (NoSuchMethodException | IllegalAccessException | InvocationTargetException e) { + e.printStackTrace(); + throw new InvalidKeyException( + "Token Authenticaation failed. Token either incorrect, expired, or not provided correctly"); + } + } + } +} diff --git a/frontend/server/src/main/java/org/pytorch/serve/http/api/rest/InferenceRequestHandler.java b/frontend/server/src/main/java/org/pytorch/serve/http/api/rest/InferenceRequestHandler.java index 34c0e49261..363717cb7f 100644 --- a/frontend/server/src/main/java/org/pytorch/serve/http/api/rest/InferenceRequestHandler.java +++ b/frontend/server/src/main/java/org/pytorch/serve/http/api/rest/InferenceRequestHandler.java @@ -59,8 +59,6 @@ public void handleRequest( String[] segments) throws ModelException, DownloadArchiveException, WorkflowException, WorkerInitializationException { - ConfigManager configManager = ConfigManager.getInstance(); - configManager.checkTokenAuthorization(req, 2); if (isInferenceReq(segments)) { if (endpointMap.getOrDefault(segments[1], null) != null) { handleCustomEndpoint(ctx, req, segments, decoder); diff --git a/frontend/server/src/main/java/org/pytorch/serve/http/api/rest/ManagementRequestHandler.java b/frontend/server/src/main/java/org/pytorch/serve/http/api/rest/ManagementRequestHandler.java index 55dc8efd89..29a6f156cf 100644 --- a/frontend/server/src/main/java/org/pytorch/serve/http/api/rest/ManagementRequestHandler.java +++ b/frontend/server/src/main/java/org/pytorch/serve/http/api/rest/ManagementRequestHandler.java @@ -32,7 +32,6 @@ import org.pytorch.serve.openapi.OpenApiUtils; import org.pytorch.serve.servingsdk.ModelServerEndpoint; import org.pytorch.serve.util.ApiUtils; -import org.pytorch.serve.util.ConfigManager; import org.pytorch.serve.util.JsonUtils; import org.pytorch.serve.util.NettyUtils; import org.pytorch.serve.util.messages.RequestInput; @@ -62,15 +61,10 @@ public void handleRequest( String[] segments) throws ModelException, DownloadArchiveException, WorkflowException, WorkerInitializationException { - ConfigManager configManager = ConfigManager.getInstance(); if (isManagementReq(segments)) { if (endpointMap.getOrDefault(segments[1], null) != null) { - if (req.toString().contains("/token")) { - configManager.checkTokenAuthorization(req, 0); - } handleCustomEndpoint(ctx, req, segments, decoder); } else { - configManager.checkTokenAuthorization(req, 1); if (!"models".equals(segments[1])) { throw new ResourceNotFoundException(); } diff --git a/frontend/server/src/main/java/org/pytorch/serve/servingsdk/impl/PluginsManager.java b/frontend/server/src/main/java/org/pytorch/serve/servingsdk/impl/PluginsManager.java index bac20bf32d..fdf6b01dd0 100644 --- a/frontend/server/src/main/java/org/pytorch/serve/servingsdk/impl/PluginsManager.java +++ b/frontend/server/src/main/java/org/pytorch/serve/servingsdk/impl/PluginsManager.java @@ -5,6 +5,7 @@ import java.util.Map; import java.util.ServiceLoader; import org.pytorch.serve.http.InvalidPluginException; +import org.pytorch.serve.http.TokenAuthorizationHandler; import org.pytorch.serve.servingsdk.ModelServerEndpoint; import org.pytorch.serve.servingsdk.annotations.Endpoint; import org.pytorch.serve.servingsdk.annotations.helpers.EndpointTypes; @@ -30,6 +31,9 @@ public void initialize() { logger.info("Initializing plugins manager..."); inferenceEndpoints = initInferenceEndpoints(); managementEndpoints = initManagementEndpoints(); + if (managementEndpoints.containsKey("token")) { + TokenAuthorizationHandler.setupTokenClass(); + } } private boolean validateEndpointPlugin(Annotation a, EndpointTypes type) { diff --git a/frontend/server/src/main/java/org/pytorch/serve/util/ConfigManager.java b/frontend/server/src/main/java/org/pytorch/serve/util/ConfigManager.java index cc273dc4ac..4f8a7d3404 100644 --- a/frontend/server/src/main/java/org/pytorch/serve/util/ConfigManager.java +++ b/frontend/server/src/main/java/org/pytorch/serve/util/ConfigManager.java @@ -2,7 +2,6 @@ import com.google.gson.JsonObject; import com.google.gson.reflect.TypeToken; -import io.netty.handler.codec.http.FullHttpRequest; import io.netty.handler.ssl.SslContext; import io.netty.handler.ssl.SslContextBuilder; import io.netty.handler.ssl.util.SelfSignedCertificate; @@ -46,8 +45,6 @@ import org.apache.commons.cli.Option; import org.apache.commons.cli.Options; import org.apache.commons.io.IOUtils; -import org.pytorch.serve.archive.model.InvalidKeyException; -import org.pytorch.serve.archive.model.ModelException; import org.pytorch.serve.metrics.MetricBuilder; import org.pytorch.serve.servingsdk.snapshot.SnapshotSerializer; import org.pytorch.serve.snapshot.SnapshotSerializerFactory; @@ -850,95 +847,15 @@ public boolean isSnapshotDisabled() { return snapshotDisabled; } - // Imports the token class and sets the expiration time either default or custom - // calls generate key file in token api to create 3 keys and logs the result - public void setupTokenClass() { - try { - tokenClass = Class.forName("org.pytorch.serve.plugins.endpoint.Token"); - tokenObject = tokenClass.getDeclaredConstructor().newInstance(); - Method method = tokenClass.getMethod("setTime", Integer.class); - if (prop.getProperty(TS_TOKEN_EXPIRATION_TIME) != null) { - timeToExpiration = Integer.valueOf(prop.getProperty(TS_TOKEN_EXPIRATION_TIME)); - } - method.invoke(tokenObject, timeToExpiration); - method = tokenClass.getMethod("generateKeyFile", Integer.class); - if ((boolean) method.invoke(tokenObject, 0)) { - System.out.println("TOKEN CLASS IMPORTED SUCCESSFULLY"); - dumpKeyLogs(); - } - } catch (ClassNotFoundException e) { - e.printStackTrace(); - } catch (NoSuchMethodException - | IllegalAccessException - | InstantiationException - | InvocationTargetException e) { - e.printStackTrace(); - } - tokenAuthorizationEnabled = true; - } - - public void dumpKeyLogs() { - String managementKey = ""; - String inferenceKey = ""; - String apiKey = ""; - try { - Method method = tokenClass.getMethod("getManagementKey"); - managementKey = (String) method.invoke(tokenObject); - method = tokenClass.getMethod("getInferenceKey"); - inferenceKey = (String) method.invoke(tokenObject); - method = tokenClass.getMethod("getKey"); - apiKey = (String) method.invoke(tokenObject); - } catch (NoSuchMethodException | IllegalAccessException | InvocationTargetException e) { - e.printStackTrace(); - } - - logger.info("KEY FILE PATH: " + System.getProperty("user.dir") + "/key_file.txt"); - logger.info("MANAGEMENT KEY: " + managementKey); - logger.info("INFERNCE KEY: " + inferenceKey); - logger.info("API KEY: " + apiKey); - logger.info( - "MANAGEMENT API Example: curl http://localhost:8081/models/ -H \"Authorization: Bearer " - + managementKey - + "\""); - logger.info( - "INFERNCE API Example: curl http://127.0.0.1:8080/predictions/ -T -H \"Authorization: Bearer " - + inferenceKey - + "\""); - logger.info( - "API API Example: curl localhost:8081/token?type=management -H \"Authorization: Bearer " - + apiKey - + "\""); - } - public boolean isTokenEnabled() { return tokenAuthorizationEnabled; } - // Calls the checkTokenAuthorization function in the token plugin - // expects two inputs: the fullhttpRequest and an integer which is associated with the type - // 0: token api - // 1: management api - // 2: inference api - public void checkTokenAuthorization(FullHttpRequest req, Integer requestType) - throws ModelException { - - if (tokenAuthorizationEnabled) { - try { - Method method = - tokenClass.getMethod( - "checkTokenAuthorization", - io.netty.handler.codec.http.FullHttpRequest.class, - Integer.class); - boolean result = (boolean) (method.invoke(tokenObject, req, requestType)); - if (!result) { - throw new InvalidKeyException( - "Token Authenticaation failed. Token either incorrect, expired, or not provided correctly"); - } - System.out.println("TOKEN AUTHORIZATION WORKED"); - } catch (NoSuchMethodException | IllegalAccessException | InvocationTargetException e) { - e.printStackTrace(); - } + public Integer getTimeToExpiration() { + if (prop.getProperty(TS_TOKEN_EXPIRATION_TIME) != null) { + return Integer.valueOf(prop.getProperty(TS_TOKEN_EXPIRATION_TIME)); } + return 0; } public boolean isSSLEnabled(ConnectorType connectorType) { diff --git a/frontend/server/src/main/java/org/pytorch/serve/util/TokenType.java b/frontend/server/src/main/java/org/pytorch/serve/util/TokenType.java new file mode 100644 index 0000000000..9bf6318d2e --- /dev/null +++ b/frontend/server/src/main/java/org/pytorch/serve/util/TokenType.java @@ -0,0 +1,7 @@ +package org.pytorch.serve.util; + +public enum TokenType { + INFERENCE, + MANAGEMENT, + TOKEN_API +} diff --git a/plugins/endpoints/build.gradle b/plugins/endpoints/build.gradle index 36c5191a6d..a162cc2ae8 100644 --- a/plugins/endpoints/build.gradle +++ b/plugins/endpoints/build.gradle @@ -1,6 +1,7 @@ dependencies { implementation "com.google.code.gson:gson:${gson_version}" implementation "org.pytorch:torchserve-plugins-sdk:${torchserve_sdk_version}" + implementation "io.netty:netty-all:4.1.53.Final" } project.ext{ @@ -16,4 +17,3 @@ jar { exclude "META-INF//LICENSE*" exclude "META-INF//NOTICE*" } - diff --git a/plugins/endpoints/src/main/java/org/pytorch/serve/plugins/endpoint/Token.java b/plugins/endpoints/src/main/java/org/pytorch/serve/plugins/endpoint/Token.java index d2aa37b138..e7e2f7cbe0 100644 --- a/plugins/endpoints/src/main/java/org/pytorch/serve/plugins/endpoint/Token.java +++ b/plugins/endpoints/src/main/java/org/pytorch/serve/plugins/endpoint/Token.java @@ -1,18 +1,24 @@ package org.pytorch.serve.plugins.endpoint; // import java.util.Properties; +import com.google.gson.GsonBuilder; +import com.google.gson.JsonArray; import io.netty.handler.codec.http.FullHttpRequest; import io.netty.handler.codec.http.QueryStringDecoder; import java.io.File; import java.io.IOException; import java.nio.charset.StandardCharsets; import java.nio.file.Files; +import java.nio.file.Path; import java.nio.file.Paths; +import java.nio.file.attribute.PosixFilePermission; +import java.nio.file.attribute.PosixFilePermissions; import java.security.SecureRandom; import java.time.Instant; import java.util.Base64; import java.util.List; import java.util.Map; +import java.util.Set; import java.util.concurrent.TimeUnit; import org.pytorch.serve.servingsdk.Context; import org.pytorch.serve.servingsdk.ModelServerEndpoint; @@ -21,6 +27,8 @@ import org.pytorch.serve.servingsdk.http.Request; import org.pytorch.serve.servingsdk.http.Response; +// import org.pytorch.serve.util.TokenType; + @Endpoint( urlPattern = "token", endpointType = EndpointTypes.MANAGEMENT, @@ -29,9 +37,9 @@ public class Token extends ModelServerEndpoint { private static String apiKey; private static String managementKey; private static String inferenceKey; - private static Instant managementExpirationTime; - private static Instant inferenceExpirationTime; - private static Integer timeToExpiration = 30; + private static Instant managementExpirationTimeMinutes; + private static Instant inferenceExpirationTimeMinutes; + private static Integer timeToExpirationMinutes = 60; private SecureRandom secureRandom = new SecureRandom(); private Base64.Encoder baseEncoder = Base64.getUrlEncoder(); @@ -70,8 +78,8 @@ public String generateKey() { return baseEncoder.encodeToString(randomBytes); } - public Instant generateTokenExpiration(Integer time) { - return Instant.now().plusSeconds(TimeUnit.MINUTES.toSeconds(time)); + public Instant generateTokenExpiration() { + return Instant.now().plusSeconds(TimeUnit.MINUTES.toSeconds(timeToExpirationMinutes)); } // generates a key file with new keys depending on the parameter provided @@ -79,8 +87,7 @@ public Instant generateTokenExpiration(Integer time) { // 1: generates management key and keeps other 2 the same // 2: generates inference key and keeps other 2 the same public boolean generateKeyFile(Integer keyCase) throws IOException { - String fileData = " "; - String userDirectory = System.getProperty("user.dir") + "/key_file.txt"; + String userDirectory = System.getProperty("user.dir") + "/key_file.json"; File file = new File(userDirectory); if (!file.createNewFile() && !file.exists()) { return false; @@ -91,51 +98,78 @@ public boolean generateKeyFile(Integer keyCase) throws IOException { switch (keyCase) { case 1: managementKey = generateKey(); - managementExpirationTime = generateTokenExpiration(timeToExpiration); + managementExpirationTimeMinutes = generateTokenExpiration(); break; case 2: inferenceKey = generateKey(); - inferenceExpirationTime = generateTokenExpiration(timeToExpiration); + inferenceExpirationTimeMinutes = generateTokenExpiration(); break; default: managementKey = generateKey(); inferenceKey = generateKey(); - inferenceExpirationTime = generateTokenExpiration(timeToExpiration); - managementExpirationTime = generateTokenExpiration(timeToExpiration); + inferenceExpirationTimeMinutes = generateTokenExpiration(); + managementExpirationTimeMinutes = generateTokenExpiration(); } - fileData = + JsonArray jsonArray = new JsonArray(); + jsonArray.add( "Management Key: " + managementKey + " --- Expiration time: " - + managementExpirationTime - + "\nInference Key: " + + managementExpirationTimeMinutes); + jsonArray.add( + "Inference Key: " + inferenceKey + " --- Expiration time: " - + inferenceExpirationTime - + "\nAPI Key: " - + apiKey - + "\n"; - Files.write(Paths.get("key_file.txt"), fileData.getBytes()); + + inferenceExpirationTimeMinutes); + jsonArray.add("API Key: " + apiKey); + + Files.write( + Paths.get("key_file.json"), + new GsonBuilder() + .setPrettyPrinting() + .create() + .toJson(jsonArray) + .getBytes(StandardCharsets.UTF_8)); + + if (!setFilePermissions()) { + try { + Files.delete(Paths.get("key_file.txt")); + } catch (IOException e) { + return false; + } + return false; + } + return true; + } + + public boolean setFilePermissions() { + Path path = Paths.get("key_file.json"); + try { + Set permissions = PosixFilePermissions.fromString("rw-------"); + Files.setPosixFilePermissions(path, permissions); + } catch (Exception e) { + return false; + } return true; } // checks the token provided in the http with the saved keys depening on parameters - public boolean checkTokenAuthorization(FullHttpRequest req, Integer keyCase) { + public boolean checkTokenAuthorization(FullHttpRequest req, Integer type) { String key; Instant expiration; - switch (keyCase) { + switch (type) { case 0: key = apiKey; expiration = null; break; case 1: key = managementKey; - expiration = managementExpirationTime; + expiration = managementExpirationTimeMinutes; break; default: key = inferenceKey; - expiration = inferenceExpirationTime; + expiration = inferenceExpirationTimeMinutes; } String tokenBearer = req.headers().get("Authorization"); @@ -175,14 +209,14 @@ public String getKey() { } public Instant getInferenceExpirationTime() { - return inferenceExpirationTime; + return inferenceExpirationTimeMinutes; } public Instant getManagementExpirationTime() { - return managementExpirationTime; + return managementExpirationTimeMinutes; } public void setTime(Integer time) { - timeToExpiration = time; + timeToExpirationMinutes = time; } } diff --git a/ts/model_server.py b/ts/model_server.py index b2c5fa3948..53c853f672 100644 --- a/ts/model_server.py +++ b/ts/model_server.py @@ -49,7 +49,7 @@ def start() -> None: parent = psutil.Process(pid) parent.terminate() try: - os.remove(os.getcwd() + "/key_file.txt") + os.remove(os.getcwd() + "/key_file.json") except FileNotFoundError: print("Token authorization not enabled") if args.foreground: From c5c0f773e2627d2fec1fad193ff1c1a49244f051 Mon Sep 17 00:00:00 2001 From: udaij12 Date: Tue, 30 Jan 2024 15:09:26 -0800 Subject: [PATCH 13/28] fix doc --- docs/token_authorization_api.md | 2 +- .../src/main/java/org/pytorch/serve/util/ConfigManager.java | 3 --- 2 files changed, 1 insertion(+), 4 deletions(-) diff --git a/docs/token_authorization_api.md b/docs/token_authorization_api.md index 7e2b3ce982..824c2f89b5 100644 --- a/docs/token_authorization_api.md +++ b/docs/token_authorization_api.md @@ -2,7 +2,7 @@ ## Configuration 1. Enable token authorization by adding the provided plugin at start using the `--plugin-path` command. -2. Torchserve will enable token authorization if the plugin is provided. In the current working directory a file `key_file.txt` will be generated. +2. Torchserve will enable token authorization if the plugin is provided. In the current working directory a file `key_file.json` will be generated. 1. Example key file: `Management Key: aadJv_R6 --- Expiration time: 2024-01-16T22:23:32.952499Z` diff --git a/frontend/server/src/main/java/org/pytorch/serve/util/ConfigManager.java b/frontend/server/src/main/java/org/pytorch/serve/util/ConfigManager.java index 4f8a7d3404..7cd06a1c27 100644 --- a/frontend/server/src/main/java/org/pytorch/serve/util/ConfigManager.java +++ b/frontend/server/src/main/java/org/pytorch/serve/util/ConfigManager.java @@ -21,7 +21,6 @@ import java.security.KeyFactory; import java.security.KeyStore; import java.security.PrivateKey; -import java.security.SecureRandom; import java.security.cert.Certificate; import java.security.cert.CertificateFactory; import java.security.cert.X509Certificate; @@ -149,8 +148,6 @@ public final class ConfigManager { private boolean telemetryEnabled; private Logger logger = LoggerFactory.getLogger(ConfigManager.class); - private SecureRandom secureRandom = new SecureRandom(); - private Base64.Encoder baseEncoder = Base64.getUrlEncoder(); private boolean tokenAuthorizationEnabled; private Class tokenClass; private Object tokenObject; From 3b7fe29db0b17d439455d7f70182281a5675b2f6 Mon Sep 17 00:00:00 2001 From: udaij12 Date: Wed, 31 Jan 2024 11:04:53 -0800 Subject: [PATCH 14/28] fixed handler --- .../java/org/pytorch/serve/http/TokenAuthorizationHandler.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/frontend/server/src/main/java/org/pytorch/serve/http/TokenAuthorizationHandler.java b/frontend/server/src/main/java/org/pytorch/serve/http/TokenAuthorizationHandler.java index 62fbdc1f67..aa192c0609 100644 --- a/frontend/server/src/main/java/org/pytorch/serve/http/TokenAuthorizationHandler.java +++ b/frontend/server/src/main/java/org/pytorch/serve/http/TokenAuthorizationHandler.java @@ -23,7 +23,7 @@ public class TokenAuthorizationHandler extends HttpRequestHandlerChain { private static final Logger logger = LoggerFactory.getLogger(TokenAuthorizationHandler.class); private static TokenType tokenType; - private static Boolean tokenEnabled; + private static Boolean tokenEnabled = false; private static Class tokenClass; private static Object tokenObject; private static Integer timeToExpirationMinutes = 60; From 0d7afe366ecf6f18e148f74bb5f8aa3c5b404d93 Mon Sep 17 00:00:00 2001 From: udaij12 Date: Thu, 1 Feb 2024 15:29:42 -0800 Subject: [PATCH 15/28] added integration tests --- .../serve/http/TokenAuthorizationHandler.java | 18 +- .../pytorch/serve/plugins/endpoint/Token.java | 21 +- test/pytest/test_token_authorization.py | 192 ++++++++++++++++++ test/pytest/test_utils.py | 8 +- 4 files changed, 217 insertions(+), 22 deletions(-) create mode 100644 test/pytest/test_token_authorization.py diff --git a/frontend/server/src/main/java/org/pytorch/serve/http/TokenAuthorizationHandler.java b/frontend/server/src/main/java/org/pytorch/serve/http/TokenAuthorizationHandler.java index aa192c0609..a038696330 100644 --- a/frontend/server/src/main/java/org/pytorch/serve/http/TokenAuthorizationHandler.java +++ b/frontend/server/src/main/java/org/pytorch/serve/http/TokenAuthorizationHandler.java @@ -44,12 +44,12 @@ public void handleRequest( ConfigManager configManager = ConfigManager.getInstance(); if (tokenType == TokenType.MANAGEMENT) { if (req.toString().contains("/token")) { - checkTokenAuthorization(req, 0); + checkTokenAuthorization(req, "token"); } else { - checkTokenAuthorization(req, 1); + checkTokenAuthorization(req, "management"); } } else if (tokenType == TokenType.INFERENCE) { - checkTokenAuthorization(req, 2); + checkTokenAuthorization(req, "inference"); } chain.handleRequest(ctx, req, decoder, segments); } @@ -64,8 +64,8 @@ public static void setupTokenClass() { timeToExpirationMinutes = time; } method.invoke(tokenObject, timeToExpirationMinutes); - method = tokenClass.getMethod("generateKeyFile", Integer.class); - if ((boolean) method.invoke(tokenObject, 0)) { + method = tokenClass.getMethod("generateKeyFile", String.class); + if ((boolean) method.invoke(tokenObject, "token")) { logger.info("TOKEN CLASS IMPORTED SUCCESSFULLY"); } } catch (ClassNotFoundException e) { @@ -83,7 +83,7 @@ public static void setupTokenClass() { tokenEnabled = true; } - private void checkTokenAuthorization(FullHttpRequest req, Integer type) throws ModelException { + private void checkTokenAuthorization(FullHttpRequest req, String type) throws ModelException { if (tokenEnabled) { try { @@ -91,16 +91,16 @@ private void checkTokenAuthorization(FullHttpRequest req, Integer type) throws M tokenClass.getMethod( "checkTokenAuthorization", io.netty.handler.codec.http.FullHttpRequest.class, - Integer.class); + String.class); boolean result = (boolean) (method.invoke(tokenObject, req, type)); if (!result) { throw new InvalidKeyException( - "Token Authenticaation failed. Token either incorrect, expired, or not provided correctly"); + "Token Authentication failed. Token either incorrect, expired, or not provided correctly"); } } catch (NoSuchMethodException | IllegalAccessException | InvocationTargetException e) { e.printStackTrace(); throw new InvalidKeyException( - "Token Authenticaation failed. Token either incorrect, expired, or not provided correctly"); + "Token Authentication failed. Token either incorrect, expired, or not provided correctly"); } } } diff --git a/plugins/endpoints/src/main/java/org/pytorch/serve/plugins/endpoint/Token.java b/plugins/endpoints/src/main/java/org/pytorch/serve/plugins/endpoint/Token.java index e7e2f7cbe0..acd62cf708 100644 --- a/plugins/endpoints/src/main/java/org/pytorch/serve/plugins/endpoint/Token.java +++ b/plugins/endpoints/src/main/java/org/pytorch/serve/plugins/endpoint/Token.java @@ -48,9 +48,9 @@ public void doGet(Request req, Response rsp, Context ctx) throws IOException { String queryResponse = parseQuery(req); String test = ""; if ("management".equals(queryResponse)) { - generateKeyFile(1); + generateKeyFile("management"); } else if ("inference".equals(queryResponse)) { - generateKeyFile(2); + generateKeyFile("inference"); } else { test = "{\n\t\"Error\": " + queryResponse + "\n}\n"; } @@ -83,10 +83,7 @@ public Instant generateTokenExpiration() { } // generates a key file with new keys depending on the parameter provided - // 0: generates all 3 keys - // 1: generates management key and keeps other 2 the same - // 2: generates inference key and keeps other 2 the same - public boolean generateKeyFile(Integer keyCase) throws IOException { + public boolean generateKeyFile(String type) throws IOException { String userDirectory = System.getProperty("user.dir") + "/key_file.json"; File file = new File(userDirectory); if (!file.createNewFile() && !file.exists()) { @@ -95,12 +92,12 @@ public boolean generateKeyFile(Integer keyCase) throws IOException { if (apiKey == null) { apiKey = generateKey(); } - switch (keyCase) { - case 1: + switch (type) { + case "management": managementKey = generateKey(); managementExpirationTimeMinutes = generateTokenExpiration(); break; - case 2: + case "inference": inferenceKey = generateKey(); inferenceExpirationTimeMinutes = generateTokenExpiration(); break; @@ -155,15 +152,15 @@ public boolean setFilePermissions() { } // checks the token provided in the http with the saved keys depening on parameters - public boolean checkTokenAuthorization(FullHttpRequest req, Integer type) { + public boolean checkTokenAuthorization(FullHttpRequest req, String type) { String key; Instant expiration; switch (type) { - case 0: + case "token": key = apiKey; expiration = null; break; - case 1: + case "management": key = managementKey; expiration = managementExpirationTimeMinutes; break; diff --git a/test/pytest/test_token_authorization.py b/test/pytest/test_token_authorization.py new file mode 100644 index 0000000000..a2a04e79b5 --- /dev/null +++ b/test/pytest/test_token_authorization.py @@ -0,0 +1,192 @@ +import json +import os +import shutil +import subprocess +import tempfile +import time + +import requests +import test_utils + +ROOT_DIR = os.path.join(tempfile.gettempdir(), "workspace") +REPO_ROOT = os.path.join(os.path.dirname(os.path.abspath(__file__)), "../../") +data_file_kitten = os.path.join(REPO_ROOT, "test/pytest/test_data/kitten.jpg") + + +# Set up token plugin +def get_plugin_jar(): + new_folder_path = os.path.join(ROOT_DIR, "plugins-path") + plugin_folder = os.path.join(REPO_ROOT, "plugins") + os.makedirs(new_folder_path, exist_ok=True) + os.chdir(plugin_folder) + subprocess.run(["./gradlew", "build"]) + jar_path = os.path.join(plugin_folder, "endpoints/build/libs") + jar_file = [file for file in os.listdir(jar_path) if file.endswith(".jar")] + if jar_file: + shutil.move( + os.path.join(jar_path, jar_file[0]), + os.path.join(new_folder_path, jar_file[0]), + ) + os.chdir(REPO_ROOT) + result = subprocess.run( + f"python ts_scripts/install_from_source", + shell=True, + capture_output=True, + text=True, + ) + + +# Parse json file and return key +def read_key_file(type): + json_file_path = os.path.join(REPO_ROOT, "key_file.json") + with open(json_file_path) as json_file: + json_data = json.load(json_file) + + # Extract the three keys + management_key = None + inference_key = None + api_key = None + for key_string in json_data: + if "Management Key" in key_string: + management_key = key_string.split(":")[1].strip().split("---")[0].strip() + elif "Inference Key" in key_string: + inference_key = key_string.split(":")[1].strip().split("---")[0].strip() + elif "API Key" in key_string: + api_key = key_string.split(":")[1].strip().split("---")[0].strip() + + options = { + "management": management_key, + "inference": inference_key, + "token": api_key, + } + key = options.get(type, "Invalid data type") + return key + + +def setup_torchserve(): + get_plugin_jar() + MODEL_STORE = os.path.join(ROOT_DIR, "model_store/") + PLUGIN_STORE = os.path.join(ROOT_DIR, "plugins-path") + + test_utils.start_torchserve(no_config_snapshots=True, plugin_folder=PLUGIN_STORE) + time.sleep(10) + + key = read_key_file("management") + header = {"Authorization": f"Bearer {key}"} + + params = ( + ("model_name", "resnet18"), + ("url", "resnet-18.mar"), + ("initial_workers", "1"), + ("synchronous", "true"), + ) + response = requests.post( + "http://localhost:8081/models", params=params, headers=header + ) + time.sleep(5) + + result = subprocess.run( + f"cat {REPO_ROOT}/key_file.json", + shell=True, + capture_output=True, + text=True, + ) + print("Curl output:") + print(result.stdout) + + +# Test describe model API with token enabled +def test_managament_api_with_token(): + test_utils.stop_torchserve() + setup_torchserve() + key = read_key_file("management") + header = {"Authorization": f"Bearer {key}"} + response = requests.get(f"http://localhost:8081/models/resnet18", headers=header) + time.sleep(5) + print(response.text) + + assert response.status_code == 200, "Token check failed" + + +# Test describe model API with incorrect token and no token +def test_managament_api_with_incorrect_token(): + # Using random key + header = {"Authorization": "Bearer abcd1234"} + + response = requests.get(f"http://localhost:8081/models/resnet18", headers=header) + time.sleep(5) + print(response.text) + + assert response.status_code == 400, "Token check failed" + + response = requests.get(f"http://localhost:8081/models/resnet18") + time.sleep(5) + print(response.text) + + assert response.status_code == 400, "Token check failed" + + +# Test inference API with token enabled +def test_inference_api_with_token(): + key = read_key_file("inference") + header = {"Authorization": f"Bearer {key}"} + + response = requests.post( + url="http://localhost:8080/predictions/resnet18", + files={"data": open(data_file_kitten, "rb")}, + headers=header, + ) + time.sleep(5) + print(response.text) + + assert response.status_code == 200, "Token check failed" + + +# Test inference API with incorrect token +def test_inference_api_with_incorrect_token(): + # Using random key + header = {"Authorization": "Bearer abcd1234"} + + response = requests.post( + url="http://localhost:8080/predictions/resnet18", + files={"data": open(data_file_kitten, "rb")}, + headers=header, + ) + time.sleep(5) + print(response.text) + + assert response.status_code == 400, "Token check failed" + + +# Test Token API for regenerating new inference key +def test_token_inference_api(): + token_key = read_key_file("token") + inference_key = read_key_file("inference") + header = {"Authorization": f"Bearer {token_key}"} + params = {"type": "inference"} + + response = requests.get( + url="http://localhost:8081/token", params=params, headers=header + ) + time.sleep(5) + print(response.text) + + assert response.status_code == 200, "Token check failed" + assert inference_key != read_key_file("inference"), "Key file not updated" + + +# Test Token API for regenerating new management key +def test_token_management_api(): + token_key = read_key_file("token") + management_key = read_key_file("management") + header = {"Authorization": f"Bearer {token_key}"} + params = {"type": "management"} + + response = requests.get( + url="http://localhost:8081/token", params=params, headers=header + ) + time.sleep(5) + + assert management_key != read_key_file("management"), "Key file not updated" + assert response.status_code == 200, "Token check failed" + test_utils.stop_torchserve() diff --git a/test/pytest/test_utils.py b/test/pytest/test_utils.py index 417bba460c..70317b27df 100644 --- a/test/pytest/test_utils.py +++ b/test/pytest/test_utils.py @@ -54,7 +54,11 @@ def run(self): def start_torchserve( - model_store=None, snapshot_file=None, no_config_snapshots=False, gen_mar=True + model_store=None, + snapshot_file=None, + no_config_snapshots=False, + gen_mar=True, + plugin_folder=None, ): stop_torchserve() crate_mar_file_table() @@ -63,6 +67,8 @@ def start_torchserve( if gen_mar: mg.gen_mar(model_store) cmd.extend(["--model-store", model_store]) + if plugin_folder: + cmd.extend(["--plugins-path", plugin_folder]) if snapshot_file: cmd.extend(["--ts-config", snapshot_file]) if no_config_snapshots: From 13caa27257f45dacfcd99f268035bc119b856027 Mon Sep 17 00:00:00 2001 From: udaij12 Date: Thu, 1 Feb 2024 19:55:51 -0800 Subject: [PATCH 16/28] Added integration tests --- .../serve/http/TokenAuthorizationHandler.java | 6 +-- .../org/pytorch/serve/util/ConfigManager.java | 16 +++---- .../pytorch/serve/plugins/endpoint/Token.java | 6 +-- .../gradle/wrapper/gradle-wrapper.properties | 2 +- plugins/settings.gradle | 3 +- test/pytest/test_token_authorization.py | 43 ++++++++++++++----- 6 files changed, 45 insertions(+), 31 deletions(-) diff --git a/frontend/server/src/main/java/org/pytorch/serve/http/TokenAuthorizationHandler.java b/frontend/server/src/main/java/org/pytorch/serve/http/TokenAuthorizationHandler.java index a038696330..4f0fe25cec 100644 --- a/frontend/server/src/main/java/org/pytorch/serve/http/TokenAuthorizationHandler.java +++ b/frontend/server/src/main/java/org/pytorch/serve/http/TokenAuthorizationHandler.java @@ -15,7 +15,7 @@ import org.slf4j.LoggerFactory; /** - * A class handling inbound HTTP requests to the inference API. + * A class handling token check for all inbound HTTP requests. * *

This class // */ @@ -26,7 +26,7 @@ public class TokenAuthorizationHandler extends HttpRequestHandlerChain { private static Boolean tokenEnabled = false; private static Class tokenClass; private static Object tokenObject; - private static Integer timeToExpirationMinutes = 60; + private static Integer timeToExpirationMinutes = 180; /** Creates a new {@code InferenceRequestHandler} instance. */ public TokenAuthorizationHandler(TokenType type) { @@ -60,7 +60,7 @@ public static void setupTokenClass() { tokenObject = tokenClass.getDeclaredConstructor().newInstance(); Method method = tokenClass.getMethod("setTime", Integer.class); Integer time = ConfigManager.getInstance().getTimeToExpiration(); - if (time == 0) { + if (time != 0) { timeToExpirationMinutes = time; } method.invoke(tokenObject, timeToExpirationMinutes); diff --git a/frontend/server/src/main/java/org/pytorch/serve/util/ConfigManager.java b/frontend/server/src/main/java/org/pytorch/serve/util/ConfigManager.java index 2df7c68bd1..c7f0c695d6 100644 --- a/frontend/server/src/main/java/org/pytorch/serve/util/ConfigManager.java +++ b/frontend/server/src/main/java/org/pytorch/serve/util/ConfigManager.java @@ -8,7 +8,6 @@ import java.io.File; import java.io.IOException; import java.io.InputStream; -import java.lang.reflect.*; import java.lang.reflect.Field; import java.lang.reflect.Type; import java.net.InetAddress; @@ -151,11 +150,6 @@ public final class ConfigManager { private boolean telemetryEnabled; private Logger logger = LoggerFactory.getLogger(ConfigManager.class); - private boolean tokenAuthorizationEnabled; - private Class tokenClass; - private Object tokenObject; - private Integer timeToExpiration = 60; - private ConfigManager(Arguments args) throws IOException { prop = new Properties(); @@ -866,13 +860,13 @@ public boolean isSnapshotDisabled() { return snapshotDisabled; } - public boolean isTokenEnabled() { - return tokenAuthorizationEnabled; - } - public Integer getTimeToExpiration() { if (prop.getProperty(TS_TOKEN_EXPIRATION_TIME) != null) { - return Integer.valueOf(prop.getProperty(TS_TOKEN_EXPIRATION_TIME)); + try { + return Integer.valueOf(prop.getProperty(TS_TOKEN_EXPIRATION_TIME)); + } catch (NumberFormatException e) { + logger.error("Token expiration not a valid integer"); + } } return 0; } diff --git a/plugins/endpoints/src/main/java/org/pytorch/serve/plugins/endpoint/Token.java b/plugins/endpoints/src/main/java/org/pytorch/serve/plugins/endpoint/Token.java index acd62cf708..8c779cbd83 100644 --- a/plugins/endpoints/src/main/java/org/pytorch/serve/plugins/endpoint/Token.java +++ b/plugins/endpoints/src/main/java/org/pytorch/serve/plugins/endpoint/Token.java @@ -39,7 +39,7 @@ public class Token extends ModelServerEndpoint { private static String inferenceKey; private static Instant managementExpirationTimeMinutes; private static Instant inferenceExpirationTimeMinutes; - private static Integer timeToExpirationMinutes = 60; + private static Integer timeToExpirationMinutes; private SecureRandom secureRandom = new SecureRandom(); private Base64.Encoder baseEncoder = Base64.getUrlEncoder(); @@ -57,7 +57,7 @@ public void doGet(Request req, Response rsp, Context ctx) throws IOException { rsp.getOutputStream().write(test.getBytes(StandardCharsets.UTF_8)); } - // parses query and either returns "management"/"inference" or a wrong type error + // parses query and either returns management/inference or a wrong type error public String parseQuery(Request req) { QueryStringDecoder decoder = new QueryStringDecoder(req.getRequestURI()); Map> parameters = decoder.parameters(); @@ -175,7 +175,7 @@ public boolean checkTokenAuthorization(FullHttpRequest req, String type) { } String[] arrOfStr = tokenBearer.split(" ", 2); if (arrOfStr.length == 1) { - return false; + return false; } String token = arrOfStr[1]; diff --git a/plugins/gradle/wrapper/gradle-wrapper.properties b/plugins/gradle/wrapper/gradle-wrapper.properties index 7191a69876..62036368e7 100644 --- a/plugins/gradle/wrapper/gradle-wrapper.properties +++ b/plugins/gradle/wrapper/gradle-wrapper.properties @@ -3,4 +3,4 @@ distributionBase=GRADLE_USER_HOME distributionPath=wrapper/dists zipStoreBase=GRADLE_USER_HOME zipStorePath=wrapper/dists -distributionUrl=https\://services.gradle.org/distributions/gradle-6.4-bin.zip +distributionUrl=https\://services.gradle.org/distributions/gradle-7.3-all.zip diff --git a/plugins/settings.gradle b/plugins/settings.gradle index fd3dc1373c..1e949f6cf7 100644 --- a/plugins/settings.gradle +++ b/plugins/settings.gradle @@ -9,5 +9,4 @@ rootProject.name = 'plugins' include 'endpoints' -include 'DDBEndPoint' - +// include 'DDBEndPoint' diff --git a/test/pytest/test_token_authorization.py b/test/pytest/test_token_authorization.py index a2a04e79b5..d12eee8a0b 100644 --- a/test/pytest/test_token_authorization.py +++ b/test/pytest/test_token_authorization.py @@ -19,7 +19,8 @@ def get_plugin_jar(): plugin_folder = os.path.join(REPO_ROOT, "plugins") os.makedirs(new_folder_path, exist_ok=True) os.chdir(plugin_folder) - subprocess.run(["./gradlew", "build"]) + subprocess.run(["./gradlew", "formatJava"]) + result = subprocess.run(["./gradlew", "build"]) jar_path = os.path.join(plugin_folder, "endpoints/build/libs") jar_file = [file for file in os.listdir(jar_path) if file.endswith(".jar")] if jar_file: @@ -84,6 +85,8 @@ def setup_torchserve(): "http://localhost:8081/models", params=params, headers=header ) time.sleep(5) + print("register reponse") + print(response.text) result = subprocess.run( f"cat {REPO_ROOT}/key_file.json", @@ -101,7 +104,8 @@ def test_managament_api_with_token(): setup_torchserve() key = read_key_file("management") header = {"Authorization": f"Bearer {key}"} - response = requests.get(f"http://localhost:8081/models/resnet18", headers=header) + print(key) + response = requests.get("http://localhost:8081/models/resnet18", headers=header) time.sleep(5) print(response.text) @@ -119,12 +123,6 @@ def test_managament_api_with_incorrect_token(): assert response.status_code == 400, "Token check failed" - response = requests.get(f"http://localhost:8081/models/resnet18") - time.sleep(5) - print(response.text) - - assert response.status_code == 400, "Token check failed" - # Test inference API with token enabled def test_inference_api_with_token(): @@ -138,6 +136,7 @@ def test_inference_api_with_token(): ) time.sleep(5) print(response.text) + print(key) assert response.status_code == 200, "Token check failed" @@ -162,18 +161,38 @@ def test_inference_api_with_incorrect_token(): def test_token_inference_api(): token_key = read_key_file("token") inference_key = read_key_file("inference") - header = {"Authorization": f"Bearer {token_key}"} + header_inference = {"Authorization": f"Bearer {inference_key}"} + header_token = {"Authorization": f"Bearer {token_key}"} params = {"type": "inference"} + # check inference works with current token + response = requests.post( + url="http://localhost:8080/predictions/resnet18", + files={"data": open(data_file_kitten, "rb")}, + headers=header_inference, + ) + time.sleep(5) + assert response.status_code == 200, "Token check failed" + + # generate new inference token and check it is different response = requests.get( - url="http://localhost:8081/token", params=params, headers=header + url="http://localhost:8081/token", params=params, headers=header_token ) time.sleep(5) print(response.text) - + print(token_key) assert response.status_code == 200, "Token check failed" assert inference_key != read_key_file("inference"), "Key file not updated" + # check inference does not works with original token + response = requests.post( + url="http://localhost:8080/predictions/resnet18", + files={"data": open(data_file_kitten, "rb")}, + headers=header_inference, + ) + time.sleep(5) + assert response.status_code == 400, "Token check failed" + # Test Token API for regenerating new management key def test_token_management_api(): @@ -186,6 +205,8 @@ def test_token_management_api(): url="http://localhost:8081/token", params=params, headers=header ) time.sleep(5) + print(response.text) + print(token_key) assert management_key != read_key_file("management"), "Key file not updated" assert response.status_code == 200, "Token check failed" From a49acf5b71a6f83a5540181448340d034f22c593 Mon Sep 17 00:00:00 2001 From: udaij12 Date: Wed, 14 Feb 2024 12:04:10 -0800 Subject: [PATCH 17/28] updating token auth --- docs/token_authorization_api.md | 2 +- .../serve/http/TokenAuthorizationHandler.java | 46 +++++++++---------- .../org/pytorch/serve/util/ConfigManager.java | 6 +-- ts/model_server.py | 2 +- 4 files changed, 28 insertions(+), 28 deletions(-) diff --git a/docs/token_authorization_api.md b/docs/token_authorization_api.md index 824c2f89b5..ba92a9a83d 100644 --- a/docs/token_authorization_api.md +++ b/docs/token_authorization_api.md @@ -30,7 +30,7 @@ ## Customization Torchserve offers various ways to customize the token authorization to allow owners to reach the desired result. -1. Time to expiration is set to default at 60 minutes but can be changed in the config.properties by adding `token_expiration`. Ex:`token_expiration=30` +1. Time to expiration is set to default at 60 minutes but can be changed in the config.properties by adding `token_expiration_min`. Ex:`token_expiration_min=30` 2. The token authorization code is consolidated in the plugin and thus can be changed without impacting the frontend or end result. The only thing the user cannot change is: 1. The urlPattern for the plugin must be 'token' and the class name must not change 2. The `generateKeyFile`, `checkTokenAuthorization`, and `setTime` functions return type and signature must not change. However, the code in the functions can be modified depending on user necessity. diff --git a/frontend/server/src/main/java/org/pytorch/serve/http/TokenAuthorizationHandler.java b/frontend/server/src/main/java/org/pytorch/serve/http/TokenAuthorizationHandler.java index 4f0fe25cec..8675650096 100644 --- a/frontend/server/src/main/java/org/pytorch/serve/http/TokenAuthorizationHandler.java +++ b/frontend/server/src/main/java/org/pytorch/serve/http/TokenAuthorizationHandler.java @@ -26,7 +26,7 @@ public class TokenAuthorizationHandler extends HttpRequestHandlerChain { private static Boolean tokenEnabled = false; private static Class tokenClass; private static Object tokenObject; - private static Integer timeToExpirationMinutes = 180; + private static Integer timeToExpirationMinutes = 60; /** Creates a new {@code InferenceRequestHandler} instance. */ public TokenAuthorizationHandler(TokenType type) { @@ -41,15 +41,17 @@ public void handleRequest( String[] segments) throws ModelException, DownloadArchiveException, WorkflowException, WorkerInitializationException { - ConfigManager configManager = ConfigManager.getInstance(); - if (tokenType == TokenType.MANAGEMENT) { - if (req.toString().contains("/token")) { - checkTokenAuthorization(req, "token"); - } else { - checkTokenAuthorization(req, "management"); + if (tokenEnabled) { + ConfigManager configManager = ConfigManager.getInstance(); + if (tokenType == TokenType.MANAGEMENT) { + if (req.toString().contains("/token")) { + checkTokenAuthorization(req, "token"); + } else { + checkTokenAuthorization(req, "management"); + } + } else if (tokenType == TokenType.INFERENCE) { + checkTokenAuthorization(req, "inference"); } - } else if (tokenType == TokenType.INFERENCE) { - checkTokenAuthorization(req, "inference"); } chain.handleRequest(ctx, req, decoder, segments); } @@ -85,23 +87,21 @@ public static void setupTokenClass() { private void checkTokenAuthorization(FullHttpRequest req, String type) throws ModelException { - if (tokenEnabled) { - try { - Method method = - tokenClass.getMethod( - "checkTokenAuthorization", - io.netty.handler.codec.http.FullHttpRequest.class, - String.class); - boolean result = (boolean) (method.invoke(tokenObject, req, type)); - if (!result) { - throw new InvalidKeyException( - "Token Authentication failed. Token either incorrect, expired, or not provided correctly"); - } - } catch (NoSuchMethodException | IllegalAccessException | InvocationTargetException e) { - e.printStackTrace(); + try { + Method method = + tokenClass.getMethod( + "checkTokenAuthorization", + io.netty.handler.codec.http.FullHttpRequest.class, + String.class); + boolean result = (boolean) (method.invoke(tokenObject, req, type)); + if (!result) { throw new InvalidKeyException( "Token Authentication failed. Token either incorrect, expired, or not provided correctly"); } + } catch (NoSuchMethodException | IllegalAccessException | InvocationTargetException e) { + e.printStackTrace(); + throw new InvalidKeyException( + "Token Authentication failed. Token either incorrect, expired, or not provided correctly"); } } } diff --git a/frontend/server/src/main/java/org/pytorch/serve/util/ConfigManager.java b/frontend/server/src/main/java/org/pytorch/serve/util/ConfigManager.java index c7f0c695d6..05b4324131 100644 --- a/frontend/server/src/main/java/org/pytorch/serve/util/ConfigManager.java +++ b/frontend/server/src/main/java/org/pytorch/serve/util/ConfigManager.java @@ -107,7 +107,7 @@ public final class ConfigManager { private static final String TS_WORKFLOW_STORE = "workflow_store"; private static final String TS_CPP_LOG_CONFIG = "cpp_log_config"; private static final String TS_OPEN_INFERENCE_PROTOCOL = "ts_open_inference_protocol"; - private static final String TS_TOKEN_EXPIRATION_TIME = "token_expiration"; // minutes + private static final String TS_TOKEN_EXPIRATION_TIME_MIN = "token_expiration_min"; // Configuration which are not documented or enabled through environment variables private static final String USE_NATIVE_IO = "use_native_io"; @@ -861,9 +861,9 @@ public boolean isSnapshotDisabled() { } public Integer getTimeToExpiration() { - if (prop.getProperty(TS_TOKEN_EXPIRATION_TIME) != null) { + if (prop.getProperty(TS_TOKEN_EXPIRATION_TIME_MIN) != null) { try { - return Integer.valueOf(prop.getProperty(TS_TOKEN_EXPIRATION_TIME)); + return Integer.valueOf(prop.getProperty(TS_TOKEN_EXPIRATION_TIME_MIN)); } catch (NumberFormatException e) { logger.error("Token expiration not a valid integer"); } diff --git a/ts/model_server.py b/ts/model_server.py index 6a363340da..eeefd833fd 100644 --- a/ts/model_server.py +++ b/ts/model_server.py @@ -51,7 +51,7 @@ def start() -> None: try: os.remove(os.getcwd() + "/key_file.json") except FileNotFoundError: - print("Token authorization not enabled") + print("Delete key file if it exists") if args.foreground: try: parent.wait(timeout=60) From 6b329b2586ecdbece85560ad16f364774f71cb03 Mon Sep 17 00:00:00 2001 From: udaij12 Date: Thu, 15 Feb 2024 13:38:09 -0800 Subject: [PATCH 18/28] small changes to token auth --- docs/token_authorization_api.md | 2 +- .../serve/http/TokenAuthorizationHandler.java | 19 ++++---- .../org/pytorch/serve/util/ConfigManager.java | 10 +++++ .../pytorch/serve/plugins/endpoint/Token.java | 15 ++++--- test/pytest/test_token_authorization.py | 43 +++++-------------- 5 files changed, 40 insertions(+), 49 deletions(-) diff --git a/docs/token_authorization_api.md b/docs/token_authorization_api.md index ba92a9a83d..ce033f1187 100644 --- a/docs/token_authorization_api.md +++ b/docs/token_authorization_api.md @@ -16,7 +16,6 @@ 2. Inference key: Used for inference APIs. Example: `curl http://127.0.0.1:8080/predictions/densenet161 -T examples/image_classifier/kitten.jpg -H "Authorization: Bearer poZXAlqe"` 3. API key: Used for the token authorization API. Check section 4 for API use. - 4. 3 tokens allow the owner with the most flexibility in use and enables them to adapt the tokens to their use. Owners of the server can provide users with the inference token if users should not mess with models. The owner can also provide owners with the management key if owners want users to add and remove models. 4. The plugin also includes an API in order to generate a new key to replace either the management or inference key. 1. Management Example: `curl localhost:8081/token?type=management -H "Authorization: Bearer xryL_Vzs"` will replace the current management key in the key_file with a new one and will update the expiration time. @@ -37,3 +36,4 @@ Torchserve offers various ways to customize the token authorization to allow own ## Notes 1. DO NOT MODIFY THE KEY FILE. Modifying the key file might impact reading and writing to the file thus preventing new keys from properly being displayed in the file. +2. 3 tokens allow the owner with the most flexibility in use and enables them to adapt the tokens to their use. Owners of the server can provide users with the inference token if users should not mess with models. The owner can also provide owners with the management key if owners want users to add and remove models. diff --git a/frontend/server/src/main/java/org/pytorch/serve/http/TokenAuthorizationHandler.java b/frontend/server/src/main/java/org/pytorch/serve/http/TokenAuthorizationHandler.java index 8675650096..ab6692817e 100644 --- a/frontend/server/src/main/java/org/pytorch/serve/http/TokenAuthorizationHandler.java +++ b/frontend/server/src/main/java/org/pytorch/serve/http/TokenAuthorizationHandler.java @@ -15,7 +15,7 @@ import org.slf4j.LoggerFactory; /** - * A class handling token check for all inbound HTTP requests. + * A class handling inbound HTTP requests to the inference API. * *

This class // */ @@ -42,7 +42,7 @@ public void handleRequest( throws ModelException, DownloadArchiveException, WorkflowException, WorkerInitializationException { if (tokenEnabled) { - ConfigManager configManager = ConfigManager.getInstance(); + // ConfigManager configManager = ConfigManager.getInstance(); if (tokenType == TokenType.MANAGEMENT) { if (req.toString().contains("/token")) { checkTokenAuthorization(req, "token"); @@ -62,7 +62,7 @@ public static void setupTokenClass() { tokenObject = tokenClass.getDeclaredConstructor().newInstance(); Method method = tokenClass.getMethod("setTime", Integer.class); Integer time = ConfigManager.getInstance().getTimeToExpiration(); - if (time != 0) { + if (time == 0) { timeToExpirationMinutes = time; } method.invoke(tokenObject, timeToExpirationMinutes); @@ -70,17 +70,18 @@ public static void setupTokenClass() { if ((boolean) method.invoke(tokenObject, "token")) { logger.info("TOKEN CLASS IMPORTED SUCCESSFULLY"); } - } catch (ClassNotFoundException e) { - logger.error("TOKEN CLASS IMPORTED UNSUCCESSFULLY"); - e.printStackTrace(); - return; + // } catch (ClassNotFoundException e) { + // logger.error("TOKEN CLASS IMPORTED UNSUCCESSFULLY"); + // e.printStackTrace(); + // return; } catch (NoSuchMethodException | IllegalAccessException | InstantiationException - | InvocationTargetException e) { + | InvocationTargetException + | ClassNotFoundException e) { e.printStackTrace(); logger.error("TOKEN CLASS IMPORTED UNSUCCESSFULLY"); - return; + throw new IllegalStateException("Unable to import token class", e); } tokenEnabled = true; } diff --git a/frontend/server/src/main/java/org/pytorch/serve/util/ConfigManager.java b/frontend/server/src/main/java/org/pytorch/serve/util/ConfigManager.java index 05b4324131..38004188c4 100644 --- a/frontend/server/src/main/java/org/pytorch/serve/util/ConfigManager.java +++ b/frontend/server/src/main/java/org/pytorch/serve/util/ConfigManager.java @@ -8,6 +8,7 @@ import java.io.File; import java.io.IOException; import java.io.InputStream; +import java.lang.reflect.*; import java.lang.reflect.Field; import java.lang.reflect.Type; import java.net.InetAddress; @@ -150,6 +151,11 @@ public final class ConfigManager { private boolean telemetryEnabled; private Logger logger = LoggerFactory.getLogger(ConfigManager.class); + private boolean tokenAuthorizationEnabled; + private Class tokenClass; + private Object tokenObject; + private Integer timeToExpiration = 60; + private ConfigManager(Arguments args) throws IOException { prop = new Properties(); @@ -860,6 +866,10 @@ public boolean isSnapshotDisabled() { return snapshotDisabled; } + public boolean isTokenEnabled() { + return tokenAuthorizationEnabled; + } + public Integer getTimeToExpiration() { if (prop.getProperty(TS_TOKEN_EXPIRATION_TIME_MIN) != null) { try { diff --git a/plugins/endpoints/src/main/java/org/pytorch/serve/plugins/endpoint/Token.java b/plugins/endpoints/src/main/java/org/pytorch/serve/plugins/endpoint/Token.java index 8c779cbd83..7cfd1ef997 100644 --- a/plugins/endpoints/src/main/java/org/pytorch/serve/plugins/endpoint/Token.java +++ b/plugins/endpoints/src/main/java/org/pytorch/serve/plugins/endpoint/Token.java @@ -39,9 +39,10 @@ public class Token extends ModelServerEndpoint { private static String inferenceKey; private static Instant managementExpirationTimeMinutes; private static Instant inferenceExpirationTimeMinutes; - private static Integer timeToExpirationMinutes; + private static Integer timeToExpirationMinutes = 60; private SecureRandom secureRandom = new SecureRandom(); private Base64.Encoder baseEncoder = Base64.getUrlEncoder(); + private String fileName = "key_file.json"; @Override public void doGet(Request req, Response rsp, Context ctx) throws IOException { @@ -57,7 +58,7 @@ public void doGet(Request req, Response rsp, Context ctx) throws IOException { rsp.getOutputStream().write(test.getBytes(StandardCharsets.UTF_8)); } - // parses query and either returns management/inference or a wrong type error + // parses query and either returns "management"/"inference" or a wrong type error public String parseQuery(Request req) { QueryStringDecoder decoder = new QueryStringDecoder(req.getRequestURI()); Map> parameters = decoder.parameters(); @@ -84,7 +85,7 @@ public Instant generateTokenExpiration() { // generates a key file with new keys depending on the parameter provided public boolean generateKeyFile(String type) throws IOException { - String userDirectory = System.getProperty("user.dir") + "/key_file.json"; + String userDirectory = System.getProperty("user.dir") + "/" + fileName; File file = new File(userDirectory); if (!file.createNewFile() && !file.exists()) { return false; @@ -122,7 +123,7 @@ public boolean generateKeyFile(String type) throws IOException { jsonArray.add("API Key: " + apiKey); Files.write( - Paths.get("key_file.json"), + Paths.get(fileName), new GsonBuilder() .setPrettyPrinting() .create() @@ -131,7 +132,7 @@ public boolean generateKeyFile(String type) throws IOException { if (!setFilePermissions()) { try { - Files.delete(Paths.get("key_file.txt")); + Files.delete(Paths.get(fileName)); } catch (IOException e) { return false; } @@ -141,7 +142,7 @@ public boolean generateKeyFile(String type) throws IOException { } public boolean setFilePermissions() { - Path path = Paths.get("key_file.json"); + Path path = Paths.get(fileName); try { Set permissions = PosixFilePermissions.fromString("rw-------"); Files.setPosixFilePermissions(path, permissions); @@ -175,7 +176,7 @@ public boolean checkTokenAuthorization(FullHttpRequest req, String type) { } String[] arrOfStr = tokenBearer.split(" ", 2); if (arrOfStr.length == 1) { - return false; + return false; } String token = arrOfStr[1]; diff --git a/test/pytest/test_token_authorization.py b/test/pytest/test_token_authorization.py index d12eee8a0b..a2a04e79b5 100644 --- a/test/pytest/test_token_authorization.py +++ b/test/pytest/test_token_authorization.py @@ -19,8 +19,7 @@ def get_plugin_jar(): plugin_folder = os.path.join(REPO_ROOT, "plugins") os.makedirs(new_folder_path, exist_ok=True) os.chdir(plugin_folder) - subprocess.run(["./gradlew", "formatJava"]) - result = subprocess.run(["./gradlew", "build"]) + subprocess.run(["./gradlew", "build"]) jar_path = os.path.join(plugin_folder, "endpoints/build/libs") jar_file = [file for file in os.listdir(jar_path) if file.endswith(".jar")] if jar_file: @@ -85,8 +84,6 @@ def setup_torchserve(): "http://localhost:8081/models", params=params, headers=header ) time.sleep(5) - print("register reponse") - print(response.text) result = subprocess.run( f"cat {REPO_ROOT}/key_file.json", @@ -104,8 +101,7 @@ def test_managament_api_with_token(): setup_torchserve() key = read_key_file("management") header = {"Authorization": f"Bearer {key}"} - print(key) - response = requests.get("http://localhost:8081/models/resnet18", headers=header) + response = requests.get(f"http://localhost:8081/models/resnet18", headers=header) time.sleep(5) print(response.text) @@ -123,6 +119,12 @@ def test_managament_api_with_incorrect_token(): assert response.status_code == 400, "Token check failed" + response = requests.get(f"http://localhost:8081/models/resnet18") + time.sleep(5) + print(response.text) + + assert response.status_code == 400, "Token check failed" + # Test inference API with token enabled def test_inference_api_with_token(): @@ -136,7 +138,6 @@ def test_inference_api_with_token(): ) time.sleep(5) print(response.text) - print(key) assert response.status_code == 200, "Token check failed" @@ -161,38 +162,18 @@ def test_inference_api_with_incorrect_token(): def test_token_inference_api(): token_key = read_key_file("token") inference_key = read_key_file("inference") - header_inference = {"Authorization": f"Bearer {inference_key}"} - header_token = {"Authorization": f"Bearer {token_key}"} + header = {"Authorization": f"Bearer {token_key}"} params = {"type": "inference"} - # check inference works with current token - response = requests.post( - url="http://localhost:8080/predictions/resnet18", - files={"data": open(data_file_kitten, "rb")}, - headers=header_inference, - ) - time.sleep(5) - assert response.status_code == 200, "Token check failed" - - # generate new inference token and check it is different response = requests.get( - url="http://localhost:8081/token", params=params, headers=header_token + url="http://localhost:8081/token", params=params, headers=header ) time.sleep(5) print(response.text) - print(token_key) + assert response.status_code == 200, "Token check failed" assert inference_key != read_key_file("inference"), "Key file not updated" - # check inference does not works with original token - response = requests.post( - url="http://localhost:8080/predictions/resnet18", - files={"data": open(data_file_kitten, "rb")}, - headers=header_inference, - ) - time.sleep(5) - assert response.status_code == 400, "Token check failed" - # Test Token API for regenerating new management key def test_token_management_api(): @@ -205,8 +186,6 @@ def test_token_management_api(): url="http://localhost:8081/token", params=params, headers=header ) time.sleep(5) - print(response.text) - print(token_key) assert management_key != read_key_file("management"), "Key file not updated" assert response.status_code == 200, "Token check failed" From 0690e0164d567953a8d46809e168bda8a9521511 Mon Sep 17 00:00:00 2001 From: udaij12 Date: Thu, 15 Feb 2024 15:12:17 -0800 Subject: [PATCH 19/28] fixing changes --- .../serve/http/TokenAuthorizationHandler.java | 4 +- .../org/pytorch/serve/util/ConfigManager.java | 10 ----- .../pytorch/serve/plugins/endpoint/Token.java | 4 +- test/pytest/test_token_authorization.py | 43 ++++++++++++++----- 4 files changed, 36 insertions(+), 25 deletions(-) diff --git a/frontend/server/src/main/java/org/pytorch/serve/http/TokenAuthorizationHandler.java b/frontend/server/src/main/java/org/pytorch/serve/http/TokenAuthorizationHandler.java index ab6692817e..cf5ded3bd8 100644 --- a/frontend/server/src/main/java/org/pytorch/serve/http/TokenAuthorizationHandler.java +++ b/frontend/server/src/main/java/org/pytorch/serve/http/TokenAuthorizationHandler.java @@ -15,7 +15,7 @@ import org.slf4j.LoggerFactory; /** - * A class handling inbound HTTP requests to the inference API. + * A class handling token check for all inbound HTTP requests * *

This class // */ @@ -62,7 +62,7 @@ public static void setupTokenClass() { tokenObject = tokenClass.getDeclaredConstructor().newInstance(); Method method = tokenClass.getMethod("setTime", Integer.class); Integer time = ConfigManager.getInstance().getTimeToExpiration(); - if (time == 0) { + if (time != 0) { timeToExpirationMinutes = time; } method.invoke(tokenObject, timeToExpirationMinutes); diff --git a/frontend/server/src/main/java/org/pytorch/serve/util/ConfigManager.java b/frontend/server/src/main/java/org/pytorch/serve/util/ConfigManager.java index 38004188c4..05b4324131 100644 --- a/frontend/server/src/main/java/org/pytorch/serve/util/ConfigManager.java +++ b/frontend/server/src/main/java/org/pytorch/serve/util/ConfigManager.java @@ -8,7 +8,6 @@ import java.io.File; import java.io.IOException; import java.io.InputStream; -import java.lang.reflect.*; import java.lang.reflect.Field; import java.lang.reflect.Type; import java.net.InetAddress; @@ -151,11 +150,6 @@ public final class ConfigManager { private boolean telemetryEnabled; private Logger logger = LoggerFactory.getLogger(ConfigManager.class); - private boolean tokenAuthorizationEnabled; - private Class tokenClass; - private Object tokenObject; - private Integer timeToExpiration = 60; - private ConfigManager(Arguments args) throws IOException { prop = new Properties(); @@ -866,10 +860,6 @@ public boolean isSnapshotDisabled() { return snapshotDisabled; } - public boolean isTokenEnabled() { - return tokenAuthorizationEnabled; - } - public Integer getTimeToExpiration() { if (prop.getProperty(TS_TOKEN_EXPIRATION_TIME_MIN) != null) { try { diff --git a/plugins/endpoints/src/main/java/org/pytorch/serve/plugins/endpoint/Token.java b/plugins/endpoints/src/main/java/org/pytorch/serve/plugins/endpoint/Token.java index 7cfd1ef997..1ee4fd3da8 100644 --- a/plugins/endpoints/src/main/java/org/pytorch/serve/plugins/endpoint/Token.java +++ b/plugins/endpoints/src/main/java/org/pytorch/serve/plugins/endpoint/Token.java @@ -39,7 +39,7 @@ public class Token extends ModelServerEndpoint { private static String inferenceKey; private static Instant managementExpirationTimeMinutes; private static Instant inferenceExpirationTimeMinutes; - private static Integer timeToExpirationMinutes = 60; + private static Integer timeToExpirationMinutes; private SecureRandom secureRandom = new SecureRandom(); private Base64.Encoder baseEncoder = Base64.getUrlEncoder(); private String fileName = "key_file.json"; @@ -58,7 +58,7 @@ public void doGet(Request req, Response rsp, Context ctx) throws IOException { rsp.getOutputStream().write(test.getBytes(StandardCharsets.UTF_8)); } - // parses query and either returns "management"/"inference" or a wrong type error + // parses query and either returns management/inference or a wrong type error public String parseQuery(Request req) { QueryStringDecoder decoder = new QueryStringDecoder(req.getRequestURI()); Map> parameters = decoder.parameters(); diff --git a/test/pytest/test_token_authorization.py b/test/pytest/test_token_authorization.py index a2a04e79b5..d12eee8a0b 100644 --- a/test/pytest/test_token_authorization.py +++ b/test/pytest/test_token_authorization.py @@ -19,7 +19,8 @@ def get_plugin_jar(): plugin_folder = os.path.join(REPO_ROOT, "plugins") os.makedirs(new_folder_path, exist_ok=True) os.chdir(plugin_folder) - subprocess.run(["./gradlew", "build"]) + subprocess.run(["./gradlew", "formatJava"]) + result = subprocess.run(["./gradlew", "build"]) jar_path = os.path.join(plugin_folder, "endpoints/build/libs") jar_file = [file for file in os.listdir(jar_path) if file.endswith(".jar")] if jar_file: @@ -84,6 +85,8 @@ def setup_torchserve(): "http://localhost:8081/models", params=params, headers=header ) time.sleep(5) + print("register reponse") + print(response.text) result = subprocess.run( f"cat {REPO_ROOT}/key_file.json", @@ -101,7 +104,8 @@ def test_managament_api_with_token(): setup_torchserve() key = read_key_file("management") header = {"Authorization": f"Bearer {key}"} - response = requests.get(f"http://localhost:8081/models/resnet18", headers=header) + print(key) + response = requests.get("http://localhost:8081/models/resnet18", headers=header) time.sleep(5) print(response.text) @@ -119,12 +123,6 @@ def test_managament_api_with_incorrect_token(): assert response.status_code == 400, "Token check failed" - response = requests.get(f"http://localhost:8081/models/resnet18") - time.sleep(5) - print(response.text) - - assert response.status_code == 400, "Token check failed" - # Test inference API with token enabled def test_inference_api_with_token(): @@ -138,6 +136,7 @@ def test_inference_api_with_token(): ) time.sleep(5) print(response.text) + print(key) assert response.status_code == 200, "Token check failed" @@ -162,18 +161,38 @@ def test_inference_api_with_incorrect_token(): def test_token_inference_api(): token_key = read_key_file("token") inference_key = read_key_file("inference") - header = {"Authorization": f"Bearer {token_key}"} + header_inference = {"Authorization": f"Bearer {inference_key}"} + header_token = {"Authorization": f"Bearer {token_key}"} params = {"type": "inference"} + # check inference works with current token + response = requests.post( + url="http://localhost:8080/predictions/resnet18", + files={"data": open(data_file_kitten, "rb")}, + headers=header_inference, + ) + time.sleep(5) + assert response.status_code == 200, "Token check failed" + + # generate new inference token and check it is different response = requests.get( - url="http://localhost:8081/token", params=params, headers=header + url="http://localhost:8081/token", params=params, headers=header_token ) time.sleep(5) print(response.text) - + print(token_key) assert response.status_code == 200, "Token check failed" assert inference_key != read_key_file("inference"), "Key file not updated" + # check inference does not works with original token + response = requests.post( + url="http://localhost:8080/predictions/resnet18", + files={"data": open(data_file_kitten, "rb")}, + headers=header_inference, + ) + time.sleep(5) + assert response.status_code == 400, "Token check failed" + # Test Token API for regenerating new management key def test_token_management_api(): @@ -186,6 +205,8 @@ def test_token_management_api(): url="http://localhost:8081/token", params=params, headers=header ) time.sleep(5) + print(response.text) + print(token_key) assert management_key != read_key_file("management"), "Key file not updated" assert response.status_code == 200, "Token check failed" From eb37eff0d1f7f1206540c9926841a1afbb900363 Mon Sep 17 00:00:00 2001 From: udaij12 Date: Thu, 15 Feb 2024 16:31:57 -0800 Subject: [PATCH 20/28] changed keyfile to dictionary and updated readme and tests --- docs/token_authorization_api.md | 27 +++++++++++++------ .../pytorch/serve/plugins/endpoint/Token.java | 26 ++++++++++-------- test/pytest/test_token_authorization.py | 15 +++++------ 3 files changed, 41 insertions(+), 27 deletions(-) diff --git a/docs/token_authorization_api.md b/docs/token_authorization_api.md index ce033f1187..19f7dc8a80 100644 --- a/docs/token_authorization_api.md +++ b/docs/token_authorization_api.md @@ -5,22 +5,33 @@ 2. Torchserve will enable token authorization if the plugin is provided. In the current working directory a file `key_file.json` will be generated. 1. Example key file: - `Management Key: aadJv_R6 --- Expiration time: 2024-01-16T22:23:32.952499Z` +```python + [ + { + "Management Key": "I_J_ItMb", + "ExpirationTime": "2024-02-16T01:27:56.749292Z" + }, + { + "Inference Key": "FINhR1fj", + "ExpirationTime": "2024-02-16T01:27:56.749273Z" + }, + { + "API Key": "m4M-5IBY" + } + ] +``` - `Inference Key: poZXAlqe --- Expiration time: 2024-01-16T22:23:50.621298Z` - - `API Key: xryL_Vzs` 3. There are 3 keys and each have a different use. 1. Management key: Used for management APIs. Example: - `curl http://localhost:8081/models/densenet161 -H "Authorization: Bearer aadJv_R6"` + `curl http://localhost:8081/models/densenet161 -H "Authorization: Bearer I_J_ItMb"` 2. Inference key: Used for inference APIs. Example: - `curl http://127.0.0.1:8080/predictions/densenet161 -T examples/image_classifier/kitten.jpg -H "Authorization: Bearer poZXAlqe"` + `curl http://127.0.0.1:8080/predictions/densenet161 -T examples/image_classifier/kitten.jpg -H "Authorization: Bearer FINhR1fj"` 3. API key: Used for the token authorization API. Check section 4 for API use. 4. The plugin also includes an API in order to generate a new key to replace either the management or inference key. 1. Management Example: - `curl localhost:8081/token?type=management -H "Authorization: Bearer xryL_Vzs"` will replace the current management key in the key_file with a new one and will update the expiration time. + `curl localhost:8081/token?type=management -H "Authorization: Bearer m4M-5IBY"` will replace the current management key in the key_file with a new one and will update the expiration time. 2. Inference example: - `curl localhost:8081/token?type=inference -H "Authorization: Bearer xryL_Vzs"` + `curl localhost:8081/token?type=inference -H "Authorization: Bearer m4M-5IBY"` Users will have to use either one of the APIs above. diff --git a/plugins/endpoints/src/main/java/org/pytorch/serve/plugins/endpoint/Token.java b/plugins/endpoints/src/main/java/org/pytorch/serve/plugins/endpoint/Token.java index 1ee4fd3da8..df951434e4 100644 --- a/plugins/endpoints/src/main/java/org/pytorch/serve/plugins/endpoint/Token.java +++ b/plugins/endpoints/src/main/java/org/pytorch/serve/plugins/endpoint/Token.java @@ -3,6 +3,8 @@ // import java.util.Properties; import com.google.gson.GsonBuilder; import com.google.gson.JsonArray; +import com.google.gson.JsonObject; + import io.netty.handler.codec.http.FullHttpRequest; import io.netty.handler.codec.http.QueryStringDecoder; import java.io.File; @@ -110,17 +112,19 @@ public boolean generateKeyFile(String type) throws IOException { } JsonArray jsonArray = new JsonArray(); - jsonArray.add( - "Management Key: " - + managementKey - + " --- Expiration time: " - + managementExpirationTimeMinutes); - jsonArray.add( - "Inference Key: " - + inferenceKey - + " --- Expiration time: " - + inferenceExpirationTimeMinutes); - jsonArray.add("API Key: " + apiKey); + JsonObject managementObject = new JsonObject(); + managementObject.addProperty("Management Key", managementKey); + managementObject.addProperty("ExpirationTime", managementExpirationTimeMinutes.toString()); + jsonArray.add(managementObject); + + JsonObject inferenceObject = new JsonObject(); + inferenceObject.addProperty("Inference Key", inferenceKey); + inferenceObject.addProperty("ExpirationTime", inferenceExpirationTimeMinutes.toString()); + jsonArray.add(inferenceObject); + + JsonObject apiKeyObject = new JsonObject(); + apiKeyObject.addProperty("API Key", apiKey); + jsonArray.add(apiKeyObject); Files.write( Paths.get(fileName), diff --git a/test/pytest/test_token_authorization.py b/test/pytest/test_token_authorization.py index d12eee8a0b..cabc646c76 100644 --- a/test/pytest/test_token_authorization.py +++ b/test/pytest/test_token_authorization.py @@ -47,14 +47,13 @@ def read_key_file(type): management_key = None inference_key = None api_key = None - for key_string in json_data: - if "Management Key" in key_string: - management_key = key_string.split(":")[1].strip().split("---")[0].strip() - elif "Inference Key" in key_string: - inference_key = key_string.split(":")[1].strip().split("---")[0].strip() - elif "API Key" in key_string: - api_key = key_string.split(":")[1].strip().split("---")[0].strip() - + for item in json_data: + if "Management Key" in item: + management_key = item["Management Key"] + elif "Inference Key" in item: + inference_key = item["Inference Key"] + elif "API Key" in item: + api_key = item["API Key"] options = { "management": management_key, "inference": inference_key, From 71bc7f5fe9eec150c75b93f0e23bb40d1aa00f3e Mon Sep 17 00:00:00 2001 From: udaij12 Date: Fri, 16 Feb 2024 10:23:58 -0800 Subject: [PATCH 21/28] remove comments --- .../org/pytorch/serve/http/TokenAuthorizationHandler.java | 5 ----- 1 file changed, 5 deletions(-) diff --git a/frontend/server/src/main/java/org/pytorch/serve/http/TokenAuthorizationHandler.java b/frontend/server/src/main/java/org/pytorch/serve/http/TokenAuthorizationHandler.java index cf5ded3bd8..a788c92ebf 100644 --- a/frontend/server/src/main/java/org/pytorch/serve/http/TokenAuthorizationHandler.java +++ b/frontend/server/src/main/java/org/pytorch/serve/http/TokenAuthorizationHandler.java @@ -42,7 +42,6 @@ public void handleRequest( throws ModelException, DownloadArchiveException, WorkflowException, WorkerInitializationException { if (tokenEnabled) { - // ConfigManager configManager = ConfigManager.getInstance(); if (tokenType == TokenType.MANAGEMENT) { if (req.toString().contains("/token")) { checkTokenAuthorization(req, "token"); @@ -70,10 +69,6 @@ public static void setupTokenClass() { if ((boolean) method.invoke(tokenObject, "token")) { logger.info("TOKEN CLASS IMPORTED SUCCESSFULLY"); } - // } catch (ClassNotFoundException e) { - // logger.error("TOKEN CLASS IMPORTED UNSUCCESSFULLY"); - // e.printStackTrace(); - // return; } catch (NoSuchMethodException | IllegalAccessException | InstantiationException From 3e182309b3a2acfebbfce074ed17678c4dcd5403 Mon Sep 17 00:00:00 2001 From: udaij12 Date: Fri, 16 Feb 2024 13:22:34 -0800 Subject: [PATCH 22/28] changes to tests --- docs/token_authorization_api.md | 26 ++-- .../pytorch/serve/plugins/endpoint/Token.java | 25 ++-- test/pytest/test_token_authorization.py | 118 ++++++++++-------- ts/model_server.py | 5 +- 4 files changed, 92 insertions(+), 82 deletions(-) diff --git a/docs/token_authorization_api.md b/docs/token_authorization_api.md index 19f7dc8a80..0c7907da94 100644 --- a/docs/token_authorization_api.md +++ b/docs/token_authorization_api.md @@ -6,19 +6,19 @@ 1. Example key file: ```python - [ - { - "Management Key": "I_J_ItMb", - "ExpirationTime": "2024-02-16T01:27:56.749292Z" - }, - { - "Inference Key": "FINhR1fj", - "ExpirationTime": "2024-02-16T01:27:56.749273Z" - }, - { - "API Key": "m4M-5IBY" - } - ] + { + "management": { + "key": "B-E5KSRM", + "expiration time": "2024-02-16T21:12:24.801167Z" + }, + "inference": { + "key": "gNRuA7dS", + "expiration time": "2024-02-16T21:12:24.801148Z" + }, + "API": { + "key": "yv9uQajP" + } +} ``` 3. There are 3 keys and each have a different use. diff --git a/plugins/endpoints/src/main/java/org/pytorch/serve/plugins/endpoint/Token.java b/plugins/endpoints/src/main/java/org/pytorch/serve/plugins/endpoint/Token.java index df951434e4..5ca2ceb8b4 100644 --- a/plugins/endpoints/src/main/java/org/pytorch/serve/plugins/endpoint/Token.java +++ b/plugins/endpoints/src/main/java/org/pytorch/serve/plugins/endpoint/Token.java @@ -2,9 +2,7 @@ // import java.util.Properties; import com.google.gson.GsonBuilder; -import com.google.gson.JsonArray; import com.google.gson.JsonObject; - import io.netty.handler.codec.http.FullHttpRequest; import io.netty.handler.codec.http.QueryStringDecoder; import java.io.File; @@ -111,27 +109,28 @@ public boolean generateKeyFile(String type) throws IOException { managementExpirationTimeMinutes = generateTokenExpiration(); } - JsonArray jsonArray = new JsonArray(); + JsonObject parentObject = new JsonObject(); + JsonObject managementObject = new JsonObject(); - managementObject.addProperty("Management Key", managementKey); - managementObject.addProperty("ExpirationTime", managementExpirationTimeMinutes.toString()); - jsonArray.add(managementObject); + managementObject.addProperty("key", managementKey); + managementObject.addProperty("expiration time", managementExpirationTimeMinutes.toString()); + parentObject.add("management", managementObject); JsonObject inferenceObject = new JsonObject(); - inferenceObject.addProperty("Inference Key", inferenceKey); - inferenceObject.addProperty("ExpirationTime", inferenceExpirationTimeMinutes.toString()); - jsonArray.add(inferenceObject); + inferenceObject.addProperty("key", inferenceKey); + inferenceObject.addProperty("expiration time", inferenceExpirationTimeMinutes.toString()); + parentObject.add("inference", inferenceObject); - JsonObject apiKeyObject = new JsonObject(); - apiKeyObject.addProperty("API Key", apiKey); - jsonArray.add(apiKeyObject); + JsonObject apiObject = new JsonObject(); + apiObject.addProperty("key", apiKey); + parentObject.add("API", apiObject); Files.write( Paths.get(fileName), new GsonBuilder() .setPrettyPrinting() .create() - .toJson(jsonArray) + .toJson(parentObject) .getBytes(StandardCharsets.UTF_8)); if (!setFilePermissions()) { diff --git a/test/pytest/test_token_authorization.py b/test/pytest/test_token_authorization.py index cabc646c76..d101e53db6 100644 --- a/test/pytest/test_token_authorization.py +++ b/test/pytest/test_token_authorization.py @@ -4,13 +4,16 @@ import subprocess import tempfile import time +from pathlib import Path +import pytest import requests import test_utils ROOT_DIR = os.path.join(tempfile.gettempdir(), "workspace") REPO_ROOT = os.path.join(os.path.dirname(os.path.abspath(__file__)), "../../") data_file_kitten = os.path.join(REPO_ROOT, "test/pytest/test_data/kitten.jpg") +config_file = os.path.join(REPO_ROOT, "test/resources/config_token.properties") # Set up token plugin @@ -44,32 +47,26 @@ def read_key_file(type): json_data = json.load(json_file) # Extract the three keys - management_key = None - inference_key = None - api_key = None - for item in json_data: - if "Management Key" in item: - management_key = item["Management Key"] - elif "Inference Key" in item: - inference_key = item["Inference Key"] - elif "API Key" in item: - api_key = item["API Key"] + # management_key = + # inference_key = json_data.get("inference", {}).get("key", "NOT_PRESENT") + # api_key = json_data.get("API", {}).get("key", "NOT_PRESENT") + options = { - "management": management_key, - "inference": inference_key, - "token": api_key, + "management": json_data.get("management", {}).get("key", "NOT_PRESENT"), + "inference": json_data.get("inference", {}).get("key", "NOT_PRESENT"), + "token": json_data.get("API", {}).get("key", "NOT_PRESENT"), } key = options.get(type, "Invalid data type") return key +@pytest.fixture(scope="module") def setup_torchserve(): get_plugin_jar() MODEL_STORE = os.path.join(ROOT_DIR, "model_store/") PLUGIN_STORE = os.path.join(ROOT_DIR, "plugins-path") test_utils.start_torchserve(no_config_snapshots=True, plugin_folder=PLUGIN_STORE) - time.sleep(10) key = read_key_file("management") header = {"Authorization": f"Bearer {key}"} @@ -83,48 +80,66 @@ def setup_torchserve(): response = requests.post( "http://localhost:8081/models", params=params, headers=header ) - time.sleep(5) - print("register reponse") - print(response.text) + file_content = Path(f"{REPO_ROOT}/key_file.json").read_text() + print(file_content) - result = subprocess.run( - f"cat {REPO_ROOT}/key_file.json", - shell=True, - capture_output=True, - text=True, + yield "test" + + test_utils.stop_torchserve() + + +@pytest.fixture(scope="module") +def setup_torchserve_expiration(): + get_plugin_jar() + MODEL_STORE = os.path.join(ROOT_DIR, "model_store/") + PLUGIN_STORE = os.path.join(ROOT_DIR, "plugins-path") + + test_utils.start_torchserve( + snapshot_file=config_file, + no_config_snapshots=True, + plugin_folder=PLUGIN_STORE, ) - print("Curl output:") - print(result.stdout) + key = read_key_file("management") + header = {"Authorization": f"Bearer {key}"} + + params = ( + ("model_name", "resnet18"), + ("url", "resnet-18.mar"), + ("initial_workers", "1"), + ("synchronous", "true"), + ) + response = requests.post( + "http://localhost:8081/models", params=params, headers=header + ) + file_content = Path(f"{REPO_ROOT}/key_file.json").read_text() + print(file_content) + + yield "test" -# Test describe model API with token enabled -def test_managament_api_with_token(): test_utils.stop_torchserve() - setup_torchserve() + + +# Test describe model API with token enabled +def test_managament_api_with_token(setup_torchserve): key = read_key_file("management") header = {"Authorization": f"Bearer {key}"} - print(key) response = requests.get("http://localhost:8081/models/resnet18", headers=header) - time.sleep(5) - print(response.text) assert response.status_code == 200, "Token check failed" # Test describe model API with incorrect token and no token -def test_managament_api_with_incorrect_token(): +def test_managament_api_with_incorrect_token(setup_torchserve): # Using random key header = {"Authorization": "Bearer abcd1234"} - response = requests.get(f"http://localhost:8081/models/resnet18", headers=header) - time.sleep(5) - print(response.text) assert response.status_code == 400, "Token check failed" # Test inference API with token enabled -def test_inference_api_with_token(): +def test_inference_api_with_token(setup_torchserve): key = read_key_file("inference") header = {"Authorization": f"Bearer {key}"} @@ -133,15 +148,12 @@ def test_inference_api_with_token(): files={"data": open(data_file_kitten, "rb")}, headers=header, ) - time.sleep(5) - print(response.text) - print(key) assert response.status_code == 200, "Token check failed" # Test inference API with incorrect token -def test_inference_api_with_incorrect_token(): +def test_inference_api_with_incorrect_token(setup_torchserve): # Using random key header = {"Authorization": "Bearer abcd1234"} @@ -150,14 +162,12 @@ def test_inference_api_with_incorrect_token(): files={"data": open(data_file_kitten, "rb")}, headers=header, ) - time.sleep(5) - print(response.text) assert response.status_code == 400, "Token check failed" # Test Token API for regenerating new inference key -def test_token_inference_api(): +def test_token_inference_api(setup_torchserve): token_key = read_key_file("token") inference_key = read_key_file("inference") header_inference = {"Authorization": f"Bearer {inference_key}"} @@ -170,16 +180,12 @@ def test_token_inference_api(): files={"data": open(data_file_kitten, "rb")}, headers=header_inference, ) - time.sleep(5) assert response.status_code == 200, "Token check failed" # generate new inference token and check it is different response = requests.get( url="http://localhost:8081/token", params=params, headers=header_token ) - time.sleep(5) - print(response.text) - print(token_key) assert response.status_code == 200, "Token check failed" assert inference_key != read_key_file("inference"), "Key file not updated" @@ -189,12 +195,11 @@ def test_token_inference_api(): files={"data": open(data_file_kitten, "rb")}, headers=header_inference, ) - time.sleep(5) assert response.status_code == 400, "Token check failed" # Test Token API for regenerating new management key -def test_token_management_api(): +def test_token_management_api(setup_torchserve): token_key = read_key_file("token") management_key = read_key_file("management") header = {"Authorization": f"Bearer {token_key}"} @@ -203,10 +208,19 @@ def test_token_management_api(): response = requests.get( url="http://localhost:8081/token", params=params, headers=header ) - time.sleep(5) - print(response.text) - print(token_key) assert management_key != read_key_file("management"), "Key file not updated" assert response.status_code == 200, "Token check failed" - test_utils.stop_torchserve() + + +# Test expiration time +def test_token_expiration_time(setup_torchserve_expiration): + key = read_key_file("management") + header = {"Authorization": f"Bearer {key}"} + response = requests.get("http://localhost:8081/models/resnet18", headers=header) + assert response.status_code == 200, "Token check failed" + + time.sleep(60) + + response = requests.get("http://localhost:8081/models/resnet18", headers=header) + assert response.status_code == 400, "Token check failed" diff --git a/ts/model_server.py b/ts/model_server.py index eeefd833fd..3b4a2d882e 100644 --- a/ts/model_server.py +++ b/ts/model_server.py @@ -48,10 +48,7 @@ def start() -> None: try: parent = psutil.Process(pid) parent.terminate() - try: - os.remove(os.getcwd() + "/key_file.json") - except FileNotFoundError: - print("Delete key file if it exists") + os.remove(os.getcwd() + "/key_file.json") if args.foreground: try: parent.wait(timeout=60) From f69c632869bd50718d4a339d68e43402017654bd Mon Sep 17 00:00:00 2001 From: udaij12 Date: Fri, 16 Feb 2024 13:59:31 -0800 Subject: [PATCH 23/28] added config file --- test/resources/config_token.properties | 1 + 1 file changed, 1 insertion(+) create mode 100644 test/resources/config_token.properties diff --git a/test/resources/config_token.properties b/test/resources/config_token.properties new file mode 100644 index 0000000000..2bf3499772 --- /dev/null +++ b/test/resources/config_token.properties @@ -0,0 +1 @@ +token_expiration_min=1 From 55cedd538d23dbcd130cdd89d17ba2f1c0e43071 Mon Sep 17 00:00:00 2001 From: udaij12 Date: Fri, 16 Feb 2024 16:21:12 -0800 Subject: [PATCH 24/28] reduce time for expiration test --- docs/token_authorization_api.md | 2 +- .../pytorch/serve/http/TokenAuthorizationHandler.java | 8 ++++---- .../java/org/pytorch/serve/util/ConfigManager.java | 6 +++--- .../org/pytorch/serve/plugins/endpoint/Token.java | 8 ++++---- test/pytest/test_token_authorization.py | 11 +++++++---- test/resources/config_token.properties | 2 +- ts/model_server.py | 3 ++- 7 files changed, 22 insertions(+), 18 deletions(-) diff --git a/docs/token_authorization_api.md b/docs/token_authorization_api.md index 0c7907da94..10e346e53b 100644 --- a/docs/token_authorization_api.md +++ b/docs/token_authorization_api.md @@ -1,7 +1,7 @@ # TorchServe token authorization API ## Configuration -1. Enable token authorization by adding the provided plugin at start using the `--plugin-path` command. +1. Enable token authorization by adding the provided plugin at start using the `--plugins-path` command. 2. Torchserve will enable token authorization if the plugin is provided. In the current working directory a file `key_file.json` will be generated. 1. Example key file: diff --git a/frontend/server/src/main/java/org/pytorch/serve/http/TokenAuthorizationHandler.java b/frontend/server/src/main/java/org/pytorch/serve/http/TokenAuthorizationHandler.java index a788c92ebf..cab59f11b5 100644 --- a/frontend/server/src/main/java/org/pytorch/serve/http/TokenAuthorizationHandler.java +++ b/frontend/server/src/main/java/org/pytorch/serve/http/TokenAuthorizationHandler.java @@ -26,7 +26,7 @@ public class TokenAuthorizationHandler extends HttpRequestHandlerChain { private static Boolean tokenEnabled = false; private static Class tokenClass; private static Object tokenObject; - private static Integer timeToExpirationMinutes = 60; + private static Double timeToExpirationMinutes = 60.0; /** Creates a new {@code InferenceRequestHandler} instance. */ public TokenAuthorizationHandler(TokenType type) { @@ -59,9 +59,9 @@ public static void setupTokenClass() { try { tokenClass = Class.forName("org.pytorch.serve.plugins.endpoint.Token"); tokenObject = tokenClass.getDeclaredConstructor().newInstance(); - Method method = tokenClass.getMethod("setTime", Integer.class); - Integer time = ConfigManager.getInstance().getTimeToExpiration(); - if (time != 0) { + Method method = tokenClass.getMethod("setTime", Double.class); + Double time = ConfigManager.getInstance().getTimeToExpiration(); + if (time != 0.0) { timeToExpirationMinutes = time; } method.invoke(tokenObject, timeToExpirationMinutes); diff --git a/frontend/server/src/main/java/org/pytorch/serve/util/ConfigManager.java b/frontend/server/src/main/java/org/pytorch/serve/util/ConfigManager.java index 05b4324131..849b895ddd 100644 --- a/frontend/server/src/main/java/org/pytorch/serve/util/ConfigManager.java +++ b/frontend/server/src/main/java/org/pytorch/serve/util/ConfigManager.java @@ -860,15 +860,15 @@ public boolean isSnapshotDisabled() { return snapshotDisabled; } - public Integer getTimeToExpiration() { + public Double getTimeToExpiration() { if (prop.getProperty(TS_TOKEN_EXPIRATION_TIME_MIN) != null) { try { - return Integer.valueOf(prop.getProperty(TS_TOKEN_EXPIRATION_TIME_MIN)); + return Double.valueOf(prop.getProperty(TS_TOKEN_EXPIRATION_TIME_MIN)); } catch (NumberFormatException e) { logger.error("Token expiration not a valid integer"); } } - return 0; + return 0.0; } public boolean isSSLEnabled(ConnectorType connectorType) { diff --git a/plugins/endpoints/src/main/java/org/pytorch/serve/plugins/endpoint/Token.java b/plugins/endpoints/src/main/java/org/pytorch/serve/plugins/endpoint/Token.java index 5ca2ceb8b4..805b6169ff 100644 --- a/plugins/endpoints/src/main/java/org/pytorch/serve/plugins/endpoint/Token.java +++ b/plugins/endpoints/src/main/java/org/pytorch/serve/plugins/endpoint/Token.java @@ -19,7 +19,6 @@ import java.util.List; import java.util.Map; import java.util.Set; -import java.util.concurrent.TimeUnit; import org.pytorch.serve.servingsdk.Context; import org.pytorch.serve.servingsdk.ModelServerEndpoint; import org.pytorch.serve.servingsdk.annotations.Endpoint; @@ -39,7 +38,7 @@ public class Token extends ModelServerEndpoint { private static String inferenceKey; private static Instant managementExpirationTimeMinutes; private static Instant inferenceExpirationTimeMinutes; - private static Integer timeToExpirationMinutes; + private static Double timeToExpirationMinutes; private SecureRandom secureRandom = new SecureRandom(); private Base64.Encoder baseEncoder = Base64.getUrlEncoder(); private String fileName = "key_file.json"; @@ -80,7 +79,8 @@ public String generateKey() { } public Instant generateTokenExpiration() { - return Instant.now().plusSeconds(TimeUnit.MINUTES.toSeconds(timeToExpirationMinutes)); + long secondsToAdd = (long) (timeToExpirationMinutes * 60); + return Instant.now().plusSeconds(secondsToAdd); } // generates a key file with new keys depending on the parameter provided @@ -217,7 +217,7 @@ public Instant getManagementExpirationTime() { return managementExpirationTimeMinutes; } - public void setTime(Integer time) { + public void setTime(Double time) { timeToExpirationMinutes = time; } } diff --git a/test/pytest/test_token_authorization.py b/test/pytest/test_token_authorization.py index d101e53db6..4cda0ae609 100644 --- a/test/pytest/test_token_authorization.py +++ b/test/pytest/test_token_authorization.py @@ -66,6 +66,8 @@ def setup_torchserve(): MODEL_STORE = os.path.join(ROOT_DIR, "model_store/") PLUGIN_STORE = os.path.join(ROOT_DIR, "plugins-path") + Path(test_utils.MODEL_STORE).mkdir(parents=True, exist_ok=True) + test_utils.start_torchserve(no_config_snapshots=True, plugin_folder=PLUGIN_STORE) key = read_key_file("management") @@ -94,10 +96,10 @@ def setup_torchserve_expiration(): MODEL_STORE = os.path.join(ROOT_DIR, "model_store/") PLUGIN_STORE = os.path.join(ROOT_DIR, "plugins-path") + Path(test_utils.MODEL_STORE).mkdir(parents=True, exist_ok=True) + test_utils.start_torchserve( - snapshot_file=config_file, - no_config_snapshots=True, - plugin_folder=PLUGIN_STORE, + snapshot_file=config_file, no_config_snapshots=True, plugin_folder=PLUGIN_STORE ) key = read_key_file("management") @@ -214,13 +216,14 @@ def test_token_management_api(setup_torchserve): # Test expiration time +@pytest.mark.module2 def test_token_expiration_time(setup_torchserve_expiration): key = read_key_file("management") header = {"Authorization": f"Bearer {key}"} response = requests.get("http://localhost:8081/models/resnet18", headers=header) assert response.status_code == 200, "Token check failed" - time.sleep(60) + time.sleep(15) response = requests.get("http://localhost:8081/models/resnet18", headers=header) assert response.status_code == 400, "Token check failed" diff --git a/test/resources/config_token.properties b/test/resources/config_token.properties index 2bf3499772..b62cd26870 100644 --- a/test/resources/config_token.properties +++ b/test/resources/config_token.properties @@ -1 +1 @@ -token_expiration_min=1 +token_expiration_min=0.25 diff --git a/ts/model_server.py b/ts/model_server.py index 3b4a2d882e..7291c250c4 100644 --- a/ts/model_server.py +++ b/ts/model_server.py @@ -3,6 +3,7 @@ """ import os +import pathlib import platform import re import subprocess @@ -48,7 +49,7 @@ def start() -> None: try: parent = psutil.Process(pid) parent.terminate() - os.remove(os.getcwd() + "/key_file.json") + pathlib.Path("key_file.json").unlink(missing_ok=True) if args.foreground: try: parent.wait(timeout=60) From 54841d9e9497f7243edac3aaa98b720585d9c21d Mon Sep 17 00:00:00 2001 From: udaij12 Date: Fri, 16 Feb 2024 17:14:40 -0800 Subject: [PATCH 25/28] change test to mnist --- test/pytest/test_data/0.png | Bin 0 -> 272 bytes test/pytest/test_token_authorization.py | 39 +++++++++++------------- 2 files changed, 17 insertions(+), 22 deletions(-) create mode 100644 test/pytest/test_data/0.png diff --git a/test/pytest/test_data/0.png b/test/pytest/test_data/0.png new file mode 100644 index 0000000000000000000000000000000000000000..a193c47ba45d876f6231d161c76a406da3b9418e GIT binary patch literal 272 zcmV+r0q_2aP)Sge`mIU!0UhiUd`b`lMFoj4+8$Ys$jqXQkTGTGcz!y=n4SE WdUQhQ?LY4T0000 Date: Fri, 16 Feb 2024 17:19:24 -0800 Subject: [PATCH 26/28] removing install from src --- test/pytest/test_token_authorization.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/test/pytest/test_token_authorization.py b/test/pytest/test_token_authorization.py index c279891268..54089da9b5 100644 --- a/test/pytest/test_token_authorization.py +++ b/test/pytest/test_token_authorization.py @@ -32,12 +32,12 @@ def get_plugin_jar(): os.path.join(new_folder_path, jar_file[0]), ) os.chdir(REPO_ROOT) - result = subprocess.run( - f"python ts_scripts/install_from_source", - shell=True, - capture_output=True, - text=True, - ) + # result = subprocess.run( + # f"python ts_scripts/install_from_source", + # shell=True, + # capture_output=True, + # text=True, + # ) # Parse json file and return key From 698b95a876164f057f780248e9e76b540d84cab3 Mon Sep 17 00:00:00 2001 From: udaij12 Date: Fri, 16 Feb 2024 20:58:33 -0800 Subject: [PATCH 27/28] final test change --- test/pytest/test_token_authorization.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/test/pytest/test_token_authorization.py b/test/pytest/test_token_authorization.py index 54089da9b5..4fa6a4799e 100644 --- a/test/pytest/test_token_authorization.py +++ b/test/pytest/test_token_authorization.py @@ -32,12 +32,6 @@ def get_plugin_jar(): os.path.join(new_folder_path, jar_file[0]), ) os.chdir(REPO_ROOT) - # result = subprocess.run( - # f"python ts_scripts/install_from_source", - # shell=True, - # capture_output=True, - # text=True, - # ) # Parse json file and return key From 68b3b041d5dbe903acbad3a295a11230628a83e2 Mon Sep 17 00:00:00 2001 From: Matthias Reso <13337103+mreso@users.noreply.github.com> Date: Mon, 19 Feb 2024 22:20:02 +0000 Subject: [PATCH 28/28] Fix spellcheck --- ts_scripts/spellcheck_conf/wordlist.txt | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/ts_scripts/spellcheck_conf/wordlist.txt b/ts_scripts/spellcheck_conf/wordlist.txt index 3cfa9e6840..aa050a38fb 100644 --- a/ts_scripts/spellcheck_conf/wordlist.txt +++ b/ts_scripts/spellcheck_conf/wordlist.txt @@ -1187,3 +1187,11 @@ FxGraphCache TorchInductor fx locustapache +FINhR +IBY +ItMb +checkTokenAuthorization +fj +generateKeyFile +setTime +urlPattern