Skip to content

Commit

Permalink
remove vertexai safety settings
Browse files Browse the repository at this point in the history
  • Loading branch information
luxzoli committed Jun 14, 2024
1 parent e265cc2 commit f1551c0
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 63 deletions.
30 changes: 2 additions & 28 deletions autogen/oai/gemini.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@
from __future__ import annotations

import base64
import logging
import os
import random
import re
Expand All @@ -44,7 +43,6 @@
import google.generativeai as genai
import requests
import vertexai
from flaml.automl.logger import logger_formatter
from google.ai.generativelanguage import Content, Part
from google.api_core.exceptions import InternalServerError
from openai.types.chat import ChatCompletion
Expand All @@ -53,12 +51,7 @@
from PIL import Image
from vertexai.generative_models import Content as VertexAIContent
from vertexai.generative_models import GenerativeModel
from vertexai.generative_models import HarmBlockThreshold as VertexAIHarmBlockThreshold
from vertexai.generative_models import HarmCategory as VertexAIHarmCategory
from vertexai.generative_models import Part as VertexAIPart
from vertexai.generative_models import SafetySetting as VertexAISafetySetting

logger = logging.getLogger(__name__)


class GeminiClient:
Expand Down Expand Up @@ -173,7 +166,6 @@ def create(self, params: Dict) -> ChatCompletion:
if autogen_term in params
}
safety_settings = params.get("safety_settings", {})
vertexai_safety_settings = GeminiClient._to_vertexai_safety_settings(safety_settings)

if stream:
warnings.warn(
Expand All @@ -189,7 +181,7 @@ def create(self, params: Dict) -> ChatCompletion:
gemini_messages = self._oai_messages_to_gemini_messages(messages)
if self.use_vertexai:
model = GenerativeModel(
model_name, generation_config=generation_config, safety_settings=vertexai_safety_settings
model_name, generation_config=generation_config, safety_settings=safety_settings
)
else:
# we use chat model by default
Expand Down Expand Up @@ -226,7 +218,7 @@ def create(self, params: Dict) -> ChatCompletion:
# B. handle the vision model
if self.use_vertexai:
model = GenerativeModel(
model_name, generation_config=generation_config, safety_settings=vertexai_safety_settings
model_name, generation_config=generation_config, safety_settings=safety_settings
)
else:
model = genai.GenerativeModel(
Expand Down Expand Up @@ -380,24 +372,6 @@ def _oai_messages_to_gemini_messages(self, messages: list[Dict[str, Any]]) -> li

return rst

@staticmethod
def _to_vertexai_safety_settings(safety_settings):
vertexai_safety_settings = []
for safety_setting in safety_settings:
if safety_setting["category"] not in VertexAIHarmCategory.__members__:
invalid_category = safety_setting["category"]
logger.error(f"Safety setting category {invalid_category} is invalid")
elif safety_setting["threshold"] not in VertexAIHarmBlockThreshold.__members__:
invalid_threshold = safety_setting["threshold"]
logger.error(f"Safety threshold {invalid_threshold} is invalid")
else:
vertexai_safety_setting = VertexAISafetySetting(
category=safety_setting["category"],
threshold=safety_setting["threshold"],
)
vertexai_safety_settings.append(vertexai_safety_setting)
return vertexai_safety_settings


def _to_pil(data: str) -> Image.Image:
"""
Expand Down
35 changes: 0 additions & 35 deletions test/oai/test_gemini.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,6 @@

try:
from google.api_core.exceptions import InternalServerError
from vertexai.generative_models import HarmBlockThreshold as VertexAIHarmBlockThreshold
from vertexai.generative_models import HarmCategory as VertexAIHarmCategory
from vertexai.generative_models import SafetySetting as VertexAISafetySetting

from autogen.oai.gemini import GeminiClient

Expand Down Expand Up @@ -55,38 +52,6 @@ def test_valid_initialization(gemini_client):
assert gemini_client.api_key == "fake_api_key", "API Key should be correctly set"


@pytest.mark.skipif(skip, reason="Google GenAI dependency is not installed")
def test_vertexai_safety_setting_conversion(gemini_client):
safety_settings = [
{"category": "HARM_CATEGORY_HARASSMENT", "threshold": "BLOCK_ONLY_HIGH"},
{"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "BLOCK_ONLY_HIGH"},
{"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "BLOCK_ONLY_HIGH"},
{"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "BLOCK_ONLY_HIGH"},
]
converted_safety_settings = GeminiClient._to_vertexai_safety_settings(safety_settings)
harm_categories = [
VertexAIHarmCategory.HARM_CATEGORY_HARASSMENT,
VertexAIHarmCategory.HARM_CATEGORY_HATE_SPEECH,
VertexAIHarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT,
VertexAIHarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
]
expected_safety_settings = [
VertexAISafetySetting(category=category, threshold=VertexAIHarmBlockThreshold.BLOCK_ONLY_HIGH)
for category in harm_categories
]

def compare_safety_settings(converted_safety_settings, expected_safety_settings):
for i, expected_setting in enumerate(expected_safety_settings):
converted_setting = converted_safety_settings[i]
yield expected_setting.to_dict() == converted_setting.to_dict()

assert len(converted_safety_settings) == len(
expected_safety_settings
), "The length of the safety settings is incorrect"
settings_comparison = compare_safety_settings(converted_safety_settings, expected_safety_settings)
assert all(settings_comparison), "Converted safety settings are incorrect"


@pytest.mark.skipif(skip, reason="Google GenAI dependency is not installed")
def test_gemini_message_handling(gemini_client):
messages = [
Expand Down

0 comments on commit f1551c0

Please sign in to comment.