Skip to content

Commit

Permalink
polish(lxy): add wandb anonymous mode parameter (#528)
Browse files Browse the repository at this point in the history
* add wandb anonymous mode parameter

* polish style

* polish style

* polish unittest for wandb

* polish unittest for wandb

* polish unittest for wandb

* change anonymous mode to must in logger
  • Loading branch information
karroyan authored Oct 31, 2022
1 parent ae74916 commit f5f219b
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 12 deletions.
26 changes: 22 additions & 4 deletions ding/framework/middleware/functional/logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,9 @@ def _logger(ctx: "OfflineRLContext"):
return _logger


def wandb_online_logger(cfg: EasyDict, env: BaseEnvManagerV2, model: torch.nn.Module) -> Callable:
def wandb_online_logger(
cfg: EasyDict, env: BaseEnvManagerV2, model: torch.nn.Module, anonymous: bool = False
) -> Callable:
'''
Overview:
Wandb visualizer to track the experiment.
Expand All @@ -116,13 +118,18 @@ def wandb_online_logger(cfg: EasyDict, env: BaseEnvManagerV2, model: torch.nn.Mo
- action_logger: `q_value` or `action probability`.
- env (:obj:`BaseEnvManagerV2`): Evaluator environment.
- model (:obj:`nn.Module`): Model.
    - anonymous (:obj:`bool`): Whether to enable wandb's anonymous mode. \
        Anonymous mode allows experiment data to be visualized without a wandb account.
'''

color_list = ["orange", "red", "blue", "purple", "green", "darkcyan"]
metric_list = ["q_value", "target q_value", "loss", "lr", "entropy"]
# Initialize wandb with default settings
# Settings can be covered by calling wandb.init() at the top of the script
wandb.init()
if anonymous:
wandb.init(anonymous="must")
else:
wandb.init()
# The visualizer is called to save the replay of the simulation
# which will be uploaded to wandb later
env.enable_save_replay(replay_path=cfg.record_path)
Expand Down Expand Up @@ -192,7 +199,13 @@ def _plot(ctx: "OnlineRLContext"):
return _plot


def wandb_offline_logger(cfg: EasyDict, env: BaseEnvManagerV2, model: torch.nn.Module, datasetpath: str) -> Callable:
def wandb_offline_logger(
cfg: EasyDict,
env: BaseEnvManagerV2,
model: torch.nn.Module,
datasetpath: str,
anonymous: bool = False
) -> Callable:
'''
Overview:
Wandb visualizer to track the experiment.
Expand All @@ -205,13 +218,18 @@ def wandb_offline_logger(cfg: EasyDict, env: BaseEnvManagerV2, model: torch.nn.M
- env (:obj:`BaseEnvManagerV2`): Evaluator environment.
- model (:obj:`nn.Module`): Model.
- datasetpath (:obj:`str`): The path of offline dataset.
    - anonymous (:obj:`bool`): Whether to enable wandb's anonymous mode. \
        Anonymous mode allows experiment data to be visualized without a wandb account.
'''

color_list = ["orange", "red", "blue", "purple", "green", "darkcyan"]
metric_list = ["q_value", "target q_value", "loss", "lr", "entropy", "target_q_value", "td_error"]
# Initialize wandb with default settings
# Settings can be covered by calling wandb.init() at the top of the script
wandb.init()
if anonymous:
wandb.init(anonymous="must")
else:
wandb.init()
# The visualizer is called to save the replay of the simulation
# which will be uploaded to wandb later
env.enable_save_replay(replay_path=cfg.record_path)
Expand Down
16 changes: 8 additions & 8 deletions ding/framework/middleware/tests/test_logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,7 @@ def __getitem__(self, index):
return [[1]] * 50


@pytest.mark.other # due to no api key in github now
@pytest.mark.unittest
def test_wandb_online_logger():

cfg = EasyDict(
Expand All @@ -197,7 +197,7 @@ def test_wandb_online_logger():
ctx = OnlineRLContext()
ctx.train_output = [{'reward': 1, 'q_value': [1.0]}]
model = TheModelClass()
wandb.init(config=cfg)
wandb.init(config=cfg, anonymous="must")

def mock_metric_logger(metric_dict):
metric_list = [
Expand All @@ -211,17 +211,17 @@ def mock_gradient_logger(input_model):

def test_wandb_online_logger_metric():
with patch.object(wandb, 'log', new=mock_metric_logger):
wandb_online_logger(cfg, env, model)(ctx)
wandb_online_logger(cfg, env, model, anonymous=True)(ctx)

def test_wandb_online_logger_gradient():
with patch.object(wandb, 'watch', new=mock_gradient_logger):
wandb_online_logger(cfg, env, model)(ctx)
wandb_online_logger(cfg, env, model, anonymous=True)(ctx)

test_wandb_online_logger_metric()
test_wandb_online_logger_gradient()


@pytest.mark.other # due to no api key in github now
@pytest.mark.unittest
def test_wandb_offline_logger(mocker):

cfg = EasyDict(
Expand All @@ -237,7 +237,7 @@ def test_wandb_offline_logger(mocker):
ctx = OnlineRLContext()
ctx.train_output = [{'reward': 1, 'q_value': [1.0]}]
model = TheModelClass()
wandb.init(config=cfg)
wandb.init(config=cfg, anonymous="must")

def mock_metric_logger(metric_dict):
metric_list = [
Expand All @@ -255,7 +255,7 @@ def mock_image_logger(imagepath):
def test_wandb_offline_logger_gradient():
cfg.vis_dataset = False
with patch.object(wandb, 'watch', new=mock_gradient_logger):
wandb_offline_logger(cfg, env, model, 'dataset.h5')(ctx)
wandb_offline_logger(cfg, env, model, 'dataset.h5', anonymous=True)(ctx)

def test_wandb_offline_logger_dataset():
cfg.vis_dataset = True
Expand All @@ -264,7 +264,7 @@ def test_wandb_offline_logger_dataset():
with patch.object(wandb, 'log', new=mock_metric_logger):
with patch.object(wandb, 'Image', new=mock_image_logger):
mocker.patch('h5py.File', return_value=m)
wandb_offline_logger(cfg, env, model, 'dataset.h5')(ctx)
wandb_offline_logger(cfg, env, model, 'dataset.h5', anonymous=True)(ctx)

test_wandb_offline_logger_gradient()
test_wandb_offline_logger_dataset()

0 comments on commit f5f219b

Please sign in to comment.