This repository has been archived by the owner on Dec 16, 2022. It is now read-only.

Multi-GPU error related to process initialization? #4289

Closed
elkotito opened this issue May 26, 2020 · 9 comments

@elkotito

System information:

  • OS: Ubuntu 18.04
  • AllenNLP version: v1.0.0.rc4
  • PyTorch version: 1.4.0

Question
I tried to train a model with distributed training, using this trainer configuration:

  trainer: {
    cuda_device: 0,
    distributed: true,
    world_size: 4,
    // ...
  }

but it failed 😢

Traceback (most recent call last):
  File "/home/pi3ni0/.venv/dev/lib/python3.6/site-packages/allennlp/__main__.py", line 23, in <module>
    run()
  File "/home/pi3ni0/.venv/dev/lib/python3.6/site-packages/allennlp/__main__.py", line 19, in run
    main(prog="allennlp")
  File "/home/pi3ni0/.venv/dev/lib/python3.6/site-packages/allennlp/commands/__init__.py", line 92, in main
    args.func(args)
  File "/home/pi3ni0/.venv/dev/lib/python3.6/site-packages/allennlp/commands/train.py", line 112, in train_model_from_args
    dry_run=args.dry_run,
  File "/home/pi3ni0/.venv/dev/lib/python3.6/site-packages/allennlp/commands/train.py", line 171, in train_model_from_file
    dry_run=dry_run,
  File "/home/pi3ni0/.venv/dev/lib/python3.6/site-packages/allennlp/commands/train.py", line 230, in train_model
    dry_run=dry_run,
  File "/home/pi3ni0/.venv/dev/lib/python3.6/site-packages/allennlp/commands/train.py", line 418, in _train_worker
    params=params, serialization_dir=serialization_dir, local_rank=process_rank,
  File "/home/pi3ni0/.venv/dev/lib/python3.6/site-packages/allennlp/common/from_params.py", line 580, in from_params
    **extras,
  File "/home/pi3ni0/.venv/dev/lib/python3.6/site-packages/allennlp/common/from_params.py", line 611, in from_params
    return constructor_to_call(**kwargs)  # type: ignore
  File "/home/pi3ni0/.venv/dev/lib/python3.6/site-packages/allennlp/commands/train.py", line 669, in from_partial_objects
    model=model_, data_loader=data_loader_, validation_data_loader=validation_data_loader_,
  File "/home/pi3ni0/.venv/dev/lib/python3.6/site-packages/allennlp/common/lazy.py", line 46, in construct
    return self._constructor(**kwargs)
  File "/home/pi3ni0/.venv/dev/lib/python3.6/site-packages/allennlp/common/from_params.py", line 446, in constructor
    return value_cls.from_params(params=deepcopy(popped_params), **constructor_extras)
  File "/home/pi3ni0/.venv/dev/lib/python3.6/site-packages/allennlp/common/from_params.py", line 580, in from_params
    **extras,
  File "/home/pi3ni0/.venv/dev/lib/python3.6/site-packages/allennlp/common/from_params.py", line 611, in from_params
    return constructor_to_call(**kwargs)  # type: ignore
  File "/home/pi3ni0/.venv/dev/lib/python3.6/site-packages/allennlp/training/trainer.py", line 1065, in from_partial_objects
    opt_level=opt_level,
  File "/home/pi3ni0/.venv/dev/lib/python3.6/site-packages/allennlp/training/trainer.py", line 407, in __init__
    find_unused_parameters=True,
  File "/home/pi3ni0/.venv/dev/lib/python3.6/site-packages/torch/nn/parallel/distributed.py", line 255, in __init__
    self.process_group = _get_default_group()
  File "/home/pi3ni0/.venv/dev/lib/python3.6/site-packages/torch/distributed/distributed_c10d.py", line 262, in _get_default_group
    raise RuntimeError("Default process group has not been initialized, "
RuntimeError: Default process group has not been initialized, please make sure to call init_process_group.

Process finished with exit code 1

I suspect this is still a work in progress (I'm using v1.0.0rc4), but clearly a process group has to be initialized somehow, e.g. as below. Is that call missing from the code, or am I missing something?

import torch.distributed as dist
dist.init_process_group("gloo", rank=rank, world_size=world_size)

@epwalsh
Member

epwalsh commented May 26, 2020

Hi @mateuszpieniak, sorry that our distributed training mechanism is a little opaque at the moment. We have plans to add a tutorial soon.

The issue here is with your configuration file. You shouldn't set "distributed", "cuda_device", or "world_size" in the "trainer" part of your config. Instead, you should specify distributed training at the top level of your config, like this: https://github.com/allenai/allennlp-models/blob/transformer-qa-training/training_config/rc/transformer_qa_distributed.jsonnet#L43
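For illustration, the fix amounts to removing those keys from "trainer" and listing the GPUs in a top-level "distributed" block. The sketch below does this on a config loaded as a Python dict. The helper name is hypothetical, and the "cuda_devices" key is assumed from AllenNLP 1.0-era configs such as the one linked above, so check it against your installed version.

```python
import copy
import json


def move_distributed_settings(config: dict) -> dict:
    """Hypothetical helper: move trainer-level distributed keys to the top level.

    Returns a new dict; the input config is left unchanged.
    """
    config = copy.deepcopy(config)
    trainer = config.get("trainer", {})
    if not trainer.pop("distributed", False):
        return config  # single-process config: nothing to move
    world_size = trainer.pop("world_size")
    trainer.pop("cuda_device", None)  # devices are listed at the top level instead
    # Assumption: one worker per GPU, using GPUs 0 .. world_size - 1.
    config["distributed"] = {"cuda_devices": list(range(world_size))}
    return config


old = {"trainer": {"cuda_device": 0, "distributed": True, "world_size": 4, "num_epochs": 3}}
print(json.dumps(move_distributed_settings(old), indent=2))
```

Single-GPU configs pass through unchanged, so the helper only touches configs that set "distributed" inside "trainer".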

@elkotito
Author

elkotito commented May 26, 2020

@epwalsh That's fine, thank you! I have a question, though. Should batch_size be configured "per device"? In other words, is 1 GPU with batch_size = 32 equivalent to 4 GPUs with batch_size = 8? I guess it won't be exactly equivalent because DDP averages the gradients, but I'm asking about the general rule.

@epwalsh
Member

epwalsh commented May 26, 2020

Yes, batch_size should be interpreted as "per device batch size". With 4 GPUs, setting batch_size = 8 means each individual worker uses a batch size of 8, making your "effective batch size" 32.
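The rule is plain multiplication, assuming one worker per GPU. As a tiny sketch (the function names are only for illustration):

```python
def effective_batch_size(batch_size: int, num_gpus: int) -> int:
    """Examples per optimizer step when `batch_size` is per worker (one worker per GPU)."""
    return batch_size * num_gpus


def per_gpu_batch_size(target: int, num_gpus: int) -> int:
    """Per-worker batch_size that reproduces a given single-GPU batch size."""
    if target % num_gpus != 0:
        raise ValueError(f"{target} examples cannot be split evenly over {num_gpus} GPUs")
    return target // num_gpus


print(effective_batch_size(8, 4))  # 32
print(per_gpu_batch_size(32, 4))   # 8
```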

@epwalsh
Member

epwalsh commented May 26, 2020

I think we can close this, but if you have any other issues feel free to reach out!

epwalsh closed this as completed May 26, 2020

@elkotito
Author

elkotito commented May 27, 2020

@epwalsh Just a quick follow up

  1. I have one node with 32 GB of RAM and 4 GPUs with 11 GB of memory each. I load the MNLI dataset and spawn 4 processes. The function spawned in each process is _train_worker, which internally runs TrainModel.from_params(...), which in turn reads the whole dataset via datasets = training_util.read_all_datasets(...). My machine hangs because it runs out of RAM. I think this is related to "Using DistributedDataParallel for multi GPU training" #2536, in particular:

We do not need to worry about the dataset readers reading multiple copies of the data. People who care about this can write new dataset readers in exactly the way you have for the snli reader.

  2. Alright, so I decided to use the lazy flag of the DatasetReader. Unfortunately, it doesn't work, because an exception is raised here (even in the simple CPU case):
Traceback (most recent call last):
  File "/home/pi3ni0/.venv/dev/lib/python3.6/site-packages/allennlp/__main__.py", line 23, in <module>
    run()
  File "/home/pi3ni0/.venv/dev/lib/python3.6/site-packages/allennlp/__main__.py", line 19, in run
    main(prog="allennlp")
  File "/home/pi3ni0/.venv/dev/lib/python3.6/site-packages/allennlp/commands/__init__.py", line 92, in main
    args.func(args)
  File "/home/pi3ni0/.venv/dev/lib/python3.6/site-packages/allennlp/commands/train.py", line 144, in train_model_from_args
    dry_run=args.dry_run,
  File "/home/pi3ni0/.venv/dev/lib/python3.6/site-packages/allennlp/commands/train.py", line 203, in train_model_from_file
    dry_run=dry_run,
  File "/home/pi3ni0/.venv/dev/lib/python3.6/site-packages/allennlp/commands/train.py", line 266, in train_model
    dry_run=dry_run,
  File "/home/pi3ni0/.venv/dev/lib/python3.6/site-packages/allennlp/commands/train.py", line 453, in _train_worker
    batch_weight_key=batch_weight_key,
  File "/home/pi3ni0/.venv/dev/lib/python3.6/site-packages/allennlp/common/from_params.py", line 580, in from_params
    **extras,
  File "/home/pi3ni0/.venv/dev/lib/python3.6/site-packages/allennlp/common/from_params.py", line 611, in from_params
    return constructor_to_call(**kwargs)  # type: ignore
  File "/home/pi3ni0/.venv/dev/lib/python3.6/site-packages/allennlp/commands/train.py", line 665, in from_partial_objects
    dataset.index_with(model_.vocab)
AttributeError: 'itertools.islice' object has no attribute 'index_with'

Process finished with exit code 1

I updated to v1.0.0rc5, where this seems to have been fixed (it looks like this was still a work in progress in rc4).

  3. Now I can use a lazy dataset reader, but it doesn't let me use BucketBatchSampler. That is understandable, but perhaps it should be noted in the DatasetReader docs? You may think that's unnecessary because the restriction comes from torch.utils.data.DataLoader, but since from the user's perspective AllenNLP exposes a DatasetReader rather than a Dataset, it would be nice.
Traceback (most recent call last):
  File "/home/pi3ni0/.venv/dev/lib/python3.6/site-packages/allennlp/__main__.py", line 23, in <module>
    run()
  File "/home/pi3ni0/.venv/dev/lib/python3.6/site-packages/allennlp/__main__.py", line 19, in run
    main(prog="allennlp")
  File "/home/pi3ni0/.venv/dev/lib/python3.6/site-packages/allennlp/commands/__init__.py", line 92, in main
    args.func(args)
  File "/home/pi3ni0/.venv/dev/lib/python3.6/site-packages/allennlp/commands/train.py", line 112, in train_model_from_args
    dry_run=args.dry_run,
  File "/home/pi3ni0/.venv/dev/lib/python3.6/site-packages/allennlp/commands/train.py", line 171, in train_model_from_file
    dry_run=dry_run,
  File "/home/pi3ni0/.venv/dev/lib/python3.6/site-packages/allennlp/commands/train.py", line 230, in train_model
    dry_run=dry_run,
  File "/home/pi3ni0/.venv/dev/lib/python3.6/site-packages/allennlp/commands/train.py", line 418, in _train_worker
    params=params, serialization_dir=serialization_dir, local_rank=process_rank,
  File "/home/pi3ni0/.venv/dev/lib/python3.6/site-packages/allennlp/common/from_params.py", line 580, in from_params
    **extras,
  File "/home/pi3ni0/.venv/dev/lib/python3.6/site-packages/allennlp/common/from_params.py", line 611, in from_params
    return constructor_to_call(**kwargs)  # type: ignore
  File "/home/pi3ni0/.venv/dev/lib/python3.6/site-packages/allennlp/commands/train.py", line 644, in from_partial_objects
    data_loader_ = data_loader.construct(dataset=datasets["train"])
  File "/home/pi3ni0/.venv/dev/lib/python3.6/site-packages/allennlp/common/lazy.py", line 46, in construct
    return self._constructor(**kwargs)
  File "/home/pi3ni0/.venv/dev/lib/python3.6/site-packages/allennlp/common/from_params.py", line 446, in constructor
    return value_cls.from_params(params=deepcopy(popped_params), **constructor_extras)
  File "/home/pi3ni0/.venv/dev/lib/python3.6/site-packages/allennlp/common/from_params.py", line 580, in from_params
    **extras,
  File "/home/pi3ni0/.venv/dev/lib/python3.6/site-packages/allennlp/common/from_params.py", line 611, in from_params
    return constructor_to_call(**kwargs)  # type: ignore
  File "/home/pi3ni0/.venv/dev/lib/python3.6/site-packages/allennlp/data/dataloader.py", line 143, in from_partial_objects
    batches_per_epoch=batches_per_epoch,
  File "/home/pi3ni0/.venv/dev/lib/python3.6/site-packages/allennlp/data/dataloader.py", line 80, in __init__
    multiprocessing_context=multiprocessing_context,
  File "/home/pi3ni0/.venv/dev/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 184, in __init__
    "batch_sampler option, but got batch_sampler={}".format(batch_sampler))
ValueError: DataLoader with IterableDataset: expected unspecified batch_sampler option, but got batch_sampler=<allennlp.data.samplers.bucket_batch_sampler.BucketBatchSampler object at 0x7fb7e3d16f28>

Process finished with exit code 1

  4. I noticed in the code that the vocabulary is created separately in each process, but only from that process's subset of the data! I may be mistaken, but I didn't see any function that aggregates the partial vocabularies. If so, that would be a bug. Another interesting case: if a user overrode the DatasetReader so that the whole dataset goes to each process, every token would be counted n times. Perhaps I can help with that?

Since I'm mostly fine-tuning models, is it possible to "disable vocab building" to speed things up (just for now)?
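The approach quoted from #2536 in item 1 comes down to each worker keeping only its own share of the instances. A minimal plain-Python sketch, assuming the reader yields instances lazily and each process knows its rank and world_size (the helper names are hypothetical, not AllenNLP API). The result is a bare `itertools.islice`, the same type named in the `AttributeError` in item 2, so it lacks dataset methods such as `index_with`; it only illustrates the split.

```python
import itertools
from typing import Iterable, Iterator, TypeVar

T = TypeVar("T")


def shard_for_worker(instances: Iterable[T], rank: int, world_size: int) -> Iterator[T]:
    """Lazily yield every world_size-th instance, starting at index `rank`.

    Each worker still scans the whole input but keeps only its own share;
    only the current item is held at a time unless the caller materializes
    the result (e.g. with list()).
    """
    if not 0 <= rank < world_size:
        raise ValueError(f"rank {rank} is out of range for world_size {world_size}")
    return itertools.islice(instances, rank, None, world_size)


def read_examples(n: int) -> Iterator[str]:
    """Hypothetical stand-in for a lazy dataset reader."""
    for i in range(n):
        yield f"example-{i}"


for rank in range(4):
    print(rank, list(shard_for_worker(read_examples(10), rank, 4)))
```

Every instance lands in exactly one shard, so the workers together still see the whole dataset once per epoch.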

@epwalsh
Member

epwalsh commented May 27, 2020

Hi @mateuszpieniak, would you mind making a PR to the docs to address # 3?
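For context on that docs note: PyTorch's DataLoader rejects any batch_sampler for an IterableDataset (the ValueError in # 3), and bucketing could not work on a one-pass iterable anyway, because it sorts instances by length before forming batches and so needs every instance's length up front. A simplified, hypothetical sketch of the idea, not AllenNLP's actual BucketBatchSampler:

```python
import random
from typing import List, Sequence


def bucket_batches(lengths: Sequence[int], batch_size: int, seed: int = 0) -> List[List[int]]:
    """Group instance indices into batches of similar length.

    Requires len() and random access to every instance's length, which a
    one-pass iterable cannot offer without reading everything first.
    """
    order = sorted(range(len(lengths)), key=lambda i: lengths[i])
    batches = [order[i:i + batch_size] for i in range(0, len(order), batch_size)]
    random.Random(seed).shuffle(batches)  # shuffle batch order; each batch stays length-homogeneous
    return batches


print(bucket_batches([5, 1, 9, 3, 7, 2], batch_size=2))
```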

As for # 4, the vocab is created from all of the instances by the main process before any of the workers are spawned. See here: https://github.com/allenai/allennlp/blob/master/allennlp/commands/train.py#L271

The vocab is then saved to the serialization directory, and the "vocab" params are modified so that each spawned worker just reads that vocab from the saved files: https://github.com/allenai/allennlp/blob/master/allennlp/commands/train.py#L274
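The pattern described above (count once over all instances in the main process, write the result to the serialization directory, have every worker load the same files) can be sketched in plain Python. The function names, file name, and JSON format are hypothetical stand-ins; AllenNLP's Vocabulary has its own on-disk format. The sketch relies on every worker being able to read the same directory, i.e. a single machine or a shared filesystem.

```python
import json
import os
import tempfile
from collections import Counter
from typing import Dict, Iterable


def build_and_save_vocab(texts: Iterable[str], directory: str) -> str:
    """Main process: count tokens over ALL instances once and save the result."""
    counts = Counter(token for text in texts for token in text.split())
    # Most frequent first; ties broken alphabetically so the order is deterministic.
    tokens = sorted(counts, key=lambda tok: (-counts[tok], tok))
    path = os.path.join(directory, "tokens.json")  # hypothetical file name/format
    with open(path, "w") as f:
        json.dump(tokens, f)
    return path


def load_vocab(path: str) -> Dict[str, int]:
    """Worker: load the shared token list instead of counting its own shard."""
    with open(path) as f:
        return {tok: idx for idx, tok in enumerate(json.load(f))}


with tempfile.TemporaryDirectory() as serialization_dir:
    vocab_path = build_and_save_vocab(["a b", "b c", "c d"], serialization_dir)
    worker_vocabs = [load_vocab(vocab_path) for _rank in range(4)]
    print(all(v == worker_vocabs[0] for v in worker_vocabs))  # True
    print(worker_vocabs[0])  # {'b': 0, 'c': 1, 'a': 2, 'd': 3}
```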

@epwalsh
Member

epwalsh commented May 27, 2020

Let me know if that doesn't make sense; I'm pretty new to the distributed code.

@elkotito
Author

elkotito commented Jun 2, 2020

@epwalsh Sure, I will do a PR. That makes sense as long as training takes place on a single machine. Otherwise, the vocabulary would have to be sent over the network to the workers, because they cannot read it with from_files. Don't worry, I am also pretty new to distributed code 😉

Btw, how does gradient accumulation work for distributed training? Is it the number of steps each worker takes before its gradients are sent to the master? Consider an example with 4 GPUs, batch_size == 4 and num_gradient_accumulation_steps == 8. Is the "effective batch size" then equal to 128?

@epwalsh
Member

epwalsh commented Jun 2, 2020

You're right, this won't work over the network. Currently our distributed training will only work on a single machine (so, not exactly distributed, but in theory it's more efficient than the old multi-GPU training that used DataParallel instead of DistributedDataParallel).

And yes, the num_gradient_accumulation_steps parameter is given to each worker's trainer as-is. So in your example, each worker does 8 batches of size 4 before the gradients are synchronized and an optimization step is performed.
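A toy check of that arithmetic in plain Python: with equal-sized batches, and assuming each batch loss is a mean that is divided by num_gradient_accumulation_steps before backward and that DDP averages gradients across workers, 4 workers x 8 accumulated batches of 4 give the same gradient as one batch of 128 examples. The model and helper functions are hypothetical stand-ins for the real trainer.

```python
import math
from typing import List, Sequence, Tuple

Example = Tuple[float, float]  # (x, y) pairs for a toy model y ≈ w * x


def batch_grad(w: float, batch: Sequence[Example]) -> float:
    """Gradient w.r.t. w of the mean squared error over one batch."""
    return sum(2 * (w * x - y) * x for x, y in batch) / len(batch)


def ddp_accumulated_grad(w: float, data: Sequence[Example], num_workers: int,
                         batch_size: int, accum_steps: int) -> float:
    """Simulate per-worker gradient accumulation followed by a DDP-style mean.

    Assumes each batch loss is divided by accum_steps and that gradients are
    averaged across workers; data must split evenly.
    """
    assert len(data) == num_workers * batch_size * accum_steps
    per_worker: List[float] = []
    for rank in range(num_workers):
        shard = data[rank::num_workers]  # round-robin split across workers
        accumulated = 0.0
        for step in range(accum_steps):
            batch = shard[step * batch_size:(step + 1) * batch_size]
            accumulated += batch_grad(w, batch) / accum_steps
        per_worker.append(accumulated)
    return sum(per_worker) / num_workers


data = [(float(i), 2.0 * i + i % 3) for i in range(128)]
g_ddp = ddp_accumulated_grad(0.5, data, num_workers=4, batch_size=4, accum_steps=8)
g_full = batch_grad(0.5, data)
print(math.isclose(g_ddp, g_full, rel_tol=1e-9))  # True
```

Under these assumptions the effective batch size is batch_size x num_gradient_accumulation_steps x number of GPUs, i.e. 4 x 8 x 4 = 128.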
