Skip to content

Commit

Permalink
Remove numpy data type (#571)
Browse files Browse the repository at this point in the history
* Change numpy data type; change test requirements.

* Lint
  • Loading branch information
lihuoran authored Jan 11, 2023
1 parent a4e3168 commit eb6324c
Show file tree
Hide file tree
Showing 5 changed files with 7 additions and 8 deletions.
5 changes: 2 additions & 3 deletions maro/cli/data_pipeline/citi_bike.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
from enum import Enum

import geopy.distance
import numpy as np
import pandas as pd
from yaml import safe_load

Expand Down Expand Up @@ -320,7 +319,7 @@ def _process_distance(self, station_info: pd.DataFrame):
0,
index=station_info["station_index"],
columns=station_info["station_index"],
dtype=np.float,
dtype=float,
)
look_up_df = station_info[["latitude", "longitude"]]
return distance_adj.apply(
Expand Down Expand Up @@ -617,7 +616,7 @@ def _gen_distance(self, station_init: pd.DataFrame):
0,
index=station_init["station_index"],
columns=station_init["station_index"],
dtype=np.float,
dtype=float,
)
look_up_df = station_init[["latitude", "longitude"]]
distance_df = distance_adj.apply(
Expand Down
4 changes: 2 additions & 2 deletions maro/rl/training/replay_memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,7 @@ def __init__(
self._states = np.zeros((self._capacity, self._state_dim), dtype=np.float32)
self._actions = np.zeros((self._capacity, self._action_dim), dtype=np.float32)
self._rewards = np.zeros(self._capacity, dtype=np.float32)
self._terminals = np.zeros(self._capacity, dtype=np.bool)
self._terminals = np.zeros(self._capacity, dtype=bool)
self._next_states = np.zeros((self._capacity, self._state_dim), dtype=np.float32)
self._returns = np.zeros(self._capacity, dtype=np.float32)
self._advantages = np.zeros(self._capacity, dtype=np.float32)
Expand Down Expand Up @@ -373,7 +373,7 @@ def __init__(
self._actions = [np.zeros((self._capacity, action_dim), dtype=np.float32) for action_dim in self._action_dims]
self._rewards = [np.zeros(self._capacity, dtype=np.float32) for _ in range(self.agent_num)]
self._next_states = np.zeros((self._capacity, self._state_dim), dtype=np.float32)
self._terminals = np.zeros(self._capacity, dtype=np.bool)
self._terminals = np.zeros(self._capacity, dtype=bool)

assert len(agent_states_dims) == self.agent_num
self._agent_states_dims = agent_states_dims
Expand Down
2 changes: 1 addition & 1 deletion maro/streamit/client/metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def is_float_type(v_type: type):
Returns:
bool: True if an float type.
"""
return v_type is float or v_type is np.float or v_type is np.float32 or v_type is np.float64
return v_type is float or v_type is np.float16 or v_type is np.float32 or v_type is np.float64


def parse_value(value: object):
Expand Down
2 changes: 1 addition & 1 deletion tests/requirements.test.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ deepdiff>=5.7.0
geopy>=2.0.0
holidays>=0.10.3
kubernetes>=21.7.0
numpy>=1.19.5,<1.24.0
numpy>=1.19.5
pandas>=0.25.3
paramiko>=2.9.2
pytest>=7.1.2
Expand Down
2 changes: 1 addition & 1 deletion tests/test_frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -311,7 +311,7 @@ def test_append_nodes(self):
self.assertListEqual([0.0, 0.0, 0.0, 0.0, 9.0], list(states)[0:5])

# 2 padding (NAN) in the end
self.assertTrue((states[-2:].astype(np.int) == 0).all())
self.assertTrue((states[-2:].astype(int) == 0).all())

states = static_snapshot[1::"a3"]

Expand Down

0 comments on commit eb6324c

Please sign in to comment.