Skip to content

Commit

Permalink
renaming around to make it clearer
Browse files Browse the repository at this point in the history
  • Loading branch information
alessandrobenedetti committed Nov 18, 2024
1 parent af1cb50 commit 216180c
Show file tree
Hide file tree
Showing 8 changed files with 68 additions and 68 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -36,19 +36,19 @@ public class SolrEmbeddingModel implements Accountable {

private final String name;
private final Map<String, Object> params;
private final EmbeddingModel embedder;
private final EmbeddingModel textToVector;
private final Integer hashCode;

public static SolrEmbeddingModel getInstance(
String className, String name, Map<String, Object> params) throws EmbeddingModelException {
try {
/*
* The idea herea is to build a {@link dev.langchain4j.model.embedding.EmbeddingModel} using inversion
* The idea here is to build a {@link dev.langchain4j.model.embedding.EmbeddingModel} using inversion
* of control.
* Each model has its own list of parameters we don't know beforehand, but each {@link dev.langchain4j.model.embedding.EmbeddingModel} class
* has its own builder that uses setters with the same name of the parameter in input.
* */
EmbeddingModel embedder;
EmbeddingModel textToVector;
Class<?> modelClass = Class.forName(className);
var builder = modelClass.getMethod("builder").invoke(null);
if (params != null) {
Expand Down Expand Up @@ -94,22 +94,22 @@ public static SolrEmbeddingModel getInstance(
}
}
}
embedder = (EmbeddingModel) builder.getClass().getMethod("build").invoke(builder);
return new SolrEmbeddingModel(name, embedder, params);
textToVector = (EmbeddingModel) builder.getClass().getMethod("build").invoke(builder);
return new SolrEmbeddingModel(name, textToVector, params);
} catch (final Exception e) {
throw new EmbeddingModelException("Model loading failed for " + className, e);
}
}

