Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ENH/POC: infer resolution in array_strptime #55778

Merged
merged 3 commits into from
Nov 1, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions pandas/_libs/tslibs/strptime.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -14,5 +14,8 @@ cdef class DatetimeParseState:
cdef:
bint found_tz
bint found_naive
bint creso_ever_changed
NPY_DATETIMEUNIT creso

cdef tzinfo process_datetime(self, datetime dt, tzinfo tz, bint utc_convert)
cdef bint update_creso(self, NPY_DATETIMEUNIT item_reso) noexcept
103 changes: 88 additions & 15 deletions pandas/_libs/tslibs/strptime.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,10 @@ from numpy cimport (

from pandas._libs.missing cimport checknull_with_nat_and_na
from pandas._libs.tslibs.conversion cimport get_datetime64_nanos
from pandas._libs.tslibs.dtypes cimport (
get_supported_reso,
npy_unit_to_abbrev,
)
from pandas._libs.tslibs.nattype cimport (
NPY_NAT,
c_nat_strings as nat_strings,
Expand All @@ -57,6 +61,7 @@ from pandas._libs.tslibs.np_datetime cimport (
NPY_DATETIMEUNIT,
NPY_FR_ns,
check_dts_bounds,
get_datetime64_unit,
import_pandas_datetime,
npy_datetimestruct,
npy_datetimestruct_to_datetime,
Expand Down Expand Up @@ -232,9 +237,21 @@ cdef _get_format_regex(str fmt):


cdef class DatetimeParseState:
def __cinit__(self):
def __cinit__(self, NPY_DATETIMEUNIT creso=NPY_DATETIMEUNIT.NPY_FR_ns):
self.found_tz = False
self.found_naive = False
self.creso = creso
self.creso_ever_changed = False

cdef bint update_creso(self, NPY_DATETIMEUNIT item_reso) noexcept:
# Return a bool indicating whether we bumped to a higher resolution
if self.creso == NPY_DATETIMEUNIT.NPY_FR_GENERIC:
self.creso = item_reso
elif item_reso > self.creso:
self.creso = item_reso
self.creso_ever_changed = True
return True
return False

cdef tzinfo process_datetime(self, datetime dt, tzinfo tz, bint utc_convert):
if dt.tzinfo is not None:
Expand Down Expand Up @@ -268,6 +285,7 @@ def array_strptime(
bint exact=True,
errors="raise",
bint utc=False,
NPY_DATETIMEUNIT creso=NPY_FR_ns,
):
"""
Calculates the datetime structs represented by the passed array of strings
Expand All @@ -278,6 +296,8 @@ def array_strptime(
fmt : string-like regex
exact : matches must be exact if True, search if False
errors : string specifying error handling, {'raise', 'ignore', 'coerce'}
creso : NPY_DATETIMEUNIT, default NPY_FR_ns
Set to NPY_FR_GENERIC to infer a resolution.
"""

cdef:
Expand All @@ -291,17 +311,22 @@ def array_strptime(
bint is_coerce = errors=="coerce"
tzinfo tz_out = None
bint iso_format = format_is_iso(fmt)
NPY_DATETIMEUNIT out_bestunit
NPY_DATETIMEUNIT out_bestunit, item_reso
int out_local = 0, out_tzoffset = 0
bint string_to_dts_succeeded = 0
DatetimeParseState state = DatetimeParseState()
bint infer_reso = creso == NPY_DATETIMEUNIT.NPY_FR_GENERIC
DatetimeParseState state = DatetimeParseState(creso)

assert is_raise or is_ignore or is_coerce

_validate_fmt(fmt)
format_regex, locale_time = _get_format_regex(fmt)

result = np.empty(n, dtype="M8[ns]")
if infer_reso:
abbrev = "ns"
else:
abbrev = npy_unit_to_abbrev(creso)
result = np.empty(n, dtype=f"M8[{abbrev}]")
iresult = result.view("i8")
result_timezone = np.empty(n, dtype="object")

Expand All @@ -318,20 +343,32 @@ def array_strptime(
iresult[i] = NPY_NAT
continue
elif PyDateTime_Check(val):
if isinstance(val, _Timestamp):
item_reso = val._creso
else:
item_reso = NPY_DATETIMEUNIT.NPY_FR_us
state.update_creso(item_reso)
tz_out = state.process_datetime(val, tz_out, utc)
if isinstance(val, _Timestamp):
iresult[i] = val.tz_localize(None).as_unit("ns")._value
val = (<_Timestamp>val)._as_creso(state.creso)
iresult[i] = val.tz_localize(None)._value
else:
iresult[i] = pydatetime_to_dt64(val.replace(tzinfo=None), &dts)
check_dts_bounds(&dts)
iresult[i] = pydatetime_to_dt64(
val.replace(tzinfo=None), &dts, reso=state.creso
)
check_dts_bounds(&dts, state.creso)
result_timezone[i] = val.tzinfo
continue
elif PyDate_Check(val):
iresult[i] = pydate_to_dt64(val, &dts)
check_dts_bounds(&dts)
item_reso = NPY_DATETIMEUNIT.NPY_FR_s
state.update_creso(item_reso)
iresult[i] = pydate_to_dt64(val, &dts, reso=state.creso)
check_dts_bounds(&dts, state.creso)
continue
elif is_datetime64_object(val):
iresult[i] = get_datetime64_nanos(val, NPY_FR_ns)
item_reso = get_supported_reso(get_datetime64_unit(val))
state.update_creso(item_reso)
iresult[i] = get_datetime64_nanos(val, state.creso)
continue
elif (
(is_integer_object(val) or is_float_object(val))
Expand All @@ -355,7 +392,9 @@ def array_strptime(
if string_to_dts_succeeded:
# No error reported by string_to_dts, pick back up
# where we left off
value = npy_datetimestruct_to_datetime(NPY_FR_ns, &dts)
item_reso = get_supported_reso(out_bestunit)
state.update_creso(item_reso)
value = npy_datetimestruct_to_datetime(state.creso, &dts)
if out_local == 1:
# Store the out_tzoffset in seconds
# since we store the total_seconds of
Expand All @@ -368,7 +407,9 @@ def array_strptime(
check_dts_bounds(&dts)
continue

if parse_today_now(val, &iresult[i], utc, NPY_FR_ns):
if parse_today_now(val, &iresult[i], utc, state.creso):
item_reso = NPY_DATETIMEUNIT.NPY_FR_us
state.update_creso(item_reso)
continue

# Some ISO formats can't be parsed by string_to_dts
Expand All @@ -380,9 +421,10 @@ def array_strptime(
raise ValueError(f"Time data {val} is not ISO8601 format")

tz = _parse_with_format(
val, fmt, exact, format_regex, locale_time, &dts
val, fmt, exact, format_regex, locale_time, &dts, &item_reso
)
iresult[i] = npy_datetimestruct_to_datetime(NPY_FR_ns, &dts)
state.update_creso(item_reso)
iresult[i] = npy_datetimestruct_to_datetime(state.creso, &dts)
check_dts_bounds(&dts)
result_timezone[i] = tz

Expand All @@ -403,11 +445,34 @@ def array_strptime(
raise
return values, []

if infer_reso:
if state.creso_ever_changed:
# We encountered mismatched resolutions, need to re-parse with
# the correct one.
return array_strptime(
values,
fmt=fmt,
exact=exact,
errors=errors,
utc=utc,
creso=state.creso,
)

# Otherwise we can use the single reso that we encountered and avoid
# a second pass.
abbrev = npy_unit_to_abbrev(state.creso)
result = iresult.base.view(f"M8[{abbrev}]")
return result, result_timezone.base


cdef tzinfo _parse_with_format(
str val, str fmt, bint exact, format_regex, locale_time, npy_datetimestruct* dts
str val,
str fmt,
bint exact,
format_regex,
locale_time,
npy_datetimestruct* dts,
NPY_DATETIMEUNIT* item_reso,
):
# Based on https://github.com/python/cpython/blob/main/Lib/_strptime.py#L293
cdef:
Expand Down Expand Up @@ -441,6 +506,8 @@ cdef tzinfo _parse_with_format(
f"time data \"{val}\" doesn't match format \"{fmt}\""
)

item_reso[0] = NPY_DATETIMEUNIT.NPY_FR_s

iso_year = -1
year = 1900
month = day = 1
Expand Down Expand Up @@ -527,6 +594,12 @@ cdef tzinfo _parse_with_format(
elif parse_code == 10:
# e.g. val='10:10:10.100'; fmt='%H:%M:%S.%f'
s = found_dict["f"]
if len(s) <= 3:
item_reso[0] = NPY_DATETIMEUNIT.NPY_FR_ms
elif len(s) <= 6:
item_reso[0] = NPY_DATETIMEUNIT.NPY_FR_us
else:
item_reso[0] = NPY_DATETIMEUNIT.NPY_FR_ns
# Pad to always return nanoseconds
s += "0" * (9 - len(s))
us = long(s)
Expand Down
61 changes: 61 additions & 0 deletions pandas/tests/tslibs/test_strptime.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
from datetime import (
datetime,
timezone,
)

import numpy as np
import pytest

from pandas._libs.tslibs.dtypes import NpyDatetimeUnit
from pandas._libs.tslibs.strptime import array_strptime

from pandas import Timestamp
import pandas._testing as tm

creso_infer = NpyDatetimeUnit.NPY_FR_GENERIC.value


class TestArrayStrptimeResolutionInference:
@pytest.mark.parametrize("tz", [None, timezone.utc])
def test_array_strptime_resolution_inference_homogeneous_strings(self, tz):
dt = datetime(2016, 1, 2, 3, 4, 5, 678900, tzinfo=tz)

fmt = "%Y-%m-%d %H:%M:%S"
dtstr = dt.strftime(fmt)
arr = np.array([dtstr] * 3, dtype=object)
expected = np.array([dt.replace(tzinfo=None)] * 3, dtype="M8[s]")

res, _ = array_strptime(arr, fmt=fmt, utc=False, creso=creso_infer)
tm.assert_numpy_array_equal(res, expected)

fmt = "%Y-%m-%d %H:%M:%S.%f"
dtstr = dt.strftime(fmt)
arr = np.array([dtstr] * 3, dtype=object)
expected = np.array([dt.replace(tzinfo=None)] * 3, dtype="M8[us]")

res, _ = array_strptime(arr, fmt=fmt, utc=False, creso=creso_infer)
tm.assert_numpy_array_equal(res, expected)

fmt = "ISO8601"
res, _ = array_strptime(arr, fmt=fmt, utc=False, creso=creso_infer)
tm.assert_numpy_array_equal(res, expected)

@pytest.mark.parametrize("tz", [None, timezone.utc])
def test_array_strptime_resolution_mixed(self, tz):
dt = datetime(2016, 1, 2, 3, 4, 5, 678900, tzinfo=tz)

ts = Timestamp(dt).as_unit("ns")

arr = np.array([dt, ts], dtype=object)
expected = np.array(
[Timestamp(dt).as_unit("ns").asm8, ts.asm8],
dtype="M8[ns]",
)

fmt = "%Y-%m-%d %H:%M:%S"
res, _ = array_strptime(arr, fmt=fmt, utc=False, creso=creso_infer)
tm.assert_numpy_array_equal(res, expected)

fmt = "ISO8601"
res, _ = array_strptime(arr, fmt=fmt, utc=False, creso=creso_infer)
tm.assert_numpy_array_equal(res, expected)
Loading