Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

BUG: fix df deep copy #709

Merged
merged 7 commits into from
Sep 22, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion python/xorbits/_mars/core/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ def __copy__(self):
return self.copy()

def copy(self):
    # NOTE(review): this hunk is reported as "1 addition & 1 deletion", so the
    # two returns below are presumably the removed line followed by its
    # replacement, not two live statements -- confirm against the real file.
    # Presumably the removed line: the new instance was built with this object's key.
    return self.copy_to(type(self)(_key=self.key))
    # Presumably the added line: the new instance is built with no arguments.
    return self.copy_to(type(self)())

def copy_to(self, target: "Base"):
target_fields = target._FIELDS
Expand Down
27 changes: 26 additions & 1 deletion python/xorbits/_mars/core/entity/tileables.py
Original file line number Diff line number Diff line change
Expand Up @@ -364,7 +364,14 @@ def __copy__(self):
def _view(self):
    # Returns whatever the parent class's ``copy`` returns, bypassing any
    # ``copy`` override defined on this class.
    return super().copy()

def copy(self: TileableType) -> TileableType:
def copy(self: TileableType, **kw) -> TileableType:
from ...dataframe import Index
from ...deploy.oscar.session import SyncSession

new_name = None
if isinstance(self, Index):
new_name = kw.pop("name", None)

new_op = self.op.copy()
if new_op.create_view:
# if the operand is a view, make it a copy
Expand All @@ -378,6 +385,24 @@ def copy(self: TileableType) -> TileableType:
new_outs = new_op.new_tileables(
self.op.inputs, kws=params, output_limit=len(params)
)

sess = self._executed_sessions[-1] if self._executed_sessions else None
to_incref_keys = []
for _out in new_outs:
if sess:
_out._attach_session(sess)
to_incref_keys.append(_out.key)
if self.data in sess._tileable_to_fetch:
sess._tileable_to_fetch[_out.data] = sess._tileable_to_fetch[
self.data
]
if new_name:
_out.name = new_name

if to_incref_keys:
assert sess is not None
SyncSession.from_isolated_session(sess).incref(*to_incref_keys)

pos = -1
for i, out in enumerate(self.op.outputs):
# create a ref to copied one
Expand Down
50 changes: 50 additions & 0 deletions python/xorbits/_mars/dataframe/base/tests/test_base_execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -3182,3 +3182,53 @@ def test_nunique(setup, method, chunked, axis):
raw_df.nunique(axis=axis),
mdf.nunique(axis=axis, method=method).execute().fetch(),
)


@pytest.mark.parametrize("chunk_size", [None, 10])
def test_copy_deep(setup, chunk_size):
    """Check that modifying the result of ``copy()`` leaves the source unchanged.

    Exercises a DataFrame, a Series and an Index built from pandas objects,
    then a DataFrame produced by arithmetic that has already been executed.
    Each case compares against the equivalent pandas operations.
    """
    # Fixed seed so the generated frame and series are reproducible.
    ns = np.random.RandomState(0)
    df = pd.DataFrame(ns.rand(100, 10), columns=["a" + str(i) for i in range(10)])
    mdf = from_pandas_df(df, chunk_size=chunk_size)

    # Case 1: the copy is taken directly from the source data, with no
    # other operation in between.
    res = mdf.copy()
    res["a0"] = res["a0"] + 1
    dfc = df.copy(deep=True)
    dfc["a0"] = dfc["a0"] + 1
    # The copy carries the change ...
    pd.testing.assert_frame_equal(res.execute().fetch(), dfc)
    # ... while the source still matches the untouched pandas frame.
    pd.testing.assert_frame_equal(mdf.execute().fetch(), df)

    s = pd.Series(ns.randint(0, 100, size=(100,)))
    ms = from_pandas_series(s, chunk_size=chunk_size)

    # Same check for a Series, mutating a single element through ``iloc``.
    res = ms.copy()
    res.iloc[0] = 111.0
    sc = s.copy(deep=True)
    sc.iloc[0] = 111.0
    pd.testing.assert_series_equal(res.execute().fetch(), sc)
    pd.testing.assert_series_equal(ms.execute().fetch(), s)

    index = pd.Index([i for i in range(100)], name="test")
    m_index = from_pandas_index(index, chunk_size=chunk_size)

    # Index copy without arguments: a distinct object with equal content.
    res = m_index.copy()
    assert res is not m_index
    pd.testing.assert_index_equal(res.execute().fetch(), index.copy())
    pd.testing.assert_index_equal(m_index.execute().fetch(), index)

    # Index copy with ``name``: the copy is renamed, the source keeps its name.
    res = m_index.copy(name="abc")
    pd.testing.assert_index_equal(res.execute().fetch(), index.copy(name="abc"))
    pd.testing.assert_index_equal(m_index.execute().fetch(), index)

    # Case 2: other operations sit between the source data and the copy,
    # and the intermediate result is executed before it is copied.
    xdf = (mdf + 1) * 2 / 7
    expected = (df + 1) * 2 / 7
    pd.testing.assert_frame_equal(xdf.execute().fetch(), expected)

    xdf_c = xdf.copy()
    expected_c = expected.copy(deep=True)
    # Before any modification the copy equals the source result.
    pd.testing.assert_frame_equal(xdf_c.execute().fetch(), expected)
    xdf_c["a1"] = xdf_c["a1"] + 0.8
    expected_c["a1"] = expected_c["a1"] + 0.8
    pd.testing.assert_frame_equal(xdf_c.execute().fetch(), expected_c)
    # The source result must be unaffected by the change made on the copy.
    pd.testing.assert_frame_equal(xdf.execute().fetch(), expected)
