Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[🐛 BUG] Fix bugs when collecting results from mp.spawn in multi-GPU training #1875

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 46 additions & 22 deletions docs/source/get_started/distributed_training.rst
Original file line number Diff line number Diff line change
Expand Up @@ -121,21 +121,33 @@ In above example, you can create a new python file (e.g., `run_a.py`) on node A,
nproc = 4,
group_offset = 0
)

# Optional, only needed if you want to get the result of each process.
queue = mp.get_context('spawn').SimpleQueue()

config_dict = config_dict or {}
config_dict.update({
"world_size": args.world_size,
"ip": args.ip,
"port": args.port,
"nproc": args.nproc,
"offset": args.group_offset,
})
kwargs = {
"config_dict": config_dict,
"queue": queue, # Optional
}

mp.spawn(
run_recboles,
args=(
args.model,
args.dataset,
config_file_list,
args.ip,
args.port,
args.world_size,
args.nproc,
args.group_offset,
),
nprocs=args.nproc,
args=(args.model, args.dataset, args.config_file_list, kwargs),
nprocs=nproc,
join=True,
)

# Normally, there should be only one item in the queue
res = None if queue.empty() else queue.get()


Then run the following command:

Expand All @@ -159,21 +171,33 @@ Similarly, you can create a new python file (e.g., `run_b.py`) on node B, and wr
nproc = 4,
group_offset = 4
)

# Optional, only needed if you want to get the result of each process.
queue = mp.get_context('spawn').SimpleQueue()

config_dict = config_dict or {}
config_dict.update({
"world_size": args.world_size,
"ip": args.ip,
"port": args.port,
"nproc": args.nproc,
"offset": args.group_offset,
})
kwargs = {
"config_dict": config_dict,
"queue": queue, # Optional
}

mp.spawn(
run_recboles,
args=(
args.model,
args.dataset,
config_file_list,
args.ip,
args.port,
args.world_size,
args.nproc,
args.group_offset,
),
nprocs=args.nproc,
args=(args.model, args.dataset, args.config_file_list, kwargs),
nprocs=nproc,
join=True,
)

# Normally, there should be only one item in the queue
res = None if queue.empty() else queue.get()


Then run the following command:

