deepspeed.zero.Init causes infinite recursion error #2139

Closed
awaelchli opened this issue Jul 26, 2022 · 7 comments · Fixed by #2150 or Lightning-AI/pytorch-lightning#13967
Labels
bug Something isn't working inference

Comments

@awaelchli
Contributor

awaelchli commented Jul 26, 2022

Describe the bug

When the deepspeed.zero.Init context wraps not only the model construction but also the deepspeed.initialize call, a RecursionError is raised.
This happens in DeepSpeed 0.6.5 but not in 0.6.4. It blocks the integration with Lightning Lite, where until now we have wrapped the entire run() method in the context.

To Reproduce

import argparse
import os

import deepspeed
import torch
import torch.nn as nn


class TheModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2, bias=False)


config = {
    "activation_checkpointing": {
        "contiguous_memory_optimization": False,
        "cpu_checkpointing": False,
        "partition_activations": False,
        "synchronize_checkpoint_boundary": False,
    },
    "aio": {
        "block_size": 1048576,
        "overlap_events": True,
        "queue_depth": 8,
        "single_submit": False,
        "thread_count": 1,
    },
    "train_micro_batch_size_per_gpu": 1,
    "zero_allow_untested_optimizer": True,
    "zero_optimization": {
        "allgather_bucket_size": 200000000,
        "allgather_partitions": True,
        "contiguous_gradients": True,
        "overlap_comm": True,
        "reduce_bucket_size": 200000000,
        "reduce_scatter": True,
        "stage": 3,
        "sub_group_size": 1000000000000,
    },
}


def worker(rank):
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "12234"
    os.environ["WORLD_SIZE"] = "2"
    os.environ["RANK"] = str(rank)
    os.environ["LOCAL_RANK"] = str(rank)
    deepspeed.init_distributed()

    model_parallel_context = deepspeed.zero.Init(
        remote_device="cpu", pin_memory=True, config_dict_or_path=config, dtype=torch.float32
    )

    # If the context goes over the model AND the deepspeed.initialize call, we get an infinite recursion error
    # This worked in 0.6.4, but not in 0.6.5
    with model_parallel_context:
        model = TheModel()

    # If the context only goes over the model, no error occurs (unindent the lines below)
        deepspeed_engine, deepspeed_optimizer, _, _ = deepspeed.initialize(
            args=argparse.Namespace(device_rank=rank),
            model=model,
            # model_parameters=model.parameters(),
            # optimizer=optimizer,
            dist_init_required=False,
            config=config,
        )


if __name__ == "__main__":
    torch.multiprocessing.spawn(worker, nprocs=2)

Output (may need to press ctrl+c on hang):

  File "/home/adrian/anaconda3/envs/lightning/lib/python3.10/site-packages/deepspeed/runtime/engine.py", line 432, in __getattr__
    if name in dir(self):
  File "/home/adrian/anaconda3/envs/lightning/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1847, in __dir__
    parameters = list(self._parameters.keys())
  File "/home/adrian/anaconda3/envs/lightning/lib/python3.10/site-packages/deepspeed/runtime/engine.py", line 432, in __getattr__
    if name in dir(self):
  File "/home/adrian/anaconda3/envs/lightning/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1845, in __dir__
    module_attrs = dir(self.__class__)
RecursionError: maximum recursion depth exceeded while calling a Python object

Expected behavior
This worked in 0.6.4, so my assumption is that the change was unintentional. Git blame points to #1915. We weren't able to spot exactly which lines caused it, but we suspect the __getattr__ changes on the DeepSpeed engine.
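
The cycle visible in the traceback (`__getattr__` calls `dir(self)`, and `__dir__` reads `self._parameters`) can be reproduced without DeepSpeed or PyTorch. The sketch below is a toy model of that pattern, not the library code: `ToyModule`, `ToyEngine`, and `triggers_recursion` are hypothetical names, and it assumes for illustration that `_parameters` is not yet set when the lookup happens, which this thread does not confirm as the actual cause.

```python
class ToyModule:
    """Toy stand-in for a module whose __dir__ reads an instance attribute."""

    def __init__(self, init_params=True):
        # Assumption for illustration: the attribute may be missing at lookup time.
        if init_params:
            self._parameters = {}

    def __dir__(self):
        # Mirrors the traceback line: list(self._parameters.keys())
        return list(self._parameters.keys()) + dir(self.__class__)


class ToyEngine(ToyModule):
    """Toy stand-in for an engine whose __getattr__ consults dir(self)."""

    def __getattr__(self, name):
        # __getattr__ only runs after normal attribute lookup has failed.
        # If _parameters is missing, dir(self) re-enters __getattr__ for
        # "_parameters", which calls dir(self) again, and so on.
        if name in dir(self):
            return object.__getattribute__(self, name)
        raise AttributeError(name)


def triggers_recursion(engine):
    """Return True if a failed attribute lookup ends in RecursionError."""
    try:
        engine.some_missing_attribute
    except RecursionError:
        return True
    except AttributeError:
        return False
    return False
```

With `ToyEngine(init_params=False)` the lookup never terminates on its own and Python raises RecursionError; with `init_params=True` the same lookup fails cleanly with AttributeError.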

ds_report output

--------------------------------------------------
DeepSpeed C++/CUDA extension op report
--------------------------------------------------
NOTE: Ops not installed will be just-in-time (JIT) compiled at
      runtime if needed. Op compatibility means that your system
      meet the required dependencies to JIT install the op.
--------------------------------------------------
JIT compiled ops requires ninja
ninja .................. [OKAY]
--------------------------------------------------
op name ................ installed .. compatible
--------------------------------------------------
cpu_adam ............... [NO] ....... [OKAY]
cpu_adagrad ............ [NO] ....... [OKAY]
fused_adam ............. [NO] ....... [OKAY]
fused_lamb ............. [NO] ....... [OKAY]
 [WARNING]  please install triton==1.0.0 if you want to use sparse attention
sparse_attn ............ [NO] ....... [NO]
transformer ............ [NO] ....... [OKAY]
stochastic_transformer . [NO] ....... [OKAY]
 [WARNING]  async_io requires the dev libaio .so object and headers but these were not found.
 [WARNING]  async_io: please install the libaio-dev package with apt
 [WARNING]  If libaio is already installed (perhaps from source), try setting the CFLAGS and LDFLAGS environment variables to where it can be found.
async_io ............... [NO] ....... [NO]
utils .................. [NO] ....... [OKAY]
quantizer .............. [NO] ....... [OKAY]
transformer_inference .. [NO] ....... [OKAY]
--------------------------------------------------
DeepSpeed general environment info:
torch install path ............... ['/home/adrian/anaconda3/envs/lightning/lib/python3.10/site-packages/torch']
torch version .................... 1.11.0
torch cuda version ............... 11.3
torch hip version ................ None
nvcc version ..................... 11.1
deepspeed install path ........... ['/home/adrian/anaconda3/envs/lightning/lib/python3.10/site-packages/deepspeed']
deepspeed info ................... 0.6.4, unknown, unknown
deepspeed wheel compiled w. ...... torch 1.11, cuda 11.3

System info (please complete the following information):

  • OS: Ubuntu
  • GPU count and types: 2x RTX3090
  • Python version: 3.9

Launcher context
I'm launching using torch.multiprocessing for simplicity in reproducing, but the bug is unrelated to how it is getting launched.

Docker context
No docker

Additional context

@tjruwase
Contributor

@awaelchli, thanks for reporting this issue. zero.Init() is meant to wrap model construction only. Is there a reason why you are wrapping the entire execution, including deepspeed.initialize()?

@awaelchli
Contributor Author

awaelchli commented Jul 27, 2022

We are providing a wrapper around the user code in Lightning Lite. It looks something like this:

class Lite(LightningLite):
    def run(self):
        # user's code goes here (model, optimizer, training loop, etc.)

        model = ...
        optimizer = ...
        model, optimizer = self.setup(model, optimizer)

        # train model
         
if __name__ == "__main__":
    lite = Lite(accelerator="gpu", devices=2)
    lite.run()

We have a deepspeed integration. Without any major changes to their code, the user can turn deepspeed on by changing this:

Lite(accelerator="gpu", devices=2, strategy=DeepSpeedStrategy(stage=3, ...))

Internally, we wrap the run method in the deepspeed.zero.Init context manager so the user doesn't have to enter it themselves. The context therefore covers all of the user's code, including the self.setup() call where we call deepspeed.initialize.
Previously this worked because the context did not interfere with deepspeed.initialize.

It would be great if this could still be supported. If not, we'd need to introduce a context manager in Lite that the user has to call over their model instantiation.
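
The fallback mentioned here, a context manager scoped to model instantiation only, might look roughly like the sketch below. Everything in it is hypothetical (`ToyLite`, `model_init`, `init_context_factory`, the `events` log); it is not Lightning's or DeepSpeed's API. In a real integration the factory would presumably return something like `deepspeed.zero.Init(...)`; here `contextlib.nullcontext` stands in so the example runs without DeepSpeed.

```python
from contextlib import contextmanager, nullcontext


class ToyLite:
    """Hypothetical Lite-style wrapper; illustrates scoping only."""

    def __init__(self, init_context_factory=nullcontext):
        # Assumption: the factory returns a context manager that should be
        # active only while the model is being constructed.
        self._init_context_factory = init_context_factory
        self.events = []

    @contextmanager
    def model_init(self):
        # Enter the construction-time context, then hand control to the user.
        with self._init_context_factory():
            self.events.append("enter")
            try:
                yield
            finally:
                self.events.append("exit")

    def setup(self, model):
        # Stand-in for the call that would initialize the engine.
        self.events.append("setup")
        return model

    def run(self):
        # Only model construction happens inside the context...
        with self.model_init():
            model = {"layer": [0.0] * 4}
            self.events.append("build")
        # ...and setup runs after the context has exited.
        return self.setup(model)
```

Calling `ToyLite().run()` leaves `events` as `["enter", "build", "exit", "setup"]`, showing that the setup step runs outside the construction context rather than inside it.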

cc @carmocca

@tjruwase
Contributor

@awaelchli, thanks for sharing this context. Honestly, I don't know why it worked previously as wrapping deepspeed.initialize() was not an intended use for zero.Init. So, I am quite puzzled as to the recent changes that could have broken the integration. I will try to repro the failure.

@jeffra
Collaborator

jeffra commented Jul 28, 2022

@awaelchli can you give #2150 a try? In our local tests it seems to fix your issue.

@awaelchli
Contributor Author

@jeffra Thank you very much for the quick response on the issue. I just tried the fix in my simple test and it works well! ❤️

@jeffra
Collaborator

jeffra commented Jul 30, 2022

@awaelchli please re-open if this issue isn't resolved for you. The related PR is now merged into master.

@awaelchli
Contributor Author

Fantastic. Thank you for resolving this so quickly!
