fix: Retire pytz library (feast-dev#4406)
* fix: Remove pytz.

Signed-off-by: Shuchu Han <shuchu.han@gmail.com>

* fix: Keep the pytz.UTC part in dask.py

Signed-off-by: Shuchu Han <shuchu.han@gmail.com>

---------

Signed-off-by: Shuchu Han <shuchu.han@gmail.com>
shuchu authored Aug 16, 2024
1 parent cebbe04 commit 23c6c86
Showing 29 changed files with 109 additions and 133 deletions.
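All of the changes below follow the same substitution table: pytz.utc / pytz.UTC becomes datetime.timezone.utc, pytz.FixedOffset(n) becomes timezone(timedelta(minutes=n)), and pytz.timezone(name) becomes zoneinfo.ZoneInfo(name). A minimal sketch of the mapping (the datetime value is illustrative, not from the commit):

```python
from datetime import datetime, timedelta, timezone
from zoneinfo import ZoneInfo  # standard library since Python 3.9; needs tzdata on some platforms

# pytz.utc / pytz.UTC           -> timezone.utc
aware = datetime(2024, 8, 16, 12, 0, tzinfo=timezone.utc)

# pytz.FixedOffset(60)          -> timezone(timedelta(minutes=60))
plus_one = aware.astimezone(timezone(timedelta(minutes=60)))

# pytz.timezone("US/Pacific")   -> ZoneInfo("US/Pacific")
pacific = aware.astimezone(ZoneInfo("US/Pacific"))

print(aware.isoformat(), plus_one.isoformat(), pacific.isoformat())
```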
13 changes: 9 additions & 4 deletions sdk/python/feast/driver_test_data.py
@@ -1,10 +1,11 @@
 # This module generates dummy data to be used for tests and examples.
 import itertools
+from datetime import timedelta, timezone
 from enum import Enum
 
 import numpy as np
 import pandas as pd
-from pytz import FixedOffset, timezone, utc
+from zoneinfo import ZoneInfo
 
 from feast.infra.offline_stores.offline_utils import (
     DEFAULT_ENTITY_DF_EVENT_TIMESTAMP_COL,
@@ -22,11 +23,15 @@ def _convert_event_timestamp(event_timestamp: pd.Timestamp, t: EventTimestampTyp
     if t == EventTimestampType.TZ_NAIVE:
         return event_timestamp
     elif t == EventTimestampType.TZ_AWARE_UTC:
-        return event_timestamp.replace(tzinfo=utc)
+        return event_timestamp.replace(tzinfo=timezone.utc)
     elif t == EventTimestampType.TZ_AWARE_FIXED_OFFSET:
-        return event_timestamp.replace(tzinfo=utc).astimezone(FixedOffset(60))
+        return event_timestamp.replace(tzinfo=timezone.utc).astimezone(
+            tz=timezone(timedelta(minutes=60))
+        )
     elif t == EventTimestampType.TZ_AWARE_US_PACIFIC:
-        return event_timestamp.replace(tzinfo=utc).astimezone(timezone("US/Pacific"))
+        return event_timestamp.replace(tzinfo=timezone.utc).astimezone(
+            tz=ZoneInfo("US/Pacific")
+        )
 
 
 def create_orders_df(
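As a sanity check on the rewritten test helper: timezone(timedelta(minutes=60)) is the same UTC+01:00 offset that pytz.FixedOffset(60) produced, and all three tz-aware branches still denote the same instant. A small sketch, assuming a tz-naive pandas Timestamp as input, mirroring the helper above:

```python
from datetime import timedelta, timezone
from zoneinfo import ZoneInfo

import pandas as pd

ts = pd.Timestamp("2024-08-16 12:00:00")  # tz-naive, illustrative value

utc_ts = ts.replace(tzinfo=timezone.utc)
fixed = utc_ts.astimezone(tz=timezone(timedelta(minutes=60)))
pacific = utc_ts.astimezone(tz=ZoneInfo("US/Pacific"))

assert fixed.utcoffset() == timedelta(hours=1)  # same offset as pytz.FixedOffset(60)
assert utc_ts == fixed == pacific               # same instant, different zone labels
```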
4 changes: 2 additions & 2 deletions sdk/python/feast/embedded_go/type_map.py
@@ -1,12 +1,12 @@
+from datetime import timezone
 from typing import List
 
 import pyarrow as pa
-import pytz
 
 from feast.protos.feast.types import Value_pb2
 from feast.types import Array, PrimitiveFeastType
 
-PA_TIMESTAMP_TYPE = pa.timestamp("s", tz=pytz.UTC)
+PA_TIMESTAMP_TYPE = pa.timestamp("s", tz=timezone.utc)
 
 ARROW_TYPE_TO_PROTO_FIELD = {
     pa.int32(): "int32_val",
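pyarrow's pa.timestamp accepts either a timezone string or a tzinfo object for its tz argument, which is what lets timezone.utc stand in for pytz.UTC here. A quick sketch:

```python
from datetime import timezone

import pyarrow as pa

# tz may be a string ("UTC") or a tzinfo instance; passing timezone.utc
# drops the pytz dependency without changing the resulting type.
ts_type = pa.timestamp("s", tz=timezone.utc)
print(ts_type)  # e.g. timestamp[s, tz=UTC]; exact rendering may vary by pyarrow version
```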
4 changes: 2 additions & 2 deletions sdk/python/feast/feature_logging.py
@@ -1,8 +1,8 @@
 import abc
+from datetime import timezone
 from typing import TYPE_CHECKING, Dict, Optional, Type, cast
 
 import pyarrow as pa
-from pytz import UTC
 
 from feast.data_source import DataSource
 from feast.embedded_go.type_map import FEAST_TYPE_TO_ARROW_TYPE, PA_TIMESTAMP_TYPE
@@ -97,7 +97,7 @@ def get_schema(self, registry: "BaseRegistry") -> pa.Schema:
         )
 
         # system columns
-        fields[LOG_TIMESTAMP_FIELD] = pa.timestamp("us", tz=UTC)
+        fields[LOG_TIMESTAMP_FIELD] = pa.timestamp("us", tz=timezone.utc)
         fields[LOG_DATE_FIELD] = pa.date32()
         fields[REQUEST_ID_FIELD] = pa.string()
 
8 changes: 5 additions & 3 deletions sdk/python/feast/infra/materialization/snowflake_engine.py
@@ -1,14 +1,13 @@
 import os
 import shutil
 from dataclasses import dataclass
-from datetime import datetime
+from datetime import datetime, timezone
 from typing import Callable, List, Literal, Optional, Sequence, Union
 
 import click
 import pandas as pd
 from colorama import Fore, Style
 from pydantic import ConfigDict, Field, StrictStr
-from pytz import utc
 from tqdm import tqdm
 
 import feast
@@ -276,7 +275,10 @@ def _materialize_one(
                 execute_snowflake_statement(conn, query).fetchall()[0][0]
                 / 1_000_000_000
             )
-            if last_commit_change_time < start_date.astimezone(tz=utc).timestamp():
+            if (
+                last_commit_change_time
+                < start_date.astimezone(tz=timezone.utc).timestamp()
+            ):
                 return SnowflakeMaterializationJob(
                     job_id=job_id, status=MaterializationJobStatus.SUCCEEDED
                 )
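Note that .timestamp() on a datetime already returns POSIX epoch seconds; the astimezone(tz=timezone.utc) relabels the instant without changing the numeric result, keeping the UTC intent explicit. A sketch with an illustrative value:

```python
from datetime import datetime, timedelta, timezone

start_date = datetime(2024, 8, 16, 9, 30, tzinfo=timezone(timedelta(hours=2)))  # illustrative

# Same epoch either way: astimezone only changes the tzinfo label, not the instant.
assert start_date.timestamp() == start_date.astimezone(tz=timezone.utc).timestamp()
```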
@@ -1,6 +1,6 @@
 import contextlib
 import uuid
-from datetime import datetime
+from datetime import datetime, timezone
 from pathlib import Path
 from typing import (
     Callable,
@@ -19,7 +19,6 @@
 import pyarrow
 import pyarrow as pa
 from pydantic import StrictStr
-from pytz import utc
 
 from feast import OnDemandFeatureView
 from feast.data_source import DataSource
@@ -100,8 +99,8 @@ def pull_latest_from_table_or_query(
     athena_client = aws_utils.get_athena_data_client(config.offline_store.region)
     s3_resource = aws_utils.get_s3_resource(config.offline_store.region)
 
-    start_date = start_date.astimezone(tz=utc)
-    end_date = end_date.astimezone(tz=utc)
+    start_date = start_date.astimezone(tz=timezone.utc)
+    end_date = end_date.astimezone(tz=timezone.utc)
 
     query = f"""
         SELECT
@@ -151,7 +150,7 @@ def pull_all_from_table_or_query(
     query = f"""
         SELECT {field_string}
        FROM {from_expression}
-        WHERE {timestamp_field} BETWEEN TIMESTAMP '{start_date.astimezone(tz=utc).strftime("%Y-%m-%d %H:%M:%S.%f")[:-3]}' AND TIMESTAMP '{end_date.astimezone(tz=utc).strftime("%Y-%m-%d %H:%M:%S.%f")[:-3]}'
+        WHERE {timestamp_field} BETWEEN TIMESTAMP '{start_date.astimezone(tz=timezone.utc).strftime("%Y-%m-%d %H:%M:%S.%f")[:-3]}' AND TIMESTAMP '{end_date.astimezone(tz=timezone.utc).strftime("%Y-%m-%d %H:%M:%S.%f")[:-3]}'
         {"AND "+date_partition_column+" >= '"+start_date.strftime('%Y-%m-%d')+"' AND "+date_partition_column+" <= '"+end_date.strftime('%Y-%m-%d')+"' " if date_partition_column != "" and date_partition_column is not None else ''}
     """
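One subtlety of the boundary normalization used here and in the stores below: astimezone only relabels a tz-aware datetime, but it interprets a tz-naive one as system local time before converting, exactly as the old pytz-based call did (astimezone is a datetime method either way). A minimal sketch with illustrative values:

```python
from datetime import datetime, timedelta, timezone

aware = datetime(2024, 8, 16, 9, 30, tzinfo=timezone(timedelta(hours=2)))
print(aware.astimezone(tz=timezone.utc))  # 2024-08-16 07:30:00+00:00 -- same instant

naive = datetime(2024, 8, 16, 9, 30)
# Naive values are assumed to be in the *system local* zone first,
# so this result depends on where the code runs.
print(naive.astimezone(tz=timezone.utc))
```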
@@ -1,6 +1,6 @@
 import contextlib
 from dataclasses import asdict
-from datetime import datetime
+from datetime import datetime, timezone
 from typing import (
     Any,
     Callable,
@@ -20,7 +20,6 @@
 import pyarrow as pa
 from jinja2 import BaseLoader, Environment
 from psycopg import sql
-from pytz import utc
 
 from feast.data_source import DataSource
 from feast.errors import InvalidEntityType, ZeroColumnQueryResult, ZeroRowsQueryResult
@@ -214,8 +213,8 @@ def pull_all_from_table_or_query(
         join_key_columns + feature_name_columns + [timestamp_field]
     )
 
-    start_date = start_date.astimezone(tz=utc)
-    end_date = end_date.astimezone(tz=utc)
+    start_date = start_date.astimezone(tz=timezone.utc)
+    end_date = end_date.astimezone(tz=timezone.utc)
 
     query = f"""
         SELECT {field_string}
@@ -2,7 +2,7 @@
 import tempfile
 import uuid
 import warnings
-from datetime import datetime
+from datetime import datetime, timezone
 from typing import Any, Callable, Dict, List, Optional, Tuple, Union
 
 import numpy as np
@@ -14,7 +14,6 @@
 from pydantic import StrictStr
 from pyspark import SparkConf
 from pyspark.sql import SparkSession
-from pytz import utc
 
 from feast import FeatureView, OnDemandFeatureView
 from feast.data_source import DataSource
@@ -284,8 +283,8 @@ def pull_all_from_table_or_query(
 
     fields = ", ".join(join_key_columns + feature_name_columns + [timestamp_field])
     from_expression = data_source.get_table_query_string()
-    start_date = start_date.astimezone(tz=utc)
-    end_date = end_date.astimezone(tz=utc)
+    start_date = start_date.astimezone(tz=timezone.utc)
+    end_date = end_date.astimezone(tz=timezone.utc)
 
     query = f"""
         SELECT {fields}
@@ -520,21 +519,18 @@ def _upload_entity_df(
             entity_df[event_timestamp_col], utc=True
         )
        spark_session.createDataFrame(entity_df).createOrReplaceTempView(table_name)
-        return
     elif isinstance(entity_df, str):
         spark_session.sql(entity_df).createOrReplaceTempView(table_name)
-        return
     elif isinstance(entity_df, pyspark.sql.DataFrame):
         entity_df.createOrReplaceTempView(table_name)
-        return
     else:
         raise InvalidEntityType(type(entity_df))
 
 
 def _format_datetime(t: datetime) -> str:
     # Since Hive does not support timezone, need to transform to utc.
     if t.tzinfo:
-        t = t.astimezone(tz=utc)
+        t = t.astimezone(tz=timezone.utc)
     dt = t.strftime("%Y-%m-%d %H:%M:%S.%f")
     return dt
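_format_datetime feeds Hive, which has no timezone-aware type, so aware values are normalized to UTC and rendered as plain wall-clock strings while naive values pass through untouched (the trino store's format_datetime below follows the same pattern). A standalone sketch of the rewritten function's behavior, with an illustrative input:

```python
from datetime import datetime, timedelta, timezone


def _format_datetime(t: datetime) -> str:
    # Hive does not support time zones: convert aware values to UTC and
    # render a plain wall-clock string; naive values pass through as-is.
    if t.tzinfo:
        t = t.astimezone(tz=timezone.utc)
    return t.strftime("%Y-%m-%d %H:%M:%S.%f")


print(_format_datetime(datetime(2024, 8, 16, 9, 30, tzinfo=timezone(timedelta(hours=2)))))
# 2024-08-16 07:30:00.000000
```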
@@ -18,13 +18,12 @@
 ```
 """
 
-from datetime import datetime
+from datetime import datetime, timezone
 from typing import Any, Dict, Iterator, Optional, Set
 
 import numpy as np
 import pandas as pd
 import pyarrow
-from pytz import utc
 
 from feast.infra.offline_stores.contrib.trino_offline_store.trino_queries import Trino
 from feast.infra.offline_stores.contrib.trino_offline_store.trino_type_map import (
@@ -141,7 +140,7 @@ def _format_value(row: pd.Series, schema: Dict[str, Any]) -> str:
 
 def format_datetime(t: datetime) -> str:
     if t.tzinfo:
-        t = t.astimezone(tz=utc)
+        t = t.astimezone(tz=timezone.utc)
     return t.strftime("%Y-%m-%d %H:%M:%S.%f")
 
 
12 changes: 8 additions & 4 deletions sdk/python/feast/infra/offline_stores/dask.py
@@ -1,6 +1,6 @@
 import os
 import uuid
-from datetime import datetime
+from datetime import datetime, timezone
 from pathlib import Path
 from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union
 
@@ -178,6 +178,8 @@ def evaluate_historical_retrieval():
         entity_df_event_timestamp_col_type = entity_df_with_features.dtypes[
             entity_df_event_timestamp_col
         ]
+
+        # TODO: need to figure out why the value of entity_df_event_timestamp_col_type.tz is pytz.UTC
         if (
             not hasattr(entity_df_event_timestamp_col_type, "tz")
             or entity_df_event_timestamp_col_type.tz != pytz.UTC
@@ -189,7 +191,7 @@
                 ].apply(
                     lambda x: x
                     if x.tzinfo is not None
-                    else x.replace(tzinfo=pytz.utc)
+                    else x.replace(tzinfo=timezone.utc)
                 )
             )
 
@@ -616,6 +618,7 @@ def _normalize_timestamp(
     if created_timestamp_column:
         created_timestamp_column_type = df_to_join_types[created_timestamp_column]
 
+    # TODO: need to figure out why the value of timestamp_field_type.tz is pytz.UTC
    if not hasattr(timestamp_field_type, "tz") or timestamp_field_type.tz != pytz.UTC:
         # if you are querying for the event timestamp field, we have to deduplicate
         if len(df_to_join[timestamp_field].shape) > 1:
@@ -624,10 +627,11 @@
 
         # Make sure all timestamp fields are tz-aware. We default tz-naive fields to UTC
         df_to_join[timestamp_field] = df_to_join[timestamp_field].apply(
-            lambda x: x if x.tzinfo is not None else x.replace(tzinfo=pytz.utc),
+            lambda x: x if x.tzinfo else x.replace(tzinfo=timezone.utc),
             meta=(timestamp_field, "datetime64[ns, UTC]"),
         )
 
+    # TODO: need to figure out why the value of created_timestamp_column_type.tz is pytz.UTC
     if created_timestamp_column and (
         not hasattr(created_timestamp_column_type, "tz")
         or created_timestamp_column_type.tz != pytz.UTC
@@ -640,7 +644,7 @@
         df_to_join[created_timestamp_column] = df_to_join[
             created_timestamp_column
         ].apply(
-            lambda x: x if x.tzinfo is not None else x.replace(tzinfo=pytz.utc),
+            lambda x: x if x.tzinfo else x.replace(tzinfo=timezone.utc),
             meta=(timestamp_field, "datetime64[ns, UTC]"),
         )
 
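This is the one file where pytz survives, per the second commit message: the TODO comments record that pandas reports pytz.UTC as the tz of a datetime64[ns, UTC] column (at least in the pandas builds in use here, where pytz backs string timezone specs), and pytz.UTC does not compare equal to timezone.utc, so the dtype checks keep the pytz comparison. A sketch of the mismatch, assuming such a pandas build:

```python
from datetime import timezone

import pandas as pd
import pytz

s = pd.Series(pd.to_datetime(["2024-08-16"], utc=True))
tz = s.dtype.tz
# With pytz-backed pandas, tz is pytz.UTC, which is not equal to timezone.utc.
print(tz == pytz.UTC, tz == timezone.utc)

# A comparison that works regardless of which library supplied the tzinfo:
is_utc = str(tz) == "UTC"
```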
11 changes: 5 additions & 6 deletions sdk/python/feast/infra/offline_stores/ibis.py
@@ -1,7 +1,7 @@
 import random
 import string
 import uuid
-from datetime import datetime, timedelta
+from datetime import datetime, timedelta, timezone
 from pathlib import Path
 from typing import Any, Callable, Dict, List, Optional, Tuple, Union
 
@@ -12,7 +12,6 @@
 import pyarrow
 from ibis.expr import datatypes as dt
 from ibis.expr.types import Table
-from pytz import utc
 
 from feast.data_source import DataSource
 from feast.feature_logging import LoggingConfig, LoggingSource
@@ -55,8 +54,8 @@ def pull_latest_from_table_or_query_ibis(
     fields = join_key_columns + feature_name_columns + [timestamp_field]
     if created_timestamp_column:
         fields.append(created_timestamp_column)
-    start_date = start_date.astimezone(tz=utc)
-    end_date = end_date.astimezone(tz=utc)
+    start_date = start_date.astimezone(tz=timezone.utc)
+    end_date = end_date.astimezone(tz=timezone.utc)
 
     table = data_source_reader(data_source)
 
@@ -265,8 +264,8 @@ def pull_all_from_table_or_query_ibis(
     staging_location_endpoint_override: Optional[str] = None,
 ) -> RetrievalJob:
     fields = join_key_columns + feature_name_columns + [timestamp_field]
-    start_date = start_date.astimezone(tz=utc)
-    end_date = end_date.astimezone(tz=utc)
+    start_date = start_date.astimezone(tz=timezone.utc)
+    end_date = end_date.astimezone(tz=timezone.utc)
 
     table = data_source_reader(data_source)
 
11 changes: 5 additions & 6 deletions sdk/python/feast/infra/offline_stores/redshift.py
@@ -1,6 +1,6 @@
 import contextlib
 import uuid
-from datetime import datetime
+from datetime import datetime, timezone
 from pathlib import Path
 from typing import (
     Any,
@@ -21,7 +21,6 @@
 import pyarrow as pa
 from dateutil import parser
 from pydantic import StrictStr, model_validator
-from pytz import utc
 
 from feast import OnDemandFeatureView, RedshiftSource
 from feast.data_source import DataSource
@@ -127,8 +126,8 @@ def pull_latest_from_table_or_query(
     )
     s3_resource = aws_utils.get_s3_resource(config.offline_store.region)
 
-    start_date = start_date.astimezone(tz=utc)
-    end_date = end_date.astimezone(tz=utc)
+    start_date = start_date.astimezone(tz=timezone.utc)
+    end_date = end_date.astimezone(tz=timezone.utc)
 
     query = f"""
         SELECT
@@ -174,8 +173,8 @@ def pull_all_from_table_or_query(
     )
     s3_resource = aws_utils.get_s3_resource(config.offline_store.region)
 
-    start_date = start_date.astimezone(tz=utc)
-    end_date = end_date.astimezone(tz=utc)
+    start_date = start_date.astimezone(tz=timezone.utc)
+    end_date = end_date.astimezone(tz=timezone.utc)
 
     query = f"""
         SELECT {field_string}
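The same two-line boundary normalization now recurs across the athena, postgres, spark, ibis, and redshift stores; each call site is equivalent to the small helper sketched below (a hypothetical consolidation for illustration, not part of this commit):

```python
from datetime import datetime, timezone


def _to_utc(dt: datetime) -> datetime:
    # Aware datetimes are converted to UTC; naive ones are first
    # interpreted as system local time (datetime.astimezone semantics).
    return dt.astimezone(tz=timezone.utc)


# Illustrative usage mirroring the pull_* functions:
start_date = _to_utc(datetime(2024, 8, 1, 0, 0, tzinfo=timezone.utc))
end_date = _to_utc(datetime(2024, 8, 16, 0, 0, tzinfo=timezone.utc))
```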