diff --git a/.devcontainer/Dockerfile b/.devcontainer/Dockerfile index 4431ee58b4f..8a75a1487ff 100644 --- a/.devcontainer/Dockerfile +++ b/.devcontainer/Dockerfile @@ -9,18 +9,24 @@ FROM mcr.microsoft.com/vscode/devcontainers/python:3.10 # Update the OS and maybe install packages # ENV DEBIAN_FRONTEND=noninteractive + +# add git lhs to apt +RUN curl -s https://packagecloud.io/install/repositories/github/git-lfs/script.deb.sh | bash + RUN apt-get update \ && apt-get upgrade -y \ - && apt-get -y install --no-install-recommends build-essential npm \ + && apt-get -y install --no-install-recommends build-essential npm git-lfs \ && apt-get autoremove -y \ && apt-get clean -y \ - && wget https://github.com/quarto-dev/quarto-cli/releases/download/v1.5.23/quarto-1.5.23-linux-amd64.deb \ - && dpkg -i quarto-1.5.23-linux-amd64.deb \ - && rm -rf /var/lib/apt/lists/* quarto-1.5.23-linux-amd64.deb + && arch=$(arch | sed s/aarch64/arm64/ | sed s/x86_64/amd64/) \ + && wget https://github.com/quarto-dev/quarto-cli/releases/download/v1.5.23/quarto-1.5.23-linux-${arch}.deb \ + && dpkg -i quarto-1.5.23-linux-${arch}.deb \ + && rm -rf /var/lib/apt/lists/* quarto-1.5.23-linux-${arch}.deb ENV DEBIAN_FRONTEND=dialog # For docs RUN npm install --global yarn +RUN pip install --upgrade pip RUN pip install pydoc-markdown RUN pip install pyyaml RUN pip install colored diff --git a/.devcontainer/dev/Dockerfile b/.devcontainer/dev/Dockerfile index 4749e41ba6d..04f4c54edf4 100644 --- a/.devcontainer/dev/Dockerfile +++ b/.devcontainer/dev/Dockerfile @@ -1,10 +1,13 @@ # Basic setup FROM python:3.11-slim-bookworm +# add git lhs to apt +RUN curl -s https://packagecloud.io/install/repositories/github/git-lfs/script.deb.sh | bash + # Update and install necessary packages RUN apt-get update && apt-get -y update # added vim and nano for convenience -RUN apt-get install -y sudo git npm vim nano curl wget +RUN apt-get install -y sudo git npm vim nano curl wget git-lfs # Setup a non-root user 'autogen' with sudo access RUN adduser --disabled-password --gecos '' autogen @@ -44,6 +47,7 @@ ENV PATH="${PATH}:/home/autogen/quarto/quarto-1.5.23/bin/" EXPOSE 3000 # Pre-load popular Python packages +RUN pip install --upgrade pip RUN pip install numpy pandas matplotlib seaborn scikit-learn requests urllib3 nltk pillow pytest beautifulsoup4 # Set the default command to bash diff --git a/.devcontainer/full/Dockerfile b/.devcontainer/full/Dockerfile index 15122b2ac55..0787ad24027 100644 --- a/.devcontainer/full/Dockerfile +++ b/.devcontainer/full/Dockerfile @@ -1,9 +1,12 @@ FROM python:3.11-slim-bookworm +# add git lhs to apt +RUN curl -s https://packagecloud.io/install/repositories/github/git-lfs/script.deb.sh | bash + # Update and install dependencies RUN apt-get update \ && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \ - software-properties-common sudo\ + software-properties-common sudo git-lfs \ && apt-get clean \ && rm -rf /var/lib/apt/lists/* diff --git a/.devcontainer/studio/Dockerfile b/.devcontainer/studio/Dockerfile index 5bf2d4c27d8..d612cea9dab 100644 --- a/.devcontainer/studio/Dockerfile +++ b/.devcontainer/studio/Dockerfile @@ -9,9 +9,13 @@ FROM mcr.microsoft.com/vscode/devcontainers/python:3.10 # Update the OS and maybe install packages # ENV DEBIAN_FRONTEND=noninteractive + +# add git lhs to apt +RUN curl -s https://packagecloud.io/install/repositories/github/git-lfs/script.deb.sh | bash + RUN apt-get update \ && apt-get upgrade -y \ - && apt-get -y install --no-install-recommends build-essential 
npm \ + && apt-get -y install --no-install-recommends build-essential npm git-lfs \ && apt-get autoremove -y \ && apt-get clean -y \ && rm -rf /var/lib/apt/lists/* @@ -19,4 +23,5 @@ ENV DEBIAN_FRONTEND=dialog # For docs RUN npm install --global yarn +RUN pip install --upgrade pip RUN pip install pydoc-markdown diff --git a/.gitattributes b/.gitattributes new file mode 100644 index 00000000000..c139e44b4dc --- /dev/null +++ b/.gitattributes @@ -0,0 +1,3 @@ +*.png filter=lfs diff=lfs merge=lfs -text +*.jpg filter=lfs diff=lfs merge=lfs -text +*.jpeg filter=lfs diff=lfs merge=lfs -text diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index c02e7bc2ecc..1c32eee6036 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -31,9 +31,9 @@ jobs: os: [ubuntu-latest, macos-latest, windows-latest] python-version: ["3.8", "3.9", "3.10", "3.11", "3.12"] steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v4 + uses: actions/setup-python@v5 with: python-version: ${{ matrix.python-version }} - name: Install packages and dependencies @@ -43,10 +43,10 @@ jobs: python -c "import autogen" pip install pytest mock - name: Install optional dependencies for code executors - # code executors auto skip without deps, so only run for python 3.11 + # code executors and udfs auto skip without deps, so only run for python 3.11 if: matrix.python-version == '3.11' run: | - pip install -e ".[jupyter-executor]" + pip install -e ".[jupyter-executor,test]" python -m ipykernel install --user --name python3 - name: Set AUTOGEN_USE_DOCKER based on OS shell: bash @@ -57,16 +57,16 @@ jobs: - name: Test with pytest skipping openai tests if: matrix.python-version != '3.10' && matrix.os == 'ubuntu-latest' run: | - pytest test --skip-openai + pytest test --ignore=test/agentchat/contrib --skip-openai --durations=10 --durations-min=1.0 - name: Test with pytest skipping openai and docker tests if: matrix.python-version != '3.10' && matrix.os != 'ubuntu-latest' run: | - pytest test --skip-openai --skip-docker + pytest test --ignore=test/agentchat/contrib --skip-openai --skip-docker --durations=10 --durations-min=1.0 - name: Coverage if: matrix.python-version == '3.10' run: | - pip install -e .[test,redis] - coverage run -a -m pytest test --ignore=test/agentchat/contrib --skip-openai + pip install -e .[test,redis,websockets] + coverage run -a -m pytest test --ignore=test/agentchat/contrib --skip-openai --durations=10 --durations-min=1.0 coverage xml - name: Upload coverage to Codecov if: matrix.python-version == '3.10' diff --git a/.github/workflows/contrib-openai.yml b/.github/workflows/contrib-openai.yml index 6443aa62de6..c60a45b3ad1 100644 --- a/.github/workflows/contrib-openai.yml +++ b/.github/workflows/contrib-openai.yml @@ -27,11 +27,11 @@ jobs: steps: # checkout to pr branch - name: Checkout - uses: actions/checkout@v3 + uses: actions/checkout@v4 with: ref: ${{ github.event.pull_request.head.sha }} - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v4 + uses: actions/setup-python@v5 with: python-version: ${{ matrix.python-version }} - name: Install packages and dependencies @@ -41,11 +41,15 @@ jobs: pip install -e . 
python -c "import autogen" pip install coverage pytest-asyncio + - name: Install PostgreSQL + run: | + sudo apt install postgresql -y + - name: Start PostgreSQL service + run: sudo service postgresql start - name: Install packages for test when needed run: | pip install docker - pip install qdrant_client[fastembed] - pip install -e .[retrievechat] + pip install -e .[retrievechat-qdrant,retrievechat-pgvector] - name: Coverage env: OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }} @@ -53,7 +57,7 @@ jobs: AZURE_OPENAI_API_BASE: ${{ secrets.AZURE_OPENAI_API_BASE }} OAI_CONFIG_LIST: ${{ secrets.OAI_CONFIG_LIST }} run: | - coverage run -a -m pytest test/agentchat/contrib/test_retrievechat.py test/agentchat/contrib/test_qdrant_retrievechat.py + coverage run -a -m pytest test/agentchat/contrib/test_retrievechat.py::test_retrievechat test/agentchat/contrib/test_qdrant_retrievechat.py::test_retrievechat test/agentchat/contrib/test_pgvector_retrievechat.py::test_retrievechat coverage xml - name: Upload coverage to Codecov uses: codecov/codecov-action@v3 @@ -70,11 +74,11 @@ jobs: steps: # checkout to pr branch - name: Checkout - uses: actions/checkout@v3 + uses: actions/checkout@v4 with: ref: ${{ github.event.pull_request.head.sha }} - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v4 + uses: actions/setup-python@v5 with: python-version: ${{ matrix.python-version }} - name: Install packages and dependencies @@ -111,11 +115,11 @@ jobs: steps: # checkout to pr branch - name: Checkout - uses: actions/checkout@v3 + uses: actions/checkout@v4 with: ref: ${{ github.event.pull_request.head.sha }} - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v4 + uses: actions/setup-python@v5 with: python-version: ${{ matrix.python-version }} - name: Install packages and dependencies @@ -152,11 +156,11 @@ jobs: steps: # checkout to pr branch - name: Checkout - uses: actions/checkout@v3 + uses: actions/checkout@v4 with: ref: ${{ github.event.pull_request.head.sha }} - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v4 + uses: actions/setup-python@v5 with: python-version: ${{ matrix.python-version }} - name: Install packages and dependencies @@ -190,11 +194,11 @@ jobs: steps: # checkout to pr branch - name: Checkout - uses: actions/checkout@v3 + uses: actions/checkout@v4 with: ref: ${{ github.event.pull_request.head.sha }} - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v4 + uses: actions/setup-python@v5 with: python-version: ${{ matrix.python-version }} - name: Install packages and dependencies @@ -231,11 +235,11 @@ jobs: steps: # checkout to pr branch - name: Checkout - uses: actions/checkout@v3 + uses: actions/checkout@v4 with: ref: ${{ github.event.pull_request.head.sha }} - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v4 + uses: actions/setup-python@v5 with: python-version: ${{ matrix.python-version }} - name: Install packages and dependencies @@ -270,11 +274,11 @@ jobs: steps: # checkout to pr branch - name: Checkout - uses: actions/checkout@v3 + uses: actions/checkout@v4 with: ref: ${{ github.event.pull_request.head.sha }} - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v4 + uses: actions/setup-python@v5 with: python-version: ${{ matrix.python-version }} - name: Install packages and dependencies @@ -299,3 +303,77 @@ jobs: with: file: ./coverage.xml flags: unittests + ImageGen: + strategy: + matrix: + os: [ubuntu-latest] + 
python-version: ["3.12"] + runs-on: ${{ matrix.os }} + environment: openai1 + steps: + # checkout to pr branch + - name: Checkout + uses: actions/checkout@v4 + with: + ref: ${{ github.event.pull_request.head.sha }} + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} + - name: Install packages and dependencies + run: | + docker --version + python -m pip install --upgrade pip wheel + pip install -e .[lmm] + python -c "import autogen" + pip install coverage pytest + - name: Coverage + env: + OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }} + run: | + coverage run -a -m pytest test/agentchat/contrib/capabilities/test_image_generation_capability.py + coverage xml + - name: Upload coverage to Codecov + uses: codecov/codecov-action@v3 + with: + file: ./coverage.xml + flags: unittests + + AgentOptimizer: + strategy: + matrix: + os: [ ubuntu-latest ] + python-version: [ "3.11" ] + runs-on: ${{ matrix.os }} + environment: openai1 + steps: + # checkout to pr branch + - name: Checkout + uses: actions/checkout@v4 + with: + ref: ${{ github.event.pull_request.head.sha }} + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} + - name: Install packages and dependencies + run: | + docker --version + python -m pip install --upgrade pip wheel + pip install -e . + python -c "import autogen" + pip install coverage pytest + - name: Coverage + env: + OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }} + AZURE_OPENAI_API_KEY: ${{ secrets.AZURE_OPENAI_API_KEY }} + AZURE_OPENAI_API_BASE: ${{ secrets.AZURE_OPENAI_API_BASE }} + OAI_CONFIG_LIST: ${{ secrets.OAI_CONFIG_LIST }} + run: | + coverage run -a -m pytest test/agentchat/contrib/test_agent_optimizer.py + coverage xml + - name: Upload coverage to Codecov + uses: codecov/codecov-action@v3 + with: + file: ./coverage.xml + flags: unittests diff --git a/.github/workflows/contrib-tests.yml b/.github/workflows/contrib-tests.yml index c4d698655d9..4e042b458e0 100644 --- a/.github/workflows/contrib-tests.yml +++ b/.github/workflows/contrib-tests.yml @@ -15,8 +15,9 @@ on: concurrency: group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref }} cancel-in-progress: ${{ github.ref != 'refs/heads/main' }} -permissions: {} - # actions: read +permissions: + {} + # actions: read # checks: read # contents: read # deployments: read @@ -29,9 +30,9 @@ jobs: os: [ubuntu-latest, macos-latest, windows-2019] python-version: ["3.9", "3.10", "3.11"] steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v4 + uses: actions/setup-python@v5 with: python-version: ${{ matrix.python-version }} - name: Install packages and dependencies for all tests @@ -41,27 +42,30 @@ jobs: - name: Install qdrant_client when python-version is 3.10 if: matrix.python-version == '3.10' run: | - pip install qdrant_client[fastembed] - - name: Install unstructured when python-version is 3.9 and not windows - if: matrix.python-version == '3.9' && matrix.os != 'windows-2019' + pip install .[retrievechat-qdrant] + - name: Install unstructured when python-version is 3.9 and on linux + run: | + sudo apt-get update + sudo apt-get install -y tesseract-ocr poppler-utils + pip install unstructured[all-docs]==0.13.0 + - name: Install and Start PostgreSQL + runs-on: ubuntu-latest run: | - pip install unstructured[all-docs] - - name: Install packages and dependencies 
for RetrieveChat + sudo apt install postgresql -y + sudo service postgresql start + - name: Install packages and dependencies for PGVector run: | - pip install -e .[retrievechat] + pip install -e .[retrievechat-pgvector] - name: Set AUTOGEN_USE_DOCKER based on OS shell: bash run: | if [[ ${{ matrix.os }} != ubuntu-latest ]]; then echo "AUTOGEN_USE_DOCKER=False" >> $GITHUB_ENV fi - - name: Test RetrieveChat - run: | - pytest test/test_retrieve_utils.py test/agentchat/contrib/test_retrievechat.py test/agentchat/contrib/test_qdrant_retrievechat.py --skip-openai - name: Coverage run: | pip install coverage>=5.3 - coverage run -a -m pytest test/test_retrieve_utils.py test/agentchat/contrib/test_retrievechat.py test/agentchat/contrib/test_qdrant_retrievechat.py --skip-openai + coverage run -a -m pytest test/test_retrieve_utils.py test/agentchat/contrib/test_retrievechat.py test/agentchat/contrib/test_qdrant_retrievechat.py test/agentchat/contrib/vectordb --skip-openai coverage xml - name: Upload coverage to Codecov uses: codecov/codecov-action@v3 @@ -77,9 +81,9 @@ jobs: os: [ubuntu-latest, macos-latest, windows-2019] python-version: ["3.8"] steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v4 + uses: actions/setup-python@v5 with: python-version: ${{ matrix.python-version }} - name: Install packages and dependencies for all tests @@ -114,9 +118,9 @@ jobs: os: [ubuntu-latest, macos-latest, windows-2019] python-version: ["3.10"] steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v4 + uses: actions/setup-python@v5 with: python-version: ${{ matrix.python-version }} - name: Install packages and dependencies for all tests @@ -151,9 +155,9 @@ jobs: os: [ubuntu-latest, macos-latest, windows-2019] python-version: ["3.11"] steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v4 + uses: actions/setup-python@v5 with: python-version: ${{ matrix.python-version }} - name: Install packages and dependencies for all tests @@ -188,9 +192,9 @@ jobs: os: [ubuntu-latest, macos-latest, windows-2019] python-version: ["3.12"] steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v4 + uses: actions/setup-python@v5 with: python-version: ${{ matrix.python-version }} - name: Install packages and dependencies for all tests @@ -225,9 +229,11 @@ jobs: os: [ubuntu-latest, macos-latest, windows-2019] python-version: ["3.12"] steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 + with: + lfs: true - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v4 + uses: actions/setup-python@v5 with: python-version: ${{ matrix.python-version }} - name: Install packages and dependencies for all tests @@ -246,7 +252,45 @@ jobs: - name: Coverage run: | pip install coverage>=5.3 - coverage run -a -m pytest test/agentchat/contrib/test_img_utils.py test/agentchat/contrib/test_lmm.py test/agentchat/contrib/test_llava.py --skip-openai + coverage run -a -m pytest test/agentchat/contrib/test_img_utils.py test/agentchat/contrib/test_lmm.py test/agentchat/contrib/test_llava.py test/agentchat/contrib/capabilities/test_image_generation_capability.py test/agentchat/contrib/capabilities/test_vision_capability.py --skip-openai + coverage xml + - 
name: Upload coverage to Codecov + uses: codecov/codecov-action@v3 + with: + file: ./coverage.xml + flags: unittests + + GeminiTest: + runs-on: ${{ matrix.os }} + strategy: + fail-fast: false + matrix: + os: [ubuntu-latest, macos-latest, windows-2019] + python-version: ["3.9", "3.10", "3.11", "3.12"] + steps: + - uses: actions/checkout@v4 + with: + lfs: true + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} + - name: Install packages and dependencies for all tests + run: | + python -m pip install --upgrade pip wheel + pip install pytest + - name: Install packages and dependencies for Gemini + run: | + pip install -e .[gemini,test] + - name: Set AUTOGEN_USE_DOCKER based on OS + shell: bash + run: | + if [[ ${{ matrix.os }} != ubuntu-latest ]]; then + echo "AUTOGEN_USE_DOCKER=False" >> $GITHUB_ENV + fi + - name: Coverage + run: | + coverage run -a -m pytest test/oai/test_gemini.py --skip-openai coverage xml - name: Upload coverage to Codecov uses: codecov/codecov-action@v3 @@ -262,9 +306,9 @@ jobs: os: [ubuntu-latest, macos-latest, windows-2019] python-version: ["3.11"] steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v4 + uses: actions/setup-python@v5 with: python-version: ${{ matrix.python-version }} - name: Install packages and dependencies for all tests @@ -290,3 +334,40 @@ jobs: with: file: ./coverage.xml flags: unittests + + TransformMessages: + runs-on: ${{ matrix.os }} + strategy: + fail-fast: false + matrix: + os: [ubuntu-latest, macos-latest, windows-2019] + python-version: ["3.11"] + steps: + - uses: actions/checkout@v3 + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v4 + with: + python-version: ${{ matrix.python-version }} + - name: Install packages and dependencies for all tests + run: | + python -m pip install --upgrade pip wheel + pip install pytest + - name: Install packages and dependencies for Transform Messages + run: | + pip install -e . 
+ - name: Set AUTOGEN_USE_DOCKER based on OS + shell: bash + run: | + if [[ ${{ matrix.os }} != ubuntu-latest ]]; then + echo "AUTOGEN_USE_DOCKER=False" >> $GITHUB_ENV + fi + - name: Coverage + run: | + pip install coverage>=5.3 + coverage run -a -m pytest test/agentchat/contrib/capabilities/test_transform_messages.py --skip-openai + coverage xml + - name: Upload coverage to Codecov + uses: codecov/codecov-action@v3 + with: + file: ./coverage.xml + flags: unittest diff --git a/.github/workflows/deploy-website.yml b/.github/workflows/deploy-website.yml index 0f2702e5e98..8798fca7ca6 100644 --- a/.github/workflows/deploy-website.yml +++ b/.github/workflows/deploy-website.yml @@ -4,15 +4,15 @@ on: pull_request: branches: [main] path: - - 'autogen/*' - - 'website/*' - - '.github/workflows/deploy-website.yml' + - "autogen/*" + - "website/*" + - ".github/workflows/deploy-website.yml" push: branches: [main] path: - - 'autogen/*' - - 'website/*' - - '.github/workflows/deploy-website.yml' + - "autogen/*" + - "website/*" + - ".github/workflows/deploy-website.yml" workflow_dispatch: merge_group: types: [checks_requested] @@ -26,18 +26,22 @@ jobs: run: working-directory: website steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 + with: + lfs: true - uses: actions/setup-node@v4 with: node-version: 18.x - name: setup python - uses: actions/setup-python@v4 + uses: actions/setup-python@v5 with: python-version: "3.8" - name: pydoc-markdown install run: | python -m pip install --upgrade pip pip install pydoc-markdown pyyaml termcolor + # Pin databind packages as version 4.5.0 is not compatible with pydoc-markdown. + pip install databind.core==4.4.2 databind.json==4.4.2 - name: pydoc-markdown run run: | pydoc-markdown @@ -69,18 +73,22 @@ jobs: run: working-directory: website steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 + with: + lfs: true - uses: actions/setup-node@v4 with: node-version: 18.x - name: setup python - uses: actions/setup-python@v4 + uses: actions/setup-python@v5 with: python-version: "3.8" - name: pydoc-markdown install run: | python -m pip install --upgrade pip pip install pydoc-markdown pyyaml termcolor + # Pin databind packages as version 4.5.0 is not compatible with pydoc-markdown. 
+ pip install databind.core==4.4.2 databind.json==4.4.2 - name: pydoc-markdown run run: | pydoc-markdown diff --git a/.github/workflows/dotnet-build.yml b/.github/workflows/dotnet-build.yml index 18031661758..d223fffd28b 100644 --- a/.github/workflows/dotnet-build.yml +++ b/.github/workflows/dotnet-build.yml @@ -28,9 +28,9 @@ jobs: run: working-directory: dotnet steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - name: Setup .NET - uses: actions/setup-dotnet@v3 + uses: actions/setup-dotnet@v4 with: global-json-file: dotnet/global.json - name: Restore dependencies @@ -53,9 +53,9 @@ jobs: if: success() && (github.ref == 'refs/heads/main' || github.ref == 'refs/heads/dotnet') needs: build steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - name: Setup .NET - uses: actions/setup-dotnet@v3 + uses: actions/setup-dotnet@v4 with: global-json-file: dotnet/global.json - name: Restore dependencies @@ -83,12 +83,12 @@ jobs: echo "ls output directory" ls -R ./output - name: Upload package - uses: actions/upload-artifact@v2 + uses: actions/upload-artifact@v4 with: name: nightly path: ./dotnet/output/nightly - name: Upload package - uses: actions/upload-artifact@v2 + uses: actions/upload-artifact@v4 with: name: release path: ./dotnet/output/release @@ -102,17 +102,17 @@ jobs: needs: openai-test steps: - name: Setup .NET - uses: actions/setup-dotnet@v3 + uses: actions/setup-dotnet@v4 with: dotnet-version: '6.0.x' source-url: https://devdiv.pkgs.visualstudio.com/DevDiv/_packaging/AutoGen/nuget/v3/index.json env: NUGET_AUTH_TOKEN: ${{ secrets.AZURE_DEVOPS_TOKEN }} - - uses: actions/download-artifact@v2 + - uses: actions/download-artifact@v4 with: name: nightly path: ./dotnet/output/nightly - - uses: actions/download-artifact@v2 + - uses: actions/download-artifact@v4 with: name: release path: ./dotnet/output/release diff --git a/.github/workflows/dotnet-release.yml b/.github/workflows/dotnet-release.yml new file mode 100644 index 00000000000..d66f21a6cd6 --- /dev/null +++ b/.github/workflows/dotnet-release.yml @@ -0,0 +1,69 @@ +# This workflow will build a .NET project +# For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-net + +name: dotnet-release + +on: + workflow_dispatch: + push: + branches: + - dotnet/release + +concurrency: + group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref }} + cancel-in-progress: true + +permissions: + contents: read + packages: write + +jobs: + build: + name: Build and release + runs-on: ubuntu-latest + environment: dotnet + defaults: + run: + working-directory: dotnet + steps: + - uses: actions/checkout@v4 + - name: Setup .NET + uses: actions/setup-dotnet@v4 + with: + global-json-file: dotnet/global.json + - name: Restore dependencies + run: | + dotnet restore -bl + - name: Build + run: | + echo "Build AutoGen" + dotnet build --no-restore --configuration Release -bl /p:SignAssembly=true + - name: Unit Test + run: dotnet test --no-build -bl --configuration Release + env: + AZURE_OPENAI_API_KEY: ${{ secrets.AZURE_OPENAI_API_KEY }} + AZURE_OPENAI_ENDPOINT: ${{ secrets.AZURE_OPENAI_ENDPOINT }} + AZURE_GPT_35_MODEL_ID: ${{ secrets.AZURE_GPT_35_MODEL_ID }} + OEPNAI_API_KEY: ${{ secrets.OPENAI_API_KEY }} + - name: Pack + run: | + echo "Create release build package" + dotnet pack --no-build --configuration Release --output './output/release' -bl + + echo "ls output directory" + ls -R ./output + - name: Publish package to Nuget + run: | + echo "Publish package to Nuget" + echo 
"ls output directory" + ls -R ./output/release + dotnet nuget push --api-key AzureArtifacts ./output/release/*.nupkg --skip-duplicate --api-key ${{ secrets.AUTOGEN_NUGET_API_KEY }} + - name: Tag commit + run: | + Write-Host "Tag commit" + # version = eng/MetaInfo.props.Project.PropertyGroup.VersionPrefix + $metaInfoContent = cat ./eng/MetaInfo.props + $version = $metaInfoContent | Select-String -Pattern "(.*)" | ForEach-Object { $_.Matches.Groups[1].Value } + git tag -a "$version" -m "AutoGen.Net release $version" + git push origin --tags + shell: pwsh \ No newline at end of file diff --git a/.github/workflows/lfs-check.yml b/.github/workflows/lfs-check.yml new file mode 100644 index 00000000000..e2bcfb5668e --- /dev/null +++ b/.github/workflows/lfs-check.yml @@ -0,0 +1,15 @@ +name: "Git LFS Check" + +on: pull_request +permissions: {} +jobs: + lfs-check: + runs-on: ubuntu-latest + steps: + - name: Checkout + uses: actions/checkout@v4 + with: + lfs: true + - name: Check Git LFS files for consistency + run: | + git lfs fsck diff --git a/.github/workflows/openai.yml b/.github/workflows/openai.yml index 2018aa8e6c0..d2780eea542 100644 --- a/.github/workflows/openai.yml +++ b/.github/workflows/openai.yml @@ -36,11 +36,11 @@ jobs: steps: # checkout to pr branch - name: Checkout - uses: actions/checkout@v3 + uses: actions/checkout@v4 with: ref: ${{ github.event.pull_request.head.sha }} - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v4 + uses: actions/setup-python@v5 with: python-version: ${{ matrix.python-version }} - name: Install packages and dependencies @@ -63,7 +63,7 @@ jobs: AZURE_OPENAI_API_BASE: ${{ secrets.AZURE_OPENAI_API_BASE }} OAI_CONFIG_LIST: ${{ secrets.OAI_CONFIG_LIST }} run: | - coverage run -a -m pytest test --ignore=test/agentchat/contrib + coverage run -a -m pytest test --ignore=test/agentchat/contrib --durations=10 --durations-min=1.0 coverage xml - name: Coverage and check notebook outputs if: matrix.python-version != '3.9' @@ -75,7 +75,7 @@ jobs: OAI_CONFIG_LIST: ${{ secrets.OAI_CONFIG_LIST }} run: | pip install nbconvert nbformat ipykernel - coverage run -a -m pytest test/test_notebook.py + coverage run -a -m pytest test/test_notebook.py --durations=10 --durations-min=1.0 coverage xml cat "$(pwd)/test/executed_openai_notebook_output.txt" - name: Upload coverage to Codecov diff --git a/.github/workflows/pre-commit.yml b/.github/workflows/pre-commit.yml index 18b23afd18e..8404de61154 100644 --- a/.github/workflows/pre-commit.yml +++ b/.github/workflows/pre-commit.yml @@ -18,13 +18,15 @@ jobs: pre-commit-check: runs-on: ubuntu-latest + env: + SKIP: "mypy" steps: - - uses: actions/checkout@v3 - - uses: actions/setup-python@v4 + - uses: actions/checkout@v4 + - uses: actions/setup-python@v5 - name: Set $PY environment variable run: echo "PY=$(python -VV | sha256sum | cut -d' ' -f1)" >> $GITHUB_ENV - - uses: actions/cache@v3 + - uses: actions/cache@v4 with: path: ~/.cache/pre-commit key: pre-commit|${{ env.PY }}|${{ hashFiles('.pre-commit-config.yaml') }} - - uses: pre-commit/action@v3.0.0 + - uses: pre-commit/action@v3.0.1 diff --git a/.github/workflows/python-package.yml b/.github/workflows/python-package.yml index 4f57c10ef70..f2967c13f5f 100644 --- a/.github/workflows/python-package.yml +++ b/.github/workflows/python-package.yml @@ -22,9 +22,9 @@ jobs: environment: package steps: - name: Checkout - uses: actions/checkout@v3 + uses: actions/checkout@v4 # - name: Cache conda - # uses: actions/cache@v3 + # uses: actions/cache@v4 # with: # path: 
~/conda_pkgs_dir # key: conda-${{ matrix.os }}-python-${{ matrix.python-version }}-${{ hashFiles('environment.yml') }} diff --git a/.github/workflows/samples-tools-tests.yml b/.github/workflows/samples-tools-tests.yml new file mode 100644 index 00000000000..12c8de3b7af --- /dev/null +++ b/.github/workflows/samples-tools-tests.yml @@ -0,0 +1,46 @@ +# This workflow will install Python dependencies, run tests and lint with a variety of Python versions +# For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions + +name: SamplesToolsTests + +on: + pull_request: + branches: ["main"] + paths: + - "autogen/**" + - "samples/tools/**" + - ".github/workflows/samples-tools-tests.yml" + - "setup.py" + +concurrency: + group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref }} + cancel-in-progress: ${{ github.ref != 'refs/heads/main' }} +permissions: {} +jobs: + SamplesToolsFineTuningTests: + runs-on: ${{ matrix.os }} + strategy: + fail-fast: false + matrix: + os: [ubuntu-latest, macos-latest] + python-version: ["3.9", "3.10", "3.11"] + steps: + - uses: actions/checkout@v4 + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} + - name: Install packages and dependencies for all tests + run: | + python -m pip install --upgrade pip wheel + pip install -e . + pip install pytest + - name: Set AUTOGEN_USE_DOCKER based on OS + shell: bash + run: | + if [[ ${{ matrix.os }} != ubuntu-latest ]]; then + echo "AUTOGEN_USE_DOCKER=False" >> $GITHUB_ENV + fi + - name: Test finetuning tools + run: | + pytest samples/tools/finetuning/tests/ diff --git a/.github/workflows/type-check.yml b/.github/workflows/type-check.yml new file mode 100644 index 00000000000..f6896d1145d --- /dev/null +++ b/.github/workflows/type-check.yml @@ -0,0 +1,27 @@ +name: Type check +# see: https://help.github.com/en/actions/reference/events-that-trigger-workflows +on: # Trigger the workflow on pull request or merge + pull_request: + merge_group: + types: [checks_requested] +defaults: + run: + shell: bash +permissions: {} +jobs: + type-check: + strategy: + fail-fast: true + matrix: + version: ["3.8", "3.9", "3.10", "3.11", "3.12"] + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.version }} + # All additional modules should be defined in setup.py + - run: pip install ".[types]" + # Any additional configuration should be defined in pyproject.toml + - run: | + mypy diff --git a/.gitignore b/.gitignore index 25e88f30c77..4c925f739ec 100644 --- a/.gitignore +++ b/.gitignore @@ -172,11 +172,21 @@ test/my_tmp/* # Storage for the AgentEval output test/test_files/agenteval-in-out/out/ +# local cache or coding foler +local_cache/ +coding/ + # Files created by tests *tmp_code_* test/agentchat/test_agent_scripts/* # test cache .cache_test +.db +local_cache + notebook/result.png +samples/apps/autogen-studio/autogenstudio/models/test/ + +notebook/coding diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index db89b0f034b..fcea09223c6 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -4,11 +4,11 @@ exclude: 'dotnet' ci: autofix_prs: true autoupdate_commit_msg: '[pre-commit.ci] pre-commit suggestions' - autoupdate_schedule: 'quarterly' + autoupdate_schedule: 'monthly' repos: - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v4.4.0 + rev: v4.5.0 hooks: - id: 
check-added-large-files - id: check-ast @@ -23,30 +23,47 @@ repos: - id: end-of-file-fixer - id: no-commit-to-branch - repo: https://github.com/psf/black - rev: 23.3.0 + rev: 24.3.0 hooks: - id: black - - repo: https://github.com/charliermarsh/ruff-pre-commit - rev: v0.0.261 + - repo: https://github.com/astral-sh/ruff-pre-commit + rev: v0.3.4 hooks: - id: ruff - args: ["--fix"] + types_or: [ python, pyi, jupyter ] + args: ["--fix", "--ignore=E402"] - repo: https://github.com/codespell-project/codespell rev: v2.2.6 hooks: - id: codespell - args: ["-L", "ans,linar,nam,"] + args: ["-L", "ans,linar,nam,tread,ot,"] exclude: | (?x)^( pyproject.toml | website/static/img/ag.svg | website/yarn.lock | + website/docs/tutorial/code-executors.ipynb | + website/docs/topics/code-execution/custom-executor.ipynb | + website/docs/topics/non-openai-models/cloud-gemini.ipynb | notebook/.* )$ + # See https://jaredkhan.com/blog/mypy-pre-commit + - repo: local + hooks: + - id: mypy + name: mypy + entry: "./scripts/pre-commit-mypy-run.sh" + language: python + # use your preferred Python version + # language_version: python3.8 + additional_dependencies: [] + types: [python] + # use require_serial so that script + # is only called once per commit + require_serial: true + # Print the number of files as a sanity-check + verbose: true - repo: https://github.com/nbQA-dev/nbQA - rev: 1.7.1 + rev: 1.8.5 hooks: - - id: nbqa-ruff - # Don't require notebooks to have all imports at the top - args: ["--fix", "--ignore=E402"] - id: nbqa-black diff --git a/OAI_CONFIG_LIST_sample b/OAI_CONFIG_LIST_sample index ef027f815ba..9fc0dc803a0 100644 --- a/OAI_CONFIG_LIST_sample +++ b/OAI_CONFIG_LIST_sample @@ -5,7 +5,8 @@ [ { "model": "gpt-4", - "api_key": "" + "api_key": "", + "tags": ["gpt-4", "tool"] }, { "model": "", diff --git a/README.md b/README.md index c10862baaa4..dffa451db98 100644 --- a/README.md +++ b/README.md @@ -1,8 +1,10 @@ + + [![PyPI version](https://badge.fury.io/py/pyautogen.svg)](https://badge.fury.io/py/pyautogen) [![Build](https://github.com/microsoft/autogen/actions/workflows/python-package.yml/badge.svg)](https://github.com/microsoft/autogen/actions/workflows/python-package.yml) ![Python Version](https://img.shields.io/badge/3.8%20%7C%203.9%20%7C%203.10%20%7C%203.11%20%7C%203.12-blue) [![Downloads](https://static.pepy.tech/badge/pyautogen/week)](https://pepy.tech/project/pyautogen) -[![](https://img.shields.io/discord/1153072414184452236?logo=discord&style=flat)](https://discord.gg/pAbnFJrkgZ) +[![Discord](https://img.shields.io/discord/1153072414184452236?logo=discord&style=flat)](https://aka.ms/autogen-dc) [![Twitter](https://img.shields.io/twitter/url/https/twitter.com/cloudposse.svg?style=social&label=Follow%20%40pyautogen)](https://twitter.com/pyautogen) @@ -12,28 +14,31 @@

--> +:fire: Apr 17, 2024: Andrew Ng cited AutoGen in [The Batch newsletter](https://www.deeplearning.ai/the-batch/issue-245/) and [What's next for AI agentic workflows](https://youtu.be/sal78ACtGTc?si=JduUzN_1kDnMq0vF) at Sequoia Capital's AI Ascent (Mar 26). + +:fire: Mar 3, 2024: What's new in AutoGen? πŸ“°[Blog](https://microsoft.github.io/autogen/blog/2024/03/03/AutoGen-Update); πŸ“Ί[Youtube](https://www.youtube.com/watch?v=j_mtwQiaLGU). -:fire: Jan 30: AutoGen is highlighted by Peter Lee in Microsoft Research Forum [Keynote](https://t.co/nUBSjPDjqD). +:fire: Mar 1, 2024: the first AutoGen multi-agent experiment on the challenging [GAIA](https://huggingface.co/spaces/gaia-benchmark/leaderboard) benchmark achieved the No. 1 accuracy in all the three levels. -:warning: Jan 23: **Breaking Change in Latest Release v0.2.8** `use_docker` defaults to `True` for code-execution. See [blog post](https://microsoft.github.io/autogen/blog/2024/01/23/Code-execution-in-docker) for details and [FAQ](https://microsoft.github.io/autogen/docs/FAQ#agents-are-throwing-due-to-docker-not-running-how-can-i-resolve-this) for troubleshooting any issues. +:tada: Jan 30, 2024: AutoGen is highlighted by Peter Lee in Microsoft Research Forum [Keynote](https://t.co/nUBSjPDjqD). -:fire: Dec 31: [AutoGen: Enabling Next-Gen LLM Applications via Multi-Agent Conversation Framework](https://arxiv.org/abs/2308.08155) is selected by [TheSequence: My Five Favorite AI Papers of 2023](https://thesequence.substack.com/p/my-five-favorite-ai-papers-of-2023). +:tada: Dec 31, 2023: [AutoGen: Enabling Next-Gen LLM Applications via Multi-Agent Conversation Framework](https://arxiv.org/abs/2308.08155) is selected by [TheSequence: My Five Favorite AI Papers of 2023](https://thesequence.substack.com/p/my-five-favorite-ai-papers-of-2023). -:fire: Nov 8: AutoGen is selected into [Open100: Top 100 Open Source achievements](https://www.benchcouncil.org/evaluation/opencs/annual.html) 35 days after spinoff. +:tada: Nov 8, 2023: AutoGen is selected into [Open100: Top 100 Open Source achievements](https://www.benchcouncil.org/evaluation/opencs/annual.html) 35 days after spinoff. -:fire: Nov 6: AutoGen is mentioned by Satya Nadella in a [fireside chat](https://youtu.be/0pLBvgYtv6U). +:tada: Nov 6, 2023: AutoGen is mentioned by Satya Nadella in a [fireside chat](https://youtu.be/0pLBvgYtv6U). -:fire: Nov 1: AutoGen is the top trending repo on GitHub in October 2023. +:tada: Nov 1, 2023: AutoGen is the top trending repo on GitHub in October 2023. -:tada: Oct 03: AutoGen spins off from FLAML on GitHub and has a major paper update (first version on Aug 16). +:tada: Oct 03, 2023: AutoGen spins off from FLAML on GitHub and has a major paper update (first version on Aug 16). -:tada: Mar 29: AutoGen is first created in [FLAML](https://github.com/microsoft/FLAML). +:tada: Mar 29, 2023: AutoGen is first created in [FLAML](https://github.com/microsoft/FLAML). +

+<p align="right">
+  <a href="#readme-top">
+    ↑ Back to Top ↑
+  </a>
+</p>

+ ## What is AutoGen AutoGen is a framework that enables the development of LLM applications using multiple agents that can converse with each other to solve tasks. AutoGen agents are customizable, conversable, and seamlessly allow human participation. They can operate in various modes that employ combinations of LLMs, human inputs, and tools. @@ -56,6 +67,23 @@ AutoGen is a framework that enables the development of LLM applications using mu AutoGen is powered by collaborative [research studies](https://microsoft.github.io/autogen/docs/Research) from Microsoft, Penn State University, and the University of Washington. +

+<p align="right">
+  <a href="#readme-top">
+    ↑ Back to Top ↑
+  </a>
+</p>

+
+## Roadmaps
+
+To see what we are working on and what we plan to work on, please check our
+[Roadmap Issues](https://aka.ms/autogen-roadmap).
+

+<p align="right">
+  <a href="#readme-top">
+    ↑ Back to Top ↑
+  </a>
+</p>

+ ## Quickstart The easiest way to start playing is 1. Click below to use the GitHub Codespace @@ -66,10 +94,17 @@ The easiest way to start playing is 3. Start playing with the notebooks! *NOTE*: OAI_CONFIG_LIST_sample lists GPT-4 as the default model, as this represents our current recommendation, and is known to work well with AutoGen. If you use a model other than GPT-4, you may need to revise various system prompts (especially if using weaker models like GPT-3.5-turbo). Moreover, if you use models other than those hosted by OpenAI or Azure, you may incur additional risks related to alignment and safety. Proceed with caution if updating this default. + +

+<p align="right">
+  <a href="#readme-top">
+    ↑ Back to Top ↑
+  </a>
+</p>
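As a side note to the Quickstart and to the `OAI_CONFIG_LIST_sample` change earlier in this diff (which adds a `tags` field to each entry), here is a minimal sketch of how such a config list is typically loaded and filtered. This is not part of the diff; the file name and tag value are assumptions for illustration only.

```python
import autogen

# Load model configurations from an OAI_CONFIG_LIST file (assumed to exist)
# and keep only the entries tagged "gpt-4", matching the tags added in the sample.
config_list = autogen.config_list_from_json(
    env_or_file="OAI_CONFIG_LIST",
    filter_dict={"tags": ["gpt-4"]},  # hypothetical filter for illustration
)
print(f"{len(config_list)} matching model configuration(s) found")
```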

+ ## [Installation](https://microsoft.github.io/autogen/docs/Installation) ### Option 1. Install and Run AutoGen in Docker -Find detailed instructions for users [here](https://microsoft.github.io/autogen/docs/Installation#option-1-install-and-run-autogen-in-docker), and for developers [here](https://microsoft.github.io/autogen/docs/Contribute#docker-for-development). +Find detailed instructions for users [here](https://microsoft.github.io/autogen/docs/installation/Docker#step-1-install-docker), and for developers [here](https://microsoft.github.io/autogen/docs/Contribute#docker-for-development). ### Option 2. Install AutoGen Locally @@ -94,6 +129,12 @@ Even if you are installing and running AutoGen locally outside of docker, the re For LLM inference configurations, check the [FAQs](https://microsoft.github.io/autogen/docs/FAQ#set-your-api-endpoints). +

+<p align="right">
+  <a href="#readme-top">
+    ↑ Back to Top ↑
+  </a>
+</p>

+ ## Multi-Agent Conversation Framework Autogen enables the next-gen LLM applications with a generic [multi-agent conversation](https://microsoft.github.io/autogen/docs/Use-Cases/agent_chat) framework. It offers customizable and conversable agents that integrate LLMs, tools, and humans. @@ -133,6 +174,12 @@ The figure below shows an example conversation flow with AutoGen. Alternatively, the [sample code](https://github.com/microsoft/autogen/blob/main/samples/simple_chat.py) here allows a user to chat with an AutoGen agent in ChatGPT style. Please find more [code examples](https://microsoft.github.io/autogen/docs/Examples#automated-multi-agent-chat) for this feature. +

+<p align="right">
+  <a href="#readme-top">
+    ↑ Back to Top ↑
+  </a>
+</p>
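To make the Multi-Agent Conversation Framework section above concrete, here is a minimal two-agent sketch. It is not part of this diff, assumes an `OAI_CONFIG_LIST` file is available, and the task message is only an example.

```python
import autogen

config_list = autogen.config_list_from_json(env_or_file="OAI_CONFIG_LIST")

# LLM-backed assistant that proposes answers and code.
assistant = autogen.AssistantAgent("assistant", llm_config={"config_list": config_list})

# Stand-in for the human user; executes any generated code in ./coding without Docker.
user_proxy = autogen.UserProxyAgent(
    "user_proxy",
    human_input_mode="NEVER",
    code_execution_config={"work_dir": "coding", "use_docker": False},
)

# The two agents converse until a termination condition is met.
user_proxy.initiate_chat(assistant, message="Plot a chart of NVDA stock price change YTD.")
```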

+ ## Enhanced LLM Inferences Autogen also helps maximize the utility out of the expensive LLMs such as ChatGPT and GPT-4. It offers [enhanced LLM inference](https://microsoft.github.io/autogen/docs/Use-Cases/enhanced_inference#api-unification) with powerful functionalities like caching, error handling, multi-config inference and templating. @@ -156,6 +203,12 @@ response = autogen.Completion.create(context=test_instance, **config) Please find more [code examples](https://microsoft.github.io/autogen/docs/Examples#tune-gpt-models) for this feature. --> +

+<p align="right">
+  <a href="#readme-top">
+    ↑ Back to Top ↑
+  </a>
+</p>
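For the Enhanced LLM Inference section above, a rough sketch of the `OpenAIWrapper` client with caching and multi-config inference follows. It is not part of this diff; the cache seed and prompt are arbitrary examples.

```python
import autogen

config_list = autogen.config_list_from_json(env_or_file="OAI_CONFIG_LIST")

# The wrapper tries each configuration in order and caches completions on disk,
# keyed by cache_seed, so an identical repeated call is served from the cache.
client = autogen.OpenAIWrapper(config_list=config_list)
response = client.create(
    messages=[{"role": "user", "content": "What is the capital of France?"}],
    cache_seed=41,  # arbitrary seed; reuse it to hit the cache on subsequent calls
)
print(client.extract_text_or_completion_object(response))
```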

+ ## Documentation You can find detailed documentation about AutoGen [here](https://microsoft.github.io/autogen/). @@ -164,12 +217,18 @@ In addition, you can find: - [Research](https://microsoft.github.io/autogen/docs/Research), [blogposts](https://microsoft.github.io/autogen/blog) around AutoGen, and [Transparency FAQs](https://github.com/microsoft/autogen/blob/main/TRANSPARENCY_FAQS.md) -- [Discord](https://discord.gg/pAbnFJrkgZ) +- [Discord](https://aka.ms/autogen-dc) - [Contributing guide](https://microsoft.github.io/autogen/docs/Contribute) - [Roadmap](https://github.com/orgs/microsoft/projects/989/views/3) +

+<p align="right">
+  <a href="#readme-top">
+    ↑ Back to Top ↑
+  </a>
+</p>

+ ## Related Papers [AutoGen](https://arxiv.org/abs/2308.08155) @@ -207,6 +266,12 @@ In addition, you can find: } ``` +

+<p align="right">
+  <a href="#readme-top">
+    ↑ Back to Top ↑
+  </a>
+</p>

+ ## Contributing This project welcomes contributions and suggestions. Most contributions require you to agree to a @@ -223,11 +288,23 @@ This project has adopted the [Microsoft Open Source Code of Conduct](https://ope For more information, see the [Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) or contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with any additional questions or comments. +

+<p align="right">
+  <a href="#readme-top">
+    ↑ Back to Top ↑
+  </a>
+</p>

+ ## Contributors Wall - + +

+<p align="right">
+  <a href="#readme-top">
+    ↑ Back to Top ↑
+  </a>
+</p>

+ # Legal Notices Microsoft and any contributors grant you a license to the Microsoft documentation and other content @@ -244,3 +321,9 @@ Privacy information can be found at https://privacy.microsoft.com/en-us/ Microsoft and any contributors reserve all other rights, whether under their respective copyrights, patents, or trademarks, whether by implication, estoppel, or otherwise. + +

+<p align="right">
+  <a href="#readme-top">
+    ↑ Back to Top ↑
+  </a>
+</p>

diff --git a/autogen/__init__.py b/autogen/__init__.py index ba920c92e46..02f956c4bcf 100644 --- a/autogen/__init__.py +++ b/autogen/__init__.py @@ -1,10 +1,10 @@ import logging -from .version import __version__ -from .oai import * + from .agentchat import * -from .exception_utils import * from .code_utils import DEFAULT_MODEL, FAST_MODEL - +from .exception_utils import * +from .oai import * +from .version import __version__ # Set the root logger. logger = logging.getLogger(__name__) diff --git a/autogen/_pydantic.py b/autogen/_pydantic.py index 89dbc4fd291..9a37208c406 100644 --- a/autogen/_pydantic.py +++ b/autogen/_pydantic.py @@ -13,7 +13,7 @@ from pydantic._internal._typing_extra import eval_type_lenient as evaluate_forwardref from pydantic.json_schema import JsonSchemaValue - def type2schema(t: Optional[Type]) -> JsonSchemaValue: + def type2schema(t: Any) -> JsonSchemaValue: """Convert a type to a JSON schema Args: @@ -51,11 +51,11 @@ def model_dump_json(model: BaseModel) -> str: # Remove this once we drop support for pydantic 1.x else: # pragma: no cover from pydantic import schema_of - from pydantic.typing import evaluate_forwardref as evaluate_forwardref + from pydantic.typing import evaluate_forwardref as evaluate_forwardref # type: ignore[no-redef] - JsonSchemaValue = Dict[str, Any] + JsonSchemaValue = Dict[str, Any] # type: ignore[misc] - def type2schema(t: Optional[Type]) -> JsonSchemaValue: + def type2schema(t: Any) -> JsonSchemaValue: """Convert a type to a JSON schema Args: diff --git a/autogen/agentchat/__init__.py b/autogen/agentchat/__init__.py index 817fc8abdb6..d31a59d98fb 100644 --- a/autogen/agentchat/__init__.py +++ b/autogen/agentchat/__init__.py @@ -1,9 +1,9 @@ from .agent import Agent from .assistant_agent import AssistantAgent +from .chat import ChatResult, initiate_chats from .conversable_agent import ConversableAgent, register_function from .groupchat import GroupChat, GroupChatManager from .user_proxy_agent import UserProxyAgent -from .chat import initiate_chats, ChatResult from .utils import gather_usage_summary __all__ = ( diff --git a/autogen/agentchat/assistant_agent.py b/autogen/agentchat/assistant_agent.py index 25f7edbf073..b5ec7de90c7 100644 --- a/autogen/agentchat/assistant_agent.py +++ b/autogen/agentchat/assistant_agent.py @@ -1,7 +1,8 @@ from typing import Callable, Dict, Literal, Optional, Union +from autogen.runtime_logging import log_new_agent, logging_enabled + from .conversable_agent import ConversableAgent -from autogen.runtime_logging import logging_enabled, log_new_agent class AssistantAgent(ConversableAgent): diff --git a/autogen/agentchat/chat.py b/autogen/agentchat/chat.py index 8c9ed1ee2ef..b527f8e0bae 100644 --- a/autogen/agentchat/chat.py +++ b/autogen/agentchat/chat.py @@ -1,12 +1,15 @@ import asyncio +import datetime import logging -from collections import defaultdict -from typing import Dict, List, Any, Set, Tuple -from dataclasses import dataclass import warnings -from termcolor import colored -from .utils import consolidate_chat_info +from collections import abc, defaultdict +from dataclasses import dataclass +from functools import partial +from typing import Any, Dict, List, Set, Tuple +from ..formatting_utils import colored +from ..io.base import IOStream +from .utils import consolidate_chat_info logger = logging.getLogger(__name__) Prerequisite = Tuple[int, int] @@ -22,8 +25,12 @@ class ChatResult: """The chat history.""" summary: str = None """A summary obtained from the chat.""" - cost: tuple = None # (dict, dict) - 
(total_cost, actual_cost_with_cache) - """The cost of the chat. a tuple of (total_cost, total_actual_cost), where total_cost is a dictionary of cost information, and total_actual_cost is a dictionary of information on the actual incurred cost with cache.""" + cost: Dict[str, dict] = None # keys: "usage_including_cached_inference", "usage_excluding_cached_inference" + """The cost of the chat. + The value for each usage type is a dictionary containing cost information for that specific type. + - "usage_including_cached_inference": Cost information on the total usage, including the tokens in cached inference. + - "usage_excluding_cached_inference": Cost information on the usage of tokens, excluding the tokens in cache. No larger than "usage_including_cached_inference". + """ human_input: List[str] = None """A list of human input solicited during the chat.""" @@ -100,7 +107,9 @@ def __find_async_chat_order(chat_ids: Set[int], prerequisites: List[Prerequisite return chat_order -def __post_carryover_processing(chat_info: Dict[str, Any]): +def __post_carryover_processing(chat_info: Dict[str, Any]) -> None: + iostream = IOStream.get_default() + if "message" not in chat_info: warnings.warn( "message is not provided in a chat_queue entry. input() will be called to get the initial message.", @@ -111,57 +120,60 @@ def __post_carryover_processing(chat_info: Dict[str, Any]): if isinstance(chat_info["carryover"], list) else chat_info["carryover"] ) - print(colored("\n" + "*" * 80, "blue"), flush=True, sep="") - print( + message = chat_info.get("message") + if isinstance(message, str): + print_message = message + elif callable(message): + print_message = "Callable: " + message.__name__ + elif isinstance(message, dict): + print_message = "Dict: " + str(message) + elif message is None: + print_message = "None" + iostream.print(colored("\n" + "*" * 80, "blue"), flush=True, sep="") + iostream.print( colored( - "Start a new chat with the following message: \n" - + chat_info.get("message") - + "\n\nWith the following carryover: \n" - + print_carryover, + "Starting a new chat....", "blue", ), flush=True, ) - print(colored("\n" + "*" * 80, "blue"), flush=True, sep="") + if chat_info.get("verbose", False): + iostream.print(colored("Message:\n" + print_message, "blue"), flush=True) + iostream.print(colored("Carryover:\n" + print_carryover, "blue"), flush=True) + iostream.print(colored("\n" + "*" * 80, "blue"), flush=True, sep="") def initiate_chats(chat_queue: List[Dict[str, Any]]) -> List[ChatResult]: """Initiate a list of chats. - Args: - chat_queue (List[Dict]): a list of dictionaries containing the information of the chats. - Each dictionary should contain the input arguments for `ConversableAgent.initiate_chat`. - More specifically, each dictionary could include the following fields: - - recipient: the recipient agent. - - "sender": the sender agent. - - "recipient": the recipient agent. - - clear_history (bool): whether to clear the chat history with the agent. Default is True. - - silent (bool or None): (Experimental) whether to print the messages for this conversation. Default is False. - - cache (Cache or None): the cache client to be used for this conversation. Default is None. - - max_turns (int or None): the maximum number of turns for the chat. If None, the chat will continue until a termination condition is met. Default is None. - - "message" needs to be provided if the `generate_init_message` method is not overridden. - Otherwise, input() will be called to get the initial message. 
- - "summary_method": a string or callable specifying the method to get a summary from the chat. Default is DEFAULT_summary_method, i.e., "last_msg". - - Supported string are "last_msg" and "reflection_with_llm": - when set "last_msg", it returns the last message of the dialog as the summary. - when set "reflection_with_llm", it returns a summary extracted using an llm client. - `llm_config` must be set in either the recipient or sender. - "reflection_with_llm" requires the llm_config to be set in either the sender or the recipient. - - A callable summary_method should take the recipient and sender agent in a chat as input and return a string of summary. E.g, - ```python - def my_summary_method( - sender: ConversableAgent, - recipient: ConversableAgent, - ): - return recipient.last_message(sender)["content"] - ``` - - "summary_prompt": This filed can be used to specify the prompt used to extract a summary when summary_method is "reflection_with_llm". - Default is None and the following default prompt will be used when "summary_method" is set to "reflection_with_llm": - "Identify and extract the final solution to the originally asked question based on the conversation." - - "carryover": It can be used to specify the carryover information to be passed to this chat. - If provided, we will combine this carryover with the "message" content when generating the initial chat - message in `generate_init_message`. + chat_queue (List[Dict]): A list of dictionaries containing the information about the chats. + Each dictionary should contain the input arguments for + [`ConversableAgent.initiate_chat`](/docs/reference/agentchat/conversable_agent#initiate_chat). + For example: + - `"sender"` - the sender agent. + - `"recipient"` - the recipient agent. + - `"clear_history" (bool) - whether to clear the chat history with the agent. + Default is True. + - `"silent"` (bool or None) - (Experimental) whether to print the messages in this + conversation. Default is False. + - `"cache"` (Cache or None) - the cache client to use for this conversation. + Default is None. + - `"max_turns"` (int or None) - maximum number of turns for the chat. If None, the chat + will continue until a termination condition is met. Default is None. + - `"summary_method"` (str or callable) - a string or callable specifying the method to get + a summary from the chat. Default is DEFAULT_summary_method, i.e., "last_msg". + - `"summary_args"` (dict) - a dictionary of arguments to be passed to the summary_method. + Default is {}. + - `"message"` (str, callable or None) - if None, input() will be called to get the + initial message. + - `**context` - additional context information to be passed to the chat. + - `"carryover"` - It can be used to specify the carryover information to be passed + to this chat. If provided, we will combine this carryover with the "message" content when + generating the initial chat message in `generate_init_message`. + - `"finished_chat_indexes_to_exclude_from_carryover"` - It can be used by specifying a list of indexes of the finished_chats list, + from which to exclude the summaries for carryover. If 'finished_chat_indexes_to_exclude_from_carryover' is not provided or an empty list, + then summary from all the finished chats will be taken. Returns: (list): a list of ChatResult objects corresponding to the finished chats in the chat_queue. 
""" @@ -173,9 +185,16 @@ def my_summary_method( while current_chat_queue: chat_info = current_chat_queue.pop(0) _chat_carryover = chat_info.get("carryover", []) + finished_chat_indexes_to_exclude_from_carryover = chat_info.get( + "finished_chat_indexes_to_exclude_from_carryover", [] + ) + if isinstance(_chat_carryover, str): _chat_carryover = [_chat_carryover] - chat_info["carryover"] = _chat_carryover + [r.summary for r in finished_chats] + chat_info["carryover"] = _chat_carryover + [ + r.summary for i, r in enumerate(finished_chats) if i not in finished_chat_indexes_to_exclude_from_carryover + ] + __post_carryover_processing(chat_info) sender = chat_info["sender"] chat_res = sender.initiate_chat(**chat_info) @@ -183,41 +202,78 @@ def my_summary_method( return finished_chats +def __system_now_str(): + ct = datetime.datetime.now() + return f" System time at {ct}. " + + +def _on_chat_future_done(chat_future: asyncio.Future, chat_id: int): + """ + Update ChatResult when async Task for Chat is completed. + """ + logger.debug(f"Update chat {chat_id} result on task completion." + __system_now_str()) + chat_result = chat_future.result() + chat_result.chat_id = chat_id + + +async def _dependent_chat_future( + chat_id: int, chat_info: Dict[str, Any], prerequisite_chat_futures: Dict[int, asyncio.Future] +) -> asyncio.Task: + """ + Create an async Task for each chat. + """ + logger.debug(f"Create Task for chat {chat_id}." + __system_now_str()) + _chat_carryover = chat_info.get("carryover", []) + finished_chats = dict() + for chat in prerequisite_chat_futures: + chat_future = prerequisite_chat_futures[chat] + if chat_future.cancelled(): + raise RuntimeError(f"Chat {chat} is cancelled.") + + # wait for prerequisite chat results for the new chat carryover + finished_chats[chat] = await chat_future + + if isinstance(_chat_carryover, str): + _chat_carryover = [_chat_carryover] + chat_info["carryover"] = _chat_carryover + [finished_chats[pre_id].summary for pre_id in finished_chats] + __post_carryover_processing(chat_info) + sender = chat_info["sender"] + chat_res_future = asyncio.create_task(sender.a_initiate_chat(**chat_info)) + call_back_with_args = partial(_on_chat_future_done, chat_id=chat_id) + chat_res_future.add_done_callback(call_back_with_args) + logger.debug(f"Task for chat {chat_id} created." + __system_now_str()) + return chat_res_future + + async def a_initiate_chats(chat_queue: List[Dict[str, Any]]) -> Dict[int, ChatResult]: """(async) Initiate a list of chats. args: - Please refer to `initiate_chats`. + - Please refer to `initiate_chats`. returns: - (Dict): a dict of ChatId: ChatResult corresponding to the finished chats in the chat_queue. + - (Dict): a dict of ChatId: ChatResult corresponding to the finished chats in the chat_queue. """ - consolidate_chat_info(chat_queue) _validate_recipients(chat_queue) chat_book = {chat_info["chat_id"]: chat_info for chat_info in chat_queue} num_chats = chat_book.keys() prerequisites = __create_async_prerequisites(chat_queue) chat_order_by_id = __find_async_chat_order(num_chats, prerequisites) - finished_chats = dict() + finished_chat_futures = dict() for chat_id in chat_order_by_id: chat_info = chat_book[chat_id] - condition = asyncio.Condition() prerequisite_chat_ids = chat_info.get("prerequisites", []) - async with condition: - await condition.wait_for(lambda: all([id in finished_chats for id in prerequisite_chat_ids])) - # Do the actual work here. 
- _chat_carryover = chat_info.get("carryover", []) - if isinstance(_chat_carryover, str): - _chat_carryover = [_chat_carryover] - chat_info["carryover"] = _chat_carryover + [ - finished_chats[pre_id].summary for pre_id in prerequisite_chat_ids - ] - __post_carryover_processing(chat_info) - sender = chat_info["sender"] - chat_res = await sender.a_initiate_chat(**chat_info) - chat_res.chat_id = chat_id - finished_chats[chat_id] = chat_res - + pre_chat_futures = dict() + for pre_chat_id in prerequisite_chat_ids: + pre_chat_future = finished_chat_futures[pre_chat_id] + pre_chat_futures[pre_chat_id] = pre_chat_future + current_chat_future = await _dependent_chat_future(chat_id, chat_info, pre_chat_futures) + finished_chat_futures[chat_id] = current_chat_future + await asyncio.gather(*list(finished_chat_futures.values())) + finished_chats = dict() + for chat in finished_chat_futures: + chat_result = finished_chat_futures[chat].result() + finished_chats[chat] = chat_result return finished_chats diff --git a/autogen/agentchat/contrib/agent_builder.py b/autogen/agentchat/contrib/agent_builder.py index 7a3850d79ae..272d954ff27 100644 --- a/autogen/agentchat/contrib/agent_builder.py +++ b/autogen/agentchat/contrib/agent_builder.py @@ -1,10 +1,11 @@ -import autogen -import time -import subprocess as sp -import socket -import json import hashlib -from typing import Optional, List, Dict, Tuple +import json +import socket +import subprocess as sp +import time +from typing import Dict, List, Optional, Tuple + +import autogen def _config_check(config: Dict): @@ -202,9 +203,6 @@ def _create_agent( Returns: agent: a set-up agent. """ - from huggingface_hub import HfApi - from huggingface_hub.utils import GatedRepoError, RepositoryNotFoundError - config_list = autogen.config_list_from_json( self.config_file_or_env, file_location=self.config_file_location, @@ -217,10 +215,15 @@ def _create_agent( f"If you load configs from json, make sure the model in agent_configs is in the {self.config_file_or_env}." ) try: + from huggingface_hub import HfApi + from huggingface_hub.utils import GatedRepoError, RepositoryNotFoundError + hf_api = HfApi() hf_api.model_info(model_name_or_hf_repo) model_name = model_name_or_hf_repo.split("/")[-1] server_id = f"{model_name}_{self.host}" + except ImportError: + server_id = self.online_server_name except GatedRepoError as e: raise e except RepositoryNotFoundError: @@ -494,9 +497,6 @@ def build_from_library( agent_list: a list of agents. cached_configs: cached configs. """ - import chromadb - from chromadb.utils import embedding_functions - if code_execution_config is None: code_execution_config = { "last_n_messages": 2, @@ -527,6 +527,9 @@ def build_from_library( print("==> Looking for suitable agents in library...") if embedding_model is not None: + import chromadb + from chromadb.utils import embedding_functions + chroma_client = chromadb.Client() collection = chroma_client.create_collection( name="agent_list", diff --git a/autogen/agentchat/contrib/agent_optimizer.py b/autogen/agentchat/contrib/agent_optimizer.py new file mode 100644 index 00000000000..af264d4b65f --- /dev/null +++ b/autogen/agentchat/contrib/agent_optimizer.py @@ -0,0 +1,444 @@ +import copy +import json +from typing import Dict, List, Literal, Optional, Union + +import autogen +from autogen.code_utils import execute_code + +ADD_FUNC = { + "type": "function", + "function": { + "name": "add_function", + "description": "Add a function in the context of the conversation. Necessary Python packages must be declared. 
The name of the function MUST be the same with the function name in the code you generated.", + "parameters": { + "type": "object", + "properties": { + "name": {"type": "string", "description": "The name of the function in the code implementation."}, + "description": {"type": "string", "description": "A short description of the function."}, + "arguments": { + "type": "string", + "description": 'JSON schema of arguments encoded as a string. Please note that the JSON schema only supports specific types including string, integer, object, array, boolean. (do not have float type) For example: { "url": { "type": "string", "description": "The URL", }}. Please avoid the error \'array schema missing items\' when using array type.', + }, + "packages": { + "type": "string", + "description": "A list of package names imported by the function, and that need to be installed with pip prior to invoking the function. This solves ModuleNotFoundError. It should be string, not list.", + }, + "code": { + "type": "string", + "description": "The implementation in Python. Do not include the function declaration.", + }, + }, + "required": ["name", "description", "arguments", "packages", "code"], + }, + }, +} + +REVISE_FUNC = { + "type": "function", + "function": { + "name": "revise_function", + "description": "Revise a function in the context of the conversation. Necessary Python packages must be declared. The name of the function MUST be the same with the function name in the code you generated.", + "parameters": { + "type": "object", + "properties": { + "name": {"type": "string", "description": "The name of the function in the code implementation."}, + "description": {"type": "string", "description": "A short description of the function."}, + "arguments": { + "type": "string", + "description": 'JSON schema of arguments encoded as a string. Please note that the JSON schema only supports specific types including string, integer, object, array, boolean. (do not have float type) For example: { "url": { "type": "string", "description": "The URL", }}. Please avoid the error \'array schema missing items\' when using array type.', + }, + "packages": { + "type": "string", + "description": "A list of package names imported by the function, and that need to be installed with pip prior to invoking the function. This solves ModuleNotFoundError. It should be string, not list.", + }, + "code": { + "type": "string", + "description": "The implementation in Python. Do not include the function declaration.", + }, + }, + "required": ["name", "description", "arguments", "packages", "code"], + }, + }, +} + +REMOVE_FUNC = { + "type": "function", + "function": { + "name": "remove_function", + "description": "Remove one function in the context of the conversation. Once remove one function, the assistant will not use this function in future conversation.", + "parameters": { + "type": "object", + "properties": { + "name": {"type": "string", "description": "The name of the function in the code implementation."} + }, + "required": ["name"], + }, + }, +} + +OPT_PROMPT = """You are a function optimizer. Your task is to maintain a list of functions for the assistant according to the existing function list and conversation history that happens between the assistant and the user. +You can perform one of the following four actions to manipulate the function list using the functions you have: +1. Revise one existing function (using revise_function). +2. Remove one existing function (using remove_function). +3. 
Add one new function (using add_function). +4. Directly return "TERMINATE" to me if no more actions are needed for the current function list. + +Below are the principles that you need to follow for taking these four actions. +(1) Revise one existing function: +1. Pay more attention to the failed tasks and corresponding error information, and optimize the function used in these tasks according to the conversation history if needed. +2. A failed function call can occur due to incorrect input arguments (missing arguments) or an incorrect function code implementation. You should focus more on the function code implementation and make it easy to get success function call. +3. Do not revise the function that you think works well and plays a critical role in solving the problems according to the conversation history. Only making revisions if needed. +4. Sometimes, a NameError may occur. To fix this error, you can either revise the name of the function in the code implementation or revise the name of the function call to make these two names consistent. +(2) Remove one existing function: +1. Only remove the function that you think is not needed anymore in future tasks. +(3) Add one new function: +1. The added function should be general enough to be used in future tasks. For instance, if you encounter a problem that this function can solve, or one step of it, you can use the generated function directly instead of starting from scratch +2. The added new function should solve a higher-level question that encompasses the original query and extend the code's functionality to make it more versatile and widely applicable. +3. Replace specific strings or variable names with general variables to enhance the tool's applicability to various queries. All names used inside the function should be passed in as arguments. +Below is an example of a function that potentially deserves to be adde in solving MATH problems, which can be used to solve a higher-level question: +{{ + \"name\": \"evaluate_expression\", + \"description\": \"Evaluate arithmetic or mathematical expressions provided as strings.\", + \"arguments\": {{ + \"expression\": {{ + \"type\": \"string\", + \"description\": \"The mathematical expression to evaluate.\" + }} + }}, + \"packages\": \"sympy\", + \"code\": \"from sympy import sympify, SympifyError\\n\\ndef evaluate_expression(expression):\\n try:\\n result = sympify(expression)\\n if result.is_number:\\n result = float(result)\\n else:\\n result = str(result)\\n return result\\n except SympifyError as e:\\n return str(e)\" +}} +(4) Directly return "TERMINATE": +If you think there is no need to perform any other actions for the current function list since the current list is optimal more actions will harm the performance in future tasks. Please directly reply to me with "TERMINATE". + +One function signature includes the following five elements: +1. Function name +2. Function description +3. JSON schema of arguments encoded as a string +4. A list of package names imported by the function packages +5. The code implementation + +Below are the signatures of the current functions: +List A: {best_functions}. +The following list are the function signatures that you have after taking {actions_num} actions to manipulate List A: +List B: {incumbent_functions}. + +{accumulated_experience} + +Here are {best_conversations_num} conversation histories of solving {best_conversations_num} tasks using List A. 
+History: +{best_conversations_history} + +{statistic_informations} + +According to the information I provide, please take one of four actions to manipulate list B using the functions you know. +Instead of returning TERMINATE directly or taking no action, you should try your best to optimize the function list. Only take no action if you really think the current list is optimal, as more actions will harm performance in future tasks. +Even adding a general function that can substitute the assistant’s repeated suggestions of Python code with the same functionality could also be helpful. +""" + + +def execute_func(name, packages, code, **args): + """ + The wrapper for generated functions. + """ + pip_install = ( + f"""print("Installing package: {packages}")\nsubprocess.run(["pip", "-qq", "install", "{packages}"])""" + if packages + else "" + ) + str = f""" +import subprocess +{pip_install} +print("Result of {name} function execution:") +{code} +args={args} +result={name}(**args) +if result is not None: print(result) +""" + print(f"execute_code:\n{str}") + result = execute_code(str, use_docker="shaokun529/evoagent:v1") + if result[0] != 0: + raise Exception("Error in executing function:" + result[1]) + print(f"Result: {result[1]}") + return result[1] + + +class AgentOptimizer: + """ + Base class for optimizing AutoGen agents. Specifically, it is used to optimize the functions used in the agent. + More information could be found in the following paper: https://arxiv.org/abs/2402.11359. + """ + + def __init__( + self, + max_actions_per_step: int, + llm_config: dict, + optimizer_model: Optional[str] = "gpt-4-1106-preview", + ): + """ + (These APIs are experimental and may change in the future.) + Args: + max_actions_per_step (int): the maximum number of actions that the optimizer can take in one step. + llm_config (dict): llm inference configuration. + Please refer to [OpenAIWrapper.create](/docs/reference/oai/client#create) for available options. + When using OpenAI or Azure OpenAI endpoints, please specify a non-empty 'model' either in `llm_config` or in each config of 'config_list' in `llm_config`. + optimizer_model: the model used for the optimizer. + """ + self.max_actions_per_step = max_actions_per_step + self._max_trials = 3 + self.optimizer_model = optimizer_model + + self._trial_conversations_history = [] + self._trial_conversations_performance = [] + self._trial_functions = [] + + self._best_conversations_history = [] + self._best_conversations_performance = [] + self._best_functions = [] + + self._failure_functions_performance = [] + self._best_performance = -1 + + assert isinstance(llm_config, dict), "llm_config must be a dict" + llm_config = copy.deepcopy(llm_config) + self.llm_config = llm_config + if self.llm_config in [{}, {"config_list": []}, {"config_list": [{"model": ""}]}]: + raise ValueError( + "When using OpenAI or Azure OpenAI endpoints, specify a non-empty 'model' either in 'llm_config' or in each config of 'config_list'." + ) + self.llm_config["config_list"] = autogen.filter_config( + llm_config["config_list"], {"model": [self.optimizer_model]} + ) + self._client = autogen.OpenAIWrapper(**self.llm_config) + + def record_one_conversation(self, conversation_history: List[Dict], is_satisfied: bool = None): + """ + record one conversation history. + Args: + conversation_history (List[Dict]): the chat messages of the conversation. + is_satisfied (bool): whether the user is satisfied with the solution. If it is none, the user will be asked to input the satisfaction. 
+ """ + if is_satisfied is None: + reply = input( + "Please provide whether the user is satisfied with the solution. 1 represents satisfied. 0 represents not satisfied. Press enter to submit. \n" + ) + assert reply in [ + "0", + "1", + ], "The input is invalid. Please input 1 or 0. 1 represents satisfied. 0 represents not satisfied." + is_satisfied = True if reply == "1" else False + self._trial_conversations_history.append( + {"Conversation {i}".format(i=len(self._trial_conversations_history)): conversation_history} + ) + self._trial_conversations_performance.append( + {"Conversation {i}".format(i=len(self._trial_conversations_performance)): 1 if is_satisfied else 0} + ) + + def step(self): + """ + One step of training. It will return register_for_llm and register_for_executor at each iteration, + which are subsequently utilized to update the assistant and executor agents, respectively. + See example: https://github.com/microsoft/autogen/blob/main/notebook/agentchat_agentoptimizer.ipynb + """ + performance = sum(sum(d.values()) for d in self._trial_conversations_performance) / len( + self._trial_conversations_performance + ) + + if performance < self._best_performance: + self._failure_functions_performance.append({"functions": self._trial_functions, "performance": performance}) + self._failure_functions_performance = sorted( + self._failure_functions_performance, key=lambda x: x["performance"] + ) + else: + self._failure_functions_performance = [] + self._best_performance = performance + self._best_functions = copy.deepcopy(self._trial_functions) + self._best_conversations_history = copy.deepcopy(self._trial_conversations_history) + self._best_conversations_performance = copy.deepcopy(self._trial_conversations_performance) + self._trial_conversations_history = [] + self._trial_conversations_performance = [] + + best_functions = copy.deepcopy(self._best_functions) + incumbent_functions = copy.deepcopy(self._best_functions) + failure_experience_prompt, statistic_prompt = self._construct_intermediate_prompt() + + for action_index in range(self.max_actions_per_step): + prompt = OPT_PROMPT.format( + best_conversations_history=self._best_conversations_history, + best_conversations_num=len(self._best_conversations_history), + actions_num=action_index, + best_functions=best_functions, + incumbent_functions=incumbent_functions, + accumulated_experience=failure_experience_prompt, + statistic_informations=statistic_prompt, + ) + messages = [{"role": "user", "content": prompt}] + for _ in range(self._max_trials): + response = self._client.create( + messages=messages, tools=[ADD_FUNC, REVISE_FUNC, REMOVE_FUNC], tool_choice="auto" + ) + actions = response.choices[0].message.tool_calls + if self._validate_actions(actions, incumbent_functions): + break + if actions is not None and self._validate_actions(actions, incumbent_functions): + incumbent_functions = self._update_function_call(incumbent_functions, actions) + + remove_functions = list( + set([key for dictionary in self._trial_functions for key in dictionary.keys()]) + - set([key for dictionary in incumbent_functions for key in dictionary.keys()]) + ) + + register_for_llm = [] + register_for_exector = {} + for name in remove_functions: + register_for_llm.append({"func_sig": {"name": name}, "is_remove": True}) + register_for_exector.update({name: None}) + for func in incumbent_functions: + register_for_llm.append( + { + "func_sig": { + "name": func.get("name"), + "description": func.get("description"), + "parameters": {"type": "object", "properties": 
func.get("arguments")}, + }, + "is_remove": False, + } + ) + register_for_exector.update( + { + func.get("name"): lambda **args: execute_func( + func.get("name"), func.get("packages"), func.get("code"), **args + ) + } + ) + + self._trial_functions = incumbent_functions + return register_for_llm, register_for_exector + + def reset_optimizer(self): + """ + reset the optimizer. + """ + + self._trial_conversations_history = [] + self._trial_conversations_performance = [] + self._trial_functions = [] + + self._best_conversations_history = [] + self._best_conversations_performance = [] + self._best_functions = [] + + self._best_performance = -1 + self._failure_functions_performance = [] + + def _update_function_call(self, incumbent_functions, actions): + """ + update function call. + """ + + formated_actions = [] + for action in actions: + func = json.loads(action.function.arguments.strip('"')) + func["action_name"] = action.function.name + + if func.get("action_name") == "remove_function": + item = { + "action_name": func.get("action_name"), + "name": func.get("name"), + } + else: + item = { + "action_name": func.get("action_name"), + "name": func.get("name"), + "description": func.get("description"), + "arguments": json.loads(func.get("arguments").strip('"')), + "packages": func.get("packages"), + "code": func.get("code"), + } + formated_actions.append(item) + actions = formated_actions + + for action in actions: + name, description, arguments, packages, code, action_name = ( + action.get("name"), + action.get("description"), + action.get("arguments"), + action.get("packages"), + action.get("code"), + action.get("action_name"), + ) + if action_name == "remove_function": + incumbent_functions = [item for item in incumbent_functions if item["name"] != name] + else: + incumbent_functions = [item for item in incumbent_functions if item["name"] != name] + incumbent_functions.append( + { + "name": name, + "description": description, + "arguments": arguments, + "packages": packages, + "code": code, + } + ) + + return incumbent_functions + + def _construct_intermediate_prompt(self): + """ + construct intermediate prompts. + """ + if len(self._failure_functions_performance) != 0: + failure_experience_prompt = "We also provide more examples for different functions and their corresponding performance (0-100).\n The following function signatures are arranged in are arranged in ascending order based on their performance, where higher performance indicate better quality." + failure_experience_prompt += "\n" + for item in self._failure_functions_performance: + failure_experience_prompt += "Function: \n" + str(item["functions"]) + "\n" + failure_experience_prompt += "Performance: \n" + str(item["performance"]) + "\n" + else: + failure_experience_prompt = "\n" + + if len(self._best_conversations_performance) != 0: + statistic_prompt = "The following table shows the statistical information for solving each task in each conversation and indicates, whether the result is satisfied by the users. 1 represents satisfied. 0 represents not satisfied." + statistic_prompt += "\n" + for item in self._best_conversations_performance: + statistic_prompt += str(item) + "\n" + else: + statistic_prompt = "\n" + + return failure_experience_prompt, statistic_prompt + + def _validate_actions(self, actions, incumbent_functions): + """ + validate whether the proposed actions are feasible. 
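As a rough sketch of the training loop implied by `record_one_conversation` and `step` above (mirroring the pattern in the referenced notebook), the optimizer is fed finished conversations and the returned registrations are applied to the assistant and executor. The task string, placeholder config, and agent names below are illustrative assumptions, not part of this diff.

```python
import autogen
from autogen.agentchat.contrib.agent_optimizer import AgentOptimizer

llm_config = {"config_list": [{"model": "gpt-4-1106-preview", "api_key": "sk-..."}]}  # placeholder

assistant = autogen.AssistantAgent("assistant", llm_config=llm_config)
user_proxy = autogen.UserProxyAgent(
    "user_proxy",
    human_input_mode="NEVER",
    max_consecutive_auto_reply=3,
    code_execution_config={"work_dir": "coding", "use_docker": False},
)
optimizer = AgentOptimizer(max_actions_per_step=3, llm_config=llm_config)

for task in ["What is 3**7 + 11?"]:  # toy training task
    user_proxy.initiate_chat(assistant, message=task)
    optimizer.record_one_conversation(
        assistant.chat_messages_for_summary(user_proxy), is_satisfied=True
    )

# Apply the proposed function updates to the two agents.
register_for_llm, register_for_executor = optimizer.step()
for item in register_for_llm:
    assistant.update_function_signature(**item)
if register_for_executor:
    user_proxy.register_function(function_map=register_for_executor)
```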
+ """ + if actions is None: + return True + else: + # val json format + for action in actions: + function_args = action.function.arguments + try: + function_args = json.loads(function_args.strip('"')) + if "arguments" in function_args.keys(): + json.loads(function_args.get("arguments").strip('"')) + except Exception as e: + print("JSON is invalid:", e) + return False + # val syntax + for action in actions: + if action.function.name != "remove_function": + function_args = json.loads(action.function.arguments.strip('"')) + code = function_args.get("code") + try: + compile(code, "", "exec") + print("successfully compiled") + except Exception as e: + print("Syntax is invalid:", e) + return False + for action in actions: + action_name = action.function.name + if action_name == "remove_function": + function_args = json.loads(action.function.arguments.strip('"')) + if function_args.get("name") not in [item["name"] for item in incumbent_functions]: + print("The function you want to remove does not exist.") + return False + return True diff --git a/autogen/agentchat/contrib/capabilities/context_handling.py b/autogen/agentchat/contrib/capabilities/context_handling.py index 1510ae5fcd6..173811842eb 100644 --- a/autogen/agentchat/contrib/capabilities/context_handling.py +++ b/autogen/agentchat/contrib/capabilities/context_handling.py @@ -1,9 +1,18 @@ import sys -from termcolor import colored -from typing import Dict, Optional, List -from autogen import ConversableAgent -from autogen import token_count_utils +from typing import Dict, List, Optional +from warnings import warn + import tiktoken +from termcolor import colored + +from autogen import ConversableAgent, token_count_utils + +warn( + "Context handling with TransformChatHistory is deprecated. " + "Please use TransformMessages from autogen/agentchat/contrib/capabilities/transform_messages.py instead.", + DeprecationWarning, + stacklevel=2, +) class TransformChatHistory: @@ -26,7 +35,8 @@ class TransformChatHistory: 3. Third, it limits the total number of tokens in the chat history When adding this capability to an agent, the following are modified: - - A hook is added to the hookable method `process_all_messages_before_reply` to transform the received messages for possible truncation. + - A hook is added to the hookable method `process_all_messages_before_reply` to transform the + received messages for possible truncation. Not modifying the stored message history. """ diff --git a/autogen/agentchat/contrib/capabilities/generate_images.py b/autogen/agentchat/contrib/capabilities/generate_images.py new file mode 100644 index 00000000000..e4a8f1195c2 --- /dev/null +++ b/autogen/agentchat/contrib/capabilities/generate_images.py @@ -0,0 +1,291 @@ +import re +from typing import Any, Dict, List, Literal, Optional, Protocol, Tuple, Union + +from openai import OpenAI +from PIL.Image import Image + +from autogen import Agent, ConversableAgent, code_utils +from autogen.agentchat.contrib import img_utils +from autogen.agentchat.contrib.capabilities.agent_capability import AgentCapability +from autogen.agentchat.contrib.text_analyzer_agent import TextAnalyzerAgent +from autogen.cache import AbstractCache + +SYSTEM_MESSAGE = "You've been given the special ability to generate images." +DESCRIPTION_MESSAGE = "This agent has the ability to generate images." + +PROMPT_INSTRUCTIONS = """In detail, please summarize the provided prompt to generate the image described in the TEXT. +DO NOT include any advice. 
RESPOND like the following example: +EXAMPLE: Blue background, 3D shapes, ... +""" + + +class ImageGenerator(Protocol): + """This class defines an interface for image generators. + + Concrete implementations of this protocol must provide a `generate_image` method that takes a string prompt as + input and returns a PIL Image object. + + NOTE: Current implementation does not allow you to edit a previously existing image. + """ + + def generate_image(self, prompt: str) -> Image: + """Generates an image based on the provided prompt. + + Args: + prompt: A string describing the desired image. + + Returns: + A PIL Image object representing the generated image. + + Raises: + ValueError: If the image generation fails. + """ + ... + + def cache_key(self, prompt: str) -> str: + """Generates a unique cache key for the given prompt. + + This key can be used to store and retrieve generated images based on the prompt. + + Args: + prompt: A string describing the desired image. + + Returns: + A unique string that can be used as a cache key. + """ + ... + + +class DalleImageGenerator: + """Generates images using OpenAI's DALL-E models. + + This class provides a convenient interface for generating images based on textual prompts using OpenAI's DALL-E + models. It allows you to specify the DALL-E model, resolution, quality, and the number of images to generate. + + Note: Current implementation does not allow you to edit a previously existing image. + """ + + def __init__( + self, + llm_config: Dict, + resolution: Literal["256x256", "512x512", "1024x1024", "1792x1024", "1024x1792"] = "1024x1024", + quality: Literal["standard", "hd"] = "standard", + num_images: int = 1, + ): + """ + Args: + llm_config (dict): llm config, must contain a valid dalle model and OpenAI API key in config_list. + resolution (str): The resolution of the image you want to generate. Must be one of "256x256", "512x512", "1024x1024", "1792x1024", "1024x1792". + quality (str): The quality of the image you want to generate. Must be one of "standard", "hd". + num_images (int): The number of images to generate. + """ + config_list = llm_config["config_list"] + _validate_dalle_model(config_list[0]["model"]) + _validate_resolution_format(resolution) + + self._model = config_list[0]["model"] + self._resolution = resolution + self._quality = quality + self._num_images = num_images + self._dalle_client = OpenAI(api_key=config_list[0]["api_key"]) + + def generate_image(self, prompt: str) -> Image: + response = self._dalle_client.images.generate( + model=self._model, + prompt=prompt, + size=self._resolution, + quality=self._quality, + n=self._num_images, + ) + + image_url = response.data[0].url + if image_url is None: + raise ValueError("Failed to generate image.") + + return img_utils.get_pil_image(image_url) + + def cache_key(self, prompt: str) -> str: + keys = (prompt, self._model, self._resolution, self._quality, self._num_images) + return ",".join([str(k) for k in keys]) + + +class ImageGeneration(AgentCapability): + """This capability allows a ConversableAgent to generate images based on the message received from other Agents. + + 1. Utilizes a TextAnalyzerAgent to analyze incoming messages to identify requests for image generation and + extract relevant details. + 2. Leverages the provided ImageGenerator (e.g., DalleImageGenerator) to create the image. + 3. Optionally caches generated images for faster retrieval in future conversations. 
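Because `ImageGenerator` is a structural `Protocol`, any object exposing matching `generate_image` and `cache_key` methods can be passed to the capability, not just `DalleImageGenerator`. A toy, locally runnable sketch (class name and behavior invented here for illustration):

```python
from PIL import Image


class SolidColorImageGenerator:
    """Illustrative ImageGenerator: returns a solid-color canvas instead of calling a model."""

    def __init__(self, size: tuple = (256, 256)):
        self._size = size

    def generate_image(self, prompt: str) -> Image.Image:
        # A real implementation would call DALL-E, Stable Diffusion, etc.
        color = "skyblue" if "sky" in prompt.lower() else "lightgray"
        return Image.new("RGB", self._size, color=color)

    def cache_key(self, prompt: str) -> str:
        return ",".join([prompt, str(self._size)])
```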
+ + NOTE: This capability increases the token usage of the agent, as it uses TextAnalyzerAgent to analyze every + message received by the agent. + + Example: + ```python + import autogen + from autogen.agentchat.contrib.capabilities.image_generation import ImageGeneration + + # Assuming you have llm configs configured for the LLMs you want to use and Dalle. + # Create the agent + agent = autogen.ConversableAgent( + name="dalle", llm_config={...}, max_consecutive_auto_reply=3, human_input_mode="NEVER" + ) + + # Create an ImageGenerator with desired settings + dalle_gen = generate_images.DalleImageGenerator(llm_config={...}) + + # Add the ImageGeneration capability to the agent + agent.add_capability(ImageGeneration(image_generator=dalle_gen)) + ``` + """ + + def __init__( + self, + image_generator: ImageGenerator, + cache: Optional[AbstractCache] = None, + text_analyzer_llm_config: Optional[Dict] = None, + text_analyzer_instructions: str = PROMPT_INSTRUCTIONS, + verbosity: int = 0, + register_reply_position: int = 2, + ): + """ + Args: + image_generator (ImageGenerator): The image generator you would like to use to generate images. + cache (None or AbstractCache): The cache client to use to store and retrieve generated images. If None, + no caching will be used. + text_analyzer_llm_config (Dict or None): The LLM config for the text analyzer. If None, the LLM config will + be retrieved from the agent you're adding the ability to. + text_analyzer_instructions (str): Instructions provided to the TextAnalyzerAgent used to analyze + incoming messages and extract the prompt for image generation. The default instructions focus on + summarizing the prompt. You can customize the instructions to achieve more granular control over prompt + extraction. + Example: 'Extract specific details from the message, like desired objects, styles, or backgrounds.' + verbosity (int): The verbosity level. Defaults to 0 and must be greater than or equal to 0. The text + analyzer llm calls will be silent if verbosity is less than 2. + register_reply_position (int): The position of the reply function in the agent's list of reply functions. + This capability registers a new reply function to handle messages with image generation requests. + Defaults to 2 to place it after the check termination and human reply for a ConversableAgent. + """ + self._image_generator = image_generator + self._cache = cache + self._text_analyzer_llm_config = text_analyzer_llm_config + self._text_analyzer_instructions = text_analyzer_instructions + self._verbosity = verbosity + self._register_reply_position = register_reply_position + + self._agent: Optional[ConversableAgent] = None + self._text_analyzer: Optional[TextAnalyzerAgent] = None + + def add_to_agent(self, agent: ConversableAgent): + """Adds the Image Generation capability to the specified ConversableAgent. + + This function performs the following modifications to the agent: + + 1. Registers a reply function: A new reply function is registered with the agent to handle messages that + potentially request image generation. This function analyzes the message and triggers image generation if + necessary. + 2. Creates an Agent (TextAnalyzerAgent): This is used to analyze messages for image generation requirements. + 3. Updates System Message: The agent's system message is updated to include a message indicating the + capability to generate images has been added. + 4. Updates Description: The agent's description is updated to reflect the addition of the Image Generation + capability. 
This might be helpful in certain use cases, like group chats. + + Args: + agent (ConversableAgent): The ConversableAgent to add the capability to. + """ + self._agent = agent + + agent.register_reply([Agent, None], self._image_gen_reply, position=self._register_reply_position) + + self._text_analyzer_llm_config = self._text_analyzer_llm_config or agent.llm_config + self._text_analyzer = TextAnalyzerAgent(llm_config=self._text_analyzer_llm_config) + + agent.update_system_message(agent.system_message + "\n" + SYSTEM_MESSAGE) + agent.description += "\n" + DESCRIPTION_MESSAGE + + def _image_gen_reply( + self, + recipient: ConversableAgent, + messages: Optional[List[Dict]], + sender: Optional[Agent] = None, + config: Optional[Any] = None, + ) -> Tuple[bool, Union[str, Dict, None]]: + if messages is None: + return False, None + + last_message = code_utils.content_str(messages[-1]["content"]) + + if not last_message: + return False, None + + if self._should_generate_image(last_message): + prompt = self._extract_prompt(last_message) + + image = self._cache_get(prompt) + if image is None: + image = self._image_generator.generate_image(prompt) + self._cache_set(prompt, image) + + return True, self._generate_content_message(prompt, image) + + else: + return False, None + + def _should_generate_image(self, message: str) -> bool: + assert self._text_analyzer is not None + + instructions = """ + Does any part of the TEXT ask the agent to generate an image? + The TEXT must explicitly mention that the image must be generated. + Answer with just one word, yes or no. + """ + analysis = self._text_analyzer.analyze_text(message, instructions) + + return "yes" in self._extract_analysis(analysis).lower() + + def _extract_prompt(self, last_message) -> str: + assert self._text_analyzer is not None + + analysis = self._text_analyzer.analyze_text(last_message, self._text_analyzer_instructions) + return self._extract_analysis(analysis) + + def _cache_get(self, prompt: str) -> Optional[Image]: + if self._cache: + key = self._image_generator.cache_key(prompt) + cached_value = self._cache.get(key) + + if cached_value: + return img_utils.get_pil_image(cached_value) + + def _cache_set(self, prompt: str, image: Image): + if self._cache: + key = self._image_generator.cache_key(prompt) + self._cache.set(key, img_utils.pil_to_data_uri(image)) + + def _extract_analysis(self, analysis: Union[str, Dict, None]) -> str: + if isinstance(analysis, Dict): + return code_utils.content_str(analysis["content"]) + else: + return code_utils.content_str(analysis) + + def _generate_content_message(self, prompt: str, image: Image) -> Dict[str, Any]: + return { + "content": [ + {"type": "text", "text": f"I generated an image with the prompt: {prompt}"}, + {"type": "image_url", "image_url": {"url": img_utils.pil_to_data_uri(image)}}, + ] + } + + +### Helpers +def _validate_resolution_format(resolution: str): + """Checks if a string is in a valid resolution format (e.g., "1024x768").""" + pattern = r"^\d+x\d+$" # Matches a pattern of digits, "x", and digits + matched_resolution = re.match(pattern, resolution) + if matched_resolution is None: + raise ValueError(f"Invalid resolution format: {resolution}") + + +def _validate_dalle_model(model: str): + if model not in ["dall-e-3", "dall-e-2"]: + raise ValueError(f"Invalid DALL-E model: {model}. 
Must be 'dall-e-3' or 'dall-e-2'") diff --git a/autogen/agentchat/contrib/capabilities/teachability.py b/autogen/agentchat/contrib/capabilities/teachability.py index 9e18f99a345..596e449ce34 100644 --- a/autogen/agentchat/contrib/capabilities/teachability.py +++ b/autogen/agentchat/contrib/capabilities/teachability.py @@ -1,12 +1,15 @@ import os +import pickle from typing import Dict, Optional, Union + import chromadb from chromadb.config import Settings -import pickle + from autogen.agentchat.assistant_agent import ConversableAgent from autogen.agentchat.contrib.capabilities.agent_capability import AgentCapability from autogen.agentchat.contrib.text_analyzer_agent import TextAnalyzerAgent -from autogen.agentchat.conversable_agent import colored + +from ....formatting_utils import colored class Teachability(AgentCapability): @@ -83,7 +86,7 @@ def prepopulate_db(self): """Adds a few arbitrary memos to the DB.""" self.memo_store.prepopulate() - def process_last_received_message(self, text): + def process_last_received_message(self, text: Union[Dict, str]): """ Appends any relevant memos to the message text, and stores any apparent teachings in new memos. Uses TextAnalyzerAgent to make decisions about memo storage and retrieval. @@ -100,7 +103,7 @@ def process_last_received_message(self, text): # Return the (possibly) expanded message text. return expanded_text - def _consider_memo_storage(self, comment): + def _consider_memo_storage(self, comment: Union[Dict, str]): """Decides whether to store something from one user comment in the DB.""" memo_added = False @@ -158,7 +161,7 @@ def _consider_memo_storage(self, comment): # Yes. Save them to disk. self.memo_store._save_memos() - def _consider_memo_retrieval(self, comment): + def _consider_memo_retrieval(self, comment: Union[Dict, str]): """Decides whether to retrieve memos from the DB, and add them to the chat context.""" # First, use the comment directly as the lookup key. @@ -192,7 +195,7 @@ def _consider_memo_retrieval(self, comment): # Append the memos to the text of the last message. return comment + self._concatenate_memo_texts(memo_list) - def _retrieve_relevant_memos(self, input_text): + def _retrieve_relevant_memos(self, input_text: str) -> list: """Returns semantically related memos from the DB.""" memo_list = self.memo_store.get_related_memos( input_text, n_results=self.max_num_retrievals, threshold=self.recall_threshold @@ -210,7 +213,7 @@ def _retrieve_relevant_memos(self, input_text): memo_list = [memo[1] for memo in memo_list] return memo_list - def _concatenate_memo_texts(self, memo_list): + def _concatenate_memo_texts(self, memo_list: list) -> str: """Concatenates the memo texts into a single string for inclusion in the chat context.""" memo_texts = "" if len(memo_list) > 0: @@ -222,7 +225,7 @@ def _concatenate_memo_texts(self, memo_list): memo_texts = memo_texts + "\n" + info return memo_texts - def _analyze(self, text_to_analyze, analysis_instructions): + def _analyze(self, text_to_analyze: Union[Dict, str], analysis_instructions: Union[Dict, str]): """Asks TextAnalyzerAgent to analyze the given text according to specific instructions.""" self.analyzer.reset() # Clear the analyzer's list of messages. self.teachable_agent.send( @@ -243,10 +246,16 @@ class MemoStore: Vector embeddings are currently supplied by Chroma's default Sentence Transformers. 
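For context on how `Teachability` and its `MemoStore` are typically wired up, a brief usage sketch follows; the placeholder `llm_config` and agent name are assumptions, and the constructor arguments mirror the `MemoStore` defaults annotated in this diff.

```python
import autogen
from autogen.agentchat.contrib.capabilities.teachability import Teachability

llm_config = {"config_list": [{"model": "gpt-4", "api_key": "sk-..."}]}  # placeholder

teachable_agent = autogen.ConversableAgent("teachable_agent", llm_config=llm_config)

# reset_db / path_to_db_dir correspond to the MemoStore defaults shown above.
teachability = Teachability(verbosity=0, reset_db=False, path_to_db_dir="./tmp/teachable_agent_db")
teachability.add_to_agent(teachable_agent)
```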
""" - def __init__(self, verbosity, reset, path_to_db_dir): + def __init__( + self, + verbosity: Optional[int] = 0, + reset: Optional[bool] = False, + path_to_db_dir: Optional[str] = "./tmp/teachable_agent_db", + ): """ Args: - verbosity (Optional, int): 1 to print memory operations, 0 to omit them. 3+ to print memo lists. + - reset (Optional, bool): True to clear the DB before starting. Default False. - path_to_db_dir (Optional, str): path to the directory where the DB is stored. """ self.verbosity = verbosity @@ -301,7 +310,7 @@ def reset_db(self): self.uid_text_dict = {} self._save_memos() - def add_input_output_pair(self, input_text, output_text): + def add_input_output_pair(self, input_text: str, output_text: str): """Adds an input-output pair to the vector DB.""" self.last_memo_id += 1 self.vec_db.add(documents=[input_text], ids=[str(self.last_memo_id)]) @@ -318,7 +327,7 @@ def add_input_output_pair(self, input_text, output_text): if self.verbosity >= 3: self.list_memos() - def get_nearest_memo(self, query_text): + def get_nearest_memo(self, query_text: str): """Retrieves the nearest memo to the given query text.""" results = self.vec_db.query(query_texts=[query_text], n_results=1) uid, input_text, distance = results["ids"][0][0], results["documents"][0][0], results["distances"][0][0] @@ -335,7 +344,7 @@ def get_nearest_memo(self, query_text): ) return input_text, output_text, distance - def get_related_memos(self, query_text, n_results, threshold): + def get_related_memos(self, query_text: str, n_results: int, threshold: Union[int, float]): """Retrieves memos that are related to the given query text within the specified distance threshold.""" if n_results > len(self.uid_text_dict): n_results = len(self.uid_text_dict) diff --git a/autogen/agentchat/contrib/capabilities/transform_messages.py b/autogen/agentchat/contrib/capabilities/transform_messages.py new file mode 100644 index 00000000000..e96dc39fa7b --- /dev/null +++ b/autogen/agentchat/contrib/capabilities/transform_messages.py @@ -0,0 +1,87 @@ +import copy +from typing import Dict, List + +from autogen import ConversableAgent + +from ....formatting_utils import colored +from .transforms import MessageTransform + + +class TransformMessages: + """Agent capability for transforming messages before reply generation. + + This capability allows you to apply a series of message transformations to + a ConversableAgent's incoming messages before they are processed for response + generation. This is useful for tasks such as: + + - Limiting the number of messages considered for context. + - Truncating messages to meet token limits. + - Filtering sensitive information. + - Customizing message formatting. + + To use `TransformMessages`: + + 1. Create message transformations (e.g., `MessageHistoryLimiter`, `MessageTokenLimiter`). + 2. Instantiate `TransformMessages` with a list of these transformations. + 3. Add the `TransformMessages` instance to your `ConversableAgent` using `add_to_agent`. + + NOTE: Order of message transformations is important. You could get different results based on + the order of transformations. + + Example: + ```python + from agentchat import ConversableAgent + from agentchat.contrib.capabilities import TransformMessages, MessageHistoryLimiter, MessageTokenLimiter + + max_messages = MessageHistoryLimiter(max_messages=2) + truncate_messages = MessageTokenLimiter(max_tokens=500) + transform_messages = TransformMessages(transforms=[max_messages, truncate_messages]) + + agent = ConversableAgent(...) 
+ transform_messages.add_to_agent(agent) + ``` + """ + + def __init__(self, *, transforms: List[MessageTransform] = [], verbose: bool = True): + """ + Args: + transforms: A list of message transformations to apply. + verbose: Whether to print logs of each transformation or not. + """ + self._transforms = transforms + self._verbose = verbose + + def add_to_agent(self, agent: ConversableAgent): + """Adds the message transformations capability to the specified ConversableAgent. + + This function performs the following modifications to the agent: + + 1. Registers a hook that automatically transforms all messages before they are processed for + response generation. + """ + agent.register_hook(hookable_method="process_all_messages_before_reply", hook=self._transform_messages) + + def _transform_messages(self, messages: List[Dict]) -> List[Dict]: + post_transform_messages = copy.deepcopy(messages) + system_message = None + + if messages[0]["role"] == "system": + system_message = copy.deepcopy(messages[0]) + post_transform_messages.pop(0) + + for transform in self._transforms: + # deepcopy in case pre_transform_messages will later be used for logs printing + pre_transform_messages = ( + copy.deepcopy(post_transform_messages) if self._verbose else post_transform_messages + ) + post_transform_messages = transform.apply_transform(pre_transform_messages) + + if self._verbose: + logs_str, had_effect = transform.get_logs(pre_transform_messages, post_transform_messages) + if had_effect: + print(colored(logs_str, "yellow")) + + if system_message: + post_transform_messages.insert(0, system_message) + + return post_transform_messages diff --git a/autogen/agentchat/contrib/capabilities/transforms.py b/autogen/agentchat/contrib/capabilities/transforms.py new file mode 100644 index 00000000000..6dc1d59fe9c --- /dev/null +++ b/autogen/agentchat/contrib/capabilities/transforms.py @@ -0,0 +1,255 @@ +import copy +import sys +from typing import Any, Dict, List, Optional, Protocol, Tuple, Union + +import tiktoken +from termcolor import colored + +from autogen import token_count_utils + + +class MessageTransform(Protocol): + """Defines a contract for message transformation. + + Classes implementing this protocol should provide an `apply_transform` method + that takes a list of messages and returns the transformed list. + """ + + def apply_transform(self, messages: List[Dict]) -> List[Dict]: + """Applies a transformation to a list of messages. + + Args: + messages: A list of dictionaries representing messages. + + Returns: + A new list of dictionaries containing the transformed messages. + """ + ... + + def get_logs(self, pre_transform_messages: List[Dict], post_transform_messages: List[Dict]) -> Tuple[str, bool]: + """Creates the string including the logs of the transformation + + Alongside the string, it returns a boolean indicating whether the transformation had an effect or not. + + Args: + pre_transform_messages: A list of dictionaries representing messages before the transformation. + post_transform_messages: A list of dictionaries representig messages after the transformation. + + Returns: + A tuple with a string with the logs and a flag indicating whether the transformation had an effect or not. + """ + ... + + +class MessageHistoryLimiter: + """Limits the number of messages considered by an agent for response generation. + + This transform keeps only the most recent messages up to the specified maximum number of messages (max_messages). 
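A transform can also be exercised on its own, outside of `TransformMessages`. A small illustrative snippet (not part of the diff) exercising `MessageHistoryLimiter.apply_transform` directly:

```python
from autogen.agentchat.contrib.capabilities.transforms import MessageHistoryLimiter

limiter = MessageHistoryLimiter(max_messages=2)

messages = [
    {"role": "user", "content": "hello"},
    {"role": "assistant", "content": "Hi! How can I help?"},
    {"role": "user", "content": "Summarize our conversation so far."},
]

# Keeps only the two most recent messages.
print(limiter.apply_transform(messages))
```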
+ It trims the conversation history by removing older messages, retaining only the most recent messages. + """ + + def __init__(self, max_messages: Optional[int] = None): + """ + Args: + max_messages (None or int): Maximum number of messages to keep in the context. + Must be greater than 0 if not None. + """ + self._validate_max_messages(max_messages) + self._max_messages = max_messages + + def apply_transform(self, messages: List[Dict]) -> List[Dict]: + """Truncates the conversation history to the specified maximum number of messages. + + This method returns a new list containing the most recent messages up to the specified + maximum number of messages (max_messages). If max_messages is None, it returns the + original list of messages unmodified. + + Args: + messages (List[Dict]): The list of messages representing the conversation history. + + Returns: + List[Dict]: A new list containing the most recent messages up to the specified maximum. + """ + if self._max_messages is None: + return messages + + return messages[-self._max_messages :] + + def get_logs(self, pre_transform_messages: List[Dict], post_transform_messages: List[Dict]) -> Tuple[str, bool]: + pre_transform_messages_len = len(pre_transform_messages) + post_transform_messages_len = len(post_transform_messages) + + if post_transform_messages_len < pre_transform_messages_len: + logs_str = ( + f"Removed {pre_transform_messages_len - post_transform_messages_len} messages. " + f"Number of messages reduced from {pre_transform_messages_len} to {post_transform_messages_len}." + ) + return logs_str, True + return "No messages were removed.", False + + def _validate_max_messages(self, max_messages: Optional[int]): + if max_messages is not None and max_messages < 1: + raise ValueError("max_messages must be None or greater than 1") + + +class MessageTokenLimiter: + """Truncates messages to meet token limits for efficient processing and response generation. + + This transformation applies two levels of truncation to the conversation history: + + 1. Truncates each individual message to the maximum number of tokens specified by max_tokens_per_message. + 2. Truncates the overall conversation history to the maximum number of tokens specified by max_tokens. + + NOTE: Tokens are counted using the encoder for the specified model. Different models may yield different token + counts for the same text. + + NOTE: For multimodal LLMs, the token count may be inaccurate as it does not account for the non-text input + (e.g images). + + The truncation process follows these steps in order: + + 1. Messages are processed in reverse order (newest to oldest). + 2. Individual messages are truncated based on max_tokens_per_message. For multimodal messages containing both text + and other types of content, only the text content is truncated. + 3. The overall conversation history is truncated based on the max_tokens limit. Once the accumulated token count + exceeds this limit, the current message being processed get truncated to meet the total token count and any + remaining messages get discarded. + 4. The truncated conversation history is reconstructed by prepending the messages to a new list to preserve the + original message order. + """ + + def __init__( + self, + max_tokens_per_message: Optional[int] = None, + max_tokens: Optional[int] = None, + model: str = "gpt-3.5-turbo-0613", + ): + """ + Args: + max_tokens_per_message (None or int): Maximum number of tokens to keep in each message. + Must be greater than or equal to 0 if not None. 
+ max_tokens (Optional[int]): Maximum number of tokens to keep in the chat history. + Must be greater than or equal to 0 if not None. + model (str): The target OpenAI model for tokenization alignment. + """ + self._model = model + self._max_tokens_per_message = self._validate_max_tokens(max_tokens_per_message) + self._max_tokens = self._validate_max_tokens(max_tokens) + + def apply_transform(self, messages: List[Dict]) -> List[Dict]: + """Applies token truncation to the conversation history. + + Args: + messages (List[Dict]): The list of messages representing the conversation history. + + Returns: + List[Dict]: A new list containing the truncated messages up to the specified token limits. + """ + assert self._max_tokens_per_message is not None + assert self._max_tokens is not None + + temp_messages = copy.deepcopy(messages) + processed_messages = [] + processed_messages_tokens = 0 + + for msg in reversed(temp_messages): + # Some messages may not have content. + if not isinstance(msg.get("content"), (str, list)): + processed_messages.insert(0, msg) + continue + + expected_tokens_remained = self._max_tokens - processed_messages_tokens - self._max_tokens_per_message + + # If adding this message would exceed the token limit, truncate the last message to meet the total token + # limit and discard all remaining messages + if expected_tokens_remained < 0: + msg["content"] = self._truncate_str_to_tokens( + msg["content"], self._max_tokens - processed_messages_tokens + ) + processed_messages.insert(0, msg) + break + + msg["content"] = self._truncate_str_to_tokens(msg["content"], self._max_tokens_per_message) + msg_tokens = _count_tokens(msg["content"]) + + # prepend the message to the list to preserve order + processed_messages_tokens += msg_tokens + processed_messages.insert(0, msg) + + return processed_messages + + def get_logs(self, pre_transform_messages: List[Dict], post_transform_messages: List[Dict]) -> Tuple[str, bool]: + pre_transform_messages_tokens = sum( + _count_tokens(msg["content"]) for msg in pre_transform_messages if "content" in msg + ) + post_transform_messages_tokens = sum( + _count_tokens(msg["content"]) for msg in post_transform_messages if "content" in msg + ) + + if post_transform_messages_tokens < pre_transform_messages_tokens: + logs_str = ( + f"Truncated {pre_transform_messages_tokens - post_transform_messages_tokens} tokens. " + f"Number of tokens reduced from {pre_transform_messages_tokens} to {post_transform_messages_tokens}" + ) + return logs_str, True + return "No tokens were truncated.", False + + def _truncate_str_to_tokens(self, contents: Union[str, List], n_tokens: int) -> Union[str, List]: + if isinstance(contents, str): + return self._truncate_tokens(contents, n_tokens) + elif isinstance(contents, list): + return self._truncate_multimodal_text(contents, n_tokens) + else: + raise ValueError(f"Contents must be a string or a list of dictionaries. 
Received type: {type(contents)}") + + def _truncate_multimodal_text(self, contents: List[Dict[str, Any]], n_tokens: int) -> List[Dict[str, Any]]: + """Truncates text content within a list of multimodal elements, preserving the overall structure.""" + tmp_contents = [] + for content in contents: + if content["type"] == "text": + truncated_text = self._truncate_tokens(content["text"], n_tokens) + tmp_contents.append({"type": "text", "text": truncated_text}) + else: + tmp_contents.append(content) + return tmp_contents + + def _truncate_tokens(self, text: str, n_tokens: int) -> str: + encoding = tiktoken.encoding_for_model(self._model) # Get the appropriate tokenizer + + encoded_tokens = encoding.encode(text) + truncated_tokens = encoded_tokens[:n_tokens] + truncated_text = encoding.decode(truncated_tokens) # Decode back to text + + return truncated_text + + def _validate_max_tokens(self, max_tokens: Optional[int] = None) -> Optional[int]: + if max_tokens is not None and max_tokens < 0: + raise ValueError("max_tokens and max_tokens_per_message must be None or greater than or equal to 0") + + try: + allowed_tokens = token_count_utils.get_max_token_limit(self._model) + except Exception: + print(colored(f"Model {self._model} not found in token_count_utils.", "yellow")) + allowed_tokens = None + + if max_tokens is not None and allowed_tokens is not None: + if max_tokens > allowed_tokens: + print( + colored( + f"Max token was set to {max_tokens}, but {self._model} can only accept {allowed_tokens} tokens. Capping it to {allowed_tokens}.", + "yellow", + ) + ) + return allowed_tokens + + return max_tokens if max_tokens is not None else sys.maxsize + + +def _count_tokens(content: Union[str, List[Dict[str, Any]]]) -> int: + token_count = 0 + if isinstance(content, str): + token_count = token_count_utils.count_token(content) + elif isinstance(content, list): + for item in content: + token_count += _count_tokens(item.get("text", "")) + return token_count diff --git a/autogen/agentchat/contrib/capabilities/vision_capability.py b/autogen/agentchat/contrib/capabilities/vision_capability.py new file mode 100644 index 00000000000..acfb9c8f6d8 --- /dev/null +++ b/autogen/agentchat/contrib/capabilities/vision_capability.py @@ -0,0 +1,211 @@ +import copy +from typing import Callable, Dict, List, Optional, Union + +from autogen.agentchat.assistant_agent import ConversableAgent +from autogen.agentchat.contrib.capabilities.agent_capability import AgentCapability +from autogen.agentchat.contrib.img_utils import ( + convert_base64_to_data_uri, + get_image_data, + get_pil_image, + gpt4v_formatter, + message_formatter_pil_to_b64, +) +from autogen.agentchat.contrib.multimodal_conversable_agent import MultimodalConversableAgent +from autogen.agentchat.conversable_agent import colored +from autogen.code_utils import content_str +from autogen.oai.client import OpenAIWrapper + +DEFAULT_DESCRIPTION_PROMPT = ( + "Write a detailed caption for this image. " + "Pay special attention to any details that might be useful or relevant " + "to the ongoing conversation." +) + + +class VisionCapability(AgentCapability): + """We can add vision capability to regular ConversableAgent, even if the agent does not have the multimodal capability, + such as GPT-3.5-turbo agent, Llama, Orca, or Mistral agents. This vision capability will invoke a LMM client to describe + the image (captioning) before sending the information to the agent's actual client. + + The vision capability will hook to the ConversableAgent's `process_last_received_message`. 
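A minimal sketch of attaching this capability with a `custom_caption_func`, which sidesteps the need for an LMM client entirely; the stub captioner and placeholder config are assumptions for illustration only.

```python
from autogen import ConversableAgent
from autogen.agentchat.contrib.capabilities.vision_capability import VisionCapability


def caption_stub(image_url, image_data, lmm_client) -> str:
    # Placeholder captioner; a real one would query an LMM or a local vision model.
    return f"a placeholder caption for the image at {image_url}"


llm_config = {"config_list": [{"model": "gpt-3.5-turbo", "api_key": "sk-..."}]}  # placeholder
assistant = ConversableAgent("assistant", llm_config=llm_config, human_input_mode="NEVER")

# Either a valid lmm_config or a custom_caption_func must be supplied; here the stub is used.
vision = VisionCapability(lmm_config=None, custom_caption_func=caption_stub)
vision.add_to_agent(assistant)
```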
+ + Some technical details: + When the agent (who has the vision capability) received an message, it will: + 1. _process_received_message: + a. _append_oai_message + 2. generate_reply: if the agent is a MultimodalAgent, it will also use the image tag. + a. hook process_last_received_message (NOTE: this is where the vision capability will be hooked to.) + b. hook process_all_messages_before_reply + 3. send: + a. hook process_message_before_send + b. _append_oai_message + """ + + def __init__( + self, + lmm_config: Dict, + description_prompt: Optional[str] = DEFAULT_DESCRIPTION_PROMPT, + custom_caption_func: Callable = None, + ) -> None: + """ + Initializes a new instance, setting up the configuration for interacting with + a Language Multimodal (LMM) client and specifying optional parameters for image + description and captioning. + + Args: + lmm_config (Dict): Configuration for the LMM client, which is used to call + the LMM service for describing the image. This must be a dictionary containing + the necessary configuration parameters. If `lmm_config` is False or an empty dictionary, + it is considered invalid, and initialization will assert. + description_prompt (Optional[str], optional): The prompt to use for generating + descriptions of the image. This parameter allows customization of the + prompt passed to the LMM service. Defaults to `DEFAULT_DESCRIPTION_PROMPT` if not provided. + custom_caption_func (Callable, optional): A callable that, if provided, will be used + to generate captions for images. This allows for custom captioning logic outside + of the standard LMM service interaction. + The callable should take three parameters as input: + 1. an image URL (or local location) + 2. image_data (a PIL image) + 3. lmm_client (to call remote LMM) + and then return a description (as string). + If not provided, captioning will rely on the LMM client configured via `lmm_config`. + If provided, we will not run the default self._get_image_caption method. + + Raises: + AssertionError: If neither a valid `lmm_config` nor a `custom_caption_func` is provided, + an AssertionError is raised to indicate that the Vision Capability requires + one of these to be valid for operation. + """ + self._lmm_config = lmm_config + self._description_prompt = description_prompt + self._parent_agent = None + + if lmm_config: + self._lmm_client = OpenAIWrapper(**lmm_config) + else: + self._lmm_client = None + + self._custom_caption_func = custom_caption_func + assert ( + self._lmm_config or custom_caption_func + ), "Vision Capability requires a valid lmm_config or custom_caption_func." + + def add_to_agent(self, agent: ConversableAgent) -> None: + self._parent_agent = agent + + # Append extra info to the system message. + agent.update_system_message(agent.system_message + "\nYou've been given the ability to interpret images.") + + # Register a hook for processing the last message. + agent.register_hook(hookable_method="process_last_received_message", hook=self.process_last_received_message) + + def process_last_received_message(self, content: Union[str, List[dict]]) -> str: + """ + Processes the last received message content by normalizing and augmenting it + with descriptions of any included images. The function supports input content + as either a string or a list of dictionaries, where each dictionary represents + a content item (e.g., text, image). If the content contains image URLs, it + fetches the image data, generates a caption for each image, and inserts the + caption into the augmented content. 
+ + The function aims to transform the content into a format compatible with GPT-4V + multimodal inputs, specifically by formatting strings into PIL-compatible + images if needed and appending text descriptions for images. This allows for + a more accessible presentation of the content, especially in contexts where + images cannot be displayed directly. + + Args: + content (Union[str, List[dict]]): The last received message content, which + can be a plain text string or a list of dictionaries representing + different types of content items (e.g., text, image_url). + + Returns: + str: The augmented message content + + Raises: + AssertionError: If an item in the content list is not a dictionary. + + Examples: + Assuming `self._get_image_caption(img_data)` returns + "A beautiful sunset over the mountains" for the image. + + - Input as String: + content = "Check out this cool photo!" + Output: "Check out this cool photo!" + (Content is a string without an image, remains unchanged.) + + - Input as String, with image location: + content = "What's weather in this cool photo: " + Output: "What's weather in this cool photo: in case you can not see, the caption of this image is: + A beautiful sunset over the mountains\n" + (Caption added after the image) + + - Input as List with Text Only: + content = [{"type": "text", "text": "Here's an interesting fact."}] + Output: "Here's an interesting fact." + (No images in the content, it remains unchanged.) + + - Input as List with Image URL: + content = [ + {"type": "text", "text": "What's weather in this cool photo:"}, + {"type": "image_url", "image_url": {"url": "http://example.com/photo.jpg"}} + ] + Output: "What's weather in this cool photo: in case you can not see, the caption of this image is: + A beautiful sunset over the mountains\n" + (Caption added after the image) + """ + copy.deepcopy(content) + # normalize the content into the gpt-4v format for multimodal + # we want to keep the URL format to keep it concise. + if isinstance(content, str): + content = gpt4v_formatter(content, img_format="url") + + aug_content: str = "" + for item in content: + assert isinstance(item, dict) + if item["type"] == "text": + aug_content += item["text"] + elif item["type"] == "image_url": + img_url = item["image_url"]["url"] + img_caption = "" + + if self._custom_caption_func: + img_caption = self._custom_caption_func(img_url, get_pil_image(img_url), self._lmm_client) + elif self._lmm_client: + img_data = get_image_data(img_url) + img_caption = self._get_image_caption(img_data) + else: + img_caption = "" + + aug_content += f" in case you can not see, the caption of this image is: {img_caption}\n" + else: + print(f"Warning: the input type should either be `test` or `image_url`. Skip {item['type']} here.") + + return aug_content + + def _get_image_caption(self, img_data: str) -> str: + """ + Args: + img_data (str): base64 encoded image data. + Returns: + str: caption for the given image. 
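For reference, the normalized structure that the loop above walks is the standard GPT-4V-style content list. A hand-built equivalent of the documented example, with a placeholder URL and caption:

```python
# Roughly what gpt4v_formatter(..., img_format="url") produces for
# "What's the weather in this cool photo: <img http://example.com/photo.jpg>"
content = [
    {"type": "text", "text": "What's the weather in this cool photo: "},
    {"type": "image_url", "image_url": {"url": "http://example.com/photo.jpg"}},  # placeholder URL
]

caption = "A beautiful sunset over the mountains"  # whatever the LMM or custom function returned
aug_content = ""
for item in content:
    if item["type"] == "text":
        aug_content += item["text"]
    elif item["type"] == "image_url":
        aug_content += f" in case you can not see, the caption of this image is: {caption}\n"

print(aug_content)
```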
+ """ + response = self._lmm_client.create( + context=None, + messages=[ + { + "role": "user", + "content": [ + {"type": "text", "text": self._description_prompt}, + { + "type": "image_url", + "image_url": { + "url": convert_base64_to_data_uri(img_data), + }, + }, + ], + } + ], + ) + description = response.choices[0].message.content + return content_str(description) diff --git a/autogen/agentchat/contrib/compressible_agent.py b/autogen/agentchat/contrib/compressible_agent.py index e196773effc..9c4e78af852 100644 --- a/autogen/agentchat/contrib/compressible_agent.py +++ b/autogen/agentchat/contrib/compressible_agent.py @@ -1,26 +1,28 @@ -from typing import Callable, Dict, Optional, Union, Tuple, List, Any -from autogen import OpenAIWrapper -from autogen import Agent, ConversableAgent -import copy import asyncio -import logging +import copy import inspect -from autogen.token_count_utils import count_token, get_max_token_limit, num_tokens_from_functions - -try: - from termcolor import colored -except ImportError: +import logging +from typing import Any, Callable, Dict, List, Optional, Tuple, Union +from warnings import warn - def colored(x, *args, **kwargs): - return x +from autogen import Agent, ConversableAgent, OpenAIWrapper +from autogen.token_count_utils import count_token, get_max_token_limit, num_tokens_from_functions +from ...formatting_utils import colored logger = logging.getLogger(__name__) +warn( + "Context handling with CompressibleAgent is deprecated. " + "Please use `TransformMessages`, documentation can be found at https://microsoft.github.io/autogen/docs/reference/agentchat/contrib/capabilities/transform_messages", + DeprecationWarning, + stacklevel=2, +) + class CompressibleAgent(ConversableAgent): - """(Experimental) CompressibleAgent agent. While this agent retains all the default functionalities of the `AssistantAgent`, - it also provides the added feature of compression when activated through the `compress_config` setting. + """CompressibleAgent agent. While this agent retains all the default functionalities of the `AssistantAgent`, + it also provides the added feature of compression when activated through the `compress_config` setting. `compress_config` is set to False by default, making this agent equivalent to the `AssistantAgent`. This agent does not work well in a GroupChat: The compressed messages will not be sent to all the agents in the group. @@ -73,6 +75,7 @@ def __init__( system_message (str): system message for the ChatCompletion inference. Please override this attribute if you want to reprogram the agent. llm_config (dict): llm inference configuration. + Note: you must set `model` in llm_config. It will be used to compute the token count. Please refer to [OpenAIWrapper.create](/docs/reference/oai/client#create) for available options. 
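Since the deprecation warning above points users to `TransformMessages`, a rough sketch of the suggested replacement follows. The limiter classes and numbers are assumptions based on the capabilities module referenced in the warning; check the linked documentation for the exact transforms available in your version:

```python
from autogen import ConversableAgent
from autogen.agentchat.contrib.capabilities import transform_messages, transforms

assistant = ConversableAgent(name="assistant", llm_config=False)  # placeholder agent

# Keep only recent history and cap the token budget; the numbers here are arbitrary.
context_handling = transform_messages.TransformMessages(
    transforms=[
        transforms.MessageHistoryLimiter(max_messages=10),
        transforms.MessageTokenLimiter(max_tokens=4000, max_tokens_per_message=1000),
    ]
)
context_handling.add_to_agent(assistant)
```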
is_termination_msg (function): a function that takes a message in the form of a dictionary @@ -121,6 +124,8 @@ def __init__( self.llm_compress_config = False self.compress_client = None else: + if "model" not in llm_config: + raise ValueError("llm_config must contain the 'model' field.") self.llm_compress_config = self.llm_config.copy() # remove functions if "functions" in self.llm_compress_config: diff --git a/autogen/agentchat/contrib/gpt_assistant_agent.py b/autogen/agentchat/contrib/gpt_assistant_agent.py index c4c73ed2c51..253d4d18e2e 100644 --- a/autogen/agentchat/contrib/gpt_assistant_agent.py +++ b/autogen/agentchat/contrib/gpt_assistant_agent.py @@ -1,16 +1,16 @@ -from collections import defaultdict -import openai +import copy import json -import time import logging -import copy +import time +from collections import defaultdict +from typing import Any, Dict, List, Optional, Tuple, Union + +import openai from autogen import OpenAIWrapper -from autogen.oai.openai_utils import retrieve_assistants_by_name from autogen.agentchat.agent import Agent -from autogen.agentchat.assistant_agent import ConversableAgent -from autogen.agentchat.assistant_agent import AssistantAgent -from typing import Dict, Optional, Union, List, Tuple, Any +from autogen.agentchat.assistant_agent import AssistantAgent, ConversableAgent +from autogen.oai.openai_utils import retrieve_assistants_by_name logger = logging.getLogger(__name__) @@ -28,6 +28,7 @@ def __init__( name="GPT Assistant", instructions: Optional[str] = None, llm_config: Optional[Union[Dict, bool]] = None, + assistant_config: Optional[Dict] = None, overwrite_instructions: bool = False, overwrite_tools: bool = False, **kwargs, @@ -43,8 +44,9 @@ def __init__( AssistantAgent.DEFAULT_SYSTEM_MESSAGE. If the assistant exists, the system message will be set to the existing assistant instructions. llm_config (dict or False): llm inference configuration. - - assistant_id: ID of the assistant to use. If None, a new assistant will be created. - model: Model to use for the assistant (gpt-4-1106-preview, gpt-3.5-turbo-1106). + assistant_config + - assistant_id: ID of the assistant to use. If None, a new assistant will be created. - check_every_ms: check thread run status interval - tools: Give Assistants access to OpenAI-hosted tools like Code Interpreter and Knowledge Retrieval, or build your own tools using Function calling. 
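With the new `assistant_config` parameter, assistant-level settings move out of `llm_config`. A sketch of the intended call shape; the model name, API key, and tool list are placeholders:

```python
from autogen.agentchat.contrib.gpt_assistant_agent import GPTAssistantAgent

gpt_assistant = GPTAssistantAgent(
    name="Coder Assistant",
    instructions="You are a helpful coding assistant.",
    llm_config={  # client-level settings only
        "config_list": [{"model": "gpt-4-1106-preview", "api_key": "sk-..."}],  # placeholder credentials
    },
    assistant_config={  # assistant-level settings
        "assistant_id": None,  # None: create a new assistant, or reuse one matched by name
        "tools": [{"type": "code_interpreter"}],
        "file_ids": [],
    },
)
```

Passing the same keys inside `llm_config` still works: `_process_assistant_config` moves `assistant_id`, `tools`, `file_ids`, and `check_every_ms` over to the assistant config for compatibility with older configurations.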
ref https://platform.openai.com/docs/assistants/tools @@ -57,23 +59,19 @@ def __init__( """ self._verbose = kwargs.pop("verbose", False) + openai_client_cfg, openai_assistant_cfg = self._process_assistant_config(llm_config, assistant_config) + super().__init__( - name=name, system_message=instructions, human_input_mode="NEVER", llm_config=llm_config, **kwargs + name=name, system_message=instructions, human_input_mode="NEVER", llm_config=openai_client_cfg, **kwargs ) - if llm_config is False: - raise ValueError("llm_config=False is not supported for GPTAssistantAgent.") - # Use AutooGen OpenAIWrapper to create a client - openai_client_cfg = copy.deepcopy(llm_config) - # Use the class variable - model_name = GPTAssistantAgent.DEFAULT_MODEL_NAME - # GPTAssistantAgent's azure_deployment param may cause NotFoundError (404) in client.beta.assistants.list() # See: https://github.com/microsoft/autogen/pull/1721 + model_name = self.DEFAULT_MODEL_NAME if openai_client_cfg.get("config_list") is not None and len(openai_client_cfg["config_list"]) > 0: - model_name = openai_client_cfg["config_list"][0].pop("model", GPTAssistantAgent.DEFAULT_MODEL_NAME) + model_name = openai_client_cfg["config_list"][0].pop("model", self.DEFAULT_MODEL_NAME) else: - model_name = openai_client_cfg.pop("model", GPTAssistantAgent.DEFAULT_MODEL_NAME) + model_name = openai_client_cfg.pop("model", self.DEFAULT_MODEL_NAME) logger.warning("OpenAI client config of GPTAssistantAgent(%s) - model: %s", name, model_name) @@ -82,14 +80,17 @@ def __init__( logger.warning("GPT Assistant only supports one OpenAI client. Using the first client in the list.") self._openai_client = oai_wrapper._clients[0]._oai_client - openai_assistant_id = llm_config.get("assistant_id", None) + openai_assistant_id = openai_assistant_cfg.get("assistant_id", None) if openai_assistant_id is None: # try to find assistant by name first candidate_assistants = retrieve_assistants_by_name(self._openai_client, name) if len(candidate_assistants) > 0: # Filter out candidates with the same name but different instructions, file IDs, and function names. candidate_assistants = self.find_matching_assistant( - candidate_assistants, instructions, llm_config.get("tools", []), llm_config.get("file_ids", []) + candidate_assistants, + instructions, + openai_assistant_cfg.get("tools", []), + openai_assistant_cfg.get("file_ids", []), ) if len(candidate_assistants) == 0: @@ -103,9 +104,9 @@ def __init__( self._openai_assistant = self._openai_client.beta.assistants.create( name=name, instructions=instructions, - tools=llm_config.get("tools", []), + tools=openai_assistant_cfg.get("tools", []), model=model_name, - file_ids=llm_config.get("file_ids", []), + file_ids=openai_assistant_cfg.get("file_ids", []), ) else: logger.warning( @@ -135,8 +136,8 @@ def __init__( "overwrite_instructions is False. Provided instructions will be used without permanently modifying the assistant in the API." 
) - # Check if tools are specified in llm_config - specified_tools = llm_config.get("tools", None) + # Check if tools are specified in assistant_config + specified_tools = openai_assistant_cfg.get("tools", None) if specified_tools is None: # Check if the current assistant has tools defined @@ -155,7 +156,7 @@ def __init__( ) self._openai_assistant = self._openai_client.beta.assistants.update( assistant_id=openai_assistant_id, - tools=llm_config.get("tools", []), + tools=openai_assistant_cfg.get("tools", []), ) else: # Tools are specified but overwrite_tools is False; do not update the assistant's tools @@ -164,9 +165,7 @@ def __init__( # lazily create threads self._openai_threads = {} self._unread_index = defaultdict(int) - self.register_reply(Agent, GPTAssistantAgent._invoke_assistant) - self.register_reply(Agent, GPTAssistantAgent.check_termination_and_human_reply) - self.register_reply(Agent, GPTAssistantAgent.a_check_termination_and_human_reply) + self.register_reply(Agent, GPTAssistantAgent._invoke_assistant, position=2) def _invoke_assistant( self, @@ -414,6 +413,10 @@ def assistant_id(self): def openai_client(self): return self._openai_client + @property + def openai_assistant(self): + return self._openai_assistant + def get_assistant_instructions(self): """Return the assistant instructions from OAI assistant API""" return self._openai_assistant.instructions @@ -472,3 +475,31 @@ def find_matching_assistant(self, candidate_assistants, instructions, tools, fil matching_assistants.append(assistant) return matching_assistants + + def _process_assistant_config(self, llm_config, assistant_config): + """ + Process the llm_config and assistant_config to extract the model name and assistant related configurations. + """ + + if llm_config is False: + raise ValueError("llm_config=False is not supported for GPTAssistantAgent.") + + if llm_config is None: + openai_client_cfg = {} + else: + openai_client_cfg = copy.deepcopy(llm_config) + + if assistant_config is None: + openai_assistant_cfg = {} + else: + openai_assistant_cfg = copy.deepcopy(assistant_config) + + # Move the assistant related configurations to assistant_config + # It's important to keep forward compatibility + assistant_config_items = ["assistant_id", "tools", "file_ids", "check_every_ms"] + for item in assistant_config_items: + if openai_client_cfg.get(item) is not None and openai_assistant_cfg.get(item) is None: + openai_assistant_cfg[item] = openai_client_cfg[item] + openai_client_cfg.pop(item, None) + + return openai_client_cfg, openai_assistant_cfg diff --git a/autogen/agentchat/contrib/img_utils.py b/autogen/agentchat/contrib/img_utils.py index 6062f3b0553..a389c74b064 100644 --- a/autogen/agentchat/contrib/img_utils.py +++ b/autogen/agentchat/contrib/img_utils.py @@ -1,14 +1,15 @@ import base64 import copy -import mimetypes import os import re from io import BytesIO -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Dict, List, Tuple, Union import requests from PIL import Image +from autogen.agentchat import utils + def get_pil_image(image_file: Union[str, Image.Image]) -> Image.Image: """ @@ -24,6 +25,12 @@ def get_pil_image(image_file: Union[str, Image.Image]) -> Image.Image: # Already a PIL Image object return image_file + # Remove quotes if existed + if image_file.startswith('"') and image_file.endswith('"'): + image_file = image_file[1:-1] + if image_file.startswith("'") and image_file.endswith("'"): + image_file = image_file[1:-1] + if image_file.startswith("http://") or 
image_file.startswith("https://"): # A URL file response = requests.get(image_file) @@ -173,13 +180,9 @@ def gpt4v_formatter(prompt: str, img_format: str = "uri") -> List[Union[str, dic last_index = 0 image_count = 0 - # Regular expression pattern for matching tags - img_tag_pattern = re.compile(r"]+)>") - # Find all image tags - for match in img_tag_pattern.finditer(prompt): - image_location = match.group(1) - + for parsed_tag in utils.parse_tags_from_content("img", prompt): + image_location = parsed_tag["attr"]["src"] try: if img_format == "pil": img_data = get_pil_image(image_location) @@ -196,12 +199,12 @@ def gpt4v_formatter(prompt: str, img_format: str = "uri") -> List[Union[str, dic continue # Add text before this image tag to output list - output.append({"type": "text", "text": prompt[last_index : match.start()]}) + output.append({"type": "text", "text": prompt[last_index : parsed_tag["match"].start()]}) # Add image data to output list output.append({"type": "image_url", "image_url": {"url": img_data}}) - last_index = match.end() + last_index = parsed_tag["match"].end() image_count += 1 # Add remaining text to output list diff --git a/autogen/agentchat/contrib/llava_agent.py b/autogen/agentchat/contrib/llava_agent.py index c26f576ab39..063b256d3cd 100644 --- a/autogen/agentchat/contrib/llava_agent.py +++ b/autogen/agentchat/contrib/llava_agent.py @@ -1,6 +1,7 @@ import json import logging from typing import List, Optional, Tuple + import replicate import requests @@ -8,8 +9,8 @@ from autogen.agentchat.contrib.img_utils import get_image_data, llava_formatter from autogen.agentchat.contrib.multimodal_conversable_agent import MultimodalConversableAgent from autogen.code_utils import content_str -from autogen.agentchat.conversable_agent import colored +from ...formatting_utils import colored logger = logging.getLogger(__name__) diff --git a/autogen/agentchat/contrib/math_user_proxy_agent.py b/autogen/agentchat/contrib/math_user_proxy_agent.py index 67c86daf05d..d2b6b7cde00 100644 --- a/autogen/agentchat/contrib/math_user_proxy_agent.py +++ b/autogen/agentchat/contrib/math_user_proxy_agent.py @@ -1,15 +1,15 @@ -import re import os -from pydantic import BaseModel, Extra, root_validator -from typing import Any, Callable, Dict, List, Optional, Union, Tuple +import re from time import sleep +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +from pydantic import BaseModel, Extra, root_validator from autogen._pydantic import PYDANTIC_V1 from autogen.agentchat import Agent, UserProxyAgent -from autogen.code_utils import UNKNOWN, extract_code, execute_code, infer_lang +from autogen.code_utils import UNKNOWN, execute_code, extract_code, infer_lang from autogen.math_utils import get_answer - PROMPTS = { # default "default": """Let's use Python to solve a math problem. @@ -177,28 +177,35 @@ def __init__( self._previous_code = "" self.last_reply = None - def generate_init_message(self, problem, prompt_type="default", customized_prompt=None): + @staticmethod + def message_generator(sender, recipient, context): """Generate a prompt for the assistant agent with the given problem and prompt. Args: - problem (str): the problem to be solved. - prompt_type (str): the type of the prompt. Possible values are "default", "python", "wolfram". - (1) "default": the prompt that allows the agent to choose between 3 ways to solve a problem: - 1. write a python program to solve it directly. - 2. solve it directly without python. - 3. solve it step by step with python. 
- (2) "python": - a simplified prompt from the third way of the "default" prompt, that asks the assistant - to solve the problem step by step with python. - (3) "two_tools": - a simplified prompt similar to the "python" prompt, but allows the model to choose between - Python and Wolfram Alpha to solve the problem. - customized_prompt (str): a customized prompt to be used. If it is not None, the prompt_type will be ignored. + sender (Agent): the sender of the message. + recipient (Agent): the recipient of the message. + context (dict): a dictionary with the following fields: + problem (str): the problem to be solved. + prompt_type (str, Optional): the type of the prompt. Possible values are "default", "python", "wolfram". + (1) "default": the prompt that allows the agent to choose between 3 ways to solve a problem: + 1. write a python program to solve it directly. + 2. solve it directly without python. + 3. solve it step by step with python. + (2) "python": + a simplified prompt from the third way of the "default" prompt, that asks the assistant + to solve the problem step by step with python. + (3) "two_tools": + a simplified prompt similar to the "python" prompt, but allows the model to choose between + Python and Wolfram Alpha to solve the problem. + customized_prompt (str, Optional): a customized prompt to be used. If it is not None, the prompt_type will be ignored. Returns: str: the generated prompt ready to be sent to the assistant agent. """ - self._reset() + sender._reset() + problem = context.get("problem") + prompt_type = context.get("prompt_type", "default") + customized_prompt = context.get("customized_prompt", None) if customized_prompt is not None: return customized_prompt + problem return PROMPTS[prompt_type] + problem diff --git a/autogen/agentchat/contrib/multimodal_conversable_agent.py b/autogen/agentchat/contrib/multimodal_conversable_agent.py index 2355c630f9c..edeb88cd531 100644 --- a/autogen/agentchat/contrib/multimodal_conversable_agent.py +++ b/autogen/agentchat/contrib/multimodal_conversable_agent.py @@ -11,7 +11,6 @@ from ..._pydantic import model_dump - DEFAULT_LMM_SYS_MSG = """You are a helpful AI assistant.""" DEFAULT_MODEL = "gpt-4-vision-preview" @@ -53,16 +52,8 @@ def __init__( ) # Override the `generate_oai_reply` - def _replace_reply_func(arr, x, y): - for item in arr: - if item["reply_func"] is x: - item["reply_func"] = y - - _replace_reply_func( - self._reply_func_list, ConversableAgent.generate_oai_reply, MultimodalConversableAgent.generate_oai_reply - ) - _replace_reply_func( - self._reply_func_list, + self.replace_reply_func(ConversableAgent.generate_oai_reply, MultimodalConversableAgent.generate_oai_reply) + self.replace_reply_func( ConversableAgent.a_generate_oai_reply, MultimodalConversableAgent.a_generate_oai_reply, ) diff --git a/autogen/agentchat/contrib/qdrant_retrieve_user_proxy_agent.py b/autogen/agentchat/contrib/qdrant_retrieve_user_proxy_agent.py index 01b51362374..1ece138963f 100644 --- a/autogen/agentchat/contrib/qdrant_retrieve_user_proxy_agent.py +++ b/autogen/agentchat/contrib/qdrant_retrieve_user_proxy_agent.py @@ -1,17 +1,21 @@ from typing import Callable, Dict, List, Optional from autogen.agentchat.contrib.retrieve_user_proxy_agent import RetrieveUserProxyAgent -from autogen.retrieve_utils import get_files_from_dir, split_files_to_chunks, TEXT_FORMATS -import logging +from autogen.agentchat.contrib.vectordb.utils import ( + chroma_results_to_query_results, + filter_results_by_distance, + get_logger, +) +from autogen.retrieve_utils 
import TEXT_FORMATS, get_files_from_dir, split_files_to_chunks -logger = logging.getLogger(__name__) +logger = get_logger(__name__) try: + import fastembed from qdrant_client import QdrantClient, models from qdrant_client.fastembed_common import QueryResponse - import fastembed except ImportError as e: - logging.fatal("Failed to import qdrant_client with fastembed. Try running 'pip install qdrant_client[fastembed]'") + logger.fatal("Failed to import qdrant_client with fastembed. Try running 'pip install qdrant_client[fastembed]'") raise e @@ -136,6 +140,11 @@ def retrieve_docs(self, problem: str, n_results: int = 20, search_string: str = collection_name=self._collection_name, embedding_model=self._embedding_model, ) + results["contents"] = results.pop("documents") + results = chroma_results_to_query_results(results, "distances") + results = filter_results_by_distance(results, self._distance_threshold) + + self._search_string = search_string self._results = results @@ -190,12 +199,12 @@ def create_qdrant_from_dir( client.set_model(embedding_model) if custom_text_split_function is not None: - chunks = split_files_to_chunks( + chunks, sources = split_files_to_chunks( get_files_from_dir(dir_path, custom_text_types, recursive), custom_text_split_function=custom_text_split_function, ) else: - chunks = split_files_to_chunks( + chunks, sources = split_files_to_chunks( get_files_from_dir(dir_path, custom_text_types, recursive), max_tokens, chunk_mode, must_break_at_empty_line ) logger.info(f"Found {len(chunks)} chunks.") @@ -281,20 +290,24 @@ class QueryResponse(BaseModel, extra="forbid"): # type: ignore collection_name, query_texts, limit=n_results, - query_filter=models.Filter( - must=[ - models.FieldCondition( - key="document", - match=models.MatchText(text=search_string), - ) - ] - ) - if search_string - else None, + query_filter=( + models.Filter( + must=[ + models.FieldCondition( + key="document", + match=models.MatchText(text=search_string), + ) + ] + ) + if search_string + else None + ), ) data = { "ids": [[result.id for result in sublist] for sublist in results], "documents": [[result.document for result in sublist] for sublist in results], + "distances": [[result.score for result in sublist] for sublist in results], + "metadatas": [[result.metadata for result in sublist] for sublist in results], } return data diff --git a/autogen/agentchat/contrib/retrieve_assistant_agent.py b/autogen/agentchat/contrib/retrieve_assistant_agent.py index a09677710aa..9b5ace200dc 100644 --- a/autogen/agentchat/contrib/retrieve_assistant_agent.py +++ b/autogen/agentchat/contrib/retrieve_assistant_agent.py @@ -1,6 +1,7 @@ +from typing import Any, Dict, List, Optional, Tuple, Union + from autogen.agentchat.agent import Agent from autogen.agentchat.assistant_agent import AssistantAgent -from typing import Dict, Optional, Union, List, Tuple, Any class RetrieveAssistantAgent(AssistantAgent): diff --git a/autogen/agentchat/contrib/retrieve_user_proxy_agent.py b/autogen/agentchat/contrib/retrieve_user_proxy_agent.py index b6ec6363096..476c7c0739d 100644 --- a/autogen/agentchat/contrib/retrieve_user_proxy_agent.py +++ b/autogen/agentchat/contrib/retrieve_user_proxy_agent.py @@ -1,19 +1,35 @@ +import hashlib +import os import re -from typing import Callable, Dict, Optional, Union, List, Tuple, Any +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + from IPython import get_ipython try: import chromadb except ImportError: raise ImportError("Please install dependencies first. 
`pip install pyautogen[retrievechat]`") -from autogen.agentchat.agent import Agent from autogen.agentchat import UserProxyAgent -from autogen.retrieve_utils import create_vector_db_from_dir, query_vector_db, TEXT_FORMATS -from autogen.token_count_utils import count_token +from autogen.agentchat.agent import Agent +from autogen.agentchat.contrib.vectordb.base import Document, QueryResults, VectorDB, VectorDBFactory +from autogen.agentchat.contrib.vectordb.utils import ( + chroma_results_to_query_results, + filter_results_by_distance, + get_logger, +) from autogen.code_utils import extract_code -from autogen import logger -from autogen.agentchat.conversable_agent import colored +from autogen.retrieve_utils import ( + TEXT_FORMATS, + create_vector_db_from_dir, + get_files_from_dir, + query_vector_db, + split_files_to_chunks, +) +from autogen.token_count_utils import count_token + +from ...formatting_utils import colored +logger = get_logger(__name__) PROMPT_DEFAULT = """You're a retrieve augmented chatbot. You answer user's questions based on your own knowledge and the context provided by the user. You should follow the following steps to answer a question: @@ -33,6 +49,10 @@ User's question is: {input_question} Context is: {input_context} + +The source of the context is: {input_sources} + +If you can answer the question, in the end of your answer, add the source of the context in the format of `Sources: source1, source2, ...`. """ PROMPT_CODE = """You're a retrieve augmented coding assistant. You answer user's questions based on your own knowledge and the @@ -60,8 +80,14 @@ Context is: {input_context} """ +HASH_LENGTH = int(os.environ.get("HASH_LENGTH", 8)) + class RetrieveUserProxyAgent(UserProxyAgent): + """(In preview) The Retrieval-Augmented User Proxy retrieves document chunks based on the embedding + similarity, and sends them along with the question to the Retrieval-Augmented Assistant + """ + def __init__( self, name="RetrieveChatAgent", # default set to RetrieveChatAgent @@ -73,67 +99,126 @@ def __init__( r""" Args: name (str): name of the agent. + human_input_mode (str): whether to ask for human inputs every time a message is received. Possible values are "ALWAYS", "TERMINATE", "NEVER". 1. When "ALWAYS", the agent prompts for human input every time a message is received. Under this mode, the conversation stops when the human input is "exit", or when is_termination_msg is True and there is no human input. - 2. When "TERMINATE", the agent only prompts for human input only when a termination message is received or - the number of auto reply reaches the max_consecutive_auto_reply. - 3. When "NEVER", the agent will never prompt for human input. Under this mode, the conversation stops - when the number of auto reply reaches the max_consecutive_auto_reply or when is_termination_msg is True. + 2. When "TERMINATE", the agent only prompts for human input only when a termination + message is received or the number of auto reply reaches + the max_consecutive_auto_reply. + 3. When "NEVER", the agent will never prompt for human input. Under this mode, the + conversation stops when the number of auto reply reaches the + max_consecutive_auto_reply or when is_termination_msg is True. + is_termination_msg (function): a function that takes a message in the form of a dictionary and returns a boolean value indicating if this received message is a termination message. The dict can contain the following keys: "content", "role", "name", "function_call". 
+ retrieve_config (dict or None): config for the retrieve agent. - To use default config, set to None. Otherwise, set to a dictionary with the following keys: - - task (Optional, str): the task of the retrieve chat. Possible values are "code", "qa" and "default". System - prompt will be different for different tasks. The default value is `default`, which supports both code and qa. - - client (Optional, chromadb.Client): the chromadb client. If key not provided, a default client `chromadb.Client()` - will be used. If you want to use other vector db, extend this class and override the `retrieve_docs` function. - - docs_path (Optional, Union[str, List[str]]): the path to the docs directory. It can also be the path to a single file, - the url to a single file or a list of directories, files and urls. Default is None, which works only if the collection is already created. - - extra_docs (Optional, bool): when true, allows adding documents with unique IDs without overwriting existing ones; when false, it replaces existing documents using default IDs, risking collection overwrite., - when set to true it enables the system to assign unique IDs starting from "length+i" for new document chunks, preventing the replacement of existing documents and facilitating the addition of more content to the collection.. - By default, "extra_docs" is set to false, starting document IDs from zero. This poses a risk as new documents might overwrite existing ones, potentially causing unintended loss or alteration of data in the collection. - - collection_name (Optional, str): the name of the collection. - If key not provided, a default name `autogen-docs` will be used. - - model (Optional, str): the model to use for the retrieve chat. + + To use default config, set to None. Otherwise, set to a dictionary with the + following keys: + - `task` (Optional, str) - the task of the retrieve chat. Possible values are + "code", "qa" and "default". System prompt will be different for different tasks. + The default value is `default`, which supports both code and qa, and provides + source information in the end of the response. + - `vector_db` (Optional, Union[str, VectorDB]) - the vector db for the retrieve chat. + If it's a string, it should be the type of the vector db, such as "chroma"; otherwise, + it should be an instance of the VectorDB protocol. Default is "chroma". + Set `None` to use the deprecated `client`. + - `db_config` (Optional, Dict) - the config for the vector db. Default is `{}`. Please make + sure you understand the config for the vector db you are using, otherwise, leave it as `{}`. + Only valid when `vector_db` is a string. + - `client` (Optional, chromadb.Client) - the chromadb client. If key not provided, a + default client `chromadb.Client()` will be used. If you want to use other + vector db, extend this class and override the `retrieve_docs` function. + **Deprecated**: use `vector_db` instead. + - `docs_path` (Optional, Union[str, List[str]]) - the path to the docs directory. It + can also be the path to a single file, the url to a single file or a list + of directories, files and urls. Default is None, which works only if the + collection is already created. 
+ - `extra_docs` (Optional, bool) - when true, allows adding documents with unique IDs + without overwriting existing ones; when false, it replaces existing documents + using default IDs, risking collection overwrite., when set to true it enables + the system to assign unique IDs starting from "length+i" for new document + chunks, preventing the replacement of existing documents and facilitating the + addition of more content to the collection.. + By default, "extra_docs" is set to false, starting document IDs from zero. + This poses a risk as new documents might overwrite existing ones, potentially + causing unintended loss or alteration of data in the collection. + **Deprecated**: use `new_docs` when use `vector_db` instead of `client`. + - `new_docs` (Optional, bool) - when True, only adds new documents to the collection; + when False, updates existing documents and adds new ones. Default is True. + Document id is used to determine if a document is new or existing. By default, the + id is the hash value of the content. + - `model` (Optional, str) - the model to use for the retrieve chat. If key not provided, a default model `gpt-4` will be used. - - chunk_token_size (Optional, int): the chunk token size for the retrieve chat. + - `chunk_token_size` (Optional, int) - the chunk token size for the retrieve chat. If key not provided, a default size `max_tokens * 0.4` will be used. - - context_max_tokens (Optional, int): the context max token size for the retrieve chat. + - `context_max_tokens` (Optional, int) - the context max token size for the + retrieve chat. If key not provided, a default size `max_tokens * 0.8` will be used. - - chunk_mode (Optional, str): the chunk mode for the retrieve chat. Possible values are - "multi_lines" and "one_line". If key not provided, a default mode `multi_lines` will be used. - - must_break_at_empty_line (Optional, bool): chunk will only break at empty line if True. Default is True. + - `chunk_mode` (Optional, str) - the chunk mode for the retrieve chat. Possible values + are "multi_lines" and "one_line". If key not provided, a default mode + `multi_lines` will be used. + - `must_break_at_empty_line` (Optional, bool) - chunk will only break at empty line + if True. Default is True. If chunk_mode is "one_line", this parameter will be ignored. - - embedding_model (Optional, str): the embedding model to use for the retrieve chat. - If key not provided, a default model `all-MiniLM-L6-v2` will be used. All available models - can be found at `https://www.sbert.net/docs/pretrained_models.html`. The default model is a - fast model. If you want to use a high performance model, `all-mpnet-base-v2` is recommended. - - embedding_function (Optional, Callable): the embedding function for creating the vector db. Default is None, - SentenceTransformer with the given `embedding_model` will be used. If you want to use OpenAI, Cohere, HuggingFace or - other embedding functions, you can pass it here, follow the examples in `https://docs.trychroma.com/embeddings`. - - customized_prompt (Optional, str): the customized prompt for the retrieve chat. Default is None. - - customized_answer_prefix (Optional, str): the customized answer prefix for the retrieve chat. Default is "". - If not "" and the customized_answer_prefix is not in the answer, `Update Context` will be triggered. - - update_context (Optional, bool): if False, will not apply `Update Context` for interactive retrieval. Default is True. 
- - get_or_create (Optional, bool): if True, will create/return a collection for the retrieve chat. This is the same as that used in chromadb. - Default is False. Will raise ValueError if the collection already exists and get_or_create is False. Will be set to True if docs_path is None. - - custom_token_count_function (Optional, Callable): a custom function to count the number of tokens in a string. - The function should take (text:str, model:str) as input and return the token_count(int). the retrieve_config["model"] will be passed in the function. - Default is autogen.token_count_utils.count_token that uses tiktoken, which may not be accurate for non-OpenAI models. - - custom_text_split_function (Optional, Callable): a custom function to split a string into a list of strings. - Default is None, will use the default function in `autogen.retrieve_utils.split_text_to_chunks`. - - custom_text_types (Optional, List[str]): a list of file types to be processed. Default is `autogen.retrieve_utils.TEXT_FORMATS`. - This only applies to files under the directories in `docs_path`. Explicitly included files and urls will be chunked regardless of their types. - - recursive (Optional, bool): whether to search documents recursively in the docs_path. Default is True. + - `embedding_model` (Optional, str) - the embedding model to use for the retrieve chat. + If key not provided, a default model `all-MiniLM-L6-v2` will be used. All available + models can be found at `https://www.sbert.net/docs/pretrained_models.html`. + The default model is a fast model. If you want to use a high performance model, + `all-mpnet-base-v2` is recommended. + **Deprecated**: no need when use `vector_db` instead of `client`. + - `embedding_function` (Optional, Callable) - the embedding function for creating the + vector db. Default is None, SentenceTransformer with the given `embedding_model` + will be used. If you want to use OpenAI, Cohere, HuggingFace or other embedding + functions, you can pass it here, + follow the examples in `https://docs.trychroma.com/embeddings`. + - `customized_prompt` (Optional, str) - the customized prompt for the retrieve chat. + Default is None. + - `customized_answer_prefix` (Optional, str) - the customized answer prefix for the + retrieve chat. Default is "". + If not "" and the customized_answer_prefix is not in the answer, + `Update Context` will be triggered. + - `update_context` (Optional, bool) - if False, will not apply `Update Context` for + interactive retrieval. Default is True. + - `collection_name` (Optional, str) - the name of the collection. + If key not provided, a default name `autogen-docs` will be used. + - `get_or_create` (Optional, bool) - Whether to get the collection if it exists. Default is True. + - `overwrite` (Optional, bool) - Whether to overwrite the collection if it exists. Default is False. + Case 1. if the collection does not exist, create the collection. + Case 2. the collection exists, if overwrite is True, it will overwrite the collection. + Case 3. the collection exists and overwrite is False, if get_or_create is True, it will get the collection, + otherwise it raise a ValueError. + - `custom_token_count_function` (Optional, Callable) - a custom function to count the + number of tokens in a string. + The function should take (text:str, model:str) as input and return the + token_count(int). the retrieve_config["model"] will be passed in the function. 
+ Default is autogen.token_count_utils.count_token that uses tiktoken, which may + not be accurate for non-OpenAI models. + - `custom_text_split_function` (Optional, Callable) - a custom function to split a + string into a list of strings. + Default is None, will use the default function in + `autogen.retrieve_utils.split_text_to_chunks`. + - `custom_text_types` (Optional, List[str]) - a list of file types to be processed. + Default is `autogen.retrieve_utils.TEXT_FORMATS`. + This only applies to files under the directories in `docs_path`. Explicitly + included files and urls will be chunked regardless of their types. + - `recursive` (Optional, bool) - whether to search documents recursively in the + docs_path. Default is True. + - `distance_threshold` (Optional, float) - the threshold for the distance score, only + distance smaller than it will be returned. Will be ignored if < 0. Default is -1. + `**kwargs` (dict): other kwargs in [UserProxyAgent](../user_proxy_agent#__init__). Example: - Example of overriding retrieve_docs - If you have set up a customized vector db, and it's not compatible with chromadb, you can easily plug in it with below code. + Example of overriding retrieve_docs - If you have set up a customized vector db, and it's + not compatible with chromadb, you can easily plug in it with below code. + **Deprecated**: Use `vector_db` instead. You can extend VectorDB and pass it to the agent. ```python class MyRetrieveUserProxyAgent(RetrieveUserProxyAgent): def query_vector_db( @@ -166,9 +251,12 @@ def retrieve_docs(self, problem: str, n_results: int = 20, search_string: str = self._retrieve_config = {} if retrieve_config is None else retrieve_config self._task = self._retrieve_config.get("task", "default") + self._vector_db = self._retrieve_config.get("vector_db", "chroma") + self._db_config = self._retrieve_config.get("db_config", {}) self._client = self._retrieve_config.get("client", chromadb.Client()) self._docs_path = self._retrieve_config.get("docs_path", None) self._extra_docs = self._retrieve_config.get("extra_docs", False) + self._new_docs = self._retrieve_config.get("new_docs", True) self._collection_name = self._retrieve_config.get("collection_name", "autogen-docs") if "docs_path" not in self._retrieve_config: logger.warning( @@ -187,25 +275,104 @@ def retrieve_docs(self, problem: str, n_results: int = 20, search_string: str = self.customized_answer_prefix = self._retrieve_config.get("customized_answer_prefix", "").upper() self.update_context = self._retrieve_config.get("update_context", True) self._get_or_create = self._retrieve_config.get("get_or_create", False) if self._docs_path is not None else True + self._overwrite = self._retrieve_config.get("overwrite", False) self.custom_token_count_function = self._retrieve_config.get("custom_token_count_function", count_token) self.custom_text_split_function = self._retrieve_config.get("custom_text_split_function", None) self._custom_text_types = self._retrieve_config.get("custom_text_types", TEXT_FORMATS) self._recursive = self._retrieve_config.get("recursive", True) - self._context_max_tokens = self._max_tokens * 0.8 + self._context_max_tokens = self._retrieve_config.get("context_max_tokens", self._max_tokens * 0.8) self._collection = True if self._docs_path is None else False # whether the collection is created self._ipython = get_ipython() self._doc_idx = -1 # the index of the current used doc - self._results = {} # the results of the current query + self._results = [] # the results of the current query 
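Pulling the new options together, a representative configuration that uses the `vector_db` path rather than the deprecated `client`/`extra_docs` keys might look like the following; paths, names, and thresholds are placeholders:

```python
from autogen.agentchat.contrib.retrieve_user_proxy_agent import RetrieveUserProxyAgent

ragproxyagent = RetrieveUserProxyAgent(
    name="ragproxyagent",
    human_input_mode="NEVER",
    retrieve_config={
        "task": "default",
        "vector_db": "chroma",            # or "pgvector", or a VectorDB instance
        "db_config": {"path": "tmp/db"},  # forwarded to the vector db; placeholder path
        "docs_path": "./docs",            # placeholder; a directory, a single file, or URL(s)
        "collection_name": "autogen-docs",
        "new_docs": True,                 # only insert chunks whose hashed id is not already stored
        "get_or_create": True,
        "overwrite": False,
        "chunk_mode": "multi_lines",
        "model": "gpt-4",                 # used for token counting
        "distance_threshold": 0.5,        # keep only results with distance < 0.5; -1 disables the filter
    },
    code_execution_config=False,
)
```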
self._intermediate_answers = set() # the intermediate answers self._doc_contents = [] # the contents of the current used doc self._doc_ids = [] # the ids of the current used doc + self._current_docs_in_context = [] # the ids of the current context sources self._search_string = "" # the search string used in the current query + self._distance_threshold = self._retrieve_config.get("distance_threshold", -1) # update the termination message function self._is_termination_msg = ( self._is_termination_msg_retrievechat if is_termination_msg is None else is_termination_msg ) + if isinstance(self._vector_db, str): + if not isinstance(self._db_config, dict): + raise ValueError("`db_config` should be a dictionary.") + if "embedding_function" in self._retrieve_config: + self._db_config["embedding_function"] = self._embedding_function + self._vector_db = VectorDBFactory.create_vector_db(db_type=self._vector_db, **self._db_config) self.register_reply(Agent, RetrieveUserProxyAgent._generate_retrieve_user_reply, position=2) + def _init_db(self): + if not self._vector_db: + return + + IS_TO_CHUNK = False # whether to chunk the raw files + if self._new_docs: + IS_TO_CHUNK = True + if not self._docs_path: + try: + self._vector_db.get_collection(self._collection_name) + logger.warning(f"`docs_path` is not provided. Use the existing collection `{self._collection_name}`.") + self._overwrite = False + self._get_or_create = True + IS_TO_CHUNK = False + except ValueError: + raise ValueError( + "`docs_path` is not provided. " + f"The collection `{self._collection_name}` doesn't exist either. " + "Please provide `docs_path` or create the collection first." + ) + elif self._get_or_create and not self._overwrite: + try: + self._vector_db.get_collection(self._collection_name) + logger.info(f"Use the existing collection `{self._collection_name}`.", color="green") + except ValueError: + IS_TO_CHUNK = True + else: + IS_TO_CHUNK = True + + self._vector_db.active_collection = self._vector_db.create_collection( + self._collection_name, overwrite=self._overwrite, get_or_create=self._get_or_create + ) + + docs = None + if IS_TO_CHUNK: + if self.custom_text_split_function is not None: + chunks, sources = split_files_to_chunks( + get_files_from_dir(self._docs_path, self._custom_text_types, self._recursive), + custom_text_split_function=self.custom_text_split_function, + ) + else: + chunks, sources = split_files_to_chunks( + get_files_from_dir(self._docs_path, self._custom_text_types, self._recursive), + self._max_tokens, + self._chunk_mode, + self._must_break_at_empty_line, + ) + logger.info(f"Found {len(chunks)} chunks.") + + if self._new_docs: + all_docs_ids = set( + [ + doc["id"] + for doc in self._vector_db.get_docs_by_ids(ids=None, collection_name=self._collection_name) + ] + ) + else: + all_docs_ids = set() + + chunk_ids = [hashlib.blake2b(chunk.encode("utf-8")).hexdigest()[:HASH_LENGTH] for chunk in chunks] + chunk_ids_set = set(chunk_ids) + chunk_ids_set_idx = [chunk_ids.index(hash_value) for hash_value in chunk_ids_set] + docs = [ + Document(id=chunk_ids[idx], content=chunks[idx], metadata=sources[idx]) + for idx in chunk_ids_set_idx + if chunk_ids[idx] not in all_docs_ids + ] + + self._vector_db.insert_docs(docs=docs, collection_name=self._collection_name, upsert=True) + def _is_termination_msg_retrievechat(self, message): """Check if a message is a termination message. For code generation, terminate when no code block is detected. Currently only detect python code blocks. 
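The deduplication above keys each chunk on a short content hash; in isolation, the id computation is:

```python
import hashlib
import os

HASH_LENGTH = int(os.environ.get("HASH_LENGTH", 8))

def chunk_id(chunk: str) -> str:
    """Content-derived document id; identical chunks map to the same id, so with
    new_docs=True they are skipped instead of being inserted twice."""
    return hashlib.blake2b(chunk.encode("utf-8")).hexdigest()[:HASH_LENGTH]

assert chunk_id("same text") == chunk_id("same text")
assert chunk_id("same text") != chunk_id("other text")
```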
@@ -238,37 +405,42 @@ def get_max_tokens(model="gpt-3.5-turbo"): def _reset(self, intermediate=False): self._doc_idx = -1 # the index of the current used doc - self._results = {} # the results of the current query + self._results = [] # the results of the current query if not intermediate: self._intermediate_answers = set() # the intermediate answers self._doc_contents = [] # the contents of the current used doc self._doc_ids = [] # the ids of the current used doc - def _get_context(self, results: Dict[str, Union[List[str], List[List[str]]]]): + def _get_context(self, results: QueryResults): doc_contents = "" + self._current_docs_in_context = [] current_tokens = 0 _doc_idx = self._doc_idx _tmp_retrieve_count = 0 - for idx, doc in enumerate(results["documents"][0]): + for idx, doc in enumerate(results[0]): + doc = doc[0] if idx <= _doc_idx: continue - if results["ids"][0][idx] in self._doc_ids: + if doc["id"] in self._doc_ids: continue - _doc_tokens = self.custom_token_count_function(doc, self._model) + _doc_tokens = self.custom_token_count_function(doc["content"], self._model) if _doc_tokens > self._context_max_tokens: - func_print = f"Skip doc_id {results['ids'][0][idx]} as it is too long to fit in the context." + func_print = f"Skip doc_id {doc['id']} as it is too long to fit in the context." print(colored(func_print, "green"), flush=True) self._doc_idx = idx continue if current_tokens + _doc_tokens > self._context_max_tokens: break - func_print = f"Adding doc_id {results['ids'][0][idx]} to context." + func_print = f"Adding content of doc {doc['id']} to context." print(colored(func_print, "green"), flush=True) current_tokens += _doc_tokens - doc_contents += doc + "\n" + doc_contents += doc["content"] + "\n" + _metadata = doc.get("metadata") + if isinstance(_metadata, dict): + self._current_docs_in_context.append(_metadata.get("source", "")) self._doc_idx = idx - self._doc_ids.append(results["ids"][0][idx]) - self._doc_contents.append(doc) + self._doc_ids.append(doc["id"]) + self._doc_contents.append(doc["content"]) _tmp_retrieve_count += 1 if _tmp_retrieve_count >= self.n_results: break @@ -285,7 +457,9 @@ def _generate_message(self, doc_contents, task="default"): elif task.upper() == "QA": message = PROMPT_QA.format(input_question=self.problem, input_context=doc_contents) elif task.upper() == "DEFAULT": - message = PROMPT_DEFAULT.format(input_question=self.problem, input_context=doc_contents) + message = PROMPT_DEFAULT.format( + input_question=self.problem, input_context=doc_contents, input_sources=self._current_docs_in_context + ) else: raise NotImplementedError(f"task {task} is not implemented.") return message @@ -360,21 +534,40 @@ def _generate_retrieve_user_reply( def retrieve_docs(self, problem: str, n_results: int = 20, search_string: str = ""): """Retrieve docs based on the given problem and assign the results to the class property `_results`. - In case you want to customize the retrieval process, such as using a different vector db whose APIs are not - compatible with chromadb or filter results with metadata, you can override this function. Just keep the current - parameters and add your own parameters with default values, and keep the results in below type. - - Type of the results: Dict[str, List[List[Any]]], should have keys "ids" and "documents", "ids" for the ids of - the retrieved docs and "documents" for the contents of the retrieved docs. Any other keys are optional. Refer - to `chromadb.api.types.QueryResult` as an example. 
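Because `_results` is now a `QueryResults` value rather than the raw chromadb dict, the structure that `_get_context` walks is a list with one entry per query, each entry being `(Document, distance)` tuples. A small illustration with made-up documents:

```python
# QueryResults = List[List[Tuple[Document, float]]]; one inner list per query string.
results = [
    [
        ({"id": "1a2b3c4d", "content": "AutoGen is a framework ...", "metadata": {"source": "./docs/intro.md"}}, 0.12),
        ({"id": "9f8e7d6c", "content": "RetrieveChat adds RAG ...", "metadata": {"source": "./docs/rag.md"}}, 0.35),
    ]
]

for doc, distance in results[0]:
    print(doc["id"], distance, doc["metadata"]["source"])
```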
- ids: List[string] - documents: List[List[string]] + The retrieved docs should be type of `QueryResults` which is a list of tuples containing the document and + the distance. Args: problem (str): the problem to be solved. n_results (int): the number of results to be retrieved. Default is 20. search_string (str): only docs that contain an exact match of this string will be retrieved. Default is "". + Not used if the vector_db doesn't support it. + + Returns: + None. """ + if isinstance(self._vector_db, VectorDB): + if not self._collection or not self._get_or_create: + print("Trying to create collection.") + self._init_db() + self._collection = True + self._get_or_create = True + + kwargs = {} + if hasattr(self._vector_db, "type") and self._vector_db.type == "chroma": + kwargs["where_document"] = {"$contains": search_string} if search_string else None + results = self._vector_db.retrieve_docs( + queries=[problem], + n_results=n_results, + collection_name=self._collection_name, + distance_threshold=self._distance_threshold, + **kwargs, + ) + self._search_string = search_string + self._results = results + print("VectorDB returns doc_ids: ", [[r[0]["id"] for r in rr] for rr in results]) + return + if not self._collection or not self._get_or_create: print("Trying to create collection.") self._client = create_vector_db_from_dir( @@ -404,27 +597,39 @@ def retrieve_docs(self, problem: str, n_results: int = 20, search_string: str = embedding_model=self._embedding_model, embedding_function=self._embedding_function, ) + results["contents"] = results.pop("documents") + results = chroma_results_to_query_results(results, "distances") + results = filter_results_by_distance(results, self._distance_threshold) + self._search_string = search_string self._results = results - print("doc_ids: ", results["ids"]) - - def generate_init_message(self, problem: str, n_results: int = 20, search_string: str = ""): - """Generate an initial message with the given problem and prompt. + print("doc_ids: ", [[r[0]["id"] for r in rr] for rr in results]) + @staticmethod + def message_generator(sender, recipient, context): + """ + Generate an initial message with the given context for the RetrieveUserProxyAgent. Args: - problem (str): the problem to be solved. - n_results (int): the number of results to be retrieved. - search_string (str): only docs containing this string will be retrieved. - + sender (Agent): the sender agent. It should be the instance of RetrieveUserProxyAgent. + recipient (Agent): the recipient agent. Usually it's the assistant agent. + context (dict): the context for the message generation. It should contain the following keys: + - `problem` (str) - the problem to be solved. + - `n_results` (int) - the number of results to be retrieved. Default is 20. + - `search_string` (str) - only docs that contain an exact match of this string will be retrieved. Default is "". Returns: - str: the generated prompt ready to be sent to the assistant agent. + str: the generated message ready to be sent to the recipient agent. 
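The new static `message_generator` is meant to be handed to `initiate_chat` as a callable message, with extra keyword arguments forming its `context` dict. Roughly, continuing the configuration sketch above (the assistant, `config_list`, and question are placeholders, and this assumes `initiate_chat` forwards the kwargs to the generator):

```python
from autogen.agentchat.contrib.retrieve_assistant_agent import RetrieveAssistantAgent

assistant = RetrieveAssistantAgent(name="assistant", llm_config={"config_list": config_list})

ragproxyagent.initiate_chat(
    assistant,
    message=ragproxyagent.message_generator,  # invoked as message_generator(sender, recipient, context)
    problem="How does AutoGen's RetrieveChat work?",  # placeholder question
    n_results=20,
)
```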
""" - self._reset() - self.retrieve_docs(problem, n_results, search_string) - self.problem = problem - self.n_results = n_results - doc_contents = self._get_context(self._results) - message = self._generate_message(doc_contents, self._task) + sender._reset() + + problem = context.get("problem", "") + n_results = context.get("n_results", 20) + search_string = context.get("search_string", "") + + sender.retrieve_docs(problem, n_results, search_string) + sender.problem = problem + sender.n_results = n_results + doc_contents = sender._get_context(sender._results) + message = sender._generate_message(doc_contents, sender._task) return message def run_code(self, code, **kwargs): diff --git a/autogen/agentchat/contrib/society_of_mind_agent.py b/autogen/agentchat/contrib/society_of_mind_agent.py index 6a6f4aa2186..97cf6aee1a5 100644 --- a/autogen/agentchat/contrib/society_of_mind_agent.py +++ b/autogen/agentchat/contrib/society_of_mind_agent.py @@ -1,10 +1,11 @@ # ruff: noqa: E722 +import copy import json import traceback -import copy from dataclasses import dataclass -from typing import Dict, List, Optional, Union, Callable, Literal, Tuple -from autogen import Agent, ConversableAgent, GroupChatManager, GroupChat, OpenAIWrapper +from typing import Callable, Dict, List, Literal, Optional, Tuple, Union + +from autogen import Agent, ConversableAgent, GroupChat, GroupChatManager, OpenAIWrapper class SocietyOfMindAgent(ConversableAgent): diff --git a/autogen/agentchat/contrib/text_analyzer_agent.py b/autogen/agentchat/contrib/text_analyzer_agent.py index 10100c9e57f..e917cca574f 100644 --- a/autogen/agentchat/contrib/text_analyzer_agent.py +++ b/autogen/agentchat/contrib/text_analyzer_agent.py @@ -1,7 +1,8 @@ +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + from autogen import oai from autogen.agentchat.agent import Agent from autogen.agentchat.assistant_agent import ConversableAgent -from typing import Callable, Dict, Optional, Union, List, Tuple, Any system_message = """You are an expert in text analysis. The user will give you TEXT to analyze. diff --git a/autogen/agentchat/contrib/vectordb/__init__.py b/autogen/agentchat/contrib/vectordb/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/autogen/agentchat/contrib/vectordb/base.py b/autogen/agentchat/contrib/vectordb/base.py new file mode 100644 index 00000000000..29a08008619 --- /dev/null +++ b/autogen/agentchat/contrib/vectordb/base.py @@ -0,0 +1,213 @@ +from typing import Any, List, Mapping, Optional, Protocol, Sequence, Tuple, TypedDict, Union, runtime_checkable + +Metadata = Union[Mapping[str, Any], None] +Vector = Union[Sequence[float], Sequence[int]] +ItemID = Union[str, int] # chromadb doesn't support int ids, VikingDB does + + +class Document(TypedDict): + """A Document is a record in the vector database. + + id: ItemID | the unique identifier of the document. + content: str | the text content of the chunk. + metadata: Metadata, Optional | contains additional information about the document such as source, date, etc. + embedding: Vector, Optional | the vector representation of the content. + """ + + id: ItemID + content: str + metadata: Optional[Metadata] + embedding: Optional[Vector] + + +"""QueryResults is the response from the vector database for a query/queries. +A query is a list containing one string while queries is a list containing multiple strings. +The response is a list of query results, each query result is a list of tuples containing the document and the distance. 
+""" +QueryResults = List[List[Tuple[Document, float]]] + + +@runtime_checkable +class VectorDB(Protocol): + """ + Abstract class for vector database. A vector database is responsible for storing and retrieving documents. + + Attributes: + active_collection: Any | The active collection in the vector database. Make get_collection faster. Default is None. + type: str | The type of the vector database, chroma, pgvector, etc. Default is "". + + Methods: + create_collection: Callable[[str, bool, bool], Any] | Create a collection in the vector database. + get_collection: Callable[[str], Any] | Get the collection from the vector database. + delete_collection: Callable[[str], Any] | Delete the collection from the vector database. + insert_docs: Callable[[List[Document], str, bool], None] | Insert documents into the collection of the vector database. + update_docs: Callable[[List[Document], str], None] | Update documents in the collection of the vector database. + delete_docs: Callable[[List[ItemID], str], None] | Delete documents from the collection of the vector database. + retrieve_docs: Callable[[List[str], str, int, float], QueryResults] | Retrieve documents from the collection of the vector database based on the queries. + get_docs_by_ids: Callable[[List[ItemID], str], List[Document]] | Retrieve documents from the collection of the vector database based on the ids. + """ + + active_collection: Any = None + type: str = "" + + def create_collection(self, collection_name: str, overwrite: bool = False, get_or_create: bool = True) -> Any: + """ + Create a collection in the vector database. + Case 1. if the collection does not exist, create the collection. + Case 2. the collection exists, if overwrite is True, it will overwrite the collection. + Case 3. the collection exists and overwrite is False, if get_or_create is True, it will get the collection, + otherwise it raise a ValueError. + + Args: + collection_name: str | The name of the collection. + overwrite: bool | Whether to overwrite the collection if it exists. Default is False. + get_or_create: bool | Whether to get the collection if it exists. Default is True. + + Returns: + Any | The collection object. + """ + ... + + def get_collection(self, collection_name: str = None) -> Any: + """ + Get the collection from the vector database. + + Args: + collection_name: str | The name of the collection. Default is None. If None, return the + current active collection. + + Returns: + Any | The collection object. + """ + ... + + def delete_collection(self, collection_name: str) -> Any: + """ + Delete the collection from the vector database. + + Args: + collection_name: str | The name of the collection. + + Returns: + Any + """ + ... + + def insert_docs(self, docs: List[Document], collection_name: str = None, upsert: bool = False, **kwargs) -> None: + """ + Insert documents into the collection of the vector database. + + Args: + docs: List[Document] | A list of documents. Each document is a TypedDict `Document`. + collection_name: str | The name of the collection. Default is None. + upsert: bool | Whether to update the document if it exists. Default is False. + kwargs: Dict | Additional keyword arguments. + + Returns: + None + """ + ... + + def update_docs(self, docs: List[Document], collection_name: str = None, **kwargs) -> None: + """ + Update documents in the collection of the vector database. + + Args: + docs: List[Document] | A list of documents. + collection_name: str | The name of the collection. Default is None. 
+ kwargs: Dict | Additional keyword arguments. + + Returns: + None + """ + ... + + def delete_docs(self, ids: List[ItemID], collection_name: str = None, **kwargs) -> None: + """ + Delete documents from the collection of the vector database. + + Args: + ids: List[ItemID] | A list of document ids. Each id is a typed `ItemID`. + collection_name: str | The name of the collection. Default is None. + kwargs: Dict | Additional keyword arguments. + + Returns: + None + """ + ... + + def retrieve_docs( + self, + queries: List[str], + collection_name: str = None, + n_results: int = 10, + distance_threshold: float = -1, + **kwargs, + ) -> QueryResults: + """ + Retrieve documents from the collection of the vector database based on the queries. + + Args: + queries: List[str] | A list of queries. Each query is a string. + collection_name: str | The name of the collection. Default is None. + n_results: int | The number of relevant documents to return. Default is 10. + distance_threshold: float | The threshold for the distance score, only distance smaller than it will be + returned. Don't filter with it if < 0. Default is -1. + kwargs: Dict | Additional keyword arguments. + + Returns: + QueryResults | The query results. Each query result is a list of list of tuples containing the document and + the distance. + """ + ... + + def get_docs_by_ids( + self, ids: List[ItemID] = None, collection_name: str = None, include=None, **kwargs + ) -> List[Document]: + """ + Retrieve documents from the collection of the vector database based on the ids. + + Args: + ids: List[ItemID] | A list of document ids. If None, will return all the documents. Default is None. + collection_name: str | The name of the collection. Default is None. + include: List[str] | The fields to include. Default is None. + If None, will include ["metadatas", "documents"], ids will always be included. + kwargs: dict | Additional keyword arguments. + + Returns: + List[Document] | The results. + """ + ... + + +class VectorDBFactory: + """ + Factory class for creating vector databases. + """ + + PREDEFINED_VECTOR_DB = ["chroma", "pgvector"] + + @staticmethod + def create_vector_db(db_type: str, **kwargs) -> VectorDB: + """ + Create a vector database. + + Args: + db_type: str | The type of the vector database. + kwargs: Dict | The keyword arguments for initializing the vector database. + + Returns: + VectorDB | The vector database. + """ + if db_type.lower() in ["chroma", "chromadb"]: + from .chromadb import ChromaVectorDB + + return ChromaVectorDB(**kwargs) + if db_type.lower() in ["pgvector", "pgvectordb"]: + from .pgvectordb import PGVectorDB + + return PGVectorDB(**kwargs) + else: + raise ValueError( + f"Unsupported vector database type: {db_type}. Valid types are {VectorDBFactory.PREDEFINED_VECTOR_DB}." 
+ ) diff --git a/autogen/agentchat/contrib/vectordb/chromadb.py b/autogen/agentchat/contrib/vectordb/chromadb.py new file mode 100644 index 00000000000..3f1fbc86a44 --- /dev/null +++ b/autogen/agentchat/contrib/vectordb/chromadb.py @@ -0,0 +1,318 @@ +import os +from typing import Callable, List + +from .base import Document, ItemID, QueryResults, VectorDB +from .utils import chroma_results_to_query_results, filter_results_by_distance, get_logger + +try: + import chromadb + + if chromadb.__version__ < "0.4.15": + raise ImportError("Please upgrade chromadb to version 0.4.15 or later.") + import chromadb.utils.embedding_functions as ef + from chromadb.api.models.Collection import Collection +except ImportError: + raise ImportError("Please install chromadb: `pip install chromadb`") + +CHROMADB_MAX_BATCH_SIZE = os.environ.get("CHROMADB_MAX_BATCH_SIZE", 40000) +logger = get_logger(__name__) + + +class ChromaVectorDB(VectorDB): + """ + A vector database that uses ChromaDB as the backend. + """ + + def __init__( + self, *, client=None, path: str = "tmp/db", embedding_function: Callable = None, metadata: dict = None, **kwargs + ) -> None: + """ + Initialize the vector database. + + Args: + client: chromadb.Client | The client object of the vector database. Default is None. + If provided, it will use the client object directly and ignore other arguments. + path: str | The path to the vector database. Default is `tmp/db`. The default was `None` for version <=0.2.24. + embedding_function: Callable | The embedding function used to generate the vector representation + of the documents. Default is None, SentenceTransformerEmbeddingFunction("all-MiniLM-L6-v2") will be used. + metadata: dict | The metadata of the vector database. Default is None. If None, it will use this + setting: {"hnsw:space": "ip", "hnsw:construction_ef": 30, "hnsw:M": 32}. For more details of + the metadata, please refer to [distances](https://github.com/nmslib/hnswlib#supported-distances), + [hnsw](https://github.com/chroma-core/chroma/blob/566bc80f6c8ee29f7d99b6322654f32183c368c4/chromadb/segment/impl/vector/local_hnsw.py#L184), + and [ALGO_PARAMS](https://github.com/nmslib/hnswlib/blob/master/ALGO_PARAMS.md). + kwargs: dict | Additional keyword arguments. + + Returns: + None + """ + self.client = client + self.path = path + self.embedding_function = ( + ef.SentenceTransformerEmbeddingFunction("all-MiniLM-L6-v2") + if embedding_function is None + else embedding_function + ) + self.metadata = metadata if metadata else {"hnsw:space": "ip", "hnsw:construction_ef": 30, "hnsw:M": 32} + if not self.client: + if self.path is not None: + self.client = chromadb.PersistentClient(path=self.path, **kwargs) + else: + self.client = chromadb.Client(**kwargs) + self.active_collection = None + self.type = "chroma" + + def create_collection( + self, collection_name: str, overwrite: bool = False, get_or_create: bool = True + ) -> Collection: + """ + Create a collection in the vector database. + Case 1. if the collection does not exist, create the collection. + Case 2. the collection exists, if overwrite is True, it will overwrite the collection. + Case 3. the collection exists and overwrite is False, if get_or_create is True, it will get the collection, + otherwise it raise a ValueError. + + Args: + collection_name: str | The name of the collection. + overwrite: bool | Whether to overwrite the collection if it exists. Default is False. + get_or_create: bool | Whether to get the collection if it exists. Default is True. 
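A hedged usage sketch for the factory defined above; the path and collection name are placeholders, and keyword arguments are forwarded to the chosen backend's constructor:

```python
from autogen.agentchat.contrib.vectordb.base import VectorDBFactory

# "chroma" dispatches to ChromaVectorDB; "pgvector" would dispatch to PGVectorDB.
db = VectorDBFactory.create_vector_db("chroma", path="tmp/demo_db")
collection = db.create_collection("demo_docs", overwrite=False, get_or_create=True)
```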
+ + Returns: + Collection | The collection object. + """ + try: + if self.active_collection and self.active_collection.name == collection_name: + collection = self.active_collection + else: + collection = self.client.get_collection(collection_name) + except ValueError: + collection = None + if collection is None: + return self.client.create_collection( + collection_name, + embedding_function=self.embedding_function, + get_or_create=get_or_create, + metadata=self.metadata, + ) + elif overwrite: + self.client.delete_collection(collection_name) + return self.client.create_collection( + collection_name, + embedding_function=self.embedding_function, + get_or_create=get_or_create, + metadata=self.metadata, + ) + elif get_or_create: + return collection + else: + raise ValueError(f"Collection {collection_name} already exists.") + + def get_collection(self, collection_name: str = None) -> Collection: + """ + Get the collection from the vector database. + + Args: + collection_name: str | The name of the collection. Default is None. If None, return the + current active collection. + + Returns: + Collection | The collection object. + """ + if collection_name is None: + if self.active_collection is None: + raise ValueError("No collection is specified.") + else: + logger.info( + f"No collection is specified. Using current active collection {self.active_collection.name}." + ) + else: + if not (self.active_collection and self.active_collection.name == collection_name): + self.active_collection = self.client.get_collection(collection_name) + return self.active_collection + + def delete_collection(self, collection_name: str) -> None: + """ + Delete the collection from the vector database. + + Args: + collection_name: str | The name of the collection. + + Returns: + None + """ + self.client.delete_collection(collection_name) + if self.active_collection and self.active_collection.name == collection_name: + self.active_collection = None + + def _batch_insert( + self, collection: Collection, embeddings=None, ids=None, metadatas=None, documents=None, upsert=False + ) -> None: + batch_size = int(CHROMADB_MAX_BATCH_SIZE) + for i in range(0, len(documents), min(batch_size, len(documents))): + end_idx = i + min(batch_size, len(documents) - i) + collection_kwargs = { + "documents": documents[i:end_idx], + "ids": ids[i:end_idx], + "metadatas": metadatas[i:end_idx] if metadatas else None, + "embeddings": embeddings[i:end_idx] if embeddings else None, + } + if upsert: + collection.upsert(**collection_kwargs) + else: + collection.add(**collection_kwargs) + + def insert_docs(self, docs: List[Document], collection_name: str = None, upsert: bool = False) -> None: + """ + Insert documents into the collection of the vector database. + + Args: + docs: List[Document] | A list of documents. Each document is a TypedDict `Document`. + collection_name: str | The name of the collection. Default is None. + upsert: bool | Whether to update the document if it exists. Default is False. + kwargs: Dict | Additional keyword arguments. + + Returns: + None + """ + if not docs: + return + if docs[0].get("content") is None: + raise ValueError("The document content is required.") + if docs[0].get("id") is None: + raise ValueError("The document id is required.") + documents = [doc.get("content") for doc in docs] + ids = [doc.get("id") for doc in docs] + collection = self.get_collection(collection_name) + if docs[0].get("embedding") is None: + logger.info( + "No content embedding is provided. 
Will use the VectorDB's embedding function to generate the content embedding." + ) + embeddings = None + else: + embeddings = [doc.get("embedding") for doc in docs] + if docs[0].get("metadata") is None: + metadatas = None + else: + metadatas = [doc.get("metadata") for doc in docs] + self._batch_insert(collection, embeddings, ids, metadatas, documents, upsert) + + def update_docs(self, docs: List[Document], collection_name: str = None) -> None: + """ + Update documents in the collection of the vector database. + + Args: + docs: List[Document] | A list of documents. + collection_name: str | The name of the collection. Default is None. + + Returns: + None + """ + self.insert_docs(docs, collection_name, upsert=True) + + def delete_docs(self, ids: List[ItemID], collection_name: str = None, **kwargs) -> None: + """ + Delete documents from the collection of the vector database. + + Args: + ids: List[ItemID] | A list of document ids. Each id is a typed `ItemID`. + collection_name: str | The name of the collection. Default is None. + kwargs: Dict | Additional keyword arguments. + + Returns: + None + """ + collection = self.get_collection(collection_name) + collection.delete(ids, **kwargs) + + def retrieve_docs( + self, + queries: List[str], + collection_name: str = None, + n_results: int = 10, + distance_threshold: float = -1, + **kwargs, + ) -> QueryResults: + """ + Retrieve documents from the collection of the vector database based on the queries. + + Args: + queries: List[str] | A list of queries. Each query is a string. + collection_name: str | The name of the collection. Default is None. + n_results: int | The number of relevant documents to return. Default is 10. + distance_threshold: float | The threshold for the distance score, only distance smaller than it will be + returned. Don't filter with it if < 0. Default is -1. + kwargs: Dict | Additional keyword arguments. + + Returns: + QueryResults | The query results. Each query result is a list of list of tuples containing the document and + the distance. + """ + collection = self.get_collection(collection_name) + if isinstance(queries, str): + queries = [queries] + results = collection.query( + query_texts=queries, + n_results=n_results, + **kwargs, + ) + results["contents"] = results.pop("documents") + results = chroma_results_to_query_results(results) + results = filter_results_by_distance(results, distance_threshold) + return results + + @staticmethod + def _chroma_get_results_to_list_documents(data_dict) -> List[Document]: + """Converts a dictionary with list values to a list of Document. + + Args: + data_dict: A dictionary where keys map to lists or None. + + Returns: + List[Document] | The list of Document. + + Example: + data_dict = { + "key1s": [1, 2, 3], + "key2s": ["a", "b", "c"], + "key3s": None, + "key4s": ["x", "y", "z"], + } + + results = [ + {"key1": 1, "key2": "a", "key4": "x"}, + {"key1": 2, "key2": "b", "key4": "y"}, + {"key1": 3, "key2": "c", "key4": "z"}, + ] + """ + + results = [] + keys = [key for key in data_dict if data_dict[key] is not None] + + for i in range(len(data_dict[keys[0]])): + sub_dict = {} + for key in data_dict.keys(): + if data_dict[key] is not None and len(data_dict[key]) > i: + sub_dict[key[:-1]] = data_dict[key][i] + results.append(sub_dict) + return results + + def get_docs_by_ids( + self, ids: List[ItemID] = None, collection_name: str = None, include=None, **kwargs + ) -> List[Document]: + """ + Retrieve documents from the collection of the vector database based on the ids. 
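A hedged end-to-end sketch of ChromaVectorDB insertion and retrieval; the documents and query below are invented for illustration:

```python
from autogen.agentchat.contrib.vectordb.chromadb import ChromaVectorDB

db = ChromaVectorDB(path="tmp/demo_db")
db.create_collection("demo_docs")
db.insert_docs(
    [
        {"id": "1", "content": "AutoGen agents can retrieve documents."},
        {"id": "2", "content": "ChromaDB stores embeddings locally."},
    ],
    collection_name="demo_docs",
)
results = db.retrieve_docs(["How are documents retrieved?"], collection_name="demo_docs", n_results=2)
for doc, distance in results[0]:  # QueryResults: one list of (document, distance) pairs per query
    print(doc["id"], distance)
```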
+ + Args: + ids: List[ItemID] | A list of document ids. If None, will return all the documents. Default is None. + collection_name: str | The name of the collection. Default is None. + include: List[str] | The fields to include. Default is None. + If None, will include ["metadatas", "documents"], ids will always be included. + kwargs: dict | Additional keyword arguments. + + Returns: + List[Document] | The results. + """ + collection = self.get_collection(collection_name) + include = include if include else ["metadatas", "documents"] + results = collection.get(ids, include=include, **kwargs) + results = self._chroma_get_results_to_list_documents(results) + return results diff --git a/autogen/agentchat/contrib/vectordb/pgvectordb.py b/autogen/agentchat/contrib/vectordb/pgvectordb.py new file mode 100644 index 00000000000..ae9d5cbbbec --- /dev/null +++ b/autogen/agentchat/contrib/vectordb/pgvectordb.py @@ -0,0 +1,736 @@ +import os +import re +from typing import Callable, List + +import numpy as np +from sentence_transformers import SentenceTransformer + +from .base import Document, ItemID, QueryResults, VectorDB +from .utils import get_logger + +try: + import pgvector + from pgvector.psycopg import register_vector +except ImportError: + raise ImportError("Please install pgvector: `pip install pgvector`") + +try: + import psycopg +except ImportError: + raise ImportError("Please install pgvector: `pip install psycopg`") + +PGVECTOR_MAX_BATCH_SIZE = os.environ.get("PGVECTOR_MAX_BATCH_SIZE", 40000) +logger = get_logger(__name__) + + +class Collection: + """ + A Collection object for PGVector. + + Attributes: + client: The PGVector client. + collection_name (str): The name of the collection. Default is "documents". + embedding_function (Callable): The embedding function used to generate the vector representation. + metadata (Optional[dict]): The metadata of the collection. + get_or_create (Optional): The flag indicating whether to get or create the collection. + + """ + + def __init__( + self, + client=None, + collection_name: str = "autogen-docs", + embedding_function: Callable = None, + metadata=None, + get_or_create=None, + ): + """ + Initialize the Collection object. + + Args: + client: The PostgreSQL client. + collection_name: The name of the collection. Default is "documents". + embedding_function: The embedding function used to generate the vector representation. + metadata: The metadata of the collection. + get_or_create: The flag indicating whether to get or create the collection. + + Returns: + None + """ + self.client = client + self.embedding_function = embedding_function + self.name = self.set_collection_name(collection_name) + self.require_embeddings_or_documents = False + self.ids = [] + self.embedding_function = ( + SentenceTransformer("all-MiniLM-L6-v2") if embedding_function is None else embedding_function + ) + self.metadata = metadata if metadata else {"hnsw:space": "ip", "hnsw:construction_ef": 32, "hnsw:M": 16} + self.documents = "" + self.get_or_create = get_or_create + + def set_collection_name(self, collection_name): + name = re.sub("-", "_", collection_name) + self.name = name + return self.name + + def add(self, ids: List[ItemID], embeddings: List, metadatas: List, documents: List): + """ + Add documents to the collection. + + Args: + ids (List[ItemID]): A list of document IDs. + embeddings (List): A list of document embeddings. + metadatas (List): A list of document metadatas. + documents (List): A list of documents. 
+ + Returns: + None + """ + cursor = self.client.cursor() + + sql_values = [] + for doc_id, embedding, metadata, document in zip(ids, embeddings, metadatas, documents): + sql_values.append((doc_id, embedding, metadata, document)) + sql_string = f"INSERT INTO {self.name} (id, embedding, metadata, document) " f"VALUES (%s, %s, %s, %s);" + cursor.executemany(sql_string, sql_values) + cursor.close() + + def upsert(self, ids: List[ItemID], documents: List, embeddings: List = None, metadatas: List = None) -> None: + """ + Upsert documents into the collection. + + Args: + ids (List[ItemID]): A list of document IDs. + documents (List): A list of documents. + embeddings (List): A list of document embeddings. + metadatas (List): A list of document metadatas. + + Returns: + None + """ + cursor = self.client.cursor() + sql_values = [] + if embeddings is not None and metadatas is not None: + for doc_id, embedding, metadata, document in zip(ids, embeddings, metadatas, documents): + metadata = re.sub("'", '"', str(metadata)) + sql_values.append((doc_id, embedding, metadata, document, embedding, metadata, document)) + sql_string = ( + f"INSERT INTO {self.name} (id, embedding, metadatas, documents)\n" + f"VALUES (%s, %s, %s, %s)\n" + f"ON CONFLICT (id)\n" + f"DO UPDATE SET embedding = %s,\n" + f"metadatas = %s, documents = %s;\n" + ) + elif embeddings is not None: + for doc_id, embedding, document in zip(ids, embeddings, documents): + sql_values.append((doc_id, embedding, document, embedding, document)) + sql_string = ( + f"INSERT INTO {self.name} (id, embedding, documents) " + f"VALUES (%s, %s, %s) ON CONFLICT (id)\n" + f"DO UPDATE SET embedding = %s, documents = %s;\n" + ) + elif metadatas is not None: + for doc_id, metadata, document in zip(ids, metadatas, documents): + metadata = re.sub("'", '"', str(metadata)) + embedding = self.embedding_function.encode(document) + sql_values.append((doc_id, metadata, embedding, document, metadata, document, embedding)) + sql_string = ( + f"INSERT INTO {self.name} (id, metadatas, embedding, documents)\n" + f"VALUES (%s, %s, %s, %s)\n" + f"ON CONFLICT (id)\n" + f"DO UPDATE SET metadatas = %s, documents = %s, embedding = %s;\n" + ) + else: + for doc_id, document in zip(ids, documents): + embedding = self.embedding_function.encode(document) + sql_values.append((doc_id, document, embedding, document)) + sql_string = ( + f"INSERT INTO {self.name} (id, documents, embedding)\n" + f"VALUES (%s, %s, %s)\n" + f"ON CONFLICT (id)\n" + f"DO UPDATE SET documents = %s;\n" + ) + logger.debug(f"Upsert SQL String:\n{sql_string}\n{sql_values}") + cursor.executemany(sql_string, sql_values) + cursor.close() + + def count(self): + """ + Get the total number of documents in the collection. + + Returns: + int: The total number of documents. + """ + cursor = self.client.cursor() + query = f"SELECT COUNT(*) FROM {self.name}" + cursor.execute(query) + total = cursor.fetchone()[0] + cursor.close() + try: + total = int(total) + except (TypeError, ValueError): + total = None + return total + + def get(self, ids=None, include=None, where=None, limit=None, offset=None): + """ + Retrieve documents from the collection. + + Args: + ids (Optional[List]): A list of document IDs. + include (Optional): The fields to include. + where (Optional): Additional filtering criteria. + limit (Optional): The maximum number of documents to retrieve. + offset (Optional): The offset for pagination. + + Returns: + List: The retrieved documents. 
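The upsert above assembles an INSERT ... ON CONFLICT statement; a minimal standalone sketch of that pattern with psycopg follows. Table and column names are illustrative, and EXCLUDED is used here as an equivalent way to reference the conflicting row:

```python
import psycopg

# Assumes a reachable PostgreSQL instance; the conninfo string is a placeholder.
with psycopg.connect("dbname=postgres user=postgres", autocommit=True) as conn:
    with conn.cursor() as cur:
        cur.executemany(
            "INSERT INTO autogen_docs (id, documents) VALUES (%s, %s) "
            "ON CONFLICT (id) DO UPDATE SET documents = EXCLUDED.documents;",
            [("doc_0001", "hello"), ("doc_0002", "world")],
        )
```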
+ """ + cursor = self.client.cursor() + if include: + query = f'SELECT (id, {", ".join(map(str, include))}, embedding) FROM {self.name}' + else: + query = f"SELECT * FROM {self.name}" + if ids: + query = f"{query} WHERE id IN {ids}" + elif where: + query = f"{query} WHERE {where}" + if offset: + query = f"{query} OFFSET {offset}" + if limit: + query = f"{query} LIMIT {limit}" + retreived_documents = [] + try: + cursor.execute(query) + retrieval = cursor.fetchall() + for retrieved_document in retrieval: + retreived_documents.append( + Document( + id=retrieved_document[0][0], + metadata=retrieved_document[0][1], + content=retrieved_document[0][2], + embedding=retrieved_document[0][3], + ) + ) + except (psycopg.errors.UndefinedTable, psycopg.errors.UndefinedColumn): + logger.info(f"Error executing select on non-existent table: {self.name}. Creating it instead.") + self.create_collection(collection_name=self.name) + logger.info(f"Created table {self.name}") + cursor.close() + return retreived_documents + + def update(self, ids: List, embeddings: List, metadatas: List, documents: List): + """ + Update documents in the collection. + + Args: + ids (List): A list of document IDs. + embeddings (List): A list of document embeddings. + metadatas (List): A list of document metadatas. + documents (List): A list of documents. + + Returns: + None + """ + cursor = self.client.cursor() + sql_values = [] + for doc_id, embedding, metadata, document in zip(ids, embeddings, metadatas, documents): + sql_values.append((doc_id, embedding, metadata, document, doc_id, embedding, metadata, document)) + sql_string = ( + f"INSERT INTO {self.name} (id, embedding, metadata, document) " + f"VALUES (%s, %s, %s, %s) " + f"ON CONFLICT (id) " + f"DO UPDATE SET id = %s, embedding = %s, " + f"metadata = %s, document = %s;\n" + ) + logger.debug(f"Upsert SQL String:\n{sql_string}\n") + cursor.executemany(sql_string, sql_values) + cursor.close() + + @staticmethod + def euclidean_distance(arr1: List[float], arr2: List[float]) -> float: + """ + Calculate the Euclidean distance between two vectors. + + Parameters: + - arr1 (List[float]): The first vector. + - arr2 (List[float]): The second vector. + + Returns: + - float: The Euclidean distance between arr1 and arr2. + """ + dist = np.linalg.norm(arr1 - arr2) + return dist + + @staticmethod + def cosine_distance(arr1: List[float], arr2: List[float]) -> float: + """ + Calculate the cosine distance between two vectors. + + Parameters: + - arr1 (List[float]): The first vector. + - arr2 (List[float]): The second vector. + + Returns: + - float: The cosine distance between arr1 and arr2. + """ + dist = np.dot(arr1, arr2) / (np.linalg.norm(arr1) * np.linalg.norm(arr2)) + return dist + + @staticmethod + def inner_product_distance(arr1: List[float], arr2: List[float]) -> float: + """ + Calculate the Euclidean distance between two vectors. + + Parameters: + - arr1 (List[float]): The first vector. + - arr2 (List[float]): The second vector. + + Returns: + - float: The Euclidean distance between arr1 and arr2. + """ + dist = np.linalg.norm(arr1 - arr2) + return dist + + def query( + self, + query_texts: List[str], + collection_name: str = None, + n_results: int = 10, + distance_type: str = "euclidean", + distance_threshold: float = -1, + ) -> QueryResults: + """ + Query documents in the collection. + + Args: + query_texts (List[str]): A list of query texts. + collection_name (Optional[str]): The name of the collection. + n_results (int): The maximum number of results to return. 
+ distance_type (Optional[str]): Distance search type - euclidean or cosine + distance_threshold (Optional[float]): Distance threshold to limit searches + Returns: + QueryResults: The query results. + """ + if collection_name: + self.name = collection_name + + if distance_threshold == -1: + distance_threshold = "" + elif distance_threshold > 0: + distance_threshold = f"< {distance_threshold}" + + cursor = self.client.cursor() + results = [] + for query in query_texts: + vector = self.embedding_function.encode(query, convert_to_tensor=False).tolist() + if distance_type.lower() == "cosine": + index_function = "<=>" + elif distance_type.lower() == "euclidean": + index_function = "<->" + elif distance_type.lower() == "inner-product": + index_function = "<#>" + else: + index_function = "<->" + + query = ( + f"SELECT id, documents, embedding, metadatas FROM {self.name}\n" + f"ORDER BY embedding {index_function} '{str(vector)}'::vector {distance_threshold}\n" + f"LIMIT {n_results}" + ) + cursor.execute(query) + for row in cursor.fetchall(): + fetched_document = Document(id=row[0], content=row[1], embedding=row[2], metadata=row[3]) + fetched_document_array = self.convert_string_to_array(array_string=fetched_document.get("embedding")) + if distance_type.lower() == "cosine": + distance = self.cosine_distance(fetched_document_array, vector) + elif distance_type.lower() == "euclidean": + distance = self.euclidean_distance(fetched_document_array, vector) + elif distance_type.lower() == "inner-product": + distance = self.inner_product_distance(fetched_document_array, vector) + else: + distance = self.euclidean_distance(fetched_document_array, vector) + results.append((fetched_document, distance)) + cursor.close() + results = [results] + logger.debug(f"Query Results: {results}") + return results + + @staticmethod + def convert_string_to_array(array_string) -> List[float]: + """ + Convert a string representation of an array to a list of floats. + + Parameters: + - array_string (str): The string representation of the array. + + Returns: + - list: A list of floats parsed from the input string. If the input is + not a string, it returns the input itself. + """ + if not isinstance(array_string, str): + return array_string + array_string = array_string.strip("[]") + array = [float(num) for num in array_string.split()] + return array + + def modify(self, metadata, collection_name: str = None): + """ + Modify metadata for the collection. + + Args: + collection_name: The name of the collection. + metadata: The new metadata. + + Returns: + None + """ + if collection_name: + self.name = collection_name + cursor = self.client.cursor() + cursor.execute( + "UPDATE collections" "SET metadata = '%s'" "WHERE collection_name = '%s';", (metadata, self.name) + ) + cursor.close() + + def delete(self, ids: List[ItemID], collection_name: str = None): + """ + Delete documents from the collection. + + Args: + ids (List[ItemID]): A list of document IDs to delete. + collection_name (str): The name of the collection to delete. + + Returns: + None + """ + if collection_name: + self.name = collection_name + cursor = self.client.cursor() + cursor.execute(f"DELETE FROM {self.name} WHERE id IN ({ids});") + cursor.close() + + def delete_collection(self, collection_name: str = None): + """ + Delete the entire collection. + + Args: + collection_name (Optional[str]): The name of the collection to delete. 
+ + Returns: + None + """ + if collection_name: + self.name = collection_name + cursor = self.client.cursor() + cursor.execute(f"DROP TABLE IF EXISTS {self.name}") + cursor.close() + + def create_collection(self, collection_name: str = None): + """ + Create a new collection. + + Args: + collection_name (Optional[str]): The name of the new collection. + + Returns: + None + """ + if collection_name: + self.name = collection_name + cursor = self.client.cursor() + cursor.execute( + f"CREATE TABLE {self.name} (" + f"documents text, id CHAR(8) PRIMARY KEY, metadatas JSONB, embedding vector(384));" + f"CREATE INDEX " + f'ON {self.name} USING hnsw (embedding vector_l2_ops) WITH (m = {self.metadata["hnsw:M"]}, ' + f'ef_construction = {self.metadata["hnsw:construction_ef"]});' + f"CREATE INDEX " + f'ON {self.name} USING hnsw (embedding vector_cosine_ops) WITH (m = {self.metadata["hnsw:M"]}, ' + f'ef_construction = {self.metadata["hnsw:construction_ef"]});' + f"CREATE INDEX " + f'ON {self.name} USING hnsw (embedding vector_ip_ops) WITH (m = {self.metadata["hnsw:M"]}, ' + f'ef_construction = {self.metadata["hnsw:construction_ef"]});' + ) + cursor.close() + + +class PGVectorDB(VectorDB): + """ + A vector database that uses PGVector as the backend. + """ + + def __init__( + self, + *, + connection_string: str = None, + host: str = None, + port: int = None, + dbname: str = None, + connect_timeout: int = 10, + embedding_function: Callable = None, + metadata: dict = None, + ) -> None: + """ + Initialize the vector database. + + Note: connection_string or host + port + dbname must be specified + + Args: + connection_string: "postgresql://username:password@hostname:port/database" | The PGVector connection string. Default is None. + host: str | The host to connect to. Default is None. + port: int | The port to connect to. Default is None. + dbname: str | The database name to connect to. Default is None. + connect_timeout: int | The timeout to set for the connection. Default is 10. + embedding_function: Callable | The embedding function used to generate the vector representation + of the documents. Default is None. + metadata: dict | The metadata of the vector database. Default is None. If None, it will use this + setting: {"hnsw:space": "ip", "hnsw:construction_ef": 30, "hnsw:M": 16}. Creates Index on table + using hnsw (embedding vector_l2_ops) WITH (m = hnsw:M) ef_construction = "hnsw:construction_ef". + For more info: https://github.com/pgvector/pgvector?tab=readme-ov-file#hnsw + kwargs: dict | Additional keyword arguments. + + Returns: + None + """ + if connection_string: + self.client = psycopg.connect(conninfo=connection_string, autocommit=True) + elif host and port and dbname: + self.client = psycopg.connect( + host=host, port=port, dbname=dbname, connect_timeout=connect_timeout, autocommit=True + ) + self.embedding_function = ( + SentenceTransformer("all-MiniLM-L6-v2") if embedding_function is None else embedding_function + ) + self.metadata = metadata + self.client.execute("CREATE EXTENSION IF NOT EXISTS vector") + register_vector(self.client) + self.active_collection = None + + def create_collection( + self, collection_name: str, overwrite: bool = False, get_or_create: bool = True + ) -> Collection: + """ + Create a collection in the vector database. + Case 1. if the collection does not exist, create the collection. + Case 2. the collection exists, if overwrite is True, it will overwrite the collection. + Case 3. 
the collection exists and overwrite is False, if get_or_create is True, it will get the collection, + otherwise it raise a ValueError. + + Args: + collection_name: str | The name of the collection. + overwrite: bool | Whether to overwrite the collection if it exists. Default is False. + get_or_create: bool | Whether to get the collection if it exists. Default is True. + + Returns: + Collection | The collection object. + """ + try: + if self.active_collection and self.active_collection.name == collection_name: + collection = self.active_collection + else: + collection = self.get_collection(collection_name) + except ValueError: + collection = None + if collection is None: + collection = Collection( + collection_name=collection_name, + embedding_function=self.embedding_function, + get_or_create=get_or_create, + metadata=self.metadata, + ) + collection.set_collection_name(collection_name=collection_name) + collection.create_collection(collection_name=collection_name) + return collection + elif overwrite: + self.delete_collection(collection_name) + collection = Collection( + collection_name=collection_name, + embedding_function=self.embedding_function, + get_or_create=get_or_create, + metadata=self.metadata, + ) + collection.set_collection_name(collection_name=collection_name) + collection.create_collection(collection_name=collection_name) + return collection + elif get_or_create: + return collection + else: + raise ValueError(f"Collection {collection_name} already exists.") + + def get_collection(self, collection_name: str = None) -> Collection: + """ + Get the collection from the vector database. + + Args: + collection_name: str | The name of the collection. Default is None. If None, return the + current active collection. + + Returns: + Collection | The collection object. + """ + if collection_name is None: + if self.active_collection is None: + raise ValueError("No collection is specified.") + else: + logger.debug( + f"No collection is specified. Using current active collection {self.active_collection.name}." + ) + else: + self.active_collection = Collection( + client=self.client, collection_name=collection_name, embedding_function=self.embedding_function + ) + return self.active_collection + + def delete_collection(self, collection_name: str) -> None: + """ + Delete the collection from the vector database. + + Args: + collection_name: str | The name of the collection. + + Returns: + None + """ + self.active_collection.delete_collection(collection_name) + if self.active_collection and self.active_collection.name == collection_name: + self.active_collection = None + + def _batch_insert( + self, collection: Collection, embeddings=None, ids=None, metadatas=None, documents=None, upsert=False + ): + batch_size = int(PGVECTOR_MAX_BATCH_SIZE) + default_metadata = {"hnsw:space": "ip", "hnsw:construction_ef": 32, "hnsw:M": 16} + default_metadatas = [default_metadata] + for i in range(0, len(documents), min(batch_size, len(documents))): + end_idx = i + min(batch_size, len(documents) - i) + collection_kwargs = { + "documents": documents[i:end_idx], + "ids": ids[i:end_idx], + "metadatas": metadatas[i:end_idx] if metadatas else default_metadatas, + "embeddings": embeddings[i:end_idx] if embeddings else None, + } + if upsert: + collection.upsert(**collection_kwargs) + else: + collection.add(**collection_kwargs) + + def insert_docs(self, docs: List[Document], collection_name: str = None, upsert: bool = False) -> None: + """ + Insert documents into the collection of the vector database. 
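A hedged sketch of the intended PGVectorDB call sequence, assuming a running PostgreSQL server with the pgvector extension; the connection string, collection name, and documents are placeholders (ids are kept to 8 characters to match the CHAR(8) primary key created above):

```python
from autogen.agentchat.contrib.vectordb.pgvectordb import PGVectorDB

db = PGVectorDB(connection_string="postgresql://user:pass@localhost:5432/vectordb")
db.create_collection("autogen_docs", overwrite=True)
db.insert_docs(
    [
        {"id": "doc_0001", "content": "AutoGen supports retrieval-augmented chat."},
        {"id": "doc_0002", "content": "pgvector stores embeddings inside PostgreSQL."},
    ],
    collection_name="autogen_docs",
)
results = db.retrieve_docs(["Where are embeddings stored?"], collection_name="autogen_docs", n_results=2)
```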
+ + Args: + docs: List[Document] | A list of documents. Each document is a TypedDict `Document`. + collection_name: str | The name of the collection. Default is None. + upsert: bool | Whether to update the document if it exists. Default is False. + kwargs: Dict | Additional keyword arguments. + + Returns: + None + """ + if not docs: + return + if docs[0].get("content") is None: + raise ValueError("The document content is required.") + if docs[0].get("id") is None: + raise ValueError("The document id is required.") + documents = [doc.get("content") for doc in docs] + ids = [doc.get("id") for doc in docs] + + collection = self.get_collection(collection_name) + if docs[0].get("embedding") is None: + logger.debug( + "No content embedding is provided. " + "Will use the VectorDB's embedding function to generate the content embedding." + ) + embeddings = None + else: + embeddings = [doc.get("embedding") for doc in docs] + if docs[0].get("metadata") is None: + metadatas = None + else: + metadatas = [doc.get("metadata") for doc in docs] + + self._batch_insert(collection, embeddings, ids, metadatas, documents, upsert) + + def update_docs(self, docs: List[Document], collection_name: str = None) -> None: + """ + Update documents in the collection of the vector database. + + Args: + docs: List[Document] | A list of documents. + collection_name: str | The name of the collection. Default is None. + + Returns: + None + """ + self.insert_docs(docs, collection_name, upsert=True) + + def delete_docs(self, ids: List[ItemID], collection_name: str = None) -> None: + """ + Delete documents from the collection of the vector database. + + Args: + ids: List[ItemID] | A list of document ids. Each id is a typed `ItemID`. + collection_name: str | The name of the collection. Default is None. + kwargs: Dict | Additional keyword arguments. + + Returns: + None + """ + collection = self.get_collection(collection_name) + collection.delete(ids=ids, collection_name=collection_name) + + def retrieve_docs( + self, + queries: List[str], + collection_name: str = None, + n_results: int = 10, + distance_threshold: float = -1, + ) -> QueryResults: + """ + Retrieve documents from the collection of the vector database based on the queries. + + Args: + queries: List[str] | A list of queries. Each query is a string. + collection_name: str | The name of the collection. Default is None. + n_results: int | The number of relevant documents to return. Default is 10. + distance_threshold: float | The threshold for the distance score, only distance smaller than it will be + returned. Don't filter with it if < 0. Default is -1. + kwargs: Dict | Additional keyword arguments. + + Returns: + QueryResults | The query results. Each query result is a list of list of tuples containing the document and + the distance. + """ + collection = self.get_collection(collection_name) + if isinstance(queries, str): + queries = [queries] + results = collection.query( + query_texts=queries, + n_results=n_results, + distance_threshold=distance_threshold, + ) + logger.debug(f"Retrieve Docs Results:\n{results}") + return results + + def get_docs_by_ids(self, ids: List[ItemID], collection_name: str = None, include=None, **kwargs) -> List[Document]: + """ + Retrieve documents from the collection of the vector database based on the ids. + + Args: + ids: List[ItemID] | A list of document ids. + collection_name: str | The name of the collection. Default is None. + include: List[str] | The fields to include. Default is None. 
+ If None, will include ["metadatas", "documents"], ids will always be included. + kwargs: dict | Additional keyword arguments. + + Returns: + List[Document] | The results. + """ + collection = self.get_collection(collection_name) + include = include if include else ["metadatas", "documents"] + results = collection.get(ids, include=include, **kwargs) + logger.debug(f"Retrieve Documents by ID Results:\n{results}") + return results diff --git a/autogen/agentchat/contrib/vectordb/utils.py b/autogen/agentchat/contrib/vectordb/utils.py new file mode 100644 index 00000000000..3dcf79f1f55 --- /dev/null +++ b/autogen/agentchat/contrib/vectordb/utils.py @@ -0,0 +1,115 @@ +import logging +from typing import Any, Dict, List + +from termcolor import colored + +from .base import QueryResults + + +class ColoredLogger(logging.Logger): + def __init__(self, name, level=logging.NOTSET): + super().__init__(name, level) + + def debug(self, msg, *args, color=None, **kwargs): + super().debug(colored(msg, color), *args, **kwargs) + + def info(self, msg, *args, color=None, **kwargs): + super().info(colored(msg, color), *args, **kwargs) + + def warning(self, msg, *args, color="yellow", **kwargs): + super().warning(colored(msg, color), *args, **kwargs) + + def error(self, msg, *args, color="light_red", **kwargs): + super().error(colored(msg, color), *args, **kwargs) + + def critical(self, msg, *args, color="red", **kwargs): + super().critical(colored(msg, color), *args, **kwargs) + + def fatal(self, msg, *args, color="red", **kwargs): + super().fatal(colored(msg, color), *args, **kwargs) + + +def get_logger(name: str, level: int = logging.INFO) -> ColoredLogger: + logger = ColoredLogger(name, level) + console_handler = logging.StreamHandler() + logger.addHandler(console_handler) + formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s") + logger.handlers[0].setFormatter(formatter) + return logger + + +logger = get_logger(__name__) + + +def filter_results_by_distance(results: QueryResults, distance_threshold: float = -1) -> QueryResults: + """Filters results based on a distance threshold. + + Args: + results: QueryResults | The query results. List[List[Tuple[Document, float]]] + distance_threshold: The maximum distance allowed for results. + + Returns: + QueryResults | A filtered results containing only distances smaller than the threshold. + """ + + if distance_threshold > 0: + results = [[(key, value) for key, value in data if value < distance_threshold] for data in results] + + return results + + +def chroma_results_to_query_results(data_dict: Dict[str, List[List[Any]]], special_key="distances") -> QueryResults: + """Converts a dictionary with list-of-list values to a list of tuples. + + Args: + data_dict: A dictionary where keys map to lists of lists or None. + special_key: The key in the dictionary containing the special values + for each tuple. + + Returns: + A list of tuples, where each tuple contains a sub-dictionary with + some keys from the original dictionary and the value from the + special_key. 
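The distance filter is easiest to see on a small literal QueryResults value; this traces filter_results_by_distance as defined above, with made-up documents and distances:

```python
from autogen.agentchat.contrib.vectordb.utils import filter_results_by_distance

results = [[({"id": "a", "content": "kept"}, 0.2), ({"id": "b", "content": "dropped"}, 0.9)]]
print(filter_results_by_distance(results, distance_threshold=0.5))
# [[({'id': 'a', 'content': 'kept'}, 0.2)]]; a negative threshold disables filtering
```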
+ + Example: + data_dict = { + "key1s": [[1, 2, 3], [4, 5, 6], [7, 8, 9]], + "key2s": [["a", "b", "c"], ["c", "d", "e"], ["e", "f", "g"]], + "key3s": None, + "key4s": [["x", "y", "z"], ["1", "2", "3"], ["4", "5", "6"]], + "distances": [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6], [0.7, 0.8, 0.9]], + } + + results = [ + [ + ({"key1": 1, "key2": "a", "key4": "x"}, 0.1), + ({"key1": 2, "key2": "b", "key4": "y"}, 0.2), + ({"key1": 3, "key2": "c", "key4": "z"}, 0.3), + ], + [ + ({"key1": 4, "key2": "c", "key4": "1"}, 0.4), + ({"key1": 5, "key2": "d", "key4": "2"}, 0.5), + ({"key1": 6, "key2": "e", "key4": "3"}, 0.6), + ], + [ + ({"key1": 7, "key2": "e", "key4": "4"}, 0.7), + ({"key1": 8, "key2": "f", "key4": "5"}, 0.8), + ({"key1": 9, "key2": "g", "key4": "6"}, 0.9), + ], + ] + """ + + keys = [key for key in data_dict if key != special_key] + result = [] + + for i in range(len(data_dict[special_key])): + sub_result = [] + for j, distance in enumerate(data_dict[special_key][i]): + sub_dict = {} + for key in keys: + if data_dict[key] is not None and len(data_dict[key]) > i: + sub_dict[key[:-1]] = data_dict[key][i][j] # remove 's' in the end from key + sub_result.append((sub_dict, distance)) + result.append(sub_result) + + return result diff --git a/autogen/agentchat/contrib/web_surfer.py b/autogen/agentchat/contrib/web_surfer.py index 9b7320f092f..1a54aeebe15 100644 --- a/autogen/agentchat/contrib/web_surfer.py +++ b/autogen/agentchat/contrib/web_surfer.py @@ -1,16 +1,18 @@ -import json import copy +import json import logging import re from dataclasses import dataclass -from typing import Any, Dict, List, Optional, Union, Callable, Literal, Tuple +from datetime import datetime +from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union + from typing_extensions import Annotated -from ... import Agent, ConversableAgent, AssistantAgent, UserProxyAgent, GroupChatManager, GroupChat, OpenAIWrapper + +from ... 
import Agent, AssistantAgent, ConversableAgent, GroupChat, GroupChatManager, OpenAIWrapper, UserProxyAgent from ...browser_utils import SimpleTextBrowser from ...code_utils import content_str -from datetime import datetime -from ...token_count_utils import count_token, get_max_token_limit from ...oai.openai_utils import filter_config +from ...token_count_utils import count_token, get_max_token_limit logger = logging.getLogger(__name__) @@ -79,8 +81,7 @@ def __init__( if inner_llm_config not in [None, False]: self._register_functions() - self._reply_func_list = [] - self.register_reply([Agent, None], WebSurferAgent.generate_surfer_reply) + self.register_reply([Agent, None], WebSurferAgent.generate_surfer_reply, remove_other_reply_funcs=True) self.register_reply([Agent, None], ConversableAgent.generate_code_execution_reply) self.register_reply([Agent, None], ConversableAgent.generate_function_call_reply) self.register_reply([Agent, None], ConversableAgent.check_termination_and_human_reply) diff --git a/autogen/agentchat/conversable_agent.py b/autogen/agentchat/conversable_agent.py index 0d968f2ac6b..f457667cf8b 100644 --- a/autogen/agentchat/conversable_agent.py +++ b/autogen/agentchat/conversable_agent.py @@ -5,42 +5,37 @@ import json import logging import re +import warnings from collections import defaultdict from functools import partial from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Type, TypeVar, Union -import warnings + from openai import BadRequestError -from ..coding.base import CodeExecutor -from ..coding.factory import CodeExecutorFactory +from autogen.exception_utils import InvalidCarryOverType, SenderRequired -from ..oai.client import OpenAIWrapper, ModelClient -from ..runtime_logging import logging_enabled, log_new_agent -from ..cache.cache import Cache +from .._pydantic import model_dump +from ..cache.cache import AbstractCache from ..code_utils import ( + PYTHON_VARIANTS, UNKNOWN, - content_str, check_can_use_docker_or_throw, + content_str, decide_use_docker, execute_code, extract_code, infer_lang, ) -from .utils import gather_usage_summary, consolidate_chat_info -from .chat import ChatResult, initiate_chats, a_initiate_chats - - +from ..coding.base import CodeExecutor +from ..coding.factory import CodeExecutorFactory +from ..formatting_utils import colored from ..function_utils import get_function_schema, load_basemodels_if_needed, serialize_to_str +from ..io.base import IOStream +from ..oai.client import ModelClient, OpenAIWrapper +from ..runtime_logging import log_new_agent, logging_enabled from .agent import Agent, LLMAgent -from .._pydantic import model_dump - -try: - from termcolor import colored -except ImportError: - - def colored(x, *args, **kwargs): - return x - +from .chat import ChatResult, a_initiate_chats, initiate_chats +from .utils import consolidate_chat_info, gather_usage_summary __all__ = ("ConversableAgent",) @@ -61,14 +56,13 @@ class ConversableAgent(LLMAgent): To modify the way to get human input, override `get_human_input` method. To modify the way to execute code blocks, single code block, or function call, override `execute_code_blocks`, `run_code`, and `execute_function` methods respectively. - To customize the initial message when a conversation starts, override `generate_init_message` method. 
""" - DEFAULT_CONFIG = {} # An empty configuration + DEFAULT_CONFIG = False # False or dict, the default config for llm inference MAX_CONSECUTIVE_AUTO_REPLY = 100 # maximum number of consecutive auto replies (subject to future change) - DEFAULT_summary_prompt = "Summarize the takeaway from the conversation. Do not add any introductory phrases." - DEFAULT_summary_method = "last_msg" + DEFAULT_SUMMARY_PROMPT = "Summarize the takeaway from the conversation. Do not add any introductory phrases." + DEFAULT_SUMMARY_METHOD = "last_msg" llm_config: Union[Dict, Literal[False]] def __init__( @@ -77,7 +71,7 @@ def __init__( system_message: Optional[Union[str, List]] = "You are a helpful AI Assistant.", is_termination_msg: Optional[Callable[[Dict], bool]] = None, max_consecutive_auto_reply: Optional[int] = None, - human_input_mode: Optional[str] = "TERMINATE", + human_input_mode: Literal["ALWAYS", "NEVER", "TERMINATE"] = "TERMINATE", function_map: Optional[Dict[str, Callable]] = None, code_execution_config: Union[Dict, Literal[False]] = False, llm_config: Optional[Union[Dict, Literal[False]]] = None, @@ -122,11 +116,19 @@ def __init__( llm_config (dict or False or None): llm inference configuration. Please refer to [OpenAIWrapper.create](/docs/reference/oai/client#create) for available options. + When using OpenAI or Azure OpenAI endpoints, please specify a non-empty 'model' either in `llm_config` or in each config of 'config_list' in `llm_config`. To disable llm-based auto reply, set to False. + When set to None, will use self.DEFAULT_CONFIG, which defaults to False. default_auto_reply (str or dict): default auto reply when no code execution or llm-based reply is generated. description (str): a short description of the agent. This description is used by other agents (e.g. the GroupChatManager) to decide when to call upon this agent. (Default: system_message) """ + # we change code_execution_config below and we have to make sure we don't change the input + # in case of UserProxyAgent, without this we could even change the default value {} + code_execution_config = ( + code_execution_config.copy() if hasattr(code_execution_config, "copy") else code_execution_config + ) + self._name = name # a dictionary of conversations, default value is list self._oai_messages = defaultdict(list) @@ -137,22 +139,11 @@ def __init__( if is_termination_msg is not None else (lambda x: content_str(x.get("content")) == "TERMINATE") ) + # Take a copy to avoid modifying the given dict + if isinstance(llm_config, dict): + llm_config = copy.deepcopy(llm_config) - if llm_config is False: - self.llm_config = False - self.client = None - else: - self.llm_config = self.DEFAULT_CONFIG.copy() - if isinstance(llm_config, dict): - self.llm_config.update(llm_config) - if "model" not in self.llm_config and ( - not self.llm_config.get("config_list") - or any(not config.get("model") for config in self.llm_config["config_list"]) - ): - raise ValueError( - "Please either set llm_config to False, or specify a non-empty 'model' either in 'llm_config' or in each config of 'config_list'." 
- ) - self.client = OpenAIWrapper(**self.llm_config) + self._validate_llm_config(llm_config) if logging_enabled(): log_new_agent(self, locals()) @@ -173,7 +164,6 @@ def __init__( ) self._default_auto_reply = default_auto_reply self._reply_func_list = [] - self._ignore_async_func_in_sync_chat_list = [] self._human_input = [] self.reply_at_receive = defaultdict(bool) self.register_reply([Agent, None], ConversableAgent.generate_oai_reply) @@ -197,6 +187,21 @@ def __init__( self._code_execution_config = code_execution_config if self._code_execution_config.get("executor") is not None: + if "use_docker" in self._code_execution_config: + raise ValueError( + "'use_docker' in code_execution_config is not valid when 'executor' is set. Use the appropriate arg in the chosen executor instead." + ) + + if "work_dir" in self._code_execution_config: + raise ValueError( + "'work_dir' in code_execution_config is not valid when 'executor' is set. Use the appropriate arg in the chosen executor instead." + ) + + if "timeout" in self._code_execution_config: + raise ValueError( + "'timeout' in code_execution_config is not valid when 'executor' is set. Use the appropriate arg in the chosen executor instead." + ) + # Use the new code executor. self._code_executor = CodeExecutorFactory.create(self._code_execution_config) self.register_reply([Agent, None], ConversableAgent._generate_code_execution_reply_using_executor) @@ -230,6 +235,20 @@ def __init__( "process_message_before_send": [], } + def _validate_llm_config(self, llm_config): + assert llm_config in (None, False) or isinstance( + llm_config, dict + ), "llm_config must be a dict or False or None." + if llm_config is None: + llm_config = self.DEFAULT_CONFIG + self.llm_config = self.DEFAULT_CONFIG if llm_config is None else llm_config + # TODO: more complete validity check + if self.llm_config in [{}, {"config_list": []}, {"config_list": [{"model": ""}]}]: + raise ValueError( + "When using OpenAI or Azure OpenAI endpoints, specify a non-empty 'model' either in 'llm_config' or in each config of 'config_list'." + ) + self.client = None if self.llm_config is False else OpenAIWrapper(**self.llm_config) + @property def name(self) -> str: """Get the name of the agent.""" @@ -246,13 +265,10 @@ def description(self, description: str): self._description = description @property - def code_executor(self) -> CodeExecutor: - """The code executor used by this agent. Raise if code execution is disabled.""" + def code_executor(self) -> Optional[CodeExecutor]: + """The code executor used by this agent. Returns None if code execution is disabled.""" if not hasattr(self, "_code_executor"): - raise ValueError( - "No code executor as code execution is disabled. " - "To enable code execution, set code_execution_config." - ) + return None return self._code_executor def register_reply( @@ -264,6 +280,7 @@ def register_reply( reset_config: Optional[Callable] = None, *, ignore_async_in_sync_chat: bool = False, + remove_other_reply_funcs: bool = False, ): """Register a reply function. @@ -275,34 +292,29 @@ def register_reply( from both sync and async chats. However, an async reply function will only be triggered from async chats (initiated with `ConversableAgent.a_initiate_chat`). 
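A hedged sketch of what _validate_llm_config accepts and rejects; the model name and api_key are placeholders:

```python
from autogen import ConversableAgent

no_llm = ConversableAgent("echo", llm_config=False)  # disables llm-based auto reply
with_llm = ConversableAgent(
    "assistant",
    llm_config={"config_list": [{"model": "gpt-4", "api_key": "sk-placeholder"}]},
)

try:
    ConversableAgent("broken", llm_config={"config_list": []})  # empty config_list
except ValueError as err:
    print(err)  # asks for a non-empty 'model' in llm_config or config_list
```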
If an `async` reply function is registered and a chat is initialized with a sync function, `ignore_async_in_sync_chat` determines the behaviour as follows: - - if `ignore_async_in_sync_chat` is set to `False` (default value), an exception will be raised, and - - if `ignore_async_in_sync_chat` is set to `True`, the reply function will be ignored. + if `ignore_async_in_sync_chat` is set to `False` (default value), an exception will be raised, and + if `ignore_async_in_sync_chat` is set to `True`, the reply function will be ignored. Args: trigger (Agent class, str, Agent instance, callable, or list): the trigger. - - If a class is provided, the reply function will be called when the sender is an instance of the class. - - If a string is provided, the reply function will be called when the sender's name matches the string. - - If an agent instance is provided, the reply function will be called when the sender is the agent instance. - - If a callable is provided, the reply function will be called when the callable returns True. - - If a list is provided, the reply function will be called when any of the triggers in the list is activated. - - If None is provided, the reply function will be called only when the sender is None. - Note: Be sure to register `None` as a trigger if you would like to trigger an auto-reply function with non-empty messages and `sender=None`. + If a class is provided, the reply function will be called when the sender is an instance of the class. + If a string is provided, the reply function will be called when the sender's name matches the string. + If an agent instance is provided, the reply function will be called when the sender is the agent instance. + If a callable is provided, the reply function will be called when the callable returns True. + If a list is provided, the reply function will be called when any of the triggers in the list is activated. + If None is provided, the reply function will be called only when the sender is None. + Note: Be sure to register `None` as a trigger if you would like to trigger an auto-reply function with non-empty messages and `sender=None`. reply_func (Callable): the reply function. The function takes a recipient agent, a list of messages, a sender agent and a config as input and returns a reply message. - position: the position of the reply function in the reply function list. - config: the config to be passed to the reply function, see below. - reset_config: the function to reset the config, see below. - ignore_async_in_sync_chat: whether to ignore the async reply function in sync chats. If `False`, an exception - will be raised if an async reply function is registered and a chat is initialized with a sync - function. - ```python - def reply_func( - recipient: ConversableAgent, - messages: Optional[List[Dict]] = None, - sender: Optional[Agent] = None, - config: Optional[Any] = None, - ) -> Tuple[bool, Union[str, Dict, None]]: - ``` + + ```python + def reply_func( + recipient: ConversableAgent, + messages: Optional[List[Dict]] = None, + sender: Optional[Agent] = None, + config: Optional[Any] = None, + ) -> Tuple[bool, Union[str, Dict, None]]: + ``` position (int): the position of the reply function in the reply function list. The function registered later will be checked earlier by default. To change the order, set the position to a positive integer. @@ -310,9 +322,15 @@ def reply_func( When an agent is reset, the config will be reset to the original value. reset_config (Callable): the function to reset the config. 
The function returns None. Signature: ```def reset_config(config: Any)``` + ignore_async_in_sync_chat (bool): whether to ignore the async reply function in sync chats. If `False`, an exception + will be raised if an async reply function is registered and a chat is initialized with a sync + function. + remove_other_reply_funcs (bool): whether to remove other reply functions when registering this reply function. """ if not isinstance(trigger, (type, str, Agent, Callable, list)): raise ValueError("trigger must be a class, a string, an agent, a callable or a list.") + if remove_other_reply_funcs: + self._reply_func_list.clear() self._reply_func_list.insert( position, { @@ -321,10 +339,20 @@ def reply_func( "config": copy.copy(config), "init_config": config, "reset_config": reset_config, + "ignore_async_in_sync_chat": ignore_async_in_sync_chat and inspect.iscoroutinefunction(reply_func), }, ) - if ignore_async_in_sync_chat and inspect.iscoroutinefunction(reply_func): - self._ignore_async_func_in_sync_chat_list.append(reply_func) + + def replace_reply_func(self, old_reply_func: Callable, new_reply_func: Callable): + """Replace a registered reply function with a new one. + + Args: + old_reply_func (Callable): the old reply function to be replaced. + new_reply_func (Callable): the new reply function to replace the old one. + """ + for f in self._reply_func_list: + if f["reply_func"] == old_reply_func: + f["reply_func"] = new_reply_func @staticmethod def _summary_from_nested_chats( @@ -343,6 +371,8 @@ def _summary_from_nested_chats( chat_to_run = [] for i, c in enumerate(chat_queue): current_c = c.copy() + if current_c.get("sender") is None: + current_c["sender"] = recipient message = current_c.get("message") # If message is not provided in chat_queue, we by default use the last message from the original chat history as the first message in this nested chat (for the first chat in the chat queue). # NOTE: This setting is prone to change. @@ -356,7 +386,7 @@ def _summary_from_nested_chats( chat_to_run.append(current_c) if not chat_to_run: return True, None - res = recipient.initiate_chats(chat_to_run) + res = initiate_chats(chat_to_run) return True, res[-1].summary def register_nested_chats( @@ -561,7 +591,7 @@ def send( recipient: Agent, request_reply: Optional[bool] = None, silent: Optional[bool] = False, - ) -> ChatResult: + ): """Send a message to another agent. Args: @@ -593,9 +623,6 @@ def send( Raises: ValueError: if the message can't be converted into a valid ChatCompletion message. - - Returns: - ChatResult: a ChatResult object. """ message = self._process_message_before_send(message, recipient, silent) # When the agent composes and sends the message, the role of the message is "assistant" @@ -614,7 +641,7 @@ async def a_send( recipient: Agent, request_reply: Optional[bool] = None, silent: Optional[bool] = False, - ) -> ChatResult: + ): """(async) Send a message to another agent. Args: @@ -646,9 +673,6 @@ async def a_send( Raises: ValueError: if the message can't be converted into a valid ChatCompletion message. - - Returns: - ChatResult: an ChatResult object. 
""" message = self._process_message_before_send(message, recipient, silent) # When the agent composes and sends the message, the role of the message is "assistant" @@ -662,8 +686,9 @@ async def a_send( ) def _print_received_message(self, message: Union[Dict, str], sender: Agent): + iostream = IOStream.get_default() # print the message received - print(colored(sender.name, "yellow"), "(to", f"{self.name}):\n", flush=True) + iostream.print(colored(sender.name, "yellow"), "(to", f"{self.name}):\n", flush=True) message = self._message_to_dict(message) if message.get("tool_responses"): # Handle tool multi-call responses @@ -677,11 +702,11 @@ def _print_received_message(self, message: Union[Dict, str], sender: Agent): id_key = "name" else: id_key = "tool_call_id" - - func_print = f"***** Response from calling {message['role']} \"{message[id_key]}\" *****" - print(colored(func_print, "green"), flush=True) - print(message["content"], flush=True) - print(colored("*" * len(func_print), "green"), flush=True) + id = message.get(id_key, "No id found") + func_print = f"***** Response from calling {message['role']} ({id}) *****" + iostream.print(colored(func_print, "green"), flush=True) + iostream.print(message["content"], flush=True) + iostream.print(colored("*" * len(func_print), "green"), flush=True) else: content = message.get("content") if content is not None: @@ -691,35 +716,35 @@ def _print_received_message(self, message: Union[Dict, str], sender: Agent): message["context"], self.llm_config and self.llm_config.get("allow_format_str_template", False), ) - print(content_str(content), flush=True) + iostream.print(content_str(content), flush=True) if "function_call" in message and message["function_call"]: function_call = dict(message["function_call"]) func_print = ( - f"***** Suggested function Call: {function_call.get('name', '(No function name found)')} *****" + f"***** Suggested function call: {function_call.get('name', '(No function name found)')} *****" ) - print(colored(func_print, "green"), flush=True) - print( + iostream.print(colored(func_print, "green"), flush=True) + iostream.print( "Arguments: \n", function_call.get("arguments", "(No arguments found)"), flush=True, sep="", ) - print(colored("*" * len(func_print), "green"), flush=True) + iostream.print(colored("*" * len(func_print), "green"), flush=True) if "tool_calls" in message and message["tool_calls"]: for tool_call in message["tool_calls"]: - id = tool_call.get("id", "(No id found)") + id = tool_call.get("id", "No tool call id found") function_call = dict(tool_call.get("function", {})) - func_print = f"***** Suggested tool Call ({id}): {function_call.get('name', '(No function name found)')} *****" - print(colored(func_print, "green"), flush=True) - print( + func_print = f"***** Suggested tool call ({id}): {function_call.get('name', '(No function name found)')} *****" + iostream.print(colored(func_print, "green"), flush=True) + iostream.print( "Arguments: \n", function_call.get("arguments", "(No arguments found)"), flush=True, sep="", ) - print(colored("*" * len(func_print), "green"), flush=True) + iostream.print(colored("*" * len(func_print), "green"), flush=True) - print("\n", "-" * 80, flush=True, sep="") + iostream.print("\n", "-" * 80, flush=True, sep="") def _process_received_message(self, message: Union[Dict, str], sender: Agent, silent: bool): # When the agent receives a message, the role of the message is "user". (If 'role' exists and is 'function', it will remain unchanged.) 
@@ -826,12 +851,12 @@ def _raise_exception_on_async_reply_functions(self) -> None: Raises: RuntimeError: if any async reply functions are registered. """ - reply_functions = {f["reply_func"] for f in self._reply_func_list}.difference( - self._ignore_async_func_in_sync_chat_list - ) + reply_functions = { + f["reply_func"] for f in self._reply_func_list if not f.get("ignore_async_in_sync_chat", False) + } async_reply_functions = [f for f in reply_functions if inspect.iscoroutinefunction(f)] - if async_reply_functions != []: + if async_reply_functions: msg = ( "Async reply functions can only be used with ConversableAgent.a_initiate_chat(). The following async reply functions are found: " + ", ".join([f.__name__ for f in async_reply_functions]) @@ -842,48 +867,94 @@ def _raise_exception_on_async_reply_functions(self) -> None: def initiate_chat( self, recipient: "ConversableAgent", - clear_history: Optional[bool] = True, + clear_history: bool = True, silent: Optional[bool] = False, - cache: Optional[Cache] = None, + cache: Optional[AbstractCache] = None, max_turns: Optional[int] = None, - **context, + summary_method: Optional[Union[str, Callable]] = DEFAULT_SUMMARY_METHOD, + summary_args: Optional[dict] = {}, + message: Optional[Union[Dict, str, Callable]] = None, + **kwargs, ) -> ChatResult: """Initiate a chat with the recipient agent. Reset the consecutive auto reply counter. If `clear_history` is True, the chat history with the recipient agent will be cleared. - `generate_init_message` is called to generate the initial message for the agent. + Args: recipient: the recipient agent. clear_history (bool): whether to clear the chat history with the agent. Default is True. silent (bool or None): (Experimental) whether to print the messages for this conversation. Default is False. - cache (Cache or None): the cache client to be used for this conversation. Default is None. + cache (AbstractCache or None): the cache client to be used for this conversation. Default is None. max_turns (int or None): the maximum number of turns for the chat between the two agents. One turn means one conversation round trip. Note that this is different from - [max_consecutive_auto_reply](#max_consecutive_auto_reply) which is the maximum number of consecutive auto replies; and it is also different from [max_rounds in GroupChat](./groupchat#groupchat-objects) which is the maximum number of rounds in a group chat session. - If max_turns is set to None, the chat will continue until a termination condition is met. Default is None. - **context: any context information. It has the following reserved fields: - "message": a str of message. Needs to be provided. Otherwise, input() will be called to get the initial message. - "summary_method": a string or callable specifying the method to get a summary from the chat. Default is DEFAULT_summary_method, i.e., "last_msg". - - Supported string are "last_msg" and "reflection_with_llm": - when set "last_msg", it returns the last message of the dialog as the summary. - when set "reflection_with_llm", it returns a summary extracted using an llm client. - `llm_config` must be set in either the recipient or sender. - "reflection_with_llm" requires the llm_config to be set in either the sender or the recipient. - - A callable summary_method should take the recipient and sender agent in a chat as input and return a string of summary. 
E.g, - ```python - def my_summary_method( - sender: ConversableAgent, - recipient: ConversableAgent, - ): - return recipient.last_message(sender)["content"] - ``` - "summary_prompt": a string of text used to prompt a LLM-based agent (the sender or receiver agent) to reflext - on the conversation and extract a summary when summary_method is "reflection_with_llm". - Default is DEFAULT_summary_prompt, i.e., "Summarize takeaway from the conversation. Do not add any introductory phrases. If the intended request is NOT properly addressed, please point it out." - "carryover": a string or a list of string to specify the carryover information to be passed to this chat. It can be a string or a list of string. - If provided, we will combine this carryover with the "message" content when generating the initial chat + [max_consecutive_auto_reply](#max_consecutive_auto_reply) which is the maximum number of consecutive auto replies; and it is also different from [max_rounds in GroupChat](./groupchat#groupchat-objects) which is the maximum number of rounds in a group chat session. + If max_turns is set to None, the chat will continue until a termination condition is met. Default is None. + summary_method (str or callable): a method to get a summary from the chat. Default is DEFAULT_SUMMARY_METHOD, i.e., "last_msg". + + Supported strings are "last_msg" and "reflection_with_llm": + - when set to "last_msg", it returns the last message of the dialog as the summary. + - when set to "reflection_with_llm", it returns a summary extracted using an llm client. + `llm_config` must be set in either the recipient or sender. + + A callable summary_method should take the recipient and sender agent in a chat as input and return a string of summary. E.g., + + ```python + def my_summary_method( + sender: ConversableAgent, + recipient: ConversableAgent, + summary_args: dict, + ): + return recipient.last_message(sender)["content"] + ``` + summary_args (dict): a dictionary of arguments to be passed to the summary_method. + One example key is "summary_prompt", and value is a string of text used to prompt a LLM-based agent (the sender or receiver agent) to reflect + on the conversation and extract a summary when summary_method is "reflection_with_llm". + The default summary_prompt is DEFAULT_SUMMARY_PROMPT, i.e., "Summarize takeaway from the conversation. Do not add any introductory phrases. If the intended request is NOT properly addressed, please point it out." + message (str, dict or Callable): the initial message to be sent to the recipient. Needs to be provided. Otherwise, input() will be called to get the initial message. + - If a string or a dict is provided, it will be used as the initial message. `generate_init_message` is called to generate the initial message for the agent based on this string and the context. + If dict, it may contain the following reserved fields (either content or tool_calls need to be provided). + + 1. "content": content of the message, can be None. + 2. "function_call": a dictionary containing the function name and arguments. (deprecated in favor of "tool_calls") + 3. "tool_calls": a list of dictionaries containing the function name and arguments. + 4. "role": role of the message, can be "assistant", "user", "function". + This field is only needed to distinguish between "function" or "assistant"/"user". + 5. "name": In most cases, this field is not needed. When the role is "function", this field is needed to indicate the function name. + 6. 
"context" (dict): the context of the message, which will be passed to + [OpenAIWrapper.create](../oai/client#create). + + - If a callable is provided, it will be called to get the initial message in the form of a string or a dict. + If the returned type is dict, it may contain the reserved fields mentioned above. + + Example of a callable message (returning a string): + + ```python + def my_message(sender: ConversableAgent, recipient: ConversableAgent, context: dict) -> Union[str, Dict]: + carryover = context.get("carryover", "") + if isinstance(message, list): + carryover = carryover[-1] + final_msg = "Write a blogpost." + "\\nContext: \\n" + carryover + return final_msg + ``` + + Example of a callable message (returning a dict): + + ```python + def my_message(sender: ConversableAgent, recipient: ConversableAgent, context: dict) -> Union[str, Dict]: + final_msg = {} + carryover = context.get("carryover", "") + if isinstance(message, list): + carryover = carryover[-1] + final_msg["content"] = "Write a blogpost." + "\\nContext: \\n" + carryover + final_msg["context"] = {"prefix": "Today I feel"} + return final_msg + ``` + **kwargs: any additional information. It has the following reserved fields: + - "carryover": a string or a list of string to specify the carryover information to be passed to this chat. + If provided, we will combine this carryover (by attaching a "context: " string and the carryover content after the message content) with the "message" content when generating the initial chat message in `generate_init_message`. + - "verbose": a boolean to specify whether to print the message and carryover in a chat. Default is False. Raises: RuntimeError: if any async reply functions are registered and not ignored in sync chat. @@ -891,8 +962,8 @@ def my_summary_method( Returns: ChatResult: an ChatResult object. 
""" - _chat_info = context.copy() - _chat_info["recipient"] = recipient + _chat_info = locals().copy() + _chat_info["sender"] = self consolidate_chat_info(_chat_info, uniform_sender=self) for agent in [self, recipient]: agent._raise_exception_on_async_reply_functions() @@ -902,7 +973,10 @@ def my_summary_method( self._prepare_chat(recipient, clear_history, reply_at_receive=False) for _ in range(max_turns): if _ == 0: - msg2send = self.generate_init_message(**context) + if isinstance(message, Callable): + msg2send = message(_chat_info["sender"], _chat_info["recipient"], kwargs) + else: + msg2send = self.generate_init_message(message, **kwargs) else: msg2send = self.generate_reply(messages=self.chat_messages[recipient], sender=recipient) if msg2send is None: @@ -910,11 +984,15 @@ def my_summary_method( self.send(msg2send, recipient, request_reply=True, silent=silent) else: self._prepare_chat(recipient, clear_history) - self.send(self.generate_init_message(**context), recipient, silent=silent) + if isinstance(message, Callable): + msg2send = message(_chat_info["sender"], _chat_info["recipient"], kwargs) + else: + msg2send = self.generate_init_message(message, **kwargs) + self.send(msg2send, recipient, silent=silent) summary = self._summarize_chat( - context.get("summary_method", ConversableAgent.DEFAULT_summary_method), + summary_method, + summary_args, recipient, - prompt=context.get("summary_prompt"), cache=cache, ) for agent in [self, recipient]: @@ -931,11 +1009,14 @@ def my_summary_method( async def a_initiate_chat( self, recipient: "ConversableAgent", - clear_history: Optional[bool] = True, + clear_history: bool = True, silent: Optional[bool] = False, - cache: Optional[Cache] = None, + cache: Optional[AbstractCache] = None, max_turns: Optional[int] = None, - **context, + summary_method: Optional[Union[str, Callable]] = DEFAULT_SUMMARY_METHOD, + summary_args: Optional[dict] = {}, + message: Optional[Union[str, Callable]] = None, + **kwargs, ) -> ChatResult: """(async) Initiate a chat with the recipient agent. @@ -948,8 +1029,8 @@ async def a_initiate_chat( Returns: ChatResult: an ChatResult object. 
""" - _chat_info = context.copy() - _chat_info["recipient"] = recipient + _chat_info = locals().copy() + _chat_info["sender"] = self consolidate_chat_info(_chat_info, uniform_sender=self) for agent in [self, recipient]: agent.previous_cache = agent.client_cache @@ -958,7 +1039,10 @@ async def a_initiate_chat( self._prepare_chat(recipient, clear_history, reply_at_receive=False) for _ in range(max_turns): if _ == 0: - msg2send = await self.a_generate_init_message(**context) + if isinstance(message, Callable): + msg2send = message(_chat_info["sender"], _chat_info["recipient"], kwargs) + else: + msg2send = await self.a_generate_init_message(message, **kwargs) else: msg2send = await self.a_generate_reply(messages=self.chat_messages[recipient], sender=recipient) if msg2send is None: @@ -966,11 +1050,15 @@ async def a_initiate_chat( await self.a_send(msg2send, recipient, request_reply=True, silent=silent) else: self._prepare_chat(recipient, clear_history) - await self.a_send(await self.a_generate_init_message(**context), recipient, silent=silent) + if isinstance(message, Callable): + msg2send = message(_chat_info["sender"], _chat_info["recipient"], kwargs) + else: + msg2send = await self.a_generate_init_message(message, **kwargs) + await self.a_send(msg2send, recipient, silent=silent) summary = self._summarize_chat( - context.get("summary_method", ConversableAgent.DEFAULT_summary_method), + summary_method, + summary_args, recipient, - prompt=context.get("summary_prompt"), cache=cache, ) for agent in [self, recipient]: @@ -987,9 +1075,9 @@ async def a_initiate_chat( def _summarize_chat( self, summary_method, + summary_args, recipient: Optional[Agent] = None, - prompt: Optional[str] = None, - cache: Optional[Cache] = None, + cache: Optional[AbstractCache] = None, ) -> str: """Get a chat summary from an agent participating in a chat. @@ -1000,39 +1088,71 @@ def _summarize_chat( def my_summary_method( sender: ConversableAgent, recipient: ConversableAgent, + summary_args: dict, ): return recipient.last_message(sender)["content"] ``` + summary_args (dict): a dictionary of arguments to be passed to the summary_method. recipient: the recipient agent in a chat. prompt (str): the prompt used to get a summary when summary_method is "reflection_with_llm". Returns: str: a chat summary from the agent. 
""" - agent = self if recipient is None else recipient summary = "" if summary_method is None: return summary + if "cache" not in summary_args: + summary_args["cache"] = cache if summary_method == "reflection_with_llm": - prompt = ConversableAgent.DEFAULT_summary_prompt if prompt is None else prompt - if not isinstance(prompt, str): - raise ValueError("The summary_prompt must be a string.") - msg_list = agent.chat_messages_for_summary(self) - try: - summary = self._reflection_with_llm(prompt, msg_list, llm_agent=agent, cache=cache) - except BadRequestError as e: - warnings.warn(f"Cannot extract summary using reflection_with_llm: {e}", UserWarning) - elif summary_method == "last_msg" or summary_method is None: - try: - summary = agent.last_message(self)["content"].replace("TERMINATE", "") - except (IndexError, AttributeError) as e: - warnings.warn(f"Cannot extract summary using last_msg: {e}", UserWarning) - elif isinstance(summary_method, Callable): - summary = summary_method(recipient, self) + summary_method = self._reflection_with_llm_as_summary + elif summary_method == "last_msg": + summary_method = self._last_msg_as_summary + + if isinstance(summary_method, Callable): + summary = summary_method(self, recipient, summary_args) + else: + raise ValueError( + "If not None, the summary_method must be a string from [`reflection_with_llm`, `last_msg`] or a callable." + ) + return summary + + @staticmethod + def _last_msg_as_summary(sender, recipient, summary_args) -> str: + """Get a chat summary from the last message of the recipient.""" + summary = "" + try: + content = recipient.last_message(sender)["content"] + if isinstance(content, str): + summary = content.replace("TERMINATE", "") + elif isinstance(content, list): + # Remove the `TERMINATE` word in the content list. + summary = "\n".join( + x["text"].replace("TERMINATE", "") for x in content if isinstance(x, dict) and "text" in x + ) + except (IndexError, AttributeError) as e: + warnings.warn(f"Cannot extract summary using last_msg: {e}. Using an empty str as summary.", UserWarning) + return summary + + @staticmethod + def _reflection_with_llm_as_summary(sender, recipient, summary_args): + prompt = summary_args.get("summary_prompt") + prompt = ConversableAgent.DEFAULT_SUMMARY_PROMPT if prompt is None else prompt + if not isinstance(prompt, str): + raise ValueError("The summary_prompt must be a string.") + msg_list = recipient.chat_messages_for_summary(sender) + agent = sender if recipient is None else recipient + try: + summary = sender._reflection_with_llm(prompt, msg_list, llm_agent=agent, cache=summary_args.get("cache")) + except BadRequestError as e: + warnings.warn( + f"Cannot extract summary using reflection_with_llm: {e}. Using an empty str as summary.", UserWarning + ) + summary = "" return summary def _reflection_with_llm( - self, prompt, messages, llm_agent: Optional[Agent] = None, cache: Optional[Cache] = None + self, prompt, messages, llm_agent: Optional[Agent] = None, cache: Optional[AbstractCache] = None ) -> str: """Get a chat summary using reflection with an llm client based on the conversation history. @@ -1040,7 +1160,7 @@ def _reflection_with_llm( prompt (str): The prompt (in this method it is used as system prompt) used to get the summary. messages (list): The messages generated as part of a chat conversation. llm_agent: the agent with an llm client. - cache (Cache or None): the cache client to be used for this conversation. + cache (AbstractCache or None): the cache client to be used for this conversation. 
""" system_msg = [ { @@ -1059,9 +1179,25 @@ def _reflection_with_llm( response = self._generate_oai_reply_from_client(llm_client=llm_client, messages=messages, cache=cache) return response + def _check_chat_queue_for_sender(self, chat_queue: List[Dict[str, Any]]) -> List[Dict[str, Any]]: + """ + Check the chat queue and add the "sender" key if it's missing. + + Args: + chat_queue (List[Dict[str, Any]]): A list of dictionaries containing chat information. + + Returns: + List[Dict[str, Any]]: A new list of dictionaries with the "sender" key added if it was missing. + """ + chat_queue_with_sender = [] + for chat_info in chat_queue: + if chat_info.get("sender") is None: + chat_info["sender"] = self + chat_queue_with_sender.append(chat_info) + return chat_queue_with_sender + def initiate_chats(self, chat_queue: List[Dict[str, Any]]) -> List[ChatResult]: """(Experimental) Initiate chats with multiple agents. - TODO: add async version of this method. Args: chat_queue (List[Dict]): a list of dictionaries containing the information of the chats. @@ -1069,16 +1205,13 @@ def initiate_chats(self, chat_queue: List[Dict[str, Any]]) -> List[ChatResult]: Returns: a list of ChatResult objects corresponding to the finished chats in the chat_queue. """ - _chat_queue = chat_queue.copy() - for chat_info in _chat_queue: - chat_info["sender"] = self + _chat_queue = self._check_chat_queue_for_sender(chat_queue) self._finished_chats = initiate_chats(_chat_queue) return self._finished_chats async def a_initiate_chats(self, chat_queue: List[Dict[str, Any]]) -> Dict[int, ChatResult]: - _chat_queue = chat_queue.copy() - for chat_info in _chat_queue: - chat_info["sender"] = self + + _chat_queue = self._check_chat_queue_for_sender(chat_queue) self._finished_chats = await a_initiate_chats(_chat_queue) return self._finished_chats @@ -1123,6 +1256,7 @@ def clear_history(self, recipient: Optional[Agent] = None, nr_messages_to_preser recipient: the agent with whom the chat history to clear. If None, clear the chat history with all agents. nr_messages_to_preserve: the number of newest messages to preserve in the chat history. """ + iostream = IOStream.get_default() if recipient is None: if nr_messages_to_preserve: for key in self._oai_messages: @@ -1132,7 +1266,7 @@ def clear_history(self, recipient: Optional[Agent] = None, nr_messages_to_preser first_msg_to_save = self._oai_messages[key][-nr_messages_to_preserve_internal] if "tool_responses" in first_msg_to_save: nr_messages_to_preserve_internal += 1 - print( + iostream.print( f"Preserving one more message for {self.name} to not divide history between tool call and " f"tool response." 
) @@ -1143,7 +1277,7 @@ def clear_history(self, recipient: Optional[Agent] = None, nr_messages_to_preser else: self._oai_messages[recipient].clear() if nr_messages_to_preserve: - print( + iostream.print( colored( "WARNING: `nr_preserved_messages` is ignored when clearing chat history with a specific agent.", "yellow", @@ -1190,7 +1324,7 @@ def _generate_oai_reply_from_client(self, llm_client, messages, cache) -> Union[ extracted_response = llm_client.extract_text_or_completion_object(response)[0] if extracted_response is None: - warnings.warn("Extracted_response from {response} is None.", UserWarning) + warnings.warn(f"Extracted_response from {response} is None.", UserWarning) return None # ensure function and tool calls will be accepted when sent back to the LLM if not isinstance(extracted_response, str) and hasattr(extracted_response, "model_dump"): @@ -1202,6 +1336,12 @@ def _generate_oai_reply_from_client(self, llm_client, messages, cache) -> Union[ ) for tool_call in extracted_response.get("tool_calls") or []: tool_call["function"]["name"] = self._normalize_name(tool_call["function"]["name"]) + # Remove id and type if they are not present. + # This is to make the tool call object compatible with Mistral API. + if tool_call.get("id") is None: + tool_call.pop("id") + if tool_call.get("type") is None: + tool_call.pop("type") return extracted_response async def a_generate_oai_reply( @@ -1211,8 +1351,19 @@ async def a_generate_oai_reply( config: Optional[Any] = None, ) -> Tuple[bool, Union[str, Dict, None]]: """Generate a reply using autogen.oai asynchronously.""" + iostream = IOStream.get_default() + + def _generate_oai_reply( + self, iostream: IOStream, *args: Any, **kwargs: Any + ) -> Tuple[bool, Union[str, Dict, None]]: + with IOStream.set_default(iostream): + return self.generate_oai_reply(*args, **kwargs) + return await asyncio.get_event_loop().run_in_executor( - None, functools.partial(self.generate_oai_reply, messages=messages, sender=sender, config=config) + None, + functools.partial( + _generate_oai_reply, self=self, iostream=iostream, messages=messages, sender=sender, config=config + ), ) def _generate_code_execution_reply_using_executor( @@ -1222,6 +1373,8 @@ def _generate_code_execution_reply_using_executor( config: Optional[Union[Dict, Literal[False]]] = None, ): """Generate a reply using code executor.""" + iostream = IOStream.get_default() + if config is not None: raise ValueError("config is not supported for _generate_code_execution_reply_using_executor.") if self._code_execution_config is False: @@ -1256,6 +1409,25 @@ def _generate_code_execution_reply_using_executor( code_blocks = self._code_executor.code_extractor.extract_code_blocks(message["content"]) if len(code_blocks) == 0: continue + + num_code_blocks = len(code_blocks) + if num_code_blocks == 1: + iostream.print( + colored( + f"\n>>>>>>>> EXECUTING CODE BLOCK (inferred language is {code_blocks[0].language})...", + "red", + ), + flush=True, + ) + else: + iostream.print( + colored( + f"\n>>>>>>>> EXECUTING {num_code_blocks} CODE BLOCKS (inferred languages are [{', '.join([x.language for x in code_blocks])}])...", + "red", + ), + flush=True, + ) + # found code blocks, execute code. 
code_result = self._code_executor.execute_code_blocks(code_blocks) exitcode2str = "execution succeeded" if code_result.exit_code == 0 else "execution failed" @@ -1399,7 +1571,6 @@ def generate_tool_calls_reply( message = messages[-1] tool_returns = [] for tool_call in message.get("tool_calls", []): - id = tool_call["id"] function_call = tool_call.get("function", {}) func = self._function_map.get(function_call.get("name", None), None) if inspect.iscoroutinefunction(func): @@ -1417,13 +1588,24 @@ def generate_tool_calls_reply( loop.close() else: _, func_return = self.execute_function(function_call) - tool_returns.append( - { - "tool_call_id": id, + content = func_return.get("content", "") + if content is None: + content = "" + tool_call_id = tool_call.get("id", None) + if tool_call_id is not None: + tool_call_response = { + "tool_call_id": tool_call_id, "role": "tool", - "content": func_return.get("content", ""), + "content": content, } - ) + else: + # Do not include tool_call_id if it is not present. + # This is to make the tool call object compatible with Mistral API. + tool_call_response = { + "role": "tool", + "content": content, + } + tool_returns.append(tool_call_response) if tool_returns: return True, { "role": "tool", @@ -1490,7 +1672,7 @@ def check_termination_and_human_reply( - Tuple[bool, Union[str, Dict, None]]: A tuple containing a boolean indicating if the conversation should be terminated, and a human reply which can be a string, a dictionary, or None. """ - # Function implementation... + iostream = IOStream.get_default() if config is None: config = self @@ -1536,7 +1718,7 @@ def check_termination_and_human_reply( # print the no_human_input_msg if no_human_input_msg: - print(colored(f"\n>>>>>>>> {no_human_input_msg}", "red"), flush=True) + iostream.print(colored(f"\n>>>>>>>> {no_human_input_msg}", "red"), flush=True) # stop the conversation if reply == "exit": @@ -1576,7 +1758,7 @@ def check_termination_and_human_reply( # increment the consecutive_auto_reply_counter self._consecutive_auto_reply_counter[sender] += 1 if self.human_input_mode != "NEVER": - print(colored("\n>>>>>>>> USING AUTO REPLY...", "red"), flush=True) + iostream.print(colored("\n>>>>>>>> USING AUTO REPLY...", "red"), flush=True) return False, None @@ -1603,6 +1785,8 @@ async def a_check_termination_and_human_reply( - Tuple[bool, Union[str, Dict, None]]: A tuple containing a boolean indicating if the conversation should be terminated, and a human reply which can be a string, a dictionary, or None. 
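[editor note] The termination and human-input behaviour checked here is driven by constructor options; a minimal illustrative configuration (parameter values are placeholders):

```python
from autogen import ConversableAgent

# Auto-reply without prompting a human, stop when the peer says "TERMINATE",
# and never auto-reply more than three times in a row.
assistant = ConversableAgent(
    "assistant",
    llm_config=False,
    human_input_mode="NEVER",
    is_termination_msg=lambda msg: isinstance(msg.get("content"), str) and "TERMINATE" in msg["content"],
    max_consecutive_auto_reply=3,
)
```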
""" + iostream = IOStream.get_default() + if config is None: config = self if messages is None: @@ -1647,7 +1831,7 @@ async def a_check_termination_and_human_reply( # print the no_human_input_msg if no_human_input_msg: - print(colored(f"\n>>>>>>>> {no_human_input_msg}", "red"), flush=True) + iostream.print(colored(f"\n>>>>>>>> {no_human_input_msg}", "red"), flush=True) # stop the conversation if reply == "exit": @@ -1687,7 +1871,7 @@ async def a_check_termination_and_human_reply( # increment the consecutive_auto_reply_counter self._consecutive_auto_reply_counter[sender] += 1 if self.human_input_mode != "NEVER": - print(colored("\n>>>>>>>> USING AUTO REPLY...", "red"), flush=True) + iostream.print(colored("\n>>>>>>>> USING AUTO REPLY...", "red"), flush=True) return False, None @@ -1805,6 +1989,7 @@ async def a_generate_reply( reply_func = reply_func_tuple["reply_func"] if "exclude" in kwargs and reply_func in kwargs["exclude"]: continue + if self._match_trigger(reply_func_tuple["trigger"], sender): if inspect.iscoroutinefunction(reply_func): final, reply = await reply_func( @@ -1816,7 +2001,7 @@ async def a_generate_reply( return reply return self._default_auto_reply - def _match_trigger(self, trigger: Union[None, str, type, Agent, Callable, List], sender: Agent) -> bool: + def _match_trigger(self, trigger: Union[None, str, type, Agent, Callable, List], sender: Optional[Agent]) -> bool: """Check if the sender matches the trigger. Args: @@ -1833,6 +2018,8 @@ def _match_trigger(self, trigger: Union[None, str, type, Agent, Callable, List], if trigger is None: return sender is None elif isinstance(trigger, str): + if sender is None: + raise SenderRequired() return trigger == sender.name elif isinstance(trigger, type): return isinstance(sender, trigger) @@ -1841,7 +2028,7 @@ def _match_trigger(self, trigger: Union[None, str, type, Agent, Callable, List], return trigger == sender elif isinstance(trigger, Callable): rst = trigger(sender) - assert rst in [True, False], f"trigger {trigger} must return a boolean value." + assert isinstance(rst, bool), f"trigger {trigger} must return a boolean value." return rst elif isinstance(trigger, list): return any(self._match_trigger(t, sender) for t in trigger) @@ -1859,7 +2046,9 @@ def get_human_input(self, prompt: str) -> str: Returns: str: human input. """ - reply = input(prompt) + iostream = IOStream.get_default() + + reply = iostream.input(prompt) self._human_input.append(reply) return reply @@ -1874,8 +2063,8 @@ async def a_get_human_input(self, prompt: str) -> str: Returns: str: human input. 
""" - reply = input(prompt) - self._human_input.append(reply) + loop = asyncio.get_running_loop() + reply = await loop.run_in_executor(None, functools.partial(self.get_human_input, prompt)) return reply def run_code(self, code, **kwargs): @@ -1896,12 +2085,14 @@ def run_code(self, code, **kwargs): def execute_code_blocks(self, code_blocks): """Execute the code blocks and return the result.""" + iostream = IOStream.get_default() + logs_all = "" for i, code_block in enumerate(code_blocks): lang, code = code_block if not lang: lang = infer_lang(code) - print( + iostream.print( colored( f"\n>>>>>>>> EXECUTING CODE BLOCK {i} (inferred language is {lang})...", "red", @@ -1910,7 +2101,7 @@ def execute_code_blocks(self, code_blocks): ) if lang in ["bash", "shell", "sh"]: exitcode, logs, image = self.run_code(code, lang=lang, **self._code_execution_config) - elif lang in ["python", "Python"]: + elif lang in PYTHON_VARIANTS: if code.startswith("# filename: "): filename = code[11 : code.find("\n")].strip() else: @@ -1982,6 +2173,8 @@ def execute_function(self, func_call, verbose: bool = False) -> Tuple[bool, Dict "function_call" deprecated as of [OpenAI API v1.1.0](https://github.com/openai/openai-python/releases/tag/v1.1.0) See https://platform.openai.com/docs/api-reference/chat/create#chat-create-function_call """ + iostream = IOStream.get_default() + func_name = func_call.get("name", "") func = self._function_map.get(func_name, None) @@ -1997,7 +2190,7 @@ def execute_function(self, func_call, verbose: bool = False) -> Tuple[bool, Dict # Try to execute the function if arguments is not None: - print( + iostream.print( colored(f"\n>>>>>>>> EXECUTING FUNCTION {func_name}...", "magenta"), flush=True, ) @@ -2010,7 +2203,7 @@ def execute_function(self, func_call, verbose: bool = False) -> Tuple[bool, Dict content = f"Error: Function {func_name} not found." if verbose: - print( + iostream.print( colored(f"\nInput arguments: {arguments}\nOutput:\n{content}", "magenta"), flush=True, ) @@ -2037,6 +2230,8 @@ async def a_execute_function(self, func_call): "function_call" deprecated as of [OpenAI API v1.1.0](https://github.com/openai/openai-python/releases/tag/v1.1.0) See https://platform.openai.com/docs/api-reference/chat/create#chat-create-function_call """ + iostream = IOStream.get_default() + func_name = func_call.get("name", "") func = self._function_map.get(func_name, None) @@ -2052,7 +2247,7 @@ async def a_execute_function(self, func_call): # Try to execute the function if arguments is not None: - print( + iostream.print( colored(f"\n>>>>>>>> EXECUTING ASYNC FUNCTION {func_name}...", "magenta"), flush=True, ) @@ -2074,70 +2269,83 @@ async def a_execute_function(self, func_call): "content": str(content), } - def generate_init_message(self, **context) -> Union[str, Dict]: + def generate_init_message(self, message: Union[Dict, str, None], **kwargs) -> Union[str, Dict]: """Generate the initial message for the agent. - TODO: offer a way to customize initial message without overriding this function. - - Override this function to customize the initial message based on user's request. - If not overridden, "message" needs to be provided in the context, or input() will be called to get the initial message. + If message is None, input() will be called to get the initial message. Args: - **context: any context information. It has the following reserved fields: - "message": a str of message. - "summary_method": a string or callable specifying the method to get a summary from the chat. 
Default is DEFAULT_summary_method, i.e., "last_msg". - - Supported string are "last_msg" and "reflection_with_llm": - when set "last_msg", it returns the last message of the dialog as the summary. - when set "reflection_with_llm", it returns a summary extracted using an llm client. - `llm_config` must be set in either the recipient or sender. - "reflection_with_llm" requires the llm_config to be set in either the sender or the recipient. - - A callable summary_method should take the recipient and sender agent in a chat as input and return a string of summary. E.g, - ```python - def my_summary_method( - sender: ConversableAgent, - recipient: ConversableAgent, - ): - return recipient.last_message(sender)["content"] - ``` - When both the sender and the recipient have an llm client, the recipient's llm client will be used. - "summary_prompt": a string of text used to prompt a LLM-based agent (the sender or receiver agent) to reflext - on the conversation and extract a summary when summary_method is "reflection_with_llm". - Default is DEFAULT_summary_prompt, i.e., "Summarize takeaway from the conversation. Do not add any introductory phrases. If the intended request is NOT properly addressed, please point it out." + message (str or None): the message to be processed. + **kwargs: any additional information. It has the following reserved fields: "carryover": a string or a list of string to specify the carryover information to be passed to this chat. It can be a string or a list of string. If provided, we will combine this carryover with the "message" content when generating the initial chat message. + Returns: + str or dict: the processed message. """ - if "message" not in context: - context["message"] = self.get_human_input(">") - self._process_carryover(context) - return context["message"] - - def _process_carryover(self, context): - carryover = context.get("carryover", "") - if carryover: - # if carryover is string - if isinstance(carryover, str): - context["message"] = context["message"] + "\nContext: \n" + carryover - elif isinstance(carryover, list): - context["message"] = context["message"] + "\nContext: \n" + ("\n").join([t for t in carryover]) - else: - raise warnings.warn( - "Carryover should be a string or a list of strings. Not adding carryover to the message." - ) + if message is None: + message = self.get_human_input(">") - async def a_generate_init_message(self, **context) -> Union[str, Dict]: - """Generate the initial message for the agent. - TODO: offer a way to customize initial message without overriding this function. + return self._handle_carryover(message, kwargs) + + def _handle_carryover(self, message: Union[str, Dict], kwargs: dict) -> Union[str, Dict]: + if not kwargs.get("carryover"): + return message - Override this function to customize the initial message based on user's request. - If not overridden, "message" needs to be provided in the context, or input() will be called to get the initial message. 
+ if isinstance(message, str): + return self._process_carryover(message, kwargs) + + elif isinstance(message, dict): + if isinstance(message.get("content"), str): + # Makes sure the original message is not mutated + message = message.copy() + message["content"] = self._process_carryover(message["content"], kwargs) + elif isinstance(message.get("content"), list): + # Makes sure the original message is not mutated + message = message.copy() + message["content"] = self._process_multimodal_carryover(message["content"], kwargs) + else: + raise InvalidCarryOverType("Carryover should be a string or a list of strings.") + + return message + + def _process_carryover(self, content: str, kwargs: dict) -> str: + # Makes sure there's a carryover + if not kwargs.get("carryover"): + return content + + # if carryover is string + if isinstance(kwargs["carryover"], str): + content += "\nContext: \n" + kwargs["carryover"] + elif isinstance(kwargs["carryover"], list): + content += "\nContext: \n" + ("\n").join([t for t in kwargs["carryover"]]) + else: + raise InvalidCarryOverType( + "Carryover should be a string or a list of strings. Not adding carryover to the message." + ) + return content + + def _process_multimodal_carryover(self, content: List[Dict], kwargs: dict) -> List[Dict]: + """Prepends the context to a multimodal message.""" + # Makes sure there's a carryover + if not kwargs.get("carryover"): + return content + + return [{"type": "text", "text": self._process_carryover("", kwargs)}] + content + + async def a_generate_init_message(self, message: Union[Dict, str, None], **kwargs) -> Union[str, Dict]: + """Generate the initial message for the agent. + If message is None, input() will be called to get the initial message. Args: Please refer to `generate_init_message` for the description of the arguments. + + Returns: + str or dict: the processed message. """ - if "message" not in context: - context["message"] = await self.a_get_human_input(">") - self._process_carryover(context) - return context["message"] + if message is None: + message = await self.a_get_human_input(">") + + return self._handle_carryover(message, kwargs) def register_function(self, function_map: Dict[str, Union[Callable, None]]): """Register functions to the agent. @@ -2178,6 +2386,11 @@ def update_function_signature(self, func_sig: Union[str, Dict], is_remove: None) func for func in self.llm_config["functions"] if func["name"] != func_sig ] else: + if not isinstance(func_sig, dict): + raise ValueError( + f"The function signature must be of the type dict. Received function signature type {type(func_sig)}" + ) + self._assert_valid_name(func_sig["name"]) if "functions" in self.llm_config.keys(): self.llm_config["functions"] = [ @@ -2214,6 +2427,10 @@ def update_tool_signature(self, tool_sig: Union[str, Dict], is_remove: None): tool for tool in self.llm_config["tools"] if tool["function"]["name"] != tool_sig ] else: + if not isinstance(tool_sig, dict): + raise ValueError( + f"The tool signature must be of the type dict. Received tool signature type {type(tool_sig)}" + ) self._assert_valid_name(tool_sig["function"]["name"]) if "tools" in self.llm_config.keys(): self.llm_config["tools"] = [ @@ -2472,30 +2689,35 @@ def process_last_received_message(self, messages): return messages # Last message contains a context key. if "content" not in last_message: return messages # Last message has no content. - user_text = last_message["content"] - if not isinstance(user_text, str): - return messages # Last message content is not a string. 
TODO: Multimodal agents will use a dict here. - if user_text == "exit": + + user_content = last_message["content"] + if not isinstance(user_content, str) and not isinstance(user_content, list): + # if the user_content is a string, it is for regular LLM + # if the user_content is a list, it should follow the multimodal LMM format. + return messages + if user_content == "exit": return messages # Last message is an exit command. # Call each hook (in order of registration) to process the user's message. - processed_user_text = user_text + processed_user_content = user_content for hook in hook_list: - processed_user_text = hook(processed_user_text) - if processed_user_text == user_text: + processed_user_content = hook(processed_user_content) + if processed_user_content == user_content: return messages # No hooks actually modified the user's message. # Replace the last user message with the expanded one. messages = messages.copy() - messages[-1]["content"] = processed_user_text + messages[-1]["content"] = processed_user_content return messages def print_usage_summary(self, mode: Union[str, List[str]] = ["actual", "total"]) -> None: """Print the usage summary.""" + iostream = IOStream.get_default() + if self.client is None: - print(f"No cost incurred from agent '{self.name}'.") + iostream.print(f"No cost incurred from agent '{self.name}'.") else: - print(f"Agent '{self.name}':") + iostream.print(f"Agent '{self.name}':") self.client.print_usage_summary(mode) def get_actual_usage(self) -> Union[None, Dict[str, int]]: diff --git a/autogen/agentchat/groupchat.py b/autogen/agentchat/groupchat.py index c9ebe82c32c..f5b6106863a 100644 --- a/autogen/agentchat/groupchat.py +++ b/autogen/agentchat/groupchat.py @@ -3,27 +3,19 @@ import re import sys from dataclasses import dataclass, field -from typing import Dict, List, Optional, Union, Tuple - +from typing import Callable, Dict, List, Literal, Optional, Tuple, Union from ..code_utils import content_str -from ..exception_utils import AgentNameConflict +from ..exception_utils import AgentNameConflict, NoEligibleSpeaker, UndefinedNextAgent +from ..graph_utils import check_graph_validity, invert_disallowed_to_allowed +from ..io.base import IOStream +from ..runtime_logging import log_new_agent, logging_enabled from .agent import Agent from .conversable_agent import ConversableAgent -from ..runtime_logging import logging_enabled, log_new_agent -from ..graph_utils import check_graph_validity, invert_disallowed_to_allowed logger = logging.getLogger(__name__) -class NoEligibleSpeakerException(Exception): - """Exception raised for early termination of a GroupChat.""" - - def __init__(self, message="No eligible speakers."): - self.message = message - super().__init__(self.message) - - @dataclass class GroupChat: """(In preview) A group chat class that contains the following data fields: @@ -36,13 +28,29 @@ class GroupChat: When set to True and when a message is a function call suggestion, the next speaker will be chosen from an agent which contains the corresponding function name in its `function_map`. + - select_speaker_message_template: customize the select speaker message (used in "auto" speaker selection), which appears first in the message context and generally includes the agent descriptions and list of agents. The string value will be converted to an f-string, use "{roles}" to output the agent's and their role descriptions and "{agentlist}" for a comma-separated list of agent names in square brackets. The default value is: + "You are in a role play game. 
The following roles are available: + {roles}. + Read the following conversation. + Then select the next role from {agentlist} to play. Only return the role." + - select_speaker_prompt_template: customize the select speaker prompt (used in "auto" speaker selection), which appears last in the message context and generally includes the list of agents and guidance for the LLM to select the next agent. The string value will be converted to an f-string, use "{agentlist}" for a comma-separated list of agent names in square brackets. The default value is: + "Read the above conversation. Then select the next role from {agentlist} to play. Only return the role." - speaker_selection_method: the method for selecting the next speaker. Default is "auto". Could be any of the following (case insensitive), will raise ValueError if not recognized: - "auto": the next speaker is selected automatically by LLM. - "manual": the next speaker is selected manually by user input. - "random": the next speaker is selected randomly. - "round_robin": the next speaker is selected in a round robin fashion, i.e., iterating in the same order as provided in `agents`. - + - a customized speaker selection function (Callable): the function will be called to select the next speaker. + The function should take the last speaker and the group chat as input and return one of the following: + 1. an `Agent` class, it must be one of the agents in the group chat. + 2. a string from ['auto', 'manual', 'random', 'round_robin'] to select a default method to use. + 3. None, which would terminate the conversation gracefully. + ```python + def custom_speaker_selection_func( + last_speaker: Agent, groupchat: GroupChat + ) -> Union[Agent, str, None]: + ``` - allow_repeat_speaker: whether to allow the same speaker to speak consecutively. Default is True, in which case all speakers are allowed to speak consecutively. If `allow_repeat_speaker` is a list of Agents, then only those listed agents are allowed to repeat. @@ -60,6 +68,7 @@ class GroupChat: "clear history" phrase in user prompt. This is experimental feature. See description of GroupChatManager.clear_agents_history function for more info. - send_introductions: send a round of introductions at the start of the group chat, so agents know who they can speak to (default: False) + - role_for_select_speaker_messages: sets the role name for speaker selection when in 'auto' mode, typically 'user' or 'system'. (default: 'system') """ agents: List[Agent] @@ -67,16 +76,29 @@ class GroupChat: max_round: Optional[int] = 10 admin_name: Optional[str] = "Admin" func_call_filter: Optional[bool] = True - speaker_selection_method: Optional[str] = "auto" + speaker_selection_method: Union[Literal["auto", "manual", "random", "round_robin"], Callable] = "auto" allow_repeat_speaker: Optional[Union[bool, List[Agent]]] = None allowed_or_disallowed_speaker_transitions: Optional[Dict] = None - speaker_transitions_type: Optional[str] = None + speaker_transitions_type: Literal["allowed", "disallowed", None] = None enable_clear_history: Optional[bool] = False - send_introductions: Optional[bool] = False + send_introductions: bool = False + select_speaker_message_template: str = """You are in a role play game. The following roles are available: + {roles}. + Read the following conversation. + Then select the next role from {agentlist} to play. Only return the role.""" + select_speaker_prompt_template: str = ( + "Read the above conversation. Then select the next role from {agentlist} to play. Only return the role." 
+ ) + role_for_select_speaker_messages: Optional[str] = "system" _VALID_SPEAKER_SELECTION_METHODS = ["auto", "manual", "random", "round_robin"] _VALID_SPEAKER_TRANSITIONS_TYPE = ["allowed", "disallowed", None] + # Define a class attribute for the default introduction message + DEFAULT_INTRO_MSG = ( + "Hello everyone. We have assembled a great team today to answer questions and solve tasks. In attendance are:" + ) + allowed_speaker_transitions_dict: Dict = field(init=False) def __post_init__(self): @@ -156,6 +178,16 @@ def __post_init__(self): agents=self.agents, ) + # Check select_speaker_message_template and select_speaker_prompt_template have values + if self.select_speaker_message_template is None or len(self.select_speaker_message_template) == 0: + raise ValueError("select_speaker_message_template cannot be empty or None.") + + if self.select_speaker_prompt_template is None or len(self.select_speaker_prompt_template) == 0: + raise ValueError("select_speaker_prompt_template cannot be empty or None.") + + if self.role_for_select_speaker_messages is None or len(self.role_for_select_speaker_messages) == 0: + raise ValueError("role_for_select_speaker_messages cannot be empty or None.") + @property def agent_names(self) -> List[str]: """Return the names of the agents in the group chat.""" @@ -203,6 +235,10 @@ def next_agent(self, agent: Agent, agents: Optional[List[Agent]] = None) -> Agen if agents is None: agents = self.agents + # Ensure the provided list of agents is a subset of self.agents + if not set(agents).issubset(set(self.agents)): + raise UndefinedNextAgent() + # What index is the agent? (-1 if not present) idx = self.agent_names.index(agent.name) if agent.name in self.agent_names else -1 @@ -215,50 +251,63 @@ def next_agent(self, agent: Agent, agents: Optional[List[Agent]] = None) -> Agen if self.agents[(offset + i) % len(self.agents)] in agents: return self.agents[(offset + i) % len(self.agents)] + # Explicitly handle cases where no valid next agent exists in the provided subset. + raise UndefinedNextAgent() + def select_speaker_msg(self, agents: Optional[List[Agent]] = None) -> str: """Return the system message for selecting the next speaker. This is always the *first* message in the context.""" if agents is None: agents = self.agents - return f"""You are in a role play game. The following roles are available: -{self._participant_roles(agents)}. -Read the following conversation. -Then select the next role from {[agent.name for agent in agents]} to play. Only return the role.""" + roles = self._participant_roles(agents) + agentlist = f"{[agent.name for agent in agents]}" + + return_msg = self.select_speaker_message_template.format(roles=roles, agentlist=agentlist) + return return_msg def select_speaker_prompt(self, agents: Optional[List[Agent]] = None) -> str: """Return the floating system prompt selecting the next speaker. This is always the *last* message in the context.""" if agents is None: agents = self.agents - return f"Read the above conversation. Then select the next role from {[agent.name for agent in agents]} to play. Only return the role." + + agentlist = f"{[agent.name for agent in agents]}" + + return_prompt = self.select_speaker_prompt_template.format(agentlist=agentlist) + return return_prompt def introductions_msg(self, agents: Optional[List[Agent]] = None) -> str: """Return the system message for selecting the next speaker. This is always the *first* message in the context.""" if agents is None: agents = self.agents - return f"""Hello everyone. 
We have assembled a great team today to answer questions and solve tasks. In attendance are: + # Use the class attribute instead of a hardcoded string + intro_msg = self.DEFAULT_INTRO_MSG + participant_roles = self._participant_roles(agents) -{self._participant_roles(agents)} -""" + return f"{intro_msg}\n\n{participant_roles}" def manual_select_speaker(self, agents: Optional[List[Agent]] = None) -> Union[Agent, None]: """Manually select the next speaker.""" + iostream = IOStream.get_default() + if agents is None: agents = self.agents - print("Please select the next speaker from the following list:") + iostream.print("Please select the next speaker from the following list:") _n_agents = len(agents) for i in range(_n_agents): - print(f"{i+1}: {agents[i].name}") + iostream.print(f"{i+1}: {agents[i].name}") try_count = 0 # Assume the user will enter a valid number within 3 tries, otherwise use auto selection to avoid blocking. while try_count <= 3: try_count += 1 if try_count >= 3: - print(f"You have tried {try_count} times. The next speaker will be selected automatically.") + iostream.print(f"You have tried {try_count} times. The next speaker will be selected automatically.") break try: - i = input("Enter the number of the next speaker (enter nothing or `q` to use auto selection): ") + i = iostream.input( + "Enter the number of the next speaker (enter nothing or `q` to use auto selection): " + ) if i == "" or i == "q": break i = int(i) @@ -267,7 +316,7 @@ def manual_select_speaker(self, agents: Optional[List[Agent]] = None) -> Union[A else: raise ValueError except ValueError: - print(f"Invalid input. Please enter a number between 1 and {_n_agents}.") + iostream.print(f"Invalid input. Please enter a number between 1 and {_n_agents}.") return None def random_select_speaker(self, agents: Optional[List[Agent]] = None) -> Union[Agent, None]: @@ -277,11 +326,34 @@ def random_select_speaker(self, agents: Optional[List[Agent]] = None) -> Union[A return random.choice(agents) def _prepare_and_select_agents( - self, last_speaker: Agent + self, + last_speaker: Agent, ) -> Tuple[Optional[Agent], List[Agent], Optional[List[Dict]]]: - if self.speaker_selection_method.lower() not in self._VALID_SPEAKER_SELECTION_METHODS: + # If self.speaker_selection_method is a callable, call it to get the next speaker. + # If self.speaker_selection_method is a string, return it. + speaker_selection_method = self.speaker_selection_method + if isinstance(self.speaker_selection_method, Callable): + selected_agent = self.speaker_selection_method(last_speaker, self) + if selected_agent is None: + raise NoEligibleSpeaker("Custom speaker selection function returned None. Terminating conversation.") + elif isinstance(selected_agent, Agent): + if selected_agent in self.agents: + return selected_agent, self.agents, None + else: + raise ValueError( + f"Custom speaker selection function returned an agent {selected_agent.name} not in the group chat." + ) + elif isinstance(selected_agent, str): + # If returned a string, assume it is a speaker selection method + speaker_selection_method = selected_agent + else: + raise ValueError( + f"Custom speaker selection function returned an object of type {type(selected_agent)} instead of Agent or str." + ) + + if speaker_selection_method.lower() not in self._VALID_SPEAKER_SELECTION_METHODS: raise ValueError( - f"GroupChat speaker_selection_method is set to '{self.speaker_selection_method}'. " + f"GroupChat speaker_selection_method is set to '{speaker_selection_method}'. 
" f"It should be one of {self._VALID_SPEAKER_SELECTION_METHODS} (case insensitive). " ) @@ -300,7 +372,7 @@ def _prepare_and_select_agents( f"GroupChat is underpopulated with {n_agents} agents. " "Please add more agents to the GroupChat or use direct communication instead." ) - elif n_agents == 2 and self.speaker_selection_method.lower() != "round_robin" and allow_repeat_speaker: + elif n_agents == 2 and speaker_selection_method.lower() != "round_robin" and allow_repeat_speaker: logger.warning( f"GroupChat is underpopulated with {n_agents} agents. " "Consider setting speaker_selection_method to 'round_robin' or allow_repeat_speaker to False, " @@ -344,9 +416,7 @@ def _prepare_and_select_agents( # this condition means last_speaker is a sink in the graph, then no agents are eligible if last_speaker not in self.allowed_speaker_transitions_dict and is_last_speaker_in_group: - raise NoEligibleSpeakerException( - f"Last speaker {last_speaker.name} is not in the allowed_speaker_transitions_dict." - ) + raise NoEligibleSpeaker(f"Last speaker {last_speaker.name} is not in the allowed_speaker_transitions_dict.") # last_speaker is not in the group, so all agents are eligible elif last_speaker not in self.allowed_speaker_transitions_dict and not is_last_speaker_in_group: graph_eligible_agents = [] @@ -366,13 +436,13 @@ def _prepare_and_select_agents( # Use the selected speaker selection method select_speaker_messages = None - if self.speaker_selection_method.lower() == "manual": + if speaker_selection_method.lower() == "manual": selected_agent = self.manual_select_speaker(graph_eligible_agents) - elif self.speaker_selection_method.lower() == "round_robin": + elif speaker_selection_method.lower() == "round_robin": selected_agent = self.next_agent(last_speaker, graph_eligible_agents) - elif self.speaker_selection_method.lower() == "random": + elif speaker_selection_method.lower() == "random": selected_agent = self.random_select_speaker(graph_eligible_agents) - else: + else: # auto selected_agent = None select_speaker_messages = self.messages.copy() # If last message is a tool call or function call, blank the call so the api doesn't throw @@ -381,7 +451,10 @@ def _prepare_and_select_agents( if select_speaker_messages[-1].get("tool_calls", False): select_speaker_messages[-1] = dict(select_speaker_messages[-1], tool_calls=None) select_speaker_messages = select_speaker_messages + [ - {"role": "system", "content": self.select_speaker_prompt(graph_eligible_agents)} + { + "role": self.role_for_select_speaker_messages, + "content": self.select_speaker_prompt(graph_eligible_agents), + } ] return selected_agent, graph_eligible_agents, select_speaker_messages @@ -439,6 +512,10 @@ def _participant_roles(self, agents: List[Agent] = None) -> str: def _mentioned_agents(self, message_content: Union[str, List], agents: Optional[List[Agent]]) -> Dict: """Counts the number of times each agent is mentioned in the provided message content. + Agent names will match under any of the following conditions (all case-sensitive): + - Exact name match + - If the agent name has underscores it will match with spaces instead (e.g. 'Story_writer' == 'Story writer') + - If the agent name has underscores it will match with '\\_' instead of '_' (e.g. 'Story_writer' == 'Story\\_writer') Args: message_content (Union[str, List]): The content of the message, either as a single string or a list of strings. 
@@ -457,9 +534,17 @@ def _mentioned_agents(self, message_content: Union[str, List], agents: Optional[ mentions = dict() for agent in agents: + # Finds agent mentions, taking word boundaries into account, + # accommodates escaping underscores and underscores as spaces regex = ( - r"(?<=\W)" + re.escape(agent.name) + r"(?=\W)" - ) # Finds agent mentions, taking word boundaries into account + r"(?<=\W)(" + + re.escape(agent.name) + + r"|" + + re.escape(agent.name.replace("_", " ")) + + r"|" + + re.escape(agent.name.replace("_", r"\_")) + + r")(?=\W)" + ) count = len(re.findall(regex, f" {message_content} ")) # Pad the message to help with matching if count > 0: mentions[agent.name] = count @@ -479,7 +564,11 @@ def __init__( system_message: Optional[Union[str, List]] = "Group chat manager.", **kwargs, ): - if kwargs.get("llm_config") and (kwargs["llm_config"].get("functions") or kwargs["llm_config"].get("tools")): + if ( + kwargs.get("llm_config") + and isinstance(kwargs["llm_config"], dict) + and (kwargs["llm_config"].get("functions") or kwargs["llm_config"].get("tools")) + ): raise ValueError( "GroupChatManager is not allowed to make function/tool calls. Please remove the 'functions' or 'tools' config in 'llm_config' you passed in." ) @@ -584,7 +673,7 @@ def run_chat( else: # admin agent is not found in the participants raise - except NoEligibleSpeakerException: + except NoEligibleSpeaker: # No eligible speaker, terminate the conversation break @@ -701,6 +790,8 @@ def clear_agents_history(self, reply: dict, groupchat: GroupChat) -> str: reply (dict): reply message dict to analyze. groupchat (GroupChat): GroupChat object. """ + iostream = IOStream.get_default() + reply_content = reply["content"] # Split the reply into words words = reply_content.split() @@ -736,21 +827,21 @@ def clear_agents_history(self, reply: dict, groupchat: GroupChat) -> str: # clear history if agent_to_memory_clear: if nr_messages_to_preserve: - print( + iostream.print( f"Clearing history for {agent_to_memory_clear.name} except last {nr_messages_to_preserve} messages." ) else: - print(f"Clearing history for {agent_to_memory_clear.name}.") + iostream.print(f"Clearing history for {agent_to_memory_clear.name}.") agent_to_memory_clear.clear_history(nr_messages_to_preserve=nr_messages_to_preserve) else: if nr_messages_to_preserve: - print(f"Clearing history for all agents except last {nr_messages_to_preserve} messages.") + iostream.print(f"Clearing history for all agents except last {nr_messages_to_preserve} messages.") # clearing history for groupchat here temp = groupchat.messages[-nr_messages_to_preserve:] groupchat.messages.clear() groupchat.messages.extend(temp) else: - print("Clearing history for all agents.") + iostream.print("Clearing history for all agents.") # clearing history for groupchat here groupchat.messages.clear() # clearing history for agents diff --git a/autogen/agentchat/user_proxy_agent.py b/autogen/agentchat/user_proxy_agent.py index dc68c6ec6d0..a80296a8355 100644 --- a/autogen/agentchat/user_proxy_agent.py +++ b/autogen/agentchat/user_proxy_agent.py @@ -1,7 +1,7 @@ from typing import Callable, Dict, List, Literal, Optional, Union +from ..runtime_logging import log_new_agent, logging_enabled from .conversable_agent import ConversableAgent -from ..runtime_logging import logging_enabled, log_new_agent class UserProxyAgent(ConversableAgent): @@ -14,7 +14,6 @@ class UserProxyAgent(ConversableAgent): To modify the way to get human input, override `get_human_input` method. 
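The widened regular expression above lets `_mentioned_agents` count a mention even when an underscored agent name appears with spaces or with markdown-escaped underscores. A small standalone sketch of the same pattern follows; the agent name and message are made up for illustration.

```python
import re

agent_name = "Story_writer"

# Mirrors the alternation built in _mentioned_agents: the exact name, underscores
# replaced with spaces, and underscores escaped as "\_", bounded by non-word characters.
regex = (
    r"(?<=\W)("
    + re.escape(agent_name)
    + r"|"
    + re.escape(agent_name.replace("_", " "))
    + r"|"
    + re.escape(agent_name.replace("_", r"\_"))
    + r")(?=\W)"
)

message = "Next, Story\\_writer should revise the draft from Story writer."

# The message is padded with spaces, as in the diff, so names at the very start
# or end of the content still satisfy the lookbehind/lookahead.
print(len(re.findall(regex, f" {message} ")))  # 2
```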
To modify the way to execute code blocks, single code block, or function call, override `execute_code_blocks`, `run_code`, and `execute_function` methods respectively. - To customize the initial message when a conversation starts, override `generate_init_message` method. """ # Default UserProxyAgent.description values, based on human_input_mode @@ -29,9 +28,9 @@ def __init__( name: str, is_termination_msg: Optional[Callable[[Dict], bool]] = None, max_consecutive_auto_reply: Optional[int] = None, - human_input_mode: Optional[str] = "ALWAYS", + human_input_mode: Literal["ALWAYS", "TERMINATE", "NEVER"] = "ALWAYS", function_map: Optional[Dict[str, Callable]] = None, - code_execution_config: Optional[Union[Dict, Literal[False]]] = None, + code_execution_config: Union[Dict, Literal[False]] = {}, default_auto_reply: Optional[Union[str, Dict, None]] = "", llm_config: Optional[Union[Dict, Literal[False]]] = False, system_message: Optional[Union[str, List]] = "", @@ -71,10 +70,11 @@ def __init__( - timeout (Optional, int): The maximum execution time in seconds. - last_n_messages (Experimental, Optional, int): The number of messages to look back for code execution. Default to 1. default_auto_reply (str or dict or None): the default auto reply message when no code execution or llm based reply is generated. - llm_config (dict or False): llm inference configuration. + llm_config (dict or False or None): llm inference configuration. Please refer to [OpenAIWrapper.create](/docs/reference/oai/client#create) for available options. - Default to false, which disables llm-based auto reply. + Default to False, which disables llm-based auto reply. + When set to None, will use self.DEFAULT_CONFIG, which defaults to False. system_message (str or List): system message for ChatCompletion inference. Only used when llm_config is not False. Use it to reprogram the agent. description (str): a short description of the agent. This description is used by other agents @@ -90,9 +90,9 @@ def __init__( code_execution_config=code_execution_config, llm_config=llm_config, default_auto_reply=default_auto_reply, - description=description - if description is not None - else self.DEFAULT_USER_PROXY_AGENT_DESCRIPTIONS[human_input_mode], + description=( + description if description is not None else self.DEFAULT_USER_PROXY_AGENT_DESCRIPTIONS[human_input_mode] + ), ) if logging_enabled(): diff --git a/autogen/agentchat/utils.py b/autogen/agentchat/utils.py index 07982d03768..b32c2f5f0a0 100644 --- a/autogen/agentchat/utils.py +++ b/autogen/agentchat/utils.py @@ -1,4 +1,6 @@ -from typing import List, Dict, Tuple, Callable +import re +from typing import Any, Callable, Dict, List, Union + from .agent import Agent @@ -24,36 +26,49 @@ def consolidate_chat_info(chat_info, uniform_sender=None) -> None: ), "llm client must be set in either the recipient or sender when summary_method is reflection_with_llm." -def gather_usage_summary(agents: List[Agent]) -> Tuple[Dict[str, any], Dict[str, any]]: +def gather_usage_summary(agents: List[Agent]) -> Dict[Dict[str, Dict], Dict[str, Dict]]: r"""Gather usage summary from all agents. Args: agents: (list): List of agents. Returns: - tuple: (total_usage_summary, actual_usage_summary) + dictionary: A dictionary containing two keys: + - "usage_including_cached_inference": Cost information on the total usage, including the tokens in cached inference. + - "usage_excluding_cached_inference": Cost information on the usage of tokens, excluding the tokens in cache. No larger than "usage_including_cached_inference". 
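With the signature change above, `human_input_mode` is now typed as a `Literal` and `code_execution_config` defaults to an empty dict, i.e. code execution enabled with default settings. A minimal construction sketch under those assumptions; the working directory is illustrative.

```python
from autogen import UserProxyAgent

user_proxy = UserProxyAgent(
    name="user_proxy",
    human_input_mode="NEVER",  # must be one of "ALWAYS", "TERMINATE", "NEVER"
    code_execution_config={"work_dir": "coding", "use_docker": False},
    llm_config=False,  # False disables llm-based auto reply; None falls back to DEFAULT_CONFIG
)
```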
Example: ```python - total_usage_summary = { - "total_cost": 0.0006090000000000001, - "gpt-35-turbo": { - "cost": 0.0006090000000000001, - "prompt_tokens": 242, - "completion_tokens": 123, - "total_tokens": 365 + { + "usage_including_cached_inference" : { + "total_cost": 0.0006090000000000001, + "gpt-35-turbo": { + "cost": 0.0006090000000000001, + "prompt_tokens": 242, + "completion_tokens": 123, + "total_tokens": 365 + }, + }, + + "usage_excluding_cached_inference" : { + "total_cost": 0.0006090000000000001, + "gpt-35-turbo": { + "cost": 0.0006090000000000001, + "prompt_tokens": 242, + "completion_tokens": 123, + "total_tokens": 365 + }, } } ``` Note: - `actual_usage_summary` follows the same format. - If none of the agents incurred any cost (not having a client), then the total_usage_summary and actual_usage_summary will be `{'total_cost': 0}`. + If none of the agents incurred any cost (not having a client), then the usage_including_cached_inference and usage_excluding_cached_inference will be `{'total_cost': 0}`. """ - def aggregate_summary(usage_summary: Dict[str, any], agent_summary: Dict[str, any]) -> None: + def aggregate_summary(usage_summary: Dict[str, Any], agent_summary: Dict[str, Any]) -> None: if agent_summary is None: return usage_summary["total_cost"] += agent_summary.get("total_cost", 0) @@ -67,12 +82,120 @@ def aggregate_summary(usage_summary: Dict[str, any], agent_summary: Dict[str, an usage_summary[model]["completion_tokens"] += data.get("completion_tokens", 0) usage_summary[model]["total_tokens"] += data.get("total_tokens", 0) - total_usage_summary = {"total_cost": 0} - actual_usage_summary = {"total_cost": 0} + usage_including_cached_inference = {"total_cost": 0} + usage_excluding_cached_inference = {"total_cost": 0} for agent in agents: if getattr(agent, "client", None): - aggregate_summary(total_usage_summary, agent.client.total_usage_summary) - aggregate_summary(actual_usage_summary, agent.client.actual_usage_summary) + aggregate_summary(usage_including_cached_inference, agent.client.total_usage_summary) + aggregate_summary(usage_excluding_cached_inference, agent.client.actual_usage_summary) + + return { + "usage_including_cached_inference": usage_including_cached_inference, + "usage_excluding_cached_inference": usage_excluding_cached_inference, + } + + +def parse_tags_from_content(tag: str, content: Union[str, List[Dict[str, Any]]]) -> List[Dict[str, Dict[str, str]]]: + """Parses HTML style tags from message contents. + + The parsing is done by looking for patterns in the text that match the format of HTML tags. The tag to be parsed is + specified as an argument to the function. The function looks for this tag in the text and extracts its content. The + content of a tag is everything that is inside the tag, between the opening and closing angle brackets. The content + can be a single string or a set of attribute-value pairs. + + Examples: + -> [{"tag": "img", "attr": {"src": "http://example.com/image.png"}, "match": re.Match}] +
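Given the new return shape documented above, callers read the two summaries by key instead of unpacking a tuple. A short self-contained sketch, using agents without an LLM client so that, per the docstring note, both summaries collapse to `{'total_cost': 0}`:

```python
from autogen import ConversableAgent
from autogen.agentchat.utils import gather_usage_summary

# Agents without a client incur no cost and are skipped during aggregation.
assistant = ConversableAgent("assistant", llm_config=False, human_input_mode="NEVER")
user_proxy = ConversableAgent("user_proxy", llm_config=False, human_input_mode="NEVER")

summary = gather_usage_summary([assistant, user_proxy])

print(summary["usage_including_cached_inference"])  # {'total_cost': 0}
print(summary["usage_excluding_cached_inference"])  # {'total_cost': 0}
```

At runtime the return value is a plain string-keyed dict with exactly these two keys.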