From 6cdd173e41398354a650d33f967bf635a6cccdca Mon Sep 17 00:00:00 2001
From: pradithya aria
Date: Sun, 24 Feb 2019 11:26:14 +0800
Subject: [PATCH] Fix import spec created from Importer.from_csv (#143)

---
 sdk/python/feast/sdk/importer.py      |   8 +-
 sdk/python/tests/sdk/test_importer.py | 188 ++++++++++++++------------
 2 files changed, 107 insertions(+), 89 deletions(-)

diff --git a/sdk/python/feast/sdk/importer.py b/sdk/python/feast/sdk/importer.py
index c7f31491b9..f9a97aad31 100644
--- a/sdk/python/feast/sdk/importer.py
+++ b/sdk/python/feast/sdk/importer.py
@@ -106,7 +106,8 @@ def from_csv(cls, path, entity, granularity, owner, staging_location=None,
         Returns:
             Importer: the importer for the dataset provided.
         """
-        source_options = {"format": "csv"}
+        src_type = "file.csv"
+        source_options = {}
         source_options["path"], require_staging = \
             _get_remote_location(path, staging_location)
         if is_gs_path(path):
@@ -118,9 +119,10 @@ def from_csv(cls, path, entity, granularity, owner, staging_location=None,
             feature_columns, timestamp_column, timestamp_value, serving_store,
             warehouse_store, df)
 
-        iport_spec = _create_import("file", source_options, job_options, entity, schema)
+        iport_spec = _create_import(src_type, source_options, job_options,
+                                    entity, schema)
 
-        props = (_properties("csv", len(df.index), require_staging,
+        props = (_properties(src_type, len(df.index), require_staging,
                              source_options["path"]))
 
         specs = _specs(iport_spec, Entity(name=entity), features)
diff --git a/sdk/python/tests/sdk/test_importer.py b/sdk/python/tests/sdk/test_importer.py
index c09cad0db8..7bf1c05c3f 100644
--- a/sdk/python/tests/sdk/test_importer.py
+++ b/sdk/python/tests/sdk/test_importer.py
@@ -15,7 +15,8 @@
 import pandas as pd
 import pytest
 import ntpath
-from feast.sdk.resources.feature import Feature, Granularity, ValueType, Datastore
+from feast.sdk.resources.feature import Feature, Granularity, ValueType, \
+    Datastore
 from feast.sdk.importer import _create_feature, Importer
 from feast.sdk.utils.gs_utils import is_gs_path
 from feast.types.Granularity_pb2 import Granularity as Granularity_pb2
@@ -30,56 +31,60 @@ def test_from_csv(self):
         staging_location = "gs://test-bucket"
         id_column = "driver_id"
         feature_columns = ["avg_distance_completed",
-            "avg_customer_distance_completed"]
+                           "avg_customer_distance_completed"]
         timestamp_column = "ts"
-        importer = Importer.from_csv(path = csv_path,
-            entity = entity_name,
-            granularity = feature_granularity,
-            owner = owner,
-            staging_location=staging_location,
-            id_column = id_column,
-            feature_columns=feature_columns,
-            timestamp_column=timestamp_column)
+        importer = Importer.from_csv(path=csv_path,
+                                     entity=entity_name,
+                                     granularity=feature_granularity,
+                                     owner=owner,
+                                     staging_location=staging_location,
+                                     id_column=id_column,
+                                     feature_columns=feature_columns,
+                                     timestamp_column=timestamp_column)
 
         self._validate_csv_importer(importer, csv_path, entity_name,
-            feature_granularity, owner, staging_location, id_column,
-            feature_columns, timestamp_column)
+                                    feature_granularity, owner,
+                                    staging_location, id_column,
+                                    feature_columns, timestamp_column)
 
     def test_from_csv_id_column_not_specified(self):
         with pytest.raises(ValueError,
-                match="Column with name driver is not found") as e_info:
+                           match="Column with name driver is not found"):
             feature_columns = ["avg_distance_completed",
-                "avg_customer_distance_completed"]
+                               "avg_customer_distance_completed"]
             csv_path = "tests/data/driver_features.csv"
-            importer = Importer.from_csv(path = csv_path,
-                entity = "driver",
-                granularity = Granularity.DAY,
-                owner = "owner@feast.com",
-                staging_location="gs://test-bucket",
-                feature_columns=feature_columns,
-                timestamp_column="ts")
+            Importer.from_csv(path=csv_path,
+                              entity="driver",
+                              granularity=Granularity.DAY,
+                              owner="owner@feast.com",
+                              staging_location="gs://test-bucket",
+                              feature_columns=feature_columns,
+                              timestamp_column="ts")
 
     def test_from_csv_timestamp_column_not_specified(self):
         feature_columns = ["avg_distance_completed",
-            "avg_customer_distance_completed", "avg_distance_cancelled"]
+                           "avg_customer_distance_completed",
+                           "avg_distance_cancelled"]
         csv_path = "tests/data/driver_features.csv"
         entity_name = "driver"
         granularity = Granularity.DAY
         owner = "owner@feast.com"
         staging_location = "gs://test-bucket"
         id_column = "driver_id"
-        importer = Importer.from_csv(path = csv_path,
-            entity = entity_name,
-            granularity = granularity,
-            owner = owner,
-            staging_location=staging_location,
-            id_column = id_column,
-            feature_columns= feature_columns)
+        importer = Importer.from_csv(path=csv_path,
+                                     entity=entity_name,
+                                     granularity=granularity,
+                                     owner=owner,
+                                     staging_location=staging_location,
+                                     id_column=id_column,
+                                     feature_columns=feature_columns)
 
         self._validate_csv_importer(importer, csv_path, entity_name,
-            granularity, owner, staging_location = staging_location,
-            id_column=id_column, feature_columns=feature_columns)
+                                    granularity, owner,
+                                    staging_location=staging_location,
+                                    id_column=id_column,
+                                    feature_columns=feature_columns)
 
     def test_from_csv_feature_columns_not_specified(self):
         csv_path = "tests/data/driver_features.csv"
@@ -89,43 +94,45 @@ def test_from_csv_feature_columns_not_specified(self):
         entity_name = "driver"
         granularity = Granularity.DAY
         owner = "owner@feast.com"
         staging_location = "gs://test-bucket"
         id_column = "driver_id"
         timestamp_column = "ts"
-        importer = Importer.from_csv(path = csv_path,
-            entity = entity_name,
-            granularity = granularity,
-            owner = owner,
-            staging_location=staging_location,
-            id_column = id_column,
-            timestamp_column=timestamp_column)
+        importer = Importer.from_csv(path=csv_path,
+                                     entity=entity_name,
+                                     granularity=granularity,
+                                     owner=owner,
+                                     staging_location=staging_location,
+                                     id_column=id_column,
+                                     timestamp_column=timestamp_column)
 
         self._validate_csv_importer(importer, csv_path, entity_name,
-            granularity, owner, staging_location = staging_location,
-            id_column=id_column, timestamp_column=timestamp_column)
+                                    granularity, owner,
+                                    staging_location=staging_location,
+                                    id_column=id_column,
+                                    timestamp_column=timestamp_column)
 
     def test_from_csv_staging_location_not_specified(self):
         with pytest.raises(ValueError,
-                match="Specify staging_location for importing local file/dataframe") as e_info:
+                           match="Specify staging_location for importing local file/dataframe"):
            feature_columns = ["avg_distance_completed",
-                "avg_customer_distance_completed"]
+                               "avg_customer_distance_completed"]
             csv_path = "tests/data/driver_features.csv"
-            importer = Importer.from_csv(path = csv_path,
-                entity = "driver",
-                granularity = Granularity.DAY,
-                owner = "owner@feast.com",
-                feature_columns=feature_columns,
-                timestamp_column="ts")
+            Importer.from_csv(path=csv_path,
+                              entity="driver",
+                              granularity=Granularity.DAY,
+                              owner="owner@feast.com",
+                              feature_columns=feature_columns,
+                              timestamp_column="ts")
 
         with pytest.raises(ValueError,
-                match="Staging location must be in GCS") as e_info:
+                           match="Staging location must be in GCS") as e_info:
             feature_columns = ["avg_distance_completed",
-                "avg_customer_distance_completed"]
+                               "avg_customer_distance_completed"]
             csv_path = "tests/data/driver_features.csv"
-            importer = Importer.from_csv(path = csv_path,
-                entity = "driver",
-                granularity = Granularity.DAY,
-                owner = "owner@feast.com",
-                staging_location = "/home",
-                feature_columns=feature_columns,
-                timestamp_column="ts")
+            Importer.from_csv(path=csv_path,
+                              entity="driver",
+                              granularity=Granularity.DAY,
+                              owner="owner@feast.com",
+                              staging_location="/home",
+                              feature_columns=feature_columns,
+                              timestamp_column="ts")
 
     def test_from_df(self):
         csv_path = "tests/data/driver_features.csv"
@@ -133,59 +140,63 @@ def test_from_df(self):
         staging_location = "gs://test-bucket"
         entity = "driver"
 
-        importer = Importer.from_df(df = df,
-            entity = entity,
-            granularity = Granularity.DAY,
-            owner = "owner@feast.com",
-            staging_location=staging_location,
-            id_column = "driver_id",
-            timestamp_column="ts")
-
+        importer = Importer.from_df(df=df,
+                                    entity=entity,
+                                    granularity=Granularity.DAY,
+                                    owner="owner@feast.com",
+                                    staging_location=staging_location,
+                                    id_column="driver_id",
+                                    timestamp_column="ts")
         assert importer.require_staging == True
         assert ("{}/tmp_{}".format(staging_location, entity)
-            in importer.remote_path)
+                in importer.remote_path)
         for feature in importer.features.values():
             assert feature.name in df.columns
             assert feature.id == "driver.day." + feature.name
 
         import_spec = importer.spec
         assert import_spec.type == "file"
-        assert import_spec.sourceOptions == {"format" : "csv", "path" : importer.remote_path}
+        assert import_spec.sourceOptions == {"format": "csv",
+                                             "path": importer.remote_path}
         assert import_spec.entities == ["driver"]
 
         schema = import_spec.schema
         assert schema.entityIdColumn == "driver_id"
         assert schema.timestampValue is not None
         feature_columns = ["completed", "avg_distance_completed",
-            "avg_customer_distance_completed",
-            "avg_distance_cancelled"]
+                           "avg_customer_distance_completed",
+                           "avg_distance_cancelled"]
         for col, field in zip(df.columns.values, schema.fields):
             assert col == field.name
             if col in feature_columns:
                 assert field.featureId == "driver.day." + col
 
     def _validate_csv_importer(self,
-            importer, csv_path, entity_name, feature_granularity, owner,
-            staging_location = None, id_column = None, feature_columns = None,
-            timestamp_column = None, timestamp_value = None):
+                               importer, csv_path, entity_name,
+                               feature_granularity, owner,
+                               staging_location=None, id_column=None,
+                               feature_columns=None,
+                               timestamp_column=None, timestamp_value=None):
         df = pd.read_csv(csv_path)
         assert not importer.require_staging == is_gs_path(csv_path)
         if importer.require_staging:
             assert importer.remote_path == "{}/{}".format(staging_location,
-                ntpath.basename(csv_path))
+                                                          ntpath.basename(
+                                                              csv_path))
 
         # check features created
         for feature in importer.features.values():
             assert feature.name in df.columns
             assert feature.id == "{}.{}.{}".format(entity_name,
-                Granularity_pb2.Enum.Name(feature_granularity.value).lower(),
-                feature.name)
+                                                   Granularity_pb2.Enum.Name(
+                                                       feature_granularity.value).lower(),
+                                                   feature.name)
 
         import_spec = importer.spec
-        assert import_spec.type == "file"
+        assert import_spec.type == "file.csv"
         path = importer.remote_path if importer.require_staging else csv_path
-        assert import_spec.sourceOptions == {"format" : "csv", "path" : path}
+        assert import_spec.sourceOptions == {"path": path}
         assert import_spec.entities == [entity_name]
 
         schema = import_spec.schema
@@ -204,19 +215,23 @@ def _validate_csv_importer(self,
         for col, field in zip(df.columns.values, schema.fields):
             assert col == field.name
             if col in feature_columns:
-                assert field.featureId == "{}.{}.{}".format(entity_name,
-                    Granularity_pb2.Enum.Name(feature_granularity.value).lower(), col)
+                assert field.featureId == \
+                    "{}.{}.{}".format(entity_name,
+                                      Granularity_pb2.Enum.Name(
+                                          feature_granularity.value).lower(),
+                                      col)
 
 
 class TestHelpers:
     def test_create_feature(self):
-        col = pd.Series([1]*3,dtype='int32',name="test")
+        col = pd.Series([1] * 3, dtype='int32', name="test")
         expected = Feature(name="test",
-            entity="test",
-            granularity=Granularity.NONE,
-            owner="person",
-            value_type=ValueType.INT32)
-        actual = _create_feature(col, "test", Granularity.NONE, "person", None, None)
+                           entity="test",
+                           granularity=Granularity.NONE,
+                           owner="person",
+                           value_type=ValueType.INT32)
+        actual = _create_feature(col, "test", Granularity.NONE, "person", None,
+                                 None)
         assert actual.id == expected.id
         assert actual.value_type == expected.value_type
         assert actual.owner == expected.owner
@@ -231,7 +246,8 @@ def test_create_feature_with_stores(self):
                            serving_store=Datastore(id="SERVING"),
                            warehouse_store=Datastore(id="WAREHOUSE"))
         actual = _create_feature(col, "test", Granularity.NONE, "person",
-            Datastore(id="SERVING"), Datastore(id="WAREHOUSE"))
+                                 Datastore(id="SERVING"),
+                                 Datastore(id="WAREHOUSE"))
         assert actual.id == expected.id
         assert actual.value_type == expected.value_type
        assert actual.owner == expected.owner
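Note for reviewers: the snippet below is a minimal sketch of the behavior change, not part of the patch. It uses only the API surface the tests above exercise, with the path, bucket, and column names borrowed from the test fixtures.

    from feast.sdk.importer import Importer
    from feast.sdk.resources.feature import Granularity

    # Importing a local CSV requires a GCS staging_location, since the file
    # is uploaded there before the import runs.
    importer = Importer.from_csv(path="tests/data/driver_features.csv",
                                 entity="driver",
                                 granularity=Granularity.DAY,
                                 owner="owner@feast.com",
                                 staging_location="gs://test-bucket",
                                 id_column="driver_id",
                                 timestamp_column="ts")

    # Before this patch, from_csv produced the same spec shape as from_df:
    #   type == "file", sourceOptions == {"format": "csv", "path": ...}
    # After it, the CSV format is carried by the source type instead:
    assert importer.spec.type == "file.csv"
    assert importer.spec.sourceOptions == {"path": importer.remote_path}

Importer.from_df is left untouched by this patch: its spec still uses type == "file" with a {"format": "csv"} source option, as the test_from_df assertions show.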