43 changes: 39 additions & 4 deletions python/xorbits/_mars/dataframe/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1167,6 +1167,37 @@ def to_series(self, index=None, name=None):

return series_from_index(self, index=index, name=name)

def copy(self, name=None, deep=False):
    """
    Make a copy of this object.

    Name is set on the new object.

    Parameters
    ----------
    name : Label, optional
        Set name for new object.
    deep : bool, default False
        Accepted in the signature but not used by this method; it is not
        forwarded to the parent ``copy`` call.

    Returns
    -------
    Index
        A new Index object that is a copy of this object.

    Notes
    -----
    Only ``name`` is passed on to the parent class's ``copy``, so
    ``deep=True`` and ``deep=False`` take the same code path here.

    Examples
    --------
    >>> idx = pd.Index(['a', 'b', 'c'])
    >>> new_idx = idx.copy()
    >>> idx is new_idx
    False
    """
    # ``deep`` is deliberately left out of the call: only ``name`` is forwarded.
    return super().copy(name=name)


class RangeIndex(Index):
__slots__ = ()
Expand Down Expand Up @@ -1591,10 +1622,9 @@ def copy(self, deep=True): # pylint: disable=arguments-differ
copy : Series or DataFrame
Object type matches caller.
"""
if deep:
return super().copy()
else:
return super()._view()
if deep is False:
raise NotImplementedError("Not support `deep=False` for now")
return super().copy()

def __len__(self):
    # Length is delegated to the wrapped ``_data`` object.
    return len(self._data)
Expand Down Expand Up @@ -2618,6 +2648,11 @@ def apply_if_callable(maybe_callable, obj, **kwargs):
data[k] = apply_if_callable(v, data)
return data

def copy(self, deep=True):
    """
    Return the copy produced by the parent class's ``copy``.

    Parameters
    ----------
    deep : bool, default True
        Passing the literal ``False`` is rejected; any other value is
        ignored and the parent ``copy`` is called without arguments.

    Raises
    ------
    NotImplementedError
        If ``deep`` is ``False``.
    """
    # Identity check: only the ``False`` singleton is rejected, so other
    # falsy values such as ``None`` or ``0`` fall through to the copy.
    if deep is False:
        raise NotImplementedError("Not support `deep=False` for now")
    return super().copy()


class DataFrameGroupByChunkData(BaseDataFrameChunkData):
type_name = "DataFrameGroupBy"
Expand Down
33 changes: 31 additions & 2 deletions python/xorbits/_mars/deploy/oscar/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -502,6 +502,17 @@ def decref(self, *tileables_keys):
Tileables' keys
"""

@abstractmethod
def incref(self, *tileables_keys):
    """
    Incref tileables.

    Abstract declaration only; it has no body beyond this docstring.
    Presumably the mirror of ``decref``, i.e. it increases the reference
    counts of the given tileables -- confirm against the implementations.

    Parameters
    ----------
    tileables_keys : list
        Tileables' keys, passed as separate positional arguments.
    """

