Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement filter for create dataset api #215

Merged
merged 5 commits into from
Jun 14, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 4 additions & 8 deletions core/src/main/java/feast/core/config/TrainingConfig.java
Original file line number Diff line number Diff line change
@@ -1,14 +1,13 @@
package feast.core.config;

import com.google.cloud.bigquery.BigQuery;
import com.google.cloud.bigquery.BigQueryOptions;
import com.google.common.base.Charsets;
import com.google.common.io.CharStreams;
import com.hubspot.jinjava.Jinjava;
import feast.core.config.StorageConfig.StorageSpecs;
import feast.core.dao.FeatureInfoRepository;
import feast.core.training.BigQueryDatasetTemplater;
import feast.core.training.BigQueryTraningDatasetCreator;
import feast.core.util.RandomUuidProvider;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
Expand All @@ -18,9 +17,7 @@
import org.springframework.core.io.ClassPathResource;
import org.springframework.core.io.Resource;

/**
* Configuration related to training API
*/
/** Configuration related to training API */
@Configuration
public class TrainingConfig {

Expand All @@ -37,10 +34,9 @@ public BigQueryDatasetTemplater getBigQueryTrainingDatasetTemplater(
@Bean
public BigQueryTraningDatasetCreator getBigQueryTrainingDatasetCreator(
BigQueryDatasetTemplater templater,
StorageSpecs storageSpecs,
@Value("${feast.core.projectId}") String projectId,
@Value("${feast.core.datasetPrefix}") String datasetPrefix) {
BigQuery bigquery = BigQueryOptions.newBuilder().setProjectId(projectId).build().getService();
return new BigQueryTraningDatasetCreator(templater, projectId, datasetPrefix);
return new BigQueryTraningDatasetCreator(
templater, projectId, datasetPrefix, new RandomUuidProvider());
}
}
3 changes: 2 additions & 1 deletion core/src/main/java/feast/core/grpc/DatasetServiceImpl.java
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,8 @@ public void createDataset(
request.getStartDate(),
request.getEndDate(),
request.getLimit(),
request.getNamePrefix());
request.getNamePrefix(),
request.getFiltersMap());
CreateDatasetResponse response =
CreateDatasetResponse.newBuilder().setDatasetInfo(datasetInfo).build();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,18 +23,18 @@
import feast.core.model.FeatureInfo;
import feast.core.storage.BigQueryStorageManager;
import feast.specs.StorageSpecProto.StorageSpec;
import feast.types.ValueProto.ValueType.Enum;
import java.time.Instant;
import java.time.ZoneId;
import java.time.format.DateTimeFormatter;
import java.time.temporal.ChronoUnit;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.NoSuchElementException;
import java.util.Set;
import java.util.stream.Collectors;
import lombok.Getter;


