diff --git a/sdk/python/feast/client.py b/sdk/python/feast/client.py index fb5fe6ffc4..543f0afeb6 100644 --- a/sdk/python/feast/client.py +++ b/sdk/python/feast/client.py @@ -21,7 +21,6 @@ from collections import OrderedDict from math import ceil from typing import Dict, List, Tuple, Union, Optional -from typing import List from urllib.parse import urlparse import fastavro @@ -29,6 +28,7 @@ import pandas as pd import pyarrow as pa import pyarrow.parquet as pq + from feast.core.CoreService_pb2 import ( GetFeastCoreVersionRequest, ListFeatureSetsResponse, @@ -48,11 +48,11 @@ from feast.core.FeatureSet_pb2 import FeatureSetStatus from feast.feature_set import FeatureSet, Entity from feast.job import Job -from feast.serving.ServingService_pb2 import FeatureReference from feast.loaders.abstract_producer import get_producer from feast.loaders.file import export_source_to_staging_location from feast.loaders.ingest import KAFKA_CHUNK_PRODUCTION_TIMEOUT from feast.loaders.ingest import get_feature_row_chunks +from feast.serving.ServingService_pb2 import FeatureReference from feast.serving.ServingService_pb2 import GetFeastServingInfoResponse from feast.serving.ServingService_pb2 import ( GetOnlineFeaturesRequest, @@ -69,9 +69,11 @@ GRPC_CONNECTION_TIMEOUT_DEFAULT = 3 # type: int GRPC_CONNECTION_TIMEOUT_APPLY = 600 # type: int -FEAST_SERVING_URL_ENV_KEY = "FEAST_SERVING_URL" # type: str -FEAST_CORE_URL_ENV_KEY = "FEAST_CORE_URL" # type: str -FEAST_PROJECT_ENV_KEY = "FEAST_PROJECT" # type: str +FEAST_CORE_URL_ENV_KEY = "FEAST_CORE_URL" +FEAST_SERVING_URL_ENV_KEY = "FEAST_SERVING_URL" +FEAST_PROJECT_ENV_KEY = "FEAST_PROJECT" +FEAST_CORE_SECURE_ENV_KEY = "FEAST_CORE_SECURE" +FEAST_SERVING_SECURE_ENV_KEY = "FEAST_SERVING_SECURE" BATCH_FEATURE_REQUEST_WAIT_TIME_SECONDS = 300 CPU_COUNT = os.cpu_count() # type: int @@ -82,7 +84,8 @@ class Client: """ def __init__( - self, core_url: str = None, serving_url: str = None, project: str = None + self, core_url: str = None, serving_url: str = None, project: str = None, + core_secure: bool = None, serving_secure: bool = None ): """ The Feast Client should be initialized with at least one service url @@ -91,10 +94,14 @@ def __init__( core_url: Feast Core URL. Used to manage features serving_url: Feast Serving URL. Used to retrieve features project: Sets the active project. This field is optional. - """ - self._core_url = core_url - self._serving_url = serving_url - self._project = project + core_secure: Use client-side SSL/TLS for Core gRPC API + serving_secure: Use client-side SSL/TLS for Serving gRPC API + """ + self._core_url: str = core_url + self._serving_url: str = serving_url + self._project: str = project + self._core_secure: bool = core_secure + self._serving_secure: bool = serving_secure self.__core_channel: grpc.Channel = None self.__serving_channel: grpc.Channel = None self._core_service_stub: CoreServiceStub = None @@ -149,6 +156,52 @@ def serving_url(self, value: str): """ self._serving_url = value + @property + def core_secure(self) -> bool: + """ + Retrieve Feast Core client-side SSL/TLS setting + + Returns: + Whether client-side SSL/TLS is enabled + """ + + if self._core_secure is not None: + return self._core_secure + return os.getenv(FEAST_CORE_SECURE_ENV_KEY, "").lower() is "true" + + @core_secure.setter + def core_secure(self, value: bool): + """ + Set the Feast Core client-side SSL/TLS setting + + Args: + value: True to enable client-side SSL/TLS + """ + self._core_secure = value + + @property + def serving_secure(self) -> bool: + """ + Retrieve Feast Serving client-side SSL/TLS setting + + Returns: + Whether client-side SSL/TLS is enabled + """ + + if self._serving_secure is not None: + return self._serving_secure + return os.getenv(FEAST_SERVING_SECURE_ENV_KEY, "").lower() is "true" + + @serving_secure.setter + def serving_secure(self, value: bool): + """ + Set the Feast Serving client-side SSL/TLS setting + + Args: + value: True to enable client-side SSL/TLS + """ + self._serving_secure = value + def version(self): """ Returns version information from Feast Core and Feast Serving @@ -185,7 +238,10 @@ def _connect_core(self, skip_if_connected: bool = True): raise ValueError("Please set Feast Core URL.") if self.__core_channel is None: - self.__core_channel = grpc.insecure_channel(self.core_url) + if self.core_secure or self.core_url.endswith(":443"): + self.__core_channel = grpc.secure_channel(self.core_url, grpc.ssl_channel_credentials()) + else: + self.__core_channel = grpc.insecure_channel(self.core_url) try: grpc.channel_ready_future(self.__core_channel).result( @@ -214,7 +270,10 @@ def _connect_serving(self, skip_if_connected=True): raise ValueError("Please set Feast Serving URL.") if self.__serving_channel is None: - self.__serving_channel = grpc.insecure_channel(self.serving_url) + if self.serving_secure or self.serving_url.endswith(":443"): + self.__serving_channel = grpc.secure_channel(self.serving_url, grpc.ssl_channel_credentials()) + else: + self.__serving_channel = grpc.insecure_channel(self.serving_url) try: grpc.channel_ready_future(self.__serving_channel).result( diff --git a/sdk/python/requirements-ci.txt b/sdk/python/requirements-ci.txt index d0fdd76e49..31818ba7f7 100644 --- a/sdk/python/requirements-ci.txt +++ b/sdk/python/requirements-ci.txt @@ -12,6 +12,7 @@ mock==2.0.0 pandas==0.* protobuf==3.* pytest +pytest-lazy-fixture==0.6.3 pytest-mock pytest-timeout PyYAML==5.1.* diff --git a/sdk/python/tests/data/localhost.crt b/sdk/python/tests/data/localhost.crt new file mode 100644 index 0000000000..1f471506aa --- /dev/null +++ b/sdk/python/tests/data/localhost.crt @@ -0,0 +1,18 @@ +-----BEGIN CERTIFICATE----- +MIIC5zCCAc+gAwIBAgIJAKzukpnyuwsVMA0GCSqGSIb3DQEBCwUAMBQxEjAQBgNV +BAMMCWxvY2FsaG9zdDAgFw0yMDAyMTcxMTE4NDNaGA8zMDE5MDYyMDExMTg0M1ow +FDESMBAGA1UEAwwJbG9jYWxob3N0MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIB +CgKCAQEAqoanhiy4EUZjPA/m8IWk50OyTjKAnqZvEW5glqmTHP6lQbfyWQnzj3Ny +c++4Xn901FO2v07h+7lE3BScjgCX6klsLOHRnWcLX8lQygR6zzO+Oey1yXuCebBA +yhrsqgTDC/8zoCxe0W3t0vqvE4AJs3tJHq5Y1ba/X9OiKKsDZuMSSsbdd4qVEL6y +BD8PRNLT/iiD84Kq58GZtOI3fJls8E/bYbvksugcPI3kmlU4Plg3VrVplMl3DcMz +7BbvQP6jmVqdPtUT7+lL0C5CsNqbdDOIwg09+Gwus+A/g8PerBBd+ZCmdvSa9LYJ +OmlJszgZPIL9AagXLfuGQvNN2Y6WowIDAQABozowODAUBgNVHREEDTALgglsb2Nh +bGhvc3QwCwYDVR0PBAQDAgeAMBMGA1UdJQQMMAoGCCsGAQUFBwMBMA0GCSqGSIb3 +DQEBCwUAA4IBAQAuF1/VeQL73Y1FKrBX4bAb/Rdh2+Dadpi+w1pgEOi3P4udmQ+y +Xn9GwwLRQmHRLjyCT5KT8lNHdldPdlBamqPGGku449aCAjA/YHVHhcHaXl0MtPGq +BfKhHYSsvI2sIymlzZIvvIaf04yuJ1g+L0j8Px4Ecor9YwcKDZmpnIXLgdUtUrIQ +5Omrb4jImX6q8jp6Bjplb4H3o4TqKoa74NLOWUiH5/Rix3Lo8MRoEVbX2GhKk+8n +0eD3AuyrI1i+ce7zY8qGJKKFHGLDWPA/+006ZIS4j/Hr2FWo07CPFQ4/3gdJ8Erw +SzgO9vvIhQrBJn2CIH4+P5Cb1ktdobNWW9XK +-----END CERTIFICATE----- diff --git a/sdk/python/tests/data/localhost.key b/sdk/python/tests/data/localhost.key new file mode 100644 index 0000000000..dbd9cda062 --- /dev/null +++ b/sdk/python/tests/data/localhost.key @@ -0,0 +1,28 @@ +-----BEGIN PRIVATE KEY----- +MIIEvQIBADANBgkqhkiG9w0BAQEFAASCBKcwggSjAgEAAoIBAQCqhqeGLLgRRmM8 +D+bwhaTnQ7JOMoCepm8RbmCWqZMc/qVBt/JZCfOPc3Jz77hef3TUU7a/TuH7uUTc +FJyOAJfqSWws4dGdZwtfyVDKBHrPM7457LXJe4J5sEDKGuyqBMML/zOgLF7Rbe3S ++q8TgAmze0kerljVtr9f06IoqwNm4xJKxt13ipUQvrIEPw9E0tP+KIPzgqrnwZm0 +4jd8mWzwT9thu+Sy6Bw8jeSaVTg+WDdWtWmUyXcNwzPsFu9A/qOZWp0+1RPv6UvQ +LkKw2pt0M4jCDT34bC6z4D+Dw96sEF35kKZ29Jr0tgk6aUmzOBk8gv0BqBct+4ZC +803ZjpajAgMBAAECggEADE4FHphxe8WheX8IQgjSumFXJ29bepc14oMdcyGvXOM/ +F3vnf+dI7Ov+sUD2A9OcoYmc4TcW9WwL/Pl7xn9iduRvatmsn3gFCRdkvf8OwY7R +Riq/f1drNc6zDiJdO3N2g5IZrpAlE2WkSJoQMg8GJC5cO1uHS3yRWJ/Tzq1wZGcW +Dot9hAFgN0qNdP0xFkOsPM5ptC3DjLqsZWboJhIM19hgsIYaWQWHvcYlCcWTVhkj +FYzvLj5GrzAgyE89RpdXus670q5E2R2Rlnja21TfcxK0UOdIrKghZ0jxZMsXEwdB +8V7kIzL5kh//RhT/dIt0mHNMSdLFFx3yMTb2wTzpWQKBgQDRiCRslDSjiNSFySkn +6IivAwJtV2gLSxV05D9u9lrrlskHogrZUJkpVF1VzSnwv/ASaCZX4AGTtNPaz+vy +yDviwfjADsuum8jkzoxKCHnR1HVMyX+vm/g+pE20PMskTUuDE4zROtrqo9Ky0afv +94mJrf93Q815rsbEM5osugaeBQKBgQDQWAPTKy1wcG7edwfu3EaLYHPZ8pW9MldP +FvCLTMwSDkSzU+wA4BGE/5Tuu0WHSAfUc5C1LnMQXKBQXun+YCaBR6GZjUAmntz3 +poBIOYaxe651zqzCmo4ip1h5wIfPvynsyGmhsbpDSNhvXFgH2mF3XSY1nduKSRHu +389cHk3ahwKBgA4gAWSYcRv9I2aJcw7PrDcwGr/IPqlUPHQO1v/h96seFRtAnz6b +IlgY6dnY5NTn+4UiJEOUREbyz71Weu949CCLNvurg6uXsOlLy0VKYPv2OJoek08B +UrDWXq6h0of19fs2HC4Wq59Zv+ByJcIVi94OLsSZe4aSc6/SUrhlKgEJAoGBAIvR +5Y88NNx2uBEYdPx6W+WBr34e7Rrxw+JSFNCHk5SyeqyWr5XOyjMliv/EMl8dmhOc +Ewtkxte+MeB+Mi8CvBSay/rO7rR8fPK+jOzrnldSF7z8HLjlHGppQFlFOl/TfQFp +ZmqbadNp+caShImQp0SCAPiOnh1p+F0FWpYJyFnVAoGAKhSRP0iUmd+tId94px2m +G248BhcM9/0r+Y3yRX1eBx5eBzlzPUPcW1MSbhiZ1DIyLZ/MyObl98A1oNBGun11 +H/7Mq0E8BcJoXmt/6Z+2NhREBV9tDNuINyS/coYBV7H50pnSqyPpREPxNmu3Ukbm +u7ggLRfH+DexDysbpbCZ9l4= +-----END PRIVATE KEY----- diff --git a/sdk/python/tests/data/localhost.pem b/sdk/python/tests/data/localhost.pem new file mode 100644 index 0000000000..1f471506aa --- /dev/null +++ b/sdk/python/tests/data/localhost.pem @@ -0,0 +1,18 @@ +-----BEGIN CERTIFICATE----- +MIIC5zCCAc+gAwIBAgIJAKzukpnyuwsVMA0GCSqGSIb3DQEBCwUAMBQxEjAQBgNV +BAMMCWxvY2FsaG9zdDAgFw0yMDAyMTcxMTE4NDNaGA8zMDE5MDYyMDExMTg0M1ow +FDESMBAGA1UEAwwJbG9jYWxob3N0MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIB +CgKCAQEAqoanhiy4EUZjPA/m8IWk50OyTjKAnqZvEW5glqmTHP6lQbfyWQnzj3Ny +c++4Xn901FO2v07h+7lE3BScjgCX6klsLOHRnWcLX8lQygR6zzO+Oey1yXuCebBA +yhrsqgTDC/8zoCxe0W3t0vqvE4AJs3tJHq5Y1ba/X9OiKKsDZuMSSsbdd4qVEL6y +BD8PRNLT/iiD84Kq58GZtOI3fJls8E/bYbvksugcPI3kmlU4Plg3VrVplMl3DcMz +7BbvQP6jmVqdPtUT7+lL0C5CsNqbdDOIwg09+Gwus+A/g8PerBBd+ZCmdvSa9LYJ +OmlJszgZPIL9AagXLfuGQvNN2Y6WowIDAQABozowODAUBgNVHREEDTALgglsb2Nh +bGhvc3QwCwYDVR0PBAQDAgeAMBMGA1UdJQQMMAoGCCsGAQUFBwMBMA0GCSqGSIb3 +DQEBCwUAA4IBAQAuF1/VeQL73Y1FKrBX4bAb/Rdh2+Dadpi+w1pgEOi3P4udmQ+y +Xn9GwwLRQmHRLjyCT5KT8lNHdldPdlBamqPGGku449aCAjA/YHVHhcHaXl0MtPGq +BfKhHYSsvI2sIymlzZIvvIaf04yuJ1g+L0j8Px4Ecor9YwcKDZmpnIXLgdUtUrIQ +5Omrb4jImX6q8jp6Bjplb4H3o4TqKoa74NLOWUiH5/Rix3Lo8MRoEVbX2GhKk+8n +0eD3AuyrI1i+ce7zY8qGJKKFHGLDWPA/+006ZIS4j/Hr2FWo07CPFQ4/3gdJ8Erw +SzgO9vvIhQrBJn2CIH4+P5Cb1ktdobNWW9XK +-----END CERTIFICATE----- diff --git a/sdk/python/tests/test_client.py b/sdk/python/tests/test_client.py index 123cbe47fd..2724fff52e 100644 --- a/sdk/python/tests/test_client.py +++ b/sdk/python/tests/test_client.py @@ -11,9 +11,12 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import pkgutil from datetime import datetime import tempfile +from unittest import mock + import grpc import pandas as pd from google.protobuf.duration_pb2 import Duration @@ -63,10 +66,38 @@ CORE_URL = "core.feast.example.com" SERVING_URL = "serving.example.com" +_PRIVATE_KEY_RESOURCE_PATH = 'data/localhost.key' +_CERTIFICATE_CHAIN_RESOURCE_PATH = 'data/localhost.pem' +_ROOT_CERTIFICATE_RESOURCE_PATH = 'data/localhost.crt' class TestClient: - @pytest.fixture(scope="function") + + @pytest.fixture + def secure_mock_client(self, mocker): + client = Client(core_url=CORE_URL, serving_url=SERVING_URL, core_secure=True, serving_secure=True) + mocker.patch.object(client, "_connect_core") + mocker.patch.object(client, "_connect_serving") + client._core_url = CORE_URL + client._serving_url = SERVING_URL + return client + + @pytest.fixture + def mock_client(self, mocker): + client = Client(core_url=CORE_URL, serving_url=SERVING_URL) + mocker.patch.object(client, "_connect_core") + mocker.patch.object(client, "_connect_serving") + client._core_url = CORE_URL + client._serving_url = SERVING_URL + return client + + @pytest.fixture + def server_credentials(self): + private_key = pkgutil.get_data(__name__, _PRIVATE_KEY_RESOURCE_PATH) + certificate_chain = pkgutil.get_data(__name__, _CERTIFICATE_CHAIN_RESOURCE_PATH) + return grpc.ssl_server_credentials(((private_key, certificate_chain),)) + + @pytest.fixture def core_server(self): server = grpc.server(futures.ThreadPoolExecutor(max_workers=10)) Core.add_CoreServiceServicer_to_server(CoreServicer(), server) @@ -75,7 +106,7 @@ def core_server(self): yield server server.stop(0) - @pytest.fixture(scope="function") + @pytest.fixture def serving_server(self): server = grpc.server(futures.ThreadPoolExecutor(max_workers=10)) Serving.add_ServingServiceServicer_to_server(ServingServicer(), server) @@ -85,48 +116,73 @@ def serving_server(self): server.stop(0) @pytest.fixture - def mock_client(self, mocker): - client = Client(core_url=CORE_URL, serving_url=SERVING_URL) - mocker.patch.object(client, "_connect_core") - mocker.patch.object(client, "_connect_serving") - client._core_url = CORE_URL - client._serving_url = SERVING_URL - return client + def secure_core_server(self, server_credentials): + server = grpc.server(futures.ThreadPoolExecutor(max_workers=10)) + Core.add_CoreServiceServicer_to_server(CoreServicer(), server) + server.add_secure_port("[::]:50053", server_credentials) + server.start() + yield server + server.stop(0) + + @pytest.fixture + def secure_serving_server(self, server_credentials): + server = grpc.server(futures.ThreadPoolExecutor(max_workers=10)) + Serving.add_ServingServiceServicer_to_server(ServingServicer(), server) + + server.add_secure_port("[::]:50054", server_credentials) + server.start() + yield server + server.stop(0) + + @pytest.fixture + def secure_client(self, secure_core_server, secure_serving_server): + root_certificate_credentials = pkgutil.get_data(__name__, _ROOT_CERTIFICATE_RESOURCE_PATH) + # this is needed to establish a secure connection using self-signed certificates, for the purpose of the test + ssl_channel_credentials = grpc.ssl_channel_credentials(root_certificates=root_certificate_credentials) + with mock.patch("grpc.ssl_channel_credentials", MagicMock(return_value=ssl_channel_credentials)): + yield Client(core_url="localhost:50053", serving_url="localhost:50054", core_secure=True, + serving_secure=True) @pytest.fixture def client(self, core_server, serving_server): return Client(core_url="localhost:50051", serving_url="localhost:50052") - def test_version(self, mock_client, mocker): - mock_client._core_service_stub = Core.CoreServiceStub(grpc.insecure_channel("")) - mock_client._serving_service_stub = Serving.ServingServiceStub( + @pytest.mark.parametrize("mocked_client", [pytest.lazy_fixture("mock_client"), + pytest.lazy_fixture("secure_mock_client") + ]) + def test_version(self, mocked_client, mocker): + mocked_client._core_service_stub = Core.CoreServiceStub(grpc.insecure_channel("")) + mocked_client._serving_service_stub = Serving.ServingServiceStub( grpc.insecure_channel("") ) mocker.patch.object( - mock_client._core_service_stub, + mocked_client._core_service_stub, "GetFeastCoreVersion", return_value=GetFeastCoreVersionResponse(version="0.3.2"), ) mocker.patch.object( - mock_client._serving_service_stub, + mocked_client._serving_service_stub, "GetFeastServingInfo", return_value=GetFeastServingInfoResponse(version="0.3.2"), ) - status = mock_client.version() + status = mocked_client.version() assert ( - status["core"]["url"] == CORE_URL - and status["core"]["version"] == "0.3.2" - and status["serving"]["url"] == SERVING_URL - and status["serving"]["version"] == "0.3.2" + status["core"]["url"] == CORE_URL + and status["core"]["version"] == "0.3.2" + and status["serving"]["url"] == SERVING_URL + and status["serving"]["version"] == "0.3.2" ) - def test_get_online_features(self, mock_client, mocker): + @pytest.mark.parametrize("mocked_client", [pytest.lazy_fixture("mock_client"), + pytest.lazy_fixture("secure_mock_client") + ]) + def test_get_online_features(self, mocked_client, mocker): ROW_COUNT = 300 - mock_client._serving_service_stub = Serving.ServingServiceStub( + mocked_client._serving_service_stub = Serving.ServingServiceStub( grpc.insecure_channel("") ) @@ -148,12 +204,12 @@ def test_get_online_features(self, mock_client, mocker): ) mocker.patch.object( - mock_client._serving_service_stub, + mocked_client._serving_service_stub, "GetOnlineFeatures", return_value=response, ) - response = mock_client.get_online_features( + response = mocked_client.get_online_features( entity_rows=entity_rows, feature_refs=[ "my_project/feature_1:1", @@ -169,17 +225,20 @@ def test_get_online_features(self, mock_client, mocker): ) # type: GetOnlineFeaturesResponse assert ( - response.field_values[0].fields["my_project/feature_1:1"].int64_val == 1 - and response.field_values[0].fields["my_project/feature_9:1"].int64_val == 9 + response.field_values[0].fields["my_project/feature_1:1"].int64_val == 1 + and response.field_values[0].fields["my_project/feature_9:1"].int64_val == 9 ) - def test_get_feature_set(self, mock_client, mocker): - mock_client._core_service_stub = Core.CoreServiceStub(grpc.insecure_channel("")) + @pytest.mark.parametrize("mocked_client", [pytest.lazy_fixture("mock_client"), + pytest.lazy_fixture("secure_mock_client") + ]) + def test_get_feature_set(self, mocked_client, mocker): + mocked_client._core_service_stub = Core.CoreServiceStub(grpc.insecure_channel("")) from google.protobuf.duration_pb2 import Duration mocker.patch.object( - mock_client._core_service_stub, + mocked_client._core_service_stub, "GetFeatureSet", return_value=GetFeatureSetResponse( feature_set=FeatureSetProto( @@ -214,29 +273,32 @@ def test_get_feature_set(self, mock_client, mocker): ) ), ) - mock_client.set_project("my_project") - feature_set = mock_client.get_feature_set("my_feature_set", version=2) + mocked_client.set_project("my_project") + feature_set = mocked_client.get_feature_set("my_feature_set", version=2) assert ( - feature_set.name == "my_feature_set" - and feature_set.version == 2 - and feature_set.fields["my_feature_1"].name == "my_feature_1" - and feature_set.fields["my_feature_1"].dtype == ValueType.FLOAT - and feature_set.fields["my_entity_1"].name == "my_entity_1" - and feature_set.fields["my_entity_1"].dtype == ValueType.INT64 - and len(feature_set.features) == 2 - and len(feature_set.entities) == 1 + feature_set.name == "my_feature_set" + and feature_set.version == 2 + and feature_set.fields["my_feature_1"].name == "my_feature_1" + and feature_set.fields["my_feature_1"].dtype == ValueType.FLOAT + and feature_set.fields["my_entity_1"].name == "my_entity_1" + and feature_set.fields["my_entity_1"].dtype == ValueType.INT64 + and len(feature_set.features) == 2 + and len(feature_set.entities) == 1 ) - def test_get_batch_features(self, mock_client, mocker): + @pytest.mark.parametrize("mocked_client", [pytest.lazy_fixture("mock_client"), + pytest.lazy_fixture("secure_mock_client") + ]) + def test_get_batch_features(self, mocked_client, mocker): - mock_client._serving_service_stub = Serving.ServingServiceStub( + mocked_client._serving_service_stub = Serving.ServingServiceStub( grpc.insecure_channel("") ) - mock_client._core_service_stub = Core.CoreServiceStub(grpc.insecure_channel("")) + mocked_client._core_service_stub = Core.CoreServiceStub(grpc.insecure_channel("")) mocker.patch.object( - mock_client._core_service_stub, + mocked_client._core_service_stub, "GetFeatureSet", return_value=GetFeatureSetResponse( feature_set=FeatureSetProto( @@ -283,7 +345,7 @@ def test_get_batch_features(self, mock_client, mocker): to_avro(file_path_or_buffer=final_results, df=expected_dataframe) mocker.patch.object( - mock_client._serving_service_stub, + mocked_client._serving_service_stub, "GetBatchFeatures", return_value=GetBatchFeaturesResponse( job=BatchFeaturesJob( @@ -297,7 +359,7 @@ def test_get_batch_features(self, mock_client, mocker): ) mocker.patch.object( - mock_client._serving_service_stub, + mocked_client._serving_service_stub, "GetJob", return_value=GetJobResponse( job=BatchFeaturesJob( @@ -311,7 +373,7 @@ def test_get_batch_features(self, mock_client, mocker): ) mocker.patch.object( - mock_client._serving_service_stub, + mocked_client._serving_service_stub, "GetFeastServingInfo", return_value=GetFeastServingInfoResponse( job_staging_location=f"file://{tempfile.mkdtemp()}/", @@ -319,8 +381,8 @@ def test_get_batch_features(self, mock_client, mocker): ), ) - mock_client.set_project("project1") - response = mock_client.get_batch_features( + mocked_client.set_project("project1") + response = mocked_client.get_batch_features( entity_rows=pd.DataFrame( { "datetime": [ @@ -348,9 +410,12 @@ def test_get_batch_features(self, mock_client, mocker): ] ) - def test_apply_feature_set_success(self, client): + @pytest.mark.parametrize("test_client", [pytest.lazy_fixture("client"), + pytest.lazy_fixture("secure_client") + ]) + def test_apply_feature_set_success(self, test_client): - client.set_project("project1") + test_client.set_project("project1") # Create Feature Sets fs1 = FeatureSet("my-feature-set-1") @@ -364,23 +429,24 @@ def test_apply_feature_set_success(self, client): fs2.add(Entity(name="fs2-my-entity-1", dtype=ValueType.INT64)) # Register Feature Set with Core - client.apply(fs1) - client.apply(fs2) + test_client.apply(fs1) + test_client.apply(fs2) - feature_sets = client.list_feature_sets() + feature_sets = test_client.list_feature_sets() # List Feature Sets assert ( - len(feature_sets) == 2 - and feature_sets[0].name == "my-feature-set-1" - and feature_sets[0].features[0].name == "fs1-my-feature-1" - and feature_sets[0].features[0].dtype == ValueType.INT64 - and feature_sets[1].features[1].dtype == ValueType.BYTES_LIST + len(feature_sets) == 2 + and feature_sets[0].name == "my-feature-set-1" + and feature_sets[0].features[0].name == "fs1-my-feature-1" + and feature_sets[0].features[0].dtype == ValueType.INT64 + and feature_sets[1].features[1].dtype == ValueType.BYTES_LIST ) - @pytest.mark.parametrize("dataframe", [dataframes.GOOD]) - def test_feature_set_ingest_success(self, dataframe, client, mocker): - client.set_project("project1") + @pytest.mark.parametrize("dataframe,test_client", [(dataframes.GOOD, pytest.lazy_fixture("client")), + (dataframes.GOOD, pytest.lazy_fixture("secure_client"))]) + def test_feature_set_ingest_success(self, dataframe, test_client, mocker): + test_client.set_project("project1") driver_fs = FeatureSet( "driver-feature-set", source=KafkaSource(brokers="kafka:9092", topic="test") ) @@ -390,12 +456,12 @@ def test_feature_set_ingest_success(self, dataframe, client, mocker): driver_fs.add(Entity(name="entity_id", dtype=ValueType.INT64)) # Register with Feast core - client.apply(driver_fs) + test_client.apply(driver_fs) driver_fs = driver_fs.to_proto() driver_fs.meta.status = FeatureSetStatusProto.STATUS_READY mocker.patch.object( - client._core_service_stub, + test_client._core_service_stub, "GetFeatureSet", return_value=GetFeatureSetResponse(feature_set=driver_fs), ) @@ -403,14 +469,16 @@ def test_feature_set_ingest_success(self, dataframe, client, mocker): # Need to create a mock producer with patch("feast.client.get_producer") as mocked_queue: # Ingest data into Feast - client.ingest("driver-feature-set", dataframe) + test_client.ingest("driver-feature-set", dataframe) - @pytest.mark.parametrize("dataframe,exception", [(dataframes.GOOD, TimeoutError)]) + @pytest.mark.parametrize("dataframe,exception,test_client", + [(dataframes.GOOD, TimeoutError, pytest.lazy_fixture("client")), + (dataframes.GOOD, TimeoutError, pytest.lazy_fixture("secure_client"))]) def test_feature_set_ingest_fail_if_pending( - self, dataframe, exception, client, mocker + self, dataframe, exception, test_client, mocker ): with pytest.raises(exception): - client.set_project("project1") + test_client.set_project("project1") driver_fs = FeatureSet( "driver-feature-set", source=KafkaSource(brokers="kafka:9092", topic="test"), @@ -421,12 +489,12 @@ def test_feature_set_ingest_fail_if_pending( driver_fs.add(Entity(name="entity_id", dtype=ValueType.INT64)) # Register with Feast core - client.apply(driver_fs) + test_client.apply(driver_fs) driver_fs = driver_fs.to_proto() driver_fs.meta.status = FeatureSetStatusProto.STATUS_PENDING mocker.patch.object( - client._core_service_stub, + test_client._core_service_stub, "GetFeatureSet", return_value=GetFeatureSetResponse(feature_set=driver_fs), ) @@ -434,18 +502,22 @@ def test_feature_set_ingest_fail_if_pending( # Need to create a mock producer with patch("feast.client.get_producer") as mocked_queue: # Ingest data into Feast - client.ingest("driver-feature-set", dataframe, timeout=1) + test_client.ingest("driver-feature-set", dataframe, timeout=1) @pytest.mark.parametrize( - "dataframe,exception", + "dataframe,exception,test_client", [ - (dataframes.BAD_NO_DATETIME, Exception), - (dataframes.BAD_INCORRECT_DATETIME_TYPE, Exception), - (dataframes.BAD_NO_ENTITY, Exception), - (dataframes.NO_FEATURES, Exception), + (dataframes.BAD_NO_DATETIME, Exception, pytest.lazy_fixture("client")), + (dataframes.BAD_INCORRECT_DATETIME_TYPE, Exception, pytest.lazy_fixture("client")), + (dataframes.BAD_NO_ENTITY, Exception, pytest.lazy_fixture("client")), + (dataframes.NO_FEATURES, Exception, pytest.lazy_fixture("client")), + (dataframes.BAD_NO_DATETIME, Exception, pytest.lazy_fixture("secure_client")), + (dataframes.BAD_INCORRECT_DATETIME_TYPE, Exception, pytest.lazy_fixture("secure_client")), + (dataframes.BAD_NO_ENTITY, Exception, pytest.lazy_fixture("secure_client")), + (dataframes.NO_FEATURES, Exception, pytest.lazy_fixture("secure_client")), ], ) - def test_feature_set_ingest_failure(self, client, dataframe, exception): + def test_feature_set_ingest_failure(self, test_client, dataframe, exception): with pytest.raises(exception): # Create feature set driver_fs = FeatureSet("driver-feature-set") @@ -454,15 +526,16 @@ def test_feature_set_ingest_failure(self, client, dataframe, exception): driver_fs.infer_fields_from_df(dataframe) # Register with Feast core - client.apply(driver_fs) + test_client.apply(driver_fs) # Ingest data into Feast - client.ingest(driver_fs, dataframe=dataframe) + test_client.ingest(driver_fs, dataframe=dataframe) - @pytest.mark.parametrize("dataframe", [dataframes.ALL_TYPES]) - def test_feature_set_types_success(self, client, dataframe, mocker): + @pytest.mark.parametrize("dataframe,test_client", [(dataframes.ALL_TYPES, pytest.lazy_fixture("client")), + (dataframes.ALL_TYPES, pytest.lazy_fixture("secure_client"))]) + def test_feature_set_types_success(self, test_client, dataframe, mocker): - client.set_project("project1") + test_client.set_project("project1") all_types_fs = FeatureSet( name="all_types", @@ -489,10 +562,10 @@ def test_feature_set_types_success(self, client, dataframe, mocker): ) # Register with Feast core - client.apply(all_types_fs) + test_client.apply(all_types_fs) mocker.patch.object( - client._core_service_stub, + test_client._core_service_stub, "GetFeatureSet", return_value=GetFeatureSetResponse(feature_set=all_types_fs.to_proto()), ) @@ -500,4 +573,38 @@ def test_feature_set_types_success(self, client, dataframe, mocker): # Need to create a mock producer with patch("feast.client.get_producer") as mocked_queue: # Ingest data into Feast - client.ingest(all_types_fs, dataframe) + test_client.ingest(all_types_fs, dataframe) + + @patch("grpc.channel_ready_future") + def test_secure_channel_creation_with_secure_client(self, _mocked_obj): + client = Client(core_url="localhost:50051", serving_url="localhost:50052", serving_secure=True, + core_secure=True) + with mock.patch("grpc.secure_channel") as _grpc_mock, \ + mock.patch("grpc.ssl_channel_credentials", MagicMock(return_value="test")) as _mocked_credentials: + client._connect_serving() + _grpc_mock.assert_called_with(client.serving_url, _mocked_credentials.return_value) + + @mock.patch("grpc.channel_ready_future") + def test_secure_channel_creation_with_secure_serving_url(self, _mocked_obj, ): + client = Client(core_url="localhost:50051", serving_url="localhost:443") + with mock.patch("grpc.secure_channel") as _grpc_mock, \ + mock.patch("grpc.ssl_channel_credentials", MagicMock(return_value="test")) as _mocked_credentials: + client._connect_serving() + _grpc_mock.assert_called_with(client.serving_url, _mocked_credentials.return_value) + + @patch("grpc.channel_ready_future") + def test_secure_channel_creation_with_secure_client(self, _mocked_obj): + client = Client(core_url="localhost:50053", serving_url="localhost:50054", serving_secure=True, + core_secure=True) + with mock.patch("grpc.secure_channel") as _grpc_mock, \ + mock.patch("grpc.ssl_channel_credentials", MagicMock(return_value="test")) as _mocked_credentials: + client._connect_core() + _grpc_mock.assert_called_with(client.core_url, _mocked_credentials.return_value) + + @patch("grpc.channel_ready_future") + def test_secure_channel_creation_with_secure_core_url(self, _mocked_obj): + client = Client(core_url="localhost:443", serving_url="localhost:50054") + with mock.patch("grpc.secure_channel") as _grpc_mock, \ + mock.patch("grpc.ssl_channel_credentials", MagicMock(return_value="test")) as _mocked_credentials: + client._connect_core() + _grpc_mock.assert_called_with(client.core_url, _mocked_credentials.return_value) \ No newline at end of file