Skip to content

Commit

Permalink
feat: [vertexai] add GenerateContentConfig to generateContentStream m…
Browse files Browse the repository at this point in the history
…ethod (#10424)

PiperOrigin-RevId: 609300701

Co-authored-by: Jaycee Li <jayceeli@google.com>
  • Loading branch information
copybara-service[bot] and jaycee-li authored Feb 27, 2024
1 parent 04e9574 commit ec9dd00
Show file tree
Hide file tree
Showing 2 changed files with 97 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -646,6 +646,21 @@ public ResponseStream<GenerateContentResponse> generateContentStream(String text
return generateContentStream(text, null, null);
}

/**
* Generate content with streaming support from generative model given a text and configs.
*
* @param text a text message to send to the generative model
* @param config a {@link GenerateContentConfig} that contains all the configs in making a
* generate content api call
* @return a {@link ResponseStream} that contains a streaming of {@link
* com.google.cloud.vertexai.api.GenerateContentResponse}
* @throws IOException if an I/O error occurs while making the API call
*/
public ResponseStream<GenerateContentResponse> generateContentStream(
String text, GenerateContentConfig config) throws IOException {
return generateContentStream(ContentMaker.fromString(text), config);
}

/**
* Generate content with streaming support from generative model given a text and generation
* config.
Expand Down Expand Up @@ -716,6 +731,22 @@ public ResponseStream<GenerateContentResponse> generateContentStream(Content con
return generateContentStream(content, null, null);
}

/**
* Generate content with streaming support from generative model given a single content and
* configs.
*
* @param content a {@link com.google.cloud.vertexai.api.Content} to send to the generative model
* @param config a {@link GenerateContentConfig} that contains all the configs in making a
* generate content api call
* @return a {@link ResponseStream} that contains a streaming of {@link
* com.google.cloud.vertexai.api.GenerateContentResponse}
* @throws IOException if an I/O error occurs while making the API call
*/
public ResponseStream<GenerateContentResponse> generateContentStream(
Content content, GenerateContentConfig config) throws IOException {
return generateContentStream(Arrays.asList(content), config);
}

/**
* Generate content with streaming support from generative model given a single Content and
* generation config.
Expand Down Expand Up @@ -856,6 +887,41 @@ public ResponseStream<GenerateContentResponse> generateContentStream(
return generateContentStream(requestBuilder);
}

/**
* Generate content with streaming support from generative model given a list of contents and
* configs.
*
* @param contents a list of {@link com.google.cloud.vertexai.api.Content} to send to the
* generative model
* @param config a {@link GenerateContentConfig} that contains all the configs in making a
* generate content api call
* @return a {@link ResponseStream} that contains a streaming of {@link
* com.google.cloud.vertexai.api.GenerateContentResponse}
* @throws IOException if an I/O error occurs while making the API call
*/
public ResponseStream<GenerateContentResponse> generateContentStream(
List<Content> contents, GenerateContentConfig config) throws IOException {
GenerateContentRequest.Builder requestBuilder =
GenerateContentRequest.newBuilder().addAllContents(contents);
if (config.getGenerationConfig() != null) {
requestBuilder.setGenerationConfig(config.getGenerationConfig());
} else if (this.generationConfig != null) {
requestBuilder.setGenerationConfig(this.generationConfig);
}
if (config.getSafetySettings().isEmpty() == false) {
requestBuilder.addAllSafetySettings(config.getSafetySettings());
} else if (this.safetySettings != null) {
requestBuilder.addAllSafetySettings(this.safetySettings);
}
if (config.getTools().isEmpty() == false) {
requestBuilder.addAllTools(config.getTools());
} else if (this.tools != null) {
requestBuilder.addAllTools(this.tools);
}

return generateContentStream(requestBuilder);
}

/**
* A base generateContentStream method that will be used internally.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -639,4 +639,35 @@ public void testGenerateContentStreamwithDefaultTools() throws Exception {
verify(mockServerStreamCallable).call(request.capture());
assertThat(request.getValue().getTools(0)).isEqualTo(TOOL);
}

@Test
public void testGenerateContentStreamwithGenerateContentConfig() throws Exception {
model = new GenerativeModel(MODEL_NAME, vertexAi);
GenerateContentConfig config =
GenerateContentConfig.newBuilder()
.setGenerationConfig(GENERATION_CONFIG)
.setSafetySettings(safetySettings)
.setTools(tools)
.build();

Field field = VertexAI.class.getDeclaredField("predictionServiceClient");
field.setAccessible(true);
field.set(vertexAi, mockPredictionServiceClient);

when(mockPredictionServiceClient.streamGenerateContentCallable())
.thenReturn(mockServerStreamCallable);
when(mockServerStreamCallable.call(any(GenerateContentRequest.class)))
.thenReturn(mockServerStream);
when(mockServerStream.iterator()).thenReturn(mockServerStreamIterator);

ResponseStream unused = model.generateContentStream(TEXT, config);

ArgumentCaptor<GenerateContentRequest> request =
ArgumentCaptor.forClass(GenerateContentRequest.class);
verify(mockServerStreamCallable).call(request.capture());

assertThat(request.getValue().getGenerationConfig()).isEqualTo(GENERATION_CONFIG);
assertThat(request.getValue().getSafetySettings(0)).isEqualTo(SAFETY_SETTING);
assertThat(request.getValue().getTools(0)).isEqualTo(TOOL);
}
}

0 comments on commit ec9dd00

Please sign in to comment.