
AttributeError: 'DistributedDataParallel' object has no attribute 'generate' when validating T5 model #754

Closed
2 of 4 tasks
gansem opened this issue Oct 13, 2022 · 6 comments

Comments

@gansem

gansem commented Oct 13, 2022

System Info

- `Accelerate` version: 0.12.0
- Platform: Linux-5.4.0-1086-gcp-x86_64-with-glibc2.17
- Python version: 3.8.13
- Numpy version: 1.23.1
- PyTorch version (GPU?): 1.12.1 (True)
- `Accelerate` default config:
        Not found

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • One of the scripts in the examples/ folder of Accelerate or an officially supported no_trainer script in the examples folder of the transformers repo (such as run_no_trainer_glue.py)
  • My own task or dataset (give details below)

Reproduction


from transformers import T5Tokenizer, T5ForConditionalGeneration
from accelerate import Accelerator, notebook_launcher


def main():
    accelerator = Accelerator()

    tokenizer = T5Tokenizer.from_pretrained("t5-base", extra_ids=0)
    model = T5ForConditionalGeneration.from_pretrained("t5-base")

    input_ids = tokenizer("Test Input", return_tensors="pt").input_ids
    model = accelerator.prepare(model)

    outputs = model.generate(input_ids=input_ids.to("cuda:0"))

if __name__ == "__main__":
    notebook_launcher(main, num_processes=2)

Stack trace:

-- Process 1 terminated with the following error:
Traceback (most recent call last):
  File "/opt/miniconda3/envs/temp_env/lib/python3.8/site-packages/torch/multiprocessing/spawn.py", line 69, in _wrap
    fn(i, *args)
  File "/opt/miniconda3/envs/temp_env/lib/python3.8/site-packages/accelerate/utils/launch.py", line 72, in __call__
    self.launcher(*args)
  File "/home/ubuntu/owlin/Context_Scoring/temp2.py", line 14, in main
    outputs = model.generate(input_ids=input_ids.to("cuda:0"))
  File "/opt/miniconda3/envs/temp_env/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1207, in __getattr__
    raise AttributeError("'{}' object has no attribute '{}'".format(
AttributeError: 'DistributedDataParallel' object has no attribute 'generate'
  File "/opt/miniconda3/envs/temp_env/lib/python3.8/site-packages/torch/multiprocessing/spawn.py", line 160, in join
    raise ProcessRaisedException(msg, error_index, failed_process.pid)
  File "/opt/miniconda3/envs/temp_env/lib/python3.8/site-packages/torch/multiprocessing/spawn.py", line 198, in start_processes
    while not context.join():
  File "/opt/miniconda3/envs/temp_env/lib/python3.8/site-packages/accelerate/launchers.py", line 127, in notebook_launcher
    start_processes(launcher, args=args, nprocs=num_processes, start_method="fork")
  File "/home/ubuntu/owlin/Context_Scoring/temp2.py", line 18, in <module>
    notebook_launcher(main, num_processes=2)
  File "/opt/miniconda3/envs/temp_env/lib/python3.8/runpy.py", line 87, in _run_code
    exec(code, run_globals)
  File "/opt/miniconda3/envs/temp_env/lib/python3.8/runpy.py", line 194, in _run_module_as_main (Current frame)
    return _run_code(code, main_globals, None,

Expected behavior

The expected behaviour for this code is to return an output tensor containing the tokens for the generated text, just as a plain T5ForConditionalGeneration model would. The error is raised because accelerator.prepare() returns a DistributedDataParallel object, which does not have a generate attribute. My question is whether there is a way to do generation in a distributed manner. Currently I have a workaround: unwrapping the model and running it on a single GPU. Ideally, however, I would like to do this in a multi-GPU setting as well.
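To illustrate why the attribute disappears: DistributedDataParallel wraps the model, forwards only `__call__`/`forward`, and keeps the original model on its `.module` attribute, so custom methods like `generate()` are not reachable on the wrapper. A minimal pure-Python sketch of this behaviour (these stub classes are illustrative stand-ins, not the real torch/transformers classes):

```python
class DistributedDataParallelStub:
    """Stand-in for torch's DDP wrapper: it stores the wrapped model on
    .module and forwards calls, but does not expose custom methods."""

    def __init__(self, module):
        self.module = module

    def __call__(self, *args, **kwargs):
        return self.module(*args, **kwargs)


class T5Stub:
    """Stand-in for a model that has a custom generate() method."""

    def __call__(self, x):
        return x

    def generate(self, input_ids):
        # pretend to decode: shift every token id by one
        return [t + 1 for t in input_ids]


wrapped = DistributedDataParallelStub(T5Stub())

try:
    wrapped.generate([1, 2, 3])  # fails, as in the stack trace above
except AttributeError:
    pass

# reaching through the wrapper restores access to generate()
print(wrapped.module.generate([1, 2, 3]))
```

Unwrapping (`accelerator.unwrap_model(model)` in the real code) is essentially a convenience for reaching that underlying model again.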

@pacman100
Contributor

Hello, generate already works in a distributed/multi-GPU setting; please refer to https://github.com/huggingface/transformers/blob/main/examples/pytorch/translation/run_translation_no_trainer.py as an example.

@gansem
Author

gansem commented Oct 13, 2022

Thanks for the quick response! I was able to implement the multi-gpu generation using the example.

@shivangsharma1

Hi @gansem,

I am looking for something similar: I have trained a T5 model and want to run multi-GPU inference on my own dataset (CSV). Could you please share a snippet for doing multi-GPU inference?

@cokuehuang

I am running into the same problem with multi-GPU inference for translation.
Hi @gansem, could you please share your implementation based on the Reproduction example above? Thanks a lot!

@gansem
Author

gansem commented Oct 20, 2022

Simply calling accelerator.unwrap_model() on the model, as shown here, fixed it for me.

A minimal working example:

from transformers import T5Tokenizer, T5ForConditionalGeneration
from accelerate import Accelerator, notebook_launcher
from torch.utils.data import DataLoader
from tqdm import tqdm
import torch


class TestDataset(torch.utils.data.Dataset):
    def __init__(self, data):
        self.inputs = data
        self.output = data

    def __getitem__(self, index):
        return {
            'inputs': self.inputs[index],
            'targets': self.output[index]
        }

    def __len__(self):
        return len(self.inputs)


def calculate_your_metrics(predictions, targets):
    pass


def main():
    text = """
    Amsterdam (/ˈæmstərdæm/ AM-stər-dam, UK also /ˌæmstərˈdæm/ AM-stər-DAM,[9][10] Dutch: [ˌɑmstərˈdɑm] (listen), lit. 
    The Dam on the River Amstel) is the capital and most populous city of the Netherlands; with a population of 907,976[11] within the city proper, 
    1,558,755 in the urban area[6] and 2,480,394 in the metropolitan area.[12] Found within the Dutch province of North Holland,[13][14] Amsterdam is 
    colloquially referred to as the "Venice of the North", due to the large number of canals which form a UNESCO World Heritage Site.[15]
    Amsterdam was founded at the Amstel, that was dammed to control flooding; the city's name derives from the Amstel dam.[16] 
    Originating as a small fishing village in the late 12th century, Amsterdam became one of the most important ports in the world during the Dutch Golden Age 
    of the 17th century, and became the leading centre for the finance and trade sectors.[17] In the 19th and 20th centuries, the city expanded and many new 
    neighborhoods and suburbs were planned and built. The 17th-century canals of Amsterdam and the 19–20th century Defence Line of Amsterdam are on the UNESCO 
    World Heritage List. Sloten, annexed in 1921 by the municipality of Amsterdam, is the oldest part of the city, dating to the 9th century. 
    """

    accelerator = Accelerator()

    tokenizer = T5Tokenizer.from_pretrained("t5-base", extra_ids=0)
    model = T5ForConditionalGeneration.from_pretrained("t5-base")

    input_ids = tokenizer([text] * 512, return_tensors="pt", padding=True).input_ids

    dataset = TestDataset(input_ids)
    dataloader = DataLoader(dataset, batch_size=64)

    model, dataloader = accelerator.prepare(model, dataloader)

    model.eval()
    for _ in tqdm(range(100)):
        for batch in dataloader:
            input_ids = batch["inputs"]
            targets = batch["targets"]
            predictions = accelerator.unwrap_model(model).generate(input_ids=input_ids)
            all_predictions, all_targets = accelerator.gather_for_metrics((predictions, targets))
            metrics = calculate_your_metrics(all_predictions, all_targets)

if __name__ == "__main__":
    notebook_launcher(main, num_processes=2)

Calling nvidia-smi while running returns:

+-----------------------------------------------------------------------------+
| NVIDIA-SMI 470.141.03   Driver Version: 470.141.03   CUDA Version: 11.4     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|===============================+======================+======================|
|   0  Tesla T4            Off  | 00000000:00:04.0 Off |                    0 |
| N/A   75C    P0    71W /  70W |   4624MiB / 15109MiB |     84%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
|   1  Tesla T4            Off  | 00000000:00:05.0 Off |                    0 |
| N/A   69C    P0    70W /  70W |   4626MiB / 15109MiB |     81%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+

I hope this helps, let me know if you have any more questions :)

@yhyu13

yhyu13 commented Apr 19, 2023

@gansem I noticed from your nvidia-smi output that persistence mode is turned off; persistence mode is crucial if you want NVLink to give an even greater multi-GPU performance boost.
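For reference, persistence mode can be enabled with the standard nvidia-smi CLI (requires root; the GPU index below is just an example):

```shell
# enable persistence mode on all GPUs
sudo nvidia-smi -pm 1

# or target a single GPU by index
sudo nvidia-smi -i 0 -pm 1
```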

https://forums.developer.nvidia.com/t/two-dual-geforce-rtx-3090s-and-nvlink-ubuntu-support-at-least-blender-has-support-for-nvlink/160561/14
