Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[data/preprocessors] concatenator should preserve order of concatenated #47997

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion doc/source/data/doc_code/preprocessors.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ def _transform_pandas(self, df: DataFrame) -> DataFrame:
# [{'X': 1.0, 'Y': 2.0}, {'X': 4.0, 'Y': 0.0}]

scaler = StandardScaler(columns=["X", "Y"])
concatenator = Concatenator()
concatenator = Concatenator(columns=["X", "Y"])
dataset_transformed = scaler.fit_transform(dataset)
dataset_transformed = concatenator.fit_transform(dataset_transformed)
print(dataset_transformed.take())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -499,7 +499,9 @@ You can use this with Ray Train Trainers by applying them on the dataset before

# Create preprocessors to scale some columns and concatenate the results.
scaler = StandardScaler(columns=["mean radius", "mean texture"])
concatenator = Concatenator(exclude=["target"], dtype=np.float32)
columns_to_concatenate = dataset.columns()
columns_to_concatenate.remove("target")
concatenator = Concatenator(columns=columns_to_concatenate, dtype=np.float32)

# Compute dataset statistics and get transformed datasets. Note that the
# fit call is executed immediately, but the transformation is lazy.
Expand Down
3 changes: 2 additions & 1 deletion python/ray/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -4240,7 +4240,8 @@ def to_tf(
:class:`~ray.data.preprocessors.Concatenator`.

>>> from ray.data.preprocessors import Concatenator
>>> preprocessor = Concatenator(output_column_name="features", exclude="target")
>>> columns_to_concat = ["sepal length (cm)", "sepal width (cm)", "petal length (cm)", "petal width (cm)"]
>>> preprocessor = Concatenator(columns=columns_to_concat, output_column_name="features")
>>> ds = preprocessor.transform(ds)
>>> ds
Concatenator
Expand Down
3 changes: 2 additions & 1 deletion python/ray/data/iterator.py
Original file line number Diff line number Diff line change
Expand Up @@ -726,7 +726,8 @@ def to_tf(
:class:`~ray.data.preprocessors.Concatenator`.

>>> from ray.data.preprocessors import Concatenator
>>> preprocessor = Concatenator(output_column_name="features", exclude="target")
>>> columns_to_concat = ["sepal length (cm)", "sepal width (cm)", "petal length (cm)", "petal width (cm)"]
>>> preprocessor = Concatenator(columns=columns_to_concat, output_column_name="features")
>>> it = preprocessor.transform(ds).iterator()
>>> it
DataIterator(Concatenator
Expand Down
2 changes: 1 addition & 1 deletion python/ray/data/preprocessors/chain.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ class Chain(Preprocessor):
>>>
>>> preprocessor = Chain(
... StandardScaler(columns=["X0", "X1"]),
... Concatenator(include=["X0", "X1"], output_column_name="X"),
... Concatenator(columns=["X0", "X1"], output_column_name="X"),
... LabelEncoder(label_column="Y")
... )
>>> preprocessor.fit_transform(ds).to_pandas() # doctest: +SKIP
Expand Down
108 changes: 26 additions & 82 deletions python/ray/data/preprocessors/concatenator.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import logging
from typing import List, Optional, Union
from typing import List, Optional

import numpy as np
import pandas as pd
Expand All @@ -13,13 +13,16 @@
@PublicAPI(stability="alpha")
class Concatenator(Preprocessor):
"""Combine numeric columns into a column of type
:class:`~ray.air.util.tensor_extensions.pandas.TensorDtype`.
:class:`~ray.air.util.tensor_extensions.pandas.TensorDtype`. Only columns
specified in ``columns`` will be concatenated.

This preprocessor concatenates numeric columns and stores the result in a new
column. The new column contains
:class:`~ray.air.util.tensor_extensions.pandas.TensorArrayElement` objects of
shape :math:`(m,)`, where :math:`m` is the number of columns concatenated.
The :math:`m` concatenated columns are dropped after concatenation.
The preprocessor preserves the order of the columns provided in the ``colummns``
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i cannot add a comment to the docstring above this line because it wasn't modified in this PR. but can we also modify this line, to indicate that only columns specified in the columns arg will be included?

Combine numeric columns into a column of type
    :class:`~ray.air.util.tensor_extensions.pandas.TensorDtype`.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done!

argument and will use that order when calling ``transform()`` and ``transform_batch()``.

Examples:
>>> import numpy as np
Expand All @@ -32,8 +35,8 @@ class Concatenator(Preprocessor):

>>> df = pd.DataFrame({"X0": [0, 3, 1], "X1": [0.5, 0.2, 0.9]})
>>> ds = ray.data.from_pandas(df) # doctest: +SKIP
>>> concatenator = Concatenator()
>>> concatenator.fit_transform(ds).to_pandas() # doctest: +SKIP
>>> concatenator = Concatenator(columns=["X0", "X1"])
>>> concatenator.transform(ds).to_pandas() # doctest: +SKIP
concat_out
0 [0.0, 0.5]
1 [3.0, 0.2]
Expand All @@ -42,104 +45,54 @@ class Concatenator(Preprocessor):
By default, the created column is called `"concat_out"`, but you can specify
a different name.

>>> concatenator = Concatenator(output_column_name="tensor")
>>> concatenator.fit_transform(ds).to_pandas() # doctest: +SKIP
>>> concatenator = Concatenator(columns=["X0", "X1"], output_column_name="tensor")
>>> concatenator.transform(ds).to_pandas() # doctest: +SKIP
tensor
0 [0.0, 0.5]
1 [3.0, 0.2]
2 [1.0, 0.9]

Sometimes, you might not want to concatenate all of of the columns in your
dataset. In this case, you can exclude columns with the ``exclude`` parameter.

>>> df = pd.DataFrame({"X0": [0, 3, 1], "X1": [0.5, 0.2, 0.9], "Y": ["blue", "orange", "blue"]})
>>> ds = ray.data.from_pandas(df) # doctest: +SKIP
>>> concatenator = Concatenator(exclude=["Y"])
>>> concatenator.fit_transform(ds).to_pandas() # doctest: +SKIP
Y concat_out
0 blue [0.0, 0.5]
1 orange [3.0, 0.2]
2 blue [1.0, 0.9]

Alternatively, you can specify which columns to concatenate with the
``include`` parameter.

>>> concatenator = Concatenator(include=["X0", "X1"])
>>> concatenator.fit_transform(ds).to_pandas() # doctest: +SKIP
Y concat_out
0 blue [0.0, 0.5]
1 orange [3.0, 0.2]
2 blue [1.0, 0.9]

Note that if a column is in both ``include`` and ``exclude``, the column is
excluded.

>>> concatenator = Concatenator(include=["X0", "X1", "Y"], exclude=["Y"])
>>> concatenator.fit_transform(ds).to_pandas() # doctest: +SKIP
Y concat_out
0 blue [0.0, 0.5]
1 orange [3.0, 0.2]
2 blue [1.0, 0.9]

By default, the concatenated tensor is a ``dtype`` common to the input columns.
However, you can also explicitly set the ``dtype`` with the ``dtype``
parameter.

>>> concatenator = Concatenator(include=["X0", "X1"], dtype=np.float32)
>>> concatenator.fit_transform(ds) # doctest: +SKIP
>>> concatenator = Concatenator(columns=["X0", "X1"], dtype=np.float32)
>>> concatenator.transform(ds) # doctest: +SKIP
Dataset(num_rows=3, schema={Y: object, concat_out: TensorDtype(shape=(2,), dtype=float32)})

Args:
output_column_name: The desired name for the new column.
Defaults to ``"concat_out"``.
include: A list of columns to concatenate. If ``None``, all columns are
concatenated.
exclude: A list of column to exclude from concatenation.
If a column is in both ``include`` and ``exclude``, the column is excluded
from concatenation.
columns: A list of columns to concatenate. The provided order of the columns
will be retained during concatenation.
dtype: The ``dtype`` to convert the output tensors to. If unspecified,
the ``dtype`` is determined by standard coercion rules.
raise_if_missing: If ``True``, an error is raised if any
of the columns in ``include`` or ``exclude`` don't exist.
of the columns in ``columns`` don't exist.
Defaults to ``False``.

Raises:
ValueError: if `raise_if_missing` is `True` and a column in `include` or
`exclude` doesn't exist in the dataset.
ValueError: if `raise_if_missing` is `True` and a column in `columns` or
doesn't exist in the dataset.
""" # noqa: E501

_is_fittable = False

def __init__(
self,
columns: List[str],
output_column_name: str = "concat_out",
include: Optional[List[str]] = None,
exclude: Optional[Union[str, List[str]]] = None,
dtype: Optional[np.dtype] = None,
raise_if_missing: bool = False,
):
if isinstance(include, str):
include = [include]
if isinstance(exclude, str):
exclude = [exclude]
self.columns = columns

self.output_column_name = output_column_name
self.include = include
self.exclude = exclude or []
self.dtype = dtype
self.raise_if_missing = raise_if_missing

def _validate(self, df: pd.DataFrame) -> None:
for parameter in "include", "exclude":
columns = getattr(self, parameter)
if columns is None:
continue

missing_columns = set(columns) - set(df)
if not missing_columns:
continue

message = f"Missing columns specified in '{parameter}': {missing_columns}"
missing_columns = set(self.columns) - set(df)
if missing_columns:
message = (
f"Missing columns specified in '{self.columns}': {missing_columns}"
)
if self.raise_if_missing:
raise ValueError(message)
else:
Expand All @@ -148,16 +101,8 @@ def _validate(self, df: pd.DataFrame) -> None:
def _transform_pandas(self, df: pd.DataFrame):
self._validate(df)

included_columns = set(df)
if self.include: # subset of included columns
included_columns = set(self.include)

columns_to_concat = list(included_columns - set(self.exclude))
ordered_columns_to_concat = [
col for col in df.columns if col in columns_to_concat
]
concatenated = df[ordered_columns_to_concat].to_numpy(dtype=self.dtype)
df = df.drop(columns=columns_to_concat)
concatenated = df[self.columns].to_numpy(dtype=self.dtype)
df = df.drop(columns=self.columns)
# Use a Pandas Series for column assignment to get more consistent
# behavior across Pandas versions.
df.loc[:, self.output_column_name] = pd.Series(list(concatenated))
Expand All @@ -166,8 +111,7 @@ def _transform_pandas(self, df: pd.DataFrame):
def __repr__(self):
default_values = {
"output_column_name": "concat_out",
"include": None,
"exclude": [],
"columns": None,
"dtype": None,
"raise_if_missing": False,
}
Expand Down
52 changes: 39 additions & 13 deletions python/ray/data/tests/preprocessors/test_concatenator.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import numpy as np
import pandas as pd
import pytest
from pandas.testing import assert_frame_equal

import ray
from ray.data.exceptions import UserCodeException
Expand All @@ -16,7 +17,7 @@ def test_basic(self):
}
)
ds = ray.data.from_pandas(df)
prep = Concatenator(output_column_name="c")
prep = Concatenator(columns=["a", "b"], output_column_name="c")
new_ds = prep.transform(ds)
for i, row in enumerate(new_ds.take()):
assert np.array_equal(row["c"], np.array([i + 1, i + 5]))
Expand All @@ -25,45 +26,70 @@ def test_raise_if_missing(self):
df = pd.DataFrame({"a": [1, 2, 3, 4]})
ds = ray.data.from_pandas(df)
prep = Concatenator(
output_column_name="c", exclude=["b"], raise_if_missing=True
columns=["a", "b"], output_column_name="c", raise_if_missing=True
)

with pytest.raises(UserCodeException):
with pytest.raises(ValueError, match="'b'"):
prep.transform(ds).materialize()

@pytest.mark.parametrize("exclude", ("b", ["b"]))
def test_exclude(self, exclude):
def test_exclude_column(self):
df = pd.DataFrame({"a": [1, 2, 3, 4], "b": [2, 3, 4, 5], "c": [3, 4, 5, 6]})
ds = ray.data.from_pandas(df)
prep = Concatenator(exclude=exclude)
prep = Concatenator(columns=["a", "c"])
new_ds = prep.transform(ds)
for _, row in enumerate(new_ds.take()):
assert set(row) == {"concat_out", "b"}

def test_include(self):
def test_include_columns(self):
df = pd.DataFrame({"a": [1, 2, 3, 4], "b": [2, 3, 4, 5], "c": [3, 4, 5, 6]})
ds = ray.data.from_pandas(df)
prep = Concatenator(include=["a", "b"])
prep = Concatenator(columns=["a", "b"])
new_ds = prep.transform(ds)
for _, row in enumerate(new_ds.take()):
assert set(row) == {"concat_out", "c"}

def test_exclude_overrides_include(self):
df = pd.DataFrame({"a": [1, 2, 3, 4], "b": [2, 3, 4, 5], "c": [3, 4, 5, 6]})
def test_change_column_order(self):
df = pd.DataFrame({"a": [1, 2, 3, 4], "b": [2, 3, 4, 5]})
ds = ray.data.from_pandas(df)
prep = Concatenator(include=["a", "b"], exclude=["b"])
prep = Concatenator(columns=["b", "a"])
new_ds = prep.transform(ds)
for _, row in enumerate(new_ds.take()):
assert set(row) == {"concat_out", "b", "c"}
expected_df = pd.DataFrame({"concat_out": [[2, 1], [3, 2], [4, 3], [5, 4]]})
print(new_ds.to_pandas())
assert_frame_equal(new_ds.to_pandas(), expected_df)

def test_strings(self):
df = pd.DataFrame({"a": ["string", "string2", "string3"]})
ds = ray.data.from_pandas(df)
prep = Concatenator(output_column_name="huh")
prep = Concatenator(columns=["a"], output_column_name="huh")
new_ds = prep.transform(ds)
assert "huh" in set(new_ds.schema().names)

def test_preserves_order(self):
df = pd.DataFrame({"a": [1, 2, 3, 4], "b": [2, 3, 4, 5]})
ds = ray.data.from_pandas(df)
prep = Concatenator(columns=["a", "b"], output_column_name="c")
prep = prep.fit(ds)

df = pd.DataFrame({"a": [5, 6, 7, 8], "b": [6, 7, 8, 9]})
concatenated_df = prep.transform_batch(df)
expected_df = pd.DataFrame({"c": [[5, 6], [6, 7], [7, 8], [8, 9]]})
assert_frame_equal(concatenated_df, expected_df)

other_df = pd.DataFrame({"a": [9, 10, 11, 12], "b": [10, 11, 12, 13]})
concatenated_other_df = prep.transform_batch(other_df)
expected_df = pd.DataFrame(
{
"c": [
[9, 10],
[10, 11],
[11, 12],
[12, 13],
]
}
)
assert_frame_equal(concatenated_other_df, expected_df)


if __name__ == "__main__":
import sys
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ def preferred_batch_format(cls) -> BatchFormat:
RobustScaler(columns=["X"]),
SimpleImputer(columns=["X"]),
StandardScaler(columns=["X"]),
Concatenator(),
Concatenator(columns=["X"]),
Tokenizer(columns=["X"]),
],
)
Expand Down
2 changes: 1 addition & 1 deletion python/ray/data/tests/test_tf.py
Original file line number Diff line number Diff line change
Expand Up @@ -366,7 +366,7 @@ def train_func():
multi_worker_model.fit(dataset)

