Skip to content

Commit

Permalink
feat: [vertexai] add fromFunctionResponse in PartMaker (#10272)
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 600847017

Co-authored-by: Jaycee Li <jayceeli@google.com>
  • Loading branch information
copybara-service[bot] and jaycee-li authored Jan 23, 2024
1 parent e761894 commit 20c8252
Show file tree
Hide file tree
Showing 3 changed files with 114 additions and 1 deletion.
2 changes: 1 addition & 1 deletion java-vertexai/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ If you are using Maven with [BOM][libraries-bom], add this to your pom.xml file:
<dependency>
<groupId>com.google.cloud</groupId>
<artifactId>libraries-bom</artifactId>
<version>26.30.0</version>
<version>26.29.0</version>
<type>pom</type>
<scope>import</scope>
</dependency>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,14 @@

import com.google.cloud.vertexai.api.Blob;
import com.google.cloud.vertexai.api.FileData;
import com.google.cloud.vertexai.api.FunctionResponse;
import com.google.cloud.vertexai.api.Part;
import com.google.protobuf.ByteString;
import com.google.protobuf.NullValue;
import com.google.protobuf.Struct;
import com.google.protobuf.Value;
import java.net.URI;
import java.util.Map;

/** Helper class to create {@link com.google.cloud.vertexai.api.Part} */
public class PartMaker {
Expand Down Expand Up @@ -77,4 +82,55 @@ public static Part fromMimeTypeAndData(String mimeType, Object partData) {
}
return part;
}

/**
* Make a {@link com.google.cloud.vertexai.api.Part} from the output of {@link
* com.google.cloud.vertexai.api.FunctionCall}.
*
* @param name a string represents the name of the {@link
* com.google.cloud.vertexai.api.FunctionDeclaration}
* @param response a structured JSON object containing any output from the function call
*/
public static Part fromFunctionResponse(String name, Struct response) {
return Part.newBuilder()
.setFunctionResponse(FunctionResponse.newBuilder().setName(name).setResponse(response))
.build();
}

/**
* Make a {@link com.google.cloud.vertexai.api.Part} from the result output of {@link
* com.google.cloud.vertexai.api.FunctionCall}.
*
* @param name a string represents the name of the {@link
* com.google.cloud.vertexai.api.FunctionDeclaration}
* @param response a map containing the output from the function call, supported output type:
* String, Double, Boolean, null
*/
public static Part fromFunctionResponse(String name, Map<String, Object> response) {
Struct.Builder structBuilder = Struct.newBuilder();
response.forEach(
(key, value) -> {
if (value instanceof String) {
String stringValue = (String) value;
structBuilder.putFields(key, Value.newBuilder().setStringValue(stringValue).build());
} else if (value instanceof Double) {
Double doubleValue = (Double) value;
structBuilder.putFields(key, Value.newBuilder().setNumberValue(doubleValue).build());
} else if (value instanceof Boolean) {
Boolean boolValue = (Boolean) value;
structBuilder.putFields(key, Value.newBuilder().setBoolValue(boolValue).build());
} else if (value == null) {
structBuilder.putFields(
key, Value.newBuilder().setNullValue(NullValue.NULL_VALUE).build());
} else {
throw new IllegalArgumentException(
"The value in the map can only be one of the following format: "
+ "String, Double, Boolean, null.");
}
});

return Part.newBuilder()
.setFunctionResponse(FunctionResponse.newBuilder().setName(name).setResponse(structBuilder))
.build();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,16 @@
package com.google.cloud.vertexai.generativeai.preview;

import static com.google.common.truth.Truth.assertThat;
import static org.junit.Assert.assertThrows;

import com.google.cloud.vertexai.api.Part;
import com.google.protobuf.ByteString;
import com.google.protobuf.Struct;
import com.google.protobuf.Value;
import java.net.URI;
import java.net.URISyntaxException;
import java.util.HashMap;
import java.util.Map;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
Expand Down Expand Up @@ -72,4 +77,56 @@ public void fromMimeTypeAndData_dataInURI() throws URISyntaxException {
assertThat(part.getFileData().getMimeType()).isEqualTo("image/png");
assertThat(part.getFileData().getFileUri()).isEqualTo(fileUri.toString());
}

@Test
public void testFromFunctionResponseWithStruct() {
String functionName = "getCurrentWeather";
Struct functionResponse =
Struct.newBuilder()
.putFields("currentWeather", Value.newBuilder().setStringValue("Super nice!").build())
.putFields("currentTemperature", Value.newBuilder().setNumberValue(85.0).build())
.putFields("isRaining", Value.newBuilder().setBoolValue(false).build())
.build();

Part part = PartMaker.fromFunctionResponse(functionName, functionResponse);

assertThat(part.getFunctionResponse().getName()).isEqualTo("getCurrentWeather");
assertThat(part.getFunctionResponse().getResponse()).isEqualTo(functionResponse);
}

@Test
public void testFromFunctionResponseWithMap() {
String functionName = "getCurrentWeather";
Map<String, Object> functionResponse = new HashMap<>();
functionResponse.put("currentWeather", "Super nice!");
functionResponse.put("currentTemperature", 85.0);
functionResponse.put("isRaining", false);
functionResponse.put("other", null);

Part part = PartMaker.fromFunctionResponse(functionName, functionResponse);

assertThat(part.getFunctionResponse().getName()).isEqualTo("getCurrentWeather");

Map<String, Value> fieldsMap = part.getFunctionResponse().getResponse().getFieldsMap();
assertThat(fieldsMap.get("currentWeather").getStringValue()).isEqualTo("Super nice!");
assertThat(fieldsMap.get("currentTemperature").getNumberValue()).isEqualTo(85.0);
assertThat(fieldsMap.get("isRaining").getBoolValue()).isEqualTo(false);
assertThat(fieldsMap.get("other").hasNullValue()).isEqualTo(true);
}

@Test
public void testFromFunctionResponseWithInvalidMap() {
String functionName = "getCurrentWeather";
Map<String, Object> invalidResponse = new HashMap<>();
invalidResponse.put("currentWeather", new byte[] {1, 2, 3});
IllegalArgumentException thrown =
assertThrows(
IllegalArgumentException.class,
() -> PartMaker.fromFunctionResponse(functionName, invalidResponse));
assertThat(thrown)
.hasMessageThat()
.isEqualTo(
"The value in the map can only be one of the following format: "
+ "String, Double, Boolean, null.");
}
}

0 comments on commit 20c8252

Please sign in to comment.