@abstractmethod
def _get_ref_counts(self) -> Dict[str, int]:
"""
Expand Down Expand Up @@ -960,10 +971,19 @@ async def execute(self, *tileables, **kwargs) -> ExecutionInfo:
def _get_to_fetch_tileable(
self, tileable: TileableType
) -> Tuple[TileableType, List[Union[slice, Integral]]]:
from ...dataframe.indexing.iloc import DataFrameIlocGetItem, SeriesIlocGetItem
from ...dataframe.indexing.iloc import (
DataFrameIlocGetItem,
IndexIlocGetItem,
SeriesIlocGetItem,
)
from ...tensor.indexing import TensorIndex

slice_op_types = TensorIndex, DataFrameIlocGetItem, SeriesIlocGetItem
slice_op_types = (
TensorIndex,
DataFrameIlocGetItem,
SeriesIlocGetItem,
IndexIlocGetItem,
)

if hasattr(tileable, "data"):
tileable = tileable.data
Expand Down Expand Up @@ -1200,6 +1220,10 @@ async def decref(self, *tileable_keys):
logger.debug("Decref tileables on client: %s", tileable_keys)
return await self._lifecycle_api.decref_tileables(list(tileable_keys))

async def incref(self, *tileable_keys):
    """Log the keys, then forward them to the lifecycle API's ``incref_tileables``.

    Returns whatever ``incref_tileables`` returns.
    """
    logger.debug("Incref tileables on client: %s", tileable_keys)
    # The varargs tuple is converted to a list before it is handed over.
    return await self._lifecycle_api.incref_tileables(list(tileable_keys))

async def _get_ref_counts(self) -> Dict[str, int]:
    # Returns the result of the lifecycle API's ``get_all_chunk_ref_counts``
    # unchanged.
    return await self._lifecycle_api.get_all_chunk_ref_counts()

Expand Down Expand Up @@ -1623,6 +1647,11 @@ def fetch_infos(self, *tileables, fields, **kwargs) -> list:
def decref(self, *tileables_keys):
pass # pragma: no cover

@implements(AbstractSyncSession.incref)
@_delegate_to_isolated_session
def incref(self, *tileables_keys):
    # The body is only a placeholder. NOTE(review): the
    # ``_delegate_to_isolated_session`` decorator presumably routes the call
    # to the isolated session's ``incref`` -- confirm in its definition.
    pass  # pragma: no cover

@implements(AbstractSyncSession._get_ref_counts)
@_delegate_to_isolated_session
def _get_ref_counts(self) -> Dict[str, int]:
Expand Down
6 changes: 6 additions & 0 deletions python/xorbits/core/adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -495,6 +495,12 @@ def collect_cls_members(
) -> Dict[str, Any]:
cls_members: Dict[str, Any] = {}
for name, cls_member in inspect.getmembers(cls):
# A Tileable class and its TileableData class may define functions with the same name.
# For example, Index and IndexData both have a `copy` function, but with completely different semantics.
# Therefore, once the Index's `copy` method has been collected,
# the method of the same name on IndexData must not be collected again.
if cls.__name__.endswith("Data") and name in DATA_MEMBERS[data_type]: # type: ignore
continue
if inspect.isfunction(cls_member) and not name.startswith("_"):
cls_members[name] = wrap_mars_callable(
cls_member,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@

from .... import pandas as xpd
from ....core.data import DataRef
from ....core.execution import need_to_execute


def test_pandas_dataframe_methods(setup):
Expand Down Expand Up @@ -499,3 +500,36 @@ def test_read_pickle(setup):
assert (x == y).all()
finally:
shutil.rmtree(tempdir)


def test_copy(setup):
    """Check whether copies still need execution, and that their content is right.

    A plain copy of an object is expected not to need execution once the
    source has been displayed with ``repr``, while a copy that is renamed
    or modified is expected to need execution again.
    """
    index = xpd.Index([i for i in range(100)], name="test")
    index_iloc = index[:20]
    assert need_to_execute(index_iloc) is True
    # NOTE(review): ``repr`` presumably triggers execution of the pending
    # slice; the assertion on the copy below relies on that -- confirm.
    repr(index_iloc)

    index_copy = index_iloc.copy()
    assert need_to_execute(index_copy) is False
    pd.testing.assert_index_equal(index_copy.to_pandas(), index_iloc.to_pandas())

    # Passing ``name`` yields a copy that reports it still needs execution.
    index_copy = index_iloc.copy(name="abc")
    assert need_to_execute(index_copy) is True
    pd.testing.assert_index_equal(
        index_copy.to_pandas(), index_iloc.to_pandas().copy(name="abc")
    )

    # Same scenario for a Series produced by an arithmetic operation.
    series = xpd.Series([1, 2, 3, 4, np.nan, 6])
    series = series + 1
    assert need_to_execute(series) is True
    repr(series)

    sc = series.copy()
    assert need_to_execute(sc) is False
    expected = series.to_pandas()
    pd.testing.assert_series_equal(sc.to_pandas(), expected)

    # Modifying the copy makes it need execution again; the result is
    # compared with the same modification applied to a pandas copy.
    sc[0] = np.nan
    assert need_to_execute(sc) is True
    ec = expected.copy()
    ec[0] = np.nan
    pd.testing.assert_series_equal(sc.to_pandas(), ec)