Expand Down
1 change: 1 addition & 0 deletions recbole/quick_start/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from recbole.quick_start.quick_start import (
run,
run_recbole,
objective_function,
load_data_and_model,
Expand Down
103 changes: 84 additions & 19 deletions recbole/quick_start/quick_start.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,20 +12,17 @@
########################
"""
import logging
from logging import getLogger

import sys
import torch.distributed as dist
from collections.abc import MutableMapping
from logging import getLogger


import pickle
from ray import tune

from recbole.config import Config
from recbole.data import (
create_dataset,
data_preparation,
save_split_dataloaders,
load_split_dataloaders,
)
from recbole.data.transform import construct_transform
from recbole.utils import (
Expand All @@ -39,8 +36,69 @@
)


def run(
model,
dataset,
config_file_list=None,
config_dict=None,
saved=True,
nproc=1,
world_size=-1,
ip="localhost",
port="5678",
group_offset=0,
):
if nproc == 1 and world_size <= 0:
res = run_recbole(
model=model,
dataset=dataset,
config_file_list=config_file_list,
config_dict=config_dict,
saved=saved,
)
else:
if world_size == -1:
world_size = nproc
import torch.multiprocessing as mp

# Refer to https://discuss.pytorch.org/t/problems-with-torch-multiprocess-spawn-and-simplequeue/69674/2
# https://discuss.pytorch.org/t/return-from-mp-spawn/94302/2
queue = mp.get_context('spawn').SimpleQueue()

config_dict = config_dict or {}
config_dict.update(
{
"world_size": world_size,
"ip": ip,
"port": port,
"nproc": nproc,
"offset": group_offset,
}
)
kwargs = {
"config_dict": config_dict,
"queue": queue,
}

mp.spawn(
run_recboles,
args=(model, dataset, config_file_list, kwargs),
nprocs=nproc,
join=True,
)

# Normally, there should be only one item in the queue
res = None if queue.empty() else queue.get()
return res


def run_recbole(
model=None, dataset=None, config_file_list=None, config_dict=None, saved=True
model=None,
dataset=None,
config_file_list=None,
config_dict=None,
saved=True,
queue=None,
):
r"""A fast running api, which includes the complete process of
training and testing a model on a specified dataset
Expand All @@ -51,6 +109,7 @@ def run_recbole(
config_file_list (list, optional): Config files used to modify experiment parameters. Defaults to ``None``.
config_dict (dict, optional): Parameters dictionary used to modify experiment parameters. Defaults to ``None``.
saved (bool, optional): Whether to save the model. Defaults to ``True``.
queue (torch.multiprocessing.Queue, optional): The queue used to pass the result to the main process. Defaults to ``None``.
"""
# configurations initialization
config = Config(
Expand Down Expand Up @@ -104,27 +163,33 @@ def run_recbole(
logger.info(set_color("best valid ", "yellow") + f": {best_valid_result}")
logger.info(set_color("test result", "yellow") + f": {test_result}")

return {
result = {
"best_valid_score": best_valid_score,
"valid_score_bigger": config["valid_metric_bigger"],
"best_valid_result": best_valid_result,
"test_result": test_result,
}

if not config["single_spec"]:
dist.destroy_process_group()

if config["local_rank"] == 0 and queue is not None:
queue.put(result) # for multiprocessing, e.g., mp.spawn

return result # for the single process


def run_recboles(rank, *args):
ip, port, world_size, nproc, offset = args[3:]
args = args[:3]
kwargs = args[-1]
if not isinstance(kwargs, MutableMapping):
raise ValueError(
f"The last argument of run_recboles should be a dict, but got {type(kwargs)}"
)
kwargs["config_dict"] = kwargs.get("config_dict", {})
kwargs["config_dict"]["local_rank"] = rank
run_recbole(
*args,
config_dict={
"local_rank": rank,
"world_size": world_size,
"ip": ip,
"port": port,
"nproc": nproc,
"offset": offset,
},
*args[:3],
**kwargs,
)


Expand Down
36 changes: 11 additions & 25 deletions run_recbole.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,8 @@
# @Email : chenyuwuxinn@gmail.com, houyupeng@ruc.edu.cn, zhlin@ruc.edu.cn

import argparse
from ast import arg

from recbole.quick_start import run_recbole, run_recboles
from recbole.quick_start import run

if __name__ == "__main__":
parser = argparse.ArgumentParser()
Expand Down Expand Up @@ -44,26 +43,13 @@
args.config_files.strip().split(" ") if args.config_files else None
)

if args.nproc == 1 and args.world_size <= 0:
run_recbole(
model=args.model, dataset=args.dataset, config_file_list=config_file_list
)
else:
if args.world_size == -1:
args.world_size = args.nproc
import torch.multiprocessing as mp

mp.spawn(
run_recboles,
args=(
args.model,
args.dataset,
config_file_list,
args.ip,
args.port,
args.world_size,
args.nproc,
args.group_offset,
),
nprocs=args.nproc,
)
run(
args.model,
args.dataset,
config_file_list=config_file_list,
nproc=args.nproc,
world_size=args.world_size,
ip=args.ip,
port=args.port,
group_offset=args.group_offset,
)
44 changes: 11 additions & 33 deletions run_recbole_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,41 +4,10 @@


import argparse
from ast import arg

from recbole.quick_start import run_recbole, run_recboles
from recbole.quick_start import run
from recbole.utils import list_to_latex


def run(args, model, config_file_list):
if args.nproc == 1 and args.world_size <= 0:
res = run_recbole(
model=model,
dataset=args.dataset,
config_file_list=config_file_list,
)
else:
if args.world_size == -1:
args.world_size = args.nproc
import torch.multiprocessing as mp

res = mp.spawn(
run_recboles,
args=(
args.model,
args.dataset,
config_file_list,
args.ip,
args.port,
args.world_size,
args.nproc,
args.group_offset,
),
nprocs=args.nproc,
)
return res


if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
Expand Down Expand Up @@ -92,7 +61,16 @@ def run(args, model, config_file_list):

valid_res_dict = {"Model": model}
test_res_dict = {"Model": model}
result = run(args, model, config_file_list)
result = run(
model,
args.dataset,
config_file_list=config_file_list,
nproc=args.nproc,
world_size=args.world_size,
ip=args.ip,
port=args.port,
group_offset=args.group_offset,
)
valid_res_dict.update(result["best_valid_result"])
test_res_dict.update(result["test_result"])
bigger_flag = result["valid_score_bigger"]
Expand Down
Loading
Loading