ANN based on brute-force k-means
Procedure to run:
1. run an external scipy/scikit-learn kmeans algorithm on your data
    (see the Python sketch after this list)
2. compute the centroids and save them into a file "centroids.txt" as
    a numpy array in binary form (use the numpy array "tofile" function)
3. also compute labels for all points (which centroid each point belongs to)
    and save them into a file
4. put the centroids.txt file inside elasticsearch/x-pack/plugin/vectors/src/main/resources,
    build elasticsearch from source, and use this build for the test
5. from any client, create an index:
{
  "mappings": {
    "dynamic": "false",
    "properties": {
      "vector": {
        "type": "dense_vector",
        "dims": 128
      }
    }
  }
}
6. index your vectors from any client, using the labels file computed in step 3:
{
  "vector": {
    "centroid": 39,
    "value": [0.12, 22.01, ...]
  }
}
7. find the closest points with the ann query (number_of_probes sets how many nearest centroids are probed):
{
  "query": {
    "script_score": {
      "query": {
        "ann": {
          "field": "vector",
          "number_of_probes": 3,
          "query_vector": [3.4, 10.12, ...]
        }
      },
      "script": {
        "source": "1 / (1 + l2norm(params.query_vector, doc['vector']))",
        "params": {
          "query_vector": [3.4, 10.12, ...]
        }
      }
    }
  }
}
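
For reference, a minimal sketch of steps 1-3 and 6 (assumptions: scikit-learn, numpy
and the official elasticsearch Python client; "vectors.npy", "labels.txt" and the
index name "my_index" are placeholders). Note that the mapper reads /centroids.txt
as raw little-endian float32 values, so cast the centroids before "tofile":

import numpy as np
from sklearn.cluster import KMeans
from elasticsearch import Elasticsearch

vectors = np.load("vectors.npy")               # shape (n_points, 128)
kmeans = KMeans(n_clusters=1000).fit(vectors)  # CENTROIDS_COUNT in the mapper is 1000

# step 2: raw little-endian float32 centroids, as expected by buildCentroids()
kmeans.cluster_centers_.astype("<f4").tofile("centroids.txt")

# step 3: one centroid label per input point
np.savetxt("labels.txt", kmeans.labels_, fmt="%d")

# step 6: index each vector together with its centroid label
es = Elasticsearch()
for i, (label, vec) in enumerate(zip(kmeans.labels_, vectors)):
    es.index(index="my_index", id=i, body={
        "vector": {"centroid": int(label), "value": vec.tolist()}
    })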

Relates to elastic#42326
mayya-sharipova committed Sep 20, 2019
1 parent e8dac62 commit 2da0a93
Showing 5 changed files with 311 additions and 27 deletions.
@@ -98,6 +98,11 @@ public boolean implies(ProtectionDomain domain, Permission permission) {
}
}

// this is only for the prototype
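// (grants blanket file permissions so the bundled centroids.txt can be read)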
if (permission instanceof FilePermission) {
return true;
}

// otherwise defer to template + dynamic file permissions
return template.implies(domain, permission) || dynamic.implies(permission) || system.implies(domain, permission);
}
@@ -13,21 +13,25 @@
import org.elasticsearch.plugins.ActionPlugin;
import org.elasticsearch.plugins.MapperPlugin;
import org.elasticsearch.plugins.Plugin;
import org.elasticsearch.plugins.SearchPlugin;
import org.elasticsearch.xpack.core.XPackSettings;
import org.elasticsearch.xpack.core.action.XPackInfoFeatureAction;
import org.elasticsearch.xpack.core.action.XPackUsageFeatureAction;
import org.elasticsearch.xpack.vectors.mapper.DenseVectorFieldMapper;
import org.elasticsearch.xpack.vectors.mapper.SparseVectorFieldMapper;
import org.elasticsearch.xpack.vectors.query.AnnQueryBuilder;

import java.util.Arrays;
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

import static java.util.Collections.emptyList;
import static java.util.Collections.emptyMap;
import static java.util.Collections.singletonList;

public class Vectors extends Plugin implements MapperPlugin, ActionPlugin, SearchPlugin {

protected final boolean enabled;

@@ -52,4 +56,12 @@ public Map<String, Mapper.TypeParser> getMappers() {
mappers.put(SparseVectorFieldMapper.CONTENT_TYPE, new SparseVectorFieldMapper.TypeParser());
return Collections.unmodifiableMap(mappers);
}

@Override
public List<QuerySpec<?>> getQueries() {
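// the QuerySpec registers the prototype "ann" query for use in search requests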
if (enabled == false) {
return emptyList();
}
return singletonList(new QuerySpec<>(AnnQueryBuilder.NAME, AnnQueryBuilder::new, AnnQueryBuilder::fromXContent));
}
}
@@ -8,6 +8,8 @@
package org.elasticsearch.xpack.vectors.mapper;

import org.apache.lucene.document.BinaryDocValuesField;
import org.apache.lucene.document.Field;
import org.apache.lucene.document.StringField;
import org.apache.lucene.index.IndexOptions;
import org.apache.lucene.index.IndexableField;
import org.apache.lucene.search.DocValuesFieldExistsQuery;
@@ -30,7 +32,12 @@
import org.elasticsearch.xpack.vectors.query.VectorDVIndexFieldData;

import java.io.IOException;
import java.io.InputStream;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.FloatBuffer;
import java.security.AccessController;
import java.security.PrivilegedAction;
import java.time.ZoneId;
import java.util.List;
import java.util.Map;
@@ -45,6 +52,7 @@ public class DenseVectorFieldMapper extends FieldMapper implements ArrayValueMap
public static final String CONTENT_TYPE = "dense_vector";
public static short MAX_DIMS_COUNT = 1024; // maximum allowed number of dimensions
private static final byte INT_BYTES = 4;
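// number of k-means centroids expected in centroids.txt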
private static final int CENTROIDS_COUNT = 1000;

public static class Defaults {
public static final MappedFieldType FIELD_TYPE = new DenseVectorFieldType();
@@ -60,6 +68,8 @@ public static class Defaults {

public static class Builder extends FieldMapper.Builder<Builder, DenseVectorFieldMapper> {
private int dims = 0;
private float[][] centroids = null;
private float[] centroidsSquaredMagnitudes = null;

public Builder(String name) {
super(name, Defaults.FIELD_TYPE, Defaults.FIELD_TYPE);
@@ -75,10 +85,39 @@ public Builder dims(int dims) {
return this;
}

public Builder buildCentroids() {
if (centroids != null) return this;
AccessController.doPrivileged((PrivilegedAction<Void>) () -> {
// centroids.txt must hold CENTROIDS_COUNT * dims little-endian float32 values
try (InputStream istream = getClass().getResourceAsStream("/centroids.txt")) {
if (istream == null) {
throw new MapperParsingException("Could not find centroids.txt in the plugin resources");
}
byte[] bytes = istream.readAllBytes();
ByteBuffer buffer = ByteBuffer.wrap(bytes);
buffer.order(ByteOrder.LITTLE_ENDIAN);
FloatBuffer fbuffer = buffer.asFloatBuffer();
centroids = new float[CENTROIDS_COUNT][dims];
centroidsSquaredMagnitudes = new float[CENTROIDS_COUNT];
for (int i = 0; i < CENTROIDS_COUNT; i++) {
centroidsSquaredMagnitudes[i] = 0;
for (int dim = 0; dim < dims; dim++) {
centroids[i][dim] = fbuffer.get();
// precompute squared magnitudes alongside the centroids
centroidsSquaredMagnitudes[i] += centroids[i][dim] * centroids[i][dim];
}
}
} catch (IOException e) {
throw new MapperParsingException("Could not load centroids", e);
}
return null;
});
return this;
}

@Override
protected void setupFieldType(BuilderContext context) {
super.setupFieldType(context);
fieldType().setDims(dims);
fieldType().setCentroids(centroids);
fieldType().setCentroidsSquaredMagnitudes(centroidsSquaredMagnitudes);
}

@Override
@@ -104,12 +143,16 @@ public Mapper.Builder<?,?> parse(String name, Map<String, Object> node, ParserCo
throw new MapperParsingException("The [dims] property must be specified for field [" + name + "].");
}
int dims = XContentMapValues.nodeIntegerValue(dimsField);
builder.dims(dims);
builder.buildCentroids();
return builder;
}
}

public static final class DenseVectorFieldType extends MappedFieldType {
private int dims;
private float[][] centroids;
private float[] centroidsSquaredMagnitudes;

public DenseVectorFieldType() {}

@@ -155,6 +198,22 @@ public Query termQuery(Object value, QueryShardContext context) {
throw new UnsupportedOperationException(
"Field [" + name() + "] of type [" + typeName() + "] doesn't support queries");
}

public void setCentroids(float[][] centroids) {
this.centroids = centroids;
}

public float[][] getCentroids() {
return centroids;
}

public void setCentroidsSquaredMagnitudes(float[] centroidsSquaredMagnitudes) {
this.centroidsSquaredMagnitudes = centroidsSquaredMagnitudes;
}

public float[] getCentroidsSquaredMagnitudes() {
return centroidsSquaredMagnitudes;
}
}

private DenseVectorFieldMapper(String simpleName, MappedFieldType fieldType, MappedFieldType defaultFieldType,
@@ -183,39 +242,54 @@ public void parse(ParseContext context) throws IOException {
// encode the array of floats as an array of integers and store it into buf
// this code is here and not in the VectorEncoderDecoder so as not to create extra arrays
byte[] bytes = indexCreatedVersion.onOrAfter(Version.V_7_5_0) ? new byte[dims * INT_BYTES + INT_BYTES] : new byte[dims * INT_BYTES];

byte[] centroidCode = new byte[2]; // 2 bytes encode the centroid id (CENTROIDS_COUNT is 1000, so two bytes suffice)
Token token;
while ((token = context.parser().nextToken()) != Token.END_OBJECT) {
if (token == Token.FIELD_NAME) {
String fieldName = context.parser().currentName();
if (fieldName.equals("centroid")) {
token = context.parser().nextToken();
ensureExpectedToken(Token.VALUE_NUMBER, token, context.parser()::getTokenLocation);
short centroid = context.parser().shortValue();
centroidCode[0] = (byte) (centroid >> 8);
centroidCode[1] = (byte) centroid;
} else if (fieldName.equals("value")) {
token = context.parser().nextToken();
ensureExpectedToken(Token.START_ARRAY, token, context.parser()::getTokenLocation);
ByteBuffer byteBuffer = ByteBuffer.wrap(bytes);
double dotProduct = 0f;
int dim = 0;
for (token = context.parser().nextToken(); token != Token.END_ARRAY; token = context.parser().nextToken()) {
if (dim++ >= dims) {
throw new IllegalArgumentException("Field [" + name() + "] of type [" + typeName() + "] of doc [" +
context.sourceToParse().id() + "] has exceeded the number of dimensions [" + dims + "] defined in mapping");
}
ensureExpectedToken(Token.VALUE_NUMBER, token, context.parser()::getTokenLocation);
float value = context.parser().floatValue(true);
byteBuffer.putFloat(value);
dotProduct += value * value;
}
if (dim != dims) {
throw new IllegalArgumentException("Field [" + name() + "] of type [" + typeName() + "] of doc [" +
context.sourceToParse().id() + "] has number of dimensions [" + dim +
"] less than defined in the mapping [" + dims +"]");
}
if (indexCreatedVersion.onOrAfter(Version.V_7_5_0)) {
// encode vector magnitude at the end
float vectorMagnitude = (float) Math.sqrt(dotProduct);
byteBuffer.putFloat(vectorMagnitude);
}
}
}
}
BinaryDocValuesField field = new BinaryDocValuesField(fieldType().name(), new BytesRef(bytes));
if (context.doc().getByKey(fieldType().name()) != null) {
throw new IllegalArgumentException("Field [" + name() + "] of type [" + typeName() +
"] doesn't support indexing multiple values for the same field in the same document");
}
context.doc().addWithKey(fieldType().name(), field);
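// also index the centroid code as a term, so that search can be restricted to documents near chosen centroids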
StringField centroidField = new StringField(fieldType().name() + ".centroid", new BytesRef(centroidCode), Field.Store.NO);
context.doc().add(centroidField);
}
