From 48edf3211275181716405509534d7697a7a9ed3c Mon Sep 17 00:00:00 2001
From: huangshiyu
Date: Tue, 12 Dec 2023 14:06:26 +0800
Subject: [PATCH 1/2] update readme

---
 Project.md | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/Project.md b/Project.md
index d7c455f5..38a8d1c7 100644
--- a/Project.md
+++ b/Project.md
@@ -18,7 +18,7 @@ However, in many practical applications, it is important to develop reasonable a
 In this paper, we propose an on-policy framework for discovering multiple strategies for the same task.
 Experimental results show that our method efficiently finds diverse strategies in a wide variety of reinforcement learning tasks.
 
-- Paper: [DGPO: Discovering Multiple Strategies with Diversity-Guided Policy Optimization](https://arxiv.org/abs/2207.05631)(AAMAS Extended Abstract 2023)
-- Authors: Wenze Chen, Shiyu Huang, Yuan Chiang, Ting Chen, Jun Zhu
+- Paper: [DGPO: Discovering Multiple Strategies with Diversity-Guided Policy Optimization](https://arxiv.org/abs/2207.05631)(AAAI 2024)
+- Authors: Wenze Chen, Shiyu Huang, Yuan Chiang, Tim Pearce, Wei-Wei Tu, Ting Chen, Jun Zhu

From e29fd9031476875e4b10c7ee47461795f57f7745 Mon Sep 17 00:00:00 2001
From: huangshiyu
Date: Tue, 12 Dec 2023 14:52:29 +0800
Subject: [PATCH 2/2] improve test

---
 openrl/buffers/offpolicy_replay_data.py       |  92 +++---
 openrl/buffers/replay_data.py                 | 312 +++++++++---------
 openrl/configs/config.py                      |   7 +-
 tests/test_buffer/test_generator.py           |  88 +++++
 tests/test_buffer/test_offpolicy_generator.py |  68 ++++
 5 files changed, 362 insertions(+), 205 deletions(-)
 create mode 100644 tests/test_buffer/test_generator.py
 create mode 100644 tests/test_buffer/test_offpolicy_generator.py

diff --git a/openrl/buffers/offpolicy_replay_data.py b/openrl/buffers/offpolicy_replay_data.py
index 4d62d53f..31e52e85 100644
--- a/openrl/buffers/offpolicy_replay_data.py
+++ b/openrl/buffers/offpolicy_replay_data.py
@@ -97,52 +97,52 @@ def __init__(
         )
         self.first_insert_flag = True
 
-    def dict_insert(self, data):
-        if self._mixed_obs:
-            for key in self.critic_obs.keys():
-                self.critic_obs[key][self.step + 1] = data["critic_obs"][key].copy()
-            for key in self.policy_obs.keys():
-                self.policy_obs[key][self.step + 1] = data["policy_obs"][key].copy()
-            for key in self.next_policy_obs.keys():
-                self.next_policy_obs[key][self.step + 1] = data["next_policy_obs"][
-                    key
-                ].copy()
-            for key in self.next_critic_obs.keys():
-                self.next_critic_obs[key][self.step + 1] = data["next_critic_obs"][
-                    key
-                ].copy()
-        else:
-            self.critic_obs[self.step + 1] = data["critic_obs"].copy()
-            self.policy_obs[self.step + 1] = data["policy_obs"].copy()
-            self.next_policy_obs[self.step + 1] = data["next_policy_obs"].copy()
-            self.next_critic_obs[self.step + 1] = data["next_critic_obs"].copy()
-
-        if "rnn_states" in data:
-            self.rnn_states[self.step + 1] = data["rnn_states"].copy()
-        if "rnn_states_critic" in data:
-            self.rnn_states_critic[self.step + 1] = data["rnn_states_critic"].copy()
-        if "actions" in data:
-            self.actions[self.step + 1] = data["actions"].copy()
-        if "action_log_probs" in data:
-            self.action_log_probs[self.step] = data["action_log_probs"].copy()
-
-        if "value_preds" in data:
-            self.value_preds[self.step] = data["value_preds"].copy()
-        if "rewards" in data:
-            self.rewards[self.step + 1] = data["rewards"].copy()
-        if "masks" in data:
-            self.masks[self.step + 1] = data["masks"].copy()
-
-        if "bad_masks" in data:
-            self.bad_masks[self.step + 1] = data["bad_masks"].copy()
-        if "active_masks" in data:
-            self.active_masks[self.step
+ 1] = data["active_masks"].copy() - if "action_masks" in data: - self.action_masks[self.step + 1] = data["action_masks"].copy() - - if (self.step + 1) % self.episode_length != 0: - self.first_insert_flag = False - self.step = (self.step + 1) % self.episode_length + # def dict_insert(self, data): + # if self._mixed_obs: + # for key in self.critic_obs.keys(): + # self.critic_obs[key][self.step + 1] = data["critic_obs"][key].copy() + # for key in self.policy_obs.keys(): + # self.policy_obs[key][self.step + 1] = data["policy_obs"][key].copy() + # for key in self.next_policy_obs.keys(): + # self.next_policy_obs[key][self.step + 1] = data["next_policy_obs"][ + # key + # ].copy() + # for key in self.next_critic_obs.keys(): + # self.next_critic_obs[key][self.step + 1] = data["next_critic_obs"][ + # key + # ].copy() + # else: + # self.critic_obs[self.step + 1] = data["critic_obs"].copy() + # self.policy_obs[self.step + 1] = data["policy_obs"].copy() + # self.next_policy_obs[self.step + 1] = data["next_policy_obs"].copy() + # self.next_critic_obs[self.step + 1] = data["next_critic_obs"].copy() + # + # if "rnn_states" in data: + # self.rnn_states[self.step + 1] = data["rnn_states"].copy() + # if "rnn_states_critic" in data: + # self.rnn_states_critic[self.step + 1] = data["rnn_states_critic"].copy() + # if "actions" in data: + # self.actions[self.step + 1] = data["actions"].copy() + # if "action_log_probs" in data: + # self.action_log_probs[self.step] = data["action_log_probs"].copy() + # + # if "value_preds" in data: + # self.value_preds[self.step] = data["value_preds"].copy() + # if "rewards" in data: + # self.rewards[self.step + 1] = data["rewards"].copy() + # if "masks" in data: + # self.masks[self.step + 1] = data["masks"].copy() + # + # if "bad_masks" in data: + # self.bad_masks[self.step + 1] = data["bad_masks"].copy() + # if "active_masks" in data: + # self.active_masks[self.step + 1] = data["active_masks"].copy() + # if "action_masks" in data: + # self.action_masks[self.step + 1] = data["action_masks"].copy() + # + # if (self.step + 1) % self.episode_length != 0: + # self.first_insert_flag = False + # self.step = (self.step + 1) % self.episode_length def init_buffer(self, raw_obs, action_masks=None): critic_obs = get_critic_obs(raw_obs) diff --git a/openrl/buffers/replay_data.py b/openrl/buffers/replay_data.py index 40a4b383..8d092d7d 100644 --- a/openrl/buffers/replay_data.py +++ b/openrl/buffers/replay_data.py @@ -198,49 +198,49 @@ def get_batch_data( else: return np.concatenate(data[step]) - def all_batch_data(self, data_name: str, min=None, max=None): - assert hasattr(self, data_name) - data = getattr(self, data_name) - - if isinstance(data, ObsData): - return data.all_batch(min, max) - else: - return data[min:max].reshape((-1, *data.shape[3:])) - - def dict_insert(self, data): - if self._mixed_obs: - for key in self.critic_obs.keys(): - self.critic_obs[key][self.step + 1] = data["critic_obs"][key].copy() - for key in self.policy_obs.keys(): - self.policy_obs[key][self.step + 1] = data["policy_obs"][key].copy() - else: - self.critic_obs[self.step + 1] = data["critic_obs"].copy() - self.policy_obs[self.step + 1] = data["policy_obs"].copy() - - if "rnn_states" in data: - self.rnn_states[self.step + 1] = data["rnn_states"].copy() - if "rnn_states_critic" in data: - self.rnn_states_critic[self.step + 1] = data["rnn_states_critic"].copy() - if "actions" in data: - self.actions[self.step] = data["actions"].copy() - if "action_log_probs" in data: - self.action_log_probs[self.step] = 
data["action_log_probs"].copy() - - if "value_preds" in data: - self.value_preds[self.step] = data["value_preds"].copy() - if "rewards" in data: - self.rewards[self.step] = data["rewards"].copy() - if "masks" in data: - self.masks[self.step + 1] = data["masks"].copy() - - if "bad_masks" in data: - self.bad_masks[self.step + 1] = data["bad_masks"].copy() - if "active_masks" in data: - self.active_masks[self.step + 1] = data["active_masks"].copy() - if "action_masks" in data: - self.action_masks[self.step + 1] = data["action_masks"].copy() - - self.step = (self.step + 1) % self.episode_length + # def all_batch_data(self, data_name: str, min=None, max=None): + # assert hasattr(self, data_name) + # data = getattr(self, data_name) + # + # if isinstance(data, ObsData): + # return data.all_batch(min, max) + # else: + # return data[min:max].reshape((-1, *data.shape[3:])) + + # def dict_insert(self, data): + # if self._mixed_obs: + # for key in self.critic_obs.keys(): + # self.critic_obs[key][self.step + 1] = data["critic_obs"][key].copy() + # for key in self.policy_obs.keys(): + # self.policy_obs[key][self.step + 1] = data["policy_obs"][key].copy() + # else: + # self.critic_obs[self.step + 1] = data["critic_obs"].copy() + # self.policy_obs[self.step + 1] = data["policy_obs"].copy() + # + # if "rnn_states" in data: + # self.rnn_states[self.step + 1] = data["rnn_states"].copy() + # if "rnn_states_critic" in data: + # self.rnn_states_critic[self.step + 1] = data["rnn_states_critic"].copy() + # if "actions" in data: + # self.actions[self.step] = data["actions"].copy() + # if "action_log_probs" in data: + # self.action_log_probs[self.step] = data["action_log_probs"].copy() + # + # if "value_preds" in data: + # self.value_preds[self.step] = data["value_preds"].copy() + # if "rewards" in data: + # self.rewards[self.step] = data["rewards"].copy() + # if "masks" in data: + # self.masks[self.step + 1] = data["masks"].copy() + # + # if "bad_masks" in data: + # self.bad_masks[self.step + 1] = data["bad_masks"].copy() + # if "active_masks" in data: + # self.active_masks[self.step + 1] = data["active_masks"].copy() + # if "action_masks" in data: + # self.action_masks[self.step + 1] = data["action_masks"].copy() + # + # self.step = (self.step + 1) % self.episode_length def insert( self, @@ -947,119 +947,119 @@ def naive_recurrent_generator(self, advantages, num_mini_batch): yield critic_obs_batch, policy_obs_batch, rnn_states_batch, rnn_states_critic_batch, actions_batch, value_preds_batch, return_batch, masks_batch, active_masks_batch, old_action_log_probs_batch, adv_targ, action_masks_batch - def recurrent_generator_v2( - self, advantages, num_mini_batch=None, mini_batch_size=None - ): - """ - Yield training data for MLP policies. - :param advantages: (np.ndarray) advantage estimates. - :param num_mini_batch: (int) number of minibatches to split the batch into. - :param mini_batch_size: (int) number of samples in each minibatch. - """ - episode_length, n_rollout_threads, num_agents = self.rewards.shape[0:3] - batch_size = n_rollout_threads * episode_length - - if mini_batch_size is None: - assert ( - batch_size >= num_mini_batch - ), ( - "PPO requires the number of processes ({}) " - "* number of steps ({}) = {} " - "to be greater than or equal to the number of PPO mini batches ({})." 
- "".format( - n_rollout_threads, - episode_length, - n_rollout_threads * episode_length, - num_mini_batch, - ) - ) - mini_batch_size = batch_size // num_mini_batch - - rand = torch.randperm(batch_size).numpy() - sampler = [ - rand[i * mini_batch_size : (i + 1) * mini_batch_size] - for i in range(num_mini_batch) - ] - - # keep (num_agent, dim) - critic_obs = self.critic_obs[:-1].reshape(-1, *self.critic_obs.shape[2:]) - - policy_obs = self.policy_obs[:-1].reshape(-1, *self.policy_obs.shape[2:]) - - rnn_states = self.rnn_states[:-1].reshape(-1, *self.rnn_states.shape[2:]) - - rnn_states_critic = self.rnn_states_critic[:-1].reshape( - -1, *self.rnn_states_critic.shape[2:] - ) - - actions = self.actions.reshape(-1, *self.actions.shape[2:]) - - if self.action_masks is not None: - action_masks = self.action_masks[:-1].reshape( - -1, *self.action_masks.shape[2:] - ) - - value_preds = self.value_preds[:-1].reshape(-1, *self.value_preds.shape[2:]) - - returns = self.returns[:-1].reshape(-1, *self.returns.shape[2:]) - - masks = self.masks[:-1].reshape(-1, *self.masks.shape[2:]) - - active_masks = self.active_masks[:-1].reshape(-1, *self.active_masks.shape[2:]) - - action_log_probs = self.action_log_probs.reshape( - -1, *self.action_log_probs.shape[2:] - ) - - advantages = advantages.reshape(-1, *advantages.shape[2:]) - - shuffle = False - if shuffle: - rows, cols = _shuffle_agent_grid(batch_size, num_agents) - - if self.action_masks is not None: - action_masks = action_masks[rows, cols] - critic_obs = critic_obs[rows, cols] - policy_obs = policy_obs[rows, cols] - rnn_states = rnn_states[rows, cols] - rnn_states_critic = rnn_states_critic[rows, cols] - actions = actions[rows, cols] - value_preds = value_preds[rows, cols] - returns = returns[rows, cols] - masks = masks[rows, cols] - active_masks = active_masks[rows, cols] - action_log_probs = action_log_probs[rows, cols] - advantages = advantages[rows, cols] - - for indices in sampler: - # [L,T,N,Dim]-->[L*T,N,Dim]-->[index,N,Dim]-->[index*N, Dim] - critic_obs_batch = critic_obs[indices].reshape(-1, *critic_obs.shape[2:]) - policy_obs_batch = policy_obs[indices].reshape(-1, *policy_obs.shape[2:]) - rnn_states_batch = rnn_states[indices].reshape(-1, *rnn_states.shape[2:]) - rnn_states_critic_batch = rnn_states_critic[indices].reshape( - -1, *rnn_states_critic.shape[2:] - ) - actions_batch = actions[indices].reshape(-1, *actions.shape[2:]) - if self.action_masks is not None: - action_masks_batch = action_masks[indices].reshape( - -1, *action_masks.shape[2:] - ) - else: - action_masks_batch = None - value_preds_batch = value_preds[indices].reshape(-1, *value_preds.shape[2:]) - return_batch = returns[indices].reshape(-1, *returns.shape[2:]) - masks_batch = masks[indices].reshape(-1, *masks.shape[2:]) - active_masks_batch = active_masks[indices].reshape( - -1, *active_masks.shape[2:] - ) - old_action_log_probs_batch = action_log_probs[indices].reshape( - -1, *action_log_probs.shape[2:] - ) - if advantages is None: - adv_targ = None - else: - adv_targ = advantages[indices].reshape(-1, *advantages.shape[2:]) - yield critic_obs_batch, policy_obs_batch, rnn_states_batch, rnn_states_critic_batch, actions_batch, value_preds_batch, return_batch, masks_batch, active_masks_batch, old_action_log_probs_batch, adv_targ, action_masks_batch + # def recurrent_generator_v2( + # self, advantages, num_mini_batch=None, mini_batch_size=None + # ): + # """ + # Yield training data for MLP policies. + # :param advantages: (np.ndarray) advantage estimates. 
+ # :param num_mini_batch: (int) number of minibatches to split the batch into. + # :param mini_batch_size: (int) number of samples in each minibatch. + # """ + # episode_length, n_rollout_threads, num_agents = self.rewards.shape[0:3] + # batch_size = n_rollout_threads * episode_length + # + # if mini_batch_size is None: + # assert ( + # batch_size >= num_mini_batch + # ), ( + # "PPO requires the number of processes ({}) " + # "* number of steps ({}) = {} " + # "to be greater than or equal to the number of PPO mini batches ({})." + # "".format( + # n_rollout_threads, + # episode_length, + # n_rollout_threads * episode_length, + # num_mini_batch, + # ) + # ) + # mini_batch_size = batch_size // num_mini_batch + # + # rand = torch.randperm(batch_size).numpy() + # sampler = [ + # rand[i * mini_batch_size : (i + 1) * mini_batch_size] + # for i in range(num_mini_batch) + # ] + # + # # keep (num_agent, dim) + # critic_obs = self.critic_obs[:-1].reshape(-1, *self.critic_obs.shape[2:]) + # + # policy_obs = self.policy_obs[:-1].reshape(-1, *self.policy_obs.shape[2:]) + # + # rnn_states = self.rnn_states[:-1].reshape(-1, *self.rnn_states.shape[2:]) + # + # rnn_states_critic = self.rnn_states_critic[:-1].reshape( + # -1, *self.rnn_states_critic.shape[2:] + # ) + # + # actions = self.actions.reshape(-1, *self.actions.shape[2:]) + # + # if self.action_masks is not None: + # action_masks = self.action_masks[:-1].reshape( + # -1, *self.action_masks.shape[2:] + # ) + # + # value_preds = self.value_preds[:-1].reshape(-1, *self.value_preds.shape[2:]) + # + # returns = self.returns[:-1].reshape(-1, *self.returns.shape[2:]) + # + # masks = self.masks[:-1].reshape(-1, *self.masks.shape[2:]) + # + # active_masks = self.active_masks[:-1].reshape(-1, *self.active_masks.shape[2:]) + # + # action_log_probs = self.action_log_probs.reshape( + # -1, *self.action_log_probs.shape[2:] + # ) + # + # advantages = advantages.reshape(-1, *advantages.shape[2:]) + # + # shuffle = False + # if shuffle: + # rows, cols = _shuffle_agent_grid(batch_size, num_agents) + # + # if self.action_masks is not None: + # action_masks = action_masks[rows, cols] + # critic_obs = critic_obs[rows, cols] + # policy_obs = policy_obs[rows, cols] + # rnn_states = rnn_states[rows, cols] + # rnn_states_critic = rnn_states_critic[rows, cols] + # actions = actions[rows, cols] + # value_preds = value_preds[rows, cols] + # returns = returns[rows, cols] + # masks = masks[rows, cols] + # active_masks = active_masks[rows, cols] + # action_log_probs = action_log_probs[rows, cols] + # advantages = advantages[rows, cols] + # + # for indices in sampler: + # # [L,T,N,Dim]-->[L*T,N,Dim]-->[index,N,Dim]-->[index*N, Dim] + # critic_obs_batch = critic_obs[indices].reshape(-1, *critic_obs.shape[2:]) + # policy_obs_batch = policy_obs[indices].reshape(-1, *policy_obs.shape[2:]) + # rnn_states_batch = rnn_states[indices].reshape(-1, *rnn_states.shape[2:]) + # rnn_states_critic_batch = rnn_states_critic[indices].reshape( + # -1, *rnn_states_critic.shape[2:] + # ) + # actions_batch = actions[indices].reshape(-1, *actions.shape[2:]) + # if self.action_masks is not None: + # action_masks_batch = action_masks[indices].reshape( + # -1, *action_masks.shape[2:] + # ) + # else: + # action_masks_batch = None + # value_preds_batch = value_preds[indices].reshape(-1, *value_preds.shape[2:]) + # return_batch = returns[indices].reshape(-1, *returns.shape[2:]) + # masks_batch = masks[indices].reshape(-1, *masks.shape[2:]) + # active_masks_batch = active_masks[indices].reshape( + # -1, 
*active_masks.shape[2:] + # ) + # old_action_log_probs_batch = action_log_probs[indices].reshape( + # -1, *action_log_probs.shape[2:] + # ) + # if advantages is None: + # adv_targ = None + # else: + # adv_targ = advantages[indices].reshape(-1, *advantages.shape[2:]) + # yield critic_obs_batch, policy_obs_batch, rnn_states_batch, rnn_states_critic_batch, actions_batch, value_preds_batch, return_batch, masks_batch, active_masks_batch, old_action_log_probs_batch, adv_targ, action_masks_batch def recurrent_generator(self, advantages, num_mini_batch, data_chunk_length): episode_length, n_rollout_threads, num_agents = self.rewards.shape[0:3] diff --git a/openrl/configs/config.py b/openrl/configs/config.py index 1137d6e3..49bdae79 100644 --- a/openrl/configs/config.py +++ b/openrl/configs/config.py @@ -498,13 +498,14 @@ def create_config_parser(): ) parser.add_argument( "--use_popart", - action="store_true", default=False, + type=bool, help="by default False, use PopArt to normalize rewards.", ) parser.add_argument( "--dual_clip_ppo", default=False, + type=bool, help="by default False, use dual-clip ppo.", ) parser.add_argument( @@ -730,8 +731,8 @@ def create_config_parser(): ) parser.add_argument( "--use_gae", - action="store_false", default=True, + type=bool, help="use generalized advantage estimation", ) parser.add_argument( @@ -748,8 +749,8 @@ def create_config_parser(): ) parser.add_argument( "--use_proper_time_limits", - action="store_true", default=False, + type=bool, help="compute returns taking into account time limits", ) parser.add_argument( diff --git a/tests/test_buffer/test_generator.py b/tests/test_buffer/test_generator.py new file mode 100644 index 00000000..4de33c02 --- /dev/null +++ b/tests/test_buffer/test_generator.py @@ -0,0 +1,88 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# Copyright 2023 The OpenRL Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""""" +import os +import sys + +import pytest + +from openrl.envs.common import make +from openrl.modules.common import PPONet as Net +from openrl.runners.common import PPOAgent as Agent + + +@pytest.fixture( + scope="module", + params=[ + "--use_recurrent_policy true --use_joint_action_loss true", + "--use_recurrent_policy true --use_joint_action_loss false", + "--use_recurrent_policy false --use_naive_recurrent true", + "--use_recurrent_policy false --use_naive_recurrent false", + ], +) +def generator_type(request): + return request.param + + +@pytest.fixture(scope="module", params=["--use_gae true", "--use_gae false"]) +def use_gae(request): + return request.param + + +@pytest.fixture( + scope="module", + params=["--use_proper_time_limits true", "--use_proper_time_limits false"], +) +def use_proper_time_limits(request): + return request.param + + +@pytest.fixture( + scope="module", + params=[ + "--use_popart true --use_valuenorm false", + "--use_popart false --use_valuenorm true", + "--use_popart false --use_valuenorm false", + ], +) +def use_popart(request): + return request.param + + +@pytest.fixture(scope="module") +def config(use_proper_time_limits, use_popart, use_gae, generator_type): + config_str = ( + use_proper_time_limits + " " + use_popart + " " + use_gae + " " + generator_type + ) + + from openrl.configs.config import create_config_parser + + cfg_parser = create_config_parser() + cfg = cfg_parser.parse_args(config_str.split()) + return cfg + + +@pytest.mark.unittest +def test_buffer_generator(config): + env = make("CartPole-v1", env_num=2) + agent = Agent(Net(env, cfg=config)) + agent.train(total_time_steps=200) + env.close() + + +if __name__ == "__main__": + sys.exit(pytest.main(["-sv", os.path.basename(__file__)])) diff --git a/tests/test_buffer/test_offpolicy_generator.py b/tests/test_buffer/test_offpolicy_generator.py new file mode 100644 index 00000000..5e5da276 --- /dev/null +++ b/tests/test_buffer/test_offpolicy_generator.py @@ -0,0 +1,68 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# Copyright 2023 The OpenRL Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""""" +import os +import sys + +import pytest + +from openrl.envs.common import make +from openrl.modules.common import DQNNet as Net +from openrl.runners.common import DQNAgent as Agent + + +@pytest.fixture( + scope="module", + params=[ + "--use_recurrent_policy false --use_joint_action_loss false", + ], +) +def generator_type(request): + return request.param + + +@pytest.fixture(scope="module", params=["--use_proper_time_limits false"]) +def use_proper_time_limits(request): + return request.param + + +@pytest.fixture(scope="module", params=["--use_popart false --use_valuenorm false"]) +def use_popart(request): + return request.param + + +@pytest.fixture(scope="module") +def config(use_proper_time_limits, use_popart, generator_type): + config_str = use_proper_time_limits + " " + use_popart + " " + generator_type + + from openrl.configs.config import create_config_parser + + cfg_parser = create_config_parser() + cfg = cfg_parser.parse_args(config_str.split()) + return cfg + + +@pytest.mark.unittest +def test_buffer_generator(config): + env = make("CartPole-v1", env_num=2) + agent = Agent(Net(env, cfg=config)) + agent.train(total_time_steps=200) + env.close() + + +if __name__ == "__main__": + sys.exit(pytest.main(["-sv", os.path.basename(__file__)]))
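
Note on the new tests (not part of the diff): the module-scoped fixtures in tests/test_buffer/test_generator.py are combined, so test_buffer_generator runs 4 generator settings x 2 GAE settings x 2 proper-time-limit settings x 3 value-normalizer settings = 48 flag combinations, each parsed by create_config_parser and exercised as a short CartPole-v1 training run; test_offpolicy_generator.py does the same for a single DQN configuration. A minimal sketch of reproducing one of those combinations outside pytest, using only the imports and calls that already appear in the tests (the "--flag true/false" style relies on the type=bool switches made in openrl/configs/config.py above):

    from openrl.configs.config import create_config_parser
    from openrl.envs.common import make
    from openrl.modules.common import PPONet as Net
    from openrl.runners.common import PPOAgent as Agent

    # One of the 48 flag combinations exercised by test_buffer_generator.
    config_str = (
        "--use_recurrent_policy true --use_joint_action_loss true"
        " --use_gae true --use_proper_time_limits false"
        " --use_popart false --use_valuenorm true"
    )
    cfg = create_config_parser().parse_args(config_str.split())

    env = make("CartPole-v1", env_num=2)  # two parallel CartPole environments
    agent = Agent(Net(env, cfg=cfg))      # PPO net/agent built from the parsed flags
    agent.train(total_time_steps=200)     # short smoke-training run, as in the test
    env.close()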