DDP initialization breaks on HPC system at larger scale #10216

Closed
proutrc opened this issue Oct 28, 2021 · 18 comments · Fixed by #10825
Labels
bug (Something isn't working) · distributed (Generic distributed-related topic) · help wanted (Open to be worked on) · won't fix (This will not be worked on)
Comments

@proutrc

proutrc commented Oct 28, 2021

We are experiencing an issue with DDP at larger scales on our HPC system (Summit at OLCF, which uses the LSF scheduler). The threshold is 14 nodes: at that point, initialization suddenly fails. It appears ranks can no longer be set up across nodes properly, as shown in the output below.

Each node has 6 GPUs, so in total we are trying to use 84 GPUs when initialization fails. At 78 GPUs (13 nodes), things work as expected.

Initialization output at 13 nodes (78 GPUs - this works as expected):

initializing ddp: GLOBAL_RANK: 72, MEMBER: 73/78
initializing ddp: GLOBAL_RANK: 36, MEMBER: 37/78
initializing ddp: GLOBAL_RANK: 10, MEMBER: 11/78
initializing ddp: GLOBAL_RANK: 19, MEMBER: 20/78
initializing ddp: GLOBAL_RANK: 32, MEMBER: 33/78
initializing ddp: GLOBAL_RANK: 12, MEMBER: 13/78
initializing ddp: GLOBAL_RANK: 41, MEMBER: 42/78
initializing ddp: GLOBAL_RANK: 74, MEMBER: 75/78
initializing ddp: GLOBAL_RANK: 27, MEMBER: 28/78
initializing ddp: GLOBAL_RANK: 21, MEMBER: 22/78
initializing ddp: GLOBAL_RANK: 13, MEMBER: 14/78
initializing ddp: GLOBAL_RANK: 73, MEMBER: 74/78
initializing ddp: GLOBAL_RANK: 25, MEMBER: 26/78
initializing ddp: GLOBAL_RANK: 24, MEMBER: 25/78
initializing ddp: GLOBAL_RANK: 28, MEMBER: 29/78
initializing ddp: GLOBAL_RANK: 67, MEMBER: 68/78
initializing ddp: GLOBAL_RANK: 69, MEMBER: 70/78
initializing ddp: GLOBAL_RANK: 68, MEMBER: 69/78
initializing ddp: GLOBAL_RANK: 47, MEMBER: 48/78
initializing ddp: GLOBAL_RANK: 56, MEMBER: 57/78
initializing ddp: GLOBAL_RANK: 71, MEMBER: 72/78
initializing ddp: GLOBAL_RANK: 53, MEMBER: 54/78
initializing ddp: GLOBAL_RANK: 66, MEMBER: 67/78
initializing ddp: GLOBAL_RANK: 70, MEMBER: 71/78
initializing ddp: GLOBAL_RANK: 55, MEMBER: 56/78
initializing ddp: GLOBAL_RANK: 49, MEMBER: 50/78
initializing ddp: GLOBAL_RANK: 44, MEMBER: 45/78
initializing ddp: GLOBAL_RANK: 46, MEMBER: 47/78
initializing ddp: GLOBAL_RANK: 57, MEMBER: 58/78
initializing ddp: GLOBAL_RANK: 51, MEMBER: 52/78
initializing ddp: GLOBAL_RANK: 45, MEMBER: 46/78
initializing ddp: GLOBAL_RANK: 54, MEMBER: 55/78
initializing ddp: GLOBAL_RANK: 50, MEMBER: 51/78
initializing ddp: GLOBAL_RANK: 59, MEMBER: 60/78
initializing ddp: GLOBAL_RANK: 52, MEMBER: 53/78
initializing ddp: GLOBAL_RANK: 43, MEMBER: 44/78
initializing ddp: GLOBAL_RANK: 58, MEMBER: 59/78
initializing ddp: GLOBAL_RANK: 42, MEMBER: 43/78
initializing ddp: GLOBAL_RANK: 48, MEMBER: 49/78
----------------------------------------------------------------------------------------------------
distributed_backend=nccl
All DDP processes registered. Starting ddp with 78 processes

Failed initialization at 14 nodes (84 GPUs - the job hangs at this point):

initializing ddp: GLOBAL_RANK: 4, MEMBER: 5/84
initializing ddp: GLOBAL_RANK: 5, MEMBER: 6/84
initializing ddp: GLOBAL_RANK: 5, MEMBER: 6/84
initializing ddp: GLOBAL_RANK: 0, MEMBER: 1/84
initializing ddp: GLOBAL_RANK: 5, MEMBER: 6/84
initializing ddp: GLOBAL_RANK: 5, MEMBER: 6/84
initializing ddp: GLOBAL_RANK: 0, MEMBER: 1/84
initializing ddp: GLOBAL_RANK: 5, MEMBER: 6/84
initializing ddp: GLOBAL_RANK: 4, MEMBER: 5/84
initializing ddp: GLOBAL_RANK: 4, MEMBER: 5/84
initializing ddp: GLOBAL_RANK: 4, MEMBER: 5/84
initializing ddp: GLOBAL_RANK: 5, MEMBER: 6/84
initializing ddp: GLOBAL_RANK: 5, MEMBER: 6/84
initializing ddp: GLOBAL_RANK: 4, MEMBER: 5/84
initializing ddp: GLOBAL_RANK: 0, MEMBER: 1/84
initializing ddp: GLOBAL_RANK: 4, MEMBER: 5/84
initializing ddp: GLOBAL_RANK: 4, MEMBER: 5/84
initializing ddp: GLOBAL_RANK: 0, MEMBER: 1/84
initializing ddp: GLOBAL_RANK: 5, MEMBER: 6/84
initializing ddp: GLOBAL_RANK: 4, MEMBER: 5/84
initializing ddp: GLOBAL_RANK: 5, MEMBER: 6/84
initializing ddp: GLOBAL_RANK: 5, MEMBER: 6/84
initializing ddp: GLOBAL_RANK: 4, MEMBER: 5/84
initializing ddp: GLOBAL_RANK: 5, MEMBER: 6/84
initializing ddp: GLOBAL_RANK: 5, MEMBER: 6/84
initializing ddp: GLOBAL_RANK: 4, MEMBER: 5/84
initializing ddp: GLOBAL_RANK: 5, MEMBER: 6/84
initializing ddp: GLOBAL_RANK: 4, MEMBER: 5/84
initializing ddp: GLOBAL_RANK: 5, MEMBER: 6/84
initializing ddp: GLOBAL_RANK: 5, MEMBER: 6/84
initializing ddp: GLOBAL_RANK: 5, MEMBER: 6/84
initializing ddp: GLOBAL_RANK: 5, MEMBER: 6/84
initializing ddp: GLOBAL_RANK: 5, MEMBER: 6/84
initializing ddp: GLOBAL_RANK: 0, MEMBER: 1/84

Here is the code:

import os

import torch
import torch.nn as nn
import torch.nn.functional as F
import pytorch_lightning as pl
from torch.utils.data import DataLoader
from torchvision.transforms import transforms

from pl_examples.basic_examples.mnist_datamodule import MNIST


class LitAutoEncoder(pl.LightningModule):

    def __init__(self):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Linear(28*28, 64),
            nn.ReLU(),
            nn.Linear(64, 3)
        )
        self.decoder = nn.Sequential(
            nn.Linear(3, 64),
            nn.ReLU(),
            nn.Linear(64, 28*28)
        )

    def forward(self, x):
        # in lightning, forward defines the prediction/inference actions
        embedding = self.encoder(x)
        return embedding

    def training_step(self, batch, batch_idx):
        # training_step defines the train loop.
        # It is independent of forward
        x, y = batch
        x = x.view(x.size(0), -1)
        z = self.encoder(x)
        x_hat = self.decoder(z)
        loss = F.mse_loss(x_hat, x)

        # Logging to TensorBoard by default
        self.log('train_loss', loss)
        return loss

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        return optimizer


# define datasets/dataloaders
dataset = MNIST(os.getcwd(), download=True, transform=transforms.ToTensor())
train_dl = DataLoader(dataset)


# train
model = LitAutoEncoder()
trainer = pl.Trainer(gpus="0,1,2,3,4,5", auto_select_gpus=True, num_nodes=14, max_epochs=3, accelerator='ddp')
trainer.fit(model, train_dl)
  • PyTorch Lightning Version (e.g., 1.3.0): 1.4.9
  • PyTorch Version (e.g., 1.8): 1.9
  • Python version: 3.7
  • OS (e.g., Linux):
  • CUDA/cuDNN version:
  • GPU models and configuration:
  • How you installed PyTorch (conda, pip, source): source
  • If compiling from source, the output of torch.__config__.show():
'PyTorch built with:\n  - GCC 7.3\n  - C++ Version: 201402\n  - OpenMP 201511 (a.k.a. OpenMP 4.5)\n  - CPU capability usage: VSX\n  - CUDA Runtime 11.0\n  - NVCC architecture flags: -gencode;arch=compute_37,code=sm_37;-gencode;arch=compute_60,code=sm_60;-gencode;arch=compute_70,code=sm_70;-gencode;arch=compute_75,code=sm_75;-gencode;arch=compute_80,code=sm_80\n  - CuDNN 8.1.1  (built against CUDA 11.2)\n  - Magma 2.5.4\n  - Build settings: BLAS_INFO=open, BUILD_TYPE=Release, CUDA_VERSION=11.0, CUDNN_VERSION=8.1.1, CXX_COMPILER=/gpfs/alpine/world-shared/stf007/davismj/open-ce-builds/rhel8-oce-1.4.0/python-env/conda-bld/pytorch-base_1633116212289/_build_env/bin/powerpc64le-conda_cos7-linux-gnu-c++, CXX_FLAGS=-fvisibility-inlines-hidden -fmessage-length=0 -mcpu=power8 -mtune=power8 -mpower8-fusion -mpower8-vector -ftree-vectorize -fPIC -fstack-protector-strong -fno-plt -O3 -pipe -I/gpfs/alpine/world-shared/stf007/davismj/open-ce-builds/rhel8-oce-1.4.0/python-env/conda-bld/pytorch-base_1633116212289/_h_env_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehol/include -fdebug-prefix-map=/gpfs/alpine/world-shared/stf007/davismj/open-ce-builds/rhel8-oce-1.4.0/python-env/conda-bld/pytorch-base_1633116212289/work=/usr/local/src/conda/pytorch-base-1.9.0 -fdebug-prefix-map=/gpfs/alpine/world-shared/stf007/davismj/open-ce-builds/rhel8-oce-1.4.0/python-env/conda-bld/pytorch-base_1633116212289/_h_env_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehol=/usr/local/src/conda-prefix -D__STDC_FORMAT_MACROS -I/gpfs/alpine/world-shared/stf007/davismj/open-ce-builds/rhel8-oce-1.4.0/python-env/conda-bld/pytorch-base_1633116212289/_h_env_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehol/include -I/sw/peak/cuda/11.0.3/include -I/gpfs/alpine/world-shared/stf007/davismj/open-ce-builds/rhel8-oce-1.4.0/python-env/conda-bld/pytorch-base_1633116212289/_build_env/include -Wno-deprecated -fvisibility-inlines-hidden -fopenmp -DNDEBUG -DUSE_KINETO -DLIBKINETO_NOCUPTI -DSYMBOLICATE_MOBILE_DEBUG_HANDLE -O2 -fPIC -Wno-narrowing -Wall -Wextra -Werror=return-type -Wno-missing-field-initializers -Wno-type-limits -Wno-array-bounds -Wno-unknown-pragmas -Wno-sign-compare -Wno-unused-parameter -Wno-unused-variable -Wno-unused-function -Wno-unused-result -Wno-unused-local-typedefs -Wno-strict-overflow -Wno-strict-aliasing -Wno-error=deprecated-declarations -Wno-stringop-overflow -Wno-psabi -Wno-error=pedantic -Wno-error=redundant-decls -Wno-error=old-style-cast -fdiagnostics-color=always -faligned-new -Wno-unused-but-set-variable -Wno-maybe-uninitialized -fno-math-errno -fno-trapping-math -Werror=format -Wno-stringop-overflow, LAPACK_INFO=open, TORCH_VERSION=1.9.0, USE_CUDA=1, USE_CUDNN=1, USE_EXCEPTION_PTR=1, USE_GFLAGS=OFF, USE_GLOG=OFF, USE_MKLDNN=OFF, USE_MPI=0, USE_NCCL=ON, USE_NNPACK=0, USE_OPENMP=1,

cc @awaelchli @rohitgr7

@proutrc proutrc added bug Something isn't working help wanted Open to be worked on labels Oct 28, 2021
@awaelchli
Contributor

awaelchli commented Oct 29, 2021

@proutrc Hi, thanks for reporting.

You mention LSF. Could you verify that trainer.training_type_plugin.cluster_environment prints LSFEnvironment? If that's the case, what will trainer.training_type_plugin.cluster_environment.global_rank() print for you? It should be determined by the env variable JSM_NAMESPACE_RANK.
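
For reference, a minimal way to print these from a running script (a sketch; it assumes the trainer from the example above and the PL 1.4.x attribute path):

import os

env = trainer.training_type_plugin.cluster_environment
print(type(env).__name__)                  # expect "LSFEnvironment" on an LSF cluster
print("global_rank:", env.global_rank())
print("world_size:", env.world_size())
print("JSM_NAMESPACE_RANK:", os.environ.get("JSM_NAMESPACE_RANK"))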

@awaelchli awaelchli added the distributed Generic distributed-related topic label Oct 29, 2021
@proutrc
Author

proutrc commented Nov 2, 2021

@awaelchli something is indeed happening when going from 13 to 14 nodes on our system. See below:

13 nodes snippet (what we expect):

...
<pytorch_lightning.plugins.environments.lsf_environment.LSFEnvironment object at 0x2000a51c7650>
global_rank: 41
world_size: 78
<pytorch_lightning.plugins.environments.lsf_environment.LSFEnvironment object at 0x2000a51cc4d0>
global_rank: 56
world_size: 78
<pytorch_lightning.plugins.environments.lsf_environment.LSFEnvironment object at 0x2000a51c6810>
global_rank: 34
world_size: 78
<pytorch_lightning.plugins.environments.lsf_environment.LSFEnvironment object at 0x2000a4cb52d0>
global_rank: 25
world_size: 78
...

14 nodes snippet:

...
<pytorch_lightning.plugins.environments.lightning_environment.LightningEnvironment object at 0x2000a51bead0>
global_rank: 0
world_size: 84
<pytorch_lightning.plugins.environments.lightning_environment.LightningEnvironment object at 0x2000a51c6510>
global_rank: 0
...

At 14 nodes it shows LightningEnvironment instead of LSFEnvironment. In addition, every process reports global_rank: 0, though it does have the correct world_size (84 for 14 nodes).

@proutrc
Author

proutrc commented Nov 2, 2021

It looks like LSB_HOSTS, which holds the list of hosts, is not always set. There appears to be a limit on the size of the list, beyond which LSF stops setting that variable. I think we are hitting this limit.

I can confirm I see LSB_HOSTS when I use 13 nodes, but not when I use 14 nodes.

The variable to use is perhaps LSB_MCPU_HOSTS, instead of LSB_HOSTS.

Somewhat dated reference, but I bet it is still relevant: https://www.ibm.com/support/pages/limitation-environment-variable-lsbhosts
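
A quick way to check, inside the batch job, which of the two variables the scheduler actually set and how large the value is (a minimal sketch):

import os

for var in ("LSB_HOSTS", "LSB_MCPU_HOSTS"):
    value = os.environ.get(var)
    if value is None:
        print(f"{var}: not set")
    else:
        print(f"{var}: set, {len(value)} bytes, {len(value.split())} entries")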

@awaelchli
Contributor

Thank you @proutrc for helping out here.
Right, given your hint, I'm finding the same information, in a bit more detail, in the docs here:
https://www.ibm.com/docs/en/spectrum-lsf/10.1.0?topic=variables-environment-variable-reference
(scroll down to the section for LSB_MCPU_HOSTS)

Quote:

The environment variables LSB_HOSTS and LSB_MCPU_HOSTS both contain the same information, but the information is presented in different formats. LSB_MCPU_HOSTS uses a shorter format than LSB_HOSTS. As a general rule, sbatchd sets both these variables. However, for some parallel jobs, LSB_HOSTS is not set.

For parallel jobs, several CPUs are used, and the length of LSB_HOSTS can become very long. sbatchd needs to spend
a lot of time parsing the string. If the size of LSB_HOSTS exceeds 4096 bytes, LSB_HOSTS is ignored, and sbatchd sets only LSB_MCPU_HOSTS.

To verify the hosts and CPUs used for your dispatched job, check the value of LSB_HOSTS for single CPU jobs, and check the value of LSB_MCPU_HOSTS for parallel jobs.

Their example there shows the format:

LSB_HOSTS= "hostA hostA hostA hostB hostB hostB"
LSB_MCPU_HOSTS="hostA 3 hostB 3" 

So, should we default to LSB_HOSTS in our plugin and, if it is not set, try to get the hosts from LSB_MCPU_HOSTS instead? When that fallback happens, we can log a debug message. The fact that the variable name contains "CPU" is a bit strange, but I think it does not matter for us.
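
A minimal sketch of that proposal (a hypothetical helper, based only on the two formats quoted above, not the plugin's actual code):

import os


def read_hosts():
    # Prefer LSB_HOSTS ("hostA hostA hostB ..."), which has one entry per slot.
    hosts = os.environ.get("LSB_HOSTS")
    if hosts:
        return hosts.split()
    # Fall back to LSB_MCPU_HOSTS ("hostA 3 hostB 3") and expand the
    # "<host> <num_slots>" pairs back into one entry per slot.
    pairs = os.environ.get("LSB_MCPU_HOSTS", "").split()
    if len(pairs) % 2 != 0:
        raise ValueError("Cannot parse LSB_MCPU_HOSTS")
    expanded = []
    for name, count in zip(pairs[::2], pairs[1::2]):
        expanded.extend([name] * int(count))
    return expanded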

@awaelchli
Contributor

awaelchli commented Nov 4, 2021

Could you help me verify that this fixes the issue by creating this custom cluster env:

from pytorch_lightning.plugins.environments import ClusterEnvironment


class NewLSFEnvironment(LSFEnvironment):

    @staticmethod
    def is_using_lsf() -> bool:
        required_env_vars = ("LSB_JOBID", "LSB_MCPU_HOSTS", "JSM_NAMESPACE_LOCAL_RANK", "JSM_NAMESPACE_SIZE")
        return all(v in os.environ for v in required_env_vars)

    @staticmethod
    def _read_hosts():
        hosts_config = os.environ.get("LSB_MCPU_HOSTS", "")
        if not hosts_config:
            raise ValueError("Could not find hosts in environment variable LSB_MCPU_HOSTS")
        host_config = hosts_config.split()

        if len(host_config) % 2 != 0:
            raise ValueError(
                "Cannot parse hosts from LSB_MCPU_HOSTS environment variable. Expected format:"
                ' "<node0_name> <node0_num_procs> <node1_name> ..."'
            )
        return host_config[::2]

    def _get_master_address(self):
        return self._read_hosts()[0]

and adding it to your trainer like so:

trainer = Trainer(plugins=NewLSFEnvironment())

The main change I made is switching to that env variable and adjusting the parsing.
I simulated it but can't run it for real. If you confirm, I can create a fix directly from this.

If things break down on your side, could you let me know the values of the two environment variables LSB_MCPU_HOSTS and LSB_HOSTS?

@proutrc
Author

proutrc commented Nov 5, 2021

Could you help me verify that this fixes the issue by creating this custom cluster env: [...]

When trying this override in my small program (listed above), it complains about LSFEnvironment not being defined. I thought maybe you meant to import LSFEnvironment at the top, instead of ClusterEnvironment, but that produces another error. At any rate, I could be doing something silly (likely, perhaps), but can you provide a full example of this override?

@awaelchli
Contributor

awaelchli commented Nov 5, 2021

Thank you very much for trying and for your patience.
Here is the full example (with correct imports, sorry about that :)) and with your code combined:

import os

import torch
import torch.nn as nn
import torch.nn.functional as F
import pytorch_lightning as pl
from torch.utils.data import DataLoader
from torchvision.transforms import transforms

from pl_examples.basic_examples.mnist_datamodule import MNIST
from pytorch_lightning.plugins import DDPPlugin


class LitAutoEncoder(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.encoder = nn.Sequential(nn.Linear(28 * 28, 64), nn.ReLU(), nn.Linear(64, 3))
        self.decoder = nn.Sequential(nn.Linear(3, 64), nn.ReLU(), nn.Linear(64, 28 * 28))

    def forward(self, x):
        # in lightning, forward defines the prediction/inference actions
        embedding = self.encoder(x)
        return embedding

    def training_step(self, batch, batch_idx):
        # training_step defines the train loop.
        # It is independent of forward
        x, y = batch
        x = x.view(x.size(0), -1)
        z = self.encoder(x)
        x_hat = self.decoder(z)
        loss = F.mse_loss(x_hat, x)

        # Logging to TensorBoard by default
        self.log("train_loss", loss)
        return loss

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        return optimizer


from pytorch_lightning.plugins.environments import LSFEnvironment


class NewLSFEnvironment(LSFEnvironment):
    @staticmethod
    def is_using_lsf() -> bool:
        required_env_vars = ("LSB_JOBID", "LSB_MCPU_HOSTS", "JSM_NAMESPACE_LOCAL_RANK", "JSM_NAMESPACE_SIZE")
        return all(v in os.environ for v in required_env_vars)

    @staticmethod
    def _read_hosts():
        hosts_config = os.environ.get("LSB_MCPU_HOSTS", "")
        if not hosts_config:
            raise ValueError("Could not find hosts in environment variable LSB_MCPU_HOSTS")
        host_config = hosts_config.split()

        if len(host_config) % 2 != 0:
            raise ValueError(
                "Cannot parse hosts from LSB_MCPU_HOSTS environment variable. Expected format:"
                ' "<node0_name> <node0_num_procs> <node1_name> ..."'
            )
        return host_config[::2]

    def _get_master_address(self):
        return self._read_hosts()[0]


if __name__ == "__main__":
    # define datasets/dataloaders
    dataset = MNIST(os.getcwd(), download=True, transform=transforms.ToTensor())
    train_dl = DataLoader(dataset)

    # simulate locally
    # os.environ["LSB_JOBID"] = "1234"
    # os.environ["LSB_MCPU_HOSTS"] = "localhost 2 localhost 2"
    # os.environ["JSM_NAMESPACE_LOCAL_RANK"] = "1"
    # os.environ["JSM_NAMESPACE_RANK"] = "1"
    # os.environ["JSM_NAMESPACE_SIZE"] = "4"

    # train
    model = LitAutoEncoder()

    # simulate locally
    # trainer = pl.Trainer(
    #     num_processes=2,
    #     num_nodes=2,
    #     max_epochs=3,
    #     plugins=DDPPlugin(cluster_environment=NewLSFEnvironment()),
    # )

    trainer = pl.Trainer(
        gpus="0,1,2,3,4,5",
        auto_select_gpus=True,
        num_nodes=14,
        max_epochs=3,
        plugins=DDPPlugin(cluster_environment=NewLSFEnvironment()),
    )
    print("LSB_MCPU_HOSTS:", os.environ["LSB_MCPU_HOSTS"])
    trainer.fit(model, train_dl)

@proutrc
Author

proutrc commented Nov 5, 2021

Thank you very much for trying and for your patience. Here is the full example (with correct imports, sorry about that :)) and with your code combined: [...]

It looks like we are switching to the DDPPlugin now, plugins=DDPPlugin(cluster_environment=NewLSFEnvironment()), which is causing other types of issues for me.

Error with DDPPlugin method:

pytorch_lightning.utilities.exceptions.MisconfigurationException: Horovod does not support setting num_nodes / num_gpus explicitly. Use horovodrun / mpirun to configure the number of processes.
Traceback (most recent call last):
  File "auto-test.py", line 99, in <module>
    plugins=DDPPlugin(cluster_environment=NewLSFEnvironment()),
  File "/gpfs/alpine/stf007/scratch/rprout/deep-learning-hpc-project-template/torch-1.9/lib/python3.7/site-packages/pytorch_lightning/trainer/connectors/env_vars_connector.py", line 40, in insert_env_defaults
    return fn(self, **kwargs)
  File "/gpfs/alpine/stf007/scratch/rprout/deep-learning-hpc-project-template/torch-1.9/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py", line 371, in __init__
    plugins,
  File "/gpfs/alpine/stf007/scratch/rprout/deep-learning-hpc-project-template/torch-1.9/lib/python3.7/site-packages/pytorch_lightning/trainer/connectors/accelerator_connector.py", line 152, in __init__
    self.set_distributed_mode()
  File "/gpfs/alpine/stf007/scratch/rprout/deep-learning-hpc-project-template/torch-1.9/lib/python3.7/site-packages/pytorch_lightning/trainer/connectors/accelerator_connector.py", line 740, in set_distributed_mode
    self._set_horovod_backend()
  File "/gpfs/alpine/stf007/scratch/rprout/deep-learning-hpc-project-template/torch-1.9/lib/python3.7/site-packages/pytorch_lightning/trainer/connectors/accelerator_connector.py", line 821, in _set_horovod_backend
    self.check_horovod()
  File "/gpfs/alpine/stf007/scratch/rprout/deep-learning-hpc-project-template/torch-1.9/lib/python3.7/site-packages/pytorch_lightning/trainer/connectors/accelerator_connector.py", line 856, in check_horovod
    "Horovod does not support setting num_nodes / num_gpus explicitly. Use "

This seems mostly self-explanatory. I just wanted to confirm this is the recommended path, as it appears to change some underlying things (i.e. Horovod does not support setting num_nodes / num_gpus explicitly). I can of course try without num_nodes; I'm just not sure I understand all the implied changes here yet.

On another note, when I do this (not using DDPPlugin) it hangs and never runs the model (but it does seem to initialize all ranks on the GPUs beyond our previous limit):

trainer = pl.Trainer(
    gpus="0,1,2,3,4,5",
    auto_select_gpus=True,
    num_nodes=14,
    max_epochs=3,
    accelerator='ddp',
    plugins=[NewLSFEnvironment()],
)

If the latter method, without DDPPlugin, is supposed to work, we may still have an issue. If I need to switch to the DDPPlugin method, I can continue down that path. I just wanted to confirm these changes and what they imply.

Thanks!

@proutrc
Author

proutrc commented Nov 15, 2021

@awaelchli Hopefully my last question made sense.

Ultimately, it appeared the new env variable and parsing worked. I could see it seemingly initialize the GPUs when I don't use the DDPPlugin, but the model never runs (it just hangs there).

Switching to the DDPPlugin caused some different behavior (it seems it doesn't like the trainer setup, as it is, with this plugin).

Please let me know if I can do or provide anything else.

@awaelchli
Contributor

@proutrc apologies for the delay. I'm struggling to simulate what you are seeing because I can't try it on an LSF cluster myself.

@ajtritt, who is the original contributor of the LSFEnvironment plugin, could you review the changes in my PR over here (#10354) and, if you still have access to LSF, maybe verify whether I broke anything? That would be very appreciated.

@tchaton
Contributor

tchaton commented Nov 18, 2021

Dear @proutrc,

Any chance you would be open to having a pair-debugging session with the Lightning Team (@awaelchli) so we can resolve this issue?

Best,
T.C

@ajtritt
Contributor

ajtritt commented Nov 18, 2021

Hi @proutrc,

I have been using a custom ClusterEnvironment to run on Summit:

https://github.com/exabiome/deep-taxon/blob/master/src/exabiome/nn/lsf_environment.py

It's been some time since I updated things though. Here's what my environment looks like:

PyTorch Lightning Version (e.g., 1.3.0): 1.4.3
PyTorch Version (e.g., 1.8): 1.7.1
Python version: 3.8.10
OS (e.g., Linux):
CUDA/cuDNN version: 11.0.221

The rest of the environment is cloned from open-ce-1.2.0-py38-0

I've been able to run on 128 nodes successfully; I haven't tried anything past that.

Let me know what that turns up.

Andrew

@proutrc
Author

proutrc commented Nov 24, 2021

@awaelchli @tchaton @ajtritt

My apologies for the delay this time :).

The lsf_environment.py from the exabiome project above does work for me, getting me past the issue with LSB_HOSTS (the environment variable that goes empty at a certain size - in our particular case, on Summit, anything beyond 13 nodes).

It seems the LSFEnvironment provided in pytorch-lightning could reflect your example?

I see you use LSB_DJOB_RANKFILE for reading hosts. @awaelchli this is perhaps better. It creates a file with the list of hosts, looking like this (just a two-node example):

bash-4.4$ echo $LSB_DJOB_RANKFILE
/ccs/home/rprout/.lsbatch/1637776511.1646717.hostfile
bash-4.4$ cat /ccs/home/rprout/.lsbatch/1637776511.1646717.hostfile
batch4
d29n14
d29n14
d29n14
d29n14
d29n14
d29n14
d29n14
d29n14
d29n14
d29n14
d29n14
d29n14
d29n14
d29n14
d29n14
d29n14
d29n14
d29n14
d29n14
d29n14
d29n14
d29n14
d29n14
d29n14
d29n14
d29n14
d29n14
d29n14
d29n14
d29n14
d29n14
d29n14
d29n14
d29n14
d29n14
d29n14
d29n14
d29n14
d29n14
d29n14
d29n14
d29n14
d29n15
d29n15
d29n15
d29n15
d29n15
d29n15
d29n15
d29n15
d29n15
d29n15
d29n15
d29n15
d29n15
d29n15
d29n15
d29n15
d29n15
d29n15
d29n15
d29n15
d29n15
d29n15
d29n15
d29n15
d29n15
d29n15
d29n15
d29n15
d29n15
d29n15
d29n15
d29n15
d29n15
d29n15
d29n15
d29n15
d29n15
d29n15
d29n15
d29n15
d29n15
d29n15

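For illustration, a minimal sketch of reading hosts from that rank file (hypothetical; not the exabiome code verbatim):

import os


def read_rankfile_hosts():
    # LSB_DJOB_RANKFILE points at a file with one host name per line, as shown above.
    # The first entry in the example ("batch4") is the launch/batch node, which a real
    # implementation may need to exclude when picking the master address.
    path = os.environ["LSB_DJOB_RANKFILE"]
    with open(path) as f:
        return [line.strip() for line in f if line.strip()]
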
Here is my trainer, FWIW:

trainer = pl.Trainer(
    gpus="0,1,2,3,4,5",
    auto_select_gpus=True,
    num_nodes=32,
    max_epochs=3,
    accelerator='ddp',
    plugins=[LSFEnvironment()],
)

This gets me the following for 32 nodes x 6 GPUs/node (192 ranks):

distributed_backend=nccl
All DDP processes registered. Starting ddp with 192 processes

And LSFEnvironment() came from your example in https://github.com/exabiome/deep-taxon/blob/master/src/exabiome/nn/lsf_environment.py

@tchaton
Contributor

tchaton commented Nov 24, 2021

Dear @proutrc,

It seems your LSFEnvironment implementation is more reliable at scale. Would you be willing to make a contribution and test it? I believe it should be easy to mock.

Best,
T.C

@proutrc
Author

proutrc commented Nov 24, 2021

Dear @proutrc,

It seems your LSFEnvironment implementation is more reliable at scale. Would you be willing to make a contribution and test it? I believe it should be easy to mock.

Best, T.C

@tchaton I am happy to help however I can, but I can't claim the implementation. @ajtritt is the one who provided the implementation for the LSFEnvironment that worked for me. I just tested it within my little example.

I may have goofed something with @awaelchli's fix. That method may be valid too, but the one using LSB_DJOB_RANKFILE might be easier than LSB_MCPU_HOSTS. I will retry the one using LSB_MCPU_HOSTS though, to see if I made a mistake last time.

I think I might see what those guys think too.

@awaelchli
Contributor

awaelchli commented Nov 25, 2021

We can certainly add the LSB_DJOB_RANKFILE functionality. However, according to the documentation, this environment variable is not set by default, so we probably still want to fall back to LSB_HOSTS or LSB_MCPU_HOSTS in case it is not defined.

Since I don't have the environment to properly test and debug my PR #10354, I would definitely go for what is confirmed here to work. If @ajtritt has the time, we would be happy to receive a contribution of his improved LSF code 😃
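
A minimal sketch of that fallback order (hypothetical; not the code in #10354 or the final fix):

import os


def read_hosts_with_fallback():
    # Fallback order discussed here: LSB_DJOB_RANKFILE, then LSB_HOSTS, then LSB_MCPU_HOSTS.
    rankfile = os.environ.get("LSB_DJOB_RANKFILE")
    if rankfile:
        with open(rankfile) as f:
            return [line.strip() for line in f if line.strip()]
    hosts = os.environ.get("LSB_HOSTS")
    if hosts:
        return hosts.split()
    pairs = os.environ.get("LSB_MCPU_HOSTS", "").split()
    # Expand "<host> <num_slots>" pairs into one entry per slot.
    return [name for name, count in zip(pairs[::2], pairs[1::2]) for _ in range(int(count))]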

@stale

stale bot commented Dec 26, 2021

This issue has been automatically marked as stale because it hasn't had any recent activity. This issue will be closed in 7 days if no further activity occurs. Thank you for your contributions, Pytorch Lightning Team!

@stale stale bot added the won't fix This will not be worked on label Dec 26, 2021
@awaelchli awaelchli added this to the 1.5.x milestone Dec 26, 2021
@cmlakhan

cmlakhan commented Jan 6, 2022

We are experiencing an issue with DDP at larger scales on our HPC system (Summit at OLCF, which uses the LSF scheduler). [...]

Would you mind sharing your LSF submit script? I am new to LSF and would love to see how you configure your submission script for the cluster.
