feature(luyd): fix dt new pipeline of mujoco #754

Merged · 3 commits · Dec 11, 2023
ding/framework/middleware/functional/data_processor.py (9 changes: 6 additions & 3 deletions)

@@ -191,7 +191,10 @@ def offline_data_fetcher_from_mem(cfg: EasyDict, dataset: Dataset) -> Callable:
     def producer(queue, dataset, batch_size, device):
         torch.set_num_threads(4)
         nonlocal stream
-        idx_iter = iter(range(len(dataset)))
+        idx_iter = iter(range(len(dataset) - batch_size))
+
+        if len(dataset) < batch_size:
+            logging.warning('batch_size is too large!!!!')
         with torch.cuda.stream(stream):
             while True:
                 if queue.full():
@@ -201,7 +204,7 @@ def producer(queue, dataset, batch_size, device):
                     start_idx = next(idx_iter)
                 except StopIteration:
                     del idx_iter
-                    idx_iter = iter(range(len(dataset)))
+                    idx_iter = iter(range(len(dataset) - batch_size))
                     start_idx = next(idx_iter)
                 data = [dataset.__getitem__(idx) for idx in range(start_idx, start_idx + batch_size)]
                 data = [[i[j] for i in data] for j in range(len(data[0]))]
@@ -211,7 +214,7 @@ def producer(queue, dataset, batch_size, device):
     queue = Queue(maxsize=50)
     device = 'cuda:{}'.format(get_rank() % torch.cuda.device_count()) if cfg.policy.cuda else 'cpu'
     producer_thread = Thread(
-        target=producer, args=(queue, dataset, cfg.policy.batch_size, device), name='cuda_fetcher_producer'
+        target=producer, args=(queue, dataset, cfg.policy.learn.batch_size, device), name='cuda_fetcher_producer'
     )

     def _fetch(ctx: "OfflineRLContext"):
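A note on the index change above: the producer reads a contiguous slice of batch_size items starting at start_idx, so the iterator has to stop early enough that the slice stays inside the dataset. The sketch below is illustrative only (it is not the DI-engine producer; toy_data and the sizes are made up) and just demonstrates the bound that range(len(dataset) - batch_size) enforces.

# Illustrative sketch only: why start_idx is drawn from
# range(len(dataset) - batch_size) when a contiguous batch is sliced out.
toy_data = list(range(10))  # stand-in for a Dataset with 10 items
batch_size = 4

for start_idx in range(len(toy_data) - batch_size):
    batch = [toy_data[idx] for idx in range(start_idx, start_idx + batch_size)]
    assert len(batch) == batch_size  # the slice never runs past the end of the data

# With the old range(len(toy_data)), start_idx could reach 9 and the slice 9..13
# would raise an IndexError as soon as idx hits 10.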
ding/policy/dt.py (6 changes: 4 additions & 2 deletions)

@@ -136,7 +136,7 @@ def _forward_learn(self, data: List[torch.Tensor]) -> Dict[str, Any]:
         if self._basic_discrete_env:
             actions = actions.to(torch.long)
             actions = actions.squeeze(-1)
-            action_target = torch.clone(actions).detach().to(self._device)
+        action_target = torch.clone(actions).detach().to(self._device)

         if self._atari_env:
             state_preds, action_preds, return_preds = self._learn_model.forward(
@@ -291,7 +291,7 @@ def _forward_eval(self, data: Dict[int, Any]) -> Dict[int, Any]:
                 self.states[i, self.t[i]] = data[i]['obs'].to(self._device)
             else:
                 self.states[i, self.t[i]] = (data[i]['obs'].to(self._device) - self.state_mean) / self.state_std
-            self.running_rtg[i] = self.running_rtg[i] - data[i]['reward'].to(self._device)
+            self.running_rtg[i] = self.running_rtg[i] - (data[i]['reward'] / self.rtg_scale).to(self._device)
             self.rewards_to_go[i, self.t[i]] = self.running_rtg[i]

             if self.t[i] <= self.context_len:
@@ -328,6 +328,8 @@ def _forward_eval(self, data: Dict[int, Any]) -> Dict[int, Any]:
                     act[i] = torch.multinomial(probs[i], num_samples=1)
             else:
                 act = torch.argmax(logits, axis=1).unsqueeze(1)
+        else:
+            act = logits
         for i in data_id:
             self.actions[i, self.t[i]] = act[i]  # TODO: self.actions[i] should be a queue when exceed max_t
             self.t[i] += 1
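The running_rtg change keeps the evaluation-time return-to-go in the same units as the training targets: the dataset divides returns_to_go by rtg_scale, so the evaluator must also subtract reward / rtg_scale at each step. A small numeric sketch of that argument (the target, rewards, and scale here are made up, not taken from the PR):

import numpy as np

rtg_scale = 1000.0          # plays the role of cfg.dataset.rtg_scale
rtg_target = 3600.0         # hypothetical return target used for conditioning

running_rtg = rtg_target / rtg_scale        # evaluator starts from the scaled target
for reward in np.array([1.2, 0.9, 1.1]):    # hypothetical per-step env rewards
    running_rtg -= reward / rtg_scale       # stays on the same scale as the dataset's
                                            # discount_cumsum(rewards, 1.0) / rtg_scale
print(running_rtg)                          # ~3.5968, still in scaled units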
ding/utils/data/dataset.py (48 changes: 1 addition & 47 deletions)

@@ -389,13 +389,10 @@ def __init__(self, cfg: dict) -> None:

         self.trajectories = paths

-        # calculate min len of traj, state mean and variance
-        # and returns_to_go for all traj
-        min_len = 10 ** 6
+        # calculate state mean and variance and returns_to_go for all traj
         states = []
         for traj in self.trajectories:
             traj_len = traj['observations'].shape[0]
-            min_len = min(min_len, traj_len)
             states.append(traj['observations'])
             # calculate returns to go and rescale them
             traj['returns_to_go'] = discount_cumsum(traj['rewards'], 1.0) / rtg_scale
@@ -408,46 +405,6 @@ def __init__(self, cfg: dict) -> None:
         for traj in self.trajectories:
             traj['observations'] = (traj['observations'] - self.state_mean) / self.state_std

-        # self.trajectories = {}
-        # exp_key = ['rewards', 'terminals', 'timeouts']
-        # for k in dataset.keys():
-        #     logging.info(f'Load {k} data.')
-        #     if k in exp_key:
-        #         self.trajectories[k] = np.expand_dims(dataset[k][:], axis=1)
-        #     else:
-        #         self.trajectories[k] = dataset[k][:]
-
-        # # used for input normalization
-        # states = np.concatenate(self.trajectories['observations'], axis=0)
-        # self.state_mean, self.state_std = np.mean(states, axis=0), np.std(states, axis=0) + 1e-6
-
-        # # normalize states
-        # self.trajectories['observations'] = (self.trajectories['observations'] - self.state_mean) / self.state_std
-        # self.trajectories['returns_to_go'] = discount_cumsum(self.trajectories['rewards'], 1.0) / rtg_scale
-
-        # datalen = self.trajectories['rewards'].shape[0]
-
-        # use_timeouts = False
-        # if 'timeouts' in dataset:
-        #     use_timeouts = True
-
-        # data_ = collections.defaultdict(list)
-        # episode_step = 0
-        # trajectories_tmp = []
-        # for i in range(datalen):
-        #     done_bool = bool(self.trajectories['terminals'][i])
-        #     final_timestep = (episode_step == 1000-1)
-        #     for k in ['observations', 'actions', 'returns_to_go']:
-        #         data_[k].append(self.trajectories[k][i])
-        #     if done_bool or final_timestep:
-        #         episode_step = 0
-        #         episode_data = {}
-        #         for k in data_:
-        #             episode_data[k] = np.array(data_[k])
-        #         trajectories_tmp.append(episode_data)
-        #         data_ = collections.defaultdict(list)
-        #     episode_step += 1
-        # self.trajectories = trajectories_tmp
         elif 'pkl' in dataset_path:
             if 'dqn' in dataset_path:
                 # load dataset
@@ -493,11 +450,8 @@ def __init__(self, cfg: dict) -> None:
                 with open(dataset_path, 'rb') as f:
                     self.trajectories = pickle.load(f)

-                min_len = 10 ** 6
                 states = []
                 for traj in self.trajectories:
-                    traj_len = traj['observations'].shape[0]
-                    min_len = min(min_len, traj_len)
                     states.append(traj['observations'])
                     # calculate returns to go and rescale them
                     traj['returns_to_go'] = discount_cumsum(traj['rewards'], 1.0) / rtg_scale
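For reference, the preprocessing that remains after these deletions boils down to state normalization plus scaled returns-to-go. The sketch below is a self-contained approximation, not the dataset class itself: discount_cumsum is re-implemented here under the usual decision-transformer definition, the trajectory shapes are made up, and the statistics are computed over a single trajectory for brevity (the real class concatenates all trajectories first).

import numpy as np

def discount_cumsum(x: np.ndarray, gamma: float) -> np.ndarray:
    # Reverse cumulative sum of rewards, assumed to match the helper used in dataset.py.
    out = np.zeros_like(x, dtype=np.float64)
    out[-1] = x[-1]
    for t in reversed(range(len(x) - 1)):
        out[t] = x[t] + gamma * out[t + 1]
    return out

rtg_scale = 1000.0
traj = {  # hypothetical trajectory: 5 steps of a 17-dim observation space
    'observations': np.random.randn(5, 17),
    'rewards': np.random.rand(5),
}

states = traj['observations']
state_mean, state_std = states.mean(axis=0), states.std(axis=0) + 1e-6
traj['observations'] = (traj['observations'] - state_mean) / state_std
traj['returns_to_go'] = discount_cumsum(traj['rewards'], 1.0) / rtg_scale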
dizoo/d4rl/config/hopper_expert_dt_config.py (2 changes: 1 addition & 1 deletion)

@@ -14,7 +14,7 @@
     dataset=dict(
         env_type='mujoco',
         rtg_scale=1000,
-        context_len=30,
+        context_len=20,
         data_dir_prefix='d4rl/hopper_expert-v2.pkl',
     ),
     policy=dict(
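One practical note on context_len: the dataset-side value of 20 set here should stay consistent with the model's context_len (the walker2d config further down sets model.context_len=20 as well). The fragment below is a hypothetical config, not one of the files in this PR; it only illustrates the convention of sharing one value between the sampling window and the transformer's context.

# Hypothetical config fragment illustrating the assumed convention:
# the trajectory sampling window and the transformer context length match.
context_len = 20

example_dt_config = dict(
    dataset=dict(
        env_type='mujoco',
        rtg_scale=1000,
        context_len=context_len,      # length of each sampled sub-trajectory
        data_dir_prefix='d4rl/hopper_expert-v2.pkl',
    ),
    policy=dict(
        model=dict(
            context_len=context_len,  # transformer attends over the same window
            h_dim=128,
            n_blocks=3,
        ),
    ),
)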
dizoo/d4rl/config/hopper_medium_dt_config.py (6 changes: 3 additions & 3 deletions)

@@ -14,8 +14,8 @@
     dataset=dict(
         env_type='mujoco',
         rtg_scale=1000,
-        context_len=30,
-        data_dir_prefix='d4rl/hopper_medium-v2.pkl',
+        context_len=20,
+        data_dir_prefix='d4rl/hopper_medium_expert-v2.pkl',
     ),
     policy=dict(
         cuda=True,
@@ -47,7 +47,7 @@
            data_type='d4rl_trajectory',
            unroll_len=1,
        ),
-        eval=dict(evaluator=dict(eval_freq=100, ), ),
+        eval=dict(evaluator=dict(eval_freq=1000, ), ),
    ),
)

dizoo/d4rl/config/hopper_medium_expert_dt_config.py (6 changes: 3 additions & 3 deletions)

@@ -2,7 +2,7 @@
 from copy import deepcopy

 hopper_dt_config = dict(
-    exp_name='dt_log/d4rl/hopper/hopper_medium_expert_dt_seed0',
+    exp_name='dt_log/d4rl/hopper/hopper_medium_expert_dt',
     env=dict(
         env_id='Hopper-v3',
         collector_env_num=1,
@@ -14,8 +14,8 @@
     dataset=dict(
         env_type='mujoco',
         rtg_scale=1000,
-        context_len=30,
-        data_dir_prefix='d4rl/hopper_medium_expert-v2.pkl',
+        context_len=20,
+        data_dir_prefix='d4rl/hopper_medium_expert.pkl',
     ),
     policy=dict(
         cuda=True,
dizoo/d4rl/config/walker2d_medium_dt_config.py (14 changes: 7 additions & 7 deletions)

@@ -2,9 +2,9 @@
 from copy import deepcopy

 walk2d_dt_config = dict(
-    exp_name='dt_log/d4rl/walk2d/walk2d_medium_dt_seed0',
+    exp_name='dt_log/d4rl/walk2d/walk2d_medium_dt',
     env=dict(
-        env_id='Walk2d-v3',
+        env_id='Walker2d-v3',
         collector_env_num=1,
         evaluator_env_num=8,
         use_act_scale=True,
@@ -14,16 +14,16 @@
     dataset=dict(
         env_type='mujoco',
         rtg_scale=1000,
-        context_len=30,
-        data_dir_prefix='d4rl/walk2d_medium-v2.pkl',
+        context_len=20,
+        data_dir_prefix='d4rl/walker2d_medium-v2.pkl',
     ),
     policy=dict(
         cuda=True,
         stop_value=5000,
         state_mean=None,
         state_std=None,
         evaluator_env_num=8,
-        env_name='Walk2d-v3',
+        env_name='Walker2d-v3',
         rtg_target=5000,  # max target return to go
         max_eval_ep_len=1000,  # max lenght of one episode
         wt_decay=1e-4,
@@ -32,8 +32,8 @@
         weight_decay=0.1,
         clip_grad_norm_p=0.25,
         model=dict(
-            state_dim=11,
-            act_dim=3,
+            state_dim=17,
+            act_dim=6,
             n_blocks=3,
             h_dim=128,
             context_len=20,

Review comment from a maintainer (on the model block): simplify hyper-parameters if they are the same as default configs.
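The state_dim/act_dim change matches Walker2d-v3's actual spaces; the old values 11 and 3 are Hopper's dimensions, presumably left over from copying the hopper config. If gym with the MuJoCo environments is installed, a quick check along these lines confirms the numbers. This snippet is a sanity-check sketch, not part of the PR:

import gym  # requires gym with the MuJoCo environments available

env = gym.make('Walker2d-v3')
state_dim = env.observation_space.shape[0]   # expected: 17
act_dim = env.action_space.shape[0]          # expected: 6
print(state_dim, act_dim)
env.close()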
dizoo/d4rl/entry/d4rl_dt_mujoco.py (12 changes: 5 additions & 7 deletions)

@@ -10,7 +10,7 @@
 from ding.config import compile_config
 from ding.framework import task, ding_init
 from ding.framework.context import OfflineRLContext
-from ding.framework.middleware import interaction_evaluator, trainer, CkptSaver, offline_data_fetcher, offline_logger, termination_checker
+from ding.framework.middleware import interaction_evaluator, trainer, CkptSaver, offline_data_fetcher_from_mem, offline_logger, termination_checker
 from ding.utils import set_pkg_seed
 from dizoo.d4rl.envs import D4RLEnv
 from dizoo.d4rl.config.hopper_medium_dt_config import main_config, create_config
@@ -32,16 +32,14 @@ def main():

     dataset = create_dataset(cfg)
     # env_data_stats = dataset.get_d4rl_dataset_stats(cfg.policy.dataset_name)
-    env_data_stats = dataset.get_state_stats()
-    cfg.policy.state_mean, cfg.policy.state_std = np.array(env_data_stats['state_mean']
-                                                           ), np.array(env_data_stats['state_std'])
+    cfg.policy.state_mean, cfg.policy.state_std = dataset.get_state_stats()
     model = DecisionTransformer(**cfg.policy.model)
     policy = DTPolicy(cfg.policy, model=model)
     task.use(interaction_evaluator(cfg, policy.eval_mode, evaluator_env))
-    task.use(offline_data_fetcher(cfg, dataset))
+    task.use(offline_data_fetcher_from_mem(cfg, dataset))
     task.use(trainer(cfg, policy.learn_mode))
-    task.use(termination_checker(max_train_iter=1e5))
-    task.use(CkptSaver(policy, cfg.exp_name, train_freq=100))
+    task.use(termination_checker(max_train_iter=5e4))
+    task.use(CkptSaver(policy, cfg.exp_name, train_freq=1000))
     task.use(offline_logger())
     task.run()

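Finally, the simplified stats line in the entry script assumes get_state_stats() already returns a (state_mean, state_std) pair of arrays, so the values can be unpacked straight into the policy config without re-wrapping in np.array. A toy stand-in illustrating that assumed return convention (ToyDataset is hypothetical, not the DI-engine dataset class):

import numpy as np

class ToyDataset:
    # Hypothetical stand-in for the D4RL trajectory dataset.
    def __init__(self, observations: np.ndarray):
        self._state_mean = observations.mean(axis=0)
        self._state_std = observations.std(axis=0) + 1e-6

    def get_state_stats(self):
        # Returns the pair directly, matching how the entry script unpacks it.
        return self._state_mean, self._state_std

dataset = ToyDataset(np.random.randn(100, 17))
state_mean, state_std = dataset.get_state_stats()
print(state_mean.shape, state_std.shape)  # (17,) (17,)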