From f1551c0f3e06f0a0c19ec02d7bbce9f0e81a123f Mon Sep 17 00:00:00 2001 From: Zoltan Lux Date: Thu, 13 Jun 2024 20:55:55 +0000 Subject: [PATCH] remove vertexai safety settings --- autogen/oai/gemini.py | 30 ++---------------------------- test/oai/test_gemini.py | 35 ----------------------------------- 2 files changed, 2 insertions(+), 63 deletions(-) diff --git a/autogen/oai/gemini.py b/autogen/oai/gemini.py index a19a2a3234b..30d4c3fe518 100644 --- a/autogen/oai/gemini.py +++ b/autogen/oai/gemini.py @@ -32,7 +32,6 @@ from __future__ import annotations import base64 -import logging import os import random import re @@ -44,7 +43,6 @@ import google.generativeai as genai import requests import vertexai -from flaml.automl.logger import logger_formatter from google.ai.generativelanguage import Content, Part from google.api_core.exceptions import InternalServerError from openai.types.chat import ChatCompletion @@ -53,12 +51,7 @@ from PIL import Image from vertexai.generative_models import Content as VertexAIContent from vertexai.generative_models import GenerativeModel -from vertexai.generative_models import HarmBlockThreshold as VertexAIHarmBlockThreshold -from vertexai.generative_models import HarmCategory as VertexAIHarmCategory from vertexai.generative_models import Part as VertexAIPart -from vertexai.generative_models import SafetySetting as VertexAISafetySetting - -logger = logging.getLogger(__name__) class GeminiClient: @@ -173,7 +166,6 @@ def create(self, params: Dict) -> ChatCompletion: if autogen_term in params } safety_settings = params.get("safety_settings", {}) - vertexai_safety_settings = GeminiClient._to_vertexai_safety_settings(safety_settings) if stream: warnings.warn( @@ -189,7 +181,7 @@ def create(self, params: Dict) -> ChatCompletion: gemini_messages = self._oai_messages_to_gemini_messages(messages) if self.use_vertexai: model = GenerativeModel( - model_name, generation_config=generation_config, safety_settings=vertexai_safety_settings + model_name, generation_config=generation_config, safety_settings=safety_settings ) else: # we use chat model by default @@ -226,7 +218,7 @@ def create(self, params: Dict) -> ChatCompletion: # B. handle the vision model if self.use_vertexai: model = GenerativeModel( - model_name, generation_config=generation_config, safety_settings=vertexai_safety_settings + model_name, generation_config=generation_config, safety_settings=safety_settings ) else: model = genai.GenerativeModel( @@ -380,24 +372,6 @@ def _oai_messages_to_gemini_messages(self, messages: list[Dict[str, Any]]) -> li return rst - @staticmethod - def _to_vertexai_safety_settings(safety_settings): - vertexai_safety_settings = [] - for safety_setting in safety_settings: - if safety_setting["category"] not in VertexAIHarmCategory.__members__: - invalid_category = safety_setting["category"] - logger.error(f"Safety setting category {invalid_category} is invalid") - elif safety_setting["threshold"] not in VertexAIHarmBlockThreshold.__members__: - invalid_threshold = safety_setting["threshold"] - logger.error(f"Safety threshold {invalid_threshold} is invalid") - else: - vertexai_safety_setting = VertexAISafetySetting( - category=safety_setting["category"], - threshold=safety_setting["threshold"], - ) - vertexai_safety_settings.append(vertexai_safety_setting) - return vertexai_safety_settings - def _to_pil(data: str) -> Image.Image: """ diff --git a/test/oai/test_gemini.py b/test/oai/test_gemini.py index a75a0586105..35f35620b8b 100644 --- a/test/oai/test_gemini.py +++ b/test/oai/test_gemini.py @@ -4,9 +4,6 @@ try: from google.api_core.exceptions import InternalServerError - from vertexai.generative_models import HarmBlockThreshold as VertexAIHarmBlockThreshold - from vertexai.generative_models import HarmCategory as VertexAIHarmCategory - from vertexai.generative_models import SafetySetting as VertexAISafetySetting from autogen.oai.gemini import GeminiClient @@ -55,38 +52,6 @@ def test_valid_initialization(gemini_client): assert gemini_client.api_key == "fake_api_key", "API Key should be correctly set" -@pytest.mark.skipif(skip, reason="Google GenAI dependency is not installed") -def test_vertexai_safety_setting_conversion(gemini_client): - safety_settings = [ - {"category": "HARM_CATEGORY_HARASSMENT", "threshold": "BLOCK_ONLY_HIGH"}, - {"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "BLOCK_ONLY_HIGH"}, - {"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "BLOCK_ONLY_HIGH"}, - {"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "BLOCK_ONLY_HIGH"}, - ] - converted_safety_settings = GeminiClient._to_vertexai_safety_settings(safety_settings) - harm_categories = [ - VertexAIHarmCategory.HARM_CATEGORY_HARASSMENT, - VertexAIHarmCategory.HARM_CATEGORY_HATE_SPEECH, - VertexAIHarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT, - VertexAIHarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT, - ] - expected_safety_settings = [ - VertexAISafetySetting(category=category, threshold=VertexAIHarmBlockThreshold.BLOCK_ONLY_HIGH) - for category in harm_categories - ] - - def compare_safety_settings(converted_safety_settings, expected_safety_settings): - for i, expected_setting in enumerate(expected_safety_settings): - converted_setting = converted_safety_settings[i] - yield expected_setting.to_dict() == converted_setting.to_dict() - - assert len(converted_safety_settings) == len( - expected_safety_settings - ), "The length of the safety settings is incorrect" - settings_comparison = compare_safety_settings(converted_safety_settings, expected_safety_settings) - assert all(settings_comparison), "Converted safety settings are incorrect" - - @pytest.mark.skipif(skip, reason="Google GenAI dependency is not installed") def test_gemini_message_handling(gemini_client): messages = [