A3C distributed modes #340

Merged May 18, 2019 · 39 commits

Changes from 1 commit

Commits (39)
bf66248
remove check_compatibility
kengz May 17, 2019
0a922a7
add and register GlobalAdam
kengz May 17, 2019
3f49674
move make_global_nets to net_util
kengz May 17, 2019
dae5989
Merge remote-tracking branch 'origin/v4-dev' into globalopt
kengz May 17, 2019
074bd0c
enforce net naming convention
kengz May 17, 2019
2b29495
move global_nets init to net_util
kengz May 17, 2019
5800dcf
simplify net and global_net init
kengz May 17, 2019
ec92f45
use global adam for a3cpong
kengz May 17, 2019
07311bc
lower lr
kengz May 17, 2019
657196b
override with global sync
kengz May 17, 2019
ab8db09
eval less
kengz May 17, 2019
3736c84
increase lr
kengz May 17, 2019
7a249be
sync net after training
kengz May 17, 2019
700780a
set and sync hogwild
kengz May 17, 2019
0e2883a
lower lr
kengz May 17, 2019
4636653
fix ppo misnaming critic
kengz May 17, 2019
1d44716
remove curren sync_global_nets
kengz May 18, 2019
2ca515d
do grad push and param pull inside training_step
kengz May 18, 2019
67c6b3f
guard set_global_nets, pass global_net into training_step
kengz May 18, 2019
099d435
fix typo
kengz May 18, 2019
aa44829
move local grad to cpu first
kengz May 18, 2019
f725bcb
global_net to CPU
kengz May 18, 2019
00f3384
revert
kengz May 18, 2019
82918fa
rename to train_step
kengz May 18, 2019
2bed5f4
allow for synced and shared distributed modes
kengz May 18, 2019
cb9b1ad
add basic compat check
kengz May 18, 2019
64782e8
name a3c
kengz May 18, 2019
b248e90
add a2c pong spec
kengz May 18, 2019
fdae0e1
cleanup is_venv setting
kengz May 18, 2019
6bf48f2
remove useless NUM_EVAL_EPI
kengz May 18, 2019
3453a91
divide max_tick by max session if distributed
kengz May 18, 2019
7e6957a
rename resources to search_resources for clarity
kengz May 18, 2019
2e617de
add a3c atari spec
kengz May 18, 2019
be5bc12
add GlobalRMSProp
kengz May 18, 2019
b2cbd1b
update deprecated warning and add_trace methods
kengz May 18, 2019
981f740
improve run log
kengz May 18, 2019
e739b6f
fix a3c shared hogwild cuda id assignment to offset 0
kengz May 18, 2019
bdaac5e
disable a3c gpu with synced in spec
kengz May 18, 2019
fe78439
add flaky to vizdoom test
kengz May 18, 2019
add and register GlobalAdam
kengz committed May 17, 2019
commit 0a922a7c89f068167b8b39115f8741f682aa90c2
6 changes: 4 additions & 2 deletions slm_lab/agent/net/net_util.py
@@ -1,13 +1,15 @@
 from functools import partial, wraps
 from slm_lab import ROOT_DIR
-from slm_lab.lib import logger, util
+from slm_lab.lib import logger, optimizer, util
 import os
 import pydash as ps
 import torch
 import torch.nn as nn
 
 logger = logger.get_logger(__name__)
 
+# register custom torch.optim
+setattr(torch.optim, 'GlobalAdam', optimizer.GlobalAdam)
 
 
 class NoOpLRScheduler:
     '''Symbolic LRScheduler class for API consistency'''
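
Aside (not part of the diff): the setattr registration above is what lets the custom optimizer be resolved by name later, like any built-in torch.optim class. A minimal sketch of that lookup pattern follows; the optim_spec fragment and the getattr-based resolution are illustrative assumptions, not necessarily SLM-Lab's exact spec API.

import torch
import torch.nn as nn

from slm_lab.lib import optimizer

# as in net_util.py above: make GlobalAdam discoverable on torch.optim
setattr(torch.optim, 'GlobalAdam', optimizer.GlobalAdam)

optim_spec = {'name': 'GlobalAdam', 'lr': 1e-4}  # hypothetical spec fragment
net = nn.Linear(4, 2)
OptimClass = getattr(torch.optim, optim_spec['name'])  # resolves to the registered GlobalAdam
optim = OptimClass(net.parameters(), lr=optim_spec['lr'])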
56 changes: 56 additions & 0 deletions slm_lab/lib/optimizer.py
@@ -0,0 +1,56 @@
import math
import torch


class GlobalAdam(torch.optim.Adam):
    '''
    Global Adam algorithm with shared states for Hogwild.
    Adapted from https://github.com/ikostrikov/pytorch-a3c/blob/master/my_optim.py (MIT)
    '''

    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0):
        super().__init__(params, lr, betas, eps, weight_decay)

        for group in self.param_groups:
            for p in group['params']:
                state = self.state[p]
                state['step'] = torch.zeros(1)
                state['exp_avg'] = p.data.new().resize_as_(p.data).zero_()
                state['exp_avg_sq'] = p.data.new().resize_as_(p.data).zero_()

    def share_memory(self):
        for group in self.param_groups:
            for p in group['params']:
                state = self.state[p]
                state['step'].share_memory_()
                state['exp_avg'].share_memory_()
                state['exp_avg_sq'].share_memory_()

    def step(self, closure=None):
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad.data
                state = self.state[p]
                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
                beta1, beta2 = group['betas']
                state['step'] += 1
                if group['weight_decay'] != 0:
                    grad = grad.add(group['weight_decay'], p.data)

                # Decay the first and second moment running average coefficient
                exp_avg.mul_(beta1).add_(1 - beta1, grad)
                exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)
                denom = exp_avg_sq.sqrt().add_(group['eps'])
                bias_correction1 = 1 - beta1 ** state['step'].item()
                bias_correction2 = 1 - beta2 ** state['step'].item()
                step_size = group['lr'] * math.sqrt(bias_correction2) / bias_correction1

                p.data.addcdiv_(-step_size, exp_avg, denom)
        return loss
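
For orientation, here is a minimal Hogwild-style usage sketch of GlobalAdam. It is illustrative only, not code from this PR: the worker function, dummy loss, and process count are made up, and it assumes the 'fork' start method (Linux default), as in the pytorch-a3c reference the class is adapted from.

import torch
import torch.multiprocessing as mp
import torch.nn as nn

from slm_lab.lib.optimizer import GlobalAdam


def worker(global_net, global_optim):
    # each worker computes gradients against the shared parameters
    # and steps the shared optimizer (lock-free, Hogwild-style)
    for _ in range(10):
        x = torch.randn(8, 4)
        loss = global_net(x).pow(2).mean()  # dummy loss for illustration
        global_optim.zero_grad()
        loss.backward()
        global_optim.step()  # updates shared p.data and shared exp_avg/exp_avg_sq/step


if __name__ == '__main__':
    global_net = nn.Linear(4, 2)
    global_net.share_memory()  # parameters live in shared memory for Hogwild
    global_optim = GlobalAdam(global_net.parameters(), lr=1e-4)
    global_optim.share_memory()  # share optimizer state across worker processes
    workers = [mp.Process(target=worker, args=(global_net, global_optim)) for _ in range(2)]
    for w in workers:
        w.start()
    for w in workers:
        w.join()

Because __init__ eagerly allocates exp_avg, exp_avg_sq, and step for every parameter, share_memory() can move the full optimizer state into shared memory before any worker starts; this is the point of the class versus stock torch.optim.Adam, whose state is created lazily on the first step.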