diff --git a/java-vertexai/google-cloud-vertexai/src/main/java/com/google/cloud/vertexai/generativeai/AutomaticFunctionCallingResponder.java b/java-vertexai/google-cloud-vertexai/src/main/java/com/google/cloud/vertexai/generativeai/AutomaticFunctionCallingResponder.java new file mode 100644 index 000000000000..67da2e63e990 --- /dev/null +++ b/java-vertexai/google-cloud-vertexai/src/main/java/com/google/cloud/vertexai/generativeai/AutomaticFunctionCallingResponder.java @@ -0,0 +1,283 @@ +/* + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.google.cloud.vertexai.generativeai; + +import com.google.cloud.vertexai.api.Content; +import com.google.cloud.vertexai.api.FunctionCall; +import com.google.common.collect.ImmutableList; +import com.google.protobuf.Struct; +import com.google.protobuf.Value; +import java.lang.reflect.Method; +import java.lang.reflect.Modifier; +import java.lang.reflect.Parameter; +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.logging.Logger; + +/** A responder that automatically calls functions when requested by the GenAI model. */ +public final class AutomaticFunctionCallingResponder { + private int maxFunctionCalls = 1; + private int remainingFunctionCalls; + private final Map callableFunctions = new HashMap<>(); + + private static final Logger logger = + Logger.getLogger(AutomaticFunctionCallingResponder.class.getName()); + + /** Constructs an AutomaticFunctionCallingResponder instance. */ + public AutomaticFunctionCallingResponder() { + this.remainingFunctionCalls = this.maxFunctionCalls; + } + + /** + * Constructs an AutomaticFunctionCallingResponder instance. + * + * @param maxFunctionCalls the maximum number of function calls to make in a row + */ + public AutomaticFunctionCallingResponder(int maxFunctionCalls) { + this.maxFunctionCalls = maxFunctionCalls; + this.remainingFunctionCalls = maxFunctionCalls; + } + + /** Sets the maximum number of function calls to make in a row. */ + public void setMaxFunctionCalls(int maxFunctionCalls) { + this.maxFunctionCalls = maxFunctionCalls; + this.remainingFunctionCalls = this.maxFunctionCalls; + } + + /** Gets the maximum number of function calls to make in a row. */ + public int getMaxFunctionCalls() { + return maxFunctionCalls; + } + + /** Resets the remaining function calls to the maximum number of function calls. */ + void resetRemainingFunctionCalls() { + this.remainingFunctionCalls = this.maxFunctionCalls; + } + + /** + * Adds a callable function to the AutomaticFunctionCallingResponder. + * + *

Note:: This method requires you to compile your code with the "-parameters" flag, so + * that the parameter names can be retrieved. If you don't want to do this, you can use the + * `addCallableFunction(String, Method, String...)` method instead. + * + * @param functionName the name of the function + * @param callableFunction the method to call when the functionName is requested + * @throws IllegalArgumentException if the functionName is already in the responder + */ + public void addCallableFunction(String functionName, Method callableFunction) { + if (callableFunctions.containsKey(functionName)) { + throw new IllegalArgumentException("Duplicate function name: " + functionName); + } else { + callableFunctions.put(functionName, new CallableFunction(callableFunction)); + } + } + + /** + * Adds a callable function to the AutomaticFunctionCallingResponder. + * + * @param functionName the name of the function + * @param callableFunction the method to call when the functionName is requested + * @param orderedParameterNames the names of the parameters in the order they are passed to the + * function + * @throws IllegalArgumentException if the functionName is already in the responder + */ + public void addCallableFunction( + String functionName, Method callableFunction, String... orderedParameterNames) { + if (callableFunctions.containsKey(functionName)) { + throw new IllegalArgumentException("Duplicate function name: " + functionName); + } else { + callableFunctions.put( + functionName, new CallableFunction(callableFunction, orderedParameterNames)); + } + } + + /** + * Automatically calls functions requested by the model and generates a Content that contains the + * results. + * + * @param functionCalls a list of {@link com.google.cloud.vertexai.api.FunctionCall} requested by + * the model + * @return a {@link com.google.cloud.vertexai.api.Content} that contains the results of the + * function calls + * @throws IllegalStateException if the number of automatic calls exceeds the maximum number of + * function calls + * @throws IllegalArgumentException if the model has asked to call a function that was not found + * in the responder + */ + Content getContentFromFunctionCalls(List functionCalls) { + List responseParts = new ArrayList<>(); + + for (FunctionCall functionCall : functionCalls) { + logger.info("functionCall requested from the model: " + functionCall); + if (remainingFunctionCalls <= 0) { + throw new IllegalStateException( + "Exceeded the maximum number of continuous automatic function calls (" + + maxFunctionCalls + + "). If more automatic function calls are needed, please call" + + " `setMaxFunctionCalls() to set a higher number. The last function call is:\n" + + functionCall); + } + remainingFunctionCalls -= 1; + String functionName = functionCall.getName(); + CallableFunction callableFunction = callableFunctions.get(functionName); + if (callableFunction == null) { + throw new IllegalArgumentException( + "Model has asked to call function \"" + functionName + "\" which was not found."); + } + responseParts.add( + PartMaker.fromFunctionResponse( + functionName, + Collections.singletonMap("result", callableFunction.call(functionCall.getArgs())))); + } + + return ContentMaker.fromMultiModalData(responseParts.toArray()); + } + + /** A class that represents a function that can be called automatically. */ + static class CallableFunction { + private final Method callableFunction; + private final ImmutableList orderedParameterNames; + + /** + * Constructs a CallableFunction instance. + * + *

Note:: This method requires you to compile your code with the "-parameters" flag, + * so that the parameter names can be retrieved. If you don't want to do this, you can use the + * `CallableFunction(Method, String...)` constructor instead. + * + * @param callableFunction the method to call + * @throws IllegalArgumentException if the given method is not a static method + * @throws IllegalStateException if the parameter names cannot be retrieved from reflection + */ + CallableFunction(Method callableFunction) { + validateFunction(callableFunction); + this.callableFunction = callableFunction; + ImmutableList.Builder builder = ImmutableList.builder(); + for (Parameter parameter : callableFunction.getParameters()) { + if (parameter.isNamePresent()) { + builder.add(parameter.getName()); + } else { + throw new IllegalStateException( + "Failed to retrieve the parameter name from reflection. Please compile your code with" + + " \"-parameters\" flag or use `addCallableFunction(String, Method, String...)`" + + " to manually enter parameter names"); + } + } + this.orderedParameterNames = builder.build(); + } + + /** + * Constructs a CallableFunction instance. + * + * @param callableFunction the method to call + * @param orderedParameterNames the names of the parameters in the order they are passed to the + * function + * @throws IllegalArgumentException if the given method is not a static method or the number of + * provided parameter names doesn't match the number of parameters in the callable function + */ + CallableFunction(Method callableFunction, String... orderedParameterNames) { + validateFunction(callableFunction); + if (orderedParameterNames.length != callableFunction.getParameters().length) { + throw new IllegalArgumentException( + "The number of provided parameter names doesn't match the number of parameters in the" + + " callable function."); + } + this.callableFunction = callableFunction; + this.orderedParameterNames = ImmutableList.copyOf(orderedParameterNames); + } + + /** + * Calls the callable function with the given arguments. + * + * @param args the arguments to pass to the function + * @return the result of the function call + * @throws IllegalStateException if there are errors when invoking the function + * @throws IllegalArgumentException if the args map doesn't contain all the parameters of the + * function or the value types in the args map are not supported + */ + Object call(Struct args) { + // Extract the arguments from the Struct + Map argsMap = args.getFieldsMap(); + List argsList = new ArrayList<>(); + for (int i = 0; i < orderedParameterNames.size(); i++) { + String parameterName = orderedParameterNames.get(i); + if (!argsMap.containsKey(parameterName)) { + throw new IllegalArgumentException( + "The parameter \"" + + parameterName + + "\" was not found in the arguments requested by the model. Args map: " + + argsMap); + } + Value value = argsMap.get(parameterName); + switch (value.getKindCase()) { + case NUMBER_VALUE: + // Args map only returns double values, but the function may expect other types(int, + // float). So we need to cast the value to the correct type. + Class parameterType = callableFunction.getParameters()[i].getType(); + if (parameterType.equals(int.class)) { + argsList.add((int) value.getNumberValue()); + } else if (parameterType.equals(float.class)) { + argsList.add((float) value.getNumberValue()); + } else { + argsList.add(value.getNumberValue()); + } + break; + case STRING_VALUE: + argsList.add(value.getStringValue()); + break; + case BOOL_VALUE: + argsList.add(value.getBoolValue()); + break; + case NULL_VALUE: + argsList.add(null); + break; + default: + throw new IllegalArgumentException( + "Unsupported value type " + + value.getKindCase() + + " for parameter " + + parameterName); + } + } + + // Invoke the function + logger.info( + "Automatically calling function: " + + callableFunction.getName() + + argsList.toString().replace('[', '(').replace(']', ')')); + try { + return callableFunction.invoke(null, argsList.toArray()); + } catch (Exception e) { + throw new IllegalStateException( + "Error raised when calling function \"" + + callableFunction.getName() + + "\" as requested by the model. ", + e); + } + } + + /** Validates that the given method is a static method. */ + private void validateFunction(Method method) { + if (!Modifier.isStatic(method.getModifiers())) { + throw new IllegalArgumentException("Function calling only supports static methods."); + } + } + } +} diff --git a/java-vertexai/google-cloud-vertexai/src/main/java/com/google/cloud/vertexai/generativeai/ChatSession.java b/java-vertexai/google-cloud-vertexai/src/main/java/com/google/cloud/vertexai/generativeai/ChatSession.java index 3e5f9503def2..424cd053f0e3 100644 --- a/java-vertexai/google-cloud-vertexai/src/main/java/com/google/cloud/vertexai/generativeai/ChatSession.java +++ b/java-vertexai/google-cloud-vertexai/src/main/java/com/google/cloud/vertexai/generativeai/ChatSession.java @@ -19,10 +19,12 @@ import static com.google.cloud.vertexai.generativeai.ResponseHandler.aggregateStreamIntoResponse; import static com.google.cloud.vertexai.generativeai.ResponseHandler.getContent; import static com.google.cloud.vertexai.generativeai.ResponseHandler.getFinishReason; +import static com.google.cloud.vertexai.generativeai.ResponseHandler.getFunctionCalls; import static com.google.common.base.Preconditions.checkNotNull; import com.google.cloud.vertexai.api.Candidate.FinishReason; import com.google.cloud.vertexai.api.Content; +import com.google.cloud.vertexai.api.FunctionCall; import com.google.cloud.vertexai.api.GenerateContentResponse; import com.google.cloud.vertexai.api.GenerationConfig; import com.google.cloud.vertexai.api.SafetySetting; @@ -37,6 +39,7 @@ public final class ChatSession { private final GenerativeModel model; private final Optional rootChatSession; + private final Optional automaticFunctionCallingResponder; private List history = new ArrayList<>(); private int previousHistorySize = 0; private Optional> currentResponseStream; @@ -47,7 +50,7 @@ public final class ChatSession { * GenerationConfig) inherits from the model. */ public ChatSession(GenerativeModel model) { - this(model, Optional.empty()); + this(model, Optional.empty(), Optional.empty()); } /** @@ -57,12 +60,18 @@ public ChatSession(GenerativeModel model) { * @param model a {@link GenerativeModel} instance that generates contents in the chat. * @param rootChatSession a root {@link ChatSession} instance. All the chat history in the current * chat session will be merged to the root chat session. + * @param automaticFunctionCallingResponder an {@link AutomaticFunctionCallingResponder} instance + * that can automatically respond to function calls requested by the model. * @return a {@link ChatSession} instance. */ - private ChatSession(GenerativeModel model, Optional rootChatSession) { + private ChatSession( + GenerativeModel model, + Optional rootChatSession, + Optional automaticFunctionCallingResponder) { checkNotNull(model, "model should not be null"); this.model = model; this.rootChatSession = rootChatSession; + this.automaticFunctionCallingResponder = automaticFunctionCallingResponder; currentResponseStream = Optional.empty(); currentResponse = Optional.empty(); } @@ -77,7 +86,10 @@ private ChatSession(GenerativeModel model, Optional rootChatSession public ChatSession withGenerationConfig(GenerationConfig generationConfig) { ChatSession rootChat = rootChatSession.orElse(this); ChatSession newChatSession = - new ChatSession(model.withGenerationConfig(generationConfig), Optional.of(rootChat)); + new ChatSession( + model.withGenerationConfig(generationConfig), + Optional.of(rootChat), + automaticFunctionCallingResponder); newChatSession.history = history; newChatSession.previousHistorySize = previousHistorySize; return newChatSession; @@ -93,7 +105,10 @@ public ChatSession withGenerationConfig(GenerationConfig generationConfig) { public ChatSession withSafetySettings(List safetySettings) { ChatSession rootChat = rootChatSession.orElse(this); ChatSession newChatSession = - new ChatSession(model.withSafetySettings(safetySettings), Optional.of(rootChat)); + new ChatSession( + model.withSafetySettings(safetySettings), + Optional.of(rootChat), + automaticFunctionCallingResponder); newChatSession.history = history; newChatSession.previousHistorySize = previousHistorySize; return newChatSession; @@ -108,7 +123,28 @@ public ChatSession withSafetySettings(List safetySettings) { */ public ChatSession withTools(List tools) { ChatSession rootChat = rootChatSession.orElse(this); - ChatSession newChatSession = new ChatSession(model.withTools(tools), Optional.of(rootChat)); + ChatSession newChatSession = + new ChatSession( + model.withTools(tools), Optional.of(rootChat), automaticFunctionCallingResponder); + newChatSession.history = history; + newChatSession.previousHistorySize = previousHistorySize; + return newChatSession; + } + + /** + * Creates a copy of the current ChatSession with updated AutomaticFunctionCallingResponder. + * + * @param automaticFunctionCallingResponder an {@link AutomaticFunctionCallingResponder} instance + * that will be used in the new ChatSession. + * @return a new {@link ChatSession} instance with the specified + * AutomaticFunctionCallingResponder. + */ + public ChatSession withAutomaticFunctionCallingResponder( + AutomaticFunctionCallingResponder automaticFunctionCallingResponder) { + ChatSession rootChat = rootChatSession.orElse(this); + ChatSession newChatSession = + new ChatSession( + model, Optional.of(rootChat), Optional.of(automaticFunctionCallingResponder)); newChatSession.history = history; newChatSession.previousHistorySize = previousHistorySize; return newChatSession; @@ -141,12 +177,13 @@ public ResponseStream sendMessageStream(Content content try { respStream = model.generateContentStream(history); } catch (IOException e) { - // If the API call fails, remove the last content from the history before throwing. + // If the API call fails, revert the history before throwing. revertHistory(); throw e; } setCurrentResponseStream(Optional.of(respStream)); + // TODO(jayceeli) enable AFC in sendMessageStream return respStream; } @@ -169,17 +206,46 @@ public GenerateContentResponse sendMessage(String text) throws IOException { public GenerateContentResponse sendMessage(Content content) throws IOException { checkLastResponseAndEditHistory(); history.add(content); - GenerateContentResponse response; try { response = model.generateContent(history); - } catch (IOException e) { - // If the API call fails, remove the last content from the history before throwing. + setCurrentResponse(Optional.of(response)); + return autoRespond(response); + } catch (Exception e) { + // If any step fails, reset the history and current response. + checkLastResponseAndEditHistory(); revertHistory(); throw e; } - setCurrentResponse(Optional.of(response)); + } + + /** + * Automatically responds to the model if there is an AutomaticFunctionCallingResponder and + * model's response contains function calls. + */ + private GenerateContentResponse autoRespond(GenerateContentResponse originalResponse) + throws IOException { + // Return the original response if there is no AFC responder or no function calls in the + // response. + if (!automaticFunctionCallingResponder.isPresent()) { + return originalResponse; + } + ImmutableList functionCalls = getFunctionCalls(originalResponse); + if (functionCalls.isEmpty()) { + return originalResponse; + } + setPreviousHistorySize(getPreviousHistorySize() - 2); + GenerateContentResponse response; + try { + // Let the responder generate the response content and send it to the model. + Content autoRespondedContent = + automaticFunctionCallingResponder.get().getContentFromFunctionCalls(functionCalls); + response = sendMessage(autoRespondedContent); + } finally { + // Reset the responder whether it succeeds or fails. + automaticFunctionCallingResponder.get().resetRemainingFunctionCalls(); + } return response; } @@ -200,7 +266,7 @@ private void checkLastResponseAndEditHistory() { setCurrentResponse(Optional.empty()); checkFinishReasonAndEditHistory(currentResponse); history.add(getContent(currentResponse)); - setPreviousHistorySize(history.size()); + setPreviousHistorySize(getPreviousHistorySize() + 2); }); getCurrentResponseStream() .ifPresent( @@ -212,7 +278,7 @@ private void checkLastResponseAndEditHistory() { GenerateContentResponse response = aggregateStreamIntoResponse(responseStream); checkFinishReasonAndEditHistory(response); history.add(getContent(response)); - setPreviousHistorySize(history.size()); + setPreviousHistorySize(getPreviousHistorySize() + 2); } }); } diff --git a/java-vertexai/google-cloud-vertexai/src/main/java/com/google/cloud/vertexai/generativeai/FunctionDeclarationMaker.java b/java-vertexai/google-cloud-vertexai/src/main/java/com/google/cloud/vertexai/generativeai/FunctionDeclarationMaker.java index c118df5b0d6f..b38ae344f0cf 100644 --- a/java-vertexai/google-cloud-vertexai/src/main/java/com/google/cloud/vertexai/generativeai/FunctionDeclarationMaker.java +++ b/java-vertexai/google-cloud-vertexai/src/main/java/com/google/cloud/vertexai/generativeai/FunctionDeclarationMaker.java @@ -19,10 +19,17 @@ import static com.google.common.base.Preconditions.checkNotNull; import com.google.cloud.vertexai.api.FunctionDeclaration; +import com.google.cloud.vertexai.api.Schema; +import com.google.cloud.vertexai.api.Type; import com.google.common.base.Strings; import com.google.gson.JsonObject; import com.google.protobuf.InvalidProtocolBufferException; import com.google.protobuf.util.JsonFormat; +import java.lang.reflect.Method; +import java.lang.reflect.Modifier; +import java.lang.reflect.Parameter; +import java.util.ArrayList; +import java.util.List; /** Helper class to create {@link com.google.cloud.vertexai.api.FunctionDeclaration} */ public final class FunctionDeclarationMaker { @@ -60,4 +67,99 @@ public static FunctionDeclaration fromJsonObject(JsonObject jsonObject) checkNotNull(jsonObject, "JsonObject can't be null."); return fromJsonString(jsonObject.toString()); } + + /** + * Creates a FunctionDeclaration from a Java static method + * + *

Note:: This method requires you to compile your code with the "-parameters" flag, so + * that the parameter names can be retrieved. If you don't want to do this, you can use the + * `fromFunc(String, Method, String...)` method instead. + * + * @param functionDescription A description of the method. + * @param function A Java method. + * @return a {@link com.google.cloud.vertexai.api.FunctionDeclaration} instance. + * @throws IllegalArgumentException if the method is not a static method or parameter types in + * this method are not String, boolean, int, double, or float. + */ + public static FunctionDeclaration fromFunc(String functionDescription, Method function) { + List orderedParameterNames = new ArrayList<>(); + for (Parameter parameter : function.getParameters()) { + if (!parameter.isNamePresent()) { + throw new IllegalStateException( + "Failed to retrieve the parameter name from reflection. Please compile your" + + " code with \"-parameters\" flag or use `fromFunc(String, Method," + + " String...)` to manually enter parameter names"); + } + orderedParameterNames.add(parameter.getName()); + } + return fromFunc(functionDescription, function, orderedParameterNames.toArray(new String[0])); + } + + /** + * Creates a FunctionDeclaration from a Java static method + * + * @param functionDescription A description of the method. + * @param function A Java method. + * @param orderedParameterNames A list of parameter names in the order they are passed to the + * method. + * @return a {@link com.google.cloud.vertexai.api.FunctionDeclaration} instance. + * @throws IllegalArgumentException if the method is not a static method or the number of provided + * parameter names doesn't match the number of parameters in the callable function or + * parameter types in this method are not String, boolean, int, double, or float. + */ + public static FunctionDeclaration fromFunc( + String functionDescription, Method function, String... orderedParameterNames) { + if (!Modifier.isStatic(function.getModifiers())) { + throw new IllegalArgumentException( + "Instance methods are not supported. Please use static methods."); + } + Schema.Builder parametersBuilder = Schema.newBuilder().setType(Type.OBJECT); + + Parameter[] parameters = function.getParameters(); + if (parameters.length != orderedParameterNames.length) { + throw new IllegalArgumentException( + "The number of parameter names does not match the number of parameters in the method."); + } + + for (int i = 0; i < parameters.length; i++) { + addParameterToParametersBuilder( + parametersBuilder, orderedParameterNames[i], parameters[i].getType()); + } + + return FunctionDeclaration.newBuilder() + .setName(function.getName()) + .setDescription(functionDescription) + .setParameters(parametersBuilder.build()) + .build(); + } + + /** Adds a parameter to the parameters builder. */ + private static void addParameterToParametersBuilder( + Schema.Builder parametersBuilder, String parameterName, Class parameterType) { + Schema.Builder parameterBuilder = Schema.newBuilder().setDescription(parameterName); + switch (parameterType.getName()) { + case "java.lang.String": + parameterBuilder.setType(Type.STRING); + break; + case "boolean": + parameterBuilder.setType(Type.BOOLEAN); + break; + case "int": + parameterBuilder.setType(Type.INTEGER); + break; + case "double": + case "float": + parameterBuilder.setType(Type.NUMBER); + break; + default: + throw new IllegalArgumentException( + "Unsupported parameter type " + + parameterType.getName() + + " for parameter " + + parameterName); + } + parametersBuilder + .addRequired(parameterName) + .putProperties(parameterName, parameterBuilder.build()); + } } diff --git a/java-vertexai/google-cloud-vertexai/src/test/java/com/google/cloud/vertexai/generativeai/AutomaticFunctionCallingResponderTest.java b/java-vertexai/google-cloud-vertexai/src/test/java/com/google/cloud/vertexai/generativeai/AutomaticFunctionCallingResponderTest.java new file mode 100644 index 000000000000..eefd1993456d --- /dev/null +++ b/java-vertexai/google-cloud-vertexai/src/test/java/com/google/cloud/vertexai/generativeai/AutomaticFunctionCallingResponderTest.java @@ -0,0 +1,305 @@ +/* + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.cloud.vertexai.generativeai; + +import static com.google.common.truth.Truth.assertThat; +import static org.junit.Assert.assertThrows; + +import com.google.cloud.vertexai.api.Content; +import com.google.cloud.vertexai.api.FunctionCall; +import com.google.protobuf.Struct; +import com.google.protobuf.Value; +import java.lang.reflect.Method; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public final class AutomaticFunctionCallingResponderTest { + private static final int MAX_FUNCTION_CALLS = 5; + private static final int DEFAULT_MAX_FUNCTION_CALLS = 1; + private static final String FUNCTION_NAME_1 = "getCurrentWeather"; + private static final String FUNCTION_NAME_2 = "getCurrentTemperature"; + private static final String STRING_PARAMETER_NAME = "location"; + private static final FunctionCall FUNCTION_CALL_1 = + FunctionCall.newBuilder() + .setName(FUNCTION_NAME_1) + .setArgs( + Struct.newBuilder() + .putFields( + STRING_PARAMETER_NAME, Value.newBuilder().setStringValue("Boston").build())) + .build(); + private static final FunctionCall FUNCTION_CALL_2 = + FunctionCall.newBuilder() + .setName(FUNCTION_NAME_2) + .setArgs( + Struct.newBuilder() + .putFields( + STRING_PARAMETER_NAME, + Value.newBuilder().setStringValue("Vancouver").build())) + .build(); + private static final FunctionCall FUNCTION_CALL_WITH_FALSE_FUNCTION_NAME = + FunctionCall.newBuilder() + .setName("nonExistFunction") + .setArgs( + Struct.newBuilder() + .putFields( + STRING_PARAMETER_NAME, Value.newBuilder().setStringValue("Boston").build())) + .build(); + private static final FunctionCall FUNCTION_CALL_WITH_FALSE_PARAMETER_NAME = + FunctionCall.newBuilder() + .setName(FUNCTION_NAME_1) + .setArgs( + Struct.newBuilder() + .putFields( + "nonExistParameter", Value.newBuilder().setStringValue("Boston").build())) + .build(); + private static final FunctionCall FUNCTION_CALL_WITH_FALSE_PARAMETER_VALUE = + FunctionCall.newBuilder() + .setName(FUNCTION_NAME_1) + .setArgs( + Struct.newBuilder() + .putFields(STRING_PARAMETER_NAME, Value.newBuilder().setBoolValue(false).build())) + .build(); + + public static String getCurrentWeather(String location) { + if (location.equals("Boston")) { + return "snowing"; + } else if (location.equals("Vancouver")) { + return "raining"; + } else { + return "sunny"; + } + } + + public static int getCurrentTemperature(String location) { + if (location.equals("Boston")) { + return 32; + } else if (location.equals("Vancouver")) { + return 45; + } else { + return 75; + } + } + + public boolean nonStaticMethod() { + return true; + } + + @Test + public void testInitAutomaticFunctionCallingResponder_containsRightFields() { + AutomaticFunctionCallingResponder responder = new AutomaticFunctionCallingResponder(); + + assertThat(responder.getMaxFunctionCalls()).isEqualTo(DEFAULT_MAX_FUNCTION_CALLS); + } + + @Test + public void testInitAutomaticFunctionCallingResponderWithMaxFunctionCalls_containsRightFields() { + AutomaticFunctionCallingResponder responder = + new AutomaticFunctionCallingResponder(MAX_FUNCTION_CALLS); + + assertThat(responder.getMaxFunctionCalls()).isEqualTo(MAX_FUNCTION_CALLS); + } + + @Test + public void testSetMaxFunctionCalls_containsRightFields() { + AutomaticFunctionCallingResponder responder = new AutomaticFunctionCallingResponder(); + responder.setMaxFunctionCalls(MAX_FUNCTION_CALLS); + + assertThat(responder.getMaxFunctionCalls()).isEqualTo(MAX_FUNCTION_CALLS); + } + + @Test + public void testAddCallableFunctionWithoutOrderedParameterNames_throwsIllegalArgumentException() + throws NoSuchMethodException { + AutomaticFunctionCallingResponder responder = new AutomaticFunctionCallingResponder(); + Method callableFunction = + AutomaticFunctionCallingResponderTest.class.getMethod(FUNCTION_NAME_1, String.class); + + IllegalStateException thrown = + assertThrows( + IllegalStateException.class, + () -> responder.addCallableFunction(FUNCTION_NAME_1, callableFunction)); + assertThat(thrown) + .hasMessageThat() + .isEqualTo( + "Failed to retrieve the parameter name from reflection. Please compile your code with" + + " \"-parameters\" flag or use `addCallableFunction(String, Method, String...)`" + + " to manually enter parameter names"); + } + + @Test + public void testAddNonStaticCallableFunction_throwsIllegalArgumentException() + throws NoSuchMethodException { + AutomaticFunctionCallingResponder responder = new AutomaticFunctionCallingResponder(); + Method nonStaticMethod = + AutomaticFunctionCallingResponderTest.class.getMethod("nonStaticMethod"); + + IllegalArgumentException thrown = + assertThrows( + IllegalArgumentException.class, + () -> + responder.addCallableFunction( + FUNCTION_NAME_1, nonStaticMethod, STRING_PARAMETER_NAME)); + assertThat(thrown).hasMessageThat().isEqualTo("Function calling only supports static methods."); + } + + @Test + public void testAddRepeatedCallableFunction_throwsIllegalArgumentException() + throws NoSuchMethodException { + AutomaticFunctionCallingResponder responder = new AutomaticFunctionCallingResponder(); + Method callableFunction = + AutomaticFunctionCallingResponderTest.class.getMethod(FUNCTION_NAME_1, String.class); + responder.addCallableFunction(FUNCTION_NAME_1, callableFunction, STRING_PARAMETER_NAME); + + IllegalArgumentException thrown = + assertThrows( + IllegalArgumentException.class, + () -> + responder.addCallableFunction( + FUNCTION_NAME_1, callableFunction, STRING_PARAMETER_NAME)); + assertThat(thrown).hasMessageThat().isEqualTo("Duplicate function name: " + FUNCTION_NAME_1); + } + + @Test + public void testAddCallableFunctionWithWrongParameterNames_throwsIllegalArgumentException() + throws NoSuchMethodException { + AutomaticFunctionCallingResponder responder = new AutomaticFunctionCallingResponder(); + Method callableFunction = + AutomaticFunctionCallingResponderTest.class.getMethod(FUNCTION_NAME_1, String.class); + + IllegalArgumentException thrown = + assertThrows( + IllegalArgumentException.class, + () -> + responder.addCallableFunction( + FUNCTION_NAME_1, callableFunction, STRING_PARAMETER_NAME, "anotherParameter")); + assertThat(thrown) + .hasMessageThat() + .isEqualTo( + "The number of provided parameter names doesn't match the number of parameters in the" + + " callable function."); + } + + @Test + public void testRespondToFunctionCall_returnsCorrectResponse() throws Exception { + AutomaticFunctionCallingResponder responder = new AutomaticFunctionCallingResponder(2); + Method callableFunction1 = + AutomaticFunctionCallingResponderTest.class.getMethod(FUNCTION_NAME_1, String.class); + Method callableFunction2 = + AutomaticFunctionCallingResponderTest.class.getMethod(FUNCTION_NAME_2, String.class); + responder.addCallableFunction(FUNCTION_NAME_1, callableFunction1, STRING_PARAMETER_NAME); + responder.addCallableFunction(FUNCTION_NAME_2, callableFunction2, STRING_PARAMETER_NAME); + List functionCalls = Arrays.asList(FUNCTION_CALL_1, FUNCTION_CALL_2); + + Content response = responder.getContentFromFunctionCalls(functionCalls); + + Content expectedResponse = + ContentMaker.fromMultiModalData( + PartMaker.fromFunctionResponse( + FUNCTION_NAME_1, Collections.singletonMap("result", "snowing")), + PartMaker.fromFunctionResponse( + FUNCTION_NAME_2, Collections.singletonMap("result", 45))); + + assertThat(response).isEqualTo(expectedResponse); + } + + @Test + public void testRespondToFunctionCallExceedsMaxFunctionCalls_throwsIllegalStateException() + throws Exception { + AutomaticFunctionCallingResponder responder = new AutomaticFunctionCallingResponder(); + Method callableFunction1 = + AutomaticFunctionCallingResponderTest.class.getMethod(FUNCTION_NAME_1, String.class); + Method callableFunction2 = + AutomaticFunctionCallingResponderTest.class.getMethod(FUNCTION_NAME_2, String.class); + responder.addCallableFunction(FUNCTION_NAME_1, callableFunction1, STRING_PARAMETER_NAME); + responder.addCallableFunction(FUNCTION_NAME_2, callableFunction2, STRING_PARAMETER_NAME); + List functionCalls = Arrays.asList(FUNCTION_CALL_1, FUNCTION_CALL_2); + + IllegalStateException thrown = + assertThrows( + IllegalStateException.class, + () -> responder.getContentFromFunctionCalls(functionCalls)); + assertThat(thrown) + .hasMessageThat() + .contains("Exceeded the maximum number of continuous automatic function calls"); + } + + @Test + public void testRespondToFunctionCallWithNonExistFunction_throwsIllegalArgumentException() + throws Exception { + AutomaticFunctionCallingResponder responder = new AutomaticFunctionCallingResponder(); + Method callableFunction1 = + AutomaticFunctionCallingResponderTest.class.getMethod(FUNCTION_NAME_1, String.class); + responder.addCallableFunction(FUNCTION_NAME_1, callableFunction1, STRING_PARAMETER_NAME); + List functionCalls = Arrays.asList(FUNCTION_CALL_WITH_FALSE_FUNCTION_NAME); + + IllegalArgumentException thrown = + assertThrows( + IllegalArgumentException.class, + () -> responder.getContentFromFunctionCalls(functionCalls)); + assertThat(thrown) + .hasMessageThat() + .isEqualTo("Model has asked to call function \"nonExistFunction\" which was not found."); + } + + @Test + public void testRespondToFunctionCallWithNonExistParameter_throwsIllegalArgumentException() + throws Exception { + AutomaticFunctionCallingResponder responder = new AutomaticFunctionCallingResponder(); + Method callableFunction1 = + AutomaticFunctionCallingResponderTest.class.getMethod(FUNCTION_NAME_1, String.class); + responder.addCallableFunction(FUNCTION_NAME_1, callableFunction1, STRING_PARAMETER_NAME); + List functionCalls = Arrays.asList(FUNCTION_CALL_WITH_FALSE_PARAMETER_NAME); + + IllegalArgumentException thrown = + assertThrows( + IllegalArgumentException.class, + () -> responder.getContentFromFunctionCalls(functionCalls)); + assertThat(thrown) + .hasMessageThat() + .contains( + "The parameter \"" + + STRING_PARAMETER_NAME + + "\" was not found in the arguments requested by the" + + " model."); + } + + @Test + public void testRespondToFunctionCallWithWrongParameterValue_throwsIllegalStateException() + throws Exception { + AutomaticFunctionCallingResponder responder = new AutomaticFunctionCallingResponder(); + Method callableFunction1 = + AutomaticFunctionCallingResponderTest.class.getMethod(FUNCTION_NAME_1, String.class); + responder.addCallableFunction(FUNCTION_NAME_1, callableFunction1, STRING_PARAMETER_NAME); + List functionCalls = Arrays.asList(FUNCTION_CALL_WITH_FALSE_PARAMETER_VALUE); + + IllegalStateException thrown = + assertThrows( + IllegalStateException.class, + () -> responder.getContentFromFunctionCalls(functionCalls)); + assertThat(thrown) + .hasMessageThat() + .contains( + "Error raised when calling function \"" + + FUNCTION_NAME_1 + + "\" as requested by the model. "); + } +} diff --git a/java-vertexai/google-cloud-vertexai/src/test/java/com/google/cloud/vertexai/generativeai/ChatSessionTest.java b/java-vertexai/google-cloud-vertexai/src/test/java/com/google/cloud/vertexai/generativeai/ChatSessionTest.java index 26cb1f9cee3d..0537d96fb0dc 100644 --- a/java-vertexai/google-cloud-vertexai/src/test/java/com/google/cloud/vertexai/generativeai/ChatSessionTest.java +++ b/java-vertexai/google-cloud-vertexai/src/test/java/com/google/cloud/vertexai/generativeai/ChatSessionTest.java @@ -28,6 +28,7 @@ import com.google.cloud.vertexai.api.Candidate; import com.google.cloud.vertexai.api.Candidate.FinishReason; import com.google.cloud.vertexai.api.Content; +import com.google.cloud.vertexai.api.FunctionCall; import com.google.cloud.vertexai.api.FunctionDeclaration; import com.google.cloud.vertexai.api.GenerateContentRequest; import com.google.cloud.vertexai.api.GenerateContentResponse; @@ -40,8 +41,11 @@ import com.google.cloud.vertexai.api.Schema; import com.google.cloud.vertexai.api.Tool; import com.google.cloud.vertexai.api.Type; +import com.google.protobuf.Struct; +import com.google.protobuf.Value; import java.io.IOException; import java.util.Arrays; +import java.util.Collections; import java.util.Iterator; import java.util.List; import org.junit.Before; @@ -59,6 +63,10 @@ public final class ChatSessionTest { private static final String PROJECT = "test_project"; private static final String LOCATION = "test_location"; + private static final String FUNCTION_CALL_MESSAGE = "What is the current weather in Boston?"; + private static final String FUNCTION_CALL_NAME = "getCurrentWeather"; + private static final String FUNCTION_CALL_PARAMETER_NAME = "location"; + private static final String SAMPLE_MESSAGE1 = "how are you?"; private static final String RESPONSE_STREAM_CHUNK1_TEXT = "I do not have any feelings"; private static final String RESPONSE_STREAM_CHUNK2_TEXT = "But I'm happy to help you!"; @@ -117,6 +125,30 @@ public final class ChatSessionTest { .setContent( Content.newBuilder().addParts(Part.newBuilder().setText(FULL_RESPONSE_TEXT)))) .build(); + private static final GenerateContentResponse RESPONSE_WITH_FUNCTION_CALL = + GenerateContentResponse.newBuilder() + .addCandidates( + Candidate.newBuilder() + .setFinishReason(FinishReason.STOP) + .setContent( + Content.newBuilder() + .addParts( + Part.newBuilder() + .setFunctionCall( + FunctionCall.newBuilder() + .setName(FUNCTION_CALL_NAME) + .setArgs( + Struct.newBuilder() + .putFields( + FUNCTION_CALL_PARAMETER_NAME, + Value.newBuilder() + .setStringValue("Boston") + .build())))))) + .build(); + private static final Content FUNCTION_RESPONSE_CONTENT = + ContentMaker.fromMultiModalData( + PartMaker.fromFunctionResponse( + FUNCTION_CALL_NAME, Collections.singletonMap("result", "snowing"))); private static final GenerationConfig GENERATION_CONFIG = GenerationConfig.newBuilder().setCandidateCount(1).build(); @@ -156,6 +188,17 @@ public final class ChatSessionTest { private ChatSession chat; + /** Callable function getCurrentWeather for testing automatic function calling. */ + public static String getCurrentWeather(String location) { + if (location.equals("Boston")) { + return "snowing"; + } else if (location.equals("Vancouver")) { + return "raining"; + } else { + return "sunny"; + } + } + @Before public void doBeforeEachTest() { chat = new ChatSession(mockGenerativeModel); @@ -303,6 +346,151 @@ public void sendMessageWithText_throwsIllegalStateExceptionWhenFinishReasonIsNot assertThat(history.size()).isEqualTo(0); } + @Test + public void sendMessageWithAutomaticFunctionCallingResponder_autoRespondsToFunctionCalls() + throws IOException, NoSuchMethodException { + // (Arrange) Set up the return value of the generateContent + when(mockGenerativeModel.generateContent( + Arrays.asList(ContentMaker.fromString(FUNCTION_CALL_MESSAGE)))) + .thenReturn(RESPONSE_WITH_FUNCTION_CALL); + when(mockGenerativeModel.generateContent( + Arrays.asList( + ContentMaker.fromString(FUNCTION_CALL_MESSAGE), + ResponseHandler.getContent(RESPONSE_WITH_FUNCTION_CALL), + FUNCTION_RESPONSE_CONTENT))) + .thenReturn(RESPONSE_FROM_UNARY_CALL); + + // (Act) Send text message via sendMessage + AutomaticFunctionCallingResponder responder = new AutomaticFunctionCallingResponder(); + responder.addCallableFunction( + FUNCTION_CALL_NAME, + ChatSessionTest.class.getMethod(FUNCTION_CALL_NAME, String.class), + FUNCTION_CALL_PARAMETER_NAME); + GenerateContentResponse response = + chat.withAutomaticFunctionCallingResponder(responder).sendMessage(FUNCTION_CALL_MESSAGE); + + // (Act & Assert) get history and assert that the history contains 4 contents and the response + // is the final response instead of the intermediate one. + assertThat(chat.getHistory().size()).isEqualTo(4); + assertThat(response).isEqualTo(RESPONSE_FROM_UNARY_CALL); + } + + @Test + public void sendMessageWithAutomaticFunctionCallingResponderIOException_chatHistoryGetReverted() + throws IOException, NoSuchMethodException { + // (Arrange) Set up the return value of the generateContent + when(mockGenerativeModel.generateContent( + Arrays.asList(ContentMaker.fromString(FUNCTION_CALL_MESSAGE)))) + .thenReturn(RESPONSE_WITH_FUNCTION_CALL); + when(mockGenerativeModel.generateContent( + Arrays.asList( + ContentMaker.fromString(FUNCTION_CALL_MESSAGE), + ResponseHandler.getContent(RESPONSE_WITH_FUNCTION_CALL), + FUNCTION_RESPONSE_CONTENT))) + .thenThrow(new IOException("Server error")); + + // (Act) Send text message via sendMessage + AutomaticFunctionCallingResponder responder = new AutomaticFunctionCallingResponder(); + responder.addCallableFunction( + FUNCTION_CALL_NAME, + ChatSessionTest.class.getMethod(FUNCTION_CALL_NAME, String.class), + FUNCTION_CALL_PARAMETER_NAME); + + IOException thrown = + assertThrows( + IOException.class, + () -> + chat.withAutomaticFunctionCallingResponder(responder) + .sendMessage(FUNCTION_CALL_MESSAGE)); + assertThat(thrown).hasMessageThat().isEqualTo("Server error"); + + // (Act & Assert) get history and assert that the history contains no contents since the + // intermediate response got an error and all contents got reverted. + assertThat(chat.getHistory().size()).isEqualTo(0); + } + + @Test + public void + sendMessageWithAutomaticFunctionCallingResponderIllegalStateException_chatHistoryGetReverted() + throws IOException, NoSuchMethodException { + // (Arrange) Set up the return value of the generateContent + when(mockGenerativeModel.generateContent( + Arrays.asList(ContentMaker.fromString(FUNCTION_CALL_MESSAGE)))) + .thenReturn(RESPONSE_WITH_FUNCTION_CALL); + when(mockGenerativeModel.generateContent( + Arrays.asList( + ContentMaker.fromString(FUNCTION_CALL_MESSAGE), + ResponseHandler.getContent(RESPONSE_WITH_FUNCTION_CALL), + FUNCTION_RESPONSE_CONTENT))) + .thenReturn(RESPONSE_WITH_FUNCTION_CALL); + when(mockGenerativeModel.generateContent( + Arrays.asList( + ContentMaker.fromString(FUNCTION_CALL_MESSAGE), + ResponseHandler.getContent(RESPONSE_WITH_FUNCTION_CALL), + FUNCTION_RESPONSE_CONTENT, + ResponseHandler.getContent(RESPONSE_WITH_FUNCTION_CALL), + FUNCTION_RESPONSE_CONTENT))) + .thenReturn(RESPONSE_FROM_UNARY_CALL); + + // (Act) Send text message via sendMessage + AutomaticFunctionCallingResponder responder = new AutomaticFunctionCallingResponder(); + responder.addCallableFunction( + FUNCTION_CALL_NAME, + ChatSessionTest.class.getMethod(FUNCTION_CALL_NAME, String.class), + FUNCTION_CALL_PARAMETER_NAME); + + // After mocking, there should be 2 consecutive auto function calls, but the max number of + // function calls in the responder is 1, so an IllegalStateException will be thrown. + IllegalStateException thrown = + assertThrows( + IllegalStateException.class, + () -> + chat.withAutomaticFunctionCallingResponder(responder) + .sendMessage(FUNCTION_CALL_MESSAGE)); + assertThat(thrown) + .hasMessageThat() + .contains("Exceeded the maximum number of continuous automatic function calls"); + + // (Act & Assert) get history and assert that the history contains no contents since the + // intermediate response got an error and all contents got reverted. + assertThat(chat.getHistory().size()).isEqualTo(0); + } + + @Test + public void + sendMessageWithAutomaticFunctionCallingResponderFinishReasonNotStop_chatHistoryGetReverted() + throws IOException, NoSuchMethodException { + // (Arrange) Set up the return value of the generateContent + when(mockGenerativeModel.generateContent( + Arrays.asList(ContentMaker.fromString(FUNCTION_CALL_MESSAGE)))) + .thenReturn(RESPONSE_WITH_FUNCTION_CALL); + when(mockGenerativeModel.generateContent( + Arrays.asList( + ContentMaker.fromString(FUNCTION_CALL_MESSAGE), + ResponseHandler.getContent(RESPONSE_WITH_FUNCTION_CALL), + FUNCTION_RESPONSE_CONTENT))) + .thenReturn(RESPONSE_FROM_UNARY_CALL_WITH_OTHER_FINISH_REASON); + + // (Act) Send text message via sendMessage + AutomaticFunctionCallingResponder responder = new AutomaticFunctionCallingResponder(); + responder.addCallableFunction( + FUNCTION_CALL_NAME, + ChatSessionTest.class.getMethod(FUNCTION_CALL_NAME, String.class), + FUNCTION_CALL_PARAMETER_NAME); + GenerateContentResponse response = + chat.withAutomaticFunctionCallingResponder(responder).sendMessage(FUNCTION_CALL_MESSAGE); + + // (Act & Assert) get history will throw since the final response stopped with error. The + // history should be reverted. + assertThat(response).isEqualTo(RESPONSE_FROM_UNARY_CALL_WITH_OTHER_FINISH_REASON); + IllegalStateException thrown = + assertThrows(IllegalStateException.class, () -> chat.getHistory()); + assertThat(thrown).hasMessageThat().isEqualTo("Rerun getHistory() to get cleaned history."); + // Assert that the history can be fetched again and it's empty. + List history = chat.getHistory(); + assertThat(history.size()).isEqualTo(0); + } + @Test public void testChatSessionMergeHistoryToRootChatSession() throws Exception { diff --git a/java-vertexai/google-cloud-vertexai/src/test/java/com/google/cloud/vertexai/generativeai/FunctionDeclarationMakerTest.java b/java-vertexai/google-cloud-vertexai/src/test/java/com/google/cloud/vertexai/generativeai/FunctionDeclarationMakerTest.java index f8894ba1ca99..a17c81a4d94c 100644 --- a/java-vertexai/google-cloud-vertexai/src/test/java/com/google/cloud/vertexai/generativeai/FunctionDeclarationMakerTest.java +++ b/java-vertexai/google-cloud-vertexai/src/test/java/com/google/cloud/vertexai/generativeai/FunctionDeclarationMakerTest.java @@ -21,14 +21,91 @@ import com.google.cloud.vertexai.api.FunctionDeclaration; import com.google.cloud.vertexai.api.Schema; import com.google.cloud.vertexai.api.Type; +import com.google.common.collect.ImmutableList; +import com.google.gson.Gson; import com.google.gson.JsonObject; import com.google.protobuf.InvalidProtocolBufferException; +import java.lang.reflect.Method; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; @RunWith(JUnit4.class) public final class FunctionDeclarationMakerTest { + private static final String FUNCTION_NAME = "functionName"; + private static final String FUNCTION_DESCRIPTION = "functionDescription"; + private static final String STRING_PARAM_NAME = "stringParam"; + private static final String INTEGER_PARAM_NAME = "integerParam"; + private static final String DOUBLE_PARAM_NAME = "doubleParam"; + private static final String FLOAT_PARAM_NAME = "floatParam"; + private static final String BOOLEAN_PARAM_NAME = "booleanParam"; + private static final ImmutableList REQUIRED_PARAM_NAMES = + ImmutableList.of( + STRING_PARAM_NAME, + INTEGER_PARAM_NAME, + DOUBLE_PARAM_NAME, + FLOAT_PARAM_NAME, + BOOLEAN_PARAM_NAME); + + private static final FunctionDeclaration EXPECTED_FUNCTION_DECLARATION = + FunctionDeclaration.newBuilder() + .setName(FUNCTION_NAME) + .setDescription(FUNCTION_DESCRIPTION) + .setParameters( + Schema.newBuilder() + .setType(Type.OBJECT) + .putProperties( + STRING_PARAM_NAME, + Schema.newBuilder() + .setType(Type.STRING) + .setDescription(STRING_PARAM_NAME) + .build()) + .putProperties( + INTEGER_PARAM_NAME, + Schema.newBuilder() + .setType(Type.INTEGER) + .setDescription(INTEGER_PARAM_NAME) + .build()) + .putProperties( + DOUBLE_PARAM_NAME, + Schema.newBuilder() + .setType(Type.NUMBER) + .setDescription(DOUBLE_PARAM_NAME) + .build()) + .putProperties( + FLOAT_PARAM_NAME, + Schema.newBuilder() + .setType(Type.NUMBER) + .setDescription(FLOAT_PARAM_NAME) + .build()) + .putProperties( + BOOLEAN_PARAM_NAME, + Schema.newBuilder() + .setType(Type.BOOLEAN) + .setDescription(BOOLEAN_PARAM_NAME) + .build()) + .addAllRequired(REQUIRED_PARAM_NAMES)) + .build(); + + /** A function (static method) to test fromFunc functionalities. */ + public static int functionName( + String stringParam, + int integerParam, + double doubleParam, + float floatParam, + boolean booleanParam) { + return 0; + } + + /** An instance method to test fromFunc. */ + public int instanceMethod(String stringParam) { + return 1; + } + + /** A function with invalid parameter type to test fromFunc. */ + public static int functionWithInvalidType(Object objectParam) { + return 2; + } @Test public void fromValidJsonStringTested_returnsFunctionDeclaration() @@ -40,31 +117,35 @@ public void fromValidJsonStringTested_returnsFunctionDeclaration() + " \"parameters\": {\n" + " \"type\": \"OBJECT\", \n" + " \"properties\": {\n" - + " \"param1\": {\n" + + " \"stringParam\": {\n" + " \"type\": \"STRING\",\n" - + " \"description\": \"param1Description\"\n" + + " \"description\": \"stringParam\"\n" + + " },\n" + + " \"integerParam\": {\n" + + " \"type\": \"INTEGER\",\n" + + " \"description\": \"integerParam\"\n" + + " },\n" + + " \"doubleParam\": {\n" + + " \"type\": \"NUMBER\",\n" + + " \"description\": \"doubleParam\"\n" + + " },\n" + + " \"floatParam\": {\n" + + " \"type\": \"NUMBER\",\n" + + " \"description\": \"floatParam\"\n" + + " },\n" + + " \"booleanParam\": {\n" + + " \"type\": \"BOOLEAN\",\n" + + " \"description\": \"booleanParam\"\n" + " }\n" - + " }\n" + + " },\n" + + " \"required\": [\"stringParam\", \"integerParam\", \"doubleParam\"," + + " \"floatParam\", \"booleanParam\"]\n" + " }\n" + "}"; FunctionDeclaration functionDeclaration = FunctionDeclarationMaker.fromJsonString(jsonString); - FunctionDeclaration expectedFunctionDeclaration = - FunctionDeclaration.newBuilder() - .setName("functionName") - .setDescription("functionDescription") - .setParameters( - Schema.newBuilder() - .setType(Type.OBJECT) - .putProperties( - "param1", - Schema.newBuilder() - .setType(Type.STRING) - .setDescription("param1Description") - .build())) - .build(); - assertThat(functionDeclaration).isEqualTo(expectedFunctionDeclaration); + assertThat(functionDeclaration).isEqualTo(EXPECTED_FUNCTION_DECLARATION); } @Test @@ -140,38 +221,135 @@ public void fromJsonStringStringIsNull_throwsIllegalArgumentException() @Test public void fromValidJsonObject_returnsFunctionDeclaration() throws InvalidProtocolBufferException { - JsonObject param1JsonObject = new JsonObject(); - param1JsonObject.addProperty("type", "STRING"); - param1JsonObject.addProperty("description", "param1Description"); + JsonObject stringParamJsonObject = new JsonObject(); + stringParamJsonObject.addProperty("type", "STRING"); + stringParamJsonObject.addProperty("description", STRING_PARAM_NAME); + + JsonObject integerParamJsonObject = new JsonObject(); + integerParamJsonObject.addProperty("type", "INTEGER"); + integerParamJsonObject.addProperty("description", INTEGER_PARAM_NAME); + + JsonObject doubleParamJsonObject = new JsonObject(); + doubleParamJsonObject.addProperty("type", "NUMBER"); + doubleParamJsonObject.addProperty("description", DOUBLE_PARAM_NAME); + + JsonObject floatParamJsonObject = new JsonObject(); + floatParamJsonObject.addProperty("type", "NUMBER"); + floatParamJsonObject.addProperty("description", FLOAT_PARAM_NAME); + + JsonObject booleanParamJsonObject = new JsonObject(); + booleanParamJsonObject.addProperty("type", "BOOLEAN"); + booleanParamJsonObject.addProperty("description", BOOLEAN_PARAM_NAME); JsonObject propertiesJsonObject = new JsonObject(); - propertiesJsonObject.add("param1", param1JsonObject); + propertiesJsonObject.add(STRING_PARAM_NAME, stringParamJsonObject); + propertiesJsonObject.add(INTEGER_PARAM_NAME, integerParamJsonObject); + propertiesJsonObject.add(DOUBLE_PARAM_NAME, doubleParamJsonObject); + propertiesJsonObject.add(FLOAT_PARAM_NAME, floatParamJsonObject); + propertiesJsonObject.add(BOOLEAN_PARAM_NAME, booleanParamJsonObject); JsonObject parametersJsonObject = new JsonObject(); parametersJsonObject.addProperty("type", "OBJECT"); parametersJsonObject.add("properties", propertiesJsonObject); + parametersJsonObject.add( + "required", new Gson().toJsonTree(REQUIRED_PARAM_NAMES).getAsJsonArray()); JsonObject jsonObject = new JsonObject(); - jsonObject.addProperty("name", "functionName"); - jsonObject.addProperty("description", "functionDescription"); + jsonObject.addProperty("name", FUNCTION_NAME); + jsonObject.addProperty("description", FUNCTION_DESCRIPTION); jsonObject.add("parameters", parametersJsonObject); FunctionDeclaration functionDeclaration = FunctionDeclarationMaker.fromJsonObject(jsonObject); - FunctionDeclaration expectedFunctionDeclaration = - FunctionDeclaration.newBuilder() - .setName("functionName") - .setDescription("functionDescription") - .setParameters( - Schema.newBuilder() - .setType(Type.OBJECT) - .putProperties( - "param1", - Schema.newBuilder() - .setType(Type.STRING) - .setDescription("param1Description") - .build())) - .build(); - assertThat(functionDeclaration).isEqualTo(expectedFunctionDeclaration); + assertThat(functionDeclaration).isEqualTo(EXPECTED_FUNCTION_DECLARATION); + } + + @Test + public void fromFuncWithoutParameterNamesWithoutReflection_throwsIllegalStateException() + throws NoSuchMethodException { + Method function = + FunctionDeclarationMakerTest.class.getMethod( + FUNCTION_NAME, String.class, int.class, double.class, float.class, boolean.class); + + IllegalStateException thrown = + assertThrows( + IllegalStateException.class, + () -> FunctionDeclarationMaker.fromFunc(FUNCTION_DESCRIPTION, function)); + assertThat(thrown) + .hasMessageThat() + .isEqualTo( + "Failed to retrieve the parameter name from reflection. Please compile your" + + " code with \"-parameters\" flag or use `fromFunc(String, Method," + + " String...)` to manually enter parameter names"); + } + + @Test + public void fromFuncWithParameterNames_returnsFunctionDeclaration() throws NoSuchMethodException { + Method function = + FunctionDeclarationMakerTest.class.getMethod( + FUNCTION_NAME, String.class, int.class, double.class, float.class, boolean.class); + + FunctionDeclaration functionDeclaration = + FunctionDeclarationMaker.fromFunc( + FUNCTION_DESCRIPTION, + function, + STRING_PARAM_NAME, + INTEGER_PARAM_NAME, + DOUBLE_PARAM_NAME, + FLOAT_PARAM_NAME, + BOOLEAN_PARAM_NAME); + + assertThat(functionDeclaration).isEqualTo(EXPECTED_FUNCTION_DECLARATION); + } + + @Test + public void fromFuncWithInstanceMethod_throwsIllegalArgumentException() + throws NoSuchMethodException { + Method function = FunctionDeclarationMakerTest.class.getMethod("instanceMethod", String.class); + + IllegalArgumentException thrown = + assertThrows( + IllegalArgumentException.class, + () -> + FunctionDeclarationMaker.fromFunc( + FUNCTION_DESCRIPTION, function, STRING_PARAM_NAME)); + assertThat(thrown) + .hasMessageThat() + .isEqualTo("Instance methods are not supported. Please use static methods."); + } + + @Test + public void fromFuncWithInvalidParameterType_throwsIllegalArgumentException() + throws NoSuchMethodException { + Method function = + FunctionDeclarationMakerTest.class.getMethod("functionWithInvalidType", Object.class); + + IllegalArgumentException thrown = + assertThrows( + IllegalArgumentException.class, + () -> FunctionDeclarationMaker.fromFunc(FUNCTION_DESCRIPTION, function, "objectParam")); + assertThat(thrown) + .hasMessageThat() + .isEqualTo( + "Unsupported parameter type " + Object.class.getName() + " for parameter objectParam"); + } + + @Test + public void fromFuncWithUnmatchedParameterNames_throwsIllegalArgumentException() + throws NoSuchMethodException { + Method function = + FunctionDeclarationMakerTest.class.getMethod( + FUNCTION_NAME, String.class, int.class, double.class, float.class, boolean.class); + + IllegalArgumentException thrown = + assertThrows( + IllegalArgumentException.class, + () -> + FunctionDeclarationMaker.fromFunc( + FUNCTION_DESCRIPTION, function, STRING_PARAM_NAME)); + assertThat(thrown) + .hasMessageThat() + .isEqualTo( + "The number of parameter names does not match the number of parameters in the method."); } } diff --git a/java-vertexai/google-cloud-vertexai/src/test/java/com/google/cloud/vertexai/it/ITChatSessionIntegrationTest.java b/java-vertexai/google-cloud-vertexai/src/test/java/com/google/cloud/vertexai/it/ITChatSessionIntegrationTest.java index 03d65fe4ab24..1e3e5152044d 100644 --- a/java-vertexai/google-cloud-vertexai/src/test/java/com/google/cloud/vertexai/it/ITChatSessionIntegrationTest.java +++ b/java-vertexai/google-cloud-vertexai/src/test/java/com/google/cloud/vertexai/it/ITChatSessionIntegrationTest.java @@ -24,6 +24,7 @@ import com.google.cloud.vertexai.api.HarmCategory; import com.google.cloud.vertexai.api.SafetySetting; import com.google.cloud.vertexai.api.Tool; +import com.google.cloud.vertexai.generativeai.AutomaticFunctionCallingResponder; import com.google.cloud.vertexai.generativeai.ChatSession; import com.google.cloud.vertexai.generativeai.ContentMaker; import com.google.cloud.vertexai.generativeai.FunctionDeclarationMaker; @@ -32,13 +33,12 @@ import com.google.cloud.vertexai.generativeai.ResponseHandler; import com.google.cloud.vertexai.generativeai.ResponseStream; import com.google.common.collect.ImmutableList; -import com.google.gson.JsonObject; import java.io.IOException; +import java.lang.reflect.Method; import java.util.Collections; import java.util.logging.Logger; import org.junit.After; import org.junit.Before; -import org.junit.Ignore; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; @@ -55,6 +55,17 @@ public class ITChatSessionIntegrationTest { private GenerativeModel model; private ChatSession chat; + /** Callable function getCurrentWeather for testing automatic function calling. */ + public static String getCurrentWeather(String location) { + if (location.equals("Boston")) { + return "snowing"; + } else if (location.equals("Vancouver")) { + return "raining"; + } else { + return "sunny"; + } + } + @Before public void setUp() throws IOException { vertexAi = new VertexAI(PROJECT_ID, LOCATION); @@ -89,9 +100,6 @@ private static void assertSizeAndAlternatingRolesInHistory( } } - @Ignore( - "TODO(b/335830545): The Gen AI API is too flaky to handle three sets of simultanenous IT on" - + " the GitHub side.") @Test public void sendMessageMixedStreamAndUnary_historyOfFour() throws IOException { // Arrange @@ -119,9 +127,6 @@ public void sendMessageMixedStreamAndUnary_historyOfFour() throws IOException { ImmutableList.of(expectedFirstContent, expectedThirdContent)); } - @Ignore( - "TODO(b/335830545): The Gen AI API is too flaky to handle three sets of simultanenous IT on" - + " the GitHub side.") @Test public void sendMessageWithNewConfigs_historyContainsFullConversation() throws IOException { // Arrange @@ -162,47 +167,20 @@ public void sendMessageWithNewConfigs_historyContainsFullConversation() throws I ImmutableList.of(expectedFirstContent, expectedThirdContent)); } - @Ignore( - "TODO(b/335830545): The Gen AI API is too flaky to handle three sets of simultanenous IT on" - + " the GitHub side.") @Test - public void sendMessageWithFunctionCalling_functionCallInResponse() throws IOException { + public void sendMessageWithFunctionCalling_functionCallInResponse() + throws IOException, NoSuchMethodException { // Arrange String firstMessage = "hello!"; String secondMessage = "What is the weather in Boston?"; - // Making an Json object representing a function declaration - // The following code makes a function declaration - // { - // "name": "getCurrentWeather", - // "description": "Get the current weather in a given location", - // "parameters": { - // "type": "OBJECT", - // "properties": { - // "location": { - // "type": "STRING", - // "description": "location" - // } - // } - // } - // } - JsonObject locationJsonObject = new JsonObject(); - locationJsonObject.addProperty("type", "STRING"); - locationJsonObject.addProperty("description", "location"); - - JsonObject propertiesJsonObject = new JsonObject(); - propertiesJsonObject.add("location", locationJsonObject); - - JsonObject parametersJsonObject = new JsonObject(); - parametersJsonObject.addProperty("type", "OBJECT"); - parametersJsonObject.add("properties", propertiesJsonObject); - - JsonObject jsonObject = new JsonObject(); - jsonObject.addProperty("name", "getCurrentWeather"); - jsonObject.addProperty("description", "Get the current weather in a given location"); - jsonObject.add("parameters", parametersJsonObject); + + Method function = + ITChatSessionIntegrationTest.class.getMethod("getCurrentWeather", String.class); Tool tool = Tool.newBuilder() - .addFunctionDeclarations(FunctionDeclarationMaker.fromJsonObject(jsonObject)) + .addFunctionDeclarations( + FunctionDeclarationMaker.fromFunc( + "Get the current weather in a given location", function, "location")) .build(); ImmutableList tools = ImmutableList.of(tool); @@ -232,4 +210,45 @@ public void sendMessageWithFunctionCalling_functionCallInResponse() throws IOExc ContentMaker.fromString(secondMessage), functionResponse)); } + + @Test + public void sendMessageWithAutomaticFunctionCalling_autoRespondToFunctionCall() + throws IOException, NoSuchMethodException { + // Arrange + String message = "What is the weather in Boston?"; + Method function = + ITChatSessionIntegrationTest.class.getMethod("getCurrentWeather", String.class); + Tool tool = + Tool.newBuilder() + .addFunctionDeclarations( + FunctionDeclarationMaker.fromFunc( + "Get the current weather in a given location", function, "location")) + .build(); + ImmutableList tools = ImmutableList.of(tool); + + AutomaticFunctionCallingResponder responder = new AutomaticFunctionCallingResponder(); + responder.addCallableFunction("getCurrentWeather", function, "location"); + + // Act + chat = model.startChat(); + GenerateContentResponse response = + chat.withTools(tools).withAutomaticFunctionCallingResponder(responder).sendMessage(message); + + // Assert + assertThat(response.getCandidatesList()).hasSize(1); + // The final response should not contain any function calls since the function was called + // automatically. + assertThat(ResponseHandler.getFunctionCalls(response)).isEmpty(); + + ImmutableList history = chat.getHistory(); + Content expectedFunctionResponse = + ContentMaker.fromMultiModalData( + PartMaker.fromFunctionResponse( + "getCurrentWeather", Collections.singletonMap("result", "snowing"))); + assertSizeAndAlternatingRolesInHistory( + Thread.currentThread().getStackTrace()[1].getMethodName(), + history, + 4, + ImmutableList.of(ContentMaker.fromString(message), expectedFunctionResponse)); + } }