Skip to content

Commit

Permalink
perf(clean): improve the performance of the clean subpackage
Browse files Browse the repository at this point in the history
  • Loading branch information
brandonlockhart committed Jan 5, 2021
1 parent 0965f20 commit c7c787b
Show file tree
Hide file tree
Showing 10 changed files with 786 additions and 849 deletions.
206 changes: 101 additions & 105 deletions dataprep/clean/clean_country.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,27 @@
"""
Implement clean_country function
"""
from typing import Union, Any
from functools import lru_cache
import os
from functools import lru_cache
from operator import itemgetter
from typing import Any, Union

import regex as re
import pandas as pd
import numpy as np
import dask.dataframe as dd
import dask
import dask.dataframe as dd
import numpy as np
import pandas as pd
import regex as re

from .utils import NULL_VALUES, create_report, to_dask
from ..eda.progress_bar import ProgressBar
from .utils import NULL_VALUES, create_report_new, to_dask

COUNTRY_DATA_FILE = os.path.join(os.path.split(os.path.abspath(__file__))[0], "country_data.tsv")

DATA = pd.read_csv(COUNTRY_DATA_FILE, sep="\t", encoding="utf-8", dtype=str)

REGEXES = [re.compile(entry, re.IGNORECASE) for entry in DATA.regex]
STATS = {"cleaned": 0, "null": 0, "unknown": 0}
# alternative regex search strategy given on line 243
# REGEXES = re.compile("|".join(f"(?P<a{i}>{x})" for i, x in enumerate(DATA.regex)), re.IGNORECASE)


