Skip to content

Commit

Permalink
Release 1.4.0+sklearn MultiColumnTfidfVectorizer support (#229)
Browse files Browse the repository at this point in the history
* Enable multiple column based count vectorizer output for text processing

* Add feature multicolumnvectorizer

* black reformatting

Co-authored-by: Ubuntu <ubuntu@ip-172-31-41-116.us-west-2.compute.internal>
  • Loading branch information
CloudManX and Ubuntu authored Sep 29, 2021
1 parent b767faa commit d2755c6
Show file tree
Hide file tree
Showing 2 changed files with 103 additions and 14 deletions.
71 changes: 58 additions & 13 deletions python/tvm/relay/frontend/sklearn.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
# pylint: disable=import-outside-toplevel

import numpy as np
import tvm
from enum import IntEnum
from tvm import relay
from tvm.ir import IRModule

Expand Down Expand Up @@ -143,6 +145,24 @@ def _Pipeline(op, inexpr, dshape, dtype, func_name, columns=None):
return inexpr


def _MultiColumn(op, inexpr, dshape, dtype, func_name, columns=None):
    """
    Scikit-Learn Pipeline:
    Handles multi-column transforms applied across several inputs
    (requires len(inexpr) > 0).
    Currently supported: MultiColumnTfidfVectorizer.
    """
    assert len(inexpr) > 0
    # Convert each per-column sub-vectorizer against its matching input,
    # sizing the output by that sub-vectorizer's vocabulary.
    converted = [
        sklearn_op_to_relay(
            sub_vec,
            inexpr[col_idx],
            (dshape[0], len(sub_vec.get_feature_names())),
            dtype,
            func_name,
            columns,
        )
        for col_idx, sub_vec in enumerate(op.vectorizers_)
    ]
    # Stitch the per-column feature matrices back together column-wise.
    return _op.concatenate(converted, axis=1)


def _ColumnTransformer(op, inexpr, dshape, dtype, func_name, columns=None):
"""
Scikit-Learn Compose:
Expand All @@ -163,6 +183,9 @@ def _ColumnTransformer(op, inexpr, dshape, dtype, func_name, columns=None):
pipe, inexpr[op_type], date_shape, dtype, func_name, date_cols
)
)
elif proc_name == "text_processing":
# Pass all text input count matrices to pipeline, which is the last section of input array
out.append(sklearn_op_to_relay(pipe, inexpr[op_type:], dshape, dtype, func_name, cols))
else:
out.append(sklearn_op_to_relay(pipe, inexpr[op_type], dshape, dtype, func_name, cols))

Expand Down Expand Up @@ -585,6 +608,7 @@ def _DateTimeVectorizer(op, inexpr, dshape, dtype, columns=None):


_convert_map = {
"ColumnTransformer": {"transform": _ColumnTransformer},
"ColumnTransformer": {"transform": _ColumnTransformer},
"SimpleImputer": {"transform": _SimpleImputer},
"RobustImputer": {"transform": _RobustImputer},
Expand All @@ -595,25 +619,31 @@ def _DateTimeVectorizer(op, inexpr, dshape, dtype, columns=None):
"RobustOrdinalEncoder": {"transform": _RobustOrdinalEncoder},
"KBinsDiscretizer": {"transform": _KBinsDiscretizer},
"TfidfVectorizer": {"transform": _TfidfVectorizer},
"MultiColumnTfidfVectorizer": {"transform": _MultiColumn},
"RobustMissingIndicator": {"transform": _RobustMissingIndicator},
"RobustPCA": {"transform": _RobustPCA},
"FeatureUnion": {"transform": _FeatureUnion},
"DateTimeVectorizer": {"transform": _DateTimeVectorizer},
"Pipeline": {"transform": _Pipeline},
}

INPUT_FLOAT = 0
INPUT_STRING = 1
INPUT_DATETIME = 2

class Category(IntEnum):
    # Input-category codes for ColumnTransformer sub-pipelines. The integer
    # value doubles as an index into the relay input-expression list (see the
    # `inexpr[op_type]` / `inexpr[op_type:]` lookups in _ColumnTransformer).
    INPUT_FLOAT = 0  # numeric columns (imputers, scalers, PCA, ...)
    INPUT_STRING = 1  # categorical/string columns (ordinal / one-hot encoders)
    INPUT_DATETIME = 2  # datetime columns expanded by DateTimeVectorizer
    INPUT_TEXT = 3  # free-text columns handled by MultiColumnTfidfVectorizer


column_transformer_op_types = {
"RobustImputer": INPUT_FLOAT,
"RobustMissingIndicator": INPUT_FLOAT,
"FeatureUnion": INPUT_FLOAT,
"RobustStandardScaler": INPUT_FLOAT,
"RobustOrdinalEncoder": INPUT_STRING,
"ThresholdOneHotEncoder": INPUT_STRING,
"DateTimeVectorizer": INPUT_DATETIME,
"RobustImputer": Category.INPUT_FLOAT,
"RobustMissingIndicator": Category.INPUT_FLOAT,
"FeatureUnion": Category.INPUT_FLOAT,
"RobustStandardScaler": Category.INPUT_FLOAT,
"RobustOrdinalEncoder": Category.INPUT_STRING,
"ThresholdOneHotEncoder": Category.INPUT_STRING,
"DateTimeVectorizer": Category.INPUT_DATETIME,
"MultiColumnTfidfVectorizer": Category.INPUT_TEXT,
}


Expand All @@ -632,7 +662,7 @@ def sklearn_op_to_relay(op, inexpr, dshape, dtype, func_name, columns=None):
)
)

if classname in ["ColumnTransformer", "Pipeline", "FeatureUnion"]:
if classname in ["ColumnTransformer", "Pipeline", "FeatureUnion", "MultiColumnTfidfVectorizer"]:
return _convert_map[classname][func_name](op, inexpr, dshape, dtype, func_name, columns)

return _convert_map[classname][func_name](op, inexpr, dshape, dtype, columns)
Expand Down Expand Up @@ -669,6 +699,8 @@ def from_auto_ml(model, shape=None, dtype="float32", func_name="transform"):
if func_name == "transform":
inexpr_float = _expr.var("input_float", shape=shape, dtype=dtype)
inexpr_string = _expr.var("input_string", shape=shape, dtype=dtype)
inexpr_datetime = None
inexpr_texts = []

inexpr = [inexpr_float, inexpr_string]

Expand All @@ -680,12 +712,25 @@ def from_auto_ml(model, shape=None, dtype="float32", func_name="transform"):
)
)

for proc_name, _, cols in column_transformer.transformers_:
for proc_name, transformer, cols in column_transformer.transformers_:
if proc_name == "datetime_processing":
inexpr_datetime = _expr.var(
"input_datetime", shape=(shape[0], len(cols) * kNumDateTimeCols), dtype=dtype
)
inexpr.append(inexpr_datetime)
if proc_name == "text_processing":
multivec = transformer.steps[0][1]
for idx, sub_vec in enumerate(multivec.vectorizers_):
num_features = len(sub_vec.get_feature_names())
inexpr_texts.append(
_expr.var(
"input_text{}".format(idx), shape=(shape[0], num_features), dtype=dtype
)
)

inexpr.append(inexpr_datetime) # Padding None if inexpr_datetime is empty

for inexpr_text in inexpr_texts:
inexpr.append(inexpr_text)

outexpr = inexpr
for _, transformer in model.feature_transformer.steps:
Expand Down
46 changes: 45 additions & 1 deletion tests/python/frontend/sklearn/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,12 @@
from sklearn.compose import ColumnTransformer
from sklearn.decomposition import PCA, TruncatedSVD
from sklearn.preprocessing import StandardScaler, KBinsDiscretizer
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.feature_extraction.text import CountVectorizer, TfidfVectorizer
from sagemaker_sklearn_extension.externals import AutoMLTransformer
from sagemaker_sklearn_extension.externals import Header
from sagemaker_sklearn_extension.impute import RobustImputer, RobustMissingIndicator
from sagemaker_sklearn_extension.decomposition import RobustPCA
from sagemaker_sklearn_extension.feature_extraction.text import MultiColumnTfidfVectorizer
from sagemaker_sklearn_extension.preprocessing import (
RobustStandardScaler,
ThresholdOneHotEncoder,
Expand Down Expand Up @@ -279,6 +280,49 @@ def test_automl():
_test_model_impl(st_helper, automl_transformer, dshape, data, auto_ml=True)


def test_automl_multicolumn_tfidifvectorizer():
    """End-to-end check: an AutoML pipeline wrapping a MultiColumnTfidfVectorizer
    compiles through the relay frontend and matches sklearn's own output."""
    helper = SklearnTestHelper()
    multi_tfidf = MultiColumnTfidfVectorizer()

    docs = np.array(
        [
            ["Cats eat rats.", "Rats are mammals."],
            ["Dogs chase cats.", "Rats are mammals."],
            ["People like dogs.", "Rats are mammals."],
            ["People hate rats.", "Rats are mammals."],
        ]
    )

    text_pipeline = Pipeline(steps=[("multicolumnvectorizer", multi_tfidf)])
    ct = ColumnTransformer(transformers=[("text_processing", text_pipeline, [0, 1])])
    ct.fit(docs)

    outer_pipeline = Pipeline(steps=[("column_transformer", ct)])
    header = Header(column_names=["x1", "x2"], target_column_name="x2")
    automl_transformer = AutoMLTransformer(header, outer_pipeline, None)

    dshape = [relay.Any(), relay.Any()]
    helper.compile(automl_transformer, dshape, "float32", "transform", None, True)

    fitted_multivec = ct.transformers_[0][1].steps[0][1]
    expected = multi_tfidf.fit_transform(docs).toarray()

    # Rebuild one dense count matrix per text column, reusing the vocabulary
    # learned by each fitted sub-vectorizer; these are the relay inputs.
    count_matrices = []
    for col, sub_vec in enumerate(fitted_multivec.vectorizers_):
        counter = CountVectorizer(dtype=np.float32)
        counter.vocabulary_ = sub_vec.vocabulary_
        counter.fixed_vocabulary_ = sub_vec.fixed_vocabulary_
        counter.stop_words_ = sub_vec.stop_words_
        count_matrices.append(counter.transform(docs[:, col]).toarray())

    actual = helper.ex.evaluate()(count_matrices[0], count_matrices[1]).asnumpy()

    tvm.testing.assert_allclose(expected, actual, rtol=1e-5, atol=1e-5)


def test_feature_union():
st_helper = SklearnTestHelper()
rPCA = RobustPCA(n_components=2)
Expand Down

0 comments on commit d2755c6

Please sign in to comment.