diff --git a/solr/modules/llm/src/java/org/apache/solr/llm/embedding/SolrEmbeddingModel.java b/solr/modules/llm/src/java/org/apache/solr/llm/embedding/SolrEmbeddingModel.java index 70a580935bb..dbac1ba23b1 100644 --- a/solr/modules/llm/src/java/org/apache/solr/llm/embedding/SolrEmbeddingModel.java +++ b/solr/modules/llm/src/java/org/apache/solr/llm/embedding/SolrEmbeddingModel.java @@ -42,10 +42,22 @@ public class SolrEmbeddingModel implements Accountable { public static SolrEmbeddingModel getInstance( String className, String name, Map params) throws EmbeddingModelException { try { + /* + * The idea herea is to build a {@link dev.langchain4j.model.embedding.EmbeddingModel} using inversion + * of control. + * Each model has its own list of parameters we don't know beforehand, but each {@link dev.langchain4j.model.embedding.EmbeddingModel} class + * has its own builder that uses setters with the same name of the parameter in input. + * */ EmbeddingModel embedder; Class modelClass = Class.forName(className); var builder = modelClass.getMethod("builder").invoke(null); if (params != null) { + /** + * Some {@link dev.langchain4j.model.embedding.EmbeddingModel} classes have params of + * specific types that must be constructed, for primitive types we can resort to the + * default. N.B. when adding support to new models, pay attention to all the parameters they + * support, some of them may require to be handled in here as separate switch cases + */ for (String paramName : params.keySet()) { switch (paramName) { case TIMEOUT_PARAM: @@ -65,14 +77,14 @@ public static SolrEmbeddingModel getInstance( .invoke(builder, ((Long) params.get(paramName)).intValue()); break; default: - ArrayList methods = new ArrayList<>(); + ArrayList paramNameMatches = new ArrayList<>(); for (var method : builder.getClass().getMethods()) { if (paramName.equals(method.getName()) && method.getParameterCount() == 1) { - methods.add(method); + paramNameMatches.add(method); } } - if (methods.size() == 1) { - methods.get(0).invoke(builder, params.get(paramName)); + if (paramNameMatches.size() == 1) { + paramNameMatches.get(0).invoke(builder, params.get(paramName)); } else { builder .getClass() @@ -141,7 +153,7 @@ public String getName() { public String getEmbedderClassName() { return embedder.getClass().getName(); } - + public Map getParams() { return params; } diff --git a/solr/modules/llm/src/test/org/apache/solr/llm/embedding/DummyEmbeddingModel.java b/solr/modules/llm/src/test/org/apache/solr/llm/embedding/DummyEmbeddingModel.java index 53d8dd882b9..1076f5d0bbf 100644 --- a/solr/modules/llm/src/test/org/apache/solr/llm/embedding/DummyEmbeddingModel.java +++ b/solr/modules/llm/src/test/org/apache/solr/llm/embedding/DummyEmbeddingModel.java @@ -73,4 +73,4 @@ public DummyEmbeddingModel build() { return new DummyEmbeddingModel(this.builderEmbeddings); } } -} \ No newline at end of file +} diff --git a/solr/modules/llm/src/test/org/apache/solr/llm/embedding/DummyEmbeddingModelTest.java b/solr/modules/llm/src/test/org/apache/solr/llm/embedding/DummyEmbeddingModelTest.java index e3f5c85cadf..3aca97d7e42 100644 --- a/solr/modules/llm/src/test/org/apache/solr/llm/embedding/DummyEmbeddingModelTest.java +++ b/solr/modules/llm/src/test/org/apache/solr/llm/embedding/DummyEmbeddingModelTest.java @@ -24,25 +24,25 @@ public class DummyEmbeddingModelTest extends SolrTestCase { @Test public void constructAndEmbed() throws Exception { assertEquals( - "[1.0, 2.0, 3.0, 4.0]", - new DummyEmbeddingModel(new float[] {1, 2, 3, 4}) - .embed("hello") - .content() - .vectorAsList() - .toString()); + "[1.0, 2.0, 3.0, 4.0]", + new DummyEmbeddingModel(new float[] {1, 2, 3, 4}) + .embed("hello") + .content() + .vectorAsList() + .toString()); assertEquals( - "[8.0, 7.0, 6.0, 5.0]", - new DummyEmbeddingModel(new float[] {8, 7, 6, 5}) - .embed("world") - .content() - .vectorAsList() - .toString()); + "[8.0, 7.0, 6.0, 5.0]", + new DummyEmbeddingModel(new float[] {8, 7, 6, 5}) + .embed("world") + .content() + .vectorAsList() + .toString()); assertEquals( - "[0.0, 0.0, 4.0, 2.0]", - new DummyEmbeddingModel(new float[] {0, 0, 4, 2}) - .embed("answer") - .content() - .vectorAsList() - .toString()); + "[0.0, 0.0, 4.0, 2.0]", + new DummyEmbeddingModel(new float[] {0, 0, 4, 2}) + .embed("answer") + .content() + .vectorAsList() + .toString()); } -} \ No newline at end of file +} diff --git a/solr/modules/llm/src/test/org/apache/solr/llm/search/TextEmbedderQParserTest.java b/solr/modules/llm/src/test/org/apache/solr/llm/search/TextEmbedderQParserTest.java index 10df9998aa2..b82463bec05 100644 --- a/solr/modules/llm/src/test/org/apache/solr/llm/search/TextEmbedderQParserTest.java +++ b/solr/modules/llm/src/test/org/apache/solr/llm/search/TextEmbedderQParserTest.java @@ -96,7 +96,8 @@ public void missingVectorFieldParam_shouldThrowException() throws Exception { @Test public void vectorByteEncodingField_shouldRaiseException() throws Exception { - final String solrQuery = "{!text_embedder model=dummy-1 f=vector_byte_encoding topK=5}hello world"; + final String solrQuery = + "{!text_embedder model=dummy-1 f=vector_byte_encoding topK=5}hello world"; final SolrQuery query = new SolrQuery(); query.setQuery(solrQuery); query.add("fl", "id"); @@ -255,7 +256,8 @@ public void embeddedQueryUsedInFilters_rejectIncludeExclude() throws Exception { @Test public void embeddedQueryAsSubQuery() throws Exception { - final String solrQuery = "*:* AND {!text_embedder model=dummy-1 f=vector topK=5 v='hello world'}"; + final String solrQuery = + "*:* AND {!text_embedder model=dummy-1 f=vector topK=5 v='hello world'}"; final SolrQuery query = new SolrQuery(); query.setQuery(solrQuery); query.setFilterQueries("id:(2 4 7 9 8 20 3)");