Skip to content

Commit

Permalink
Fixed code.
Browse files Browse the repository at this point in the history
  • Loading branch information
Семенов Андрей Максимович committed Sep 19, 2024
1 parent 1d0a890 commit adb0a3d
Showing 1 changed file with 5 additions and 9 deletions.
14 changes: 5 additions & 9 deletions rectools/models/sasrec.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ def __init__(

def forward(self, items: torch.Tensor) -> torch.Tensor:
"""TODO"""
feature_dense = self.get_item_features_dense(items)
feature_dense = self.get_dense_item_features(items)

feature_embeddings = torch.concatenate(
[self.category_embeddings[feature_name].get_all_embeddings() for feature_name in self.category_embeddings],
Expand All @@ -138,11 +138,7 @@ def device(self) -> torch.device:
"""TODO"""
return self.category_embeddings[list(self.category_embeddings.keys())[0]].device

def get_item_features_by_certain_feature(self, feature_name: str) -> torch.Tensor:
"""TODO"""
return self.category_embeddings[feature_name].get_all_embeddings()

def get_item_features_dense(self, items: torch.Tensor) -> torch.Tensor:
def get_dense_item_features(self, items: torch.Tensor) -> torch.Tensor:
"""TODO"""
feature_dense = self.item_features.take(items.detach().cpu().numpy()).get_dense()
return torch.from_numpy(feature_dense).to(self.device)
Expand Down Expand Up @@ -227,8 +223,8 @@ class ConstructedItemNetBlock(ItemNetBase):
def __init__(
self,
n_items: int,
ids_embeddings: tp.Optional[IdEmbeddingsItemNet] = None,
cat_features_embeddings: tp.Optional[CatFeaturesEmebbedingsItemBlock] = None,
ids_embeddings: tp.Optional[IdEmbeddingsItemNet],
cat_features_embeddings: tp.Optional[CatFeaturesEmebbedingsItemBlock],
) -> None:
"""TODO"""
super().__init__()
Expand All @@ -243,7 +239,7 @@ def __init__(
item_nets["cat_features_embeddings"] = cat_features_embeddings

if ids_embeddings is None and cat_features_embeddings is None:
explanation = "Either ids_embeddings or category_features_embeddings must be provided"
explanation = "Either `ids_embeddings` or `cat_features_embeddings`, or both at once must be provided."
raise ValueError(explanation)

self.item_nets = nn.ModuleDict(item_nets)
Expand Down

0 comments on commit adb0a3d

Please sign in to comment.