feature(lxy): add offline wandb (#523)
* fix import path error in lunarlander

* add offline wandb

* remove the installation of pickle, which is already part of the Python standard library

* polish offline wandb and add unit test

* polish style and package
karroyan authored Oct 28, 2022
1 parent efc0cd8 commit ae74916
Showing 4 changed files with 240 additions and 29 deletions.
2 changes: 1 addition & 1 deletion ding/framework/middleware/functional/__init__.py
@@ -4,7 +4,7 @@
from .collector import inferencer, rolloutor, TransitionList
from .evaluator import interaction_evaluator
from .termination_checker import termination_checker
from .logger import online_logger, offline_logger, wandb_online_logger
from .logger import online_logger, offline_logger, wandb_online_logger, wandb_offline_logger
from .ctx_helper import final_ctx_saver

# algorithm
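With the new export in place, both loggers come from the same package path; a minimal illustration (mirroring the import used in the updated test file further down, not part of the diff):

# Illustrative import only.
from ding.framework.middleware.functional import wandb_online_logger, wandb_offline_logger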
197 changes: 171 additions & 26 deletions ding/framework/middleware/functional/logger.py
@@ -3,9 +3,14 @@
from easydict import EasyDict
from matplotlib import pyplot as plt
from matplotlib import animation
from matplotlib import ticker as mtick
from torch.nn import functional as F
from sklearn.manifold import TSNE
import numpy as np
import torch
import wandb
import h5py
import pickle
from ding.envs import BaseEnvManagerV2
from ding.utils import DistributedWriter
from ding.torch_utils import to_ndarray
@@ -15,6 +20,28 @@
from ding.framework import OnlineRLContext, OfflineRLContext


def action_prob(num, action_prob, ln):
ax = plt.gca()
ax.set_ylim([0, 1])
for rect, x in zip(ln, action_prob[num]):
rect.set_height(x)
return ln


def return_prob(num, return_prob, ln):
return ln


def return_distribution(reward):
num = len(reward)
max_return = max(reward)
min_return = min(reward)
hist, bins = np.histogram(reward, bins=np.linspace(min_return - 50, max_return + 50, 6))
gap = (max_return - min_return + 100) / 5
x_dim = ['{:.1f}'.format(min_return - 50 + gap * x) for x in range(5)]
return hist / num, x_dim
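The three module-level helpers above back the GIFs logged below: action_prob and return_prob are frame callbacks for matplotlib's FuncAnimation, and return_distribution builds the reward histogram. A hand-checked sketch of its output on toy values (not part of the commit):

# Toy input: five episode returns between 0 and 40.
reward = [0.0, 10.0, 20.0, 30.0, 40.0]
hist, x_dim = return_distribution(reward)
# The bins span [min - 50, max + 50] = [-50, 90] in five equal steps of width 28, so:
# hist  -> array([0. , 0.2, 0.6, 0.2, 0. ])            # fraction of episodes per bin
# x_dim -> ['-50.0', '-22.0', '6.0', '34.0', '62.0']   # left-edge label of each bin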


def online_logger(record_train_iter: bool = False, train_show_freq: int = 100) -> Callable:
writer = DistributedWriter.get_instance()
last_train_show_iter = -1
Expand Down Expand Up @@ -77,7 +104,7 @@ def _logger(ctx: "OfflineRLContext"):
return _logger


def wandb_online_logger(cfg: EasyDict, env: BaseEnvManagerV2, model) -> Callable:
def wandb_online_logger(cfg: EasyDict, env: BaseEnvManagerV2, model: torch.nn.Module) -> Callable:
'''
Overview:
Wandb visualizer to track the experiment.
@@ -106,25 +133,6 @@ def wandb_online_logger(cfg: EasyDict, env: BaseEnvManagerV2, model) -> Callable
"If you want to use wandb to visualize the gradient, please set gradient_logger = True in the config."
)

def _action_prob(num, action_prob, ln):
ax = plt.gca()
ax.set_ylim([0, 1])
for rect, x in zip(ln, action_prob[num]):
rect.set_height(x)
return ln

def _return_prob(num, return_prob, ln):
return ln

def _return_distribution(reward):
num = len(reward)
max_return = max(reward)
min_return = min(reward)
hist, bins = np.histogram(reward, bins=np.linspace(min_return - 50, max_return + 50, 6))
gap = (max_return - min_return + 100) / 5
x_dim = [str(min_return - 50 + gap * x) for x in range(5)]
return hist / num, x_dim

def _plot(ctx: "OnlineRLContext"):
if not cfg.plot_logger:
one_time_warning(
Expand All @@ -137,7 +145,7 @@ def _plot(ctx: "OnlineRLContext"):
wandb.log({metric: metric_value})

if ctx.eval_value != -np.inf:
wandb.log({"reward": ctx.eval_value})
wandb.log({"reward": ctx.eval_value, "train iter": ctx.train_iter})

eval_output = ctx.eval_output['output']
eval_reward = ctx.eval_output['reward']
@@ -153,26 +161,163 @@ def _plot(ctx: "OnlineRLContext"):
video_path = os.path.join(cfg.record_path, file_list[-2])
action_path = os.path.join(cfg.record_path, (str(ctx.env_step) + "_action.gif"))
return_path = os.path.join(cfg.record_path, (str(ctx.env_step) + "_return.gif"))
if cfg.action_logger == 'q_value' or 'action probability':
if cfg.action_logger in ['q_value', 'action probability']:
fig, ax = plt.subplots()
plt.ylim([-1, 1])
action_dim = len(action_value[0])
x_range = [str(x + 1) for x in range(action_dim)]
ln = ax.bar(x_range, [0 for x in range(action_dim)], color=color_list[:action_dim])
ani = animation.FuncAnimation(
fig, action_prob, fargs=(action_value, ln), blit=True, save_count=len(action_value)
)
ani.save(action_path, writer='pillow')
wandb.log({cfg.action_logger: wandb.Video(action_path, format="gif")})
plt.clf()

fig, ax = plt.subplots()
ax = plt.gca()
ax.set_ylim([0, 1])
hist, x_dim = return_distribution(eval_reward)
assert len(hist) == len(x_dim)
ln_return = ax.bar(x_dim, hist, width=1, color='r', linewidth=0.7)
ani = animation.FuncAnimation(fig, return_prob, fargs=(hist, ln_return), blit=True, save_count=1)
ani.save(return_path, writer='pillow')
wandb.log(
{
"video": wandb.Video(video_path, format="mp4"),
"return distribution": wandb.Video(return_path, format="gif")
}
)

return _plot
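For orientation, the callable returned by wandb_online_logger is meant to be registered as middleware in a DI-engine pipeline; a rough wiring sketch (assuming the task API, with cfg, evaluator_env and policy_model as placeholders; not code from this commit):

from ding.framework import task, OnlineRLContext
from ding.framework.middleware.functional import wandb_online_logger

with task.start(ctx=OnlineRLContext()):
    # ... collector / trainer / evaluator middleware would be registered here ...
    task.use(wandb_online_logger(cfg, evaluator_env, policy_model))
    task.run(max_step=100)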


def wandb_offline_logger(cfg: EasyDict, env: BaseEnvManagerV2, model: torch.nn.Module, datasetpath: str) -> Callable:
'''
Overview:
Wandb visualizer to track the experiment.
Arguments:
- cfg (:obj:`EasyDict`): Config, a dict of following settings:
- record_path: string. The path to save the replay of the simulation.
- gradient_logger: boolean. Whether to track the gradient.
- plot_logger: boolean. Whether to track metrics such as reward and loss.
- action_logger: `q_value` or `action probability`.
- vis_dataset: boolean. Whether to visualize the offline dataset with t-SNE.
- env (:obj:`BaseEnvManagerV2`): Evaluator environment.
- model (:obj:`nn.Module`): Model.
- datasetpath (:obj:`str`): The path of offline dataset.
'''

color_list = ["orange", "red", "blue", "purple", "green", "darkcyan"]
metric_list = ["q_value", "target q_value", "loss", "lr", "entropy", "target_q_value", "td_error"]
# Initialize wandb with default settings
# Settings can be overridden by calling wandb.init() at the top of the script
wandb.init()
# The visualizer is called to save the replay of the simulation
# which will be uploaded to wandb later
env.enable_save_replay(replay_path=cfg.record_path)
if cfg.gradient_logger:
wandb.watch(model)
else:
one_time_warning(
"If you want to use wandb to visualize the gradient, please set gradient_logger = True in the config."
)

def _vis_dataset(datasetpath: str):
assert os.path.splitext(datasetpath)[-1] in ['.pkl', '.h5', '.hdf5']
if os.path.splitext(datasetpath)[-1] == '.pkl':
with open(datasetpath, 'rb') as f:
data = pickle.load(f)
obs = []
action = []
reward = []
for i in range(len(data)):
obs.extend(data[i]['observations'])
action.extend(data[i]['actions'])
reward.extend(data[i]['rewards'])
elif os.path.splitext(datasetpath)[-1] in ['.h5', '.hdf5']:
with h5py.File(datasetpath, 'r') as f:
obs = f['obs'][()]
action = f['action'][()]
reward = f['reward'][()]

cmap = plt.cm.hsv
obs = np.array(obs)
reward = np.array(reward)
obs_action = np.hstack((obs, np.array(action)))
reward = reward / (max(reward) - min(reward))

embedded_obs = TSNE(n_components=2).fit_transform(obs)
embedded_obs_action = TSNE(n_components=2).fit_transform(obs_action)
x_min, x_max = np.min(embedded_obs, 0), np.max(embedded_obs, 0)
embedded_obs = embedded_obs / (x_max - x_min)

x_min, x_max = np.min(embedded_obs_action, 0), np.max(embedded_obs_action, 0)
embedded_obs_action = embedded_obs_action / (x_max - x_min)

fig = plt.figure()
f, axes = plt.subplots(nrows=1, ncols=3)

axes[0].scatter(embedded_obs[:, 0], embedded_obs[:, 1], c=cmap(reward))
axes[1].scatter(embedded_obs[:, 0], embedded_obs[:, 1], c=cmap(action))
axes[2].scatter(embedded_obs_action[:, 0], embedded_obs_action[:, 1], c=cmap(reward))
axes[0].set_title('state-reward')
axes[1].set_title('state-action')
axes[2].set_title('stateAction-reward')
plt.savefig('dataset.png')

wandb.log({"dataset": wandb.Image("dataset.png")})

if cfg.vis_dataset is True:
_vis_dataset(datasetpath)

def _plot(ctx: "OfflineRLContext"):
if not cfg.plot_logger:
one_time_warning(
"If you want to use wandb to visualize the result, please set plot_logger = True in the config."
)
return
for metric in metric_list:
if metric in ctx.train_output:
metric_value = ctx.train_output[metric]
wandb.log({metric: metric_value})

if ctx.eval_value != -np.inf:
wandb.log({"reward": ctx.eval_value, "train iter": ctx.train_iter})

eval_output = ctx.eval_output['output']
eval_reward = ctx.eval_output['reward']
if 'logit' in eval_output[0]:
action_value = [to_ndarray(F.softmax(v['logit'], dim=-1)) for v in eval_output]

file_list = []
for p in os.listdir(cfg.record_path):
if os.path.splitext(p)[-1] == ".mp4":
file_list.append(p)
file_list.sort(key=lambda fn: os.path.getmtime(os.path.join(cfg.record_path, fn)))

video_path = os.path.join(cfg.record_path, file_list[-2])
action_path = os.path.join(cfg.record_path, (str(ctx.train_iter) + "_action.gif"))
return_path = os.path.join(cfg.record_path, (str(ctx.train_iter) + "_return.gif"))
if cfg.action_logger in ['q_value', 'action probability']:
fig, ax = plt.subplots()
plt.ylim([-1, 1])
action_dim = len(action_value[0])
x_range = [str(x + 1) for x in range(action_dim)]
ln = ax.bar(x_range, [0 for x in range(action_dim)], color=color_list[:action_dim])
ani = animation.FuncAnimation(
fig, _action_prob, fargs=(action_value, ln), blit=True, save_count=len(action_value)
fig, action_prob, fargs=(action_value, ln), blit=True, save_count=len(action_value)
)
ani.save(action_path, writer='pillow')
wandb.log({cfg.action_logger: wandb.Video(action_path, format="gif")})
plt.close()
plt.clf()

fig, ax = plt.subplots()
ax = plt.gca()
ax.set_ylim([0, 1])
hist, x_dim = _return_distribution(eval_reward)
hist, x_dim = return_distribution(eval_reward)
assert len(hist) == len(x_dim)
ln_return = ax.bar(x_dim, hist, width=1, color='r', linewidth=0.7)
ani = animation.FuncAnimation(fig, _return_prob, fargs=(hist, ln_return), blit=True, save_count=1)
ani = animation.FuncAnimation(fig, return_prob, fargs=(hist, ln_return), blit=True, save_count=1)
ani.save(return_path, writer='pillow')
wandb.log(
{
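Pulling the new function together, a caller would configure it roughly like this (an illustrative sketch assembled from the docstring and the test below; eval_env, policy_model and the dataset path are placeholders):

from easydict import EasyDict
from ding.framework.middleware.functional import wandb_offline_logger

cfg = EasyDict(
    dict(
        record_path='./video_pendulum_cql',  # where evaluator replays (.mp4) and GIFs are written
        gradient_logger=True,                # call wandb.watch(model)
        plot_logger=True,                    # log q_value / loss / reward curves
        action_logger='action probability',  # or 'q_value'
        vis_dataset=True,                    # t-SNE visualization of the offline dataset
    )
)
logger_fn = wandb_offline_logger(cfg, env=eval_env, model=policy_model, datasetpath='dataset.hdf5')
# Inside the pipeline the returned callable is invoked with an OfflineRLContext:
# logger_fn(ctx)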
68 changes: 66 additions & 2 deletions ding/framework/middleware/tests/test_logger.py
@@ -6,12 +6,14 @@
import pytest
import shutil
import wandb
import h5py
import torch.nn as nn
from unittest.mock import Mock, patch
from ding.utils import DistributedWriter
from ding.framework.middleware.tests import MockPolicy, CONFIG
from ding.framework import OnlineRLContext, OfflineRLContext
from ding.framework.middleware.functional import online_logger, offline_logger, wandb_online_logger
from ding.framework.middleware.functional import online_logger, offline_logger, wandb_online_logger, \
wandb_offline_logger

test_folder = "test_exp"
test_path = path.join(os.getcwd(), test_folder)
@@ -171,6 +173,18 @@ def enable_save_replay(self, replay_path):
return


class TheObsDataClass(Mock):

def __getitem__(self, index):
return [[1, 1, 1]] * 50


class The1DDataClass(Mock):

def __getitem__(self, index):
return [[1]] * 50


@pytest.mark.other # due to no api key in github now
def test_wandb_online_logger():

@@ -187,7 +201,8 @@ def test_wandb_online_logger():

def mock_metric_logger(metric_dict):
metric_list = [
"q_value", "target q_value", "loss", "lr", "entropy", "reward", "q value", "video", "q value distribution"
"q_value", "target q_value", "loss", "lr", "entropy", "reward", "q value", "video", "q value distribution",
"train iter"
]
assert set(metric_dict.keys()) < set(metric_list)

@@ -204,3 +219,52 @@ def test_wandb_online_logger_gradient():

test_wandb_online_logger_metric()
test_wandb_online_logger_gradient()


@pytest.mark.other # due to no api key in github now
def test_wandb_offline_logger(mocker):

cfg = EasyDict(
dict(
record_path='./video_pendulum_cql',
gradient_logger=True,
plot_logger=True,
action_logger='action probability',
vis_dataset=True
)
)
env = TheEnvClass()
ctx = OnlineRLContext()
ctx.train_output = [{'reward': 1, 'q_value': [1.0]}]
model = TheModelClass()
wandb.init(config=cfg)

def mock_metric_logger(metric_dict):
metric_list = [
"q_value", "target q_value", "loss", "lr", "entropy", "reward", "q value", "video", "q value distribution",
"train iter", 'dataset'
]
assert set(metric_dict.keys()) < set(metric_list)

def mock_gradient_logger(input_model):
assert input_model == model

def mock_image_logger(imagepath):
assert os.path.splitext(imagepath)[-1] == '.png'

def test_wandb_offline_logger_gradient():
cfg.vis_dataset = False
with patch.object(wandb, 'watch', new=mock_gradient_logger):
wandb_offline_logger(cfg, env, model, 'dataset.h5')(ctx)

def test_wandb_offline_logger_dataset():
cfg.vis_dataset = True
m = mocker.MagicMock()
m.__enter__.return_value = {'obs': TheObsDataClass(), 'action': The1DDataClass(), 'reward': The1DDataClass()}
with patch.object(wandb, 'log', new=mock_metric_logger):
with patch.object(wandb, 'Image', new=mock_image_logger):
mocker.patch('h5py.File', return_value=m)
wandb_offline_logger(cfg, env, model, 'dataset.h5')(ctx)

test_wandb_offline_logger_gradient()
test_wandb_offline_logger_dataset()
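The test above mocks h5py.File; for reference, a real dataset that the vis_dataset path in logger.py could read might be produced like this (a toy sketch with made-up values and illustrative shapes; the keys mirror what _vis_dataset expects):

import h5py
import numpy as np

with h5py.File('toy_dataset.hdf5', 'w') as f:
    f.create_dataset('obs', data=np.random.randn(200, 3).astype(np.float32))
    f.create_dataset('action', data=np.random.randn(200, 1).astype(np.float32))
    f.create_dataset('reward', data=np.random.rand(200).astype(np.float32))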
2 changes: 2 additions & 0 deletions setup.py
@@ -84,6 +84,8 @@
'wandb',
'matplotlib',
'MarkupSafe==2.0.1', # compatibility
'h5py',
'scikit-learn',
],
extras_require={
'test': [
