Skip to content

Commit

Permalink
Allow users to compute statistics over retrieved batch datasets (#799)
Browse files Browse the repository at this point in the history
* Refactor bigquery stats, add functionality to compute statistics over retrieved batch datasets

* Rename dataset

* Add documentation

* Fix end to end tests

* Apply spotless

* Avoid comparing histograms
  • Loading branch information
zhilingc authored Jun 23, 2020
1 parent a1207b5 commit 8c2201c
Show file tree
Hide file tree
Showing 25 changed files with 630 additions and 400 deletions.
21 changes: 11 additions & 10 deletions core/src/main/java/feast/core/service/StatsService.java
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
import feast.proto.core.CoreServiceProto.GetFeatureStatisticsResponse;
import feast.proto.core.StoreProto;
import feast.proto.core.StoreProto.Store.StoreType;
import feast.storage.api.statistics.FeatureSetStatistics;
import feast.storage.api.statistics.FeatureStatistics;
import feast.storage.api.statistics.StatisticsRetriever;
import feast.storage.connectors.bigquery.statistics.BigQueryStatisticsRetriever;
import java.io.IOException;
Expand Down Expand Up @@ -200,7 +200,7 @@ private List<FeatureNameStatistics> getFeatureNameStatisticsByDataset(
// Else, add to the list of features we still need to retrieve statistics for.
for (String featureName : features) {
Feature feature = featureNameToFeature.get(featureName);
Optional<FeatureStatistics> cachedFeatureStatistics = Optional.empty();
Optional<feast.core.model.FeatureStatistics> cachedFeatureStatistics = Optional.empty();
if (!forceRefresh) {
cachedFeatureStatistics =
featureStatisticsRepository.findFeatureStatisticsByFeatureAndDatasetId(
Expand All @@ -216,7 +216,7 @@ private List<FeatureNameStatistics> getFeatureNameStatisticsByDataset(
// Retrieve the balance of statistics after checking the cache, and add it to the
// list of FeatureNameStatistics.
if (featuresMissingStats.size() > 0) {
FeatureSetStatistics featureSetStatistics =
FeatureStatistics featureSetStatistics =
statisticsRetriever.getFeatureStatistics(
featureSet.toProto().getSpec(), featuresMissingStats, datasetId);

Expand All @@ -226,9 +226,9 @@ private List<FeatureNameStatistics> getFeatureNameStatisticsByDataset(
continue;
}
Feature feature = featureNameToFeature.get(stat.getName());
FeatureStatistics featureStatistics =
FeatureStatistics.createForDataset(feature, stat, datasetId);
Optional<FeatureStatistics> existingRecord =
feast.core.model.FeatureStatistics featureStatistics =
feast.core.model.FeatureStatistics.createForDataset(feature, stat, datasetId);
Optional<feast.core.model.FeatureStatistics> existingRecord =
featureStatisticsRepository.findFeatureStatisticsByFeatureAndDatasetId(
featureStatistics.getFeature(), datasetId);
existingRecord.ifPresent(statistics -> featureStatistics.setId(statistics.getId()));
Expand Down Expand Up @@ -270,7 +270,7 @@ private List<FeatureNameStatistics> getFeatureNameStatisticsByDate(
// Else, add to the list of features we still need to retrieve statistics for.
for (String featureName : features) {
Feature feature = featureNameToFeature.get(featureName);
Optional<FeatureStatistics> cachedFeatureStatistics = Optional.empty();
Optional<feast.core.model.FeatureStatistics> cachedFeatureStatistics = Optional.empty();
if (!forceRefresh) {
cachedFeatureStatistics =
featureStatisticsRepository.findFeatureStatisticsByFeatureAndDate(feature, date);
Expand All @@ -285,7 +285,7 @@ private List<FeatureNameStatistics> getFeatureNameStatisticsByDate(
// Retrieve the balance of statistics after checking the cache, and add it to the
// list of FeatureNameStatistics.
if (featuresMissingStats.size() > 0) {
FeatureSetStatistics featureSetStatistics =
FeatureStatistics featureSetStatistics =
statisticsRetriever.getFeatureStatistics(
featureSet.toProto().getSpec(),
featuresMissingStats,
Expand All @@ -297,8 +297,9 @@ private List<FeatureNameStatistics> getFeatureNameStatisticsByDate(
continue;
}
Feature feature = featureNameToFeature.get(stat.getName());
FeatureStatistics featureStatistics = FeatureStatistics.createForDate(feature, stat, date);
Optional<FeatureStatistics> existingRecord =
feast.core.model.FeatureStatistics featureStatistics =
feast.core.model.FeatureStatistics.createForDate(feature, stat, date);
Optional<feast.core.model.FeatureStatistics> existingRecord =
featureStatisticsRepository.findFeatureStatisticsByFeatureAndDate(
featureStatistics.getFeature(), date);
existingRecord.ifPresent(statistics -> featureStatistics.setId(statistics.getId()));
Expand Down
32 changes: 32 additions & 0 deletions docs/user-guide/feature-retrieval.md
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,38 @@ Feast can retrieve features from any amount of feature sets, as long as they occ

Point-in-time-correct joins also prevent feature leakage by trying to accurately reconstruct the state of the world at a single point in time, instead of just joining features based on the nearest timestamps.

### **Computing statistics over retrieved data**

Feast is able to compute [TFDV](https://www.tensorflow.org/tfx/tutorials/data_validation/tfdv_basic)-compatible statistics over data retrieved from historical stores. The statistics can be used in conjunction with feature schemas and TFDV to verify the integrity of your retrieved dataset, or with [Facets](https://github.com/PAIR-code/facets) to visualize the distributions.

The computation of statistics is not enabled by default. To indicate to Feast that the statistics are to be computed for a given historical retrieval request, pass `compute_statistics=True` to `get_batch_features`.

```python
dataset = client.get_batch_features(
    feature_refs=features,
    entity_rows=entity_df,
    compute_statistics=True,
)

stats = dataset.statistics()
```

If a schema is already defined over the feature sets in question, TFDV can be used to detect anomalies over the dataset.

```python
# Build combined schema over retrieved dataset
schema = schema_pb2.Schema()
for feature_set in feature_sets:
fs_schema = feature_set.export_tfx_schema()
for feature_schema in fs_schema.feature:
if feature_schema.name in features:
schema.feature.append(feature_schema)

# detect anomalies
anomalies = tfdv.validate_statistics(statistics=stats, schema=schema)
```


## Online feature retrieval

Online feature retrieval works in much the same way as batch retrieval, with one important distinction: Online stores only maintain the current state of features. No historical data is served.
Expand Down
25 changes: 16 additions & 9 deletions protos/feast/serving/ServingService.proto
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ package feast.serving;

import "google/protobuf/timestamp.proto";
import "feast/types/Value.proto";
import "tensorflow_metadata/proto/v0/statistics.proto";

option java_package = "feast.proto.serving";
option java_outer_classname = "ServingAPIProto";
Expand Down Expand Up @@ -100,6 +101,18 @@ message GetOnlineFeaturesRequest {
}
}

// Request for a batch (historical) feature retrieval job.
// NOTE(review): field number 1 is not used in this message — presumably retired;
// confirm it is marked `reserved` so it cannot be accidentally reused.
message GetBatchFeaturesRequest {
  // List of features that are being retrieved
  repeated FeatureReference features = 3;

  // Source of the entity dataset containing the timestamps and entity keys to retrieve
  // features for.
  DatasetSource dataset_source = 2;

  // Compute statistics for the dataset retrieved
  bool compute_statistics = 4;
}

message GetOnlineFeaturesResponse {
// Feature values retrieved from feast.
repeated FieldValues field_values = 1;
Expand Down Expand Up @@ -134,15 +147,6 @@ message GetOnlineFeaturesResponse {
}
}

message GetBatchFeaturesRequest {
// List of features that are being retrieved
repeated FeatureReference features = 3;

// Source of the entity dataset containing the timestamps and entity keys to retrieve
// features for.
DatasetSource dataset_source = 2;
}

message GetBatchFeaturesResponse {
Job job = 1;
}
Expand Down Expand Up @@ -196,6 +200,9 @@ message Job {
// Output only. The data format for all the files.
// For CSV format, the files contain both feature values and a column header.
DataFormat data_format = 6;
// Output only. The statistics computed over
// the retrieved dataset. Only available for BigQuery stores.
tensorflow.metadata.v0.DatasetFeatureStatisticsList dataset_feature_statistics_list = 7;
}

message DatasetSource {
Expand Down
8 changes: 6 additions & 2 deletions sdk/python/feast/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,10 +51,10 @@
GetFeatureSetRequest,
GetFeatureSetResponse,
GetFeatureStatisticsRequest,
ListFeaturesRequest,
ListFeaturesResponse,
ListFeatureSetsRequest,
ListFeatureSetsResponse,
ListFeaturesRequest,
ListFeaturesResponse,
ListIngestionJobsRequest,
ListProjectsRequest,
ListProjectsResponse,
Expand Down Expand Up @@ -561,6 +561,7 @@ def get_batch_features(
self,
feature_refs: List[str],
entity_rows: Union[pd.DataFrame, str],
compute_statistics: bool = False,
project: str = None,
) -> RetrievalJob:
"""
Expand All @@ -577,6 +578,8 @@ def get_batch_features(
Each entity in a feature set must be present as a column in this
dataframe. The datetime column must contain timestamps in
datetime64 format.
compute_statistics (bool):
Indicates whether Feast should compute statistics over the retrieved dataset.
project: Specifies the project which contain the FeatureSets
which the requested features belong to.
Expand Down Expand Up @@ -656,6 +659,7 @@ def get_batch_features(
file_uris=staged_files, data_format=DataFormat.DATA_FORMAT_AVRO
)
),
compute_statistics=compute_statistics,
)

# Retrieve Feast Job object to manage life cycle of retrieval
Expand Down
21 changes: 21 additions & 0 deletions sdk/python/feast/job.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from feast.serving.ServingService_pb2_grpc import ServingServiceStub
from feast.source import Source
from feast.wait import wait_retry_backoff
from tensorflow_metadata.proto.v0 import statistics_pb2


class RetrievalJob:
Expand Down Expand Up @@ -193,6 +194,26 @@ def to_chunked_dataframe(
def __iter__(self):
return iter(self.result())

def statistics(
    self, timeout_sec: int = int(defaults[CONFIG_TIMEOUT_KEY])
) -> statistics_pb2.DatasetFeatureStatisticsList:
    """
    Get statistics computed over the retrieved data set. Statistics will only be
    computed for columns that are part of Feast, and not the columns that were
    provided by the caller.

    Blocks until the retrieval job completes (via ``get_avro_files``) before
    reading statistics off the job proto.

    Args:
        timeout_sec (int):
            Maximum number of seconds to wait until the job is done. If
            "timeout_sec" is exceeded, an exception will be raised.

    Returns:
        DatasetFeatureStatisticsList containing statistics of Feast features
        over the retrieved dataset.

    Raises:
        Exception: If the retrieval job completed with an error.
    """
    self.get_avro_files(timeout_sec)  # wait for job completion
    if self.job_proto.error:
        raise Exception(self.job_proto.error)
    return self.job_proto.dataset_feature_statistics_list


class IngestJob:
"""
Expand Down
2 changes: 1 addition & 1 deletion sdk/python/tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,8 @@
from feast.core.CoreService_pb2 import (
GetFeastCoreVersionResponse,
GetFeatureSetResponse,
ListFeaturesResponse,
ListFeatureSetsResponse,
ListFeaturesResponse,
ListIngestionJobsResponse,
)
from feast.core.FeatureSet_pb2 import EntitySpec as EntitySpecProto
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,10 @@ public GetBatchFeaturesResponse getBatchFeatures(GetBatchFeaturesRequest getFeat
public void run() {
HistoricalRetrievalResult result =
retriever.getHistoricalFeatures(
retrievalId, getFeaturesRequest.getDatasetSource(), featureSetRequests);
retrievalId,
getFeaturesRequest.getDatasetSource(),
featureSetRequests,
getFeaturesRequest.getComputeStatistics());
jobService.upsert(resultToJob(result));
}
});
Expand Down Expand Up @@ -111,9 +114,11 @@ private Job resultToJob(HistoricalRetrievalResult result) {
if (result.hasError()) {
return builder.setError(result.getError()).build();
}
return builder
.addAllFileUris(result.getFileUris())
.setDataFormat(result.getDataFormat())
.build();
Builder jobBuilder =
builder.addAllFileUris(result.getFileUris()).setDataFormat(result.getDataFormat());
if (result.getStats() != null) {
jobBuilder.setDatasetFeatureStatisticsList(result.getStats());
}
return builder.build();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import java.io.Serializable;
import java.util.List;
import javax.annotation.Nullable;
import org.tensorflow.metadata.v0.DatasetFeatureStatisticsList;

/** Result of a historical feature retrieval request. */
@AutoValue
Expand All @@ -40,6 +41,9 @@ public abstract class HistoricalRetrievalResult implements Serializable {
@Nullable
public abstract DataFormat getDataFormat();

@Nullable
public abstract DatasetFeatureStatisticsList getStats();

/**
* Instantiates a {@link HistoricalRetrievalResult} indicating that the retrieval was a failure,
* together with its associated error.
Expand Down Expand Up @@ -75,10 +79,29 @@ public static HistoricalRetrievalResult success(
.build();
}

/**
 * Returns a copy of this result with the given statistics attached; all other fields are
 * carried over unchanged.
 *
 * @param stats {@link DatasetFeatureStatisticsList} for the retrieved dataset
 * @return new {@link HistoricalRetrievalResult} carrying the statistics
 */
public HistoricalRetrievalResult withStats(DatasetFeatureStatisticsList stats) {
  return toBuilder().setStats(stats).build();
}

// Entry point for constructing results via the generated AutoValue builder.
static Builder newBuilder() {
  return new AutoValue_HistoricalRetrievalResult.Builder();
}

/**
 * Creates a {@link Builder} pre-populated with every field of this result so callers can
 * derive a modified copy.
 *
 * @return {@link Builder} carrying this result's current state
 */
Builder toBuilder() {
  return newBuilder()
      .setId(getId())
      .setStatus(getStatus())
      .setFileUris(getFileUris())
      .setError(getError())
      .setDataFormat(getDataFormat())
      // Copy stats as well: omitting it would silently drop statistics on any
      // round-trip through toBuilder(). getStats() is @Nullable, so the AutoValue
      // setter accepts null when no statistics were computed.
      .setStats(getStats());
}

@AutoValue.Builder
abstract static class Builder {
abstract Builder setId(String id);
Expand All @@ -91,6 +114,8 @@ abstract static class Builder {

abstract Builder setDataFormat(DataFormat dataFormat);

abstract Builder setStats(DatasetFeatureStatisticsList stats);

abstract HistoricalRetrievalResult build();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,13 @@ public interface HistoricalRetriever {
* entity columns.
* @param featureSetRequests List of {@link FeatureSetRequest} to feature references in the
* request tied to that feature set.
* @param computeStatistics whether to compute statistics over the resultant dataset.
* @return {@link HistoricalRetrievalResult} if successful, contains the location of the results,
* else contains the error to be returned to the user.
*/
HistoricalRetrievalResult getHistoricalFeatures(
String retrievalId, DatasetSource datasetSource, List<FeatureSetRequest> featureSetRequests);
String retrievalId,
DatasetSource datasetSource,
List<FeatureSetRequest> featureSetRequests,
boolean computeStatistics);
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,16 +20,16 @@
import com.google.common.collect.ImmutableList;
import org.tensorflow.metadata.v0.FeatureNameStatistics;

/** Feature statistics for a feature set over a bounded set of data. */
/** Feature statistics over a bounded set of data. */
@AutoValue
public abstract class FeatureSetStatistics {
public abstract class FeatureStatistics {

public abstract long getNumExamples();

public abstract ImmutableList<FeatureNameStatistics> getFeatureNameStatistics();

public static Builder newBuilder() {
return new AutoValue_FeatureSetStatistics.Builder();
return new AutoValue_FeatureStatistics.Builder();
}

@AutoValue.Builder
Expand All @@ -43,6 +43,6 @@ public Builder addFeatureNameStatistics(FeatureNameStatistics featureNameStatist
return this;
}

public abstract FeatureSetStatistics build();
public abstract FeatureStatistics build();
}
}
Loading

0 comments on commit 8c2201c

Please sign in to comment.