Skip to content

Commit

Permalink
refactor shrink
Browse files Browse the repository at this point in the history
  • Loading branch information
tvdboom committed Aug 17, 2023
1 parent 77cad79 commit 101cf8b
Show file tree
Hide file tree
Showing 14 changed files with 234 additions and 164 deletions.
198 changes: 113 additions & 85 deletions atom/atom.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from copy import deepcopy
from platform import machine, platform, python_build, python_version
from typing import Callable
from pyarrow.lib import ArrowInvalid

import dill as pickle
import numpy as np
import pandas as pd
Expand Down Expand Up @@ -50,7 +50,7 @@
__version__, bk, check_dependency, check_is_fitted, check_scaling,
composed, crash, custom_transform, fit_one, flt, get_cols,
get_custom_scorer, has_task, infer_task, is_multioutput, is_sparse, lst,
method_to_log, sign, variable_return,
method_to_log, sign, variable_return, MISSING_VALUES
)


Expand Down Expand Up @@ -92,8 +92,7 @@ def __init__(
self._memory = check_memory(tempfile.gettempdir())

self._missing = [
None, np.nan, np.inf, -np.inf, "", "?", "NA",
"nan", "NaN", "none", "None", "inf", "-inf"
"", "?", "NA", "nan", "NaN", "NaT", "none", "None", "inf", "-inf"
]

self._models = ClassMap()
Expand Down Expand Up @@ -121,9 +120,9 @@ def __init__(
if "gpu" in self.device.lower():
self.log("GPU training enabled.", 1)
if (data := self.engine.get("data")) != "numpy":
self.log(f"Data execution engine: {data}.", 1)
self.log(f"Data engine: {data}.", 1)
if (models := self.engine.get("models")) != "sklearn":
self.log(f"Models execution engine: {models}.", 1)
self.log(f"Models engine: {models}.", 1)
if self.backend == "ray" or self.n_jobs > 1:
self.log(f"Parallelization backend: {self.backend}", 1)
if self.experiment:
Expand Down Expand Up @@ -200,16 +199,16 @@ def missing(self) -> list:
These values are used by the [clean][self-clean] and
[impute][self-impute] methods. Default values are: None, NaN,
+inf, -inf, "", "?", "None", "NA", "nan", "NaN" and "inf".
Note that None, NaN, +inf and -inf are always considered
NaT, +inf, -inf, "", "?", "None", "NA", "nan", "NaN", "NaT",
"inf". Note that None, NaN, +inf and -inf are always considered
missing since they are incompatible with sklearn estimators.
"""
return self._missing

@missing.setter
def missing(self, value: SEQUENCE):
self._missing = list(set(list(value) + [None, np.nan, np.inf, -np.inf]))
self._missing = list(value)

@property
def scaled(self) -> bool:
Expand All @@ -231,15 +230,15 @@ def duplicates(self) -> SERIES:
def nans(self) -> SERIES | None:
"""Columns with the number of missing values in them."""
if not is_sparse(self.X):
nans = self.dataset.replace(self.missing, np.NaN)
nans = self.dataset.replace(self.missing + MISSING_VALUES, np.NaN)
nans = nans.isna().sum()
return nans[nans > 0]

@property
def n_nans(self) -> int | None:
"""Number of samples containing missing values."""
if not is_sparse(self.X):
nans = self.dataset.replace(self.missing, np.NaN)
nans = self.dataset.replace(self.missing + MISSING_VALUES, np.NaN)
nans = nans.isna().sum(axis=1)
return len(nans[nans > 0])

Expand Down Expand Up @@ -267,15 +266,16 @@ def n_categorical(self) -> int:
def outliers(self) -> SERIES | None:
"""Columns in training set with amount of outlier values."""
if not is_sparse(self.X):
z_scores = self.train.select_dtypes(include=["number"]).apply(stats.zscore)
z_scores = (z_scores.abs() > 3).sum(axis=0)
data = self.train.select_dtypes(include=["number"])
z_scores = (np.abs(stats.zscore(data.values.astype(float))) > 3).sum(axis=0)
return z_scores[z_scores > 0]

@property
def n_outliers(self) -> int | None:
"""Number of samples in the training set containing outliers."""
if not is_sparse(self.X):
z_scores = self.train.select_dtypes(include=["number"]).apply(stats.zscore)
data = self.train.select_dtypes(include=["number"])
z_scores = (np.abs(stats.zscore(data.values.astype(float))) > 3)
return (z_scores.abs() > 3).any(axis=1).sum()

@property
Expand Down Expand Up @@ -765,29 +765,35 @@ def save_data(self, filename: str = "auto", *, dataset: str = "dataset", **kwarg
def shrink(
self,
*,
obj2cat: bool = True,
int2bool: bool = False,
int2uint: bool = False,
str2cat: bool = False,
dense2sparse: bool = False,
columns: INT | str | slice | SEQUENCE | None = None,
):
"""Converts the columns to the smallest possible matching dtype.
Examples are: float64 -> float32, int64 -> int8, etc... Sparse
arrays also transform their non-fill value. Use this method for
memory optimization. Note that applying transformers to the
data may alter the types again.
memory optimization before [saving][self-save_data] the dataset.
Note that applying transformers to the data may alter the types
again.
Parameters
----------
obj2cat: bool, default=True
Whether to convert `object` to `category`. Only if the
number of categories would be less than 30% of the length
of the column.
int2bool: bool, default=False
Whether to convert `int` columns to `bool` type. Only if the
values in the column are strictly in (0, 1) or (-1, 1).
int2uint: bool, default=False
Whether to convert `int` to `uint` (unsigned integer). Only if
the values in the column are strictly positive.
str2cat: bool, default=False
Whether to convert `string` to `category`. Only if the
number of categories would be less than 30% of the length
of the column.
dense2sparse: bool, default=False
Whether to convert all features to sparse format. The value
that is compressed is the most frequent value in the column.
Expand All @@ -796,75 +802,101 @@ def shrink(
Names, positions or dtypes of the columns in the dataset to
shrink. If None, transform all columns.
Notes
-----
Partially from: https://github.com/fastai/fastai/blob/master/
fastai/tabular/core.py
"""
columns = self.branch._get_columns(columns)
exclude_types = ["category", "datetime64[ns]", "bool"]

# Build column filter and types
types_1 = (np.int8, np.int16, np.int32, np.int64)
types_2 = (np.uint8, np.uint16, np.uint32, np.uint64)
types_3 = (np.float32, np.float64, np.longdouble)
def get_data(new_t: str) -> SERIES:
"""Get the series with the right data format.
types = {
"int": [(np.dtype(x), np.iinfo(x).min, np.iinfo(x).max) for x in types_1],
"uint": [(np.dtype(x), np.iinfo(x).min, np.iinfo(x).max) for x in types_2],
"float": [(np.dtype(x), np.finfo(x).min, np.finfo(x).max) for x in types_3],
}
Also converts to sparse format if `dense2sparse=True`.
if obj2cat:
types["object"] = "category"
else:
exclude_types += ["object"]
Parameters
----------
new_t: str
Name of the new data type.
new_dtypes = {}
for name, column in self.dataset.items():
old_t = column.dtype
if name not in columns or old_t.name in exclude_types:
continue
Returns
-------
series
Object with the new data type.
"""
if pd.api.types.is_sparse(column):
t = next(v for k, v in types.items() if old_t.subtype.name.startswith(k))
# If already sparse array, cast directly to new sparse type
return column.astype(pd.SparseDtype(new_t, column.dtype.fill_value))
else:
t = next(v for k, v in types.items() if old_t.name.startswith(k))
if dense2sparse and name not in lst(self.target): # Skip target cols
# Select most frequent value to fill the sparse array
fill_value = column.mode(dropna=False)[0]

if isinstance(t, list):
# Use uint if values are strictly positive
if int2uint and t == types["int"] and column.min() >= 0:
t = types["uint"]
# Convert first to sparse array, else fails for nullable pd types
sparse_col = pd.arrays.SparseArray(column, fill_value=fill_value)

# Find the smallest type that fits
new_t = next(
r[0] for r in t if r[1] <= column.min() and r[2] >= column.max()
)
if new_t and new_t == old_t:
new_t = None # Keep as is
return sparse_col.astype(pd.SparseDtype(new_t, fill_value=fill_value))
else:
return column.astype(new_t)

t1 = (pd.Int8Dtype, pd.Int16Dtype, pd.Int32Dtype, pd.Int64Dtype)
t2 = (pd.UInt8Dtype, pd.UInt16Dtype, pd.UInt32Dtype, pd.UInt64Dtype)
t3 = (pd.Float32Dtype, pd.Float64Dtype)

types = {
"int": [(x.name, np.iinfo(x.type).min, np.iinfo(x.type).max) for x in t1],
"uint": [(x.name, np.iinfo(x.type).min, np.iinfo(x.type).max) for x in t2],
"float": [(x.name, np.finfo(x.type).min, np.finfo(x.type).max) for x in t3],
}

# Convert selected columns to the best nullable dtype
data = self.dataset[self.branch._get_columns(columns)]

Check notice on line 849 in atom/atom.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

Accessing a protected member of a class or a module

Access to a protected member _get_columns of a class

for name, column in data.items():
if pd.api.types.is_sparse(column):
old_t = column.dtype.subtype
else:
# Convert to category if number of categories less than 30% of column
new_t = t if column.nunique() <= int(len(column) * 0.3) else "object"
old_t = column.dtype

if new_t:
if pd.api.types.is_sparse(column):
new_dtypes[name] = pd.SparseDtype(new_t, old_t.fill_value)
else:
new_dtypes[name] = new_t
# TODO: Finish shrink for pyarrow
if "pyarrow" in old_t.name:
column = column.astype(column.to_numpy().dtype)

self.branch.dataset = self.branch.dataset.astype(new_dtypes)
# TODO: Finish shrink for pyarrow
column = column.convert_dtypes()

if dense2sparse:
new_cols = {}
for name, column in self.X.items():
new_cols[name] = pd.arrays.SparseArray(
data=column,
fill_value=column.mode(dropna=False)[0],
dtype=column.dtype,
)
if old_t.name.startswith("string"):
if str2cat and column.nunique() <= int(len(column) * 0.3):
self.branch._data[name] = get_data("category")

Check notice on line 866 in atom/atom.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

Accessing a protected member of a class or a module

Access to a protected member _data of a class
continue

try:
# Get the types to look at
t = next(v for k, v in types.items() if old_t.name.lower().startswith(k))
except StopIteration:
self.branch._data[name] = get_data(column.dtype.name)

Check notice on line 873 in atom/atom.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

Accessing a protected member of a class or a module

Access to a protected member _data of a class
continue

self.branch.X = bk.DataFrame(new_cols, index=self.y.index)
# Use bool if values are in (0, 1)
if int2bool and (t == types["int"] or t == types["uint"]):
if column.isin([0, 1]).all() or column.isin([-1, 1]).all():
self.branch._data[name] = get_data("bool")

Check notice on line 879 in atom/atom.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

Accessing a protected member of a class or a module

Access to a protected member _data of a class
continue

# Use uint if values are strictly positive
if int2uint and t == types["int"] and column.min() >= 0:
t = types["uint"]

# Find the smallest type that fits
self.branch._data[name] = next(

Check notice on line 887 in atom/atom.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

Accessing a protected member of a class or a module

Access to a protected member _data of a class
get_data(r[0]) for r in t if r[1] <= column.min() and r[2] >= column.max()
)

# TODO: Finish shrink for pyarrow
from pandas.core.dtypes.cast import convert_dtypes
print(self.dtypes)
self.branch.dataset = self.branch.dataset.astype(
{
name: convert_dtypes(column, dtype_backend="pyarrow")
for name, column in data.items()
}
)

self.log("The column dtypes are successfully converted.", 1)

Expand Down Expand Up @@ -907,14 +939,7 @@ def stats(self, _vb: INT = -2, /):
else:
nans = self.nans.sum()
n_categorical = self.n_categorical
try: # Fails for pyarrow dtypes
outliers = self.outliers.sum()
except ArrowInvalid:
outliers = None
self.log(
"Unable to calculate the number of outlier values. "
"Incompatible operation with the pyarrrow data engine.", 3
)
outliers = self.outliers.sum()
try: # Can fail for unhashable columns (e.g. multilabel with lists)
duplicates = self.dataset.duplicated().sum()
except TypeError:
Expand Down Expand Up @@ -1295,7 +1320,8 @@ def balance(self, strategy: str = "adasyn", **kwargs):
def clean(
self,
*,
drop_types: str | SEQUENCE | None = None,
convert_dtypes: bool = True,
drop_dtypes: str | SEQUENCE | None = None,
drop_chars: str | None = None,
strip_categorical: bool = True,
drop_duplicates: bool = False,
Expand All @@ -1308,6 +1334,7 @@ def clean(
Use the parameters to choose which transformations to perform.
The available steps are:
- Convert dtypes to the best possible types.
- Drop columns with specific data types.
- Remove characters from column names.
- Strip categorical features from white spaces.
Expand All @@ -1320,7 +1347,8 @@ def clean(
"""
columns = kwargs.pop("columns", None)
cleaner = Cleaner(
drop_types=drop_types,
convert_dtypes=convert_dtypes,
drop_dtypes=drop_dtypes,
drop_chars=drop_chars,
strip_categorical=strip_categorical,
drop_duplicates=drop_duplicates,
Expand Down
2 changes: 1 addition & 1 deletion atom/basemodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,7 +285,7 @@ def _gpu(self) -> bool:
def _est_class(self) -> Predictor:
"""Return the estimator's class (not instance)."""
try:
module = import_module(f"{self.engine}.{self._module}")
module = import_module(f"{self.engine['models']}.{self._module}")
cls = self._estimators.get(self.goal, self._estimators.get("reg"))
except (ModuleNotFoundError, AttributeError):
if "sklearn" in self.supports_engines:
Expand Down
4 changes: 2 additions & 2 deletions atom/baserunner.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import re
from typing import Any, Callable

import pandas as pd
from joblib.memory import Memory
from sklearn.utils.class_weight import compute_sample_weight
from sklearn.utils.metaestimators import available_if
Expand All @@ -24,9 +25,8 @@
from atom.pipeline import Pipeline
from atom.utils import (
DF_ATTRS, FLOAT, INT, INT_TYPES, SEQUENCE, SERIES, ClassMap, CustomDict,
Model, bk, check_is_fitted, composed, crash, divide, export_pipeline, flt,
Model, check_is_fitted, composed, crash, divide, export_pipeline, flt,
get_best_score, get_versions, has_task, is_multioutput, lst, method_to_log,
pd,
)


Expand Down
Loading

0 comments on commit 101cf8b

Please sign in to comment.