[train] Updates to support xgboost==2.1.0
#46667
@@ -1,9 +1,13 @@
 import json
 import logging
 import os
+import threading
 from contextlib import contextmanager
 from dataclasses import dataclass
+from typing import Optional

+import xgboost
+from packaging.version import Version
 from xgboost import RabitTracker
 from xgboost.collective import CommunicatorContext
@@ -37,7 +41,7 @@ class XGBoostConfig(BackendConfig):
     def train_func_context(self):
         @contextmanager
         def collective_communication_context():
-            with CommunicatorContext():
+            with CommunicatorContext(**_get_xgboost_args()):
                 yield

         return collective_communication_context
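The kwargs forwarded to `CommunicatorContext` here are the rabit tracker's connection info plus a per-worker task id (see `_get_xgboost_args()` at the bottom of this diff), so the user's training function ends up running inside an active xgboost collective communicator. A rough sketch of what that worker-side wrapping amounts to, assuming xgboost >= 2.1; `worker_train` and `worker_args` are illustrative names, not Ray Train APIs:

```python
import xgboost
from xgboost.collective import CommunicatorContext


def worker_train(worker_args: dict, dtrain: xgboost.DMatrix) -> xgboost.Booster:
    # worker_args is roughly what _get_xgboost_args() returns on a worker:
    # the tracker's connection info plus this worker's dmlc_task_id.
    # The RabitTracker started by the backend must still be alive to connect.
    with CommunicatorContext(**worker_args):
        return xgboost.train(
            {"objective": "reg:squarederror"}, dtrain, num_boost_round=10
        )
```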
@@ -52,13 +56,10 @@ def backend_cls(self):

 class _XGBoostRabitBackend(Backend):
     def __init__(self):
-        self._tracker = None
-
-    def on_training_start(
-        self, worker_group: WorkerGroup, backend_config: XGBoostConfig
-    ):
-        assert backend_config.xgboost_communicator == "rabit"
+        self._tracker: Optional[RabitTracker] = None
+        self._wait_thread: Optional[threading.Thread] = None

+    def _setup_xgboost_less_than_210(self, worker_group: WorkerGroup):
         # Set up the rabit tracker on the Train driver.
         num_workers = len(worker_group)
         rabit_args = {"DMLC_NUM_WORKER": num_workers}
@@ -67,12 +68,14 @@ def on_training_start(
         # NOTE: sortby="task" is needed to ensure that the xgboost worker ranks
         # align with Ray Train worker ranks.
         # The worker ranks will be sorted by `DMLC_TASK_ID`,
-        # which is defined in `on_training_start`.
+        # which is defined below.
         self._tracker = RabitTracker(
-            host_ip=train_driver_ip, n_workers=num_workers, sortby="task"
+            n_workers=num_workers, host_ip=train_driver_ip, sortby="task"
         )
-        rabit_args.update(self._tracker.worker_envs())
-        self._tracker.start(num_workers)
+        self._tracker.start(n_workers=num_workers)
+
+        worker_args = self._tracker.worker_envs()
+        rabit_args.update(worker_args)

         start_log = (
             "RabitTracker coordinator started with parameters:\n"
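The `sortby="task"` note above is also why the task ids constructed further down in this diff embed a zero-padded rank: worker ranks are ordered by the task id string, so lexicographic order has to match numeric rank order. A small illustration (values are made up; the format mirrors the `dmlc_task_id` built below):

```python
# Zero-padded ranks keep lexicographic order equal to numeric order, which is
# what sortby="task" relies on when assigning xgboost worker ranks.
task_ids = [f"[xgboost.ray-rank={rank:08}]:actor-{rank}" for rank in (0, 2, 10, 1)]
print(sorted(task_ids))
# ['[xgboost.ray-rank=00000000]:actor-0',
#  '[xgboost.ray-rank=00000001]:actor-1',
#  '[xgboost.ray-rank=00000002]:actor-2',
#  '[xgboost.ray-rank=00000010]:actor-10']
```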
@@ -95,13 +98,80 @@ def set_xgboost_env_vars():

         worker_group.execute(set_xgboost_env_vars)

+    def _setup_xgboost(self, worker_group: WorkerGroup):
+        # Set up the rabit tracker on the Train driver.
+        num_workers = len(worker_group)
+        rabit_args = {"n_workers": num_workers}
+        train_driver_ip = ray.util.get_node_ip_address()
+
+        # NOTE: sortby="task" is needed to ensure that the xgboost worker ranks
+        # align with Ray Train worker ranks.
+        # The worker ranks will be sorted by `dmlc_task_id`,
+        # which is defined below.
+        self._tracker = RabitTracker(
+            n_workers=num_workers, host_ip=train_driver_ip, sortby="task"
+        )
+        self._tracker.start()
+
+        # The RabitTracker is started in a separate thread, and the
+        # `wait_for` method must be called for `worker_args` to return.
+        self._wait_thread = threading.Thread(target=self._tracker.wait_for, daemon=True)
+        self._wait_thread.start()
+
+        rabit_args.update(self._tracker.worker_args())
+
+        start_log = (
+            "RabitTracker coordinator started with parameters:\n"
+            f"{json.dumps(rabit_args, indent=2)}"
+        )
+        logger.debug(start_log)
+
+        def set_xgboost_communicator_args(args):
+            import ray.train
+
+            args["dmlc_task_id"] = (
+                f"[xgboost.ray-rank={ray.train.get_context().get_world_rank():08}]:"
+                f"{ray.get_runtime_context().get_actor_id()}"
+            )
+
+            _set_xgboost_args(args)
+
+        worker_group.execute(set_xgboost_communicator_args, rabit_args)
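For readers unfamiliar with the xgboost >= 2.1 tracker API used above: the driver starts a `RabitTracker`, pushes `wait_for()` onto a background thread (per the comment in the diff, it has to be running for `worker_args()` to return), and then hands the worker args out. A minimal standalone driver-side sketch, assuming xgboost >= 2.1; real workers would connect from separate processes:

```python
import threading

from xgboost import RabitTracker

# Start the tracker on the driver; sortby="task" orders ranks by task id.
tracker = RabitTracker(n_workers=2, host_ip="127.0.0.1", sortby="task")
tracker.start()

# wait_for() runs on a daemon thread so the driver can keep going while
# workers connect (same pattern as _setup_xgboost above).
wait_thread = threading.Thread(target=tracker.wait_for, daemon=True)
wait_thread.start()

# Connection info that each worker later passes to CommunicatorContext
# (the exact keys depend on the xgboost version).
print(tracker.worker_args())
```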
+    def on_training_start(
+        self, worker_group: WorkerGroup, backend_config: XGBoostConfig
+    ):
+        assert backend_config.xgboost_communicator == "rabit"
+
+        if Version(xgboost.__version__) < Version("2.1.0"):
+            self._setup_xgboost_less_than_210(worker_group)
+        else:
+            self._setup_xgboost(worker_group)

Inline review thread on the `xgboost_communicator` assertion:

nit: it seems

yeah, I can probably remove this field for now, since we don't support the "federated" option.
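Separately, one detail of the `packaging.version` gate above worth noting: pre-releases compare below the final release, so a 2.1.0 release candidate of xgboost would still take the pre-2.1 code path. For example:

```python
from packaging.version import Version

print(Version("2.0.3") < Version("2.1.0"))     # True
print(Version("2.1.0rc1") < Version("2.1.0"))  # True: pre-releases sort below the release
print(Version("2.1.0") < Version("2.1.0"))     # False
```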
     def on_shutdown(self, worker_group: WorkerGroup, backend_config: XGBoostConfig):
         timeout = 5
-        self._tracker.thread.join(timeout=timeout)
-
-        if self._tracker.thread.is_alive():
-            logger.warning(
-                "During shutdown, the RabitTracker thread failed to join "
-                f"within {timeout} seconds. "
-                "The process will still be terminated as part of Ray actor cleanup."
-            )
+        if self._wait_thread is not None:
+            self._wait_thread.join(timeout=timeout)
+
+            if self._wait_thread.is_alive():
+                logger.warning(
+                    "During shutdown, the RabitTracker thread failed to join "
+                    f"within {timeout} seconds. "
+                    "The process will still be terminated as part of Ray actor cleanup."
+                )
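The warning above exists because `Thread.join(timeout=...)` only waits; it never stops the thread. A tiny illustration of that behavior:

```python
import threading
import time

t = threading.Thread(target=time.sleep, args=(10,), daemon=True)
t.start()
t.join(timeout=0.1)  # returns after ~0.1 s even though the target is still sleeping
print(t.is_alive())  # True: a timed-out join() does not terminate the thread
```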
+
+
+_xgboost_args: dict = {}
+_xgboost_args_lock = threading.Lock()
+
+
+def _set_xgboost_args(args):
+    with _xgboost_args_lock:
+        global _xgboost_args
+        _xgboost_args = args
+
+
+def _get_xgboost_args() -> dict:
+    with _xgboost_args_lock:
+        return _xgboost_args
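A guess at why this is a lock-guarded module global rather than, say, an attribute on the backend: the value is written by a `worker_group.execute(...)` call and read later from inside the training function, which presumably run on different threads of the same worker process (my reading of the code above, not something stated in the diff). A generic sketch of that set-in-one-thread, read-in-another pattern:

```python
import threading

_args: dict = {}
_args_lock = threading.Lock()


def _set(args: dict) -> None:
    global _args
    with _args_lock:
        _args = args


def _get() -> dict:
    with _args_lock:
        return _args


# The setter and the reader run on different threads of the same process.
setter = threading.Thread(target=_set, args=({"dmlc_task_id": "rank-0"},))
setter.start()
setter.join()

reader = threading.Thread(target=lambda: print(_get()))
reader.start()
reader.join()
```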
Review discussion:

Are we able to save the xgboost_args into the XGBoost config so we can avoid modifying the global variable?

Hmm, interesting. I actually don't understand why we need both `BackendConfig` and `Backend` classes. Any context here @matthewdeng?

The `BackendConfig` is the public API that the user could interact with. There is probably a better/cleaner way to organize the two.

Yeah, currently the dependency between `BackendConfig` and `Backend` is unidirectional. It's kind of hard to pass information from `Backend` -> `BackendConfig`.

Should `train_func_context` be part of the `Backend` instead?

Or at the very least the default one.

Oh hm, maybe that won't work because we construct the train loop before the backend...