Skip to content

Commit

Permalink
add safety setting support for vertexai
Browse files Browse the repository at this point in the history
  • Loading branch information
luxzoli committed Jun 14, 2024
1 parent f8a2e02 commit e265cc2
Show file tree
Hide file tree
Showing 2 changed files with 64 additions and 4 deletions.
30 changes: 28 additions & 2 deletions autogen/oai/gemini.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
from __future__ import annotations

import base64
import logging
import os
import random
import re
Expand All @@ -43,6 +44,7 @@
import google.generativeai as genai
import requests
import vertexai
from flaml.automl.logger import logger_formatter
from google.ai.generativelanguage import Content, Part
from google.api_core.exceptions import InternalServerError
from openai.types.chat import ChatCompletion
Expand All @@ -51,7 +53,12 @@
from PIL import Image
from vertexai.generative_models import Content as VertexAIContent
from vertexai.generative_models import GenerativeModel
from vertexai.generative_models import HarmBlockThreshold as VertexAIHarmBlockThreshold
from vertexai.generative_models import HarmCategory as VertexAIHarmCategory
from vertexai.generative_models import Part as VertexAIPart
from vertexai.generative_models import SafetySetting as VertexAISafetySetting

logger = logging.getLogger(__name__)


class GeminiClient:
Expand Down Expand Up @@ -166,6 +173,7 @@ def create(self, params: Dict) -> ChatCompletion:
if autogen_term in params
}
safety_settings = params.get("safety_settings", {})
vertexai_safety_settings = GeminiClient._to_vertexai_safety_settings(safety_settings)

if stream:
warnings.warn(
Expand All @@ -181,7 +189,7 @@ def create(self, params: Dict) -> ChatCompletion:
gemini_messages = self._oai_messages_to_gemini_messages(messages)
if self.use_vertexai:
model = GenerativeModel(
model_name, generation_config=generation_config, safety_settings=safety_settings
model_name, generation_config=generation_config, safety_settings=vertexai_safety_settings
)
else:
# we use chat model by default
Expand Down Expand Up @@ -218,7 +226,7 @@ def create(self, params: Dict) -> ChatCompletion:
# B. handle the vision model
if self.use_vertexai:
model = GenerativeModel(
model_name, generation_config=generation_config, safety_settings=safety_settings
model_name, generation_config=generation_config, safety_settings=vertexai_safety_settings
)
else:
model = genai.GenerativeModel(
Expand Down Expand Up @@ -372,6 +380,24 @@ def _oai_messages_to_gemini_messages(self, messages: list[Dict[str, Any]]) -> li

return rst

@staticmethod
def _to_vertexai_safety_settings(safety_settings):
vertexai_safety_settings = []
for safety_setting in safety_settings:
if safety_setting["category"] not in VertexAIHarmCategory.__members__:
invalid_category = safety_setting["category"]
logger.error(f"Safety setting category {invalid_category} is invalid")
elif safety_setting["threshold"] not in VertexAIHarmBlockThreshold.__members__:
invalid_threshold = safety_setting["threshold"]
logger.error(f"Safety threshold {invalid_threshold} is invalid")
else:
vertexai_safety_setting = VertexAISafetySetting(
category=safety_setting["category"],
threshold=safety_setting["threshold"],
)
vertexai_safety_settings.append(vertexai_safety_setting)
return vertexai_safety_settings


def _to_pil(data: str) -> Image.Image:
"""
Expand Down
38 changes: 36 additions & 2 deletions test/oai/test_gemini.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@

try:
from google.api_core.exceptions import InternalServerError
from vertexai.generative_models import HarmBlockThreshold as VertexAIHarmBlockThreshold
from vertexai.generative_models import HarmCategory as VertexAIHarmCategory
from vertexai.generative_models import SafetySetting as VertexAISafetySetting

from autogen.oai.gemini import GeminiClient

Expand Down Expand Up @@ -52,9 +55,40 @@ def test_valid_initialization(gemini_client):
assert gemini_client.api_key == "fake_api_key", "API Key should be correctly set"


@patch("autogen.oai.gemini.genai")
@pytest.mark.skipif(skip, reason="Google GenAI dependency is not installed")
def test_gemini_message_handling(mock_genai, gemini_client):
def test_vertexai_safety_setting_conversion(gemini_client):
safety_settings = [
{"category": "HARM_CATEGORY_HARASSMENT", "threshold": "BLOCK_ONLY_HIGH"},
{"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "BLOCK_ONLY_HIGH"},
{"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "BLOCK_ONLY_HIGH"},
{"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "BLOCK_ONLY_HIGH"},
]
converted_safety_settings = GeminiClient._to_vertexai_safety_settings(safety_settings)
harm_categories = [
VertexAIHarmCategory.HARM_CATEGORY_HARASSMENT,
VertexAIHarmCategory.HARM_CATEGORY_HATE_SPEECH,
VertexAIHarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT,
VertexAIHarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
]
expected_safety_settings = [
VertexAISafetySetting(category=category, threshold=VertexAIHarmBlockThreshold.BLOCK_ONLY_HIGH)
for category in harm_categories
]

def compare_safety_settings(converted_safety_settings, expected_safety_settings):
for i, expected_setting in enumerate(expected_safety_settings):
converted_setting = converted_safety_settings[i]
yield expected_setting.to_dict() == converted_setting.to_dict()

assert len(converted_safety_settings) == len(
expected_safety_settings
), "The length of the safety settings is incorrect"
settings_comparison = compare_safety_settings(converted_safety_settings, expected_safety_settings)
assert all(settings_comparison), "Converted safety settings are incorrect"


@pytest.mark.skipif(skip, reason="Google GenAI dependency is not installed")
def test_gemini_message_handling(gemini_client):
messages = [
{"role": "system", "content": "You are my personal assistant."},
{"role": "model", "content": "How can I help you?"},
Expand Down

0 comments on commit e265cc2

Please sign in to comment.