diff --git a/java/src/main/java/com/logicalclocks/hsfs/FeatureStore.java b/java/src/main/java/com/logicalclocks/hsfs/FeatureStore.java
index 4353bd91e9..46353dc8fd 100644
--- a/java/src/main/java/com/logicalclocks/hsfs/FeatureStore.java
+++ b/java/src/main/java/com/logicalclocks/hsfs/FeatureStore.java
@@ -121,9 +121,8 @@ public Dataset sql(String query) {
     return SparkEngine.getInstance().sql(query);
   }
 
-  public StorageConnector getStorageConnector(String name, StorageConnectorType type)
-      throws FeatureStoreException, IOException {
-    return storageConnectorApi.getByNameAndType(this, name, type);
+  public StorageConnector getStorageConnector(String name) throws FeatureStoreException, IOException {
+    return storageConnectorApi.getByName(this, name);
   }
 
   public StorageConnector getOnlineStorageConnector() throws FeatureStoreException, IOException {
diff --git a/java/src/main/java/com/logicalclocks/hsfs/TrainingDataset.java b/java/src/main/java/com/logicalclocks/hsfs/TrainingDataset.java
index 3186c40877..8381226dbb 100644
--- a/java/src/main/java/com/logicalclocks/hsfs/TrainingDataset.java
+++ b/java/src/main/java/com/logicalclocks/hsfs/TrainingDataset.java
@@ -71,10 +71,6 @@ public class TrainingDataset {
   @JsonIgnore
   private FeatureStore featureStore;
 
-  @Getter
-  @Setter
-  private Integer storageConnectorId;
-
   @Getter
   @Setter
   @JsonIgnore
@@ -137,7 +133,6 @@ public TrainingDataset(@NonNull String name, Integer version, String description
     this.storageConnector = storageConnector;
 
     if (storageConnector != null) {
-      this.storageConnectorId = storageConnector.getId();
       if (storageConnector.getStorageConnectorType() == StorageConnectorType.S3) {
         // Default it's already HOPSFS_TRAINING_DATASET
         this.trainingDatasetType = TrainingDatasetType.EXTERNAL_TRAINING_DATASET;
diff --git a/java/src/main/java/com/logicalclocks/hsfs/engine/Utils.java b/java/src/main/java/com/logicalclocks/hsfs/engine/Utils.java
index 2bb6a5bd46..f7d0b5b63d 100644
--- a/java/src/main/java/com/logicalclocks/hsfs/engine/Utils.java
+++ b/java/src/main/java/com/logicalclocks/hsfs/engine/Utils.java
@@ -21,8 +21,6 @@
 import com.logicalclocks.hsfs.FeatureGroup;
 import com.logicalclocks.hsfs.TrainingDatasetFeature;
 import com.logicalclocks.hsfs.StorageConnector;
-import com.logicalclocks.hsfs.StorageConnectorType;
-import com.logicalclocks.hsfs.TrainingDatasetFeature;
 import com.logicalclocks.hsfs.metadata.StorageConnectorApi;
 import org.apache.commons.io.FileUtils;
 import org.apache.spark.sql.Dataset;
@@ -115,8 +113,8 @@ public String getFgName(FeatureGroup featureGroup) {
   }
 
   public String getHiveMetastoreConnector(FeatureGroup featureGroup) throws IOException, FeatureStoreException {
-    StorageConnector storageConnector = storageConnectorApi.getByNameAndType(featureGroup.getFeatureStore(),
-        featureGroup.getFeatureStore().getName(), StorageConnectorType.JDBC);
+    StorageConnector storageConnector = storageConnectorApi.getByName(featureGroup.getFeatureStore(),
+        featureGroup.getFeatureStore().getName());
     String connStr = storageConnector.getConnectionString();
     String pw = FileUtils.readFileToString(new File("material_passwd"));
     return connStr + "sslTrustStore=t_certificate;trustStorePassword=" + pw
diff --git a/java/src/main/java/com/logicalclocks/hsfs/metadata/StorageConnectorApi.java b/java/src/main/java/com/logicalclocks/hsfs/metadata/StorageConnectorApi.java
index cc588ba910..792aadb3c7 100644
--- a/java/src/main/java/com/logicalclocks/hsfs/metadata/StorageConnectorApi.java
+++ b/java/src/main/java/com/logicalclocks/hsfs/metadata/StorageConnectorApi.java
@@ -20,7 +20,6 @@
 import com.logicalclocks.hsfs.FeatureStore;
 import com.logicalclocks.hsfs.FeatureStoreException;
 import com.logicalclocks.hsfs.StorageConnector;
-import com.logicalclocks.hsfs.StorageConnectorType;
 import org.apache.http.client.methods.HttpGet;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
@@ -36,7 +35,7 @@ public class StorageConnectorApi {
 
   private static final Logger LOGGER = LoggerFactory.getLogger(StorageConnectorApi.class);
 
-  public StorageConnector getByNameAndType(FeatureStore featureStore, String name, StorageConnectorType type)
+  public StorageConnector getByName(FeatureStore featureStore, String name)
       throws IOException, FeatureStoreException {
     HopsworksClient hopsworksClient = HopsworksClient.getInstance();
     String pathTemplate = HopsworksClient.PROJECT_PATH
@@ -46,7 +45,6 @@ public StorageConnector getByNameAndType(FeatureStore featureStore, String name,
     String uri = UriTemplate.fromTemplate(pathTemplate)
         .set("projectId", featureStore.getProjectId())
        .set("fsId", featureStore.getId())
-        .set("connType", type)
         .set("name", name)
         .set("temporaryCredentials", true)
         .expand();
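The Java client (and the Python client below) now resolves a connector by name only, so the type segment disappears from the lookup URI. A sketch of the endpoint before and after this change (the leading segments come from HopsworksClient.PROJECT_PATH and are abbreviated here):

    before: GET .../project/{projectId}/featurestores/{fsId}/storageconnectors/{connType}/{name}?temporaryCredentials=true
    after:  GET .../project/{projectId}/featurestores/{fsId}/storageconnectors/{name}?temporaryCredentials=true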
diff --git a/python/hsfs/core/storage_connector_api.py b/python/hsfs/core/storage_connector_api.py
index 94f438bc3f..6753be8c57 100644
--- a/python/hsfs/core/storage_connector_api.py
+++ b/python/hsfs/core/storage_connector_api.py
@@ -23,13 +23,11 @@ class StorageConnectorApi:
     def __init__(self, feature_store_id):
         self._feature_store_id = feature_store_id
 
-    def get(self, name, connector_type):
-        """Get storage connector with name and type.
+    def get(self, name):
+        """Get storage connector by name.
 
         :param name: name of the storage connector
         :type name: str
-        :param connector_type: connector type
-        :type connector_type: str
         :return: the storage connector
         :rtype: StorageConnector
         """
@@ -40,7 +38,6 @@ def get(self, name):
             "featurestores",
             self._feature_store_id,
             "storageconnectors",
-            connector_type,
             name,
         ]
         query_params = {"temporaryCredentials": True}
diff --git a/python/hsfs/engine/spark.py b/python/hsfs/engine/spark.py
index 8c58f3ab99..31752c002d 100644
--- a/python/hsfs/engine/spark.py
+++ b/python/hsfs/engine/spark.py
@@ -185,13 +185,7 @@ def save_dataframe(
         )
 
     def _save_offline_dataframe(
-        self,
-        table_name,
-        feature_group,
-        dataframe,
-        save_mode,
-        operation,
-        write_options,
+        self, table_name, feature_group, dataframe, save_mode, operation, write_options,
     ):
         if feature_group.time_travel_format == "HUDI":
             hudi_engine_instance = hudi_engine.HudiEngine(
@@ -246,10 +240,8 @@ def read(self, storage_connector, data_format, read_options, path):
 
     def profile(self, dataframe, relevant_columns, correlations, histograms):
         """Profile a dataframe with Deequ."""
-        return (
-            self._jvm.com.logicalclocks.hsfs.engine.SparkEngine.getInstance().profile(
-                dataframe._jdf, relevant_columns, correlations, histograms
-            )
+        return self._jvm.com.logicalclocks.hsfs.engine.SparkEngine.getInstance().profile(
+            dataframe._jdf, relevant_columns, correlations, histograms
         )
 
     def write_options(self, data_format, provided_options):
@@ -362,8 +354,7 @@ def _setup_s3(self, storage_connector, path):
             "org.apache.hadoop.fs.s3a.TemporaryAWSCredentialsProvider",
         )
         self._spark_context._jsc.hadoopConfiguration().set(
-            "fs.s3a.session.token",
-            storage_connector.session_token,
+            "fs.s3a.session.token", storage_connector.session_token,
         )
 
         return path.replace("s3", "s3a", 1)
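The spark.py hunks are formatting only; in particular `_setup_s3` still rewrites just the URI scheme. The count argument `1` to `str.replace` is what keeps later occurrences of "s3" in the path intact, e.g.:

    >>> "s3://my-bucket/s3-data".replace("s3", "s3a", 1)
    's3a://my-bucket/s3-data'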
diff --git a/python/hsfs/feature_store.py b/python/hsfs/feature_store.py
index 417161cec1..004aea77d4 100644
--- a/python/hsfs/feature_store.py
+++ b/python/hsfs/feature_store.py
@@ -180,7 +180,7 @@ def get_training_dataset(self, name: str, version: int = None):
             version = self.DEFAULT_VERSION
         return self._training_dataset_api.get(name, version)
 
-    def get_storage_connector(self, name: str, connector_type: str):
+    def get_storage_connector(self, name: str):
         """Get a previously created storage connector from the feature store.
 
         Storage connectors encapsulate all information needed for the execution engine
@@ -194,20 +194,18 @@ def get_storage_connector(self, name: str):
 
         !!! example "Getting a Storage Connector"
             ```python
-            sc = fs.get_storage_connector("demo_fs_meb10000_Training_Datasets", "HOPSFS")
+            sc = fs.get_storage_connector("demo_fs_meb10000_Training_Datasets")
 
             td = fs.create_training_dataset(..., storage_connector=sc, ...)
             ```
 
         # Arguments
             name: Name of the storage connector to retrieve.
-            connector_type: Type of the storage connector, e.g. `"JDBC"`, `"HOPSFS"`
-                or `"S3"`.
 
         # Returns
             `StorageConnector`. Storage connector object.
         """
-        return self._storage_connector_api.get(name, connector_type)
+        return self._storage_connector_api.get(name)
 
     def sql(self, query, dataframe_type="default", online=False):
         return self._feature_group_engine.sql(query, self._name, dataframe_type, online)
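For Python callers the migration is simply dropping the second argument; a minimal sketch (connector name hypothetical, `fs` an existing feature store handle):

    # before this change:
    sc = fs.get_storage_connector("my_connector", "S3")

    # after this change, the backend infers the connector type from the name:
    sc = fs.get_storage_connector("my_connector")

    td = fs.create_training_dataset(..., storage_connector=sc, ...)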
diff --git a/python/hsfs/storage_connector.py b/python/hsfs/storage_connector.py
index 62d7af9d04..0b1d011597 100644
--- a/python/hsfs/storage_connector.py
+++ b/python/hsfs/storage_connector.py
@@ -110,77 +110,167 @@ def connector_type(self):
     @property
     def access_key(self):
         """Access key."""
-        return self._access_key
+        if self._storage_connector_type.upper() == "S3":
+            return self._access_key
+        else:
+            raise Exception(
+                "Access key is not supported for connector "
+                + self._storage_connector_type
+            )
 
     @property
     def secret_key(self):
         """Secret key."""
-        return self._secret_key
+        if self._storage_connector_type.upper() == "S3":
+            return self._secret_key
+        else:
+            raise Exception(
+                "Secret key is not supported for connector "
+                + self._storage_connector_type
+            )
 
     @property
     def server_encryption_algorithm(self):
         """Encryption algorithm if server-side S3 bucket encryption is enabled."""
-        return self._server_encryption_algorithm
+        if self._storage_connector_type.upper() == "S3":
+            return self._server_encryption_algorithm
+        else:
+            raise Exception(
+                "Encryption algorithm is not supported for connector "
+                + self._storage_connector_type
+            )
 
     @property
     def server_encryption_key(self):
         """Encryption key if server-side S3 bucket encryption is enabled."""
-        return self._server_encryption_key
+        if self._storage_connector_type.upper() == "S3":
+            return self._server_encryption_key
+        else:
+            raise Exception(
+                "Encryption key is not supported for connector "
+                + self._storage_connector_type
+            )
 
     @property
     def cluster_identifier(self):
         """Cluster identifier for redshift cluster."""
-        return self._cluster_identifier
+        if self._storage_connector_type.upper() == "REDSHIFT":
+            return self._cluster_identifier
+        else:
+            raise Exception(
+                "Cluster identifier is not supported for connector "
+                + self._storage_connector_type
+            )
 
     @property
     def database_driver(self):
-        """Database endpoint for redshift cluster."""
-        return self._database_driver
+        """Database driver for redshift cluster."""
+        if self._storage_connector_type.upper() == "REDSHIFT":
+            return self._database_driver
+        else:
+            raise Exception(
+                "Database driver is not supported for connector "
+                + self._storage_connector_type
+            )
 
     @property
     def database_endpoint(self):
         """Database endpoint for redshift cluster."""
-        return self._database_endpoint
+        if self._storage_connector_type.upper() == "REDSHIFT":
+            return self._database_endpoint
+        else:
+            raise Exception(
+                "Database endpoint is not supported for connector "
+                + self._storage_connector_type
+            )
 
     @property
     def database_name(self):
         """Database name for redshift cluster."""
-        return self._database_name
+        if self._storage_connector_type.upper() == "REDSHIFT":
+            return self._database_name
+        else:
+            raise Exception(
+                "Database name is not supported for connector "
+                + self._storage_connector_type
+            )
 
     @property
     def database_port(self):
         """Database port for redshift cluster."""
-        return self._database_port
+        if self._storage_connector_type.upper() == "REDSHIFT":
+            return self._database_port
+        else:
+            raise Exception(
+                "Database port is not supported for connector "
+                + self._storage_connector_type
+            )
 
     @property
     def table_name(self):
         """Table name for redshift cluster."""
-        return self._table_name
+        if self._storage_connector_type.upper() == "REDSHIFT":
+            return self._table_name
+        else:
+            raise Exception(
+                "Table name is not supported for connector "
+                + self._storage_connector_type
+            )
 
     @property
     def database_user_name(self):
         """Database username for redshift cluster."""
-        return self._database_user_name
+        if self._storage_connector_type.upper() == "REDSHIFT":
+            return self._database_user_name
+        else:
+            raise Exception(
+                "Database username is not supported for connector "
+                + self._storage_connector_type
+            )
 
     @property
     def auto_create(self):
-        """Database username for redshift cluster."""
-        return self._auto_create
+        """Auto create option for redshift cluster."""
+        if self._storage_connector_type.upper() == "REDSHIFT":
+            return self._auto_create
+        else:
+            raise Exception(
+                "Auto create is not supported for connector "
+                + self._storage_connector_type
+            )
 
     @property
     def database_group(self):
-        """Database username for redshift cluster."""
-        return self._database_group
+        """Database group for redshift cluster."""
+        if self._storage_connector_type.upper() == "REDSHIFT":
+            return self._database_group
+        else:
+            raise Exception(
+                "Database group is not supported for connector "
+                + self._storage_connector_type
+            )
 
     @property
     def database_password(self):
         """Database password for redshift cluster."""
-        return self._database_password
+        if self._storage_connector_type.upper() == "REDSHIFT":
+            return self._database_password
+        else:
+            raise Exception(
+                "Database password is not supported for connector "
+                + self._storage_connector_type
+            )
 
     @property
     def session_token(self):
         """Session token."""
-        return self._session_token
+        if self._storage_connector_type.upper() == "S3":
+            return self._session_token
+        else:
+            raise Exception(
+                "Session token is not supported for connector "
+                + self._storage_connector_type
+            )
 
     @property
     def iam_role(self):
@@ -190,23 +280,62 @@ def iam_role(self):
 
     @property
     def expiration(self):
         """Cluster temporary credential expiration time."""
-        return self._expiration
+        if self._storage_connector_type.upper() in ["S3", "REDSHIFT"]:
+            return self._expiration
+        else:
+            raise Exception(
+                "Expiration is not supported for connector "
+                + self._storage_connector_type
+            )
+
+    @property
+    def bucket(self):
+        """Return the bucket for S3 connectors."""
+        if self._storage_connector_type.upper() == "S3":
+            return self._bucket
+        else:
+            raise Exception(
+                "Bucket is not supported for connector " + self._storage_connector_type
+            )
 
     @property
     def connection_string(self):
         """JDBC connection string."""
-        return self._connection_string
+        if self._storage_connector_type.upper() == "JDBC":
+            return self._connection_string
+        else:
+            raise Exception(
+                "Connection string is not supported for connector "
+                + self._storage_connector_type
+            )
 
     @property
     def arguments(self):
         """Additional JDBC arguments."""
-        return self._arguments
+        if self._storage_connector_type.upper() == "JDBC":
+            return self._arguments
+        else:
+            raise Exception(
+                "Arguments are not supported for connector "
+                + self._storage_connector_type
+            )
+
+    @property
+    def path(self):
+        """If the connector refers to a path (e.g. S3), return the path of the connector.
+        """
+        if self._storage_connector_type.upper() == "S3":
+            return "s3://" + self._bucket
+        else:
+            raise Exception(
+                "Path is not supported for connector " + self._storage_connector_type
+            )
 
     def spark_options(self):
         """Return prepared options to be passed to Spark, based on the additional
         arguments.
         """
-        if self._storage_connector_type == "JDBC":
+        if self._storage_connector_type.upper() == "JDBC":
             args = [arg.split("=") for arg in self._arguments.split(",")]
             return {
@@ -214,7 +343,7 @@ def spark_options(self):
                 "user": [arg[1] for arg in args if arg[0] == "user"][0],
                 "password": [arg[1] for arg in args if arg[0] == "password"][0],
             }
-        elif self._storage_connector_type == "REDSHIFT":
+        elif self._storage_connector_type.upper() == "REDSHIFT":
             connstr = (
                 "jdbc:redshift://"
                 + self._cluster_identifier
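The new guards make type mismatches explicit instead of returning unset attributes; a behavior sketch, assuming a hypothetical S3 connector:

    sc = fs.get_storage_connector("my_s3_connector")
    sc.access_key     # S3 property, returns the key
    sc.path           # returns "s3://" plus the bucket name
    sc.database_port  # raises Exception: Redshift-only property on an S3 connector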
diff --git a/python/hsfs/training_dataset.py b/python/hsfs/training_dataset.py
index bf8e78a2b3..c508ac93e3 100644
--- a/python/hsfs/training_dataset.py
+++ b/python/hsfs/training_dataset.py
@@ -28,7 +28,6 @@
 from hsfs.core import (
     query,
     training_dataset_api,
-    storage_connector_api,
     training_dataset_engine,
     tfdata_engine,
     statistics_engine,
@@ -60,9 +59,6 @@ def __init__(
         id=None,
         jobs=None,
         inode_id=None,
-        storage_connector_name=None,
-        storage_connector_type=None,
-        storage_connector_id=None,
         training_dataset_type=None,
         from_query=None,
         querydto=None,
@@ -86,10 +82,6 @@ def __init__(
             featurestore_id
         )
 
-        self._storage_connector_api = storage_connector_api.StorageConnectorApi(
-            featurestore_id
-        )
-
         self._statistics_engine = statistics_engine.StatisticsEngine(
             featurestore_id, self.ENTITY_TYPE
         )
@@ -105,9 +97,10 @@ def __init__(
         else:
             # type available -> init from backend response
-            # make rest call to get all connector information, description etc.
-            self._storage_connector = self._storage_connector_api.get(
-                storage_connector_name, storage_connector_type
+            # the connector payload is embedded in the response, no extra rest call
+            self._storage_connector = StorageConnector.from_response_json(
+                storage_connector
             )
+
             self._features = [
                 training_dataset_feature.TrainingDatasetFeature.from_response_json(feat)
                 for feat in features
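`TrainingDataset` now builds its connector straight from the payload embedded in the backend response rather than issuing a second REST call. A minimal sketch of that path (the payload keys are placeholders; real field names follow the backend's JSON schema):

    from hsfs.storage_connector import StorageConnector

    # hypothetical connector fragment from a training dataset response
    connector_json = {"id": 1, "name": "my_connector", "storageConnectorType": "HOPSFS"}
    sc = StorageConnector.from_response_json(connector_json)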