Skip to content

Commit

Permalink
EA: Tighten signature on DatetimeArray._from_sequence (#36718)
Browse files Browse the repository at this point in the history
  • Loading branch information
jbrockmendel authored Oct 2, 2020
1 parent cc238b9 commit 456fcb9
Show file tree
Hide file tree
Showing 8 changed files with 41 additions and 15 deletions.
6 changes: 5 additions & 1 deletion pandas/core/arrays/datetimes.py
Original file line number Diff line number Diff line change
Expand Up @@ -299,7 +299,11 @@ def _simple_new(
return result

@classmethod
def _from_sequence(
def _from_sequence(cls, scalars, dtype=None, copy: bool = False):
return cls._from_sequence_not_strict(scalars, dtype=dtype, copy=copy)

@classmethod
def _from_sequence_not_strict(
cls,
data,
dtype=None,
Expand Down
2 changes: 1 addition & 1 deletion pandas/core/indexes/datetimes.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,7 +295,7 @@ def __new__(

name = maybe_extract_name(name, data, cls)

dtarr = DatetimeArray._from_sequence(
dtarr = DatetimeArray._from_sequence_not_strict(
data,
dtype=dtype,
copy=copy,
Expand Down
4 changes: 3 additions & 1 deletion pandas/core/nanops.py
Original file line number Diff line number Diff line change
Expand Up @@ -1616,7 +1616,9 @@ def na_accum_func(values: ArrayLike, accum_func, skipna: bool) -> ArrayLike:
result = result.view(orig_dtype)
else:
# DatetimeArray
result = type(values)._from_sequence(result, dtype=orig_dtype)
result = type(values)._simple_new( # type: ignore[attr-defined]
result, dtype=orig_dtype
)

elif skipna and not issubclass(values.dtype.type, (np.integer, np.bool_)):
vals = values.copy()
Expand Down
4 changes: 3 additions & 1 deletion pandas/tests/arrays/test_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,9 @@ def test_array_copy():
datetime.datetime(2000, 1, 1, tzinfo=cet),
datetime.datetime(2001, 1, 1, tzinfo=cet),
],
DatetimeArray._from_sequence(["2000", "2001"], tz=cet),
DatetimeArray._from_sequence(
["2000", "2001"], dtype=pd.DatetimeTZDtype(tz=cet)
),
),
# timedelta
(
Expand Down
27 changes: 19 additions & 8 deletions pandas/tests/arrays/test_datetimes.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def test_mixing_naive_tzaware_raises(self, meth):
def test_from_pandas_array(self):
arr = pd.array(np.arange(5, dtype=np.int64)) * 3600 * 10 ** 9

result = DatetimeArray._from_sequence(arr, freq="infer")
result = DatetimeArray._from_sequence(arr)._with_freq("infer")

expected = pd.date_range("1970-01-01", periods=5, freq="H")._data
tm.assert_datetime_array_equal(result, expected)
Expand Down Expand Up @@ -162,7 +162,9 @@ def test_cmp_dt64_arraylike_tznaive(self, all_compare_operators):

class TestDatetimeArray:
def test_astype_to_same(self):
arr = DatetimeArray._from_sequence(["2000"], tz="US/Central")
arr = DatetimeArray._from_sequence(
["2000"], dtype=DatetimeTZDtype(tz="US/Central")
)
result = arr.astype(DatetimeTZDtype(tz="US/Central"), copy=False)
assert result is arr

Expand Down Expand Up @@ -193,7 +195,9 @@ def test_astype_int(self, dtype):
tm.assert_numpy_array_equal(result, expected)

def test_tz_setter_raises(self):
arr = DatetimeArray._from_sequence(["2000"], tz="US/Central")
arr = DatetimeArray._from_sequence(
["2000"], dtype=DatetimeTZDtype(tz="US/Central")
)
with pytest.raises(AttributeError, match="tz_localize"):
arr.tz = "UTC"

Expand Down Expand Up @@ -282,7 +286,8 @@ def test_fillna_preserves_tz(self, method):

fill_val = dti[1] if method == "pad" else dti[3]
expected = DatetimeArray._from_sequence(
[dti[0], dti[1], fill_val, dti[3], dti[4]], freq=None, tz="US/Central"
[dti[0], dti[1], fill_val, dti[3], dti[4]],
dtype=DatetimeTZDtype(tz="US/Central"),
)

result = arr.fillna(method=method)
Expand Down Expand Up @@ -434,19 +439,24 @@ def test_shift_value_tzawareness_mismatch(self):

class TestSequenceToDT64NS:
def test_tz_dtype_mismatch_raises(self):
arr = DatetimeArray._from_sequence(["2000"], tz="US/Central")
arr = DatetimeArray._from_sequence(
["2000"], dtype=DatetimeTZDtype(tz="US/Central")
)
with pytest.raises(TypeError, match="data is already tz-aware"):
sequence_to_dt64ns(arr, dtype=DatetimeTZDtype(tz="UTC"))

def test_tz_dtype_matches(self):
arr = DatetimeArray._from_sequence(["2000"], tz="US/Central")
arr = DatetimeArray._from_sequence(
["2000"], dtype=DatetimeTZDtype(tz="US/Central")
)
result, _, _ = sequence_to_dt64ns(arr, dtype=DatetimeTZDtype(tz="US/Central"))
tm.assert_numpy_array_equal(arr._data, result)


class TestReductions:
@pytest.mark.parametrize("tz", [None, "US/Central"])
def test_min_max(self, tz):
dtype = DatetimeTZDtype(tz=tz) if tz is not None else np.dtype("M8[ns]")
arr = DatetimeArray._from_sequence(
[
"2000-01-03",
Expand All @@ -456,7 +466,7 @@ def test_min_max(self, tz):
"2000-01-05",
"2000-01-04",
],
tz=tz,
dtype=dtype,
)

result = arr.min()
Expand All @@ -476,7 +486,8 @@ def test_min_max(self, tz):
@pytest.mark.parametrize("tz", [None, "US/Central"])
@pytest.mark.parametrize("skipna", [True, False])
def test_min_max_empty(self, skipna, tz):
arr = DatetimeArray._from_sequence([], tz=tz)
dtype = DatetimeTZDtype(tz=tz) if tz is not None else np.dtype("M8[ns]")
arr = DatetimeArray._from_sequence([], dtype=dtype)
result = arr.min(skipna=skipna)
assert result is pd.NaT

Expand Down
4 changes: 3 additions & 1 deletion pandas/tests/extension/test_datetime.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,8 +181,10 @@ def test_concat_mixed_dtypes(self, data):
@pytest.mark.parametrize("obj", ["series", "frame"])
def test_unstack(self, obj):
# GH-13287: can't use base test, since building the expected fails.
dtype = DatetimeTZDtype(tz="US/Central")
data = DatetimeArray._from_sequence(
["2000", "2001", "2002", "2003"], tz="US/Central"
["2000", "2001", "2002", "2003"],
dtype=dtype,
)
index = pd.MultiIndex.from_product(([["A", "B"], ["a", "b"]]), names=["a", "b"])

Expand Down
4 changes: 3 additions & 1 deletion pandas/tests/indexes/datetimes/test_constructors.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,9 @@


class TestDatetimeIndex:
@pytest.mark.parametrize("dt_cls", [DatetimeIndex, DatetimeArray._from_sequence])
@pytest.mark.parametrize(
"dt_cls", [DatetimeIndex, DatetimeArray._from_sequence_not_strict]
)
def test_freq_validation_with_nat(self, dt_cls):
# GH#11587 make sure we get a useful error message when generate_range
# raises
Expand Down
5 changes: 4 additions & 1 deletion pandas/tests/scalar/test_nat.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

from pandas import (
DatetimeIndex,
DatetimeTZDtype,
Index,
NaT,
Period,
Expand Down Expand Up @@ -440,7 +441,9 @@ def test_nat_rfloordiv_timedelta(val, expected):
DatetimeIndex(["2011-01-01", "2011-01-02"], name="x"),
DatetimeIndex(["2011-01-01", "2011-01-02"], tz="US/Eastern", name="x"),
DatetimeArray._from_sequence(["2011-01-01", "2011-01-02"]),
DatetimeArray._from_sequence(["2011-01-01", "2011-01-02"], tz="US/Pacific"),
DatetimeArray._from_sequence(
["2011-01-01", "2011-01-02"], dtype=DatetimeTZDtype(tz="US/Pacific")
),
TimedeltaIndex(["1 day", "2 day"], name="x"),
],
)
Expand Down

0 comments on commit 456fcb9

Please sign in to comment.