Fix cuda
alex-hse-repository committed Jun 24, 2022
1 parent bfd1e2e commit fe99b4c
Showing 1 changed file with 5 additions and 5 deletions.
10 changes: 5 additions & 5 deletions etna/models/nn/deepstate.py
@@ -3,14 +3,14 @@
from typing import List
from typing import Optional
from typing import Tuple
-from typing_extensions import TypedDict

import torch
import torch.nn as nn
from pytorch_lightning import LightningModule
from torch import Tensor
from torch.distributions.multivariate_normal import MultivariateNormal
from torch.distributions.normal import Normal
+from typing_extensions import TypedDict

from etna.core import BaseMixin

@@ -169,7 +169,7 @@ def forward(self, inference_batch: InferenceBatch):
output, (h_n, c_n) = self.RNN(encoder_real) # (batch_size, seq_length, latent_dim)
lds = LDS(
emission_coeff=self.ssm.emission_coeff(datetime_index_train), # (batch_size, seq_length, latent_dim)
-transition_coeff=self.ssm.transition_coeff(), # (latent_dim, latent_dim)
+transition_coeff=self.ssm.transition_coeff().type_as(targets), # (latent_dim, latent_dim)
innovation_coeff=self.ssm.innovation_coeff(datetime_index_train)
* self.projectors["innovation"](output), # (batch_size, seq_length, latent_dim)
noise_std=self.projectors["noise_std"](output), # (batch_size, seq_length, 1)
@@ -185,7 +185,7 @@ def forward(self, inference_batch: InferenceBatch):
output, (_, _) = self.RNN(decoder_real, (h_n, c_n)) # (batch_size, seq_length, latent_dim)
lds = LDS(
emission_coeff=self.ssm.emission_coeff(datetime_index_test), # (batch_size, seq_length, latent_dim)
-transition_coeff=self.ssm.transition_coeff(), # (latent_dim, latent_dim)
+transition_coeff=self.ssm.transition_coeff().type_as(targets), # (latent_dim, latent_dim)
innovation_coeff=self.ssm.innovation_coeff(datetime_index_test)
* self.projectors["innovation"](output), # (batch_size, seq_length, latent_dim)
noise_std=self.projectors["noise_std"](output), # (batch_size, seq_length, latent_dim)
@@ -211,7 +211,7 @@ def training_step(self, train_batch: TrainBatch, batch_idx):

lds = LDS(
emission_coeff=self.ssm.emission_coeff(datetime_index), # (batch_size, seq_length, latent_dim)
-transition_coeff=self.ssm.transition_coeff(), # (latent_dim, latent_dim)
+transition_coeff=self.ssm.transition_coeff().type_as(targets), # (latent_dim, latent_dim)
innovation_coeff=self.ssm.innovation_coeff(datetime_index)
* self.projectors["innovation"](output), # (batch_size, seq_length, latent_dim)
noise_std=self.projectors["noise_std"](output), # (batch_size, seq_length, 1)
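
The three hunks above all apply the same fix: `self.ssm.transition_coeff()` returns a freshly created tensor, which defaults to CPU float32 even when the rest of the batch lives on a GPU, so building the LDS mixes devices. `Tensor.type_as` casts a tensor to the dtype and device of a reference tensor, which is the usual way to keep a LightningModule device-agnostic. A minimal sketch of the pattern, with illustrative names and shapes rather than the module's real ones:

```python
import torch

# Illustrative sketch of the fix, not the module's actual code.
# New tensors default to CPU float32; batch tensors may live on CUDA.
latent_dim = 4
targets = torch.randn(8, 20, 1)  # stand-in for the batch tensor
if torch.cuda.is_available():
    targets = targets.cuda()

transition_coeff = torch.eye(latent_dim)  # created on the CPU

# .type_as casts to the reference tensor's dtype *and* device, so the
# downstream LDS matmuls no longer mix CPU and CUDA tensors.
transition_coeff = transition_coeff.type_as(targets)
assert transition_coeff.device == targets.device
assert transition_coeff.dtype == targets.dtype
```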
@@ -316,7 +316,7 @@ def kalman_filter_step(
# print(filtered_mean.shape)
# P = (I - KH)P_t (batch_size, latent_dim, latent_dim)
filtered_cov = (
-torch.eye(self.latent_dim) - kalman_gain.unsqueeze(-1) @ emission_coeff.permute(0, 2, 1)
+torch.eye(self.latent_dim).type_as(target) - kalman_gain.unsqueeze(-1) @ emission_coeff.permute(0, 2, 1)
) @ prior_cov
# print(filtered_cov.shape)
# log-likelihood (batch_size, 1)
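
The last hunk is the same device fix inside the Kalman filter's covariance update `P = (I - KH)P_t`: `torch.eye(self.latent_dim)` is created on the CPU, so the update fails as soon as `kalman_gain` and `prior_cov` live on CUDA. A self-contained sketch with hypothetical shapes (the real ones come from the LDS above):

```python
import torch

# Hypothetical shapes; the real ones come from the LDS.
batch_size, latent_dim = 8, 4
device = "cuda" if torch.cuda.is_available() else "cpu"
target = torch.randn(batch_size, 1, device=device)
kalman_gain = torch.randn(batch_size, latent_dim, device=device)
emission_coeff = torch.randn(batch_size, latent_dim, 1, device=device)
prior_cov = torch.eye(latent_dim, device=device).expand(batch_size, -1, -1)

# P = (I - K H) P_t. Without .type_as(target) the identity matrix stays
# on the CPU and the subtraction raises a device-mismatch error on CUDA.
identity = torch.eye(latent_dim).type_as(target)
filtered_cov = (
    identity - kalman_gain.unsqueeze(-1) @ emission_coeff.permute(0, 2, 1)
) @ prior_cov
print(filtered_cov.shape)  # torch.Size([8, 4, 4])
```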