feature(zjow): add wandb logger features; fix relative bugs for wandb online logger #579

Merged 63 commits on Mar 16, 2023
Changes from 13 commits
Commits
e571f50
td3 fix
zjowowen Nov 4, 2022
a614e3f
Merge branch 'opendilab:main' into benchmark-2
zjowowen Dec 19, 2022
9060c53
Add benchmark config file.
zjowowen Dec 19, 2022
731a2ad
Merge branch 'opendilab:main' into benchmark-2
zjowowen Jan 11, 2023
82a4944
add main
zjowowen Jan 15, 2023
ad616ff
fix
zjowowen Jan 15, 2023
f1aba9c
fix
zjowowen Jan 15, 2023
448daa1
add feature to wandb;fix bugs
zjowowen Feb 10, 2023
1e18f25
merge main
zjowowen Feb 10, 2023
8de9b9e
format code
zjowowen Feb 10, 2023
f36bec8
remove files.
zjowowen Feb 10, 2023
e5ec188
polish code
zjowowen Feb 10, 2023
46f64e6
Merge branch 'main' of https://github.com/zjowowen/DI-engine into ben…
zjowowen Feb 22, 2023
e520359
Merge branch 'main' of https://github.com/zjowowen/DI-engine into ben…
zjowowen Feb 24, 2023
6a9a565
fix td3 policy
zjowowen Feb 24, 2023
0222c04
Add td3
zjowowen Feb 28, 2023
929776b
Add td3 env
zjowowen Feb 28, 2023
4fba3b9
Add td3 env
zjowowen Feb 28, 2023
0257ae9
polish code
zjowowen Feb 28, 2023
cccd585
polish code
zjowowen Feb 28, 2023
d7f272e
polish code
zjowowen Feb 28, 2023
902f9b0
polish code
zjowowen Feb 28, 2023
17ba3a6
polish code
zjowowen Feb 28, 2023
21dcc8b
polish code
zjowowen Feb 28, 2023
bb0df37
polish code
zjowowen Feb 28, 2023
d01558d
polish code
zjowowen Feb 28, 2023
60f47b6
polish code
zjowowen Feb 28, 2023
511d71e
polish code
zjowowen Feb 28, 2023
6a9fd45
polish code
zjowowen Feb 28, 2023
d5573e9
polish code
zjowowen Feb 28, 2023
b7c2011
Merge branch 'main' of https://github.com/zjowowen/DI-engine into ben…
zjowowen Mar 1, 2023
0a167f1
Merge branch 'opendilab:main' into benchmark-3
zjowowen Mar 2, 2023
3906543
fix data type error for mujoco
zjowowen Mar 2, 2023
e665493
polish code
zjowowen Mar 2, 2023
88f5181
polish code
zjowowen Mar 2, 2023
693a4cb
Add features
zjowowen Mar 2, 2023
e6bd0c5
fix base env manager readyimage
zjowowen Mar 3, 2023
cdb9928
polish code
zjowowen Mar 3, 2023
3015a92
remove NoReturn
zjowowen Mar 3, 2023
6e7041b
remove NoReturn
zjowowen Mar 3, 2023
c97a8d4
Merge branch 'main' of https://github.com/zjowowen/DI-engine into ben…
zjowowen Mar 6, 2023
fe415b2
format code
zjowowen Mar 7, 2023
8f808b2
merge from main
zjowowen Mar 7, 2023
3432754
format code
zjowowen Mar 7, 2023
3f6ef3d
polish code
zjowowen Mar 7, 2023
535fd77
polish code
zjowowen Mar 7, 2023
4271610
fix logger
zjowowen Mar 7, 2023
ba0979b
format code
zjowowen Mar 7, 2023
3c19c2c
Merge branch 'main' of https://github.com/zjowowen/DI-engine into ben…
zjowowen Mar 7, 2023
82826e2
format code
zjowowen Mar 7, 2023
da0dd12
Merge branch 'main' of https://github.com/zjowowen/DI-engine into ben…
zjowowen Mar 7, 2023
bb35f90
Merge branch 'main' of https://github.com/zjowowen/DI-engine into ben…
zjowowen Mar 10, 2023
5340658
change api for ckpt; polish code
zjowowen Mar 10, 2023
2d3f6c8
polish code
zjowowen Mar 13, 2023
2e8292c
merge from main
zjowowen Mar 13, 2023
2f883d7
format code
zjowowen Mar 13, 2023
3c15c84
polish code
zjowowen Mar 13, 2023
6ce1421
fix load bug
zjowowen Mar 13, 2023
eac9434
fix bug
zjowowen Mar 13, 2023
6fda31b
fix dtype error
zjowowen Mar 14, 2023
6b9def4
polish code
zjowowen Mar 15, 2023
6f49d0a
polish code
zjowowen Mar 15, 2023
4c69cb0
Polish code
zjowowen Mar 16, 2023
3 changes: 2 additions & 1 deletion ding/framework/context.py
@@ -65,13 +65,14 @@ class OnlineRLContext(Context):
# eval
eval_value: float = -np.inf
last_eval_iter: int = -1
last_eval_value: int = -np.inf
eval_output: List = dataclasses.field(default_factory=dict)

def __post_init__(self):
# This method is called just after __init__ method. Here, concretely speaking,
# this method is called just after the object initialize its fields.
# We use this method here to keep the fields needed for each iteration.
self.keep('env_step', 'env_episode', 'train_iter', 'last_eval_iter')
self.keep('env_step', 'env_episode', 'train_iter', 'last_eval_iter', 'last_eval_value')


@dataclasses.dataclass
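
Note on the change above: last_eval_value is added as a persistent field and registered via keep(), so it survives the per-iteration context renewal while eval_value is reset. Below is a minimal sketch of that keep/renew pattern; MiniContext is a simplified stand-in for illustration, not DI-engine's actual Context implementation.

import dataclasses
import numpy as np


@dataclasses.dataclass
class MiniContext:
    # Simplified stand-in for OnlineRLContext, for illustration only.
    train_iter: int = 0
    eval_value: float = -np.inf          # reset every iteration
    last_eval_value: float = -np.inf     # meant to survive renewals
    _kept_keys: set = dataclasses.field(default_factory=set)

    def keep(self, *keys: str) -> None:
        # Mark which fields must be carried over when the context is renewed.
        self._kept_keys.update(keys)

    def renew(self) -> "MiniContext":
        # Build a fresh context and copy only the kept fields into it.
        new_ctx = MiniContext()
        for key in self._kept_keys:
            setattr(new_ctx, key, getattr(self, key))
        new_ctx.keep(*self._kept_keys)
        return new_ctx


ctx = MiniContext()
ctx.keep('train_iter', 'last_eval_value')
ctx.eval_value = 100.0
ctx.last_eval_value = 100.0
ctx = ctx.renew()
assert ctx.last_eval_value == 100.0   # carried over
assert ctx.eval_value == -np.inf      # reset to its default
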
3 changes: 2 additions & 1 deletion ding/framework/middleware/functional/collector.py
@@ -1,6 +1,7 @@
from typing import TYPE_CHECKING, Callable, List, Tuple, Any
from easydict import EasyDict
from functools import reduce
import numpy as np
Review comment (Member): remove this
Reply (Collaborator Author): fixed

import treetensor.torch as ttorch
from ding.envs import BaseEnvManager
from ding.policy import Policy
@@ -77,7 +78,7 @@ def _inference(ctx: "OnlineRLContext"):

obs = {i: obs[i] for i in range(get_shape0(obs))} # TBD
inference_output = policy.forward(obs, **ctx.collect_kwargs)
ctx.action = [to_ndarray(v['action']) for v in inference_output.values()] # TBD
ctx.action = np.array([to_ndarray(v['action']) for v in inference_output.values()]) # TBD
PaParaZz1 marked this conversation as resolved.
ctx.inference_output = inference_output

return _inference
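
Note on the _inference change: wrapping the per-environment actions in np.array hands the env manager one stacked (num_envs, action_dim) array with a consistent dtype instead of a Python list of per-env arrays. The values below are made up for illustration; only the list-versus-array point comes from the diff.

import numpy as np

# Hypothetical per-environment inference outputs (values are illustrative).
inference_output = {
    0: {'action': np.array([0.1, -0.3])},
    1: {'action': np.array([0.5, 0.2])},
}

# Previous behaviour: a Python list of small ndarrays.
actions_list = [v['action'] for v in inference_output.values()]

# New behaviour: one stacked ndarray, so a vectorized env.step can
# index it per environment and the dtype stays uniform.
actions_batch = np.array(actions_list)
print(actions_batch.shape)   # (2, 2)
print(actions_batch.dtype)   # float64
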
4 changes: 3 additions & 1 deletion ding/framework/middleware/functional/ctx_helper.py
@@ -11,11 +11,13 @@ def final_ctx_saver(name: str) -> Callable:

def _save(ctx: "Context"):
if task.finish:
# make sure the items to be recorded are all kept in the context
with open(os.path.join(name, 'result.pkl'), 'wb') as f:
final_data = {
'total_step': ctx.total_step,
'train_iter': ctx.train_iter,
'eval_value': ctx.eval_value,
'last_eval_iter': ctx.last_eval_iter,
'eval_value': ctx.last_eval_value,
}
if ctx.has_attr('env_step'):
final_data['env_step'] = ctx.env_step
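
Note on the final_ctx_saver change: result.pkl now records last_eval_iter and reports eval_value from ctx.last_eval_value. A sketch of reading the file back; the directory name 'exp_dir' is a placeholder, not part of the PR.

import os
import pickle

# 'exp_dir' stands in for the experiment directory passed to final_ctx_saver.
with open(os.path.join('exp_dir', 'result.pkl'), 'rb') as f:
    final_data = pickle.load(f)

print(final_data['train_iter'], final_data['last_eval_iter'])
print(final_data['eval_value'])      # taken from ctx.last_eval_value
if 'env_step' in final_data:         # only written when the context has env_step
    print(final_data['env_step'])
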
6 changes: 4 additions & 2 deletions ding/framework/middleware/functional/evaluator.py
@@ -257,7 +257,7 @@ def _evaluate(ctx: Union["OnlineRLContext", "OfflineRLContext"]):
eval_monitor.update_video(env.ready_imgs)
eval_monitor.update_output(inference_output)
output = [v for v in inference_output.values()]
action = [to_ndarray(v['action']) for v in output] # TBD
action = np.array([to_ndarray(v['action']) for v in output]) # TBD
Review comment (Member): the same problem
Reply (Collaborator Author): fixed

timesteps = env.step(action)
for timestep in timesteps:
env_id = timestep.env_id.item()
@@ -282,7 +282,8 @@ def _evaluate(ctx: Union["OnlineRLContext", "OfflineRLContext"]):
raise TypeError("not supported ctx type: {}".format(type(ctx)))
ctx.last_eval_iter = ctx.train_iter
ctx.eval_value = episode_return
ctx.eval_output = {'reward': episode_return}
ctx.last_eval_value = ctx.eval_value
ctx.eval_output = {'episode_return': episode_return}
episode_info = eval_monitor.get_episode_info()
if episode_info is not None:
ctx.eval_output['episode_info'] = episode_info
@@ -374,6 +375,7 @@ def _evaluate(ctx: "OnlineRLContext"):
)
ctx.last_eval_iter = ctx.train_iter
ctx.eval_value = episode_return_mean
ctx.last_eval_value = ctx.eval_value
ctx.eval_output = {'episode_return': episode_return}
episode_info = eval_monitor.get_episode_info()
if episode_info is not None:
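
Note on the evaluator change: besides stacking the action array, the evaluation output key is renamed from 'reward' to 'episode_return', and ctx.last_eval_value now mirrors ctx.eval_value after each evaluation. Downstream code reading ctx.eval_output has to switch keys; a small illustrative example with made-up numbers:

# Hypothetical evaluation result, mirroring the new ctx.eval_output layout.
eval_output = {'episode_return': [100.0, 95.0, 102.5]}

# Old access pattern, which now raises KeyError:
# returns = eval_output['reward']

# New access pattern:
returns = eval_output['episode_return']
print(sum(returns) / len(returns))
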
224 changes: 149 additions & 75 deletions ding/framework/middleware/functional/logger.py
@@ -147,7 +147,7 @@ def wandb_online_logger(
return task.void()
PaParaZz1 marked this conversation as resolved.
color_list = ["orange", "red", "blue", "purple", "green", "darkcyan"]
if metric_list is None:
metric_list = ["q_value", "target q_value", "loss", "lr", "entropy"]
metric_list = ["q_value", "target q_value", "loss", "lr", "entropy", "target_q_value", "td_error"]
# Initialize wandb with default settings
# Settings can be covered by calling wandb.init() at the top of the script
if anonymous:
@@ -176,18 +176,32 @@ def wandb_online_logger(
)

def _plot(ctx: "OnlineRLContext"):
info_for_logging = {}

if not cfg.plot_logger:
one_time_warning(
"If you want to use wandb to visualize the result, please set plot_logger = True in the config."
)
return
for metric in metric_list:
if metric in ctx.train_output[0]:
metric_value = np.mean([item[metric] for item in ctx.train_output])
wandb.log({metric: metric_value})
metric_value_list = []
for item in ctx.train_output:
if isinstance(item[metric], torch.Tensor):
metric_value_list.append(item[metric].cpu().detach().numpy())
else:
metric_value_list.append(item[metric])
metric_value = np.mean(metric_value_list)
info_for_logging.update({metric: metric_value})

if ctx.eval_value != -np.inf:
wandb.log({"reward": ctx.eval_value, "train iter": ctx.train_iter, "env step": ctx.env_step})
info_for_logging.update(
{
"episode return mean": ctx.eval_value,
"train iter": ctx.train_iter,
"env step": ctx.env_step
}
)

eval_output = ctx.eval_output['output']
episode_return = ctx.eval_output['episode_return']
@@ -202,26 +216,37 @@ def _plot(ctx: "OnlineRLContext"):
file_list.append(p)
file_list.sort(key=lambda fn: os.path.getmtime(os.path.join(record_path, fn)))
video_path = os.path.join(record_path, file_list[-2])
wandb.log({"video": wandb.Video(video_path, format="mp4")})
info_for_logging.update({"video": wandb.Video(video_path, format="mp4")})

action_path = os.path.join(record_path, (str(ctx.env_step) + "_action.gif"))
return_path = os.path.join(record_path, (str(ctx.env_step) + "_return.gif"))
if cfg.action_logger in ['q_value', 'action probability']:
if isinstance(eval_output, tnp.ndarray):
action_prob = softmax(eval_output.logit)
else:
action_prob = [softmax(to_ndarray(v['logit'])) for v in eval_output]
fig, ax = plt.subplots()
plt.ylim([-1, 1])
action_dim = len(action_prob[1])
x_range = [str(x + 1) for x in range(action_dim)]
ln = ax.bar(x_range, [0 for x in range(action_dim)], color=color_list[:action_dim])
ani = animation.FuncAnimation(
fig, action_prob, fargs=(action_prob, ln), blit=True, save_count=len(action_prob)
)
ani.save(action_path, writer='pillow')
wandb.log({cfg.action_logger: wandb.Video(action_path, format="gif")})
plt.clf()
if cfg.action_logger:
if all(['logit' in v for v in eval_output]) or hasattr(eval_output, "logit"):
if isinstance(eval_output, tnp.ndarray):
action_prob = softmax(eval_output.logit)
else:
action_prob = [softmax(to_ndarray(v['logit'])) for v in eval_output]
fig, ax = plt.subplots()
plt.ylim([-1, 1])
action_dim = len(action_prob[1])
x_range = [str(x + 1) for x in range(action_dim)]
ln = ax.bar(x_range, [0 for x in range(action_dim)], color=color_list[:action_dim])
ani = animation.FuncAnimation(
fig, action_prob, fargs=(action_prob, ln), blit=True, save_count=len(action_prob)
)
ani.save(action_path, writer='pillow')
info_for_logging.update({"action": wandb.Video(action_path, format="gif")})

elif all(['action' in v for v in eval_output[0]]):
for i, action_trajectory in enumerate(eval_output):
fig, ax = plt.subplots()
fig_data = np.array([[i + 1, *v['action']] for i, v in enumerate(action_trajectory)])
steps = fig_data[:, 0]
actions = fig_data[:, 1:]
plt.ylim([-1, 1])
for j in range(actions.shape[1]):
ax.scatter(steps, actions[:, j])
info_for_logging.update({"actions_of_trajectory_{}".format(i): fig})

if cfg.return_logger:
fig, ax = plt.subplots()
@@ -232,45 +257,63 @@ def _plot(ctx: "OnlineRLContext"):
ln_return = ax.bar(x_dim, hist, width=1, color='r', linewidth=0.7)
ani = animation.FuncAnimation(fig, return_prob, fargs=(hist, ln_return), blit=True, save_count=1)
ani.save(return_path, writer='pillow')
wandb.log({"return distribution": wandb.Video(return_path, format="gif")})
info_for_logging.update({"return distribution": wandb.Video(return_path, format="gif")})

wandb.log(data=info_for_logging, step=ctx.env_step)
plt.clf()

return _plot


def wandb_offline_logger(
cfg: EasyDict,
env: BaseEnvManagerV2,
model: torch.nn.Module,
record_path: str,
datasetpath: str,
cfg: Union[str, EasyDict] = 'default',
metric_list: Optional[List[str]] = None,
env: Optional[BaseEnvManagerV2] = None,
model: Optional[torch.nn.Module] = None,
anonymous: bool = False
) -> Callable:
'''
Overview:
Wandb visualizer to track the experiment.
Arguments:
- cfg (:obj:`EasyDict`): Config, a dict of following settings:
- record_path: string. The path to save the replay of simulation.
- record_path (:obj:`str`): The path to save the replay of simulation.
- cfg (:obj:`Union[str, EasyDict]`): Config, a dict of following settings:
- gradient_logger: boolean. Whether to track the gradient.
- plot_logger: boolean. Whether to track the metrics like reward and loss.
- action_logger: `q_value` or `action probability`.
- metric_list (:obj:`Optional[List[str]]`): Logged metric list, specialized by different policies.
- env (:obj:`BaseEnvManagerV2`): Evaluator environment.
- model (:obj:`nn.Module`): Model.
- datasetpath (:obj:`str`): The path of offline dataset.
- model (:obj:`nn.Module`): Policy neural network model.
- anonymous (:obj:`bool`): Open the anonymous mode of wandb or not.
The anonymous mode allows visualization of data without wandb count.
'''

if task.router.is_active and not task.has_role(task.role.LEARNER):
return task.void()
color_list = ["orange", "red", "blue", "purple", "green", "darkcyan"]
metric_list = ["q_value", "target q_value", "loss", "lr", "entropy", "target_q_value", "td_error"]
if metric_list is None:
metric_list = ["q_value", "target q_value", "loss", "lr", "entropy", "target_q_value", "td_error"]
# Initialize wandb with default settings
# Settings can be covered by calling wandb.init() at the top of the script
if anonymous:
wandb.init(anonymous="must")
else:
wandb.init()
if cfg == 'default':
cfg = EasyDict(
dict(
gradient_logger=False,
plot_logger=True,
video_logger=False,
action_logger=False,
return_logger=False,
)
)
# The visualizer is called to save the replay of the simulation
# which will be uploaded to wandb later
env.enable_save_replay(replay_path=cfg.record_path)
if env is not None:
env.enable_save_replay(replay_path=record_path)
if cfg.gradient_logger:
wandb.watch(model)
else:
@@ -326,60 +369,91 @@ def _vis_dataset(datasetpath: str):
if cfg.vis_dataset is True:
_vis_dataset(datasetpath)

def _plot(ctx: "OfflineRLContext"):
def _plot(ctx: "OnlineRLContext"):
info_for_logging = {}

if not cfg.plot_logger:
one_time_warning(
"If you want to use wandb to visualize the result, please set plot_logger = True in the config."
)
return
for metric in metric_list:
if metric in ctx.train_output:
metric_value = ctx.train_output[metric]
wandb.log({metric: metric_value})
if metric in ctx.train_output[0]:
metric_value_list = []
for item in ctx.train_output:
if isinstance(item[metric], torch.Tensor):
metric_value_list.append(item[metric].cpu().detach().numpy())
else:
metric_value_list.append(item[metric])
metric_value = np.mean(metric_value_list)
info_for_logging.update({metric: metric_value})

if ctx.eval_value != -np.inf:
wandb.log({"reward": ctx.eval_value, "train iter": ctx.train_iter})
info_for_logging.update(
{
"episode return mean": ctx.eval_value,
"train iter": ctx.train_iter,
"env step": ctx.env_step
}
)

eval_output = ctx.eval_output['output']
episode_return = ctx.eval_output['episode_return']
if 'logit' in eval_output[0]:
action_value = [to_ndarray(F.softmax(v['logit'], dim=-1)) for v in eval_output]

file_list = []
for p in os.listdir(cfg.record_path):
if os.path.splitext(p)[-1] == ".mp4":
file_list.append(p)
file_list.sort(key=lambda fn: os.path.getmtime(os.path.join(cfg.record_path, fn)))

video_path = os.path.join(cfg.record_path, file_list[-2])
action_path = os.path.join(cfg.record_path, (str(ctx.train_iter) + "_action.gif"))
return_path = os.path.join(cfg.record_path, (str(ctx.train_iter) + "_return.gif"))
if cfg.action_logger in ['q_value', 'action probability']:
episode_return = np.array(episode_return)
if len(episode_return.shape) == 2:
episode_return = episode_return.squeeze(1)

if cfg.video_logger:
file_list = []
for p in os.listdir(record_path):
if os.path.splitext(p)[-1] == ".mp4":
file_list.append(p)
file_list.sort(key=lambda fn: os.path.getmtime(os.path.join(record_path, fn)))
video_path = os.path.join(record_path, file_list[-2])
info_for_logging.update({"video": wandb.Video(video_path, format="mp4")})

action_path = os.path.join(record_path, (str(ctx.env_step) + "_action.gif"))
return_path = os.path.join(record_path, (str(ctx.env_step) + "_return.gif"))
if cfg.action_logger:
if all(['logit' in v for v in eval_output]) or hasattr(eval_output, "logit"):
if isinstance(eval_output, tnp.ndarray):
action_prob = softmax(eval_output.logit)
else:
action_prob = [softmax(to_ndarray(v['logit'])) for v in eval_output]
fig, ax = plt.subplots()
plt.ylim([-1, 1])
action_dim = len(action_prob[1])
x_range = [str(x + 1) for x in range(action_dim)]
ln = ax.bar(x_range, [0 for x in range(action_dim)], color=color_list[:action_dim])
ani = animation.FuncAnimation(
fig, action_prob, fargs=(action_prob, ln), blit=True, save_count=len(action_prob)
)
ani.save(action_path, writer='pillow')
info_for_logging.update({"action": wandb.Video(action_path, format="gif")})

elif all(['action' in v for v in eval_output[0]]):
for i, action_trajectory in enumerate(eval_output):
fig, ax = plt.subplots()
fig_data = np.array([[i + 1, *v['action']] for i, v in enumerate(action_trajectory)])
steps = fig_data[:, 0]
actions = fig_data[:, 1:]
plt.ylim([-1, 1])
for j in range(actions.shape[1]):
ax.scatter(steps, actions[:, j])
info_for_logging.update({"actions_of_trajectory_{}".format(i): fig})

if cfg.return_logger:
fig, ax = plt.subplots()
plt.ylim([-1, 1])
action_dim = len(action_value[0])
x_range = [str(x + 1) for x in range(action_dim)]
ln = ax.bar(x_range, [0 for x in range(action_dim)], color=color_list[:action_dim])
ani = animation.FuncAnimation(
fig, action_prob, fargs=(action_value, ln), blit=True, save_count=len(action_value)
)
ani.save(action_path, writer='pillow')
wandb.log({cfg.action_logger: wandb.Video(action_path, format="gif")})
plt.clf()

fig, ax = plt.subplots()
ax = plt.gca()
ax.set_ylim([0, 1])
hist, x_dim = return_distribution(episode_return)
assert len(hist) == len(x_dim)
ln_return = ax.bar(x_dim, hist, width=1, color='r', linewidth=0.7)
ani = animation.FuncAnimation(fig, return_prob, fargs=(hist, ln_return), blit=True, save_count=1)
ani.save(return_path, writer='pillow')
wandb.log(
{
"video": wandb.Video(video_path, format="mp4"),
"return distribution": wandb.Video(return_path, format="gif")
}
)
ax = plt.gca()
ax.set_ylim([0, 1])
hist, x_dim = return_distribution(episode_return)
assert len(hist) == len(x_dim)
ln_return = ax.bar(x_dim, hist, width=1, color='r', linewidth=0.7)
ani = animation.FuncAnimation(fig, return_prob, fargs=(hist, ln_return), blit=True, save_count=1)
ani.save(return_path, writer='pillow')
info_for_logging.update({"return distribution": wandb.Video(return_path, format="gif")})

wandb.log(data=info_for_logging, step=ctx.env_step)
plt.clf()

return _plot
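
Note on the logger rewrite: both loggers now accumulate metrics, videos, and figures into a single info_for_logging dict and emit one wandb.log(data=..., step=ctx.env_step) call per iteration, so every panel shares the environment-step x-axis. A minimal sketch of that collect-then-log-once pattern; the metric names and values are illustrative and the snippet is not the middleware itself.

import numpy as np
import wandb

# Offline/anonymous init so the sketch runs without a wandb account.
wandb.init(mode='offline', anonymous='must')

# Hypothetical training outputs for one iteration.
train_output = [{'loss': 0.42, 'q_value': 1.3}, {'loss': 0.40, 'q_value': 1.5}]
env_step = 1000

info_for_logging = {}
for metric in ('loss', 'q_value', 'entropy'):
    if metric in train_output[0]:
        info_for_logging[metric] = np.mean([item[metric] for item in train_output])

# One log call per iteration keyed on env_step keeps all wandb panels
# aligned on the same x-axis instead of interleaving separate log calls.
wandb.log(data=info_for_logging, step=env_step)
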