
[train] Updates to support xgboost==2.1.0 #46667

Merged: 20 commits, Aug 8, 2024

Changes from 9 commits
386f731
make compatible for xgboost 2.1.0
justinvyu Jul 16, 2024
d82323c
update xgboost to 2.1.0
justinvyu Jul 16, 2024
da5eadc
make global var naming private
justinvyu Jul 16, 2024
49d645c
small cleanup
justinvyu Jul 16, 2024
e6297c4
Merge branch 'master' of https://github.com/ray-project/ray into xgb2…
justinvyu Jul 17, 2024
f97bbf4
update req-compiled
justinvyu Jul 17, 2024
661d6ad
Merge branch 'master' of https://github.com/ray-project/ray into xgb2…
justinvyu Jul 23, 2024
2170aa3
update requirements compiled
justinvyu Jul 24, 2024
6ea9533
Merge branch 'master' of https://github.com/ray-project/ray into xgb2…
justinvyu Jul 25, 2024
dc3c878
Merge branch 'master' of https://github.com/ray-project/ray into xgb2…
justinvyu Jul 25, 2024
030f415
[TEMP] remove ci dep for pip compile to run
justinvyu Jul 25, 2024
2a83490
update req compiled
justinvyu Jul 25, 2024
7ad42d4
Merge branch 'master' of https://github.com/ray-project/ray into xgb2…
justinvyu Aug 1, 2024
a7ac8a0
Merge branch 'master' of https://github.com/ray-project/ray into xgb2…
justinvyu Aug 6, 2024
b09df85
separate into 2 different classes
justinvyu Aug 6, 2024
8518fee
TEMP: add nvidia nccl dep
justinvyu Aug 6, 2024
3a9fed6
update req compiled
justinvyu Aug 6, 2024
fd9ad1a
revert TEMP
justinvyu Aug 6, 2024
5053ef8
Merge branch 'master' of https://github.com/ray-project/ray into xgb2…
justinvyu Aug 6, 2024
f40de02
Merge branch 'master' of https://github.com/ray-project/ray into xgb2…
justinvyu Aug 8, 2024
106 changes: 88 additions & 18 deletions python/ray/train/xgboost/config.py
@@ -1,9 +1,13 @@
import json
import logging
import os
import threading
from contextlib import contextmanager
from dataclasses import dataclass
from typing import Optional

import xgboost
from packaging.version import Version
from xgboost import RabitTracker
from xgboost.collective import CommunicatorContext

@@ -37,7 +41,7 @@ class XGBoostConfig(BackendConfig):
def train_func_context(self):
@contextmanager
def collective_communication_context():
with CommunicatorContext():
with CommunicatorContext(**_get_xgboost_args()):
Member:
Are we able to save the xgboost_args into the XGBoost config so we can avoid modifying the global variable?

Contributor Author:
Hmm, interesting. I actually don't understand why we need both BackendConfig and Backend classes. Any context here @matthewdeng?

Contributor:
The BackendConfig is the public API that the user could interact with. There is probably a better/cleaner way to organize the two.

Member:
Yeah, currently the dependency between BackendConfig and Backend is unidirectional. It's kind of hard to pass information from Backend -> BackendConfig.

Contributor:
Should train_func_context be part of the Backend instead?

Contributor:
Or at the very least the default one.

Contributor:
Oh hm, maybe that won't work because we construct the train loop before the backend...

yield

return collective_communication_context
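
The thread above boils down to an ordering constraint: the training function (and therefore `train_func_context`) is constructed before the backend's `on_training_start` runs, so the communicator arguments have to be handed over through shared state rather than through `XGBoostConfig`. Below is a minimal sketch of that flow with hypothetical names (`communicator_context`, `train_fn`, `_set_args`); it illustrates the pattern this diff uses, not Ray Train's actual internals.

```python
import threading
from contextlib import contextmanager

_args: dict = {}
_args_lock = threading.Lock()


def _set_args(args: dict) -> None:
    with _args_lock:
        global _args
        _args = args


def _get_args() -> dict:
    with _args_lock:
        return dict(_args)


@contextmanager
def communicator_context():
    # The context manager exists before the backend runs, so it reads the
    # args lazily on entry instead of baking them in at construction time.
    print("entering communicator with", _get_args())
    yield


def train_fn():
    with communicator_context():
        pass  # user training code runs here


# Later, the backend's on_training_start injects per-worker args...
_set_args({"dmlc_task_id": "[xgboost.ray-rank=00000000]:actor-id"})
# ...and only then does the worker actually execute the training function.
train_fn()
```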
@@ -52,13 +56,10 @@ def backend_cls(self):

class _XGBoostRabitBackend(Backend):
def __init__(self):
self._tracker = None

def on_training_start(
self, worker_group: WorkerGroup, backend_config: XGBoostConfig
):
assert backend_config.xgboost_communicator == "rabit"
self._tracker: Optional[RabitTracker] = None
self._wait_thread: Optional[threading.Thread] = None

def _setup_xgboost_less_than_210(self, worker_group: WorkerGroup):
# Set up the rabit tracker on the Train driver.
num_workers = len(worker_group)
rabit_args = {"DMLC_NUM_WORKER": num_workers}
@@ -67,12 +68,14 @@ def on_training_start(
# NOTE: sortby="task" is needed to ensure that the xgboost worker ranks
# align with Ray Train worker ranks.
# The worker ranks will be sorted by `DMLC_TASK_ID`,
# which is defined in `on_training_start`.
# which is defined below.
self._tracker = RabitTracker(
host_ip=train_driver_ip, n_workers=num_workers, sortby="task"
n_workers=num_workers, host_ip=train_driver_ip, sortby="task"
)
rabit_args.update(self._tracker.worker_envs())
self._tracker.start(num_workers)
self._tracker.start(n_workers=num_workers)

worker_args = self._tracker.worker_envs()
rabit_args.update(worker_args)

start_log = (
"RabitTracker coordinator started with parameters:\n"
@@ -95,13 +98,80 @@ def set_xgboost_env_vars():

worker_group.execute(set_xgboost_env_vars)

def _setup_xgboost(self, worker_group: WorkerGroup):
# Set up the rabit tracker on the Train driver.
num_workers = len(worker_group)
rabit_args = {"n_workers": num_workers}
train_driver_ip = ray.util.get_node_ip_address()

# NOTE: sortby="task" is needed to ensure that the xgboost worker ranks
# align with Ray Train worker ranks.
# The worker ranks will be sorted by `dmlc_task_id`,
# which is defined below.
self._tracker = RabitTracker(
n_workers=num_workers, host_ip=train_driver_ip, sortby="task"
)
self._tracker.start()

# The RabitTracker is started in a separate thread, and the
# `wait_for` method must be called for `worker_args` to return.
self._wait_thread = threading.Thread(target=self._tracker.wait_for, daemon=True)
self._wait_thread.start()

rabit_args.update(self._tracker.worker_args())

start_log = (
"RabitTracker coordinator started with parameters:\n"
f"{json.dumps(rabit_args, indent=2)}"
)
logger.debug(start_log)

def set_xgboost_communicator_args(args):
import ray.train

args["dmlc_task_id"] = (
f"[xgboost.ray-rank={ray.train.get_context().get_world_rank():08}]:"
f"{ray.get_runtime_context().get_actor_id()}"
)

_set_xgboost_args(args)

worker_group.execute(set_xgboost_communicator_args, rabit_args)

def on_training_start(
self, worker_group: WorkerGroup, backend_config: XGBoostConfig
):
assert backend_config.xgboost_communicator == "rabit"
Contributor:
nit: it seems XGBoostConfig already hard-codes this field to "rabit", so why do we still need an assertion here?

Contributor Author:
Yeah, I can probably remove this field for now, since we don't support the "federated" option.


if Version(xgboost.__version__) < Version("2.1.0"):
self._setup_xgboost_less_than_210(worker_group)
else:
self._setup_xgboost(worker_group)

def on_shutdown(self, worker_group: WorkerGroup, backend_config: XGBoostConfig):
timeout = 5
self._tracker.thread.join(timeout=timeout)

if self._tracker.thread.is_alive():
logger.warning(
"During shutdown, the RabitTracker thread failed to join "
f"within {timeout} seconds. "
"The process will still be terminated as part of Ray actor cleanup."
)
if self._wait_thread is not None:
self._wait_thread.join(timeout=timeout)

if self._wait_thread.is_alive():
logger.warning(
"During shutdown, the RabitTracker thread failed to join "
f"within {timeout} seconds. "
"The process will still be terminated as part of Ray actor cleanup."
)


_xgboost_args: dict = {}
_xgboost_args_lock = threading.Lock()


def _set_xgboost_args(args):
with _xgboost_args_lock:
global _xgboost_args
_xgboost_args = args


def _get_xgboost_args() -> dict:
with _xgboost_args_lock:
return _xgboost_args
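
For reference, the difference between the two setup paths above comes down to the RabitTracker driver API, which changed in xgboost 2.1. The sketch below condenses the version gate from the diff; it mirrors the calls shown above rather than the exact backend code, and the tracker method signatures are taken from this diff rather than verified against every xgboost release.

```python
import threading

import xgboost
from packaging.version import Version
from xgboost import RabitTracker


def start_rabit_tracker(host_ip: str, num_workers: int) -> dict:
    """Start the tracker on the driver and return the args workers need."""
    tracker = RabitTracker(n_workers=num_workers, host_ip=host_ip, sortby="task")

    if Version(xgboost.__version__) < Version("2.1.0"):
        # xgboost < 2.1: start() takes the worker count and the worker
        # environment variables come from worker_envs().
        tracker.start(n_workers=num_workers)
        args = {"DMLC_NUM_WORKER": num_workers}
        args.update(tracker.worker_envs())
    else:
        # xgboost >= 2.1: start() takes no arguments, wait_for() runs in a
        # background thread (as in the diff) so that worker_args() can
        # return, and the args use lowercase keys such as dmlc_task_id.
        tracker.start()
        threading.Thread(target=tracker.wait_for, daemon=True).start()
        args = {"n_workers": num_workers}
        args.update(tracker.worker_args())

    return args
```

Keeping the pre-2.1 path in its own method (`_setup_xgboost_less_than_210`) preserves the old behavior for users who have not upgraded, while the pinned version in Ray's requirements moves to 2.1.0.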
2 changes: 1 addition & 1 deletion python/requirements/ml/core-requirements.txt
@@ -4,7 +4,7 @@ mlflow==2.9.2
wandb==0.17.0

# ML training frameworks
xgboost==1.7.6
xgboost==2.1.0
lightgbm==3.3.5

# Huggingface
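
With the pin bumped to 2.1.0, a quick local sanity check (an illustration only, not part of the PR) is to confirm that the installed wheel exposes the APIs the new code path imports:

```python
import xgboost
from packaging.version import Version

# Matches the pin in python/requirements/ml/core-requirements.txt.
assert Version(xgboost.__version__) >= Version("2.1.0"), xgboost.__version__

# Both imports used by python/ray/train/xgboost/config.py in this PR.
from xgboost import RabitTracker  # noqa: F401
from xgboost.collective import CommunicatorContext  # noqa: F401

print(f"xgboost {xgboost.__version__}: collective API available")
```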