From f5f219b8d8d7ae0569dac40e2b75c693e7c0a91d Mon Sep 17 00:00:00 2001 From: karroyan Date: Mon, 31 Oct 2022 11:44:37 +0800 Subject: [PATCH] polish(lxy): add wandb anonymous mode parameter (#528) * add wandb anonymous mode parameter * polish style * polish style * polish unittest for wandb * polish unittest for wandb * polish unittest for wandb * change anonymous mode to must in logger --- .../framework/middleware/functional/logger.py | 26 ++++++++++++++++--- .../framework/middleware/tests/test_logger.py | 16 ++++++------ 2 files changed, 30 insertions(+), 12 deletions(-) diff --git a/ding/framework/middleware/functional/logger.py b/ding/framework/middleware/functional/logger.py index b0c9f02866..d13b4a9c2f 100644 --- a/ding/framework/middleware/functional/logger.py +++ b/ding/framework/middleware/functional/logger.py @@ -104,7 +104,9 @@ def _logger(ctx: "OfflineRLContext"): return _logger -def wandb_online_logger(cfg: EasyDict, env: BaseEnvManagerV2, model: torch.nn.Module) -> Callable: +def wandb_online_logger( + cfg: EasyDict, env: BaseEnvManagerV2, model: torch.nn.Module, anonymous: bool = False +) -> Callable: ''' Overview: Wandb visualizer to track the experiment. @@ -116,13 +118,18 @@ def wandb_online_logger(cfg: EasyDict, env: BaseEnvManagerV2, model: torch.nn.Mo - action_logger: `q_value` or `action probability`. - env (:obj:`BaseEnvManagerV2`): Evaluator environment. - model (:obj:`nn.Module`): Model. + - anonymous (:obj:`bool`): Open the anonymous mode of wandb or not. + The anonymous mode allows visualization of data without a wandb account.
''' color_list = ["orange", "red", "blue", "purple", "green", "darkcyan"] metric_list = ["q_value", "target q_value", "loss", "lr", "entropy"] # Initialize wandb with default settings # Settings can be covered by calling wandb.init() at the top of the script - wandb.init() + if anonymous: + wandb.init(anonymous="must") + else: + wandb.init() # The visualizer is called to save the replay of the simulation # which will be uploaded to wandb later env.enable_save_replay(replay_path=cfg.record_path) @@ -192,7 +199,13 @@ def _plot(ctx: "OnlineRLContext"): return _plot -def wandb_offline_logger(cfg: EasyDict, env: BaseEnvManagerV2, model: torch.nn.Module, datasetpath: str) -> Callable: +def wandb_offline_logger( + cfg: EasyDict, + env: BaseEnvManagerV2, + model: torch.nn.Module, + datasetpath: str, + anonymous: bool = False +) -> Callable: ''' Overview: Wandb visualizer to track the experiment. @@ -205,13 +218,18 @@ def wandb_offline_logger(cfg: EasyDict, env: BaseEnvManagerV2, model: torch.nn.M - env (:obj:`BaseEnvManagerV2`): Evaluator environment. - model (:obj:`nn.Module`): Model. - datasetpath (:obj:`str`): The path of offline dataset. + - anonymous (:obj:`bool`): Open the anonymous mode of wandb or not. + The anonymous mode allows visualization of data without a wandb account.
''' color_list = ["orange", "red", "blue", "purple", "green", "darkcyan"] metric_list = ["q_value", "target q_value", "loss", "lr", "entropy", "target_q_value", "td_error"] # Initialize wandb with default settings # Settings can be covered by calling wandb.init() at the top of the script - wandb.init() + if anonymous: + wandb.init(anonymous="must") + else: + wandb.init() # The visualizer is called to save the replay of the simulation # which will be uploaded to wandb later env.enable_save_replay(replay_path=cfg.record_path) diff --git a/ding/framework/middleware/tests/test_logger.py b/ding/framework/middleware/tests/test_logger.py index bc70f43e8b..cbeab92889 100644 --- a/ding/framework/middleware/tests/test_logger.py +++ b/ding/framework/middleware/tests/test_logger.py @@ -185,7 +185,7 @@ def __getitem__(self, index): return [[1]] * 50 -@pytest.mark.other # due to no api key in github now +@pytest.mark.unittest def test_wandb_online_logger(): cfg = EasyDict( @@ -197,7 +197,7 @@ def test_wandb_online_logger(): ctx = OnlineRLContext() ctx.train_output = [{'reward': 1, 'q_value': [1.0]}] model = TheModelClass() - wandb.init(config=cfg) + wandb.init(config=cfg, anonymous="must") def mock_metric_logger(metric_dict): metric_list = [ @@ -211,17 +211,17 @@ def mock_gradient_logger(input_model): def test_wandb_online_logger_metric(): with patch.object(wandb, 'log', new=mock_metric_logger): - wandb_online_logger(cfg, env, model)(ctx) + wandb_online_logger(cfg, env, model, anonymous=True)(ctx) def test_wandb_online_logger_gradient(): with patch.object(wandb, 'watch', new=mock_gradient_logger): - wandb_online_logger(cfg, env, model)(ctx) + wandb_online_logger(cfg, env, model, anonymous=True)(ctx) test_wandb_online_logger_metric() test_wandb_online_logger_gradient() -@pytest.mark.other # due to no api key in github now +@pytest.mark.unittest def test_wandb_offline_logger(mocker): cfg = EasyDict( @@ -237,7 +237,7 @@ def test_wandb_offline_logger(mocker): ctx = OnlineRLContext() 
ctx.train_output = [{'reward': 1, 'q_value': [1.0]}] model = TheModelClass() - wandb.init(config=cfg) + wandb.init(config=cfg, anonymous="must") def mock_metric_logger(metric_dict): metric_list = [ @@ -255,7 +255,7 @@ def mock_image_logger(imagepath): def test_wandb_offline_logger_gradient(): cfg.vis_dataset = False with patch.object(wandb, 'watch', new=mock_gradient_logger): - wandb_offline_logger(cfg, env, model, 'dataset.h5')(ctx) + wandb_offline_logger(cfg, env, model, 'dataset.h5', anonymous=True)(ctx) def test_wandb_offline_logger_dataset(): cfg.vis_dataset = True @@ -264,7 +264,7 @@ def test_wandb_offline_logger_dataset(): with patch.object(wandb, 'log', new=mock_metric_logger): with patch.object(wandb, 'Image', new=mock_image_logger): mocker.patch('h5py.File', return_value=m) - wandb_offline_logger(cfg, env, model, 'dataset.h5')(ctx) + wandb_offline_logger(cfg, env, model, 'dataset.h5', anonymous=True)(ctx) test_wandb_offline_logger_gradient() test_wandb_offline_logger_dataset()