
Commit

Pass mypy
lihuoran committed Jun 19, 2022
1 parent d32d8e0 commit aa35681
Showing 41 changed files with 624 additions and 485 deletions.
2 changes: 1 addition & 1 deletion examples/cim/rl/algorithms/ac.py
@@ -54,9 +54,9 @@ def get_ac_policy(state_dim: int, action_num: int, name: str) -> DiscretePolicyG
def get_ac(state_dim: int, name: str) -> ActorCriticTrainer:
return ActorCriticTrainer(
name=name,
reward_discount=0.0,
params=ActorCriticParams(
get_v_critic_net_func=lambda: MyCriticNet(state_dim),
reward_discount=0.0,
grad_iters=10,
critic_loss_cls=torch.nn.SmoothL1Loss,
min_logp=None,
6 changes: 3 additions & 3 deletions examples/cim/rl/algorithms/dqn.py
@@ -55,14 +55,14 @@ def get_dqn_policy(state_dim: int, action_num: int, name: str) -> ValueBasedPoli
def get_dqn(name: str) -> DQNTrainer:
return DQNTrainer(
name=name,
reward_discount=0.0,
replay_memory_capacity=10000,
batch_size=32,
params=DQNParams(
reward_discount=0.0,
update_target_every=5,
num_epochs=10,
soft_update_coef=0.1,
double=False,
replay_memory_capacity=10000,
random_overwrite=False,
batch_size=32,
),
)
2 changes: 1 addition & 1 deletion examples/cim/rl/algorithms/maddpg.py
@@ -62,8 +62,8 @@ def get_maddpg_policy(state_dim: int, action_num: int, name: str) -> DiscretePol
def get_maddpg(state_dim: int, action_dims: List[int], name: str) -> DiscreteMADDPGTrainer:
return DiscreteMADDPGTrainer(
name=name,
reward_discount=0.0,
params=DiscreteMADDPGParams(
reward_discount=0.0,
num_epoch=10,
get_q_critic_net_func=partial(get_multi_critic_net, state_dim, action_dims),
shared_critic=False,
2 changes: 1 addition & 1 deletion examples/cim/rl/algorithms/ppo.py
@@ -16,9 +16,9 @@ def get_ppo_policy(state_dim: int, action_num: int, name: str) -> DiscretePolicy
def get_ppo(state_dim: int, name: str) -> PPOTrainer:
return PPOTrainer(
name=name,
reward_discount=0.0,
params=PPOParams(
get_v_critic_net_func=lambda: MyCriticNet(state_dim),
reward_discount=0.0,
grad_iters=10,
critic_loss_cls=torch.nn.SmoothL1Loss,
min_logp=None,
2 changes: 1 addition & 1 deletion examples/vm_scheduling/rl/algorithms/ac.py
@@ -61,9 +61,9 @@ def get_ac_policy(state_dim: int, action_num: int, num_features: int, name: str)
def get_ac(state_dim: int, num_features: int, name: str) -> ActorCriticTrainer:
return ActorCriticTrainer(
name=name,
reward_discount=0.9,
params=ActorCriticParams(
get_v_critic_net_func=lambda: MyCriticNet(state_dim, num_features),
reward_discount=0.9,
grad_iters=100,
critic_loss_cls=torch.nn.MSELoss,
min_logp=-20,
8 changes: 4 additions & 4 deletions examples/vm_scheduling/rl/algorithms/dqn.py
@@ -77,15 +77,15 @@ def get_dqn_policy(state_dim: int, action_num: int, num_features: int, name: str
def get_dqn(name: str) -> DQNTrainer:
return DQNTrainer(
name=name,
reward_discount=0.9,
replay_memory_capacity=10000,
batch_size=32,
data_parallelism=2,
params=DQNParams(
reward_discount=0.9,
update_target_every=5,
num_epochs=100,
soft_update_coef=0.1,
double=False,
replay_memory_capacity=10000,
random_overwrite=False,
batch_size=32,
data_parallelism=2,
),
)
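The example edits above (cim and vm_scheduling; ac, dqn, maddpg, ppo) all move trainer hyperparameters such as reward_discount, replay_memory_capacity, batch_size, and data_parallelism between the trainer constructor and the algorithm-specific *Params object, so each keyword argument lands on the signature that actually declares it. A minimal sketch of why mypy cares about this placement, using hypothetical simplified names rather than MARO's real classes:

```python
# Hypothetical, simplified illustration -- not MARO's actual API.
from dataclasses import dataclass


@dataclass
class AlgoParams:
    """Algorithm-specific hyperparameters, declared as typed fields."""
    reward_discount: float = 0.9
    num_epochs: int = 10


class Trainer:
    def __init__(self, name: str, params: AlgoParams, batch_size: int = 32) -> None:
        self.name = name
        self.params = params
        self.batch_size = batch_size


# mypy accepts this call because every keyword matches a declared parameter:
trainer = Trainer(name="dqn", params=AlgoParams(reward_discount=0.0), batch_size=32)

# Passing a keyword on the wrong layer, e.g. Trainer(..., reward_discount=0.0),
# would be rejected with "unexpected keyword argument" -- the kind of mismatch
# these example edits resolve.
```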
4 changes: 4 additions & 0 deletions maro/rl/distributed/__init__.py
@@ -3,8 +3,12 @@

from .abs_proxy import AbsProxy
from .abs_worker import AbsWorker
from .port_config import DEFAULT_ROLLOUT_PRODUCER_PORT, DEFAULT_TRAINING_BACKEND_PORT, DEFAULT_TRAINING_FRONTEND_PORT

__all__ = [
"AbsProxy",
"AbsWorker",
"DEFAULT_ROLLOUT_PRODUCER_PORT",
"DEFAULT_TRAINING_FRONTEND_PORT",
"DEFAULT_TRAINING_BACKEND_PORT",
]
3 changes: 2 additions & 1 deletion maro/rl/distributed/abs_worker.py
@@ -2,6 +2,7 @@
# Licensed under the MIT license.

from abc import abstractmethod
from typing import Union

import zmq
from tornado.ioloop import IOLoop
@@ -33,7 +34,7 @@ def __init__(
super(AbsWorker, self).__init__()

self._id = f"worker.{idx}"
self._logger = logger if logger else DummyLogger()
self._logger: Union[LoggerV2, DummyLogger] = logger if logger else DummyLogger()

# ZMQ sockets and streams
self._context = Context.instance()
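The abs_worker.py change adds an explicit Union annotation because the fallback branch assigns a DummyLogger while callers may pass a LoggerV2; without the annotation, mypy infers the attribute's type from a single branch and then rejects the other. A sketch of the pattern with placeholder logger classes (not MARO's):

```python
from typing import Optional, Union


class RealLogger:
    def info(self, msg: str) -> None:
        print(msg)


class NullLogger:
    def info(self, msg: str) -> None:
        pass


class Worker:
    def __init__(self, logger: Optional[RealLogger] = None) -> None:
        # Spelling out the Union tells mypy that either branch of the fallback is fine.
        self._logger: Union[RealLogger, NullLogger] = logger if logger else NullLogger()


Worker()._logger.info("falls back to the null logger")
```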
6 changes: 6 additions & 0 deletions maro/rl/distributed/port_config.py
@@ -0,0 +1,6 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

DEFAULT_ROLLOUT_PRODUCER_PORT = 20000
DEFAULT_TRAINING_FRONTEND_PORT = 10000
DEFAULT_TRAINING_BACKEND_PORT = 10001
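The new port_config.py centralizes the default ZMQ ports, and the package __init__.py above re-exports them. A hedged sketch of how such a constant is typically consumed when building a socket address; the endpoint format below is illustrative and not taken from this diff:

```python
# Illustrative only: shows how a shared default-port constant is typically used.
from maro.rl.distributed import DEFAULT_TRAINING_FRONTEND_PORT


def frontend_address(host: str = "127.0.0.1", port: int = DEFAULT_TRAINING_FRONTEND_PORT) -> str:
    # ZMQ-style TCP endpoint string built from the shared default.
    return f"tcp://{host}:{port}"


print(frontend_address())  # tcp://127.0.0.1:10000
```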
5 changes: 3 additions & 2 deletions maro/rl/exploration/scheduling.py
@@ -98,14 +98,15 @@ def __init__(
start_ep: int = 1,
initial_value: float = None,
) -> None:
super().__init__(exploration_params, param_name, initial_value=initial_value)

# validate splits
splits = [(start_ep, initial_value)] + splits + [(last_ep, final_value)]
splits = [(start_ep, self._exploration_params[self.param_name])] + splits + [(last_ep, final_value)]
splits.sort()
for (ep, _), (ep2, _) in zip(splits, splits[1:]):
if ep == ep2:
raise ValueError("The zeroth element of split points must be unique")

super().__init__(exploration_params, param_name, initial_value=initial_value)
self.final_value = final_value
self._splits = splits
self._ep = start_ep
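In scheduling.py, super().__init__(...) now runs before the split validation because the rewritten first split reads self._exploration_params[self.param_name], which only exists once the base initializer has stored the parameters; the old code used the raw initial_value argument, which may be None. A simplified sketch of that ordering constraint with hypothetical classes:

```python
from typing import Dict, Optional


class BaseScheduler:
    def __init__(self, params: Dict[str, float], param_name: str, initial_value: Optional[float] = None) -> None:
        self._exploration_params = params
        self.param_name = param_name
        if initial_value is not None:
            self._exploration_params[param_name] = initial_value


class PiecewiseScheduler(BaseScheduler):
    def __init__(self, params: Dict[str, float], param_name: str, initial_value: Optional[float] = None) -> None:
        # Base initializer first, so the current parameter value is available...
        super().__init__(params, param_name, initial_value=initial_value)
        # ...and the first split point can use it instead of a possibly-None argument.
        start_value = self._exploration_params[self.param_name]
        self._splits = [(1, start_value)]


sched = PiecewiseScheduler({"epsilon": 0.4}, "epsilon")
print(sched._splits)  # [(1, 0.4)]
```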
20 changes: 12 additions & 8 deletions maro/rl/model/abs_net.py
@@ -4,7 +4,7 @@
from __future__ import annotations

from abc import ABCMeta
from typing import Any, Dict, Optional
from typing import Any, Dict

import torch.nn
from torch.optim import Optimizer
@@ -18,17 +18,21 @@ class AbsNet(torch.nn.Module, metaclass=ABCMeta):
def __init__(self) -> None:
super(AbsNet, self).__init__()

self._optim: Optional[Optimizer] = None
@property
def optim(self) -> Optimizer:
optim = getattr(self, "_optim", None)
assert isinstance(optim, Optimizer), "Each AbsNet must have an optimizer"
return optim

def step(self, loss: torch.Tensor) -> None:
"""Run a training step to update the net's parameters according to the given loss.
Args:
loss (torch.tensor): Loss used to update the model.
"""
self._optim.zero_grad()
self.optim.zero_grad()
loss.backward()
self._optim.step()
self.optim.step()

def get_gradients(self, loss: torch.Tensor) -> Dict[str, torch.Tensor]:
"""Get the gradients with respect to all parameters according to the given loss.
@@ -39,7 +43,7 @@ def get_gradients(self, loss: torch.Tensor) -> Dict[str, torch.Tensor]:
Returns:
Gradients (Dict[str, torch.Tensor]): A dict that contains gradients for all parameters.
"""
self._optim.zero_grad()
self.optim.zero_grad()
loss.backward()
return {name: param.grad for name, param in self.named_parameters()}

@@ -51,7 +55,7 @@ def apply_gradients(self, grad: Dict[str, torch.Tensor]) -> None:
"""
for name, param in self.named_parameters():
param.grad = grad[name]
self._optim.step()
self.optim.step()

def _forward_unimplemented(self, *input: Any) -> None:
pass
@@ -64,7 +68,7 @@ def get_state(self) -> dict:
"""
return {
"network": self.state_dict(),
"optim": self._optim.state_dict(),
"optim": self.optim.state_dict(),
}

def set_state(self, net_state: dict) -> None:
@@ -74,7 +78,7 @@ def set_state(self, net_state: dict) -> None:
net_state (dict): A dict that contains the net's state.
"""
self.load_state_dict(net_state["network"])
self._optim.load_state_dict(net_state["optim"])
self.optim.load_state_dict(net_state["optim"])

def soft_update(self, other_model: AbsNet, tau: float) -> None:
"""Soft update the net's parameters according to another net, i.e.,
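The AbsNet change replaces the Optional[Optimizer] attribute with an optim property that asserts the optimizer exists, so call sites such as step, get_gradients, get_state, and set_state see a plain Optimizer and mypy stops flagging method calls on a possibly-None value. A standalone sketch of the same pattern on a toy module (everything except the property idea is made up):

```python
from typing import Optional

import torch
from torch.optim import Optimizer


class TinyNet(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.fc = torch.nn.Linear(4, 1)
        # Subclasses are expected to create this; typed Optional at assignment time.
        self._optim: Optional[Optimizer] = torch.optim.SGD(self.parameters(), lr=0.01)

    @property
    def optim(self) -> Optimizer:
        # Narrow Optional[Optimizer] to Optimizer once, here, instead of at every call site.
        optim = getattr(self, "_optim", None)
        assert isinstance(optim, Optimizer), "Each net must have an optimizer"
        return optim

    def step(self, loss: torch.Tensor) -> None:
        self.optim.zero_grad()
        loss.backward()
        self.optim.step()


net = TinyNet()
net.step(net.fc(torch.randn(2, 4)).mean())
```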
8 changes: 4 additions & 4 deletions maro/rl/model/fc_block.py
@@ -2,7 +2,7 @@
# Licensed under the MIT license.

from collections import OrderedDict
from typing import Any, List, Optional, Type
from typing import Any, List, Optional, Tuple, Type

import torch
import torch.nn as nn
@@ -46,7 +46,7 @@ def __init__(
skip_connection: bool = False,
dropout_p: float = None,
gradient_threshold: float = None,
name: str = None,
name: str = "NONAME",
) -> None:
super(FullyConnected, self).__init__()
self._input_dim = input_dim
@@ -101,12 +101,12 @@ def input_dim(self) -> int:
def output_dim(self) -> int:
return self._output_dim

def _build_layer(self, input_dim: int, output_dim: int, head: bool = False) -> torch.nn.Module:
def _build_layer(self, input_dim: int, output_dim: int, head: bool = False) -> nn.Module:
"""Build a basic layer.
BN -> Linear -> Activation -> Dropout
"""
components = []
components: List[Tuple[str, nn.Module]] = []
if self._batch_norm:
components.append(("batch_norm", nn.BatchNorm1d(input_dim)))
components.append(("linear", nn.Linear(input_dim, output_dim)))
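The fc_block.py edit annotates the components list as List[Tuple[str, nn.Module]]; without it, mypy infers a narrower element type from the first append and can then complain when the pairs are passed through OrderedDict into nn.Sequential. A compact sketch of the annotated pattern, independent of MARO's FullyConnected block:

```python
from collections import OrderedDict
from typing import List, Tuple

import torch
import torch.nn as nn


def build_layer(input_dim: int, output_dim: int, batch_norm: bool = True) -> nn.Module:
    # Annotating the list up front keeps mypy happy when the heterogeneous
    # (name, module) pairs are handed to OrderedDict and nn.Sequential.
    components: List[Tuple[str, nn.Module]] = []
    if batch_norm:
        components.append(("batch_norm", nn.BatchNorm1d(input_dim)))
    components.append(("linear", nn.Linear(input_dim, output_dim)))
    components.append(("activation", nn.ReLU()))
    return nn.Sequential(OrderedDict(components))


layer = build_layer(8, 4)
print(layer(torch.randn(2, 8)).shape)  # torch.Size([2, 4])
```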
16 changes: 8 additions & 8 deletions maro/rl/policy/abs_policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from __future__ import annotations

from abc import ABCMeta, abstractmethod
from typing import Dict, List, Optional, Tuple
from typing import Any, Dict, Optional, Tuple, Union

import numpy as np
import torch
@@ -27,14 +27,14 @@ def __init__(self, name: str, trainable: bool) -> None:
self._trainable = trainable

@abstractmethod
def get_actions(self, states: object) -> object:
def get_actions(self, states: Union[list, np.ndarray]) -> Any:
"""Get actions according to states.
Args:
states (object): States.
states (Union[list, np.ndarray]): States.
Returns:
actions (object): Actions.
actions (Any): Actions.
"""
raise NotImplementedError

@@ -79,7 +79,7 @@ class DummyPolicy(AbsPolicy):
def __init__(self) -> None:
super(DummyPolicy, self).__init__(name="DUMMY_POLICY", trainable=False)

def get_actions(self, states: object) -> None:
def get_actions(self, states: Union[list, np.ndarray]) -> None:
return None

def explore(self) -> None:
@@ -101,11 +101,11 @@ class RuleBasedPolicy(AbsPolicy, metaclass=ABCMeta):
def __init__(self, name: str) -> None:
super(RuleBasedPolicy, self).__init__(name=name, trainable=False)

def get_actions(self, states: List[object]) -> List[object]:
def get_actions(self, states: list) -> list:
return self._rule(states)

@abstractmethod
def _rule(self, states: List[object]) -> List[object]:
def _rule(self, states: list) -> list:
raise NotImplementedError

def explore(self) -> None:
@@ -304,7 +304,7 @@ def unfreeze(self) -> None:
raise NotImplementedError

@abstractmethod
def get_state(self) -> object:
def get_state(self) -> dict:
"""Get the state of the policy."""
raise NotImplementedError

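The abs_policy.py edits replace object annotations with more specific ones: Union[list, np.ndarray] for states, Any for actions, and dict for policy state. object is the least useful type to receive, since nothing can be done with it without a cast, while the replacements document intent and let mypy check the body. A small illustration with hypothetical function names:

```python
from typing import Any, Union

import numpy as np


def get_actions_object(states: object) -> object:
    # `object` exposes no attributes, so even len(states) would not type-check here.
    return None


def get_actions_typed(states: Union[list, np.ndarray]) -> Any:
    # A precise parameter type lets mypy check what the body does with `states`.
    return len(states)


print(get_actions_typed(np.zeros((3, 4))))  # 3
```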
6 changes: 2 additions & 4 deletions maro/rl/policy/continuous_rl_policy.py
@@ -62,12 +62,10 @@ def __init__(
)

self._lbounds, self._ubounds = _parse_action_range(self.action_dim, action_range)
assert self._lbounds is not None and self._ubounds is not None

self._policy_net = policy_net

@property
def action_bounds(self) -> Tuple[List[float], List[float]]:
def action_bounds(self) -> Tuple[Optional[List[float]], Optional[List[float]]]:
return self._lbounds, self._ubounds

@property
@@ -118,7 +116,7 @@ def eval(self) -> None:
def train(self) -> None:
self._policy_net.train()

def get_state(self) -> object:
def get_state(self) -> dict:
return self._policy_net.get_state()

def set_state(self, policy_state: dict) -> None:
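In continuous_rl_policy.py the non-None assertion on the parsed bounds is dropped and action_bounds is widened to Tuple[Optional[List[float]], Optional[List[float]]], so the property reports what is actually stored and callers perform the None check. A hedged sketch of that widened-property pattern on a made-up class:

```python
from typing import List, Optional, Tuple


class BoundedPolicy:
    def __init__(self, lower: Optional[List[float]], upper: Optional[List[float]]) -> None:
        self._lbounds, self._ubounds = lower, upper

    @property
    def action_bounds(self) -> Tuple[Optional[List[float]], Optional[List[float]]]:
        # Report exactly what is stored; callers decide how to handle None.
        return self._lbounds, self._ubounds


lb, ub = BoundedPolicy([-1.0], [1.0]).action_bounds
if lb is not None and ub is not None:
    print(ub[0] - lb[0])  # 2.0
```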
10 changes: 6 additions & 4 deletions maro/rl/policy/discrete_rl_policy.py
@@ -85,9 +85,11 @@ def __init__(

self._exploration_func = exploration_strategy[0]
self._exploration_params = clone(exploration_strategy[1]) # deep copy is needed to avoid unwanted sharing
self._exploration_schedulers = [
opt[1](self._exploration_params, opt[0], **opt[2]) for opt in exploration_scheduling_options
]
self._exploration_schedulers = (
[opt[1](self._exploration_params, opt[0], **opt[2]) for opt in exploration_scheduling_options]
if exploration_scheduling_options is not None
else []
)

self._call_cnt = 0
self._warmup = warmup
@@ -219,7 +221,7 @@ def eval(self) -> None:
def train(self) -> None:
self._q_net.train()

def get_state(self) -> object:
def get_state(self) -> dict:
return self._q_net.get_state()

def set_state(self, policy_state: dict) -> None:
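The discrete_rl_policy.py change wraps the scheduler construction in a conditional expression so that a None exploration_scheduling_options yields an empty list instead of iterating over None. A tiny sketch of the guard with a hypothetical option layout:

```python
from typing import Callable, List, Optional, Tuple

# Hypothetical option layout: (param_name, scheduler_factory, scheduler_kwargs).
Option = Tuple[str, Callable[..., object], dict]


def build_schedulers(options: Optional[List[Option]], params: dict) -> list:
    # Conditional expression: an absent option list means "no schedulers",
    # rather than an iteration over None.
    return (
        [factory(params, name, **kwargs) for name, factory, kwargs in options]
        if options is not None
        else []
    )


print(build_schedulers(None, {"epsilon": 0.4}))  # []
```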