From b7b503a76e5c066bf96b5db54fb08dd8b803c1d4 Mon Sep 17 00:00:00 2001 From: luyudong Date: Fri, 15 Sep 2023 18:03:05 +0800 Subject: [PATCH 1/6] Fix test files --- ding/entry/tests/test_serial_entry.py | 2 +- ding/example/dt.py | 2 +- ding/model/template/__init__.py | 2 +- .../{dt.py => decision_transformer.py} | 2 +- ding/model/template/tests/test_acer.py | 42 ++++++++++ ding/model/template/tests/test_bcq.py | 76 +++++++++++++++++++ .../tests/test_decision_transformer.py | 54 +++++++++---- ding/model/template/tests/test_edac.py | 58 ++++++++++++++ ding/model/template/tests/test_ngu.py | 49 ++++++++++++ dizoo/atari/entry/atari_dt_main.py | 2 +- 10 files changed, 270 insertions(+), 19 deletions(-) rename ding/model/template/{dt.py => decision_transformer.py} (99%) create mode 100644 ding/model/template/tests/test_acer.py create mode 100644 ding/model/template/tests/test_bcq.py create mode 100644 ding/model/template/tests/test_edac.py create mode 100644 ding/model/template/tests/test_ngu.py diff --git a/ding/entry/tests/test_serial_entry.py b/ding/entry/tests/test_serial_entry.py index 1a44c7b548..1ee45f9535 100644 --- a/ding/entry/tests/test_serial_entry.py +++ b/ding/entry/tests/test_serial_entry.py @@ -650,7 +650,7 @@ def test_discrete_dt(): from ding.utils import set_pkg_seed from ding.data import create_dataset from ding.config import compile_config - from ding.model.template.dt import DecisionTransformer + from ding.model.template.decision_transformer import DecisionTransformer from ding.policy import DTPolicy from ding.framework.middleware import interaction_evaluator, trainer, CkptSaver, \ OfflineMemoryDataFetcher, offline_logger, termination_checker diff --git a/ding/example/dt.py b/ding/example/dt.py index 84d2b9d522..74ea1525de 100644 --- a/ding/example/dt.py +++ b/ding/example/dt.py @@ -1,6 +1,6 @@ import gym from ditk import logging -from ding.model.template.dt import DecisionTransformer +from ding.model.template.decision_transformer import DecisionTransformer from ding.policy import DTPolicy from ding.envs import DingEnvWrapper, BaseEnvManager, BaseEnvManagerV2 from ding.envs.env_wrappers.env_wrappers import AllinObsWrapper diff --git a/ding/model/template/__init__.py b/ding/model/template/__init__.py index df5f337888..b2dd815287 100755 --- a/ding/model/template/__init__.py +++ b/ding/model/template/__init__.py @@ -22,7 +22,7 @@ from .maqac import DiscreteMAQAC, ContinuousMAQAC from .madqn import MADQN from .vae import VanillaVAE -from .dt import DecisionTransformer +from .decision_transformer import DecisionTransformer from .procedure_cloning import ProcedureCloningMCTS, ProcedureCloningBFS from .bcq import BCQ from .edac import EDAC diff --git a/ding/model/template/dt.py b/ding/model/template/decision_transformer.py similarity index 99% rename from ding/model/template/dt.py rename to ding/model/template/decision_transformer.py index da1e72f7d6..73330da25f 100644 --- a/ding/model/template/dt.py +++ b/ding/model/template/decision_transformer.py @@ -183,7 +183,7 @@ def forward(self, timesteps, states, actions, returns_to_go, tar=None): action_preds = self.predict_action(h[:, 1]) # predict action given r, s else: state_embeddings = self.state_encoder( - states.reshape(-1, 4, 84, 84).type(torch.float32).contiguous() + states.reshape(-1, *self.state_dim).type(torch.float32).contiguous() ) # (batch * block_size, h_dim) state_embeddings = state_embeddings.reshape(B, T, self.h_dim) # (batch, block_size, h_dim) returns_embeddings = self.embed_rtg(returns_to_go.type(torch.float32)) diff --git a/ding/model/template/tests/test_acer.py b/ding/model/template/tests/test_acer.py new file mode 100644 index 0000000000..ded6b50bb6 --- /dev/null +++ b/ding/model/template/tests/test_acer.py @@ -0,0 +1,42 @@ +import torch +import pytest +from itertools import product + +from ding.model.template import ACER +from ding.torch_utils import is_differentiable + + +B = 4 +obs_shape = [4, (8, ), (4, 64, 64)] +act_shape = [3, (6, )] +args = list(product(*[obs_shape, act_shape])) + +@pytest.mark.unittest +class TestACER: + + @pytest.mark.parametrize('obs_shape, act_shape', args) + def test_ACER(self, obs_shape, act_shape): + if isinstance(obs_shape, int): + inputs = torch.randn(B, obs_shape) + else: + inputs = torch.randn(B, *obs_shape) + model = ACER(obs_shape, act_shape) + + outputs_c = model(inputs, mode='compute_critic') + assert isinstance(outputs_c, dict) + if isinstance(act_shape, int): + assert outputs_c['q_value'].shape == (B, act_shape) + elif len(act_shape) == 1: + assert outputs_c['q_value'].shape == (B, *act_shape) + + outputs_a = model(inputs, mode='compute_actor') + assert isinstance(outputs_a, dict) + if isinstance(act_shape, int): + assert outputs_a['logit'].shape == (B, act_shape) + elif len(act_shape) == 1: + assert outputs_a['logit'].shape == (B, *act_shape) + + outputs = {**outputs_a, **outputs_c} + loss = sum([v.sum() for v in outputs.values()]) + is_differentiable(loss, model) + diff --git a/ding/model/template/tests/test_bcq.py b/ding/model/template/tests/test_bcq.py new file mode 100644 index 0000000000..894a92c0b1 --- /dev/null +++ b/ding/model/template/tests/test_bcq.py @@ -0,0 +1,76 @@ +import pytest +from itertools import product +import torch +from ding.model.template import BCQ +from ding.torch_utils import is_differentiable + + +B = 4 +obs_shape = [4, (8, )] +act_shape = [3, (6, )] +args = list(product(*[obs_shape, act_shape])) + + +@pytest.mark.unittest +class TestBCQ: + + def output_check(self, model, outputs): + if isinstance(outputs, torch.Tensor): + loss = outputs.sum() + elif isinstance(outputs, dict): + loss = sum([v.sum() for v in outputs.values()]) + is_differentiable(loss, model) + + @pytest.mark.parametrize('obs_shape, act_shape', args) + def test_BCQ(self, obs_shape, act_shape): + if isinstance(obs_shape, int): + inputs_obs = torch.randn(B, obs_shape) + else: + inputs_obs = torch.randn(B, *obs_shape) + if isinstance(act_shape, int): + inputs_act = torch.randn(B, act_shape) + else: + inputs_act = torch.randn(B, *act_shape) + inputs = {'obs': inputs_obs, 'action': inputs_act} + model = BCQ(obs_shape, act_shape) + + outputs_c = model(inputs, mode='compute_critic') + assert isinstance(outputs_c, dict) + if isinstance(act_shape, int): + assert torch.stack(outputs_c['q_value']).shape == (2, B) + else: + assert torch.stack(outputs_c['q_value']).shape == (2, B) + self.output_check(model.critic, torch.stack(outputs_c['q_value'])) + + outputs_a = model(inputs, mode='compute_actor') + assert isinstance(outputs_a, dict) + if isinstance(act_shape, int): + assert outputs_a['action'].shape == (B, act_shape) + elif len(act_shape) == 1: + assert outputs_a['action'].shape == (B, *act_shape) + self.output_check(model.actor, outputs_a) + + outputs_vae = model(inputs, mode='compute_vae') + assert isinstance(outputs_vae, dict) + if isinstance(act_shape, int): + assert outputs_vae['recons_action'].shape == (B, act_shape) + assert outputs_vae['mu'].shape == (B, act_shape * 2) + assert outputs_vae['log_var'].shape == (B, act_shape * 2) + assert outputs_vae['z'].shape == (B, act_shape * 2) + elif len(act_shape) == 1: + assert outputs_vae['recons_action'].shape == (B, *act_shape) + assert outputs_vae['mu'].shape == (B, act_shape[0] * 2) + assert outputs_vae['log_var'].shape == (B, act_shape[0] * 2) + assert outputs_vae['z'].shape == (B, act_shape[0] * 2) + if isinstance(obs_shape, int): + assert outputs_vae['prediction_residual'].shape == (B, obs_shape) + else: + assert outputs_vae['prediction_residual'].shape == (B, *obs_shape) + + outputs_eval = model(inputs, mode='compute_eval') + assert isinstance(outputs_eval, dict) + assert isinstance(outputs_eval, dict) + if isinstance(act_shape, int): + assert outputs_eval['action'].shape == (B, act_shape) + elif len(act_shape) == 1: + assert outputs_eval['action'].shape == (B, *act_shape) diff --git a/ding/model/template/tests/test_decision_transformer.py b/ding/model/template/tests/test_decision_transformer.py index 0ee054d176..17096dab25 100644 --- a/ding/model/template/tests/test_decision_transformer.py +++ b/ding/model/template/tests/test_decision_transformer.py @@ -1,23 +1,30 @@ import pytest from itertools import product import torch +import torch.nn as nn import torch.nn.functional as F from ding.model.template import DecisionTransformer from ding.torch_utils import is_differentiable -args = ['continuous', 'discrete'] - +action_space = ['continuous', 'discrete'] +state_encoder = [None, nn.Sequential(nn.Flatten(), nn.Linear(8, 8), nn.Tanh())] +args = list(product(*[action_space, state_encoder])) +args.pop(1) @pytest.mark.unittest -@pytest.mark.parametrize('action_space', args) -def test_decision_transformer(action_space): +@pytest.mark.parametrize('action_space, state_encoder', args) +def test_decision_transformer(action_space, state_encoder): B, T = 4, 6 - state_dim = 3 + if state_encoder: + state_dim = (2, 2, 2) + else: + state_dim = 3 act_dim = 2 DT_model = DecisionTransformer( state_dim=state_dim, act_dim=act_dim, + state_encoder=state_encoder, n_blocks=3, h_dim=8, context_len=T, @@ -27,8 +34,14 @@ def test_decision_transformer(action_space): ) is_continuous = True if action_space == 'continuous' else False - timesteps = torch.randint(0, 100, [B, T], dtype=torch.long) # B x T - states = torch.randn([B, T, state_dim]) # B x T x state_dim + if state_encoder: + timesteps = torch.randint(0, 100, [B, 3*T-1, 1], dtype=torch.long) # B x T + else: + timesteps = torch.randint(0, 100, [B, T], dtype=torch.long) # B x T + if isinstance(state_dim, int): + states = torch.randn([B, T, state_dim]) # B x T x state_dim + else: + states = torch.randn([B, T, *state_dim]) # B x T x state_dim if action_space == 'continuous': actions = torch.randn([B, T, act_dim]) # B x T x act_dim action_target = torch.randn([B, T, act_dim]) @@ -51,12 +64,19 @@ def test_decision_transformer(action_space): state_preds, action_preds, return_preds = DT_model.forward( timesteps=timesteps, states=states, actions=actions, returns_to_go=returns_to_go ) - assert state_preds.shape == (B, T, state_dim) + if state_encoder: + assert state_preds == None + assert return_preds == None + else: + assert state_preds.shape == (B, T, state_dim) + assert return_preds.shape == (B, T, 1) assert action_preds.shape == (B, T, act_dim) - assert return_preds.shape == (B, T, 1) # only consider non padded elements - action_preds = action_preds.view(-1, act_dim)[traj_mask.view(-1, ) > 0] + if state_encoder: + action_preds = action_preds.reshape(-1, act_dim) + else: + action_preds = action_preds.view(-1, act_dim)[traj_mask.view(-1, ) > 0] if is_continuous: action_target = action_target.view(-1, act_dim)[traj_mask.view(-1, ) > 0] @@ -68,11 +88,17 @@ def test_decision_transformer(action_space): else: action_loss = F.cross_entropy(action_preds, action_target) - # print(action_loss) - # is_differentiable(action_loss, DT_model) - is_differentiable( + if state_encoder: + is_differentiable( + action_loss, [ + DT_model.transformer, DT_model.embed_action, DT_model.embed_rtg, + DT_model.state_encoder + ] + ) + else: + is_differentiable( action_loss, [ DT_model.transformer, DT_model.embed_action, DT_model.predict_action, DT_model.embed_rtg, DT_model.embed_state ] - ) # pass + ) diff --git a/ding/model/template/tests/test_edac.py b/ding/model/template/tests/test_edac.py new file mode 100644 index 0000000000..3569dab488 --- /dev/null +++ b/ding/model/template/tests/test_edac.py @@ -0,0 +1,58 @@ +import torch +import pytest +from itertools import product + +from ding.model.template import EDAC +from ding.torch_utils import is_differentiable + + +B = 4 +obs_shape = [4, (8, )] +act_shape = [3, (6, )] +args = list(product(*[obs_shape, act_shape])) + +@pytest.mark.unittest +class TestEDAC: + + def output_check(self, model, outputs): + if isinstance(outputs, torch.Tensor): + loss = outputs.sum() + elif isinstance(outputs, list): + loss = sum([t.sum() for t in outputs]) + elif isinstance(outputs, dict): + loss = sum([v.sum() for v in outputs.values()]) + is_differentiable(loss, model) + + @pytest.mark.parametrize('obs_shape, act_shape', args) + def test_EDAC(self, obs_shape, act_shape): + if isinstance(obs_shape, int): + inputs_obs = torch.randn(B, obs_shape) + else: + inputs_obs = torch.randn(B, *obs_shape) + if isinstance(act_shape, int): + inputs_act = torch.randn(B, act_shape) + else: + inputs_act = torch.randn(B, *act_shape) + inputs = {'obs': inputs_obs, 'action': inputs_act} + model = EDAC(obs_shape, act_shape, ensemble_num=2) + + outputs_c = model(inputs, mode='compute_critic') + assert isinstance(outputs_c, dict) + assert outputs_c['q_value'].shape == (2, B) + self.output_check(model.critic, outputs_c) + + if isinstance(obs_shape, int): + inputs = torch.randn(B, obs_shape) + else: + inputs = torch.randn(B, *obs_shape) + outputs_a = model(inputs, mode='compute_actor') + assert isinstance(outputs_a, dict) + if isinstance(act_shape, int): + assert outputs_a['logit'][0].shape == (B, act_shape) + assert outputs_a['logit'][1].shape == (B, act_shape) + elif len(act_shape) == 1: + assert outputs_a['logit'][0].shape == (B, *act_shape) + assert outputs_a['logit'][1].shape == (B, *act_shape) + outputs = {'mu': outputs_a['logit'][0], 'sigma': outputs_a['logit'][1]} + self.output_check(model.actor, outputs) + diff --git a/ding/model/template/tests/test_ngu.py b/ding/model/template/tests/test_ngu.py new file mode 100644 index 0000000000..da0d5dfc2c --- /dev/null +++ b/ding/model/template/tests/test_ngu.py @@ -0,0 +1,49 @@ +import pytest +from itertools import product +import torch +from ding.model.template import NGU +from ding.torch_utils import is_differentiable + +B = 4 +H = 4 +obs_shape = [4, (8, ), (4, 64, 64)] +act_shape = [4, (4, )] +args = list(product(*[obs_shape, act_shape])) + + +@pytest.mark.unittest +class TestNGU: + + def output_check(self, model, outputs): + if isinstance(outputs, torch.Tensor): + loss = outputs.sum() + elif isinstance(outputs, list): + loss = sum([t.sum() for t in outputs]) + elif isinstance(outputs, dict): + loss = sum([v.sum() for v in outputs.values()]) + is_differentiable(loss, model) + + @pytest.mark.parametrize('obs_shape, act_shape', args) + def test_ngu(self, obs_shape, act_shape): + if isinstance(obs_shape, int): + inputs_obs = torch.randn(B, H, obs_shape) + else: + inputs_obs = torch.randn(B, H, *obs_shape) + if isinstance(act_shape, int): + inputs_prev_action = torch.ones(B, act_shape).long() + else: + inputs_prev_action = torch.ones(B, *act_shape).long() + inputs_prev_reward_extrinsic = torch.randn(B, H, 1) + inputs_beta = 2*torch.ones([4,4], dtype=torch.long) + inputs = {'obs': inputs_obs, 'prev_state': None, + 'prev_action': inputs_prev_action, 'prev_reward_extrinsic':inputs_prev_reward_extrinsic, + 'beta': inputs_beta} + + model = NGU(obs_shape, act_shape, collector_env_num=3) + outputs = model(inputs) + assert isinstance(outputs, dict) + if isinstance(act_shape, int): + assert outputs['logit'].shape == (B, act_shape, act_shape) + elif len(act_shape) == 1: + assert outputs['logit'].shape == (B, *act_shape, *act_shape) + self.output_check(model, outputs['logit']) \ No newline at end of file diff --git a/dizoo/atari/entry/atari_dt_main.py b/dizoo/atari/entry/atari_dt_main.py index d6d52eee3b..fbaa0edadb 100644 --- a/dizoo/atari/entry/atari_dt_main.py +++ b/dizoo/atari/entry/atari_dt_main.py @@ -1,6 +1,6 @@ import torch.nn as nn from ditk import logging -from ding.model.template.dt import DecisionTransformer +from ding.model.template.decision_transformer import DecisionTransformer from ding.policy import DTPolicy from ding.envs import SubprocessEnvManagerV2 from ding.envs.env_wrappers.env_wrappers import AllinObsWrapper From c09057492e195b619e88d406d47c1dd6597e400e Mon Sep 17 00:00:00 2001 From: luyudong Date: Mon, 18 Sep 2023 16:52:57 +0800 Subject: [PATCH 2/6] Add vac test and fix dt test --- .../tests/test_decision_transformer.py | 1 + ding/model/template/tests/test_vac.py | 21 +++++++++++++++++-- 2 files changed, 20 insertions(+), 2 deletions(-) diff --git a/ding/model/template/tests/test_decision_transformer.py b/ding/model/template/tests/test_decision_transformer.py index 17096dab25..ce0d93403b 100644 --- a/ding/model/template/tests/test_decision_transformer.py +++ b/ding/model/template/tests/test_decision_transformer.py @@ -32,6 +32,7 @@ def test_decision_transformer(action_space, state_encoder): drop_p=0.1, continuous=(action_space == 'continuous') ) + DT_model.configure_optimizers(1.0, 0.0003) is_continuous = True if action_space == 'continuous' else False if state_encoder: diff --git a/ding/model/template/tests/test_vac.py b/ding/model/template/tests/test_vac.py index 4f31e942b6..c44e568e06 100644 --- a/ding/model/template/tests/test_vac.py +++ b/ding/model/template/tests/test_vac.py @@ -3,14 +3,16 @@ import torch from itertools import product -from ding.model import VAC +from ding.model import VAC, DREAMERVAC from ding.torch_utils import is_differentiable from ding.model import ConvEncoder +from easydict import EasyDict +ezD = EasyDict({'action_args_shape': (3, ), 'action_type_shape': 4}) B, C, H, W = 4, 3, 128, 128 obs_shape = [4, (8, ), (4, 64, 64)] -act_args = [[6, 'discrete'], [(3, ), 'continuous'], [[2, 3, 6], 'discrete']] +act_args = [[6, 'discrete'], [(3, ), 'continuous'], [[2, 3, 6], 'discrete'], [ezD, 'hybrid']] # act_args = [[(3, ), True]] args = list(product(*[obs_shape, act_args, [False, True]])) @@ -20,6 +22,8 @@ def output_check(model, outputs, action_shape): loss = sum([t.sum() for t in outputs]) elif np.isscalar(action_shape): loss = outputs.sum() + elif isinstance(action_shape, dict): + loss = outputs.sum() is_differentiable(loss, model) @@ -28,6 +32,8 @@ def model_check(model, inputs): value, logit = outputs['value'], outputs['logit'] if model.action_space == 'continuous': outputs = value.sum() + logit['mu'].sum() + logit['sigma'].sum() + elif model.action_space == 'hybrid': + outputs = value.sum() + logit['action_type'].sum() + logit['action_args']['mu'].sum() + logit['action_args']['sigma'].sum() else: if model.multi_head: outputs = value.sum() + sum([t.sum() for t in logit]) @@ -40,6 +46,8 @@ def model_check(model, inputs): logit = model(inputs, mode='compute_actor')['logit'] if model.action_space == 'continuous': logit = logit['mu'].sum() + logit['sigma'].sum() + elif model.action_space == 'hybrid': + logit = logit['action_type'].sum() + logit['action_args']['mu'].sum() + logit['action_args']['sigma'].sum() output_check(model.actor, logit, model.action_shape) for p in model.parameters(): @@ -49,6 +57,15 @@ def model_check(model, inputs): output_check(model.critic, value, 1) +@pytest.mark.unittest +class TestDREAMERVAC: + + def test_DREAMERVAC(self): + obs_shape = 8 + act_shape = 6 + model = DREAMERVAC(obs_shape, act_shape) + + @pytest.mark.unittest @pytest.mark.parametrize('obs_shape, act_args, share_encoder', args) class TestVACGeneral: From a6288b9148a533f35431814cd71815c307493660 Mon Sep 17 00:00:00 2001 From: luyudong Date: Mon, 18 Sep 2023 16:54:15 +0800 Subject: [PATCH 3/6] Add qtrain test and GTrXLDQN test --- ding/model/template/q_learning.py | 2 +- ding/model/template/tests/test_q_learning.py | 10 +++++++++- ding/model/template/tests/test_qtran.py | 19 +++++++++++++++++++ 3 files changed, 29 insertions(+), 2 deletions(-) create mode 100644 ding/model/template/tests/test_qtran.py diff --git a/ding/model/template/q_learning.py b/ding/model/template/q_learning.py index e1ddbd6e5f..43c831c174 100644 --- a/ding/model/template/q_learning.py +++ b/ding/model/template/q_learning.py @@ -1149,7 +1149,7 @@ def forward(self, x: torch.Tensor) -> Dict: >>> # Init input's Keys: >>> obs_dim, seq_len, bs, action_dim = 128, 64, 32, 4 >>> obs = torch.rand(seq_len, bs, obs_dim) - >>> model = GTrXLDiscreteHead(obs_dim, action_dim) + >>> model = GTrXLDQN(obs_dim, action_dim) >>> outputs = model(obs) >>> assert isinstance(outputs, dict) """ diff --git a/ding/model/template/tests/test_q_learning.py b/ding/model/template/tests/test_q_learning.py index 303481cb1c..ce6a150dd0 100644 --- a/ding/model/template/tests/test_q_learning.py +++ b/ding/model/template/tests/test_q_learning.py @@ -1,7 +1,7 @@ import pytest from itertools import product import torch -from ding.model.template import DQN, RainbowDQN, QRDQN, IQN, FQF, DRQN, C51DQN, BDQ +from ding.model.template import DQN, RainbowDQN, QRDQN, IQN, FQF, DRQN, C51DQN, BDQ, GTrXLDQN from ding.torch_utils import is_differentiable T, B = 3, 4 @@ -283,3 +283,11 @@ def test_drqn_inference_res_link(self, obs_shape, act_shape): assert all([len(t) == 2 for t in outputs['next_state']]) assert all([t['h'].shape == (1, 1, 64) for t in outputs['next_state']]) self.output_check(model, outputs['logit']) + + @pytest.mark.tmp + def test_GTrXLDQN(self): + obs_dim, seq_len, bs, action_dim = [4,64,64], 64, 32, 4 + obs = torch.rand(seq_len, bs, *obs_dim) + model = GTrXLDQN(obs_dim, action_dim,encoder_hidden_size_list=[16,16,16]) + outputs = model(obs) + assert isinstance(outputs, dict) \ No newline at end of file diff --git a/ding/model/template/tests/test_qtran.py b/ding/model/template/tests/test_qtran.py new file mode 100644 index 0000000000..0d9f201444 --- /dev/null +++ b/ding/model/template/tests/test_qtran.py @@ -0,0 +1,19 @@ +import pytest +from itertools import product +import torch +from ding.model.template import QTran +from ding.torch_utils import is_differentiable + +@pytest.mark.unittest +def test_qtran(): + B = 1 + obs_shape = (1,64,64) + act_shape = 2 + # inputs = { + # 'obs': {'agent_state': torch.randn(B, *obs_shape), + # 'global_state': torch.randn(B, *obs_shape)}, + # 'prev_state': [[torch.randn(1, 1, *obs_shape) for __ in range(1)] for _ in range(1)], + # 'action': torch.randn(B, act_shape) + # } + model = QTran(1, obs_shape, 4*64*64, act_shape, [8, 8, 8], 5) + # model.forward(inputs) \ No newline at end of file From b7ecc380e849f36751c7a006c6d98bb3275cc02d Mon Sep 17 00:00:00 2001 From: luyudong Date: Mon, 18 Sep 2023 16:54:30 +0800 Subject: [PATCH 4/6] Fix ngu test --- ding/model/template/tests/test_ngu.py | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/ding/model/template/tests/test_ngu.py b/ding/model/template/tests/test_ngu.py index da0d5dfc2c..ec67bc74ad 100644 --- a/ding/model/template/tests/test_ngu.py +++ b/ding/model/template/tests/test_ngu.py @@ -46,4 +46,18 @@ def test_ngu(self, obs_shape, act_shape): assert outputs['logit'].shape == (B, act_shape, act_shape) elif len(act_shape) == 1: assert outputs['logit'].shape == (B, *act_shape, *act_shape) - self.output_check(model, outputs['logit']) \ No newline at end of file + self.output_check(model, outputs['logit']) + + inputs = {'obs': inputs_obs, 'prev_state': None, + 'action': inputs_prev_action, + 'reward': inputs_prev_reward_extrinsic, + 'prev_reward_extrinsic':inputs_prev_reward_extrinsic, + 'beta': inputs_beta} + model = NGU(obs_shape, act_shape, collector_env_num=3) + outputs = model(inputs) + assert isinstance(outputs, dict) + if isinstance(act_shape, int): + assert outputs['logit'].shape == (B, act_shape, act_shape) + elif len(act_shape) == 1: + assert outputs['logit'].shape == (B, *act_shape, *act_shape) + self.output_check(model, outputs['logit']) From 87294d53d3985413c08cb6d7f119b738d83d3c40 Mon Sep 17 00:00:00 2001 From: luyudong Date: Mon, 18 Sep 2023 16:54:46 +0800 Subject: [PATCH 5/6] Add transformer_segment_wrapper test --- ding/model/wrapper/test_model_wrappers.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/ding/model/wrapper/test_model_wrappers.py b/ding/model/wrapper/test_model_wrappers.py index 890d1eb1fc..f301dc79d0 100644 --- a/ding/model/wrapper/test_model_wrappers.py +++ b/ding/model/wrapper/test_model_wrappers.py @@ -514,12 +514,23 @@ def test_transformer_input_wrapper(self): model.reset() assert model.obs_memory is None + def test_transformer_segment_wrapper(self): + seq_len, bs, obs_shape = 12, 8, 32 + layer_num, memory_len, emb_dim = 3, 4, 4 + model = GTrXL(input_dim=obs_shape, embedding_dim=emb_dim, memory_len=memory_len, layer_num=layer_num) + model = model_wrap(model, wrapper_name='transformer_segment', seq_len=seq_len) + inputs1 = torch.randn((seq_len, bs, obs_shape)) + out = model.forward(inputs1) + info = model.info('info') + info = model.info('x') + def test_transformer_memory_wrapper(self): seq_len, bs, obs_shape = 12, 8, 32 layer_num, memory_len, emb_dim = 3, 4, 4 model = GTrXL(input_dim=obs_shape, embedding_dim=emb_dim, memory_len=memory_len, layer_num=layer_num) model1 = model_wrap(model, wrapper_name='transformer_memory', batch_size=bs) model2 = model_wrap(model, wrapper_name='transformer_memory', batch_size=bs) + model1.show_memory_occupancy() inputs1 = torch.randn((seq_len, bs, obs_shape)) out = model1.forward(inputs1) new_memory1 = model1.memory From 4eebb5319e99f2ff7bc1108b59969ff5d00a30b8 Mon Sep 17 00:00:00 2001 From: luyudong Date: Mon, 18 Sep 2023 16:58:47 +0800 Subject: [PATCH 6/6] Reformat --- ding/model/template/tests/test_acer.py | 3 +- ding/model/template/tests/test_bcq.py | 1 - .../tests/test_decision_transformer.py | 20 ++-- ding/model/template/tests/test_edac.py | 3 +- ding/model/template/tests/test_ngu.py | 25 +++-- ding/model/template/tests/test_q_learning.py | 6 +- ding/model/template/tests/test_qtran.py | 7 +- ding/model/template/tests/test_vac.py | 3 +- ding/model/wrapper/test_model_wrappers.py | 2 +- .../pendulum/config/pendulum_pg_config.py | 2 +- dizoo/dmc2gym/entry/dmc2gym_onppo_main.py | 10 +- dizoo/procgen/entry/coinrun_onppo_main.py | 10 +- dizoo/tabmwp/envs/tabmwp_env.py | 104 +++++++++++------- dizoo/tabmwp/envs/utils.py | 57 ++++++---- setup.py | 2 +- 15 files changed, 152 insertions(+), 103 deletions(-) diff --git a/ding/model/template/tests/test_acer.py b/ding/model/template/tests/test_acer.py index ded6b50bb6..1c3877335a 100644 --- a/ding/model/template/tests/test_acer.py +++ b/ding/model/template/tests/test_acer.py @@ -5,12 +5,12 @@ from ding.model.template import ACER from ding.torch_utils import is_differentiable - B = 4 obs_shape = [4, (8, ), (4, 64, 64)] act_shape = [3, (6, )] args = list(product(*[obs_shape, act_shape])) + @pytest.mark.unittest class TestACER: @@ -39,4 +39,3 @@ def test_ACER(self, obs_shape, act_shape): outputs = {**outputs_a, **outputs_c} loss = sum([v.sum() for v in outputs.values()]) is_differentiable(loss, model) - diff --git a/ding/model/template/tests/test_bcq.py b/ding/model/template/tests/test_bcq.py index 894a92c0b1..101cfd9b9c 100644 --- a/ding/model/template/tests/test_bcq.py +++ b/ding/model/template/tests/test_bcq.py @@ -4,7 +4,6 @@ from ding.model.template import BCQ from ding.torch_utils import is_differentiable - B = 4 obs_shape = [4, (8, )] act_shape = [3, (6, )] diff --git a/ding/model/template/tests/test_decision_transformer.py b/ding/model/template/tests/test_decision_transformer.py index ce0d93403b..4c70877049 100644 --- a/ding/model/template/tests/test_decision_transformer.py +++ b/ding/model/template/tests/test_decision_transformer.py @@ -12,6 +12,7 @@ args = list(product(*[action_space, state_encoder])) args.pop(1) + @pytest.mark.unittest @pytest.mark.parametrize('action_space, state_encoder', args) def test_decision_transformer(action_space, state_encoder): @@ -36,7 +37,7 @@ def test_decision_transformer(action_space, state_encoder): is_continuous = True if action_space == 'continuous' else False if state_encoder: - timesteps = torch.randint(0, 100, [B, 3*T-1, 1], dtype=torch.long) # B x T + timesteps = torch.randint(0, 100, [B, 3 * T - 1, 1], dtype=torch.long) # B x T else: timesteps = torch.randint(0, 100, [B, T], dtype=torch.long) # B x T if isinstance(state_dim, int): @@ -91,15 +92,12 @@ def test_decision_transformer(action_space, state_encoder): if state_encoder: is_differentiable( - action_loss, [ - DT_model.transformer, DT_model.embed_action, DT_model.embed_rtg, - DT_model.state_encoder - ] - ) + action_loss, [DT_model.transformer, DT_model.embed_action, DT_model.embed_rtg, DT_model.state_encoder] + ) else: is_differentiable( - action_loss, [ - DT_model.transformer, DT_model.embed_action, DT_model.predict_action, DT_model.embed_rtg, - DT_model.embed_state - ] - ) + action_loss, [ + DT_model.transformer, DT_model.embed_action, DT_model.predict_action, DT_model.embed_rtg, + DT_model.embed_state + ] + ) diff --git a/ding/model/template/tests/test_edac.py b/ding/model/template/tests/test_edac.py index 3569dab488..76f0cca60a 100644 --- a/ding/model/template/tests/test_edac.py +++ b/ding/model/template/tests/test_edac.py @@ -5,12 +5,12 @@ from ding.model.template import EDAC from ding.torch_utils import is_differentiable - B = 4 obs_shape = [4, (8, )] act_shape = [3, (6, )] args = list(product(*[obs_shape, act_shape])) + @pytest.mark.unittest class TestEDAC: @@ -55,4 +55,3 @@ def test_EDAC(self, obs_shape, act_shape): assert outputs_a['logit'][1].shape == (B, *act_shape) outputs = {'mu': outputs_a['logit'][0], 'sigma': outputs_a['logit'][1]} self.output_check(model.actor, outputs) - diff --git a/ding/model/template/tests/test_ngu.py b/ding/model/template/tests/test_ngu.py index ec67bc74ad..ed0e86f194 100644 --- a/ding/model/template/tests/test_ngu.py +++ b/ding/model/template/tests/test_ngu.py @@ -34,10 +34,14 @@ def test_ngu(self, obs_shape, act_shape): else: inputs_prev_action = torch.ones(B, *act_shape).long() inputs_prev_reward_extrinsic = torch.randn(B, H, 1) - inputs_beta = 2*torch.ones([4,4], dtype=torch.long) - inputs = {'obs': inputs_obs, 'prev_state': None, - 'prev_action': inputs_prev_action, 'prev_reward_extrinsic':inputs_prev_reward_extrinsic, - 'beta': inputs_beta} + inputs_beta = 2 * torch.ones([4, 4], dtype=torch.long) + inputs = { + 'obs': inputs_obs, + 'prev_state': None, + 'prev_action': inputs_prev_action, + 'prev_reward_extrinsic': inputs_prev_reward_extrinsic, + 'beta': inputs_beta + } model = NGU(obs_shape, act_shape, collector_env_num=3) outputs = model(inputs) @@ -48,11 +52,14 @@ def test_ngu(self, obs_shape, act_shape): assert outputs['logit'].shape == (B, *act_shape, *act_shape) self.output_check(model, outputs['logit']) - inputs = {'obs': inputs_obs, 'prev_state': None, - 'action': inputs_prev_action, - 'reward': inputs_prev_reward_extrinsic, - 'prev_reward_extrinsic':inputs_prev_reward_extrinsic, - 'beta': inputs_beta} + inputs = { + 'obs': inputs_obs, + 'prev_state': None, + 'action': inputs_prev_action, + 'reward': inputs_prev_reward_extrinsic, + 'prev_reward_extrinsic': inputs_prev_reward_extrinsic, + 'beta': inputs_beta + } model = NGU(obs_shape, act_shape, collector_env_num=3) outputs = model(inputs) assert isinstance(outputs, dict) diff --git a/ding/model/template/tests/test_q_learning.py b/ding/model/template/tests/test_q_learning.py index ce6a150dd0..2307a372d1 100644 --- a/ding/model/template/tests/test_q_learning.py +++ b/ding/model/template/tests/test_q_learning.py @@ -286,8 +286,8 @@ def test_drqn_inference_res_link(self, obs_shape, act_shape): @pytest.mark.tmp def test_GTrXLDQN(self): - obs_dim, seq_len, bs, action_dim = [4,64,64], 64, 32, 4 + obs_dim, seq_len, bs, action_dim = [4, 64, 64], 64, 32, 4 obs = torch.rand(seq_len, bs, *obs_dim) - model = GTrXLDQN(obs_dim, action_dim,encoder_hidden_size_list=[16,16,16]) + model = GTrXLDQN(obs_dim, action_dim, encoder_hidden_size_list=[16, 16, 16]) outputs = model(obs) - assert isinstance(outputs, dict) \ No newline at end of file + assert isinstance(outputs, dict) diff --git a/ding/model/template/tests/test_qtran.py b/ding/model/template/tests/test_qtran.py index 0d9f201444..2e44fc9b2b 100644 --- a/ding/model/template/tests/test_qtran.py +++ b/ding/model/template/tests/test_qtran.py @@ -4,10 +4,11 @@ from ding.model.template import QTran from ding.torch_utils import is_differentiable + @pytest.mark.unittest def test_qtran(): B = 1 - obs_shape = (1,64,64) + obs_shape = (1, 64, 64) act_shape = 2 # inputs = { # 'obs': {'agent_state': torch.randn(B, *obs_shape), @@ -15,5 +16,5 @@ def test_qtran(): # 'prev_state': [[torch.randn(1, 1, *obs_shape) for __ in range(1)] for _ in range(1)], # 'action': torch.randn(B, act_shape) # } - model = QTran(1, obs_shape, 4*64*64, act_shape, [8, 8, 8], 5) - # model.forward(inputs) \ No newline at end of file + model = QTran(1, obs_shape, 4 * 64 * 64, act_shape, [8, 8, 8], 5) + # model.forward(inputs) diff --git a/ding/model/template/tests/test_vac.py b/ding/model/template/tests/test_vac.py index c44e568e06..85e44e8a4c 100644 --- a/ding/model/template/tests/test_vac.py +++ b/ding/model/template/tests/test_vac.py @@ -33,7 +33,8 @@ def model_check(model, inputs): if model.action_space == 'continuous': outputs = value.sum() + logit['mu'].sum() + logit['sigma'].sum() elif model.action_space == 'hybrid': - outputs = value.sum() + logit['action_type'].sum() + logit['action_args']['mu'].sum() + logit['action_args']['sigma'].sum() + outputs = value.sum() + logit['action_type'].sum() + logit['action_args']['mu'].sum( + ) + logit['action_args']['sigma'].sum() else: if model.multi_head: outputs = value.sum() + sum([t.sum() for t in logit]) diff --git a/ding/model/wrapper/test_model_wrappers.py b/ding/model/wrapper/test_model_wrappers.py index f301dc79d0..1da744d36f 100644 --- a/ding/model/wrapper/test_model_wrappers.py +++ b/ding/model/wrapper/test_model_wrappers.py @@ -523,7 +523,7 @@ def test_transformer_segment_wrapper(self): out = model.forward(inputs1) info = model.info('info') info = model.info('x') - + def test_transformer_memory_wrapper(self): seq_len, bs, obs_shape = 12, 8, 32 layer_num, memory_len, emb_dim = 3, 4, 4 diff --git a/dizoo/classic_control/pendulum/config/pendulum_pg_config.py b/dizoo/classic_control/pendulum/config/pendulum_pg_config.py index d448dee002..b512548398 100644 --- a/dizoo/classic_control/pendulum/config/pendulum_pg_config.py +++ b/dizoo/classic_control/pendulum/config/pendulum_pg_config.py @@ -18,7 +18,7 @@ action_shape=1, ), learn=dict( - batch_size=400 + batch_size=400, learning_rate=0.001, entropy_weight=0.001, ), diff --git a/dizoo/dmc2gym/entry/dmc2gym_onppo_main.py b/dizoo/dmc2gym/entry/dmc2gym_onppo_main.py index fe1ca9f761..412a46577a 100644 --- a/dizoo/dmc2gym/entry/dmc2gym_onppo_main.py +++ b/dizoo/dmc2gym/entry/dmc2gym_onppo_main.py @@ -71,10 +71,12 @@ def wrapped_dmc2gym_env(cfg): width=default_cfg["width"], frame_skip=default_cfg["frame_skip"] ), - cfg={'env_wrapper': [ - lambda env: Dmc2GymWrapper(env, default_cfg), - lambda env: EvalEpisodeReturnWrapper(env), - ]} + cfg={ + 'env_wrapper': [ + lambda env: Dmc2GymWrapper(env, default_cfg), + lambda env: EvalEpisodeReturnWrapper(env), + ] + } ) diff --git a/dizoo/procgen/entry/coinrun_onppo_main.py b/dizoo/procgen/entry/coinrun_onppo_main.py index c80eb37c92..ca132b1fa7 100644 --- a/dizoo/procgen/entry/coinrun_onppo_main.py +++ b/dizoo/procgen/entry/coinrun_onppo_main.py @@ -60,10 +60,12 @@ def wrapped_procgen_env(cfg): num_levels=default_cfg.num_levels ) if default_cfg.control_level else gym.make('procgen:procgen-' + default_cfg.env_id + '-v0', start_level=0, num_levels=1), - cfg={'env_wrapper': [ - lambda env: CoinrunWrapper(env, default_cfg), - lambda env: EvalEpisodeReturnWrapper(env), - ]} + cfg={ + 'env_wrapper': [ + lambda env: CoinrunWrapper(env, default_cfg), + lambda env: EvalEpisodeReturnWrapper(env), + ] + } ) diff --git a/dizoo/tabmwp/envs/tabmwp_env.py b/dizoo/tabmwp/envs/tabmwp_env.py index 4da1fdfe98..fe32e02b35 100644 --- a/dizoo/tabmwp/envs/tabmwp_env.py +++ b/dizoo/tabmwp/envs/tabmwp_env.py @@ -26,9 +26,7 @@ def __init__(self, cfg): openai.api_key = cfg.api_key self.observation_space = gym.spaces.Dict() self.action_space = gym.spaces.Discrete(self.cfg.cand_number * (self.cfg.cand_number - 1)) - self.reward_space = gym.spaces.Box( - low=-1, high=1, shape=(1,), dtype=np.float32 - ) + self.reward_space = gym.spaces.Box(low=-1, high=1, shape=(1, ), dtype=np.float32) self.correct_num = 0 # Initialize language model if needed. @@ -61,8 +59,12 @@ def get_output(self, inp: str) -> str: inputs = TabMWP.tokenizer(inp + " [MASK].", return_tensors="pt") inputs = TabMWP.tokenizer.build_inputs_for_generation(inputs, max_gen_length=512) inputs = {key: value.cuda() for key, value in inputs.items()} - outputs = TabMWP.model.generate(**inputs, max_length=512, eos_token_id=TabMWP.tokenizer.eop_token_id, - pad_token_id=TabMWP.tokenizer.eos_token_id) + outputs = TabMWP.model.generate( + **inputs, + max_length=512, + eos_token_id=TabMWP.tokenizer.eop_token_id, + pad_token_id=TabMWP.tokenizer.eos_token_id + ) outputs = TabMWP.tokenizer.decode(outputs[0].tolist()) t0 = outputs.find('<|startofpiece|>') + 16 @@ -78,29 +80,37 @@ def reset(self) -> dict: if TabMWP.model is not None: TabMWP.model = TabMWP.model.cuda() if self.enable_replay: - self.cand_pids = ['32889', '8044', '16892', '5408', '4051', '37355', '17962', '25807', '30602', '5514', - '19270', '23713', '17209', '33379', '34987', '11177'] + self.cand_pids = [ + '32889', '8044', '16892', '5408', '4051', '37355', '17962', '25807', '30602', '5514', '19270', '23713', + '17209', '33379', '34987', '11177' + ] if self.cfg.seed == 0: # train - self.train_pids = ['14229', '3409', '29980', '799', '5086', '21778', '36441', '34146', '69', '33433', - '26979', '18135', '13347', '17679', '38426', '3454', '10432', '31011', '12162', - '13063', '7812', '29661', '24482', '4970', '4405', '17405', '27781', '26724', '5993', - '16442', '30148', '15895', '6855', '29903', '18107', '29504', '11106', '32964', - '29891', '32104', '15712', '24287', '4997', '32581', '21020', '17247', '31455', - '13245', '15850', '10011', '10313', '10158', '1817', '33479', '35842', '14198', - '26039', '3791', '4909', '37056', '7144', '8185', '2131', '4398', '38199', '29520', - '37329', '21388', '28659', '15044', '28510', '12903', '11794', '37095', '32229', - '22918', '31680', '15024', '24607', '26930'] + self.train_pids = [ + '14229', '3409', '29980', '799', '5086', '21778', '36441', '34146', '69', '33433', '26979', '18135', + '13347', '17679', '38426', '3454', '10432', '31011', '12162', '13063', '7812', '29661', '24482', + '4970', '4405', '17405', '27781', '26724', '5993', '16442', '30148', '15895', '6855', '29903', + '18107', '29504', '11106', '32964', '29891', '32104', '15712', '24287', '4997', '32581', '21020', + '17247', '31455', '13245', '15850', '10011', '10313', '10158', '1817', '33479', '35842', '14198', + '26039', '3791', '4909', '37056', '7144', '8185', '2131', '4398', '38199', '29520', '37329', + '21388', '28659', '15044', '28510', '12903', '11794', '37095', '32229', '22918', '31680', '15024', + '24607', '26930' + ] model_io_path = 'dizoo/tabmwp/data/model_in_out_train.txt' if not os.path.exists(model_io_path): - os.system(f'wget https://opendilab.net/download/DI-zoo/tabmwp/model_in_out_train.txt -O ' - + model_io_path + ' --no-check-certificate') + os.system( + f'wget https://opendilab.net/download/DI-zoo/tabmwp/model_in_out_train.txt -O ' + + model_io_path + ' --no-check-certificate' + ) else: - self.train_pids = ['21037', '22976', '2224', '14145', '27962', '26553', '22110', '16541', '26044', - '19492', '31882', '11991', '27594', '7637', '15394', '7666', '5177', '33761', - '13703', '29105'] + self.train_pids = [ + '21037', '22976', '2224', '14145', '27962', '26553', '22110', '16541', '26044', '19492', '31882', + '11991', '27594', '7637', '15394', '7666', '5177', '33761', '13703', '29105' + ] model_io_path = 'dizoo/tabmwp/data/model_in_out_eval.txt' - os.system(f'wget https://opendilab.net/download/DI-zoo/tabmwp/model_in_out_eval.txt -O ' - + model_io_path + ' --no-check-certificate') + os.system( + f'wget https://opendilab.net/download/DI-zoo/tabmwp/model_in_out_eval.txt -O ' + model_io_path + + ' --no-check-certificate' + ) self.cfg.cand_number = len(self.cand_pids) self.cfg.train_number = len(self.train_pids) @@ -135,8 +145,19 @@ def search_answer(self, pid, pids): raise ValueError('item does not exists.') def parse_all_answers(self): - self.cand_pids = ['32889', '8044', '16892', '5408', '4051', '37355', '17962', '25807', '30602', '5514', '19270', '23713', '17209', '33379', '34987', '11177', '30218', '26066', '24169', '28492'] - self.train_pids = ['14229', '3409', '29980', '799', '5086', '21778', '36441', '34146', '69', '33433', '26979', '18135', '13347', '17679', '38426', '3454', '10432', '31011', '12162', '13063', '7812', '29661', '24482', '4970', '4405', '17405', '27781', '26724', '5993', '16442', '30148', '15895', '6855', '29903', '18107', '29504', '11106', '32964', '29891', '32104', '15712', '24287', '4997', '32581', '21020', '17247', '31455', '13245', '15850', '10011', '10313', '10158', '1817', '33479', '35842', '14198', '26039', '3791', '4909', '37056', '7144', '8185', '2131', '4398', '38199', '29520', '37329', '21388', '28659', '15044', '28510', '12903', '11794', '37095', '32229', '22918', '31680', '15024', '24607', '26930'] + self.cand_pids = [ + '32889', '8044', '16892', '5408', '4051', '37355', '17962', '25807', '30602', '5514', '19270', '23713', + '17209', '33379', '34987', '11177', '30218', '26066', '24169', '28492' + ] + self.train_pids = [ + '14229', '3409', '29980', '799', '5086', '21778', '36441', '34146', '69', '33433', '26979', '18135', + '13347', '17679', '38426', '3454', '10432', '31011', '12162', '13063', '7812', '29661', '24482', '4970', + '4405', '17405', '27781', '26724', '5993', '16442', '30148', '15895', '6855', '29903', '18107', '29504', + '11106', '32964', '29891', '32104', '15712', '24287', '4997', '32581', '21020', '17247', '31455', '13245', + '15850', '10011', '10313', '10158', '1817', '33479', '35842', '14198', '26039', '3791', '4909', '37056', + '7144', '8185', '2131', '4398', '38199', '29520', '37329', '21388', '28659', '15044', '28510', '12903', + '11794', '37095', '32229', '22918', '31680', '15024', '24607', '26930' + ] self.problem_id = 0 self.cfg.train_number = len(self.train_pids) n = len(self.cand_pids) @@ -221,24 +242,25 @@ def __repr__(self) -> str: if __name__ == '__main__': from easydict import EasyDict - env_cfg = EasyDict(dict( - cand_number=16, - train_number=20, - engine='text-davinci-002', - temperature=0., - max_tokens=512, - top_p=1., - frequency_penalty=0., - presence_penalty=0., - option_inds=["A", "B", "C", "D", "E", "F"], - api_key='xxx', - prompt_format='TQ-A', - enable_replay=True, - seed=0, - )) + env_cfg = EasyDict( + dict( + cand_number=16, + train_number=20, + engine='text-davinci-002', + temperature=0., + max_tokens=512, + top_p=1., + frequency_penalty=0., + presence_penalty=0., + option_inds=["A", "B", "C", "D", "E", "F"], + api_key='xxx', + prompt_format='TQ-A', + enable_replay=True, + seed=0, + ) + ) env = TabMWP(env_cfg) env.seed(0) env.reset() env.parse_all_answers() env.search_answer('22976', ['32889', '8044']) - diff --git a/dizoo/tabmwp/envs/utils.py b/dizoo/tabmwp/envs/utils.py index f1f74a3f0c..c97c183935 100644 --- a/dizoo/tabmwp/envs/utils.py +++ b/dizoo/tabmwp/envs/utils.py @@ -31,7 +31,12 @@ def sample_logits(out: torch.Tensor, temperature: float = 1.0, top_p: float = 0. return out -def calc_rwkv(model: transformers.RwkvForCausalLM, tokenizer: transformers.AutoTokenizer, prompt: str, max_len: int = 10) -> str: +def calc_rwkv( + model: transformers.RwkvForCausalLM, + tokenizer: transformers.AutoTokenizer, + prompt: str, + max_len: int = 10 +) -> str: # Use RWKV to generate sentence. orig_len = len(prompt) inputs = tokenizer(prompt, return_tensors="pt").to('cuda') @@ -53,8 +58,13 @@ def calc_internlm(model, tokenizer, prompt: str, args): inputs = tokenizer(prompt, return_tensors="pt") for k, v in inputs.items(): inputs[k] = v.cuda() - gen_kwargs = {"max_length": args.max_tokens, "top_p": args.top_p, "temperature": args.temperature, "do_sample": True, - "repetition_penalty": args.frequency_penalty} + gen_kwargs = { + "max_length": args.max_tokens, + "top_p": args.top_p, + "temperature": args.temperature, + "do_sample": True, + "repetition_penalty": args.frequency_penalty + } output = model.generate(**inputs, **gen_kwargs) output = tokenizer.decode(output) return output @@ -69,8 +79,10 @@ def load_data(args: dict) -> tuple: os.mkdir(data_root) if not os.path.exists(os.path.join(data_root, f'problems_train.json')): - os.system(f'wget https://opendilab.net/download/DI-zoo/tabmwp/problems_train.json -O ' - + os.path.join(data_root, f'problems_train.json') + ' --no-check-certificate') + os.system( + f'wget https://opendilab.net/download/DI-zoo/tabmwp/problems_train.json -O ' + + os.path.join(data_root, f'problems_train.json') + ' --no-check-certificate' + ) problems = json.load(open(os.path.join(data_root, f'problems_train.json'))) pids = list(problems.keys()) @@ -81,24 +93,30 @@ def load_data(args: dict) -> tuple: def get_gpt3_output(prompt: str, args: dict) -> str: - return call_gpt3(args.engine, prompt, args.temperature, args.max_tokens, args.top_p, args.frequency_penalty, - args.presence_penalty) + return call_gpt3( + args.engine, prompt, args.temperature, args.max_tokens, args.top_p, args.frequency_penalty, + args.presence_penalty + ) @lru_cache(maxsize=10000) -def call_gpt3(engine: str, prompt: str, temperature: float, max_tokens: int, top_p: float, - frequency_penalty: float, presence_penalty: float) -> str: +def call_gpt3( + engine: str, prompt: str, temperature: float, max_tokens: int, top_p: float, frequency_penalty: float, + presence_penalty: float +) -> str: patience = 100 while True: try: - response = openai.Completion.create(engine=engine, - prompt=prompt, - temperature=temperature, - max_tokens=max_tokens, - top_p=top_p, - frequency_penalty=frequency_penalty, - presence_penalty=presence_penalty, - stop=["\n"]) + response = openai.Completion.create( + engine=engine, + prompt=prompt, + temperature=temperature, + max_tokens=max_tokens, + top_p=top_p, + frequency_penalty=frequency_penalty, + presence_penalty=presence_penalty, + stop=["\n"] + ) output = response["choices"][0]["text"].strip() break except Exception: @@ -146,8 +164,9 @@ def get_solution_text(problem: dict) -> str: return solution -def create_one_example(format: str, table: str, question: str, answer: str, - solution: str, test_example:bool = True) -> str: +def create_one_example( + format: str, table: str, question: str, answer: str, solution: str, test_example: bool = True +) -> str: # Using template to generate one prompt example. input_format, output_format = format.split("-") # e.g., "TQ-A" diff --git a/setup.py b/setup.py index 823165d6e7..6b349a99aa 100644 --- a/setup.py +++ b/setup.py @@ -107,7 +107,7 @@ 'numpy-stl', 'numba>=0.53.0', ], - 'video':[ + 'video': [ 'moviepy', 'imageio[ffmpeg]', ],