Skip to content

Commit

Permalink
fixing some ml tests
Browse files Browse the repository at this point in the history
  • Loading branch information
masseyke committed Dec 6, 2023
1 parent c3a783f commit c1ce720
Show file tree
Hide file tree
Showing 3 changed files with 273 additions and 263 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -227,56 +227,57 @@ private static void createIndex(String index, boolean isDatastream) {
}

private static void indexData(String sourceIndex, int numTrainingRows, int numNonTrainingRows, String dependentVariable) {
BulkRequestBuilder bulkRequestBuilder = client().prepareBulk().setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);
for (int i = 0; i < numTrainingRows; i++) {
List<Object> source = List.of(
"@timestamp",
"2020-12-12",
BOOLEAN_FIELD,
BOOLEAN_FIELD_VALUES.get(i % BOOLEAN_FIELD_VALUES.size()),
NUMERICAL_FIELD,
NUMERICAL_FIELD_VALUES.get(i % NUMERICAL_FIELD_VALUES.size()),
DISCRETE_NUMERICAL_FIELD,
DISCRETE_NUMERICAL_FIELD_VALUES.get(i % DISCRETE_NUMERICAL_FIELD_VALUES.size()),
TEXT_FIELD,
KEYWORD_FIELD_VALUES.get(i % KEYWORD_FIELD_VALUES.size()),
KEYWORD_FIELD,
KEYWORD_FIELD_VALUES.get(i % KEYWORD_FIELD_VALUES.size()),
NESTED_FIELD,
KEYWORD_FIELD_VALUES.get(i % KEYWORD_FIELD_VALUES.size())
);
IndexRequest indexRequest = new IndexRequest(sourceIndex).source(source.toArray()).opType(DocWriteRequest.OpType.CREATE);
bulkRequestBuilder.add(indexRequest);
}
for (int i = numTrainingRows; i < numTrainingRows + numNonTrainingRows; i++) {
List<Object> source = new ArrayList<>();
if (BOOLEAN_FIELD.equals(dependentVariable) == false) {
source.addAll(List.of(BOOLEAN_FIELD, BOOLEAN_FIELD_VALUES.get(i % BOOLEAN_FIELD_VALUES.size())));
}
if (NUMERICAL_FIELD.equals(dependentVariable) == false) {
source.addAll(List.of(NUMERICAL_FIELD, NUMERICAL_FIELD_VALUES.get(i % NUMERICAL_FIELD_VALUES.size())));
}
if (DISCRETE_NUMERICAL_FIELD.equals(dependentVariable) == false) {
source.addAll(
List.of(DISCRETE_NUMERICAL_FIELD, DISCRETE_NUMERICAL_FIELD_VALUES.get(i % DISCRETE_NUMERICAL_FIELD_VALUES.size()))
try (BulkRequestBuilder bulkRequestBuilder = client().prepareBulk().setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE)) {
for (int i = 0; i < numTrainingRows; i++) {
List<Object> source = List.of(
"@timestamp",
"2020-12-12",
BOOLEAN_FIELD,
BOOLEAN_FIELD_VALUES.get(i % BOOLEAN_FIELD_VALUES.size()),
NUMERICAL_FIELD,
NUMERICAL_FIELD_VALUES.get(i % NUMERICAL_FIELD_VALUES.size()),
DISCRETE_NUMERICAL_FIELD,
DISCRETE_NUMERICAL_FIELD_VALUES.get(i % DISCRETE_NUMERICAL_FIELD_VALUES.size()),
TEXT_FIELD,
KEYWORD_FIELD_VALUES.get(i % KEYWORD_FIELD_VALUES.size()),
KEYWORD_FIELD,
KEYWORD_FIELD_VALUES.get(i % KEYWORD_FIELD_VALUES.size()),
NESTED_FIELD,
KEYWORD_FIELD_VALUES.get(i % KEYWORD_FIELD_VALUES.size())
);
IndexRequest indexRequest = new IndexRequest(sourceIndex).source(source.toArray()).opType(DocWriteRequest.OpType.CREATE);
bulkRequestBuilder.add(indexRequest);
}
if (TEXT_FIELD.equals(dependentVariable) == false) {
source.addAll(List.of(TEXT_FIELD, KEYWORD_FIELD_VALUES.get(i % KEYWORD_FIELD_VALUES.size())));
}
if (KEYWORD_FIELD.equals(dependentVariable) == false) {
source.addAll(List.of(KEYWORD_FIELD, KEYWORD_FIELD_VALUES.get(i % KEYWORD_FIELD_VALUES.size())));
for (int i = numTrainingRows; i < numTrainingRows + numNonTrainingRows; i++) {
List<Object> source = new ArrayList<>();
if (BOOLEAN_FIELD.equals(dependentVariable) == false) {
source.addAll(List.of(BOOLEAN_FIELD, BOOLEAN_FIELD_VALUES.get(i % BOOLEAN_FIELD_VALUES.size())));
}
if (NUMERICAL_FIELD.equals(dependentVariable) == false) {
source.addAll(List.of(NUMERICAL_FIELD, NUMERICAL_FIELD_VALUES.get(i % NUMERICAL_FIELD_VALUES.size())));
}
if (DISCRETE_NUMERICAL_FIELD.equals(dependentVariable) == false) {
source.addAll(
List.of(DISCRETE_NUMERICAL_FIELD, DISCRETE_NUMERICAL_FIELD_VALUES.get(i % DISCRETE_NUMERICAL_FIELD_VALUES.size()))
);
}
if (TEXT_FIELD.equals(dependentVariable) == false) {
source.addAll(List.of(TEXT_FIELD, KEYWORD_FIELD_VALUES.get(i % KEYWORD_FIELD_VALUES.size())));
}
if (KEYWORD_FIELD.equals(dependentVariable) == false) {
source.addAll(List.of(KEYWORD_FIELD, KEYWORD_FIELD_VALUES.get(i % KEYWORD_FIELD_VALUES.size())));
}
if (NESTED_FIELD.equals(dependentVariable) == false) {
source.addAll(List.of(NESTED_FIELD, KEYWORD_FIELD_VALUES.get(i % KEYWORD_FIELD_VALUES.size())));
}
source.addAll(List.of("@timestamp", "2020-12-12"));
IndexRequest indexRequest = new IndexRequest(sourceIndex).source(source.toArray()).opType(DocWriteRequest.OpType.CREATE);
bulkRequestBuilder.add(indexRequest);
}
if (NESTED_FIELD.equals(dependentVariable) == false) {
source.addAll(List.of(NESTED_FIELD, KEYWORD_FIELD_VALUES.get(i % KEYWORD_FIELD_VALUES.size())));
BulkResponse bulkResponse = bulkRequestBuilder.get();
if (bulkResponse.hasFailures()) {
fail("Failed to index data: " + bulkResponse.buildFailureMessage());
}
source.addAll(List.of("@timestamp", "2020-12-12"));
IndexRequest indexRequest = new IndexRequest(sourceIndex).source(source.toArray()).opType(DocWriteRequest.OpType.CREATE);
bulkRequestBuilder.add(indexRequest);
}
BulkResponse bulkResponse = bulkRequestBuilder.get();
if (bulkResponse.hasFailures()) {
fail("Failed to index data: " + bulkResponse.buildFailureMessage());
}
}

Expand Down
Loading

0 comments on commit c1ce720

Please sign in to comment.