[BUG] IndexError / Runtime Error with torch.nn.TransformerEncoder #1795

Closed
floatshadow opened this issue Feb 26, 2022 · 2 comments
Labels: bug (Something isn't working)

floatshadow commented Feb 26, 2022

Describe the bug
Running a plain torch.nn.TransformerEncoder under ZeRO stage 3 with CPU offload fails: the forward pass raises "RuntimeError: expected scalar type Float but found Half", and manually casting the inputs to half precision instead leads to "IndexError: list index out of range" during the backward pass.

To Reproduce
I'm fairly new to DeepSpeed. I tried ZeRO stage 3 with offload on a plain PyTorch transformer and wrote a simple test script:

import os
import torch
import deepspeed
import argparse

import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from torch.utils.data.distributed import DistributedSampler
from torch.nn.parallel import DistributedDataParallel

from apex.optimizers import FusedAdam
from deepspeed.ops.adam import DeepSpeedCPUAdam
from deepspeed.runtime.utils import see_memory_usage

parser = argparse.ArgumentParser()
parser.add_argument("--local_rank", type=int, default=0)
parser.add_argument("--world_size", default=2)
parser.add_argument("--init_method", default='tcp://localhost:12306')
parser.add_argument("--master_addr", default="127.0.0.1")
parser.add_argument("--master_port", default="12346")
deepspeed.add_config_arguments(parser)
args = parser.parse_args()
args.rank = int(os.getenv("RANK", "0"))
args.init_method = f"tcp://{args.master_addr}:{args.master_port}"

print(f"initializing distributed...", flush=True)
torch.distributed.init_process_group(backend='nccl', 
                           init_method=args.init_method, 
                           rank=args.local_rank, 
                           world_size=args.world_size)
#deepspeed.init_distributed(distributed_port=args.master_port, init_method=args.init_method)
args.rank = torch.distributed.get_rank()
if torch.distributed.is_initialized():
  print(f"Worker Rank {args.rank} initialized.")


class BoringDataset(Dataset):
  def __init__(self) -> None:
    super().__init__()
    self.data = torch.randn(20, 128, 1024)

  def __len__(self):
    return len(self.data)

  def __getitem__(self, index):
    return self.data[index, :, :], self.data[index, :, :]


train_set = BoringDataset()
train_sampler = DistributedSampler(train_set)
train_loader = DataLoader(train_set, batch_size=2, sampler=train_sampler)


see_memory_usage(f"Before Building Model", force=True)
encoder_layer = nn.TransformerEncoderLayer(
    d_model=1024,
    nhead=8)
with deepspeed.zero.Init(remote_device="cpu", config_dict_or_path=args.deepspeed_config, enabled=True):
  transformer = nn.TransformerEncoder(
    encoder_layer=encoder_layer,
    num_layers=10)
#transformer.cuda(args.rank)
#transformer = DistributedDataParallel(transformer, device_ids=[args.rank])
see_memory_usage(f"After Building Model", force=True)


# optimizer = FusedAdam(transformer.parameters())
# optimizer = DeepSpeedCPUAdam(transformer.parameters())


see_memory_usage(f"Before DeepSpeed initialized", force=True)
model, optimizer, _, _ = deepspeed.initialize(args=args,
                                              model=transformer)
see_memory_usage(f"After DeepSpeed initialized", force=True)




loss_fn = nn.MSELoss()
for step, (batch, label) in enumerate(train_loader):
  print(f"On GPU {args.local_rank}, before step {step}")
  batch = batch.cuda(args.rank)
  label = label.cuda(args.rank)
  print(f"batch size {batch.shape}")
  output = model(batch)
  model.backward(loss_fn(output, label))
  model.step()
  print(f"On GPU {args.local_rank}, after step {step}")

But I hit the following traceback:

RuntimeError: expected scalar type Float but found Half

On GPU 0, before step 0
batch size torch.Size([2, 128, 1024])
Traceback (most recent call last):
  File "BoringDeepSpeed.py", line 82, in <module>
    output = model(batch)
  File "/chenjh02/miniconda3/envs/yuan/lib/python3.8/site-packages/torch/nn/modules/module.py", line 889, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/chenjh02/miniconda3/envs/yuan/lib/python3.8/site-packages/deepspeed/runtime/engine.py", line 1606, in forward
    loss = self.module(*inputs, **kwargs)
  File "/chenjh02/miniconda3/envs/yuan/lib/python3.8/site-packages/torch/nn/modules/module.py", line 889, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/chenjh02/miniconda3/envs/yuan/lib/python3.8/site-packages/torch/nn/modules/transformer.py", line 181, in forward
    output = mod(output, src_mask=mask, src_key_padding_mask=src_key_padding_mask)
  File "/chenjh02/miniconda3/envs/yuan/lib/python3.8/site-packages/torch/nn/modules/module.py", line 889, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/chenjh02/miniconda3/envs/yuan/lib/python3.8/site-packages/torch/nn/modules/transformer.py", line 293, in forward
    src2 = self.self_attn(src, src, src, attn_mask=src_mask,
  File "/chenjh02/miniconda3/envs/yuan/lib/python3.8/site-packages/torch/nn/modules/module.py", line 889, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/chenjh02/miniconda3/envs/yuan/lib/python3.8/site-packages/torch/nn/modules/activation.py", line 980, in forward
    return F.multi_head_attention_forward(
  File "/chenjh02/miniconda3/envs/yuan/lib/python3.8/site-packages/torch/nn/functional.py", line 4636, in multi_head_attention_forward
    q, k, v = linear(query, in_proj_weight, in_proj_bias).chunk(3, dim=-1)
  File "/chenjh02/miniconda3/envs/yuan/lib/python3.8/site-packages/torch/cuda/amp/autocast_mode.py", line 209, in decorate_fwd
    return fwd(*args, **kwargs)
  File "/chenjh02/miniconda3/envs/yuan/lib/python3.8/site-packages/deepspeed/runtime/zero/linear.py", line 60, in forward
    output = input.matmul(weight.t())
RuntimeError: expected scalar type Float but found Half

So I looked through previous issues, but nothing there helped. My next guess was that I had to cast the input tensors to half precision manually, but that only produced another traceback.
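
To be concrete, this is roughly the manual change I tried (a sketch; only the .half() calls differ from the loop above):

# Same loop as above, but casting inputs and labels to fp16 before the forward pass
for step, (batch, label) in enumerate(train_loader):
  batch = batch.cuda(args.rank).half()
  label = label.cuda(args.rank).half()
  output = model(batch)
  model.backward(loss_fn(output, label))
  model.step()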

IndexError: list index out of range

On GPU 0, before step 0
batch size torch.Size([2, 128, 1024])
Traceback (most recent call last):
  File "BoringDeepSpeed.py", line 83, in <module>
    model.backward(loss_fn(output, label))
  File "/chenjh02/miniconda3/envs/yuan/lib/python3.8/site-packages/deepspeed/runtime/engine.py", line 1740, in backward
    self.allreduce_gradients()
  File "/chenjh02/miniconda3/envs/yuan/lib/python3.8/site-packages/deepspeed/runtime/engine.py", line 1663, in allreduce_gradients
    self.optimizer.overlapping_partition_gradients_reduce_epilogue()
  File "/chenjh02/miniconda3/envs/yuan/lib/python3.8/site-packages/deepspeed/runtime/zero/stage3.py", line 1896, in overlapping_partition_gradients_reduce_epilogue
    self.independent_gradient_partition_epilogue()
  File "/chenjh02/miniconda3/envs/yuan/lib/python3.8/site-packages/deepspeed/runtime/zero/stage3.py", line 1798, in independent_gradient_partition_epilogue
    self.partition_previous_reduced_grads()
  File "/chenjh02/miniconda3/envs/yuan/lib/python3.8/site-packages/deepspeed/runtime/zero/stage3.py", line 2198, in partition_previous_reduced_grads
    fp32_grad_tensor = self.fp32_partitioned_groups_flat[
IndexError: list index out of range

Sorry, I could hardly find any hints on this in the DeepSpeed docs or elsewhere, so I'm posting an issue here.

DeepSpeed JSON config

{
    "train_micro_batch_size_per_gpu": 2 ,
    "gradient_accumulation_steps": 1,
    "steps_per_print": 1,
    "zero_optimization": {
      "stage": 3,
      "stage3_max_live_parameters": 1e9,
      "stage3_max_reuse_distance": 1e9,
      "stage3_param_persitence_threshold": 1e5,
      "stage3_prefetch_bucket_size": 5e7,
      "contiguous_gradients": true,
      "overlap_comm": true,
      "reduce_bucket_size": 90000000,
      "sub_group_size": 1e8,
      "offload_param": {
        "device": "cpu",
        "pin_memory": true
      },
      "offload_optimizer": {
        "device": "cpu",
        "pin_memory": true
      }
    },
    "optimizer": {
      "type": "Adam",
      "params": {
        "lr": 0.001,
        "betas": [
          0.9,
          0.95
        ]
      }
    },
    "fp16": {
      "enabled": true,
      "loss_scale": 0,
      "loss_scale_window": 1000,
      "hysteresis": 2,
      "min_loss_scale": 1
    },
    "gradient_clipping": 1.0,
    "wall_clock_breakdown": true,
    "zero_allow_untested_optimizer": false,
    "aio": {
      "block_size": 1048576,
      "queue_depth": 16,
      "single_submit": false,
      "overlap_events": true,
      "thread_count": 2
    }
}

ds_report output

--------------------------------------------------
DeepSpeed C++/CUDA extension op report
--------------------------------------------------
NOTE: Ops not installed will be just-in-time (JIT) compiled at
      runtime if needed. Op compatibility means that your system
      meet the required dependencies to JIT install the op.
--------------------------------------------------
JIT compiled ops requires ninja
ninja .................. [OKAY]
--------------------------------------------------
op name ................ installed .. compatible
--------------------------------------------------
cpu_adam ............... [YES] ...... [OKAY]
cpu_adagrad ............ [YES] ...... [OKAY]
fused_adam ............. [YES] ...... [OKAY]
fused_lamb ............. [YES] ...... [OKAY]
sparse_attn ............ [YES] ...... [OKAY]
transformer ............ [YES] ...... [OKAY]
stochastic_transformer . [YES] ...... [OKAY]
async_io ............... [YES] ...... [OKAY]
transformer_inference .. [YES] ...... [OKAY]
utils .................. [YES] ...... [OKAY]
quantizer .............. [YES] ...... [OKAY]
--------------------------------------------------
DeepSpeed general environment info:
torch install path ............... ['/chenjh02/miniconda3/envs/yuan/lib/python3.8/site-packages/torch']
torch version .................... 1.8.1+cu101
torch cuda version ............... 10.1
nvcc version ..................... 10.1
deepspeed install path ........... ['/chenjh02/miniconda3/envs/yuan/lib/python3.8/site-packages/deepspeed']
deepspeed info ................... 0.5.8, unknown, unknown
deepspeed wheel compiled w. ...... torch 1.8, cuda 10.1


System info:

  • OS: Ubuntu 18.04 x86_64
  • GPU count and types: 2x RTX 2080 Ti on a single node
  • Python version: 3.8

Launcher context
The deepspeed launcher; here's my script:

NNODES=1
GPUS_PER_NODE=2
MASTER_PORT=12346
MASTER_ADDR=localhost
DISTRIBUTED_ARGS="--num_nodes $NNODES --num_gpus $GPUS_PER_NODE --master_addr $MASTER_ADDR --master_port $MASTER_PORT"

deepspeed ${DISTRIBUTED_ARGS} BoringDeepSpeed.py --deepspeed --deepspeed_config ds_zero3.json


floatshadow added the bug label on Feb 26, 2022
mrwyattii self-assigned this on Mar 21, 2022

mrwyattii (Contributor) commented Mar 25, 2022

I was able to get your model running. You need to cast your dataset to .half() and also modify your deepspeed.initialize(...) call; see the example code below. Also, there is a bug in the latest DeepSpeed that will cause an error with ZeRO stage 3 for this model, so please use v0.5.10 until we can resolve it (pip install deepspeed==0.5.10):

import os

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset

import deepspeed


class BoringDataset(Dataset):
    def __init__(self) -> None:
        super().__init__()
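        # fp16 data so the inputs match the fp16-enabled DeepSpeed engine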
        self.data = torch.randn(200, 128, 1024).half()

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        return self.data[index, :, :], self.data[index, :, :]


train_set = BoringDataset()
train_loader = DataLoader(train_set, batch_size=32)

ds_config = {
    "train_micro_batch_size_per_gpu": 2,
    "gradient_accumulation_steps": 1,
    "steps_per_print": 1,
    "zero_optimization": {
        "stage": 3,
        "stage3_max_live_parameters": 1e9,
        "stage3_max_reuse_distance": 1e9,
        "stage3_param_persitence_threshold": 1e5,
        "stage3_prefetch_bucket_size": 5e7,
        "contiguous_gradients": True,
        "overlap_comm": True,
        "reduce_bucket_size": 90000000,
        "sub_group_size": 1e8,
        "offload_param": {"device": "cpu", "pin_memory": True},
        "offload_optimizer": {"device": "cpu", "pin_memory": True},
    },
    "optimizer": {"type": "Adam", "params": {"lr": 0.001, "betas": [0.9, 0.95]}},
    "fp16": {
        "enabled": True,
        "loss_scale": 0,
        "loss_scale_window": 1000,
        "hysteresis": 2,
        "min_loss_scale": 1,
    },
    "gradient_clipping": 1.0,
    "wall_clock_breakdown": True,
    "zero_allow_untested_optimizer": False,
    "aio": {
        "block_size": 1048576,
        "queue_depth": 16,
        "single_submit": False,
        "overlap_events": True,
        "thread_count": 2,
    },
}

encoder_layer = nn.TransformerEncoderLayer(d_model=1024, nhead=8)
transformer = nn.TransformerEncoder(encoder_layer=encoder_layer, num_layers=10)

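# Pass the config as a dict and hand model_parameters to DeepSpeed; the engine
# builds the (CPU-offloaded) optimizer from the "optimizer" section of ds_config.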
model, _, _, _ = deepspeed.initialize(
    config=ds_config, model=transformer, model_parameters=transformer.parameters()
)

loss_fn = nn.MSELoss()
model.train()
rank = int(os.getenv("RANK", "0"))
for step, (batch, label) in enumerate(train_loader):
    if rank == 0:
        print("step:", step)
    batch = batch.cuda(rank)
    label = label.cuda(rank)
    output = model(batch)
    model.backward(loss_fn(output, label))
    model.step()
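
Assuming you keep the file name from your original script, this should launch the same way as before, minus the --deepspeed/--deepspeed_config flags since the config is now passed in code, e.g. deepspeed --num_nodes 1 --num_gpus 2 BoringDeepSpeed.py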

mrwyattii (Contributor) commented:

Closing due to inactivity; please reopen if you are unable to run the model with the updated script.
