From f19d5edf6bded5cfda2598dba4bfbaeac52f78d3 Mon Sep 17 00:00:00 2001 From: Yvonne Yu Date: Thu, 13 Jun 2024 14:55:02 -0700 Subject: [PATCH] docs: add thread safety information to GenerativeModel and ChatSession classes. PiperOrigin-RevId: 643130259 --- .../vertexai/generativeai/ChatSession.java | 108 ++++++++++++------ .../generativeai/GenerativeModel.java | 13 ++- .../generativeai/ChatSessionTest.java | 20 +++- 3 files changed, 103 insertions(+), 38 deletions(-) diff --git a/java-vertexai/google-cloud-vertexai/src/main/java/com/google/cloud/vertexai/generativeai/ChatSession.java b/java-vertexai/google-cloud-vertexai/src/main/java/com/google/cloud/vertexai/generativeai/ChatSession.java index 5d3e7dd4df12..9f5b07f1eb93 100644 --- a/java-vertexai/google-cloud-vertexai/src/main/java/com/google/cloud/vertexai/generativeai/ChatSession.java +++ b/java-vertexai/google-cloud-vertexai/src/main/java/com/google/cloud/vertexai/generativeai/ChatSession.java @@ -29,19 +29,24 @@ import com.google.cloud.vertexai.api.GenerationConfig; import com.google.cloud.vertexai.api.SafetySetting; import com.google.cloud.vertexai.api.Tool; +import com.google.cloud.vertexai.api.ToolConfig; import com.google.common.collect.ImmutableList; import java.io.IOException; import java.util.ArrayList; import java.util.List; import java.util.Optional; -/** Represents a conversation between the user and the model */ +/** + * Represents a conversation between the user and the model. + * + *

Note: this class is NOT thread-safe. + */ public final class ChatSession { private final GenerativeModel model; private final Optional rootChatSession; private final Optional automaticFunctionCallingResponder; - private List history = new ArrayList<>(); - private int previousHistorySize = 0; + private List history; + private int previousHistorySize; private Optional> currentResponseStream; private Optional currentResponse; @@ -50,7 +55,7 @@ public final class ChatSession { * GenerationConfig) inherits from the model. */ public ChatSession(GenerativeModel model) { - this(model, Optional.empty(), Optional.empty()); + this(model, new ArrayList<>(), 0, Optional.empty(), Optional.empty()); } /** @@ -58,6 +63,9 @@ public ChatSession(GenerativeModel model) { * Configurations of the chat (e.g., GenerationConfig) inherits from the model. * * @param model a {@link GenerativeModel} instance that generates contents in the chat. + * @param history a list of {@link Content} containing interleaving conversation between "user" + * and "model". + * @param previousHistorySize the size of the previous history. * @param rootChatSession a root {@link ChatSession} instance. All the chat history in the current * chat session will be merged to the root chat session. * @param automaticFunctionCallingResponder an {@link AutomaticFunctionCallingResponder} instance @@ -66,10 +74,14 @@ public ChatSession(GenerativeModel model) { */ private ChatSession( GenerativeModel model, + List history, + int previousHistorySize, Optional rootChatSession, Optional automaticFunctionCallingResponder) { checkNotNull(model, "model should not be null"); this.model = model; + this.history = history; + this.previousHistorySize = previousHistorySize; this.rootChatSession = rootChatSession; this.automaticFunctionCallingResponder = automaticFunctionCallingResponder; currentResponseStream = Optional.empty(); @@ -84,15 +96,12 @@ private ChatSession( * @return a new {@link ChatSession} instance with the specified GenerationConfig. */ public ChatSession withGenerationConfig(GenerationConfig generationConfig) { - ChatSession rootChat = rootChatSession.orElse(this); - ChatSession newChatSession = - new ChatSession( - model.withGenerationConfig(generationConfig), - Optional.of(rootChat), - automaticFunctionCallingResponder); - newChatSession.history = history; - newChatSession.previousHistorySize = previousHistorySize; - return newChatSession; + return new ChatSession( + model.withGenerationConfig(generationConfig), + history, + previousHistorySize, + Optional.of(rootChatSession.orElse(this)), + automaticFunctionCallingResponder); } /** @@ -103,15 +112,12 @@ public ChatSession withGenerationConfig(GenerationConfig generationConfig) { * @return a new {@link ChatSession} instance with the specified SafetySettings. */ public ChatSession withSafetySettings(List safetySettings) { - ChatSession rootChat = rootChatSession.orElse(this); - ChatSession newChatSession = - new ChatSession( - model.withSafetySettings(safetySettings), - Optional.of(rootChat), - automaticFunctionCallingResponder); - newChatSession.history = history; - newChatSession.previousHistorySize = previousHistorySize; - return newChatSession; + return new ChatSession( + model.withSafetySettings(safetySettings), + history, + previousHistorySize, + Optional.of(rootChatSession.orElse(this)), + automaticFunctionCallingResponder); } /** @@ -122,13 +128,44 @@ public ChatSession withSafetySettings(List safetySettings) { * @return a new {@link ChatSession} instance with the specified Tools. */ public ChatSession withTools(List tools) { - ChatSession rootChat = rootChatSession.orElse(this); - ChatSession newChatSession = - new ChatSession( - model.withTools(tools), Optional.of(rootChat), automaticFunctionCallingResponder); - newChatSession.history = history; - newChatSession.previousHistorySize = previousHistorySize; - return newChatSession; + return new ChatSession( + model.withTools(tools), + history, + previousHistorySize, + Optional.of(rootChatSession.orElse(this)), + automaticFunctionCallingResponder); + } + + /** + * Creates a copy of the current ChatSession with updated ToolConfig. + * + * @param toolConfig a {@link com.google.cloud.vertexai.api.ToolConfig} that will be used in the + * new ChatSession. + * @return a new {@link ChatSession} instance with the specified ToolConfigs. + */ + public ChatSession withToolConfig(ToolConfig toolConfig) { + return new ChatSession( + model.withToolConfig(toolConfig), + history, + previousHistorySize, + Optional.of(rootChatSession.orElse(this)), + automaticFunctionCallingResponder); + } + + /** + * Creates a copy of the current ChatSession with updated SystemInstruction. + * + * @param systemInstruction a {@link com.google.cloud.vertexai.api.Content} containing system + * instructions. + * @return a new {@link ChatSession} instance with the specified ToolConfigs. + */ + public ChatSession withSystemInstruction(Content systemInstruction) { + return new ChatSession( + model.withSystemInstruction(systemInstruction), + history, + previousHistorySize, + Optional.of(rootChatSession.orElse(this)), + automaticFunctionCallingResponder); } /** @@ -141,13 +178,12 @@ public ChatSession withTools(List tools) { */ public ChatSession withAutomaticFunctionCallingResponder( AutomaticFunctionCallingResponder automaticFunctionCallingResponder) { - ChatSession rootChat = rootChatSession.orElse(this); - ChatSession newChatSession = - new ChatSession( - model, Optional.of(rootChat), Optional.of(automaticFunctionCallingResponder)); - newChatSession.history = history; - newChatSession.previousHistorySize = previousHistorySize; - return newChatSession; + return new ChatSession( + model, + history, + previousHistorySize, + Optional.of(rootChatSession.orElse(this)), + Optional.of(automaticFunctionCallingResponder)); } /** diff --git a/java-vertexai/google-cloud-vertexai/src/main/java/com/google/cloud/vertexai/generativeai/GenerativeModel.java b/java-vertexai/google-cloud-vertexai/src/main/java/com/google/cloud/vertexai/generativeai/GenerativeModel.java index d88fdc5da081..ed86b5993607 100644 --- a/java-vertexai/google-cloud-vertexai/src/main/java/com/google/cloud/vertexai/generativeai/GenerativeModel.java +++ b/java-vertexai/google-cloud-vertexai/src/main/java/com/google/cloud/vertexai/generativeai/GenerativeModel.java @@ -39,7 +39,13 @@ import java.util.List; import java.util.Optional; -/** This class holds a generative model that can complete what you provided. */ +/** + * This class holds a generative model that can complete what you provided. This class is + * thread-safe. + * + *

Note: The instances of {@link ChatSession} returned by {@link GenerativeModel#startChat()} are + * NOT thread-safe. + */ public final class GenerativeModel { private final String modelName; private final String resourceName; @@ -645,6 +651,11 @@ public Optional getSystemInstruction() { return systemInstruction; } + /** + * Returns a new {@link ChatSession} instance that can be used to start a chat with this model. + * + *

Note: the returned {@link ChatSession} instance is NOT thread-safe. + */ public ChatSession startChat() { return new ChatSession(this); } diff --git a/java-vertexai/google-cloud-vertexai/src/test/java/com/google/cloud/vertexai/generativeai/ChatSessionTest.java b/java-vertexai/google-cloud-vertexai/src/test/java/com/google/cloud/vertexai/generativeai/ChatSessionTest.java index 0537d96fb0dc..aa0c7bf911df 100644 --- a/java-vertexai/google-cloud-vertexai/src/test/java/com/google/cloud/vertexai/generativeai/ChatSessionTest.java +++ b/java-vertexai/google-cloud-vertexai/src/test/java/com/google/cloud/vertexai/generativeai/ChatSessionTest.java @@ -29,6 +29,7 @@ import com.google.cloud.vertexai.api.Candidate.FinishReason; import com.google.cloud.vertexai.api.Content; import com.google.cloud.vertexai.api.FunctionCall; +import com.google.cloud.vertexai.api.FunctionCallingConfig; import com.google.cloud.vertexai.api.FunctionDeclaration; import com.google.cloud.vertexai.api.GenerateContentRequest; import com.google.cloud.vertexai.api.GenerateContentResponse; @@ -40,6 +41,7 @@ import com.google.cloud.vertexai.api.SafetySetting.HarmBlockThreshold; import com.google.cloud.vertexai.api.Schema; import com.google.cloud.vertexai.api.Tool; +import com.google.cloud.vertexai.api.ToolConfig; import com.google.cloud.vertexai.api.Type; import com.google.protobuf.Struct; import com.google.protobuf.Value; @@ -174,6 +176,16 @@ public final class ChatSessionTest { .build()) .addRequired("location"))) .build(); + private static final ToolConfig TOOL_CONFIG = + ToolConfig.newBuilder() + .setFunctionCallingConfig( + FunctionCallingConfig.newBuilder() + .setMode(FunctionCallingConfig.Mode.ANY) + .addAllowedFunctionNames("getCurrentWeather")) + .build(); + private static final Content SYSTEM_INSTRUCTION = + ContentMaker.fromString( + "You're a helpful assistant that starts all its answers with: \"COOL\""); @Rule public final MockitoRule mocksRule = MockitoJUnit.rule(); @@ -518,7 +530,9 @@ public void testChatSessionMergeHistoryToRootChatSession() throws Exception { rootChat .withGenerationConfig(GENERATION_CONFIG) .withSafetySettings(Arrays.asList(SAFETY_SETTING)) - .withTools(Arrays.asList(TOOL)); + .withTools(Arrays.asList(TOOL)) + .withToolConfig(TOOL_CONFIG) + .withSystemInstruction(SYSTEM_INSTRUCTION); response = childChat.sendMessage(SAMPLE_MESSAGE_2); // (Assert) root chat history should contain all 4 contents @@ -532,8 +546,12 @@ public void testChatSessionMergeHistoryToRootChatSession() throws Exception { ArgumentCaptor request = ArgumentCaptor.forClass(GenerateContentRequest.class); verify(mockUnaryCallable, times(2)).call(request.capture()); + Content expectedSystemInstruction = SYSTEM_INSTRUCTION.toBuilder().clearRole().build(); assertThat(request.getAllValues().get(1).getGenerationConfig()).isEqualTo(GENERATION_CONFIG); assertThat(request.getAllValues().get(1).getSafetySettings(0)).isEqualTo(SAFETY_SETTING); assertThat(request.getAllValues().get(1).getTools(0)).isEqualTo(TOOL); + assertThat(request.getAllValues().get(1).getToolConfig()).isEqualTo(TOOL_CONFIG); + assertThat(request.getAllValues().get(1).getSystemInstruction()) + .isEqualTo(expectedSystemInstruction); } }