Support multi node for lmi-dist
xyang16 committed Jul 8, 2024
1 parent 04445f8 commit 1ab823e
Showing 10 changed files with 215 additions and 35 deletions.
5 changes: 5 additions & 0 deletions engines/python/setup/djl_python/arg_parser.py
@@ -71,6 +71,11 @@ def python_engine_args():
dest="tensor_parallel_degree",
type=int,
help='The tensor parallel degree')
parser.add_argument('--cluster-size',
required=False,
dest="cluster_size",
type=int,
help='The cluster size')
parser.add_argument('--recommended-entry-point',
required=False,
type=str,
@@ -45,6 +45,7 @@ class HuggingFaceProperties(Properties):
device_id: int = -1
task: str = None
tensor_parallel_degree: int = -1
cluster_size: int = 1
device_map: str = None
load_in_4bit: Optional[bool] = None
load_in_8bit: Optional[bool] = None
@@ -112,10 +113,12 @@ def construct_kwargs_device_map(self):
self.kwargs["device_map"] = self.device_map
self.device = None
logging.info(f"Using device map {self.device_map}")
elif self.tensor_parallel_degree > 0 and torch.cuda.device_count() > 0:
elif self.tensor_parallel_degree > 0 \
and self.cluster_size > 0 \
and torch.cuda.device_count() > 0:
self.kwargs["device_map"] = "auto"
self.device = None
world_size = torch.cuda.device_count()
world_size = torch.cuda.device_count() * self.cluster_size
assert world_size == self.tensor_parallel_degree, \
f"TP degree ({self.tensor_parallel_degree}) doesn't match available GPUs ({world_size})"
logging.info(f"Using {world_size} gpus")
@@ -49,6 +49,7 @@ class Properties(BaseModel):
# Make the default to auto, after java front end changes and test cases are changed.
rolling_batch: RollingBatchEnum = RollingBatchEnum.disable
tensor_parallel_degree: int = 1
cluster_size: int = 1
trust_remote_code: bool = False
enable_streaming: StreamingEnum = StreamingEnum.false
batch_size: int = 1
10 changes: 7 additions & 3 deletions engines/python/setup/djl_python_engine.py
@@ -47,23 +47,24 @@ def __init__(self, args, service):

self.model_dir = args.model_dir
self.sock_type = args.sock_type
self.sock_name = f"{args.sock_name}.{rank}" if rank else args.sock_name
self.port = args.port
self.service = service
self.device_id = args.device_id
self.tensor_parallel_degree = args.tensor_parallel_degree
self.cluster_size = args.cluster_size
self.entry_point = args.entry_point
self.recommended_entry_point = args.recommended_entry_point

if self.sock_type == "unix":
if self.sock_name is None:
raise ValueError("Missing sock-name argument.")
self.sock_name = f"{args.sock_name}.{rank}" if rank else args.sock_name

self.clean_up()
elif self.sock_type == "tcp":
self.sock_name = "127.0.0.1"
if self.port is None:
raise ValueError("Missing port argument.")
self.port = int(self.port) + int(rank) if rank else self.port
else:
raise ValueError(f"Invalid socket-type: {self.sock_type}.")

@@ -99,10 +100,11 @@ def run_server(self):
if self.sock_type == "unix":
self.sock.bind(self.sock_name)
else:
logging.info(f"Socket bind on address: {self.sock_name}:{self.port}")
self.sock.bind((self.sock_name, int(self.port)))

self.sock.listen(128)
logging.info("Python engine started.")
logging.info(f"Python engine started.")

(cl_socket, _) = self.sock.accept()
# workaround error(35, 'Resource temporarily unavailable') on OSX
@@ -115,6 +117,8 @@
prop = inputs.get_properties()
if self.tensor_parallel_degree:
prop["tensor_parallel_degree"] = self.tensor_parallel_degree
if self.cluster_size:
prop["cluster_size"] = self.cluster_size
prop["device_id"] = self.device_id
if "output_formatter" in prop and hasattr(
self.service, prop["output_formatter"]):
124 changes: 107 additions & 17 deletions engines/python/src/main/java/ai/djl/python/engine/Connection.java
@@ -50,6 +50,7 @@
import java.io.IOException;
import java.net.InetSocketAddress;
import java.net.SocketAddress;
import java.net.UnknownHostException;
import java.nio.ByteBuffer;
import java.nio.file.Files;
import java.nio.file.Path;
@@ -67,17 +68,19 @@ class Connection {
private static final String MASTER_ADDR = "127.0.0.1";

private int port;
private PyEnv pyEnv;
private SocketAddress socketAddress;
private Channel channel;
private RequestHandler requestHandler;

Connection(PyEnv pyEnv, int basePort, int rank) {
Connection(PyEnv pyEnv, int basePort, int rank, String hostname) {
this.pyEnv = pyEnv;
requestHandler = new RequestHandler();
port = 19000 + basePort;
socketAddress = getSocketAddress(pyEnv.isMpiMode(), rank);
socketAddress = getSocketAddress(pyEnv.isMpiMode(), pyEnv.getClusterSize(), rank, hostname);
}

static Process startPython(PyEnv pyEnv, Model model, int workerId, int port)
static Process startPython(PyEnv pyEnv, Model model, int workerId, int port, String[] hosts)
throws IOException {
Path tmp = Paths.get(System.getProperty("java.io.tmpdir"));
try (Stream<Path> stream = Files.list(tmp)) {
@@ -100,7 +103,7 @@ static Process startPython(PyEnv pyEnv, Model model, int workerId, int port)
});
}
File modelPath = model.getModelPath().toFile();
String[] args = getPythonStartCmd(pyEnv, model, workerId, port);
String[] args = getPythonStartCmd(pyEnv, model, workerId, port, hosts);
String[] envp = pyEnv.getEnvironmentVars(model);
logger.debug("cmd: {}", (Object) args);

@@ -120,16 +123,88 @@ CompletableFuture<Output> send(Input input) throws InterruptedException {
return f;
}

static String[] getPythonStartCmd(PyEnv pyEnv, Model model, int workerId, int port) {
static String[] getPythonStartCmd(
PyEnv pyEnv, Model model, int workerId, int port, String[] hosts) {
Device device = model.getNDManager().getDevice();
int deviceId = device.getDeviceId();
int tensorParallelDegree = pyEnv.getTensorParallelDegree();
String entryPoint = pyEnv.getEntryPoint();
String recommendedEntryPoint = pyEnv.getRecommendedEntryPoint();
// int pipelineParallelDegree = pyEnv.getPipelineParallelDegree();
int clusterSize = pyEnv.getClusterSize();

if (clusterSize > 1) {
String cudaDevices = getVisibleDevices(workerId, tensorParallelDegree);
logger.info("Set before mpirun CUDA_VISIBLE_DEVICES={}", cudaDevices);
StringBuilder sb = new StringBuilder();
boolean first = true;
for (String host : hosts) {
if (first) {
first = false;
} else {
sb.append(',');
}
sb.append(host).append(':').append(tensorParallelDegree / clusterSize);
}
String[] args = new String[48];
args[0] = "mpirun";
args[1] = "-np";
// TODO: When we support multi nodes, change it to the product of tensor parallel value and pipeline parallel value.
args[2] = String.valueOf(tensorParallelDegree);
args[3] = "--host";
args[4] = sb.toString();
args[5] = "--allow-run-as-root";
args[6] = "--bind-to";
args[7] = "none";
args[8] = "--mca";
args[9] = "orte_keep_fqdn_hostnames";
args[10] = "t";
args[11] = "--tag-output";
args[12] = "-x";
args[13] = "FI_PROVIDER=efa";
args[14] = "-x";
args[15] = "RDMAV_FORK_SAFE=1";
args[16] = "-x";
args[17] = "FI_EFA_USE_DEVICE_RDMA=1";
args[18] = "-x";
args[19] = "LD_LIBRARY_PATH";
args[20] = "-x";
args[21] = "PYTHONPATH";
args[22] = "-x";
args[23] = "CUDA_VISIBLE_DEVICES=" + cudaDevices;
args[24] = "-x";
args[25] = "MASTER_ADDR";
args[26] = "-x";
args[27] = "MASTER_PORT=" + port;
args[28] = "-x";
args[29] = "MKL_DYNAMIC=FALSE";
args[30] = pyEnv.getPythonExecutable();
args[31] = PyEnv.getEngineCacheDir() + "/djl_python_engine.py";
args[32] = "--model-dir";
args[33] = model.getModelPath().toAbsolutePath().toString();
args[34] = "--entry-point";
args[35] = pyEnv.getEntryPoint();
args[36] = "--sock-type";
args[37] = "tcp";
args[38] = "--sock-name";
args[39] = "0.0.0.0";
args[40] = "--port";
args[41] = String.valueOf(port);
args[42] = "--tensor-parallel-degree";
args[43] = String.valueOf(tensorParallelDegree);
args[44] = "--cluster-size";
args[45] = String.valueOf(clusterSize);
args[46] = "--recommended-entry-point";
args[47] = recommendedEntryPoint == null ? "" : recommendedEntryPoint;
return args;
}
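For readers unfamiliar with Open MPI host syntax: the loop above emits the --host argument as one host:slots pair per node, where the slot count is tensorParallelDegree / clusterSize, i.e. the number of ranks each node contributes. A self-contained sketch with hypothetical hostnames node-0 and node-1:

public final class HostStringDemo {
    public static void main(String[] args) {
        String[] hosts = {"node-0", "node-1"}; // hypothetical hostnames
        int tensorParallelDegree = 8;          // assumed TP degree
        int clusterSize = hosts.length;        // 2 nodes
        StringBuilder sb = new StringBuilder();
        for (String host : hosts) {
            if (sb.length() > 0) {
                sb.append(',');
            }
            // each host contributes tensorParallelDegree / clusterSize MPI slots
            sb.append(host).append(':').append(tensorParallelDegree / clusterSize);
        }
        System.out.println(sb); // prints "node-0:4,node-1:4"
    }
}

Under those assumed values, the generated launch has the shape: mpirun -np 8 --host node-0:4,node-1:4 ... <python> djl_python_engine.py ... --sock-type tcp --sock-name 0.0.0.0 --port <port> --tensor-parallel-degree 8 --cluster-size 2.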

if (pyEnv.isMpiMode()) {
String cudaDevices = getVisibleDevices(workerId, tensorParallelDegree);
logger.info("Set CUDA_VISIBLE_DEVICES={}", cudaDevices);
String[] args = new String[42];
String[] args = new String[44];
args[0] = "mpirun";
args[1] = "-np";
// TODO: When we support multi nodes, change it to the product of tensor parallel value
@@ -173,8 +248,10 @@ static String[] getPythonStartCmd(PyEnv pyEnv, Model model, int workerId, int po
args[37] = getSocketPath(port);
args[38] = "--tensor-parallel-degree";
args[39] = String.valueOf(tensorParallelDegree);
args[40] = "--recommended-entry-point";
args[41] = recommendedEntryPoint == null ? "" : recommendedEntryPoint;
args[40] = "--cluster-size";
args[41] = String.valueOf(clusterSize);
args[42] = "--recommended-entry-point";
args[43] = recommendedEntryPoint == null ? "" : recommendedEntryPoint;
return args;
}

@@ -196,7 +273,7 @@ static String[] getPythonStartCmd(PyEnv pyEnv, Model model, int workerId, int po
logger.info("Set OMP_NUM_THREADS={}", neuronThreads);
}
boolean uds = Epoll.isAvailable() || KQueue.isAvailable();
String[] args = new String[14];
String[] args = new String[16];
args[0] = pyEnv.getPythonExecutable();
args[1] = PyEnv.getEngineCacheDir() + "/djl_python_engine.py";
args[2] = "--sock-type";
@@ -209,8 +286,10 @@ static String[] getPythonStartCmd(PyEnv pyEnv, Model model, int workerId, int po
args[9] = entryPoint == null ? "" : entryPoint;
args[10] = "--device-id";
args[11] = String.valueOf(deviceId);
args[12] = "--recommended-entry-point";
args[13] = recommendedEntryPoint == null ? "" : recommendedEntryPoint;
args[12] = "--cluster-size";
args[13] = String.valueOf(clusterSize);
args[14] = "--recommended-entry-point";
args[15] = recommendedEntryPoint == null ? "" : recommendedEntryPoint;
return args;
}

@@ -248,13 +327,15 @@ private static String getNeuronThreads(int tensorParallelDegree) {
return String.valueOf(1);
}

void connect() throws InterruptedException {
void connect() throws InterruptedException, UnknownHostException {
logger.info("Connecting to socket: {}", socketAddress);
EventLoopGroup group = PyEnv.getEventLoopGroup();

Bootstrap clientBootstrap = new Bootstrap();

clientBootstrap
.group(group)
.channel(getClientChannel())
.channel(getClientChannel(pyEnv.getClusterSize()))
.remoteAddress(socketAddress)
.handler(
new ChannelInitializer<>() {
@@ -295,7 +376,11 @@ private static String getSocketPath(int port) {
return System.getProperty("java.io.tmpdir") + "/djl_sock." + port;
}

private SocketAddress getSocketAddress(boolean mpiMode, int rank) {
private SocketAddress getSocketAddress(
boolean mpiMode, int clusterSize, int rank, String hostname) {
if (clusterSize > 1) {
return new InetSocketAddress(hostname, port + rank);
}
if (mpiMode) {
return new DomainSocketAddress(getSocketPath(port) + '.' + rank);
}
@@ -306,17 +391,22 @@ private SocketAddress getSocketAddress(boolean mpiMode, int rank) {
return new InetSocketAddress("127.0.0.1", port);
}
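Both sides of the cluster-mode TCP scheme, for reference: the Python worker adds its rank to the base port when binding (the self.port = int(self.port) + int(rank) change above), while this method dials hostname:(port + rank), port being 19000 + basePort from the constructor. A small sketch, assuming mpirun places ranks on hosts in slot order (4 per host, matching the --host string built earlier):

import java.net.InetSocketAddress;

public final class RankAddressDemo {
    public static void main(String[] args) {
        int port = 19000;      // 19000 + basePort, with basePort assumed 0
        int slotsPerHost = 4;  // tensorParallelDegree / clusterSize, as above
        String[] hosts = {"node-0", "node-1"}; // hypothetical hostnames
        for (int rank = 0; rank < 8; rank++) {
            // rank r listens on port + r; the client connects to the same pair
            InetSocketAddress address =
                    new InetSocketAddress(hosts[rank / slotsPerHost], port + rank);
            System.out.println("rank " + rank + " -> " + address);
        }
    }
}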

static EventLoopGroup newEventLoopGroup() {
static EventLoopGroup newEventLoopGroup(int clusterSize) {
if (clusterSize > 1) {
return new NioEventLoopGroup(new DaemonThreadFactory());
}
if (Epoll.isAvailable()) {
return new EpollEventLoopGroup(new DaemonThreadFactory());
} else if (KQueue.isAvailable()) {
return new KQueueEventLoopGroup(new DaemonThreadFactory());
}

return new NioEventLoopGroup(new DaemonThreadFactory());
}
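The rationale for the new branch, as far as the surrounding code shows: Epoll and KQueue are preferred only for their Unix domain socket support, and a domain socket cannot reach another machine, so once clusterSize > 1 the client must speak TCP, which NIO provides on every platform (getClientChannel below gets the matching short-circuit to NioSocketChannel). A standalone sketch of that cluster-mode pairing, aimed at a hypothetical remote worker; this illustrates the Netty wiring and is not code from the diff:

import io.netty.bootstrap.Bootstrap;
import io.netty.channel.ChannelInitializer;
import io.netty.channel.EventLoopGroup;
import io.netty.channel.nio.NioEventLoopGroup;
import io.netty.channel.socket.SocketChannel;
import io.netty.channel.socket.nio.NioSocketChannel;
import java.net.InetSocketAddress;

public final class TcpClientSketch {
    public static void main(String[] args) {
        EventLoopGroup group = new NioEventLoopGroup();
        Bootstrap bootstrap = new Bootstrap()
                .group(group)
                .channel(NioSocketChannel.class) // TCP; domain sockets are local-only
                .remoteAddress(new InetSocketAddress("node-1", 19000)) // hypothetical worker
                .handler(new ChannelInitializer<SocketChannel>() {
                    @Override
                    protected void initChannel(SocketChannel ch) {
                        // request/response codecs would be registered here
                    }
                });
        // bootstrap.connect().sync() would actually dial; omitted in this sketch.
        group.shutdownGracefully();
    }
}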

private static Class<? extends Channel> getClientChannel() {
private static Class<? extends Channel> getClientChannel(int clusterSize) {
if (clusterSize > 1) {
return NioSocketChannel.class;
}
if (Epoll.isAvailable()) {
return EpollDomainSocketChannel.class;
} else if (KQueue.isAvailable()) {
@@ -26,12 +26,15 @@ public final class PyEngine extends Engine {

private String engineName;
private boolean mpiMode;

private int clusterSize;
private Engine alternativeEngine;
private boolean initialized;

PyEngine(String engineName, boolean mpiMode) {
PyEngine(String engineName, boolean mpiMode, int clusterSize) {
this.engineName = engineName;
this.mpiMode = mpiMode;
this.clusterSize = clusterSize;
}

/** {@inheritDoc} */
@@ -98,4 +101,8 @@ public NDManager newBaseManager(Device device) {
boolean isMpiMode() {
return mpiMode;
}

int getClusterSize() {
return clusterSize;
}
}
@@ -14,6 +14,7 @@

import ai.djl.engine.Engine;
import ai.djl.engine.EngineProvider;
import ai.djl.util.Utils;

/** {@code PyEngineProvider} is the Python implementation of {@link EngineProvider}. */
public class PyEngineProvider implements EngineProvider {
@@ -43,8 +44,9 @@ public Engine getEngine() {
synchronized (this) {
if (!initialized) {
initialized = true;
PyEnv.init();
engine = new PyEngine(getEngineName(), mpiMode);
int clusterSize = Integer.parseInt(Utils.getenv("DJL_CLUSTER_SIZE", "1"));
PyEnv.init(clusterSize);
engine = new PyEngine(getEngineName(), mpiMode, clusterSize);
}
}
}
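Configuration-wise, everything hangs off one knob: the provider reads DJL_CLUSTER_SIZE from the environment (defaulting to "1"), passes it to PyEnv.init and the PyEngine constructor, and Connection later forwards it to each worker via the new --cluster-size flag. A plain-JDK sketch of the same lookup (Utils.getenv above is DJL's helper with a default-value overload):

public final class ClusterSizeLookup {
    public static void main(String[] args) {
        String raw = System.getenv("DJL_CLUSTER_SIZE");
        int clusterSize = Integer.parseInt(raw == null ? "1" : raw);
        // 1 keeps all of the new clusterSize > 1 branches dormant (single node).
        System.out.println("cluster size: " + clusterSize);
    }
}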