
Issue in forward(....) function of class ActorCriticPolicy while working on Custom Gym Environment. #2043

SachinVashisth opened this issue Nov 20, 2024 · 8 comments
Labels: custom gym env (Issue related to Custom Gym Env)

Comments


SachinVashisth commented Nov 20, 2024

🐛 Bug

I have created a custom environment as well as a custom ActorCritic policy. The custom environment has two methods, reset and step. I initialize a variable score to 0 in reset, and step then computes a non-zero value for it (say 0.95). I return this score in the observation from both reset and step, but when I print the obs variable inside the forward function of the ActorCriticPolicy class in common/policies.py, it always shows 0, i.e. the score initialized in reset, never the value computed in step.

Code example

import gymnasium as gym
import numpy as np
import torch  # needed for the torch.nn layers used in CustomPolicy below
from gymnasium import spaces
from gymnasium.wrappers import FlattenObservation
from stable_baselines3 import PPO
from stable_baselines3.common.env_checker import check_env
from stable_baselines3.common.policies import ActorCriticPolicy
from transformers import T5Tokenizer, T5ForConditionalGeneration

class CustomEnv(gym.Env):
    def __init__(self):
        super().__init__()
        self.score = 0
        self.observation_space = gym.spaces.Dict({
            'context': gym.spaces.Box(low = -np.inf, high = np.inf, shape = (2048, ), dtype = np.float32),
            'score': gym.spaces.Box(low = 0, high = 1, shape = (1,), dtype = np.float32),
        })
        self.action_space = gym.spaces.Discrete(5)

    def reset(self, seed=None, options=None):
        self.score = 0
        # _encode_text(...) encodes the context string with a T5 tokenizer and passes it
        # through the encoder part of the T5 model. The model and tokenizer are created in
        # __init__ but are omitted here for simplicity, as is _encode_text(...) itself.
        # Both "context" and "score" are converted to numpy arrays before being returned.
        return {"context": self._encode_text("Some Context"),
                "score": np.array([self.score], dtype=np.float32)}, {}

    def step(self, action):
        # TakeAction(...) uses the action to produce the new context and score. Omitted for simplicity.
        new_context, new_score = self.TakeAction(action)
        reward = new_score
        self.score = new_score
        terminated = new_score > 0.97
        truncated = False
        info = {}
        return {"context": self._encode_text(new_context),
                "score": np.array([self.score], dtype=np.float32)}, reward, terminated, truncated, info


class CustomPolicy(ActorCriticPolicy):
    def __init__(self, observation_space, action_space, lr_schedule, model_name="google/flan-t5-xl", **kwargs):
        super().__init__(observation_space, action_space, lr_schedule, **kwargs)

        # Load HuggingFace model and tokenizer
        self.tokenizer = T5Tokenizer.from_pretrained(model_name)
        self.model = T5ForConditionalGeneration.from_pretrained(model_name)

        # Policy head (action logits) and value head
        self.policy_head = torch.nn.Linear(self.model.config.d_model, self.action_space.n)
        self.value_head = torch.nn.Linear(self.model.config.d_model, 1)


env = CustomEnv()
env = FlattenObservation(env)
check_env(env)

model = PPO(CustomPolicy, env, verbose=1, policy_kwargs={"model_name": "google/flan-t5-xl"}).learn(1000)

As shown in the code above, I did not implement a custom forward method for my CustomPolicy class. When I print the obs variable inside the forward function defined in the ActorCriticPolicy class in common/policies.py, I always see the score as 0. The same holds for the context variable.

# This function is taken from the ActorCriticPolicy class in common/policies.py.
# It is shown here only to illustrate where I added the print statement.
def forward(self, obs: th.Tensor, deterministic: bool = False) -> Tuple[th.Tensor, th.Tensor, th.Tensor]:
    """
    Forward pass in all the networks (actor and critic)

    :param obs: Observation
    :param deterministic: Whether to sample or use deterministic actions
    :return: action, value and log probability of the action
    """
    print("obs in the implicitly defined forward function is: ", obs)
    # ... the rest of the original forward implementation is unchanged

My concern is that during rollout collection, when the time comes to update the policy, the score and context values in the observation are taken from the reset function rather than from the step function.

Observation returned by the reset method is:

{'context': array([ 0.01571305,  0.00165541, -0.0109045 , ...,  0.02560602, 0.06692536, -0.01847569], dtype=float32), 'score': array([0.], dtype=float32)}

Observation returned by the step method is:

{'context': array([ 0.12903318,  0.03283503, -0.01909198, ...,  0.05851421, 0.03815667, -0.00293452], dtype=float32), 'score': array([0.9716854], dtype=float32)}

But obs shown in the forward function of the ActorCriticPolicy class is:

tensor([[ 0.0157,  0.0016, -0.0109,  ...,  0.0669, -0.0184,  0.0000]], device='cuda:0')

As you can see, the forward function shows the context and score returned from the reset function, not from the step function.
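
One way to check what actually ends up in the rollout is to inspect PPO's rollout buffer after a short run. A minimal sketch, assuming the flattened env defined above and using the built-in "MlpPolicy" so it runs without the T5 model (the last element of the flattened observation is assumed to be the "score", as in the printed tensor above):

# Sketch: inspect the observations PPO stores during rollout collection.
check_model = PPO("MlpPolicy", env, verbose=0, n_steps=8, batch_size=8, n_epochs=1)
check_model.learn(8)

# Shape is (n_steps, n_envs, obs_dim) for the flattened Box observation.
obs_buffer = check_model.rollout_buffer.observations
print(obs_buffer.shape)

# Assuming "score" is the last element of the flattened observation,
# non-zero values here show that step() observations were stored.
print(obs_buffer[:, 0, -1])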

Relevant log output / Error message

No response

System Info

No response

araffin (Member) commented Nov 20, 2024

env = FlattenObservation(env)

SachinVashisth (Author) commented:

Yes, I have already done that:
env = FlattenObservation(env)

After the line env = CustomEnv(), I wrapped the env variable with FlattenObservation.

I did this because earlier I was getting the error:
'dict' object has no attribute 'flatten'

araffin (Member) commented Nov 20, 2024

Sorry, I read too fast; I thought the issue was the concatenation. By the way, SB3 does support dict observations.
Also, it seems that you want to define a custom features extractor rather than a full custom policy (see the docs); a minimal sketch follows at the end of this comment.

I'll have a look later, but there is no reason the transition from the step method would not be taken into account.
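
For illustration, such a features extractor is passed through policy_kwargs instead of subclassing ActorCriticPolicy. A minimal sketch, assuming the flattened Box observation used above (the network sizes are arbitrary):

import numpy as np
import torch as th
import torch.nn as nn
from gymnasium import spaces
from stable_baselines3 import PPO
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor


class ContextExtractor(BaseFeaturesExtractor):
    # Illustrative sketch: maps the flattened observation to a small feature vector.
    def __init__(self, observation_space: spaces.Box, features_dim: int = 64):
        super().__init__(observation_space, features_dim)
        n_input = int(np.prod(observation_space.shape))
        self.net = nn.Sequential(nn.Linear(n_input, features_dim), nn.ReLU())

    def forward(self, observations: th.Tensor) -> th.Tensor:
        return self.net(observations)


# The default ActorCriticPolicy builds its action and value heads on top of these features.
model = PPO(
    "MlpPolicy",
    env,  # the flattened environment defined earlier
    policy_kwargs={
        "features_extractor_class": ContextExtractor,
        "features_extractor_kwargs": {"features_dim": 64},
    },
    verbose=1,
)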

araffin (Member) commented Nov 21, 2024

Using this minimal code, I don't see any problem:

import gymnasium as gym
import numpy as np
import torch
import torch as th
from gymnasium import spaces
from gymnasium.wrappers import FlattenObservation
from stable_baselines3 import PPO
from stable_baselines3.common.env_checker import check_env
from stable_baselines3.common.policies import ActorCriticPolicy


class CustomEnv(gym.Env):
    def __init__(self):
        super().__init__()
        self.observation_space = spaces.Dict(
            {
                "context": spaces.Box(
                    low=-np.inf, high=np.inf, shape=(2,), dtype=np.float32
                ),
                "score": spaces.Box(low=0, high=1, shape=(1,), dtype=np.float32),
            }
        )
        self.action_space = gym.spaces.Discrete(5)

    def reset(self, seed=None, options=None):
        return {
            "context": np.array([1.0, 2.0], dtype=np.float32),
            "score": np.array([0.0], dtype=np.float32),
        }, {}

    def step(self, action):
        reward = 0.0
        terminated = False
        truncated = False
        info = {}
        return (
            {
                "context": np.array([1.0, 2.0], dtype=np.float32),
                "score": np.array([0.98], dtype=np.float32),
            },
            reward,
            terminated,
            truncated,
            info,
        )


class CustomPolicy(ActorCriticPolicy):
    def __init__(self, observation_space, action_space, lr_schedule, **kwargs):
        super().__init__(observation_space, action_space, lr_schedule, **kwargs)

        # Policy head (action logits) and value head
        self.policy_head = torch.nn.Linear(10, self.action_space.n)
        self.value_head = torch.nn.Linear(10, 1)

    def forward(
        self, obs: th.Tensor, deterministic: bool = False
    ) -> tuple[th.Tensor, th.Tensor, th.Tensor]:
        """
        Forward pass in all the networks (actor and critic)

        :param obs: Observation
        :param deterministic: Whether to sample or use deterministic actions
        :return: action, value and log probability of the action
        """

        print("obs in the implicity defined forward function is: ", obs)
        return super().forward(obs, deterministic)


env = CustomEnv()
env = FlattenObservation(env)
check_env(env)

model = PPO(
    CustomPolicy,
    env,
    verbose=1,
    n_steps=8,
    batch_size=8,
    n_epochs=1,
).learn(10)

SachinVashisth (Author) commented:

Hi, thanks for the reply.

I made a small change to the code you provided and reproduced the same problem with the forward function.

The change: I added a variable self.Threshold = 0.95 in the __init__ function of CustomEnv, and the terminated condition in the step function is now terminated = 0.98 >= self.Threshold.
Since the score in the step function is 0.98, this boolean becomes True and the environment resets. As a result, the obs variable in the forward function takes its score and context values from the reset function.

import gymnasium as gym
import numpy as np
import json
import torch
import torch as th
from gymnasium import spaces
from gymnasium.wrappers import FlattenObservation
from stable_baselines3 import PPO
from stable_baselines3.common.env_checker import check_env
from stable_baselines3.common.policies import ActorCriticPolicy


class CustomEnv(gym.Env):
    def __init__(self):
        super().__init__()
        
        self.Threshold = 0.95
        self.observation_space = spaces.Dict(
            {
                "context": spaces.Box(
                    low=-np.inf, high=np.inf, shape=(2,), dtype=np.float32
                ),
                "score": spaces.Box(low=0, high=1, shape=(1,), dtype=np.float32),
            }
        )
        self.action_space = gym.spaces.Discrete(5)

    def reset(self, seed=None, options=None):
        return {
            "context": np.array([1.0, 2.0], dtype=np.float32),
            "score": np.array([0.0], dtype=np.float32),
        }, {}

    def step(self, action):
        reward = 0.0
        terminated = 0.98 >= self.Threshold
        truncated = False
        info = {}
        return {
                "context": np.array([1.0, 2.0], dtype=np.float32),
                "score": np.array([0.98], dtype=np.float32),
            }, reward, terminated, truncated, info

class CustomPolicy(ActorCriticPolicy):
    def __init__(self, observation_space, action_space, lr_schedule, **kwargs):
        super().__init__(observation_space, action_space, lr_schedule, **kwargs)

        # Policy head (action logits) and value head
        self.policy_head = torch.nn.Linear(10, self.action_space.n)
        self.value_head = torch.nn.Linear(10, 1)

env = CustomEnv()
env = FlattenObservation(env)
check_env(env)

model = PPO(
    CustomPolicy,
    env,
    verbose=1,
    n_steps=5
).learn(10)

In terms of execution:
When the terminated condition was False earlier, reset was called only once and after that only step was called, like this: reset --> forward --> step --> forward --> step --> forward --> and so on.

But now the order is: reset --> forward --> step --> reset --> forward --> step --> reset --> forward --> and so on.

In my actual code, I pass training data to CustomEnv. Every time the score in the step function exceeds the threshold, I want to move to the next training instance, where the score and the context are re-initialized in the reset function. In the step function, the score and the context are then computed from the logic applied to the current training instance (a sketch of this pattern is shown below).
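
A minimal sketch of that reset pattern (the names training_data, instance_idx and current_instance are hypothetical and only for illustration; _encode_text is the method mentioned earlier, and the imports are the same as in the snippets above):

class DataDrivenEnv(gym.Env):
    def __init__(self, training_data):
        super().__init__()
        self.training_data = training_data
        self.instance_idx = -1
        self.observation_space = spaces.Dict({
            "context": spaces.Box(low=-np.inf, high=np.inf, shape=(2048,), dtype=np.float32),
            "score": spaces.Box(low=0, high=1, shape=(1,), dtype=np.float32),
        })
        self.action_space = spaces.Discrete(5)

    def reset(self, seed=None, options=None):
        super().reset(seed=seed)
        # Move to the next training instance on every reset,
        # i.e. whenever the previous episode terminated.
        self.instance_idx = (self.instance_idx + 1) % len(self.training_data)
        self.current_instance = self.training_data[self.instance_idx]
        self.score = 0.0
        return {
            "context": self._encode_text(self.current_instance),
            "score": np.array([self.score], dtype=np.float32),
        }, {}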

araffin (Member) commented Nov 21, 2024

Due to this, the obs variable in the forward function takes score and context value from the reset function.

This is correct; the terminal observation is only used to predict the value, not the action:

# Handle timeout by bootstrapping with value function
# see GitHub issue #633
for idx, done in enumerate(dones):
    if (
        done
        and infos[idx].get("terminal_observation") is not None
        and infos[idx].get("TimeLimit.truncated", False)
    ):
        terminal_obs = self.policy.obs_to_tensor(infos[idx]["terminal_observation"])[0]
        with th.no_grad():
            terminal_value = self.policy.predict_values(terminal_obs)[0]  # type: ignore[arg-type]
        rewards[idx] += self.gamma * terminal_value

(see the links in our docs and the related issue that explain the distinction between termination and truncation)

Predicting an action here would not make sense since the episode is over (so the action will not be executed)

Note: the env variable in SB3 is a VecEnv, which resets automatically (see the docs).
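
To make the automatic reset concrete, here is a minimal sketch (an illustration, using DummyVecEnv directly with the toy CustomEnv from the snippet above): when done is True, step() already returns the first observation of the new episode, and the final observation of the finished episode is kept in the info dict.

from gymnasium.wrappers import FlattenObservation
from stable_baselines3.common.vec_env import DummyVecEnv

# Wrap the toy env in a VecEnv, as SB3 does internally.
vec_env = DummyVecEnv([lambda: FlattenObservation(CustomEnv())])
obs = vec_env.reset()

obs, rewards, dones, infos = vec_env.step([vec_env.action_space.sample()])

if dones[0]:
    # `obs` is already the first observation of the new episode (score == 0);
    # the last observation of the finished episode is stored in the info dict.
    terminal_obs = infos[0]["terminal_observation"]
    print("terminal observation:", terminal_obs)
    print("observation after auto-reset:", obs)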

SachinVashisth (Author) commented:

Thanks for the reply and the code.

I looked at the code and the documentation.
If I understood it correctly, it says that:

  1. For terminated = True, the environment resets automatically. That is why I see the observations returned by reset in the forward function of the ActorCriticPolicy class: the episode has ended, so the terminal observation (the one just before the episode ends) cannot be used to compute the next action, and the agent needs a new starting point, which reset provides.
  2. This also means that if the score exceeds the threshold in the very first step (thereby setting terminated = True), the agent still updates the policy (after n_steps are collected) using the observations obtained from these single steps, but learning will be affected in the sense that the agent will struggle to generalize (especially when episodes frequently terminate after a single step).
  3. For the case where the score exceeds the threshold in the first step, I also thought of deferring terminated = True for one extra step, so that the observations from the step function are included in the rollout buffer and the agent gets more data for policy learning.
    Like this:
def step(self, action):
    reward = 0.0
    terminated = bool(0.98 >= self.Threshold)
    truncated = False
    info = {}
    obs = {
        "context": np.array([1.0, 2.0], dtype=np.float32),
        "score": np.array([0.98], dtype=np.float32),
    }
    if terminated:
        # Return terminated=False here but store the final observation in info.
        info = {"terminal_observation": obs}
        return obs, reward, False, truncated, info

    return obs, reward, terminated, truncated, info

I guess this will now cover the case when the episode terminates in the first step.

araffin (Member) commented Nov 27, 2024

terminal observations (the one just before the episode ends) cannot be used to compute the next action

yes

using observations obtained from this single step, but the learning will be affected in the sense that agent will struggle to generalize

I'm not sure about that one. But it sounds about right. If your episodes are single steps, then it will use single steps to train (which sounds reasonable).

giving the agent more data for policy learning.

It seems that you are changing the problem you want to solve. Finishing after one step doesn't seem bad if it reaches the objective you are asking for.
Also, if the episode terminates after only one step, it looks more like a bandit problem.
