[Retiarii] Policy-based RL Strategy #3650
Merged

Commits (7):
- 8b47630 Initiate RL strategy (ultmaster)
- 725e8a0 Simple RL strategy (ultmaster)
- 8c5ad65 Update RL implementation and tests (ultmaster)
- eea6888 Add docstring and dependency (ultmaster)
- c1ad4b3 Update requirements and lint (ultmaster)
- a4c482b Skip tests on windows and MacOS (ultmaster)
- f5ac5d0 Update comments for async (ultmaster)
@@ -10,3 +10,5 @@ pytorch-lightning >= 1.1.1, < 1.2
 onnx
 peewee
 graphviz
+gym
+tianshou >= 0.4.1
@@ -11,3 +11,5 @@ keras == 2.1.6
 onnx
 peewee
 graphviz
+gym
+tianshou >= 0.4.1
@@ -0,0 +1,121 @@
# This file might cause import error for those who didn't install RL-related dependencies

import logging

import gym
import numpy as np
import torch
import torch.nn as nn

from gym import spaces
from tianshou.data import to_torch

from .utils import get_targeted_model
from ..graph import ModelStatus
from ..execution import submit_models, wait_models


_logger = logging.getLogger(__name__)


class ModelEvaluationEnv(gym.Env):
    def __init__(self, base_model, mutators, search_space):
        self.base_model = base_model
        self.mutators = mutators
        self.search_space = search_space
        self.ss_keys = list(self.search_space.keys())
        self.action_dim = max(map(lambda v: len(v), self.search_space.values()))
        self.num_steps = len(self.search_space)

    @property
    def observation_space(self):
        return spaces.Dict({
            'action_history': spaces.MultiDiscrete([self.action_dim] * self.num_steps),
            'cur_step': spaces.Discrete(self.num_steps + 1),
            'action_dim': spaces.Discrete(self.action_dim + 1)
        })

    @property
    def action_space(self):
        return spaces.Discrete(self.action_dim)

    def reset(self):
        self.action_history = np.zeros(self.num_steps, dtype=np.int32)
        self.cur_step = 0
        self.sample = {}
        return {
            'action_history': self.action_history,
            'cur_step': self.cur_step,
            'action_dim': len(self.search_space[self.ss_keys[self.cur_step]])
        }

    def step(self, action):
        cur_key = self.ss_keys[self.cur_step]
        assert action < len(self.search_space[cur_key]), \
            f'Current action {action} out of range {self.search_space[cur_key]}.'
        self.action_history[self.cur_step] = action
        self.sample[cur_key] = self.search_space[cur_key][action]
        self.cur_step += 1
        obs = {
            'action_history': self.action_history,
            'cur_step': self.cur_step,
            'action_dim': len(self.search_space[self.ss_keys[self.cur_step]]) \
                if self.cur_step < self.num_steps else self.action_dim
        }
        if self.cur_step == self.num_steps:
            model = get_targeted_model(self.base_model, self.mutators, self.sample)
            _logger.info(f'New model created: {self.sample}')
            submit_models(model)
            wait_models(model)
            if model.status == ModelStatus.Failed:
                return self.reset(), 0., False, {}
            rew = model.metric
            _logger.info(f'Model metric received as reward: {rew}')
            return obs, rew, True, {}
        else:
            return obs, 0., False, {}


class Preprocessor(nn.Module):
    def __init__(self, obs_space, hidden_dim=64, num_layers=1):
        super().__init__()
        self.action_dim = obs_space['action_history'].nvec[0]
        self.hidden_dim = hidden_dim
        # first token is [SOS]
        self.embedding = nn.Embedding(self.action_dim + 1, hidden_dim)
        self.rnn = nn.LSTM(hidden_dim, hidden_dim, num_layers, batch_first=True)

    def forward(self, obs):
        seq = nn.functional.pad(obs['action_history'] + 1, (1, 1))  # pad the start token and end token
        # end token is used to avoid out-of-range of v_s_. Will not actually affect BP.
        seq = self.embedding(seq.long())
        feature, _ = self.rnn(seq)
        return feature[torch.arange(len(feature), device=feature.device), obs['cur_step'].long() + 1]


class Actor(nn.Module):
    def __init__(self, action_space, preprocess):
        super().__init__()
        self.preprocess = preprocess
        self.action_dim = action_space.n
        self.linear = nn.Linear(self.preprocess.hidden_dim, self.action_dim)

    def forward(self, obs, **kwargs):
        obs = to_torch(obs, device=self.linear.weight.device)
        out = self.linear(self.preprocess(obs))
        # to take care of choices with different number of options
        mask = torch.arange(self.action_dim).expand(len(out), self.action_dim) >= obs['action_dim'].unsqueeze(1)
        out[mask.to(out.device)] = float('-inf')
        return nn.functional.softmax(out), kwargs.get('state', None)


class Critic(nn.Module):
    def __init__(self, preprocess):
        super().__init__()
        self.preprocess = preprocess
        self.linear = nn.Linear(self.preprocess.hidden_dim, 1)

    def forward(self, obs, **kwargs):
        obs = to_torch(obs, device=self.linear.weight.device)
        return self.linear(self.preprocess(obs)).squeeze(-1)
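The environment above encodes one architecture decision per step: the observation carries a fixed-length action history, the current step index, and the number of valid options for that step, and the final step submits the sampled model and returns its metric as the reward. The following standalone sketch (not part of this PR) walks a made-up two-key search space with a random agent; it reproduces only the observation bookkeeping of `ModelEvaluationEnv` and skips model submission and reward, which require the NNI execution engine.

import numpy as np

# Hypothetical search space: two choices with different numbers of options.
search_space = {'conv_size': [3, 5, 7], 'pooling': ['max', 'avg']}
ss_keys = list(search_space.keys())
action_dim = max(len(v) for v in search_space.values())  # padded upper bound (3)
num_steps = len(search_space)                            # 2 decisions per episode

action_history = np.zeros(num_steps, dtype=np.int32)
sample = {}
for cur_step, key in enumerate(ss_keys):
    n_options = len(search_space[key])
    action = np.random.randint(n_options)    # in the real env, the RL policy picks this
    action_history[cur_step] = action
    sample[key] = search_space[key][action]
    # Observation handed to the policy for the next decision:
    obs = {
        'action_history': action_history.copy(),
        'cur_step': cur_step + 1,
        'action_dim': len(search_space[ss_keys[cur_step + 1]]) if cur_step + 1 < num_steps else action_dim,
    }
    print(obs)

# In the real environment, the completed `sample` is turned into a model via
# get_targeted_model(...), submitted for training, and the metric becomes the reward.
print(sample)  # e.g. {'conv_size': 5, 'pooling': 'max'}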
@@ -0,0 +1,92 @@
import logging
from typing import Optional, Callable

from .base import BaseStrategy
from .utils import dry_run_for_search_space
from ..execution import query_available_resources

try:
    has_tianshou = True
    import torch
    from tianshou.data import AsyncCollector, Collector, VectorReplayBuffer
    from tianshou.env import SubprocVectorEnv
    from tianshou.policy import BasePolicy, PPOPolicy  # pylint: disable=unused-import
    from ._rl_impl import ModelEvaluationEnv, Preprocessor, Actor, Critic
except ImportError:
    has_tianshou = False


_logger = logging.getLogger(__name__)


class PolicyBasedRL(BaseStrategy):
    """
    Algorithm for policy-based reinforcement learning.
    This is a wrapper of algorithms provided in tianshou (PPO by default),
    and can be easily customized with other algorithms that inherit ``BasePolicy`` (e.g., REINFORCE [1]_).

    Note that RL algorithms are known to have issues on Windows and macOS. Those platforms will be supported in the future.

    Parameters
    ----------
    max_collect : int
        How many times the collector runs to collect trials for RL. Default: 100.
    trial_per_collect : int
        How many trials (trajectories) the collector collects in each run.
        After each collect, the trainer samples a batch from the replay buffer and performs the update. Default: 20.
    policy_fn : function
        Takes ``ModelEvaluationEnv`` as input and returns a policy. See ``_default_policy_fn`` for an example.
    asynchronous : bool
        If true, the collector won't wait for all environments to complete in each step.
        This should generally not affect the result, but might affect efficiency. Note that slightly more trials
        than expected might be collected when this is enabled.
        If asynchronous is false, the collector will wait for all parallel environments to complete in each step.
        See ``tianshou.data.AsyncCollector`` for more details.

    References
    ----------

    .. [1] Barret Zoph and Quoc V. Le, "Neural Architecture Search with Reinforcement Learning".
        https://arxiv.org/abs/1611.01578
    """

    def __init__(self, max_collect: int = 100, trial_per_collect = 20,
                 policy_fn: Optional[Callable[['ModelEvaluationEnv'], 'BasePolicy']] = None, asynchronous: bool = True):
        if not has_tianshou:
            raise ImportError('`tianshou` is required to run RL-based strategy. '
                              'Please use "pip install tianshou" to install it beforehand.')

        self.policy_fn = policy_fn or self._default_policy_fn
        self.max_collect = max_collect
        self.trial_per_collect = trial_per_collect
        self.asynchronous = asynchronous

    @staticmethod
    def _default_policy_fn(env):
        net = Preprocessor(env.observation_space)
        actor = Actor(env.action_space, net)
        critic = Critic(net)
        optim = torch.optim.Adam(set(actor.parameters()).union(critic.parameters()), lr=1e-4)
        return PPOPolicy(actor, critic, optim, torch.distributions.Categorical,
                         discount_factor=1., action_space=env.action_space)

    def run(self, base_model, applied_mutators):
        search_space = dry_run_for_search_space(base_model, applied_mutators)
        concurrency = query_available_resources()

        env_fn = lambda: ModelEvaluationEnv(base_model, applied_mutators, search_space)
        policy = self.policy_fn(env_fn())

        if self.asynchronous:
            # wait for half of the envs to complete in each step
            env = SubprocVectorEnv([env_fn for _ in range(concurrency)], wait_num=int(concurrency * 0.5))
            collector = AsyncCollector(policy, env, VectorReplayBuffer(20000, len(env)))
        else:
            env = SubprocVectorEnv([env_fn for _ in range(concurrency)])
            collector = Collector(policy, env, VectorReplayBuffer(20000, len(env)))

        for cur_collect in range(1, self.max_collect + 1):
            _logger.info('Collect [%d] Running...', cur_collect)
            result = collector.collect(n_episode=self.trial_per_collect)
            _logger.info('Collect [%d] Result: %s', cur_collect, str(result))
            policy.update(0, collector.buffer, batch_size=64, repeat=5)
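The class docstring above notes that the strategy can be customized with any algorithm inheriting `BasePolicy` by passing a `policy_fn`. As a rough illustration only (not part of this PR), a custom `policy_fn` that mirrors `_default_policy_fn` but swaps PPO for tianshou's A2C could look like this; the import path for `Preprocessor`, `Actor` and `Critic` is an assumption about where `_rl_impl.py` lands in the package.

import torch
from tianshou.policy import A2CPolicy

# Assumed import path for the classes added in _rl_impl.py above.
from nni.retiarii.strategy._rl_impl import Preprocessor, Actor, Critic


def a2c_policy_fn(env):
    # Same observation encoder and actor/critic heads as _default_policy_fn.
    net = Preprocessor(env.observation_space)
    actor = Actor(env.action_space, net)
    critic = Critic(net)
    optim = torch.optim.Adam(set(actor.parameters()).union(critic.parameters()), lr=1e-3)
    # A2C shares PPO's constructor shape in tianshou, so it drops in directly;
    # any BasePolicy subclass with a compatible interface would work here.
    return A2CPolicy(actor, critic, optim, torch.distributions.Categorical,
                     discount_factor=1., action_space=env.action_space)

# strategy = PolicyBasedRL(max_collect=50, policy_fn=a2c_policy_fn)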
Review comment: I don't understand. Why does asynchronous not affect the result?
Reply: Synchronous doesn't mean single-process sampling; both the synchronous and asynchronous modes sample in parallel. "Asynchronous" adds a mechanism to give up on some environments when they have not finished yet.
Refer to https://tianshou.readthedocs.io/en/master/tutorials/cheatsheet.html#parallel-sampling if you are interested. It's a bit complicated and I don't think I can make it clear here in a few words.
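To make that distinction concrete, here is a small sketch (not from this PR, using CartPole as a stand-in for `ModelEvaluationEnv`): both modes build `concurrency` parallel subprocess environments; the asynchronous variant only sets `wait_num`, so a collect step returns once that many environments have finished instead of waiting for all of them.

import gym
from tianshou.env import SubprocVectorEnv

concurrency = 4
env_fn = lambda: gym.make('CartPole-v0')  # stand-in for ModelEvaluationEnv

# Synchronous mode: each step of the vectorized env (and hence of Collector)
# returns only after all `concurrency` environments have finished stepping.
sync_env = SubprocVectorEnv([env_fn for _ in range(concurrency)])

# Asynchronous mode: a step returns as soon as `wait_num` environments are done,
# so one slow trial cannot stall the rest. AsyncCollector is paired with such an
# env, which is why a few more trials than requested may end up being collected.
async_env = SubprocVectorEnv([env_fn for _ in range(concurrency)],
                             wait_num=concurrency // 2)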