Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(zjow): fix complex obs demo for ppo pipeline #786

Merged
merged 7 commits into from
Apr 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions .github/workflows/unit_test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,11 @@ jobs:
python -m pip install .
python -m pip install ".[test,k8s]"
python -m pip install transformers
if python --version | grep -q "Python 3.7"; then
python -m pip install wandb==0.16.4
else
echo "Python version is not 3.7, skipping wandb installation"
fi
./ding/scripts/install-k8s-tools.sh
make unittest
- name: Upload coverage to Codecov
Expand Down Expand Up @@ -55,5 +60,10 @@ jobs:
python -m pip install .
python -m pip install ".[test,k8s]"
python -m pip install transformers
if python --version | grep -q "Python 3.7"; then
python -m pip install wandb==0.16.4
else
echo "Python version is not 3.7, skipping wandb installation"
fi
./ding/scripts/install-k8s-tools.sh
make benchmark
7 changes: 6 additions & 1 deletion ding/example/ppo_with_complex_obs.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,12 @@
cuda=True,
action_space='discrete',
model=dict(
obs_shape=None,
obs_shape=dict(
key_0=dict(k1=(), k2=()),
key_1=(5, 10),
key_2=(10, 10, 3),
key_3=(2, ),
),
action_shape=2,
action_space='discrete',
critic_head_hidden_size=138,
Expand Down
5 changes: 2 additions & 3 deletions ding/framework/middleware/functional/advantage_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from ding.rl_utils import gae, gae_data, get_train_sample
from ding.framework import task
from ding.utils.data import ttorch_collate
from ding.utils.dict_helper import convert_easy_dict_to_dict
from ding.torch_utils import to_device

if TYPE_CHECKING:
Expand All @@ -33,10 +34,8 @@ def gae_estimator(cfg: EasyDict, policy: Policy, buffer_: Optional[Buffer] = Non
# Unify the shape of obs and action
obs_shape = cfg['policy']['model']['obs_shape']
obs_shape = torch.Size(torch.tensor(obs_shape)) if isinstance(obs_shape, list) \
else ttorch.size.Size(convert_easy_dict_to_dict(obs_shape)) if isinstance(obs_shape, dict) \
else torch.Size(torch.tensor(obs_shape).unsqueeze(0))
action_shape = cfg['policy']['model']['action_shape']
action_shape = torch.Size(torch.tensor(action_shape)) if isinstance(action_shape, list) \
else torch.Size(torch.tensor(action_shape).unsqueeze(0))

def _gae(ctx: "OnlineRLContext"):
"""
Expand Down
1 change: 1 addition & 0 deletions ding/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
LimitedSpaceContainer, deep_merge_dicts, set_pkg_seed, flatten_dict, one_time_warning, split_data_generator, \
RunningMeanStd, make_key_as_identifier, remove_illegal_item
from .design_helper import SingletonMetaclass
from .dict_helper import convert_easy_dict_to_dict
from .file_helper import read_file, save_file, remove_file
from .import_helper import try_import_ceph, try_import_mc, try_import_link, import_module, try_import_redis, \
try_import_rediscluster
Expand Down
13 changes: 13 additions & 0 deletions ding/utils/dict_helper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
from easydict import EasyDict


def convert_easy_dict_to_dict(easy_dict: EasyDict) -> dict:
    """
    Overview:
        Convert an EasyDict object to a plain dict object recursively. Nested EasyDict objects are converted
        both when they appear directly as values and when they appear inside plain ``list`` / ``tuple``
        values, since EasyDict also wraps dicts stored in those containers.
    Arguments:
        - easy_dict (:obj:`EasyDict`): The EasyDict object to be converted.
    Returns:
        - dict: The converted dict object. Values that are neither EasyDict nor a plain ``list`` / ``tuple``
            are returned as-is (not copied).
    """

    def _convert(value):
        # Nested EasyDict: rebuild as a plain dict, converting its values in turn.
        if isinstance(value, EasyDict):
            return {k: _convert(v) for k, v in value.items()}
        # Exact type check on purpose: subclasses of list/tuple (e.g. namedtuple, torch.Size) have their own
        # constructors and semantics, so they are left untouched rather than rebuilt from a generator.
        if type(value) in (list, tuple):
            return type(value)(_convert(item) for item in value)
        return value

    return {k: _convert(v) for k, v in easy_dict.items()}
Loading