diff --git a/CHANGELOG.md b/CHANGELOG.md index cebd5458..7114d157 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -15,6 +15,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Fixed - `display()` method in `MetricsApp` ([#169](https://github.com/MobileTeleSystems/RecTools/pull/169)) +### Fixed +- Allow warp-kos loss for LightFMWrapperModel ([#175](https://github.com/MobileTeleSystems/RecTools/pull/175)) + ## [0.7.0] - 29.07.2024 ### Added diff --git a/rectools/models/lightfm.py b/rectools/models/lightfm.py index 693d7aed..32dad456 100644 --- a/rectools/models/lightfm.py +++ b/rectools/models/lightfm.py @@ -77,12 +77,13 @@ def _fit(self, dataset: Dataset) -> None: # type: ignore ui_coo = dataset.get_user_item_matrix(include_weights=True).tocoo(copy=False) user_features = self._prepare_features(dataset.get_hot_user_features(), dataset.n_hot_users) item_features = self._prepare_features(dataset.get_hot_item_features(), dataset.n_hot_items) + sample_weight = None if self._model.loss == "warp-kos" else ui_coo self.model.fit( ui_coo, user_features=user_features, item_features=item_features, - sample_weight=ui_coo, + sample_weight=sample_weight, epochs=self.n_epochs, num_threads=self.n_threads, verbose=self.verbose > 0, diff --git a/tests/models/test_lightfm.py b/tests/models/test_lightfm.py index 96b68ffd..c0d0eeb5 100644 --- a/tests/models/test_lightfm.py +++ b/tests/models/test_lightfm.py @@ -222,6 +222,16 @@ def test_with_weights(self, interactions_df: pd.DataFrame) -> None: actual, ) + def test_with_warp_kos(self, dataset: Dataset) -> None: + base_model = DeterministicLightFM(no_components=2, loss="warp-kos") + try: + LightFMWrapperModel(model=base_model, epochs=10).fit(dataset) + except NotImplementedError: + pytest.fail("Should not raise NotImplementedError") + except ValueError: + # LightFM raises ValueError with the dataset + pass + def test_get_vectors(self, dataset_with_features: Dataset) -> None: base_model = LightFM(no_components=2, loss="logistic") model = LightFMWrapperModel(model=base_model).fit(dataset_with_features)