Skip to content

Commit

Permalink
feat: [vertexai] support ToolConfig in GenerativeModel
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 642059737
  • Loading branch information
jaycee-li authored and copybara-github committed Jun 10, 2024
1 parent d3cf203 commit 4637ad8
Show file tree
Hide file tree
Showing 2 changed files with 124 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
import com.google.cloud.vertexai.api.GenerationConfig;
import com.google.cloud.vertexai.api.SafetySetting;
import com.google.cloud.vertexai.api.Tool;
import com.google.cloud.vertexai.api.ToolConfig;
import com.google.common.base.Strings;
import com.google.common.collect.ImmutableList;
import com.google.errorprone.annotations.CanIgnoreReturnValue;
Expand All @@ -46,6 +47,7 @@ public final class GenerativeModel {
private final GenerationConfig generationConfig;
private final ImmutableList<SafetySetting> safetySettings;
private final ImmutableList<Tool> tools;
private final Optional<ToolConfig> toolConfig;
private final Optional<Content> systemInstruction;

/**
Expand All @@ -65,6 +67,7 @@ public GenerativeModel(String modelName, VertexAI vertexAi) {
ImmutableList.of(),
ImmutableList.of(),
Optional.empty(),
Optional.empty(),
vertexAi);
}

Expand All @@ -79,6 +82,10 @@ public GenerativeModel(String modelName, VertexAI vertexAi) {
* that will be used by default for generating response
* @param tools a list of {@link com.google.cloud.vertexai.api.Tool} instances that can be used by
* the model as auxiliary tools to generate content.
* @param toolConfig a {@link com.google.cloud.vertexai.api.ToolConfig} instance that will be used
* to specify the tool configuration.
* @param systemInstruction a {@link com.google.cloud.vertexai.api.Content} instance that will be
* used by default for generating response.
* @param vertexAi a {@link com.google.cloud.vertexai.VertexAI} that contains the default configs
* for the generative model
*/
Expand All @@ -87,6 +94,7 @@ private GenerativeModel(
GenerationConfig generationConfig,
ImmutableList<SafetySetting> safetySettings,
ImmutableList<Tool> tools,
Optional<ToolConfig> toolConfig,
Optional<Content> systemInstruction,
VertexAI vertexAi) {
checkArgument(
Expand All @@ -98,6 +106,8 @@ private GenerativeModel(
checkNotNull(generationConfig, "GenerationConfig can't be null.");
checkNotNull(safetySettings, "ImmutableList<SafetySettings> can't be null.");
checkNotNull(tools, "ImmutableList<Tool> can't be null.");
checkNotNull(toolConfig, "Optional<ToolConfig> can't be null.");
checkNotNull(systemInstruction, "Optional<Content> can't be null.");

this.resourceName = getResourceName(modelName, vertexAi);
// reconcileModelName should be called after getResourceName.
Expand All @@ -106,6 +116,7 @@ private GenerativeModel(
this.generationConfig = generationConfig;
this.safetySettings = safetySettings;
this.tools = tools;
this.toolConfig = toolConfig;
// We remove the role in the system instruction content because it's officially documented
// to be used without role specified:
// https://cloud.google.com/vertex-ai/generative-ai/docs/samples/generativeaionvertexai-gemini-system-instruction
Expand All @@ -128,6 +139,7 @@ public static class Builder {
private GenerationConfig generationConfig = GenerationConfig.getDefaultInstance();
private ImmutableList<SafetySetting> safetySettings = ImmutableList.of();
private ImmutableList<Tool> tools = ImmutableList.of();
private Optional<ToolConfig> toolConfig = Optional.empty();
private Optional<Content> systemInstruction = Optional.empty();

public GenerativeModel build() {
Expand All @@ -136,7 +148,13 @@ public GenerativeModel build() {
"modelName is required. Please call setModelName() before building.");
checkNotNull(vertexAi, "vertexAi is required. Please call setVertexAi() before building.");
return new GenerativeModel(
modelName, generationConfig, safetySettings, tools, systemInstruction, vertexAi);
modelName,
generationConfig,
safetySettings,
tools,
toolConfig,
systemInstruction,
vertexAi);
}

/**
Expand Down Expand Up @@ -204,6 +222,19 @@ public Builder setTools(List<Tool> tools) {
return this;
}

/**
* Sets a {@link com.google.cloud.vertexai.api.ToolConfig} that will be used by default to
* interact with the generative model.
*/
@CanIgnoreReturnValue
public Builder setToolConfig(ToolConfig toolConfig) {
checkNotNull(
toolConfig,
"toolConfig can't be null. Use Optional.empty() if no tool config is intended.");
this.toolConfig = Optional.of(toolConfig);
return this;
}

/**
* Sets a system instruction that will be used by default to interact with the generative model.
*/
Expand All @@ -228,7 +259,13 @@ public Builder setSystemInstruction(Content systemInstruction) {
public GenerativeModel withGenerationConfig(GenerationConfig generationConfig) {
checkNotNull(generationConfig, "GenerationConfig can't be null.");
return new GenerativeModel(
modelName, generationConfig, safetySettings, tools, systemInstruction, vertexAi);
modelName,
generationConfig,
safetySettings,
tools,
toolConfig,
systemInstruction,
vertexAi);
}

/**
Expand All @@ -247,6 +284,7 @@ public GenerativeModel withSafetySettings(List<SafetySetting> safetySettings) {
generationConfig,
ImmutableList.copyOf(safetySettings),
tools,
toolConfig,
systemInstruction,
vertexAi);
}
Expand All @@ -265,6 +303,28 @@ public GenerativeModel withTools(List<Tool> tools) {
generationConfig,
safetySettings,
ImmutableList.copyOf(tools),
toolConfig,
systemInstruction,
vertexAi);
}

/**
* Creates a copy of the current model with updated tool config.
*
* @param toolConfig a {@link com.google.cloud.vertexai.api.ToolConfig} that will be used in the
* new model.
* @return a new {@link GenerativeModel} instance with the specified tool config.
*/
public GenerativeModel withToolConfig(ToolConfig toolConfig) {
checkNotNull(
toolConfig,
"toolConfig can't be null. Use Optional.empty() if no tool config is intended.");
return new GenerativeModel(
modelName,
generationConfig,
safetySettings,
tools,
Optional.of(toolConfig),
systemInstruction,
vertexAi);
}
Expand All @@ -286,6 +346,7 @@ public GenerativeModel withSystemInstruction(Content systemInstruction) {
generationConfig,
safetySettings,
tools,
toolConfig,
Optional.of(systemInstruction),
vertexAi);
}
Expand Down Expand Up @@ -537,6 +598,10 @@ private GenerateContentRequest buildGenerateContentRequest(List<Content> content
.addAllSafetySettings(safetySettings)
.addAllTools(tools);

if (toolConfig.isPresent()) {
requestBuilder.setToolConfig(toolConfig.get());
}

if (systemInstruction.isPresent()) {
requestBuilder.setSystemInstruction(systemInstruction.get());
}
Expand Down Expand Up @@ -568,6 +633,13 @@ public ImmutableList<Tool> getTools() {
return tools;
}

/**
* Returns the optional {@link com.google.cloud.vertexai.api.ToolConfig} of this generative model.
*/
public Optional<ToolConfig> getToolConfig() {
return toolConfig;
}

/** Returns the optional system instruction of this generative model. */
public Optional<Content> getSystemInstruction() {
return systemInstruction;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
import com.google.cloud.vertexai.api.Content;
import com.google.cloud.vertexai.api.CountTokensRequest;
import com.google.cloud.vertexai.api.CountTokensResponse;
import com.google.cloud.vertexai.api.FunctionCallingConfig;
import com.google.cloud.vertexai.api.FunctionDeclaration;
import com.google.cloud.vertexai.api.GenerateContentRequest;
import com.google.cloud.vertexai.api.GenerateContentResponse;
Expand All @@ -44,6 +45,7 @@
import com.google.cloud.vertexai.api.SafetySetting.HarmBlockThreshold;
import com.google.cloud.vertexai.api.Schema;
import com.google.cloud.vertexai.api.Tool;
import com.google.cloud.vertexai.api.ToolConfig;
import com.google.cloud.vertexai.api.Type;
import com.google.cloud.vertexai.api.VertexAISearch;
import java.util.ArrayList;
Expand Down Expand Up @@ -96,6 +98,13 @@ public final class GenerativeModelTest {
.build())
.addRequired("location")))
.build();
private static final ToolConfig DEFAULT_TOOL_CONFIG =
ToolConfig.newBuilder()
.setFunctionCallingConfig(
FunctionCallingConfig.newBuilder()
.setMode(FunctionCallingConfig.Mode.ANY)
.addAllowedFunctionNames("getCurrentWeather"))
.build();
private static final Content DEFAULT_SYSTEM_INSTRUCTION =
ContentMaker.fromString(
"You're a helpful assistant that starts all its answers with: \"COOL\"");
Expand Down Expand Up @@ -404,6 +413,25 @@ public void generateContent_withDefaultTools_requestHasCorrectToolsAndText() thr
assertThat(request.getValue().getTools(0)).isEqualTo(TOOL);
}

@Test
public void generateContent_withDefaultToolConfig_requestHasCorrectToolConfigAndText()
throws Exception {
model =
new GenerativeModel.Builder()
.setModelName(MODEL_NAME)
.setVertexAi(vertexAi)
.setToolConfig(DEFAULT_TOOL_CONFIG)
.build();

GenerateContentResponse unused = model.generateContent(TEXT);

ArgumentCaptor<GenerateContentRequest> request =
ArgumentCaptor.forClass(GenerateContentRequest.class);
verify(mockUnaryCallable).call(request.capture());
assertThat(request.getValue().getContents(0).getParts(0).getText()).isEqualTo(TEXT);
assertThat(request.getValue().getToolConfig()).isEqualTo(DEFAULT_TOOL_CONFIG);
}

@Test
public void
generateContent_withDefaultSystemInstruction_requestHasCorrectSystemInstructionAndText()
Expand Down Expand Up @@ -433,6 +461,7 @@ public void generateContent_withAllConfigsInFluentApi_requestHasCorrectFields()
.withGenerationConfig(GENERATION_CONFIG)
.withSafetySettings(safetySettings)
.withTools(tools)
.withToolConfig(DEFAULT_TOOL_CONFIG)
.withSystemInstruction(DEFAULT_SYSTEM_INSTRUCTION)
.generateContent(TEXT);

Expand All @@ -444,6 +473,7 @@ public void generateContent_withAllConfigsInFluentApi_requestHasCorrectFields()
assertThat(request.getValue().getGenerationConfig()).isEqualTo(GENERATION_CONFIG);
assertThat(request.getValue().getSafetySettings(0)).isEqualTo(SAFETY_SETTING);
assertThat(request.getValue().getTools(0)).isEqualTo(TOOL);
assertThat(request.getValue().getToolConfig()).isEqualTo(DEFAULT_TOOL_CONFIG);
assertThat(request.getValue().getSystemInstruction()).isEqualTo(expectedSystemInstruction);
}

Expand Down Expand Up @@ -546,6 +576,24 @@ public void generateContentStream_withDefaultTools_requestHasCorrectTools() thro
assertThat(request.getValue().getTools(0)).isEqualTo(TOOL);
}

@Test
public void generateContentStream_withDefaultToolConfig_requestHasCorrectToolConfig()
throws Exception {
model =
new GenerativeModel.Builder()
.setModelName(MODEL_NAME)
.setVertexAi(vertexAi)
.setToolConfig(DEFAULT_TOOL_CONFIG)
.build();

ResponseStream unused = model.generateContentStream(TEXT);

ArgumentCaptor<GenerateContentRequest> request =
ArgumentCaptor.forClass(GenerateContentRequest.class);
verify(mockServerStreamCallable).call(request.capture());
assertThat(request.getValue().getToolConfig()).isEqualTo(DEFAULT_TOOL_CONFIG);
}

@Test
public void
generateContentStream_withDefaultSystemInstruction_requestHasCorrectSystemInstruction()
Expand Down Expand Up @@ -576,6 +624,7 @@ public void generateContentStream_withAllConfigsInFluentApi_requestHasCorrectFie
.withGenerationConfig(GENERATION_CONFIG)
.withSafetySettings(safetySettings)
.withTools(tools)
.withToolConfig(DEFAULT_TOOL_CONFIG)
.withSystemInstruction(DEFAULT_SYSTEM_INSTRUCTION)
.generateContentStream(TEXT);

Expand All @@ -587,6 +636,7 @@ public void generateContentStream_withAllConfigsInFluentApi_requestHasCorrectFie
assertThat(request.getValue().getGenerationConfig()).isEqualTo(GENERATION_CONFIG);
assertThat(request.getValue().getSafetySettings(0)).isEqualTo(SAFETY_SETTING);
assertThat(request.getValue().getTools(0)).isEqualTo(TOOL);
assertThat(request.getValue().getToolConfig()).isEqualTo(DEFAULT_TOOL_CONFIG);
assertThat(request.getValue().getSystemInstruction()).isEqualTo(expectedSystemInstruction);
}

Expand Down

0 comments on commit 4637ad8

Please sign in to comment.