Skip to content

Commit

Permalink
Fix DataFrame comparison when data type is period
Browse files Browse the repository at this point in the history
  • Loading branch information
继盛 committed Aug 23, 2021
1 parent 3f41408 commit ebc7033
Show file tree
Hide file tree
Showing 5 changed files with 40 additions and 15 deletions.
2 changes: 1 addition & 1 deletion mars/_version.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
import os
from typing import NamedTuple, Optional

version_info = (0, 8, 0, 'a1')
version_info = (0, 8, 0, 'a2')
_num_index = max(idx if isinstance(v, int) else 0
for idx, v in enumerate(version_info))
__version__ = '.'.join(map(str, version_info[:_num_index + 1])) + \
Expand Down
33 changes: 22 additions & 11 deletions mars/dataframe/arithmetic/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,9 @@

from ...core import ENTITY_TYPE, recursive_tile
from ...serialization.serializables import AnyField, Float64Field
from ...tensor.core import TENSOR_TYPE, ChunkData, Chunk
from ...tensor.core import TENSOR_TYPE, TENSOR_CHUNK_TYPE, ChunkData, Chunk
from ...tensor.datasource import tensor as astensor
from ...utils import classproperty
from ...utils import classproperty, get_dtype
from ..align import align_series_series, align_dataframe_series, align_dataframe_dataframe
from ..core import DATAFRAME_TYPE, SERIES_TYPE, DATAFRAME_CHUNK_TYPE, SERIES_CHUNK_TYPE
from ..initializer import Series, DataFrame
Expand Down Expand Up @@ -282,8 +282,9 @@ def _operator(self):

@classmethod
def _calc_properties(cls, x1, x2=None, axis='columns'):
if isinstance(x1, (DATAFRAME_TYPE, DATAFRAME_CHUNK_TYPE)) \
and (x2 is None or pd.api.types.is_scalar(x2) or isinstance(x2, TENSOR_TYPE)):
if isinstance(x1, (DATAFRAME_TYPE, DATAFRAME_CHUNK_TYPE)) and (
x2 is None or pd.api.types.is_scalar(x2) or
isinstance(x2, (TENSOR_TYPE, TENSOR_CHUNK_TYPE))):
if x2 is None:
dtypes = x1.dtypes
elif pd.api.types.is_scalar(x2):
Expand All @@ -297,13 +298,16 @@ def _calc_properties(cls, x1, x2=None, axis='columns'):
return {'shape': x1.shape, 'dtypes': dtypes,
'columns_value': x1.columns_value, 'index_value': x1.index_value}

if isinstance(x1, (SERIES_TYPE, SERIES_CHUNK_TYPE)) \
and (x2 is None or pd.api.types.is_scalar(x2) or isinstance(x2, TENSOR_TYPE)):
if isinstance(x1, (SERIES_TYPE, SERIES_CHUNK_TYPE)) and (
x2 is None or pd.api.types.is_scalar(x2) or
isinstance(x2, (TENSOR_TYPE, TENSOR_CHUNK_TYPE))):
x2_dtype = x2.dtype if hasattr(x2, 'dtype') else type(x2)
dtype = infer_dtype(x1.dtype, np.dtype(x2_dtype), cls._operator)
x2_dtype = get_dtype(x2_dtype)
dtype = infer_dtype(x1.dtype, x2_dtype, cls._operator)
ret = {'shape': x1.shape, 'dtype': dtype, 'index_value': x1.index_value}
if pd.api.types.is_scalar(x2) or (hasattr(x2, 'ndim') and (x2.ndim == 0 or
x2.ndim == 1)):
if pd.api.types.is_scalar(x2) or (
hasattr(x2, 'ndim') and (
x2.ndim == 0 or x2.ndim == 1)):
ret['name'] = x1.name
return ret

Expand Down Expand Up @@ -406,11 +410,18 @@ def _calc_properties(cls, x1, x2=None, axis='columns'):
raise NotImplementedError('Unknown combination of parameters')

def _new_chunks(self, inputs, kws=None, **kw):
property_inputs = [inp for inp in inputs if isinstance(inp, (DATAFRAME_CHUNK_TYPE, SERIES_CHUNK_TYPE))]
property_inputs = [
inp for inp in inputs
if isinstance(inp, (DATAFRAME_CHUNK_TYPE, SERIES_CHUNK_TYPE, TENSOR_CHUNK_TYPE))]
if len(property_inputs) == 1:
properties = self._calc_properties(*property_inputs)
elif any(inp.ndim == 0 for inp in property_inputs):
if property_inputs[0].ndim == 0:
property_inputs = reversed(property_inputs)
properties = self._calc_properties(*property_inputs)
else:
df1, df2 = property_inputs if isinstance(property_inputs[0], DATAFRAME_CHUNK_TYPE) else \
df1, df2 = property_inputs \
if isinstance(property_inputs[0], DATAFRAME_CHUNK_TYPE) else \
reversed(property_inputs)
properties = self._calc_properties(df1, df2, axis=self.axis)

Expand Down
9 changes: 8 additions & 1 deletion mars/dataframe/arithmetic/tests/test_comparison.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
import pytest

from mars.core import enter_mode
from mars.dataframe.initializer import DataFrame
from mars.dataframe.initializer import DataFrame, Series


def test_comp(setup):
Expand Down Expand Up @@ -56,3 +56,10 @@ def test_comp(setup):
r_df = op(df, datetime(2013, 1, 2))
pd.testing.assert_index_equal(r_df.index_value.to_pandas(),
df.index_value.to_pandas())

# test period type
raw = pd.period_range("2000-01-01", periods=10, freq="D")
raw_series = pd.Series(raw)
series = Series(raw, chunk_size=5)
r = series >= series[1]
pd.testing.assert_series_equal(r.to_pandas(), raw_series >= raw_series[1])
4 changes: 2 additions & 2 deletions mars/tensor/base/astype.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.


import numpy as np

from ... import opcodes as OperandDef
from ...serialization.serializables import KeyField, DataTypeField, StringField
from ...utils import get_dtype
from ..array_utils import as_same_device, device
from ..operands import TensorHasInput, TensorOperandMixin
from ..utils import get_order
Expand Down Expand Up @@ -150,7 +150,7 @@ def _astype(tensor, dtype, order='K', casting='unsafe', copy=True):
>>> x.astype(int).execute()
array([1, 2, 2])
"""
dtype = np.dtype(dtype)
dtype = get_dtype(dtype)
tensor_order = get_order(order, tensor.order)

if tensor.dtype == dtype and tensor.order == tensor_order:
Expand Down
7 changes: 7 additions & 0 deletions mars/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -921,6 +921,13 @@ def is_object_dtype(dtype: np.dtype) -> bool:
return False


def get_dtype(dtype: Union[np.dtype, pd.api.extensions.ExtensionDtype]):
if pd.api.types.is_extension_array_dtype(dtype):
return dtype
else:
return np.dtype(dtype)


def calc_object_overhead(chunk: ChunkType,
shape: Tuple[int]) -> int:
from .dataframe.core import DATAFRAME_CHUNK_TYPE, SERIES_CHUNK_TYPE, INDEX_CHUNK_TYPE
Expand Down

0 comments on commit ebc7033

Please sign in to comment.