Skip to content

Commit

Permalink
fix(nyz): fix logger assertion and unittest bug
Browse files Browse the repository at this point in the history
  • Loading branch information
PaParaZz1 committed Mar 21, 2023
1 parent 503d273 commit 405191d
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 3 deletions.
3 changes: 2 additions & 1 deletion ding/framework/middleware/functional/logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,8 @@ def wandb_online_logger(
else:
if not isinstance(cfg, EasyDict):
cfg = EasyDict(cfg)
assert tuple(cfg.keys()) == ("gradient_logger", "plot_logger", "video_logger", "action_logger", "return_logger")
assert set(cfg.keys()
) == set(["gradient_logger", "plot_logger", "video_logger", "action_logger", "return_logger"])
assert all(value in [True, False] for value in cfg.values())

# The visualizer is called to save the replay of the simulation
Expand Down
4 changes: 2 additions & 2 deletions ding/framework/middleware/tests/test_logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,7 @@ def test_wandb_online_logger():
dict(
gradient_logger=True,
plot_logger=True,
action_logger='action probability',
action_logger=True,
return_logger=True,
video_logger=True,
)
Expand Down Expand Up @@ -251,7 +251,7 @@ def test_wandb_online_logger_gradient():
@pytest.mark.tmp
def test_wandb_offline_logger(mocker):
record_path = './video_pendulum_cql'
cfg = EasyDict(dict(gradient_logger=True, plot_logger=True, action_logger='action probability', vis_dataset=True))
cfg = EasyDict(dict(gradient_logger=True, plot_logger=True, action_logger=True, vis_dataset=True))
env = TheEnvClass()
ctx = OnlineRLContext()
ctx.train_output = [{'reward': 1, 'q_value': [1.0]}]
Expand Down

0 comments on commit 405191d

Please sign in to comment.