Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[serving] Adds multiple node cluster configuration support #2190

Merged
merged 1 commit into from
Jul 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions serving/docker/config.properties
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
inference_address=http://0.0.0.0:8080
management_address=http://0.0.0.0:8080
cluster_address=http://0.0.0.0:8888
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: any reason for choosing 8888? Should it be a closer port like 8081?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

User may choose management port as 8081

model_store=/opt/ml/model
load_models=ALL
#model_url_pattern=.*
42 changes: 40 additions & 2 deletions serving/src/main/java/ai/djl/serving/ModelServer.java
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import ai.djl.serving.models.ModelManager;
import ai.djl.serving.plugins.DependencyManager;
import ai.djl.serving.plugins.FolderScanPluginManager;
import ai.djl.serving.util.ClusterConfig;
import ai.djl.serving.util.ConfigManager;
import ai.djl.serving.util.Connector;
import ai.djl.serving.util.ServerGroups;
Expand Down Expand Up @@ -194,7 +195,6 @@ public List<ChannelFuture> start()
GeneralSecurityException,
ServerStartupException {
long begin = System.nanoTime();
stopped.set(false);

String version = Engine.getDjlVersion();
logger.info("Starting djl-serving: {} ...", version);
Expand All @@ -204,13 +204,16 @@ public List<ChannelFuture> start()

pluginManager.loadPlugins(true);

initMultiNode();

try {
initModelStore();
} catch (BadWorkflowException | CompletionException e) {
throw new ServerStartupException(
"Failed to initialize startup models and workflows", e);
}

stopped.set(false);
Connector inferenceConnector =
configManager.getConnector(Connector.ConnectorType.INFERENCE);
Connector managementConnector =
Expand Down Expand Up @@ -273,6 +276,41 @@ public void stop() {
serverGroups.reset();
}

/**
 * Runs the multi-node handshake for this server.
 *
 * <p>When the cluster size reported by {@link ClusterConfig} is greater than one, a temporary
 * server is bound on the {@code CLUSTER} connector, this node counts itself down on the cluster
 * latch, and the call blocks in {@link ClusterConfig#await()}. Afterwards the temporary channel
 * is closed and the server groups are shut down and reset. For a cluster size of one or less
 * this method does nothing.
 *
 * @throws GeneralSecurityException propagated from {@code initializeServer}
 * @throws IOException propagated from {@code initializeServer}
 * @throws InterruptedException if the thread is interrupted while starting the server or waiting
 * @throws ServerStartupException if an error message is recorded in {@link ClusterConfig} once
 *     waiting has finished
 */
private void initMultiNode()
        throws GeneralSecurityException,
                IOException,
                InterruptedException,
                ServerStartupException {
    ClusterConfig cluster = ClusterConfig.getInstance();
    int nodes = cluster.getClusterSize();
    if (nodes <= 1) {
        // single node deployment: nothing to coordinate
        return;
    }

    Connector clusterConnector = configManager.getConnector(Connector.ConnectorType.CLUSTER);
    clusterConnector.clean();

    ChannelFuture clusterFuture =
            initializeServer(
                    clusterConnector,
                    serverGroups.getServerGroup(),
                    serverGroups.getChildGroup());

    // start download model here
    cluster.countDown();

    logger.info("Waiting for all worker nodes ready ...");
    cluster.await();

    // the cluster endpoint is only needed during the handshake; release it before continuing
    clusterFuture.channel().close();
    serverGroups.shutdown(true);
    serverGroups.reset();

    String failure = cluster.getError();
    if (failure != null) {
        throw new ServerStartupException("Failed to initialize cluster: " + failure);
    }
    logger.info("Cluster initialized with {} nodes.", nodes);
}

private ChannelFuture initializeServer(
Connector connector, EventLoopGroup serverGroup, EventLoopGroup workerGroup)
throws InterruptedException, IOException, GeneralSecurityException {
Expand Down Expand Up @@ -486,7 +524,7 @@ String mapModelUrl(Path path) {
} catch (MalformedURLException e) {
throw new AssertionError("Invalid path: " + path, e);
} catch (IOException e) {
logger.warn("Failed to access file: " + path, e);
logger.warn("Failed to access file: {}", path, e);
return null;
}
}
Expand Down
4 changes: 4 additions & 0 deletions serving/src/main/java/ai/djl/serving/ServerInitializer.java
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
package ai.djl.serving;

import ai.djl.serving.http.AdapterManagementRequestHandler;
import ai.djl.serving.http.ClusterRequestHandler;
import ai.djl.serving.http.ConfigurableHttpRequestHandler;
import ai.djl.serving.http.InferenceRequestHandler;
import ai.djl.serving.http.InvalidRequestHandler;
Expand Down Expand Up @@ -74,6 +75,9 @@ public void initChannel(Channel ch) {
case INFERENCE:
pipeline.addLast("inference", new InferenceRequestHandler());
break;
case CLUSTER:
pipeline.addLast("cluster", new ClusterRequestHandler());
break;
case BOTH:
default:
pipeline.addLast(new ConfigurableHttpRequestHandler(pluginManager));
Expand Down
105 changes: 105 additions & 0 deletions serving/src/main/java/ai/djl/serving/http/ClusterRequestHandler.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
/*
* Copyright 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
* with the License. A copy of the License is located at
*
* http://aws.amazon.com/apache2.0/
*
* or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
* OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
* and limitations under the License.
*/
package ai.djl.serving.http;

import ai.djl.ModelException;
import ai.djl.serving.util.ClusterConfig;
import ai.djl.serving.util.NettyUtils;
import ai.djl.util.Utils;

import io.netty.channel.ChannelHandlerContext;
import io.netty.handler.codec.http.FullHttpRequest;
import io.netty.handler.codec.http.QueryStringDecoder;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.IOException;
import java.io.InputStream;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.List;

/**
 * A class handling inbound HTTP requests for the cluster management API.
 *
 * <p>Only requests whose URI starts with {@code /cluster/} are accepted. Two resources are
 * served:
 *
 * <ul>
 *   <li>{@code sshkey}: sends {@code ~/.ssh/id_rsa.pub}, generating the RSA key pair with {@code
 *       ssh-keygen} first when the public key file does not exist
 *   <li>{@code status}: expects exactly one {@code message} query parameter; any value other than
 *       {@code OK} is recorded as the cluster error, then the cluster latch is counted down
 * </ul>
 */
public class ClusterRequestHandler extends HttpRequestHandler {

    private static final Logger logger = LoggerFactory.getLogger(ClusterRequestHandler.class);

    private final ClusterConfig config = ClusterConfig.getInstance();

    /** {@inheritDoc} */
    @Override
    public boolean acceptInboundMessage(Object msg) throws Exception {
        if (super.acceptInboundMessage(msg)) {
            FullHttpRequest req = (FullHttpRequest) msg;
            return req.uri().startsWith("/cluster/");
        }
        return false;
    }

    /** {@inheritDoc} */
    @Override
    protected void handleRequest(
            ChannelHandlerContext ctx,
            FullHttpRequest req,
            QueryStringDecoder decoder,
            String[] segments)
            throws ModelException {
        // A request without a resource name (presumably e.g. a bare "/cluster/") has no third
        // segment; report it as an unknown resource instead of failing on the array access.
        if (segments.length < 3) {
            throw new ResourceNotFoundException();
        }
        switch (segments[2]) {
            case "sshkey":
                Path home = Paths.get(System.getProperty("user.home")).resolve(".ssh");
                Path file = home.resolve("id_rsa.pub");
                if (Files.notExists(file)) {
                    sshkeygen(home.resolve("id_rsa").toString());
                }
                NettyUtils.sendFile(ctx, file, false);
                return;
            case "status":
                // parameters().get() returns null when the query has no "message" entry
                List<String> messages = decoder.parameters().get("message");
                if (messages == null || messages.size() != 1) {
                    NettyUtils.sendJsonResponse(ctx, new StatusResponse("Invalid request"));
                    return;
                }
                String message = messages.get(0);
                if (!"OK".equals(message)) {
                    config.setError(message);
                }
                config.countDown();
                NettyUtils.sendJsonResponse(ctx, new StatusResponse("OK"));
                return;
            default:
                throw new ResourceNotFoundException();
        }
    }

    /**
     * Generates an RSA key pair without a passphrase by running {@code ssh-keygen}.
     *
     * <p>On failure the problem is also recorded in the {@link ClusterConfig} error message.
     *
     * @param rsaFile the path of the private key file to create
     * @throws IllegalStateException if {@code ssh-keygen} cannot be run, is interrupted, or exits
     *     with a non-zero code
     */
    private void sshkeygen(String rsaFile) {
        try {
            // ProcessBuilder does not go through a shell, so the empty passphrase must be passed
            // as an empty argument; a quoted '' would be taken literally as a two character
            // passphrase.
            String[] commands = {"ssh-keygen", "-q", "-t", "rsa", "-N", "", "-f", rsaFile};
            Process exec = new ProcessBuilder(commands).redirectErrorStream(true).start();
            String logOutput;
            try (InputStream is = exec.getInputStream()) {
                logOutput = Utils.toString(is);
            }
            int exitCode = exec.waitFor();
            if (0 != exitCode) {
                logger.error("Generate ssh key failed: {}", logOutput);
                config.setError(logOutput);
                throw new IllegalStateException("Generate ssh key failed");
            } else {
                logger.debug(logOutput);
            }
        } catch (IOException | InterruptedException e) {
            if (e instanceof InterruptedException) {
                // keep the interrupt visible to the caller's thread
                Thread.currentThread().interrupt();
            }
            config.setError("Generate ssh key failed");
            throw new IllegalStateException("Generate ssh key failed", e);
        }
    }
}
87 changes: 87 additions & 0 deletions serving/src/main/java/ai/djl/serving/util/ClusterConfig.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
/*
* Copyright 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
* with the License. A copy of the License is located at
*
* http://aws.amazon.com/apache2.0/
*
* or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
* OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
* and limitations under the License.
*/
package ai.djl.serving.util;

import ai.djl.util.Utils;

import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;

/** A class that holds cluster configurations. */
public final class ClusterConfig {

private static final ClusterConfig INSTANCE = new ClusterConfig();

private int clusterSize;
private CountDownLatch latch;
private String error;

private ClusterConfig() {
clusterSize = Integer.parseInt(Utils.getenv("DJL_CLUSTER_SIZE", "1"));
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It might be slightly nicer to configure this using the config manager rather than just environment variables

Copy link
Contributor Author

@frankfliu frankfliu Jul 17, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  1. This is configured by LWS, which only supports env vars so far
  2. Other code (PyEnv) also just uses this env var

latch = new CountDownLatch(clusterSize);
}

/**
* Returns the {@code ClusterConfig} singleton object.
*
* <p>The same instance, held in a static final field, is returned on every call.
*
* @return the {@code ClusterConfig} singleton object
*/
public static ClusterConfig getInstance() {
return INSTANCE;
}

/**
* Returns the cluster size.
*
* <p>The value is parsed once, in the constructor, from {@code Utils.getenv("DJL_CLUSTER_SIZE",
* "1")}.
*
* @return the cluster size
*/
public int getClusterSize() {
return clusterSize;
}

/**
* Returns the error status message.
*
* @return the error status message, or {@code null} if none has been set
*/
public String getError() {
return error;
}

/**
* Sets the error status message, replacing any previously stored message.
*
* @param error the error status message
*/
public void setError(String error) {
this.error = error;
}

/**
* Decreases the number of waiting workers by counting down the latch that {@link #await()}
* blocks on. The latch is created with a count equal to the cluster size.
*/
public void countDown() {
latch.countDown();
}

/**
 * Blocks the calling thread until the latch reaches zero or the timeout elapses.
 *
 * <p>The timeout in seconds is parsed from {@code Utils.getenv("MODEL_LOADING_TIMEOUT", "240")}.
 * When the wait times out, the error status message is set to {@code "Worker nodes timed out"}.
 *
 * @throws InterruptedException if current thread is interrupted
 */
public void await() throws InterruptedException {
    // TODO: support per model timeout
    String configured = Utils.getenv("MODEL_LOADING_TIMEOUT", "240");
    int seconds = Integer.parseInt(configured);
    boolean allReady = latch.await(seconds, TimeUnit.SECONDS);
    if (allReady) {
        return;
    }
    error = "Worker nodes timed out";
}
}
17 changes: 13 additions & 4 deletions serving/src/main/java/ai/djl/serving/util/ConfigManager.java
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ public final class ConfigManager {

private static final String INFERENCE_ADDRESS = "inference_address";
private static final String MANAGEMENT_ADDRESS = "management_address";
private static final String CLUSTER_ADDRESS = "cluster_address";
private static final String LOAD_MODELS = "load_models";
private static final String WAIT_MODEL_LOADING = "wait_model_loading";
private static final String ALLOW_MULTI_STATUS = "allow_multi_status";
Expand Down Expand Up @@ -236,10 +237,18 @@ private boolean onError(String key) {
*/
public Connector getConnector(Connector.ConnectorType type) {
String binding;
if (type == Connector.ConnectorType.MANAGEMENT) {
binding = prop.getProperty(MANAGEMENT_ADDRESS, "http://127.0.0.1:8080");
} else {
binding = prop.getProperty(INFERENCE_ADDRESS, "http://127.0.0.1:8080");
switch (type) {
case MANAGEMENT:
binding = prop.getProperty(MANAGEMENT_ADDRESS, "http://127.0.0.1:8080");
break;
case CLUSTER:
binding = prop.getProperty(CLUSTER_ADDRESS, "http://127.0.0.1:8888");
break;
case INFERENCE:
case BOTH:
default:
binding = prop.getProperty(INFERENCE_ADDRESS, "http://127.0.0.1:8080");
break;
}
return Connector.parse(binding, type);
}
Expand Down
1 change: 1 addition & 0 deletions serving/src/main/java/ai/djl/serving/util/Connector.java
Original file line number Diff line number Diff line change
Expand Up @@ -257,6 +257,7 @@ public String toString() {

/** An enum represents type of connector. */
public enum ConnectorType {
CLUSTER,
INFERENCE,
MANAGEMENT,
BOTH
Expand Down
Loading