Add schemas to datasets (#7)
* Add schemas to datasets and parse them from a YAML file
* Remove the Unnamed: 0 index column from the CSV files
* Add tests for each dataset
simonhkswan authored Feb 16, 2024
1 parent a9071b0 commit 8fd344f
Showing 14 changed files with 531,654 additions and 530,119 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -10,6 +10,7 @@
*.pyc
**/__pycache__
version.py
venv

# VS Code
.vscode
5 changes: 5 additions & 0 deletions pyproject.toml
@@ -69,6 +69,11 @@ dependencies = [
]
[tool.setuptools.package-dir]
synthesized_datasets = "src/synthesized_datasets"

[tool.setuptools.package-data]
synthesized_datasets = [
"datasets.yaml"
]

[tool.setuptools.dynamic]
version = {attr = "synthesized_datasets.version.version"}
69 changes: 7 additions & 62 deletions src/synthesized_datasets/__init__.py
@@ -1,70 +1,15 @@
from ._datasets import _Dataset, _Tag
import os as _os

# Tabular
_Dataset("biased_data", "tabular/biased/biased_data.csv", [_Tag.FINANCE, _Tag.REGRESSION])
_Dataset("biased_data_mixed_types", "tabular/biased/biased_data_mixed_types.csv", [_Tag.FINANCE, _Tag.REGRESSION])
_Dataset("compas", "tabular/biased/compas.csv", [_Tag.REGRESSION])
import yaml as _yaml

_Dataset("uk_libraries", "tabular/geolocation_data/UK_libraries.csv", [_Tag.GEOSPATIAL])
_Dataset("uk_open_pubs", "tabular/geolocation_data/UK_open_pubs.csv", [_Tag.GEOSPATIAL])
_Dataset("uk_schools_list", "tabular/geolocation_data/UK_schools_list.csv", [_Tag.GEOSPATIAL])
from ._datasets import _Dataset

_Dataset("autism", "tabular/health/autism.csv", [_Tag.BINARY_CLASSIFICATION, _Tag.HEALTHCARE])
_Dataset("breast_cancer", "tabular/health/breast_cancer.csv", [_Tag.BINARY_CLASSIFICATION, _Tag.HEALTHCARE])
_Dataset("healthcare", "tabular/health/healthcare.csv", [_Tag.BINARY_CLASSIFICATION, _Tag.HEALTHCARE])
_Dataset("indian_liver_patient", "tabular/health/indian-liver-patient-dataset.csv", [_Tag.BINARY_CLASSIFICATION, _Tag.HEALTHCARE])
_Dataset("parkinsons", "tabular/health/parkinsons.csv", [_Tag.HEALTHCARE])
_DATASETS_YAML = _os.path.join(_os.path.dirname(__file__), "datasets.yaml")

_Dataset("retail_data_transactions", "tabular/insurance/retailer_data_transactions.csv", [_Tag.INSURANCE])
_Dataset("sweden_motor_insurance", "tabular/insurance/sweden_motor_insurance.csv", [_Tag.REGRESSION, _Tag.INSURANCE])
_Dataset("uk_insurance_claims_1", "tabular/insurance/uk_insurance_claims_1.csv", [_Tag.BINARY_CLASSIFICATION, _Tag.INSURANCE])
_Dataset("uk_insurance_claims_2", "tabular/insurance/uk_insurance_claims_2.csv", [_Tag.BINARY_CLASSIFICATION, _Tag.INSURANCE])
_Dataset("uk_land_register_transactions", "tabular/insurance/uk_land_register_transactions.csv", [_Tag.REGRESSION, _Tag.INSURANCE])
_Dataset("us_insurance_premiums", "tabular/insurance/us_insurance_premiums.csv", [_Tag.REGRESSION, _Tag.INSURANCE])

_Dataset("adult", "tabular/templates/adult.csv", [_Tag.REGRESSION])
_Dataset("atlas_higgs_detection", "tabular/templates/atlas_higgs_detection.csv", [_Tag.BINARY_CLASSIFICATION])
_Dataset("boston_housing_prices", "tabular/templates/boston_housing_prices.csv", [_Tag.REGRESSION])
_Dataset("churn_prediction", "tabular/templates/churn_prediction.csv", [_Tag.CHURN, _Tag.BINARY_CLASSIFICATION])
_Dataset("claim_prediction", "tabular/templates/claim_prediction.csv", [_Tag.INSURANCE, _Tag.BINARY_CLASSIFICATION])
_Dataset("credit", "tabular/templates/credit.csv", [_Tag.CREDIT, _Tag.BINARY_CLASSIFICATION])
_Dataset("credit_with_categories", "tabular/templates/credit_with_categoricals.csv", [_Tag.CREDIT, _Tag.BINARY_CLASSIFICATION])
_Dataset("fire_peril", "tabular/templates/fire-peril.csv", [_Tag.BINARY_CLASSIFICATION, _Tag.INSURANCE])
_Dataset("german_credit_data", "tabular/templates/german_credit_data.csv", [_Tag.CREDIT, _Tag.REGRESSION])
_Dataset("homesite_quote_conversion", "tabular/templates/homesite-quote-conversion.csv", [_Tag.GEOSPATIAL])
_Dataset("life_insurance", "tabular/templates/life-insurance.csv", [_Tag.INSURANCE])
_Dataset("noshowappointments", "tabular/templates/noshowappointments.csv", [_Tag.HEALTHCARE, _Tag.BINARY_CLASSIFICATION])
_Dataset("price_paid_household", "tabular/templates/price_paid_household.csv", [_Tag.REGRESSION])
_Dataset("sales_pipeline", "tabular/templates/sales_pipeline.csv", [_Tag.REGRESSION])
_Dataset("segmentation_analysis", "tabular/templates/segmentation_analysis.csv", [_Tag.BINARY_CLASSIFICATION])
_Dataset("telecom_churn", "tabular/templates/telecom-churn.csv", [_Tag.CHURN, _Tag.BINARY_CLASSIFICATION])
_Dataset("telecom_churn_large", "tabular/templates/telecom-churn-large.csv", [_Tag.CHURN, _Tag.BINARY_CLASSIFICATION])
_Dataset("titanic", "tabular/templates/titanic.csv", [_Tag.REGRESSION])
_Dataset("vehicle_insurance", "tabular/templates/vehicle-insurance.csv", [_Tag.INSURANCE])
with open(_DATASETS_YAML, "r", encoding="utf-8") as _f:
for _d in _yaml.full_load_all(_f):
_Dataset._from_dict(_d)

_Dataset("bank_marketing1", "tabular/uci/bank_marketing1.csv", [_Tag.REGRESSION, _Tag.FINANCE])
_Dataset("bank_marketing2", "tabular/uci/bank_marketing2.csv", [_Tag.REGRESSION, _Tag.FINANCE])
_Dataset("creditcard_default", "tabular/uci/creditcard_default.csv", [_Tag.BINARY_CLASSIFICATION, _Tag.FINANCE, _Tag.INSURANCE])
_Dataset("onlinenews_popularity", "tabular/uci/onlinenews_popularity.csv", [_Tag.REGRESSION, _Tag.TIME_SERIES])
_Dataset("wine_quality-red", "tabular/uci/wine_quality-red.csv", [_Tag.REGRESSION])
_Dataset("wine_quality-white", "tabular/uci/wine_quality-white.csv", [_Tag.REGRESSION])


# Timeseries
_Dataset("air_quality", "time-series/air-quality.csv", [_Tag.TIME_SERIES])
_Dataset("bitcoin_price", "time-series/bitcoin_price.csv", [_Tag.FINANCE, _Tag.TIME_SERIES])
_Dataset("brent_oil_prices", "time-series/brent-oil-prices.csv", [_Tag.FINANCE, _Tag.TIME_SERIES])
_Dataset("simple_fraud", "time-series/fraud-time-series.csv", [_Tag.FRAUD, _Tag.BINARY_CLASSIFICATION])
_Dataset("simple_fraud_5gb", "https://storage.googleapis.com/synthesized-datasets-public/simple_fraud_5GB.parquet", [_Tag.FRAUD, _Tag.BINARY_CLASSIFICATION])
_Dataset("household_power_consumption_small", "time-series/household_power_consumption_small.csv", [_Tag.TIME_SERIES])
_Dataset("mock_medical_data", "time-series/mock_medical_data.csv", [_Tag.HEALTHCARE, _Tag.TIME_SERIES])
_Dataset("noaa_isd_weather_additional_dtypes_small", "time-series/NoaaIsdWeather_added_dtypes_small.csv", [_Tag.TIME_SERIES])
_Dataset("noaa_isd_weather_additional_dtypes_medium", "time-series/NoaaIsdWeather_added_dtypes_medium.csv", [_Tag.TIME_SERIES])
_Dataset("noaa_isd_weather_additional_dtypes_100gb", "https://storage.googleapis.com/synthesized-datasets-public/noaa_100gb_dtypes_set.parquet", [_Tag.TIME_SERIES])
_Dataset("occupancy_data", "time-series/occupancy-data.csv", [_Tag.TIME_SERIES])
_Dataset("s_and_p_500_5yr", "time-series/sandp500_5yr.csv", [_Tag.FINANCE, _Tag.TIME_SERIES])
_Dataset("time_series_basic", "time-series/time_series_basic.csv", [_Tag.TIME_SERIES])
_Dataset("transactions", "time-series/transactions.csv", [_Tag.TIME_SERIES])
_Dataset("transactions_sample_10k", "time-series/transactions_sample_10k.csv", [_Tag.TIME_SERIES])

from ._datasets import *
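
For context, the datasets.yaml file loaded above is a multi-document YAML stream (hence yaml.full_load_all), and _Dataset._from_dict in _datasets.py below expects each document to carry a name, a url, a schema given as a list of single-key column-to-dtype mappings, a list of tags, and an optional date_format. A hypothetical document in that shape might look like this; the entry name, column names, tag string, and date_format value are assumptions for illustration, not contents of the real file:

---
name: example_transactions        # hypothetical entry, not in the real datasets.yaml
url: tabular/example/transactions.csv
schema:
- transaction_id: long            # dtype strings match the DType enum values in _dtypes.py
- amount: double
- booked_at: datetime
tags:
- regression                      # assumed tag string; _Tag values are not shown in this diff
date_format: dd/MM/yyyy           # assumed Java-style pattern, as handed to Spark's dateFormat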
98 changes: 87 additions & 11 deletions src/synthesized_datasets/_datasets.py
@@ -1,12 +1,16 @@
import os as _os
import sys as _sys
import typing as _typing
import typing as _ty
from enum import Enum as _Enum
import os as _os

import pandas as _pd
import pyspark.sql as _ps
import yaml as _yaml
from pyspark import SparkFiles as _SparkFiles

from ._dtypes import _PD_DTYPE_MAP
from ._dtypes import DType as _DType
from ._dtypes import create_pyspark_schema as _create_pyspark_schema

_ROOT_GITHUB_URL = "https://raw.githubusercontent.com/synthesized-io/datasets/master/"

@@ -29,10 +33,24 @@ def __repr__(self):


class _Dataset:
def __init__(self, name: str, url: str, tags: _typing.Optional[_typing.List[_Tag]] = None):
def __init__(
self,
name: str,
url: str,
schema: _ty.Mapping[str, _DType],
tags: _ty.Optional[_ty.List[_Tag]] = None,
date_format: _ty.Optional[str] = None,
):
self._name = name
self._url = url if url.startswith("https://storage.googleapis.com") else _ROOT_GITHUB_URL + url
self._tags: _typing.List[_Tag] = tags if tags is not None else []
self._url = (
url
if url.startswith("https://storage.googleapis.com")
else _ROOT_GITHUB_URL + url
)
self._tags: _ty.List[_Tag] = tags if tags is not None else []
self._schema = schema
self._date_format = date_format

_REGISTRIES[_Tag.ALL]._register(self)
for tag in self._tags:
_REGISTRIES[tag]._register(self)
@@ -46,7 +64,7 @@ def url(self) -> str:
return self._url

@property
def tags(self) -> _typing.List[_Tag]:
def tags(self) -> _ty.List[_Tag]:
return self._tags

def load(self) -> _pd.DataFrame:
@@ -55,34 +73,92 @@ def load(self) -> _pd.DataFrame:
df = _pd.read_parquet(self.url)
else:
# CSV load is the default
df = _pd.read_csv(self.url)
dtypes = {
col: (
_PD_DTYPE_MAP[dtype]
if dtype
not in [
_DType.DATETIME,
_DType.TIMEDELTA,
_DType.DATE,
_DType.TIME,
]
else "string"
)
for col, dtype in self._schema.items()
}
df = _pd.read_csv(self.url, dtype=dtypes)
for col, dtype in self._schema.items():
if dtype in [_DType.DATETIME, _DType.DATE]:
df[col] = _pd.to_datetime(df[col], dayfirst=True)
if dtype in [_DType.TIMEDELTA, _DType.TIME]:
df[col] = _pd.to_timedelta(df[col])
df.attrs["name"] = self.name
return df

def load_spark(self, spark: _typing.Optional[_ps.SparkSession] = None) -> _ps.DataFrame:
def load_spark(self, spark: _ty.Optional[_ps.SparkSession] = None) -> _ps.DataFrame:
"""Loads the dataset as a Spark DataFrame."""

if spark is None:
spark = _ps.SparkSession.builder.getOrCreate()

schema = _create_pyspark_schema(self._schema)
spark.sparkContext.addFile(self.url)
_, filename = _os.path.split(self.url)
if self.url.endswith("parquet"):
df = spark.read.parquet(_SparkFiles.get(filename))
else:
# CSV load is the default
df = spark.read.csv(_SparkFiles.get(filename), header=True, inferSchema=True)
df = spark.read.csv(
_SparkFiles.get(filename),
header=True,
schema=schema,
enforceSchema=False,
dateFormat=self._date_format,
)
df.name = self.name
return df

def __repr__(self):
return f"<Dataset: {self.url}>"

def _to_dict(self):
params = {
"name": self.name,
"url": (
self.url[len(_ROOT_GITHUB_URL) :]
if self.url.startswith(_ROOT_GITHUB_URL)
else self.url
),
"schema": [{key: value.value} for key, value in self._schema.items()],
"tags": [tag.value for tag in self.tags],
}
if self._date_format is not None:
params["date_format"] = self._date_format

return params

@classmethod
def _from_dict(cls, d):
return _Dataset(
name=d["name"],
url=d["url"],
date_format=d.get("date_format"),
schema={
next(iter(elem.keys())): _DType(next(iter(elem.values())))
for elem in d["schema"]
},
tags=[_Tag(tag) for tag in d["tags"]],
)

def _to_yaml(self):
return _yaml.dump(self._to_dict(), indent=2)


class _Registry:
def __init__(self, tag: _Tag):
self._tag = tag
self._datasets: _typing.MutableMapping[str, _Dataset] = {}
self._datasets: _ty.MutableMapping[str, _Dataset] = {}

def _register(self, dataset: _Dataset):
if self._tag not in dataset.tags and self._tag != _Tag.ALL:
@@ -93,7 +169,7 @@ def _register(self, dataset: _Dataset):
setattr(self, dataset.name, dataset)


_REGISTRIES: _typing.MutableMapping[_Tag, _Registry] = {}
_REGISTRIES: _ty.MutableMapping[_Tag, _Registry] = {}

for _tag in _Tag:
_registry = _Registry(_tag)
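
As a minimal, self-contained sketch of the two-phase CSV loading strategy used in load() above: temporal columns are first read with the pandas "string" dtype so read_csv never has to guess their format, and are then parsed explicitly. The schema, column names, and sample data here are invented for illustration, assuming only pandas:

import io

import pandas as pd

# Hypothetical schema: concrete dtypes for ordinary columns, a sentinel
# "datetime" marker for columns that need a second parsing pass.
schema = {"amount": "float64", "booked_at": "datetime"}

csv = io.StringIO("amount,booked_at\n9.99,16/02/2024\n5.00,17/02/2024\n")

# Phase 1: read the CSV with temporal columns downgraded to string.
dtypes = {col: ("string" if dt == "datetime" else dt) for col, dt in schema.items()}
df = pd.read_csv(csv, dtype=dtypes)

# Phase 2: parse temporal columns explicitly; dayfirst=True matches the
# DD/MM/YYYY convention that load() assumes for its datasets.
for col, dt in schema.items():
    if dt == "datetime":
        df[col] = pd.to_datetime(df[col], dayfirst=True)

print(df.dtypes)  # amount float64, booked_at datetime64[ns]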
82 changes: 82 additions & 0 deletions src/synthesized_datasets/_dtypes.py
@@ -0,0 +1,82 @@
from enum import Enum

import pyspark.sql.types as st


class DType(str, Enum):
"""Class to define the internal dtypes that can be handled."""

BOOL = "bool"
NULLABLE_BOOL = "bool?"
DATETIME = "datetime"
DATE = "date"
FLOAT = "float"
DOUBLE = "double"
INTEGER = "int"
LONG = "long"
NULLABLE_LONG = "long?"
STRING = "string"
TIMEDELTA = "timedelta"
TIME = "time"


def create_pandas_schema(schema: dict[str, DType]) -> dict[str, str]:
"""Creates a PySpark schema from a dictionary of column names and d"""
return {name: _PD_DTYPE_MAP[dtype] for name, dtype in schema.items()}


def create_pyspark_schema(schema: dict[str, DType]) -> st.StructType:
"""Creates a PySpark schema from a dictionary of column names and d"""
return st.StructType(
[
st.StructField(name, _PS_DTYPE_MAP[dtype], _PS_NULLABLE_MAP[dtype])
for name, dtype in schema.items()
]
)


_PD_DTYPE_MAP = {
DType.BOOL: "bool",
DType.NULLABLE_BOOL: "boolean",
DType.DATETIME: "datetime64[ns]",
DType.DATE: "datetime64[ns]",
DType.FLOAT: "float32",
DType.DOUBLE: "float64",
DType.INTEGER: "int32",
DType.LONG: "int64",
DType.NULLABLE_LONG: "Int64",
DType.STRING: "string",
DType.TIMEDELTA: "timedelta64[ns]",
DType.TIME: "timedelta64[ns]",
}


_PS_DTYPE_MAP = {
DType.BOOL: st.BooleanType(),
DType.NULLABLE_BOOL: st.BooleanType(),
DType.DATETIME: st.TimestampType(),
DType.DATE: st.DateType(),
DType.FLOAT: st.FloatType(),
DType.DOUBLE: st.DoubleType(),
DType.INTEGER: st.IntegerType(),
DType.LONG: st.LongType(),
DType.NULLABLE_LONG: st.LongType(),
DType.STRING: st.StringType(),
DType.TIMEDELTA: st.StringType(),
DType.TIME: st.DayTimeIntervalType(1, 3),
}

_PS_NULLABLE_MAP = {
DType.BOOL: False,
DType.NULLABLE_BOOL: True,
DType.DATETIME: False,
DType.DATE: False,
DType.FLOAT: False,
DType.DOUBLE: False,
DType.INTEGER: False,
DType.LONG: False,
DType.NULLABLE_LONG: True,
DType.STRING: True,
DType.TIMEDELTA: True,
DType.TIME: True,
}
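
As a usage sketch for the helpers above: create_pyspark_schema turns a {column: DType} mapping into a StructType whose per-field nullability follows _PS_NULLABLE_MAP. The import path assumes the package layout from pyproject.toml, and the column names are invented:

import pyspark.sql.types as st

from synthesized_datasets._dtypes import DType, create_pyspark_schema

fields = {
    "id": DType.LONG,               # -> LongType(), nullable=False
    "score": DType.DOUBLE,          # -> DoubleType(), nullable=False
    "comment": DType.STRING,        # -> StringType(), nullable=True
    "visits": DType.NULLABLE_LONG,  # -> LongType(), nullable=True
}
schema = create_pyspark_schema(fields)
assert isinstance(schema, st.StructType)
print(schema.fieldNames())  # ['id', 'score', 'comment', 'visits']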