def clean_country(
Expand All @@ -31,6 +34,7 @@ def clean_country(
inplace: bool = False,
report: bool = True,
errors: str = "coerce",
progress: bool = True,
) -> pd.DataFrame:
"""
This function cleans countries
Expand Down Expand Up @@ -67,45 +71,61 @@ def clean_country(
* If ‘raise’, then invalid parsing will raise an exception.
* If ‘coerce’, then invalid parsing will be set as NaN.
* If ‘ignore’, then invalid parsing will return the input.
progress
If True, enable the progress bar
"""
# pylint: disable=too-many-arguments
reset_stats()

input_formats = {"auto", "name", "official", "alpha-2", "alpha-3", "numeric"}
output_formats = {"name", "official", "alpha-2", "alpha-3", "numeric"}
if input_format not in input_formats:
raise ValueError(
f'input_format {input_format} is invalid, it needs to be "auto", '
f'"name", "official", "alpha-2", "alpha-3" or "numeric'
f'input_format {input_format} is invalid, it needs to be one of "auto", '
'"name", "official", "alpha-2", "alpha-3" or "numeric'
)
if output_format not in output_formats:
raise ValueError(
f'output_format {output_format} is invalid, it needs to be "name", '
f'"official", "alpha-2", "alpha-3" or "numeric'
'"official", "alpha-2", "alpha-3" or "numeric'
)
if strict and fuzzy_dist > 0:
raise ValueError(
"can't do fuzzy matching while strict mode is enabled, "
"set strict = False for fuzzy matching or fuzzy_dist = 0 for strict matching"
"set strict=False for fuzzy matching or fuzzy_dist=0 for strict matching"
)

# convert to dask
df = to_dask(df)
meta = df.dtypes.to_dict()
meta[f"{column}_clean"] = str

df = df.apply(
format_country,
args=(column, input_format, output_format, fuzzy_dist, strict, errors),
axis=1,
meta=meta,

# To clean, create a new column "clean_code_tup" which contains
# the cleaned values and code indicating how the initial value was
# changed in a tuple. Then split the column of tuples and count the
# amount of different codes to produce the report
df["clean_code_tup"] = df[column].map_partitions(
lambda srs: [
_format_country(x, input_format, output_format, fuzzy_dist, strict, errors) for x in srs
],
meta=object,
)
df = df.assign(
_temp_=df["clean_code_tup"].map(itemgetter(0)),
_code_=df["clean_code_tup"].map(itemgetter(1)),
)
df = df.rename(columns={"_temp_": f"{column}_clean"})

# counts of codes indicating how values were changed
stats = df["_code_"].value_counts(sort=False)
df = df.drop(columns=["clean_code_tup", "_code_"])

if inplace:
df = df.drop(columns=[column])
df = df.drop(columns=column)

df, nrows = dask.compute(df, df.shape[0])
with ProgressBar(minimum=1, disable=not progress):
df, stats = dask.compute(df, stats)

# output a report describing the result of clean_country
if report:
create_report("Country", STATS, nrows)
create_report_new("Country", stats, errors)

return df

Expand All @@ -132,69 +152,62 @@ def validate_country(
"""

if isinstance(x, pd.Series):
x = x.astype(str).str.lower()
x = x.str.strip()
return x.apply(check_country, args=(input_format, False, strict))
x = x.astype(str).str.lower().str.strip()
return x.apply(_check_country, args=(input_format, strict, False))

x = str(x).lower()
x = x.strip()
return check_country(x, input_format, False, strict)
x = str(x).lower().strip()
return _check_country(x, input_format, strict, False)


def _format_country(
    val: Any,
    input_format: str,
    output_format: str,
    fuzzy_dist: int,
    strict: bool,
    errors: str,
) -> Any:
    """
    Transform a single country value into the desired output format.

    Returns a 2-tuple (cleaned_value, code) where code records how the
    input value was changed:
        0 := the value is null
        1 := the value could not be parsed
        2 := the value was cleaned and differs from the input value
        3 := the value was already clean (no transformation needed)
    """
    # pylint: disable=too-many-arguments
    country = str(val).lower().strip()

    # _check_country returns the row index of the matched country in the
    # DATA dataframe together with a status: "null" (val is a null value),
    # "unknown" (val could not be parsed) or "success" (a successful parse).
    result_index, status = _check_country(country, input_format, strict, True)

    # Fall back to fuzzy matching for name-like input formats.
    if fuzzy_dist > 0 and status == "unknown" and input_format in ("auto", "name", "official"):
        result_index, status = _check_fuzzy_dist(country, fuzzy_dist)

    if status == "null":
        return np.nan, 0

    if status == "unknown":
        if errors == "raise":
            raise ValueError(f"unable to parse value {val}")
        return (val, 1) if errors == "ignore" else (np.nan, 1)

    result = DATA.loc[result_index, output_format]
    if pd.isna(result):
        # The matched country has no entry for the requested output format.
        if errors == "raise":
            raise ValueError(f"unable to parse value {val}")
        return (val, 1) if errors == "ignore" else (np.nan, 1)

    return (result, 3) if val == result else (result, 2)


@lru_cache(maxsize=2 ** 20)
def _check_country(country: str, input_format: str, strict: bool, clean: bool) -> Any:
    """
    Find the index of the given country in the DATA dataframe.

    Parameters
    ----------
    country
        string containing the country value being cleaned
    input_format
        the ISO 3166 input format of the country
    strict
        If True, for input formats "name" and "official" only a direct match
        in the DATA dataframe counts. If False, the input is searched for a
        regex match.
    clean
        If True, return a tuple (index, status). If False, return True/False
        for use by the validate_country function.
    """
    if country in NULL_VALUES:
        return (None, "null") if clean else False

    if input_format == "auto":
        input_format = _get_format_from_name(country)

    if strict and input_format == "regex":
        # exact (case-insensitive) match against the name and official columns
        for form in ("name", "official"):
            ind = DATA[
                DATA[form].str.contains(f"^{country}$", flags=re.IGNORECASE, na=False)
            ].index
            if len(ind) > 0:
                return (ind[0], "success") if clean else True
    elif not strict and input_format in ("regex", "name", "official"):
        # search the input against every precompiled country regex
        for index, country_regex in enumerate(REGEXES):
            if country_regex.search(country):
                return (index, "success") if clean else True
    else:
        # exact (case-insensitive) match against the requested format column
        ind = DATA[
            DATA[input_format].str.contains(f"^{country}$", flags=re.IGNORECASE, na=False)
        ].index
        if len(ind) > 0:
            return (ind[0], "success") if clean else True

    return (None, "unknown") if clean else False


@lru_cache(maxsize=2 ** 20)
def _check_fuzzy_dist(country: str, fuzzy_dist: int) -> Any:
    """
    Return the index of the country whose regex matches the input with the
    minimum number of edit errors (at most fuzzy_dist), or (None, "unknown")
    when no regex matches within the allowed distance.
    """
    candidates = []
    for idx, country_regex in enumerate(DATA.regex):
        # {e<=fuzzy_dist} allows at most fuzzy_dist total errors
        # (insertions, deletions and substitutions); re.BESTMATCH makes
        # the regex engine pick the match with the fewest errors.
        pattern = f"({country_regex}){{e<={fuzzy_dist}}}"
        match = re.search(pattern, country, flags=re.BESTMATCH | re.IGNORECASE)
        if match is not None:
            # record (total number of errors, index) so min() prefers
            # the candidate with the fewest errors
            candidates.append((sum(match.fuzzy_counts), idx))

    if not candidates:
        return None, "unknown"
    return min(candidates)[1], "success"

def get_format_from_name(name: str) -> str:

def _get_format_from_name(name: str) -> str:
"""
Function to infer the input format. Used when the input format is auto.
"""
try:
int(name)
src_format = "numeric"
return "numeric"
except ValueError:
if len(name) == 2:
src_format = "alpha-2"
elif len(name) == 3:
src_format = "alpha-3"
else:
src_format = "regex"

return src_format


def reset_stats() -> None:
"""
Reset global statistics dictionary
"""
STATS["cleaned"] = 0
STATS["null"] = 0
STATS["unknown"] = 0
return "alpha-2" if len(name) == 2 else "alpha-3" if len(name) == 3 else "regex"
Loading

0 comments on commit c7c787b

Please sign in to comment.