Skip to content

Commit

Permalink
add udaf PatternMatch DTWMatch
Browse files Browse the repository at this point in the history
  • Loading branch information
CritasWang committed Nov 29, 2024
1 parent a2d0117 commit e0b6855
Show file tree
Hide file tree
Showing 20 changed files with 11,993 additions and 0 deletions.

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

package org.apache.iotdb.library.match;

public class PatternMatchConfig {

// Minimum height (in percentage, related to the entire section) that is needed to create a new
// section
// This is to avoid that very small sections that are similar to a horizontal line create many
// sections
// Without it the algorithm that divides the sequence in sections would be too much sensible to
// noises
// was 0.01 for tests
public static final double DIVIDE_SECTION_MIN_HEIGHT_DATA = 0.01;

// 0.01 for tests, 0.1 for 1NN
public static final double DIVIDE_SECTION_MIN_HEIGHT_QUERY = 0.01;

public static final int MAX_REGEX_IT = 25;

// query compatibility
/**
* if the number of sections is N, and the number of sections with a different sign is D. The
* algorithm consider the two subsequences as incompatible if D/N > 0.5
*/
public static final double QUERY_SIGN_MAXIMUM_TOLERABLE_DIFFERENT_SIGN_SECTIONS = 0.5;

/**
* keep only one (best) match if the same area is selected in different smooth iterations with not
* experiments it's better a false, so every smooth iteration has a not match, so they are easier
* to view
*/
public static final boolean REMOVE_EQUAL_MATCHES = false;

/** true for tests, false for 1NN */
public static final boolean CHECK_QUERY_COMPATIBILITY = true;

/** the first and last sections are cut to have a good fit */
public static final boolean START_END_CUT_IN_SUBPARTS = true;

/**
* the first and last sections are cut as well in the results, or are returned highlighting the
* whole section. false in tests
*/
public static final boolean START_END_CUT_IN_SUBPARTS_IN_RESULTS = true;

/** true for tests */
public static final boolean RESCALING_Y = true;

public static final int VALUE_DIFFERENCE_WEIGHT = 1;
public static final int RESCALING_COST_WEIGHT = 1;
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,201 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

package org.apache.iotdb.library.match;

import org.apache.iotdb.library.match.model.DTWMatchResult;
import org.apache.iotdb.library.match.model.DTWState;
import org.apache.iotdb.udf.api.State;
import org.apache.iotdb.udf.api.UDAF;
import org.apache.iotdb.udf.api.customizer.config.UDAFConfigurations;
import org.apache.iotdb.udf.api.customizer.parameter.UDFParameterValidator;
import org.apache.iotdb.udf.api.customizer.parameter.UDFParameters;
import org.apache.iotdb.udf.api.type.Type;
import org.apache.iotdb.udf.api.utils.ResultValue;

import org.apache.tsfile.block.column.Column;
import org.apache.tsfile.utils.Binary;
import org.apache.tsfile.utils.BitMap;

import java.nio.charset.Charset;
import java.util.Arrays;
import java.util.List;
import java.util.Map;

public class UDAFDTWMatch implements UDAF {
// private static final Logger LOGGER = Logger.getLogger(DTWMatch.class);
private Double[] pattern;
private float threshold;
private DTWState state;

@Override
public void beforeStart(UDFParameters udfParameters, UDAFConfigurations udafConfigurations) {
udafConfigurations.setOutputDataType(Type.TEXT);
Map<String, String> attributes = udfParameters.getAttributes();
threshold = Float.parseFloat(attributes.get("threshold"));
pattern =
Arrays.stream(attributes.get("pattern").split(","))
.map(Double::valueOf)
.toArray(Double[]::new);
if (state != null) {
state.setSize(pattern.length);
}
}

@Override
public State createState() {
if (pattern != null) {
state = new DTWState(pattern.length);
} else {
state = new DTWState();
}

return state;
}

@Override
public void addInput(State state, Column[] columns, BitMap bitMap) {
DTWState DTWState = (DTWState) state;

int count = columns[0].getPositionCount();
for (int i = 0; i < count; i++) {
if (bitMap != null && !bitMap.isMarked(i)) {
continue;
}
if (!columns[1].isNull(i)) {
long timestamp = columns[1].getLong(i);
double value = getValue(columns[0], i);
DTWState.updateBuffer(timestamp, value);
if (DTWState.getValueBuffer().length == pattern.length) {
float dtw = calculateDTW(DTWState.getValueBuffer(), pattern);
if (dtw <= threshold) {
((DTWState) state)
.addMatchResult(
new DTWMatchResult(dtw, DTWState.getFirstTime(), DTWState.getLastTime()));
}
}
}
}
}

private double getValue(Column column, int i) {
switch (column.getDataType()) {
case INT32:
return column.getInt(i);
case INT64:
return column.getLong(i);
case FLOAT:
return column.getFloat(i);
case DOUBLE:
return column.getDouble(i);
case BOOLEAN:
return column.getBoolean(i) ? 1.0D : 0.0D;
default:
throw new RuntimeException(String.format("Unsupported datatype %s", column.getDataType()));
}
}

private float calculateDTW(Double[] series1, Double[] series2) {
int n = series1.length;
double[][] dtw = new double[n][n];

// Initialize the DTW matrix
for (int i = 0; i < n; i++) {
for (int j = 0; j < n; j++) {
dtw[i][j] = Double.POSITIVE_INFINITY;
}
}
dtw[0][0] = 0;

// Compute the DTW distance
for (int i = 1; i < n; i++) {
for (int j = 1; j < n; j++) {
double cost = Math.abs(series1[i] - series2[j]);
dtw[i][j] = cost + Math.min(Math.min(dtw[i - 1][j], dtw[i][j - 1]), dtw[i - 1][j - 1]);
}
}

return (float) dtw[n - 1][n - 1];
}

@Override
public void combineState(State state, State state1) {
DTWState dtwState = (DTWState) state;
DTWState newDTWState = (DTWState) state1;

Long[] times = newDTWState.getTimeBuffer();
Double[] values = newDTWState.getValueBuffer();

for (int i = 0; i < times.length; i++) {
if (times[i] > dtwState.getFirstTime()) {
dtwState.updateBuffer(times[i], values[i]);
if (dtwState.getValueBuffer().length == pattern.length) {
float dtw = calculateDTW(dtwState.getValueBuffer(), pattern);
if (dtw <= threshold) {
dtwState.addMatchResult(
new DTWMatchResult(dtw, dtwState.getFirstTime(), dtwState.getLastTime()));
}
}
}
}
}

public List<DTWMatchResult> calcMatch(
List<Long> times, List<Double> values, Double[] pattern, float threshold) {
this.pattern = pattern;
this.threshold = threshold;
DTWState dtwState = (DTWState) this.createState();
dtwState.reset();
for (int i = 0; i < times.size(); i++) {
dtwState.updateBuffer(times.get(i), values.get(i));
if (dtwState.getValueBuffer().length == pattern.length) {
float dtw = calculateDTW(dtwState.getValueBuffer(), pattern);
if (dtw <= threshold) {
dtwState.addMatchResult(
new DTWMatchResult(dtw, dtwState.getFirstTime(), dtwState.getLastTime()));
}
}
}
return dtwState.getMatchResults();
}

@Override
public void outputFinal(State state, ResultValue resultValue) {
DTWState DTWState = (DTWState) state;
List<DTWMatchResult> matchResults = DTWState.getMatchResults();
if (!matchResults.isEmpty()) {
resultValue.setBinary(new Binary(matchResults.toString(), Charset.defaultCharset()));
} else {
resultValue.setNull();
}
}

@Override
public void removeState(State state, State removed) {}

@Override
public void validate(UDFParameterValidator validator) {
validator
.validateInputSeriesNumber(1)
.validateInputSeriesDataType(
0, Type.INT32, Type.INT64, Type.FLOAT, Type.DOUBLE, Type.BOOLEAN)
.validateRequiredAttribute("pattern")
.validateRequiredAttribute("threshold");
}
}
Loading

0 comments on commit e0b6855

Please sign in to comment.