diff --git a/.gitattributes b/.gitattributes index c139e44b4dc..513c7ecbf03 100644 --- a/.gitattributes +++ b/.gitattributes @@ -1,3 +1,91 @@ +# Source code +*.bash text eol=lf +*.bat text eol=crlf +*.cmd text eol=crlf +*.coffee text +*.css text diff=css eol=lf +*.htm text diff=html eol=lf +*.html text diff=html eol=lf +*.inc text +*.ini text +*.js text +*.json text eol=lf +*.jsx text +*.less text +*.ls text +*.map text -diff +*.od text +*.onlydata text +*.php text diff=php +*.pl text +*.ps1 text eol=crlf +*.py text diff=python eol=lf +*.rb text diff=ruby eol=lf +*.sass text +*.scm text +*.scss text diff=css +*.sh text eol=lf +.husky/* text eol=lf +*.sql text +*.styl text +*.tag text +*.ts text +*.tsx text +*.xml text +*.xhtml text diff=html + +# Docker +Dockerfile text eol=lf + +# Documentation +*.ipynb text +*.markdown text diff=markdown eol=lf +*.md text diff=markdown eol=lf +*.mdwn text diff=markdown eol=lf +*.mdown text diff=markdown eol=lf +*.mkd text diff=markdown eol=lf +*.mkdn text diff=markdown eol=lf +*.mdtxt text eol=lf +*.mdtext text eol=lf +*.txt text eol=lf +AUTHORS text eol=lf +CHANGELOG text eol=lf +CHANGES text eol=lf +CONTRIBUTING text eol=lf +COPYING text eol=lf +copyright text eol=lf +*COPYRIGHT* text eol=lf +INSTALL text eol=lf +license text eol=lf +LICENSE text eol=lf +NEWS text eol=lf +readme text eol=lf +*README* text eol=lf +TODO text + +# Configs +*.cnf text eol=lf +*.conf text eol=lf +*.config text eol=lf +.editorconfig text +.env text eol=lf +.gitattributes text eol=lf +.gitconfig text eol=lf +.htaccess text +*.lock text -diff +package.json text eol=lf +package-lock.json text eol=lf -diff +pnpm-lock.yaml text eol=lf -diff +.prettierrc text +yarn.lock text -diff +*.toml text eol=lf +*.yaml text eol=lf +*.yml text eol=lf +browserslist text +Makefile text eol=lf +makefile text eol=lf + +# Images *.png filter=lfs diff=lfs merge=lfs -text *.jpg filter=lfs diff=lfs merge=lfs -text *.jpeg filter=lfs diff=lfs merge=lfs -text diff --git a/.github/workflows/contrib-tests.yml b/.github/workflows/contrib-tests.yml index 38fab877402..98a47e1e510 100644 --- a/.github/workflows/contrib-tests.yml +++ b/.github/workflows/contrib-tests.yml @@ -418,9 +418,9 @@ jobs: os: [ubuntu-latest, macos-latest, windows-2019] python-version: ["3.11"] steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v4 + uses: actions/setup-python@v5 with: python-version: ${{ matrix.python-version }} - name: Install packages and dependencies for all tests diff --git a/autogen/agentchat/contrib/vectordb/pgvectordb.py b/autogen/agentchat/contrib/vectordb/pgvectordb.py index b5db55f7eb1..38507cb7998 100644 --- a/autogen/agentchat/contrib/vectordb/pgvectordb.py +++ b/autogen/agentchat/contrib/vectordb/pgvectordb.py @@ -1,7 +1,7 @@ import os import re import urllib.parse -from typing import Callable, List +from typing import Callable, List, Optional, Union import numpy as np from sentence_transformers import SentenceTransformer @@ -231,7 +231,14 @@ def table_exists(self, table_name: str) -> bool: exists = cursor.fetchone()[0] return exists - def get(self, ids=None, include=None, where=None, limit=None, offset=None) -> List[Document]: + def get( + self, + ids: Optional[str] = None, + include: Optional[str] = None, + where: Optional[str] = None, + limit: Optional[Union[int, str]] = None, + offset: Optional[Union[int, str]] = None, + ) -> List[Document]: """ Retrieve documents from the collection. @@ -272,7 +279,6 @@ def get(self, ids=None, include=None, where=None, limit=None, offset=None) -> Li # Construct the full query query = f"{select_clause} {from_clause} {where_clause} {limit_clause} {offset_clause}" - retrieved_documents = [] try: # Execute the query with the appropriate values @@ -380,11 +386,11 @@ def inner_product_distance(arr1: List[float], arr2: List[float]) -> float: def query( self, query_texts: List[str], - collection_name: str = None, - n_results: int = 10, - distance_type: str = "euclidean", - distance_threshold: float = -1, - include_embedding: bool = False, + collection_name: Optional[str] = None, + n_results: Optional[int] = 10, + distance_type: Optional[str] = "euclidean", + distance_threshold: Optional[float] = -1, + include_embedding: Optional[bool] = False, ) -> QueryResults: """ Query documents in the collection. @@ -450,7 +456,7 @@ def query( return results @staticmethod - def convert_string_to_array(array_string) -> List[float]: + def convert_string_to_array(array_string: str) -> List[float]: """ Convert a string representation of an array to a list of floats. @@ -467,7 +473,7 @@ def convert_string_to_array(array_string) -> List[float]: array = [float(num) for num in array_string.split()] return array - def modify(self, metadata, collection_name: str = None) -> None: + def modify(self, metadata, collection_name: Optional[str] = None) -> None: """ Modify metadata for the collection. @@ -486,7 +492,7 @@ def modify(self, metadata, collection_name: str = None) -> None: ) cursor.close() - def delete(self, ids: List[ItemID], collection_name: str = None) -> None: + def delete(self, ids: List[ItemID], collection_name: Optional[str] = None) -> None: """ Delete documents from the collection. @@ -504,7 +510,7 @@ def delete(self, ids: List[ItemID], collection_name: str = None) -> None: cursor.execute(f"DELETE FROM {self.name} WHERE id IN ({id_placeholders});", ids) cursor.close() - def delete_collection(self, collection_name: str = None) -> None: + def delete_collection(self, collection_name: Optional[str] = None) -> None: """ Delete the entire collection. @@ -520,7 +526,7 @@ def delete_collection(self, collection_name: str = None) -> None: cursor.execute(f"DROP TABLE IF EXISTS {self.name}") cursor.close() - def create_collection(self, collection_name: str = None) -> None: + def create_collection(self, collection_name: Optional[str] = None) -> None: """ Create a new collection. @@ -557,16 +563,17 @@ class PGVectorDB(VectorDB): def __init__( self, *, - connection_string: str = None, - host: str = None, - port: int = None, - dbname: str = None, - username: str = None, - password: str = None, - connect_timeout: int = 10, + conn: Optional[psycopg.Connection] = None, + connection_string: Optional[str] = None, + host: Optional[str] = None, + port: Optional[Union[int, str]] = None, + dbname: Optional[str] = None, + username: Optional[str] = None, + password: Optional[str] = None, + connect_timeout: Optional[int] = 10, embedding_function: Callable = None, - metadata: dict = None, - model_name: str = "all-MiniLM-L6-v2", + metadata: Optional[dict] = None, + model_name: Optional[str] = "all-MiniLM-L6-v2", ) -> None: """ Initialize the vector database. @@ -574,6 +581,9 @@ def __init__( Note: connection_string or host + port + dbname must be specified Args: + conn: psycopg.Connection | A customer connection object to connect to the database. + A connection object may include additional key/values: + https://www.postgresql.org/docs/current/libpq-connect.html#LIBPQ-CONNSTRING connection_string: "postgresql://username:password@hostname:port/database" | The PGVector connection string. Default is None. host: str | The host to connect to. Default is None. port: int | The port to connect to. Default is None. @@ -593,46 +603,108 @@ def __init__( Returns: None """ + self.client = self.establish_connection( + conn=conn, + connection_string=connection_string, + host=host, + port=port, + dbname=dbname, + username=username, + password=password, + connect_timeout=connect_timeout, + ) + self.model_name = model_name try: - if connection_string: + self.embedding_function = ( + SentenceTransformer(self.model_name) if embedding_function is None else embedding_function + ) + except Exception as e: + logger.error( + f"Validate the model name entered: {self.model_name} " + f"from https://huggingface.co/models?library=sentence-transformers\nError: {e}" + ) + raise e + self.metadata = metadata + register_vector(self.client) + self.active_collection = None + + def establish_connection( + self, + conn: Optional[psycopg.Connection] = None, + connection_string: Optional[str] = None, + host: Optional[str] = None, + port: Optional[Union[int, str]] = None, + dbname: Optional[str] = None, + username: Optional[str] = None, + password: Optional[str] = None, + connect_timeout: Optional[int] = 10, + ) -> psycopg.Connection: + """ + Establishes a connection to a PostgreSQL database using psycopg. + + Args: + conn: An existing psycopg connection object. If provided, this connection will be used. + connection_string: A string containing the connection information. If provided, a new connection will be established using this string. + host: The hostname of the PostgreSQL server. Used if connection_string is not provided. + port: The port number to connect to at the server host. Used if connection_string is not provided. + dbname: The database name. Used if connection_string is not provided. + username: The username to connect as. Used if connection_string is not provided. + password: The user's password. Used if connection_string is not provided. + connect_timeout: Maximum wait for connection, in seconds. The default is 10 seconds. + + Returns: + A psycopg.Connection object representing the established connection. + + Raises: + PermissionError if no credentials are supplied + psycopg.Error: If an error occurs while trying to connect to the database. + """ + try: + if conn: + self.client = conn + elif connection_string: parsed_connection = urllib.parse.urlparse(connection_string) encoded_username = urllib.parse.quote(parsed_connection.username, safe="") encoded_password = urllib.parse.quote(parsed_connection.password, safe="") + encoded_password = f":{encoded_password}@" encoded_host = urllib.parse.quote(parsed_connection.hostname, safe="") + encoded_port = f":{parsed_connection.port}" encoded_database = urllib.parse.quote(parsed_connection.path[1:], safe="") connection_string_encoded = ( - f"{parsed_connection.scheme}://{encoded_username}:{encoded_password}" - f"@{encoded_host}:{parsed_connection.port}/{encoded_database}" + f"{parsed_connection.scheme}://{encoded_username}{encoded_password}" + f"{encoded_host}{encoded_port}/{encoded_database}" ) self.client = psycopg.connect(conninfo=connection_string_encoded, autocommit=True) - elif host and port and dbname: + elif host: + connection_string = "" + if host: + encoded_host = urllib.parse.quote(host, safe="") + connection_string += f"host={encoded_host} " + if port: + connection_string += f"port={port} " + if dbname: + encoded_database = urllib.parse.quote(dbname, safe="") + connection_string += f"dbname={encoded_database} " + if username: + encoded_username = urllib.parse.quote(username, safe="") + connection_string += f"user={encoded_username} " + if password: + encoded_password = urllib.parse.quote(password, safe="") + connection_string += f"password={encoded_password} " + self.client = psycopg.connect( - host=host, - port=port, - dbname=dbname, - username=username, - password=password, + conninfo=connection_string, connect_timeout=connect_timeout, autocommit=True, ) + else: + logger.error("Credentials were not supplied...") + raise PermissionError + self.client.execute("CREATE EXTENSION IF NOT EXISTS vector") except psycopg.Error as e: logger.error("Error connecting to the database: ", e) raise e - self.model_name = model_name - try: - self.embedding_function = ( - SentenceTransformer(self.model_name) if embedding_function is None else embedding_function - ) - except Exception as e: - logger.error( - f"Validate the model name entered: {self.model_name} " - f"from https://huggingface.co/models?library=sentence-transformers\nError: {e}" - ) - raise e - self.metadata = metadata - self.client.execute("CREATE EXTENSION IF NOT EXISTS vector") - register_vector(self.client) - self.active_collection = None + return self.client def create_collection( self, collection_name: str, overwrite: bool = False, get_or_create: bool = True diff --git a/autogen/function_utils.py b/autogen/function_utils.py index dd225fd4719..6b9b6f5b129 100644 --- a/autogen/function_utils.py +++ b/autogen/function_utils.py @@ -353,4 +353,4 @@ def serialize_to_str(x: Any) -> str: elif isinstance(x, BaseModel): return model_dump_json(x) else: - return json.dumps(x) + return json.dumps(x, ensure_ascii=False) diff --git a/dotnet/AutoGen.sln b/dotnet/AutoGen.sln index be40e7b61b6..de2549cae13 100644 --- a/dotnet/AutoGen.sln +++ b/dotnet/AutoGen.sln @@ -44,8 +44,15 @@ EndProject Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "AutoGen.Ollama.Tests", "test\AutoGen.Ollama.Tests\AutoGen.Ollama.Tests.csproj", "{03E31CAA-3728-48D3-B936-9F11CF6C18FE}" EndProject Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "AutoGen.Ollama.Sample", "sample\AutoGen.Ollama.Sample\AutoGen.Ollama.Sample.csproj", "{93AA4D0D-6EE4-44D5-AD77-7F73A3934544}" +EndProject Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "AutoGen.SemanticKernel.Sample", "sample\AutoGen.SemanticKernel.Sample\AutoGen.SemanticKernel.Sample.csproj", "{52958A60-3FF7-4243-9058-34A6E4F55C31}" EndProject +Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "AutoGen.Anthropic", "src\AutoGen.Anthropic\AutoGen.Anthropic.csproj", "{6A95E113-B824-4524-8F13-CD0C3E1C8804}" +EndProject +Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "AutoGen.Anthropic.Tests", "test\AutoGen.Anthropic.Tests\AutoGen.Anthropic.Tests.csproj", "{815E937E-86D6-4476-9EC6-B7FBCBBB5DB6}" +EndProject +Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "AutoGen.Anthropic.Samples", "sample\AutoGen.Anthropic.Samples\AutoGen.Anthropic.Samples.csproj", "{834B4E85-64E5-4382-8465-548F332E5298}" +EndProject Global GlobalSection(SolutionConfigurationPlatforms) = preSolution Debug|Any CPU = Debug|Any CPU @@ -128,6 +135,18 @@ Global {52958A60-3FF7-4243-9058-34A6E4F55C31}.Debug|Any CPU.Build.0 = Debug|Any CPU {52958A60-3FF7-4243-9058-34A6E4F55C31}.Release|Any CPU.ActiveCfg = Release|Any CPU {52958A60-3FF7-4243-9058-34A6E4F55C31}.Release|Any CPU.Build.0 = Release|Any CPU + {6A95E113-B824-4524-8F13-CD0C3E1C8804}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {6A95E113-B824-4524-8F13-CD0C3E1C8804}.Debug|Any CPU.Build.0 = Debug|Any CPU + {6A95E113-B824-4524-8F13-CD0C3E1C8804}.Release|Any CPU.ActiveCfg = Release|Any CPU + {6A95E113-B824-4524-8F13-CD0C3E1C8804}.Release|Any CPU.Build.0 = Release|Any CPU + {815E937E-86D6-4476-9EC6-B7FBCBBB5DB6}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {815E937E-86D6-4476-9EC6-B7FBCBBB5DB6}.Debug|Any CPU.Build.0 = Debug|Any CPU + {815E937E-86D6-4476-9EC6-B7FBCBBB5DB6}.Release|Any CPU.ActiveCfg = Release|Any CPU + {815E937E-86D6-4476-9EC6-B7FBCBBB5DB6}.Release|Any CPU.Build.0 = Release|Any CPU + {834B4E85-64E5-4382-8465-548F332E5298}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {834B4E85-64E5-4382-8465-548F332E5298}.Debug|Any CPU.Build.0 = Debug|Any CPU + {834B4E85-64E5-4382-8465-548F332E5298}.Release|Any CPU.ActiveCfg = Release|Any CPU + {834B4E85-64E5-4382-8465-548F332E5298}.Release|Any CPU.Build.0 = Release|Any CPU EndGlobalSection GlobalSection(SolutionProperties) = preSolution HideSolutionNode = FALSE @@ -148,6 +167,9 @@ Global {1DFABC4A-8458-4875-8DCB-59F3802DAC65} = {F823671B-3ECA-4AE6-86DA-25E920D3FE64} {D36A85F9-C172-487D-8192-6BFE5D05B4A7} = {F823671B-3ECA-4AE6-86DA-25E920D3FE64} {B61388CA-DC73-4B7F-A7B2-7B9A86C9229E} = {F823671B-3ECA-4AE6-86DA-25E920D3FE64} + {6A95E113-B824-4524-8F13-CD0C3E1C8804} = {18BF8DD7-0585-48BF-8F97-AD333080CE06} + {815E937E-86D6-4476-9EC6-B7FBCBBB5DB6} = {F823671B-3ECA-4AE6-86DA-25E920D3FE64} + {834B4E85-64E5-4382-8465-548F332E5298} = {FBFEAD1F-29EB-4D99-A672-0CD8473E10B9} {9F9E6DED-3D92-4970-909A-70FC11F1A665} = {18BF8DD7-0585-48BF-8F97-AD333080CE06} {03E31CAA-3728-48D3-B936-9F11CF6C18FE} = {F823671B-3ECA-4AE6-86DA-25E920D3FE64} {93AA4D0D-6EE4-44D5-AD77-7F73A3934544} = {FBFEAD1F-29EB-4D99-A672-0CD8473E10B9} diff --git a/dotnet/sample/AutoGen.Anthropic.Samples/AnthropicSamples.cs b/dotnet/sample/AutoGen.Anthropic.Samples/AnthropicSamples.cs new file mode 100644 index 00000000000..94b5f37511e --- /dev/null +++ b/dotnet/sample/AutoGen.Anthropic.Samples/AnthropicSamples.cs @@ -0,0 +1,28 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// AnthropicSamples.cs + +using AutoGen.Anthropic.Extensions; +using AutoGen.Anthropic.Utils; +using AutoGen.Core; + +namespace AutoGen.Anthropic.Samples; + +public static class AnthropicSamples +{ + public static async Task RunAsync() + { + #region create_anthropic_agent + var apiKey = Environment.GetEnvironmentVariable("ANTHROPIC_API_KEY") ?? throw new Exception("Missing ANTHROPIC_API_KEY environment variable."); + var anthropicClient = new AnthropicClient(new HttpClient(), AnthropicConstants.Endpoint, apiKey); + var agent = new AnthropicClientAgent(anthropicClient, "assistant", AnthropicConstants.Claude3Haiku); + #endregion + + #region register_middleware + var agentWithConnector = agent + .RegisterMessageConnector() + .RegisterPrintMessage(); + #endregion register_middleware + + await agentWithConnector.SendAsync(new TextMessage(Role.Assistant, "Hello", from: "user")); + } +} diff --git a/dotnet/sample/AutoGen.Anthropic.Samples/AutoGen.Anthropic.Samples.csproj b/dotnet/sample/AutoGen.Anthropic.Samples/AutoGen.Anthropic.Samples.csproj new file mode 100644 index 00000000000..33a5aa7f16b --- /dev/null +++ b/dotnet/sample/AutoGen.Anthropic.Samples/AutoGen.Anthropic.Samples.csproj @@ -0,0 +1,18 @@ + + + + Exe + $(TestTargetFramework) + enable + enable + True + + + + + + + + + + diff --git a/dotnet/sample/AutoGen.Anthropic.Samples/Program.cs b/dotnet/sample/AutoGen.Anthropic.Samples/Program.cs new file mode 100644 index 00000000000..f3c61508861 --- /dev/null +++ b/dotnet/sample/AutoGen.Anthropic.Samples/Program.cs @@ -0,0 +1,12 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Program.cs + +namespace AutoGen.Anthropic.Samples; + +internal static class Program +{ + public static async Task Main(string[] args) + { + await AnthropicSamples.RunAsync(); + } +} diff --git a/dotnet/src/AutoGen.Anthropic/Agent/AnthropicClientAgent.cs b/dotnet/src/AutoGen.Anthropic/Agent/AnthropicClientAgent.cs new file mode 100644 index 00000000000..e395bb4a225 --- /dev/null +++ b/dotnet/src/AutoGen.Anthropic/Agent/AnthropicClientAgent.cs @@ -0,0 +1,91 @@ +using System; +using System.Collections.Generic; +using System.Runtime.CompilerServices; +using System.Threading; +using System.Threading.Tasks; +using AutoGen.Anthropic.DTO; +using AutoGen.Core; + +namespace AutoGen.Anthropic; + +public class AnthropicClientAgent : IStreamingAgent +{ + private readonly AnthropicClient _anthropicClient; + public string Name { get; } + private readonly string _modelName; + private readonly string _systemMessage; + private readonly decimal _temperature; + private readonly int _maxTokens; + + public AnthropicClientAgent( + AnthropicClient anthropicClient, + string name, + string modelName, + string systemMessage = "You are a helpful AI assistant", + decimal temperature = 0.7m, + int maxTokens = 1024) + { + Name = name; + _anthropicClient = anthropicClient; + _modelName = modelName; + _systemMessage = systemMessage; + _temperature = temperature; + _maxTokens = maxTokens; + } + + public async Task GenerateReplyAsync(IEnumerable messages, GenerateReplyOptions? options = null, + CancellationToken cancellationToken = default) + { + var response = await _anthropicClient.CreateChatCompletionsAsync(CreateParameters(messages, options, false), cancellationToken); + return new MessageEnvelope(response, from: this.Name); + } + + public async IAsyncEnumerable GenerateStreamingReplyAsync(IEnumerable messages, + GenerateReplyOptions? options = null, [EnumeratorCancellation] CancellationToken cancellationToken = default) + { + await foreach (var message in _anthropicClient.StreamingChatCompletionsAsync( + CreateParameters(messages, options, true), cancellationToken)) + { + yield return new MessageEnvelope(message, from: this.Name); + } + } + + private ChatCompletionRequest CreateParameters(IEnumerable messages, GenerateReplyOptions? options, bool shouldStream) + { + var chatCompletionRequest = new ChatCompletionRequest() + { + SystemMessage = _systemMessage, + MaxTokens = options?.MaxToken ?? _maxTokens, + Model = _modelName, + Stream = shouldStream, + Temperature = (decimal?)options?.Temperature ?? _temperature, + }; + + chatCompletionRequest.Messages = BuildMessages(messages); + + return chatCompletionRequest; + } + + private List BuildMessages(IEnumerable messages) + { + List chatMessages = new(); + foreach (IMessage? message in messages) + { + switch (message) + { + case IMessage chatMessage when chatMessage.Content.Role == "system": + throw new InvalidOperationException( + "system message has already been set and only one system message is supported. \"system\" role for input messages in the Message"); + + case IMessage chatMessage: + chatMessages.Add(chatMessage.Content); + break; + + default: + throw new ArgumentException($"Unexpected message type: {message?.GetType()}"); + } + } + + return chatMessages; + } +} diff --git a/dotnet/src/AutoGen.Anthropic/AnthropicClient.cs b/dotnet/src/AutoGen.Anthropic/AnthropicClient.cs new file mode 100644 index 00000000000..8ea0bef86e2 --- /dev/null +++ b/dotnet/src/AutoGen.Anthropic/AnthropicClient.cs @@ -0,0 +1,122 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// AnthropicClient.cs + +using System; +using System.Collections.Generic; +using System.IO; +using System.Net.Http; +using System.Runtime.CompilerServices; +using System.Text; +using System.Text.Json; +using System.Text.Json.Serialization; +using System.Threading; +using System.Threading.Tasks; +using AutoGen.Anthropic.Converters; +using AutoGen.Anthropic.DTO; + +namespace AutoGen.Anthropic; + +public sealed class AnthropicClient : IDisposable +{ + private readonly HttpClient _httpClient; + private readonly string _baseUrl; + + private static readonly JsonSerializerOptions JsonSerializerOptions = new() + { + DefaultIgnoreCondition = JsonIgnoreCondition.WhenWritingNull + }; + + private static readonly JsonSerializerOptions JsonDeserializerOptions = new() + { + Converters = { new ContentBaseConverter() } + }; + + public AnthropicClient(HttpClient httpClient, string baseUrl, string apiKey) + { + _httpClient = httpClient; + _baseUrl = baseUrl; + + _httpClient.DefaultRequestHeaders.Add("x-api-key", apiKey); + _httpClient.DefaultRequestHeaders.Add("anthropic-version", "2023-06-01"); + } + + public async Task CreateChatCompletionsAsync(ChatCompletionRequest chatCompletionRequest, + CancellationToken cancellationToken) + { + var httpResponseMessage = await SendRequestAsync(chatCompletionRequest, cancellationToken); + var responseStream = await httpResponseMessage.Content.ReadAsStreamAsync(); + + if (httpResponseMessage.IsSuccessStatusCode) + return await DeserializeResponseAsync(responseStream, cancellationToken); + + ErrorResponse res = await DeserializeResponseAsync(responseStream, cancellationToken); + throw new Exception(res.Error?.Message); + } + + public async IAsyncEnumerable StreamingChatCompletionsAsync( + ChatCompletionRequest chatCompletionRequest, [EnumeratorCancellation] CancellationToken cancellationToken) + { + var httpResponseMessage = await SendRequestAsync(chatCompletionRequest, cancellationToken); + using var reader = new StreamReader(await httpResponseMessage.Content.ReadAsStreamAsync()); + + var currentEvent = new SseEvent(); + while (await reader.ReadLineAsync() is { } line) + { + if (!string.IsNullOrEmpty(line)) + { + currentEvent.Data = line.Substring("data:".Length).Trim(); + } + else + { + if (currentEvent.Data == "[DONE]") + continue; + + if (currentEvent.Data != null) + { + yield return await JsonSerializer.DeserializeAsync( + new MemoryStream(Encoding.UTF8.GetBytes(currentEvent.Data)), + cancellationToken: cancellationToken) ?? throw new Exception("Failed to deserialize response"); + } + else if (currentEvent.Data != null) + { + var res = await JsonSerializer.DeserializeAsync( + new MemoryStream(Encoding.UTF8.GetBytes(currentEvent.Data)), cancellationToken: cancellationToken); + + throw new Exception(res?.Error?.Message); + } + + // Reset the current event for the next one + currentEvent = new SseEvent(); + } + } + } + + private Task SendRequestAsync(T requestObject, CancellationToken cancellationToken) + { + var httpRequestMessage = new HttpRequestMessage(HttpMethod.Post, _baseUrl); + var jsonRequest = JsonSerializer.Serialize(requestObject, JsonSerializerOptions); + httpRequestMessage.Content = new StringContent(jsonRequest, Encoding.UTF8, "application/json"); + return _httpClient.SendAsync(httpRequestMessage, cancellationToken); + } + + private async Task DeserializeResponseAsync(Stream responseStream, CancellationToken cancellationToken) + { + return await JsonSerializer.DeserializeAsync(responseStream, JsonDeserializerOptions, cancellationToken) + ?? throw new Exception("Failed to deserialize response"); + } + + public void Dispose() + { + _httpClient.Dispose(); + } + + private struct SseEvent + { + public string? Data { get; set; } + + public SseEvent(string? data = null) + { + Data = data; + } + } +} diff --git a/dotnet/src/AutoGen.Anthropic/AutoGen.Anthropic.csproj b/dotnet/src/AutoGen.Anthropic/AutoGen.Anthropic.csproj new file mode 100644 index 00000000000..fefc439e00b --- /dev/null +++ b/dotnet/src/AutoGen.Anthropic/AutoGen.Anthropic.csproj @@ -0,0 +1,22 @@ + + + + netstandard2.0 + AutoGen.Anthropic + + + + + + + AutoGen.Anthropic + + Provide support for consuming Anthropic models in AutoGen + + + + + + + + diff --git a/dotnet/src/AutoGen.Anthropic/Converters/ContentBaseConverter.cs b/dotnet/src/AutoGen.Anthropic/Converters/ContentBaseConverter.cs new file mode 100644 index 00000000000..281274048ed --- /dev/null +++ b/dotnet/src/AutoGen.Anthropic/Converters/ContentBaseConverter.cs @@ -0,0 +1,37 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// ContentConverter.cs + +using AutoGen.Anthropic.DTO; + +namespace AutoGen.Anthropic.Converters; + +using System; +using System.Text.Json; +using System.Text.Json.Serialization; + +public sealed class ContentBaseConverter : JsonConverter +{ + public override ContentBase Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options) + { + using var doc = JsonDocument.ParseValue(ref reader); + if (doc.RootElement.TryGetProperty("type", out JsonElement typeProperty) && !string.IsNullOrEmpty(typeProperty.GetString())) + { + string? type = typeProperty.GetString(); + var text = doc.RootElement.GetRawText(); + switch (type) + { + case "text": + return JsonSerializer.Deserialize(text, options) ?? throw new InvalidOperationException(); + case "image": + return JsonSerializer.Deserialize(text, options) ?? throw new InvalidOperationException(); + } + } + + throw new JsonException("Unknown content type"); + } + + public override void Write(Utf8JsonWriter writer, ContentBase value, JsonSerializerOptions options) + { + JsonSerializer.Serialize(writer, value, value.GetType(), options); + } +} diff --git a/dotnet/src/AutoGen.Anthropic/DTO/ChatCompletionRequest.cs b/dotnet/src/AutoGen.Anthropic/DTO/ChatCompletionRequest.cs new file mode 100644 index 00000000000..fa1654bc11d --- /dev/null +++ b/dotnet/src/AutoGen.Anthropic/DTO/ChatCompletionRequest.cs @@ -0,0 +1,60 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. + +using System.Text.Json.Serialization; + +namespace AutoGen.Anthropic.DTO; + +using System.Collections.Generic; + +public class ChatCompletionRequest +{ + [JsonPropertyName("model")] + public string? Model { get; set; } + + [JsonPropertyName("messages")] + public List Messages { get; set; } + + [JsonPropertyName("system")] + public string? SystemMessage { get; set; } + + [JsonPropertyName("max_tokens")] + public int MaxTokens { get; set; } + + [JsonPropertyName("metadata")] + public object? Metadata { get; set; } + + [JsonPropertyName("stop_sequences")] + public string[]? StopSequences { get; set; } + + [JsonPropertyName("stream")] + public bool? Stream { get; set; } + + [JsonPropertyName("temperature")] + public decimal? Temperature { get; set; } + + [JsonPropertyName("top_k")] + public int? TopK { get; set; } + + [JsonPropertyName("top_p")] + public decimal? TopP { get; set; } + + public ChatCompletionRequest() + { + Messages = new List(); + } +} + +public class ChatMessage +{ + [JsonPropertyName("role")] + public string Role { get; set; } + + [JsonPropertyName("content")] + public string Content { get; set; } + + public ChatMessage(string role, string content) + { + Role = role; + Content = content; + } +} diff --git a/dotnet/src/AutoGen.Anthropic/DTO/ChatCompletionResponse.cs b/dotnet/src/AutoGen.Anthropic/DTO/ChatCompletionResponse.cs new file mode 100644 index 00000000000..c6861f9c315 --- /dev/null +++ b/dotnet/src/AutoGen.Anthropic/DTO/ChatCompletionResponse.cs @@ -0,0 +1,90 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. + +namespace AutoGen.Anthropic.DTO; + +using System.Collections.Generic; +using System.Text.Json.Serialization; + +public class ChatCompletionResponse +{ + [JsonPropertyName("content")] + public List? Content { get; set; } + + [JsonPropertyName("id")] + public string? Id { get; set; } + + [JsonPropertyName("model")] + public string? Model { get; set; } + + [JsonPropertyName("role")] + public string? Role { get; set; } + + [JsonPropertyName("stop_reason")] + public string? StopReason { get; set; } + + [JsonPropertyName("stop_sequence")] + public object? StopSequence { get; set; } + + [JsonPropertyName("type")] + public string? Type { get; set; } + + [JsonPropertyName("usage")] + public Usage? Usage { get; set; } + + [JsonPropertyName("delta")] + public Delta? Delta { get; set; } + + [JsonPropertyName("message")] + public StreamingMessage? streamingMessage { get; set; } +} + +public class StreamingMessage +{ + [JsonPropertyName("id")] + public string? Id { get; set; } + + [JsonPropertyName("type")] + public string? Type { get; set; } + + [JsonPropertyName("role")] + public string? Role { get; set; } + + [JsonPropertyName("content")] + public List? Content { get; set; } + + [JsonPropertyName("model")] + public string? Model { get; set; } + + [JsonPropertyName("stop_reason")] + public object? StopReason { get; set; } + + [JsonPropertyName("stop_sequence")] + public object? StopSequence { get; set; } + + [JsonPropertyName("usage")] + public Usage? Usage { get; set; } +} + +public class Usage +{ + [JsonPropertyName("input_tokens")] + public int InputTokens { get; set; } + + [JsonPropertyName("output_tokens")] + public int OutputTokens { get; set; } +} + +public class Delta +{ + [JsonPropertyName("stop_reason")] + public string? StopReason { get; set; } + + [JsonPropertyName("type")] + public string? Type { get; set; } + + [JsonPropertyName("text")] + public string? Text { get; set; } + + [JsonPropertyName("usage")] + public Usage? Usage { get; set; } +} diff --git a/dotnet/src/AutoGen.Anthropic/DTO/Content.cs b/dotnet/src/AutoGen.Anthropic/DTO/Content.cs new file mode 100644 index 00000000000..dd2481bd58f --- /dev/null +++ b/dotnet/src/AutoGen.Anthropic/DTO/Content.cs @@ -0,0 +1,42 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Content.cs + +using System.Text.Json.Serialization; + +namespace AutoGen.Anthropic.DTO; + +public abstract class ContentBase +{ + [JsonPropertyName("type")] + public abstract string Type { get; } +} + +public class TextContent : ContentBase +{ + [JsonPropertyName("type")] + public override string Type => "text"; + + [JsonPropertyName("text")] + public string? Text { get; set; } +} + +public class ImageContent : ContentBase +{ + [JsonPropertyName("type")] + public override string Type => "image"; + + [JsonPropertyName("source")] + public ImageSource? Source { get; set; } +} + +public class ImageSource +{ + [JsonPropertyName("type")] + public string Type => "base64"; + + [JsonPropertyName("media_type")] + public string? MediaType { get; set; } + + [JsonPropertyName("data")] + public string? Data { get; set; } +} diff --git a/dotnet/src/AutoGen.Anthropic/DTO/ErrorResponse.cs b/dotnet/src/AutoGen.Anthropic/DTO/ErrorResponse.cs new file mode 100644 index 00000000000..d02a8f6d1cf --- /dev/null +++ b/dotnet/src/AutoGen.Anthropic/DTO/ErrorResponse.cs @@ -0,0 +1,21 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// ErrorResponse.cs + +using System.Text.Json.Serialization; + +namespace AutoGen.Anthropic.DTO; + +public sealed class ErrorResponse +{ + [JsonPropertyName("error")] + public Error? Error { get; set; } +} + +public sealed class Error +{ + [JsonPropertyName("Type")] + public string? Type { get; set; } + + [JsonPropertyName("message")] + public string? Message { get; set; } +} diff --git a/dotnet/src/AutoGen.Anthropic/Extensions/AnthropicAgentExtension.cs b/dotnet/src/AutoGen.Anthropic/Extensions/AnthropicAgentExtension.cs new file mode 100644 index 00000000000..35ea8ed190a --- /dev/null +++ b/dotnet/src/AutoGen.Anthropic/Extensions/AnthropicAgentExtension.cs @@ -0,0 +1,34 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// AnthropicAgentExtension.cs + +using AutoGen.Anthropic.Middleware; +using AutoGen.Core; + +namespace AutoGen.Anthropic.Extensions; + +public static class AnthropicAgentExtension +{ + /// + /// Register an to the + /// + /// the connector to use. If null, a new instance of will be created. + public static MiddlewareStreamingAgent RegisterMessageConnector( + this AnthropicClientAgent agent, AnthropicMessageConnector? connector = null) + { + connector ??= new AnthropicMessageConnector(); + + return agent.RegisterStreamingMiddleware(connector); + } + + /// + /// Register an to the where T is + /// + /// the connector to use. If null, a new instance of will be created. + public static MiddlewareStreamingAgent RegisterMessageConnector( + this MiddlewareStreamingAgent agent, AnthropicMessageConnector? connector = null) + { + connector ??= new AnthropicMessageConnector(); + + return agent.RegisterStreamingMiddleware(connector); + } +} diff --git a/dotnet/src/AutoGen.Anthropic/Middleware/AnthropicMessageConnector.cs b/dotnet/src/AutoGen.Anthropic/Middleware/AnthropicMessageConnector.cs new file mode 100644 index 00000000000..bfe79190925 --- /dev/null +++ b/dotnet/src/AutoGen.Anthropic/Middleware/AnthropicMessageConnector.cs @@ -0,0 +1,125 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// AnthropicMessageConnector.cs + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Runtime.CompilerServices; +using System.Threading; +using System.Threading.Tasks; +using AutoGen.Anthropic.DTO; +using AutoGen.Core; + +namespace AutoGen.Anthropic.Middleware; + +public class AnthropicMessageConnector : IStreamingMiddleware +{ + public string? Name => nameof(AnthropicMessageConnector); + + public async Task InvokeAsync(MiddlewareContext context, IAgent agent, CancellationToken cancellationToken = default) + { + var messages = context.Messages; + var chatMessages = ProcessMessage(messages, agent); + var response = await agent.GenerateReplyAsync(chatMessages, context.Options, cancellationToken); + + return response is IMessage chatMessage + ? PostProcessMessage(chatMessage.Content, agent) + : response; + } + + public async IAsyncEnumerable InvokeAsync(MiddlewareContext context, IStreamingAgent agent, + [EnumeratorCancellation] CancellationToken cancellationToken = default) + { + var messages = context.Messages; + var chatMessages = ProcessMessage(messages, agent); + + await foreach (var reply in agent.GenerateStreamingReplyAsync(chatMessages, context.Options, cancellationToken)) + { + if (reply is IStreamingMessage chatMessage) + { + var response = ProcessChatCompletionResponse(chatMessage, agent); + if (response is not null) + { + yield return response; + } + } + else + { + yield return reply; + } + } + } + + private IStreamingMessage? ProcessChatCompletionResponse(IStreamingMessage chatMessage, + IStreamingAgent agent) + { + Delta? delta = chatMessage.Content.Delta; + return delta != null && !string.IsNullOrEmpty(delta.Text) + ? new TextMessageUpdate(role: Role.Assistant, delta.Text, from: agent.Name) + : null; + } + + private IEnumerable ProcessMessage(IEnumerable messages, IAgent agent) + { + return messages.SelectMany(m => + { + return m switch + { + TextMessage textMessage => ProcessTextMessage(textMessage, agent), + _ => [m], + }; + }); + } + + private IMessage PostProcessMessage(ChatCompletionResponse response, IAgent from) + { + if (response.Content is null) + throw new ArgumentNullException(nameof(response.Content)); + + if (response.Content.Count != 1) + throw new NotSupportedException($"{nameof(response.Content)} != 1"); + + return new TextMessage(Role.Assistant, ((TextContent)response.Content[0]).Text ?? string.Empty, from: from.Name); + } + + private IEnumerable> ProcessTextMessage(TextMessage textMessage, IAgent agent) + { + IEnumerable messages; + + if (textMessage.From == agent.Name) + { + messages = [new ChatMessage( + "assistant", textMessage.Content)]; + } + else if (textMessage.From is null) + { + if (textMessage.Role == Role.User) + { + messages = [new ChatMessage( + "user", textMessage.Content)]; + } + else if (textMessage.Role == Role.Assistant) + { + messages = [new ChatMessage( + "assistant", textMessage.Content)]; + } + else if (textMessage.Role == Role.System) + { + messages = [new ChatMessage( + "system", textMessage.Content)]; + } + else + { + throw new NotSupportedException($"Role {textMessage.Role} is not supported"); + } + } + else + { + // if from is not null, then the message is from user + messages = [new ChatMessage( + "user", textMessage.Content)]; + } + + return messages.Select(m => new MessageEnvelope(m, from: textMessage.From)); + } +} diff --git a/dotnet/src/AutoGen.Anthropic/Utils/AnthropicConstants.cs b/dotnet/src/AutoGen.Anthropic/Utils/AnthropicConstants.cs new file mode 100644 index 00000000000..e70572cbddf --- /dev/null +++ b/dotnet/src/AutoGen.Anthropic/Utils/AnthropicConstants.cs @@ -0,0 +1,14 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Constants.cs + +namespace AutoGen.Anthropic.Utils; + +public static class AnthropicConstants +{ + public static string Endpoint = "https://api.anthropic.com/v1/messages"; + + // Models + public static string Claude3Opus = "claude-3-opus-20240229"; + public static string Claude3Sonnet = "claude-3-sonnet-20240229"; + public static string Claude3Haiku = "claude-3-haiku-20240307"; +} diff --git a/dotnet/test/AutoGen.Anthropic.Tests/AnthropicClientAgentTest.cs b/dotnet/test/AutoGen.Anthropic.Tests/AnthropicClientAgentTest.cs new file mode 100644 index 00000000000..ba31f2297ba --- /dev/null +++ b/dotnet/test/AutoGen.Anthropic.Tests/AnthropicClientAgentTest.cs @@ -0,0 +1,31 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// AnthropicClientAgentTest.cs + +using AutoGen.Anthropic.Extensions; +using AutoGen.Anthropic.Utils; +using AutoGen.Tests; +using Xunit.Abstractions; + +namespace AutoGen.Anthropic; + +public class AnthropicClientAgentTest +{ + private readonly ITestOutputHelper _output; + + public AnthropicClientAgentTest(ITestOutputHelper output) => _output = output; + + [ApiKeyFact("ANTHROPIC_API_KEY")] + public async Task AnthropicAgentChatCompletionTestAsync() + { + var client = new AnthropicClient(new HttpClient(), AnthropicConstants.Endpoint, AnthropicTestUtils.ApiKey); + + var agent = new AnthropicClientAgent( + client, + name: "AnthropicAgent", + AnthropicConstants.Claude3Haiku).RegisterMessageConnector(); + + var singleAgentTest = new SingleAgentTest(_output); + await singleAgentTest.UpperCaseTestAsync(agent); + await singleAgentTest.UpperCaseStreamingTestAsync(agent); + } +} diff --git a/dotnet/test/AutoGen.Anthropic.Tests/AnthropicClientTest.cs b/dotnet/test/AutoGen.Anthropic.Tests/AnthropicClientTest.cs new file mode 100644 index 00000000000..0b64c9e4e3c --- /dev/null +++ b/dotnet/test/AutoGen.Anthropic.Tests/AnthropicClientTest.cs @@ -0,0 +1,87 @@ +using System.Text; +using System.Text.Json; +using System.Text.Json.Serialization; +using AutoGen.Anthropic.DTO; +using AutoGen.Anthropic.Utils; +using AutoGen.Tests; +using FluentAssertions; +using Xunit; + +namespace AutoGen.Anthropic; + +public class AnthropicClientTests +{ + [ApiKeyFact("ANTHROPIC_API_KEY")] + public async Task AnthropicClientChatCompletionTestAsync() + { + var anthropicClient = new AnthropicClient(new HttpClient(), AnthropicConstants.Endpoint, AnthropicTestUtils.ApiKey); + + var request = new ChatCompletionRequest(); + request.Model = AnthropicConstants.Claude3Haiku; + request.Stream = false; + request.MaxTokens = 100; + request.Messages = new List() { new ChatMessage("user", "Hello world") }; + ChatCompletionResponse response = await anthropicClient.CreateChatCompletionsAsync(request, CancellationToken.None); + + Assert.NotNull(response); + Assert.NotNull(response.Content); + Assert.NotEmpty(response.Content); + response.Content.Count.Should().Be(1); + response.Content.First().Should().BeOfType(); + var textContent = (TextContent)response.Content.First(); + Assert.Equal("text", textContent.Type); + Assert.NotNull(response.Usage); + response.Usage.OutputTokens.Should().BeGreaterThan(0); + } + + [ApiKeyFact("ANTHROPIC_API_KEY")] + public async Task AnthropicClientStreamingChatCompletionTestAsync() + { + var anthropicClient = new AnthropicClient(new HttpClient(), AnthropicConstants.Endpoint, AnthropicTestUtils.ApiKey); + + var request = new ChatCompletionRequest(); + request.Model = AnthropicConstants.Claude3Haiku; + request.Stream = true; + request.MaxTokens = 500; + request.SystemMessage = "You are a helpful assistant that convert input to json object"; + request.Messages = new List() + { + new("user", "name: John, age: 41, email: g123456@gmail.com") + }; + + var response = anthropicClient.StreamingChatCompletionsAsync(request, CancellationToken.None); + var results = await response.ToListAsync(); + results.Count.Should().BeGreaterThan(0); + + // Merge the chunks. + StringBuilder sb = new(); + foreach (ChatCompletionResponse result in results) + { + if (result.Delta is not null && !string.IsNullOrEmpty(result.Delta.Text)) + sb.Append(result.Delta.Text); + } + + string resultContent = sb.ToString(); + Assert.NotNull(resultContent); + + var person = JsonSerializer.Deserialize(resultContent); + Assert.NotNull(person); + person.Name.Should().Be("John"); + person.Age.Should().Be(41); + person.Email.Should().Be("g123456@gmail.com"); + Assert.NotNull(results.First().streamingMessage); + results.First().streamingMessage!.Role.Should().Be("assistant"); + } + + private sealed class Person + { + [JsonPropertyName("name")] + public string Name { get; set; } = string.Empty; + + [JsonPropertyName("age")] + public int Age { get; set; } + + [JsonPropertyName("email")] + public string Email { get; set; } = string.Empty; + } +} diff --git a/dotnet/test/AutoGen.Anthropic.Tests/AnthropicTestUtils.cs b/dotnet/test/AutoGen.Anthropic.Tests/AnthropicTestUtils.cs new file mode 100644 index 00000000000..a5b80eee3bd --- /dev/null +++ b/dotnet/test/AutoGen.Anthropic.Tests/AnthropicTestUtils.cs @@ -0,0 +1,10 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// AnthropicTestUtils.cs + +namespace AutoGen.Anthropic; + +public static class AnthropicTestUtils +{ + public static string ApiKey => Environment.GetEnvironmentVariable("ANTHROPIC_API_KEY") ?? + throw new Exception("Please set ANTHROPIC_API_KEY environment variable."); +} diff --git a/dotnet/test/AutoGen.Anthropic.Tests/AutoGen.Anthropic.Tests.csproj b/dotnet/test/AutoGen.Anthropic.Tests/AutoGen.Anthropic.Tests.csproj new file mode 100644 index 00000000000..8cd1e3003b0 --- /dev/null +++ b/dotnet/test/AutoGen.Anthropic.Tests/AutoGen.Anthropic.Tests.csproj @@ -0,0 +1,23 @@ + + + + $(TestTargetFramework) + enable + false + True + AutoGen.Anthropic.Tests + + + + + + + + + + + + + + + diff --git a/notebook/agentchat_pgvector_RetrieveChat.ipynb b/notebook/agentchat_pgvector_RetrieveChat.ipynb index 068ea55c7fc..9b037b7c468 100644 --- a/notebook/agentchat_pgvector_RetrieveChat.ipynb +++ b/notebook/agentchat_pgvector_RetrieveChat.ipynb @@ -40,17 +40,16 @@ "version: '3.9'\n", "\n", "services:\n", - " db:\n", - " hostname: db\n", - " image: ankane/pgvector\n", + " pgvector:\n", + " image: pgvector/pgvector:pg16\n", + " shm_size: 128mb\n", + " restart: unless-stopped\n", " ports:\n", - " - 5432:5432\n", - " restart: always\n", + " - \"5432:5432\"\n", " environment:\n", - " - POSTGRES_DB=postgres\n", - " - POSTGRES_USER=postgres\n", - " - POSTGRES_PASSWORD=postgres\n", - " - POSTGRES_HOST_AUTH_METHOD=trust\n", + " POSTGRES_USER: \n", + " POSTGRES_PASSWORD: \n", + " POSTGRES_DB: \n", " volumes:\n", " - ./init.sql:/docker-entrypoint-initdb.d/init.sql\n", "```\n", @@ -73,14 +72,14 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 2, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "models to use: ['Meta-Llama-3-8B-Instruct-imatrix', 'gpt-3.5-turbo-0125', 'gpt-35-turbo']\n" + "models to use: ['gpt-35-turbo', 'gpt4-1106-preview', 'gpt-35-turbo-0613']\n" ] } ], @@ -89,6 +88,7 @@ "import os\n", "\n", "import chromadb\n", + "import psycopg\n", "\n", "import autogen\n", "from autogen.agentchat.contrib.retrieve_assistant_agent import RetrieveAssistantAgent\n", @@ -137,7 +137,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 3, "metadata": {}, "outputs": [ { @@ -145,7 +145,7 @@ "output_type": "stream", "text": [ "Accepted file formats for `docs_path`:\n", - "['org', 'pdf', 'md', 'docx', 'epub', 'rst', 'rtf', 'xml', 'ppt', 'txt', 'jsonl', 'msg', 'htm', 'yaml', 'html', 'xlsx', 'log', 'yml', 'odt', 'tsv', 'doc', 'pptx', 'csv', 'json']\n" + "['txt', 'json', 'csv', 'tsv', 'md', 'html', 'htm', 'rtf', 'rst', 'jsonl', 'log', 'xml', 'yaml', 'yml', 'pdf']\n" ] } ], @@ -156,15 +156,17 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ - "/home/lijiang1/anaconda3/envs/autogen/lib/python3.10/site-packages/torch/cuda/__init__.py:141: UserWarning: CUDA initialization: The NVIDIA driver on your system is too old (found version 11060). Please update your GPU driver by downloading and installing a new version from the URL: http://www.nvidia.com/Download/index.aspx Alternatively, go to: https://pytorch.org to install a PyTorch version that has been compiled with your version of the CUDA driver. (Triggered internally at ../c10/cuda/CUDAFunctions.cpp:108.)\n", - " return torch._C._cuda_getDeviceCount() > 0\n" + "/workspace/anaconda3/envs/autogen/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", + " from .autonotebook import tqdm as notebook_tqdm\n", + "/workspace/anaconda3/envs/autogen/lib/python3.11/site-packages/huggingface_hub/file_download.py:1132: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.\n", + " warnings.warn(\n" ] } ], @@ -172,7 +174,7 @@ "# 1. create an RetrieveAssistantAgent instance named \"assistant\"\n", "assistant = RetrieveAssistantAgent(\n", " name=\"assistant\",\n", - " system_message=\"You are a helpful assistant.\",\n", + " system_message=\"You are a helpful assistant. You must always reply with some form of text.\",\n", " llm_config={\n", " \"timeout\": 600,\n", " \"cache_seed\": 42,\n", @@ -180,6 +182,9 @@ " },\n", ")\n", "\n", + "# Optionally create psycopg conn object\n", + "# conn = psycopg.connect(conninfo=\"postgresql://postgres:postgres@localhost:5432/postgres\", autocommit=True)\n", + "\n", "# 2. create the RetrieveUserProxyAgent instance named \"ragproxyagent\"\n", "# By default, the human_input_mode is \"ALWAYS\", which means the agent will ask for human input at every step. We set it to \"NEVER\" here.\n", "# `docs_path` is the path to the docs directory. It can also be the path to a single file, or the url to a single file. By default,\n", @@ -208,12 +213,13 @@ " \"collection_name\": \"flaml_collection\",\n", " \"db_config\": {\n", " \"connection_string\": \"postgresql://postgres:postgres@localhost:5432/postgres\", # Optional - connect to an external vector database\n", - " # \"host\": postgres, # Optional vector database host\n", + " # \"host\": \"postgres\", # Optional vector database host\n", " # \"port\": 5432, # Optional vector database port\n", - " # \"database\": postgres, # Optional vector database name\n", - " # \"username\": postgres, # Optional vector database username\n", - " # \"password\": postgres, # Optional vector database password\n", + " # \"dbname\": \"postgres\", # Optional vector database name\n", + " # \"username\": \"postgres\", # Optional vector database username\n", + " # \"password\": \"postgres\", # Optional vector database password\n", " \"model_name\": \"all-MiniLM-L6-v2\", # Sentence embedding model from https://huggingface.co/models?library=sentence-transformers or https://www.sbert.net/docs/pretrained_models.html\n", + " # \"conn\": conn, # Optional - conn object to connect to database\n", " },\n", " \"get_or_create\": True, # set to False if you don't want to reuse an existing collection\n", " \"overwrite\": False, # set to True if you want to overwrite an existing collection\n", @@ -238,14 +244,14 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 5, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ - "2024-04-25 11:23:53,000 - autogen.agentchat.contrib.retrieve_user_proxy_agent - INFO - \u001b[32mUse the existing collection `flaml_collection`.\u001b[0m\n" + "2024-05-23 08:48:18,875 - autogen.agentchat.contrib.retrieve_user_proxy_agent - INFO - \u001b[32mUse the existing collection `flaml_collection`.\u001b[0m\n" ] }, { @@ -259,7 +265,11 @@ "name": "stderr", "output_type": "stream", "text": [ - "2024-04-25 11:23:54,745 - autogen.agentchat.contrib.retrieve_user_proxy_agent - INFO - Found 2 chunks.\u001b[0m\n" + "2024-05-23 08:48:19,975 - autogen.agentchat.contrib.retrieve_user_proxy_agent - INFO - Found 2 chunks.\u001b[0m\n", + "2024-05-23 08:48:19,977 - autogen.agentchat.contrib.vectordb.pgvectordb - INFO - Error executing select on non-existent table: flaml_collection. Creating it instead. Error: relation \"flaml_collection\" does not exist\n", + "LINE 1: SELECT id, metadatas, documents, embedding FROM flaml_collec...\n", + " ^\u001b[0m\n", + "2024-05-23 08:48:19,996 - autogen.agentchat.contrib.vectordb.pgvectordb - INFO - Created table flaml_collection\u001b[0m\n" ] }, { @@ -794,60 +804,7 @@ "--------------------------------------------------------------------------------\n", "\u001b[33massistant\u001b[0m (to ragproxyagent):\n", "\n", - "To use FLAML for a classification task and perform parallel training using Spark and train for 30 seconds while forcing cancel jobs if the time limit is reached, you can use the following code:\n", - "\n", - "```python\n", - "import flaml\n", - "from flaml.automl.spark.utils import to_pandas_on_spark\n", - "from pyspark.ml.feature import VectorAssembler\n", - "\n", - "# load your classification dataset as a pandas DataFrame\n", - "dataframe = ...\n", - "\n", - "# convert the pandas DataFrame to a pandas-on-spark DataFrame\n", - "psdf = to_pandas_on_spark(dataframe)\n", - "\n", - "# define the label column\n", - "label = ...\n", - "\n", - "# use VectorAssembler to merge all feature columns into a single vector column\n", - "columns = psdf.columns\n", - "feature_cols = [col for col in columns if col != label]\n", - "featurizer = VectorAssembler(inputCols=feature_cols, outputCol=\"features\")\n", - "psdf = featurizer.transform(psdf.to_spark(index_col=\"index\"))[\"index\", \"features\"]\n", - "\n", - "# configure the AutoML settings\n", - "settings = {\n", - " \"time_budget\": 30,\n", - " \"metric\": 'accuracy',\n", - " \"task\": 'classification',\n", - " \"log_file_name\": 'classification.log',\n", - " \"estimator_list\": ['lgbm_spark'],\n", - " \"n_concurrent_trials\": 2,\n", - " \"use_spark\": True,\n", - " \"force_cancel\": True\n", - "}\n", - "\n", - "# create and run the AutoML experiment\n", - "automl = flaml.AutoML()\n", - "automl.fit(\n", - " dataframe=psdf,\n", - " label=label,\n", - " **settings\n", - ")\n", - "```\n", - "\n", - "Note that you will need to replace the placeholders with your own dataset and label column names. This code will use FLAML's `lgbm_spark` estimator for training the classification model in parallel using Spark. The training will be restricted to 30 seconds, and if the time limit is reached, FLAML will force cancel the Spark jobs.\n", - "\n", - "--------------------------------------------------------------------------------\n", - "\u001b[33mragproxyagent\u001b[0m (to assistant):\n", - "\n", - "\n", - "\n", - "--------------------------------------------------------------------------------\n", - "\u001b[33massistant\u001b[0m (to ragproxyagent):\n", - "\n", - "UPDATE CONTEXT\n", + "To use FLAML to perform a classification task and use Spark to do parallel training, you need to use the Spark ML estimators for AutoML. First, you need to prepare your data in the required format as described in the previous section. FLAML provides a convenient function \"to_pandas_on_spark\" to convert your data into a pandas-on-spark dataframe/series, which Spark estimators require. After that, use the pandas-on-spark data like non-spark data and pass them using X_train, y_train or dataframe, label. Finally, configure FLAML to use Spark as the parallel backend during parallel tuning by setting the use_spark to true. An example code snippet is provided in the context above.\n", "\n", "--------------------------------------------------------------------------------\n" ] @@ -883,7 +840,7 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 6, "metadata": {}, "outputs": [ { @@ -1153,276 +1110,18 @@ "\n", "\n", "\n", - "--------------------------------------------------------------------------------\n", + "--------------------------------------------------------------------------------\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ "\u001b[33massistant\u001b[0m (to ragproxyagent):\n", "\n", "The authors of FLAML are Chi Wang, Qingyun Wu, Markus Weimer, and Erkang Zhu.\n", "\n", "--------------------------------------------------------------------------------\n", - "\u001b[32mAdding content of doc bdfbc921 to context.\u001b[0m\n", - "\u001b[33mragproxyagent\u001b[0m (to assistant):\n", - "\n", - "You're a retrieve augmented coding assistant. You answer user's questions based on your own knowledge and the\n", - "context provided by the user.\n", - "If you can't answer the question with or without the current context, you should reply exactly `UPDATE CONTEXT`.\n", - "For code generation, you must obey the following rules:\n", - "Rule 1. You MUST NOT install any packages because all the packages needed are already installed.\n", - "Rule 2. You must follow the formats below to write your code:\n", - "```language\n", - "# your code\n", - "```\n", - "\n", - "User's question is: Who is the author of FLAML?\n", - "\n", - "Context is: # Research\n", - "\n", - "For technical details, please check our research publications.\n", - "\n", - "- [FLAML: A Fast and Lightweight AutoML Library](https://www.microsoft.com/en-us/research/publication/flaml-a-fast-and-lightweight-automl-library/). Chi Wang, Qingyun Wu, Markus Weimer, Erkang Zhu. MLSys 2021.\n", - "\n", - "```bibtex\n", - "@inproceedings{wang2021flaml,\n", - " title={FLAML: A Fast and Lightweight AutoML Library},\n", - " author={Chi Wang and Qingyun Wu and Markus Weimer and Erkang Zhu},\n", - " year={2021},\n", - " booktitle={MLSys},\n", - "}\n", - "```\n", - "\n", - "- [Frugal Optimization for Cost-related Hyperparameters](https://arxiv.org/abs/2005.01571). Qingyun Wu, Chi Wang, Silu Huang. AAAI 2021.\n", - "\n", - "```bibtex\n", - "@inproceedings{wu2021cfo,\n", - " title={Frugal Optimization for Cost-related Hyperparameters},\n", - " author={Qingyun Wu and Chi Wang and Silu Huang},\n", - " year={2021},\n", - " booktitle={AAAI},\n", - "}\n", - "```\n", - "\n", - "- [Economical Hyperparameter Optimization With Blended Search Strategy](https://www.microsoft.com/en-us/research/publication/economical-hyperparameter-optimization-with-blended-search-strategy/). Chi Wang, Qingyun Wu, Silu Huang, Amin Saied. ICLR 2021.\n", - "\n", - "```bibtex\n", - "@inproceedings{wang2021blendsearch,\n", - " title={Economical Hyperparameter Optimization With Blended Search Strategy},\n", - " author={Chi Wang and Qingyun Wu and Silu Huang and Amin Saied},\n", - " year={2021},\n", - " booktitle={ICLR},\n", - "}\n", - "```\n", - "\n", - "- [An Empirical Study on Hyperparameter Optimization for Fine-Tuning Pre-trained Language Models](https://aclanthology.org/2021.acl-long.178.pdf). Susan Xueqing Liu, Chi Wang. ACL 2021.\n", - "\n", - "```bibtex\n", - "@inproceedings{liuwang2021hpolm,\n", - " title={An Empirical Study on Hyperparameter Optimization for Fine-Tuning Pre-trained Language Models},\n", - " author={Susan Xueqing Liu and Chi Wang},\n", - " year={2021},\n", - " booktitle={ACL},\n", - "}\n", - "```\n", - "\n", - "- [ChaCha for Online AutoML](https://www.microsoft.com/en-us/research/publication/chacha-for-online-automl/). Qingyun Wu, Chi Wang, John Langford, Paul Mineiro and Marco Rossi. ICML 2021.\n", - "\n", - "```bibtex\n", - "@inproceedings{wu2021chacha,\n", - " title={ChaCha for Online AutoML},\n", - " author={Qingyun Wu and Chi Wang and John Langford and Paul Mineiro and Marco Rossi},\n", - " year={2021},\n", - " booktitle={ICML},\n", - "}\n", - "```\n", - "\n", - "- [Fair AutoML](https://arxiv.org/abs/2111.06495). Qingyun Wu, Chi Wang. ArXiv preprint arXiv:2111.06495 (2021).\n", - "\n", - "```bibtex\n", - "@inproceedings{wuwang2021fairautoml,\n", - " title={Fair AutoML},\n", - " author={Qingyun Wu and Chi Wang},\n", - " year={2021},\n", - " booktitle={ArXiv preprint arXiv:2111.06495},\n", - "}\n", - "```\n", - "\n", - "- [Mining Robust Default Configurations for Resource-constrained AutoML](https://arxiv.org/abs/2202.09927). Moe Kayali, Chi Wang. ArXiv preprint arXiv:2202.09927 (2022).\n", - "\n", - "```bibtex\n", - "@inproceedings{kayaliwang2022default,\n", - " title={Mining Robust Default Configurations for Resource-constrained AutoML},\n", - " author={Moe Kayali and Chi Wang},\n", - " year={2022},\n", - " booktitle={ArXiv preprint arXiv:2202.09927},\n", - "}\n", - "```\n", - "\n", - "- [Targeted Hyperparameter Optimization with Lexicographic Preferences Over Multiple Objectives](https://openreview.net/forum?id=0Ij9_q567Ma). Shaokun Zhang, Feiran Jia, Chi Wang, Qingyun Wu. ICLR 2023 (notable-top-5%).\n", - "\n", - "```bibtex\n", - "@inproceedings{zhang2023targeted,\n", - " title={Targeted Hyperparameter Optimization with Lexicographic Preferences Over Multiple Objectives},\n", - " author={Shaokun Zhang and Feiran Jia and Chi Wang and Qingyun Wu},\n", - " booktitle={International Conference on Learning Representations},\n", - " year={2023},\n", - " url={https://openreview.net/forum?id=0Ij9_q567Ma},\n", - "}\n", - "```\n", - "\n", - "- [Cost-Effective Hyperparameter Optimization for Large Language Model Generation Inference](https://arxiv.org/abs/2303.04673). Chi Wang, Susan Xueqing Liu, Ahmed H. Awadallah. ArXiv preprint arXiv:2303.04673 (2023).\n", - "\n", - "```bibtex\n", - "@inproceedings{wang2023EcoOptiGen,\n", - " title={Cost-Effective Hyperparameter Optimization for Large Language Model Generation Inference},\n", - " author={Chi Wang and Susan Xueqing Liu and Ahmed H. Awadallah},\n", - " year={2023},\n", - " booktitle={ArXiv preprint arXiv:2303.04673},\n", - "}\n", - "```\n", - "\n", - "- [An Empirical Study on Challenging Math Problem Solving with GPT-4](https://arxiv.org/abs/2306.01337). Yiran Wu, Feiran Jia, Shaokun Zhang, Hangyu Li, Erkang Zhu, Yue Wang, Yin Tat Lee, Richard Peng, Qingyun Wu, Chi Wang. ArXiv preprint arXiv:2306.01337 (2023).\n", - "\n", - "```bibtex\n", - "@inproceedings{wu2023empirical,\n", - " title={An Empirical Study on Challenging Math Problem Solving with GPT-4},\n", - " author={Yiran Wu and Feiran Jia and Shaokun Zhang and Hangyu Li and Erkang Zhu and Yue Wang and Yin Tat Lee and Richard Peng and Qingyun Wu and Chi Wang},\n", - " year={2023},\n", - " booktitle={ArXiv preprint arXiv:2306.01337},\n", - "}\n", - "```\n", - "# Integrate - Spark\n", - "\n", - "FLAML has integrated Spark for distributed training. There are two main aspects of integration with Spark:\n", - "\n", - "- Use Spark ML estimators for AutoML.\n", - "- Use Spark to run training in parallel spark jobs.\n", - "\n", - "## Spark ML Estimators\n", - "\n", - "FLAML integrates estimators based on Spark ML models. These models are trained in parallel using Spark, so we called them Spark estimators. To use these models, you first need to organize your data in the required format.\n", - "\n", - "### Data\n", - "\n", - "For Spark estimators, AutoML only consumes Spark data. FLAML provides a convenient function `to_pandas_on_spark` in the `flaml.automl.spark.utils` module to convert your data into a pandas-on-spark (`pyspark.pandas`) dataframe/series, which Spark estimators require.\n", - "\n", - "This utility function takes data in the form of a `pandas.Dataframe` or `pyspark.sql.Dataframe` and converts it into a pandas-on-spark dataframe. It also takes `pandas.Series` or `pyspark.sql.Dataframe` and converts it into a [pandas-on-spark](https://spark.apache.org/docs/latest/api/python/user_guide/pandas_on_spark/index.html) series. If you pass in a `pyspark.pandas.Dataframe`, it will not make any changes.\n", - "\n", - "This function also accepts optional arguments `index_col` and `default_index_type`.\n", - "\n", - "- `index_col` is the column name to use as the index, default is None.\n", - "- `default_index_type` is the default index type, default is \"distributed-sequence\". More info about default index type could be found on Spark official [documentation](https://spark.apache.org/docs/latest/api/python/user_guide/pandas_on_spark/options.html#default-index-type)\n", - "\n", - "Here is an example code snippet for Spark Data:\n", - "\n", - "```python\n", - "import pandas as pd\n", - "from flaml.automl.spark.utils import to_pandas_on_spark\n", - "\n", - "# Creating a dictionary\n", - "data = {\n", - " \"Square_Feet\": [800, 1200, 1800, 1500, 850],\n", - " \"Age_Years\": [20, 15, 10, 7, 25],\n", - " \"Price\": [100000, 200000, 300000, 240000, 120000],\n", - "}\n", - "\n", - "# Creating a pandas DataFrame\n", - "dataframe = pd.DataFrame(data)\n", - "label = \"Price\"\n", - "\n", - "# Convert to pandas-on-spark dataframe\n", - "psdf = to_pandas_on_spark(dataframe)\n", - "```\n", - "\n", - "To use Spark ML models you need to format your data appropriately. Specifically, use [`VectorAssembler`](https://spark.apache.org/docs/latest/api/python/reference/api/pyspark.ml.feature.VectorAssembler.html) to merge all feature columns into a single vector column.\n", - "\n", - "Here is an example of how to use it:\n", - "\n", - "```python\n", - "from pyspark.ml.feature import VectorAssembler\n", - "\n", - "columns = psdf.columns\n", - "feature_cols = [col for col in columns if col != label]\n", - "featurizer = VectorAssembler(inputCols=feature_cols, outputCol=\"features\")\n", - "psdf = featurizer.transform(psdf.to_spark(index_col=\"index\"))[\"index\", \"features\"]\n", - "```\n", - "\n", - "Later in conducting the experiment, use your pandas-on-spark data like non-spark data and pass them using `X_train, y_train` or `dataframe, label`.\n", - "\n", - "### Estimators\n", - "\n", - "#### Model List\n", - "\n", - "- `lgbm_spark`: The class for fine-tuning Spark version LightGBM models, using [SynapseML](https://microsoft.github.io/SynapseML/docs/features/lightgbm/about/) API.\n", - "\n", - "#### Usage\n", - "\n", - "First, prepare your data in the required format as described in the previous section.\n", - "\n", - "By including the models you intend to try in the `estimators_list` argument to `flaml.automl`, FLAML will start trying configurations for these models. If your input is Spark data, FLAML will also use estimators with the `_spark` postfix by default, even if you haven't specified them.\n", - "\n", - "Here is an example code snippet using SparkML models in AutoML:\n", - "\n", - "```python\n", - "import flaml\n", - "\n", - "# prepare your data in pandas-on-spark format as we previously mentioned\n", - "\n", - "automl = flaml.AutoML()\n", - "settings = {\n", - " \"time_budget\": 30,\n", - " \"metric\": \"r2\",\n", - " \"estimator_list\": [\"lgbm_spark\"], # this setting is optional\n", - " \"task\": \"regression\",\n", - "}\n", - "\n", - "automl.fit(\n", - " dataframe=psdf,\n", - " label=label,\n", - " **settings,\n", - ")\n", - "```\n", - "\n", - "[Link to notebook](https://github.com/microsoft/FLAML/blob/main/notebook/automl_bankrupt_synapseml.ipynb) | [Open in colab](https://colab.research.google.com/github/microsoft/FLAML/blob/main/notebook/automl_bankrupt_synapseml.ipynb)\n", - "\n", - "## Parallel Spark Jobs\n", - "\n", - "You can activate Spark as the parallel backend during parallel tuning in both [AutoML](/docs/Use-Cases/Task-Oriented-AutoML#parallel-tuning) and [Hyperparameter Tuning](/docs/Use-Cases/Tune-User-Defined-Function#parallel-tuning), by setting the `use_spark` to `true`. FLAML will dispatch your job to the distributed Spark backend using [`joblib-spark`](https://github.com/joblib/joblib-spark).\n", - "\n", - "Please note that you should not set `use_spark` to `true` when applying AutoML and Tuning for Spark Data. This is because only SparkML models will be used for Spark Data in AutoML and Tuning. As SparkML models run in parallel, there is no need to distribute them with `use_spark` again.\n", - "\n", - "All the Spark-related arguments are stated below. These arguments are available in both Hyperparameter Tuning and AutoML:\n", - "\n", - "- `use_spark`: boolean, default=False | Whether to use spark to run the training in parallel spark jobs. This can be used to accelerate training on large models and large datasets, but will incur more overhead in time and thus slow down training in some cases. GPU training is not supported yet when use_spark is True. For Spark clusters, by default, we will launch one trial per executor. However, sometimes we want to launch more trials than the number of executors (e.g., local mode). In this case, we can set the environment variable `FLAML_MAX_CONCURRENT` to override the detected `num_executors`. The final number of concurrent trials will be the minimum of `n_concurrent_trials` and `num_executors`.\n", - "- `n_concurrent_trials`: int, default=1 | The number of concurrent trials. When n_concurrent_trials > 1, FLAML performes parallel tuning.\n", - "- `force_cancel`: boolean, default=False | Whether to forcely cancel Spark jobs if the search time exceeded the time budget. Spark jobs include parallel tuning jobs and Spark-based model training jobs.\n", - "\n", - "An example code snippet for using parallel Spark jobs:\n", - "\n", - "```python\n", - "import flaml\n", - "\n", - "automl_experiment = flaml.AutoML()\n", - "automl_settings = {\n", - " \"time_budget\": 30,\n", - " \"metric\": \"r2\",\n", - " \"task\": \"regression\",\n", - " \"n_concurrent_trials\": 2,\n", - " \"use_spark\": True,\n", - " \"force_cancel\": True, # Activating the force_cancel option can immediately halt Spark jobs once they exceed the allocated time_budget.\n", - "}\n", - "\n", - "automl.fit(\n", - " dataframe=dataframe,\n", - " label=label,\n", - " **automl_settings,\n", - ")\n", - "```\n", - "\n", - "[Link to notebook](https://github.com/microsoft/FLAML/blob/main/notebook/integrate_spark.ipynb) | [Open in colab](https://colab.research.google.com/github/microsoft/FLAML/blob/main/notebook/integrate_spark.ipynb)\n", - "\n", - "\n", - "\n", - "--------------------------------------------------------------------------------\n", - "\u001b[33massistant\u001b[0m (to ragproxyagent):\n", - "\n", "The authors of FLAML are Chi Wang, Qingyun Wu, Markus Weimer, and Erkang Zhu.\n", "\n", "--------------------------------------------------------------------------------\n" @@ -1436,6 +1135,13 @@ "qa_problem = \"Who is the author of FLAML?\"\n", "chat_result = ragproxyagent.initiate_chat(assistant, message=ragproxyagent.message_generator, problem=qa_problem)" ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] } ], "metadata": { @@ -1460,7 +1166,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.13" + "version": "3.11.9" }, "skip_test": "Requires interactive usage" }, diff --git a/setup.py b/setup.py index a4fa4f63aa5..b3a86832750 100644 --- a/setup.py +++ b/setup.py @@ -1,4 +1,5 @@ import os +import platform import setuptools @@ -13,6 +14,9 @@ exec(fp.read(), version) __version__ = version["__version__"] + +current_os = platform.system() + install_requires = [ "openai>=1.3", "diskcache", @@ -25,6 +29,7 @@ # Disallowing 2.6.0 can be removed when this is fixed https://github.com/pydantic/pydantic/issues/8705 "pydantic>=1.10,<3,!=2.6.0", # could be both V1 and V2 "docker", + "packaging", ] jupyter_executor = [ @@ -45,6 +50,13 @@ "markdownify", ] +retrieve_chat_pgvector = [*retrieve_chat, "pgvector>=0.2.5"] + +if current_os in ["Windows", "Darwin"]: + retrieve_chat_pgvector.extend(["psycopg[binary]>=3.1.18"]) +elif current_os == "Linux": + retrieve_chat_pgvector.extend(["psycopg>=3.1.18"]) + extra_require = { "test": [ "ipykernel", @@ -59,11 +71,7 @@ "blendsearch": ["flaml[blendsearch]"], "mathchat": ["sympy", "pydantic==1.10.9", "wolframalpha"], "retrievechat": retrieve_chat, - "retrievechat-pgvector": [ - *retrieve_chat, - "pgvector>=0.2.5", - "psycopg>=3.1.18", - ], + "retrievechat-pgvector": retrieve_chat_pgvector, "retrievechat-qdrant": [ *retrieve_chat, "qdrant_client[fastembed]", diff --git a/test/agentchat/contrib/vectordb/test_pgvectordb.py b/test/agentchat/contrib/vectordb/test_pgvectordb.py index bcccef2abfe..d238b657cb9 100644 --- a/test/agentchat/contrib/vectordb/test_pgvectordb.py +++ b/test/agentchat/contrib/vectordb/test_pgvectordb.py @@ -1,5 +1,6 @@ import os import sys +import urllib.parse import pytest from conftest import reason @@ -8,6 +9,7 @@ try: import pgvector + import psycopg import sentence_transformers from autogen.agentchat.contrib.vectordb.pgvectordb import PGVectorDB @@ -24,12 +26,52 @@ reason=reason, ) def test_pgvector(): - # test create collection + # test db config db_config = { "connection_string": "postgresql://postgres:postgres@localhost:5432/postgres", } - db = PGVectorDB(connection_string=db_config["connection_string"]) + # test create collection with connection_string authentication + db = PGVectorDB( + connection_string=db_config["connection_string"], + ) + collection_name = "test_collection" + collection = db.create_collection(collection_name=collection_name, overwrite=True, get_or_create=True) + assert collection.name == collection_name + + # test create collection with conn object authentication + parsed_connection = urllib.parse.urlparse(db_config["connection_string"]) + encoded_username = urllib.parse.quote(parsed_connection.username, safe="") + encoded_password = urllib.parse.quote(parsed_connection.password, safe="") + encoded_host = urllib.parse.quote(parsed_connection.hostname, safe="") + encoded_database = urllib.parse.quote(parsed_connection.path[1:], safe="") + connection_string_encoded = ( + f"{parsed_connection.scheme}://{encoded_username}:{encoded_password}" + f"@{encoded_host}:{parsed_connection.port}/{encoded_database}" + ) + conn = psycopg.connect(conninfo=connection_string_encoded, autocommit=True) + + db = PGVectorDB(conn=conn) + collection_name = "test_collection" + collection = db.create_collection(collection_name=collection_name, overwrite=True, get_or_create=True) + assert collection.name == collection_name + + # test create collection with basic authentication + db_config = { + "username": "postgres", + "password": os.environ.get("POSTGRES_PASSWORD", default="postgres"), + "host": "localhost", + "port": 5432, + "dbname": "postgres", + } + + db = PGVectorDB( + username=db_config["username"], + password=db_config["password"], + port=db_config["port"], + host=db_config["host"], + dbname=db_config["dbname"], + ) collection_name = "test_collection" collection = db.create_collection(collection_name=collection_name, overwrite=True, get_or_create=True) assert collection.name == collection_name diff --git a/test/test_function_utils.py b/test/test_function_utils.py index adddbf32e26..0475044b49f 100644 --- a/test/test_function_utils.py +++ b/test/test_function_utils.py @@ -391,6 +391,10 @@ async def f( assert actual[1] == "EUR" +def test_serialize_to_str_with_nonascii() -> None: + assert serialize_to_str("中文") == "中文" + + def test_serialize_to_json() -> None: assert serialize_to_str("abc") == "abc" assert serialize_to_str(123) == "123"