From 88fd50f804763899f59e469ad5be8a0d98b8853a Mon Sep 17 00:00:00 2001 From: zjowowen Date: Thu, 28 Mar 2024 21:15:08 +0800 Subject: [PATCH 1/6] fix complex obs demo for ppo pipeline --- ding/example/ppo_with_complex_obs.py | 7 ++++++- .../framework/middleware/functional/advantage_estimator.py | 3 +++ ding/utils/__init__.py | 1 + 3 files changed, 10 insertions(+), 1 deletion(-) diff --git a/ding/example/ppo_with_complex_obs.py b/ding/example/ppo_with_complex_obs.py index a05875ba29..2563532a37 100644 --- a/ding/example/ppo_with_complex_obs.py +++ b/ding/example/ppo_with_complex_obs.py @@ -30,7 +30,12 @@ cuda=True, action_space='discrete', model=dict( - obs_shape=None, + obs_shape=dict( + key_0=dict(k1=(), k2=()), + key_1=(5, 10), + key_2=(10, 10, 3), + key_3=(2, ), + ), action_shape=2, action_space='discrete', critic_head_hidden_size=138, diff --git a/ding/framework/middleware/functional/advantage_estimator.py b/ding/framework/middleware/functional/advantage_estimator.py index cb80089fe2..24e7f6a338 100644 --- a/ding/framework/middleware/functional/advantage_estimator.py +++ b/ding/framework/middleware/functional/advantage_estimator.py @@ -8,6 +8,7 @@ from ding.rl_utils import gae, gae_data, get_train_sample from ding.framework import task from ding.utils.data import ttorch_collate +from ding.utils.dict_helper import convert_easy_dict_to_dict from ding.torch_utils import to_device if TYPE_CHECKING: @@ -33,9 +34,11 @@ def gae_estimator(cfg: EasyDict, policy: Policy, buffer_: Optional[Buffer] = Non # Unify the shape of obs and action obs_shape = cfg['policy']['model']['obs_shape'] obs_shape = torch.Size(torch.tensor(obs_shape)) if isinstance(obs_shape, list) \ + else ttorch.size.Size(convert_easy_dict_to_dict(obs_shape)) if isinstance(obs_shape, dict) \ else torch.Size(torch.tensor(obs_shape).unsqueeze(0)) action_shape = cfg['policy']['model']['action_shape'] action_shape = torch.Size(torch.tensor(action_shape)) if isinstance(action_shape, list) \ + else ttorch.size.Size(action_shape) if isinstance(action_shape, dict) \ else torch.Size(torch.tensor(action_shape).unsqueeze(0)) def _gae(ctx: "OnlineRLContext"): diff --git a/ding/utils/__init__.py b/ding/utils/__init__.py index 6c2cb1c326..36d60b1c29 100644 --- a/ding/utils/__init__.py +++ b/ding/utils/__init__.py @@ -5,6 +5,7 @@ LimitedSpaceContainer, deep_merge_dicts, set_pkg_seed, flatten_dict, one_time_warning, split_data_generator, \ RunningMeanStd, make_key_as_identifier, remove_illegal_item from .design_helper import SingletonMetaclass +from .dict_helper import convert_easy_dict_to_dict from .file_helper import read_file, save_file, remove_file from .import_helper import try_import_ceph, try_import_mc, try_import_link, import_module, try_import_redis, \ try_import_rediscluster From 9f16bf175ef98d9f6dd6105f2f85c8c273e51fe7 Mon Sep 17 00:00:00 2001 From: zjowowen Date: Tue, 2 Apr 2024 18:57:59 +0800 Subject: [PATCH 2/6] polish code --- .../middleware/functional/advantage_estimator.py | 4 ---- ding/utils/dict_helper.py | 12 ++++++++++++ 2 files changed, 12 insertions(+), 4 deletions(-) create mode 100644 ding/utils/dict_helper.py diff --git a/ding/framework/middleware/functional/advantage_estimator.py b/ding/framework/middleware/functional/advantage_estimator.py index 24e7f6a338..c365f4d7a3 100644 --- a/ding/framework/middleware/functional/advantage_estimator.py +++ b/ding/framework/middleware/functional/advantage_estimator.py @@ -36,10 +36,6 @@ def gae_estimator(cfg: EasyDict, policy: Policy, buffer_: Optional[Buffer] = Non obs_shape = torch.Size(torch.tensor(obs_shape)) if isinstance(obs_shape, list) \ else ttorch.size.Size(convert_easy_dict_to_dict(obs_shape)) if isinstance(obs_shape, dict) \ else torch.Size(torch.tensor(obs_shape).unsqueeze(0)) - action_shape = cfg['policy']['model']['action_shape'] - action_shape = torch.Size(torch.tensor(action_shape)) if isinstance(action_shape, list) \ - else ttorch.size.Size(action_shape) if isinstance(action_shape, dict) \ - else torch.Size(torch.tensor(action_shape).unsqueeze(0)) def _gae(ctx: "OnlineRLContext"): """ diff --git a/ding/utils/dict_helper.py b/ding/utils/dict_helper.py new file mode 100644 index 0000000000..876dacffbb --- /dev/null +++ b/ding/utils/dict_helper.py @@ -0,0 +1,12 @@ +from easydict import EasyDict + +def convert_easy_dict_to_dict(easy_dict: EasyDict) -> dict: + """ + Overview: + Convert an EasyDict object to a dict object recursively. + Arguments: + - easy_dict (:obj:`EasyDict`): The EasyDict object to be converted. + Returns: + - dict: The converted dict object. + """ + return {k: convert_easy_dict_to_dict(v) if isinstance(v, EasyDict) else v for k, v in easy_dict.items()} From d7ccc9f734ac581121c98b3a64b0dc13f7286ddd Mon Sep 17 00:00:00 2001 From: zjowowen Date: Tue, 2 Apr 2024 19:30:36 +0800 Subject: [PATCH 3/6] polish code --- ding/utils/dict_helper.py | 1 + 1 file changed, 1 insertion(+) diff --git a/ding/utils/dict_helper.py b/ding/utils/dict_helper.py index 876dacffbb..e6a61799e6 100644 --- a/ding/utils/dict_helper.py +++ b/ding/utils/dict_helper.py @@ -1,5 +1,6 @@ from easydict import EasyDict + def convert_easy_dict_to_dict(easy_dict: EasyDict) -> dict: """ Overview: From cdd5ff4d66d9617d5bb6ce30c12a3b22ad4ab408 Mon Sep 17 00:00:00 2001 From: zjowowen Date: Tue, 2 Apr 2024 20:27:44 +0800 Subject: [PATCH 4/6] prevent wandb version error --- .github/workflows/unit_test.yml | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/.github/workflows/unit_test.yml b/.github/workflows/unit_test.yml index 771dc596e3..3f7c255bbd 100644 --- a/.github/workflows/unit_test.yml +++ b/.github/workflows/unit_test.yml @@ -26,6 +26,11 @@ jobs: python -m pip install . python -m pip install ".[test,k8s]" python -m pip install transformers + if python --version | grep -q "Python 3.7"; then + pip install wandb==0.16.4 + else + echo "Python version is not 3.7, skipping wandb installation" + fi ./ding/scripts/install-k8s-tools.sh make unittest - name: Upload coverage to Codecov From 78a27395dad445bc1061abac4be5719990d9296c Mon Sep 17 00:00:00 2001 From: zjowowen Date: Tue, 2 Apr 2024 20:35:25 +0800 Subject: [PATCH 5/6] polish code --- .github/workflows/unit_test.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/unit_test.yml b/.github/workflows/unit_test.yml index 3f7c255bbd..63d07c3e9b 100644 --- a/.github/workflows/unit_test.yml +++ b/.github/workflows/unit_test.yml @@ -27,7 +27,7 @@ jobs: python -m pip install ".[test,k8s]" python -m pip install transformers if python --version | grep -q "Python 3.7"; then - pip install wandb==0.16.4 + python -m pip install wandb==0.16.4 else echo "Python version is not 3.7, skipping wandb installation" fi From 9cc2dc8873154751cd4e78a56260f62743743b97 Mon Sep 17 00:00:00 2001 From: zjowowen Date: Tue, 2 Apr 2024 20:36:25 +0800 Subject: [PATCH 6/6] polish code --- .github/workflows/unit_test.yml | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/.github/workflows/unit_test.yml b/.github/workflows/unit_test.yml index 63d07c3e9b..ef9038610c 100644 --- a/.github/workflows/unit_test.yml +++ b/.github/workflows/unit_test.yml @@ -60,5 +60,10 @@ jobs: python -m pip install . python -m pip install ".[test,k8s]" python -m pip install transformers + if python --version | grep -q "Python 3.7"; then + python -m pip install wandb==0.16.4 + else + echo "Python version is not 3.7, skipping wandb installation" + fi ./ding/scripts/install-k8s-tools.sh make benchmark