Support multi node for lmi-dist #2125
bf4e1ac
to
d3f7acf
Compare
logger.info("Printing mpi boolean: {}", pyEnv.isMpiMode());

if (clusterSize > 1) {
    String leaderAddress = Utils.getenv("LWS_LEADER_ADDRESS");
We need to revisit all the env names.
These names are fixed by the LWS environment; we cannot change them in an LWS+EKS environment, except for DJL_CLUSTER_SIZE.
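For illustration, the split between the fixed LWS name and the DJL-owned name could be sketched as below. This is only a sketch: `ClusterEnv` and its methods are hypothetical names, the default of 1 for an unset DJL_CLUSTER_SIZE is an assumption, and the lookup is passed in as a function so the logic can be exercised without touching the real process environment.

```java
import java.util.Map;
import java.util.function.Function;

/** Hypothetical helper (not part of this PR): reads cluster settings through an injectable lookup. */
final class ClusterEnv {

    private ClusterEnv() {}

    /** Reads DJL_CLUSTER_SIZE; an unset or blank value is assumed to mean a single node. */
    static int clusterSize(Function<String, String> env) {
        String value = env.apply("DJL_CLUSTER_SIZE");
        if (value == null || value.trim().isEmpty()) {
            return 1; // assumption: default to single-node mode
        }
        int size = Integer.parseInt(value.trim());
        if (size < 1) {
            throw new IllegalArgumentException("DJL_CLUSTER_SIZE must be >= 1, got " + size);
        }
        return size;
    }

    /** Reads LWS_LEADER_ADDRESS, whose name is fixed by LWS; fails fast if it is missing. */
    static String leaderAddress(Function<String, String> env) {
        String value = env.apply("LWS_LEADER_ADDRESS");
        if (value == null || value.isEmpty()) {
            throw new IllegalStateException("LWS_LEADER_ADDRESS is not set");
        }
        return value;
    }

    public static void main(String[] args) {
        // A real caller would pass System::getenv; a map keeps this demo deterministic.
        Map<String, String> fake =
                Map.of("DJL_CLUSTER_SIZE", "2", "LWS_LEADER_ADDRESS", "leader.example");
        System.out.println(clusterSize(fake::get) + " nodes, leader " + leaderAddress(fake::get));
    }
}
```

Failing fast on a missing leader address keeps a misconfigured pod from silently starting in the wrong mode.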
6db60b5
to
7d63ed7
Compare
engines/python/src/main/java/ai/djl/python/engine/Connection.java
Outdated
engines/python/src/main/java/ai/djl/python/engine/PyProcess.java
Outdated
engines/python/src/main/java/ai/djl/python/engine/PyProcess.java
Outdated
engines/python/src/main/java/ai/djl/python/engine/PyProcess.java
Outdated
@@ -187,6 +217,20 @@ synchronized void startPythonProcess() {
        }
    }

    public static String[] getHosts() {
Better to create an object to hold all those values.
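One way to read this suggestion is a small immutable holder that replaces the bare `String[]`. The sketch below is an assumption about what such an object could look like; `ClusterInfo` and its accessors are hypothetical names, not the class that was actually added.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;

/** Hypothetical immutable holder for multi-node settings (illustrative only). */
final class ClusterInfo {

    private final String leaderAddress;
    private final List<String> hosts;

    /** The host list is expected to start with the leader, followed by the workers. */
    ClusterInfo(String leaderAddress, List<String> hosts) {
        if (hosts.isEmpty() || !hosts.get(0).equals(leaderAddress)) {
            throw new IllegalArgumentException("hosts must start with the leader address");
        }
        this.leaderAddress = leaderAddress;
        // Defensive copy so later changes to the caller's list do not leak in.
        this.hosts = Collections.unmodifiableList(new ArrayList<>(hosts));
    }

    String getLeaderAddress() {
        return leaderAddress;
    }

    List<String> getHosts() {
        return hosts;
    }

    /** Cluster size is derived from the host list rather than stored separately. */
    int getClusterSize() {
        return hosts.size();
    }

    boolean isMultiNode() {
        return hosts.size() > 1;
    }

    public static void main(String[] args) {
        ClusterInfo info = new ClusterInfo("leader", Arrays.asList("leader", "worker-1"));
        System.out.println(info.getClusterSize() + " hosts, multi-node: " + info.isMultiNode());
    }
}
```

Deriving the size from the list avoids the two values drifting apart, at the cost of requiring the full host list up front.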
Device device = model.getNDManager().getDevice();
int deviceId = device.getDeviceId();
int tensorParallelDegree = pyEnv.getTensorParallelDegree();
// int pipelineParallelDegree = pyEnv.getPipelineParallelDegree();

if (clusterSize > 1) {
I think it's better for us to wrap the settings into separate functions for better debuggability.
Added a PyEnv.isMultiNode() utility function.
args[6] = "--bind-to";
args[7] = "none";
args[8] = "--mca";
args[9] = "btl_vader_single_copy_mechanism";
I believe the MCA settings for multi-node might be different.
Changed.
@@ -112,10 +113,12 @@ def construct_kwargs_device_map(self):
             self.kwargs["device_map"] = self.device_map
             self.device = None
             logging.info(f"Using device map {self.device_map}")
-        elif self.tensor_parallel_degree > 0 and torch.cuda.device_count() > 0:
+        elif self.tensor_parallel_degree > 0 \
+                and self.cluster_size > 0 \
In which case would cluster_size be 0?
cluster_size should never be 0. This is just a safety check.
52cd931
to
1ab823e
Compare
 *
 * @return the pipeline parallel degree
 */
public int getPipelineParallelDegree() {
Can you remove the pipelineParallelDegree read from this PR, since we're not directly using it here? I will add it separately later as required.
Removed.
a496758
to
7caaad0
Compare
a38f240
to
2b5c3cb
Compare
String[] res = new String[clusterSize];
res[0] = leaderAddress;
for (int i = 1; i < clusterSize; i++) {
    res[i] = String.format("%s-%s-%d.%s.%s", lwsName, groupIndex, i, lwsName, namespace);
-    res[i] = String.format("%s-%s-%d.%s.%s", lwsName, groupIndex, i, lwsName, namespace);
+    res[i] = lwsName + '-' + groupIndex + '-' + i + '.' + lwsName + '.' + namespace;
Replaced the environment variable name.
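The host-name construction discussed in this thread can be tried out in isolation. The sketch below is a standalone approximation: `LwsHosts` is a hypothetical class, the values are passed as parameters instead of being read from the environment, and `groupIndex` is assumed to be a string because the original format uses `%s` for it. It uses the concatenation form from the suggestion.

```java
import java.util.Arrays;

/** Hypothetical standalone version of the host-list logic shown in the diff. */
final class LwsHosts {

    private LwsHosts() {}

    /**
     * Builds the host list: the leader first, then one
     * "<lwsName>-<groupIndex>-<i>.<lwsName>.<namespace>" entry per worker.
     */
    static String[] getHosts(
            String leaderAddress,
            String lwsName,
            String groupIndex,
            String namespace,
            int clusterSize) {
        if (clusterSize < 1) {
            throw new IllegalArgumentException("clusterSize must be >= 1");
        }
        String[] res = new String[clusterSize];
        res[0] = leaderAddress;
        for (int i = 1; i < clusterSize; i++) {
            // Plain concatenation yields the same string as the String.format call it replaces.
            res[i] = lwsName + '-' + groupIndex + '-' + i + '.' + lwsName + '.' + namespace;
        }
        return res;
    }

    public static void main(String[] args) {
        // Example values are made up for the demo.
        String[] hosts = getHosts("lmi-0.lmi.default", "lmi", "0", "default", 3);
        System.out.println(Arrays.toString(hosts));
    }
}
```

With the made-up values above, the two worker entries come out as `lmi-0-1.lmi.default` and `lmi-0-2.lmi.default`.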
9ddc65c
to
c3471e3
Compare
Description
Brief description of what this PR is about