public class BigQueryDatasetTemplater {

Expand All @@ -45,7 +45,9 @@ public class BigQueryDatasetTemplater {
private final DateTimeFormatter formatter;

public BigQueryDatasetTemplater(
Jinjava jinjava, String templateString, StorageSpec storageSpec,
Jinjava jinjava,
String templateString,
StorageSpec storageSpec,
FeatureInfoRepository featureInfoRepository) {
this.storageSpec = storageSpec;
this.featureInfoRepository = featureInfoRepository;
Expand All @@ -65,35 +67,84 @@ protected StorageSpec getStorageSpec() {
* @param startDate start date
* @param endDate end date
* @param limit limit
* @param filters additional WHERE clause
* @return SQL query for creating training table.
*/
String createQuery(FeatureSet featureSet, Timestamp startDate, Timestamp endDate, long limit) {
String createQuery(
FeatureSet featureSet,
Timestamp startDate,
Timestamp endDate,
long limit,
Map<String, String> filters) {
List<String> featureIds = featureSet.getFeatureIdsList();
List<FeatureInfo> featureInfos = featureInfoRepository.findAllById(featureIds);
String tableId = featureInfos.size() > 0 ? getBqTableId(featureInfos.get(0)) : "";
Features features = new Features(featureInfos, tableId);
List<FeatureInfo> featureInfos = getFeatureInfosOrThrow(featureIds);

// split filter based on ValueType of the feature
Map<String, String> tmpFilter = new HashMap<>(filters);
Map<String, String> numberFilters = new HashMap<>();
Map<String, String> stringFilters = new HashMap<>();
if (filters.containsKey("job_id")) {
stringFilters.put("job_id", tmpFilter.get("job_id"));
tmpFilter.remove("job_id");
}

List<FeatureInfo> featureFilterInfos = getFeatureInfosOrThrow(new ArrayList<>(tmpFilter.keySet()));
Map<String, FeatureInfo> featureInfoMap = new HashMap<>();
for (FeatureInfo featureInfo: featureFilterInfos) {
featureInfoMap.put(featureInfo.getId(), featureInfo);
}


for (Map.Entry<String, String> filter : tmpFilter.entrySet()) {
FeatureInfo featureInfo = featureInfoMap.get(filter.getKey());
if (isMappableToString(featureInfo.getValueType())) {
stringFilters.put(featureInfo.getName(), filter.getValue());
} else {
numberFilters.put(featureInfo.getName(), filter.getValue());
}
}

List<String> featureNames = getFeatureNames(featureInfos);
String tableId = getBqTableId(featureInfos.get(0));
String startDateStr = formatDateString(startDate);
String endDateStr = formatDateString(endDate);
String limitStr = (limit != 0) ? String.valueOf(limit) : null;
return renderTemplate(tableId, featureNames, startDateStr, endDateStr, limitStr,
numberFilters, stringFilters);
}

private boolean isMappableToString(Enum valueType) {
return valueType.equals(Enum.STRING);
}

private List<String> getFeatureNames(List<FeatureInfo> featureInfos) {
return featureInfos.stream().map(FeatureInfo::getName).collect(Collectors.toList());
}

private List<FeatureInfo> getFeatureInfosOrThrow(List<String> featureIds) {
List<FeatureInfo> featureInfos = featureInfoRepository.findAllById(featureIds);
if (featureInfos.size() < featureIds.size()) {
Set<String> foundFeatureIds =
featureInfos.stream().map(FeatureInfo::getId).collect(Collectors.toSet());
featureIds.removeAll(foundFeatureIds);
throw new NoSuchElementException("features not found: " + featureIds);
}

String startDateStr = formatDateString(startDate);
String endDateStr = formatDateString(endDate);
String limitStr = (limit != 0) ? String.valueOf(limit) : null;
return renderTemplate(features, startDateStr, endDateStr, limitStr);
return featureInfos;
}

private String renderTemplate(
Features features, String startDateStr, String endDateStr, String limitStr) {
String tableId, List<String> features, String startDateStr, String endDateStr, String limitStr,
Map<String, String> numberFilters,
Map<String, String> stringFilters) {
Map<String, Object> context = new HashMap<>();

context.put("feature_set", features);
context.put("table_id", tableId);
context.put("features", features);
context.put("start_date", startDateStr);
context.put("end_date", endDateStr);
context.put("limit", limitStr);
context.put("number_filters", numberFilters);
context.put("string_filters", stringFilters);
return jinjava.render(template, context);
}

Expand All @@ -117,16 +168,4 @@ private String formatDateString(Timestamp timestamp) {
Instant instant = Instant.ofEpochSecond(timestamp.getSeconds()).truncatedTo(ChronoUnit.DAYS);
return formatter.format(instant);
}

@Getter
static final class Features {

final List<String> columns;
final String tableId;

Features(List<FeatureInfo> featureInfos, String tableId) {
columns = featureInfos.stream().map(FeatureInfo::getName).collect(Collectors.toList());
this.tableId = tableId;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -30,15 +30,10 @@
import feast.core.DatasetServiceProto.DatasetInfo;
import feast.core.DatasetServiceProto.FeatureSet;
import feast.core.exception.TrainingDatasetCreationException;
import java.math.BigInteger;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import feast.core.util.UuidProvider;
import java.time.Instant;
import java.time.ZoneId;
import java.time.format.DateTimeFormatter;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import lombok.extern.slf4j.Slf4j;

Expand All @@ -49,26 +44,34 @@ public class BigQueryTraningDatasetCreator {
private final DateTimeFormatter formatter;
private final String projectId;
private final String datasetPrefix;
private final UuidProvider uuidProvider;
private transient BigQuery bigQuery;

public BigQueryTraningDatasetCreator(
BigQueryDatasetTemplater templater,
String projectId,
String datasetPrefix) {
this(templater, projectId, datasetPrefix,
String datasetPrefix,
UuidProvider uuidProvider) {
this(
templater,
projectId,
datasetPrefix,
uuidProvider,
BigQueryOptions.newBuilder().setProjectId(projectId).build().getService());
}

public BigQueryTraningDatasetCreator(
BigQueryDatasetTemplater templater,
String projectId,
String datasetPrefix,
UuidProvider uuidProvider,
BigQuery bigQuery) {
this.templater = templater;
this.formatter = DateTimeFormatter.ofPattern("yyyyMMdd").withZone(ZoneId.of("UTC"));
this.projectId = projectId;
this.datasetPrefix = datasetPrefix;
this.bigQuery = bigQuery;
this.uuidProvider = uuidProvider;
}

/**
Expand All @@ -80,18 +83,19 @@ public BigQueryTraningDatasetCreator(
* @param endDate end date of the training dataset (inclusive)
* @param limit maximum number of row should be created.
* @param namePrefix prefix for dataset name
* @param filters additional where clause
* @return dataset info associated with the created training dataset
*/
public DatasetInfo createDataset(
FeatureSet featureSet,
Timestamp startDate,
Timestamp endDate,
long limit,
String namePrefix) {
String namePrefix,
Map<String, String> filters) {
try {
String query = templater.createQuery(featureSet, startDate, endDate, limit);
String tableName =
createBqTableName(datasetPrefix, featureSet, startDate, endDate, namePrefix);
String query = templater.createQuery(featureSet, startDate, endDate, limit, filters);
String tableName = createBqTableName(datasetPrefix, featureSet, namePrefix);
String tableDescription = createBqTableDescription(featureSet, startDate, endDate, query);

Map<String, String> options = templater.getStorageSpec().getOptionsMap();
Expand Down Expand Up @@ -124,47 +128,22 @@ public DatasetInfo createDataset(
throw new TrainingDatasetCreationException("Failed creating training dataset", e);
} catch (InterruptedException e) {
log.error("Training dataset creation was interrupted", e);
throw new TrainingDatasetCreationException("Training dataset creation was interrupted",
e);
throw new TrainingDatasetCreationException("Training dataset creation was interrupted", e);
}
}

private String createBqTableName(
String datasetPrefix,
FeatureSet featureSet,
Timestamp startDate,
Timestamp endDate,
String namePrefix) {

List<String> features = new ArrayList(featureSet.getFeatureIdsList());
Collections.sort(features);
private String createBqTableName(String datasetPrefix, FeatureSet featureSet, String namePrefix) {

String datasetId = String.format("%s_%s_%s", features, startDate, endDate);
StringBuilder hashText;

// create hash from datasetId
try {
MessageDigest md = MessageDigest.getInstance("SHA-1");
byte[] messageDigest = md.digest(datasetId.getBytes());
BigInteger no = new BigInteger(1, messageDigest);
hashText = new StringBuilder(no.toString(16));
while (hashText.length() < 32) {
hashText.insert(0, "0");
}
} catch (NoSuchAlgorithmException e) {
throw new RuntimeException(e);
}
String suffix = uuidProvider.getUuid();

if (!Strings.isNullOrEmpty(namePrefix)) {
// only alphanumeric and underscore are allowed
namePrefix = namePrefix.replaceAll("[^a-zA-Z0-9_]", "_");
return String.format(
"%s_%s_%s_%s", datasetPrefix, featureSet.getEntityName(), namePrefix,
hashText.toString());
"%s_%s_%s_%s", datasetPrefix, featureSet.getEntityName(), namePrefix, suffix);
}

return String.format(
"%s_%s_%s", datasetPrefix, featureSet.getEntityName(), hashText.toString());
return String.format("%s_%s_%s", datasetPrefix, featureSet.getEntityName(), suffix);
}

private String createBqTableDescription(
Expand Down
10 changes: 10 additions & 0 deletions core/src/main/java/feast/core/util/RandomUuidProvider.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
package feast.core.util;

import java.util.UUID;

/**
 * {@link UuidProvider} implementation backed by {@link UUID#randomUUID()}.
 *
 * <p>Returns a random (version 4) UUID rendered as 32 lowercase hexadecimal
 * characters, i.e. the canonical UUID string with the dashes stripped.
 */
public class RandomUuidProvider implements UuidProvider {

  @Override
  public String getUuid() {
    // Canonical form is 36 chars incl. 4 dashes; strip them for a BigQuery-safe suffix.
    String canonical = UUID.randomUUID().toString();
    return canonical.replace("-", "");
  }
}
5 changes: 5 additions & 0 deletions core/src/main/java/feast/core/util/UuidProvider.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
package feast.core.util;

/**
 * Supplier of UUID strings.
 *
 * <p>Abstracted behind an interface so that production code can use a random
 * implementation while tests inject a deterministic one.
 */
@FunctionalInterface
public interface UuidProvider {

  /**
   * Returns a UUID string.
   *
   * @return a UUID rendered as a string; never {@code null}
   */
  String getUuid();
}
8 changes: 5 additions & 3 deletions core/src/main/resources/templates/bq_training.tmpl
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
SELECT
id,
event_timestamp{%- if feature_set.columns | length > 0 %},{%- endif %}
{{ feature_set.columns | join(',') }}
event_timestamp{%- if features | length > 0 %},{%- endif %}
{{ features | join(',') }}
FROM
`{{ feature_set.tableId }}`
`{{ table_id }}`
WHERE event_timestamp >= TIMESTAMP("{{ start_date }}") AND event_timestamp <= TIMESTAMP(DATETIME_ADD("{{ end_date }}", INTERVAL 1 DAY))
{%- for key, val in number_filters.items() %} AND {{ key }} = {{ val }} {%- endfor %}
{%- for key, val in string_filters.items() %} AND {{ key }} = "{{ val }}" {%- endfor %}
{% if limit is not none -%}
LIMIT {{ limit }}
{%- endif %}
Loading