
[Bug] GPU memory explodes when using Conv2D layers in Dict Observations FeatureExtractor #863

Closed
ThomasRochefortB opened this issue Apr 14, 2022 · 10 comments
Labels
bug (Something isn't working) · custom gym env (Issue related to Custom Gym Env)

Comments

@ThomasRochefortB

🐛 Bug

I am having issues in SB3 with a custom features extractor for a Dict observation space that makes my GPU memory explode. The observation space is composed of a single-channel image (1, 51, 101) and three vectors of dimensions (9,), (1,) and (1,). The problem is that when I add Conv2D layers for the image in the features extractor, the GPU memory explodes and I get an OOM error. Replacing the Conv2D layers with a simple Flatten() layer works like a charm: with the Conv2D layers for the image, the GPU memory of my 3080 caps at over 9 GB, while using the Flatten() layer instead runs with only 3.2 GB of used memory. How can a couple of Conv2D layers add over 6 GB of GPU memory usage?

I have done the math for the space taken by one batch of observations in float32: ((101*51) + 9 + 1 + 1) * 4 bytes * 10000 = 0.2 GB of memory. The value and policy networks with the features extractor should only be around 1.3M parameters, which should more than fit in the 10 GB of memory on the 3080.
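As a quick sanity check on that estimate, here is the same calculation as a short sketch (the 10000 is the PPO batch_size used in the script below):

obs_floats = (101 * 51) + 9 + 1 + 1    # floats per observation: image + three vectors
batch_bytes = obs_floats * 4 * 10000   # 4 bytes per float32, batch_size = 10000
print(batch_bytes / 1e9)               # ~0.21 GB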

To Reproduce

import gym
import numpy as np
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.vec_env import VecNormalize,VecTransposeImage, SubprocVecEnv, VecFrameStack, VecCheckNan
from stable_baselines3 import PPO,DQN,SAC,A2C,TD3
import torch
from torch import nn
import torch as th
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor
torch.set_default_tensor_type(torch.FloatTensor)

class CustomEnv(gym.Env):
    
    def __init__(self):
        self.action_space = gym.spaces.Box(low=-1,high=1,shape=(7,), dtype=np.float32)
        self.observation_space = gym.spaces.Dict(
                                    spaces={
                                        "out_1": gym.spaces.Box(-1, 1, (9,),dtype=np.float32),
                                        "out_2":gym.spaces.Box(0,1,(1,),dtype=np.float32),
                                        "image": gym.spaces.Box(0, 1, (1,51,101),dtype=np.float32),
                                        "out_3":gym.spaces.Box(0,1,(1,),dtype=np.float32)
                                        }
                                    )    
    def reset(self):

        out_dict={"out_1":np.float32(np.zeros((9,))),
                 "out_2":np.float32(np.zeros((1,))),
                 "image":np.float32(np.zeros((1,51,101))),
                 "out_3":np.float32(np.zeros((1,)))}

        return out_dict

    def step(self, action):

        out_dict={"out_1":np.float32(np.zeros((9,))),
                 "out_2":np.float32(np.zeros((1,))),
                 "image":np.float32(np.zeros((1,51,101))),
                 "out_3":np.float32(np.zeros((1,)))}

        reward=0
        done=False
        info={}
        return out_dict, reward, done, info


class CustomCombinedExtractor(BaseFeaturesExtractor):
    def __init__(self, observation_space: gym.spaces.Dict):
        # We do not know features-dim here before going over all the items,
        # so put something dummy for now. PyTorch requires calling
        # nn.Module.__init__ before adding modules
        super(CustomCombinedExtractor, self).__init__(observation_space, features_dim=1)

        extractors = {}
        total_concat_size = 0
        # We need to know size of the output of this extractor,
        # so go over all the spaces and compute output feature sizes
        for key, subspace in observation_space.spaces.items():
            if key == "image":
                n_input_channels = subspace.shape[0]
                extractors[key] = nn.Sequential(nn.Conv2d(n_input_channels, 16, kernel_size=3, stride=1, padding=1),
                                                nn.ReLU(),
                                                nn.MaxPool2d(2),
                                                nn.Conv2d(16, 16, kernel_size=3, stride=1, padding=1),
                                                nn.ReLU(),
                                                nn.MaxPool2d(2),
                                                
                                                nn.Flatten(),
                                                )
                
                # The two MaxPool2d(2) layers halve the spatial size twice, hence the //4.
                # Uncomment if using the two Conv2D layers:
                total_concat_size += ((subspace.shape[1] // 4) * (subspace.shape[2] // 4) * 16)
                
                # Uncomment if only using flatten
                #total_concat_size += ((subspace.shape[1])*(subspace.shape[2]))

            elif key in ("out_1", "out_2", "out_3"):
                # Run through a simple MLP
                extractors[key] = nn.Sequential(nn.Linear(subspace.shape[0], 16),
                                                nn.ReLU(),
                                                nn.Linear(16,16),
                                                nn.ReLU(),
                                                nn.Flatten()
                                               )
                total_concat_size += (16)
           
        self.extractors = nn.ModuleDict(extractors)
        # Update the features dim manually
        self._features_dim = total_concat_size
        print(total_concat_size)
        
    def forward(self, observations) -> th.Tensor:
        encoded_tensor_list = []
        # self.extractors contain nn.Modules that do all the processing.
        for key, extractor in self.extractors.items():
            encoded_tensor_list.append(extractor(observations[key]))
        # Return a (B, self._features_dim) PyTorch tensor, where B is batch dimension.
        return th.cat(encoded_tensor_list, dim=1)
    
policy_kwargs = dict(
    features_extractor_class=CustomCombinedExtractor,
    activation_fn=th.nn.ReLU,
    normalize_images=False,
    net_arch=[dict(pi=[256,256,256], vf=[256,256,256])],
)

# Create the vectorized environment
env = CustomEnv()
num_cpu = 50  # Number of processes to use
env = make_vec_env(lambda: env, n_envs=num_cpu, vec_env_cls=SubprocVecEnv)
env = VecNormalize(env, gamma=1.0)

model = PPO("MultiInputPolicy",
            env,
            policy_kwargs=policy_kwargs,
            seed=42,
            n_steps=1200,
            batch_size=10000,
            verbose=1,
            device='cuda')

model.learn(100000)

Expected behavior

I expect the two Conv2D layers to add only a modest number of parameters and a correspondingly modest amount of GPU memory.
System Info

  • stable-baselines3==1.5.0
  • torch==1.11.0+cu113
  • gym==0.21.0
  • python==3.7.10

The GPU is an RTX 3080.
I installed everything using pip in a separate conda environment.

Additional context

I am surprised that the network runs fine with just an nn.Flatten() layer for the image features extractor and that adding the two Conv2D layers adds up to 6 GB of memory usage.
I am using a lot of vectorized environments because the real use case involves a CPU-intensive env.

Checklist

  • [x] I have checked that there is no similar issue in the repo (required)
  • [x] I have read the documentation (required)
  • [x] I have provided a minimal working example to reproduce the bug (required)
@ThomasRochefortB ThomasRochefortB added the bug (Something isn't working) label Apr 14, 2022
@araffin araffin added the custom gym env (Issue related to Custom Gym Env) label Apr 14, 2022
@araffin
Member

araffin commented Apr 14, 2022

Hello,
I will try to have a look, probably related to #834, see my comment for potential solutions: #834 (comment)

@araffin
Member

araffin commented Apr 14, 2022

have done the math for the space taken by the batch of observations in Float32: ((101*51) + (9) + (1) +(1)) * 4bytes * 10000 = 0.2 GB of memory.
The value and policy network with the feature extractor should only be around 1.3M parameters which should more than fit in my 10GB of memory on the 3080.

Does the error happen when you collect data or when the training is done (during the gradient update)?
During the gradient update, the GPU will need to store all the gradients for the backward pass.
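As a rough back-of-the-envelope sketch (my own numbers, assuming the batch_size of 10000 and the (1, 51, 101) image from the repro script), the output of the first Conv2d alone that autograd keeps around for the backward pass is already sizeable:

batch, channels, h, w = 10000, 16, 51, 101
conv1_out_bytes = batch * channels * h * w * 4   # float32 output of the first Conv2d
print(conv1_out_bytes / 1e9)                     # ~3.3 GB, before ReLU, pooling and the second Conv2d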

Why do you use VecNormalize with the image? (Images are normalized by default when defined properly, see the env checker.)
"UserWarning: It seems that your observation is an image but the dtype of your observation_space is not np.uint8. If your observation is not an image, we recommend you to flatten the observation to have only a 1D vector"

@ThomasRochefortB
Author

@araffin The error happens at the start of training.

I am using VecNormalize because of the general recommendation in the documentation. I guess I could only normalize the rewards and not the observations?

@ThomasRochefortB
Author

@araffin What is curious is the drastic difference in memory usage between using only the Flatten() layer and using the two Conv2D layers.

@araffin
Member

araffin commented Apr 14, 2022

I guess I could only normalize the rewards and not the observations?

Yes, and you can exclude specific observation keys too (as recommended here).

Could you please give the output of sb3.get_system_info()?

@ThomasRochefortB
Author

OS: Linux-5.14.18-100.fc33.x86_64-x86_64-with-fedora-33-Thirty_Three #1 SMP Fri Nov 12 17:38:44 UTC 2021
Python: 3.7.10
Stable-Baselines3: 1.5.0
PyTorch: 1.11.0+cu113
GPU Enabled: True
Numpy: 1.20.3
Gym: 0.21.0

@araffin
Member

araffin commented Apr 14, 2022

have done the math for the space taken by the batch of observations in Float32: ((101*51) + (9) + (1) +(1)) * 4bytes * 10000 = 0.2 GB of memory.

Your math is only taking into account the input, no?
It does not take into account the intermediate tensors (for each layer of the CNN), nor the tensors used for computing gradients, no?

I think you should be able to provide minimal code to reproduce that issue independently of SB3 (because your issue happens at train time).
In that case, that would be a PyTorch issue, not an SB3 one...
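Something along these lines should reproduce it without SB3 (a rough sketch, not a definitive repro: it just pushes one PPO-sized batch of 10000 images through the same two-Conv2d extractor and a dummy backward pass, and will likely OOM on a 10 GB card):

import torch
from torch import nn

device = "cuda"
extractor = nn.Sequential(
    nn.Conv2d(1, 16, kernel_size=3, stride=1, padding=1),
    nn.ReLU(),
    nn.MaxPool2d(2),
    nn.Conv2d(16, 16, kernel_size=3, stride=1, padding=1),
    nn.ReLU(),
    nn.MaxPool2d(2),
    nn.Flatten(),
).to(device)

images = torch.zeros(10000, 1, 51, 101, device=device)  # one PPO-sized batch of images
out = extractor(images)
out.sum().backward()  # dummy loss, forces activations to be kept for the backward pass
print(torch.cuda.max_memory_allocated() / 1e9, "GB peak")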

@ThomasRochefortB
Author

@araffin I guess you're right! Although it's strange: the difference between the features extractor with only the nn.Flatten() layer and the one with the two Conv2D layers should only be a few thousand parameters, which should not have that much of an impact on memory usage. We're talking about a 6 GB difference in memory usage for two Conv2D layers of 16 3x3 kernels each.
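For what it's worth, a quick parameter count of those two Conv2D layers (my own sketch):

from torch import nn

# The two Conv2d layers from the extractor only add ~2.5k parameters in total
conv_stack = nn.Sequential(
    nn.Conv2d(1, 16, kernel_size=3, stride=1, padding=1),
    nn.Conv2d(16, 16, kernel_size=3, stride=1, padding=1),
)
print(sum(p.numel() for p in conv_stack.parameters()))  # 160 + 2320 = 2480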

@ThomasRochefortB
Author

ThomasRochefortB commented Apr 14, 2022

Replacing the two nn.Conv2d layers with two nn.Linear layers of 1024 units works like a charm and has no significant impact on memory usage. Those two dense layers should have more parameters than the two Conv2d layers... which again points to something weird with the nn.Conv2d memory usage.
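A plausible explanation (my own rough numbers, assuming the batch_size of 10000 from above): the dense layers have far more parameters but much smaller outputs, and it is the per-batch activations saved for the backward pass, not the parameters, that dominate the memory here:

batch = 10000
linear_act_bytes = batch * 1024 * 4          # output of one nn.Linear(..., 1024): ~0.04 GB
conv_act_bytes = batch * 16 * 51 * 101 * 4   # output of the first Conv2d: ~3.3 GB
print(linear_act_bytes / 1e9, conv_act_bytes / 1e9)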

@EloyAnguiano

EloyAnguiano commented Jul 27, 2023

I am still having some issues with this at #1630. I am not using any Conv2D layer; in fact, I am using a GraphTransformer, so I cannot reduce the size of the matrices or the graph, nor flatten them. Has anyone addressed the real issue with this?
