Support multi node for lmi-dist (#2125)
xyang16 authored Jul 9, 2024
1 parent d62f747 commit 70aca0c
Showing 8 changed files with 175 additions and 25 deletions.
5 changes: 5 additions & 0 deletions engines/python/setup/djl_python/arg_parser.py
@@ -71,6 +71,11 @@ def python_engine_args():
dest="tensor_parallel_degree",
type=int,
help='The tensor parallel degree')
parser.add_argument('--cluster-size',
required=False,
dest="cluster_size",
type=int,
help='The cluster size')
parser.add_argument('--recommended-entry-point',
required=False,
type=str,
@@ -45,6 +45,7 @@ class HuggingFaceProperties(Properties):
device_id: int = -1
task: str = None
tensor_parallel_degree: int = -1
cluster_size: int = 1
device_map: str = None
load_in_4bit: Optional[bool] = None
load_in_8bit: Optional[bool] = None
@@ -112,10 +113,12 @@ def construct_kwargs_device_map(self):
self.kwargs["device_map"] = self.device_map
self.device = None
logging.info(f"Using device map {self.device_map}")
elif self.tensor_parallel_degree > 0 and torch.cuda.device_count() > 0:
elif self.tensor_parallel_degree > 0 \
and self.cluster_size > 0 \
and torch.cuda.device_count() > 0:
self.kwargs["device_map"] = "auto"
self.device = None
world_size = torch.cuda.device_count()
world_size = torch.cuda.device_count() * self.cluster_size
assert world_size == self.tensor_parallel_degree, \
f"TP degree ({self.tensor_parallel_degree}) doesn't match available GPUs ({world_size})"
logging.info(f"Using {world_size} gpus")
@@ -49,6 +49,7 @@ class Properties(BaseModel):
# Make the default to auto, after java front end changes and test cases are changed.
rolling_batch: RollingBatchEnum = RollingBatchEnum.disable
tensor_parallel_degree: int = 1
cluster_size: int = 1
trust_remote_code: bool = False
enable_streaming: StreamingEnum = StreamingEnum.false
batch_size: int = 1
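Since `Properties` is a pydantic `BaseModel`, string-valued settings such as `"cluster_size": "2"` are coerced to `int` automatically, which is why the test below feeds strings. A simplified sketch (not the real class) of that coercion:

```python
from pydantic import BaseModel

class MiniProperties(BaseModel):
    # simplified stand-in for the Properties model above
    tensor_parallel_degree: int = 1
    cluster_size: int = 1

props = MiniProperties(**{"tensor_parallel_degree": "4", "cluster_size": "2"})
print(props.cluster_size)  # 2 (parsed from the string "2")
```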
@@ -290,6 +290,7 @@ def test_hf_all_configs(self):
"model_id": "model_id",
"model_dir": "model_dir",
"tensor_parallel_degree": "4",
"cluster_size": "2",
"load_in_4bit": "false",
"load_in_8bit": "true",
"low_cpu_mem_usage": "true",
@@ -305,6 +306,10 @@
}

hf_configs = HuggingFaceProperties(**properties)
self.assertEqual(hf_configs.tensor_parallel_degree,
int(properties['tensor_parallel_degree']))
self.assertEqual(hf_configs.cluster_size,
int(properties['cluster_size']))
self.assertTrue(hf_configs.load_in_8bit)
self.assertTrue(hf_configs.low_cpu_mem_usage)
self.assertFalse(hf_configs.disable_flash_attn)
12 changes: 10 additions & 2 deletions engines/python/setup/djl_python_engine.py
@@ -47,23 +47,27 @@ def __init__(self, args, service):

self.model_dir = args.model_dir
self.sock_type = args.sock_type
self.sock_name = f"{args.sock_name}.{rank}" if rank else args.sock_name
self.sock_name = args.sock_name
self.port = args.port
self.service = service
self.device_id = args.device_id
self.tensor_parallel_degree = args.tensor_parallel_degree
self.cluster_size = args.cluster_size
self.entry_point = args.entry_point
self.recommended_entry_point = args.recommended_entry_point

if self.sock_type == "unix":
if self.sock_name is None:
raise ValueError("Missing sock-name argument.")
self.sock_name = f"{args.sock_name}.{rank}" if rank else args.sock_name

self.clean_up()
elif self.sock_type == "tcp":
self.sock_name = "127.0.0.1"
if self.sock_name is None:
self.sock_name = "0.0.0.0"
if self.port is None:
raise ValueError("Missing port argument.")
self.port = int(self.port) + int(rank) if rank else self.port
else:
raise ValueError(f"Invalid socket-type: {self.sock_type}.")

@@ -99,6 +103,8 @@ def run_server(self):
if self.sock_type == "unix":
self.sock.bind(self.sock_name)
else:
logging.info(
f"Socket bind on address: {self.sock_name}:{self.port}")
self.sock.bind((self.sock_name, int(self.port)))

self.sock.listen(128)
@@ -115,6 +121,8 @@ def run_server(self):
prop = inputs.get_properties()
if self.tensor_parallel_degree:
prop["tensor_parallel_degree"] = self.tensor_parallel_degree
if self.cluster_size:
prop["cluster_size"] = self.cluster_size
prop["device_id"] = self.device_id
if "output_formatter" in prop and hasattr(
self.service, prop["output_formatter"]):
106 changes: 90 additions & 16 deletions engines/python/src/main/java/ai/djl/python/engine/Connection.java
@@ -50,6 +50,7 @@
import java.io.IOException;
import java.net.InetSocketAddress;
import java.net.SocketAddress;
import java.net.UnknownHostException;
import java.nio.ByteBuffer;
import java.nio.file.Files;
import java.nio.file.Path;
@@ -64,20 +65,19 @@
class Connection {

private static final Logger logger = LoggerFactory.getLogger(Connection.class);
private static final String MASTER_ADDR = "127.0.0.1";

private int port;
private SocketAddress socketAddress;
private Channel channel;
private RequestHandler requestHandler;

Connection(PyEnv pyEnv, int basePort, int rank) {
Connection(PyEnv pyEnv, int basePort, int rank, String hostname) {
requestHandler = new RequestHandler();
port = 19000 + basePort;
socketAddress = getSocketAddress(pyEnv.isMpiMode(), rank);
socketAddress = getSocketAddress(pyEnv.isMpiMode(), rank, hostname);
}

static Process startPython(PyEnv pyEnv, Model model, int workerId, int port)
static Process startPython(PyEnv pyEnv, Model model, int workerId, int port, String[] hosts)
throws IOException {
Path tmp = Paths.get(System.getProperty("java.io.tmpdir"));
try (Stream<Path> stream = Files.list(tmp)) {
@@ -100,7 +100,7 @@ static Process startPython(PyEnv pyEnv, Model model, int workerId, int port)
});
}
File modelPath = model.getModelPath().toFile();
String[] args = getPythonStartCmd(pyEnv, model, workerId, port);
String[] args = getPythonStartCmd(pyEnv, model, workerId, port, hosts);
String[] envp = pyEnv.getEnvironmentVars(model);
logger.debug("cmd: {}", (Object) args);

@@ -120,21 +120,84 @@ CompletableFuture<Output> send(Input input) throws InterruptedException {
return f;
}

static String[] getPythonStartCmd(PyEnv pyEnv, Model model, int workerId, int port) {
static String[] getPythonStartCmd(
PyEnv pyEnv, Model model, int workerId, int port, String[] hosts) {
Device device = model.getNDManager().getDevice();
int deviceId = device.getDeviceId();
int clusterSize = PyEnv.getClusterSize();
int tensorParallelDegree = pyEnv.getTensorParallelDegree();
String entryPoint = pyEnv.getEntryPoint();
String recommendedEntryPoint = pyEnv.getRecommendedEntryPoint();

if (PyEnv.isMultiNode()) {
String cudaDevices = getVisibleDevices(workerId, tensorParallelDegree / clusterSize);
logger.info("Set before mpirun CUDA_VISIBLE_DEVICES={}", cudaDevices);
StringBuilder sb = new StringBuilder();
boolean first = true;
for (String host : hosts) {
if (first) {
first = false;
} else {
sb.append(',');
}
sb.append(host).append(':').append(tensorParallelDegree / clusterSize);
}
String[] args = new String[46];
args[0] = "mpirun";
args[1] = "-np";
args[2] = String.valueOf(tensorParallelDegree);
args[3] = "--host";
args[4] = sb.toString();
args[5] = "--allow-run-as-root";
args[6] = "--bind-to";
args[7] = "none";
args[8] = "--mca";
args[9] = "orte_keep_fqdn_hostnames";
args[10] = "t";
args[11] = "--tag-output";
args[12] = "-x";
args[13] = "FI_PROVIDER=efa";
args[14] = "-x";
args[15] = "RDMAV_FORK_SAFE=1";
args[16] = "-x";
args[17] = "FI_EFA_USE_DEVICE_RDMA=1";
args[18] = "-x";
args[19] = "LD_LIBRARY_PATH";
args[20] = "-x";
args[21] = "PYTHONPATH";
args[22] = "-x";
args[23] = "CUDA_VISIBLE_DEVICES=" + cudaDevices;
args[24] = "-x";
args[25] = "MASTER_ADDR=" + pyEnv.getMasterAddr();
args[26] = "-x";
args[27] = "MKL_DYNAMIC=FALSE";
args[28] = pyEnv.getPythonExecutable();
args[29] = PyEnv.getEngineCacheDir() + "/djl_python_engine.py";
args[30] = "--model-dir";
args[31] = model.getModelPath().toAbsolutePath().toString();
args[32] = "--entry-point";
args[33] = entryPoint == null ? "" : entryPoint;
args[34] = "--sock-type";
args[35] = "tcp";
args[36] = "--sock-name";
args[37] = "0.0.0.0";
args[38] = "--port";
args[39] = String.valueOf(port);
args[40] = "--tensor-parallel-degree";
args[41] = String.valueOf(tensorParallelDegree);
args[42] = "--cluster-size";
args[43] = String.valueOf(clusterSize);
args[44] = "--recommended-entry-point";
args[45] = recommendedEntryPoint == null ? "" : recommendedEntryPoint;
return args;
}

if (pyEnv.isMpiMode()) {
String cudaDevices = getVisibleDevices(workerId, tensorParallelDegree);
logger.info("Set CUDA_VISIBLE_DEVICES={}", cudaDevices);
String[] args = new String[42];
args[0] = "mpirun";
args[1] = "-np";
// TODO: When we support multi nodes, change it to the product of tensor parallel value
// and
// pipeline parallel value.
args[2] = String.valueOf(tensorParallelDegree);
args[3] = "--allow-run-as-root";
args[4] = "--bind-to";
@@ -156,7 +219,7 @@ static String[] getPythonStartCmd(PyEnv pyEnv, Model model, int workerId, int po
args[20] = "-x";
args[21] = "CUDA_VISIBLE_DEVICES=" + cudaDevices;
args[22] = "-x";
args[23] = "MASTER_ADDR=" + MASTER_ADDR;
args[23] = "MASTER_ADDR=" + pyEnv.getMasterAddr();
args[24] = "-x";
args[25] = "MASTER_PORT=" + port;
args[26] = "-x";
@@ -196,7 +259,7 @@ static String[] getPythonStartCmd(PyEnv pyEnv, Model model, int workerId, int po
logger.info("Set OMP_NUM_THREADS={}", neuronThreads);
}
boolean uds = Epoll.isAvailable() || KQueue.isAvailable();
String[] args = new String[14];
String[] args = new String[16];
args[0] = pyEnv.getPythonExecutable();
args[1] = PyEnv.getEngineCacheDir() + "/djl_python_engine.py";
args[2] = "--sock-type";
@@ -209,8 +272,10 @@ static String[] getPythonStartCmd(PyEnv pyEnv, Model model, int workerId, int po
args[9] = entryPoint == null ? "" : entryPoint;
args[10] = "--device-id";
args[11] = String.valueOf(deviceId);
args[12] = "--recommended-entry-point";
args[13] = recommendedEntryPoint == null ? "" : recommendedEntryPoint;
args[12] = "--cluster-size";
args[13] = String.valueOf(clusterSize);
args[14] = "--recommended-entry-point";
args[15] = recommendedEntryPoint == null ? "" : recommendedEntryPoint;
return args;
}
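For a concrete picture of the new multi-node branch above, here is a rough Python rendering of the `mpirun` command that `getPythonStartCmd` assembles when `PyEnv.isMultiNode()` is true; only a subset of the `-x` exports is shown, and the host names, model path, and port are illustrative values rather than output captured from the engine:

```python
# Illustrative reconstruction (not generated by DJL) of the multi-node launch command.
def build_multinode_cmd(hosts, tp_degree, cluster_size, master_addr, port, model_dir):
    slots_per_node = tp_degree // cluster_size
    # "--host node-0:8,node-1:8" style list, one slot count per node
    host_arg = ",".join(f"{h}:{slots_per_node}" for h in hosts)
    visible = ",".join(str(i) for i in range(slots_per_node))
    return [
        "mpirun", "-np", str(tp_degree),
        "--host", host_arg,
        "--allow-run-as-root", "--bind-to", "none",
        "--mca", "orte_keep_fqdn_hostnames", "t",
        "--tag-output",
        "-x", "FI_PROVIDER=efa",
        "-x", f"MASTER_ADDR={master_addr}",
        "-x", f"CUDA_VISIBLE_DEVICES={visible}",
        "python", "djl_python_engine.py",
        "--model-dir", model_dir,
        "--sock-type", "tcp", "--sock-name", "0.0.0.0",
        "--port", str(port),
        "--tensor-parallel-degree", str(tp_degree),
        "--cluster-size", str(cluster_size),
    ]

# e.g. a 2-node cluster, 8 GPUs per node, serving a TP-16 model
print(" ".join(build_multinode_cmd(["node-0", "node-1"], 16, 2, "node-0", 19000,
                                   "/opt/ml/model")))
```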

Expand Down Expand Up @@ -248,7 +313,8 @@ private static String getNeuronThreads(int tensorParallelDegree) {
return String.valueOf(1);
}

void connect() throws InterruptedException {
void connect() throws InterruptedException, UnknownHostException {
logger.debug("Connecting to socket: {}", socketAddress);
EventLoopGroup group = PyEnv.getEventLoopGroup();

Bootstrap clientBootstrap = new Bootstrap();
@@ -295,7 +361,10 @@ private static String getSocketPath(int port) {
return System.getProperty("java.io.tmpdir") + "/djl_sock." + port;
}

private SocketAddress getSocketAddress(boolean mpiMode, int rank) {
private SocketAddress getSocketAddress(boolean mpiMode, int rank, String hostname) {
if (PyEnv.isMultiNode()) {
return new InetSocketAddress(hostname, port + rank);
}
if (mpiMode) {
return new DomainSocketAddress(getSocketPath(port) + '.' + rank);
}
@@ -307,16 +376,21 @@ private SocketAddress getSocketAddress(boolean mpiMode, int rank) {
}

static EventLoopGroup newEventLoopGroup() {
if (PyEnv.isMultiNode()) {
return new NioEventLoopGroup(new DaemonThreadFactory());
}
if (Epoll.isAvailable()) {
return new EpollEventLoopGroup(new DaemonThreadFactory());
} else if (KQueue.isAvailable()) {
return new KQueueEventLoopGroup(new DaemonThreadFactory());
}

return new NioEventLoopGroup(new DaemonThreadFactory());
}

private static Class<? extends Channel> getClientChannel() {
if (PyEnv.isMultiNode()) {
return NioSocketChannel.class;
}
if (Epoll.isAvailable()) {
return EpollDomainSocketChannel.class;
} else if (KQueue.isAvailable()) {
27 changes: 26 additions & 1 deletion engines/python/src/main/java/ai/djl/python/engine/PyEnv.java
@@ -42,6 +42,7 @@ public class PyEnv {

static final Logger logger = LoggerFactory.getLogger(PyEnv.class);

private static int clusterSize;
private static String engineCacheDir;
private static String version;
private static EventLoopGroup eventLoopGroup;
@@ -84,6 +85,7 @@ static synchronized void init() {
return;
}

setClusterSize();
eventLoopGroup = Connection.newEventLoopGroup();

Path tmp = null;
@@ -128,6 +130,20 @@ static synchronized void init() {
}
}

static void setClusterSize() {
if (clusterSize == 0) {
clusterSize = Integer.parseInt(Utils.getenv("DJL_CLUSTER_SIZE", "1"));
}
}

static int getClusterSize() {
return clusterSize;
}

static boolean isMultiNode() {
return clusterSize > 1;
}

static String getVersion() {
return version;
}
@@ -304,6 +320,15 @@ public void setPythonExecutable(String pythonExecutable) {
this.pythonExecutable = pythonExecutable;
}

/**
* Returns the master address.
*
* @return the master address
*/
public String getMasterAddr() {
return Utils.getenv("MASTER_ADDR", "127.0.0.1");
}

/**
* Returns the tensor parallel degree.
*
@@ -339,7 +364,7 @@ public void setTensorParallelDegree(int tensorParallelDegree) {
}

int getMpiWorkers() {
int gpuCount = CudaUtils.getGpuCount();
int gpuCount = CudaUtils.getGpuCount() * clusterSize;
String visibleDevices = Utils.getenv("CUDA_VISIBLE_DEVICES");
if (gpuCount > 0 && visibleDevices != null) {
int visibleCount = visibleDevices.split(",").length;
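Pulling the PyEnv pieces together: the cluster size is read once from the `DJL_CLUSTER_SIZE` environment variable (default 1), multi-node mode simply means a cluster size greater than 1, the master address falls back to `127.0.0.1` via `MASTER_ADDR`, and `getMpiWorkers` now scales the local GPU count by the cluster size. A small Python sketch of the same resolution, assuming 8 GPUs on the local node:

```python
import os

# Sketch of the environment-driven settings PyEnv resolves for multi-node mode.
def cluster_settings(local_gpu_count=8):
    cluster_size = int(os.environ.get("DJL_CLUSTER_SIZE", "1"))
    master_addr = os.environ.get("MASTER_ADDR", "127.0.0.1")
    multi_node = cluster_size > 1
    mpi_workers = local_gpu_count * cluster_size  # mirrors the getMpiWorkers change
    return cluster_size, master_addr, multi_node, mpi_workers

os.environ["DJL_CLUSTER_SIZE"] = "2"
print(cluster_settings())  # (2, '127.0.0.1', True, 16)
```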