Commit 13139a2

fix assignment error

tsugumi-sys committed Jan 7, 2024
1 parent 146eb36 · commit 13139a2

Showing 11 changed files with 52 additions and 37 deletions.
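The recurring fix: `weights_initializer` parameters were annotated `Optional[str]` (or defaulted to `WeightsInitializer.Zeros.value`, a plain str) while actually carrying enum members, which mypy reports as an `assignment` error once that error code is re-enabled (see the pyproject.toml change below). A minimal sketch of the before/after pattern, assuming `WeightsInitializer` is a plain `enum.Enum`; the member string values here are illustrative:

    from enum import Enum
    from typing import Optional


    class WeightsInitializer(Enum):
        # Member names appear in the diff; the string values are assumed.
        Zeros = "zeros"
        He = "he"
        Xavier = "xavier"


    # Before: the default is an enum member but the annotation says Optional[str],
    # so mypy reports: Incompatible default for argument ... [assignment]
    def make_cell_before(
        weights_initializer: Optional[str] = WeightsInitializer.Zeros,  # type: ignore[assignment]
    ) -> None:
        ...


    # After: annotation and default agree, so the "assignment" error code no
    # longer needs to be disabled project-wide.
    def make_cell_after(
        weights_initializer: WeightsInitializer = WeightsInitializer.Zeros,
    ) -> None:
        ...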
convlstm/model.py (2 changes: 1 addition & 1 deletion)

@@ -28,7 +28,7 @@ def __init__(
         padding: Union[int, Tuple, str],
         activation: str,
         frame_size: Tuple,
-        weights_initializer: Optional[str] = WeightsInitializer.Zeros,
+        weights_initializer: WeightsInitializer = WeightsInitializer.Zeros,
     ) -> None:
         """
convlstm/seq2seq.py (2 changes: 1 addition & 1 deletion)

@@ -28,7 +28,7 @@ def __init__(
         num_layers: int,
         input_seq_length: int,
         out_channels: Optional[int] = None,
-        weights_initializer: Optional[str] = WeightsInitializer.Zeros.value,
+        weights_initializer: WeightsInitializer = WeightsInitializer.Zeros,
         return_sequences: bool = False,
     ) -> None:
         """
core/convlstm_cell.py (44 changes: 26 additions & 18 deletions)

@@ -1,4 +1,4 @@
-from typing import Optional, Tuple, Union
+from typing import Tuple, Union
 
 import torch
 from torch import nn
@@ -17,7 +17,7 @@ def __init__(
         padding: Union[int, Tuple, str],
         activation: str,
         frame_size: Tuple,
-        weights_initializer: Optional[str] = WeightsInitializer.Zeros,
+        weights_initializer: WeightsInitializer = WeightsInitializer.Zeros,
     ) -> None:
         """
@@ -31,17 +31,7 @@ def __init__(
         """
         super().__init__()
 
-        if activation == "tanh":
-            self.activation = torch.tanh
-        elif activation == "relu":
-            self.activation = torch.relu
-        elif activation == "leakyRelu":
-            self.activation = torch.nn.LeakyReLU()
-        elif activation == "sigmoid":
-            self.activation = torch.sigmoid
-        else:
-            raise ValueError(f"Unknown activation: {activation}")
-
+        self.activation = self.__activation(activation)
         self.conv = nn.Conv2d(
             in_channels=in_channels + out_channels,
             out_channels=4 * out_channels,
@@ -59,19 +49,37 @@ def __init__(
         self.W_cf = nn.parameter.Parameter(
             torch.zeros(out_channels, *frame_size, dtype=torch.float)
         ).to(DEVICE)
+        self.__initialize_weights(weights_initializer)
 
+    def __activation(self, activation: str) -> nn.Module:
+        if activation == "tanh":
+            return nn.Tanh()
+        elif activation == "relu":
+            return nn.ReLU()
+        elif activation == "leakyRelu":
+            return nn.LeakyReLU()
+        elif activation == "sigmoid":
+            return nn.Sigmoid()
+        else:
+            raise ValueError(f"Unknown activation: {activation}")
+
-        if weights_initializer == WeightsInitializer.Zeros:
-            pass
-        elif weights_initializer == WeightsInitializer.He:
+    def __initialize_weights(self, initializer: WeightsInitializer):
+        if initializer == WeightsInitializer.Zeros:
+            return
+
+        elif initializer == WeightsInitializer.He:
             nn.init.kaiming_normal_(self.W_ci, mode="fan_in", nonlinearity="leaky_relu")
             nn.init.kaiming_normal_(self.W_co, mode="fan_in", nonlinearity="leaky_relu")
             nn.init.kaiming_normal_(self.W_cf, mode="fan_in", nonlinearity="leaky_relu")
-        elif weights_initializer == WeightsInitializer.Xavier:
+            return
+
+        elif initializer == WeightsInitializer.Xavier:
             nn.init.xavier_normal_(self.W_ci, gain=1.0)
             nn.init.xavier_normal_(self.W_co, gain=1.0)
             nn.init.xavier_normal_(self.W_cf, gain=1.0)
+            return
         else:
-            raise ValueError(f"Invalid weights Initializer: {weights_initializer}")
+            raise ValueError(f"Invalid weights Initializer: {initializer}")
 
     def forward(
         self, X: torch.Tensor, prev_h: torch.Tensor, prev_cell: torch.Tensor
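For context, a construction sketch of the refactored cell. The parameter names come from the signature above; the concrete values, the `kernel_size` parameter, the class name, and the `WeightsInitializer` import path are assumptions:

    import torch

    from core.constants import WeightsInitializer  # import path assumed
    from core.convlstm_cell import ConvLSTMCell  # class name assumed from the file name

    cell = ConvLSTMCell(
        in_channels=1,
        out_channels=16,
        kernel_size=3,  # parameter assumed; not visible in the hunk
        padding=1,
        activation="leakyRelu",  # dispatched by __activation to nn.LeakyReLU()
        frame_size=(64, 64),  # spatial size of the W_ci / W_co / W_cf parameters
        weights_initializer=WeightsInitializer.He,  # enum member, not "he" or None
    )

    # forward takes (X, prev_h, prev_cell); the return value is assumed to be
    # the new hidden and cell states.
    new_h, new_cell = cell(
        torch.randn(8, 1, 64, 64),
        torch.zeros(8, 16, 64, 64),
        torch.zeros(8, 16, 64, 64),
    )

Extracting `__activation` and `__initialize_weights` also moves the two if/elif chains out of `__init__`, so each dispatch has a single owner and a single raise point.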
pipelines/utils/early_stopping.py (4 changes: 2 additions & 2 deletions)

@@ -1,4 +1,4 @@
-from typing import Callable
+from typing import Callable, Optional
 
 import numpy as np
 from torch import nn
@@ -27,7 +27,7 @@ def __init__(
         self.patience = patience
         self.verbose = verbose
         self.counter = 0
-        self.best_score = None
+        self.best_score: Optional[float] = None
         self.early_stop = False
         self.val_loss_min = np.Inf
         self.delta = delta
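Annotating `best_score` explicitly as `Optional[float]` is what lets mypy accept the later float assignments without the blanket `assignment` suppression. A minimal sketch of the pattern; the update logic here is illustrative, not the repository's:

    from typing import Optional


    class EarlyStopping:
        def __init__(self) -> None:
            # Starts as None and later holds a float, hence Optional[float];
            # without the annotation mypy infers type None and flags the
            # subsequent float assignment.
            self.best_score: Optional[float] = None

        def update(self, score: float) -> bool:
            if self.best_score is None or score > self.best_score:
                self.best_score = score
                return True
            return False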
pyproject.toml (2 changes: 1 addition & 1 deletion)

@@ -30,4 +30,4 @@ split-on-trailing-comma = true
 namespace_packages = true
 ignore_missing_imports = true
 python_version = "3.11"
-disable_error_code = ["assignment", "override", "index", "var-annotated"]
+disable_error_code = ["override", "index", "var-annotated"]
self_attention_convlstm/cell.py (4 changes: 2 additions & 2 deletions)

@@ -1,4 +1,4 @@
-from typing import Optional, Tuple, Union
+from typing import Tuple, Union
 
 import torch
 
@@ -21,7 +21,7 @@ def __init__(
         padding: Union[int, Tuple, str],
         activation: str,
         frame_size: Tuple,
-        weights_initializer: Optional[str] = WeightsInitializer.Zeros,
+        weights_initializer: WeightsInitializer = WeightsInitializer.Zeros,
     ) -> None:
         super().__init__(
             in_channels,
self_attention_convlstm/model.py (12 changes: 8 additions & 4 deletions)

@@ -25,7 +25,7 @@ def __init__(
         padding: Union[int, Tuple, str],
         activation: str,
         frame_size: Tuple,
-        weights_initializer: Optional[str] = WeightsInitializer.Zeros.value,
+        weights_initializer: WeightsInitializer = WeightsInitializer.Zeros,
     ) -> None:
         super().__init__()
@@ -40,9 +40,13 @@ def __init__(
             weights_initializer,
         )
 
-        self.attention_scores = None
         self.in_channels = in_channels
         self.out_channels = out_channels
+        self._attention_scores: Optional[torch.Tensor] = None
 
+    @property
+    def attention_scores(self) -> Optional[torch.Tensor]:
+        return self._attention_scores
+
     def forward(
         self,
@@ -54,7 +58,7 @@ def forward(
 
         # NOTE: Cannot store all attention scores because of memory. So only store attention map of the center.
        # And the same attention score are applied to each channels.
-        self.attention_scores = torch.zeros(
+        self._attention_scores = torch.zeros(
            (batch_size, seq_len, height * width), device=DEVICE
        )
@@ -76,7 +80,7 @@ def forward(
             h, cell, attention = self.sa_convlstm_cell(X[:, :, time_step], h, cell)
 
             output[:, :, time_step] = h  # type: ignore
-            self.attention_scores[:, time_step] = attention[
+            self._attention_scores[:, time_step] = attention[
                 :, attention.size(0) // 2
             ]  # attention shape is (batch_size, height*width, height*width)
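Replacing the public attribute with a private `_attention_scores` field behind a read-only property keeps the external API (`model.attention_scores`) unchanged while making the `Optional[torch.Tensor]` type explicit. A self-contained sketch of the pattern; the stub below stands in for the real forward pass:

    from typing import Optional

    import torch


    class AttentionScoreHolder:
        """Minimal sketch of the property pattern introduced in this commit."""

        def __init__(self) -> None:
            self._attention_scores: Optional[torch.Tensor] = None

        @property
        def attention_scores(self) -> Optional[torch.Tensor]:
            # Read-only: callers can inspect the scores but not reassign them.
            return self._attention_scores

        def forward_stub(self, batch_size: int, seq_len: int, height: int, width: int) -> None:
            # Mirrors the NOTE in the diff: only the center attention map per
            # step is stored, to bound memory, so the buffer is (batch, seq, H*W).
            self._attention_scores = torch.zeros(batch_size, seq_len, height * width)


    holder = AttentionScoreHolder()
    holder.forward_stub(4, 10, 16, 16)
    assert holder.attention_scores is not None
    print(holder.attention_scores.shape)  # torch.Size([4, 10, 256])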
self_attention_convlstm/seq2seq.py (2 changes: 1 addition & 1 deletion)

@@ -29,7 +29,7 @@ def __init__(
         num_layers: int,
         input_seq_length: int,
         out_channels: Optional[int] = None,
-        weights_initializer: Optional[str] = WeightsInitializer.Zeros.value,
+        weights_initializer: WeightsInitializer = WeightsInitializer.Zeros,
         return_sequences: bool = False,
     ) -> None:
         """
self_attention_memory_convlstm/cell.py (4 changes: 2 additions & 2 deletions)

@@ -1,4 +1,4 @@
-from typing import Optional, Tuple, Union
+from typing import Tuple, Union
 
 import torch
 
@@ -19,7 +19,7 @@ def __init__(
         padding: Union[int, Tuple, str],
         activation: str,
         frame_size: Tuple,
-        weights_initializer: Optional[str] = WeightsInitializer.Zeros.value,
+        weights_initializer: WeightsInitializer = WeightsInitializer.Zeros,
     ) -> None:
         super().__init__(
             in_channels,
self_attention_memory_convlstm/model.py (11 changes: 7 additions & 4 deletions)

@@ -25,7 +25,7 @@ def __init__(
         padding: Union[int, Tuple, str],
         activation: str,
         frame_size: Tuple,
-        weights_initializer: Optional[str] = WeightsInitializer.Zeros.value,
+        weights_initializer: WeightsInitializer = WeightsInitializer.Zeros,
     ):
         super().__init__()
         self.sam_convlstm_cell = SAMConvLSTMCell(
@@ -41,8 +41,11 @@ def __init__(
 
         self.in_channels = in_channels
         self.out_channels = out_channels
+        self._attention_scores: Optional[torch.Tensor] = None
 
-        self.attention_scores = None
+    @property
+    def attention_scores(self) -> Optional[torch.Tensor]:
+        return self._attention_scores
 
     def forward(
         self,
@@ -55,7 +58,7 @@ def forward(
 
         # NOTE: Cannot store all attention scores because of memory. So only store attention map of the center.
         # And the same attention score are applied to each channels.
-        self.attention_scores = torch.zeros(
+        self._attention_scores = torch.zeros(
             (batch_size, seq_len, height * width), device=DEVICE
         )
@@ -87,7 +90,7 @@ def forward(
             # Save attention maps of the center point because storing
             # the full `attention_h` is difficult because of the lot of memory usage.
             # `attention_h` shape is (batch_size, height*width, height*width)
-            self.attention_scores[:, time_step] = attention_h[
+            self._attention_scores[:, time_step] = attention_h[
                 :, attention_h.size(0) // 2
             ]
self_attention_memory_convlstm/seq2seq.py (2 changes: 1 addition & 1 deletion)

@@ -27,7 +27,7 @@ def __init__(
         num_layers: int,
         input_seq_length: int,
         out_channels: Optional[int] = None,
-        weights_initializer: Optional[str] = WeightsInitializer.Zeros,
+        weights_initializer: WeightsInitializer = WeightsInitializer.Zeros,
         return_sequences: bool = False,
     ):
         super().__init__()
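Beyond the type annotations, some of the old defaults passed `WeightsInitializer.Zeros.value` (a plain str) down to cells whose initializer branches compare against enum members, so a string could never match them. A minimal sketch of that failure mode, reusing the assumed enum values from the first example:

    from enum import Enum


    class WeightsInitializer(Enum):
        # Member string values assumed, as above.
        Zeros = "zeros"
        He = "he"
        Xavier = "xavier"


    def pick_branch(initializer) -> str:
        # Mirrors __initialize_weights: comparisons are against enum members.
        if initializer == WeightsInitializer.Zeros:
            return "zeros branch"
        elif initializer == WeightsInitializer.He:
            return "he branch"
        raise ValueError(f"Invalid weights Initializer: {initializer}")


    print(pick_branch(WeightsInitializer.He))  # matches: enum member == enum member
    # A plain Enum member never equals its .value, so the str falls through
    # to the ValueError:
    pick_branch(WeightsInitializer.He.value)  # raises ValueError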
