Skip to content

Commit

Permalink
Encoding (#33)
Browse files Browse the repository at this point in the history
  • Loading branch information
yinsn authored Jul 18, 2022
1 parent 96dd91b commit 29d1ceb
Show file tree
Hide file tree
Showing 11 changed files with 101 additions and 3 deletions.
1 change: 0 additions & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@ repos:
- id: trailing-whitespace
- id: end-of-file-fixer
- id: check-yaml
- id: check-added-large-files
- repo: https://github.com/PyCQA/isort
rev: 5.10.1
hooks:
Expand Down
1 change: 1 addition & 0 deletions MANIFEST.in
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
include src/rsdiv/embedding/*.pkl
include src/rsdiv/encoding/*.json
12 changes: 12 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,17 @@ metrics.get_lorenz_curve(ratings['movieId'])
```
![Lorenz](pics/Lorenz.png)

### Evaluate the imbalance from a geographic perspective
**rsdiv** provides encoders, including a geographic encoder, to give practitioners a more intuitive understanding of the data. To start, plot some random values:
```
>>> geo = rs.GeoEncoder()
>>> df = geo.read_source()
>>> rng = np.random.RandomState(42)
>>> df['random_values'] = rng.rand(len(df))
>>> geo.graw_geo_graph(df, 'random_values')
```
![GeoEncoder](pics/random_values.png)

### Train a recommender
**rsdiv** provides various implementations of core recommender algorithms. To start with, a wrapper for `LightFM` is also supported:
```
Expand Down Expand Up @@ -139,6 +150,7 @@ Not only for categorical labels, **rsdiv** also supports embedding for items, fo
- implement the Bounded Greedy Selection Strategy, BGS diversify algorithm
- implement the Determinantal Point Process, DPP diversify algorithm
- implement the Modified Gram-Schmidt, MGS diversify algorithm

### Hyperparameter optimization
**TODO**
- compatible with [Optuna](https://github.com/optuna/optuna).
Expand Down
Binary file added pics/random_values.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
3 changes: 2 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,13 @@
"lightfm>=1.16",
"scikit-learn>=1.1.1",
"matplotlib>=3.5.2",
"plotly>=5.6.0",
],
)

setup(
name="rsdiv",
version="0.1.9",
version="0.1.10",
author="Yin Cheng",
author_email="yin.sjtu@gmail.com",
long_description=LONG_DESCRIPTION,
Expand Down
1 change: 1 addition & 0 deletions src/rsdiv/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from .dataset import *
from .embedding import *
from .encoding import *
from .evaluation import *
from .recommenders import *
1 change: 0 additions & 1 deletion src/rsdiv/embedding/base.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
from abc import ABCMeta
from pathlib import Path
from typing import Dict

import numpy as np
Expand Down
9 changes: 9 additions & 0 deletions src/rsdiv/encoding/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
# Public API of the ``rsdiv.encoding`` sub-package.
#
# NOTE: the previous stray ``import imp`` was unused here, and the ``imp``
# module no longer exists on Python 3.12+, so it made the whole package
# un-importable on current interpreters.
from .base import BaseEncoder
from .geo_encoder import GeoEncoder

__all__ = [
    "BaseEncoder",
    "GeoEncoder",
]
10 changes: 10 additions & 0 deletions src/rsdiv/encoding/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
from abc import ABCMeta, abstractclassmethod, abstractmethod
from typing import Any, Dict, List, Union


class BaseEncoder(metaclass=ABCMeta):
    """Abstract base class for encoders.

    Concrete encoders must implement :meth:`encoding_single`, which maps one
    original value to its encoded form.
    """

    # Raw lookup data used by the encoder. Declared without a value here;
    # subclasses are expected to assign it.
    encode_source: Dict[str, Any]

    @abstractmethod
    def encoding_single(self, org: Union[List, str]) -> Union[int, str]:
        """Encode a single original value.

        Args:
            org: The original value to encode (a list or a string).

        Returns:
            The encoded value, as an ``int`` or a ``str``.

        Raises:
            NotImplementedError: Always, when reached through ``super()``;
                subclasses must provide their own implementation.
        """
        raise NotImplementedError("encoding_single must be implemented.")
65 changes: 65 additions & 0 deletions src/rsdiv/encoding/geo_encoder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
import json
import pkgutil
from typing import Any, Dict, List, Optional, Union

import numpy as np
import pandas as pd
import plotly.express as px
from scipy import spatial

from .base import BaseEncoder


class GeoEncoder(BaseEncoder):
    r"""Encode coordinates as the id of the nearest feature in a GeoJSON file.

    The lookup data is the ``geojson-counties-fips.json`` file bundled with
    the ``rsdiv.encoding`` package (a Plotly sample dataset). Each feature is
    reduced to one representative coordinate; a query point is then encoded
    as the ``id`` of the feature whose representative coordinate is closest.
    """

    # Despite the name, this holds the raw *bytes* of the bundled GeoJSON file
    # (``pkgutil.get_data`` returns the content, not a path), or ``None``.
    ECD_PATH: Optional[bytes] = pkgutil.get_data(
        "rsdiv.encoding", "geojson-counties-fips.json"
    )
    # Only parsed when the data could be loaded; otherwise ``encode_source``
    # stays unset and later accesses raise ``AttributeError``.
    if ECD_PATH:
        encode_source: Dict[str, Any] = json.loads(ECD_PATH)

    def __init__(self) -> None:
        """Build the lookup table of feature ids and coordinates."""
        super().__init__()
        self.encoder: pd.DataFrame = self.read_source()
        self.coord: List[np.ndarray] = self.encoder.coord.to_list()
        self.index: pd.Index = pd.Index(self.encoder["index"])
        # The nearest-neighbour tree is built lazily on the first query, so
        # users who only plot do not pay for it.
        self._tree: Optional[spatial.KDTree] = None

    def _get_tree(self) -> spatial.KDTree:
        """Return the KD-tree over ``self.coord``, building it on first use."""
        tree = getattr(self, "_tree", None)
        if tree is None:
            tree = spatial.KDTree(self.coord)
            self._tree = tree
        return tree

    def read_source(self) -> pd.DataFrame:
        """Flatten the GeoJSON features into a dataframe.

        Returns:
            A dataframe with one row per feature and the columns ``index``
            (the feature ``id``), ``coord`` (mean coordinate of the feature,
            with the two components reversed relative to the GeoJSON order),
            ``name`` (``NAME`` property) and ``lstd`` (``LSAD`` property).
        """
        geo_county_dict: Dict[str, List] = {}
        for item in self.encode_source["features"]:
            coordinates = item["geometry"]["coordinates"]
            # Average the points of every part, then average the parts.
            parts = [
                np.asarray(part).squeeze().mean(axis=0).squeeze()
                for part in coordinates
            ]
            coord = np.asarray(parts).mean(axis=0)[::-1]  # reverse lat/lng
            properties = item["properties"]
            feature_id = item["id"]
            geo_county_dict[feature_id] = [
                coord,
                properties["NAME"],
                properties["LSAD"],
            ]
        # NOTE(review): "lstd" looks like a typo for "lsad"; it is kept as-is
        # because renaming the column would change the returned schema.
        dataframe = pd.DataFrame.from_dict(
            geo_county_dict, orient="index", columns=["coord", "name", "lstd"]
        ).reset_index()
        return dataframe

    def encoding_single(self, org: Union[List, str]) -> Union[int, str]:
        """Encode one coordinate as the id of the nearest feature.

        Args:
            org: Query point, in the same component order as the ``coord``
                column produced by :meth:`read_source`.

        Returns:
            The id of the nearest feature, as a string.
        """
        # ``query`` returns ``(distance, position)``; only the position of
        # the nearest neighbour is needed.
        nearest = int(self._get_tree().query(org)[1])
        return str(self.index[nearest])

    def encoding_series(self, series: pd.Series) -> pd.Series:
        """Encode every coordinate of ``series`` with :meth:`encoding_single`.

        Args:
            series: Series whose elements are query points.

        Returns:
            A series of feature ids aligned with ``series``.
        """
        encodings = pd.Series(series.apply(self.encoding_single))
        return encodings

    def graw_geo_graph(self, dataframe: pd.DataFrame, source_name: str) -> None:
        """Show a choropleth map of one dataframe column.

        Args:
            dataframe: Data to plot; its ``index`` column is used as the
                location key matched against the GeoJSON features.
            source_name: Name of the column that determines the colour.
        """
        # NOTE(review): the method name looks like a typo for "draw"; it is
        # kept because it is the public name.
        # Round the colour range outwards to whole numbers.
        max_value: float = np.ceil(dataframe[source_name].max())
        min_value: float = np.floor(dataframe[source_name].min())
        fig = px.choropleth(
            dataframe,
            geojson=self.encode_source,
            locations="index",
            color=source_name,
            color_continuous_scale="OrRd",
            range_color=(min_value, max_value),
            scope="usa",
        )
        fig.update_layout(margin={"r": 0, "t": 0, "l": 0, "b": 0})
        fig.show()
1 change: 1 addition & 0 deletions src/rsdiv/encoding/geojson-counties-fips.json

Large diffs are not rendered by default.

0 comments on commit 29d1ceb

Please sign in to comment.