Skip to content

Commit

Permalink
Vertex AI embedding client fix (#561)
Browse files Browse the repository at this point in the history
  • Loading branch information
maheshrajamani authored Oct 11, 2023
1 parent e49b780 commit ab34252
Showing 1 changed file with 71 additions and 7 deletions.
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
package io.stargate.sgv2.jsonapi.service.embedding.operation;

import com.fasterxml.jackson.annotation.JsonIgnore;
import io.quarkus.rest.client.reactive.QuarkusRestClientBuilder;
import jakarta.ws.rs.HeaderParam;
import jakarta.ws.rs.POST;
import jakarta.ws.rs.Path;
import jakarta.ws.rs.PathParam;
import java.net.URI;
import java.util.List;
import java.util.stream.Collectors;
import org.eclipse.microprofile.rest.client.annotation.ClientHeaderParam;
import org.eclipse.microprofile.rest.client.inject.RegisterRestClient;

Expand All @@ -29,24 +31,86 @@ public interface VertexAIEmbeddingService {
@POST
@Path("/{modelId}:predict")
@ClientHeaderParam(name = "Content-Type", value = "application/json")
List<float[]> embed(
EmbeddingResponse embed(
@HeaderParam("Authorization") String accessToken,
@PathParam("modelId") String modelId,
EmbeddingRequest request);
}

private record EmbeddingRequest(List<Content> instances, Options options) {
private record EmbeddingRequest(List<Content> instances) {
public record Content(String content) {}
}

private static class EmbeddingResponse {
public EmbeddingResponse() {}

private List<Prediction> predictions;

@JsonIgnore private Object metadata;

public List<Prediction> getPredictions() {
return predictions;
}

public void setPredictions(List<Prediction> predictions) {
this.predictions = predictions;
}

public Object getMetadata() {
return metadata;
}

public void setMetadata(Object metadata) {
this.metadata = metadata;
}

protected static class Prediction {
public Prediction() {}

private Embeddings embeddings;

public Embeddings getEmbeddings() {
return embeddings;
}

public void setEmbeddings(Embeddings embeddings) {
this.embeddings = embeddings;
}

protected static class Embeddings {
public Embeddings() {}

private float[] values;

@JsonIgnore private Object statistics;

public float[] getValues() {
return values;
}

public void setValues(float[] values) {
this.values = values;
}

public Object getStatistics() {
return statistics;
}

public record Options(boolean waitForModel) {}
public void setStatistics(Object statistics) {
this.statistics = statistics;
}
}
}
}

@Override
public List<float[]> vectorize(List<String> texts) {
EmbeddingRequest request =
new EmbeddingRequest(
texts.stream().map(t -> new EmbeddingRequest.Content(t)).toList(),
new EmbeddingRequest.Options(true));
return embeddingService.embed("Bearer " + apiKey, modelName, request);
new EmbeddingRequest(texts.stream().map(t -> new EmbeddingRequest.Content(t)).toList());
EmbeddingResponse serviceResponse =
embeddingService.embed("Bearer " + apiKey, modelName, request);
return serviceResponse.getPredictions().stream()
.map(prediction -> prediction.getEmbeddings().getValues())
.collect(Collectors.toList());
}
}

0 comments on commit ab34252

Please sign in to comment.