diff --git a/mars/dataframe/base/rechunk.py b/mars/dataframe/base/rechunk.py index fa7024f4e9..932a913621 100644 --- a/mars/dataframe/base/rechunk.py +++ b/mars/dataframe/base/rechunk.py @@ -209,11 +209,14 @@ def compute_rechunk(a, chunk_size): calc_sliced_size(s, chunk_slice[0]) for s in old_chunk.shape ) new_index_value = indexing_index_value( - old_chunk.index_value, chunk_slice[0] + old_chunk.index_value, chunk_slice[0], rechunk=True ) if is_dataframe: new_columns_value = indexing_index_value( - old_chunk.columns_value, chunk_slice[1], store_data=True + old_chunk.columns_value, + chunk_slice[1], + store_data=True, + rechunk=True, ) merge_chunk_op = DataFrameIlocGetItem( list(chunk_slice), diff --git a/mars/dataframe/indexing/loc.py b/mars/dataframe/indexing/loc.py index 69e9a41b6a..2dc9c2bd9e 100644 --- a/mars/dataframe/indexing/loc.py +++ b/mars/dataframe/indexing/loc.py @@ -26,7 +26,7 @@ from ...serialization.serializables import KeyField, ListField from ...tensor.datasource import asarray from ...tensor.utils import calc_sliced_size, filter_inputs -from ...utils import lazy_import +from ...utils import lazy_import, is_full_slice from ..core import IndexValue, DATAFRAME_TYPE from ..operands import DataFrameOperand, DataFrameOperandMixin from ..utils import parse_index @@ -154,7 +154,13 @@ def _calc_slice_param( axis: int, ) -> Dict: param = dict() - if input_index_value.has_value(): + if is_full_slice(index): + # full slice on this axis + param["shape"] = inp.shape[axis] + param["index_value"] = input_index_value + if axis == 1: + param["dtypes"] = inp.dtypes + elif input_index_value.has_value(): start, end = pd_index.slice_locs( index.start, index.stop, index.step, kind="loc" ) diff --git a/mars/dataframe/indexing/tests/test_indexing.py b/mars/dataframe/indexing/tests/test_indexing.py index 098b86da33..10df6ec78b 100644 --- a/mars/dataframe/indexing/tests/test_indexing.py +++ b/mars/dataframe/indexing/tests/test_indexing.py @@ -122,6 +122,7 @@ def 
test_iloc_getitem(): df4 = tile(df4) assert isinstance(df4, DATAFRAME_TYPE) assert isinstance(df4.op, DataFrameIlocGetItem) + assert df4.index_value.key == df2.index_value.key assert df4.shape == (3, 1) assert df4.chunk_shape == (2, 1) assert df4.chunks[0].shape == (2, 1) @@ -479,6 +480,7 @@ def test_dataframe_loc(): df2.index_value.to_pandas(), df.index_value.to_pandas() ) assert df2.name == "y" + assert df2.index_value.key == df.index_value.key df2 = tile(df2) assert len(df2.chunks) == 2 diff --git a/mars/dataframe/utils.py b/mars/dataframe/utils.py index c20f8eddba..9a5c5b20c1 100644 --- a/mars/dataframe/utils.py +++ b/mars/dataframe/utils.py @@ -27,7 +27,7 @@ from ..core import Entity, ExecutableTuple from ..lib.mmh3 import hash as mmh_hash from ..tensor.utils import dictify_chunk_size, normalize_chunk_sizes -from ..utils import tokenize, sbytes, lazy_import, ModulePlaceholder +from ..utils import tokenize, sbytes, lazy_import, ModulePlaceholder, is_full_slice try: import pyarrow as pa @@ -804,9 +804,13 @@ def filter_index_value(index_value, min_max, store_data=False): return parse_index(pd_index[f], store_data=store_data) -def indexing_index_value(index_value, indexes, store_data=False): +def indexing_index_value(index_value, indexes, store_data=False, rechunk=False): pd_index = index_value.to_pandas() - if not index_value.has_value(): + # when rechunk is True, the output index shall be treated + # different from the input one + if not rechunk and isinstance(indexes, slice) and is_full_slice(indexes): + return index_value + elif not index_value.has_value(): new_index_value = parse_index(pd_index, indexes, store_data=store_data) new_index_value._index_value._min_val = index_value.min_val new_index_value._index_value._min_val_close = index_value.min_val_close diff --git a/mars/utils.py b/mars/utils.py index 7b6dcede85..8b05ab3c71 100644 --- a/mars/utils.py +++ b/mars/utils.py @@ -1519,3 +1519,13 @@ def flatten_dict_to_nested_dict(flatten_dict: Dict, sep=".") -> 
def is_full_slice(slc: Any) -> bool:
    """Report whether ``slc`` is a slice spelled as ``:`` or ``0:``.

    The answer is ``True`` only when ``slc`` is a ``slice`` whose start is
    ``None`` or compares equal to ``0`` and whose stop and step are both
    ``None``. Any other object, including non-slice values, gives ``False``.
    """
    if not isinstance(slc, slice):
        return False

    # The start is examined before stop/step, and the equality test runs
    # before the ``None`` test.
    starts_at_origin = slc.start == 0 or slc.start is None
    if not starts_at_origin:
        return False

    return slc.stop is None and slc.step is None