dataset = ray.data.from_items(8 * [{"X0": 0, "X1": 0, "Y": 0, "W": 0}])
concatenator = Concatenator(exclude=["Y", "W"], output_column_name="X")
concatenator = Concatenator(columns=["X0", "X1"], output_column_name="X")
dataset = concatenator.transform(dataset)

trainer = TensorflowTrainer(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,8 @@ def train_func(config: dict):

def train_tensorflow_regression(num_workers: int = 2, use_gpu: bool = False) -> Result:
dataset = ray.data.read_csv("s3://anonymous@air-example-data/regression.csv")
preprocessor = Concatenator(exclude=["", "y"], output_column_name="x")
columns_to_concatenate = [f"x{i:03}" for i in range(100)]
preprocessor = Concatenator(columns=columns_to_concatenate, output_column_name="x")
dataset = preprocessor.fit_transform(dataset)

config = {"lr": 1e-3, "batch_size": 32, "epochs": 4}
Expand Down
3 changes: 2 additions & 1 deletion python/ray/train/tests/test_tensorflow_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,8 @@ def train_func(config):
}
scaling_config = ScalingConfig(num_workers=num_workers)
dataset = ray.data.read_csv("s3://anonymous@air-example-data/regression.csv")
preprocessor = Concatenator(exclude=["", "y"], output_column_name="x")
columns_to_concatenate = [f"x{i:03}" for i in range(100)]
preprocessor = Concatenator(columns=columns_to_concatenate, output_column_name="x")
dataset = preprocessor.transform(dataset)

trainer = TensorflowTrainer(
Expand Down
Loading