Skip to content

Commit

Permalink
added comments to make cleare the inversion of control part + gradle …
Browse files Browse the repository at this point in the history
…tidy
  • Loading branch information
alessandrobenedetti committed Nov 18, 2024
1 parent 7e10707 commit af1cb50
Show file tree
Hide file tree
Showing 4 changed files with 41 additions and 27 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -42,10 +42,22 @@ public class SolrEmbeddingModel implements Accountable {
public static SolrEmbeddingModel getInstance(
String className, String name, Map<String, Object> params) throws EmbeddingModelException {
try {
/*
* The idea herea is to build a {@link dev.langchain4j.model.embedding.EmbeddingModel} using inversion
* of control.
* Each model has its own list of parameters we don't know beforehand, but each {@link dev.langchain4j.model.embedding.EmbeddingModel} class
* has its own builder that uses setters with the same name of the parameter in input.
* */
EmbeddingModel embedder;
Class<?> modelClass = Class.forName(className);
var builder = modelClass.getMethod("builder").invoke(null);
if (params != null) {
/**
* Some {@link dev.langchain4j.model.embedding.EmbeddingModel} classes have params of
* specific types that must be constructed, for primitive types we can resort to the
* default. N.B. when adding support to new models, pay attention to all the parameters they
* support, some of them may require to be handled in here as separate switch cases
*/
for (String paramName : params.keySet()) {
switch (paramName) {
case TIMEOUT_PARAM:
Expand All @@ -65,14 +77,14 @@ public static SolrEmbeddingModel getInstance(
.invoke(builder, ((Long) params.get(paramName)).intValue());
break;
default:
ArrayList<Method> methods = new ArrayList<>();
ArrayList<Method> paramNameMatches = new ArrayList<>();
for (var method : builder.getClass().getMethods()) {
if (paramName.equals(method.getName()) && method.getParameterCount() == 1) {
methods.add(method);
paramNameMatches.add(method);
}
}
if (methods.size() == 1) {
methods.get(0).invoke(builder, params.get(paramName));
if (paramNameMatches.size() == 1) {
paramNameMatches.get(0).invoke(builder, params.get(paramName));
} else {
builder
.getClass()
Expand Down Expand Up @@ -141,7 +153,7 @@ public String getName() {
public String getEmbedderClassName() {
return embedder.getClass().getName();
}

public Map<String, Object> getParams() {
return params;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,4 +73,4 @@ public DummyEmbeddingModel build() {
return new DummyEmbeddingModel(this.builderEmbeddings);
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -24,25 +24,25 @@ public class DummyEmbeddingModelTest extends SolrTestCase {
@Test
public void constructAndEmbed() throws Exception {
assertEquals(
"[1.0, 2.0, 3.0, 4.0]",
new DummyEmbeddingModel(new float[] {1, 2, 3, 4})
.embed("hello")
.content()
.vectorAsList()
.toString());
"[1.0, 2.0, 3.0, 4.0]",
new DummyEmbeddingModel(new float[] {1, 2, 3, 4})
.embed("hello")
.content()
.vectorAsList()
.toString());
assertEquals(
"[8.0, 7.0, 6.0, 5.0]",
new DummyEmbeddingModel(new float[] {8, 7, 6, 5})
.embed("world")
.content()
.vectorAsList()
.toString());
"[8.0, 7.0, 6.0, 5.0]",
new DummyEmbeddingModel(new float[] {8, 7, 6, 5})
.embed("world")
.content()
.vectorAsList()
.toString());
assertEquals(
"[0.0, 0.0, 4.0, 2.0]",
new DummyEmbeddingModel(new float[] {0, 0, 4, 2})
.embed("answer")
.content()
.vectorAsList()
.toString());
"[0.0, 0.0, 4.0, 2.0]",
new DummyEmbeddingModel(new float[] {0, 0, 4, 2})
.embed("answer")
.content()
.vectorAsList()
.toString());
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,8 @@ public void missingVectorFieldParam_shouldThrowException() throws Exception {

@Test
public void vectorByteEncodingField_shouldRaiseException() throws Exception {
final String solrQuery = "{!text_embedder model=dummy-1 f=vector_byte_encoding topK=5}hello world";
final String solrQuery =
"{!text_embedder model=dummy-1 f=vector_byte_encoding topK=5}hello world";
final SolrQuery query = new SolrQuery();
query.setQuery(solrQuery);
query.add("fl", "id");
Expand Down Expand Up @@ -255,7 +256,8 @@ public void embeddedQueryUsedInFilters_rejectIncludeExclude() throws Exception {

@Test
public void embeddedQueryAsSubQuery() throws Exception {
final String solrQuery = "*:* AND {!text_embedder model=dummy-1 f=vector topK=5 v='hello world'}";
final String solrQuery =
"*:* AND {!text_embedder model=dummy-1 f=vector topK=5 v='hello world'}";
final SolrQuery query = new SolrQuery();
query.setQuery(solrQuery);
query.setFilterQueries("id:(2 4 7 9 8 20 3)");
Expand Down

0 comments on commit af1cb50

Please sign in to comment.