diff --git a/auto_doc.py b/auto_doc.py index c1259631b5..8676a6cada 100644 --- a/auto_doc.py +++ b/auto_doc.py @@ -78,6 +78,12 @@ "hsfs.constructor.query.Query" ), }, + "statistics.md": { + "statistics_config": ["hsfs.statistics_config.StatisticsConfig"], + "statistics_config_properties": keras_autodoc.get_properties( + "hsfs.statistics_config.StatisticsConfig" + ), + }, "api/connection_api.md": { "connection": ["hsfs.connection.Connection"], "connection_properties": keras_autodoc.get_properties( @@ -130,6 +136,12 @@ "hsfs.storage_connector.StorageConnector" ), }, + "api/statistics_config_api.md": { + "statistics_config": ["hsfs.statistics_config.StatisticsConfig"], + "statistics_config_properties": keras_autodoc.get_properties( + "hsfs.statistics_config.StatisticsConfig" + ), + }, } hsfs_dir = pathlib.Path(__file__).resolve().parents[0] diff --git a/docs/templates/api/statistics_config_api.md b/docs/templates/api/statistics_config_api.md new file mode 100644 index 0000000000..a907d1d323 --- /dev/null +++ b/docs/templates/api/statistics_config_api.md @@ -0,0 +1,7 @@ +# StatisticsConfig + +{{statistics_config}} + +## Properties + +{{statistics_config_properties}} diff --git a/docs/templates/statistics.md b/docs/templates/statistics.md new file mode 100644 index 0000000000..039fa8daf3 --- /dev/null +++ b/docs/templates/statistics.md @@ -0,0 +1,44 @@ +# Statistics + +HSFS provides functionality to compute statistics for [training datasets](training_dataset.md) and [feature groups](feature_group.md) and save these along with their other metadata in the [feature store](feature_store.md). +These statistics are meant to be helpful for Data Scientists to perform explorative data analysis and then recognize suitable [features](feature.md) or [training datasets](training_dataset.md) for models. + +Statistics are configured on a training dataset or feature group level using a `StatisticsConfig` object. 
+This object can be passed at creation time of the dataset or group or it can later on be updated through the API. + +{{statistics_config}} + +For example, to enable all statistics (descriptive, histograms and correlations) for a training dataset: + +=== "Python" + ```python + from hsfs.statistics_config import StatisticsConfig + + td = fs.create_training_dataset("rain_dataset", + version=1, + label="weekly_rain", + data_format="tfrecords", + statistics_config=StatisticsConfig(True, True, True)) + + ``` +=== "Scala" + ```scala + val td = (fs.createTrainingDataset() + .name("rain_dataset") + .version(1) + .label("weekly_rain") + .dataFormat("tfrecords") + .statisticsConfig(new StatisticsConfig(true, true, true)) + .build()) + ``` + +And similarly for feature groups. + +!!! note "Default StatisticsConfig" + By default all training datasets and feature groups will be configured such that only descriptive statistics + are computed. However, you can also enable `histograms` and `correlations` or limit the features for which + statistics are computed. 
+ +## Properties + +{{statistics_config_properties}} diff --git a/java/src/main/java/com/logicalclocks/hsfs/FeatureGroup.java b/java/src/main/java/com/logicalclocks/hsfs/FeatureGroup.java index 9ebdc966dd..c5a1c42995 100644 --- a/java/src/main/java/com/logicalclocks/hsfs/FeatureGroup.java +++ b/java/src/main/java/com/logicalclocks/hsfs/FeatureGroup.java @@ -75,8 +75,7 @@ public class FeatureGroup extends FeatureGroupBase { public FeatureGroup(FeatureStore featureStore, @NonNull String name, Integer version, String description, List primaryKeys, List partitionKeys, String hudiPrecombineKey, boolean onlineEnabled, TimeTravelFormat timeTravelFormat, List features, - Boolean statisticsEnabled, Boolean histograms, Boolean correlations, - List statisticColumns) { + StatisticsConfig statisticsConfig) { this.featureStore = featureStore; this.name = name; this.version = version; @@ -87,10 +86,7 @@ public FeatureGroup(FeatureStore featureStore, @NonNull String name, Integer ver this.onlineEnabled = onlineEnabled; this.timeTravelFormat = timeTravelFormat != null ? timeTravelFormat : TimeTravelFormat.HUDI; this.features = features; - this.statisticsEnabled = statisticsEnabled != null ? statisticsEnabled : true; - this.histograms = histograms; - this.correlations = correlations; - this.statisticColumns = statisticColumns; + this.statisticsConfig = statisticsConfig != null ? 
statisticsConfig : new StatisticsConfig(); } public FeatureGroup() { @@ -183,7 +179,7 @@ public void save(Dataset featureData, Map writeOptions) throws FeatureStoreException, IOException { featureGroupEngine.saveFeatureGroup(this, featureData, primaryKeys, partitionKeys, hudiPrecombineKey, writeOptions); - if (statisticsEnabled) { + if (statisticsConfig.getEnabled()) { statisticsEngine.computeStatistics(this, featureData); } } diff --git a/java/src/main/java/com/logicalclocks/hsfs/OnDemandFeatureGroup.java b/java/src/main/java/com/logicalclocks/hsfs/OnDemandFeatureGroup.java index 9d8b9154df..ce442c7dc0 100644 --- a/java/src/main/java/com/logicalclocks/hsfs/OnDemandFeatureGroup.java +++ b/java/src/main/java/com/logicalclocks/hsfs/OnDemandFeatureGroup.java @@ -68,8 +68,7 @@ public class OnDemandFeatureGroup extends FeatureGroupBase { public OnDemandFeatureGroup(FeatureStore featureStore, @NonNull String name, Integer version, String query, OnDemandDataFormat dataFormat, String path, Map options, @NonNull StorageConnector storageConnector, String description, List features, - Boolean statisticsEnabled, Boolean histograms, Boolean correlations, - List statisticColumns) { + StatisticsConfig statisticsConfig) { this.featureStore = featureStore; this.name = name; this.version = version; @@ -83,10 +82,7 @@ public OnDemandFeatureGroup(FeatureStore featureStore, @NonNull String name, Int this.description = description; this.storageConnector = storageConnector; this.features = features; - this.statisticsEnabled = statisticsEnabled != null ? statisticsEnabled : true; - this.histograms = histograms; - this.correlations = correlations; - this.statisticColumns = statisticColumns; + this.statisticsConfig = statisticsConfig != null ? 
statisticsConfig : new StatisticsConfig(); } public OnDemandFeatureGroup() { @@ -95,7 +91,7 @@ public OnDemandFeatureGroup() { public void save() throws FeatureStoreException, IOException { onDemandFeatureGroupEngine.saveFeatureGroup(this); - if (statisticsEnabled) { + if (statisticsConfig.getEnabled()) { statisticsEngine.computeStatistics(this, read()); } } diff --git a/java/src/main/java/com/logicalclocks/hsfs/StatisticsConfig.java b/java/src/main/java/com/logicalclocks/hsfs/StatisticsConfig.java new file mode 100644 index 0000000000..4a73bf08ca --- /dev/null +++ b/java/src/main/java/com/logicalclocks/hsfs/StatisticsConfig.java @@ -0,0 +1,53 @@ +/* + * Copyright (c) 2021 Logical Clocks AB + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * + * See the License for the specific language governing permissions and limitations under the License. 
+ */ + +package com.logicalclocks.hsfs; + +import lombok.AllArgsConstructor; +import lombok.Builder; +import lombok.Getter; +import lombok.NoArgsConstructor; +import lombok.Setter; + +import java.util.ArrayList; +import java.util.List; + +@AllArgsConstructor +@NoArgsConstructor +@Builder +public class StatisticsConfig { + @Getter + @Setter + private Boolean enabled = true; + + @Getter + @Setter + private Boolean histograms = false; + + @Getter + @Setter + private Boolean correlations = false; + + @Getter + @Setter + private List columns = new ArrayList<>(); + + public StatisticsConfig(Boolean enabled, Boolean histograms, Boolean correlations) { + this.enabled = enabled; + this.histograms = histograms; + this.correlations = correlations; + } +} diff --git a/java/src/main/java/com/logicalclocks/hsfs/TrainingDataset.java b/java/src/main/java/com/logicalclocks/hsfs/TrainingDataset.java index 703eb21125..8039f0b26f 100644 --- a/java/src/main/java/com/logicalclocks/hsfs/TrainingDataset.java +++ b/java/src/main/java/com/logicalclocks/hsfs/TrainingDataset.java @@ -90,23 +90,7 @@ public class TrainingDataset { @Getter @Setter - @JsonIgnore - private Boolean statisticsEnabled = true; - - @Getter - @Setter - @JsonIgnore - private Boolean histograms; - - @Getter - @Setter - @JsonIgnore - private Boolean correlations; - - @Getter - @Setter - @JsonIgnore - private List statisticColumns; + private StatisticsConfig statisticsConfig = new StatisticsConfig(); @Getter @Setter @@ -123,8 +107,7 @@ public class TrainingDataset { @Builder public TrainingDataset(@NonNull String name, Integer version, String description, DataFormat dataFormat, StorageConnector storageConnector, String location, List splits, Long seed, - FeatureStore featureStore, Boolean statisticsEnabled, Boolean histograms, - Boolean correlations, List statisticColumns, List label) { + FeatureStore featureStore, StatisticsConfig statisticsConfig, List label) { this.name = name; this.version = version; this.description = 
description; @@ -142,10 +125,7 @@ public TrainingDataset(@NonNull String name, Integer version, String description this.splits = splits; this.seed = seed; this.featureStore = featureStore; - this.statisticsEnabled = statisticsEnabled != null ? statisticsEnabled : true; - this.histograms = histograms; - this.correlations = correlations; - this.statisticColumns = statisticColumns; + this.statisticsConfig = statisticsConfig != null ? statisticsConfig : new StatisticsConfig(); this.label = label; } @@ -195,7 +175,7 @@ public void save(Query query, Map writeOptions) throws FeatureSt public void save(Dataset dataset, Map writeOptions) throws FeatureStoreException, IOException { trainingDatasetEngine.save(this, dataset, writeOptions, label); - if (statisticsEnabled) { + if (statisticsConfig.getEnabled()) { statisticsEngine.computeStatistics(this, dataset); } } @@ -314,12 +294,24 @@ public void show(int numRows) { * @throws IOException */ public Statistics computeStatistics() throws FeatureStoreException, IOException { - if (statisticsEnabled) { + if (statisticsConfig.getEnabled()) { return statisticsEngine.computeStatistics(this, read()); } return null; } + /** + * Update the statistics configuration of the training dataset. + * Change the `enabled`, `histograms`, `correlations` or `columns` attributes and persist + * the changes by calling this method. + * + * @throws FeatureStoreException + * @throws IOException + */ + public void updateStatisticsConfig() throws FeatureStoreException, IOException { + trainingDatasetEngine.updateStatisticsConfig(this); + } + /** * Get the last statistics commit for the training dataset. 
* diff --git a/java/src/main/java/com/logicalclocks/hsfs/engine/FeatureGroupBaseEngine.java b/java/src/main/java/com/logicalclocks/hsfs/engine/FeatureGroupBaseEngine.java index 2b4ce722e7..09b10a4121 100644 --- a/java/src/main/java/com/logicalclocks/hsfs/engine/FeatureGroupBaseEngine.java +++ b/java/src/main/java/com/logicalclocks/hsfs/engine/FeatureGroupBaseEngine.java @@ -66,10 +66,10 @@ public void appendFeatures(FeatureGroupBase featureGroup, List features FeatureGroup apiFG = featureGroupApi.updateMetadata(fgBaseSend, "updateMetadata"); featureGroup.setFeatures(apiFG.getFeatures()); } - - public void updateStatisticsConfig(FeatureGroupBase featureGroup) throws FeatureStoreException, IOException { - FeatureGroup apiFG = featureGroupApi.updateMetadata(featureGroup, "updateStatsSettings"); - featureGroup.setCorrelations(apiFG.getCorrelations()); - featureGroup.setHistograms(apiFG.getHistograms()); + + public void updateStatisticsConfig(FeatureGroup featureGroup) throws FeatureStoreException, IOException { + FeatureGroup apiFG = featureGroupApi.updateMetadata(featureGroup, "updateStatsConfig"); + featureGroup.getStatisticsConfig().setCorrelations(apiFG.getStatisticsConfig().getCorrelations()); + featureGroup.getStatisticsConfig().setHistograms(apiFG.getStatisticsConfig().getHistograms()); } } diff --git a/java/src/main/java/com/logicalclocks/hsfs/engine/FeatureGroupEngine.java b/java/src/main/java/com/logicalclocks/hsfs/engine/FeatureGroupEngine.java index 9847e0de3f..1380617aba 100644 --- a/java/src/main/java/com/logicalclocks/hsfs/engine/FeatureGroupEngine.java +++ b/java/src/main/java/com/logicalclocks/hsfs/engine/FeatureGroupEngine.java @@ -109,8 +109,7 @@ public void saveFeatureGroup(FeatureGroup featureGroup, Dataset dataset, Li featureGroup.setVersion(apiFG.getVersion()); featureGroup.setLocation(apiFG.getLocation()); featureGroup.setId(apiFG.getId()); - featureGroup.setCorrelations(apiFG.getCorrelations()); - 
featureGroup.setHistograms(apiFG.getHistograms()); + featureGroup.setStatisticsConfig(apiFG.getStatisticsConfig()); /* if hudi precombine key was not provided and TimeTravelFormat is HUDI, retrieve from backend and set */ if (featureGroup.getTimeTravelFormat() == TimeTravelFormat.HUDI & hudiPrecombineKey == null) { diff --git a/java/src/main/java/com/logicalclocks/hsfs/engine/StatisticsEngine.java b/java/src/main/java/com/logicalclocks/hsfs/engine/StatisticsEngine.java index 2e117ddf6a..cebf554c19 100644 --- a/java/src/main/java/com/logicalclocks/hsfs/engine/StatisticsEngine.java +++ b/java/src/main/java/com/logicalclocks/hsfs/engine/StatisticsEngine.java @@ -44,14 +44,17 @@ public StatisticsEngine(EntityEndpointType entityType) { public Statistics computeStatistics(TrainingDataset trainingDataset, Dataset dataFrame) throws FeatureStoreException, IOException { - return statisticsApi.post(trainingDataset, computeStatistics(dataFrame, trainingDataset.getStatisticColumns(), - trainingDataset.getHistograms(), trainingDataset.getCorrelations())); + return statisticsApi.post(trainingDataset, computeStatistics(dataFrame, + trainingDataset.getStatisticsConfig().getColumns(), + trainingDataset.getStatisticsConfig().getHistograms(), + trainingDataset.getStatisticsConfig().getCorrelations())); } public Statistics computeStatistics(FeatureGroupBase featureGroup, Dataset dataFrame) throws FeatureStoreException, IOException { - return statisticsApi.post(featureGroup, computeStatistics(dataFrame, featureGroup.getStatisticColumns(), - featureGroup.getHistograms(), featureGroup.getCorrelations())); + return statisticsApi.post(featureGroup, computeStatistics(dataFrame, + featureGroup.getStatisticsConfig().getColumns(), + featureGroup.getStatisticsConfig().getHistograms(), featureGroup.getStatisticsConfig().getCorrelations())); } private Statistics computeStatistics(Dataset dataFrame, List statisticColumns, Boolean histograms, diff --git 
a/java/src/main/java/com/logicalclocks/hsfs/engine/TrainingDatasetEngine.java b/java/src/main/java/com/logicalclocks/hsfs/engine/TrainingDatasetEngine.java index 20bbd50ee8..9a3788353d 100644 --- a/java/src/main/java/com/logicalclocks/hsfs/engine/TrainingDatasetEngine.java +++ b/java/src/main/java/com/logicalclocks/hsfs/engine/TrainingDatasetEngine.java @@ -150,4 +150,10 @@ public String getQuery(TrainingDataset trainingDataset, Storage storage, boolean throws FeatureStoreException, IOException { return trainingDatasetApi.getQuery(trainingDataset, withLabel).getStorageQuery(storage); } + + public void updateStatisticsConfig(TrainingDataset trainingDataset) throws FeatureStoreException, IOException { + TrainingDataset apiTD = trainingDatasetApi.updateMetadata(trainingDataset, "updateStatsConfig"); + trainingDataset.getStatisticsConfig().setCorrelations(apiTD.getStatisticsConfig().getCorrelations()); + trainingDataset.getStatisticsConfig().setHistograms(apiTD.getStatisticsConfig().getHistograms()); + } } diff --git a/java/src/main/java/com/logicalclocks/hsfs/metadata/FeatureGroupApi.java b/java/src/main/java/com/logicalclocks/hsfs/metadata/FeatureGroupApi.java index 702bcc0c33..a28aeb07f1 100644 --- a/java/src/main/java/com/logicalclocks/hsfs/metadata/FeatureGroupApi.java +++ b/java/src/main/java/com/logicalclocks/hsfs/metadata/FeatureGroupApi.java @@ -40,7 +40,7 @@ public class FeatureGroupApi { public static final String FEATURE_GROUP_ROOT_PATH = "/featuregroups"; public static final String FEATURE_GROUP_PATH = FEATURE_GROUP_ROOT_PATH + "{/fgName}{?version}"; - public static final String FEATURE_GROUP_ID_PATH = FEATURE_GROUP_ROOT_PATH + "{/fgId}{?updateStatsSettings," + public static final String FEATURE_GROUP_ID_PATH = FEATURE_GROUP_ROOT_PATH + "{/fgId}{?updateStatsConfig," + "updateMetadata}"; public static final String FEATURE_GROUP_COMMIT_PATH = FEATURE_GROUP_ID_PATH + "/commits{?sort_by,offset,limit}"; diff --git 
a/java/src/main/java/com/logicalclocks/hsfs/metadata/FeatureGroupBase.java b/java/src/main/java/com/logicalclocks/hsfs/metadata/FeatureGroupBase.java index baf8afe421..0479815bb4 100644 --- a/java/src/main/java/com/logicalclocks/hsfs/metadata/FeatureGroupBase.java +++ b/java/src/main/java/com/logicalclocks/hsfs/metadata/FeatureGroupBase.java @@ -17,11 +17,12 @@ package com.logicalclocks.hsfs.metadata; import com.fasterxml.jackson.annotation.JsonIgnore; -import com.fasterxml.jackson.annotation.JsonProperty; import com.logicalclocks.hsfs.EntityEndpointType; import com.logicalclocks.hsfs.Feature; +import com.logicalclocks.hsfs.FeatureGroup; import com.logicalclocks.hsfs.FeatureStore; import com.logicalclocks.hsfs.FeatureStoreException; +import com.logicalclocks.hsfs.StatisticsConfig; import com.logicalclocks.hsfs.constructor.Filter; import com.logicalclocks.hsfs.constructor.FilterLogic; import com.logicalclocks.hsfs.constructor.Query; @@ -77,22 +78,7 @@ public class FeatureGroupBase { @Getter @Setter - @JsonProperty("descStatsEnabled") - protected Boolean statisticsEnabled; - - @Getter - @Setter - @JsonProperty("featHistEnabled") - protected Boolean histograms; - - @Getter - @Setter - @JsonProperty("featCorrEnabled") - protected Boolean correlations; - - @Getter - @Setter - protected List statisticColumns; + protected StatisticsConfig statisticsConfig = new StatisticsConfig(); private FeatureGroupBaseEngine featureGroupBaseEngine = new FeatureGroupBaseEngine(); protected StatisticsEngine statisticsEngine = new StatisticsEngine(EntityEndpointType.FEATURE_GROUP); @@ -236,14 +222,14 @@ public void appendFeatures(Feature features) throws FeatureStoreException, IOExc /** * Update the statistics configuration of the feature group. - * Change the `statisticsEnabled`, `histograms`, `correlations` or `statisticColumns` attributes and persist + * Change the `enabled`, `histograms`, `correlations` or `columns` attributes and persist * the changes by calling this method. 
* * @throws FeatureStoreException * @throws IOException */ public void updateStatisticsConfig() throws FeatureStoreException, IOException { - featureGroupBaseEngine.updateStatisticsConfig(this); + featureGroupBaseEngine.updateStatisticsConfig((FeatureGroup) this); } /** @@ -254,7 +240,7 @@ public void updateStatisticsConfig() throws FeatureStoreException, IOException { * @throws IOException */ public Statistics computeStatistics() throws FeatureStoreException, IOException { - if (statisticsEnabled) { + if (statisticsConfig.getEnabled()) { return statisticsEngine.computeStatistics(this, read()); } else { LOGGER.info("StorageWarning: The statistics are not enabled of feature group `" + name + "`, with version `" diff --git a/java/src/main/java/com/logicalclocks/hsfs/metadata/TrainingDatasetApi.java b/java/src/main/java/com/logicalclocks/hsfs/metadata/TrainingDatasetApi.java index 96eeeabaa6..5f1bc949bd 100644 --- a/java/src/main/java/com/logicalclocks/hsfs/metadata/TrainingDatasetApi.java +++ b/java/src/main/java/com/logicalclocks/hsfs/metadata/TrainingDatasetApi.java @@ -24,17 +24,22 @@ import org.apache.http.HttpHeaders; import org.apache.http.client.methods.HttpGet; import org.apache.http.client.methods.HttpPost; +import org.apache.http.client.methods.HttpPut; import org.apache.http.entity.StringEntity; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import java.io.IOException; +import static com.logicalclocks.hsfs.metadata.HopsworksClient.PROJECT_PATH; + public class TrainingDatasetApi { private static final String TRAINING_DATASETS_PATH = "/trainingdatasets"; private static final String TRAINING_DATASET_PATH = TRAINING_DATASETS_PATH + "{/tdName}{?version}"; private static final String TRAINING_QUERY_PATH = TRAINING_DATASETS_PATH + "{/tdId}/query{?withLabel}"; + public static final String TRAINING_DATASET_ID_PATH = TRAINING_DATASETS_PATH + "{/fgId}{?updateStatsConfig," + + "updateMetadata}"; private static final Logger LOGGER = 
LoggerFactory.getLogger(TrainingDatasetApi.class); @@ -103,4 +108,29 @@ public FsQuery getQuery(TrainingDataset trainingDataset, boolean withLabel) return hopsworksClient.handleRequest(getRequest, FsQuery.class); } + + public TrainingDataset updateMetadata(TrainingDataset trainingDataset, String queryParameter) + throws FeatureStoreException, IOException { + HopsworksClient hopsworksClient = HopsworksClient.getInstance(); + String pathTemplate = PROJECT_PATH + + FeatureStoreApi.FEATURE_STORE_PATH + + TRAINING_DATASET_ID_PATH; + + String uri = UriTemplate.fromTemplate(pathTemplate) + .set("projectId", trainingDataset.getFeatureStore().getProjectId()) + .set("fsId", trainingDataset.getFeatureStore().getId()) + .set("fgId", trainingDataset.getId()) + .set(queryParameter, true) + .expand(); + + String trainingDatasetJson = hopsworksClient.getObjectMapper().writeValueAsString(trainingDataset); + HttpPut putRequest = new HttpPut(uri); + putRequest.setHeader(HttpHeaders.CONTENT_TYPE, "application/json"); + putRequest.setEntity(new StringEntity(trainingDatasetJson)); + + LOGGER.info("Sending metadata request: " + uri); + LOGGER.info(trainingDatasetJson); + + return hopsworksClient.handleRequest(putRequest, TrainingDataset.class); + } } diff --git a/mkdocs.yml b/mkdocs.yml index 03a1b953d1..f7ff7689e4 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -51,7 +51,7 @@ nav: - Feature: generated/feature.md - Training Dataset: generated/training_dataset.md - Query vs. 
Dataframe: generated/query_vs_dataframe.md - # - Statistics: guides/statistics.md + - Statistics: generated/statistics.md # - Data Validation: guides/data_validation.md - API Reference: - Connection: generated/api/connection_api.md @@ -60,6 +60,7 @@ nav: - TrainingDataset: generated/api/training_dataset_api.md - Storage Connector: generated/api/storage_connector_api.md - Feature: generated/api/feature_api.md + - StatisticsConfig: generated/api/statistics_config_api.md - Contributing: CONTRIBUTING.md - Support: https://community.hopsworks.ai/ diff --git a/python/hsfs/core/feature_group_base_engine.py b/python/hsfs/core/feature_group_base_engine.py index 0bd06489c0..ab7544cd92 100644 --- a/python/hsfs/core/feature_group_base_engine.py +++ b/python/hsfs/core/feature_group_base_engine.py @@ -41,3 +41,9 @@ def delete_tag(self, feature_group, name): def get_tags(self, feature_group, name): """Get tag with a certain name or all tags for a feature group.""" return [tag.to_dict() for tag in self._tags_api.get(feature_group, name)] + + def update_statistics_config(self, feature_group): + """Update the statistics configuration of a feature group.""" + self._feature_group_api.update_metadata( + feature_group, feature_group, "updateStatsConfig" + ) diff --git a/python/hsfs/core/feature_group_engine.py b/python/hsfs/core/feature_group_engine.py index 33f0306d1c..8665033025 100644 --- a/python/hsfs/core/feature_group_engine.py +++ b/python/hsfs/core/feature_group_engine.py @@ -153,12 +153,6 @@ def commit_delete(feature_group, delete_df, write_options): ) return hudi_engine_instance.delete_record(delete_df, write_options) - def update_statistics_config(self, feature_group): - """Update the statistics configuration of a feature group.""" - self._feature_group_api.update_metadata( - feature_group, feature_group, "updateStatsSettings" - ) - def _get_table_name(self, feature_group): return ( feature_group.feature_store_name diff --git a/python/hsfs/core/training_dataset_api.py 
b/python/hsfs/core/training_dataset_api.py index adaf52f8a9..2d3d31cba2 100644 --- a/python/hsfs/core/training_dataset_api.py +++ b/python/hsfs/core/training_dataset_api.py @@ -73,7 +73,6 @@ def get_query(self, training_dataset_instance, with_label): def compute(self, training_dataset_instance, td_app_conf): """ Setup a Hopsworks job to compute the query and write the training dataset - Args: training_dataset_instance (training_dataset): the metadata instance of the training dataset app_options ([type]): the configuration for the training dataset job application @@ -95,3 +94,46 @@ def compute(self, training_dataset_instance, td_app_conf): "POST", path_params, headers=headers, data=td_app_conf.json() ) ) + + def update_metadata( + self, training_dataset_instance, training_dataset_copy, query_parameter + ): + """Update the metadata of a training dataset. + + The `query_parameter` controls which information gets updated. The + `training_dataset_copy` is the metadata object sent to the backend, while + `training_dataset_instance` is the user object, which is only updated + after a successful REST call. + + # Arguments + training_dataset_instance: TrainingDataset. User metadata object of the + training dataset. + training_dataset_copy: TrainingDataset. Metadata object of the training + dataset with the information to be updated. + query_parameter: str. Query parameter that will be set to true to + control which information is updated. E.g. "updateMetadata" or + "updateStatsConfig". + + # Returns + TrainingDataset. The updated training dataset metadata object. 
+ """ + _client = client.get_instance() + path_params = [ + "project", + _client._project_id, + "featurestores", + self._feature_store_id, + "trainingdatasets", + training_dataset_instance.id, + ] + headers = {"content-type": "application/json"} + query_params = {query_parameter: True} + return training_dataset_instance.update_from_response_json( + _client._send_request( + "PUT", + path_params, + query_params, + headers=headers, + data=training_dataset_copy.json(), + ), + ) diff --git a/python/hsfs/core/training_dataset_engine.py b/python/hsfs/core/training_dataset_engine.py index 1518635bbc..93f479b49c 100644 --- a/python/hsfs/core/training_dataset_engine.py +++ b/python/hsfs/core/training_dataset_engine.py @@ -107,3 +107,9 @@ def delete_tag(self, training_dataset, name): def get_tags(self, training_dataset, name): """Get tag with a certain name or all tags for a training dataset.""" return [tag.to_dict() for tag in self._tags_api.get(training_dataset, name)] + + def update_statistics_config(self, training_dataset): + """Update the statistics configuration of a feature group.""" + self._training_dataset_api.update_metadata( + training_dataset, training_dataset, "updateStatsConfig" + ) diff --git a/python/hsfs/feature_group.py b/python/hsfs/feature_group.py index 006d600198..09b1e4e65d 100644 --- a/python/hsfs/feature_group.py +++ b/python/hsfs/feature_group.py @@ -226,6 +226,33 @@ def get_feature(self, name: str): f"'FeatureGroup' object has no feature called '{name}'." ) + def update_statistics_config(self): + """Update the statistics configuration of the feature group. + + Change the `statistics_config` object and persist the changes by calling + this method. + + # Returns + `FeatureGroup`. The updated metadata object of the feature group. + + # Raises + `RestAPIError`. + """ + self._feature_group_engine.update_statistics_config(self) + return self + + def update_description(self, description: str): + """Update the description of the feature gorup. 
+ + # Arguments + description: str. New description string. + + # Returns + `FeatureGroup`. The updated feature group object. + """ + self._feature_group_engine.update_description(self, description) + return self + def __getattr__(self, name): try: return self.__getitem__(name) @@ -340,10 +367,6 @@ def __init__( features=None, location=None, jobs=None, - desc_stats_enabled=None, - feat_corr_enabled=None, - feat_hist_enabled=None, - statistic_columns=None, online_enabled=False, time_travel_format=None, statistics_config=None, @@ -372,12 +395,6 @@ def __init__( if id is not None: # initialized by backend - self.statistics_config = StatisticsConfig( - desc_stats_enabled, - feat_corr_enabled, - feat_hist_enabled, - statistic_columns, - ) self._primary_key = [ feat.name for feat in self._features if feat.primary is True ] @@ -393,9 +410,12 @@ def __init__( ][0] else: self._hudi_precombine_key = None + self._statistics_config = StatisticsConfig.from_response_json( + statistics_config + ) + else: # initialized by user - self.statistics_config = statistics_config self._primary_key = primary_key self._partition_key = partition_key self._hudi_precombine_key = ( @@ -404,6 +424,7 @@ def __init__( and time_travel_format.upper() == "HUDI" else None ) + self.statistics_config = statistics_config self._feature_group_engine = feature_group_engine.FeatureGroupEngine( featurestore_id @@ -749,10 +770,7 @@ def to_dict(self): "features": self._features, "featurestoreId": self._feature_store_id, "type": "cachedFeaturegroupDTO", - "descStatsEnabled": self._statistics_config.enabled, - "featHistEnabled": self._statistics_config.histograms, - "featCorrEnabled": self._statistics_config.correlations, - "statisticColumns": self._statistics_config.columns, + "statisticsConfig": self._statistics_config, } @property @@ -882,11 +900,6 @@ def __init__( id=None, features=None, jobs=None, - desc_stats_enabled=None, - feat_corr_enabled=None, - feat_hist_enabled=None, - cluster_analysis_enabled=None, - 
statistic_columns=None, statistics_config=None, ): super().__init__(featurestore_id) @@ -903,10 +916,6 @@ def __init__( self._path = path self._id = id self._jobs = jobs - self._desc_stats_enabled = desc_stats_enabled - self._feat_corr_enabled = feat_corr_enabled - self._feat_hist_enabled = feat_hist_enabled - self._statistic_columns = statistic_columns self._feature_group_engine = ( on_demand_feature_group_engine.OnDemandFeatureGroupEngine(featurestore_id) @@ -919,13 +928,10 @@ def __init__( if features else None ) - - self.statistics_config = StatisticsConfig( - desc_stats_enabled, - feat_corr_enabled, - feat_hist_enabled, - statistic_columns, + self._statistics_config = StatisticsConfig.from_response_json( + statistics_config ) + self._options = ( {option["name"]: option["value"] for option in options} if options @@ -1002,10 +1008,7 @@ def to_dict(self): else None, "storageConnector": self._storage_connector.to_dict(), "type": "onDemandFeaturegroupDTO", - "descStatsEnabled": self._statistics_config.enabled, - "featHistEnabled": self._statistics_config.histograms, - "featCorrEnabled": self._statistics_config.correlations, - "statisticColumns": self._statistics_config.columns, + "statisticsConfig": self._statistics_config, } @property diff --git a/python/hsfs/statistics_config.py b/python/hsfs/statistics_config.py index 90015fc030..db386c1e4f 100644 --- a/python/hsfs/statistics_config.py +++ b/python/hsfs/statistics_config.py @@ -14,14 +14,19 @@ # limitations under the License. 
# +import json +import humps + +from hsfs import util + class StatisticsConfig: def __init__( self, enabled=True, - correlations=None, - histograms=None, - columns=None, + correlations=False, + histograms=False, + columns=[], ): self._enabled = enabled # use setters for input validation @@ -29,8 +34,25 @@ def __init__( self.histograms = histograms self._columns = columns + @classmethod + def from_response_json(cls, json_dict): + json_decamelized = humps.decamelize(json_dict) + return cls(**json_decamelized) + + def json(self): + return json.dumps(self, cls=util.FeatureStoreEncoder) + + def to_dict(self): + return { + "enabled": self._enabled, + "correlations": self._correlations, + "histograms": self._histograms, + "columns": self._columns, + } + @property def enabled(self): + """Enable statistics, by default this computes only descriptive statistics.""" return self._enabled @enabled.setter @@ -39,34 +61,35 @@ def enabled(self, enabled): @property def correlations(self): + """Enable correlations as an additional statistic to be computed for each + feature pair.""" return self._correlations @correlations.setter def correlations(self, correlations): - if correlations and not self._enabled: - # do validation to fail fast, backend implements same logic - raise ValueError( - "Correlations can only be enabled with general statistics enabled. Set `enabled` in config to `True`." - ) self._correlations = correlations @property def histograms(self): + """Enable histograms as an additional statistic to be computed for each + feature.""" return self._histograms @histograms.setter def histograms(self, histograms): - if histograms and not self._enabled: - # do validation to fail fast, backend implements same logic - raise ValueError( - "Histograms can only be enabled with general statistics enabled. Set `enabled` in config to `True`." 
- ) self._histograms = histograms @property def columns(self): + """Specify a subset of columns to compute statistics for.""" return self._columns @columns.setter def columns(self, columns): self._columns = columns + + def __str__(self): + return self.json() + + def __repr__(self): + return f"StatisticsConfig({self._enabled}, {self._correlations}, {self._histograms}, {self._columns})" diff --git a/python/hsfs/training_dataset.py b/python/hsfs/training_dataset.py index c1581e55cc..b4ace9a9d6 100644 --- a/python/hsfs/training_dataset.py +++ b/python/hsfs/training_dataset.py @@ -107,7 +107,9 @@ def __init__( ] self._splits = splits self._training_dataset_type = training_dataset_type - self.statistics_config = None + self._statistics_config = StatisticsConfig.from_response_json( + statistics_config + ) self._label = [feat.name for feat in self._features if feat.label] def save( @@ -299,6 +301,21 @@ def get_tag(self, name=None): """ return self._training_dataset_engine.get_tags(self, name) + def update_statistics_config(self): + """Update the statistics configuration of the training dataset. + + Change the `statistics_config` object and persist the changes by calling + this method. + + # Returns + `TrainingDataset`. The updated metadata object of the training dataset. + + # Raises + `RestAPIError`. + """ + self._training_dataset_engine.update_statistics_config(self) + return self + @classmethod def from_response_json(cls, json_dict): json_decamelized = humps.decamelize(json_dict) @@ -342,6 +359,7 @@ def to_dict(self): "splits": self._splits, "seed": self._seed, "queryDTO": self._querydto.to_dict() if self._querydto else None, + "statisticsConfig": self._statistics_config, } @property