public SolrEmbeddingModel(String name, EmbeddingModel embedder, Map<String, Object> params) {
public SolrEmbeddingModel(String name, EmbeddingModel textToVector, Map<String, Object> params) {
this.name = name;
this.embedder = embedder;
this.textToVector = textToVector;
this.params = params;
this.hashCode = calculateHashCode();
}

public float[] vectorise(String text) {
Embedding vector = embedder.embed(text).content();
Embedding vector = textToVector.embed(text).content();
return vector.vector();
}

Expand All @@ -122,7 +122,7 @@ public String toString() {
public long ramBytesUsed() {
return BASE_RAM_BYTES
+ RamUsageEstimator.sizeOfObject(name)
+ RamUsageEstimator.sizeOfObject(embedder);
+ RamUsageEstimator.sizeOfObject(textToVector);
}

@Override
Expand All @@ -134,7 +134,7 @@ private int calculateHashCode() {
final int prime = 31;
int result = 1;
result = (prime * result) + Objects.hashCode(name);
result = (prime * result) + Objects.hashCode(embedder);
result = (prime * result) + Objects.hashCode(textToVector);
return result;
}

Expand All @@ -143,15 +143,15 @@ public boolean equals(Object obj) {
if (this == obj) return true;
if (!(obj instanceof SolrEmbeddingModel)) return false;
final SolrEmbeddingModel other = (SolrEmbeddingModel) obj;
return Objects.equals(embedder, other.embedder) && Objects.equals(name, other.name);
return Objects.equals(textToVector, other.textToVector) && Objects.equals(name, other.name);
}

public String getName() {
return name;
}

public String getEmbedderClassName() {
return embedder.getClass().getName();
public String getEmbeddingModelClassName() {
return textToVector.getClass().getName();
}

public Map<String, Object> getParams() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,19 +39,19 @@
import org.apache.solr.search.neural.KnnQParser;

/**
* A neural query parser that embed the query and then run K-nearest neighbors search on Dense
* Vector fields. See Wiki page
* A neural query parser that encode the query to a vector and then run K-nearest neighbors search
* on Dense Vector fields. See Wiki page
* https://solr.apache.org/guide/solr/latest/query-guide/dense-vector-search.html
*/
public class TextEmbedderQParserPlugin extends QParserPlugin
public class TextToVectorQParserPlugin extends QParserPlugin
implements ResourceLoaderAware, ManagedResourceObserver {
public static final String EMBEDDING_MODEL_PARAM = "model";
private ManagedEmbeddingModelStore modelStore = null;

@Override
public QParser createParser(
String qstr, SolrParams localParams, SolrParams params, SolrQueryRequest req) {
return new TextEmbedderQParser(qstr, localParams, params, req);
return new TextToVectorQParser(qstr, localParams, params, req);
}

@Override
Expand All @@ -67,26 +67,25 @@ public void onManagedResourceInitialized(NamedList<?> args, ManagedResource res)
modelStore = (ManagedEmbeddingModelStore) res;
}
if (modelStore != null) {
// now we can safely load the models
modelStore.loadStoredModels();
}
}

public class TextEmbedderQParser extends KnnQParser {
public class TextToVectorQParser extends KnnQParser {

public TextEmbedderQParser(
String qstr, SolrParams localParams, SolrParams params, SolrQueryRequest req) {
super(qstr, localParams, params, req);
public TextToVectorQParser(
String queryString, SolrParams localParams, SolrParams params, SolrQueryRequest req) {
super(queryString, localParams, params, req);
}

@Override
public Query parse() throws SyntaxError {
checkParam(qstr, "Query string is empty, nothing to embed");
checkParam(qstr, "Query string is empty, nothing to vectorise");
final String embeddingModelName = localParams.get(EMBEDDING_MODEL_PARAM);
checkParam(embeddingModelName, "The 'model' parameter is missing");
SolrEmbeddingModel embedder = modelStore.getModel(embeddingModelName);
SolrEmbeddingModel textToVector = modelStore.getModel(embeddingModelName);

if (embedder != null) {
if (textToVector != null) {
final SchemaField schemaField = req.getCore().getLatestSchema().getField(getFieldName());
final DenseVectorField denseVectorType = getCheckedFieldType(schemaField);
int fieldDimensions = denseVectorType.getDimension();
Expand All @@ -96,7 +95,7 @@ public Query parse() throws SyntaxError {
switch (vectorEncoding) {
case FLOAT32:
{
float[] vectorToSearch = embedder.vectorise(qstr);
float[] vectorToSearch = textToVector.vectorise(qstr);
checkVectorDimension(vectorToSearch.length, fieldDimensions);
return new KnnFloatVectorQuery(
schemaField.getName(), vectorToSearch, topK, getFilterQuery());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,7 @@ public static SolrEmbeddingModel fromEmbeddingModelMap(Map<String, Object> embed
private static LinkedHashMap<String, Object> toEmbeddingModelMap(SolrEmbeddingModel model) {
final LinkedHashMap<String, Object> modelMap = new LinkedHashMap<>(5, 1.0f);
modelMap.put(NAME_KEY, model.getName());
modelMap.put(CLASS_KEY, model.getEmbedderClassName());
modelMap.put(CLASS_KEY, model.getEmbeddingModelClassName());
modelMap.put(PARAMS_KEY, model.getParams());
return modelMap;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@
</requestDispatcher>

<!-- Query parser used to run neural queries-->
<queryParser name="text_embedder"
class="org.apache.solr.llm.search.TextEmbedderQParserPlugin" />
<queryParser name="text_to_vector"
class="org.apache.solr.llm.search.TextToVectorQParserPlugin" />

<query>
<filterCache class="solr.CaffeineCache" size="4096"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
import org.junit.BeforeClass;
import org.junit.Test;

public class TextEmbedderQParserTest extends TestLlmBase {
public class TextToVectorQParserTest extends TestLlmBase {
@BeforeClass
public static void init() throws Exception {
setupTest("solrconfig-llm.xml", "schema.xml", true, false);
Expand All @@ -31,7 +31,7 @@ public static void init() throws Exception {

@Test
public void notExistentModel_shouldThrowException() throws Exception {
final String solrQuery = "{!text_embedder model=not-exist f=vector topK=5}hello world";
final String solrQuery = "{!text_to_vector model=not-exist f=vector topK=5}hello world";
final SolrQuery query = new SolrQuery();
query.setQuery(solrQuery);
query.add("fl", "id");
Expand All @@ -44,7 +44,7 @@ public void notExistentModel_shouldThrowException() throws Exception {

@Test
public void missingModelParam_shouldThrowException() throws Exception {
final String solrQuery = "{!text_embedder f=vector topK=5}hello world";
final String solrQuery = "{!text_to_vector f=vector topK=5}hello world";
final SolrQuery query = new SolrQuery();
query.setQuery(solrQuery);
query.add("fl", "id");
Expand All @@ -57,7 +57,7 @@ public void missingModelParam_shouldThrowException() throws Exception {

@Test
public void incorrectVectorFieldType_shouldThrowException() throws Exception {
final String solrQuery = "{!text_embedder model=dummy-1 f=id topK=5}hello world";
final String solrQuery = "{!text_to_vector model=dummy-1 f=id topK=5}hello world";
final SolrQuery query = new SolrQuery();
query.setQuery(solrQuery);
query.add("fl", "id");
Expand All @@ -70,7 +70,7 @@ public void incorrectVectorFieldType_shouldThrowException() throws Exception {

@Test
public void undefinedVectorField_shouldThrowException() throws Exception {
final String solrQuery = "{!text_embedder model=dummy-1 f=notExistent topK=5}hello world";
final String solrQuery = "{!text_to_vector model=dummy-1 f=notExistent topK=5}hello world";
final SolrQuery query = new SolrQuery();
query.setQuery(solrQuery);
query.add("fl", "id");
Expand All @@ -83,7 +83,7 @@ public void undefinedVectorField_shouldThrowException() throws Exception {

@Test
public void missingVectorFieldParam_shouldThrowException() throws Exception {
final String solrQuery = "{!text_embedder model=dummy-1 topK=5}hello world";
final String solrQuery = "{!text_to_vector model=dummy-1 topK=5}hello world";
final SolrQuery query = new SolrQuery();
query.setQuery(solrQuery);
query.add("fl", "id");
Expand All @@ -97,7 +97,7 @@ public void missingVectorFieldParam_shouldThrowException() throws Exception {
@Test
public void vectorByteEncodingField_shouldRaiseException() throws Exception {
final String solrQuery =
"{!text_embedder model=dummy-1 f=vector_byte_encoding topK=5}hello world";
"{!text_to_vector model=dummy-1 f=vector_byte_encoding topK=5}hello world";
final SolrQuery query = new SolrQuery();
query.setQuery(solrQuery);
query.add("fl", "id");
Expand All @@ -110,20 +110,21 @@ public void vectorByteEncodingField_shouldRaiseException() throws Exception {

@Test
public void missingQueryToEmbed_shouldThrowException() throws Exception {
final String solrQuery = "{!text_embedder model=dummy-1 f=vector topK=5}";
final String solrQuery = "{!text_to_vector model=dummy-1 f=vector topK=5}";
final SolrQuery query = new SolrQuery();
query.setQuery(solrQuery);
query.add("fl", "id");

assertJQ(
"/query" + query.toQueryString(),
"/error/msg=='Query string is empty, nothing to embed'",
"/error/msg=='Query string is empty, nothing to vectorise'",
"/error/code==400");
}

@Test
public void incorrectVectorToSearchDimension_shouldThrowException() throws Exception {
final String solrQuery = "{!text_embedder model=dummy-1 f=2048_float_vector topK=5}hello world";
final String solrQuery =
"{!text_to_vector model=dummy-1 f=2048_float_vector topK=5}hello world";
final SolrQuery query = new SolrQuery();
query.setQuery(solrQuery);
query.add("fl", "id");
Expand All @@ -136,7 +137,7 @@ public void incorrectVectorToSearchDimension_shouldThrowException() throws Excep

@Test
public void topK_shouldEmbedAndReturnOnlyTopKResults() throws Exception {
final String solrQuery = "{!text_embedder model=dummy-1 f=vector topK=5}hello world";
final String solrQuery = "{!text_to_vector model=dummy-1 f=vector topK=5}hello world";
final SolrQuery query = new SolrQuery();
query.setQuery(solrQuery);
query.add("fl", "id");
Expand All @@ -153,7 +154,7 @@ public void topK_shouldEmbedAndReturnOnlyTopKResults() throws Exception {

@Test
public void vectorFieldParam_shouldSearchOnThatField() throws Exception {
final String solrQuery = "{!text_embedder model=dummy-1 f=vector2 topK=5}hello world";
final String solrQuery = "{!text_to_vector model=dummy-1 f=vector2 topK=5}hello world";
final SolrQuery query = new SolrQuery();
query.setQuery(solrQuery);
query.add("fl", "id");
Expand All @@ -168,7 +169,7 @@ public void vectorFieldParam_shouldSearchOnThatField() throws Exception {

@Test
public void embeddedQuery_shouldRankBySimilarityFunction() throws Exception {
final String solrQuery = "{!text_embedder model=dummy-1 f=vector topK=10}hello world";
final String solrQuery = "{!text_to_vector model=dummy-1 f=vector topK=10}hello world";
final SolrQuery query = new SolrQuery();
query.setQuery(solrQuery);
query.add("fl", "id");
Expand All @@ -191,7 +192,7 @@ public void embeddedQuery_shouldRankBySimilarityFunction() throws Exception {
@Test
public void embeddedQueryUsedInFilter_shouldFilterResultsBeforeTheQueryExecution()
throws Exception {
final String solrQuery = "{!text_embedder model=dummy-1 f=vector topK=4}hello world";
final String solrQuery = "{!text_to_vector model=dummy-1 f=vector topK=4}hello world";
final SolrQuery query = new SolrQuery();
query.setQuery("id:(3 4 9 2)");
query.setFilterQueries(solrQuery);
Expand All @@ -207,7 +208,7 @@ public void embeddedQueryUsedInFilter_shouldFilterResultsBeforeTheQueryExecution
@Test
public void embeddedQueryUsedInFilters_shouldFilterResultsBeforeTheQueryExecution()
throws Exception {
final String solrQuery = "{!text_embedder model=dummy-1 f=vector topK=4}hello world";
final String solrQuery = "{!text_to_vector model=dummy-1 f=vector topK=4}hello world";
final SolrQuery query = new SolrQuery();
query.setQuery("id:(3 4 9 2)");
query.setFilterQueries(solrQuery, "id:(4 20 9)");
Expand All @@ -222,7 +223,7 @@ public void embeddedQueryUsedInFilters_shouldFilterResultsBeforeTheQueryExecutio
public void embeddedQueryUsedInFiltersWithPreFilter_shouldFilterResultsBeforeTheQueryExecution()
throws Exception {
final String solrQuery =
"{!text_embedder model=dummy-1 f=vector topK=4 preFilter='id:(1 4 7 8 9)'}hello world";
"{!text_to_vector model=dummy-1 f=vector topK=4 preFilter='id:(1 4 7 8 9)'}hello world";
final SolrQuery query = new SolrQuery();
query.setQuery("id:(3 4 9 2)");
query.setFilterQueries(solrQuery, "id:(4 20 9)");
Expand All @@ -240,8 +241,8 @@ public void embeddedQueryUsedInFiltersWithPreFilter_shouldFilterResultsBeforeThe
public void embeddedQueryUsedInFilters_rejectIncludeExclude() throws Exception {
for (String fq :
Arrays.asList(
"{!text_embedder model=dummy-1 f=vector topK=5 includeTags=xxx}hello world",
"{!text_embedder model=dummy-1 f=vector topK=5 excludeTags=xxx}hello world")) {
"{!text_to_vector model=dummy-1 f=vector topK=5 includeTags=xxx}hello world",
"{!text_to_vector model=dummy-1 f=vector topK=5 excludeTags=xxx}hello world")) {
final SolrQuery query = new SolrQuery();
query.setQuery("*:*");
query.setFilterQueries(fq);
Expand All @@ -257,7 +258,7 @@ public void embeddedQueryUsedInFilters_rejectIncludeExclude() throws Exception {
@Test
public void embeddedQueryAsSubQuery() throws Exception {
final String solrQuery =
"*:* AND {!text_embedder model=dummy-1 f=vector topK=5 v='hello world'}";
"*:* AND {!text_to_vector model=dummy-1 f=vector topK=5 v='hello world'}";
final SolrQuery query = new SolrQuery();
query.setQuery(solrQuery);
query.setFilterQueries("id:(2 4 7 9 8 20 3)");
Expand All @@ -276,7 +277,7 @@ public void embeddedQueryAsSubQuery() throws Exception {
@Test
public void embeddedQueryAsSubQuery_withPreFilter() throws Exception {
final String solrQuery =
"*:* AND {!text_embedder model=dummy-1 f=vector topK=5 preFilter='id:(2 4 7 9 8 20 3)' v='hello world'}";
"*:* AND {!text_to_vector model=dummy-1 f=vector topK=5 preFilter='id:(2 4 7 9 8 20 3)' v='hello world'}";
final SolrQuery query = new SolrQuery();
query.setQuery(solrQuery);
query.add("fl", "id");
Expand All @@ -297,8 +298,8 @@ public void embeddedQueryAsSubQuery_withPreFilter() throws Exception {
public void embeddedQueryAsSubQuery_rejectIncludeExclude() throws Exception {
for (String q :
Arrays.asList(
"{!text_embedder model=dummy-1 f=vector topK=5 includeTags=xxx}hello world",
"{!text_embedder model=dummy-1 f=vector topK=5 excludeTags=xxx}hello world")) {
"{!text_to_vector model=dummy-1 f=vector topK=5 includeTags=xxx}hello world",
"{!text_to_vector model=dummy-1 f=vector topK=5 excludeTags=xxx}hello world")) {
final SolrQuery query = new SolrQuery();
query.setQuery("*:* OR " + q);
query.add("fl", "id");
Expand All @@ -312,7 +313,7 @@ public void embeddedQueryAsSubQuery_rejectIncludeExclude() throws Exception {

@Test
public void embeddedQueryWithCostlyFq_shouldPerformKnnSearchWithPostFilter() throws Exception {
final String solrQuery = "{!text_embedder model=dummy-1 f=vector topK=10}hello world";
final String solrQuery = "{!text_to_vector model=dummy-1 f=vector topK=10}hello world";
final SolrQuery query = new SolrQuery();
query.setQuery(solrQuery);
query.setFilterQueries("{!frange cache=false l=0.99}$q");
Expand All @@ -331,7 +332,7 @@ public void embeddedQueryWithCostlyFq_shouldPerformKnnSearchWithPostFilter() thr
@Test
public void embeddedQueryWithFilterQueries_shouldPerformKnnSearchWithPreFiltersAndPostFilters()
throws Exception {
final String solrQuery = "{!text_embedder model=dummy-1 f=vector topK=4}hello world";
final String solrQuery = "{!text_to_vector model=dummy-1 f=vector topK=4}hello world";
final SolrQuery query = new SolrQuery();
query.setQuery(solrQuery);
query.setFilterQueries("id:(3 4 9 2)", "{!frange cache=false l=0.99}$q");
Expand All @@ -347,7 +348,7 @@ public void embeddedQueryWithFilterQueries_shouldPerformKnnSearchWithPreFiltersA
@Test
public void embeddedQueryWithNegativeFilterQuery_shouldPerformKnnSearchInPreFilteredResults()
throws Exception {
final String solrQuery = "{!text_embedder model=dummy-1 f=vector topK=4}hello world";
final String solrQuery = "{!text_to_vector model=dummy-1 f=vector topK=4}hello world";
final SolrQuery query = new SolrQuery();
query.setQuery(solrQuery);
query.setFilterQueries("-id:4");
Expand All @@ -370,7 +371,7 @@ public void embeddedQueryWithNegativeFilterQuery_shouldPerformKnnSearchInPreFilt
public void embeddedQueryAsRerank_shouldAddSimilarityFunctionScore() throws Exception {
final SolrQuery query = new SolrQuery();
query.set("rq", "{!rerank reRankQuery=$rqq reRankDocs=4 reRankWeight=1}");
query.set("rqq", "{!text_embedder model=dummy-1 f=vector topK=4}hello world");
query.set("rqq", "{!text_to_vector model=dummy-1 f=vector topK=4}hello world");
query.setQuery("id:(3 4 9 2)");
query.add("fl", "id");

Expand Down
Loading

0 comments on commit 216180c

Please sign in to comment.