Add unit test for historical retrieval with pandas dataframe
Signed-off-by: Khor Shu Heng <khor.heng@gojek.com>
khorshuheng committed Oct 21, 2020
1 parent 3a65949 commit 5fab886
Showing 4 changed files with 143 additions and 43 deletions.
5 changes: 3 additions & 2 deletions sdk/python/feast/client.py
@@ -827,8 +827,9 @@ def get_historical_features(
feature_tables = self._get_feature_tables_from_feature_refs(
feature_refs, project
)
-output_location = self._config.get(
-CONFIG_SPARK_HISTORICAL_FEATURE_OUTPUT_LOCATION
+output_location = os.path.join(
+self._config.get(CONFIG_SPARK_HISTORICAL_FEATURE_OUTPUT_LOCATION),
+str(uuid.uuid4()),
)
output_format = self._config.get(CONFIG_SPARK_HISTORICAL_FEATURE_OUTPUT_FORMAT)

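The client.py change above appends a random UUID to the configured output location, so each get_historical_features call writes into its own subdirectory instead of overwriting the previous run's output. A minimal sketch of the idea, using an illustrative base path rather than Feast's real configuration constant:

```python
import os
import uuid

# Illustrative base path; in Feast this would come from the
# CONFIG_SPARK_HISTORICAL_FEATURE_OUTPUT_LOCATION config option.
base_location = "file:///tmp/historical_feature_output"

def unique_output_location(base: str) -> str:
    # Each call yields a fresh subdirectory, e.g. .../3f2a9c1e-...
    return os.path.join(base, str(uuid.uuid4()))

print(unique_output_location(base_location))
print(unique_output_location(base_location))  # different path every call
```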
1 change: 1 addition & 0 deletions sdk/python/feast/pyspark/launchers/standalone/local.py
@@ -124,6 +124,7 @@ def get_output_file_uri(self, timeout_sec: int = None):
with self._process as p:
try:
p.wait(timeout_sec)
+return self._output_file_uri
except Exception:
p.kill()
raise SparkJobFailure("Timeout waiting for subprocess to return")
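Without the added return, get_output_file_uri waited for the subprocess and then fell off the end of the method, handing None back to the caller. A sketch of the fixed control flow, with the surrounding class reduced to what the diff shows and a generic RuntimeError standing in for SparkJobFailure:

```python
import subprocess

class StandaloneJob:
    # Reduced, illustrative stand-in for the launcher's job wrapper.
    def __init__(self, process: subprocess.Popen, output_file_uri: str):
        self._process = process
        self._output_file_uri = output_file_uri

    def get_output_file_uri(self, timeout_sec: int = None) -> str:
        with self._process as p:
            try:
                p.wait(timeout_sec)
                # The fix: surface the URI instead of implicitly returning None.
                return self._output_file_uri
            except Exception:
                p.kill()
                raise RuntimeError("Timeout waiting for subprocess to return")
```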
2 changes: 1 addition & 1 deletion sdk/python/feast/staging/storage_client.py
@@ -270,7 +270,7 @@ def list_files(self, bucket: str, path: str) -> List[str]:
raise NotImplementedError("list files not implemented for Local file")

def upload_file(self, local_path: str, bucket: str, remote_path: str):
dest_fpath = "/" + remote_path
dest_fpath = remote_path if remote_path.startswith("/") else "/" + remote_path
os.makedirs(os.path.dirname(dest_fpath), exist_ok=True)
shutil.copy(local_path, dest_fpath)

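The one-line storage_client.py fix makes the local destination path handling idempotent: a leading "/" is added only when it is missing, so "tmp/out.parquet" and "/tmp/out.parquet" now resolve to the same destination. A standalone sketch of the same logic:

```python
import os
import shutil

def upload_file(local_path: str, remote_path: str) -> None:
    # Prepend "/" only when it is missing, so already-absolute
    # paths pass through unchanged.
    dest_fpath = remote_path if remote_path.startswith("/") else "/" + remote_path
    os.makedirs(os.path.dirname(dest_fpath), exist_ok=True)
    shutil.copy(local_path, dest_fpath)
```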
178 changes: 138 additions & 40 deletions sdk/python/tests/test_historical_feature_retrieval.py
@@ -6,10 +6,14 @@
from contextlib import closing
from datetime import datetime
from typing import List, Tuple
+from urllib.parse import urlparse

import grpc
+import numpy as np
+import pandas as pd
import pytest
from google.protobuf.duration_pb2 import Duration
+from pandas.util.testing import assert_frame_equal
from pyspark.sql import DataFrame, SparkSession
from pyspark.sql.types import (
BooleanType,
@@ -19,6 +23,7 @@
StructType,
TimestampType,
)
+from pytz import utc

from feast import Client, Entity, Feature, FeatureTable, FileSource, ValueType
from feast.core import CoreService_pb2_grpc as Core
@@ -82,6 +87,26 @@ def client(server):
return Client(core_url=f"localhost:{free_port}")


+@pytest.fixture()
+def client_with_local_spark(tmpdir):
+import pyspark
+
+spark_staging_location = f"file://{os.path.join(tmpdir, 'staging')}"
+historical_feature_output_location = (
+f"file://{os.path.join(tmpdir, 'historical_feature_retrieval_output')}"
+)
+
+return Client(
+core_url=f"localhost:{free_port}",
+spark_launcher="standalone",
+spark_standalone_master="local",
+spark_home=os.path.dirname(pyspark.__file__),
+spark_staging_location=spark_staging_location,
+historical_feature_output_location=historical_feature_output_location,
+historical_feature_output_format="parquet",
+)


@pytest.fixture()
def driver_entity(client):
return client.apply_entity(Entity("driver_id", "description", ValueType.INT32))
@@ -116,36 +141,36 @@ def transactions_feature_table(spark, client):
df_data = [
(
1001,
-datetime(year=2020, month=9, day=1),
-datetime(year=2020, month=9, day=1),
+datetime(year=2020, month=9, day=1, tzinfo=utc),
+datetime(year=2020, month=9, day=1, tzinfo=utc),
50.0,
True,
),
(
1001,
-datetime(year=2020, month=9, day=1),
-datetime(year=2020, month=9, day=2),
+datetime(year=2020, month=9, day=1, tzinfo=utc),
+datetime(year=2020, month=9, day=2, tzinfo=utc),
100.0,
True,
),
(
2001,
-datetime(year=2020, month=9, day=1),
-datetime(year=2020, month=9, day=1),
+datetime(year=2020, month=9, day=1, tzinfo=utc),
+datetime(year=2020, month=9, day=1, tzinfo=utc),
400.0,
False,
),
(
1001,
-datetime(year=2020, month=9, day=2),
-datetime(year=2020, month=9, day=1),
+datetime(year=2020, month=9, day=2, tzinfo=utc),
+datetime(year=2020, month=9, day=1, tzinfo=utc),
200.0,
False,
),
(
1001,
-datetime(year=2020, month=9, day=4),
-datetime(year=2020, month=9, day=1),
+datetime(year=2020, month=9, day=4, tzinfo=utc),
+datetime(year=2020, month=9, day=1, tzinfo=utc),
300.0,
False,
),
@@ -180,20 +205,20 @@ def bookings_feature_table(spark, client):
df_data = [
(
8001,
-datetime(year=2020, month=9, day=1),
-datetime(year=2020, month=9, day=1),
+datetime(year=2020, month=9, day=1, tzinfo=utc),
+datetime(year=2020, month=9, day=1, tzinfo=utc),
100,
),
(
8001,
-datetime(year=2020, month=9, day=2),
-datetime(year=2020, month=9, day=2),
+datetime(year=2020, month=9, day=2, tzinfo=utc),
+datetime(year=2020, month=9, day=2, tzinfo=utc),
150,
),
(
8002,
-datetime(year=2020, month=9, day=2),
-datetime(year=2020, month=9, day=2),
+datetime(year=2020, month=9, day=2, tzinfo=utc),
+datetime(year=2020, month=9, day=2, tzinfo=utc),
200,
),
]
@@ -225,20 +250,20 @@ def bookings_feature_table_with_mapping(spark, client):
df_data = [
(
8001,
-datetime(year=2020, month=9, day=1),
-datetime(year=2020, month=9, day=1),
+datetime(year=2020, month=9, day=1, tzinfo=utc),
+datetime(year=2020, month=9, day=1, tzinfo=utc),
100,
),
(
8001,
datetime(year=2020, month=9, day=2),
datetime(year=2020, month=9, day=2),
datetime(year=2020, month=9, day=2, tzinfo=utc),
datetime(year=2020, month=9, day=2, tzinfo=utc),
150,
),
(
8002,
-datetime(year=2020, month=9, day=2),
-datetime(year=2020, month=9, day=2),
+datetime(year=2020, month=9, day=2, tzinfo=utc),
+datetime(year=2020, month=9, day=2, tzinfo=utc),
200,
),
]
@@ -273,12 +298,12 @@ def test_historical_feature_retrieval_from_local_spark_session(
]
)
df_data = [
-(1001, 8001, datetime(year=2020, month=9, day=1),),
-(2001, 8001, datetime(year=2020, month=9, day=2),),
-(2001, 8002, datetime(year=2020, month=9, day=1),),
-(1001, 8001, datetime(year=2020, month=9, day=2),),
-(1001, 8001, datetime(year=2020, month=9, day=3),),
-(1001, 8001, datetime(year=2020, month=9, day=4),),
+(1001, 8001, datetime(year=2020, month=9, day=1, tzinfo=utc)),
+(2001, 8001, datetime(year=2020, month=9, day=2, tzinfo=utc)),
+(2001, 8002, datetime(year=2020, month=9, day=1, tzinfo=utc)),
+(1001, 8001, datetime(year=2020, month=9, day=2, tzinfo=utc)),
+(1001, 8001, datetime(year=2020, month=9, day=3, tzinfo=utc)),
+(1001, 8001, datetime(year=2020, month=9, day=4, tzinfo=utc)),
]
temp_dir, file_uri = create_temp_parquet_file(
spark, "customer_driver_pair", schema, df_data
@@ -300,12 +325,12 @@
]
)
expected_joined_df_data = [
-(1001, 8001, datetime(year=2020, month=9, day=1), 100.0, 100),
-(2001, 8001, datetime(year=2020, month=9, day=2), 400.0, 150),
-(2001, 8002, datetime(year=2020, month=9, day=1), 400.0, None),
-(1001, 8001, datetime(year=2020, month=9, day=2), 200.0, 150),
-(1001, 8001, datetime(year=2020, month=9, day=3), 200.0, 150),
-(1001, 8001, datetime(year=2020, month=9, day=4), 300.0, None),
+(1001, 8001, datetime(year=2020, month=9, day=1, tzinfo=utc), 100.0, 100),
+(2001, 8001, datetime(year=2020, month=9, day=2, tzinfo=utc), 400.0, 150),
+(2001, 8002, datetime(year=2020, month=9, day=1, tzinfo=utc), 400.0, None),
+(1001, 8001, datetime(year=2020, month=9, day=2, tzinfo=utc), 200.0, 150),
+(1001, 8001, datetime(year=2020, month=9, day=3, tzinfo=utc), 200.0, 150),
+(1001, 8001, datetime(year=2020, month=9, day=4, tzinfo=utc), 300.0, None),
]
expected_joined_df = spark.createDataFrame(
spark.sparkContext.parallelize(expected_joined_df_data),
@@ -325,9 +350,9 @@ def test_historical_feature_retrieval_with_field_mappings_from_local_spark_sessi
]
)
df_data = [
-(8001, datetime(year=2020, month=9, day=1)),
-(8001, datetime(year=2020, month=9, day=2)),
-(8002, datetime(year=2020, month=9, day=1)),
+(8001, datetime(year=2020, month=9, day=1, tzinfo=utc)),
+(8001, datetime(year=2020, month=9, day=2, tzinfo=utc)),
+(8002, datetime(year=2020, month=9, day=1, tzinfo=utc)),
]
temp_dir, file_uri = create_temp_parquet_file(spark, "drivers", schema, df_data)
entity_source = FileSource(
@@ -344,13 +369,86 @@
]
)
expected_joined_df_data = [
-(8001, datetime(year=2020, month=9, day=1), 100),
-(8001, datetime(year=2020, month=9, day=2), 150),
-(8002, datetime(year=2020, month=9, day=1), None),
+(8001, datetime(year=2020, month=9, day=1, tzinfo=utc), 100),
+(8001, datetime(year=2020, month=9, day=2, tzinfo=utc), 150),
+(8002, datetime(year=2020, month=9, day=1, tzinfo=utc), None),
]
expected_joined_df = spark.createDataFrame(
spark.sparkContext.parallelize(expected_joined_df_data),
expected_joined_df_schema,
)
assert_dataframe_equal(joined_df, expected_joined_df)
shutil.rmtree(temp_dir)


+@pytest.mark.usefixtures(
+"driver_entity",
+"customer_entity",
+"bookings_feature_table",
+"transactions_feature_table",
+)
+def test_historical_feature_retrieval_with_pandas_dataframe_input(
+client_with_local_spark,
+):
+
+customer_driver_pairs_pandas_df = pd.DataFrame(
+np.array(
+[
+[1001, 8001, datetime(year=2020, month=9, day=1, tzinfo=utc)],
+[2001, 8001, datetime(year=2020, month=9, day=2, tzinfo=utc)],
+[2001, 8002, datetime(year=2020, month=9, day=1, tzinfo=utc)],
+[1001, 8001, datetime(year=2020, month=9, day=2, tzinfo=utc)],
+[1001, 8001, datetime(year=2020, month=9, day=3, tzinfo=utc)],
+[1001, 8001, datetime(year=2020, month=9, day=4, tzinfo=utc)],
+]
+),
+columns=["customer_id", "driver_id", "event_timestamp"],
+)
+customer_driver_pairs_pandas_df = customer_driver_pairs_pandas_df.astype(
+{"customer_id": "int32", "driver_id": "int32"}
+)
+
+job_output = client_with_local_spark.get_historical_features(
+["transactions:total_transactions", "bookings:total_completed_bookings"],
+customer_driver_pairs_pandas_df,
+)
+
+output_dir = job_output.get_output_file_uri()
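+# get_output_file_uri blocks until the local Spark job finishes and returns a
+# file:// URI; urlparse(...).path strips the scheme so pandas (via pyarrow)
+# can read the directory of parquet files directly.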
+joined_df = pd.read_parquet(urlparse(output_dir).path)
+
+expected_joined_df = pd.DataFrame(
+np.array(
+[
+[1001, 8001, datetime(year=2020, month=9, day=1), 100.0, 100],
+[2001, 8001, datetime(year=2020, month=9, day=2), 400.0, 150],
+[2001, 8002, datetime(year=2020, month=9, day=1), 400.0, None],
+[1001, 8001, datetime(year=2020, month=9, day=2), 200.0, 150],
+[1001, 8001, datetime(year=2020, month=9, day=3), 200.0, 150],
+[1001, 8001, datetime(year=2020, month=9, day=4), 300.0, None],
+]
+),
+columns=[
+"customer_id",
+"driver_id",
+"event_timestamp",
+"transactions__total_transactions",
+"bookings__total_completed_bookings",
+],
+)
+expected_joined_df = expected_joined_df.astype(
+{
+"customer_id": "int32",
+"driver_id": "int32",
+"transactions__total_transactions": "float64",
+"bookings__total_completed_bookings": "float64",
+}
+)
+
+assert_frame_equal(
+joined_df.sort_values(
+by=["customer_id", "driver_id", "event_timestamp"]
+).reset_index(drop=True),
+expected_joined_df.sort_values(
+by=["customer_id", "driver_id", "event_timestamp"]
+).reset_index(drop=True),
+)
