
learn.py: Incompatibility with gym env despite having stable_baselines3 version 2.x #221

Closed
KevinHan1209 opened this issue Jun 23, 2024 · 0 comments

@KevinHan1209

Hi,

I am running the learn.py example that comes with the repo. The only change I have made is converting it into a notebook (.ipynb) so I can render/evaluate the already trained model, but the same error also occurs in the original script.

Training seems to work fine and the model is saved, but I get an error at the call to predict(). There appears to be an inconsistency between what Gym and SB3 return from reset(), so predict() receives an invalid argument. I tried this bypass, and while it fixed the issue at predict(), it led to another error later in stable_baselines3's evaluation.py, so I still cannot evaluate my model. That second error is of the same nature: a problem with the dimensions of what reset() returns.
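A minimal sketch of the usual fix for the predict() error, assuming the Gymnasium API (`reset()` returns `(obs, info)`, `step()` returns a 5-tuple): unpack the reset result before calling `predict()`. With the real objects the equivalent line would be `obs, info = test_env.reset(seed=42, options={})`. `ToyGymnasiumEnv`, `toy_predict` and `unpack_reset` are hypothetical stand-ins, not part of gym-pybullet-drones or SB3, so the example runs without either library.

```python
from typing import Any, Dict, List, Optional, Tuple


class ToyGymnasiumEnv:
    """Hypothetical stand-in that mimics Gymnasium's reset()/step() signatures."""

    def __init__(self, max_steps: int = 3) -> None:
        self.max_steps = max_steps
        self.t = 0

    def reset(self, seed: Optional[int] = None,
              options: Optional[dict] = None) -> Tuple[List[float], Dict]:
        self.t = 0
        return [0.0, 0.0, 1.0], {}  # Gymnasium API: (obs, info)

    def step(self, action: List[float]) -> Tuple[List[float], float, bool, bool, Dict]:
        self.t += 1
        obs = [0.0, 0.0, 1.0 + 0.1 * self.t]
        truncated = self.t >= self.max_steps
        # Gymnasium API: (obs, reward, terminated, truncated, info)
        return obs, 0.0, False, truncated, {}


def toy_predict(observation: Any, deterministic: bool = True) -> Tuple[List[float], None]:
    """Hypothetical stand-in for model.predict(); rejects tuples the way SB3 does."""
    if isinstance(observation, tuple):
        raise ValueError("You have passed a tuple to the predict() function")
    return [0.0, 0.0, 0.0, 0.0], None  # (action, state); deterministic is ignored here


def unpack_reset(result: Any) -> Any:
    """Hypothetical helper: keep only obs from a Gymnasium-style (obs, info) result."""
    if isinstance(result, tuple) and len(result) == 2 and isinstance(result[1], dict):
        return result[0]
    return result  # already a bare observation, e.g. from an SB3 VecEnv


def rollout(env: ToyGymnasiumEnv) -> int:
    """Run one episode and return the number of steps taken."""
    obs, info = env.reset(seed=42, options={})  # unpack; never pass the tuple on
    steps = 0
    while True:
        action, _states = toy_predict(obs, deterministic=True)
        obs, reward, terminated, truncated, info = env.step(action)
        steps += 1
        if terminated or truncated:
            return steps
```

The broken pattern, `toy_predict(env.reset())`, raises the same kind of ValueError as in the traceback, because the whole `(obs, info)` tuple is passed on as the observation.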

I made sure I have Gymnasium support installed as well, and I have tried several versions of stable_baselines3 2.x. I'm at a loss here. Am I supposed to edit the SB3 source code?

Error message:

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
Cell In[10], line 1
----> 1 render()

Cell In[9], line 97
     95 start = time.time()
     96 for i in range((test_env.EPISODE_LEN_SEC+2)*test_env.CTRL_FREQ):
---> 97     action, _states = model.predict(obs,
     98                                     deterministic=True
     99                                     )
    100     obs, reward, terminated, truncated, info = test_env.step(action)
    101     obs2 = obs.squeeze()

File ~/opt/anaconda3/envs/drones/lib/python3.10/site-packages/stable_baselines3/common/base_class.py:553, in BaseAlgorithm.predict(self, observation, state, episode_start, deterministic)
    533 def predict(
    534     self,
    535     observation: Union[np.ndarray, Dict[str, np.ndarray]],
   (...)
    538     deterministic: bool = False,
    539 ) -> Tuple[np.ndarray, Optional[Tuple[np.ndarray, ...]]]:
    540     """
    541     Get the policy action from an observation (and optional hidden state).
    542     Includes sugar-coating to handle different observations (e.g. normalizing images).
   (...)
...
    361     )
    363 obs_tensor, vectorized_env = self.obs_to_tensor(observation)
    365 with th.no_grad():

ValueError: You have passed a tuple to the predict() function instead of a Numpy array or a Dict. You are probably mixing Gym API with SB3 VecEnv API: `obs, info = env.reset()` (Gym) vs `obs = vec_env.reset()` (SB3 VecEnv). See related issue https://github.com/DLR-RM/stable-baselines3/issues/1694 and documentation for more information: https://stable-baselines3.readthedocs.io/en/master/guide/vec_envs.html#vecenv-api-vs-gym-api
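The error message contrasts the two APIs. A minimal sketch of the VecEnv side, assuming toy stand-ins: `ToyEnv` follows the Gymnasium API and the hypothetical `ToyVecEnv` imitates SB3 VecEnv conventions (batched `reset()` returning only observations, a 4-tuple `step()`, automatic reset with the last observation stored under `terminal_observation`). It is not the real `DummyVecEnv`. This matters for evaluation because SB3's `evaluate_policy` wraps a plain Gymnasium env in a `DummyVecEnv`, so the VecEnv conventions apply inside it.

```python
from typing import Any, Dict, List, Tuple


class ToyEnv:
    """Hypothetical single env following the Gymnasium reset()/step() API."""

    def __init__(self, max_steps: int = 2) -> None:
        self.max_steps = max_steps
        self.t = 0

    def reset(self, seed=None, options=None) -> Tuple[List[float], Dict[str, Any]]:
        self.t = 0
        return [0.0, 1.0], {}  # (obs, info)

    def step(self, action: Any) -> Tuple[List[float], float, bool, bool, Dict[str, Any]]:
        self.t += 1
        truncated = self.t >= self.max_steps
        return [float(self.t), 1.0], 1.0, False, truncated, {}


class ToyVecEnv:
    """Hypothetical stand-in mimicking SB3 VecEnv conventions (not the real DummyVecEnv)."""

    def __init__(self, envs: List[ToyEnv]) -> None:
        self.envs = envs

    def reset(self) -> List[List[float]]:
        # VecEnv API: a batch of observations, one row per env; reset infos are not returned
        return [env.reset()[0] for env in self.envs]

    def step(self, actions: List[Any]):
        obs, rewards, dones, infos = [], [], [], []
        for env, action in zip(self.envs, actions):
            o, r, terminated, truncated, info = env.step(action)
            done = terminated or truncated
            if done:
                # Keep the last observation in info, then auto-reset the finished env
                info = dict(info, terminal_observation=o)
                o, _ = env.reset()
            obs.append(o)
            rewards.append(r)
            dones.append(done)
            infos.append(info)
        # VecEnv API: 4-tuple with a single `dones` flag per env
        return obs, rewards, dones, infos
```

Because the batch carries a leading env dimension, a per-env observation of shape `(obs_dim,)` becomes `(n_envs, obs_dim)`. Running a loop written for one API against an object of the other is a plausible source of the shape errors mentioned in the issue, although the traceback alone does not show which shape `evaluation.py` received.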

KevinHan1209 changed the title from "Imcompatibility with gym env despite having stable_baselines3 version 2.x" to "learn.py: Incompatibility with gym env despite having stable_baselines3 version 2.x" on Jun 23, 2024