This repository has been archived by the owner on Sep 18, 2024. It is now read-only.

[Retiarii] Policy-based RL Strategy #3650

Merged (7 commits) on May 25, 2021
2 changes: 2 additions & 0 deletions dependencies/recommended.txt
@@ -10,3 +10,5 @@ pytorch-lightning >= 1.1.1, < 1.2
onnx
peewee
graphviz
gym
tianshou >= 0.4.1
2 changes: 2 additions & 0 deletions dependencies/recommended_legacy.txt
@@ -11,3 +11,5 @@ keras == 2.1.6
onnx
peewee
graphviz
gym
tianshou >= 0.4.1
2 changes: 2 additions & 0 deletions docs/requirements.txt
@@ -22,5 +22,7 @@ prettytable
psutil
ruamel.yaml
ipython
gym
tianshou
https://download.pytorch.org/whl/cpu/torch-1.7.1%2Bcpu-cp37-cp37m-linux_x86_64.whl
https://download.pytorch.org/whl/cpu/torchvision-0.8.2%2Bcpu-cp37-cp37m-linux_x86_64.whl
1 change: 1 addition & 0 deletions nni/retiarii/strategy/__init__.py
@@ -6,3 +6,4 @@
from .evolution import RegularizedEvolution
from .tpe_strategy import TPEStrategy
from .local_debug_strategy import _LocalDebugStrategy
from .rl import PolicyBasedRL
121 changes: 121 additions & 0 deletions nni/retiarii/strategy/_rl_impl.py
@@ -0,0 +1,121 @@
# This file might cause an import error for users who haven't installed the RL-related dependencies (gym, tianshou)

import logging

import gym
import numpy as np
import torch
import torch.nn as nn

from gym import spaces
from tianshou.data import to_torch

from .utils import get_targeted_model
from ..graph import ModelStatus
from ..execution import submit_models, wait_models


_logger = logging.getLogger(__name__)


class ModelEvaluationEnv(gym.Env):
    def __init__(self, base_model, mutators, search_space):
        self.base_model = base_model
        self.mutators = mutators
        self.search_space = search_space
        self.ss_keys = list(self.search_space.keys())
        self.action_dim = max(map(lambda v: len(v), self.search_space.values()))
        self.num_steps = len(self.search_space)

    @property
    def observation_space(self):
        return spaces.Dict({
            'action_history': spaces.MultiDiscrete([self.action_dim] * self.num_steps),
            'cur_step': spaces.Discrete(self.num_steps + 1),
            'action_dim': spaces.Discrete(self.action_dim + 1)
        })

    @property
    def action_space(self):
        return spaces.Discrete(self.action_dim)

    def reset(self):
        self.action_history = np.zeros(self.num_steps, dtype=np.int32)
        self.cur_step = 0
        self.sample = {}
        return {
            'action_history': self.action_history,
            'cur_step': self.cur_step,
            'action_dim': len(self.search_space[self.ss_keys[self.cur_step]])
        }

    def step(self, action):
        cur_key = self.ss_keys[self.cur_step]
        assert action < len(self.search_space[cur_key]), \
            f'Current action {action} out of range {self.search_space[cur_key]}.'
        self.action_history[self.cur_step] = action
        self.sample[cur_key] = self.search_space[cur_key][action]
        self.cur_step += 1
        obs = {
            'action_history': self.action_history,
            'cur_step': self.cur_step,
            'action_dim': len(self.search_space[self.ss_keys[self.cur_step]]) \
                if self.cur_step < self.num_steps else self.action_dim
        }
        if self.cur_step == self.num_steps:
            model = get_targeted_model(self.base_model, self.mutators, self.sample)
            _logger.info(f'New model created: {self.sample}')
            submit_models(model)
            wait_models(model)
            if model.status == ModelStatus.Failed:
                return self.reset(), 0., False, {}
            rew = model.metric
            _logger.info(f'Model metric received as reward: {rew}')
            return obs, rew, True, {}
        else:
            return obs, 0., False, {}
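
For a concrete picture of what this environment exposes, the sketch below instantiates it on a made-up search space and inspects the spaces and the first observation. Passing ``None`` for ``base_model`` and ``mutators`` is only acceptable here because they are not touched until the final ``step`` submits a model; a real run goes through ``PolicyBasedRL`` instead, and the search-space values are invented for illustration.

```python
from nni.retiarii.strategy._rl_impl import ModelEvaluationEnv

# Hypothetical search space: two choice points with 3 and 2 options respectively.
search_space = {'conv_size': [3, 5, 7], 'use_bias': [False, True]}

# base_model/mutators are only needed when the last step submits a model,
# so None is enough for inspecting the spaces here.
env = ModelEvaluationEnv(None, None, search_space)
print(env.action_space)       # Discrete(3): the largest choice has 3 options
print(env.observation_space)  # Dict space with 'action_history', 'cur_step' and 'action_dim' entries
obs = env.reset()
print(obs)                    # {'action_history': array([0, 0]), 'cur_step': 0, 'action_dim': 3}
obs, rew, done, _ = env.step(1)  # pick conv_size=5; reward stays 0 until the last choice is made
```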


class Preprocessor(nn.Module):
    def __init__(self, obs_space, hidden_dim=64, num_layers=1):
        super().__init__()
        self.action_dim = obs_space['action_history'].nvec[0]
        self.hidden_dim = hidden_dim
        # first token is [SOS]
        self.embedding = nn.Embedding(self.action_dim + 1, hidden_dim)
        self.rnn = nn.LSTM(hidden_dim, hidden_dim, num_layers, batch_first=True)

    def forward(self, obs):
        seq = nn.functional.pad(obs['action_history'] + 1, (1, 1))  # pad the start token and end token
        # end token is used to avoid out-of-range of v_s_. Will not actually affect BP.
        seq = self.embedding(seq.long())
        feature, _ = self.rnn(seq)
        return feature[torch.arange(len(feature), device=feature.device), obs['cur_step'].long() + 1]


class Actor(nn.Module):
    def __init__(self, action_space, preprocess):
        super().__init__()
        self.preprocess = preprocess
        self.action_dim = action_space.n
        self.linear = nn.Linear(self.preprocess.hidden_dim, self.action_dim)

    def forward(self, obs, **kwargs):
        obs = to_torch(obs, device=self.linear.weight.device)
        out = self.linear(self.preprocess(obs))
        # to take care of choices with different number of options
        mask = torch.arange(self.action_dim).expand(len(out), self.action_dim) >= obs['action_dim'].unsqueeze(1)
        out[mask.to(out.device)] = float('-inf')
        return nn.functional.softmax(out), kwargs.get('state', None)


class Critic(nn.Module):
    def __init__(self, preprocess):
        super().__init__()
        self.preprocess = preprocess
        self.linear = nn.Linear(self.preprocess.hidden_dim, 1)

    def forward(self, obs, **kwargs):
        obs = to_torch(obs, device=self.linear.weight.device)
        return self.linear(self.preprocess(obs)).squeeze(-1)
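
To see how the actor handles choices with different numbers of options, the following self-contained sketch wires the three modules above together on a toy observation space (two choice points, with 3 and 2 valid options). It assumes torch, gym and tianshou are installed and imports the classes from this new module; the toy numbers and observations are made up for illustration.

```python
import numpy as np
from gym import spaces
from nni.retiarii.strategy._rl_impl import Preprocessor, Actor, Critic

action_dim, num_steps = 3, 2   # largest choice has 3 options, 2 choice points in total
obs_space = spaces.Dict({
    'action_history': spaces.MultiDiscrete([action_dim] * num_steps),
    'cur_step': spaces.Discrete(num_steps + 1),
    'action_dim': spaces.Discrete(action_dim + 1),
})
net = Preprocessor(obs_space)
actor, critic = Actor(spaces.Discrete(action_dim), net), Critic(net)

# A batch of two observations: the first is at step 0 where 3 options are valid,
# the second is at step 1 where only 2 options are valid.
obs = {
    'action_history': np.array([[0, 0], [2, 0]]),
    'cur_step': np.array([0, 1]),
    'action_dim': np.array([3, 2]),
}
probs, _ = actor(obs)
print(probs)        # second row ends with probability 0 thanks to the -inf mask
print(critic(obs))  # one baseline value per observation
```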
92 changes: 92 additions & 0 deletions nni/retiarii/strategy/rl.py
@@ -0,0 +1,92 @@
import logging
from typing import Optional, Callable

from .base import BaseStrategy
from .utils import dry_run_for_search_space
from ..execution import query_available_resources

try:
    has_tianshou = True
    import torch
    from tianshou.data import AsyncCollector, Collector, VectorReplayBuffer
    from tianshou.env import SubprocVectorEnv
    from tianshou.policy import BasePolicy, PPOPolicy  # pylint: disable=unused-import
    from ._rl_impl import ModelEvaluationEnv, Preprocessor, Actor, Critic
except ImportError:
    has_tianshou = False


_logger = logging.getLogger(__name__)


class PolicyBasedRL(BaseStrategy):
    """
    Algorithm for policy-based reinforcement learning.
    This is a wrapper of algorithms provided in tianshou (PPO by default),
    and can be easily customized with other algorithms that inherit ``BasePolicy`` (e.g., REINFORCE [1]_).

    Note that RL algorithms are known to have issues on Windows and macOS. They will be supported in the future.

    Parameters
    ----------
    max_collect : int
        How many times the collector runs to collect trials for RL. Default: 100.
    trial_per_collect : int
        How many trials (trajectories) the collector collects in each run.
        After each collect, the trainer samples a batch from the replay buffer and performs the update. Default: 20.
    policy_fn : function
        Takes ``ModelEvaluationEnv`` as input and returns a policy. See ``_default_policy_fn`` for an example.
    asynchronous : bool
        If true, the collector won't wait for all the environments to complete in each step.
        This should generally not affect the result, but might affect efficiency. Note that slightly more trials
        than expected might be collected if this is enabled.
        If asynchronous is false, the collector will wait for all parallel environments to complete in each step.
        See ``tianshou.data.AsyncCollector`` for more details.
[Inline review comment on the ``asynchronous`` parameter]
Contributor: I don't understand why asynchronous does not affect the result.
Contributor (author): Synchronous doesn't mean single-process sampling; both the synchronous and the asynchronous collectors sample in parallel. "Asynchronous" adds a mechanism to give up on an environment that hasn't finished yet in a given step. Refer to https://tianshou.readthedocs.io/en/master/tutorials/cheatsheet.html#parallel-sampling if you are interested; it's a bit complicated and hard to make clear in a few words.

    References
    ----------

    .. [1] Barret Zoph and Quoc V. Le, "Neural Architecture Search with Reinforcement Learning".
        https://arxiv.org/abs/1611.01578
    """

    def __init__(self, max_collect: int = 100, trial_per_collect = 20,
                 policy_fn: Optional[Callable[['ModelEvaluationEnv'], 'BasePolicy']] = None, asynchronous: bool = True):
        if not has_tianshou:
            raise ImportError('`tianshou` is required to run RL-based strategy. '
                              'Please use "pip install tianshou" to install it beforehand.')

        self.policy_fn = policy_fn or self._default_policy_fn
        self.max_collect = max_collect
        self.trial_per_collect = trial_per_collect
        self.asynchronous = asynchronous

    @staticmethod
    def _default_policy_fn(env):
        net = Preprocessor(env.observation_space)
        actor = Actor(env.action_space, net)
        critic = Critic(net)
        optim = torch.optim.Adam(set(actor.parameters()).union(critic.parameters()), lr=1e-4)
        return PPOPolicy(actor, critic, optim, torch.distributions.Categorical,
                         discount_factor=1., action_space=env.action_space)
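
Since ``policy_fn`` is the main customization hook, here is a sketch of supplying one. It reuses the PPO wiring from ``_default_policy_fn`` above with different hyper-parameters; any tianshou policy that inherits ``BasePolicy`` could be returned instead. The hyper-parameter values are illustrative, not tuned.

```python
import torch
from tianshou.policy import PPOPolicy
from nni.retiarii.strategy import PolicyBasedRL
from nni.retiarii.strategy._rl_impl import Preprocessor, Actor, Critic

def my_policy_fn(env):
    # Same PPO setup as the default, with a wider hidden layer and a larger learning rate.
    net = Preprocessor(env.observation_space, hidden_dim=128)
    actor = Actor(env.action_space, net)
    critic = Critic(net)
    optim = torch.optim.Adam(set(actor.parameters()).union(critic.parameters()), lr=3e-4)
    return PPOPolicy(actor, critic, optim, torch.distributions.Categorical,
                     discount_factor=1., action_space=env.action_space)

search_strategy = PolicyBasedRL(max_collect=50, trial_per_collect=10, policy_fn=my_policy_fn)
```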

    def run(self, base_model, applied_mutators):
        search_space = dry_run_for_search_space(base_model, applied_mutators)
        concurrency = query_available_resources()

        env_fn = lambda: ModelEvaluationEnv(base_model, applied_mutators, search_space)
        policy = self.policy_fn(env_fn())

        if self.asynchronous:
            # wait for half of the envs to complete in each step
            env = SubprocVectorEnv([env_fn for _ in range(concurrency)], wait_num=int(concurrency * 0.5))
            collector = AsyncCollector(policy, env, VectorReplayBuffer(20000, len(env)))
        else:
            env = SubprocVectorEnv([env_fn for _ in range(concurrency)])
            collector = Collector(policy, env, VectorReplayBuffer(20000, len(env)))

        for cur_collect in range(1, self.max_collect + 1):
            _logger.info('Collect [%d] Running...', cur_collect)
            result = collector.collect(n_episode=self.trial_per_collect)
            _logger.info('Collect [%d] Result: %s', cur_collect, str(result))
            policy.update(0, collector.buffer, batch_size=64, repeat=5)
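
To show how this strategy is typically plugged in, here is a hedged end-to-end sketch against the NNI 2.x Retiarii experiment API. The toy ``Net`` and the ``evaluator`` placeholder are inventions for illustration, and experiment-side names such as ``RetiariiExperiment`` and ``RetiariiExeConfig`` are assumptions based on the NNI release this PR targets rather than part of this diff.

```python
import nni.retiarii.nn.pytorch as nn
import nni.retiarii.strategy as strategy
from nni.retiarii.experiment.pytorch import RetiariiExperiment, RetiariiExeConfig


class Net(nn.Module):
    """Toy search space with a single LayerChoice, in the spirit of the unit test below."""
    def __init__(self):
        super().__init__()
        self.fc = nn.LayerChoice([
            nn.Linear(32, 10, bias=True),
            nn.Linear(32, 10, bias=False),
        ], label='fc')

    def forward(self, x):
        return self.fc(x)


search_strategy = strategy.PolicyBasedRL(max_collect=10, trial_per_collect=10, asynchronous=False)

# `evaluator` is a placeholder for whatever trains and scores a sampled model
# (e.g., a pytorch-lightning based classification evaluator).
exp = RetiariiExperiment(Net(), evaluator, [], search_strategy)
exp_config = RetiariiExeConfig('local')
exp_config.trial_concurrency = 4   # the RL strategy runs this many environments in parallel
exp_config.max_trial_number = 200
exp.run(exp_config, 8081)          # the port number is arbitrary
```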
1 change: 1 addition & 0 deletions pipelines/full-test-linux.yml
@@ -35,6 +35,7 @@ jobs:
python3 -m pip install keras==2.1.6
python3 -m pip install tensorflow==2.3.1 tensorflow-estimator==2.3.0
python3 -m pip install thop
python3 -m pip install tianshou>=0.4.1 gym
sudo apt-get install swig -y
displayName: Install extra dependencies

1 change: 1 addition & 0 deletions pipelines/full-test-windows.yml
@@ -30,6 +30,7 @@ jobs:
python -m pip install torch==1.6.0 torchvision==0.7.0 -f https://download.pytorch.org/whl/torch_stable.html
python -m pip install 'pytorch-lightning>=1.1.1,<1.2'
python -m pip install tensorflow==2.3.1 tensorflow-estimator==2.3.0
python -m pip install tianshou>=0.4.1 gym
displayName: Install extra dependencies

# Need del later
28 changes: 24 additions & 4 deletions test/ut/retiarii/test_strategy.py
@@ -1,11 +1,13 @@
import random
import sys
import time
import threading
from typing import *

import nni.retiarii.execution.api
import nni.retiarii.nn.pytorch as nn
import nni.retiarii.strategy as strategy
import pytest
import torch
import torch.nn.functional as F
from nni.retiarii import Model
@@ -58,7 +60,7 @@ def _reset_execution_engine(engine=None):


class Net(nn.Module):
    def __init__(self, hidden_size=32):
    def __init__(self, hidden_size=32, diff_size=False):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 20, 5, 1)
        self.conv2 = nn.Conv2d(20, 50, 5, 1)
@@ -69,7 +71,7 @@ def __init__(self, hidden_size=32):
        self.fc2 = nn.LayerChoice([
            nn.Linear(hidden_size, 10, bias=False),
            nn.Linear(hidden_size, 10, bias=True)
        ], label='fc2')
        ] + ([] if not diff_size else [nn.Linear(hidden_size, 10, bias=False)]), label='fc2')

    def forward(self, x):
        x = F.relu(self.conv1(x))
@@ -82,8 +84,8 @@ def forward(self, x):
        return F.log_softmax(x, dim=1)


def _get_model_and_mutators():
    base_model = Net()
def _get_model_and_mutators(**kwargs):
    base_model = Net(**kwargs)
    script_module = torch.jit.script(base_model)
    base_model_ir = convert_to_graph(script_module, base_model)
    base_model_ir.evaluator = DebugEvaluator()
@@ -139,7 +141,25 @@ def test_evolution():
    _reset_execution_engine()


@pytest.mark.skipif(sys.platform in ('win32', 'darwin'), reason='Does not run on Windows and MacOS')
def test_rl():
    rl = strategy.PolicyBasedRL(max_collect=2, trial_per_collect=10)
    engine = MockExecutionEngine(failure_prob=0.2)
    _reset_execution_engine(engine)
    rl.run(*_get_model_and_mutators(diff_size=True))
    wait_models(*engine.models)
    _reset_execution_engine()

    rl = strategy.PolicyBasedRL(max_collect=2, trial_per_collect=10, asynchronous=False)
    engine = MockExecutionEngine(failure_prob=0.2)
    _reset_execution_engine(engine)
    rl.run(*_get_model_and_mutators())
    wait_models(*engine.models)
    _reset_execution_engine()


if __name__ == '__main__':
    test_grid_search()
    test_random_search()
    test_evolution()
    test_rl()