Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: [vertexai] support ToolConfig in GenerativeModel #10950

Merged
merged 1 commit into from
Jun 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
import com.google.cloud.vertexai.api.GenerationConfig;
import com.google.cloud.vertexai.api.SafetySetting;
import com.google.cloud.vertexai.api.Tool;
import com.google.cloud.vertexai.api.ToolConfig;
import com.google.common.base.Strings;
import com.google.common.collect.ImmutableList;
import com.google.errorprone.annotations.CanIgnoreReturnValue;
Expand All @@ -46,6 +47,7 @@ public final class GenerativeModel {
private final GenerationConfig generationConfig;
private final ImmutableList<SafetySetting> safetySettings;
private final ImmutableList<Tool> tools;
private final Optional<ToolConfig> toolConfig;
private final Optional<Content> systemInstruction;

/**
Expand All @@ -65,6 +67,7 @@ public GenerativeModel(String modelName, VertexAI vertexAi) {
ImmutableList.of(),
ImmutableList.of(),
Optional.empty(),
Optional.empty(),
vertexAi);
}

Expand All @@ -79,6 +82,10 @@ public GenerativeModel(String modelName, VertexAI vertexAi) {
* that will be used by default for generating response
* @param tools a list of {@link com.google.cloud.vertexai.api.Tool} instances that can be used by
* the model as auxiliary tools to generate content.
* @param toolConfig a {@link com.google.cloud.vertexai.api.ToolConfig} instance that will be used
* to specify the tool configuration.
* @param systemInstruction a {@link com.google.cloud.vertexai.api.Content} instance that will be
* used by default for generating response.
* @param vertexAi a {@link com.google.cloud.vertexai.VertexAI} that contains the default configs
* for the generative model
*/
Expand All @@ -87,6 +94,7 @@ private GenerativeModel(
GenerationConfig generationConfig,
ImmutableList<SafetySetting> safetySettings,
ImmutableList<Tool> tools,
Optional<ToolConfig> toolConfig,
Optional<Content> systemInstruction,
VertexAI vertexAi) {
checkArgument(
Expand All @@ -98,6 +106,8 @@ private GenerativeModel(
checkNotNull(generationConfig, "GenerationConfig can't be null.");
checkNotNull(safetySettings, "ImmutableList<SafetySettings> can't be null.");
checkNotNull(tools, "ImmutableList<Tool> can't be null.");
checkNotNull(toolConfig, "Optional<ToolConfig> can't be null.");
checkNotNull(systemInstruction, "Optional<Content> can't be null.");

this.resourceName = getResourceName(modelName, vertexAi);
// reconcileModelName should be called after getResourceName.
Expand All @@ -106,6 +116,7 @@ private GenerativeModel(
this.generationConfig = generationConfig;
this.safetySettings = safetySettings;
this.tools = tools;
this.toolConfig = toolConfig;
// We remove the role in the system instruction content because it's officially documented
// to be used without role specified:
// https://cloud.google.com/vertex-ai/generative-ai/docs/samples/generativeaionvertexai-gemini-system-instruction
Expand All @@ -128,6 +139,7 @@ public static class Builder {
private GenerationConfig generationConfig = GenerationConfig.getDefaultInstance();
private ImmutableList<SafetySetting> safetySettings = ImmutableList.of();
private ImmutableList<Tool> tools = ImmutableList.of();
private Optional<ToolConfig> toolConfig = Optional.empty();
private Optional<Content> systemInstruction = Optional.empty();

public GenerativeModel build() {
Expand All @@ -136,7 +148,13 @@ public GenerativeModel build() {
"modelName is required. Please call setModelName() before building.");
checkNotNull(vertexAi, "vertexAi is required. Please call setVertexAi() before building.");
return new GenerativeModel(
modelName, generationConfig, safetySettings, tools, systemInstruction, vertexAi);
modelName,
generationConfig,
safetySettings,
tools,
toolConfig,
systemInstruction,
vertexAi);
}

/**
Expand Down Expand Up @@ -204,6 +222,19 @@ public Builder setTools(List<Tool> tools) {
return this;
}

/**
* Sets a {@link com.google.cloud.vertexai.api.ToolConfig} that will be used by default to
* interact with the generative model.
*/
@CanIgnoreReturnValue
public Builder setToolConfig(ToolConfig toolConfig) {
checkNotNull(
toolConfig,
"toolConfig can't be null. Use Optional.empty() if no tool config is intended.");
this.toolConfig = Optional.of(toolConfig);
return this;
}

/**
* Sets a system instruction that will be used by default to interact with the generative model.
*/
Expand All @@ -228,7 +259,13 @@ public Builder setSystemInstruction(Content systemInstruction) {
public GenerativeModel withGenerationConfig(GenerationConfig generationConfig) {
checkNotNull(generationConfig, "GenerationConfig can't be null.");
return new GenerativeModel(
modelName, generationConfig, safetySettings, tools, systemInstruction, vertexAi);
modelName,
generationConfig,
safetySettings,
tools,
toolConfig,
systemInstruction,
vertexAi);
}

/**
Expand All @@ -247,6 +284,7 @@ public GenerativeModel withSafetySettings(List<SafetySetting> safetySettings) {
generationConfig,
ImmutableList.copyOf(safetySettings),
tools,
toolConfig,
systemInstruction,
vertexAi);
}
Expand All @@ -265,6 +303,28 @@ public GenerativeModel withTools(List<Tool> tools) {
generationConfig,
safetySettings,
ImmutableList.copyOf(tools),
toolConfig,
systemInstruction,
vertexAi);
}

/**
* Creates a copy of the current model with updated tool config.
*
* @param toolConfig a {@link com.google.cloud.vertexai.api.ToolConfig} that will be used in the
* new model.
* @return a new {@link GenerativeModel} instance with the specified tool config.
*/
public GenerativeModel withToolConfig(ToolConfig toolConfig) {
checkNotNull(
toolConfig,
"toolConfig can't be null. Use Optional.empty() if no tool config is intended.");
return new GenerativeModel(
modelName,
generationConfig,
safetySettings,
tools,
Optional.of(toolConfig),
systemInstruction,
vertexAi);
}
Expand All @@ -286,6 +346,7 @@ public GenerativeModel withSystemInstruction(Content systemInstruction) {
generationConfig,
safetySettings,
tools,
toolConfig,
Optional.of(systemInstruction),
vertexAi);
}
Expand Down Expand Up @@ -537,6 +598,10 @@ private GenerateContentRequest buildGenerateContentRequest(List<Content> content
.addAllSafetySettings(safetySettings)
.addAllTools(tools);

if (toolConfig.isPresent()) {
requestBuilder.setToolConfig(toolConfig.get());
}

if (systemInstruction.isPresent()) {
requestBuilder.setSystemInstruction(systemInstruction.get());
}
Expand Down Expand Up @@ -568,6 +633,13 @@ public ImmutableList<Tool> getTools() {
return tools;
}

/**
* Returns the optional {@link com.google.cloud.vertexai.api.ToolConfig} of this generative model.
*/
public Optional<ToolConfig> getToolConfig() {
return toolConfig;
}

/** Returns the optional system instruction of this generative model. */
public Optional<Content> getSystemInstruction() {
return systemInstruction;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
import com.google.cloud.vertexai.api.Content;
import com.google.cloud.vertexai.api.CountTokensRequest;
import com.google.cloud.vertexai.api.CountTokensResponse;
import com.google.cloud.vertexai.api.FunctionCallingConfig;
import com.google.cloud.vertexai.api.FunctionDeclaration;
import com.google.cloud.vertexai.api.GenerateContentRequest;
import com.google.cloud.vertexai.api.GenerateContentResponse;
Expand All @@ -44,6 +45,7 @@
import com.google.cloud.vertexai.api.SafetySetting.HarmBlockThreshold;
import com.google.cloud.vertexai.api.Schema;
import com.google.cloud.vertexai.api.Tool;
import com.google.cloud.vertexai.api.ToolConfig;
import com.google.cloud.vertexai.api.Type;
import com.google.cloud.vertexai.api.VertexAISearch;
import java.util.ArrayList;
Expand Down Expand Up @@ -96,6 +98,13 @@ public final class GenerativeModelTest {
.build())
.addRequired("location")))
.build();
private static final ToolConfig DEFAULT_TOOL_CONFIG =
ToolConfig.newBuilder()
.setFunctionCallingConfig(
FunctionCallingConfig.newBuilder()
.setMode(FunctionCallingConfig.Mode.ANY)
.addAllowedFunctionNames("getCurrentWeather"))
.build();
private static final Content DEFAULT_SYSTEM_INSTRUCTION =
ContentMaker.fromString(
"You're a helpful assistant that starts all its answers with: \"COOL\"");
Expand Down Expand Up @@ -404,6 +413,25 @@ public void generateContent_withDefaultTools_requestHasCorrectToolsAndText() thr
assertThat(request.getValue().getTools(0)).isEqualTo(TOOL);
}

@Test
public void generateContent_withDefaultToolConfig_requestHasCorrectToolConfigAndText()
throws Exception {
model =
new GenerativeModel.Builder()
.setModelName(MODEL_NAME)
.setVertexAi(vertexAi)
.setToolConfig(DEFAULT_TOOL_CONFIG)
.build();

GenerateContentResponse unused = model.generateContent(TEXT);

ArgumentCaptor<GenerateContentRequest> request =
ArgumentCaptor.forClass(GenerateContentRequest.class);
verify(mockUnaryCallable).call(request.capture());
assertThat(request.getValue().getContents(0).getParts(0).getText()).isEqualTo(TEXT);
assertThat(request.getValue().getToolConfig()).isEqualTo(DEFAULT_TOOL_CONFIG);
}

@Test
public void
generateContent_withDefaultSystemInstruction_requestHasCorrectSystemInstructionAndText()
Expand Down Expand Up @@ -433,6 +461,7 @@ public void generateContent_withAllConfigsInFluentApi_requestHasCorrectFields()
.withGenerationConfig(GENERATION_CONFIG)
.withSafetySettings(safetySettings)
.withTools(tools)
.withToolConfig(DEFAULT_TOOL_CONFIG)
.withSystemInstruction(DEFAULT_SYSTEM_INSTRUCTION)
.generateContent(TEXT);

Expand All @@ -444,6 +473,7 @@ public void generateContent_withAllConfigsInFluentApi_requestHasCorrectFields()
assertThat(request.getValue().getGenerationConfig()).isEqualTo(GENERATION_CONFIG);
assertThat(request.getValue().getSafetySettings(0)).isEqualTo(SAFETY_SETTING);
assertThat(request.getValue().getTools(0)).isEqualTo(TOOL);
assertThat(request.getValue().getToolConfig()).isEqualTo(DEFAULT_TOOL_CONFIG);
assertThat(request.getValue().getSystemInstruction()).isEqualTo(expectedSystemInstruction);
}

Expand Down Expand Up @@ -546,6 +576,24 @@ public void generateContentStream_withDefaultTools_requestHasCorrectTools() thro
assertThat(request.getValue().getTools(0)).isEqualTo(TOOL);
}

@Test
public void generateContentStream_withDefaultToolConfig_requestHasCorrectToolConfig()
throws Exception {
model =
new GenerativeModel.Builder()
.setModelName(MODEL_NAME)
.setVertexAi(vertexAi)
.setToolConfig(DEFAULT_TOOL_CONFIG)
.build();

ResponseStream unused = model.generateContentStream(TEXT);

ArgumentCaptor<GenerateContentRequest> request =
ArgumentCaptor.forClass(GenerateContentRequest.class);
verify(mockServerStreamCallable).call(request.capture());
assertThat(request.getValue().getToolConfig()).isEqualTo(DEFAULT_TOOL_CONFIG);
}

@Test
public void
generateContentStream_withDefaultSystemInstruction_requestHasCorrectSystemInstruction()
Expand Down Expand Up @@ -576,6 +624,7 @@ public void generateContentStream_withAllConfigsInFluentApi_requestHasCorrectFie
.withGenerationConfig(GENERATION_CONFIG)
.withSafetySettings(safetySettings)
.withTools(tools)
.withToolConfig(DEFAULT_TOOL_CONFIG)
.withSystemInstruction(DEFAULT_SYSTEM_INSTRUCTION)
.generateContentStream(TEXT);

Expand All @@ -587,6 +636,7 @@ public void generateContentStream_withAllConfigsInFluentApi_requestHasCorrectFie
assertThat(request.getValue().getGenerationConfig()).isEqualTo(GENERATION_CONFIG);
assertThat(request.getValue().getSafetySettings(0)).isEqualTo(SAFETY_SETTING);
assertThat(request.getValue().getTools(0)).isEqualTo(TOOL);
assertThat(request.getValue().getToolConfig()).isEqualTo(DEFAULT_TOOL_CONFIG);
assertThat(request.getValue().getSystemInstruction()).isEqualTo(expectedSystemInstruction);
}

Expand Down
Loading