Skip to content

Commit

Permalink
Merge pull request #196 from WenjieDu/wdu_dev
Browse files Browse the repository at this point in the history
Add learning-rate schedulers for optimizers, and
  • Loading branch information
WenjieDu authored Sep 28, 2023
2 parents 09b494d + c426cb2 commit fe4e41e
Show file tree
Hide file tree
Showing 37 changed files with 1,459 additions and 302 deletions.
25 changes: 2 additions & 23 deletions docs/pypots.forecasting.rst
Original file line number Diff line number Diff line change
@@ -1,31 +1,10 @@
pypots.forecasting package
==========================

Subpackages
-----------

.. toctree::
:maxdepth: 4

pypots.forecasting.bttf
pypots.forecasting.template

Submodules
----------

pypots.forecasting.base module
pypots.forecasting.bttf module
------------------------------

.. automodule:: pypots.forecasting.base
:members:
:undoc-members:
:show-inheritance:
:inherited-members:

Module contents
---------------

.. automodule:: pypots.forecasting
.. automodule:: pypots.forecasting.bttf
:members:
:undoc-members:
:show-inheritance:
Expand Down
9 changes: 9 additions & 0 deletions docs/pypots.optim.rst
Original file line number Diff line number Diff line change
Expand Up @@ -54,3 +54,12 @@ pypots.optim.base module
:undoc-members:
:show-inheritance:
:inherited-members:

pypots.optim.lr_scheduler module
------------------------------

.. automodule:: pypots.optim.lr_scheduler
:members:
:undoc-members:
:show-inheritance:
:inherited-members:
11 changes: 5 additions & 6 deletions pypots/classification/brits/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,13 +173,12 @@ class BRITS(BaseNNClassifier):
The "better" strategy will automatically save the model during training whenever the model performs
better than in previous epochs.
Attributes
References
----------
model : :class:`torch.nn.Module`
The underlying BRITS model.
optimizer : :class:`pypots.optim.Optimizer`
The optimizer for model training.
.. [1] `Cao, Wei, Dong Wang, Jian Li, Hao Zhou, Lei Li, and Yitan Li.
"Brits: Bidirectional recurrent imputation for time series."
Advances in neural information processing systems 31 (2018).
<https://arxiv.org/pdf/1805.10572>`_
"""

Expand Down
10 changes: 5 additions & 5 deletions pypots/classification/grud/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,13 +169,13 @@ class GRUD(BaseNNClassifier):
The "better" strategy will automatically save the model during training whenever the model performs
better than in previous epochs.
Attributes
References
----------
model : :class:`torch.nn.Module`
The underlying GRU-D model.
.. [1] `Che, Zhengping, Sanjay Purushotham, Kyunghyun Cho, David Sontag, and Yan Liu.
"Recurrent neural networks for multivariate time series with missing values."
Scientific reports 8, no. 1 (2018): 6085.
<https://www.nature.com/articles/s41598-018-24271-9.pdf>`_
optimizer : :class:`pypots.optim.Optimizer`
The optimizer for model training.
"""

def __init__(
Expand Down
12 changes: 5 additions & 7 deletions pypots/classification/raindrop/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -367,14 +367,12 @@ class Raindrop(BaseNNClassifier):
The "better" strategy will automatically save the model during training whenever the model performs
better than in previous epochs.
Attributes
References
----------
model : :class:`torch.nn.Module`
The underlying Raindrop model.
optimizer : :class:`pypots.optim.Optimizer`
The optimizer for model training.
.. [1] `Zhang, Xiang, Marko Zeman, Theodoros Tsiligkaridis, and Marinka Zitnik.
"Graph-guided network for irregularly sampled multivariate time series."
International Conference on Learning Representations (ICLR). 2022.
<https://openreview.net/forum?id=Kwm8I7dU-l5>`_
"""

def __init__(
Expand Down
12 changes: 6 additions & 6 deletions pypots/clustering/crli/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,13 +196,13 @@ class CRLI(BaseNNClusterer):
The "better" strategy will automatically save the model during training whenever the model performs
better than in previous epochs.
Attributes
References
----------
model : :class:`torch.nn.Module`
The underlying CRLI model.
optimizer : :class:`pypots.optim.Optimizer`
The optimizer for model training.
.. [1] `Ma, Qianli, Chuxin Chen, Sen Li, and Garrison W. Cottrell. 2021.
"Learning Representations for Incomplete Time Series Clustering".
Proceedings of the AAAI Conference on Artificial Intelligence 35 (10):8837-46.
https://doi.org/10.1609/aaai.v35i10.17070.
<https://ojs.aaai.org/index.php/AAAI/article/view/17070>`_
"""

Expand Down
34 changes: 23 additions & 11 deletions pypots/clustering/crli/modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,8 +65,11 @@ def forward(self, inputs: dict) -> Tuple[torch.Tensor, torch.Tensor]:
)
output_collector = torch.empty((bz, n_steps, self.d_input), device=self.device)
if self.cell_type == "LSTM":
# TODO: cell states should have different shapes
cell_states = torch.zeros((self.d_input, self.d_hidden), device=self.device)
cell_states = [
torch.zeros((bz, self.d_hidden), device=self.device)
for i in range(self.n_layer)
]

for step in range(n_steps):
x = X[:, step, :]
estimation = self.output_layer(hidden_state)
Expand All @@ -76,13 +79,14 @@ def forward(self, inputs: dict) -> Tuple[torch.Tensor, torch.Tensor]:
)
for i in range(self.n_layer):
if i == 0:
hidden_state, cell_states = self.model[i](
imputed_x, (hidden_state, cell_states)
hidden_state, cell_state = self.model[i](
imputed_x, (hidden_state, cell_states[i])
)
else:
hidden_state, cell_states = self.model[i](
hidden_state, (hidden_state, cell_states)
hidden_state, cell_state = self.model[i](
hidden_state, (hidden_state, cell_states[i])
)

hidden_state_collector[:, step, :] = hidden_state

elif self.cell_type == "GRU":
Expand Down Expand Up @@ -168,19 +172,27 @@ def forward(self, inputs: dict) -> torch.Tensor:
]
hidden_state_collector = torch.empty((bz, n_steps, 32), device=self.device)
if self.cell_type == "LSTM":
cell_states = torch.zeros((self.d_input, self.d_hidden), device=self.device)
cell_states = [
torch.zeros((bz, 32), device=self.device),
torch.zeros((bz, 16), device=self.device),
torch.zeros((bz, 8), device=self.device),
torch.zeros((bz, 16), device=self.device),
torch.zeros((bz, 32), device=self.device),
]
for step in range(n_steps):
x = imputed_X[:, step, :]
for i, rnn_cell in enumerate(self.rnn_cell_module_list):
if i == 0:
hidden_state, cell_states = rnn_cell(
x, (hidden_states[i], cell_states)
hidden_state, cell_state = rnn_cell(
x, (hidden_states[i], cell_states[i])
)
else:
hidden_state, cell_states = rnn_cell(
hidden_states[i - 1], (hidden_states[i], cell_states)
hidden_state, cell_state = rnn_cell(
hidden_states[i - 1], (hidden_states[i], cell_states[i])
)
cell_states[i] = cell_state
hidden_states[i] = hidden_state

hidden_state_collector[:, step, :] = hidden_state

elif self.cell_type == "GRU":
Expand Down
12 changes: 7 additions & 5 deletions pypots/clustering/vader/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -328,13 +328,15 @@ class VaDER(BaseNNClusterer):
The "better" strategy will automatically save the model during training whenever the model performs
better than in previous epochs.
Attributes
References
----------
model : :class:`torch.nn.Module`
The underlying VaDER model.
.. [1] `de Jong, Johann, Mohammad Asif Emon, Ping Wu, Reagon Karki, Meemansa Sood, Patrice Godard,
Ashar Ahmad, Henri Vrooman, Martin Hofmann-Apitius, and Holger Fröhlich.
"Deep learning for clustering of multivariate clinical patient trajectories with missing values."
GigaScience 8, no. 11 (2019): giz134.
<https://academic.oup.com/gigascience/article-pdf/8/11/giz134/30797160/giz134.pdf>`_
optimizer : :class:`pypots.optim.Optimizer`
The optimizer for model training.
"""

Expand Down
10 changes: 6 additions & 4 deletions pypots/data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,9 @@
from .base import BaseDataset
from .generating import (
gene_complete_random_walk,
gene_random_walk_for_classification,
gene_incomplete_random_walk_dataset,
gene_complete_random_walk_for_anomaly_detection,
gene_complete_random_walk_for_classification,
gene_random_walk,
gene_physionet2012,
)
from .load_specific_datasets import (
Expand All @@ -29,8 +30,9 @@
"BaseDataset",
# data generation
"gene_complete_random_walk",
"gene_random_walk_for_classification",
"gene_incomplete_random_walk_dataset",
"gene_complete_random_walk_for_anomaly_detection",
"gene_complete_random_walk_for_classification",
"gene_random_walk",
"gene_physionet2012",
# list and load datasets
"list_supported_datasets",
Expand Down
Loading

0 comments on commit fe4e41e

Please sign in to comment.