diff --git a/.gitattributes b/.gitattributes
index e69de29bb2d..c139e44b4dc 100644
--- a/.gitattributes
+++ b/.gitattributes
@@ -0,0 +1,3 @@
+*.png filter=lfs diff=lfs merge=lfs -text
+*.jpg filter=lfs diff=lfs merge=lfs -text
+*.jpeg filter=lfs diff=lfs merge=lfs -text
diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml
index 5be7688b06e..d06999db34c 100644
--- a/.github/workflows/build.yml
+++ b/.github/workflows/build.yml
@@ -6,11 +6,6 @@ name: Build
on:
push:
branches: ["main"]
- paths:
- - "autogen/**"
- - "test/**"
- - ".github/workflows/build.yml"
- - "setup.py"
pull_request:
branches: ["main"]
merge_group:
@@ -21,7 +16,39 @@ concurrency:
cancel-in-progress: ${{ github.ref != 'refs/heads/main' }}
permissions: {}
jobs:
+ paths-filter:
+ runs-on: ubuntu-latest
+ outputs:
+ hasChanges: ${{ steps.filter.outputs.autogen == 'true' || steps.filter.outputs.test == 'true' || steps.filter.outputs.workflows == 'true' || steps.filter.outputs.setup == 'true' }}
+ steps:
+ - uses: actions/checkout@v4
+ - uses: dorny/paths-filter@v2
+ id: filter
+ with:
+ filters: |
+ autogen:
+ - "autogen/**"
+ test:
+ - "test/**"
+ workflows:
+ - ".github/workflows/**"
+ setup:
+ - "setup.py"
+ - name: autogen has changes
+ run: echo "autogen has changes"
+ if: steps.filter.outputs.autogen == 'true'
+ - name: test has changes
+ run: echo "test has changes"
+ if: steps.filter.outputs.test == 'true'
+ - name: workflows has changes
+ run: echo "workflows has changes"
+ if: steps.filter.outputs.workflows == 'true'
+ - name: setup has changes
+ run: echo "setup has changes"
+ if: steps.filter.outputs.setup == 'true'
build:
+ needs: paths-filter
+ if: needs.paths-filter.outputs.hasChanges == 'true'
runs-on: ${{ matrix.os }}
env:
AUTOGEN_USE_DOCKER: ${{ matrix.os != 'ubuntu-latest' && 'False' }}
@@ -30,6 +57,11 @@ jobs:
matrix:
os: [ubuntu-latest, macos-latest, windows-latest]
python-version: ["3.8", "3.9", "3.10", "3.11", "3.12"]
+ exclude:
+ - os: macos-latest
+ python-version: "3.8"
+ - os: macos-latest
+ python-version: "3.9"
steps:
- uses: actions/checkout@v4
- name: Set up Python ${{ matrix.python-version }}
@@ -39,9 +71,9 @@ jobs:
- name: Install packages and dependencies
run: |
python -m pip install --upgrade pip wheel
- pip install -e .
+ pip install -e .[cosmosdb]
python -c "import autogen"
- pip install pytest mock
+ pip install pytest-cov>=5 mock
- name: Install optional dependencies for code executors
# code executors and udfs auto skip without deps, so only run for python 3.11
if: matrix.python-version == '3.11'
@@ -57,20 +89,52 @@ jobs:
- name: Test with pytest skipping openai tests
if: matrix.python-version != '3.10' && matrix.os == 'ubuntu-latest'
run: |
- pytest test --skip-openai --durations=10 --durations-min=1.0
+ pytest test --ignore=test/agentchat/contrib --skip-openai --durations=10 --durations-min=1.0
- name: Test with pytest skipping openai and docker tests
if: matrix.python-version != '3.10' && matrix.os != 'ubuntu-latest'
run: |
- pytest test --skip-openai --skip-docker --durations=10 --durations-min=1.0
- - name: Coverage
+ pytest test --ignore=test/agentchat/contrib --skip-openai --skip-docker --durations=10 --durations-min=1.0
+ - name: Coverage with Redis
if: matrix.python-version == '3.10'
run: |
pip install -e .[test,redis,websockets]
- coverage run -a -m pytest test --ignore=test/agentchat/contrib --skip-openai --durations=10 --durations-min=1.0
- coverage xml
+ pytest test --ignore=test/agentchat/contrib --skip-openai --durations=10 --durations-min=1.0
+ - name: Test with Cosmos DB
+ run: |
+ pip install -e .[test,cosmosdb]
+ pytest test/cache/test_cosmos_db_cache.py --skip-openai --durations=10 --durations-min=1.0
- name: Upload coverage to Codecov
if: matrix.python-version == '3.10'
uses: codecov/codecov-action@v3
with:
file: ./coverage.xml
flags: unittests
+ build-check:
+ if: always()
+ runs-on: ubuntu-latest
+ needs: [build]
+ steps:
+ - name: Get Date
+ shell: bash
+ run: |
+ echo "date=$(date +'%m/%d/%Y %H:%M:%S')" >> "$GITHUB_ENV"
+
+ - name: Run Type is ${{ github.event_name }}
+ if: ${{ github.event_name != 'schedule' && github.event_name != 'workflow_dispatch'}}
+ shell: bash
+ run: |
+ echo "run_type=${{ github.event_name }}" >> "$GITHUB_ENV"
+
+ - name: Fail workflow if build failed
+ id: check_build_failed
+ if: contains(join(needs.*.result, ','), 'failure')
+ uses: actions/github-script@v6
+ with:
+ script: core.setFailed('Build Failed!')
+
+ - name: Fail workflow if build cancelled
+ id: check_build_cancelled
+ if: contains(join(needs.*.result, ','), 'cancelled')
+ uses: actions/github-script@v6
+ with:
+ script: core.setFailed('Build Cancelled!')
diff --git a/.github/workflows/contrib-openai.yml b/.github/workflows/contrib-openai.yml
index 4eda8d93071..b1b3e35e478 100644
--- a/.github/workflows/contrib-openai.yml
+++ b/.github/workflows/contrib-openai.yml
@@ -5,14 +5,15 @@ name: OpenAI4ContribTests
on:
pull_request:
- branches: ['main']
+ branches: ["main"]
paths:
- - 'autogen/**'
- - 'test/agentchat/contrib/**'
- - '.github/workflows/contrib-openai.yml'
- - 'setup.py'
-permissions: {}
- # actions: read
+ - "autogen/**"
+ - "test/agentchat/contrib/**"
+ - ".github/workflows/contrib-openai.yml"
+ - "setup.py"
+permissions:
+ {}
+ # actions: read
# checks: read
# contents: read
# deployments: read
@@ -24,6 +25,21 @@ jobs:
python-version: ["3.10"]
runs-on: ${{ matrix.os }}
environment: openai1
+ services:
+ pgvector:
+ image: ankane/pgvector
+ env:
+ POSTGRES_DB: postgres
+ POSTGRES_USER: postgres
+ POSTGRES_PASSWORD: ${{ secrets.POSTGRES_PASSWORD }}
+ POSTGRES_HOST_AUTH_METHOD: trust
+ options: >-
+ --health-cmd pg_isready
+ --health-interval 10s
+ --health-timeout 5s
+ --health-retries 5
+ ports:
+ - 5432:5432
steps:
# checkout to pr branch
- name: Checkout
@@ -40,12 +56,48 @@ jobs:
python -m pip install --upgrade pip wheel
pip install -e .
python -c "import autogen"
- pip install coverage pytest-asyncio
+ pip install pytest-cov>=5 pytest-asyncio
- name: Install packages for test when needed
run: |
pip install docker
- pip install qdrant_client[fastembed]
- pip install -e .[retrievechat]
+ pip install -e .[retrievechat,retrievechat-qdrant,retrievechat-pgvector]
+ - name: Coverage
+ env:
+ OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
+ AZURE_OPENAI_API_KEY: ${{ secrets.AZURE_OPENAI_API_KEY }}
+ AZURE_OPENAI_API_BASE: ${{ secrets.AZURE_OPENAI_API_BASE }}
+ OAI_CONFIG_LIST: ${{ secrets.OAI_CONFIG_LIST }}
+ run: |
+        pytest test/agentchat/contrib/retrievechat
+ - name: Upload coverage to Codecov
+ uses: codecov/codecov-action@v3
+ with:
+ file: ./coverage.xml
+ flags: unittests
+ AgentEvalTest:
+ strategy:
+ matrix:
+ os: [ubuntu-latest]
+ python-version: ["3.10"]
+ runs-on: ${{ matrix.os }}
+ environment: openai1
+ steps:
+ # checkout to pr branch
+ - name: Checkout
+ uses: actions/checkout@v4
+ with:
+ ref: ${{ github.event.pull_request.head.sha }}
+ - name: Set up Python ${{ matrix.python-version }}
+ uses: actions/setup-python@v5
+ with:
+ python-version: ${{ matrix.python-version }}
+ - name: Install packages and dependencies
+ run: |
+ docker --version
+ python -m pip install --upgrade pip wheel
+ pip install -e .
+ python -c "import autogen"
+ pip install pytest-cov>=5 pytest-asyncio
- name: Coverage
env:
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
@@ -53,8 +105,7 @@ jobs:
AZURE_OPENAI_API_BASE: ${{ secrets.AZURE_OPENAI_API_BASE }}
OAI_CONFIG_LIST: ${{ secrets.OAI_CONFIG_LIST }}
run: |
- coverage run -a -m pytest test/agentchat/contrib/test_retrievechat.py test/agentchat/contrib/test_qdrant_retrievechat.py
- coverage xml
+ pytest test/agentchat/contrib/agent_eval/test_agent_eval.py
- name: Upload coverage to Codecov
uses: codecov/codecov-action@v3
with:
@@ -83,7 +134,7 @@ jobs:
python -m pip install --upgrade pip wheel
pip install -e .
python -c "import autogen"
- pip install coverage pytest-asyncio
+ pip install pytest-cov>=5 pytest-asyncio
- name: Install packages for test when needed
run: |
pip install docker
@@ -94,8 +145,7 @@ jobs:
AZURE_OPENAI_API_BASE: ${{ secrets.AZURE_OPENAI_API_BASE }}
OAI_CONFIG_LIST: ${{ secrets.OAI_CONFIG_LIST }}
run: |
- coverage run -a -m pytest test/agentchat/contrib/test_compressible_agent.py
- coverage xml
+ pytest test/agentchat/contrib/test_compressible_agent.py
- name: Upload coverage to Codecov
uses: codecov/codecov-action@v3
with:
@@ -124,7 +174,7 @@ jobs:
python -m pip install --upgrade pip wheel
pip install -e .
python -c "import autogen"
- pip install coverage pytest-asyncio
+ pip install pytest-cov>=5 pytest-asyncio
- name: Install packages for test when needed
run: |
pip install docker
@@ -135,8 +185,7 @@ jobs:
AZURE_OPENAI_API_BASE: ${{ secrets.AZURE_OPENAI_API_BASE }}
OAI_CONFIG_LIST: ${{ secrets.OAI_CONFIG_LIST }}
run: |
- coverage run -a -m pytest test/agentchat/contrib/test_gpt_assistant.py
- coverage xml
+ pytest test/agentchat/contrib/test_gpt_assistant.py
- name: Upload coverage to Codecov
uses: codecov/codecov-action@v3
with:
@@ -165,7 +214,7 @@ jobs:
python -m pip install --upgrade pip wheel
pip install -e .[teachable]
python -c "import autogen"
- pip install coverage pytest
+ pip install pytest-cov>=5
- name: Coverage
env:
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
@@ -173,8 +222,7 @@ jobs:
AZURE_OPENAI_API_BASE: ${{ secrets.AZURE_OPENAI_API_BASE }}
OAI_CONFIG_LIST: ${{ secrets.OAI_CONFIG_LIST }}
run: |
- coverage run -a -m pytest test/agentchat/contrib/capabilities/test_teachable_agent.py
- coverage xml
+ pytest test/agentchat/contrib/capabilities/test_teachable_agent.py
- name: Upload coverage to Codecov
uses: codecov/codecov-action@v3
with:
@@ -183,8 +231,8 @@ jobs:
AgentBuilder:
strategy:
matrix:
- os: [ ubuntu-latest ]
- python-version: [ "3.11" ]
+ os: [ubuntu-latest]
+ python-version: ["3.11"]
runs-on: ${{ matrix.os }}
environment: openai1
steps:
@@ -203,7 +251,7 @@ jobs:
python -m pip install --upgrade pip wheel
pip install -e .
python -c "import autogen"
- pip install coverage pytest-asyncio
+ pip install pytest-cov>=5 pytest-asyncio
- name: Install packages for test when needed
run: |
pip install -e .[autobuild]
@@ -214,8 +262,7 @@ jobs:
AZURE_OPENAI_API_BASE: ${{ secrets.AZURE_OPENAI_API_BASE }}
OAI_CONFIG_LIST: ${{ secrets.OAI_CONFIG_LIST }}
run: |
- coverage run -a -m pytest test/agentchat/contrib/test_agent_builder.py
- coverage xml
+ pytest test/agentchat/contrib/test_agent_builder.py
- name: Upload coverage to Codecov
uses: codecov/codecov-action@v3
with:
@@ -244,7 +291,7 @@ jobs:
python -m pip install --upgrade pip wheel
pip install -e .[websurfer]
python -c "import autogen"
- pip install coverage pytest
+ pip install pytest-cov>=5
- name: Coverage
env:
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
@@ -253,84 +300,119 @@ jobs:
OAI_CONFIG_LIST: ${{ secrets.OAI_CONFIG_LIST }}
BING_API_KEY: ${{ secrets.BING_API_KEY }}
run: |
- coverage run -a -m pytest test/agentchat/contrib/test_web_surfer.py
- coverage xml
+ pytest test/agentchat/contrib/test_web_surfer.py
- name: Upload coverage to Codecov
uses: codecov/codecov-action@v3
with:
file: ./coverage.xml
flags: unittests
ContextHandling:
- strategy:
- matrix:
- os: [ubuntu-latest]
- python-version: ["3.11"]
- runs-on: ${{ matrix.os }}
- environment: openai1
- steps:
- # checkout to pr branch
- - name: Checkout
- uses: actions/checkout@v4
- with:
- ref: ${{ github.event.pull_request.head.sha }}
- - name: Set up Python ${{ matrix.python-version }}
- uses: actions/setup-python@v5
- with:
- python-version: ${{ matrix.python-version }}
- - name: Install packages and dependencies
- run: |
- docker --version
- python -m pip install --upgrade pip wheel
- pip install -e .
- python -c "import autogen"
- pip install coverage pytest
- - name: Coverage
- env:
- OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
- AZURE_OPENAI_API_KEY: ${{ secrets.AZURE_OPENAI_API_KEY }}
- AZURE_OPENAI_API_BASE: ${{ secrets.AZURE_OPENAI_API_BASE }}
- OAI_CONFIG_LIST: ${{ secrets.OAI_CONFIG_LIST }}
- BING_API_KEY: ${{ secrets.BING_API_KEY }}
- run: |
- coverage run -a -m pytest test/agentchat/contrib/capabilities/test_context_handling.py
- coverage xml
- - name: Upload coverage to Codecov
- uses: codecov/codecov-action@v3
- with:
- file: ./coverage.xml
- flags: unittests
+ strategy:
+ matrix:
+ os: [ubuntu-latest]
+ python-version: ["3.11"]
+ runs-on: ${{ matrix.os }}
+ environment: openai1
+ steps:
+ # checkout to pr branch
+ - name: Checkout
+ uses: actions/checkout@v4
+ with:
+ ref: ${{ github.event.pull_request.head.sha }}
+ - name: Set up Python ${{ matrix.python-version }}
+ uses: actions/setup-python@v5
+ with:
+ python-version: ${{ matrix.python-version }}
+ - name: Install packages and dependencies
+ run: |
+ docker --version
+ python -m pip install --upgrade pip wheel
+ pip install -e .
+ python -c "import autogen"
+ pip install pytest-cov>=5
+ - name: Coverage
+ env:
+ OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
+ AZURE_OPENAI_API_KEY: ${{ secrets.AZURE_OPENAI_API_KEY }}
+ AZURE_OPENAI_API_BASE: ${{ secrets.AZURE_OPENAI_API_BASE }}
+ OAI_CONFIG_LIST: ${{ secrets.OAI_CONFIG_LIST }}
+ BING_API_KEY: ${{ secrets.BING_API_KEY }}
+ run: |
+ pytest test/agentchat/contrib/capabilities/test_context_handling.py
+ - name: Upload coverage to Codecov
+ uses: codecov/codecov-action@v3
+ with:
+ file: ./coverage.xml
+ flags: unittests
ImageGen:
- strategy:
- matrix:
- os: [ubuntu-latest]
- python-version: ["3.12"]
- runs-on: ${{ matrix.os }}
- environment: openai1
- steps:
- # checkout to pr branch
- - name: Checkout
- uses: actions/checkout@v4
- with:
- ref: ${{ github.event.pull_request.head.sha }}
- - name: Set up Python ${{ matrix.python-version }}
- uses: actions/setup-python@v5
- with:
- python-version: ${{ matrix.python-version }}
- - name: Install packages and dependencies
- run: |
- docker --version
- python -m pip install --upgrade pip wheel
- pip install -e .[lmm]
- python -c "import autogen"
- pip install coverage pytest
- - name: Coverage
- env:
- OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
- run: |
- coverage run -a -m pytest test/agentchat/contrib/capabilities/test_image_generation_capability.py
- coverage xml
- - name: Upload coverage to Codecov
- uses: codecov/codecov-action@v3
- with:
- file: ./coverage.xml
- flags: unittests
+ strategy:
+ matrix:
+ os: [ubuntu-latest]
+ python-version: ["3.12"]
+ runs-on: ${{ matrix.os }}
+ environment: openai1
+ steps:
+ # checkout to pr branch
+ - name: Checkout
+ uses: actions/checkout@v4
+ with:
+ ref: ${{ github.event.pull_request.head.sha }}
+ - name: Set up Python ${{ matrix.python-version }}
+ uses: actions/setup-python@v5
+ with:
+ python-version: ${{ matrix.python-version }}
+ - name: Install packages and dependencies
+ run: |
+ docker --version
+ python -m pip install --upgrade pip wheel
+ pip install -e .[lmm]
+ python -c "import autogen"
+ pip install pytest-cov>=5
+ - name: Coverage
+ env:
+ OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
+ run: |
+ pytest test/agentchat/contrib/capabilities/test_image_generation_capability.py
+ - name: Upload coverage to Codecov
+ uses: codecov/codecov-action@v3
+ with:
+ file: ./coverage.xml
+ flags: unittests
+
+ AgentOptimizer:
+ strategy:
+ matrix:
+ os: [ubuntu-latest]
+ python-version: ["3.11"]
+ runs-on: ${{ matrix.os }}
+ environment: openai1
+ steps:
+ # checkout to pr branch
+ - name: Checkout
+ uses: actions/checkout@v4
+ with:
+ ref: ${{ github.event.pull_request.head.sha }}
+ - name: Set up Python ${{ matrix.python-version }}
+ uses: actions/setup-python@v5
+ with:
+ python-version: ${{ matrix.python-version }}
+ - name: Install packages and dependencies
+ run: |
+ docker --version
+ python -m pip install --upgrade pip wheel
+ pip install -e .
+ python -c "import autogen"
+ pip install pytest-cov>=5
+ - name: Coverage
+ env:
+ OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
+ AZURE_OPENAI_API_KEY: ${{ secrets.AZURE_OPENAI_API_KEY }}
+ AZURE_OPENAI_API_BASE: ${{ secrets.AZURE_OPENAI_API_BASE }}
+ OAI_CONFIG_LIST: ${{ secrets.OAI_CONFIG_LIST }}
+ run: |
+ pytest test/agentchat/contrib/test_agent_optimizer.py
+ - name: Upload coverage to Codecov
+ uses: codecov/codecov-action@v3
+ with:
+ file: ./coverage.xml
+ flags: unittests
diff --git a/.github/workflows/contrib-tests.yml b/.github/workflows/contrib-tests.yml
index ced35dc115b..005abe4ef8e 100644
--- a/.github/workflows/contrib-tests.yml
+++ b/.github/workflows/contrib-tests.yml
@@ -27,8 +27,11 @@ jobs:
strategy:
fail-fast: false
matrix:
- os: [ubuntu-latest, macos-latest, windows-2019]
+ os: [macos-latest, windows-2019]
python-version: ["3.9", "3.10", "3.11"]
+ exclude:
+ - os: macos-latest
+ python-version: "3.9"
steps:
- uses: actions/checkout@v4
- name: Set up Python ${{ matrix.python-version }}
@@ -38,15 +41,11 @@ jobs:
- name: Install packages and dependencies for all tests
run: |
python -m pip install --upgrade pip wheel
- pip install pytest
+ pip install pytest-cov>=5
- name: Install qdrant_client when python-version is 3.10
if: matrix.python-version == '3.10'
run: |
- pip install qdrant_client[fastembed]
- - name: Install unstructured when python-version is 3.9 and not windows
- if: matrix.python-version == '3.9' && matrix.os != 'windows-2019'
- run: |
- pip install unstructured[all-docs]
+ pip install -e .[retrievechat-qdrant]
- name: Install packages and dependencies for RetrieveChat
run: |
pip install -e .[retrievechat]
@@ -56,14 +55,99 @@ jobs:
if [[ ${{ matrix.os }} != ubuntu-latest ]]; then
echo "AUTOGEN_USE_DOCKER=False" >> $GITHUB_ENV
fi
- - name: Test RetrieveChat
+ - name: Coverage
+ run: |
+ pytest test/test_retrieve_utils.py test/agentchat/contrib/retrievechat/test_retrievechat.py test/agentchat/contrib/retrievechat/test_qdrant_retrievechat.py test/agentchat/contrib/vectordb --skip-openai
+ - name: Upload coverage to Codecov
+ uses: codecov/codecov-action@v3
+ with:
+ file: ./coverage.xml
+ flags: unittests
+
+ RetrieveChatTest-Ubuntu:
+ runs-on: ubuntu-latest
+ strategy:
+ fail-fast: false
+ matrix:
+ python-version: ["3.9", "3.10", "3.11"]
+ services:
+ pgvector:
+ image: ankane/pgvector
+ env:
+ POSTGRES_DB: postgres
+ POSTGRES_USER: postgres
+ POSTGRES_PASSWORD: ${{ secrets.POSTGRES_PASSWORD }}
+ POSTGRES_HOST_AUTH_METHOD: trust
+ options: >-
+ --health-cmd pg_isready
+ --health-interval 10s
+ --health-timeout 5s
+ --health-retries 5
+ ports:
+ - 5432:5432
+ steps:
+ - uses: actions/checkout@v4
+ - name: Set up Python ${{ matrix.python-version }}
+ uses: actions/setup-python@v5
+ with:
+ python-version: ${{ matrix.python-version }}
+ - name: Install packages and dependencies for all tests
+ run: |
+ python -m pip install --upgrade pip wheel
+ pip install pytest
+ - name: Install qdrant_client when python-version is 3.10
+ if: matrix.python-version == '3.10'
+ run: |
+ pip install -e .[retrievechat-qdrant]
+ - name: Install pgvector when on linux
run: |
- pytest test/test_retrieve_utils.py test/agentchat/contrib/test_retrievechat.py test/agentchat/contrib/test_qdrant_retrievechat.py --skip-openai
+ pip install -e .[retrievechat-pgvector]
+ - name: Install unstructured when python-version is 3.9 and on linux
+ if: matrix.python-version == '3.9'
+ run: |
+ sudo apt-get update
+ sudo apt-get install -y tesseract-ocr poppler-utils
+ pip install unstructured[all-docs]==0.13.0
+ - name: Install packages and dependencies for RetrieveChat
+ run: |
+ pip install -e .[retrievechat]
+ - name: Set AUTOGEN_USE_DOCKER based on OS
+ shell: bash
+ run: |
+ echo "AUTOGEN_USE_DOCKER=False" >> $GITHUB_ENV
- name: Coverage
run: |
- pip install coverage>=5.3
- coverage run -a -m pytest test/test_retrieve_utils.py test/agentchat/contrib/test_retrievechat.py test/agentchat/contrib/test_qdrant_retrievechat.py --skip-openai
- coverage xml
+ pip install pytest-cov>=5
+ pytest test/test_retrieve_utils.py test/agentchat/contrib/retrievechat test/agentchat/contrib/vectordb --skip-openai
+ - name: Upload coverage to Codecov
+ uses: codecov/codecov-action@v3
+ with:
+ file: ./coverage.xml
+ flags: unittests
+
+ AgentEvalTest:
+ strategy:
+ fail-fast: false
+ matrix:
+ os: [ubuntu-latest]
+ python-version: ["3.10"]
+ runs-on: ${{ matrix.os }}
+ steps:
+ - uses: actions/checkout@v4
+ - name: Set up Python ${{ matrix.python-version }}
+ uses: actions/setup-python@v5
+ with:
+ python-version: ${{ matrix.python-version }}
+ - name: Install packages and dependencies for all tests
+ run: |
+ python -m pip install --upgrade pip wheel
+ pip install pytest-cov>=5
+ - name: Install packages and dependencies for AgentEval
+ run: |
+ pip install -e .
+ - name: Coverage
+ run: |
+ pytest test/agentchat/contrib/agent_eval/ --skip-openai
- name: Upload coverage to Codecov
uses: codecov/codecov-action@v3
with:
@@ -76,7 +160,7 @@ jobs:
fail-fast: false
matrix:
os: [ubuntu-latest, macos-latest, windows-2019]
- python-version: ["3.8"]
+ python-version: ["3.10"]
steps:
- uses: actions/checkout@v4
- name: Set up Python ${{ matrix.python-version }}
@@ -86,7 +170,7 @@ jobs:
- name: Install packages and dependencies for all tests
run: |
python -m pip install --upgrade pip wheel
- pip install pytest
+ pip install pytest-cov>=5
- name: Install packages and dependencies for Compression
run: |
pip install -e .
@@ -98,9 +182,7 @@ jobs:
fi
- name: Coverage
run: |
- pip install coverage>=5.3
- coverage run -a -m pytest test/agentchat/contrib/test_compressible_agent.py --skip-openai
- coverage xml
+ pytest test/agentchat/contrib/test_compressible_agent.py --skip-openai
- name: Upload coverage to Codecov
uses: codecov/codecov-action@v3
with:
@@ -123,7 +205,7 @@ jobs:
- name: Install packages and dependencies for all tests
run: |
python -m pip install --upgrade pip wheel
- pip install pytest
+ pip install pytest-cov>=5
- name: Install packages and dependencies for GPTAssistantAgent
run: |
pip install -e .
@@ -135,9 +217,7 @@ jobs:
fi
- name: Coverage
run: |
- pip install coverage>=5.3
- coverage run -a -m pytest test/agentchat/contrib/test_gpt_assistant.py --skip-openai
- coverage xml
+ pytest test/agentchat/contrib/test_gpt_assistant.py --skip-openai
- name: Upload coverage to Codecov
uses: codecov/codecov-action@v3
with:
@@ -160,7 +240,7 @@ jobs:
- name: Install packages and dependencies for all tests
run: |
python -m pip install --upgrade pip wheel
- pip install pytest
+ pip install pytest-cov>=5
- name: Install packages and dependencies for Teachability
run: |
pip install -e .[teachable]
@@ -172,9 +252,7 @@ jobs:
fi
- name: Coverage
run: |
- pip install coverage>=5.3
- coverage run -a -m pytest test/agentchat/contrib/capabilities/test_teachable_agent.py --skip-openai
- coverage xml
+ pytest test/agentchat/contrib/capabilities/test_teachable_agent.py --skip-openai
- name: Upload coverage to Codecov
uses: codecov/codecov-action@v3
with:
@@ -197,7 +275,7 @@ jobs:
- name: Install packages and dependencies for all tests
run: |
python -m pip install --upgrade pip wheel
- pip install pytest
+ pip install pytest-cov>=5
- name: Install packages and dependencies for WebSurfer
run: |
pip install -e .[websurfer]
@@ -209,9 +287,7 @@ jobs:
fi
- name: Coverage
run: |
- pip install coverage>=5.3
- coverage run -a -m pytest test/test_browser_utils.py test/agentchat/contrib/test_web_surfer.py --skip-openai
- coverage xml
+ pytest test/test_browser_utils.py test/agentchat/contrib/test_web_surfer.py --skip-openai
- name: Upload coverage to Codecov
uses: codecov/codecov-action@v3
with:
@@ -227,6 +303,8 @@ jobs:
python-version: ["3.12"]
steps:
- uses: actions/checkout@v4
+ with:
+ lfs: true
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v5
with:
@@ -234,7 +312,7 @@ jobs:
- name: Install packages and dependencies for all tests
run: |
python -m pip install --upgrade pip wheel
- pip install pytest
+ pip install pytest-cov>=5
- name: Install packages and dependencies for LMM
run: |
pip install -e .[lmm]
@@ -246,9 +324,51 @@ jobs:
fi
- name: Coverage
run: |
- pip install coverage>=5.3
- coverage run -a -m pytest test/agentchat/contrib/test_img_utils.py test/agentchat/contrib/test_lmm.py test/agentchat/contrib/test_llava.py test/agentchat/contrib/capabilities/test_image_generation_capability.py test/agentchat/contrib/capabilities/test_vision_capability.py --skip-openai
- coverage xml
+ pytest test/agentchat/contrib/test_img_utils.py test/agentchat/contrib/test_lmm.py test/agentchat/contrib/test_llava.py test/agentchat/contrib/capabilities/test_vision_capability.py --skip-openai
+ - name: Image Gen Coverage
+ if: ${{ matrix.os != 'windows-2019' && matrix.python-version != '3.12' }}
+ run: |
+ pytest test/agentchat/contrib/capabilities/test_image_generation_capability.py --skip-openai
+ - name: Upload coverage to Codecov
+ uses: codecov/codecov-action@v3
+ with:
+ file: ./coverage.xml
+ flags: unittests
+
+ GeminiTest:
+ runs-on: ${{ matrix.os }}
+ strategy:
+ fail-fast: false
+ matrix:
+ os: [ubuntu-latest, macos-latest, windows-2019]
+ python-version: ["3.9", "3.10", "3.11", "3.12"]
+ exclude:
+ - os: macos-latest
+ python-version: "3.9"
+ steps:
+ - uses: actions/checkout@v4
+ with:
+ lfs: true
+ - name: Set up Python ${{ matrix.python-version }}
+ uses: actions/setup-python@v5
+ with:
+ python-version: ${{ matrix.python-version }}
+ - name: Install packages and dependencies for all tests
+ run: |
+ python -m pip install --upgrade pip wheel
+ pip install pytest-cov>=5
+ - name: Install packages and dependencies for Gemini
+ run: |
+ pip install -e .[gemini,test]
+ - name: Set AUTOGEN_USE_DOCKER based on OS
+ shell: bash
+ run: |
+ if [[ ${{ matrix.os }} != ubuntu-latest ]]; then
+ echo "AUTOGEN_USE_DOCKER=False" >> $GITHUB_ENV
+ fi
+ - name: Coverage
+ run: |
+ pytest test/oai/test_gemini.py --skip-openai
- name: Upload coverage to Codecov
uses: codecov/codecov-action@v3
with:
@@ -271,7 +391,7 @@ jobs:
- name: Install packages and dependencies for all tests
run: |
python -m pip install --upgrade pip wheel
- pip install pytest
+ pip install pytest-cov>=5
- name: Install packages and dependencies for Context Handling
run: |
pip install -e .
@@ -283,11 +403,44 @@ jobs:
fi
- name: Coverage
run: |
- pip install coverage>=5.3
- coverage run -a -m pytest test/agentchat/contrib/capabilities/test_context_handling.py --skip-openai
- coverage xml
+ pytest test/agentchat/contrib/capabilities/test_context_handling.py --skip-openai
- name: Upload coverage to Codecov
uses: codecov/codecov-action@v3
with:
file: ./coverage.xml
flags: unittests
+
+ TransformMessages:
+ runs-on: ${{ matrix.os }}
+ strategy:
+ fail-fast: false
+ matrix:
+ os: [ubuntu-latest, macos-latest, windows-2019]
+ python-version: ["3.11"]
+ steps:
+ - uses: actions/checkout@v3
+ - name: Set up Python ${{ matrix.python-version }}
+ uses: actions/setup-python@v4
+ with:
+ python-version: ${{ matrix.python-version }}
+ - name: Install packages and dependencies for all tests
+ run: |
+ python -m pip install --upgrade pip wheel
+ pip install pytest-cov>=5
+ - name: Install packages and dependencies for Transform Messages
+ run: |
+ pip install -e '.[long-context]'
+ - name: Set AUTOGEN_USE_DOCKER based on OS
+ shell: bash
+ run: |
+ if [[ ${{ matrix.os }} != ubuntu-latest ]]; then
+ echo "AUTOGEN_USE_DOCKER=False" >> $GITHUB_ENV
+ fi
+ - name: Coverage
+ run: |
+ pytest test/agentchat/contrib/capabilities/test_transform_messages.py --skip-openai
+ - name: Upload coverage to Codecov
+ uses: codecov/codecov-action@v3
+ with:
+ file: ./coverage.xml
+        flags: unittests
diff --git a/.github/workflows/dotnet-build.yml b/.github/workflows/dotnet-build.yml
index d223fffd28b..2e679412f63 100644
--- a/.github/workflows/dotnet-build.yml
+++ b/.github/workflows/dotnet-build.yml
@@ -6,11 +6,11 @@ name: dotnet-ci
on:
workflow_dispatch:
pull_request:
- branches: [ "dotnet" ]
- paths:
- - 'dotnet/**'
+ branches: [ "main" ]
push:
- branches: [ "dotnet" ]
+ branches: [ "main" ]
+ merge_group:
+ types: [checks_requested]
concurrency:
group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref }}
@@ -21,14 +21,38 @@ permissions:
packages: write
jobs:
+ paths-filter:
+ runs-on: ubuntu-latest
+ outputs:
+ hasChanges: ${{ steps.filter.outputs.dotnet == 'true'}}
+ steps:
+ - uses: actions/checkout@v4
+ - uses: dorny/paths-filter@v2
+ id: filter
+ with:
+ filters: |
+ dotnet:
+ - "dotnet/**"
+ workflows:
+ - ".github/workflows/**"
+ - name: dotnet has changes
+ run: echo "dotnet has changes"
+ if: steps.filter.outputs.dotnet == 'true'
+ - name: workflows has changes
+ run: echo "workflows has changes"
+ if: steps.filter.outputs.workflows == 'true'
build:
- name: Build
+ name: Dotnet Build
runs-on: ubuntu-latest
+ needs: paths-filter
+ if: needs.paths-filter.outputs.hasChanges == 'true'
defaults:
run:
working-directory: dotnet
steps:
- uses: actions/checkout@v4
+ with:
+ lfs: true
- name: Setup .NET
uses: actions/setup-dotnet@v4
with:
@@ -50,10 +74,12 @@ jobs:
defaults:
run:
working-directory: dotnet
- if: success() && (github.ref == 'refs/heads/main' || github.ref == 'refs/heads/dotnet')
+ if: success() && (github.ref == 'refs/heads/main')
needs: build
steps:
- uses: actions/checkout@v4
+ with:
+ lfs: true
- name: Setup .NET
uses: actions/setup-dotnet@v4
with:
@@ -139,4 +165,3 @@ jobs:
dotnet nuget push --api-key ${{ secrets.MYGET_TOKEN }} --source "https://www.myget.org/F/agentchat/api/v3/index.json" ./output/nightly/*.nupkg --skip-duplicate
env:
MYGET_TOKEN: ${{ secrets.MYGET_TOKEN }}
-
diff --git a/.github/workflows/dotnet-release.yml b/.github/workflows/dotnet-release.yml
index d66f21a6cd6..b512b4c1696 100644
--- a/.github/workflows/dotnet-release.yml
+++ b/.github/workflows/dotnet-release.yml
@@ -7,7 +7,7 @@ on:
workflow_dispatch:
push:
branches:
- - dotnet/release
+ - dotnet/release/**
concurrency:
group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref }}
@@ -27,6 +27,8 @@ jobs:
working-directory: dotnet
steps:
- uses: actions/checkout@v4
+ with:
+ lfs: true
- name: Setup .NET
uses: actions/setup-dotnet@v4
with:
@@ -57,13 +59,6 @@ jobs:
echo "Publish package to Nuget"
echo "ls output directory"
ls -R ./output/release
- dotnet nuget push --api-key AzureArtifacts ./output/release/*.nupkg --skip-duplicate --api-key ${{ secrets.AUTOGEN_NUGET_API_KEY }}
- - name: Tag commit
- run: |
- Write-Host "Tag commit"
- # version = eng/MetaInfo.props.Project.PropertyGroup.VersionPrefix
- $metaInfoContent = cat ./eng/MetaInfo.props
- $version = $metaInfoContent | Select-String -Pattern "(.*) " | ForEach-Object { $_.Matches.Groups[1].Value }
- git tag -a "$version" -m "AutoGen.Net release $version"
- git push origin --tags
- shell: pwsh
\ No newline at end of file
+ # remove AutoGen.SourceGenerator.snupkg because it's an empty package
+ rm ./output/release/AutoGen.SourceGenerator.*.snupkg
+ dotnet nuget push --api-key ${{ secrets.AUTOGEN_NUGET_API_KEY }} --source https://api.nuget.org/v3/index.json ./output/release/*.nupkg --skip-duplicate
diff --git a/.github/workflows/lfs-check.yml b/.github/workflows/lfs-check.yml
new file mode 100644
index 00000000000..4baae925de3
--- /dev/null
+++ b/.github/workflows/lfs-check.yml
@@ -0,0 +1,15 @@
+name: "Git LFS Check"
+
+on: pull_request
+permissions: {}
+jobs:
+ lfs-check:
+ runs-on: ubuntu-latest
+ steps:
+ - name: Checkout
+ uses: actions/checkout@v4
+ with:
+ lfs: true
+ - name: "Check Git LFS files for consistency, if you see error like 'pointer: unexpectedGitObject ... should have been a pointer but was not', please install Git LFS locally, delete the problematic file, and then add it back again. This ensures it's properly tracked."
+ run: |
+ git lfs fsck
diff --git a/.github/workflows/openai.yml b/.github/workflows/openai.yml
index d2780eea542..a9ab8e9e0c5 100644
--- a/.github/workflows/openai.yml
+++ b/.github/workflows/openai.yml
@@ -13,7 +13,8 @@ on:
- "notebook/agentchat_function_call.ipynb"
- "notebook/agentchat_groupchat_finite_state_machine.ipynb"
- ".github/workflows/openai.yml"
-permissions: {}
+permissions:
+ {}
# actions: read
# checks: read
# contents: read
@@ -49,7 +50,7 @@ jobs:
python -m pip install --upgrade pip wheel
pip install -e.
python -c "import autogen"
- pip install coverage pytest-asyncio
+ pip install pytest-cov>=5 pytest-asyncio
- name: Install packages for test when needed
if: matrix.python-version == '3.9'
run: |
@@ -63,8 +64,7 @@ jobs:
AZURE_OPENAI_API_BASE: ${{ secrets.AZURE_OPENAI_API_BASE }}
OAI_CONFIG_LIST: ${{ secrets.OAI_CONFIG_LIST }}
run: |
- coverage run -a -m pytest test --ignore=test/agentchat/contrib --durations=10 --durations-min=1.0
- coverage xml
+ pytest test --ignore=test/agentchat/contrib --durations=10 --durations-min=1.0
- name: Coverage and check notebook outputs
if: matrix.python-version != '3.9'
env:
@@ -75,8 +75,7 @@ jobs:
OAI_CONFIG_LIST: ${{ secrets.OAI_CONFIG_LIST }}
run: |
pip install nbconvert nbformat ipykernel
- coverage run -a -m pytest test/test_notebook.py --durations=10 --durations-min=1.0
- coverage xml
+ pytest test/test_notebook.py --durations=10 --durations-min=1.0
cat "$(pwd)/test/executed_openai_notebook_output.txt"
- name: Upload coverage to Codecov
uses: codecov/codecov-action@v3
diff --git a/.github/workflows/samples-tools-tests.yml b/.github/workflows/samples-tools-tests.yml
index 12c8de3b7af..e774e5cb0b1 100644
--- a/.github/workflows/samples-tools-tests.yml
+++ b/.github/workflows/samples-tools-tests.yml
@@ -24,6 +24,9 @@ jobs:
matrix:
os: [ubuntu-latest, macos-latest]
python-version: ["3.9", "3.10", "3.11"]
+ exclude:
+ - os: macos-latest
+ python-version: "3.9"
steps:
- uses: actions/checkout@v4
- name: Set up Python ${{ matrix.python-version }}
@@ -34,7 +37,7 @@ jobs:
run: |
python -m pip install --upgrade pip wheel
pip install -e .
- pip install pytest
+ pip install pytest-cov>=5
- name: Set AUTOGEN_USE_DOCKER based on OS
shell: bash
run: |
diff --git a/.github/workflows/type-check.yml b/.github/workflows/type-check.yml
index f6896d1145d..c66fb6ad7b1 100644
--- a/.github/workflows/type-check.yml
+++ b/.github/workflows/type-check.yml
@@ -1,6 +1,6 @@
name: Type check
# see: https://help.github.com/en/actions/reference/events-that-trigger-workflows
-on: # Trigger the workflow on pull request or merge
+on: # Trigger the workflow on pull request or merge
pull_request:
merge_group:
types: [checks_requested]
@@ -19,7 +19,7 @@ jobs:
- uses: actions/checkout@v4
- uses: actions/setup-python@v5
with:
- python-version: ${{ matrix.version }}
+ python-version: ${{ matrix.version }}
# All additional modules should be defined in setup.py
- run: pip install ".[types]"
# Any additional configuration should be defined in pyproject.toml
diff --git a/.gitignore b/.gitignore
index 49a41e9ed2c..4c925f739ec 100644
--- a/.gitignore
+++ b/.gitignore
@@ -172,6 +172,10 @@ test/my_tmp/*
# Storage for the AgentEval output
test/test_files/agenteval-in-out/out/
+# local cache or coding folder
+local_cache/
+coding/
+
# Files created by tests
*tmp_code_*
test/agentchat/test_agent_scripts/*
@@ -179,7 +183,10 @@ test/agentchat/test_agent_scripts/*
# test cache
.cache_test
.db
+local_cache
notebook/result.png
samples/apps/autogen-studio/autogenstudio/models/test/
+
+notebook/coding
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 025cc7cbb17..fcea09223c6 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -26,11 +26,12 @@ repos:
rev: 24.3.0
hooks:
- id: black
- - repo: https://github.com/charliermarsh/ruff-pre-commit
- rev: v0.3.3
+ - repo: https://github.com/astral-sh/ruff-pre-commit
+ rev: v0.3.4
hooks:
- id: ruff
- args: ["--fix"]
+ types_or: [ python, pyi, jupyter ]
+ args: ["--fix", "--ignore=E402"]
- repo: https://github.com/codespell-project/codespell
rev: v2.2.6
hooks:
@@ -42,6 +43,8 @@ repos:
website/static/img/ag.svg |
website/yarn.lock |
website/docs/tutorial/code-executors.ipynb |
+ website/docs/topics/code-execution/custom-executor.ipynb |
+ website/docs/topics/non-openai-models/cloud-gemini.ipynb |
notebook/.*
)$
# See https://jaredkhan.com/blog/mypy-pre-commit
@@ -61,9 +64,6 @@ repos:
# Print the number of files as a sanity-check
verbose: true
- repo: https://github.com/nbQA-dev/nbQA
- rev: 1.8.4
+ rev: 1.8.5
hooks:
- - id: nbqa-ruff
- # Don't require notebooks to have all imports at the top
- args: ["--fix", "--ignore=E402"]
- id: nbqa-black
diff --git a/OAI_CONFIG_LIST_sample b/OAI_CONFIG_LIST_sample
index ef027f815ba..9fc0dc803a0 100644
--- a/OAI_CONFIG_LIST_sample
+++ b/OAI_CONFIG_LIST_sample
@@ -5,7 +5,8 @@
[
{
"model": "gpt-4",
- "api_key": ""
+ "api_key": "",
+ "tags": ["gpt-4", "tool"]
},
{
"model": "",
diff --git a/README.md b/README.md
index 76f469ecef5..327dc8c4e54 100644
--- a/README.md
+++ b/README.md
@@ -1,8 +1,10 @@
+
+
[![PyPI version](https://badge.fury.io/py/pyautogen.svg)](https://badge.fury.io/py/pyautogen)
[![Build](https://github.com/microsoft/autogen/actions/workflows/python-package.yml/badge.svg)](https://github.com/microsoft/autogen/actions/workflows/python-package.yml)
![Python Version](https://img.shields.io/badge/3.8%20%7C%203.9%20%7C%203.10%20%7C%203.11%20%7C%203.12-blue)
[![Downloads](https://static.pepy.tech/badge/pyautogen/week)](https://pepy.tech/project/pyautogen)
-[![Discord](https://img.shields.io/discord/1153072414184452236?logo=discord&style=flat)](https://discord.gg/pAbnFJrkgZ)
+[![Discord](https://img.shields.io/discord/1153072414184452236?logo=discord&style=flat)](https://aka.ms/autogen-dc)
[![Twitter](https://img.shields.io/twitter/url/https/twitter.com/cloudposse.svg?style=social&label=Follow%20%40pyautogen)](https://twitter.com/pyautogen)
@@ -12,29 +14,37 @@
-->
-:fire: Mar 3: What's new in AutoGen? 📰[Blog](https://microsoft.github.io/autogen/blog/2024/03/03/AutoGen-Update); 📺[Youtube](https://www.youtube.com/watch?v=j_mtwQiaLGU).
+:fire: May 13, 2024: [The Economist](https://www.economist.com/science-and-technology/2024/05/13/todays-ai-models-are-impressive-teams-of-them-will-be-formidable) published an article about multi-agent systems (MAS) following a January 2024 interview with [Chi Wang](https://github.com/sonichi).
+
+:fire: May 11, 2024: [AutoGen: Enabling Next-Gen LLM Applications via Multi-Agent Conversation](https://openreview.net/pdf?id=uAjxFFing2) received the best paper award in [ICLR 2024 LLM Agents Workshop](https://llmagents.github.io/).
+
+:fire: Apr 26, 2024: [AutoGen.NET](https://microsoft.github.io/autogen-for-net/) is available for .NET developers!
-:fire: Mar 1: the first AutoGen multi-agent experiment on the challenging [GAIA](https://huggingface.co/spaces/gaia-benchmark/leaderboard) benchmark achieved the No. 1 accuracy in all the three levels.
+:fire: Apr 17, 2024: Andrew Ng cited AutoGen in [The Batch newsletter](https://www.deeplearning.ai/the-batch/issue-245/) and [What's next for AI agentic workflows](https://youtu.be/sal78ACtGTc?si=JduUzN_1kDnMq0vF) at Sequoia Capital's AI Ascent (Mar 26).
-:fire: Jan 30: AutoGen is highlighted by Peter Lee in Microsoft Research Forum [Keynote](https://t.co/nUBSjPDjqD).
+:fire: Mar 3, 2024: What's new in AutoGen? 📰[Blog](https://microsoft.github.io/autogen/blog/2024/03/03/AutoGen-Update); 📺[Youtube](https://www.youtube.com/watch?v=j_mtwQiaLGU).
-:fire: Dec 31: [AutoGen: Enabling Next-Gen LLM Applications via Multi-Agent Conversation Framework](https://arxiv.org/abs/2308.08155) is selected by [TheSequence: My Five Favorite AI Papers of 2023](https://thesequence.substack.com/p/my-five-favorite-ai-papers-of-2023).
+:fire: Mar 1, 2024: the first AutoGen multi-agent experiment on the challenging [GAIA](https://huggingface.co/spaces/gaia-benchmark/leaderboard) benchmark achieved the No. 1 accuracy in all three levels.
+
+
+
+:tada: Dec 31, 2023: [AutoGen: Enabling Next-Gen LLM Applications via Multi-Agent Conversation Framework](https://arxiv.org/abs/2308.08155) is selected by [TheSequence: My Five Favorite AI Papers of 2023](https://thesequence.substack.com/p/my-five-favorite-ai-papers-of-2023).
-:fire: Nov 8: AutoGen is selected into [Open100: Top 100 Open Source achievements](https://www.benchcouncil.org/evaluation/opencs/annual.html) 35 days after spinoff.
+:tada: Nov 8, 2023: AutoGen is selected into [Open100: Top 100 Open Source achievements](https://www.benchcouncil.org/evaluation/opencs/annual.html) 35 days after spinoff from [FLAML](https://github.com/microsoft/FLAML).
-:fire: Nov 6: AutoGen is mentioned by Satya Nadella in a [fireside chat](https://youtu.be/0pLBvgYtv6U).
+
-:fire: Nov 1: AutoGen is the top trending repo on GitHub in October 2023.
+
-:tada: Oct 03: AutoGen spins off from FLAML on GitHub and has a major paper update (first version on Aug 16).
+
-:tada: Mar 29: AutoGen is first created in [FLAML](https://github.com/microsoft/FLAML).
+:tada: Mar 29, 2023: AutoGen is first created in [FLAML](https://github.com/microsoft/FLAML).
+
+
+    ↑ Back to Top ↑
+
+
+
## What is AutoGen
AutoGen is a framework that enables the development of LLM applications using multiple agents that can converse with each other to solve tasks. AutoGen agents are customizable, conversable, and seamlessly allow human participation. They can operate in various modes that employ combinations of LLMs, human inputs, and tools.
@@ -57,10 +73,22 @@ AutoGen is a framework that enables the development of LLM applications using mu
AutoGen is powered by collaborative [research studies](https://microsoft.github.io/autogen/docs/Research) from Microsoft, Penn State University, and the University of Washington.
+
+
+    ↑ Back to Top ↑
+
+
+
## Roadmaps
To see what we are working on and what we plan to work on, please check our
-[Roadmap Issues](https://github.com/microsoft/autogen/issues?q=is%3Aopen+is%3Aissue+label%3Aroadmap).
+[Roadmap Issues](https://aka.ms/autogen-roadmap).
+
+
+
+    ↑ Back to Top ↑
+
+
## Quickstart
The easiest way to start playing is
@@ -72,6 +100,13 @@ The easiest way to start playing is
3. Start playing with the notebooks!
*NOTE*: OAI_CONFIG_LIST_sample lists GPT-4 as the default model, as this represents our current recommendation, and is known to work well with AutoGen. If you use a model other than GPT-4, you may need to revise various system prompts (especially if using weaker models like GPT-3.5-turbo). Moreover, if you use models other than those hosted by OpenAI or Azure, you may incur additional risks related to alignment and safety. Proceed with caution if updating this default.
+
+
+
+    ↑ Back to Top ↑
+
+
+
## [Installation](https://microsoft.github.io/autogen/docs/Installation)
### Option 1. Install and Run AutoGen in Docker
@@ -100,6 +135,12 @@ Even if you are installing and running AutoGen locally outside of docker, the re
For LLM inference configurations, check the [FAQs](https://microsoft.github.io/autogen/docs/FAQ#set-your-api-endpoints).
+
+
+    ↑ Back to Top ↑
+
+
+
## Multi-Agent Conversation Framework
Autogen enables the next-gen LLM applications with a generic [multi-agent conversation](https://microsoft.github.io/autogen/docs/Use-Cases/agent_chat) framework. It offers customizable and conversable agents that integrate LLMs, tools, and humans.
@@ -139,6 +180,12 @@ The figure below shows an example conversation flow with AutoGen.
Alternatively, the [sample code](https://github.com/microsoft/autogen/blob/main/samples/simple_chat.py) here allows a user to chat with an AutoGen agent in ChatGPT style.
Please find more [code examples](https://microsoft.github.io/autogen/docs/Examples#automated-multi-agent-chat) for this feature.
+
+
+    ↑ Back to Top ↑
+
+
+
## Enhanced LLM Inferences
Autogen also helps maximize the utility out of the expensive LLMs such as ChatGPT and GPT-4. It offers [enhanced LLM inference](https://microsoft.github.io/autogen/docs/Use-Cases/enhanced_inference#api-unification) with powerful functionalities like caching, error handling, multi-config inference and templating.
@@ -162,6 +209,12 @@ response = autogen.Completion.create(context=test_instance, **config)
Please find more [code examples](https://microsoft.github.io/autogen/docs/Examples#tune-gpt-models) for this feature. -->
+
+
+    ↑ Back to Top ↑
+
+
+
## Documentation
You can find detailed documentation about AutoGen [here](https://microsoft.github.io/autogen/).
@@ -170,12 +223,18 @@ In addition, you can find:
- [Research](https://microsoft.github.io/autogen/docs/Research), [blogposts](https://microsoft.github.io/autogen/blog) around AutoGen, and [Transparency FAQs](https://github.com/microsoft/autogen/blob/main/TRANSPARENCY_FAQS.md)
-- [Discord](https://discord.gg/pAbnFJrkgZ)
+- [Discord](https://aka.ms/autogen-dc)
- [Contributing guide](https://microsoft.github.io/autogen/docs/Contribute)
- [Roadmap](https://github.com/orgs/microsoft/projects/989/views/3)
+
+
+    ↑ Back to Top ↑
+
+
+
## Related Papers
[AutoGen](https://arxiv.org/abs/2308.08155)
@@ -213,6 +272,23 @@ In addition, you can find:
}
```
+[AgentOptimizer](https://arxiv.org/pdf/2402.11359)
+
+```
+@article{zhang2024training,
+ title={Training Language Model Agents without Modifying Language Models},
+ author={Zhang, Shaokun and Zhang, Jieyu and Liu, Jiale and Song, Linxin and Wang, Chi and Krishna, Ranjay and Wu, Qingyun},
+ journal={ICML'24},
+ year={2024}
+}
+```
+
+
+
+    ↑ Back to Top ↑
+
+
+
## Contributing
This project welcomes contributions and suggestions. Most contributions require you to agree to a
@@ -229,11 +305,23 @@ This project has adopted the [Microsoft Open Source Code of Conduct](https://ope
For more information, see the [Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) or
contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with any additional questions or comments.
+
+
+    ↑ Back to Top ↑
+
+
+
## Contributors Wall
-
+
+
+
+    ↑ Back to Top ↑
+
+
+
# Legal Notices
Microsoft and any contributors grant you a license to the Microsoft documentation and other content
@@ -250,3 +338,9 @@ Privacy information can be found at https://privacy.microsoft.com/en-us/
Microsoft and any contributors reserve all other rights, whether under their respective copyrights, patents,
or trademarks, whether by implication, estoppel, or otherwise.
+
+
+
+    ↑ Back to Top ↑
+
+
diff --git a/autogen/__init__.py b/autogen/__init__.py
index ba920c92e46..02f956c4bcf 100644
--- a/autogen/__init__.py
+++ b/autogen/__init__.py
@@ -1,10 +1,10 @@
import logging
-from .version import __version__
-from .oai import *
+
from .agentchat import *
-from .exception_utils import *
from .code_utils import DEFAULT_MODEL, FAST_MODEL
-
+from .exception_utils import *
+from .oai import *
+from .version import __version__
# Set the root logger.
logger = logging.getLogger(__name__)
diff --git a/autogen/_pydantic.py b/autogen/_pydantic.py
index 89dbc4fd291..c463dbb3875 100644
--- a/autogen/_pydantic.py
+++ b/autogen/_pydantic.py
@@ -13,7 +13,7 @@
from pydantic._internal._typing_extra import eval_type_lenient as evaluate_forwardref
from pydantic.json_schema import JsonSchemaValue
- def type2schema(t: Optional[Type]) -> JsonSchemaValue:
+ def type2schema(t: Any) -> JsonSchemaValue:
"""Convert a type to a JSON schema
Args:
@@ -51,11 +51,11 @@ def model_dump_json(model: BaseModel) -> str:
# Remove this once we drop support for pydantic 1.x
else: # pragma: no cover
from pydantic import schema_of
- from pydantic.typing import evaluate_forwardref as evaluate_forwardref
+ from pydantic.typing import evaluate_forwardref as evaluate_forwardref # type: ignore[no-redef]
- JsonSchemaValue = Dict[str, Any]
+ JsonSchemaValue = Dict[str, Any] # type: ignore[misc]
- def type2schema(t: Optional[Type]) -> JsonSchemaValue:
+ def type2schema(t: Any) -> JsonSchemaValue:
"""Convert a type to a JSON schema
Args:
@@ -64,27 +64,27 @@ def type2schema(t: Optional[Type]) -> JsonSchemaValue:
Returns:
JsonSchemaValue: The JSON schema
"""
- if PYDANTIC_V1:
- if t is None:
- return {"type": "null"}
- elif get_origin(t) is Union:
- return {"anyOf": [type2schema(tt) for tt in get_args(t)]}
- elif get_origin(t) in [Tuple, tuple]:
- prefixItems = [type2schema(tt) for tt in get_args(t)]
- return {
- "maxItems": len(prefixItems),
- "minItems": len(prefixItems),
- "prefixItems": prefixItems,
- "type": "array",
- }
-
- d = schema_of(t)
- if "title" in d:
- d.pop("title")
- if "description" in d:
- d.pop("description")
-
- return d
+
+ if t is None:
+ return {"type": "null"}
+ elif get_origin(t) is Union:
+ return {"anyOf": [type2schema(tt) for tt in get_args(t)]}
+ elif get_origin(t) in [Tuple, tuple]:
+ prefixItems = [type2schema(tt) for tt in get_args(t)]
+ return {
+ "maxItems": len(prefixItems),
+ "minItems": len(prefixItems),
+ "prefixItems": prefixItems,
+ "type": "array",
+ }
+ else:
+ d = schema_of(t)
+ if "title" in d:
+ d.pop("title")
+ if "description" in d:
+ d.pop("description")
+
+ return d
def model_dump(model: BaseModel) -> Dict[str, Any]:
"""Convert a pydantic model to a dict
diff --git a/autogen/agentchat/__init__.py b/autogen/agentchat/__init__.py
index 817fc8abdb6..d31a59d98fb 100644
--- a/autogen/agentchat/__init__.py
+++ b/autogen/agentchat/__init__.py
@@ -1,9 +1,9 @@
from .agent import Agent
from .assistant_agent import AssistantAgent
+from .chat import ChatResult, initiate_chats
from .conversable_agent import ConversableAgent, register_function
from .groupchat import GroupChat, GroupChatManager
from .user_proxy_agent import UserProxyAgent
-from .chat import initiate_chats, ChatResult
from .utils import gather_usage_summary
__all__ = (
diff --git a/autogen/agentchat/assistant_agent.py b/autogen/agentchat/assistant_agent.py
index 25f7edbf073..b5ec7de90c7 100644
--- a/autogen/agentchat/assistant_agent.py
+++ b/autogen/agentchat/assistant_agent.py
@@ -1,7 +1,8 @@
from typing import Callable, Dict, Literal, Optional, Union
+from autogen.runtime_logging import log_new_agent, logging_enabled
+
from .conversable_agent import ConversableAgent
-from autogen.runtime_logging import logging_enabled, log_new_agent
class AssistantAgent(ConversableAgent):
diff --git a/autogen/agentchat/chat.py b/autogen/agentchat/chat.py
index bd56cf2f579..b527f8e0bae 100644
--- a/autogen/agentchat/chat.py
+++ b/autogen/agentchat/chat.py
@@ -1,15 +1,15 @@
import asyncio
-from functools import partial
-import logging
-from collections import defaultdict, abc
-from typing import Dict, List, Any, Set, Tuple
-from dataclasses import dataclass
-from .utils import consolidate_chat_info
import datetime
+import logging
import warnings
-from ..io.base import IOStream
-from ..formatting_utils import colored
+from collections import abc, defaultdict
+from dataclasses import dataclass
+from functools import partial
+from typing import Any, Dict, List, Set, Tuple
+from ..formatting_utils import colored
+from ..io.base import IOStream
+from .utils import consolidate_chat_info
logger = logging.getLogger(__name__)
Prerequisite = Tuple[int, int]
@@ -25,8 +25,12 @@ class ChatResult:
"""The chat history."""
summary: str = None
"""A summary obtained from the chat."""
- cost: tuple = None # (dict, dict) - (total_cost, actual_cost_with_cache)
- """The cost of the chat. a tuple of (total_cost, total_actual_cost), where total_cost is a dictionary of cost information, and total_actual_cost is a dictionary of information on the actual incurred cost with cache."""
+ cost: Dict[str, dict] = None # keys: "usage_including_cached_inference", "usage_excluding_cached_inference"
+ """The cost of the chat.
+ The value for each usage type is a dictionary containing cost information for that specific type.
+ - "usage_including_cached_inference": Cost information on the total usage, including the tokens in cached inference.
+ - "usage_excluding_cached_inference": Cost information on the usage of tokens, excluding the tokens in cache. No larger than "usage_including_cached_inference".
+ """
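+    # Illustrative shape only (a hedged sketch: the model key and numbers below are placeholders,
+    # not values produced by this code; the per-model summary layout is assumed):
+    # cost = {
+    #     "usage_including_cached_inference": {
+    #         "total_cost": 0.011,
+    #         "gpt-4": {"cost": 0.011, "prompt_tokens": 250, "completion_tokens": 80, "total_tokens": 330},
+    #     },
+    #     "usage_excluding_cached_inference": {
+    #         "total_cost": 0.006,
+    #         "gpt-4": {"cost": 0.006, "prompt_tokens": 120, "completion_tokens": 50, "total_tokens": 170},
+    #     },
+    # }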
human_input: List[str] = None
"""A list of human input solicited during the chat."""
@@ -141,25 +145,35 @@ def __post_carryover_processing(chat_info: Dict[str, Any]) -> None:
def initiate_chats(chat_queue: List[Dict[str, Any]]) -> List[ChatResult]:
"""Initiate a list of chats.
-
Args:
- chat_queue (List[Dict]): a list of dictionaries containing the information about the chats.
-
- Each dictionary should contain the input arguments for [`ConversableAgent.initiate_chat`](/docs/reference/agentchat/conversable_agent#initiate_chat). For example:
- - "sender": the sender agent.
- - "recipient": the recipient agent.
- - "clear_history" (bool): whether to clear the chat history with the agent. Default is True.
- - "silent" (bool or None): (Experimental) whether to print the messages in this conversation. Default is False.
- - "cache" (AbstractCache or None): the cache client to use for this conversation. Default is None.
- - "max_turns" (int or None): maximum number of turns for the chat. If None, the chat will continue until a termination condition is met. Default is None.
- - "summary_method" (str or callable): a string or callable specifying the method to get a summary from the chat. Default is DEFAULT_summary_method, i.e., "last_msg".
- - "summary_args" (dict): a dictionary of arguments to be passed to the summary_method. Default is {}.
- - "message" (str, callable or None): if None, input() will be called to get the initial message.
- - **context: additional context information to be passed to the chat.
- - "carryover": It can be used to specify the carryover information to be passed to this chat.
- If provided, we will combine this carryover with the "message" content when generating the initial chat
- message in `generate_init_message`.
-
+ chat_queue (List[Dict]): A list of dictionaries containing the information about the chats.
+
+ Each dictionary should contain the input arguments for
+ [`ConversableAgent.initiate_chat`](/docs/reference/agentchat/conversable_agent#initiate_chat).
+ For example:
+ - `"sender"` - the sender agent.
+ - `"recipient"` - the recipient agent.
+          - `"clear_history"` (bool) - whether to clear the chat history with the agent.
+ Default is True.
+ - `"silent"` (bool or None) - (Experimental) whether to print the messages in this
+ conversation. Default is False.
+ - `"cache"` (Cache or None) - the cache client to use for this conversation.
+ Default is None.
+ - `"max_turns"` (int or None) - maximum number of turns for the chat. If None, the chat
+ will continue until a termination condition is met. Default is None.
+ - `"summary_method"` (str or callable) - a string or callable specifying the method to get
+ a summary from the chat. Default is DEFAULT_summary_method, i.e., "last_msg".
+ - `"summary_args"` (dict) - a dictionary of arguments to be passed to the summary_method.
+ Default is {}.
+ - `"message"` (str, callable or None) - if None, input() will be called to get the
+ initial message.
+ - `**context` - additional context information to be passed to the chat.
+ - `"carryover"` - It can be used to specify the carryover information to be passed
+ to this chat. If provided, we will combine this carryover with the "message" content when
+ generating the initial chat message in `generate_init_message`.
+          - `"finished_chat_indexes_to_exclude_from_carryover"` - a list of indexes into the list of
+            finished chats whose summaries should be excluded from the carryover passed to this chat.
+            If not provided, or provided as an empty list, the summaries of all finished chats are carried over.
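+
+        A hedged usage sketch of a two-chat queue using the exclusion list (the agent names, messages,
+        and `OAI_CONFIG_LIST` file below are illustrative assumptions, not part of this module):
+
+        ```python
+        import autogen
+
+        llm_config = {"config_list": autogen.config_list_from_json("OAI_CONFIG_LIST")}
+        student = autogen.UserProxyAgent("student", human_input_mode="NEVER", code_execution_config=False)
+        teacher = autogen.AssistantAgent("teacher", llm_config=llm_config)
+
+        chat_results = autogen.initiate_chats(
+            [
+                {"sender": student, "recipient": teacher, "message": "Explain binary search.", "max_turns": 1},
+                {
+                    "sender": student,
+                    "recipient": teacher,
+                    "message": "Explain hash tables.",
+                    "max_turns": 1,
+                    # The binary-search summary (finished chat index 0) is irrelevant here, so drop it from carryover.
+                    "finished_chat_indexes_to_exclude_from_carryover": [0],
+                },
+            ]
+        )
+        ```
+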
Returns:
(list): a list of ChatResult objects corresponding to the finished chats in the chat_queue.
"""
@@ -171,9 +185,16 @@ def initiate_chats(chat_queue: List[Dict[str, Any]]) -> List[ChatResult]:
while current_chat_queue:
chat_info = current_chat_queue.pop(0)
_chat_carryover = chat_info.get("carryover", [])
+ finished_chat_indexes_to_exclude_from_carryover = chat_info.get(
+ "finished_chat_indexes_to_exclude_from_carryover", []
+ )
+
if isinstance(_chat_carryover, str):
_chat_carryover = [_chat_carryover]
- chat_info["carryover"] = _chat_carryover + [r.summary for r in finished_chats]
+ chat_info["carryover"] = _chat_carryover + [
+ r.summary for i, r in enumerate(finished_chats) if i not in finished_chat_indexes_to_exclude_from_carryover
+ ]
+
__post_carryover_processing(chat_info)
sender = chat_info["sender"]
chat_res = sender.initiate_chat(**chat_info)
@@ -228,11 +249,11 @@ async def a_initiate_chats(chat_queue: List[Dict[str, Any]]) -> Dict[int, ChatRe
"""(async) Initiate a list of chats.
args:
- Please refer to `initiate_chats`.
+ - Please refer to `initiate_chats`.
returns:
- (Dict): a dict of ChatId: ChatResult corresponding to the finished chats in the chat_queue.
+ - (Dict): a dict of ChatId: ChatResult corresponding to the finished chats in the chat_queue.
"""
consolidate_chat_info(chat_queue)
_validate_recipients(chat_queue)
diff --git a/autogen/agentchat/contrib/agent_builder.py b/autogen/agentchat/contrib/agent_builder.py
index 7a3850d79ae..272d954ff27 100644
--- a/autogen/agentchat/contrib/agent_builder.py
+++ b/autogen/agentchat/contrib/agent_builder.py
@@ -1,10 +1,11 @@
-import autogen
-import time
-import subprocess as sp
-import socket
-import json
import hashlib
-from typing import Optional, List, Dict, Tuple
+import json
+import socket
+import subprocess as sp
+import time
+from typing import Dict, List, Optional, Tuple
+
+import autogen
def _config_check(config: Dict):
@@ -202,9 +203,6 @@ def _create_agent(
Returns:
agent: a set-up agent.
"""
- from huggingface_hub import HfApi
- from huggingface_hub.utils import GatedRepoError, RepositoryNotFoundError
-
config_list = autogen.config_list_from_json(
self.config_file_or_env,
file_location=self.config_file_location,
@@ -217,10 +215,15 @@ def _create_agent(
f"If you load configs from json, make sure the model in agent_configs is in the {self.config_file_or_env}."
)
try:
+ from huggingface_hub import HfApi
+ from huggingface_hub.utils import GatedRepoError, RepositoryNotFoundError
+
hf_api = HfApi()
hf_api.model_info(model_name_or_hf_repo)
model_name = model_name_or_hf_repo.split("/")[-1]
server_id = f"{model_name}_{self.host}"
+ except ImportError:
+ server_id = self.online_server_name
except GatedRepoError as e:
raise e
except RepositoryNotFoundError:
@@ -494,9 +497,6 @@ def build_from_library(
agent_list: a list of agents.
cached_configs: cached configs.
"""
- import chromadb
- from chromadb.utils import embedding_functions
-
if code_execution_config is None:
code_execution_config = {
"last_n_messages": 2,
@@ -527,6 +527,9 @@ def build_from_library(
print("==> Looking for suitable agents in library...")
if embedding_model is not None:
+ import chromadb
+ from chromadb.utils import embedding_functions
+
chroma_client = chromadb.Client()
collection = chroma_client.create_collection(
name="agent_list",
diff --git a/autogen/agentchat/contrib/agent_eval/README.md b/autogen/agentchat/contrib/agent_eval/README.md
new file mode 100644
index 00000000000..6588a1ec611
--- /dev/null
+++ b/autogen/agentchat/contrib/agent_eval/README.md
@@ -0,0 +1,7 @@
+Agents for running the AgentEval pipeline.
+
+AgentEval is a process for evaluating an LLM-based system's performance on a given task.
+
+When given a task to evaluate and a few example runs, the critic and subcritic agents create evaluation criteria for assessing a system's solution. Once the criteria have been created, the quantifier agent can evaluate subsequent task solutions based on the generated criteria.
+
+For more information see: [AgentEval Integration Roadmap](https://github.com/microsoft/autogen/issues/2162)
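
A hedged end-to-end sketch of the pipeline described above. The task contents, the `OAI_CONFIG_LIST` source, and the ground-truth string are placeholders; the import paths follow the modules added in this change.

```python
import autogen
from autogen.agentchat.contrib.agent_eval.agent_eval import generate_criteria, quantify_criteria
from autogen.agentchat.contrib.agent_eval.task import Task

llm_config = {"config_list": autogen.config_list_from_json("OAI_CONFIG_LIST")}

task = Task(
    name="Math problem solving",
    description="Given a math problem, solve it and report the final answer.",
    successful_response="The answer is 4.",
    failed_response="I cannot solve this problem.",
)

# Critic (and optionally subcritic) agents propose evaluation criteria for the task.
criteria = generate_criteria(llm_config=llm_config, task=task, use_subcritic=True)

# The quantifier agent then scores a new solution against those criteria.
results = quantify_criteria(
    llm_config=llm_config,
    criteria=criteria,
    task=task,
    test_case="Problem: 2 + 2. Model response: The answer is 4.",
    ground_truth="true",
)
print(results["estimated_performance"])
```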
diff --git a/autogen/agentchat/contrib/agent_eval/agent_eval.py b/autogen/agentchat/contrib/agent_eval/agent_eval.py
new file mode 100644
index 00000000000..b48c65a66d2
--- /dev/null
+++ b/autogen/agentchat/contrib/agent_eval/agent_eval.py
@@ -0,0 +1,101 @@
+from typing import Dict, List, Literal, Optional, Union
+
+import autogen
+from autogen.agentchat.contrib.agent_eval.criterion import Criterion
+from autogen.agentchat.contrib.agent_eval.critic_agent import CriticAgent
+from autogen.agentchat.contrib.agent_eval.quantifier_agent import QuantifierAgent
+from autogen.agentchat.contrib.agent_eval.subcritic_agent import SubCriticAgent
+from autogen.agentchat.contrib.agent_eval.task import Task
+
+
+def generate_criteria(
+ llm_config: Optional[Union[Dict, Literal[False]]] = None,
+ task: Task = None,
+ additional_instructions: str = "",
+ max_round=2,
+ use_subcritic: bool = False,
+):
+ """
+ Creates a list of criteria for evaluating the utility of a given task.
+ Args:
+ llm_config (dict or bool): llm inference configuration.
+ task (Task): The task to evaluate.
+ additional_instructions (str): Additional instructions for the criteria agent.
+ max_round (int): The maximum number of rounds to run the conversation.
+ use_subcritic (bool): Whether to use the subcritic agent to generate subcriteria.
+ Returns:
+ list: A list of Criterion objects for evaluating the utility of the given task.
+ """
+ critic = CriticAgent(
+ system_message=CriticAgent.DEFAULT_SYSTEM_MESSAGE + "\n" + additional_instructions,
+ llm_config=llm_config,
+ )
+
+ critic_user = autogen.UserProxyAgent(
+ name="critic_user",
+ max_consecutive_auto_reply=0, # terminate without auto-reply
+ human_input_mode="NEVER",
+ code_execution_config={"use_docker": False},
+ )
+
+ agents = [critic_user, critic]
+
+ if use_subcritic:
+ subcritic = SubCriticAgent(
+ llm_config=llm_config,
+ )
+ agents.append(subcritic)
+
+ groupchat = autogen.GroupChat(
+ agents=agents, messages=[], max_round=max_round, speaker_selection_method="round_robin"
+ )
+ critic_manager = autogen.GroupChatManager(groupchat=groupchat, llm_config=llm_config)
+
+ critic_user.initiate_chat(critic_manager, message=task.get_sys_message())
+ criteria = critic_user.last_message()
+ content = criteria["content"]
+ # need to strip out any extra code around the returned json
+ content = content[content.find("[") : content.rfind("]") + 1]
+ criteria = Criterion.parse_json_str(content)
+ return criteria
+
+
+def quantify_criteria(
+ llm_config: Optional[Union[Dict, Literal[False]]] = None,
+ criteria: List[Criterion] = None,
+ task: Task = None,
+ test_case: str = "",
+ ground_truth: str = "",
+):
+ """
+ Quantifies the performance of a system using the provided criteria.
+ Args:
+ llm_config (dict or bool): llm inference configuration.
+ criteria ([Criterion]): A list of criteria for evaluating the utility of a given task.
+ task (Task): The task to evaluate.
+ test_case (str): The test case to evaluate.
+ ground_truth (str): The ground truth for the test case.
+ Returns:
+        dict: A dictionary where the keys are the criteria and the values are the assessed performance based on the accepted values for each criterion.
+ """
+ quantifier = QuantifierAgent(
+ llm_config=llm_config,
+ )
+
+ quantifier_user = autogen.UserProxyAgent(
+ name="quantifier_user",
+ max_consecutive_auto_reply=0, # terminate without auto-reply
+ human_input_mode="NEVER",
+ code_execution_config={"use_docker": False},
+ )
+
+ quantifier_user.initiate_chat( # noqa: F841
+ quantifier,
+ message=task.get_sys_message()
+ + "Evaluation dictionary: "
+ + Criterion.write_json(criteria)
+ + "actual test case to evaluate: "
+ + test_case,
+ )
+ quantified_results = quantifier_user.last_message()
+ return {"actual_success": ground_truth, "estimated_performance": quantified_results["content"]}
diff --git a/autogen/agentchat/contrib/agent_eval/criterion.py b/autogen/agentchat/contrib/agent_eval/criterion.py
new file mode 100644
index 00000000000..5efd121ec07
--- /dev/null
+++ b/autogen/agentchat/contrib/agent_eval/criterion.py
@@ -0,0 +1,41 @@
+from __future__ import annotations
+
+import json
+from typing import List
+
+import pydantic_core
+from pydantic import BaseModel
+from pydantic.json import pydantic_encoder
+
+
+class Criterion(BaseModel):
+ """
+ A class that represents a criterion for agent evaluation.
+ """
+
+ name: str
+ description: str
+ accepted_values: List[str]
+ sub_criteria: List[Criterion] = list()
+
+ @staticmethod
+ def parse_json_str(criteria: str):
+ """
+ Create a list of Criterion objects from a json string.
+ Args:
+ criteria (str): Json string that represents the criteria
+ returns:
+ [Criterion]: A list of Criterion objects that represents the json criteria information.
+ """
+ return [Criterion(**crit) for crit in json.loads(criteria)]
+
+ @staticmethod
+ def write_json(criteria):
+ """
+ Create a json string from a list of Criterion objects.
+ Args:
+ criteria ([Criterion]): A list of Criterion objects.
+ Returns:
+ str: A json string that represents the list of Criterion objects.
+ """
+ return json.dumps([crit.model_dump() for crit in criteria], indent=2)
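
For illustration, a hypothetical criteria list in the JSON shape that `parse_json_str` expects and `write_json` emits:

```python
from autogen.agentchat.contrib.agent_eval.criterion import Criterion

criteria_json = """
[
    {
        "name": "accuracy",
        "description": "How factually correct the response is.",
        "accepted_values": ["poor", "fair", "good", "excellent"]
    }
]
"""

criteria = Criterion.parse_json_str(criteria_json)
print(criteria[0].name)  # accuracy
print(Criterion.write_json(criteria))  # round-trips back to a JSON string
```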
diff --git a/autogen/agentchat/contrib/agent_eval/critic_agent.py b/autogen/agentchat/contrib/agent_eval/critic_agent.py
new file mode 100644
index 00000000000..2f5e5598ba6
--- /dev/null
+++ b/autogen/agentchat/contrib/agent_eval/critic_agent.py
@@ -0,0 +1,41 @@
+from typing import Optional
+
+from autogen.agentchat.conversable_agent import ConversableAgent
+
+
+class CriticAgent(ConversableAgent):
+ """
+    An agent for creating a list of criteria for evaluating the utility of a given task.
+ """
+
+ DEFAULT_SYSTEM_MESSAGE = """You are a helpful assistant. You suggest criteria for evaluating different tasks. They should be distinguishable, quantifiable and not redundant.
+    Convert the evaluation criteria into a list where each item is a criterion represented by a dictionary as follows
+ {"name": name of the criterion, "description": criteria description , "accepted_values": possible accepted inputs for this key}
+ Make sure "accepted_values" include the acceptable inputs for each key that are fine-grained and preferably multi-graded levels and "description" includes the criterion description.
+ Output just the criteria string you have created, no code.
+ """
+
+ DEFAULT_DESCRIPTION = "An AI agent for creating list criteria for evaluating the utility of a given task."
+
+ def __init__(
+ self,
+ name="critic",
+ system_message: Optional[str] = DEFAULT_SYSTEM_MESSAGE,
+ description: Optional[str] = DEFAULT_DESCRIPTION,
+ **kwargs,
+ ):
+ """
+ Args:
+ name (str): agent name.
+ system_message (str): system message for the ChatCompletion inference.
+ Please override this attribute if you want to reprogram the agent.
+ description (str): The description of the agent.
+ **kwargs (dict): Please refer to other kwargs in
+ [ConversableAgent](../../conversable_agent#__init__).
+ """
+ super().__init__(
+ name=name,
+ system_message=system_message,
+ description=description,
+ **kwargs,
+ )
diff --git a/autogen/agentchat/contrib/agent_eval/quantifier_agent.py b/autogen/agentchat/contrib/agent_eval/quantifier_agent.py
new file mode 100644
index 00000000000..02a8f650fab
--- /dev/null
+++ b/autogen/agentchat/contrib/agent_eval/quantifier_agent.py
@@ -0,0 +1,36 @@
+from typing import Optional
+
+from autogen.agentchat.conversable_agent import ConversableAgent
+
+
+class QuantifierAgent(ConversableAgent):
+ """
+ An agent for quantifying the performance of a system using the provided criteria.
+ """
+
+ DEFAULT_SYSTEM_MESSAGE = """"You are a helpful assistant. You quantify the output of different tasks based on the given criteria.
+ The criterion is given in a json list format where each element is a distinct criteria.
+ The each element is a dictionary as follows {"name": name of the criterion, "description": criteria description , "accepted_values": possible accepted inputs for this key}
+ You are going to quantify each of the crieria for a given task based on the task description.
+ Return a dictionary where the keys are the criteria and the values are the assessed performance based on accepted values for each criteria.
+ Return only the dictionary, no code."""
+
+ DEFAULT_DESCRIPTION = "An AI agent for quantifing the performance of a system using the provided criteria."
+
+ def __init__(
+ self,
+ name="quantifier",
+ system_message: Optional[str] = DEFAULT_SYSTEM_MESSAGE,
+ description: Optional[str] = DEFAULT_DESCRIPTION,
+ **kwargs,
+ ):
+ """
+ Args:
+ name (str): agent name.
+ system_message (str): system message for the ChatCompletion inference.
+ Please override this attribute if you want to reprogram the agent.
+ description (str): The description of the agent.
+ **kwargs (dict): Please refer to other kwargs in
+ [ConversableAgent](../../conversable_agent#__init__).
+ """
+ super().__init__(name=name, system_message=system_message, description=description, **kwargs)
diff --git a/autogen/agentchat/contrib/agent_eval/subcritic_agent.py b/autogen/agentchat/contrib/agent_eval/subcritic_agent.py
new file mode 100755
index 00000000000..fa994ee7bda
--- /dev/null
+++ b/autogen/agentchat/contrib/agent_eval/subcritic_agent.py
@@ -0,0 +1,42 @@
+from typing import Optional
+
+from autogen.agentchat.conversable_agent import ConversableAgent
+
+
+class SubCriticAgent(ConversableAgent):
+ """
+ An agent for creating subcriteria from a given list of criteria for evaluating the utility of a given task.
+ """
+
+ DEFAULT_SYSTEM_MESSAGE = """You are a helpful assistant to the critic agent. You suggest sub criteria for evaluating different tasks based on the criteria provided by the critic agent (if you feel it is needed).
+ They should be distinguishable, quantifiable, and related to the overall theme of the critic's provided criteria.
+ You operate by taking in the description of the criteria. You then create a new key called sub criteria where you provide the sub criteria for the given criteria.
+ The value of the sub_criteria is a dictionary where the keys are the subcriteria and each value is as follows {"description": sub criteria description , "accepted_values": possible accepted inputs for this key}
+ Do this for each criteria provided by the critic (removing the criteria's accepted values). "accepted_values" include the acceptable inputs for each key that are fine-grained and preferably multi-graded levels. "description" includes the criterion description.
+ Once you have created the sub criteria for the given criteria, you return the json (make sure to include the contents of the critic's dictionary in the final dictionary as well).
+ Make sure to return a valid json and no code"""
+
+ DEFAULT_DESCRIPTION = "An AI agent for creating subcriteria from a given list of criteria."
+
+ def __init__(
+ self,
+ name="subcritic",
+ system_message: Optional[str] = DEFAULT_SYSTEM_MESSAGE,
+ description: Optional[str] = DEFAULT_DESCRIPTION,
+ **kwargs,
+ ):
+ """
+ Args:
+ name (str): agent name.
+ system_message (str): system message for the ChatCompletion inference.
+ Please override this attribute if you want to reprogram the agent.
+ description (str): The description of the agent.
+ **kwargs (dict): Please refer to other kwargs in
+ [ConversableAgent](../../conversable_agent#__init__).
+ """
+ super().__init__(
+ name=name,
+ system_message=system_message,
+ description=description,
+ **kwargs,
+ )
diff --git a/autogen/agentchat/contrib/agent_eval/task.py b/autogen/agentchat/contrib/agent_eval/task.py
new file mode 100644
index 00000000000..9f96fbf79e2
--- /dev/null
+++ b/autogen/agentchat/contrib/agent_eval/task.py
@@ -0,0 +1,37 @@
+import json
+
+from pydantic import BaseModel
+
+
+class Task(BaseModel):
+ """
+    Class representing a task for agent completion; it includes an example agent execution used for criteria generation.
+ """
+
+ name: str
+ description: str
+ successful_response: str
+ failed_response: str
+
+ def get_sys_message(self):
+ return f"""Task: {self.name}.
+ Task description: {self.description}
+ Task successful example: {self.successful_response}
+ Task failed example: {self.failed_response}
+ """
+
+ @staticmethod
+ def parse_json_str(task: str):
+ """
+        Create a Task object from a json string.
+        Args:
+            task (str): A json string that represents the task.
+ Returns:
+ Task: A Task object that represents the json task information.
+ """
+ json_data = json.loads(task)
+ name = json_data.get("name")
+ description = json_data.get("description")
+ successful_response = json_data.get("successful_response")
+ failed_response = json_data.get("failed_response")
+        return Task(name=name, description=description, successful_response=successful_response, failed_response=failed_response)
diff --git a/autogen/agentchat/contrib/agent_optimizer.py b/autogen/agentchat/contrib/agent_optimizer.py
index 711874efc8f..af264d4b65f 100644
--- a/autogen/agentchat/contrib/agent_optimizer.py
+++ b/autogen/agentchat/contrib/agent_optimizer.py
@@ -1,8 +1,9 @@
-from autogen.code_utils import execute_code
-from typing import List, Dict, Optional
-import json
import copy
+import json
+from typing import Dict, List, Literal, Optional, Union
+
import autogen
+from autogen.code_utils import execute_code
ADD_FUNC = {
"type": "function",
@@ -171,16 +172,16 @@ class AgentOptimizer:
def __init__(
self,
max_actions_per_step: int,
- config_file_or_env: Optional[str] = "OAI_CONFIG_LIST",
- config_file_location: Optional[str] = "",
+ llm_config: dict,
optimizer_model: Optional[str] = "gpt-4-1106-preview",
):
"""
(These APIs are experimental and may change in the future.)
Args:
max_actions_per_step (int): the maximum number of actions that the optimizer can take in one step.
- config_file_or_env: path or environment of the OpenAI api configs.
- config_file_location: the location of the OpenAI config file.
+ llm_config (dict): llm inference configuration.
+ Please refer to [OpenAIWrapper.create](/docs/reference/oai/client#create) for available options.
+ When using OpenAI or Azure OpenAI endpoints, please specify a non-empty 'model' either in `llm_config` or in each config of 'config_list' in `llm_config`.
optimizer_model: the model used for the optimizer.
"""
self.max_actions_per_step = max_actions_per_step
@@ -198,14 +199,17 @@ def __init__(
self._failure_functions_performance = []
self._best_performance = -1
- config_list = autogen.config_list_from_json(
- config_file_or_env,
- file_location=config_file_location,
- filter_dict={"model": [self.optimizer_model]},
+ assert isinstance(llm_config, dict), "llm_config must be a dict"
+ llm_config = copy.deepcopy(llm_config)
+ self.llm_config = llm_config
+ if self.llm_config in [{}, {"config_list": []}, {"config_list": [{"model": ""}]}]:
+ raise ValueError(
+ "When using OpenAI or Azure OpenAI endpoints, specify a non-empty 'model' either in 'llm_config' or in each config of 'config_list'."
+ )
+ self.llm_config["config_list"] = autogen.filter_config(
+ llm_config["config_list"], {"model": [self.optimizer_model]}
)
- if len(config_list) == 0:
- raise RuntimeError("No valid openai config found in the config file or environment variable.")
- self._client = autogen.OpenAIWrapper(config_list=config_list)
+ self._client = autogen.OpenAIWrapper(**self.llm_config)
def record_one_conversation(self, conversation_history: List[Dict], is_satisfied: bool = None):
"""
@@ -265,7 +269,7 @@ def step(self):
actions_num=action_index,
best_functions=best_functions,
incumbent_functions=incumbent_functions,
- accumerated_experience=failure_experience_prompt,
+ accumulated_experience=failure_experience_prompt,
statistic_informations=statistic_prompt,
)
messages = [{"role": "user", "content": prompt}]
diff --git a/autogen/agentchat/contrib/capabilities/context_handling.py b/autogen/agentchat/contrib/capabilities/context_handling.py
index 1510ae5fcd6..173811842eb 100644
--- a/autogen/agentchat/contrib/capabilities/context_handling.py
+++ b/autogen/agentchat/contrib/capabilities/context_handling.py
@@ -1,9 +1,18 @@
import sys
-from termcolor import colored
-from typing import Dict, Optional, List
-from autogen import ConversableAgent
-from autogen import token_count_utils
+from typing import Dict, List, Optional
+from warnings import warn
+
import tiktoken
+from termcolor import colored
+
+from autogen import ConversableAgent, token_count_utils
+
+warn(
+ "Context handling with TransformChatHistory is deprecated. "
+ "Please use TransformMessages from autogen/agentchat/contrib/capabilities/transform_messages.py instead.",
+ DeprecationWarning,
+ stacklevel=2,
+)
class TransformChatHistory:
@@ -26,7 +35,8 @@ class TransformChatHistory:
3. Third, it limits the total number of tokens in the chat history
When adding this capability to an agent, the following are modified:
- - A hook is added to the hookable method `process_all_messages_before_reply` to transform the received messages for possible truncation.
+ - A hook is added to the hookable method `process_all_messages_before_reply` to transform the
+ received messages for possible truncation.
Not modifying the stored message history.
"""
diff --git a/autogen/agentchat/contrib/capabilities/generate_images.py b/autogen/agentchat/contrib/capabilities/generate_images.py
index d16121ddb9a..e4a8f1195c2 100644
--- a/autogen/agentchat/contrib/capabilities/generate_images.py
+++ b/autogen/agentchat/contrib/capabilities/generate_images.py
@@ -5,10 +5,10 @@
from PIL.Image import Image
from autogen import Agent, ConversableAgent, code_utils
-from autogen.cache import AbstractCache
from autogen.agentchat.contrib import img_utils
from autogen.agentchat.contrib.capabilities.agent_capability import AgentCapability
from autogen.agentchat.contrib.text_analyzer_agent import TextAnalyzerAgent
+from autogen.cache import AbstractCache
SYSTEM_MESSAGE = "You've been given the special ability to generate images."
DESCRIPTION_MESSAGE = "This agent has the ability to generate images."
diff --git a/autogen/agentchat/contrib/capabilities/teachability.py b/autogen/agentchat/contrib/capabilities/teachability.py
index 58ba35ed425..596e449ce34 100644
--- a/autogen/agentchat/contrib/capabilities/teachability.py
+++ b/autogen/agentchat/contrib/capabilities/teachability.py
@@ -1,11 +1,14 @@
import os
+import pickle
from typing import Dict, Optional, Union
+
import chromadb
from chromadb.config import Settings
-import pickle
+
from autogen.agentchat.assistant_agent import ConversableAgent
from autogen.agentchat.contrib.capabilities.agent_capability import AgentCapability
from autogen.agentchat.contrib.text_analyzer_agent import TextAnalyzerAgent
+
from ....formatting_utils import colored
@@ -83,7 +86,7 @@ def prepopulate_db(self):
"""Adds a few arbitrary memos to the DB."""
self.memo_store.prepopulate()
- def process_last_received_message(self, text):
+ def process_last_received_message(self, text: Union[Dict, str]):
"""
Appends any relevant memos to the message text, and stores any apparent teachings in new memos.
Uses TextAnalyzerAgent to make decisions about memo storage and retrieval.
@@ -100,7 +103,7 @@ def process_last_received_message(self, text):
# Return the (possibly) expanded message text.
return expanded_text
- def _consider_memo_storage(self, comment):
+ def _consider_memo_storage(self, comment: Union[Dict, str]):
"""Decides whether to store something from one user comment in the DB."""
memo_added = False
@@ -158,7 +161,7 @@ def _consider_memo_storage(self, comment):
# Yes. Save them to disk.
self.memo_store._save_memos()
- def _consider_memo_retrieval(self, comment):
+ def _consider_memo_retrieval(self, comment: Union[Dict, str]):
"""Decides whether to retrieve memos from the DB, and add them to the chat context."""
# First, use the comment directly as the lookup key.
@@ -192,7 +195,7 @@ def _consider_memo_retrieval(self, comment):
# Append the memos to the text of the last message.
return comment + self._concatenate_memo_texts(memo_list)
- def _retrieve_relevant_memos(self, input_text):
+ def _retrieve_relevant_memos(self, input_text: str) -> list:
"""Returns semantically related memos from the DB."""
memo_list = self.memo_store.get_related_memos(
input_text, n_results=self.max_num_retrievals, threshold=self.recall_threshold
@@ -210,7 +213,7 @@ def _retrieve_relevant_memos(self, input_text):
memo_list = [memo[1] for memo in memo_list]
return memo_list
- def _concatenate_memo_texts(self, memo_list):
+ def _concatenate_memo_texts(self, memo_list: list) -> str:
"""Concatenates the memo texts into a single string for inclusion in the chat context."""
memo_texts = ""
if len(memo_list) > 0:
@@ -222,7 +225,7 @@ def _concatenate_memo_texts(self, memo_list):
memo_texts = memo_texts + "\n" + info
return memo_texts
- def _analyze(self, text_to_analyze, analysis_instructions):
+ def _analyze(self, text_to_analyze: Union[Dict, str], analysis_instructions: Union[Dict, str]):
"""Asks TextAnalyzerAgent to analyze the given text according to specific instructions."""
self.analyzer.reset() # Clear the analyzer's list of messages.
self.teachable_agent.send(
@@ -243,10 +246,16 @@ class MemoStore:
Vector embeddings are currently supplied by Chroma's default Sentence Transformers.
"""
- def __init__(self, verbosity, reset, path_to_db_dir):
+ def __init__(
+ self,
+ verbosity: Optional[int] = 0,
+ reset: Optional[bool] = False,
+ path_to_db_dir: Optional[str] = "./tmp/teachable_agent_db",
+ ):
"""
Args:
-            verbosity (Optional, int): 1 to print memory operations, 0 to omit them. 3+ to print memo lists.
-            path_to_db_dir (Optional, str): path to the directory where the DB is stored.
+            - verbosity (Optional, int): 1 to print memory operations, 0 to omit them. 3+ to print memo lists.
+            - reset (Optional, bool): True to clear the DB before starting. Default False.
+            - path_to_db_dir (Optional, str): path to the directory where the DB is stored.
"""
self.verbosity = verbosity
@@ -301,7 +310,7 @@ def reset_db(self):
self.uid_text_dict = {}
self._save_memos()
- def add_input_output_pair(self, input_text, output_text):
+ def add_input_output_pair(self, input_text: str, output_text: str):
"""Adds an input-output pair to the vector DB."""
self.last_memo_id += 1
self.vec_db.add(documents=[input_text], ids=[str(self.last_memo_id)])
@@ -318,7 +327,7 @@ def add_input_output_pair(self, input_text, output_text):
if self.verbosity >= 3:
self.list_memos()
- def get_nearest_memo(self, query_text):
+ def get_nearest_memo(self, query_text: str):
"""Retrieves the nearest memo to the given query text."""
results = self.vec_db.query(query_texts=[query_text], n_results=1)
uid, input_text, distance = results["ids"][0][0], results["documents"][0][0], results["distances"][0][0]
@@ -335,7 +344,7 @@ def get_nearest_memo(self, query_text):
)
return input_text, output_text, distance
- def get_related_memos(self, query_text, n_results, threshold):
+ def get_related_memos(self, query_text: str, n_results: int, threshold: Union[int, float]):
"""Retrieves memos that are related to the given query text within the specified distance threshold."""
if n_results > len(self.uid_text_dict):
n_results = len(self.uid_text_dict)
diff --git a/autogen/agentchat/contrib/capabilities/text_compressors.py b/autogen/agentchat/contrib/capabilities/text_compressors.py
new file mode 100644
index 00000000000..78554bdc935
--- /dev/null
+++ b/autogen/agentchat/contrib/capabilities/text_compressors.py
@@ -0,0 +1,68 @@
+from typing import Any, Dict, Optional, Protocol
+
+IMPORT_ERROR: Optional[Exception] = None
+try:
+ import llmlingua
+except ImportError:
+ IMPORT_ERROR = ImportError(
+ "LLMLingua is not installed. Please install it with `pip install pyautogen[long-context]`"
+ )
+ PromptCompressor = object
+else:
+ from llmlingua import PromptCompressor
+
+
+class TextCompressor(Protocol):
+ """Defines a protocol for text compression to optimize agent interactions."""
+
+ def compress_text(self, text: str, **compression_params) -> Dict[str, Any]:
+ """This method takes a string as input and returns a dictionary containing the compressed text and other
+ relevant information. The compressed text should be stored under the 'compressed_text' key in the dictionary.
+ To calculate the number of saved tokens, the dictionary should include 'origin_tokens' and 'compressed_tokens' keys.
+ """
+ ...
+
+
+class LLMLingua:
+ """Compresses text messages using LLMLingua for improved efficiency in processing and response generation.
+
+ NOTE: The effectiveness of compression and the resultant token savings can vary based on the content of the messages
+ and the specific configurations used for the PromptCompressor.
+ """
+
+ def __init__(
+ self,
+ prompt_compressor_kwargs: Dict = dict(
+ model_name="microsoft/llmlingua-2-bert-base-multilingual-cased-meetingbank",
+ use_llmlingua2=True,
+ device_map="cpu",
+ ),
+ structured_compression: bool = False,
+ ) -> None:
+ """
+ Args:
+ prompt_compressor_kwargs (dict): A dictionary of keyword arguments for the PromptCompressor. Defaults to a
+ dictionary with model_name set to "microsoft/llmlingua-2-bert-base-multilingual-cased-meetingbank",
+ use_llmlingua2 set to True, and device_map set to "cpu".
+ structured_compression (bool): A flag indicating whether to use structured compression. If True, the
+ structured_compress_prompt method of the PromptCompressor is used. Otherwise, the compress_prompt method
+                is used. Defaults to False.
+
+ Raises:
+ ImportError: If the llmlingua library is not installed.
+ """
+ if IMPORT_ERROR:
+ raise IMPORT_ERROR
+
+ self._prompt_compressor = PromptCompressor(**prompt_compressor_kwargs)
+
+ assert isinstance(self._prompt_compressor, llmlingua.PromptCompressor)
+ self._compression_method = (
+ self._prompt_compressor.structured_compress_prompt
+ if structured_compression
+ else self._prompt_compressor.compress_prompt
+ )
+
+ def compress_text(self, text: str, **compression_params) -> Dict[str, Any]:
+ return self._compression_method([text], **compression_params)
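
A small sketch of the compressor used in isolation, assuming the optional dependency is installed via `pip install pyautogen[long-context]`. The result keys shown are the ones read by the message transform that consumes this class.

```python
from autogen.agentchat.contrib.capabilities.text_compressors import LLMLingua

compressor = LLMLingua()  # downloads the default llmlingua-2 model on first use
result = compressor.compress_text(
    "A long, repetitive prompt that we would like to shrink before sending it to the model."
)
print(result["compressed_prompt"])
print(result["origin_tokens"], "->", result["compressed_tokens"])
```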
diff --git a/autogen/agentchat/contrib/capabilities/transform_messages.py b/autogen/agentchat/contrib/capabilities/transform_messages.py
new file mode 100644
index 00000000000..e96dc39fa7b
--- /dev/null
+++ b/autogen/agentchat/contrib/capabilities/transform_messages.py
@@ -0,0 +1,87 @@
+import copy
+from typing import Dict, List
+
+from autogen import ConversableAgent
+
+from ....formatting_utils import colored
+from .transforms import MessageTransform
+
+
+class TransformMessages:
+ """Agent capability for transforming messages before reply generation.
+
+ This capability allows you to apply a series of message transformations to
+ a ConversableAgent's incoming messages before they are processed for response
+ generation. This is useful for tasks such as:
+
+ - Limiting the number of messages considered for context.
+ - Truncating messages to meet token limits.
+ - Filtering sensitive information.
+ - Customizing message formatting.
+
+ To use `TransformMessages`:
+
+ 1. Create message transformations (e.g., `MessageHistoryLimiter`, `MessageTokenLimiter`).
+ 2. Instantiate `TransformMessages` with a list of these transformations.
+ 3. Add the `TransformMessages` instance to your `ConversableAgent` using `add_to_agent`.
+
+ NOTE: Order of message transformations is important. You could get different results based on
+ the order of transformations.
+
+ Example:
+ ```python
+    from autogen import ConversableAgent
+    from autogen.agentchat.contrib.capabilities.transform_messages import TransformMessages
+    from autogen.agentchat.contrib.capabilities.transforms import MessageHistoryLimiter, MessageTokenLimiter
+
+ max_messages = MessageHistoryLimiter(max_messages=2)
+ truncate_messages = MessageTokenLimiter(max_tokens=500)
+ transform_messages = TransformMessages(transforms=[max_messages, truncate_messages])
+
+ agent = ConversableAgent(...)
+ transform_messages.add_to_agent(agent)
+ ```
+ """
+
+ def __init__(self, *, transforms: List[MessageTransform] = [], verbose: bool = True):
+ """
+ Args:
+ transforms: A list of message transformations to apply.
+ verbose: Whether to print logs of each transformation or not.
+ """
+ self._transforms = transforms
+ self._verbose = verbose
+
+ def add_to_agent(self, agent: ConversableAgent):
+ """Adds the message transformations capability to the specified ConversableAgent.
+
+ This function performs the following modifications to the agent:
+
+ 1. Registers a hook that automatically transforms all messages before they are processed for
+ response generation.
+ """
+ agent.register_hook(hookable_method="process_all_messages_before_reply", hook=self._transform_messages)
+
+ def _transform_messages(self, messages: List[Dict]) -> List[Dict]:
+ post_transform_messages = copy.deepcopy(messages)
+ system_message = None
+
+ if messages[0]["role"] == "system":
+ system_message = copy.deepcopy(messages[0])
+ post_transform_messages.pop(0)
+
+ for transform in self._transforms:
+ # deepcopy in case pre_transform_messages will later be used for logs printing
+ pre_transform_messages = (
+ copy.deepcopy(post_transform_messages) if self._verbose else post_transform_messages
+ )
+ post_transform_messages = transform.apply_transform(pre_transform_messages)
+
+ if self._verbose:
+ logs_str, had_effect = transform.get_logs(pre_transform_messages, post_transform_messages)
+ if had_effect:
+ print(colored(logs_str, "yellow"))
+
+ if system_message:
+ post_transform_messages.insert(0, system_message)
+
+ return post_transform_messages
diff --git a/autogen/agentchat/contrib/capabilities/transforms.py b/autogen/agentchat/contrib/capabilities/transforms.py
new file mode 100644
index 00000000000..8303843e881
--- /dev/null
+++ b/autogen/agentchat/contrib/capabilities/transforms.py
@@ -0,0 +1,436 @@
+import copy
+import json
+import sys
+from typing import Any, Dict, List, Optional, Protocol, Tuple, Union
+
+import tiktoken
+from termcolor import colored
+
+from autogen import token_count_utils
+from autogen.cache import AbstractCache, Cache
+
+from .text_compressors import LLMLingua, TextCompressor
+
+
+class MessageTransform(Protocol):
+ """Defines a contract for message transformation.
+
+ Classes implementing this protocol should provide an `apply_transform` method
+ that takes a list of messages and returns the transformed list.
+ """
+
+ def apply_transform(self, messages: List[Dict]) -> List[Dict]:
+ """Applies a transformation to a list of messages.
+
+ Args:
+ messages: A list of dictionaries representing messages.
+
+ Returns:
+ A new list of dictionaries containing the transformed messages.
+ """
+ ...
+
+ def get_logs(self, pre_transform_messages: List[Dict], post_transform_messages: List[Dict]) -> Tuple[str, bool]:
+ """Creates the string including the logs of the transformation
+
+ Alongside the string, it returns a boolean indicating whether the transformation had an effect or not.
+
+ Args:
+ pre_transform_messages: A list of dictionaries representing messages before the transformation.
+            post_transform_messages: A list of dictionaries representing messages after the transformation.
+
+ Returns:
+ A tuple with a string with the logs and a flag indicating whether the transformation had an effect or not.
+ """
+ ...
+
+
+class MessageHistoryLimiter:
+ """Limits the number of messages considered by an agent for response generation.
+
+ This transform keeps only the most recent messages up to the specified maximum number of messages (max_messages).
+ It trims the conversation history by removing older messages, retaining only the most recent messages.
+ """
+
+ def __init__(self, max_messages: Optional[int] = None):
+ """
+ Args:
+            max_messages (Optional[int]): Maximum number of messages to keep in the context. Must be greater than 0 if not None.
+ """
+ self._validate_max_messages(max_messages)
+ self._max_messages = max_messages
+
+ def apply_transform(self, messages: List[Dict]) -> List[Dict]:
+ """Truncates the conversation history to the specified maximum number of messages.
+
+ This method returns a new list containing the most recent messages up to the specified
+ maximum number of messages (max_messages). If max_messages is None, it returns the
+ original list of messages unmodified.
+
+ Args:
+ messages (List[Dict]): The list of messages representing the conversation history.
+
+ Returns:
+ List[Dict]: A new list containing the most recent messages up to the specified maximum.
+ """
+
+ if self._max_messages is None:
+ return messages
+
+ return messages[-self._max_messages :]
+
+ def get_logs(self, pre_transform_messages: List[Dict], post_transform_messages: List[Dict]) -> Tuple[str, bool]:
+ pre_transform_messages_len = len(pre_transform_messages)
+ post_transform_messages_len = len(post_transform_messages)
+
+ if post_transform_messages_len < pre_transform_messages_len:
+ logs_str = (
+ f"Removed {pre_transform_messages_len - post_transform_messages_len} messages. "
+ f"Number of messages reduced from {pre_transform_messages_len} to {post_transform_messages_len}."
+ )
+ return logs_str, True
+ return "No messages were removed.", False
+
+ def _validate_max_messages(self, max_messages: Optional[int]):
+ if max_messages is not None and max_messages < 1:
+ raise ValueError("max_messages must be None or greater than 1")
+
+
+class MessageTokenLimiter:
+ """Truncates messages to meet token limits for efficient processing and response generation.
+
+ This transformation applies two levels of truncation to the conversation history:
+
+ 1. Truncates each individual message to the maximum number of tokens specified by max_tokens_per_message.
+ 2. Truncates the overall conversation history to the maximum number of tokens specified by max_tokens.
+
+ NOTE: Tokens are counted using the encoder for the specified model. Different models may yield different token
+ counts for the same text.
+
+ NOTE: For multimodal LLMs, the token count may be inaccurate as it does not account for the non-text input
+ (e.g images).
+
+ The truncation process follows these steps in order:
+
+    1. The minimum tokens threshold (`min_tokens`) is checked (0 by default). If the total number of tokens in the messages
+    is less than this threshold, the messages are returned as is. Otherwise, the following process is applied.
+ 2. Messages are processed in reverse order (newest to oldest).
+ 3. Individual messages are truncated based on max_tokens_per_message. For multimodal messages containing both text
+ and other types of content, only the text content is truncated.
+ 4. The overall conversation history is truncated based on the max_tokens limit. Once the accumulated token count
+    exceeds this limit, the current message being processed is truncated to meet the total token count and any
+    remaining messages are discarded.
+ 5. The truncated conversation history is reconstructed by prepending the messages to a new list to preserve the
+ original message order.
+ """
+
+ def __init__(
+ self,
+ max_tokens_per_message: Optional[int] = None,
+ max_tokens: Optional[int] = None,
+ min_tokens: Optional[int] = None,
+ model: str = "gpt-3.5-turbo-0613",
+ ):
+ """
+ Args:
+ max_tokens_per_message (None or int): Maximum number of tokens to keep in each message.
+ Must be greater than or equal to 0 if not None.
+ max_tokens (Optional[int]): Maximum number of tokens to keep in the chat history.
+ Must be greater than or equal to 0 if not None.
+ min_tokens (Optional[int]): Minimum number of tokens in messages to apply the transformation.
+ Must be greater than or equal to 0 if not None.
+ model (str): The target OpenAI model for tokenization alignment.
+ """
+ self._model = model
+ self._max_tokens_per_message = self._validate_max_tokens(max_tokens_per_message)
+ self._max_tokens = self._validate_max_tokens(max_tokens)
+ self._min_tokens = self._validate_min_tokens(min_tokens, max_tokens)
+
+ def apply_transform(self, messages: List[Dict]) -> List[Dict]:
+ """Applies token truncation to the conversation history.
+
+ Args:
+ messages (List[Dict]): The list of messages representing the conversation history.
+
+ Returns:
+ List[Dict]: A new list containing the truncated messages up to the specified token limits.
+ """
+ assert self._max_tokens_per_message is not None
+ assert self._max_tokens is not None
+ assert self._min_tokens is not None
+
+ # if the total number of tokens in the messages is less than the min_tokens, return the messages as is
+ if not _min_tokens_reached(messages, self._min_tokens):
+ return messages
+
+ temp_messages = copy.deepcopy(messages)
+ processed_messages = []
+ processed_messages_tokens = 0
+
+ for msg in reversed(temp_messages):
+ # Some messages may not have content.
+ if not isinstance(msg.get("content"), (str, list)):
+ processed_messages.insert(0, msg)
+ continue
+
+ expected_tokens_remained = self._max_tokens - processed_messages_tokens - self._max_tokens_per_message
+
+            # If adding this message would exceed the token limit, truncate the current message to meet the total
+            # token limit and discard all remaining (older) messages
+ if expected_tokens_remained < 0:
+ msg["content"] = self._truncate_str_to_tokens(
+ msg["content"], self._max_tokens - processed_messages_tokens
+ )
+ processed_messages.insert(0, msg)
+ break
+
+ msg["content"] = self._truncate_str_to_tokens(msg["content"], self._max_tokens_per_message)
+ msg_tokens = _count_tokens(msg["content"])
+
+ # prepend the message to the list to preserve order
+ processed_messages_tokens += msg_tokens
+ processed_messages.insert(0, msg)
+
+ return processed_messages
+
+ def get_logs(self, pre_transform_messages: List[Dict], post_transform_messages: List[Dict]) -> Tuple[str, bool]:
+ pre_transform_messages_tokens = sum(
+ _count_tokens(msg["content"]) for msg in pre_transform_messages if "content" in msg
+ )
+ post_transform_messages_tokens = sum(
+ _count_tokens(msg["content"]) for msg in post_transform_messages if "content" in msg
+ )
+
+ if post_transform_messages_tokens < pre_transform_messages_tokens:
+ logs_str = (
+ f"Truncated {pre_transform_messages_tokens - post_transform_messages_tokens} tokens. "
+ f"Number of tokens reduced from {pre_transform_messages_tokens} to {post_transform_messages_tokens}"
+ )
+ return logs_str, True
+ return "No tokens were truncated.", False
+
+ def _truncate_str_to_tokens(self, contents: Union[str, List], n_tokens: int) -> Union[str, List]:
+ if isinstance(contents, str):
+ return self._truncate_tokens(contents, n_tokens)
+ elif isinstance(contents, list):
+ return self._truncate_multimodal_text(contents, n_tokens)
+ else:
+ raise ValueError(f"Contents must be a string or a list of dictionaries. Received type: {type(contents)}")
+
+ def _truncate_multimodal_text(self, contents: List[Dict[str, Any]], n_tokens: int) -> List[Dict[str, Any]]:
+ """Truncates text content within a list of multimodal elements, preserving the overall structure."""
+ tmp_contents = []
+ for content in contents:
+ if content["type"] == "text":
+ truncated_text = self._truncate_tokens(content["text"], n_tokens)
+ tmp_contents.append({"type": "text", "text": truncated_text})
+ else:
+ tmp_contents.append(content)
+ return tmp_contents
+
+ def _truncate_tokens(self, text: str, n_tokens: int) -> str:
+ encoding = tiktoken.encoding_for_model(self._model) # Get the appropriate tokenizer
+
+ encoded_tokens = encoding.encode(text)
+ truncated_tokens = encoded_tokens[:n_tokens]
+ truncated_text = encoding.decode(truncated_tokens) # Decode back to text
+
+ return truncated_text
+
+ def _validate_max_tokens(self, max_tokens: Optional[int] = None) -> Optional[int]:
+ if max_tokens is not None and max_tokens < 0:
+ raise ValueError("max_tokens and max_tokens_per_message must be None or greater than or equal to 0")
+
+ try:
+ allowed_tokens = token_count_utils.get_max_token_limit(self._model)
+ except Exception:
+ print(colored(f"Model {self._model} not found in token_count_utils.", "yellow"))
+ allowed_tokens = None
+
+ if max_tokens is not None and allowed_tokens is not None:
+ if max_tokens > allowed_tokens:
+ print(
+ colored(
+ f"Max token was set to {max_tokens}, but {self._model} can only accept {allowed_tokens} tokens. Capping it to {allowed_tokens}.",
+ "yellow",
+ )
+ )
+ return allowed_tokens
+
+ return max_tokens if max_tokens is not None else sys.maxsize
+
+ def _validate_min_tokens(self, min_tokens: Optional[int], max_tokens: Optional[int]) -> int:
+ if min_tokens is None:
+ return 0
+ if min_tokens < 0:
+ raise ValueError("min_tokens must be None or greater than or equal to 0.")
+ if max_tokens is not None and min_tokens > max_tokens:
+ raise ValueError("min_tokens must not be more than max_tokens.")
+ return min_tokens
+
+
+class TextMessageCompressor:
+ """A transform for compressing text messages in a conversation history.
+
+ It uses a specified text compression method to reduce the token count of messages, which can lead to more efficient
+ processing and response generation by downstream models.
+ """
+
+ def __init__(
+ self,
+ text_compressor: Optional[TextCompressor] = None,
+ min_tokens: Optional[int] = None,
+ compression_params: Dict = dict(),
+ cache: Optional[AbstractCache] = Cache.disk(),
+ ):
+ """
+ Args:
+ text_compressor (TextCompressor or None): An instance of a class that implements the TextCompressor
+ protocol. If None, it defaults to LLMLingua.
+ min_tokens (int or None): Minimum number of tokens in messages to apply the transformation. Must be greater
+ than or equal to 0 if not None. If None, no threshold-based compression is applied.
+            compression_params (dict): A dictionary of arguments for the compression method. Defaults to an empty
+ dictionary.
+ cache (None or AbstractCache): The cache client to use to store and retrieve previously compressed messages.
+ If None, no caching will be used.
+ """
+
+ if text_compressor is None:
+ text_compressor = LLMLingua()
+
+ self._validate_min_tokens(min_tokens)
+
+ self._text_compressor = text_compressor
+ self._min_tokens = min_tokens
+ self._compression_args = compression_params
+ self._cache = cache
+
+        # Track the most recent token savings so that get_logs can report them without recomputing.
+ self._recent_tokens_savings = 0
+
+ def apply_transform(self, messages: List[Dict]) -> List[Dict]:
+ """Applies compression to messages in a conversation history based on the specified configuration.
+
+ The function processes each message according to the `compression_args` and `min_tokens` settings, applying
+ the specified compression configuration and returning a new list of messages with reduced token counts
+ where possible.
+
+ Args:
+ messages (List[Dict]): A list of message dictionaries to be compressed.
+
+ Returns:
+ List[Dict]: A list of dictionaries with the message content compressed according to the configured
+ method and scope.
+ """
+ # Make sure there is at least one message
+ if not messages:
+ return messages
+
+ # if the total number of tokens in the messages is less than the min_tokens, return the messages as is
+ if not _min_tokens_reached(messages, self._min_tokens):
+ return messages
+
+ total_savings = 0
+ processed_messages = messages.copy()
+ for message in processed_messages:
+ # Some messages may not have content.
+ if not isinstance(message.get("content"), (str, list)):
+ continue
+
+ if _is_content_text_empty(message["content"]):
+ continue
+
+ cached_content = self._cache_get(message["content"])
+ if cached_content is not None:
+ savings, compressed_content = cached_content
+ else:
+ savings, compressed_content = self._compress(message["content"])
+
+ self._cache_set(message["content"], compressed_content, savings)
+
+ message["content"] = compressed_content
+ total_savings += savings
+
+ self._recent_tokens_savings = total_savings
+ return processed_messages
+
+ def get_logs(self, pre_transform_messages: List[Dict], post_transform_messages: List[Dict]) -> Tuple[str, bool]:
+ if self._recent_tokens_savings > 0:
+ return f"{self._recent_tokens_savings} tokens saved with text compression.", True
+ else:
+ return "No tokens saved with text compression.", False
+
+ def _compress(self, content: Union[str, List[Dict]]) -> Tuple[int, Union[str, List[Dict]]]:
+ """Compresses the given text or multimodal content using the specified compression method."""
+ if isinstance(content, str):
+ return self._compress_text(content)
+ elif isinstance(content, list):
+ return self._compress_multimodal(content)
+ else:
+ return 0, content
+
+ def _compress_multimodal(self, content: List[Dict]) -> Tuple[int, List[Dict]]:
+ tokens_saved = 0
+ for msg in content:
+ if "text" in msg:
+ savings, msg["text"] = self._compress_text(msg["text"])
+ tokens_saved += savings
+ return tokens_saved, content
+
+ def _compress_text(self, text: str) -> Tuple[int, str]:
+ """Compresses the given text using the specified compression method."""
+ compressed_text = self._text_compressor.compress_text(text, **self._compression_args)
+
+ savings = 0
+ if "origin_tokens" in compressed_text and "compressed_tokens" in compressed_text:
+ savings = compressed_text["origin_tokens"] - compressed_text["compressed_tokens"]
+
+ return savings, compressed_text["compressed_prompt"]
+
+ def _cache_get(self, content: Union[str, List[Dict]]) -> Optional[Tuple[int, Union[str, List[Dict]]]]:
+ if self._cache:
+ cached_value = self._cache.get(self._cache_key(content))
+ if cached_value:
+ return cached_value
+
+ def _cache_set(
+ self, content: Union[str, List[Dict]], compressed_content: Union[str, List[Dict]], tokens_saved: int
+ ):
+ if self._cache:
+ value = (tokens_saved, json.dumps(compressed_content))
+ self._cache.set(self._cache_key(content), value)
+
+ def _cache_key(self, content: Union[str, List[Dict]]) -> str:
+ return f"{json.dumps(content)}_{self._min_tokens}"
+
+ def _validate_min_tokens(self, min_tokens: Optional[int]):
+ if min_tokens is not None and min_tokens <= 0:
+ raise ValueError("min_tokens must be greater than 0 or None")
+
+
+def _min_tokens_reached(messages: List[Dict], min_tokens: Optional[int]) -> bool:
+ """Returns True if the total number of tokens in the messages is greater than or equal to the specified value."""
+ if not min_tokens:
+ return True
+
+ messages_tokens = sum(_count_tokens(msg["content"]) for msg in messages if "content" in msg)
+ return messages_tokens >= min_tokens
+
+
+def _count_tokens(content: Union[str, List[Dict[str, Any]]]) -> int:
+ token_count = 0
+ if isinstance(content, str):
+ token_count = token_count_utils.count_token(content)
+ elif isinstance(content, list):
+ for item in content:
+ token_count += _count_tokens(item.get("text", ""))
+ return token_count
+
+
+def _is_content_text_empty(content: Union[str, List[Dict[str, Any]]]) -> bool:
+ if isinstance(content, str):
+ return content == ""
+ elif isinstance(content, list):
+ return all(_is_content_text_empty(item.get("text", "")) for item in content)
+ else:
+ return False
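
A hedged sketch applying the new transforms directly to a message list, outside of an agent; the message contents are placeholders.

```python
import copy

from autogen.agentchat.contrib.capabilities.transforms import MessageHistoryLimiter, MessageTokenLimiter

messages = [
    {"role": "user", "content": "Here is a very long design document ..."},
    {"role": "assistant", "content": "Thanks, I have read it."},
    {"role": "user", "content": "Now summarize the main risks for me."},
]

history_limiter = MessageHistoryLimiter(max_messages=2)
token_limiter = MessageTokenLimiter(max_tokens_per_message=10, model="gpt-3.5-turbo-0613")

# Keep only the two most recent messages, then truncate each remaining message to 10 tokens.
limited = history_limiter.apply_transform(copy.deepcopy(messages))
truncated = token_limiter.apply_transform(copy.deepcopy(limited))

print(history_limiter.get_logs(messages, limited)[0])
print(token_limiter.get_logs(limited, truncated)[0])
```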
diff --git a/autogen/agentchat/contrib/compressible_agent.py b/autogen/agentchat/contrib/compressible_agent.py
index 152cc871a56..9c4e78af852 100644
--- a/autogen/agentchat/contrib/compressible_agent.py
+++ b/autogen/agentchat/contrib/compressible_agent.py
@@ -1,20 +1,28 @@
-from typing import Callable, Dict, Optional, Union, Tuple, List, Any
-from autogen import OpenAIWrapper
-from autogen import Agent, ConversableAgent
-import copy
import asyncio
-import logging
+import copy
import inspect
+import logging
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+from warnings import warn
+
+from autogen import Agent, ConversableAgent, OpenAIWrapper
from autogen.token_count_utils import count_token, get_max_token_limit, num_tokens_from_functions
from ...formatting_utils import colored
logger = logging.getLogger(__name__)
+warn(
+ "Context handling with CompressibleAgent is deprecated. "
+ "Please use `TransformMessages`, documentation can be found at https://microsoft.github.io/autogen/docs/reference/agentchat/contrib/capabilities/transform_messages",
+ DeprecationWarning,
+ stacklevel=2,
+)
+
class CompressibleAgent(ConversableAgent):
- """(CompressibleAgent will be deprecated. Refer to https://github.com/microsoft/autogen/blob/main/notebook/agentchat_capability_long_context_handling.ipynb for long context handling capability.) CompressibleAgent agent. While this agent retains all the default functionalities of the `AssistantAgent`,
- it also provides the added feature of compression when activated through the `compress_config` setting.
+ """CompressibleAgent agent. While this agent retains all the default functionalities of the `AssistantAgent`,
+ it also provides the added feature of compression when activated through the `compress_config` setting.
`compress_config` is set to False by default, making this agent equivalent to the `AssistantAgent`.
This agent does not work well in a GroupChat: The compressed messages will not be sent to all the agents in the group.
diff --git a/autogen/agentchat/contrib/gpt_assistant_agent.py b/autogen/agentchat/contrib/gpt_assistant_agent.py
index 20acd2b08f8..0f5de8adcb5 100644
--- a/autogen/agentchat/contrib/gpt_assistant_agent.py
+++ b/autogen/agentchat/contrib/gpt_assistant_agent.py
@@ -1,16 +1,16 @@
-from collections import defaultdict
-import openai
+import copy
import json
-import time
import logging
-import copy
+import time
+from collections import defaultdict
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+import openai
from autogen import OpenAIWrapper
-from autogen.oai.openai_utils import retrieve_assistants_by_name
from autogen.agentchat.agent import Agent
-from autogen.agentchat.assistant_agent import ConversableAgent
-from autogen.agentchat.assistant_agent import AssistantAgent
-from typing import Dict, Optional, Union, List, Tuple, Any
+from autogen.agentchat.assistant_agent import AssistantAgent, ConversableAgent
+from autogen.oai.openai_utils import create_gpt_assistant, retrieve_assistants_by_name, update_gpt_assistant
logger = logging.getLogger(__name__)
@@ -50,7 +50,8 @@ def __init__(
- check_every_ms: check thread run status interval
- tools: Give Assistants access to OpenAI-hosted tools like Code Interpreter and Knowledge Retrieval,
or build your own tools using Function calling. ref https://platform.openai.com/docs/assistants/tools
- - file_ids: files used by retrieval in run
+                - file_ids: (Deprecated) files used by retrieval in the run. Use tool_resources instead; see https://platform.openai.com/docs/assistants/migration/what-has-changed.
+ - tool_resources: A set of resources that are used by the assistant's tools. The resources are specific to the type of tool.
overwrite_instructions (bool): whether to overwrite the instructions of an existing assistant. This parameter is in effect only when assistant_id is specified in llm_config.
overwrite_tools (bool): whether to overwrite the tools of an existing assistant. This parameter is in effect only when assistant_id is specified in llm_config.
kwargs (dict): Additional configuration options for the agent.
@@ -90,7 +91,6 @@ def __init__(
candidate_assistants,
instructions,
openai_assistant_cfg.get("tools", []),
- openai_assistant_cfg.get("file_ids", []),
)
if len(candidate_assistants) == 0:
@@ -101,12 +101,12 @@ def __init__(
"No instructions were provided for new assistant. Using default instructions from AssistantAgent.DEFAULT_SYSTEM_MESSAGE."
)
instructions = AssistantAgent.DEFAULT_SYSTEM_MESSAGE
- self._openai_assistant = self._openai_client.beta.assistants.create(
+ self._openai_assistant = create_gpt_assistant(
+ self._openai_client,
name=name,
instructions=instructions,
- tools=openai_assistant_cfg.get("tools", []),
model=model_name,
- file_ids=openai_assistant_cfg.get("file_ids", []),
+ assistant_config=openai_assistant_cfg,
)
else:
logger.warning(
@@ -127,9 +127,12 @@ def __init__(
logger.warning(
"overwrite_instructions is True. Provided instructions will be used and will modify the assistant in the API"
)
- self._openai_assistant = self._openai_client.beta.assistants.update(
+ self._openai_assistant = update_gpt_assistant(
+ self._openai_client,
assistant_id=openai_assistant_id,
- instructions=instructions,
+ assistant_config={
+ "instructions": instructions,
+ },
)
else:
logger.warning(
@@ -154,9 +157,13 @@ def __init__(
logger.warning(
"overwrite_tools is True. Provided tools will be used and will modify the assistant in the API"
)
- self._openai_assistant = self._openai_client.beta.assistants.update(
+ self._openai_assistant = update_gpt_assistant(
+ self._openai_client,
assistant_id=openai_assistant_id,
- tools=openai_assistant_cfg.get("tools", []),
+ assistant_config={
+ "tools": specified_tools,
+ "tool_resources": openai_assistant_cfg.get("tool_resources", None),
+ },
)
else:
# Tools are specified but overwrite_tools is False; do not update the assistant's tools
@@ -198,6 +205,8 @@ def _invoke_assistant(
assistant_thread = self._openai_threads[sender]
# Process each unread message
for message in pending_messages:
+ if message["content"].strip() == "":
+ continue
self._openai_client.beta.threads.messages.create(
thread_id=assistant_thread.id,
content=message["content"],
@@ -426,22 +435,23 @@ def delete_assistant(self):
logger.warning("Permanently deleting assistant...")
self._openai_client.beta.assistants.delete(self.assistant_id)
- def find_matching_assistant(self, candidate_assistants, instructions, tools, file_ids):
+ def find_matching_assistant(self, candidate_assistants, instructions, tools):
"""
Find the matching assistant from a list of candidate assistants.
- Filter out candidates with the same name but different instructions, file IDs, and function names.
- TODO: implement accurate match based on assistant metadata fields.
+ Filter out candidates with the same name but different instructions, and function names.
"""
matching_assistants = []
# Preprocess the required tools for faster comparison
- required_tool_types = set(tool.get("type") for tool in tools)
+ required_tool_types = set(
+ "file_search" if tool.get("type") in ["retrieval", "file_search"] else tool.get("type") for tool in tools
+ )
+
required_function_names = set(
tool.get("function", {}).get("name")
for tool in tools
- if tool.get("type") not in ["code_interpreter", "retrieval"]
+ if tool.get("type") not in ["code_interpreter", "retrieval", "file_search"]
)
- required_file_ids = set(file_ids) # Convert file_ids to a set for unordered comparison
for assistant in candidate_assistants:
# Check if instructions are similar
@@ -454,11 +464,12 @@ def find_matching_assistant(self, candidate_assistants, instructions, tools, fil
continue
# Preprocess the assistant's tools
- assistant_tool_types = set(tool.type for tool in assistant.tools)
+ assistant_tool_types = set(
+ "file_search" if tool.type in ["retrieval", "file_search"] else tool.type for tool in assistant.tools
+ )
assistant_function_names = set(tool.function.name for tool in assistant.tools if hasattr(tool, "function"))
- assistant_file_ids = set(getattr(assistant, "file_ids", [])) # Convert to set for comparison
- # Check if the tool types, function names, and file IDs match
+    # Check if the tool types and function names match
if required_tool_types != assistant_tool_types or required_function_names != assistant_function_names:
logger.warning(
"tools not match, skip assistant(%s): tools %s, functions %s",
@@ -467,9 +478,6 @@ def find_matching_assistant(self, candidate_assistants, instructions, tools, fil
assistant_function_names,
)
continue
- if required_file_ids != assistant_file_ids:
- logger.warning("file_ids not match, skip assistant(%s): %s", assistant.id, assistant_file_ids)
- continue
# Append assistant to matching list if all conditions are met
matching_assistants.append(assistant)
@@ -496,7 +504,7 @@ def _process_assistant_config(self, llm_config, assistant_config):
# Move the assistant related configurations to assistant_config
# It's important to keep forward compatibility
- assistant_config_items = ["assistant_id", "tools", "file_ids", "check_every_ms"]
+ assistant_config_items = ["assistant_id", "tools", "file_ids", "tool_resources", "check_every_ms"]
for item in assistant_config_items:
if openai_client_cfg.get(item) is not None and openai_assistant_cfg.get(item) is None:
openai_assistant_cfg[item] = openai_client_cfg[item]
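For readers following the `file_ids` removal above, the sketch below shows how an assistant might now be configured through `assistant_config` with the Assistants v2 `tool_resources` key. This is illustrative only: the `GPTAssistantAgent` import path and the exact `tool_resources` payload are assumptions based on the keys handled by `_process_assistant_config`, not definitions made by this patch.

```python
# Hypothetical usage sketch (not part of this patch).
import os

from autogen.agentchat.contrib.gpt_assistant_agent import GPTAssistantAgent  # assumed path

assistant = GPTAssistantAgent(
    name="file-aware-assistant",
    llm_config={"config_list": [{"model": "gpt-4-turbo", "api_key": os.environ["OPENAI_API_KEY"]}]},
    assistant_config={
        # "retrieval" is normalized to "file_search" by find_matching_assistant above
        "tools": [{"type": "file_search"}],
        # assumed Assistants v2 shape; replaces the removed `file_ids` key
        "tool_resources": {"file_search": {"vector_store_ids": ["vs_example"]}},
    },
)
```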
diff --git a/autogen/agentchat/contrib/llava_agent.py b/autogen/agentchat/contrib/llava_agent.py
index 182f72837b7..063b256d3cd 100644
--- a/autogen/agentchat/contrib/llava_agent.py
+++ b/autogen/agentchat/contrib/llava_agent.py
@@ -1,6 +1,7 @@
import json
import logging
from typing import List, Optional, Tuple
+
import replicate
import requests
@@ -8,8 +9,8 @@
from autogen.agentchat.contrib.img_utils import get_image_data, llava_formatter
from autogen.agentchat.contrib.multimodal_conversable_agent import MultimodalConversableAgent
from autogen.code_utils import content_str
-from ...formatting_utils import colored
+from ...formatting_utils import colored
logger = logging.getLogger(__name__)
diff --git a/autogen/agentchat/contrib/math_user_proxy_agent.py b/autogen/agentchat/contrib/math_user_proxy_agent.py
index 70f365ef9fe..d2b6b7cde00 100644
--- a/autogen/agentchat/contrib/math_user_proxy_agent.py
+++ b/autogen/agentchat/contrib/math_user_proxy_agent.py
@@ -1,15 +1,15 @@
-import re
import os
-from pydantic import BaseModel, Extra, root_validator
-from typing import Any, Callable, Dict, List, Optional, Union, Tuple
+import re
from time import sleep
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+
+from pydantic import BaseModel, Extra, root_validator
from autogen._pydantic import PYDANTIC_V1
from autogen.agentchat import Agent, UserProxyAgent
-from autogen.code_utils import UNKNOWN, extract_code, execute_code, infer_lang
+from autogen.code_utils import UNKNOWN, execute_code, extract_code, infer_lang
from autogen.math_utils import get_answer
-
PROMPTS = {
# default
"default": """Let's use Python to solve a math problem.
diff --git a/autogen/agentchat/contrib/multimodal_conversable_agent.py b/autogen/agentchat/contrib/multimodal_conversable_agent.py
index 2a016bcffba..edeb88cd531 100644
--- a/autogen/agentchat/contrib/multimodal_conversable_agent.py
+++ b/autogen/agentchat/contrib/multimodal_conversable_agent.py
@@ -11,7 +11,6 @@
from ..._pydantic import model_dump
-
DEFAULT_LMM_SYS_MSG = """You are a helpful AI assistant."""
DEFAULT_MODEL = "gpt-4-vision-preview"
diff --git a/autogen/agentchat/contrib/qdrant_retrieve_user_proxy_agent.py b/autogen/agentchat/contrib/qdrant_retrieve_user_proxy_agent.py
index c539c716ab8..1ece138963f 100644
--- a/autogen/agentchat/contrib/qdrant_retrieve_user_proxy_agent.py
+++ b/autogen/agentchat/contrib/qdrant_retrieve_user_proxy_agent.py
@@ -1,17 +1,21 @@
from typing import Callable, Dict, List, Optional
from autogen.agentchat.contrib.retrieve_user_proxy_agent import RetrieveUserProxyAgent
-from autogen.retrieve_utils import get_files_from_dir, split_files_to_chunks, TEXT_FORMATS
-import logging
+from autogen.agentchat.contrib.vectordb.utils import (
+ chroma_results_to_query_results,
+ filter_results_by_distance,
+ get_logger,
+)
+from autogen.retrieve_utils import TEXT_FORMATS, get_files_from_dir, split_files_to_chunks
-logger = logging.getLogger(__name__)
+logger = get_logger(__name__)
try:
+ import fastembed
from qdrant_client import QdrantClient, models
from qdrant_client.fastembed_common import QueryResponse
- import fastembed
except ImportError as e:
- logging.fatal("Failed to import qdrant_client with fastembed. Try running 'pip install qdrant_client[fastembed]'")
+ logger.fatal("Failed to import qdrant_client with fastembed. Try running 'pip install qdrant_client[fastembed]'")
raise e
@@ -136,6 +140,11 @@ def retrieve_docs(self, problem: str, n_results: int = 20, search_string: str =
collection_name=self._collection_name,
embedding_model=self._embedding_model,
)
+ results["contents"] = results.pop("documents")
+ results = chroma_results_to_query_results(results, "distances")
+ results = filter_results_by_distance(results, self._distance_threshold)
+
+ self._search_string = search_string
self._results = results
@@ -190,12 +199,12 @@ def create_qdrant_from_dir(
client.set_model(embedding_model)
if custom_text_split_function is not None:
- chunks = split_files_to_chunks(
+ chunks, sources = split_files_to_chunks(
get_files_from_dir(dir_path, custom_text_types, recursive),
custom_text_split_function=custom_text_split_function,
)
else:
- chunks = split_files_to_chunks(
+ chunks, sources = split_files_to_chunks(
get_files_from_dir(dir_path, custom_text_types, recursive), max_tokens, chunk_mode, must_break_at_empty_line
)
logger.info(f"Found {len(chunks)} chunks.")
@@ -298,5 +307,7 @@ class QueryResponse(BaseModel, extra="forbid"): # type: ignore
data = {
"ids": [[result.id for result in sublist] for sublist in results],
"documents": [[result.document for result in sublist] for sublist in results],
+ "distances": [[result.score for result in sublist] for sublist in results],
+ "metadatas": [[result.metadata for result in sublist] for sublist in results],
}
return data
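To make the new result plumbing concrete, here is a small self-contained sketch of the shape that `chroma_results_to_query_results` and `filter_results_by_distance` are applied to above. The helper below is a stand-in written for illustration; it is not the implementation in `vectordb/utils.py`.

```python
from typing import Any, Dict, List, Tuple

# Chroma-style batched results, one inner list per query, as built by query_qdrant above.
chroma_style: Dict[str, List[List[Any]]] = {
    "ids": [["doc-1", "doc-2"]],
    "contents": [["alpha text", "beta text"]],
    "distances": [[0.12, 0.57]],
    "metadatas": [[{"source": "a.md"}, {"source": "b.md"}]],
}


def to_query_results(res: Dict[str, List[List[Any]]], threshold: float = -1) -> List[List[Tuple[dict, float]]]:
    """Illustrative stand-in: pair each document with its distance, optionally filtering by threshold."""
    out = []
    for ids, contents, dists, metas in zip(res["ids"], res["contents"], res["distances"], res["metadatas"]):
        out.append(
            [
                ({"id": i, "content": c, "metadata": m}, d)
                for i, c, d, m in zip(ids, contents, dists, metas)
                if threshold < 0 or d < threshold
            ]
        )
    return out


print(to_query_results(chroma_style, threshold=0.5))
# -> [[({'id': 'doc-1', 'content': 'alpha text', 'metadata': {'source': 'a.md'}}, 0.12)]]
```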
diff --git a/autogen/agentchat/contrib/retrieve_assistant_agent.py b/autogen/agentchat/contrib/retrieve_assistant_agent.py
index a09677710aa..9b5ace200dc 100644
--- a/autogen/agentchat/contrib/retrieve_assistant_agent.py
+++ b/autogen/agentchat/contrib/retrieve_assistant_agent.py
@@ -1,6 +1,7 @@
+from typing import Any, Dict, List, Optional, Tuple, Union
+
from autogen.agentchat.agent import Agent
from autogen.agentchat.assistant_agent import AssistantAgent
-from typing import Dict, Optional, Union, List, Tuple, Any
class RetrieveAssistantAgent(AssistantAgent):
diff --git a/autogen/agentchat/contrib/retrieve_user_proxy_agent.py b/autogen/agentchat/contrib/retrieve_user_proxy_agent.py
index f252f60e5ec..476c7c0739d 100644
--- a/autogen/agentchat/contrib/retrieve_user_proxy_agent.py
+++ b/autogen/agentchat/contrib/retrieve_user_proxy_agent.py
@@ -1,19 +1,35 @@
+import hashlib
+import os
import re
-from typing import Callable, Dict, Optional, Union, List, Tuple, Any
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+
from IPython import get_ipython
try:
import chromadb
except ImportError:
raise ImportError("Please install dependencies first. `pip install pyautogen[retrievechat]`")
-from autogen.agentchat.agent import Agent
from autogen.agentchat import UserProxyAgent
-from autogen.retrieve_utils import create_vector_db_from_dir, query_vector_db, TEXT_FORMATS
-from autogen.token_count_utils import count_token
+from autogen.agentchat.agent import Agent
+from autogen.agentchat.contrib.vectordb.base import Document, QueryResults, VectorDB, VectorDBFactory
+from autogen.agentchat.contrib.vectordb.utils import (
+ chroma_results_to_query_results,
+ filter_results_by_distance,
+ get_logger,
+)
from autogen.code_utils import extract_code
-from autogen import logger
+from autogen.retrieve_utils import (
+ TEXT_FORMATS,
+ create_vector_db_from_dir,
+ get_files_from_dir,
+ query_vector_db,
+ split_files_to_chunks,
+)
+from autogen.token_count_utils import count_token
+
from ...formatting_utils import colored
+logger = get_logger(__name__)
PROMPT_DEFAULT = """You're a retrieve augmented chatbot. You answer user's questions based on your own knowledge and the
context provided by the user. You should follow the following steps to answer a question:
@@ -33,6 +49,10 @@
User's question is: {input_question}
Context is: {input_context}
+
+The source of the context is: {input_sources}
+
+If you can answer the question, add the source of the context at the end of your answer in the format of `Sources: source1, source2, ...`.
"""
PROMPT_CODE = """You're a retrieve augmented coding assistant. You answer user's questions based on your own knowledge and the
@@ -60,8 +80,14 @@
Context is: {input_context}
"""
+HASH_LENGTH = int(os.environ.get("HASH_LENGTH", 8))
+
class RetrieveUserProxyAgent(UserProxyAgent):
+ """(In preview) The Retrieval-Augmented User Proxy retrieves document chunks based on the embedding
+    similarity, and sends them along with the question to the Retrieval-Augmented Assistant.
+ """
+
def __init__(
self,
name="RetrieveChatAgent", # default set to RetrieveChatAgent
@@ -73,67 +99,126 @@ def __init__(
r"""
Args:
name (str): name of the agent.
+
human_input_mode (str): whether to ask for human inputs every time a message is received.
Possible values are "ALWAYS", "TERMINATE", "NEVER".
1. When "ALWAYS", the agent prompts for human input every time a message is received.
Under this mode, the conversation stops when the human input is "exit",
or when is_termination_msg is True and there is no human input.
- 2. When "TERMINATE", the agent only prompts for human input only when a termination message is received or
- the number of auto reply reaches the max_consecutive_auto_reply.
- 3. When "NEVER", the agent will never prompt for human input. Under this mode, the conversation stops
- when the number of auto reply reaches the max_consecutive_auto_reply or when is_termination_msg is True.
+ 2. When "TERMINATE", the agent only prompts for human input only when a termination
+ message is received or the number of auto reply reaches
+ the max_consecutive_auto_reply.
+ 3. When "NEVER", the agent will never prompt for human input. Under this mode, the
+ conversation stops when the number of auto reply reaches the
+ max_consecutive_auto_reply or when is_termination_msg is True.
+
is_termination_msg (function): a function that takes a message in the form of a dictionary
and returns a boolean value indicating if this received message is a termination message.
The dict can contain the following keys: "content", "role", "name", "function_call".
+
retrieve_config (dict or None): config for the retrieve agent.
- To use default config, set to None. Otherwise, set to a dictionary with the following keys:
- - task (Optional, str): the task of the retrieve chat. Possible values are "code", "qa" and "default". System
- prompt will be different for different tasks. The default value is `default`, which supports both code and qa.
- - client (Optional, chromadb.Client): the chromadb client. If key not provided, a default client `chromadb.Client()`
- will be used. If you want to use other vector db, extend this class and override the `retrieve_docs` function.
- - docs_path (Optional, Union[str, List[str]]): the path to the docs directory. It can also be the path to a single file,
- the url to a single file or a list of directories, files and urls. Default is None, which works only if the collection is already created.
- - extra_docs (Optional, bool): when true, allows adding documents with unique IDs without overwriting existing ones; when false, it replaces existing documents using default IDs, risking collection overwrite.,
- when set to true it enables the system to assign unique IDs starting from "length+i" for new document chunks, preventing the replacement of existing documents and facilitating the addition of more content to the collection..
- By default, "extra_docs" is set to false, starting document IDs from zero. This poses a risk as new documents might overwrite existing ones, potentially causing unintended loss or alteration of data in the collection.
- - collection_name (Optional, str): the name of the collection.
- If key not provided, a default name `autogen-docs` will be used.
- - model (Optional, str): the model to use for the retrieve chat.
+
+ To use default config, set to None. Otherwise, set to a dictionary with the
+ following keys:
+ - `task` (Optional, str) - the task of the retrieve chat. Possible values are
+ "code", "qa" and "default". System prompt will be different for different tasks.
+                The default value is `default`, which supports both code and qa, and provides
+                source information at the end of the response.
+ - `vector_db` (Optional, Union[str, VectorDB]) - the vector db for the retrieve chat.
+ If it's a string, it should be the type of the vector db, such as "chroma"; otherwise,
+ it should be an instance of the VectorDB protocol. Default is "chroma".
+ Set `None` to use the deprecated `client`.
+ - `db_config` (Optional, Dict) - the config for the vector db. Default is `{}`. Please make
+ sure you understand the config for the vector db you are using, otherwise, leave it as `{}`.
+ Only valid when `vector_db` is a string.
+ - `client` (Optional, chromadb.Client) - the chromadb client. If key not provided, a
+ default client `chromadb.Client()` will be used. If you want to use other
+ vector db, extend this class and override the `retrieve_docs` function.
+ **Deprecated**: use `vector_db` instead.
+ - `docs_path` (Optional, Union[str, List[str]]) - the path to the docs directory. It
+ can also be the path to a single file, the url to a single file or a list
+ of directories, files and urls. Default is None, which works only if the
+ collection is already created.
+            - `extra_docs` (Optional, bool) - when True, allows adding documents with unique IDs
+                without overwriting existing ones; when False, it replaces existing documents
+                using default IDs, risking collection overwrite. When set to True, the system
+                assigns unique IDs starting from "length+i" for new document chunks,
+                preventing the replacement of existing documents and facilitating the
+                addition of more content to the collection.
+                By default, "extra_docs" is set to False, starting document IDs from zero.
+                This poses a risk as new documents might overwrite existing ones, potentially
+                causing unintended loss or alteration of data in the collection.
+                **Deprecated**: use `new_docs` when using `vector_db` instead of `client`.
+ - `new_docs` (Optional, bool) - when True, only adds new documents to the collection;
+ when False, updates existing documents and adds new ones. Default is True.
+ Document id is used to determine if a document is new or existing. By default, the
+ id is the hash value of the content.
+ - `model` (Optional, str) - the model to use for the retrieve chat.
If key not provided, a default model `gpt-4` will be used.
- - chunk_token_size (Optional, int): the chunk token size for the retrieve chat.
+ - `chunk_token_size` (Optional, int) - the chunk token size for the retrieve chat.
If key not provided, a default size `max_tokens * 0.4` will be used.
- - context_max_tokens (Optional, int): the context max token size for the retrieve chat.
+ - `context_max_tokens` (Optional, int) - the context max token size for the
+ retrieve chat.
If key not provided, a default size `max_tokens * 0.8` will be used.
- - chunk_mode (Optional, str): the chunk mode for the retrieve chat. Possible values are
- "multi_lines" and "one_line". If key not provided, a default mode `multi_lines` will be used.
- - must_break_at_empty_line (Optional, bool): chunk will only break at empty line if True. Default is True.
+ - `chunk_mode` (Optional, str) - the chunk mode for the retrieve chat. Possible values
+ are "multi_lines" and "one_line". If key not provided, a default mode
+ `multi_lines` will be used.
+ - `must_break_at_empty_line` (Optional, bool) - chunk will only break at empty line
+ if True. Default is True.
If chunk_mode is "one_line", this parameter will be ignored.
- - embedding_model (Optional, str): the embedding model to use for the retrieve chat.
- If key not provided, a default model `all-MiniLM-L6-v2` will be used. All available models
- can be found at `https://www.sbert.net/docs/pretrained_models.html`. The default model is a
- fast model. If you want to use a high performance model, `all-mpnet-base-v2` is recommended.
- - embedding_function (Optional, Callable): the embedding function for creating the vector db. Default is None,
- SentenceTransformer with the given `embedding_model` will be used. If you want to use OpenAI, Cohere, HuggingFace or
- other embedding functions, you can pass it here, follow the examples in `https://docs.trychroma.com/embeddings`.
- - customized_prompt (Optional, str): the customized prompt for the retrieve chat. Default is None.
- - customized_answer_prefix (Optional, str): the customized answer prefix for the retrieve chat. Default is "".
- If not "" and the customized_answer_prefix is not in the answer, `Update Context` will be triggered.
- - update_context (Optional, bool): if False, will not apply `Update Context` for interactive retrieval. Default is True.
- - get_or_create (Optional, bool): if True, will create/return a collection for the retrieve chat. This is the same as that used in chromadb.
- Default is False. Will raise ValueError if the collection already exists and get_or_create is False. Will be set to True if docs_path is None.
- - custom_token_count_function (Optional, Callable): a custom function to count the number of tokens in a string.
- The function should take (text:str, model:str) as input and return the token_count(int). the retrieve_config["model"] will be passed in the function.
- Default is autogen.token_count_utils.count_token that uses tiktoken, which may not be accurate for non-OpenAI models.
- - custom_text_split_function (Optional, Callable): a custom function to split a string into a list of strings.
- Default is None, will use the default function in `autogen.retrieve_utils.split_text_to_chunks`.
- - custom_text_types (Optional, List[str]): a list of file types to be processed. Default is `autogen.retrieve_utils.TEXT_FORMATS`.
- This only applies to files under the directories in `docs_path`. Explicitly included files and urls will be chunked regardless of their types.
- - recursive (Optional, bool): whether to search documents recursively in the docs_path. Default is True.
+ - `embedding_model` (Optional, str) - the embedding model to use for the retrieve chat.
+ If key not provided, a default model `all-MiniLM-L6-v2` will be used. All available
+ models can be found at `https://www.sbert.net/docs/pretrained_models.html`.
+ The default model is a fast model. If you want to use a high performance model,
+ `all-mpnet-base-v2` is recommended.
+                **Deprecated**: not needed when using `vector_db` instead of `client`.
+ - `embedding_function` (Optional, Callable) - the embedding function for creating the
+ vector db. Default is None, SentenceTransformer with the given `embedding_model`
+ will be used. If you want to use OpenAI, Cohere, HuggingFace or other embedding
+                functions, you can pass it here;
+                follow the examples at `https://docs.trychroma.com/embeddings`.
+ - `customized_prompt` (Optional, str) - the customized prompt for the retrieve chat.
+ Default is None.
+ - `customized_answer_prefix` (Optional, str) - the customized answer prefix for the
+ retrieve chat. Default is "".
+ If not "" and the customized_answer_prefix is not in the answer,
+ `Update Context` will be triggered.
+ - `update_context` (Optional, bool) - if False, will not apply `Update Context` for
+ interactive retrieval. Default is True.
+ - `collection_name` (Optional, str) - the name of the collection.
+ If key not provided, a default name `autogen-docs` will be used.
+ - `get_or_create` (Optional, bool) - Whether to get the collection if it exists. Default is True.
+ - `overwrite` (Optional, bool) - Whether to overwrite the collection if it exists. Default is False.
+ Case 1. if the collection does not exist, create the collection.
+ Case 2. the collection exists, if overwrite is True, it will overwrite the collection.
+ Case 3. the collection exists and overwrite is False, if get_or_create is True, it will get the collection,
+                    otherwise it raises a ValueError.
+ - `custom_token_count_function` (Optional, Callable) - a custom function to count the
+ number of tokens in a string.
+ The function should take (text:str, model:str) as input and return the
+                token_count(int). The retrieve_config["model"] will be passed to the function.
+ Default is autogen.token_count_utils.count_token that uses tiktoken, which may
+ not be accurate for non-OpenAI models.
+ - `custom_text_split_function` (Optional, Callable) - a custom function to split a
+ string into a list of strings.
+ Default is None, will use the default function in
+ `autogen.retrieve_utils.split_text_to_chunks`.
+ - `custom_text_types` (Optional, List[str]) - a list of file types to be processed.
+ Default is `autogen.retrieve_utils.TEXT_FORMATS`.
+ This only applies to files under the directories in `docs_path`. Explicitly
+ included files and urls will be chunked regardless of their types.
+ - `recursive` (Optional, bool) - whether to search documents recursively in the
+ docs_path. Default is True.
+            - `distance_threshold` (Optional, float) - the threshold for the distance score; only
+                docs with a distance smaller than it will be returned. Will be ignored if < 0. Default is -1.
+
`**kwargs` (dict): other kwargs in [UserProxyAgent](../user_proxy_agent#__init__).
Example:
- Example of overriding retrieve_docs - If you have set up a customized vector db, and it's not compatible with chromadb, you can easily plug in it with below code.
+ Example of overriding retrieve_docs - If you have set up a customized vector db, and it's
+                not compatible with chromadb, you can easily plug it in with the code below.
+ **Deprecated**: Use `vector_db` instead. You can extend VectorDB and pass it to the agent.
```python
class MyRetrieveUserProxyAgent(RetrieveUserProxyAgent):
def query_vector_db(
@@ -166,9 +251,12 @@ def retrieve_docs(self, problem: str, n_results: int = 20, search_string: str =
self._retrieve_config = {} if retrieve_config is None else retrieve_config
self._task = self._retrieve_config.get("task", "default")
+ self._vector_db = self._retrieve_config.get("vector_db", "chroma")
+ self._db_config = self._retrieve_config.get("db_config", {})
self._client = self._retrieve_config.get("client", chromadb.Client())
self._docs_path = self._retrieve_config.get("docs_path", None)
self._extra_docs = self._retrieve_config.get("extra_docs", False)
+ self._new_docs = self._retrieve_config.get("new_docs", True)
self._collection_name = self._retrieve_config.get("collection_name", "autogen-docs")
if "docs_path" not in self._retrieve_config:
logger.warning(
@@ -187,25 +275,104 @@ def retrieve_docs(self, problem: str, n_results: int = 20, search_string: str =
self.customized_answer_prefix = self._retrieve_config.get("customized_answer_prefix", "").upper()
self.update_context = self._retrieve_config.get("update_context", True)
self._get_or_create = self._retrieve_config.get("get_or_create", False) if self._docs_path is not None else True
+ self._overwrite = self._retrieve_config.get("overwrite", False)
self.custom_token_count_function = self._retrieve_config.get("custom_token_count_function", count_token)
self.custom_text_split_function = self._retrieve_config.get("custom_text_split_function", None)
self._custom_text_types = self._retrieve_config.get("custom_text_types", TEXT_FORMATS)
self._recursive = self._retrieve_config.get("recursive", True)
- self._context_max_tokens = self._max_tokens * 0.8
+ self._context_max_tokens = self._retrieve_config.get("context_max_tokens", self._max_tokens * 0.8)
self._collection = True if self._docs_path is None else False # whether the collection is created
self._ipython = get_ipython()
self._doc_idx = -1 # the index of the current used doc
- self._results = {} # the results of the current query
+ self._results = [] # the results of the current query
self._intermediate_answers = set() # the intermediate answers
self._doc_contents = [] # the contents of the current used doc
self._doc_ids = [] # the ids of the current used doc
+        self._current_docs_in_context = []  # the sources of the docs in the current context
self._search_string = "" # the search string used in the current query
+ self._distance_threshold = self._retrieve_config.get("distance_threshold", -1)
# update the termination message function
self._is_termination_msg = (
self._is_termination_msg_retrievechat if is_termination_msg is None else is_termination_msg
)
+ if isinstance(self._vector_db, str):
+ if not isinstance(self._db_config, dict):
+ raise ValueError("`db_config` should be a dictionary.")
+ if "embedding_function" in self._retrieve_config:
+ self._db_config["embedding_function"] = self._embedding_function
+ self._vector_db = VectorDBFactory.create_vector_db(db_type=self._vector_db, **self._db_config)
self.register_reply(Agent, RetrieveUserProxyAgent._generate_retrieve_user_reply, position=2)
+ def _init_db(self):
+ if not self._vector_db:
+ return
+
+ IS_TO_CHUNK = False # whether to chunk the raw files
+ if self._new_docs:
+ IS_TO_CHUNK = True
+ if not self._docs_path:
+ try:
+ self._vector_db.get_collection(self._collection_name)
+ logger.warning(f"`docs_path` is not provided. Use the existing collection `{self._collection_name}`.")
+ self._overwrite = False
+ self._get_or_create = True
+ IS_TO_CHUNK = False
+ except ValueError:
+ raise ValueError(
+ "`docs_path` is not provided. "
+ f"The collection `{self._collection_name}` doesn't exist either. "
+ "Please provide `docs_path` or create the collection first."
+ )
+ elif self._get_or_create and not self._overwrite:
+ try:
+ self._vector_db.get_collection(self._collection_name)
+ logger.info(f"Use the existing collection `{self._collection_name}`.", color="green")
+ except ValueError:
+ IS_TO_CHUNK = True
+ else:
+ IS_TO_CHUNK = True
+
+ self._vector_db.active_collection = self._vector_db.create_collection(
+ self._collection_name, overwrite=self._overwrite, get_or_create=self._get_or_create
+ )
+
+ docs = None
+ if IS_TO_CHUNK:
+ if self.custom_text_split_function is not None:
+ chunks, sources = split_files_to_chunks(
+ get_files_from_dir(self._docs_path, self._custom_text_types, self._recursive),
+ custom_text_split_function=self.custom_text_split_function,
+ )
+ else:
+ chunks, sources = split_files_to_chunks(
+ get_files_from_dir(self._docs_path, self._custom_text_types, self._recursive),
+ self._max_tokens,
+ self._chunk_mode,
+ self._must_break_at_empty_line,
+ )
+ logger.info(f"Found {len(chunks)} chunks.")
+
+ if self._new_docs:
+ all_docs_ids = set(
+ [
+ doc["id"]
+ for doc in self._vector_db.get_docs_by_ids(ids=None, collection_name=self._collection_name)
+ ]
+ )
+ else:
+ all_docs_ids = set()
+
+ chunk_ids = [hashlib.blake2b(chunk.encode("utf-8")).hexdigest()[:HASH_LENGTH] for chunk in chunks]
+ chunk_ids_set = set(chunk_ids)
+ chunk_ids_set_idx = [chunk_ids.index(hash_value) for hash_value in chunk_ids_set]
+ docs = [
+ Document(id=chunk_ids[idx], content=chunks[idx], metadata=sources[idx])
+ for idx in chunk_ids_set_idx
+ if chunk_ids[idx] not in all_docs_ids
+ ]
+
+ self._vector_db.insert_docs(docs=docs, collection_name=self._collection_name, upsert=True)
+
def _is_termination_msg_retrievechat(self, message):
"""Check if a message is a termination message.
For code generation, terminate when no code block is detected. Currently only detect python code blocks.
@@ -238,37 +405,42 @@ def get_max_tokens(model="gpt-3.5-turbo"):
def _reset(self, intermediate=False):
self._doc_idx = -1 # the index of the current used doc
- self._results = {} # the results of the current query
+ self._results = [] # the results of the current query
if not intermediate:
self._intermediate_answers = set() # the intermediate answers
self._doc_contents = [] # the contents of the current used doc
self._doc_ids = [] # the ids of the current used doc
- def _get_context(self, results: Dict[str, Union[List[str], List[List[str]]]]):
+ def _get_context(self, results: QueryResults):
doc_contents = ""
+ self._current_docs_in_context = []
current_tokens = 0
_doc_idx = self._doc_idx
_tmp_retrieve_count = 0
- for idx, doc in enumerate(results["documents"][0]):
+ for idx, doc in enumerate(results[0]):
+ doc = doc[0]
if idx <= _doc_idx:
continue
- if results["ids"][0][idx] in self._doc_ids:
+ if doc["id"] in self._doc_ids:
continue
- _doc_tokens = self.custom_token_count_function(doc, self._model)
+ _doc_tokens = self.custom_token_count_function(doc["content"], self._model)
if _doc_tokens > self._context_max_tokens:
- func_print = f"Skip doc_id {results['ids'][0][idx]} as it is too long to fit in the context."
+ func_print = f"Skip doc_id {doc['id']} as it is too long to fit in the context."
print(colored(func_print, "green"), flush=True)
self._doc_idx = idx
continue
if current_tokens + _doc_tokens > self._context_max_tokens:
break
- func_print = f"Adding doc_id {results['ids'][0][idx]} to context."
+ func_print = f"Adding content of doc {doc['id']} to context."
print(colored(func_print, "green"), flush=True)
current_tokens += _doc_tokens
- doc_contents += doc + "\n"
+ doc_contents += doc["content"] + "\n"
+ _metadata = doc.get("metadata")
+ if isinstance(_metadata, dict):
+ self._current_docs_in_context.append(_metadata.get("source", ""))
self._doc_idx = idx
- self._doc_ids.append(results["ids"][0][idx])
- self._doc_contents.append(doc)
+ self._doc_ids.append(doc["id"])
+ self._doc_contents.append(doc["content"])
_tmp_retrieve_count += 1
if _tmp_retrieve_count >= self.n_results:
break
@@ -285,7 +457,9 @@ def _generate_message(self, doc_contents, task="default"):
elif task.upper() == "QA":
message = PROMPT_QA.format(input_question=self.problem, input_context=doc_contents)
elif task.upper() == "DEFAULT":
- message = PROMPT_DEFAULT.format(input_question=self.problem, input_context=doc_contents)
+ message = PROMPT_DEFAULT.format(
+ input_question=self.problem, input_context=doc_contents, input_sources=self._current_docs_in_context
+ )
else:
raise NotImplementedError(f"task {task} is not implemented.")
return message
@@ -360,21 +534,40 @@ def _generate_retrieve_user_reply(
def retrieve_docs(self, problem: str, n_results: int = 20, search_string: str = ""):
"""Retrieve docs based on the given problem and assign the results to the class property `_results`.
- In case you want to customize the retrieval process, such as using a different vector db whose APIs are not
- compatible with chromadb or filter results with metadata, you can override this function. Just keep the current
- parameters and add your own parameters with default values, and keep the results in below type.
-
- Type of the results: Dict[str, List[List[Any]]], should have keys "ids" and "documents", "ids" for the ids of
- the retrieved docs and "documents" for the contents of the retrieved docs. Any other keys are optional. Refer
- to `chromadb.api.types.QueryResult` as an example.
- ids: List[string]
- documents: List[List[string]]
+        The retrieved docs are of type `QueryResults`, which is a list of query results; each query result
+        is a list of tuples containing the document and the distance.
Args:
problem (str): the problem to be solved.
n_results (int): the number of results to be retrieved. Default is 20.
search_string (str): only docs that contain an exact match of this string will be retrieved. Default is "".
+ Not used if the vector_db doesn't support it.
+
+ Returns:
+ None.
"""
+ if isinstance(self._vector_db, VectorDB):
+ if not self._collection or not self._get_or_create:
+ print("Trying to create collection.")
+ self._init_db()
+ self._collection = True
+ self._get_or_create = True
+
+ kwargs = {}
+ if hasattr(self._vector_db, "type") and self._vector_db.type == "chroma":
+ kwargs["where_document"] = {"$contains": search_string} if search_string else None
+ results = self._vector_db.retrieve_docs(
+ queries=[problem],
+ n_results=n_results,
+ collection_name=self._collection_name,
+ distance_threshold=self._distance_threshold,
+ **kwargs,
+ )
+ self._search_string = search_string
+ self._results = results
+ print("VectorDB returns doc_ids: ", [[r[0]["id"] for r in rr] for rr in results])
+ return
+
if not self._collection or not self._get_or_create:
print("Trying to create collection.")
self._client = create_vector_db_from_dir(
@@ -404,9 +597,13 @@ def retrieve_docs(self, problem: str, n_results: int = 20, search_string: str =
embedding_model=self._embedding_model,
embedding_function=self._embedding_function,
)
+ results["contents"] = results.pop("documents")
+ results = chroma_results_to_query_results(results, "distances")
+ results = filter_results_by_distance(results, self._distance_threshold)
+
self._search_string = search_string
self._results = results
- print("doc_ids: ", results["ids"])
+ print("doc_ids: ", [[r[0]["id"] for r in rr] for rr in results])
@staticmethod
def message_generator(sender, recipient, context):
@@ -416,9 +613,9 @@ def message_generator(sender, recipient, context):
sender (Agent): the sender agent. It should be the instance of RetrieveUserProxyAgent.
recipient (Agent): the recipient agent. Usually it's the assistant agent.
context (dict): the context for the message generation. It should contain the following keys:
- - problem (str): the problem to be solved.
- - n_results (int): the number of results to be retrieved. Default is 20.
- - search_string (str): only docs that contain an exact match of this string will be retrieved. Default is "".
+ - `problem` (str) - the problem to be solved.
+ - `n_results` (int) - the number of results to be retrieved. Default is 20.
+ - `search_string` (str) - only docs that contain an exact match of this string will be retrieved. Default is "".
Returns:
str: the generated message ready to be sent to the recipient agent.
"""
diff --git a/autogen/agentchat/contrib/society_of_mind_agent.py b/autogen/agentchat/contrib/society_of_mind_agent.py
index 6a6f4aa2186..97cf6aee1a5 100644
--- a/autogen/agentchat/contrib/society_of_mind_agent.py
+++ b/autogen/agentchat/contrib/society_of_mind_agent.py
@@ -1,10 +1,11 @@
# ruff: noqa: E722
+import copy
import json
import traceback
-import copy
from dataclasses import dataclass
-from typing import Dict, List, Optional, Union, Callable, Literal, Tuple
-from autogen import Agent, ConversableAgent, GroupChatManager, GroupChat, OpenAIWrapper
+from typing import Callable, Dict, List, Literal, Optional, Tuple, Union
+
+from autogen import Agent, ConversableAgent, GroupChat, GroupChatManager, OpenAIWrapper
class SocietyOfMindAgent(ConversableAgent):
diff --git a/autogen/agentchat/contrib/text_analyzer_agent.py b/autogen/agentchat/contrib/text_analyzer_agent.py
index 10100c9e57f..e917cca574f 100644
--- a/autogen/agentchat/contrib/text_analyzer_agent.py
+++ b/autogen/agentchat/contrib/text_analyzer_agent.py
@@ -1,7 +1,8 @@
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+
from autogen import oai
from autogen.agentchat.agent import Agent
from autogen.agentchat.assistant_agent import ConversableAgent
-from typing import Callable, Dict, Optional, Union, List, Tuple, Any
system_message = """You are an expert in text analysis.
The user will give you TEXT to analyze.
diff --git a/autogen/agentchat/contrib/vectordb/__init__.py b/autogen/agentchat/contrib/vectordb/__init__.py
new file mode 100644
index 00000000000..e69de29bb2d
diff --git a/autogen/agentchat/contrib/vectordb/base.py b/autogen/agentchat/contrib/vectordb/base.py
new file mode 100644
index 00000000000..29a08008619
--- /dev/null
+++ b/autogen/agentchat/contrib/vectordb/base.py
@@ -0,0 +1,213 @@
+from typing import Any, List, Mapping, Optional, Protocol, Sequence, Tuple, TypedDict, Union, runtime_checkable
+
+Metadata = Union[Mapping[str, Any], None]
+Vector = Union[Sequence[float], Sequence[int]]
+ItemID = Union[str, int] # chromadb doesn't support int ids, VikingDB does
+
+
+class Document(TypedDict):
+ """A Document is a record in the vector database.
+
+ id: ItemID | the unique identifier of the document.
+ content: str | the text content of the chunk.
+ metadata: Metadata, Optional | contains additional information about the document such as source, date, etc.
+ embedding: Vector, Optional | the vector representation of the content.
+ """
+
+ id: ItemID
+ content: str
+ metadata: Optional[Metadata]
+ embedding: Optional[Vector]
+
+
+"""QueryResults is the response from the vector database for a query/queries.
+A query is a list containing one string while queries is a list containing multiple strings.
+The response is a list of query results; each query result is a list of tuples containing the document and the distance.
+"""
+QueryResults = List[List[Tuple[Document, float]]]
+
+
+@runtime_checkable
+class VectorDB(Protocol):
+ """
+ Abstract class for vector database. A vector database is responsible for storing and retrieving documents.
+
+ Attributes:
+        active_collection: Any | The active collection in the vector database. Makes get_collection faster. Default is None.
+ type: str | The type of the vector database, chroma, pgvector, etc. Default is "".
+
+ Methods:
+ create_collection: Callable[[str, bool, bool], Any] | Create a collection in the vector database.
+ get_collection: Callable[[str], Any] | Get the collection from the vector database.
+ delete_collection: Callable[[str], Any] | Delete the collection from the vector database.
+ insert_docs: Callable[[List[Document], str, bool], None] | Insert documents into the collection of the vector database.
+ update_docs: Callable[[List[Document], str], None] | Update documents in the collection of the vector database.
+ delete_docs: Callable[[List[ItemID], str], None] | Delete documents from the collection of the vector database.
+ retrieve_docs: Callable[[List[str], str, int, float], QueryResults] | Retrieve documents from the collection of the vector database based on the queries.
+ get_docs_by_ids: Callable[[List[ItemID], str], List[Document]] | Retrieve documents from the collection of the vector database based on the ids.
+ """
+
+ active_collection: Any = None
+ type: str = ""
+
+ def create_collection(self, collection_name: str, overwrite: bool = False, get_or_create: bool = True) -> Any:
+ """
+ Create a collection in the vector database.
+ Case 1. if the collection does not exist, create the collection.
+ Case 2. the collection exists, if overwrite is True, it will overwrite the collection.
+ Case 3. the collection exists and overwrite is False, if get_or_create is True, it will get the collection,
+        otherwise it raises a ValueError.
+
+ Args:
+ collection_name: str | The name of the collection.
+ overwrite: bool | Whether to overwrite the collection if it exists. Default is False.
+ get_or_create: bool | Whether to get the collection if it exists. Default is True.
+
+ Returns:
+ Any | The collection object.
+ """
+ ...
+
+ def get_collection(self, collection_name: str = None) -> Any:
+ """
+ Get the collection from the vector database.
+
+ Args:
+ collection_name: str | The name of the collection. Default is None. If None, return the
+ current active collection.
+
+ Returns:
+ Any | The collection object.
+ """
+ ...
+
+ def delete_collection(self, collection_name: str) -> Any:
+ """
+ Delete the collection from the vector database.
+
+ Args:
+ collection_name: str | The name of the collection.
+
+ Returns:
+ Any
+ """
+ ...
+
+ def insert_docs(self, docs: List[Document], collection_name: str = None, upsert: bool = False, **kwargs) -> None:
+ """
+ Insert documents into the collection of the vector database.
+
+ Args:
+ docs: List[Document] | A list of documents. Each document is a TypedDict `Document`.
+ collection_name: str | The name of the collection. Default is None.
+ upsert: bool | Whether to update the document if it exists. Default is False.
+ kwargs: Dict | Additional keyword arguments.
+
+ Returns:
+ None
+ """
+ ...
+
+ def update_docs(self, docs: List[Document], collection_name: str = None, **kwargs) -> None:
+ """
+ Update documents in the collection of the vector database.
+
+ Args:
+ docs: List[Document] | A list of documents.
+ collection_name: str | The name of the collection. Default is None.
+ kwargs: Dict | Additional keyword arguments.
+
+ Returns:
+ None
+ """
+ ...
+
+ def delete_docs(self, ids: List[ItemID], collection_name: str = None, **kwargs) -> None:
+ """
+ Delete documents from the collection of the vector database.
+
+ Args:
+ ids: List[ItemID] | A list of document ids. Each id is a typed `ItemID`.
+ collection_name: str | The name of the collection. Default is None.
+ kwargs: Dict | Additional keyword arguments.
+
+ Returns:
+ None
+ """
+ ...
+
+ def retrieve_docs(
+ self,
+ queries: List[str],
+ collection_name: str = None,
+ n_results: int = 10,
+ distance_threshold: float = -1,
+ **kwargs,
+ ) -> QueryResults:
+ """
+ Retrieve documents from the collection of the vector database based on the queries.
+
+ Args:
+ queries: List[str] | A list of queries. Each query is a string.
+ collection_name: str | The name of the collection. Default is None.
+ n_results: int | The number of relevant documents to return. Default is 10.
+            distance_threshold: float | The threshold for the distance score; only docs with a distance smaller than it will be
+                returned. No filtering is applied if < 0. Default is -1.
+ kwargs: Dict | Additional keyword arguments.
+
+ Returns:
+            QueryResults | The query results. Each query result is a list of tuples containing the document and
+                the distance.
+ """
+ ...
+
+ def get_docs_by_ids(
+ self, ids: List[ItemID] = None, collection_name: str = None, include=None, **kwargs
+ ) -> List[Document]:
+ """
+ Retrieve documents from the collection of the vector database based on the ids.
+
+ Args:
+ ids: List[ItemID] | A list of document ids. If None, will return all the documents. Default is None.
+ collection_name: str | The name of the collection. Default is None.
+ include: List[str] | The fields to include. Default is None.
+ If None, will include ["metadatas", "documents"], ids will always be included.
+ kwargs: dict | Additional keyword arguments.
+
+ Returns:
+ List[Document] | The results.
+ """
+ ...
+
+
+class VectorDBFactory:
+ """
+ Factory class for creating vector databases.
+ """
+
+ PREDEFINED_VECTOR_DB = ["chroma", "pgvector"]
+
+ @staticmethod
+ def create_vector_db(db_type: str, **kwargs) -> VectorDB:
+ """
+ Create a vector database.
+
+ Args:
+ db_type: str | The type of the vector database.
+ kwargs: Dict | The keyword arguments for initializing the vector database.
+
+ Returns:
+ VectorDB | The vector database.
+ """
+ if db_type.lower() in ["chroma", "chromadb"]:
+ from .chromadb import ChromaVectorDB
+
+ return ChromaVectorDB(**kwargs)
+ if db_type.lower() in ["pgvector", "pgvectordb"]:
+ from .pgvectordb import PGVectorDB
+
+ return PGVectorDB(**kwargs)
+ else:
+ raise ValueError(
+ f"Unsupported vector database type: {db_type}. Valid types are {VectorDBFactory.PREDEFINED_VECTOR_DB}."
+ )
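As a quick orientation to the protocol above, the following sketch creates a backend through `VectorDBFactory`, inserts two `Document` records, and unpacks a `QueryResults` entry. It assumes `chromadb` is installed for the "chroma" backend; everything else uses only the names defined in this file.

```python
# Minimal sketch of the VectorDB protocol (assumes `pip install chromadb`).
from autogen.agentchat.contrib.vectordb.base import Document, VectorDBFactory

db = VectorDBFactory.create_vector_db(db_type="chroma", path="tmp/db")
db.active_collection = db.create_collection("demo", overwrite=True, get_or_create=True)

db.insert_docs(
    docs=[
        Document(id="1", content="AutoGen supports retrieval-augmented chat.", metadata={"source": "readme"}, embedding=None),
        Document(id="2", content="Vector databases store embeddings.", metadata={"source": "notes"}, embedding=None),
    ],
    collection_name="demo",
)

# QueryResults: one list per query, each entry a (Document, distance) tuple.
results = db.retrieve_docs(queries=["What stores embeddings?"], collection_name="demo", n_results=1)
doc, distance = results[0][0]
print(doc["id"], doc["content"], distance)
```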
diff --git a/autogen/agentchat/contrib/vectordb/chromadb.py b/autogen/agentchat/contrib/vectordb/chromadb.py
new file mode 100644
index 00000000000..3f1fbc86a44
--- /dev/null
+++ b/autogen/agentchat/contrib/vectordb/chromadb.py
@@ -0,0 +1,318 @@
+import os
+from typing import Callable, List
+
+from .base import Document, ItemID, QueryResults, VectorDB
+from .utils import chroma_results_to_query_results, filter_results_by_distance, get_logger
+
+try:
+ import chromadb
+
+ if chromadb.__version__ < "0.4.15":
+ raise ImportError("Please upgrade chromadb to version 0.4.15 or later.")
+ import chromadb.utils.embedding_functions as ef
+ from chromadb.api.models.Collection import Collection
+except ImportError:
+ raise ImportError("Please install chromadb: `pip install chromadb`")
+
+CHROMADB_MAX_BATCH_SIZE = os.environ.get("CHROMADB_MAX_BATCH_SIZE", 40000)
+logger = get_logger(__name__)
+
+
+class ChromaVectorDB(VectorDB):
+ """
+ A vector database that uses ChromaDB as the backend.
+ """
+
+ def __init__(
+ self, *, client=None, path: str = "tmp/db", embedding_function: Callable = None, metadata: dict = None, **kwargs
+ ) -> None:
+ """
+ Initialize the vector database.
+
+ Args:
+ client: chromadb.Client | The client object of the vector database. Default is None.
+ If provided, it will use the client object directly and ignore other arguments.
+ path: str | The path to the vector database. Default is `tmp/db`. The default was `None` for version <=0.2.24.
+ embedding_function: Callable | The embedding function used to generate the vector representation
+ of the documents. Default is None, SentenceTransformerEmbeddingFunction("all-MiniLM-L6-v2") will be used.
+ metadata: dict | The metadata of the vector database. Default is None. If None, it will use this
+ setting: {"hnsw:space": "ip", "hnsw:construction_ef": 30, "hnsw:M": 32}. For more details of
+ the metadata, please refer to [distances](https://github.com/nmslib/hnswlib#supported-distances),
+ [hnsw](https://github.com/chroma-core/chroma/blob/566bc80f6c8ee29f7d99b6322654f32183c368c4/chromadb/segment/impl/vector/local_hnsw.py#L184),
+ and [ALGO_PARAMS](https://github.com/nmslib/hnswlib/blob/master/ALGO_PARAMS.md).
+ kwargs: dict | Additional keyword arguments.
+
+ Returns:
+ None
+ """
+ self.client = client
+ self.path = path
+ self.embedding_function = (
+ ef.SentenceTransformerEmbeddingFunction("all-MiniLM-L6-v2")
+ if embedding_function is None
+ else embedding_function
+ )
+ self.metadata = metadata if metadata else {"hnsw:space": "ip", "hnsw:construction_ef": 30, "hnsw:M": 32}
+ if not self.client:
+ if self.path is not None:
+ self.client = chromadb.PersistentClient(path=self.path, **kwargs)
+ else:
+ self.client = chromadb.Client(**kwargs)
+ self.active_collection = None
+ self.type = "chroma"
+
+ def create_collection(
+ self, collection_name: str, overwrite: bool = False, get_or_create: bool = True
+ ) -> Collection:
+ """
+ Create a collection in the vector database.
+ Case 1. if the collection does not exist, create the collection.
+ Case 2. the collection exists, if overwrite is True, it will overwrite the collection.
+ Case 3. the collection exists and overwrite is False, if get_or_create is True, it will get the collection,
+        otherwise it raises a ValueError.
+
+ Args:
+ collection_name: str | The name of the collection.
+ overwrite: bool | Whether to overwrite the collection if it exists. Default is False.
+ get_or_create: bool | Whether to get the collection if it exists. Default is True.
+
+ Returns:
+ Collection | The collection object.
+ """
+ try:
+ if self.active_collection and self.active_collection.name == collection_name:
+ collection = self.active_collection
+ else:
+ collection = self.client.get_collection(collection_name)
+ except ValueError:
+ collection = None
+ if collection is None:
+ return self.client.create_collection(
+ collection_name,
+ embedding_function=self.embedding_function,
+ get_or_create=get_or_create,
+ metadata=self.metadata,
+ )
+ elif overwrite:
+ self.client.delete_collection(collection_name)
+ return self.client.create_collection(
+ collection_name,
+ embedding_function=self.embedding_function,
+ get_or_create=get_or_create,
+ metadata=self.metadata,
+ )
+ elif get_or_create:
+ return collection
+ else:
+ raise ValueError(f"Collection {collection_name} already exists.")
+
+ def get_collection(self, collection_name: str = None) -> Collection:
+ """
+ Get the collection from the vector database.
+
+ Args:
+ collection_name: str | The name of the collection. Default is None. If None, return the
+ current active collection.
+
+ Returns:
+ Collection | The collection object.
+ """
+ if collection_name is None:
+ if self.active_collection is None:
+ raise ValueError("No collection is specified.")
+ else:
+ logger.info(
+ f"No collection is specified. Using current active collection {self.active_collection.name}."
+ )
+ else:
+ if not (self.active_collection and self.active_collection.name == collection_name):
+ self.active_collection = self.client.get_collection(collection_name)
+ return self.active_collection
+
+ def delete_collection(self, collection_name: str) -> None:
+ """
+ Delete the collection from the vector database.
+
+ Args:
+ collection_name: str | The name of the collection.
+
+ Returns:
+ None
+ """
+ self.client.delete_collection(collection_name)
+ if self.active_collection and self.active_collection.name == collection_name:
+ self.active_collection = None
+
+ def _batch_insert(
+ self, collection: Collection, embeddings=None, ids=None, metadatas=None, documents=None, upsert=False
+ ) -> None:
+ batch_size = int(CHROMADB_MAX_BATCH_SIZE)
+ for i in range(0, len(documents), min(batch_size, len(documents))):
+ end_idx = i + min(batch_size, len(documents) - i)
+ collection_kwargs = {
+ "documents": documents[i:end_idx],
+ "ids": ids[i:end_idx],
+ "metadatas": metadatas[i:end_idx] if metadatas else None,
+ "embeddings": embeddings[i:end_idx] if embeddings else None,
+ }
+ if upsert:
+ collection.upsert(**collection_kwargs)
+ else:
+ collection.add(**collection_kwargs)
+
+ def insert_docs(self, docs: List[Document], collection_name: str = None, upsert: bool = False) -> None:
+ """
+ Insert documents into the collection of the vector database.
+
+ Args:
+ docs: List[Document] | A list of documents. Each document is a TypedDict `Document`.
+ collection_name: str | The name of the collection. Default is None.
+ upsert: bool | Whether to update the document if it exists. Default is False.
+ kwargs: Dict | Additional keyword arguments.
+
+ Returns:
+ None
+ """
+ if not docs:
+ return
+ if docs[0].get("content") is None:
+ raise ValueError("The document content is required.")
+ if docs[0].get("id") is None:
+ raise ValueError("The document id is required.")
+ documents = [doc.get("content") for doc in docs]
+ ids = [doc.get("id") for doc in docs]
+ collection = self.get_collection(collection_name)
+ if docs[0].get("embedding") is None:
+ logger.info(
+ "No content embedding is provided. Will use the VectorDB's embedding function to generate the content embedding."
+ )
+ embeddings = None
+ else:
+ embeddings = [doc.get("embedding") for doc in docs]
+ if docs[0].get("metadata") is None:
+ metadatas = None
+ else:
+ metadatas = [doc.get("metadata") for doc in docs]
+ self._batch_insert(collection, embeddings, ids, metadatas, documents, upsert)
+
+ def update_docs(self, docs: List[Document], collection_name: str = None) -> None:
+ """
+ Update documents in the collection of the vector database.
+
+ Args:
+ docs: List[Document] | A list of documents.
+ collection_name: str | The name of the collection. Default is None.
+
+ Returns:
+ None
+ """
+ self.insert_docs(docs, collection_name, upsert=True)
+
+ def delete_docs(self, ids: List[ItemID], collection_name: str = None, **kwargs) -> None:
+ """
+ Delete documents from the collection of the vector database.
+
+ Args:
+ ids: List[ItemID] | A list of document ids. Each id is a typed `ItemID`.
+ collection_name: str | The name of the collection. Default is None.
+ kwargs: Dict | Additional keyword arguments.
+
+ Returns:
+ None
+ """
+ collection = self.get_collection(collection_name)
+ collection.delete(ids, **kwargs)
+
+ def retrieve_docs(
+ self,
+ queries: List[str],
+ collection_name: str = None,
+ n_results: int = 10,
+ distance_threshold: float = -1,
+ **kwargs,
+ ) -> QueryResults:
+ """
+ Retrieve documents from the collection of the vector database based on the queries.
+
+ Args:
+ queries: List[str] | A list of queries. Each query is a string.
+ collection_name: str | The name of the collection. Default is None.
+ n_results: int | The number of relevant documents to return. Default is 10.
+            distance_threshold: float | The threshold for the distance score; only docs with a distance smaller than it will be
+                returned. No filtering is applied if < 0. Default is -1.
+ kwargs: Dict | Additional keyword arguments.
+
+ Returns:
+            QueryResults | The query results. Each query result is a list of tuples containing the document and
+                the distance.
+ """
+ collection = self.get_collection(collection_name)
+ if isinstance(queries, str):
+ queries = [queries]
+ results = collection.query(
+ query_texts=queries,
+ n_results=n_results,
+ **kwargs,
+ )
+ results["contents"] = results.pop("documents")
+ results = chroma_results_to_query_results(results)
+ results = filter_results_by_distance(results, distance_threshold)
+ return results
+
+ @staticmethod
+ def _chroma_get_results_to_list_documents(data_dict) -> List[Document]:
+ """Converts a dictionary with list values to a list of Document.
+
+ Args:
+ data_dict: A dictionary where keys map to lists or None.
+
+ Returns:
+ List[Document] | The list of Document.
+
+ Example:
+ data_dict = {
+ "key1s": [1, 2, 3],
+ "key2s": ["a", "b", "c"],
+ "key3s": None,
+ "key4s": ["x", "y", "z"],
+ }
+
+ results = [
+ {"key1": 1, "key2": "a", "key4": "x"},
+ {"key1": 2, "key2": "b", "key4": "y"},
+ {"key1": 3, "key2": "c", "key4": "z"},
+ ]
+ """
+
+ results = []
+ keys = [key for key in data_dict if data_dict[key] is not None]
+
+ for i in range(len(data_dict[keys[0]])):
+ sub_dict = {}
+ for key in data_dict.keys():
+ if data_dict[key] is not None and len(data_dict[key]) > i:
+ sub_dict[key[:-1]] = data_dict[key][i]
+ results.append(sub_dict)
+ return results
+
+ def get_docs_by_ids(
+ self, ids: List[ItemID] = None, collection_name: str = None, include=None, **kwargs
+ ) -> List[Document]:
+ """
+ Retrieve documents from the collection of the vector database based on the ids.
+
+ Args:
+ ids: List[ItemID] | A list of document ids. If None, will return all the documents. Default is None.
+ collection_name: str | The name of the collection. Default is None.
+ include: List[str] | The fields to include. Default is None.
+ If None, will include ["metadatas", "documents"], ids will always be included.
+ kwargs: dict | Additional keyword arguments.
+
+ Returns:
+ List[Document] | The results.
+ """
+ collection = self.get_collection(collection_name)
+ include = include if include else ["metadatas", "documents"]
+ results = collection.get(ids, include=include, **kwargs)
+ results = self._chroma_get_results_to_list_documents(results)
+ return results
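For completeness, a short usage sketch of the class above: it persists a collection under a local path, upserts a document via `update_docs`, and reads it back with `get_docs_by_ids`. Paths and ids are placeholders; the first run will download the default SentenceTransformer model.

```python
# Hypothetical direct-usage sketch for ChromaVectorDB (requires `pip install chromadb`).
from autogen.agentchat.contrib.vectordb.chromadb import ChromaVectorDB

db = ChromaVectorDB(path="tmp/db")  # PersistentClient rooted at tmp/db
db.active_collection = db.create_collection("notes", get_or_create=True)

db.insert_docs(
    docs=[{"id": "n1", "content": "first note", "metadata": {"source": "inbox"}}],
    collection_name="notes",
)
# update_docs delegates to insert_docs(..., upsert=True), so re-using an id overwrites it
db.update_docs(
    docs=[{"id": "n1", "content": "first note (edited)", "metadata": {"source": "inbox"}}],
    collection_name="notes",
)

print(db.get_docs_by_ids(ids=["n1"], collection_name="notes"))  # metadatas and documents included
db.delete_docs(ids=["n1"], collection_name="notes")
db.delete_collection("notes")
```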
diff --git a/autogen/agentchat/contrib/vectordb/pgvectordb.py b/autogen/agentchat/contrib/vectordb/pgvectordb.py
new file mode 100644
index 00000000000..b5db55f7eb1
--- /dev/null
+++ b/autogen/agentchat/contrib/vectordb/pgvectordb.py
@@ -0,0 +1,888 @@
+import os
+import re
+import urllib.parse
+from typing import Callable, List
+
+import numpy as np
+from sentence_transformers import SentenceTransformer
+
+from .base import Document, ItemID, QueryResults, VectorDB
+from .utils import get_logger
+
+try:
+ import pgvector
+ from pgvector.psycopg import register_vector
+except ImportError:
+ raise ImportError("Please install pgvector: `pip install pgvector`")
+
+try:
+ import psycopg
+except ImportError:
+ raise ImportError("Please install pgvector: `pip install psycopg`")
+
+PGVECTOR_MAX_BATCH_SIZE = os.environ.get("PGVECTOR_MAX_BATCH_SIZE", 40000)
+logger = get_logger(__name__)
+
+
+class Collection:
+ """
+ A Collection object for PGVector.
+
+ Attributes:
+ client: The PGVector client.
+ collection_name (str): The name of the collection. Default is "documents".
+ embedding_function (Callable): The embedding function used to generate the vector representation.
+ metadata (Optional[dict]): The metadata of the collection.
+ get_or_create (Optional): The flag indicating whether to get or create the collection.
+ model_name: (Optional str) | Sentence embedding model to use. Models can be chosen from:
+ https://huggingface.co/models?library=sentence-transformers
+ """
+
+ def __init__(
+ self,
+ client=None,
+ collection_name: str = "autogen-docs",
+ embedding_function: Callable = None,
+ metadata=None,
+ get_or_create=None,
+ model_name="all-MiniLM-L6-v2",
+ ):
+ """
+ Initialize the Collection object.
+
+ Args:
+ client: The PostgreSQL client.
+ collection_name: The name of the collection. Default is "documents".
+ embedding_function: The embedding function used to generate the vector representation.
+ metadata: The metadata of the collection.
+ get_or_create: The flag indicating whether to get or create the collection.
+            model_name: Sentence embedding model to use. Models can be chosen from:
+ https://huggingface.co/models?library=sentence-transformers
+ Returns:
+ None
+ """
+ self.client = client
+ self.embedding_function = embedding_function
+ self.model_name = model_name
+ self.name = self.set_collection_name(collection_name)
+ self.require_embeddings_or_documents = False
+ self.ids = []
+ try:
+ self.embedding_function = (
+ SentenceTransformer(self.model_name) if embedding_function is None else embedding_function
+ )
+ except Exception as e:
+ logger.error(
+ f"Validate the model name entered: {self.model_name} "
+ f"from https://huggingface.co/models?library=sentence-transformers\nError: {e}"
+ )
+ raise e
+ self.metadata = metadata if metadata else {"hnsw:space": "ip", "hnsw:construction_ef": 32, "hnsw:M": 16}
+ self.documents = ""
+ self.get_or_create = get_or_create
+
+ def set_collection_name(self, collection_name) -> str:
+ name = re.sub("-", "_", collection_name)
+ self.name = name
+ return self.name
+
+ def add(self, ids: List[ItemID], documents: List, embeddings: List = None, metadatas: List = None) -> None:
+ """
+ Add documents to the collection.
+
+ Args:
+ ids (List[ItemID]): A list of document IDs.
+ embeddings (List): A list of document embeddings. Optional
+ metadatas (List): A list of document metadatas. Optional
+ documents (List): A list of documents.
+
+ Returns:
+ None
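+
+        Example (an illustrative sketch; the ids, texts, and metadata values are placeholders, and the
+        collection table is assumed to already exist):
+            ```python
+            collection.add(
+                ids=["doc_0001", "doc_0002"],
+                documents=["PGVector stores embeddings in PostgreSQL.", "AutoGen agents can use RAG."],
+                metadatas=[{"source": "notes"}, {"source": "notes"}],
+            )
+            ```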
+ """
+ cursor = self.client.cursor()
+ sql_values = []
+ if embeddings is not None and metadatas is not None:
+ for doc_id, embedding, metadata, document in zip(ids, embeddings, metadatas, documents):
+ metadata = re.sub("'", '"', str(metadata))
+ sql_values.append((doc_id, embedding, metadata, document))
+ sql_string = (
+ f"INSERT INTO {self.name} (id, embedding, metadatas, documents)\n" f"VALUES (%s, %s, %s, %s);\n"
+ )
+ elif embeddings is not None:
+ for doc_id, embedding, document in zip(ids, embeddings, documents):
+ sql_values.append((doc_id, embedding, document))
+ sql_string = f"INSERT INTO {self.name} (id, embedding, documents) " f"VALUES (%s, %s, %s);\n"
+ elif metadatas is not None:
+ for doc_id, metadata, document in zip(ids, metadatas, documents):
+ metadata = re.sub("'", '"', str(metadata))
+ embedding = self.embedding_function.encode(document)
+ sql_values.append((doc_id, metadata, embedding, document))
+ sql_string = (
+ f"INSERT INTO {self.name} (id, metadatas, embedding, documents)\n" f"VALUES (%s, %s, %s, %s);\n"
+ )
+ else:
+ for doc_id, document in zip(ids, documents):
+ embedding = self.embedding_function.encode(document)
+ sql_values.append((doc_id, document, embedding))
+ sql_string = f"INSERT INTO {self.name} (id, documents, embedding)\n" f"VALUES (%s, %s, %s);\n"
+ logger.debug(f"Add SQL String:\n{sql_string}\n{sql_values}")
+ cursor.executemany(sql_string, sql_values)
+ cursor.close()
+
+ def upsert(self, ids: List[ItemID], documents: List, embeddings: List = None, metadatas: List = None) -> None:
+ """
+ Upsert documents into the collection.
+
+ Args:
+ ids (List[ItemID]): A list of document IDs.
+ documents (List): A list of documents.
+ embeddings (List): A list of document embeddings.
+ metadatas (List): A list of document metadatas.
+
+ Returns:
+ None
+ """
+ cursor = self.client.cursor()
+ sql_values = []
+ if embeddings is not None and metadatas is not None:
+ for doc_id, embedding, metadata, document in zip(ids, embeddings, metadatas, documents):
+ metadata = re.sub("'", '"', str(metadata))
+ sql_values.append((doc_id, embedding, metadata, document, embedding, metadata, document))
+ sql_string = (
+ f"INSERT INTO {self.name} (id, embedding, metadatas, documents)\n"
+ f"VALUES (%s, %s, %s, %s)\n"
+ f"ON CONFLICT (id)\n"
+ f"DO UPDATE SET embedding = %s,\n"
+ f"metadatas = %s, documents = %s;\n"
+ )
+ elif embeddings is not None:
+ for doc_id, embedding, document in zip(ids, embeddings, documents):
+ sql_values.append((doc_id, embedding, document, embedding, document))
+ sql_string = (
+ f"INSERT INTO {self.name} (id, embedding, documents) "
+ f"VALUES (%s, %s, %s) ON CONFLICT (id)\n"
+ f"DO UPDATE SET embedding = %s, documents = %s;\n"
+ )
+ elif metadatas is not None:
+ for doc_id, metadata, document in zip(ids, metadatas, documents):
+ metadata = re.sub("'", '"', str(metadata))
+ embedding = self.embedding_function.encode(document)
+ sql_values.append((doc_id, metadata, embedding, document, metadata, document, embedding))
+ sql_string = (
+ f"INSERT INTO {self.name} (id, metadatas, embedding, documents)\n"
+ f"VALUES (%s, %s, %s, %s)\n"
+ f"ON CONFLICT (id)\n"
+ f"DO UPDATE SET metadatas = %s, documents = %s, embedding = %s;\n"
+ )
+ else:
+ for doc_id, document in zip(ids, documents):
+ embedding = self.embedding_function.encode(document)
+ sql_values.append((doc_id, document, embedding, document))
+ sql_string = (
+ f"INSERT INTO {self.name} (id, documents, embedding)\n"
+ f"VALUES (%s, %s, %s)\n"
+ f"ON CONFLICT (id)\n"
+ f"DO UPDATE SET documents = %s;\n"
+ )
+ logger.debug(f"Upsert SQL String:\n{sql_string}\n{sql_values}")
+ cursor.executemany(sql_string, sql_values)
+ cursor.close()
+
+ def count(self) -> int:
+ """
+ Get the total number of documents in the collection.
+
+ Returns:
+ int: The total number of documents.
+ """
+ cursor = self.client.cursor()
+ query = f"SELECT COUNT(*) FROM {self.name}"
+ cursor.execute(query)
+ total = cursor.fetchone()[0]
+ cursor.close()
+ try:
+ total = int(total)
+ except (TypeError, ValueError):
+ total = None
+ return total
+
+ def table_exists(self, table_name: str) -> bool:
+ """
+ Check if a table exists in the PostgreSQL database.
+
+ Args:
+ table_name (str): The name of the table to check.
+
+ Returns:
+ bool: True if the table exists, False otherwise.
+ """
+
+ cursor = self.client.cursor()
+ cursor.execute(
+ """
+ SELECT EXISTS (
+ SELECT 1
+ FROM information_schema.tables
+ WHERE table_name = %s
+ )
+ """,
+ (table_name,),
+ )
+        exists = cursor.fetchone()[0]
+        cursor.close()
+        return exists
+
+ def get(self, ids=None, include=None, where=None, limit=None, offset=None) -> List[Document]:
+ """
+ Retrieve documents from the collection.
+
+ Args:
+ ids (Optional[List]): A list of document IDs.
+ include (Optional): The fields to include.
+ where (Optional): Additional filtering criteria.
+ limit (Optional): The maximum number of documents to retrieve.
+ offset (Optional): The offset for pagination.
+
+ Returns:
+ List: The retrieved documents.
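+
+        Example (an illustrative sketch; the ids and paging values are placeholders):
+            ```python
+            docs_by_id = collection.get(ids=["doc_0001", "doc_0002"])
+            first_page = collection.get(limit=10, offset=0)
+            ```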
+ """
+ cursor = self.client.cursor()
+
+ # Initialize variables for query components
+ select_clause = "SELECT id, metadatas, documents, embedding"
+ from_clause = f"FROM {self.name}"
+ where_clause = ""
+ limit_clause = ""
+ offset_clause = ""
+
+ # Handle include clause
+ if include:
+ select_clause = f"SELECT id, {', '.join(include)}, embedding"
+
+ # Handle where clause
+ if ids:
+ where_clause = f"WHERE id IN ({', '.join(['%s' for _ in ids])})"
+ elif where:
+ where_clause = f"WHERE {where}"
+
+ # Handle limit and offset clauses
+ if limit:
+ limit_clause = "LIMIT %s"
+ if offset:
+ offset_clause = "OFFSET %s"
+
+ # Construct the full query
+ query = f"{select_clause} {from_clause} {where_clause} {limit_clause} {offset_clause}"
+
+ retrieved_documents = []
+ try:
+ # Execute the query with the appropriate values
+ if ids is not None:
+ cursor.execute(query, ids)
+ else:
+ query_params = []
+ if limit:
+ query_params.append(limit)
+ if offset:
+ query_params.append(offset)
+ cursor.execute(query, query_params)
+
+ retrieval = cursor.fetchall()
+ for retrieved_document in retrieval:
+ retrieved_documents.append(
+ Document(
+ id=retrieved_document[0].strip(),
+ metadata=retrieved_document[1],
+ content=retrieved_document[2],
+ embedding=retrieved_document[3],
+ )
+ )
+ except (psycopg.errors.UndefinedTable, psycopg.errors.UndefinedColumn) as e:
+ logger.info(f"Error executing select on non-existent table: {self.name}. Creating it instead. Error: {e}")
+ self.create_collection(collection_name=self.name)
+ logger.info(f"Created table {self.name}")
+
+ cursor.close()
+ return retrieved_documents
+
+ def update(self, ids: List, embeddings: List, metadatas: List, documents: List) -> None:
+ """
+ Update documents in the collection.
+
+ Args:
+ ids (List): A list of document IDs.
+ embeddings (List): A list of document embeddings.
+ metadatas (List): A list of document metadatas.
+ documents (List): A list of documents.
+
+ Returns:
+ None
+ """
+ cursor = self.client.cursor()
+ sql_values = []
+ for doc_id, embedding, metadata, document in zip(ids, embeddings, metadatas, documents):
+ sql_values.append((doc_id, embedding, metadata, document, doc_id, embedding, metadata, document))
+ sql_string = (
+ f"INSERT INTO {self.name} (id, embedding, metadata, document) "
+ f"VALUES (%s, %s, %s, %s) "
+ f"ON CONFLICT (id) "
+ f"DO UPDATE SET id = %s, embedding = %s, "
+ f"metadata = %s, document = %s;\n"
+ )
+ logger.debug(f"Upsert SQL String:\n{sql_string}\n")
+ cursor.executemany(sql_string, sql_values)
+ cursor.close()
+
+ @staticmethod
+ def euclidean_distance(arr1: List[float], arr2: List[float]) -> float:
+ """
+ Calculate the Euclidean distance between two vectors.
+
+ Parameters:
+ - arr1 (List[float]): The first vector.
+ - arr2 (List[float]): The second vector.
+
+ Returns:
+ - float: The Euclidean distance between arr1 and arr2.
+ """
+ dist = np.linalg.norm(arr1 - arr2)
+ return dist
+
+ @staticmethod
+ def cosine_distance(arr1: List[float], arr2: List[float]) -> float:
+ """
+ Calculate the cosine distance between two vectors.
+
+ Parameters:
+ - arr1 (List[float]): The first vector.
+ - arr2 (List[float]): The second vector.
+
+ Returns:
+ - float: The cosine distance between arr1 and arr2.
+ """
+ dist = np.dot(arr1, arr2) / (np.linalg.norm(arr1) * np.linalg.norm(arr2))
+ return dist
+
+ @staticmethod
+ def inner_product_distance(arr1: List[float], arr2: List[float]) -> float:
+ """
+        Calculate the inner product (dot product) between two vectors.
+
+        Parameters:
+        - arr1 (List[float]): The first vector.
+        - arr2 (List[float]): The second vector.
+
+        Returns:
+        - float: The inner product of arr1 and arr2.
+        """
+        dist = np.dot(arr1, arr2)
+ return dist
+
+ def query(
+ self,
+ query_texts: List[str],
+ collection_name: str = None,
+ n_results: int = 10,
+ distance_type: str = "euclidean",
+ distance_threshold: float = -1,
+ include_embedding: bool = False,
+ ) -> QueryResults:
+ """
+ Query documents in the collection.
+
+ Args:
+ query_texts (List[str]): A list of query texts.
+ collection_name (Optional[str]): The name of the collection.
+ n_results (int): The maximum number of results to return.
+ distance_type (Optional[str]): Distance search type - euclidean or cosine
+ distance_threshold (Optional[float]): Distance threshold to limit searches
+ include_embedding (Optional[bool]): Include embedding values in QueryResults
+ Returns:
+ QueryResults: The query results.
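+
+        Example (an illustrative sketch; the query text is a placeholder):
+            ```python
+            results = collection.query(
+                query_texts=["How do I create an HNSW index?"],
+                n_results=3,
+                distance_type="cosine",
+            )
+            # One inner list per query text, each holding (Document, distance) tuples.
+            ```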
+ """
+ if collection_name:
+ self.name = collection_name
+
+ clause = "ORDER BY"
+ if distance_threshold == -1:
+ distance_threshold = ""
+ clause = "ORDER BY"
+ elif distance_threshold > 0:
+ distance_threshold = f"< {distance_threshold}"
+ clause = "WHERE"
+
+ cursor = self.client.cursor()
+ results = []
+ for query_text in query_texts:
+ vector = self.embedding_function.encode(query_text, convert_to_tensor=False).tolist()
+ if distance_type.lower() == "cosine":
+ index_function = "<=>"
+ elif distance_type.lower() == "euclidean":
+ index_function = "<->"
+ elif distance_type.lower() == "inner-product":
+ index_function = "<#>"
+ else:
+ index_function = "<->"
+ query = (
+ f"SELECT id, documents, embedding, metadatas "
+ f"FROM {self.name} "
+ f"{clause} embedding {index_function} '{str(vector)}' {distance_threshold} "
+ f"LIMIT {n_results}"
+ )
+ cursor.execute(query)
+ result = []
+ for row in cursor.fetchall():
+ fetched_document = Document(id=row[0].strip(), content=row[1], embedding=row[2], metadata=row[3])
+ fetched_document_array = self.convert_string_to_array(array_string=fetched_document.get("embedding"))
+ if distance_type.lower() == "cosine":
+ distance = self.cosine_distance(fetched_document_array, vector)
+ elif distance_type.lower() == "euclidean":
+ distance = self.euclidean_distance(fetched_document_array, vector)
+ elif distance_type.lower() == "inner-product":
+ distance = self.inner_product_distance(fetched_document_array, vector)
+ else:
+ distance = self.euclidean_distance(fetched_document_array, vector)
+ if not include_embedding:
+ fetched_document = Document(id=row[0].strip(), content=row[1], metadata=row[3])
+ result.append((fetched_document, distance))
+ results.append(result)
+ cursor.close()
+ logger.debug(f"Query Results: {results}")
+ return results
+
+ @staticmethod
+ def convert_string_to_array(array_string) -> List[float]:
+ """
+ Convert a string representation of an array to a list of floats.
+
+ Parameters:
+ - array_string (str): The string representation of the array.
+
+ Returns:
+ - list: A list of floats parsed from the input string. If the input is
+ not a string, it returns the input itself.
+ """
+ if not isinstance(array_string, str):
+ return array_string
+ array_string = array_string.strip("[]")
+ array = [float(num) for num in array_string.split()]
+ return array
+
+ def modify(self, metadata, collection_name: str = None) -> None:
+ """
+ Modify metadata for the collection.
+
+ Args:
+ collection_name: The name of the collection.
+ metadata: The new metadata.
+
+ Returns:
+ None
+ """
+ if collection_name:
+ self.name = collection_name
+ cursor = self.client.cursor()
+ cursor.execute(
+ "UPDATE collections" "SET metadata = '%s'" "WHERE collection_name = '%s';", (metadata, self.name)
+ )
+ cursor.close()
+
+ def delete(self, ids: List[ItemID], collection_name: str = None) -> None:
+ """
+ Delete documents from the collection.
+
+ Args:
+ ids (List[ItemID]): A list of document IDs to delete.
+ collection_name (str): The name of the collection to delete.
+
+ Returns:
+ None
+ """
+ if collection_name:
+ self.name = collection_name
+ cursor = self.client.cursor()
+ id_placeholders = ", ".join(["%s" for _ in ids])
+ cursor.execute(f"DELETE FROM {self.name} WHERE id IN ({id_placeholders});", ids)
+ cursor.close()
+
+ def delete_collection(self, collection_name: str = None) -> None:
+ """
+ Delete the entire collection.
+
+ Args:
+ collection_name (Optional[str]): The name of the collection to delete.
+
+ Returns:
+ None
+ """
+ if collection_name:
+ self.name = collection_name
+ cursor = self.client.cursor()
+ cursor.execute(f"DROP TABLE IF EXISTS {self.name}")
+ cursor.close()
+
+ def create_collection(self, collection_name: str = None) -> None:
+ """
+ Create a new collection.
+
+ Args:
+ collection_name (Optional[str]): The name of the new collection.
+
+ Returns:
+ None
+ """
+ if collection_name:
+ self.name = collection_name
+ cursor = self.client.cursor()
+ cursor.execute(
+ f"CREATE TABLE {self.name} ("
+ f"documents text, id CHAR(8) PRIMARY KEY, metadatas JSONB, embedding vector(384));"
+ f"CREATE INDEX "
+ f'ON {self.name} USING hnsw (embedding vector_l2_ops) WITH (m = {self.metadata["hnsw:M"]}, '
+ f'ef_construction = {self.metadata["hnsw:construction_ef"]});'
+ f"CREATE INDEX "
+ f'ON {self.name} USING hnsw (embedding vector_cosine_ops) WITH (m = {self.metadata["hnsw:M"]}, '
+ f'ef_construction = {self.metadata["hnsw:construction_ef"]});'
+ f"CREATE INDEX "
+ f'ON {self.name} USING hnsw (embedding vector_ip_ops) WITH (m = {self.metadata["hnsw:M"]}, '
+ f'ef_construction = {self.metadata["hnsw:construction_ef"]});'
+ )
+ cursor.close()
+
+
+class PGVectorDB(VectorDB):
+ """
+ A vector database that uses PGVector as the backend.
+ """
+
+ def __init__(
+ self,
+ *,
+ connection_string: str = None,
+ host: str = None,
+ port: int = None,
+ dbname: str = None,
+ username: str = None,
+ password: str = None,
+ connect_timeout: int = 10,
+ embedding_function: Callable = None,
+ metadata: dict = None,
+ model_name: str = "all-MiniLM-L6-v2",
+ ) -> None:
+ """
+ Initialize the vector database.
+
+ Note: connection_string or host + port + dbname must be specified
+
+ Args:
+ connection_string: "postgresql://username:password@hostname:port/database" | The PGVector connection string. Default is None.
+ host: str | The host to connect to. Default is None.
+ port: int | The port to connect to. Default is None.
+ dbname: str | The database name to connect to. Default is None.
+ username: str | The database username to use. Default is None.
+ password: str | The database user password to use. Default is None.
+ connect_timeout: int | The timeout to set for the connection. Default is 10.
+ embedding_function: Callable | The embedding function used to generate the vector representation
+ of the documents. Default is None.
+ metadata: dict | The metadata of the vector database. Default is None. If None, it will use this
+ setting: {"hnsw:space": "ip", "hnsw:construction_ef": 30, "hnsw:M": 16}. Creates Index on table
+ using hnsw (embedding vector_l2_ops) WITH (m = hnsw:M) ef_construction = "hnsw:construction_ef".
+ For more info: https://github.com/pgvector/pgvector?tab=readme-ov-file#hnsw
+ model_name: str | Sentence embedding model to use. Models can be chosen from:
+ https://huggingface.co/models?library=sentence-transformers
+
+ Returns:
+ None
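+
+        Example (an illustrative sketch; the credentials, host, and database name are placeholders):
+            ```python
+            db = PGVectorDB(connection_string="postgresql://postgres:mysecret@localhost:5432/postgres")
+            # or, equivalently, with discrete connection parameters:
+            db = PGVectorDB(host="localhost", port=5432, dbname="postgres",
+                            username="postgres", password="mysecret")
+            ```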
+ """
+ try:
+ if connection_string:
+ parsed_connection = urllib.parse.urlparse(connection_string)
+ encoded_username = urllib.parse.quote(parsed_connection.username, safe="")
+ encoded_password = urllib.parse.quote(parsed_connection.password, safe="")
+ encoded_host = urllib.parse.quote(parsed_connection.hostname, safe="")
+ encoded_database = urllib.parse.quote(parsed_connection.path[1:], safe="")
+ connection_string_encoded = (
+ f"{parsed_connection.scheme}://{encoded_username}:{encoded_password}"
+ f"@{encoded_host}:{parsed_connection.port}/{encoded_database}"
+ )
+ self.client = psycopg.connect(conninfo=connection_string_encoded, autocommit=True)
+ elif host and port and dbname:
+ self.client = psycopg.connect(
+ host=host,
+ port=port,
+ dbname=dbname,
+                    user=username,
+ password=password,
+ connect_timeout=connect_timeout,
+ autocommit=True,
+ )
+ except psycopg.Error as e:
+ logger.error("Error connecting to the database: ", e)
+ raise e
+ self.model_name = model_name
+ try:
+ self.embedding_function = (
+ SentenceTransformer(self.model_name) if embedding_function is None else embedding_function
+ )
+ except Exception as e:
+ logger.error(
+ f"Validate the model name entered: {self.model_name} "
+ f"from https://huggingface.co/models?library=sentence-transformers\nError: {e}"
+ )
+ raise e
+ self.metadata = metadata
+ self.client.execute("CREATE EXTENSION IF NOT EXISTS vector")
+ register_vector(self.client)
+ self.active_collection = None
+
+ def create_collection(
+ self, collection_name: str, overwrite: bool = False, get_or_create: bool = True
+ ) -> Collection:
+ """
+ Create a collection in the vector database.
+ Case 1. if the collection does not exist, create the collection.
+ Case 2. the collection exists, if overwrite is True, it will overwrite the collection.
+ Case 3. the collection exists and overwrite is False, if get_or_create is True, it will get the collection,
+ otherwise it raise a ValueError.
+
+ Args:
+ collection_name: str | The name of the collection.
+ overwrite: bool | Whether to overwrite the collection if it exists. Default is False.
+ get_or_create: bool | Whether to get the collection if it exists. Default is True.
+
+ Returns:
+ Collection | The collection object.
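+
+        Example (an illustrative sketch; the collection name is a placeholder):
+            ```python
+            # Reuses the table if it already exists (get_or_create=True is the default).
+            collection = db.create_collection("autogen_docs")
+            # Drops and recreates the table when overwrite=True.
+            collection = db.create_collection("autogen_docs", overwrite=True)
+            ```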
+ """
+ try:
+ if self.active_collection and self.active_collection.name == collection_name:
+ collection = self.active_collection
+ else:
+ collection = self.get_collection(collection_name)
+ except ValueError:
+ collection = None
+ if collection is None:
+ collection = Collection(
+ client=self.client,
+ collection_name=collection_name,
+ embedding_function=self.embedding_function,
+ get_or_create=get_or_create,
+ metadata=self.metadata,
+ model_name=self.model_name,
+ )
+ collection.set_collection_name(collection_name=collection_name)
+ collection.create_collection(collection_name=collection_name)
+ return collection
+ elif overwrite:
+ self.delete_collection(collection_name)
+ collection = Collection(
+ client=self.client,
+ collection_name=collection_name,
+ embedding_function=self.embedding_function,
+ get_or_create=get_or_create,
+ metadata=self.metadata,
+ model_name=self.model_name,
+ )
+ collection.set_collection_name(collection_name=collection_name)
+ collection.create_collection(collection_name=collection_name)
+ return collection
+ elif get_or_create:
+ return collection
+ elif not collection.table_exists(table_name=collection_name):
+ collection = Collection(
+ client=self.client,
+ collection_name=collection_name,
+ embedding_function=self.embedding_function,
+ get_or_create=get_or_create,
+ metadata=self.metadata,
+ model_name=self.model_name,
+ )
+ collection.set_collection_name(collection_name=collection_name)
+ collection.create_collection(collection_name=collection_name)
+ return collection
+ else:
+ raise ValueError(f"Collection {collection_name} already exists.")
+
+ def get_collection(self, collection_name: str = None) -> Collection:
+ """
+ Get the collection from the vector database.
+
+ Args:
+ collection_name: str | The name of the collection. Default is None. If None, return the
+ current active collection.
+
+ Returns:
+ Collection | The collection object.
+ """
+ if collection_name is None:
+ if self.active_collection is None:
+ raise ValueError("No collection is specified.")
+ else:
+ logger.debug(
+ f"No collection is specified. Using current active collection {self.active_collection.name}."
+ )
+ else:
+ if not (self.active_collection and self.active_collection.name == collection_name):
+ self.active_collection = Collection(
+ client=self.client,
+ collection_name=collection_name,
+ embedding_function=self.embedding_function,
+ model_name=self.model_name,
+ )
+ return self.active_collection
+
+ def delete_collection(self, collection_name: str) -> None:
+ """
+ Delete the collection from the vector database.
+
+ Args:
+ collection_name: str | The name of the collection.
+
+ Returns:
+ None
+ """
+ if self.active_collection:
+ self.active_collection.delete_collection(collection_name)
+ else:
+ collection = self.get_collection(collection_name)
+ collection.delete_collection(collection_name)
+ if self.active_collection and self.active_collection.name == collection_name:
+ self.active_collection = None
+
+ def _batch_insert(
+ self, collection: Collection, embeddings=None, ids=None, metadatas=None, documents=None, upsert=False
+ ) -> None:
+ batch_size = int(PGVECTOR_MAX_BATCH_SIZE)
+ default_metadata = {"hnsw:space": "ip", "hnsw:construction_ef": 32, "hnsw:M": 16}
+ default_metadatas = [default_metadata] * min(batch_size, len(documents))
+ for i in range(0, len(documents), min(batch_size, len(documents))):
+ end_idx = i + min(batch_size, len(documents) - i)
+ collection_kwargs = {
+ "documents": documents[i:end_idx],
+ "ids": ids[i:end_idx],
+ "metadatas": metadatas[i:end_idx] if metadatas else default_metadatas,
+ "embeddings": embeddings[i:end_idx] if embeddings else None,
+ }
+ if upsert:
+ collection.upsert(**collection_kwargs)
+ else:
+ collection.add(**collection_kwargs)
+
+ def insert_docs(self, docs: List[Document], collection_name: str = None, upsert: bool = False) -> None:
+ """
+ Insert documents into the collection of the vector database.
+
+ Args:
+ docs: List[Document] | A list of documents. Each document is a TypedDict `Document`.
+ collection_name: str | The name of the collection. Default is None.
+ upsert: bool | Whether to update the document if it exists. Default is False.
+ kwargs: Dict | Additional keyword arguments.
+
+ Returns:
+ None
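+
+        Example (an illustrative sketch; the ids and contents are placeholders):
+            ```python
+            db.insert_docs(
+                [
+                    {"id": "doc_0001", "content": "PGVector stores embeddings in PostgreSQL."},
+                    {"id": "doc_0002", "content": "AutoGen agents can retrieve context from a vector DB."},
+                ],
+                collection_name="autogen_docs",
+            )
+            ```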
+ """
+ if not docs:
+ return
+ if docs[0].get("content") is None:
+ raise ValueError("The document content is required.")
+ if docs[0].get("id") is None:
+ raise ValueError("The document id is required.")
+ documents = [doc.get("content") for doc in docs]
+ ids = [doc.get("id") for doc in docs]
+
+ collection = self.get_collection(collection_name)
+ if docs[0].get("embedding") is None:
+ logger.debug(
+ "No content embedding is provided. "
+ "Will use the VectorDB's embedding function to generate the content embedding."
+ )
+ embeddings = None
+ else:
+ embeddings = [doc.get("embedding") for doc in docs]
+ if docs[0].get("metadata") is None:
+ metadatas = None
+ else:
+ metadatas = [doc.get("metadata") for doc in docs]
+
+ self._batch_insert(collection, embeddings, ids, metadatas, documents, upsert)
+
+ def update_docs(self, docs: List[Document], collection_name: str = None) -> None:
+ """
+ Update documents in the collection of the vector database.
+
+ Args:
+ docs: List[Document] | A list of documents.
+ collection_name: str | The name of the collection. Default is None.
+
+ Returns:
+ None
+ """
+ self.insert_docs(docs, collection_name, upsert=True)
+
+ def delete_docs(self, ids: List[ItemID], collection_name: str = None) -> None:
+ """
+ Delete documents from the collection of the vector database.
+
+ Args:
+ ids: List[ItemID] | A list of document ids. Each id is a typed `ItemID`.
+ collection_name: str | The name of the collection. Default is None.
+ kwargs: Dict | Additional keyword arguments.
+
+ Returns:
+ None
+ """
+ collection = self.get_collection(collection_name)
+ collection.delete(ids=ids, collection_name=collection_name)
+
+ def retrieve_docs(
+ self,
+ queries: List[str],
+ collection_name: str = None,
+ n_results: int = 10,
+ distance_threshold: float = -1,
+ ) -> QueryResults:
+ """
+ Retrieve documents from the collection of the vector database based on the queries.
+
+ Args:
+ queries: List[str] | A list of queries. Each query is a string.
+ collection_name: str | The name of the collection. Default is None.
+ n_results: int | The number of relevant documents to return. Default is 10.
+ distance_threshold: float | The threshold for the distance score, only distance smaller than it will be
+ returned. Don't filter with it if < 0. Default is -1.
+ kwargs: Dict | Additional keyword arguments.
+
+ Returns:
+ QueryResults | The query results. Each query result is a list of list of tuples containing the document and
+ the distance.
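+
+        Example (an illustrative sketch; the query string and collection name are placeholders):
+            ```python
+            results = db.retrieve_docs(["What is pgvector?"], collection_name="autogen_docs", n_results=5)
+            top_doc, top_distance = results[0][0]  # best match for the first query
+            ```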
+ """
+ collection = self.get_collection(collection_name)
+ if isinstance(queries, str):
+ queries = [queries]
+ results = collection.query(
+ query_texts=queries,
+ n_results=n_results,
+ distance_threshold=distance_threshold,
+ )
+ logger.debug(f"Retrieve Docs Results:\n{results}")
+ return results
+
+ def get_docs_by_ids(
+ self, ids: List[ItemID] = None, collection_name: str = None, include=None, **kwargs
+ ) -> List[Document]:
+ """
+ Retrieve documents from the collection of the vector database based on the ids.
+
+ Args:
+ ids: List[ItemID] | A list of document ids. If None, will return all the documents. Default is None.
+ collection_name: str | The name of the collection. Default is None.
+ include: List[str] | The fields to include. Default is None.
+ If None, will include ["metadatas", "documents"], ids will always be included.
+ kwargs: dict | Additional keyword arguments.
+
+ Returns:
+ List[Document] | The results.
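+
+        Example (an illustrative sketch; the ids are placeholders):
+            ```python
+            docs = db.get_docs_by_ids(["doc_0001", "doc_0002"], collection_name="autogen_docs")
+            ```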
+ """
+ collection = self.get_collection(collection_name)
+ include = include if include else ["metadatas", "documents"]
+ results = collection.get(ids, include=include, **kwargs)
+ logger.debug(f"Retrieve Documents by ID Results:\n{results}")
+ return results
diff --git a/autogen/agentchat/contrib/vectordb/utils.py b/autogen/agentchat/contrib/vectordb/utils.py
new file mode 100644
index 00000000000..3dcf79f1f55
--- /dev/null
+++ b/autogen/agentchat/contrib/vectordb/utils.py
@@ -0,0 +1,115 @@
+import logging
+from typing import Any, Dict, List
+
+from termcolor import colored
+
+from .base import QueryResults
+
+
+class ColoredLogger(logging.Logger):
+ def __init__(self, name, level=logging.NOTSET):
+ super().__init__(name, level)
+
+ def debug(self, msg, *args, color=None, **kwargs):
+ super().debug(colored(msg, color), *args, **kwargs)
+
+ def info(self, msg, *args, color=None, **kwargs):
+ super().info(colored(msg, color), *args, **kwargs)
+
+ def warning(self, msg, *args, color="yellow", **kwargs):
+ super().warning(colored(msg, color), *args, **kwargs)
+
+ def error(self, msg, *args, color="light_red", **kwargs):
+ super().error(colored(msg, color), *args, **kwargs)
+
+ def critical(self, msg, *args, color="red", **kwargs):
+ super().critical(colored(msg, color), *args, **kwargs)
+
+ def fatal(self, msg, *args, color="red", **kwargs):
+ super().fatal(colored(msg, color), *args, **kwargs)
+
+
+def get_logger(name: str, level: int = logging.INFO) -> ColoredLogger:
+ logger = ColoredLogger(name, level)
+ console_handler = logging.StreamHandler()
+ logger.addHandler(console_handler)
+ formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
+ logger.handlers[0].setFormatter(formatter)
+ return logger
+
+
+logger = get_logger(__name__)
+
+
+def filter_results_by_distance(results: QueryResults, distance_threshold: float = -1) -> QueryResults:
+ """Filters results based on a distance threshold.
+
+ Args:
+ results: QueryResults | The query results. List[List[Tuple[Document, float]]]
+ distance_threshold: The maximum distance allowed for results.
+
+ Returns:
+        QueryResults | The filtered results, containing only entries with distances smaller than the threshold.
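+
+    Example (an illustrative sketch; the documents and distances are placeholders):
+        ```python
+        results = [[({"id": "1", "content": "a"}, 0.1), ({"id": "2", "content": "b"}, 0.8)]]
+        filter_results_by_distance(results, 0.5)  # -> [[({"id": "1", "content": "a"}, 0.1)]]
+        ```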
+ """
+
+ if distance_threshold > 0:
+ results = [[(key, value) for key, value in data if value < distance_threshold] for data in results]
+
+ return results
+
+
+def chroma_results_to_query_results(data_dict: Dict[str, List[List[Any]]], special_key="distances") -> QueryResults:
+ """Converts a dictionary with list-of-list values to a list of tuples.
+
+ Args:
+ data_dict: A dictionary where keys map to lists of lists or None.
+ special_key: The key in the dictionary containing the special values
+ for each tuple.
+
+ Returns:
+ A list of tuples, where each tuple contains a sub-dictionary with
+ some keys from the original dictionary and the value from the
+ special_key.
+
+ Example:
+ data_dict = {
+ "key1s": [[1, 2, 3], [4, 5, 6], [7, 8, 9]],
+ "key2s": [["a", "b", "c"], ["c", "d", "e"], ["e", "f", "g"]],
+ "key3s": None,
+ "key4s": [["x", "y", "z"], ["1", "2", "3"], ["4", "5", "6"]],
+ "distances": [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6], [0.7, 0.8, 0.9]],
+ }
+
+ results = [
+ [
+ ({"key1": 1, "key2": "a", "key4": "x"}, 0.1),
+ ({"key1": 2, "key2": "b", "key4": "y"}, 0.2),
+ ({"key1": 3, "key2": "c", "key4": "z"}, 0.3),
+ ],
+ [
+ ({"key1": 4, "key2": "c", "key4": "1"}, 0.4),
+ ({"key1": 5, "key2": "d", "key4": "2"}, 0.5),
+ ({"key1": 6, "key2": "e", "key4": "3"}, 0.6),
+ ],
+ [
+ ({"key1": 7, "key2": "e", "key4": "4"}, 0.7),
+ ({"key1": 8, "key2": "f", "key4": "5"}, 0.8),
+ ({"key1": 9, "key2": "g", "key4": "6"}, 0.9),
+ ],
+ ]
+ """
+
+ keys = [key for key in data_dict if key != special_key]
+ result = []
+
+ for i in range(len(data_dict[special_key])):
+ sub_result = []
+ for j, distance in enumerate(data_dict[special_key][i]):
+ sub_dict = {}
+ for key in keys:
+ if data_dict[key] is not None and len(data_dict[key]) > i:
+ sub_dict[key[:-1]] = data_dict[key][i][j] # remove 's' in the end from key
+ sub_result.append((sub_dict, distance))
+ result.append(sub_result)
+
+ return result
diff --git a/autogen/agentchat/contrib/web_surfer.py b/autogen/agentchat/contrib/web_surfer.py
index 6cd71dc636d..1a54aeebe15 100644
--- a/autogen/agentchat/contrib/web_surfer.py
+++ b/autogen/agentchat/contrib/web_surfer.py
@@ -1,16 +1,18 @@
-import json
import copy
+import json
import logging
import re
from dataclasses import dataclass
-from typing import Any, Dict, List, Optional, Union, Callable, Literal, Tuple
+from datetime import datetime
+from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union
+
from typing_extensions import Annotated
-from ... import Agent, ConversableAgent, AssistantAgent, UserProxyAgent, GroupChatManager, GroupChat, OpenAIWrapper
+
+from ... import Agent, AssistantAgent, ConversableAgent, GroupChat, GroupChatManager, OpenAIWrapper, UserProxyAgent
from ...browser_utils import SimpleTextBrowser
from ...code_utils import content_str
-from datetime import datetime
-from ...token_count_utils import count_token, get_max_token_limit
from ...oai.openai_utils import filter_config
+from ...token_count_utils import count_token, get_max_token_limit
logger = logging.getLogger(__name__)
diff --git a/autogen/agentchat/conversable_agent.py b/autogen/agentchat/conversable_agent.py
index 9a99386b84f..c3394a96bb6 100644
--- a/autogen/agentchat/conversable_agent.py
+++ b/autogen/agentchat/conversable_agent.py
@@ -7,7 +7,6 @@
import re
import warnings
from collections import defaultdict
-from functools import partial
from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Type, TypeVar, Union
from openai import BadRequestError
@@ -17,6 +16,7 @@
from .._pydantic import model_dump
from ..cache.cache import AbstractCache
from ..code_utils import (
+ PYTHON_VARIANTS,
UNKNOWN,
check_can_use_docker_or_throw,
content_str,
@@ -29,10 +29,10 @@
from ..coding.factory import CodeExecutorFactory
from ..formatting_utils import colored
from ..function_utils import get_function_schema, load_basemodels_if_needed, serialize_to_str
+from ..io.base import IOStream
from ..oai.client import ModelClient, OpenAIWrapper
-from ..runtime_logging import log_new_agent, logging_enabled
+from ..runtime_logging import log_event, log_new_agent, logging_enabled
from .agent import Agent, LLMAgent
-from ..io.base import IOStream
from .chat import ChatResult, a_initiate_chats, initiate_chats
from .utils import consolidate_chat_info, gather_usage_summary
@@ -76,6 +76,7 @@ def __init__(
llm_config: Optional[Union[Dict, Literal[False]]] = None,
default_auto_reply: Union[str, Dict] = "",
description: Optional[str] = None,
+ chat_messages: Optional[Dict[Agent, List[Dict]]] = None,
):
"""
Args:
@@ -121,6 +122,9 @@ def __init__(
default_auto_reply (str or dict): default auto reply when no code execution or llm-based reply is generated.
description (str): a short description of the agent. This description is used by other agents
(e.g. the GroupChatManager) to decide when to call upon this agent. (Default: system_message)
+ chat_messages (dict or None): the previous chat messages that this agent had in the past with other agents.
+ Can be used to give the agent a memory by providing the chat history. This will allow the agent to
+            resume previously held conversations. Defaults to an empty chat history.
"""
# we change code_execution_config below and we have to make sure we don't change the input
# in case of UserProxyAgent, without this we could even change the default value {}
@@ -130,7 +134,11 @@ def __init__(
self._name = name
# a dictionary of conversations, default value is list
- self._oai_messages = defaultdict(list)
+ if chat_messages is None:
+ self._oai_messages = defaultdict(list)
+ else:
+ self._oai_messages = chat_messages
+
self._oai_system_message = [{"content": system_message, "role": "system"}]
self._description = description if description is not None else system_message
self._is_termination_msg = (
@@ -138,6 +146,15 @@ def __init__(
if is_termination_msg is not None
else (lambda x: content_str(x.get("content")) == "TERMINATE")
)
+ # Take a copy to avoid modifying the given dict
+ if isinstance(llm_config, dict):
+ try:
+ llm_config = copy.deepcopy(llm_config)
+ except TypeError as e:
+ raise TypeError(
+ "Please implement __deepcopy__ method for each value class in llm_config to support deepcopy."
+ " Refer to the docs for more details: https://microsoft.github.io/autogen/docs/topics/llm_configuration#adding-http-client-in-llm_config-for-proxy"
+ ) from e
self._validate_llm_config(llm_config)
@@ -416,10 +433,15 @@ def reply_func_from_nested_chats(
reply_func_from_nested_chats = self._summary_from_nested_chats
if not callable(reply_func_from_nested_chats):
raise ValueError("reply_func_from_nested_chats must be a callable")
- reply_func = partial(reply_func_from_nested_chats, chat_queue)
+
+ def wrapped_reply_func(recipient, messages=None, sender=None, config=None):
+ return reply_func_from_nested_chats(chat_queue, recipient, messages, sender, config)
+
+ functools.update_wrapper(wrapped_reply_func, reply_func_from_nested_chats)
+
self.register_reply(
trigger,
- reply_func,
+ wrapped_reply_func,
position,
kwargs.get("config"),
kwargs.get("reset_config"),
@@ -564,6 +586,11 @@ def _append_oai_message(self, message: Union[Dict, str], role, conversation_id:
if message.get("role") in ["function", "tool"]:
oai_message["role"] = message.get("role")
+ elif "override_role" in message:
+ # If we have a direction to override the role then set the
+ # role accordingly. Used to customise the role for the
+ # select speaker prompt.
+ oai_message["role"] = message.get("override_role")
else:
oai_message["role"] = role
@@ -745,6 +772,9 @@ def _print_received_message(self, message: Union[Dict, str], sender: Agent):
def _process_received_message(self, message: Union[Dict, str], sender: Agent, silent: bool):
# When the agent receives a message, the role of the message is "user". (If 'role' exists and is 'function', it will remain unchanged.)
valid = self._append_oai_message(message, "user", sender)
+ if logging_enabled():
+ log_event(self, "received_message", message=message, sender=sender.name, valid=valid)
+
if not valid:
raise ValueError(
"Received message can't be converted into a valid ChatCompletion message. Either content or function_call must be provided."
@@ -907,6 +937,7 @@ def my_summary_method(
One example key is "summary_prompt", and value is a string of text used to prompt a LLM-based agent (the sender or receiver agent) to reflect
on the conversation and extract a summary when summary_method is "reflection_with_llm".
The default summary_prompt is DEFAULT_SUMMARY_PROMPT, i.e., "Summarize takeaway from the conversation. Do not add any introductory phrases. If the intended request is NOT properly addressed, please point it out."
+ Another available key is "summary_role", which is the role of the message sent to the agent in charge of summarizing. Default is "system".
message (str, dict or Callable): the initial message to be sent to the recipient. Needs to be provided. Otherwise, input() will be called to get the initial message.
- If a string or a dict is provided, it will be used as the initial message. `generate_init_message` is called to generate the initial message for the agent based on this string and the context.
If dict, it may contain the following reserved fields (either content or tool_calls need to be provided).
@@ -1116,11 +1147,18 @@ def my_summary_method(
@staticmethod
def _last_msg_as_summary(sender, recipient, summary_args) -> str:
"""Get a chat summary from the last message of the recipient."""
+ summary = ""
try:
- summary = recipient.last_message(sender)["content"].replace("TERMINATE", "")
+ content = recipient.last_message(sender)["content"]
+ if isinstance(content, str):
+ summary = content.replace("TERMINATE", "")
+ elif isinstance(content, list):
+ # Remove the `TERMINATE` word in the content list.
+ summary = "\n".join(
+ x["text"].replace("TERMINATE", "") for x in content if isinstance(x, dict) and "text" in x
+ )
except (IndexError, AttributeError) as e:
warnings.warn(f"Cannot extract summary using last_msg: {e}. Using an empty str as summary.", UserWarning)
- summary = ""
return summary
@staticmethod
@@ -1131,8 +1169,13 @@ def _reflection_with_llm_as_summary(sender, recipient, summary_args):
raise ValueError("The summary_prompt must be a string.")
msg_list = recipient.chat_messages_for_summary(sender)
agent = sender if recipient is None else recipient
+ role = summary_args.get("summary_role", None)
+ if role and not isinstance(role, str):
+ raise ValueError("The summary_role in summary_arg must be a string.")
try:
- summary = sender._reflection_with_llm(prompt, msg_list, llm_agent=agent, cache=summary_args.get("cache"))
+ summary = sender._reflection_with_llm(
+ prompt, msg_list, llm_agent=agent, cache=summary_args.get("cache"), role=role
+ )
except BadRequestError as e:
warnings.warn(
f"Cannot extract summary using reflection_with_llm: {e}. Using an empty str as summary.", UserWarning
@@ -1141,7 +1184,12 @@ def _reflection_with_llm_as_summary(sender, recipient, summary_args):
return summary
def _reflection_with_llm(
- self, prompt, messages, llm_agent: Optional[Agent] = None, cache: Optional[AbstractCache] = None
+ self,
+ prompt,
+ messages,
+ llm_agent: Optional[Agent] = None,
+ cache: Optional[AbstractCache] = None,
+ role: Union[str, None] = None,
) -> str:
"""Get a chat summary using reflection with an llm client based on the conversation history.
@@ -1150,10 +1198,14 @@ def _reflection_with_llm(
messages (list): The messages generated as part of a chat conversation.
llm_agent: the agent with an llm client.
cache (AbstractCache or None): the cache client to be used for this conversation.
+ role (str): the role of the message, usually "system" or "user". Default is "system".
"""
+ if not role:
+ role = "system"
+
system_msg = [
{
- "role": "system",
+ "role": role,
"content": prompt,
}
]
@@ -1168,6 +1220,23 @@ def _reflection_with_llm(
response = self._generate_oai_reply_from_client(llm_client=llm_client, messages=messages, cache=cache)
return response
+ def _check_chat_queue_for_sender(self, chat_queue: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
+ """
+ Check the chat queue and add the "sender" key if it's missing.
+
+ Args:
+ chat_queue (List[Dict[str, Any]]): A list of dictionaries containing chat information.
+
+ Returns:
+ List[Dict[str, Any]]: A new list of dictionaries with the "sender" key added if it was missing.
+ """
+ chat_queue_with_sender = []
+ for chat_info in chat_queue:
+ if chat_info.get("sender") is None:
+ chat_info["sender"] = self
+ chat_queue_with_sender.append(chat_info)
+ return chat_queue_with_sender
+
def initiate_chats(self, chat_queue: List[Dict[str, Any]]) -> List[ChatResult]:
"""(Experimental) Initiate chats with multiple agents.
@@ -1177,16 +1246,12 @@ def initiate_chats(self, chat_queue: List[Dict[str, Any]]) -> List[ChatResult]:
Returns: a list of ChatResult objects corresponding to the finished chats in the chat_queue.
"""
- _chat_queue = chat_queue.copy()
- for chat_info in _chat_queue:
- chat_info["sender"] = self
+ _chat_queue = self._check_chat_queue_for_sender(chat_queue)
self._finished_chats = initiate_chats(_chat_queue)
return self._finished_chats
async def a_initiate_chats(self, chat_queue: List[Dict[str, Any]]) -> Dict[int, ChatResult]:
- _chat_queue = chat_queue.copy()
- for chat_info in _chat_queue:
- chat_info["sender"] = self
+ _chat_queue = self._check_chat_queue_for_sender(chat_queue)
self._finished_chats = await a_initiate_chats(_chat_queue)
return self._finished_chats
@@ -1299,7 +1364,7 @@ def _generate_oai_reply_from_client(self, llm_client, messages, cache) -> Union[
extracted_response = llm_client.extract_text_or_completion_object(response)[0]
if extracted_response is None:
- warnings.warn("Extracted_response from {response} is None.", UserWarning)
+ warnings.warn(f"Extracted_response from {response} is None.", UserWarning)
return None
# ensure function and tool calls will be accepted when sent back to the LLM
if not isinstance(extracted_response, str) and hasattr(extracted_response, "model_dump"):
@@ -1907,6 +1972,15 @@ def generate_reply(
continue
if self._match_trigger(reply_func_tuple["trigger"], sender):
final, reply = reply_func(self, messages=messages, sender=sender, config=reply_func_tuple["config"])
+ if logging_enabled():
+ log_event(
+ self,
+ "reply_func_executed",
+ reply_func_module=reply_func.__module__,
+ reply_func_name=reply_func.__name__,
+ final=final,
+ reply=reply,
+ )
if final:
return reply
return self._default_auto_reply
@@ -2076,7 +2150,7 @@ def execute_code_blocks(self, code_blocks):
)
if lang in ["bash", "shell", "sh"]:
exitcode, logs, image = self.run_code(code, lang=lang, **self._code_execution_config)
- elif lang in ["python", "Python"]:
+ elif lang in PYTHON_VARIANTS:
if code.startswith("# filename: "):
filename = code[11 : code.find("\n")].strip()
else:
@@ -2259,30 +2333,54 @@ def generate_init_message(self, message: Union[Dict, str, None], **kwargs) -> Un
"""
if message is None:
message = self.get_human_input(">")
+
+ return self._handle_carryover(message, kwargs)
+
+ def _handle_carryover(self, message: Union[str, Dict], kwargs: dict) -> Union[str, Dict]:
+ if not kwargs.get("carryover"):
+ return message
+
if isinstance(message, str):
return self._process_carryover(message, kwargs)
+
elif isinstance(message, dict):
- message = message.copy()
- # TODO: Do we need to do the following?
- # if message.get("content") is None:
- # message["content"] = self.get_human_input(">")
- message["content"] = self._process_carryover(message.get("content", ""), kwargs)
- return message
+ if isinstance(message.get("content"), str):
+ # Makes sure the original message is not mutated
+ message = message.copy()
+ message["content"] = self._process_carryover(message["content"], kwargs)
+ elif isinstance(message.get("content"), list):
+ # Makes sure the original message is not mutated
+ message = message.copy()
+ message["content"] = self._process_multimodal_carryover(message["content"], kwargs)
+ else:
+ raise InvalidCarryOverType("Carryover should be a string or a list of strings.")
- def _process_carryover(self, message: str, kwargs: dict) -> str:
- carryover = kwargs.get("carryover")
- if carryover:
- # if carryover is string
- if isinstance(carryover, str):
- message += "\nContext: \n" + carryover
- elif isinstance(carryover, list):
- message += "\nContext: \n" + ("\n").join([t for t in carryover])
- else:
- raise InvalidCarryOverType(
- "Carryover should be a string or a list of strings. Not adding carryover to the message."
- )
return message
+ def _process_carryover(self, content: str, kwargs: dict) -> str:
+ # Makes sure there's a carryover
+ if not kwargs.get("carryover"):
+ return content
+
+ # if carryover is string
+ if isinstance(kwargs["carryover"], str):
+ content += "\nContext: \n" + kwargs["carryover"]
+ elif isinstance(kwargs["carryover"], list):
+ content += "\nContext: \n" + ("\n").join([t for t in kwargs["carryover"]])
+ else:
+ raise InvalidCarryOverType(
+ "Carryover should be a string or a list of strings. Not adding carryover to the message."
+ )
+ return content
+
+ def _process_multimodal_carryover(self, content: List[Dict], kwargs: dict) -> List[Dict]:
+ """Prepends the context to a multimodal message."""
+ # Makes sure there's a carryover
+ if not kwargs.get("carryover"):
+ return content
+
+ return [{"type": "text", "text": self._process_carryover("", kwargs)}] + content
+
async def a_generate_init_message(self, message: Union[Dict, str, None], **kwargs) -> Union[str, Dict]:
"""Generate the initial message for the agent.
If message is None, input() will be called to get the initial message.
@@ -2295,12 +2393,8 @@ async def a_generate_init_message(self, message: Union[Dict, str, None], **kwarg
"""
if message is None:
message = await self.a_get_human_input(">")
- if isinstance(message, str):
- return self._process_carryover(message, kwargs)
- elif isinstance(message, dict):
- message = message.copy()
- message["content"] = self._process_carryover(message["content"], kwargs)
- return message
+
+ return self._handle_carryover(message, kwargs)
def register_function(self, function_map: Dict[str, Union[Callable, None]]):
"""Register functions to the agent.
diff --git a/autogen/agentchat/groupchat.py b/autogen/agentchat/groupchat.py
index 6c0ecec90fe..b1141cfdacf 100644
--- a/autogen/agentchat/groupchat.py
+++ b/autogen/agentchat/groupchat.py
@@ -1,3 +1,5 @@
+import copy
+import json
import logging
import random
import re
@@ -5,13 +7,15 @@
from dataclasses import dataclass, field
from typing import Callable, Dict, List, Literal, Optional, Tuple, Union
-from .agent import Agent
-from .conversable_agent import ConversableAgent
-from ..io.base import IOStream
from ..code_utils import content_str
from ..exception_utils import AgentNameConflict, NoEligibleSpeaker, UndefinedNextAgent
+from ..formatting_utils import colored
from ..graph_utils import check_graph_validity, invert_disallowed_to_allowed
+from ..io.base import IOStream
from ..runtime_logging import log_new_agent, logging_enabled
+from .agent import Agent
+from .chat import ChatResult
+from .conversable_agent import ConversableAgent
logger = logging.getLogger(__name__)
@@ -28,6 +32,28 @@ class GroupChat:
When set to True and when a message is a function call suggestion,
the next speaker will be chosen from an agent which contains the corresponding function name
in its `function_map`.
+        - select_speaker_message_template: customize the select speaker message (used in "auto" speaker selection), which appears first in the message context and generally includes the agent descriptions and list of agents. If the string contains "{roles}" it will be replaced with the agents' names and their role descriptions. If the string contains "{agentlist}" it will be replaced with a comma-separated list of agent names in square brackets. The default value is:
+ "You are in a role play game. The following roles are available:
+ {roles}.
+ Read the following conversation.
+ Then select the next role from {agentlist} to play. Only return the role."
+ - select_speaker_prompt_template: customize the select speaker prompt (used in "auto" speaker selection), which appears last in the message context and generally includes the list of agents and guidance for the LLM to select the next agent. If the string contains "{agentlist}" it will be replaced with a comma-separated list of agent names in square brackets. The default value is:
+ "Read the above conversation. Then select the next role from {agentlist} to play. Only return the role."
+ - select_speaker_auto_multiple_template: customize the follow-up prompt used when selecting a speaker fails with a response that contains multiple agent names. This prompt guides the LLM to return just one agent name. Applies only to "auto" speaker selection method. If the string contains "{agentlist}" it will be replaced with a comma-separated list of agent names in square brackets. The default value is:
+ "You provided more than one name in your text, please return just the name of the next speaker. To determine the speaker use these prioritised rules:
+ 1. If the context refers to themselves as a speaker e.g. "As the..." , choose that speaker's name
+ 2. If it refers to the "next" speaker name, choose that name
+ 3. Otherwise, choose the first provided speaker's name in the context
+ The names are case-sensitive and should not be abbreviated or changed.
+ Respond with ONLY the name of the speaker and DO NOT provide a reason."
+ - select_speaker_auto_none_template: customize the follow-up prompt used when selecting a speaker fails with a response that contains no agent names. This prompt guides the LLM to return an agent name and provides a list of agent names. Applies only to "auto" speaker selection method. If the string contains "{agentlist}" it will be replaced with a comma-separated list of agent names in square brackets. The default value is:
+ "You didn't choose a speaker. As a reminder, to determine the speaker use these prioritised rules:
+ 1. If the context refers to themselves as a speaker e.g. "As the..." , choose that speaker's name
+ 2. If it refers to the "next" speaker name, choose that name
+ 3. Otherwise, choose the first provided speaker's name in the context
+ The names are case-sensitive and should not be abbreviated or changed.
+ The only names that are accepted are {agentlist}.
+ Respond with ONLY the name of the speaker and DO NOT provide a reason."
- speaker_selection_method: the method for selecting the next speaker. Default is "auto".
Could be any of the following (case insensitive), will raise ValueError if not recognized:
- "auto": the next speaker is selected automatically by LLM.
@@ -44,6 +70,15 @@ def custom_speaker_selection_func(
last_speaker: Agent, groupchat: GroupChat
) -> Union[Agent, str, None]:
```
+ - max_retries_for_selecting_speaker: the maximum number of times the speaker selection requery process will run.
+ If, during speaker selection, multiple agent names or no agent names are returned by the LLM as the next agent, it will be queried again up to the maximum number
+ of times until a single agent is returned or it exhausts the maximum attempts.
+ Applies only to "auto" speaker selection method.
+ Default is 2.
+          - select_speaker_auto_verbose: whether to output the select speaker responses and selections.
+            If set to True, the outputs from the two agents in the nested select speaker chat will be output, along with
+            whether the responses were successful, or not, in selecting an agent.
+            Applies only to "auto" speaker selection method.
- allow_repeat_speaker: whether to allow the same speaker to speak consecutively.
Default is True, in which case all speakers are allowed to speak consecutively.
If `allow_repeat_speaker` is a list of Agents, then only those listed agents are allowed to repeat.
@@ -61,6 +96,7 @@ def custom_speaker_selection_func(
"clear history" phrase in user prompt. This is experimental feature.
See description of GroupChatManager.clear_agents_history function for more info.
- send_introductions: send a round of introductions at the start of the group chat, so agents know who they can speak to (default: False)
+ - role_for_select_speaker_messages: sets the role name for speaker selection when in 'auto' mode, typically 'user' or 'system'. (default: 'system')
"""
agents: List[Agent]
@@ -69,11 +105,34 @@ def custom_speaker_selection_func(
admin_name: Optional[str] = "Admin"
func_call_filter: Optional[bool] = True
speaker_selection_method: Union[Literal["auto", "manual", "random", "round_robin"], Callable] = "auto"
+ max_retries_for_selecting_speaker: Optional[int] = 2
allow_repeat_speaker: Optional[Union[bool, List[Agent]]] = None
allowed_or_disallowed_speaker_transitions: Optional[Dict] = None
speaker_transitions_type: Literal["allowed", "disallowed", None] = None
enable_clear_history: Optional[bool] = False
send_introductions: bool = False
+ select_speaker_message_template: str = """You are in a role play game. The following roles are available:
+ {roles}.
+ Read the following conversation.
+ Then select the next role from {agentlist} to play. Only return the role."""
+ select_speaker_prompt_template: str = (
+ "Read the above conversation. Then select the next role from {agentlist} to play. Only return the role."
+ )
+ select_speaker_auto_multiple_template: str = """You provided more than one name in your text, please return just the name of the next speaker. To determine the speaker use these prioritised rules:
+ 1. If the context refers to themselves as a speaker e.g. "As the..." , choose that speaker's name
+ 2. If it refers to the "next" speaker name, choose that name
+ 3. Otherwise, choose the first provided speaker's name in the context
+ The names are case-sensitive and should not be abbreviated or changed.
+ Respond with ONLY the name of the speaker and DO NOT provide a reason."""
+ select_speaker_auto_none_template: str = """You didn't choose a speaker. As a reminder, to determine the speaker use these prioritised rules:
+ 1. If the context refers to themselves as a speaker e.g. "As the..." , choose that speaker's name
+ 2. If it refers to the "next" speaker name, choose that name
+ 3. Otherwise, choose the first provided speaker's name in the context
+ The names are case-sensitive and should not be abbreviated or changed.
+ The only names that are accepted are {agentlist}.
+ Respond with ONLY the name of the speaker and DO NOT provide a reason."""
+ select_speaker_auto_verbose: Optional[bool] = False
+ role_for_select_speaker_messages: Optional[str] = "system"
_VALID_SPEAKER_SELECTION_METHODS = ["auto", "manual", "random", "round_robin"]
_VALID_SPEAKER_TRANSITIONS_TYPE = ["allowed", "disallowed", None]
@@ -162,6 +221,37 @@ def __post_init__(self):
agents=self.agents,
)
+ # Check select speaker messages, prompts, roles, and retries have values
+ if self.select_speaker_message_template is None or len(self.select_speaker_message_template) == 0:
+ raise ValueError("select_speaker_message_template cannot be empty or None.")
+
+ if self.select_speaker_prompt_template is None or len(self.select_speaker_prompt_template) == 0:
+ raise ValueError("select_speaker_prompt_template cannot be empty or None.")
+
+ if self.role_for_select_speaker_messages is None or len(self.role_for_select_speaker_messages) == 0:
+ raise ValueError("role_for_select_speaker_messages cannot be empty or None.")
+
+ if self.select_speaker_auto_multiple_template is None or len(self.select_speaker_auto_multiple_template) == 0:
+ raise ValueError("select_speaker_auto_multiple_template cannot be empty or None.")
+
+ if self.select_speaker_auto_none_template is None or len(self.select_speaker_auto_none_template) == 0:
+ raise ValueError("select_speaker_auto_none_template cannot be empty or None.")
+
+ # Validate max select speakers retries
+ if self.max_retries_for_selecting_speaker is None or not isinstance(
+ self.max_retries_for_selecting_speaker, int
+ ):
+ raise ValueError("max_retries_for_selecting_speaker cannot be None or non-int")
+ elif self.max_retries_for_selecting_speaker < 0:
+ raise ValueError("max_retries_for_selecting_speaker must be greater than or equal to zero")
+
+ # Validate select_speaker_auto_verbose
+ if self.select_speaker_auto_verbose is None or not isinstance(self.select_speaker_auto_verbose, bool):
+ raise ValueError("select_speaker_auto_verbose cannot be None or non-bool")
+
@property
def agent_names(self) -> List[str]:
"""Return the names of the agents in the group chat."""
@@ -232,17 +322,22 @@ def select_speaker_msg(self, agents: Optional[List[Agent]] = None) -> str:
"""Return the system message for selecting the next speaker. This is always the *first* message in the context."""
if agents is None:
agents = self.agents
- return f"""You are in a role play game. The following roles are available:
-{self._participant_roles(agents)}.
-Read the following conversation.
-Then select the next role from {[agent.name for agent in agents]} to play. Only return the role."""
+ roles = self._participant_roles(agents)
+ agentlist = f"{[agent.name for agent in agents]}"
+
+ return_msg = self.select_speaker_message_template.format(roles=roles, agentlist=agentlist)
+ return return_msg
def select_speaker_prompt(self, agents: Optional[List[Agent]] = None) -> str:
"""Return the floating system prompt selecting the next speaker. This is always the *last* message in the context."""
if agents is None:
agents = self.agents
- return f"Read the above conversation. Then select the next role from {[agent.name for agent in agents]} to play. Only return the role."
+
+ agentlist = f"{[agent.name for agent in agents]}"
+
+ return_prompt = self.select_speaker_prompt_template.format(agentlist=agentlist)
+ return return_prompt
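Roughly, with two agents named 'Planner' and 'Writer' and the default templates, the two methods above would produce something like the following (a sketch; the role descriptions come from each agent's system message via `_participant_roles`):

```python
msg = groupchat.select_speaker_msg()  # the first (system) message of the selection chat
prompt = groupchat.select_speaker_prompt()  # the last (floating) prompt of the selection chat

# prompt would read approximately:
# "Read the above conversation. Then select the next role from ['Planner', 'Writer'] to play.
#  Only return the role."
```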
def introductions_msg(self, agents: Optional[List[Agent]] = None) -> str:
"""Return the system message for selecting the next speaker. This is always the *first* message in the context."""
@@ -411,7 +506,7 @@ def _prepare_and_select_agents(
selected_agent = self.next_agent(last_speaker, graph_eligible_agents)
elif speaker_selection_method.lower() == "random":
selected_agent = self.random_select_speaker(graph_eligible_agents)
- else:
+ else: # auto
selected_agent = None
select_speaker_messages = self.messages.copy()
# If last message is a tool call or function call, blank the call so the api doesn't throw
@@ -419,30 +514,34 @@ def _prepare_and_select_agents(
select_speaker_messages[-1] = dict(select_speaker_messages[-1], function_call=None)
if select_speaker_messages[-1].get("tool_calls", False):
select_speaker_messages[-1] = dict(select_speaker_messages[-1], tool_calls=None)
- select_speaker_messages = select_speaker_messages + [
- {"role": "system", "content": self.select_speaker_prompt(graph_eligible_agents)}
- ]
return selected_agent, graph_eligible_agents, select_speaker_messages
def select_speaker(self, last_speaker: Agent, selector: ConversableAgent) -> Agent:
- """Select the next speaker."""
+ """Select the next speaker (with requery)."""
+
+ # Prepare the list of available agents and select an agent if selection method allows (non-auto)
selected_agent, agents, messages = self._prepare_and_select_agents(last_speaker)
if selected_agent:
return selected_agent
- # auto speaker selection
- selector.update_system_message(self.select_speaker_msg(agents))
- final, name = selector.generate_oai_reply(messages)
- return self._finalize_speaker(last_speaker, final, name, agents)
+ elif self.speaker_selection_method == "manual":
+ # An agent has not been selected while in manual mode, so move to the next agent
+ return self.next_agent(last_speaker)
+
+ # auto speaker selection with 2-agent chat
+ return self._auto_select_speaker(last_speaker, selector, messages, agents)
async def a_select_speaker(self, last_speaker: Agent, selector: ConversableAgent) -> Agent:
- """Select the next speaker."""
+ """Select the next speaker (with requery), asynchronously."""
+
selected_agent, agents, messages = self._prepare_and_select_agents(last_speaker)
if selected_agent:
return selected_agent
- # auto speaker selection
- selector.update_system_message(self.select_speaker_msg(agents))
- final, name = await selector.a_generate_oai_reply(messages)
- return self._finalize_speaker(last_speaker, final, name, agents)
+ elif self.speaker_selection_method == "manual":
+ # An agent has not been selected while in manual mode, so move to the next agent
+ return self.next_agent(last_speaker)
+
+ # auto speaker selection with 2-agent chat
+ return await self.a_auto_select_speaker(last_speaker, selector, messages, agents)
def _finalize_speaker(self, last_speaker: Agent, final: bool, name: str, agents: Optional[List[Agent]]) -> Agent:
if not final:
@@ -462,6 +561,296 @@ def _finalize_speaker(self, last_speaker: Agent, final: bool, name: str, agents:
agent = self.agent_by_name(name)
return agent if agent else self.next_agent(last_speaker, agents)
+ def _auto_select_speaker(
+ self,
+ last_speaker: Agent,
+ selector: ConversableAgent,
+ messages: Optional[List[Dict]],
+ agents: Optional[List[Agent]],
+ ) -> Agent:
+ """Selects next speaker for the "auto" speaker selection method. Utilises its own two-agent chat to determine the next speaker and supports requerying.
+
+ Speaker selection for "auto" speaker selection method:
+ 1. Create a two-agent chat with a speaker selector agent and a speaker validator agent, like a nested chat
+ 2. Inject the group messages into the new chat
+        3. Run the two-agent chat, evaluating the result of the response from the speaker selector agent:
+ - If a single agent is provided then we return it and finish. If not, we add an additional message to this nested chat in an attempt to guide the LLM to a single agent response
+ 4. Chat continues until a single agent is nominated or there are no more attempts left
+ 5. If we run out of turns and no single agent can be determined, the next speaker in the list of agents is returned
+
+ Args:
+ last_speaker Agent: The previous speaker in the group chat
+            selector ConversableAgent: The agent whose llm_config is used to run the speaker-selection chat
+ messages Optional[List[Dict]]: Current chat messages
+ agents Optional[List[Agent]]: Valid list of agents for speaker selection
+
+ Returns:
+            Agent: The agent selected to speak next.
+ """
+
+ # If no agents are passed in, assign all the group chat's agents
+ if agents is None:
+ agents = self.agents
+
+ # The maximum number of speaker selection attempts (including requeries)
+ # is the initial speaker selection attempt plus the maximum number of retries.
+ # We track these and use them in the validation function as we can't
+ # access the max_turns from within validate_speaker_name.
+ max_attempts = 1 + self.max_retries_for_selecting_speaker
+ attempts_left = max_attempts
+ attempt = 0
+
+ # Registered reply function for checking_agent, checks the result of the response for agent names
+ def validate_speaker_name(recipient, messages, sender, config) -> Tuple[bool, Union[str, Dict, None]]:
+
+ # The number of retries left, starting at max_retries_for_selecting_speaker
+ nonlocal attempts_left
+ nonlocal attempt
+
+ attempt = attempt + 1
+ attempts_left = attempts_left - 1
+
+ return self._validate_speaker_name(recipient, messages, sender, config, attempts_left, attempt, agents)
+
+ # Two-agent chat for speaker selection
+
+        # Agent for checking the response from the speaker_selection_agent
+ checking_agent = ConversableAgent("checking_agent", default_auto_reply=max_attempts)
+
+ # Register the speaker validation function with the checking agent
+ checking_agent.register_reply(
+ [ConversableAgent, None],
+ reply_func=validate_speaker_name, # Validate each response
+ remove_other_reply_funcs=True,
+ )
+
+ # Agent for selecting a single agent name from the response
+ speaker_selection_agent = ConversableAgent(
+ "speaker_selection_agent",
+ system_message=self.select_speaker_msg(agents),
+ chat_messages={checking_agent: messages},
+ llm_config=selector.llm_config,
+ human_input_mode="NEVER", # Suppresses some extra terminal outputs, outputs will be handled by select_speaker_auto_verbose
+ )
+
+ # Run the speaker selection chat
+ result = checking_agent.initiate_chat(
+ speaker_selection_agent,
+ cache=None, # don't use caching for the speaker selection chat
+ message={
+ "content": self.select_speaker_prompt(agents),
+ "override_role": self.role_for_select_speaker_messages,
+ },
+ max_turns=2
+ * max(1, max_attempts), # Limiting the chat to the number of attempts, including the initial one
+ clear_history=False,
+ silent=not self.select_speaker_auto_verbose, # Base silence on the verbose attribute
+ )
+
+ return self._process_speaker_selection_result(result, last_speaker, agents)
+
+ async def a_auto_select_speaker(
+ self,
+ last_speaker: Agent,
+ selector: ConversableAgent,
+ messages: Optional[List[Dict]],
+ agents: Optional[List[Agent]],
+ ) -> Agent:
+ """(Asynchronous) Selects next speaker for the "auto" speaker selection method. Utilises its own two-agent chat to determine the next speaker and supports requerying.
+
+ Speaker selection for "auto" speaker selection method:
+ 1. Create a two-agent chat with a speaker selector agent and a speaker validator agent, like a nested chat
+ 2. Inject the group messages into the new chat
+        3. Run the two-agent chat, evaluating the result of the response from the speaker selector agent:
+ - If a single agent is provided then we return it and finish. If not, we add an additional message to this nested chat in an attempt to guide the LLM to a single agent response
+ 4. Chat continues until a single agent is nominated or there are no more attempts left
+ 5. If we run out of turns and no single agent can be determined, the next speaker in the list of agents is returned
+
+ Args:
+ last_speaker Agent: The previous speaker in the group chat
+            selector ConversableAgent: The agent whose llm_config is used to run the speaker-selection chat
+ messages Optional[List[Dict]]: Current chat messages
+ agents Optional[List[Agent]]: Valid list of agents for speaker selection
+
+ Returns:
+            Agent: The agent selected to speak next.
+ """
+
+ # If no agents are passed in, assign all the group chat's agents
+ if agents is None:
+ agents = self.agents
+
+ # The maximum number of speaker selection attempts (including requeries)
+ # We track these and use them in the validation function as we can't
+ # access the max_turns from within validate_speaker_name
+ max_attempts = 1 + self.max_retries_for_selecting_speaker
+ attempts_left = max_attempts
+ attempt = 0
+
+ # Registered reply function for checking_agent, checks the result of the response for agent names
+ def validate_speaker_name(recipient, messages, sender, config) -> Tuple[bool, Union[str, Dict, None]]:
+
+ # The number of retries left, starting at max_retries_for_selecting_speaker
+ nonlocal attempts_left
+ nonlocal attempt
+
+ attempt = attempt + 1
+ attempts_left = attempts_left - 1
+
+ return self._validate_speaker_name(recipient, messages, sender, config, attempts_left, attempt, agents)
+
+ # Two-agent chat for speaker selection
+
+        # Agent for checking the response from the speaker_selection_agent
+ checking_agent = ConversableAgent("checking_agent", default_auto_reply=max_attempts)
+
+ # Register the speaker validation function with the checking agent
+ checking_agent.register_reply(
+ [ConversableAgent, None],
+ reply_func=validate_speaker_name, # Validate each response
+ remove_other_reply_funcs=True,
+ )
+
+ # Agent for selecting a single agent name from the response
+ speaker_selection_agent = ConversableAgent(
+ "speaker_selection_agent",
+ system_message=self.select_speaker_msg(agents),
+ chat_messages={checking_agent: messages},
+ llm_config=selector.llm_config,
+ human_input_mode="NEVER", # Suppresses some extra terminal outputs, outputs will be handled by select_speaker_auto_verbose
+ )
+
+ # Run the speaker selection chat
+ result = await checking_agent.a_initiate_chat(
+ speaker_selection_agent,
+ cache=None, # don't use caching for the speaker selection chat
+            message={
+                "content": self.select_speaker_prompt(agents),
+                "override_role": self.role_for_select_speaker_messages,
+            },
+ max_turns=2
+ * max(1, max_attempts), # Limiting the chat to the number of attempts, including the initial one
+ clear_history=False,
+ silent=not self.select_speaker_auto_verbose, # Base silence on the verbose attribute
+ )
+
+ return self._process_speaker_selection_result(result, last_speaker, agents)
+
+ def _validate_speaker_name(
+ self, recipient, messages, sender, config, attempts_left, attempt, agents
+ ) -> Tuple[bool, Union[str, Dict, None]]:
+ """Validates the speaker response for each round in the internal 2-agent
+ chat within the auto select speaker method.
+
+ Used by auto_select_speaker and a_auto_select_speaker.
+ """
+
+ # Output the query and requery results
+ if self.select_speaker_auto_verbose:
+ iostream = IOStream.get_default()
+
+ # Validate the speaker name selected
+ select_name = messages[-1]["content"].strip()
+
+ mentions = self._mentioned_agents(select_name, agents)
+
+ if len(mentions) == 1:
+
+ # Success on retry, we have just one name mentioned
+ selected_agent_name = next(iter(mentions))
+
+ # Add the selected agent to the response so we can return it
+ messages.append({"role": "user", "content": f"[AGENT SELECTED]{selected_agent_name}"})
+
+ if self.select_speaker_auto_verbose:
+ iostream.print(
+ colored(
+ f">>>>>>>> Select speaker attempt {attempt} of {attempt + attempts_left} successfully selected: {selected_agent_name}",
+ "green",
+ ),
+ flush=True,
+ )
+
+ elif len(mentions) > 1:
+ # More than one name on requery so add additional reminder prompt for next retry
+
+ if self.select_speaker_auto_verbose:
+ iostream.print(
+ colored(
+ f">>>>>>>> Select speaker attempt {attempt} of {attempt + attempts_left} failed as it included multiple agent names.",
+ "red",
+ ),
+ flush=True,
+ )
+
+ if attempts_left:
+ # Message to return to the chat for the next attempt
+ agentlist = f"{[agent.name for agent in agents]}"
+
+ return True, {
+ "content": self.select_speaker_auto_multiple_template.format(agentlist=agentlist),
+ "override_role": self.role_for_select_speaker_messages,
+ }
+ else:
+ # Final failure, no attempts left
+ messages.append(
+ {
+ "role": "user",
+ "content": f"[AGENT SELECTION FAILED]Select speaker attempt #{attempt} of {attempt + attempts_left} failed as it returned multiple names.",
+ }
+ )
+
+ else:
+ # No names at all on requery so add additional reminder prompt for next retry
+
+ if self.select_speaker_auto_verbose:
+ iostream.print(
+ colored(
+ f">>>>>>>> Select speaker attempt #{attempt} failed as it did not include any agent names.",
+ "red",
+ ),
+ flush=True,
+ )
+
+ if attempts_left:
+ # Message to return to the chat for the next attempt
+ agentlist = f"{[agent.name for agent in agents]}"
+
+ return True, {
+ "content": self.select_speaker_auto_none_template.format(agentlist=agentlist),
+ "override_role": self.role_for_select_speaker_messages,
+ }
+ else:
+ # Final failure, no attempts left
+ messages.append(
+ {
+ "role": "user",
+ "content": f"[AGENT SELECTION FAILED]Select speaker attempt #{attempt} of {attempt + attempts_left} failed as it did not include any agent names.",
+ }
+ )
+
+ return True, None
+
+ def _process_speaker_selection_result(self, result, last_speaker: ConversableAgent, agents: Optional[List[Agent]]):
+ """Checks the result of the auto_select_speaker function, returning the
+ agent to speak.
+
+ Used by auto_select_speaker and a_auto_select_speaker."""
+ if len(result.chat_history) > 0:
+
+ # Use the final message, which will have the selected agent or reason for failure
+ final_message = result.chat_history[-1]["content"]
+
+ if "[AGENT SELECTED]" in final_message:
+
+ # Have successfully selected an agent, return it
+ return self.agent_by_name(final_message.replace("[AGENT SELECTED]", ""))
+
+ else: # "[AGENT SELECTION FAILED]"
+
+ # Failed to select an agent, so we'll select the next agent in the list
+ next_agent = self.next_agent(last_speaker, agents)
+
+            # Return the fallback agent
+ return next_agent
+
def _participant_roles(self, agents: List[Agent] = None) -> str:
# Default to all agents registered
if agents is None:
@@ -478,6 +867,10 @@ def _participant_roles(self, agents: List[Agent] = None) -> str:
def _mentioned_agents(self, message_content: Union[str, List], agents: Optional[List[Agent]]) -> Dict:
"""Counts the number of times each agent is mentioned in the provided message content.
+ Agent names will match under any of the following conditions (all case-sensitive):
+ - Exact name match
+ - If the agent name has underscores it will match with spaces instead (e.g. 'Story_writer' == 'Story writer')
+ - If the agent name has underscores it will match with '\\_' instead of '_' (e.g. 'Story_writer' == 'Story\\_writer')
Args:
message_content (Union[str, List]): The content of the message, either as a single string or a list of strings.
@@ -496,9 +889,17 @@ def _mentioned_agents(self, message_content: Union[str, List], agents: Optional[
mentions = dict()
for agent in agents:
+ # Finds agent mentions, taking word boundaries into account,
+ # accommodates escaping underscores and underscores as spaces
regex = (
- r"(?<=\W)" + re.escape(agent.name) + r"(?=\W)"
- ) # Finds agent mentions, taking word boundaries into account
+ r"(?<=\W)("
+ + re.escape(agent.name)
+ + r"|"
+ + re.escape(agent.name.replace("_", " "))
+ + r"|"
+ + re.escape(agent.name.replace("_", r"\_"))
+ + r")(?=\W)"
+ )
count = len(re.findall(regex, f" {message_content} ")) # Pad the message to help with matching
if count > 0:
mentions[agent.name] = count
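A small, self-contained illustration of the widened matching (the agent name and message here are hypothetical):

```python
import re

agent_name = "Story_writer"
regex = (
    r"(?<=\W)("
    + re.escape(agent_name)
    + r"|"
    + re.escape(agent_name.replace("_", " "))
    + r"|"
    + re.escape(agent_name.replace("_", r"\_"))
    + r")(?=\W)"
)
message_content = "Next, Story writer should draft the outline."
# The space variant now matches, so the count is 1 instead of 0.
print(len(re.findall(regex, f" {message_content} ")))
```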
@@ -718,6 +1119,290 @@ async def a_run_chat(
a.previous_cache = None
return True, None
+ def resume(
+ self,
+ messages: Union[List[Dict], str],
+ remove_termination_string: str = None,
+ silent: Optional[bool] = False,
+ ) -> Tuple[ConversableAgent, Dict]:
+ """Resumes a group chat using the previous messages as a starting point. Requires the agents, group chat, and group chat manager to be established
+ as per the original group chat.
+
+ Args:
+ - messages Union[List[Dict], str]: The content of the previous chat's messages, either as a Json string or a list of message dictionaries.
+ - remove_termination_string str: Remove the provided string from the last message to prevent immediate termination
+ - silent (bool or None): (Experimental) whether to print the messages for this conversation. Default is False.
+
+ Returns:
+ - Tuple[ConversableAgent, Dict]: A tuple containing the last agent who spoke and their message
+ """
+
+ # Convert messages from string to messages list, if needed
+ if isinstance(messages, str):
+ messages = self.messages_from_string(messages)
+ elif isinstance(messages, list) and all(isinstance(item, dict) for item in messages):
+ messages = copy.deepcopy(messages)
+ else:
+ raise Exception("Messages is not of type str or List[Dict]")
+
+ # Clean up the objects, ensuring there are no messages in the agents and group chat
+
+ # Clear agent message history
+ for agent in self._groupchat.agents:
+ if isinstance(agent, ConversableAgent):
+ agent.clear_history()
+
+ # Clear Manager message history
+ self.clear_history()
+
+ # Clear GroupChat messages
+ self._groupchat.reset()
+
+ # Validation of message and agents
+
+        self._valid_resume_messages(messages)
+
+ # Load the messages into the group chat
+ for i, message in enumerate(messages):
+
+ if "name" in message:
+ message_speaker_agent = self._groupchat.agent_by_name(message["name"])
+ else:
+                # If there's no name, assign the group chat manager (this indicates that ChatResult messages were used as state instead of groupchat.messages)
+ message_speaker_agent = self
+ message["name"] = self.name
+
+ # If it wasn't an agent speaking, it may be the manager
+ if not message_speaker_agent and message["name"] == self.name:
+ message_speaker_agent = self
+
+ # Add previous messages to each agent (except their own messages and the last message, as we'll kick off the conversation with it)
+ if i != len(messages) - 1:
+ for agent in self._groupchat.agents:
+ if agent.name != message["name"]:
+ self.send(message, self._groupchat.agent_by_name(agent.name), request_reply=False, silent=True)
+
+ # Add previous message to the new groupchat, if it's an admin message the name may not match so add the message directly
+ if message_speaker_agent:
+ self._groupchat.append(message, message_speaker_agent)
+ else:
+ self._groupchat.messages.append(message)
+
+ # Last speaker agent
+ last_speaker_name = message["name"]
+
+ # Last message to check for termination (we could avoid this by ignoring termination check for resume in the future)
+ last_message = message
+
+ # Get last speaker as an agent
+ previous_last_agent = self._groupchat.agent_by_name(name=last_speaker_name)
+
+ # If we didn't match a last speaker agent, we check that it's the group chat's admin name and assign the manager, if so
+ if not previous_last_agent and (
+ last_speaker_name == self._groupchat.admin_name or last_speaker_name == self.name
+ ):
+ previous_last_agent = self
+
+ # Termination removal and check
+ self._process_resume_termination(remove_termination_string, messages)
+
+ if not silent:
+ iostream = IOStream.get_default()
+ iostream.print(
+ f"Prepared group chat with {len(messages)} messages, the last speaker is",
+ colored(last_speaker_name, "yellow"),
+ flush=True,
+ )
+
+ # Update group chat settings for resuming
+ self._groupchat.send_introductions = False
+
+ return previous_last_agent, last_message
+
+ async def a_resume(
+ self,
+ messages: Union[List[Dict], str],
+ remove_termination_string: str = None,
+ silent: Optional[bool] = False,
+ ) -> Tuple[ConversableAgent, Dict]:
+ """Resumes a group chat using the previous messages as a starting point, asynchronously. Requires the agents, group chat, and group chat manager to be established
+ as per the original group chat.
+
+ Args:
+ - messages Union[List[Dict], str]: The content of the previous chat's messages, either as a Json string or a list of message dictionaries.
+ - remove_termination_string str: Remove the provided string from the last message to prevent immediate termination
+ - silent (bool or None): (Experimental) whether to print the messages for this conversation. Default is False.
+
+ Returns:
+ - Tuple[ConversableAgent, Dict]: A tuple containing the last agent who spoke and their message
+ """
+
+ # Convert messages from string to messages list, if needed
+ if isinstance(messages, str):
+ messages = self.messages_from_string(messages)
+ elif isinstance(messages, list) and all(isinstance(item, dict) for item in messages):
+ messages = copy.deepcopy(messages)
+ else:
+ raise Exception("Messages is not of type str or List[Dict]")
+
+ # Clean up the objects, ensuring there are no messages in the agents and group chat
+
+ # Clear agent message history
+ for agent in self._groupchat.agents:
+ if isinstance(agent, ConversableAgent):
+ agent.clear_history()
+
+ # Clear Manager message history
+ self.clear_history()
+
+ # Clear GroupChat messages
+ self._groupchat.reset()
+
+ # Validation of message and agents
+
+        self._valid_resume_messages(messages)
+
+ # Load the messages into the group chat
+ for i, message in enumerate(messages):
+
+ if "name" in message:
+ message_speaker_agent = self._groupchat.agent_by_name(message["name"])
+ else:
+                # If there's no name, assign the group chat manager (this indicates that ChatResult messages were used as state instead of groupchat.messages)
+ message_speaker_agent = self
+ message["name"] = self.name
+
+ # If it wasn't an agent speaking, it may be the manager
+ if not message_speaker_agent and message["name"] == self.name:
+ message_speaker_agent = self
+
+ # Add previous messages to each agent (except their own messages and the last message, as we'll kick off the conversation with it)
+ if i != len(messages) - 1:
+ for agent in self._groupchat.agents:
+ if agent.name != message["name"]:
+ await self.a_send(
+ message, self._groupchat.agent_by_name(agent.name), request_reply=False, silent=True
+ )
+
+ # Add previous message to the new groupchat, if it's an admin message the name may not match so add the message directly
+ if message_speaker_agent:
+ self._groupchat.append(message, message_speaker_agent)
+ else:
+ self._groupchat.messages.append(message)
+
+ # Last speaker agent
+ last_speaker_name = message["name"]
+
+ # Last message to check for termination (we could avoid this by ignoring termination check for resume in the future)
+ last_message = message
+
+ # Get last speaker as an agent
+ previous_last_agent = self._groupchat.agent_by_name(name=last_speaker_name)
+
+ # If we didn't match a last speaker agent, we check that it's the group chat's admin name and assign the manager, if so
+ if not previous_last_agent and (
+ last_speaker_name == self._groupchat.admin_name or last_speaker_name == self.name
+ ):
+ previous_last_agent = self
+
+ # Termination removal and check
+ self._process_resume_termination(remove_termination_string, messages)
+
+ if not silent:
+ iostream = IOStream.get_default()
+ iostream.print(
+ f"Prepared group chat with {len(messages)} messages, the last speaker is",
+ colored(last_speaker_name, "yellow"),
+ flush=True,
+ )
+
+ # Update group chat settings for resuming
+ self._groupchat.send_introductions = False
+
+ return previous_last_agent, last_message
+
+ def _valid_resume_messages(self, messages: List[Dict]):
+ """Validates the messages used for resuming
+
+ args:
+ messages (List[Dict]): list of messages to resume with
+
+ returns:
+            - None: raises an exception if the messages are not valid for resuming
+ """
+ # Must have messages to start with, otherwise they should run run_chat
+ if not messages:
+ raise Exception(
+ "Cannot resume group chat as no messages were provided. Use GroupChatManager.run_chat or ConversableAgent.initiate_chat to start a new chat."
+ )
+
+ # Check that all agents in the chat messages exist in the group chat
+ for message in messages:
+ if message.get("name"):
+ if (
+ not self._groupchat.agent_by_name(message["name"])
+ and not message["name"] == self._groupchat.admin_name # ignore group chat's name
+ and not message["name"] == self.name # ignore group chat manager's name
+ ):
+ raise Exception(f"Agent name in message doesn't exist as agent in group chat: {message['name']}")
+
+ def _process_resume_termination(self, remove_termination_string: str, messages: List[Dict]):
+ """Removes termination string, if required, and checks if termination may occur.
+
+ args:
+            remove_termination_string (str): termination string to remove from the last message
+            messages (List[Dict]): the list of messages to resume with; the last one is checked for termination
+
+ returns:
+ None
+ """
+
+ last_message = messages[-1]
+
+ # Replace any given termination string in the last message
+ if remove_termination_string:
+ if messages[-1].get("content") and remove_termination_string in messages[-1]["content"]:
+ messages[-1]["content"] = messages[-1]["content"].replace(remove_termination_string, "")
+
+ # Check if the last message meets termination (if it has one)
+ if self._is_termination_msg:
+ if self._is_termination_msg(last_message):
+ logger.warning("WARNING: Last message meets termination criteria and this may terminate the chat.")
+
+ def messages_from_string(self, message_string: str) -> List[Dict]:
+ """Reads the saved state of messages in Json format for resume and returns as a messages list
+
+ args:
+ - message_string: Json string, the saved state
+
+ returns:
+ - List[Dict]: List of messages
+ """
+ try:
+ state = json.loads(message_string)
+ except json.JSONDecodeError:
+ raise Exception("Messages string is not a valid JSON string")
+
+ return state
+
+ def messages_to_string(self, messages: List[Dict]) -> str:
+ """Converts the provided messages into a Json string that can be used for resuming the chat.
+ The state is made up of a list of messages
+
+ args:
+ - messages (List[Dict]): set of messages to convert to a string
+
+ returns:
+ - str: Json representation of the messages which can be persisted for resuming later
+ """
+
+ return json.dumps(messages)
+
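An end-to-end sketch of the resume flow, under the assumption that the agents, group chat, and manager have been re-created exactly as in the original run (names and the termination string are placeholders):

```python
# At the end of the original run, persist the state:
previous_state = manager.messages_to_string(groupchat.messages)

# Later, after re-creating the same agents, groupchat, and manager:
last_agent, last_message = manager.resume(
    messages=previous_state,
    remove_termination_string="TERMINATE",
)

# Continue the conversation from where it left off.
last_agent.initiate_chat(manager, message=last_message, clear_history=False)
```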
def _raise_exception_on_async_reply_functions(self) -> None:
"""Raise an exception if any async reply functions are registered.
diff --git a/autogen/agentchat/user_proxy_agent.py b/autogen/agentchat/user_proxy_agent.py
index d1d7f89ab2b..a80296a8355 100644
--- a/autogen/agentchat/user_proxy_agent.py
+++ b/autogen/agentchat/user_proxy_agent.py
@@ -1,7 +1,7 @@
from typing import Callable, Dict, List, Literal, Optional, Union
+from ..runtime_logging import log_new_agent, logging_enabled
from .conversable_agent import ConversableAgent
-from ..runtime_logging import logging_enabled, log_new_agent
class UserProxyAgent(ConversableAgent):
diff --git a/autogen/agentchat/utils.py b/autogen/agentchat/utils.py
index eef3741605d..b32c2f5f0a0 100644
--- a/autogen/agentchat/utils.py
+++ b/autogen/agentchat/utils.py
@@ -1,5 +1,5 @@
import re
-from typing import Any, Callable, Dict, List, Tuple, Union
+from typing import Any, Callable, Dict, List, Union
from .agent import Agent
@@ -26,33 +26,46 @@ def consolidate_chat_info(chat_info, uniform_sender=None) -> None:
), "llm client must be set in either the recipient or sender when summary_method is reflection_with_llm."
-def gather_usage_summary(agents: List[Agent]) -> Tuple[Dict[str, any], Dict[str, any]]:
+def gather_usage_summary(agents: List[Agent]) -> Dict[str, Dict[str, Any]]:
r"""Gather usage summary from all agents.
Args:
agents: (list): List of agents.
Returns:
- tuple: (total_usage_summary, actual_usage_summary)
+ dictionary: A dictionary containing two keys:
+ - "usage_including_cached_inference": Cost information on the total usage, including the tokens in cached inference.
+ - "usage_excluding_cached_inference": Cost information on the usage of tokens, excluding the tokens in cache. No larger than "usage_including_cached_inference".
Example:
```python
- total_usage_summary = {
- "total_cost": 0.0006090000000000001,
- "gpt-35-turbo": {
- "cost": 0.0006090000000000001,
- "prompt_tokens": 242,
- "completion_tokens": 123,
- "total_tokens": 365
+ {
+ "usage_including_cached_inference" : {
+ "total_cost": 0.0006090000000000001,
+ "gpt-35-turbo": {
+ "cost": 0.0006090000000000001,
+ "prompt_tokens": 242,
+ "completion_tokens": 123,
+ "total_tokens": 365
+ },
+ },
+
+ "usage_excluding_cached_inference" : {
+ "total_cost": 0.0006090000000000001,
+ "gpt-35-turbo": {
+ "cost": 0.0006090000000000001,
+ "prompt_tokens": 242,
+ "completion_tokens": 123,
+ "total_tokens": 365
+ },
}
}
```
Note:
- `actual_usage_summary` follows the same format.
- If none of the agents incurred any cost (not having a client), then the total_usage_summary and actual_usage_summary will be `{'total_cost': 0}`.
+ If none of the agents incurred any cost (not having a client), then the usage_including_cached_inference and usage_excluding_cached_inference will be `{'total_cost': 0}`.
"""
def aggregate_summary(usage_summary: Dict[str, Any], agent_summary: Dict[str, Any]) -> None:
@@ -69,15 +82,18 @@ def aggregate_summary(usage_summary: Dict[str, Any], agent_summary: Dict[str, An
usage_summary[model]["completion_tokens"] += data.get("completion_tokens", 0)
usage_summary[model]["total_tokens"] += data.get("total_tokens", 0)
- total_usage_summary = {"total_cost": 0}
- actual_usage_summary = {"total_cost": 0}
+ usage_including_cached_inference = {"total_cost": 0}
+ usage_excluding_cached_inference = {"total_cost": 0}
for agent in agents:
if getattr(agent, "client", None):
- aggregate_summary(total_usage_summary, agent.client.total_usage_summary)
- aggregate_summary(actual_usage_summary, agent.client.actual_usage_summary)
+ aggregate_summary(usage_including_cached_inference, agent.client.total_usage_summary)
+ aggregate_summary(usage_excluding_cached_inference, agent.client.actual_usage_summary)
- return total_usage_summary, actual_usage_summary
+ return {
+ "usage_including_cached_inference": usage_including_cached_inference,
+ "usage_excluding_cached_inference": usage_excluding_cached_inference,
+ }
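Callers that previously unpacked a tuple would now index the returned dictionary, for example (a sketch assuming `assistant` and `user_proxy` agents exist):

```python
usage = gather_usage_summary([assistant, user_proxy])
total_cost = usage["usage_including_cached_inference"]["total_cost"]
actual_cost = usage["usage_excluding_cached_inference"]["total_cost"]
```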
def parse_tags_from_content(tag: str, content: Union[str, List[Dict[str, Any]]]) -> List[Dict[str, Dict[str, str]]]:
diff --git a/autogen/browser_utils.py b/autogen/browser_utils.py
index 41d2d62f825..99e51fcd4ca 100644
--- a/autogen/browser_utils.py
+++ b/autogen/browser_utils.py
@@ -1,14 +1,15 @@
+import io
import json
+import mimetypes
import os
-import requests
import re
-import markdownify
-import io
import uuid
-import mimetypes
+from typing import Any, Dict, List, Optional, Tuple, Union
from urllib.parse import urljoin, urlparse
+
+import markdownify
+import requests
from bs4 import BeautifulSoup
-from typing import Any, Dict, List, Optional, Union, Tuple
# Optional PDF support
IS_PDF_CAPABLE = False
@@ -35,6 +36,7 @@ def __init__(
start_page: Optional[str] = None,
viewport_size: Optional[int] = 1024 * 8,
downloads_folder: Optional[Union[str, None]] = None,
+ bing_base_url: str = "https://api.bing.microsoft.com/v7.0/search",
bing_api_key: Optional[Union[str, None]] = None,
request_kwargs: Optional[Union[Dict[str, Any], None]] = None,
):
@@ -46,6 +48,7 @@ def __init__(
self.viewport_current_page = 0
self.viewport_pages: List[Tuple[int, int]] = list()
self.set_address(self.start_page)
+ self.bing_base_url = bing_base_url
self.bing_api_key = bing_api_key
self.request_kwargs = request_kwargs
@@ -144,7 +147,7 @@ def _bing_api_call(self, query: str) -> Dict[str, Dict[str, List[Dict[str, Union
request_kwargs["stream"] = False
# Make the request
- response = requests.get("https://api.bing.microsoft.com/v7.0/search", **request_kwargs)
+ response = requests.get(self.bing_base_url, **request_kwargs)
response.raise_for_status()
results = response.json()
diff --git a/autogen/cache/__init__.py b/autogen/cache/__init__.py
index febfa8c7c5d..ea547e20c8e 100644
--- a/autogen/cache/__init__.py
+++ b/autogen/cache/__init__.py
@@ -1,3 +1,4 @@
-from .cache import Cache, AbstractCache
+from .abstract_cache_base import AbstractCache
+from .cache import Cache
__all__ = ["Cache", "AbstractCache"]
diff --git a/autogen/cache/abstract_cache_base.py b/autogen/cache/abstract_cache_base.py
index ebf1cecfa40..cfe501083fa 100644
--- a/autogen/cache/abstract_cache_base.py
+++ b/autogen/cache/abstract_cache_base.py
@@ -1,6 +1,6 @@
+import sys
from types import TracebackType
from typing import Any, Optional, Protocol, Type
-import sys
if sys.version_info >= (3, 11):
from typing import Self
diff --git a/autogen/cache/cache.py b/autogen/cache/cache.py
index 31bbfa13529..6a15d993ff6 100644
--- a/autogen/cache/cache.py
+++ b/autogen/cache/cache.py
@@ -1,13 +1,12 @@
from __future__ import annotations
+
+import sys
from types import TracebackType
-from typing import Dict, Any, Optional, Type, Union
+from typing import Any, Dict, Optional, Type, TypedDict, Union
from .abstract_cache_base import AbstractCache
-
from .cache_factory import CacheFactory
-import sys
-
if sys.version_info >= (3, 11):
from typing import Self
else:
@@ -27,7 +26,12 @@ class Cache(AbstractCache):
cache: The cache instance created based on the provided configuration.
"""
- ALLOWED_CONFIG_KEYS = ["cache_seed", "redis_url", "cache_path_root"]
+ ALLOWED_CONFIG_KEYS = [
+ "cache_seed",
+ "redis_url",
+ "cache_path_root",
+ "cosmos_db_config",
+ ]
@staticmethod
def redis(cache_seed: Union[str, int] = 42, redis_url: str = "redis://localhost:6379/0") -> "Cache":
@@ -57,6 +61,32 @@ def disk(cache_seed: Union[str, int] = 42, cache_path_root: str = ".cache") -> "
"""
return Cache({"cache_seed": cache_seed, "cache_path_root": cache_path_root})
+ @staticmethod
+ def cosmos_db(
+ connection_string: Optional[str] = None,
+ container_id: Optional[str] = None,
+ cache_seed: Union[str, int] = 42,
+        client: Optional[Any] = None,
+ ) -> "Cache":
+ """
+ Create a Cosmos DB cache instance with 'autogen_cache' as database ID.
+
+ Args:
+ connection_string (str, optional): Connection string to the Cosmos DB account.
+ container_id (str, optional): The container ID for the Cosmos DB account.
+ cache_seed (Union[str, int], optional): A seed for the cache.
+            client (Optional[CosmosClient]): Pass an existing Cosmos DB client.
+ Returns:
+ Cache: A Cache instance configured for Cosmos DB.
+ """
+ cosmos_db_config = {
+ "connection_string": connection_string,
+ "database_id": "autogen_cache",
+ "container_id": container_id,
+ "client": client,
+ }
+ return Cache({"cache_seed": str(cache_seed), "cosmos_db_config": cosmos_db_config})
+
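A minimal usage sketch, assuming a real Cosmos DB account and agents created elsewhere; the connection string and container name are placeholders:

```python
from autogen.cache import Cache

with Cache.cosmos_db(
    connection_string="AccountEndpoint=https://<account>.documents.azure.com:443/;AccountKey=<key>;",
    container_id="autogen_cache_container",
    cache_seed=42,
) as cache:
    user_proxy.initiate_chat(assistant, message="What is the 7th Fibonacci number?", cache=cache)
```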
def __init__(self, config: Dict[str, Any]):
"""
Initialize the Cache with the given configuration.
@@ -70,15 +100,19 @@ def __init__(self, config: Dict[str, Any]):
ValueError: If an invalid configuration key is provided.
"""
self.config = config
+ # Ensure that the seed is always treated as a string before being passed to any cache factory or stored.
+ self.config["cache_seed"] = str(self.config.get("cache_seed", 42))
+
# validate config
for key in self.config.keys():
if key not in self.ALLOWED_CONFIG_KEYS:
raise ValueError(f"Invalid config key: {key}")
# create cache instance
self.cache = CacheFactory.cache_factory(
- self.config.get("cache_seed", "42"),
- self.config.get("redis_url", None),
- self.config.get("cache_path_root", None),
+ seed=self.config["cache_seed"],
+ redis_url=self.config.get("redis_url"),
+ cache_path_root=self.config.get("cache_path_root"),
+ cosmosdb_config=self.config.get("cosmos_db_config"),
)
def __enter__(self) -> "Cache":
diff --git a/autogen/cache/cache_factory.py b/autogen/cache/cache_factory.py
index e3c50e9eb2b..7c9d71884cb 100644
--- a/autogen/cache/cache_factory.py
+++ b/autogen/cache/cache_factory.py
@@ -1,32 +1,36 @@
-from typing import Optional, Union
+import logging
+import os
+from typing import Any, Dict, Optional, Union
+
from .abstract_cache_base import AbstractCache
from .disk_cache import DiskCache
-import logging
-
class CacheFactory:
@staticmethod
def cache_factory(
- seed: Union[str, int], redis_url: Optional[str] = None, cache_path_root: str = ".cache"
+ seed: Union[str, int],
+ redis_url: Optional[str] = None,
+ cache_path_root: str = ".cache",
+ cosmosdb_config: Optional[Dict[str, Any]] = None,
) -> AbstractCache:
"""
Factory function for creating cache instances.
- Based on the provided redis_url, this function decides whether to create a RedisCache
- or DiskCache instance. If RedisCache is available and redis_url is provided,
- a RedisCache instance is created. Otherwise, a DiskCache instance is used.
+ This function decides whether to create a RedisCache, DiskCache, or CosmosDBCache instance
+ based on the provided parameters. If RedisCache is available and a redis_url is provided,
+ a RedisCache instance is created. If connection_string, database_id, and container_id
+ are provided, a CosmosDBCache is created. Otherwise, a DiskCache instance is used.
Args:
- seed (Union[str, int]): A string or int used as a seed or namespace for the cache.
- This could be useful for creating distinct cache instances
- or for namespacing keys in the cache.
- redis_url (str or None): The URL for the Redis server. If this is None
- or if RedisCache is not available, a DiskCache instance is created.
+ seed (Union[str, int]): Used as a seed or namespace for the cache.
+ redis_url (Optional[str]): URL for the Redis server.
+ cache_path_root (str): Root path for the disk cache.
+ cosmosdb_config (Optional[Dict[str, str]]): Dictionary containing 'connection_string',
+ 'database_id', and 'container_id' for Cosmos DB cache.
Returns:
- An instance of either RedisCache or DiskCache, depending on the availability of RedisCache
- and the provided redis_url.
+ An instance of RedisCache, DiskCache, or CosmosDBCache.
Examples:
@@ -40,14 +44,36 @@ def cache_factory(
```python
disk_cache = cache_factory("myseed", None)
```
+
+ Creating a Cosmos DB cache:
+ ```python
+ cosmos_cache = cache_factory("myseed", cosmosdb_config={
+ "connection_string": "your_connection_string",
+ "database_id": "your_database_id",
+ "container_id": "your_container_id"}
+ )
+ ```
+
"""
- if redis_url is not None:
+ if redis_url:
try:
from .redis_cache import RedisCache
return RedisCache(seed, redis_url)
except ImportError:
- logging.warning("RedisCache is not available. Creating a DiskCache instance instead.")
- return DiskCache(f"./{cache_path_root}/{seed}")
- else:
- return DiskCache(f"./{cache_path_root}/{seed}")
+ logging.warning(
+ "RedisCache is not available. Checking other cache options. The last fallback is DiskCache."
+ )
+
+ if cosmosdb_config:
+ try:
+ from .cosmos_db_cache import CosmosDBCache
+
+ return CosmosDBCache.create_cache(seed, cosmosdb_config)
+
+ except ImportError:
+ logging.warning("CosmosDBCache is not available. Fallback to DiskCache.")
+
+ # Default to DiskCache if neither Redis nor Cosmos DB configurations are provided
+ path = os.path.join(cache_path_root, str(seed))
+ return DiskCache(os.path.join(".", path))
diff --git a/autogen/cache/cosmos_db_cache.py b/autogen/cache/cosmos_db_cache.py
new file mode 100644
index 00000000000..b85be923c2f
--- /dev/null
+++ b/autogen/cache/cosmos_db_cache.py
@@ -0,0 +1,144 @@
+# Install the Azure Cosmos DB SDK if it is not already installed
+
+import pickle
+from typing import Any, Optional, TypedDict, Union
+
+from azure.cosmos import CosmosClient, PartitionKey, exceptions
+from azure.cosmos.exceptions import CosmosResourceNotFoundError
+
+from autogen.cache.abstract_cache_base import AbstractCache
+
+
+class CosmosDBConfig(TypedDict, total=False):
+ connection_string: str
+ database_id: str
+ container_id: str
+ cache_seed: Optional[Union[str, int]]
+ client: Optional[CosmosClient]
+
+
+class CosmosDBCache(AbstractCache):
+ """
+ Synchronous implementation of AbstractCache using Azure Cosmos DB NoSQL API.
+
+ This class provides a concrete implementation of the AbstractCache
+ interface using Azure Cosmos DB for caching data, with synchronous operations.
+
+ Attributes:
+ seed (Union[str, int]): A seed or namespace used as a partition key.
+ client (CosmosClient): The Cosmos DB client used for caching.
+ container: The container instance used for caching.
+ """
+
+ def __init__(self, seed: Union[str, int], cosmosdb_config: CosmosDBConfig):
+ """
+ Initialize the CosmosDBCache instance.
+
+ Args:
+ seed (Union[str, int]): A seed or namespace for the cache, used as a partition key.
+            cosmosdb_config (CosmosDBConfig): The configuration for the Cosmos DB cache, including the
+                connection string, database ID, container ID, and optionally an existing CosmosClient.
+ """
+ self.seed = str(seed)
+ self.client = cosmosdb_config.get("client") or CosmosClient.from_connection_string(
+ cosmosdb_config["connection_string"]
+ )
+ database_id = cosmosdb_config.get("database_id", "autogen_cache")
+ self.database = self.client.get_database_client(database_id)
+ container_id = cosmosdb_config.get("container_id")
+ self.container = self.database.create_container_if_not_exists(
+ id=container_id, partition_key=PartitionKey(path="/partitionKey")
+ )
+
+ @classmethod
+ def create_cache(cls, seed: Union[str, int], cosmosdb_config: CosmosDBConfig):
+ """
+ Factory method to create a CosmosDBCache instance based on the provided configuration.
+ This method decides whether to use an existing CosmosClient or create a new one.
+ """
+ if "client" in cosmosdb_config and isinstance(cosmosdb_config["client"], CosmosClient):
+ return cls.from_existing_client(seed, **cosmosdb_config)
+ else:
+ return cls.from_config(seed, cosmosdb_config)
+
+ @classmethod
+ def from_config(cls, seed: Union[str, int], cosmosdb_config: CosmosDBConfig):
+ return cls(str(seed), cosmosdb_config)
+
+ @classmethod
+ def from_connection_string(cls, seed: Union[str, int], connection_string: str, database_id: str, container_id: str):
+ config = {"connection_string": connection_string, "database_id": database_id, "container_id": container_id}
+ return cls(str(seed), config)
+
+ @classmethod
+ def from_existing_client(cls, seed: Union[str, int], client: CosmosClient, database_id: str, container_id: str):
+ config = {"client": client, "database_id": database_id, "container_id": container_id}
+ return cls(str(seed), config)
+
+ def get(self, key: str, default: Optional[Any] = None) -> Optional[Any]:
+ """
+ Retrieve an item from the Cosmos DB cache.
+
+ Args:
+ key (str): The key identifying the item in the cache.
+ default (optional): The default value to return if the key is not found.
+
+ Returns:
+ The deserialized value associated with the key if found, else the default value.
+ """
+ try:
+ response = self.container.read_item(item=key, partition_key=str(self.seed))
+ return pickle.loads(response["data"])
+ except CosmosResourceNotFoundError:
+ return default
+ except Exception as e:
+ # Log the exception or rethrow after logging if needed
+ # Consider logging or handling the error appropriately here
+ raise e
+
+ def set(self, key: str, value: Any) -> None:
+ """
+ Set an item in the Cosmos DB cache.
+
+ Args:
+ key (str): The key under which the item is to be stored.
+ value: The value to be stored in the cache.
+
+ Notes:
+ The value is serialized using pickle before being stored.
+ """
+ try:
+ serialized_value = pickle.dumps(value)
+ item = {"id": key, "partitionKey": str(self.seed), "data": serialized_value}
+ self.container.upsert_item(item)
+ except Exception as e:
+ # Log or handle exception
+ raise e
+
+ def close(self) -> None:
+ """
+ Close the Cosmos DB client.
+
+ Perform any necessary cleanup, such as closing network connections.
+ """
+        # CosmosClient doesn't require explicit close in the current SDK
+ # If you created the client inside this class, you should close it if necessary
+ pass
+
+ def __enter__(self):
+ """
+ Context management entry.
+
+ Returns:
+ self: The instance itself.
+ """
+ return self
+
+ def __exit__(self, exc_type: Optional[type], exc_value: Optional[Exception], traceback: Optional[Any]) -> None:
+ """
+ Context management exit.
+
+ Perform cleanup actions such as closing the Cosmos DB client.
+ """
+ self.close()
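A direct-use sketch of the new backend (connection details are placeholders; in normal use it is reached via `CacheFactory.cache_factory` rather than instantiated by hand):

```python
from autogen.cache.cosmos_db_cache import CosmosDBCache

config = {
    "connection_string": "AccountEndpoint=https://<account>.documents.azure.com:443/;AccountKey=<key>;",
    "database_id": "autogen_cache",
    "container_id": "autogen_cache_container",
}
cache = CosmosDBCache.create_cache(seed=42, cosmosdb_config=config)
cache.set("last_reply", {"content": "Hello"})  # pickled and upserted under partition key "42"
print(cache.get("last_reply", default=None))
cache.close()
```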
diff --git a/autogen/cache/disk_cache.py b/autogen/cache/disk_cache.py
index 2cca53e6d2f..7c68e7e908c 100644
--- a/autogen/cache/disk_cache.py
+++ b/autogen/cache/disk_cache.py
@@ -1,8 +1,10 @@
+import sys
from types import TracebackType
from typing import Any, Optional, Type, Union
+
import diskcache
+
from .abstract_cache_base import AbstractCache
-import sys
if sys.version_info >= (3, 11):
from typing import Self
diff --git a/autogen/cache/in_memory_cache.py b/autogen/cache/in_memory_cache.py
new file mode 100644
index 00000000000..b79f9ecfa4f
--- /dev/null
+++ b/autogen/cache/in_memory_cache.py
@@ -0,0 +1,55 @@
+import sys
+from types import TracebackType
+from typing import Any, Dict, Optional, Type, Union
+
+from .abstract_cache_base import AbstractCache
+
+if sys.version_info >= (3, 11):
+ from typing import Self
+else:
+ from typing_extensions import Self
+
+
+class InMemoryCache(AbstractCache):
+
+ def __init__(self, seed: Union[str, int] = ""):
+ self._seed = str(seed)
+ self._cache: Dict[str, Any] = {}
+
+ def _prefixed_key(self, key: str) -> str:
+ separator = "_" if self._seed else ""
+ return f"{self._seed}{separator}{key}"
+
+ def get(self, key: str, default: Optional[Any] = None) -> Optional[Any]:
+ result = self._cache.get(self._prefixed_key(key))
+ if result is None:
+ return default
+ return result
+
+ def set(self, key: str, value: Any) -> None:
+ self._cache[self._prefixed_key(key)] = value
+
+ def close(self) -> None:
+ pass
+
+ def __enter__(self) -> Self:
+ """
+ Enter the runtime context related to the object.
+
+ Returns:
+ self: The instance itself.
+ """
+ return self
+
+ def __exit__(
+ self, exc_type: Optional[Type[BaseException]], exc_val: Optional[BaseException], exc_tb: Optional[TracebackType]
+ ) -> None:
+ """
+ Exit the runtime context related to the object.
+
+ Args:
+ exc_type: The exception type if an exception was raised in the context.
+            exc_val: The exception value if an exception was raised in the context.
+            exc_tb: The traceback if an exception was raised in the context.
+ """
+ self.close()
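A quick sketch of the new in-memory backend; keys are namespaced with the seed prefix:

```python
from autogen.cache.in_memory_cache import InMemoryCache

cache = InMemoryCache(seed="run1")
cache.set("prompt_hash", [0.1, 0.2, 0.3])  # stored under "run1_prompt_hash"
print(cache.get("prompt_hash"))  # [0.1, 0.2, 0.3]
print(cache.get("missing", default="not cached"))  # "not cached"
```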
diff --git a/autogen/cache/redis_cache.py b/autogen/cache/redis_cache.py
index d125d3ba203..36d601af702 100644
--- a/autogen/cache/redis_cache.py
+++ b/autogen/cache/redis_cache.py
@@ -1,8 +1,10 @@
import pickle
+import sys
from types import TracebackType
from typing import Any, Optional, Type, Union
+
import redis
-import sys
+
from .abstract_cache_base import AbstractCache
if sys.version_info >= (3, 11):
diff --git a/autogen/code_utils.py b/autogen/code_utils.py
index 57a817855f7..98ed6067066 100644
--- a/autogen/code_utils.py
+++ b/autogen/code_utils.py
@@ -6,14 +6,16 @@
import subprocess
import sys
import time
+import venv
from concurrent.futures import ThreadPoolExecutor, TimeoutError
from hashlib import md5
+from types import SimpleNamespace
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
-from autogen import oai
-
import docker
+from autogen import oai
+
from .types import UserMessageImageContentPart, UserMessageTextContentPart
SENTINEL = object()
@@ -35,12 +37,13 @@
DEFAULT_TIMEOUT = 600
WIN32 = sys.platform == "win32"
PATH_SEPARATOR = WIN32 and "\\" or "/"
+PYTHON_VARIANTS = ["python", "Python", "py"]
logger = logging.getLogger(__name__)
def content_str(content: Union[str, List[Union[UserMessageTextContentPart, UserMessageImageContentPart]], None]) -> str:
- """Converts the `content` field of an OpenAI merssage into a string format.
+ """Converts the `content` field of an OpenAI message into a string format.
This function processes content that may be a string, a list of mixed text and image URLs, or None,
and converts it into a string. Text is directly appended to the result string, while image URLs are
@@ -244,10 +247,14 @@ def get_powershell_command():
def _cmd(lang: str) -> str:
+ if lang in PYTHON_VARIANTS:
+ return "python"
if lang.startswith("python") or lang in ["bash", "sh"]:
return lang
if lang in ["shell"]:
return "sh"
+ if lang == "javascript":
+ return "node"
if lang in ["ps1", "pwsh", "powershell"]:
powershell_command = get_powershell_command()
return powershell_command
@@ -278,7 +285,7 @@ def in_docker_container() -> bool:
return os.path.exists("/.dockerenv")
-def decide_use_docker(use_docker) -> bool:
+def decide_use_docker(use_docker: Optional[bool]) -> Optional[bool]:
if use_docker is None:
env_var_use_docker = os.environ.get("AUTOGEN_USE_DOCKER", "True")
@@ -714,3 +721,19 @@ def implement(
# cost += metrics["gen_cost"]
# if metrics["succeed_assertions"] or i == len(configs) - 1:
# return responses[metrics["index_selected"]], cost, i
+
+
+def create_virtual_env(dir_path: str, **env_args) -> SimpleNamespace:
+ """Creates a python virtual environment and returns the context.
+
+ Args:
+ dir_path (str): Directory path where the env will be created.
+ **env_args: Any extra args to pass to the `EnvBuilder`
+
+ Returns:
+ SimpleNamespace: the virtual env context object."""
+ if not env_args:
+ env_args = {"with_pip": True}
+ env_builder = venv.EnvBuilder(**env_args)
+ env_builder.create(dir_path)
+ return env_builder.ensure_directories(dir_path)
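For example (the directory name is a placeholder), the returned context exposes the interpreter path of the new environment:

```python
from autogen.code_utils import create_virtual_env

venv_context = create_virtual_env(".venv-autogen")
print(venv_context.env_exe)  # path to the python executable inside the freshly created environment
```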
diff --git a/autogen/coding/__init__.py b/autogen/coding/__init__.py
index 2ba4e9b0734..2f53b88ca3d 100644
--- a/autogen/coding/__init__.py
+++ b/autogen/coding/__init__.py
@@ -1,8 +1,8 @@
from .base import CodeBlock, CodeExecutor, CodeExtractor, CodeResult
+from .docker_commandline_code_executor import DockerCommandLineCodeExecutor
from .factory import CodeExecutorFactory
-from .markdown_code_extractor import MarkdownCodeExtractor
from .local_commandline_code_executor import LocalCommandLineCodeExecutor
-from .docker_commandline_code_executor import DockerCommandLineCodeExecutor
+from .markdown_code_extractor import MarkdownCodeExtractor
__all__ = (
"CodeBlock",
diff --git a/autogen/coding/base.py b/autogen/coding/base.py
index f60ff0de85e..ccbfe6b9293 100644
--- a/autogen/coding/base.py
+++ b/autogen/coding/base.py
@@ -1,4 +1,5 @@
from __future__ import annotations
+
from typing import Any, List, Literal, Mapping, Optional, Protocol, TypedDict, Union, runtime_checkable
from pydantic import BaseModel, Field
diff --git a/autogen/coding/docker_commandline_code_executor.py b/autogen/coding/docker_commandline_code_executor.py
index f1db7cd07e7..6d8f4e309c8 100644
--- a/autogen/coding/docker_commandline_code_executor.py
+++ b/autogen/coding/docker_commandline_code_executor.py
@@ -1,22 +1,22 @@
from __future__ import annotations
+
import atexit
-from hashlib import md5
import logging
+import sys
+import uuid
+from hashlib import md5
from pathlib import Path
from time import sleep
from types import TracebackType
-import uuid
-from typing import Any, List, Optional, Type, Union
+from typing import Any, ClassVar, Dict, List, Optional, Type, Union
+
import docker
from docker.errors import ImageNotFound
-from .utils import _get_file_name_from_content, silence_pip
-from .base import CommandLineCodeResult
-
from ..code_utils import TIMEOUT_MSG, _cmd
-from .base import CodeBlock, CodeExecutor, CodeExtractor
+from .base import CodeBlock, CodeExecutor, CodeExtractor, CommandLineCodeResult
from .markdown_code_extractor import MarkdownCodeExtractor
-import sys
+from .utils import _get_file_name_from_content, silence_pip
if sys.version_info >= (3, 11):
from typing import Self
@@ -39,14 +39,30 @@ def _wait_for_ready(container: Any, timeout: int = 60, stop_time: float = 0.1) -
class DockerCommandLineCodeExecutor(CodeExecutor):
+ DEFAULT_EXECUTION_POLICY: ClassVar[Dict[str, bool]] = {
+ "bash": True,
+ "shell": True,
+ "sh": True,
+ "pwsh": True,
+ "powershell": True,
+ "ps1": True,
+ "python": True,
+ "javascript": False,
+ "html": False,
+ "css": False,
+ }
+ LANGUAGE_ALIASES: ClassVar[Dict[str, str]] = {"py": "python", "js": "javascript"}
+
def __init__(
self,
image: str = "python:3-slim",
container_name: Optional[str] = None,
timeout: int = 60,
work_dir: Union[Path, str] = Path("."),
+ bind_dir: Optional[Union[Path, str]] = None,
auto_remove: bool = True,
stop_container: bool = True,
+ execution_policies: Optional[Dict[str, bool]] = None,
):
"""(Experimental) A code executor class that executes code through
a command line environment in a Docker container.
@@ -67,6 +83,9 @@ def __init__(
timeout (int, optional): The timeout for code execution. Defaults to 60.
work_dir (Union[Path, str], optional): The working directory for the code
execution. Defaults to Path(".").
+ bind_dir (Union[Path, str], optional): The directory that will be bound
+ to the code executor container. Useful for cases where you want to spawn
+ the container from within a container. Defaults to work_dir.
auto_remove (bool, optional): If true, will automatically remove the Docker
container when it is stopped. Defaults to True.
stop_container (bool, optional): If true, will automatically stop the
@@ -76,18 +95,19 @@ def __init__(
Raises:
ValueError: On argument error, or if the container fails to start.
"""
-
if timeout < 1:
raise ValueError("Timeout must be greater than or equal to 1.")
if isinstance(work_dir, str):
work_dir = Path(work_dir)
+ work_dir.mkdir(exist_ok=True)
- if not work_dir.exists():
- raise ValueError(f"Working directory {work_dir} does not exist.")
+ if bind_dir is None:
+ bind_dir = work_dir
+ elif isinstance(bind_dir, str):
+ bind_dir = Path(bind_dir)
client = docker.from_env()
-
# Check if the image exists
try:
client.images.get(image)
@@ -106,7 +126,7 @@ def __init__(
entrypoint="/bin/sh",
tty=True,
auto_remove=auto_remove,
- volumes={str(work_dir.resolve()): {"bind": "/workspace", "mode": "rw"}},
+ volumes={str(bind_dir.resolve()): {"bind": "/workspace", "mode": "rw"}},
working_dir="/workspace",
)
self._container.start()
@@ -119,7 +139,6 @@ def cleanup() -> None:
container.stop()
except docker.errors.NotFound:
pass
-
atexit.unregister(cleanup)
if stop_container:
@@ -133,6 +152,10 @@ def cleanup() -> None:
self._timeout = timeout
self._work_dir: Path = work_dir
+ self._bind_dir: Path = bind_dir
+ self.execution_policies = self.DEFAULT_EXECUTION_POLICY.copy()
+ if execution_policies is not None:
+ self.execution_policies.update(execution_policies)
@property
def timeout(self) -> int:
@@ -144,6 +167,11 @@ def work_dir(self) -> Path:
"""(Experimental) The working directory for the code execution."""
return self._work_dir
+ @property
+ def bind_dir(self) -> Path:
+ """(Experimental) The binding directory for the code execution container."""
+ return self._bind_dir
+
@property
def code_extractor(self) -> CodeExtractor:
"""(Experimental) Export a code extractor that can be used by an agent."""
@@ -165,35 +193,42 @@ def execute_code_blocks(self, code_blocks: List[CodeBlock]) -> CommandLineCodeRe
files = []
last_exit_code = 0
for code_block in code_blocks:
- lang = code_block.language
+ lang = self.LANGUAGE_ALIASES.get(code_block.language.lower(), code_block.language.lower())
+ if lang not in self.DEFAULT_EXECUTION_POLICY:
+ outputs.append(f"Unsupported language {lang}\n")
+ last_exit_code = 1
+ break
+
+ execute_code = self.execution_policies.get(lang, False)
code = silence_pip(code_block.code, lang)
+ # Check if there is a filename comment
try:
- # Check if there is a filename comment
- filename = _get_file_name_from_content(code, Path("/workspace"))
+ filename = _get_file_name_from_content(code, self._work_dir)
except ValueError:
- return CommandLineCodeResult(exit_code=1, output="Filename is not in the workspace")
+ outputs.append("Filename is not in the workspace")
+ last_exit_code = 1
+ break
- if filename is None:
- # create a file with an automatically generated name
- code_hash = md5(code.encode()).hexdigest()
- filename = f"tmp_code_{code_hash}.{'py' if lang.startswith('python') else lang}"
+ if not filename:
+ filename = f"tmp_code_{md5(code.encode()).hexdigest()}.{lang}"
code_path = self._work_dir / filename
with code_path.open("w", encoding="utf-8") as fout:
fout.write(code)
+ files.append(code_path)
- command = ["timeout", str(self._timeout), _cmd(lang), filename]
+ if not execute_code:
+ outputs.append(f"Code saved to {str(code_path)}\n")
+ continue
+ command = ["timeout", str(self._timeout), _cmd(lang), filename]
result = self._container.exec_run(command)
exit_code = result.exit_code
output = result.output.decode("utf-8")
if exit_code == 124:
- output += "\n"
- output += TIMEOUT_MSG
-
+ output += "\n" + TIMEOUT_MSG
outputs.append(output)
- files.append(code_path)
last_exit_code = exit_code
if exit_code != 0:
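For context (not part of the diff), a minimal sketch of driving the Docker executor with the new `execution_policies` and `bind_dir` options; the directory names and code strings are illustrative and a running Docker daemon is assumed:

```python
from pathlib import Path

from autogen.coding import DockerCommandLineCodeExecutor
from autogen.coding.base import CodeBlock

# Run Python blocks, but only save bash blocks for review instead of executing them.
executor = DockerCommandLineCodeExecutor(
    image="python:3-slim",
    work_dir=Path("coding"),              # created automatically now (no longer has to pre-exist)
    bind_dir=Path("coding"),              # host path mounted at /workspace; defaults to work_dir
    execution_policies={"bash": False},   # merged into DEFAULT_EXECUTION_POLICY
)

result = executor.execute_code_blocks(
    [
        CodeBlock(language="python", code="print('hello from the container')"),
        CodeBlock(language="bash", code="rm -rf /tmp/scratch"),  # saved to a file, not executed
    ]
)
print(result.exit_code, result.output)
```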
diff --git a/autogen/coding/factory.py b/autogen/coding/factory.py
index 0c2d41b89da..b484d99cda1 100644
--- a/autogen/coding/factory.py
+++ b/autogen/coding/factory.py
@@ -1,4 +1,4 @@
-from .base import CodeExecutor, CodeExecutionConfig
+from .base import CodeExecutionConfig, CodeExecutor
__all__ = ("CodeExecutorFactory",)
diff --git a/autogen/coding/func_with_reqs.py b/autogen/coding/func_with_reqs.py
index c37c12c1e2f..6f199573822 100644
--- a/autogen/coding/func_with_reqs.py
+++ b/autogen/coding/func_with_reqs.py
@@ -1,16 +1,23 @@
from __future__ import annotations
-import inspect
+
import functools
-from typing import Any, Callable, List, TypeVar, Generic, Union
-from typing_extensions import ParamSpec
-from textwrap import indent, dedent
+import importlib
+import inspect
from dataclasses import dataclass, field
+from importlib.abc import SourceLoader
+from textwrap import dedent, indent
+from typing import Any, Callable, Generic, List, TypeVar, Union
+
+from typing_extensions import ParamSpec
T = TypeVar("T")
P = ParamSpec("P")
-def _to_code(func: Union[FunctionWithRequirements[T, P], Callable[P, T]]) -> str:
+def _to_code(func: Union[FunctionWithRequirements[T, P], Callable[P, T], FunctionWithRequirementsStr]) -> str:
+ if isinstance(func, FunctionWithRequirementsStr):
+ return func.func
+
code = inspect.getsource(func)
# Strip the decorator
if code.startswith("@"):
@@ -50,6 +57,57 @@ def to_str(i: Union[str, Alias]) -> str:
return f"from {im.module} import {imports}"
+class _StringLoader(SourceLoader):
+ def __init__(self, data: str):
+ self.data = data
+
+ def get_source(self, fullname: str) -> str:
+ return self.data
+
+ def get_data(self, path: str) -> bytes:
+ return self.data.encode("utf-8")
+
+ def get_filename(self, fullname: str) -> str:
+ return "/" + fullname + ".py"
+
+
+@dataclass
+class FunctionWithRequirementsStr:
+ func: str
+ _compiled_func: Callable[..., Any]
+ _func_name: str
+ python_packages: List[str] = field(default_factory=list)
+ global_imports: List[Import] = field(default_factory=list)
+
+ def __init__(self, func: str, python_packages: List[str] = [], global_imports: List[Import] = []):
+ self.func = func
+ self.python_packages = python_packages
+ self.global_imports = global_imports
+
+ module_name = "func_module"
+ loader = _StringLoader(func)
+ spec = importlib.util.spec_from_loader(module_name, loader)
+ if spec is None:
+ raise ValueError("Could not create spec")
+ module = importlib.util.module_from_spec(spec)
+ if spec.loader is None:
+ raise ValueError("Could not create loader")
+
+ try:
+ spec.loader.exec_module(module)
+ except Exception as e:
+ raise ValueError(f"Could not compile function: {e}") from e
+
+ functions = inspect.getmembers(module, inspect.isfunction)
+ if len(functions) != 1:
+ raise ValueError("The string must contain exactly one function")
+
+ self._func_name, self._compiled_func = functions[0]
+
+ def __call__(self, *args: Any, **kwargs: Any) -> None:
+ raise NotImplementedError("String based function with requirement objects are not directly callable")
+
+
@dataclass
class FunctionWithRequirements(Generic[T, P]):
func: Callable[P, T]
@@ -62,6 +120,12 @@ def from_callable(
) -> FunctionWithRequirements[T, P]:
return cls(python_packages=python_packages, global_imports=global_imports, func=func)
+ @staticmethod
+ def from_str(
+ func: str, python_packages: List[str] = [], global_imports: List[Import] = []
+ ) -> FunctionWithRequirementsStr:
+ return FunctionWithRequirementsStr(func=func, python_packages=python_packages, global_imports=global_imports)
+
# Type this based on F
def __call__(self, *args: P.args, **kwargs: P.kwargs) -> T:
return self.func(*args, **kwargs)
@@ -91,11 +155,13 @@ def wrapper(func: Callable[P, T]) -> FunctionWithRequirements[T, P]:
return wrapper
-def _build_python_functions_file(funcs: List[Union[FunctionWithRequirements[Any, P], Callable[..., Any]]]) -> str:
+def _build_python_functions_file(
+ funcs: List[Union[FunctionWithRequirements[Any, P], Callable[..., Any], FunctionWithRequirementsStr]]
+) -> str:
# First collect all global imports
global_imports = set()
for func in funcs:
- if isinstance(func, FunctionWithRequirements):
+ if isinstance(func, (FunctionWithRequirements, FunctionWithRequirementsStr)):
global_imports.update(func.global_imports)
content = "\n".join(map(_import_to_str, global_imports)) + "\n\n"
@@ -106,7 +172,7 @@ def _build_python_functions_file(funcs: List[Union[FunctionWithRequirements[Any,
return content
-def to_stub(func: Callable[..., Any]) -> str:
+def to_stub(func: Union[Callable[..., Any], FunctionWithRequirementsStr]) -> str:
"""Generate a stub for a function as a string
Args:
@@ -115,6 +181,9 @@ def to_stub(func: Callable[..., Any]) -> str:
Returns:
str: The stub for the function
"""
+ if isinstance(func, FunctionWithRequirementsStr):
+ return to_stub(func._compiled_func)
+
content = f"def {func.__name__}{inspect.signature(func)}:\n"
docstring = func.__doc__
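As a usage sketch (not part of the diff), the new string-based functions can be declared from source text and passed to an executor alongside ordinary callables; the function body below is purely illustrative:

```python
from autogen.coding import LocalCommandLineCodeExecutor
from autogen.coding.func_with_reqs import FunctionWithRequirements

add_two = FunctionWithRequirements.from_str(
    '''
def add_two(x: int) -> int:
    """Return x plus two."""
    return x + 2
'''
)

# The object is compiled once for stub generation but is not directly callable;
# it is written into the functions module that generated code can import.
executor = LocalCommandLineCodeExecutor(work_dir="coding", functions=[add_two])
print(executor.format_functions_for_prompt())
```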
diff --git a/autogen/coding/jupyter/__init__.py b/autogen/coding/jupyter/__init__.py
index 5c1a9607f56..f6f02313ec1 100644
--- a/autogen/coding/jupyter/__init__.py
+++ b/autogen/coding/jupyter/__init__.py
@@ -1,9 +1,9 @@
from .base import JupyterConnectable, JupyterConnectionInfo
-from .jupyter_client import JupyterClient
-from .local_jupyter_server import LocalJupyterServer
from .docker_jupyter_server import DockerJupyterServer
from .embedded_ipython_code_executor import EmbeddedIPythonCodeExecutor
+from .jupyter_client import JupyterClient
from .jupyter_code_executor import JupyterCodeExecutor
+from .local_jupyter_server import LocalJupyterServer
__all__ = [
"JupyterConnectable",
diff --git a/autogen/coding/jupyter/base.py b/autogen/coding/jupyter/base.py
index d896b6ac3cc..0e7acaf1e87 100644
--- a/autogen/coding/jupyter/base.py
+++ b/autogen/coding/jupyter/base.py
@@ -10,9 +10,9 @@ class JupyterConnectionInfo:
"""`str` - Host of the Jupyter gateway server"""
use_https: bool
"""`bool` - Whether to use HTTPS"""
- port: int
- """`int` - Port of the Jupyter gateway server"""
- token: Optional[str]
+ port: Optional[int] = None
+ """`Optional[int]` - Port of the Jupyter gateway server. If None, the default port is used"""
+ token: Optional[str] = None
"""`Optional[str]` - Token for authentication. If None, no token is used"""
diff --git a/autogen/coding/jupyter/docker_jupyter_server.py b/autogen/coding/jupyter/docker_jupyter_server.py
index 3b9462186b9..83455e27238 100644
--- a/autogen/coding/jupyter/docker_jupyter_server.py
+++ b/autogen/coding/jupyter/docker_jupyter_server.py
@@ -1,15 +1,16 @@
from __future__ import annotations
-from pathlib import Path
+import atexit
+import io
+import logging
+import secrets
import sys
-from types import TracebackType
import uuid
+from pathlib import Path
+from types import TracebackType
from typing import Dict, Optional, Type, Union
+
import docker
-import secrets
-import io
-import atexit
-import logging
from ..docker_commandline_code_executor import _wait_for_ready
@@ -19,8 +20,8 @@
from typing_extensions import Self
-from .jupyter_client import JupyterClient
from .base import JupyterConnectable, JupyterConnectionInfo
+from .jupyter_client import JupyterClient
class DockerJupyterServer(JupyterConnectable):
diff --git a/autogen/coding/jupyter/embedded_ipython_code_executor.py b/autogen/coding/jupyter/embedded_ipython_code_executor.py
index 0d647082a3c..f9200c7a580 100644
--- a/autogen/coding/jupyter/embedded_ipython_code_executor.py
+++ b/autogen/coding/jupyter/embedded_ipython_code_executor.py
@@ -1,9 +1,9 @@
import base64
import json
import os
-from pathlib import Path
import re
import uuid
+from pathlib import Path
from queue import Empty
from typing import Any, ClassVar, List
diff --git a/autogen/coding/jupyter/jupyter_client.py b/autogen/coding/jupyter/jupyter_client.py
index 8f97ab82418..b3de374fce9 100644
--- a/autogen/coding/jupyter/jupyter_client.py
+++ b/autogen/coding/jupyter/jupyter_client.py
@@ -1,22 +1,22 @@
from __future__ import annotations
+import sys
from dataclasses import dataclass
from types import TracebackType
from typing import Any, Dict, List, Optional, Type, cast
-import sys
if sys.version_info >= (3, 11):
from typing import Self
else:
from typing_extensions import Self
+import datetime
import json
import uuid
-import datetime
-import requests
-from requests.adapters import HTTPAdapter, Retry
+import requests
import websocket
+from requests.adapters import HTTPAdapter, Retry
from websocket import WebSocket
from .base import JupyterConnectionInfo
@@ -41,10 +41,12 @@ def _get_headers(self) -> Dict[str, str]:
def _get_api_base_url(self) -> str:
protocol = "https" if self._connection_info.use_https else "http"
- return f"{protocol}://{self._connection_info.host}:{self._connection_info.port}"
+ port = f":{self._connection_info.port}" if self._connection_info.port else ""
+ return f"{protocol}://{self._connection_info.host}{port}"
def _get_ws_base_url(self) -> str:
- return f"ws://{self._connection_info.host}:{self._connection_info.port}"
+ port = f":{self._connection_info.port}" if self._connection_info.port else ""
+ return f"ws://{self._connection_info.host}{port}"
def list_kernel_specs(self) -> Dict[str, Dict[str, str]]:
response = self._session.get(f"{self._get_api_base_url()}/api/kernelspecs", headers=self._get_headers())
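A small sketch (not part of the diff) of what the now-optional `port` means for URL construction; the host names are illustrative and the Jupyter-related optional dependencies are assumed to be installed:

```python
from autogen.coding.jupyter import JupyterClient
from autogen.coding.jupyter.base import JupyterConnectionInfo

# Behind a reverse proxy: no explicit port, so requests go to https://jupyter.example.com/...
proxied = JupyterClient(JupyterConnectionInfo(host="jupyter.example.com", use_https=True, token="secret"))

# Local gateway: explicit port, so requests go to http://127.0.0.1:8888/...
local = JupyterClient(JupyterConnectionInfo(host="127.0.0.1", use_https=False, port=8888, token=None))
```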
diff --git a/autogen/coding/jupyter/jupyter_code_executor.py b/autogen/coding/jupyter/jupyter_code_executor.py
index 4d926773517..833565ab47c 100644
--- a/autogen/coding/jupyter/jupyter_code_executor.py
+++ b/autogen/coding/jupyter/jupyter_code_executor.py
@@ -1,12 +1,12 @@
import base64
import json
import os
-from pathlib import Path
import re
-from types import TracebackType
+import sys
import uuid
+from pathlib import Path
+from types import TracebackType
from typing import Any, ClassVar, List, Optional, Type, Union
-import sys
from autogen.coding.utils import silence_pip
diff --git a/autogen/coding/jupyter/local_jupyter_server.py b/autogen/coding/jupyter/local_jupyter_server.py
index 0709f55ee4e..9b892303848 100644
--- a/autogen/coding/jupyter/local_jupyter_server.py
+++ b/autogen/coding/jupyter/local_jupyter_server.py
@@ -1,14 +1,14 @@
from __future__ import annotations
-from types import TracebackType
-from typing import Optional, Type, Union, cast
-import subprocess
-import signal
-import sys
+import atexit
import json
import secrets
+import signal
import socket
-import atexit
+import subprocess
+import sys
+from types import TracebackType
+from typing import Optional, Type, Union, cast
if sys.version_info >= (3, 11):
from typing import Self
diff --git a/autogen/coding/local_commandline_code_executor.py b/autogen/coding/local_commandline_code_executor.py
index b75f54ff121..29172bbe922 100644
--- a/autogen/coding/local_commandline_code_executor.py
+++ b/autogen/coding/local_commandline_code_executor.py
@@ -1,32 +1,60 @@
-from hashlib import md5
-from pathlib import Path
+import logging
+import os
import re
-from string import Template
+import subprocess
import sys
import warnings
-from typing import Any, Callable, ClassVar, List, TypeVar, Union, cast
+from hashlib import md5
+from pathlib import Path
+from string import Template
+from types import SimpleNamespace
+from typing import Any, Callable, ClassVar, Dict, List, Optional, Union
+
from typing_extensions import ParamSpec
-from autogen.coding.func_with_reqs import FunctionWithRequirements, _build_python_functions_file, to_stub
-from ..code_utils import TIMEOUT_MSG, WIN32, _cmd
+from autogen.coding.func_with_reqs import (
+ FunctionWithRequirements,
+ FunctionWithRequirementsStr,
+ _build_python_functions_file,
+ to_stub,
+)
+
+from ..code_utils import PYTHON_VARIANTS, TIMEOUT_MSG, WIN32, _cmd
from .base import CodeBlock, CodeExecutor, CodeExtractor, CommandLineCodeResult
from .markdown_code_extractor import MarkdownCodeExtractor
-
from .utils import _get_file_name_from_content, silence_pip
-import subprocess
-
-import logging
-
__all__ = ("LocalCommandLineCodeExecutor",)
A = ParamSpec("A")
class LocalCommandLineCodeExecutor(CodeExecutor):
- SUPPORTED_LANGUAGES: ClassVar[List[str]] = ["bash", "shell", "sh", "pwsh", "powershell", "ps1", "python"]
- FUNCTIONS_MODULE: ClassVar[str] = "functions"
- FUNCTIONS_FILENAME: ClassVar[str] = "functions.py"
+ SUPPORTED_LANGUAGES: ClassVar[List[str]] = [
+ "bash",
+ "shell",
+ "sh",
+ "pwsh",
+ "powershell",
+ "ps1",
+ "python",
+ "javascript",
+ "html",
+ "css",
+ ]
+ DEFAULT_EXECUTION_POLICY: ClassVar[Dict[str, bool]] = {
+ "bash": True,
+ "shell": True,
+ "sh": True,
+ "pwsh": True,
+ "powershell": True,
+ "ps1": True,
+ "python": True,
+ "javascript": False,
+ "html": False,
+ "css": False,
+ }
+
FUNCTION_PROMPT_TEMPLATE: ClassVar[
str
] = """You have access to the following user defined functions. They can be accessed from the module called `$module_name` by their function names.
@@ -38,31 +66,45 @@ class LocalCommandLineCodeExecutor(CodeExecutor):
def __init__(
self,
timeout: int = 60,
+ virtual_env_context: Optional[SimpleNamespace] = None,
work_dir: Union[Path, str] = Path("."),
- functions: List[Union[FunctionWithRequirements[Any, A], Callable[..., Any]]] = [],
+ functions: List[Union[FunctionWithRequirements[Any, A], Callable[..., Any], FunctionWithRequirementsStr]] = [],
+ functions_module: str = "functions",
+ execution_policies: Optional[Dict[str, bool]] = None,
):
- """(Experimental) A code executor class that executes code through a local command line
+        """(Experimental) A code executor class that executes or saves LLM generated code in a local command line
environment.
- **This will execute LLM generated code on the local machine.**
+ **This will execute or save LLM generated code on the local machine.**
+
+ Each code block is saved as a file in the working directory. Depending on the execution policy,
+ the code may be executed in a separate process.
+        The code blocks are executed or saved in the order they are received.
+ Command line code is sanitized against a list of dangerous commands to prevent self-destructive commands from being executed,
+ which could potentially affect the user's environment. Supported languages include Python, shell scripts (bash, shell, sh),
+ PowerShell (pwsh, powershell, ps1), HTML, CSS, and JavaScript.
+ Execution policies determine whether each language's code blocks are executed or saved only.
+
+ ## Execution with a Python virtual environment
+ A python virtual env can be used to execute code and install dependencies. This has the added benefit of not polluting the
+ base environment with unwanted modules.
+ ```python
+ from autogen.code_utils import create_virtual_env
+ from autogen.coding import LocalCommandLineCodeExecutor
- Each code block is saved as a file and executed in a separate process in
- the working directory, and a unique file is generated and saved in the
- working directory for each code block.
- The code blocks are executed in the order they are received.
- Command line code is sanitized using regular expression match against a list of dangerous commands in order to prevent self-destructive
- commands from being executed which may potentially affect the users environment.
- Currently the only supported languages is Python and shell scripts.
- For Python code, use the language "python" for the code block.
- For shell scripts, use the language "bash", "shell", or "sh" for the code
- block.
+ venv_dir = ".venv"
+ venv_context = create_virtual_env(venv_dir)
+
+ executor = LocalCommandLineCodeExecutor(virtual_env_context=venv_context)
+ ```
Args:
- timeout (int): The timeout for code execution. Default is 60.
- work_dir (str): The working directory for the code execution. If None,
- a default working directory will be used. The default working
- directory is the current directory ".".
- functions (List[Union[FunctionWithRequirements[Any, A], Callable[..., Any]]]): A list of functions that are available to the code executor. Default is an empty list.
+ timeout (int): The timeout for code execution, default is 60 seconds.
+ virtual_env_context (Optional[SimpleNamespace]): The virtual environment context to use.
+ work_dir (Union[Path, str]): The working directory for code execution, defaults to the current directory.
+ functions (List[Union[FunctionWithRequirements[Any, A], Callable[..., Any], FunctionWithRequirementsStr]]): A list of callable functions available to the executor.
+ functions_module (str): The module name under which functions are accessible.
+ execution_policies (Optional[Dict[str, bool]]): A dictionary mapping languages to execution policies (True for execution, False for saving only). Defaults to class-wide DEFAULT_EXECUTION_POLICY.
"""
if timeout < 1:
@@ -71,11 +113,16 @@ def __init__(
if isinstance(work_dir, str):
work_dir = Path(work_dir)
- if not work_dir.exists():
- raise ValueError(f"Working directory {work_dir} does not exist.")
+ if not functions_module.isidentifier():
+ raise ValueError("Module name must be a valid Python identifier")
+
+ self._functions_module = functions_module
+
+ work_dir.mkdir(exist_ok=True)
self._timeout = timeout
self._work_dir: Path = work_dir
+ self._virtual_env_context: Optional[SimpleNamespace] = virtual_env_context
self._functions = functions
# Setup could take some time so we intentionally wait for the first code block to do it.
@@ -84,6 +131,10 @@ def __init__(
else:
self._setup_functions_complete = True
+ self.execution_policies = self.DEFAULT_EXECUTION_POLICY.copy()
+ if execution_policies is not None:
+ self.execution_policies.update(execution_policies)
+
def format_functions_for_prompt(self, prompt_template: str = FUNCTION_PROMPT_TEMPLATE) -> str:
"""(Experimental) Format the functions for a prompt.
@@ -97,15 +148,21 @@ def format_functions_for_prompt(self, prompt_template: str = FUNCTION_PROMPT_TEM
Returns:
str: The formatted prompt.
"""
-
template = Template(prompt_template)
return template.substitute(
- module_name=self.FUNCTIONS_MODULE,
+ module_name=self._functions_module,
functions="\n\n".join([to_stub(func) for func in self._functions]),
)
@property
- def functions(self) -> List[Union[FunctionWithRequirements[Any, A], Callable[..., Any]]]:
+ def functions_module(self) -> str:
+ """(Experimental) The module name for the functions."""
+ return self._functions_module
+
+ @property
+ def functions(
+ self,
+ ) -> List[Union[FunctionWithRequirements[Any, A], Callable[..., Any], FunctionWithRequirementsStr]]:
"""(Experimental) The functions that are available to the code executor."""
return self._functions
@@ -148,7 +205,7 @@ def sanitize_command(lang: str, code: str) -> None:
def _setup_functions(self) -> None:
func_file_content = _build_python_functions_file(self._functions)
- func_file = self._work_dir / self.FUNCTIONS_FILENAME
+ func_file = self._work_dir / f"{self._functions_module}.py"
func_file.write_text(func_file_content)
# Collect requirements
@@ -157,26 +214,23 @@ def _setup_functions(self) -> None:
required_packages = list(set(flattened_packages))
if len(required_packages) > 0:
logging.info("Ensuring packages are installed in executor.")
-
- cmd = [sys.executable, "-m", "pip", "install"]
- cmd.extend(required_packages)
-
+ if self._virtual_env_context:
+ py_executable = self._virtual_env_context.env_exe
+ else:
+ py_executable = sys.executable
+ cmd = [py_executable, "-m", "pip", "install"] + required_packages
try:
result = subprocess.run(
cmd, cwd=self._work_dir, capture_output=True, text=True, timeout=float(self._timeout)
)
except subprocess.TimeoutExpired as e:
raise ValueError("Pip install timed out") from e
-
if result.returncode != 0:
raise ValueError(f"Pip install failed. {result.stdout}, {result.stderr}")
-
# Attempt to load the function file to check for syntax errors, imports etc.
exec_result = self._execute_code_dont_check_setup([CodeBlock(code=func_file_content, language="python")])
-
if exec_result.exit_code != 0:
raise ValueError(f"Functions failed to load: {exec_result.output}")
-
self._setup_functions_complete = True
def execute_code_blocks(self, code_blocks: List[CodeBlock]) -> CommandLineCodeResult:
@@ -187,10 +241,8 @@ def execute_code_blocks(self, code_blocks: List[CodeBlock]) -> CommandLineCodeRe
Returns:
CommandLineCodeResult: The result of the code execution."""
-
if not self._setup_functions_complete:
self._setup_functions()
-
return self._execute_code_dont_check_setup(code_blocks)
def _execute_code_dont_check_setup(self, code_blocks: List[CodeBlock]) -> CommandLineCodeResult:
@@ -203,6 +255,9 @@ def _execute_code_dont_check_setup(self, code_blocks: List[CodeBlock]) -> Comman
LocalCommandLineCodeExecutor.sanitize_command(lang, code)
code = silence_pip(code, lang)
+ if lang in PYTHON_VARIANTS:
+ lang = "python"
+
if WIN32 and lang in ["sh", "shell"]:
lang = "ps1"
@@ -212,6 +267,7 @@ def _execute_code_dont_check_setup(self, code_blocks: List[CodeBlock]) -> Comman
logs_all += "\n" + f"unknown language {lang}"
break
+ execute_code = self.execution_policies.get(lang, False)
try:
# Check if there is a filename comment
filename = _get_file_name_from_content(code, self._work_dir)
@@ -222,18 +278,31 @@ def _execute_code_dont_check_setup(self, code_blocks: List[CodeBlock]) -> Comman
# create a file with an automatically generated name
code_hash = md5(code.encode()).hexdigest()
filename = f"tmp_code_{code_hash}.{'py' if lang.startswith('python') else lang}"
-
written_file = (self._work_dir / filename).resolve()
with written_file.open("w", encoding="utf-8") as f:
f.write(code)
file_names.append(written_file)
- program = sys.executable if lang.startswith("python") else _cmd(lang)
+ if not execute_code:
+ # Just return a message that the file is saved.
+ logs_all += f"Code saved to {str(written_file)}\n"
+ exitcode = 0
+ continue
+
+ program = _cmd(lang)
cmd = [program, str(written_file.absolute())]
+ env = os.environ.copy()
+
+ if self._virtual_env_context:
+ path_with_virtualenv = rf"{self._virtual_env_context.bin_path}{os.pathsep}{env['PATH']}"
+ env["PATH"] = path_with_virtualenv
+ if WIN32:
+ activation_script = os.path.join(self._virtual_env_context.bin_path, "activate.bat")
+ cmd = [activation_script, "&&", *cmd]
try:
result = subprocess.run(
- cmd, cwd=self._work_dir, capture_output=True, text=True, timeout=float(self._timeout)
+ cmd, cwd=self._work_dir, capture_output=True, text=True, timeout=float(self._timeout), env=env
)
except subprocess.TimeoutExpired:
logs_all += "\n" + TIMEOUT_MSG
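A minimal usage sketch (not part of the diff) showing the default save-only policy for web languages; the code strings are illustrative:

```python
from autogen.coding import LocalCommandLineCodeExecutor
from autogen.coding.base import CodeBlock

executor = LocalCommandLineCodeExecutor(work_dir="coding")  # working directory is created if missing

result = executor.execute_code_blocks(
    [
        CodeBlock(language="python", code="print(40 + 2)"),   # executed (default policy: True)
        CodeBlock(language="html", code="<h1>report</h1>"),   # saved to a file only (default policy: False)
    ]
)
print(result.exit_code)  # 0 when the executed block succeeds
print(result.output)     # includes a "Code saved to ..." line for the HTML block
```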
diff --git a/autogen/coding/markdown_code_extractor.py b/autogen/coding/markdown_code_extractor.py
index 58ad4a09e2b..7a1084194eb 100644
--- a/autogen/coding/markdown_code_extractor.py
+++ b/autogen/coding/markdown_code_extractor.py
@@ -2,8 +2,8 @@
from typing import Any, Dict, List, Optional, Union
from ..code_utils import CODE_BLOCK_PATTERN, UNKNOWN, content_str, infer_lang
-from .base import CodeBlock, CodeExtractor
from ..types import UserMessageImageContentPart, UserMessageTextContentPart
+from .base import CodeBlock, CodeExtractor
__all__ = ("MarkdownCodeExtractor",)
diff --git a/autogen/coding/utils.py b/autogen/coding/utils.py
index 0a7c5a7785d..d692bfe35b9 100644
--- a/autogen/coding/utils.py
+++ b/autogen/coding/utils.py
@@ -3,23 +3,31 @@
from pathlib import Path
from typing import Optional
+filename_patterns = [
+    re.compile(r"^<!-- (filename:)?(.+?) -->", re.DOTALL),
+ re.compile(r"^/\* (filename:)?(.+?) \*/", re.DOTALL),
+ re.compile(r"^// (filename:)?(.+?)$", re.DOTALL),
+ re.compile(r"^# (filename:)?(.+?)$", re.DOTALL),
+]
+
# Raises ValueError if the file is not in the workspace
def _get_file_name_from_content(code: str, workspace_path: Path) -> Optional[str]:
- first_line = code.split("\n")[0]
+ first_line = code.split("\n")[0].strip()
# TODO - support other languages
- if first_line.startswith("# filename:"):
- filename = first_line.split(":")[1].strip()
-
- # Handle relative paths in the filename
- path = Path(filename)
- if not path.is_absolute():
- path = workspace_path / path
- path = path.resolve()
- # Throws an error if the file is not in the workspace
- relative = path.relative_to(workspace_path.resolve())
- return str(relative)
+ for pattern in filename_patterns:
+ matches = pattern.match(first_line)
+ if matches is not None:
+ filename = matches.group(2).strip()
+ # Handle relative paths in the filename
+ path = Path(filename)
+ if not path.is_absolute():
+ path = workspace_path / path
+ path = path.resolve()
+ # Throws an error if the file is not in the workspace
+ relative = path.relative_to(workspace_path.resolve())
+ return str(relative)
return None
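For illustration (not part of the diff), the broadened first-line detection now recognises several comment styles; the workspace path and code snippets below are illustrative:

```python
from pathlib import Path

from autogen.coding.utils import _get_file_name_from_content

workspace = Path(".").resolve()

print(_get_file_name_from_content("# filename: app.py\nprint('hi')", workspace))        # app.py
print(_get_file_name_from_content("// filename: index.js\nconsole.log(1)", workspace))  # index.js
print(_get_file_name_from_content("/* filename: style.css */\nbody {}", workspace))     # style.css
print(_get_file_name_from_content("print('no filename comment')", workspace))           # None
```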
diff --git a/autogen/function_utils.py b/autogen/function_utils.py
index 0189836a494..dd225fd4719 100644
--- a/autogen/function_utils.py
+++ b/autogen/function_utils.py
@@ -73,7 +73,7 @@ def get_typed_return_annotation(call: Callable[..., Any]) -> Any:
return get_typed_annotation(annotation, globalns)
-def get_param_annotations(typed_signature: inspect.Signature) -> Dict[int, Union[Annotated[Type[Any], str], Type[Any]]]:
+def get_param_annotations(typed_signature: inspect.Signature) -> Dict[str, Union[Annotated[Type[Any], str], Type[Any]]]:
"""Get the type annotations of the parameters of a function
Args:
@@ -110,9 +110,7 @@ class ToolFunction(BaseModel):
function: Annotated[Function, Field(description="Function under tool")]
-def get_parameter_json_schema(
- k: str, v: Union[Annotated[Type[Any], str], Type[Any]], default_values: Dict[str, Any]
-) -> JsonSchemaValue:
+def get_parameter_json_schema(k: str, v: Any, default_values: Dict[str, Any]) -> JsonSchemaValue:
"""Get a JSON schema for a parameter as defined by the OpenAI API
Args:
@@ -285,7 +283,7 @@ def f(a: Annotated[str, "Parameter a"], b: int = 2, c: Annotated[float, "Paramet
return model_dump(function)
-def get_load_param_if_needed_function(t: Any) -> Optional[Callable[[T, Type[Any]], BaseModel]]:
+def get_load_param_if_needed_function(t: Any) -> Optional[Callable[[Dict[str, Any], Type[BaseModel]], BaseModel]]:
"""Get a function to load a parameter if it is a Pydantic model
Args:
@@ -319,10 +317,10 @@ def load_basemodels_if_needed(func: Callable[..., Any]) -> Callable[..., Any]:
param_annotations = get_param_annotations(typed_signature)
# get functions for loading BaseModels when needed based on the type annotations
- kwargs_mapping = {k: get_load_param_if_needed_function(t) for k, t in param_annotations.items()}
+ kwargs_mapping_with_nones = {k: get_load_param_if_needed_function(t) for k, t in param_annotations.items()}
# remove the None values
- kwargs_mapping = {k: f for k, f in kwargs_mapping.items() if f is not None}
+ kwargs_mapping = {k: f for k, f in kwargs_mapping_with_nones.items() if f is not None}
# a function that loads the parameters before calling the original function
@functools.wraps(func)
diff --git a/autogen/graph_utils.py b/autogen/graph_utils.py
index a84fc89f9cf..d36b47a12ed 100644
--- a/autogen/graph_utils.py
+++ b/autogen/graph_utils.py
@@ -1,7 +1,7 @@
-from typing import Dict, List
import logging
+from typing import Dict, List, Optional
-from autogen.agentchat.groupchat import Agent
+from autogen.agentchat import Agent
def has_self_loops(allowed_speaker_transitions: Dict) -> bool:
@@ -110,13 +110,15 @@ def invert_disallowed_to_allowed(disallowed_speaker_transitions_dict: dict, agen
return allowed_speaker_transitions_dict
-def visualize_speaker_transitions_dict(speaker_transitions_dict: dict, agents: List[Agent]):
+def visualize_speaker_transitions_dict(
+ speaker_transitions_dict: dict, agents: List[Agent], export_path: Optional[str] = None
+):
"""
Visualize the speaker_transitions_dict using networkx.
"""
try:
- import networkx as nx
import matplotlib.pyplot as plt
+ import networkx as nx
except ImportError as e:
logging.fatal("Failed to import networkx or matplotlib. Try running 'pip install autogen[graphs]'")
raise e
@@ -133,4 +135,8 @@ def visualize_speaker_transitions_dict(speaker_transitions_dict: dict, agents: L
# Visualize
nx.draw(G, with_labels=True, font_weight="bold")
- plt.show()
+
+ if export_path is not None:
+ plt.savefig(export_path)
+ else:
+ plt.show()
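A short sketch (not part of the diff) of the new `export_path` option; the agents and output filename are illustrative, and the `graphs` extra (networkx, matplotlib) is assumed:

```python
from autogen import ConversableAgent
from autogen.graph_utils import visualize_speaker_transitions_dict

alice = ConversableAgent("alice", llm_config=False, human_input_mode="NEVER")
bob = ConversableAgent("bob", llm_config=False, human_input_mode="NEVER")

# With export_path set, the figure is written to disk instead of calling plt.show().
visualize_speaker_transitions_dict({alice: [bob], bob: [alice]}, [alice, bob], export_path="transitions.png")
```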
diff --git a/autogen/io/__init__.py b/autogen/io/__init__.py
index 20d6d5a578f..6bb8a35680f 100644
--- a/autogen/io/__init__.py
+++ b/autogen/io/__init__.py
@@ -3,6 +3,7 @@
from .websockets import IOWebsockets
# Set the default input/output stream to the console
-IOStream._default_io_stream.set(IOConsole())
+IOStream.set_global_default(IOConsole())
+IOStream.set_default(IOConsole())
__all__ = ("IOConsole", "IOStream", "InputStream", "OutputStream", "IOWebsockets")
diff --git a/autogen/io/base.py b/autogen/io/base.py
index 857d532e4f5..476e37db036 100644
--- a/autogen/io/base.py
+++ b/autogen/io/base.py
@@ -1,9 +1,12 @@
+import logging
from contextlib import contextmanager
from contextvars import ContextVar
from typing import Any, Iterator, Optional, Protocol, runtime_checkable
__all__ = ("OutputStream", "InputStream", "IOStream")
+logger = logging.getLogger(__name__)
+
@runtime_checkable
class OutputStream(Protocol):
@@ -39,6 +42,31 @@ def input(self, prompt: str = "", *, password: bool = False) -> str:
class IOStream(InputStream, OutputStream, Protocol):
"""A protocol for input/output streams."""
+ # ContextVar must be used in multithreaded or async environments
+ _default_io_stream: ContextVar[Optional["IOStream"]] = ContextVar("default_iostream", default=None)
+ _default_io_stream.set(None)
+ _global_default: Optional["IOStream"] = None
+
+ @staticmethod
+ def set_global_default(stream: "IOStream") -> None:
+        """Set the global default input/output stream.
+
+ Args:
+ stream (IOStream): The input/output stream to set as the default.
+ """
+ IOStream._global_default = stream
+
+ @staticmethod
+ def get_global_default() -> "IOStream":
+        """Get the global default input/output stream.
+
+ Returns:
+ IOStream: The default input/output stream.
+ """
+ if IOStream._global_default is None:
+ raise RuntimeError("No global default IOStream has been set")
+ return IOStream._global_default
+
@staticmethod
def get_default() -> "IOStream":
"""Get the default input/output stream.
@@ -48,13 +76,11 @@ def get_default() -> "IOStream":
"""
iostream = IOStream._default_io_stream.get()
if iostream is None:
- raise RuntimeError("No default IOStream has been set")
+ iostream = IOStream.get_global_default()
+            # Set the default IOStream of the current context (thread/coroutine)
+ IOStream.set_default(iostream)
return iostream
- # ContextVar must be used in multithreaded or async environments
- _default_io_stream: ContextVar[Optional["IOStream"]] = ContextVar("default_iostream")
- _default_io_stream.set(None)
-
@staticmethod
@contextmanager
def set_default(stream: Optional["IOStream"]) -> Iterator[None]:
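To illustrate the behavioural change (not part of the diff): a thread that never set a per-context default no longer raises; it falls back to the global default. The worker function below is illustrative:

```python
import threading

from autogen.io import IOConsole, IOStream

IOStream.set_global_default(IOConsole())  # already done in autogen.io.__init__; repeated here for clarity

def worker() -> None:
    # This thread has no context-local default; get_default() now falls back to the
    # global default instead of raising RuntimeError.
    IOStream.get_default().print("hello from a worker thread")

t = threading.Thread(target=worker)
t.start()
t.join()
```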
diff --git a/autogen/io/websockets.py b/autogen/io/websockets.py
index 45caffcdcc2..9d38a718754 100644
--- a/autogen/io/websockets.py
+++ b/autogen/io/websockets.py
@@ -4,7 +4,7 @@
from contextlib import contextmanager
from functools import partial
from time import sleep
-from typing import Any, Callable, Dict, Iterable, Iterator, Optional, TYPE_CHECKING, Protocol, Union
+from typing import TYPE_CHECKING, Any, Callable, Dict, Iterable, Iterator, Optional, Protocol, Union
from .base import IOStream
diff --git a/autogen/logger/base_logger.py b/autogen/logger/base_logger.py
index 97508b4883c..7c35f8a5091 100644
--- a/autogen/logger/base_logger.py
+++ b/autogen/logger/base_logger.py
@@ -1,15 +1,15 @@
from __future__ import annotations
-from abc import ABC, abstractmethod
-from typing import Any, Dict, List, TYPE_CHECKING, Union
import sqlite3
import uuid
+from abc import ABC, abstractmethod
+from typing import TYPE_CHECKING, Any, Dict, List, Union
-from openai import OpenAI, AzureOpenAI
+from openai import AzureOpenAI, OpenAI
from openai.types.chat import ChatCompletion
if TYPE_CHECKING:
- from autogen import ConversableAgent, OpenAIWrapper
+ from autogen import Agent, ConversableAgent, OpenAIWrapper
ConfigItem = Dict[str, Union[str, List[str]]]
LLMConfig = Dict[str, Union[None, float, int, ConfigItem, List[ConfigItem]]]
@@ -68,6 +68,18 @@ def log_new_agent(self, agent: ConversableAgent, init_args: Dict[str, Any]) -> N
"""
...
+ @abstractmethod
+ def log_event(self, source: Union[str, Agent], name: str, **kwargs: Dict[str, Any]) -> None:
+ """
+ Log an event for an agent.
+
+ Args:
+ source (str or Agent): The source/creator of the event as a string name or an Agent instance
+ name (str): The name of the event
+ kwargs (dict): The event information to log
+ """
+ ...
+
@abstractmethod
def log_new_wrapper(self, wrapper: OpenAIWrapper, init_args: Dict[str, Union[LLMConfig, List[LLMConfig]]]) -> None:
"""
diff --git a/autogen/logger/file_logger.py b/autogen/logger/file_logger.py
new file mode 100644
index 00000000000..f8578474958
--- /dev/null
+++ b/autogen/logger/file_logger.py
@@ -0,0 +1,211 @@
+from __future__ import annotations
+
+import json
+import logging
+import os
+import threading
+import uuid
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
+
+from openai import AzureOpenAI, OpenAI
+from openai.types.chat import ChatCompletion
+
+from autogen.logger.base_logger import BaseLogger
+from autogen.logger.logger_utils import get_current_ts, to_dict
+
+from .base_logger import LLMConfig
+
+if TYPE_CHECKING:
+ from autogen import Agent, ConversableAgent, OpenAIWrapper
+
+logger = logging.getLogger(__name__)
+
+
+class FileLogger(BaseLogger):
+ def __init__(self, config: Dict[str, Any]):
+ self.config = config
+ self.session_id = str(uuid.uuid4())
+
+ curr_dir = os.getcwd()
+ self.log_dir = os.path.join(curr_dir, "autogen_logs")
+ os.makedirs(self.log_dir, exist_ok=True)
+
+ self.log_file = os.path.join(self.log_dir, self.config.get("filename", "runtime.log"))
+ try:
+ with open(self.log_file, "a"):
+ pass
+ except Exception as e:
+ logger.error(f"[file_logger] Failed to create logging file: {e}")
+
+ self.logger = logging.getLogger(__name__)
+ self.logger.setLevel(logging.INFO)
+ file_handler = logging.FileHandler(self.log_file)
+ self.logger.addHandler(file_handler)
+
+ def start(self) -> str:
+ """Start the logger and return the session_id."""
+ try:
+ self.logger.info(f"Started new session with Session ID: {self.session_id}")
+ except Exception as e:
+ logger.error(f"[file_logger] Failed to create logging file: {e}")
+ finally:
+ return self.session_id
+
+ def log_chat_completion(
+ self,
+ invocation_id: uuid.UUID,
+ client_id: int,
+ wrapper_id: int,
+ request: Dict[str, Union[float, str, List[Dict[str, str]]]],
+ response: Union[str, ChatCompletion],
+ is_cached: int,
+ cost: float,
+ start_time: str,
+ ) -> None:
+ """
+ Log a chat completion.
+ """
+ thread_id = threading.get_ident()
+ try:
+ log_data = json.dumps(
+ {
+ "invocation_id": str(invocation_id),
+ "client_id": client_id,
+ "wrapper_id": wrapper_id,
+ "request": to_dict(request),
+ "response": str(response),
+ "is_cached": is_cached,
+ "cost": cost,
+ "start_time": start_time,
+ "end_time": get_current_ts(),
+ "thread_id": thread_id,
+ }
+ )
+
+ self.logger.info(log_data)
+ except Exception as e:
+ self.logger.error(f"[file_logger] Failed to log chat completion: {e}")
+
+ def log_new_agent(self, agent: ConversableAgent, init_args: Dict[str, Any] = {}) -> None:
+ """
+ Log a new agent instance.
+ """
+ thread_id = threading.get_ident()
+
+ try:
+ log_data = json.dumps(
+ {
+ "id": id(agent),
+ "agent_name": agent.name if hasattr(agent, "name") and agent.name is not None else "",
+ "wrapper_id": to_dict(
+ agent.client.wrapper_id if hasattr(agent, "client") and agent.client is not None else ""
+ ),
+ "session_id": self.session_id,
+ "current_time": get_current_ts(),
+ "agent_type": type(agent).__name__,
+ "args": to_dict(init_args),
+ "thread_id": thread_id,
+ }
+ )
+ self.logger.info(log_data)
+ except Exception as e:
+ self.logger.error(f"[file_logger] Failed to log new agent: {e}")
+
+ def log_event(self, source: Union[str, Agent], name: str, **kwargs: Dict[str, Any]) -> None:
+ """
+ Log an event from an agent or a string source.
+ """
+ from autogen import Agent
+
+ # This takes an object o as input and returns a string. If the object o cannot be serialized, instead of raising an error,
+ # it returns a string indicating that the object is non-serializable, along with its type's qualified name obtained using __qualname__.
+        json_args = json.dumps(kwargs, default=lambda o: f"<<non-serializable: {type(o).__qualname__}>>")
+ thread_id = threading.get_ident()
+
+ if isinstance(source, Agent):
+ try:
+ log_data = json.dumps(
+ {
+ "source_id": id(source),
+ "source_name": str(source.name) if hasattr(source, "name") else source,
+ "event_name": name,
+ "agent_module": source.__module__,
+ "agent_class": source.__class__.__name__,
+ "json_state": json_args,
+ "timestamp": get_current_ts(),
+ "thread_id": thread_id,
+ }
+ )
+ self.logger.info(log_data)
+ except Exception as e:
+ self.logger.error(f"[file_logger] Failed to log event {e}")
+ else:
+ try:
+ log_data = json.dumps(
+ {
+ "source_id": id(source),
+ "source_name": str(source.name) if hasattr(source, "name") else source,
+ "event_name": name,
+ "json_state": json_args,
+ "timestamp": get_current_ts(),
+ "thread_id": thread_id,
+ }
+ )
+ self.logger.info(log_data)
+ except Exception as e:
+ self.logger.error(f"[file_logger] Failed to log event {e}")
+
+ def log_new_wrapper(
+ self, wrapper: OpenAIWrapper, init_args: Dict[str, Union[LLMConfig, List[LLMConfig]]] = {}
+ ) -> None:
+ """
+ Log a new wrapper instance.
+ """
+ thread_id = threading.get_ident()
+
+ try:
+ log_data = json.dumps(
+ {
+ "wrapper_id": id(wrapper),
+ "session_id": self.session_id,
+ "json_state": json.dumps(init_args),
+ "timestamp": get_current_ts(),
+ "thread_id": thread_id,
+ }
+ )
+ self.logger.info(log_data)
+ except Exception as e:
+ self.logger.error(f"[file_logger] Failed to log event {e}")
+
+ def log_new_client(self, client: AzureOpenAI | OpenAI, wrapper: OpenAIWrapper, init_args: Dict[str, Any]) -> None:
+ """
+ Log a new client instance.
+ """
+ thread_id = threading.get_ident()
+
+ try:
+ log_data = json.dumps(
+ {
+ "client_id": id(client),
+ "wrapper_id": id(wrapper),
+ "session_id": self.session_id,
+ "class": type(client).__name__,
+ "json_state": json.dumps(init_args),
+ "timestamp": get_current_ts(),
+ "thread_id": thread_id,
+ }
+ )
+ self.logger.info(log_data)
+ except Exception as e:
+ self.logger.error(f"[file_logger] Failed to log event {e}")
+
+ def get_connection(self) -> None:
+ """Method is intentionally left blank because there is no specific connection needed for the FileLogger."""
+ pass
+
+ def stop(self) -> None:
+ """Close the file handler and remove it from the logger."""
+ for handler in self.logger.handlers:
+ if isinstance(handler, logging.FileHandler):
+ handler.close()
+ self.logger.removeHandler(handler)
diff --git a/autogen/logger/logger_factory.py b/autogen/logger/logger_factory.py
index 282efc3263e..ed9567977bb 100644
--- a/autogen/logger/logger_factory.py
+++ b/autogen/logger/logger_factory.py
@@ -1,5 +1,7 @@
-from typing import Any, Dict, Optional
+from typing import Any, Dict, Literal, Optional
+
from autogen.logger.base_logger import BaseLogger
+from autogen.logger.file_logger import FileLogger
from autogen.logger.sqlite_logger import SqliteLogger
__all__ = ("LoggerFactory",)
@@ -7,11 +9,15 @@
class LoggerFactory:
@staticmethod
- def get_logger(logger_type: str = "sqlite", config: Optional[Dict[str, Any]] = None) -> BaseLogger:
+ def get_logger(
+ logger_type: Literal["sqlite", "file"] = "sqlite", config: Optional[Dict[str, Any]] = None
+ ) -> BaseLogger:
if config is None:
config = {}
if logger_type == "sqlite":
return SqliteLogger(config)
+ elif logger_type == "file":
+ return FileLogger(config)
else:
raise ValueError(f"[logger_factory] Unknown logger type: {logger_type}")
diff --git a/autogen/logger/sqlite_logger.py b/autogen/logger/sqlite_logger.py
index 227d57f91ee..6e95a571cd0 100644
--- a/autogen/logger/sqlite_logger.py
+++ b/autogen/logger/sqlite_logger.py
@@ -6,18 +6,18 @@
import sqlite3
import threading
import uuid
+from typing import TYPE_CHECKING, Any, Dict, List, Tuple, Union
+
+from openai import AzureOpenAI, OpenAI
+from openai.types.chat import ChatCompletion
from autogen.logger.base_logger import BaseLogger
from autogen.logger.logger_utils import get_current_ts, to_dict
-from openai import OpenAI, AzureOpenAI
-from openai.types.chat import ChatCompletion
-from typing import Any, Dict, List, TYPE_CHECKING, Tuple, Union
from .base_logger import LLMConfig
-
if TYPE_CHECKING:
- from autogen import ConversableAgent, OpenAIWrapper
+ from autogen import Agent, ConversableAgent, OpenAIWrapper
logger = logging.getLogger(__name__)
lock = threading.Lock()
@@ -103,6 +103,20 @@ class TEXT, -- type or class name of cli
"""
self._run_query(query=query)
+ query = """
+ CREATE TABLE IF NOT EXISTS events (
+ event_name TEXT,
+ source_id INTEGER,
+ source_name TEXT,
+ agent_module TEXT DEFAULT NULL,
+ agent_class_name TEXT DEFAULT NULL,
+ id INTEGER PRIMARY KEY,
+ json_state TEXT,
+ timestamp DATETIME DEFAULT CURRENT_TIMESTAMP
+ );
+ """
+ self._run_query(query=query)
+
current_verion = self._get_current_db_version()
if current_verion is None:
self._run_query(
@@ -246,6 +260,41 @@ class = excluded.class,
)
self._run_query(query=query, args=args)
+ def log_event(self, source: Union[str, Agent], name: str, **kwargs: Dict[str, Any]) -> None:
+ from autogen import Agent
+
+ if self.con is None:
+ return
+
+        json_args = json.dumps(kwargs, default=lambda o: f"<<non-serializable: {type(o).__qualname__}>>")
+
+ if isinstance(source, Agent):
+ query = """
+ INSERT INTO events (source_id, source_name, event_name, agent_module, agent_class_name, json_state, timestamp) VALUES (?, ?, ?, ?, ?, ?, ?)
+ """
+ args = (
+ id(source),
+ source.name if hasattr(source, "name") else source,
+ name,
+ source.__module__,
+ source.__class__.__name__,
+ json_args,
+ get_current_ts(),
+ )
+ self._run_query(query=query, args=args)
+ else:
+ query = """
+ INSERT INTO events (source_id, source_name, event_name, json_state, timestamp) VALUES (?, ?, ?, ?, ?)
+ """
+ args_str_based = (
+ id(source),
+ source.name if hasattr(source, "name") else source,
+ name,
+ json_args,
+ get_current_ts(),
+ )
+ self._run_query(query=query, args=args_str_based)
+
def log_new_wrapper(self, wrapper: OpenAIWrapper, init_args: Dict[str, Union[LLMConfig, List[LLMConfig]]]) -> None:
if self.con is None:
return
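After a run with the sqlite logger, the new `events` table can be inspected directly; the database filename below assumes the logger's default `logs.db`:

```python
import sqlite3

con = sqlite3.connect("logs.db")  # SqliteLogger's default dbname, assumed here
for timestamp, source_name, event_name, json_state in con.execute(
    "SELECT timestamp, source_name, event_name, json_state FROM events ORDER BY timestamp"
):
    print(timestamp, source_name, event_name, json_state)
con.close()
```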
diff --git a/autogen/math_utils.py b/autogen/math_utils.py
index 00fcae57ad2..aeac2fc9715 100644
--- a/autogen/math_utils.py
+++ b/autogen/math_utils.py
@@ -1,5 +1,6 @@
from typing import Optional
-from autogen import oai, DEFAULT_MODEL
+
+from autogen import DEFAULT_MODEL, oai
_MATH_PROMPT = "{problem} Solve the problem carefully. Simplify your answer as much as possible. Put the final answer in \\boxed{{}}."
_MATH_CONFIG = {
diff --git a/autogen/oai/__init__.py b/autogen/oai/__init__.py
index 9e8437cecc7..d5e2000f759 100644
--- a/autogen/oai/__init__.py
+++ b/autogen/oai/__init__.py
@@ -1,15 +1,15 @@
-from autogen.oai.client import OpenAIWrapper, ModelClient
-from autogen.oai.completion import Completion, ChatCompletion
+from autogen.cache.cache import Cache
+from autogen.oai.client import ModelClient, OpenAIWrapper
+from autogen.oai.completion import ChatCompletion, Completion
from autogen.oai.openai_utils import (
- get_config_list,
+ config_list_from_dotenv,
+ config_list_from_json,
+ config_list_from_models,
config_list_gpt4_gpt35,
config_list_openai_aoai,
- config_list_from_models,
- config_list_from_json,
- config_list_from_dotenv,
filter_config,
+ get_config_list,
)
-from autogen.cache.cache import Cache
__all__ = [
"OpenAIWrapper",
diff --git a/autogen/oai/client.py b/autogen/oai/client.py
index f288ece3961..3edfa40d4ec 100644
--- a/autogen/oai/client.py
+++ b/autogen/oai/client.py
@@ -1,22 +1,20 @@
from __future__ import annotations
-import sys
-from typing import Any, List, Optional, Dict, Callable, Tuple, Union
-import logging
import inspect
+import logging
+import sys
import uuid
-from flaml.automl.logger import logger_formatter
+from typing import Any, Callable, Dict, List, Optional, Protocol, Tuple, Union
+from flaml.automl.logger import logger_formatter
from pydantic import BaseModel
-from typing import Protocol
from autogen.cache import Cache
from autogen.io.base import IOStream
-from autogen.oai.openai_utils import get_key, is_valid_api_key, OAI_PRICE1K
-from autogen.token_count_utils import count_token
-
-from autogen.runtime_logging import logging_enabled, log_chat_completion, log_new_client, log_new_wrapper
from autogen.logger.logger_utils import get_current_ts
+from autogen.oai.openai_utils import OAI_PRICE1K, get_key, is_valid_api_key
+from autogen.runtime_logging import log_chat_completion, log_new_client, log_new_wrapper, logging_enabled
+from autogen.token_count_utils import count_token
TOOL_ENABLED = False
try:
@@ -27,14 +25,15 @@
AzureOpenAI = object
else:
# raises exception if openai>=1 is installed and something is wrong with imports
- from openai import OpenAI, AzureOpenAI, APIError, APITimeoutError, __version__ as OPENAIVERSION
+ from openai import APIError, APITimeoutError, AzureOpenAI, OpenAI
+ from openai import __version__ as OPENAIVERSION
from openai.resources import Completions
from openai.types.chat import ChatCompletion
from openai.types.chat.chat_completion import ChatCompletionMessage, Choice # type: ignore [attr-defined]
from openai.types.chat.chat_completion_chunk import (
+ ChoiceDeltaFunctionCall,
ChoiceDeltaToolCall,
ChoiceDeltaToolCallFunction,
- ChoiceDeltaFunctionCall,
)
from openai.types.completion import Completion
from openai.types.completion_usage import CompletionUsage
@@ -43,6 +42,13 @@
TOOL_ENABLED = True
ERROR = None
+try:
+ from autogen.oai.gemini import GeminiClient
+
+ gemini_import_exception: Optional[ImportError] = None
+except ImportError as e:
+ gemini_import_exception = e
+
logger = logging.getLogger(__name__)
if not logger.handlers:
# Add the console handler.
@@ -290,6 +296,8 @@ def cost(self, response: Union[ChatCompletion, Completion]) -> float:
n_input_tokens = response.usage.prompt_tokens if response.usage is not None else 0 # type: ignore [union-attr]
n_output_tokens = response.usage.completion_tokens if response.usage is not None else 0 # type: ignore [union-attr]
+ if n_output_tokens is None:
+ n_output_tokens = 0
tmp_price1K = OAI_PRICE1K[model]
# First value is input token rate, second value is output token rate
if isinstance(tmp_price1K, tuple):
@@ -424,6 +432,10 @@ def _register_default_client(self, config: Dict[str, Any], openai_config: Dict[s
self._configure_azure_openai(config, openai_config)
client = AzureOpenAI(**openai_config)
self._clients.append(OpenAIClient(client))
+ elif api_type is not None and api_type.startswith("google"):
+ if gemini_import_exception:
+                raise ImportError("Please install `google-generativeai` to use the Google Gemini API.")
+ self._clients.append(GeminiClient(**openai_config))
else:
client = OpenAI(**openai_config)
self._clients.append(OpenAIClient(client))
@@ -806,6 +818,8 @@ def update_usage(usage_summary, response_usage):
cost = response_usage["cost"]
prompt_tokens = response_usage["prompt_tokens"]
completion_tokens = response_usage["completion_tokens"]
+ if completion_tokens is None:
+ completion_tokens = 0
total_tokens = response_usage["total_tokens"]
if usage_summary is None:
diff --git a/autogen/oai/completion.py b/autogen/oai/completion.py
index 43ccd0b3bc2..e3b01ee4dd8 100644
--- a/autogen/oai/completion.py
+++ b/autogen/oai/completion.py
@@ -1,28 +1,30 @@
-from time import sleep
import logging
-import time
-from typing import List, Optional, Dict, Callable, Union
-import sys
import shutil
+import sys
+import time
+from collections import defaultdict
+from time import sleep
+from typing import Callable, Dict, List, Optional, Union
+
import numpy as np
-from flaml import tune, BlendSearch
-from flaml.tune.space import is_constant
+from flaml import BlendSearch, tune
from flaml.automl.logger import logger_formatter
+from flaml.tune.space import is_constant
+
from .openai_utils import get_key
-from collections import defaultdict
try:
+ import diskcache
import openai
from openai import (
- RateLimitError,
+ APIConnectionError,
APIError,
+ AuthenticationError,
BadRequestError,
- APIConnectionError,
+ RateLimitError,
Timeout,
- AuthenticationError,
)
from openai import Completion as openai_Completion
- import diskcache
ERROR = None
assert openai.__version__ < "1"
diff --git a/autogen/oai/gemini.py b/autogen/oai/gemini.py
new file mode 100644
index 00000000000..fcf7e09c025
--- /dev/null
+++ b/autogen/oai/gemini.py
@@ -0,0 +1,310 @@
+"""Create a OpenAI-compatible client for Gemini features.
+
+
+Example:
+ llm_config={
+ "config_list": [{
+ "api_type": "google",
+ "model": "models/gemini-pro",
+ "api_key": os.environ.get("GOOGLE_API_KEY")
+ }
+ ]}
+
+ agent = autogen.AssistantAgent("my_agent", llm_config=llm_config)
+
+Resources:
+- https://ai.google.dev/docs
+- https://cloud.google.com/vertex-ai/docs/generative-ai/migrate-from-azure
+- https://blog.google/technology/ai/google-gemini-pro-imagen-duet-ai-update/
+- https://ai.google.dev/api/python/google/generativeai/ChatSession
+"""
+
+from __future__ import annotations
+
+import base64
+import os
+import random
+import re
+import time
+import warnings
+from io import BytesIO
+from typing import Any, Dict, List, Mapping, Union
+
+import google.generativeai as genai
+import requests
+from google.ai.generativelanguage import Content, Part
+from google.api_core.exceptions import InternalServerError
+from openai.types.chat import ChatCompletion
+from openai.types.chat.chat_completion import ChatCompletionMessage, Choice
+from openai.types.completion_usage import CompletionUsage
+from PIL import Image
+
+
+class GeminiClient:
+ """Client for Google's Gemini API.
+
+ Please visit this [page](https://github.com/microsoft/autogen/issues/2387) for the roadmap of Gemini integration
+ of AutoGen.
+ """
+
+ def __init__(self, **kwargs):
+ self.api_key = kwargs.get("api_key", None)
+ if not self.api_key:
+ self.api_key = os.getenv("GOOGLE_API_KEY")
+
+ assert (
+ self.api_key
+ ), "Please provide api_key in your config list entry for Gemini or set the GOOGLE_API_KEY env variable."
+
+ def message_retrieval(self, response) -> List:
+ """
+ Retrieve and return a list of strings or a list of Choice.Message from the response.
+
+ NOTE: if a list of Choice.Message is returned, it currently needs to contain the fields of OpenAI's ChatCompletion Message object,
+ since that is expected for function or tool calling in the rest of the codebase at the moment, unless a custom agent is being used.
+ """
+ return [choice.message for choice in response.choices]
+
+ def cost(self, response) -> float:
+ return response.cost
+
+ @staticmethod
+ def get_usage(response) -> Dict:
+ """Return usage summary of the response using RESPONSE_USAGE_KEYS."""
+ # ... # pragma: no cover
+ return {
+ "prompt_tokens": response.usage.prompt_tokens,
+ "completion_tokens": response.usage.completion_tokens,
+ "total_tokens": response.usage.total_tokens,
+ "cost": response.cost,
+ "model": response.model,
+ }
+
+ def create(self, params: Dict) -> ChatCompletion:
+ model_name = params.get("model", "gemini-pro")
+ if not model_name:
+ raise ValueError(
+ "Please provide a model name for the Gemini Client. "
+                "You can configure it in the OAI Config List file. "
+ "See this [LLM configuration tutorial](https://microsoft.github.io/autogen/docs/topics/llm_configuration/) for more details."
+ )
+
+ params.get("api_type", "google") # not used
+ messages = params.get("messages", [])
+ stream = params.get("stream", False)
+ n_response = params.get("n", 1)
+ params.get("temperature", 0.5)
+ params.get("top_p", 1.0)
+ params.get("max_tokens", 4096)
+
+ if stream:
+ # warn user that streaming is not supported
+ warnings.warn(
+ "Streaming is not supported for Gemini yet, and it will have no effect. Please set stream=False.",
+ UserWarning,
+ )
+
+ if n_response > 1:
+ warnings.warn("Gemini only supports `n=1` for now. We only generate one response.", UserWarning)
+
+ if "vision" not in model_name:
+ # A. create and call the chat model.
+ gemini_messages = oai_messages_to_gemini_messages(messages)
+
+ # we use chat model by default
+ model = genai.GenerativeModel(model_name)
+ genai.configure(api_key=self.api_key)
+ chat = model.start_chat(history=gemini_messages[:-1])
+ max_retries = 5
+ for attempt in range(max_retries):
+ ans = None
+ try:
+ response = chat.send_message(gemini_messages[-1].parts[0].text, stream=stream)
+ except InternalServerError:
+ delay = 5 * (2**attempt)
+ warnings.warn(
+ f"InternalServerError `500` occurs when calling Gemini's chat model. Retry in {delay} seconds...",
+ UserWarning,
+ )
+ time.sleep(delay)
+ except Exception as e:
+ raise RuntimeError(f"Google GenAI exception occurred while calling Gemini API: {e}")
+ else:
+ # `ans = response.text` is unstable. Use the following code instead.
+ ans: str = chat.history[-1].parts[0].text
+ break
+
+ if ans is None:
+                raise RuntimeError(f"Failed to get a response from Google AI after retrying {attempt + 1} times.")
+
+ prompt_tokens = model.count_tokens(chat.history[:-1]).total_tokens
+ completion_tokens = model.count_tokens(ans).total_tokens
+ elif model_name == "gemini-pro-vision":
+ # B. handle the vision model
+ # Gemini's vision model does not support chat history yet
+ model = genai.GenerativeModel(model_name)
+ genai.configure(api_key=self.api_key)
+ # chat = model.start_chat(history=gemini_messages[:-1])
+ # response = chat.send_message(gemini_messages[-1])
+ user_message = oai_content_to_gemini_content(messages[-1]["content"])
+ if len(messages) > 2:
+ warnings.warn(
+                    "Warning: Gemini's vision model does not support chat history yet. "
+                    "We only use the last message as the prompt.",
+ UserWarning,
+ )
+
+ response = model.generate_content(user_message, stream=stream)
+ # ans = response.text
+ ans: str = response._result.candidates[0].content.parts[0].text
+
+ prompt_tokens = model.count_tokens(user_message).total_tokens
+ completion_tokens = model.count_tokens(ans).total_tokens
+
+ # 3. convert output
+ message = ChatCompletionMessage(role="assistant", content=ans, function_call=None, tool_calls=None)
+ choices = [Choice(finish_reason="stop", index=0, message=message)]
+
+ response_oai = ChatCompletion(
+ id=str(random.randint(0, 1000)),
+ model=model_name,
+ created=int(time.time() * 1000),
+ object="chat.completion",
+ choices=choices,
+ usage=CompletionUsage(
+ prompt_tokens=prompt_tokens,
+ completion_tokens=completion_tokens,
+ total_tokens=prompt_tokens + completion_tokens,
+ ),
+ cost=calculate_gemini_cost(prompt_tokens, completion_tokens, model_name),
+ )
+
+ return response_oai
+
+
+def calculate_gemini_cost(input_tokens: int, output_tokens: int, model_name: str) -> float:
+ if "1.5" in model_name or "gemini-experimental" in model_name:
+ # "gemini-1.5-pro-preview-0409"
+ # Cost is $7 per million input tokens and $21 per million output tokens
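+ # e.g. 10,000 input + 2,000 output tokens -> 7.0 * 0.01 + 21.0 * 0.002 = $0.112 (illustrative figures)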
+ return 7.0 * input_tokens / 1e6 + 21.0 * output_tokens / 1e6
+
+ if "gemini-pro" not in model_name and "gemini-1.0-pro" not in model_name:
+ warnings.warn(f"Cost calculation is not implemented for model {model_name}. Using Gemini-1.0-Pro.", UserWarning)
+
+ # Cost is $0.5 per million input tokens and $1.5 per million output tokens
+ return 0.5 * input_tokens / 1e6 + 1.5 * output_tokens / 1e6
+
+
+def oai_content_to_gemini_content(content: Union[str, List]) -> List:
+ """Convert content from OAI format to Gemini format"""
+ rst = []
+ if isinstance(content, str):
+ rst.append(Part(text=content))
+ return rst
+
+ assert isinstance(content, list)
+
+ for msg in content:
+ if isinstance(msg, dict):
+ assert "type" in msg, f"Missing 'type' field in message: {msg}"
+ if msg["type"] == "text":
+ rst.append(Part(text=msg["text"]))
+ elif msg["type"] == "image_url":
+ b64_img = get_image_data(msg["image_url"]["url"])
+ img = _to_pil(b64_img)
+ rst.append(img)
+ else:
+ raise ValueError(f"Unsupported message type: {msg['type']}")
+ else:
+ raise ValueError(f"Unsupported message type: {type(msg)}")
+ return rst
+
+
+def concat_parts(parts: List[Part]) -> List:
+ """Concatenate parts with the same type.
+ If two adjacent parts both have the "text" attribute, then it will be joined into one part.
+ """
+ if not parts:
+ return []
+
+ concatenated_parts = []
+ previous_part = parts[0]
+
+ for current_part in parts[1:]:
+ if previous_part.text != "":
+ previous_part.text += current_part.text
+ else:
+ concatenated_parts.append(previous_part)
+ previous_part = current_part
+
+ if previous_part.text == "":
+ previous_part.text = "empty" # Empty content is not allowed.
+ concatenated_parts.append(previous_part)
+
+ return concatenated_parts
+
+
+def oai_messages_to_gemini_messages(messages: list[Dict[str, Any]]) -> list[dict[str, Any]]:
+ """Convert messages from OAI format to Gemini format.
+ Make sure the "user" role and "model" role are interleaved.
+ Also, make sure the last item is from the "user" role.
+ """
+ prev_role = None
+ rst = []
+ curr_parts = []
+ for i, message in enumerate(messages):
+ parts = oai_content_to_gemini_content(message["content"])
+ role = "user" if message["role"] in ["user", "system"] else "model"
+
+ if prev_role is None or role == prev_role:
+ curr_parts += parts
+ elif role != prev_role:
+ rst.append(Content(parts=concat_parts(curr_parts), role=prev_role))
+ curr_parts = parts
+ prev_role = role
+
+ # handle the last message
+ rst.append(Content(parts=concat_parts(curr_parts), role=role))
+
+ # Gemini is strict about the order of roles:
+ # 1. The messages should be interleaved between user and model.
+ # 2. The last message must be from the user role.
+ # We add a dummy message "continue" if the last role is not the user.
+ if rst[-1].role != "user":
+ rst.append(Content(parts=oai_content_to_gemini_content("continue"), role="user"))
+
+ return rst
+
+
+def _to_pil(data: str) -> Image.Image:
+ """
+ Converts a base64 encoded image data string to a PIL Image object.
+
+ This function first decodes the base64 encoded string to bytes, then creates a BytesIO object from the bytes,
+ and finally creates and returns a PIL Image object from the BytesIO object.
+
+ Parameters:
+ data (str): The base64 encoded image data string.
+
+ Returns:
+ Image.Image: The PIL Image object created from the input data.
+ """
+ return Image.open(BytesIO(base64.b64decode(data)))
+
+
+def get_image_data(image_file: str, use_b64=True) -> Union[bytes, str]:
+ if image_file.startswith("http://") or image_file.startswith("https://"):
+ response = requests.get(image_file)
+ content = response.content
+ elif re.match(r"data:image/(?:png|jpeg);base64,", image_file):
+ return re.sub(r"data:image/(?:png|jpeg);base64,", "", image_file)
+ else:
+ image = Image.open(image_file).convert("RGB")
+ buffered = BytesIO()
+ image.save(buffered, format="PNG")
+ content = buffered.getvalue()
+
+ if use_b64:
+ return base64.b64encode(content).decode("utf-8")
+ else:
+ return content
diff --git a/autogen/oai/openai_utils.py b/autogen/oai/openai_utils.py
index 411ac03f003..f8dad6d7984 100644
--- a/autogen/oai/openai_utils.py
+++ b/autogen/oai/openai_utils.py
@@ -1,19 +1,46 @@
+import importlib.metadata
import json
import logging
import os
import re
import tempfile
+import time
from pathlib import Path
from typing import Any, Dict, List, Optional, Set, Union
from dotenv import find_dotenv, load_dotenv
-
from openai import OpenAI
from openai.types.beta.assistant import Assistant
+from packaging.version import parse
NON_CACHE_KEY = ["api_key", "base_url", "api_type", "api_version"]
DEFAULT_AZURE_API_VERSION = "2024-02-15-preview"
OAI_PRICE1K = {
+ # https://openai.com/api/pricing/
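+ # Prices are USD per 1K tokens; a (prompt, completion) tuple gives the two rates, a single float applies to all tokens.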
+ # gpt-4o
+ "gpt-4o": (0.005, 0.015),
+ "gpt-4o-2024-05-13": (0.005, 0.015),
+ # gpt-4-turbo
+ "gpt-4-turbo-2024-04-09": (0.01, 0.03),
+ # gpt-4
+ "gpt-4": (0.03, 0.06),
+ "gpt-4-32k": (0.06, 0.12),
+ # gpt-3.5 turbo
+ "gpt-3.5-turbo": (0.0005, 0.0015), # default is 0125
+ "gpt-3.5-turbo-0125": (0.0005, 0.0015), # 16k
+ "gpt-3.5-turbo-instruct": (0.0015, 0.002),
+ # base model
+ "davinci-002": 0.002,
+ "babbage-002": 0.0004,
+ # old model
+ "gpt-4-0125-preview": (0.01, 0.03),
+ "gpt-4-1106-preview": (0.01, 0.03),
+ "gpt-4-1106-vision-preview": (0.01, 0.03), # TODO: support vision pricing of images
+ "gpt-3.5-turbo-1106": (0.001, 0.002),
+ "gpt-3.5-turbo-0613": (0.0015, 0.002),
+ # "gpt-3.5-turbo-16k": (0.003, 0.004),
+ "gpt-3.5-turbo-16k-0613": (0.003, 0.004),
+ "gpt-3.5-turbo-0301": (0.0015, 0.002),
"text-ada-001": 0.0004,
"text-babbage-001": 0.0005,
"text-curie-001": 0.002,
@@ -21,28 +48,20 @@
"code-davinci-002": 0.1,
"text-davinci-002": 0.02,
"text-davinci-003": 0.02,
- "gpt-3.5-turbo-instruct": (0.0015, 0.002),
- "gpt-3.5-turbo-0301": (0.0015, 0.002), # deprecate in Sep
- "gpt-3.5-turbo-0613": (0.0015, 0.002),
- "gpt-3.5-turbo-16k": (0.003, 0.004),
- "gpt-3.5-turbo-16k-0613": (0.003, 0.004),
- "gpt-35-turbo": (0.0015, 0.002),
- "gpt-35-turbo-16k": (0.003, 0.004),
- "gpt-35-turbo-instruct": (0.0015, 0.002),
- "gpt-4": (0.03, 0.06),
- "gpt-4-32k": (0.06, 0.12),
"gpt-4-0314": (0.03, 0.06), # deprecate in Sep
"gpt-4-32k-0314": (0.06, 0.12), # deprecate in Sep
"gpt-4-0613": (0.03, 0.06),
"gpt-4-32k-0613": (0.06, 0.12),
- # 11-06
- "gpt-3.5-turbo": (0.0015, 0.002), # default is still 0613
- "gpt-3.5-turbo-1106": (0.001, 0.002),
- "gpt-35-turbo-1106": (0.001, 0.002),
- "gpt-4-1106-preview": (0.01, 0.03),
- "gpt-4-0125-preview": (0.01, 0.03),
"gpt-4-turbo-preview": (0.01, 0.03),
- "gpt-4-1106-vision-preview": (0.01, 0.03), # TODO: support vision pricing of images
+ # https://azure.microsoft.com/en-us/pricing/details/cognitive-services/openai-service/#pricing
+ "gpt-35-turbo": (0.0005, 0.0015), # what's the default? using 0125 here.
+ "gpt-35-turbo-0125": (0.0005, 0.0015),
+ "gpt-35-turbo-instruct": (0.0015, 0.002),
+ "gpt-35-turbo-1106": (0.001, 0.002),
+ "gpt-35-turbo-0613": (0.0015, 0.002),
+ "gpt-35-turbo-0301": (0.0015, 0.002),
+ "gpt-35-turbo-16k": (0.003, 0.004),
+ "gpt-35-turbo-16k-0613": (0.003, 0.004),
}
@@ -77,7 +96,7 @@ def is_valid_api_key(api_key: str) -> bool:
Returns:
bool: A boolean that indicates if input is valid OpenAI API key.
"""
- api_key_re = re.compile(r"^sk-[A-Za-z0-9]{32,}$")
+ api_key_re = re.compile(r"^sk-(proj-)?[A-Za-z0-9]{32,}$")
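+ # Accepts both legacy "sk-..." keys and project-scoped "sk-proj-..." keys.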
return bool(re.fullmatch(api_key_re, api_key))
@@ -541,11 +560,11 @@ def get_config(
"""
config = {"api_key": api_key}
if base_url:
- config["base_url"] = base_url
+ config["base_url"] = os.getenv(base_url, default=base_url)
if api_type:
- config["api_type"] = api_type
+ config["api_type"] = os.getenv(api_type, default=api_type)
if api_version:
- config["api_version"] = api_version
+ config["api_version"] = os.getenv(api_version, default=api_version)
return config
@@ -662,3 +681,107 @@ def retrieve_assistants_by_name(client: OpenAI, name: str) -> List[Assistant]:
if assistant.name == name:
candidate_assistants.append(assistant)
return candidate_assistants
+
+
+def detect_gpt_assistant_api_version() -> str:
+ """Detect the openai assistant API version"""
+ oai_version = importlib.metadata.version("openai")
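+ # openai-python < 1.21 is treated as Assistants API v1; 1.21 and later as v2 (vector stores / file_search).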
+ if parse(oai_version) < parse("1.21"):
+ return "v1"
+ else:
+ return "v2"
+
+
+def create_gpt_vector_store(client: OpenAI, name: str, file_ids: List[str]) -> Any:
+ """Create an OpenAI vector store for a GPT assistant"""
+
+ try:
+ vector_store = client.beta.vector_stores.create(name=name)
+ except Exception as e:
+ raise AttributeError(f"Failed to create vector store, please install the latest OpenAI python package: {e}")
+
+ # poll the status of the file batch for completion.
+ batch = client.beta.vector_stores.file_batches.create_and_poll(vector_store_id=vector_store.id, file_ids=file_ids)
+
+ if batch.status == "in_progress":
+ time.sleep(1)
+ logging.debug(f"file batch status: {batch.file_counts}")
+ batch = client.beta.vector_stores.file_batches.poll(vector_store_id=vector_store.id, batch_id=batch.id)
+
+ if batch.status == "completed":
+ return vector_store
+
+ raise ValueError(f"Failed to upload files to vector store {vector_store.id}:{batch.status}")
+
+
+def create_gpt_assistant(
+ client: OpenAI, name: str, instructions: str, model: str, assistant_config: Dict[str, Any]
+) -> Assistant:
+ """Create a openai gpt assistant"""
+
+ assistant_create_kwargs = {}
+ gpt_assistant_api_version = detect_gpt_assistant_api_version()
+ tools = assistant_config.get("tools", [])
+
+ if gpt_assistant_api_version == "v2":
+ tool_resources = assistant_config.get("tool_resources", {})
+ file_ids = assistant_config.get("file_ids")
+ if tool_resources.get("file_search") is not None and file_ids is not None:
+ raise ValueError(
+ "Cannot specify both `tool_resources['file_search']` tool and `file_ids` in the assistant config."
+ )
+
+ # Designed for backwards compatibility for the V1 API
+ # Instead of V1 AssistantFile, files are attached to Assistants using the tool_resources object.
+ for tool in tools:
+ if tool["type"] == "retrieval":
+ tool["type"] = "file_search"
+ if file_ids is not None:
+ # create a vector store for the file search tool
+ vs = create_gpt_vector_store(client, f"{name}-vectorstore", file_ids)
+ tool_resources["file_search"] = {
+ "vector_store_ids": [vs.id],
+ }
+ elif tool["type"] == "code_interpreter" and file_ids is not None:
+ tool_resources["code_interpreter"] = {
+ "file_ids": file_ids,
+ }
+
+ assistant_create_kwargs["tools"] = tools
+ if len(tool_resources) > 0:
+ assistant_create_kwargs["tool_resources"] = tool_resources
+ else:
+ # forward compatibility is not supported
+ if "tool_resources" in assistant_config:
+ raise ValueError("`tool_resources` argument are not supported in the openai assistant V1 API.")
+ if any(tool["type"] == "file_search" for tool in tools):
+ raise ValueError(
+ "`file_search` tool are not supported in the openai assistant V1 API, please use `retrieval`."
+ )
+ assistant_create_kwargs["tools"] = tools
+ assistant_create_kwargs["file_ids"] = assistant_config.get("file_ids", [])
+
+ logging.info(f"Creating assistant with config: {assistant_create_kwargs}")
+ return client.beta.assistants.create(name=name, instructions=instructions, model=model, **assistant_create_kwargs)
+
+
+def update_gpt_assistant(client: OpenAI, assistant_id: str, assistant_config: Dict[str, Any]) -> Assistant:
+ """Update openai gpt assistant"""
+
+ gpt_assistant_api_version = detect_gpt_assistant_api_version()
+ assistant_update_kwargs = {}
+
+ if assistant_config.get("tools") is not None:
+ assistant_update_kwargs["tools"] = assistant_config["tools"]
+
+ if assistant_config.get("instructions") is not None:
+ assistant_update_kwargs["instructions"] = assistant_config["instructions"]
+
+ if gpt_assistant_api_version == "v2":
+ if assistant_config.get("tool_resources") is not None:
+ assistant_update_kwargs["tool_resources"] = assistant_config["tool_resources"]
+ else:
+ if assistant_config.get("file_ids") is not None:
+ assistant_update_kwargs["file_ids"] = assistant_config["file_ids"]
+
+ return client.beta.assistants.update(assistant_id=assistant_id, **assistant_update_kwargs)
diff --git a/autogen/retrieve_utils.py b/autogen/retrieve_utils.py
index c3e50c7a96f..9393903ec86 100644
--- a/autogen/retrieve_utils.py
+++ b/autogen/retrieve_utils.py
@@ -1,18 +1,25 @@
-from typing import List, Union, Callable
+import glob
+import hashlib
import os
-import requests
+import re
+from typing import Callable, List, Tuple, Union
from urllib.parse import urlparse
-import glob
+
import chromadb
+import markdownify
+import requests
+from bs4 import BeautifulSoup
if chromadb.__version__ < "0.4.15":
from chromadb.api import API
else:
from chromadb.api import ClientAPI as API
-from chromadb.api.types import QueryResult
-import chromadb.utils.embedding_functions as ef
import logging
+
+import chromadb.utils.embedding_functions as ef
import pypdf
+from chromadb.api.types import QueryResult
+
from autogen.token_count_utils import count_token
try:
@@ -58,6 +65,7 @@
TEXT_FORMATS += UNSTRUCTURED_FORMATS
TEXT_FORMATS = list(set(TEXT_FORMATS))
VALID_CHUNK_MODES = frozenset({"one_line", "multi_lines"})
+RAG_MINIMUM_MESSAGE_LENGTH = int(os.environ.get("RAG_MINIMUM_MESSAGE_LENGTH", 5))
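+# Chunks shorter than this many characters are dropped; override via the RAG_MINIMUM_MESSAGE_LENGTH environment variable.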
def split_text_to_chunks(
@@ -65,22 +73,27 @@ def split_text_to_chunks(
max_tokens: int = 4000,
chunk_mode: str = "multi_lines",
must_break_at_empty_line: bool = True,
- overlap: int = 10,
+ overlap: int = 0, # number of overlapping lines
):
"""Split a long text into chunks of max_tokens."""
if chunk_mode not in VALID_CHUNK_MODES:
raise AssertionError
if chunk_mode == "one_line":
must_break_at_empty_line = False
+ overlap = 0
chunks = []
lines = text.split("\n")
+ num_lines = len(lines)
+ if num_lines < 3 and must_break_at_empty_line:
+ logger.warning("The input text has less than 3 lines. Set `must_break_at_empty_line` to `False`")
+ must_break_at_empty_line = False
lines_tokens = [count_token(line) for line in lines]
sum_tokens = sum(lines_tokens)
while sum_tokens > max_tokens:
if chunk_mode == "one_line":
estimated_line_cut = 2
else:
- estimated_line_cut = int(max_tokens / sum_tokens * len(lines)) + 1
+ estimated_line_cut = max(int(max_tokens / sum_tokens * len(lines)), 2)
cnt = 0
prev = ""
for cnt in reversed(range(estimated_line_cut)):
@@ -94,19 +107,25 @@ def split_text_to_chunks(
f"max_tokens is too small to fit a single line of text. Breaking this line:\n\t{lines[0][:100]} ..."
)
if not must_break_at_empty_line:
- split_len = int(max_tokens / lines_tokens[0] * 0.9 * len(lines[0]))
+ split_len = max(
+ int(max_tokens / (lines_tokens[0] * 0.9 * len(lines[0]) + 0.1)), RAG_MINIMUM_MESSAGE_LENGTH
+ )
prev = lines[0][:split_len]
lines[0] = lines[0][split_len:]
lines_tokens[0] = count_token(lines[0])
else:
logger.warning("Failed to split docs with must_break_at_empty_line being True, set to False.")
must_break_at_empty_line = False
- chunks.append(prev) if len(prev) > 10 else None # don't add chunks less than 10 characters
- lines = lines[cnt:]
- lines_tokens = lines_tokens[cnt:]
+ (
+ chunks.append(prev) if len(prev) >= RAG_MINIMUM_MESSAGE_LENGTH else None
+ ) # don't add chunks less than RAG_MINIMUM_MESSAGE_LENGTH characters
+ lines = lines[cnt - overlap if cnt > overlap else cnt :]
+ lines_tokens = lines_tokens[cnt - overlap if cnt > overlap else cnt :]
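+ # keep the last `overlap` lines of the finished chunk at the start of the next chunk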
sum_tokens = sum(lines_tokens)
- text_to_chunk = "\n".join(lines)
- chunks.append(text_to_chunk) if len(text_to_chunk) > 10 else None # don't add chunks less than 10 characters
+ text_to_chunk = "\n".join(lines).strip()
+ (
+ chunks.append(text_to_chunk) if len(text_to_chunk) >= RAG_MINIMUM_MESSAGE_LENGTH else None
+ ) # don't add chunks less than RAG_MINIMUM_MESSAGE_LENGTH characters
return chunks
@@ -138,12 +157,18 @@ def split_files_to_chunks(
chunk_mode: str = "multi_lines",
must_break_at_empty_line: bool = True,
custom_text_split_function: Callable = None,
-):
+) -> Tuple[List[str], List[dict]]:
"""Split a list of files into chunks of max_tokens."""
chunks = []
+ sources = []
for file in files:
+ if isinstance(file, tuple):
+ url = file[1]
+ file = file[0]
+ else:
+ url = None
_, file_extension = os.path.splitext(file)
file_extension = file_extension.lower()
@@ -161,11 +186,13 @@ def split_files_to_chunks(
continue # Skip to the next file if no text is available
if custom_text_split_function is not None:
- chunks += custom_text_split_function(text)
+ tmp_chunks = custom_text_split_function(text)
else:
- chunks += split_text_to_chunks(text, max_tokens, chunk_mode, must_break_at_empty_line)
+ tmp_chunks = split_text_to_chunks(text, max_tokens, chunk_mode, must_break_at_empty_line)
+ chunks += tmp_chunks
+ sources += [{"source": url if url else file}] * len(tmp_chunks)
- return chunks
+ return chunks, sources
def get_files_from_dir(dir_path: Union[str, List[str]], types: list = TEXT_FORMATS, recursive: bool = True):
@@ -182,7 +209,9 @@ def get_files_from_dir(dir_path: Union[str, List[str]], types: list = TEXT_FORMA
if os.path.isfile(item):
files.append(item)
elif is_url(item):
- files.append(get_file_from_url(item))
+ filepath = get_file_from_url(item)
+ if filepath:
+ files.append(filepath)
elif os.path.exists(item):
try:
files.extend(get_files_from_dir(item, types, recursive))
@@ -198,7 +227,11 @@ def get_files_from_dir(dir_path: Union[str, List[str]], types: list = TEXT_FORMA
# If the path is a url, download it and return the downloaded file
if is_url(dir_path):
- return [get_file_from_url(dir_path)]
+ filepath = get_file_from_url(dir_path)
+ if filepath:
+ return [filepath]
+ else:
+ return []
if os.path.exists(dir_path):
for type in types:
@@ -212,19 +245,81 @@ def get_files_from_dir(dir_path: Union[str, List[str]], types: list = TEXT_FORMA
return files
-def get_file_from_url(url: str, save_path: str = None):
+def parse_html_to_markdown(html: str, url: str = None) -> str:
+ """Parse HTML to markdown."""
+ soup = BeautifulSoup(html, "html.parser")
+ title = soup.title.string
+ # Remove javascript and style blocks
+ for script in soup(["script", "style"]):
+ script.extract()
+
+ # Convert to markdown -- Wikipedia gets special attention to get a clean version of the page
+ if isinstance(url, str) and url.startswith("https://en.wikipedia.org/"):
+ body_elm = soup.find("div", {"id": "mw-content-text"})
+ title_elm = soup.find("span", {"class": "mw-page-title-main"})
+
+ if body_elm:
+ # What's the title
+ main_title = soup.title.string
+ if title_elm and len(title_elm) > 0:
+ main_title = title_elm.string
+ webpage_text = "# " + main_title + "\n\n" + markdownify.MarkdownConverter().convert_soup(body_elm)
+ else:
+ webpage_text = markdownify.MarkdownConverter().convert_soup(soup)
+ else:
+ webpage_text = markdownify.MarkdownConverter().convert_soup(soup)
+
+ # Convert newlines
+ webpage_text = re.sub(r"\r\n", "\n", webpage_text)
+ webpage_text = re.sub(r"\n{2,}", "\n\n", webpage_text).strip()
+ webpage_text = "# " + title + "\n\n" + webpage_text
+ return webpage_text
+
+
+def _generate_file_name_from_url(url: str, max_length=255) -> str:
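+ # e.g. "https://example.com/docs/page.html" -> "example.com_page.html_<8-char blake2b prefix>" (hypothetical URL)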
+ url_bytes = url.encode("utf-8")
+ hash = hashlib.blake2b(url_bytes).hexdigest()
+ parsed_url = urlparse(url)
+ file_name = os.path.basename(url)
+ file_name = f"{parsed_url.netloc}_{file_name}_{hash[:min(8, max_length-len(parsed_url.netloc)-len(file_name)-1)]}"
+ return file_name
+
+
+def get_file_from_url(url: str, save_path: str = None) -> Tuple[str, str]:
"""Download a file from a URL."""
if save_path is None:
- os.makedirs("/tmp/chromadb", exist_ok=True)
- save_path = os.path.join("/tmp/chromadb", os.path.basename(url))
+ save_path = "tmp/chromadb"
+ os.makedirs(save_path, exist_ok=True)
+ if os.path.isdir(save_path):
+ filename = _generate_file_name_from_url(url)
+ save_path = os.path.join(save_path, filename)
else:
os.makedirs(os.path.dirname(save_path), exist_ok=True)
- with requests.get(url, stream=True) as r:
- r.raise_for_status()
+
+ custom_headers = {
+ "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/96.0.4664.110 Safari/537.36",
+ }
+ try:
+ response = requests.get(url, stream=True, headers=custom_headers, timeout=30)
+ response.raise_for_status()
+ except requests.exceptions.RequestException as e:
+ logger.warning(f"Failed to download {url}, {e}")
+ return None
+
+ content_type = response.headers.get("content-type", "")
+ if "text/html" in content_type:
+ # Get the content of the response
+ html = ""
+ for chunk in response.iter_content(chunk_size=8192, decode_unicode=True):
+ html += chunk
+ text = parse_html_to_markdown(html, url)
+ with open(save_path, "w", encoding="utf-8") as f:
+ f.write(text)
+ else:
with open(save_path, "wb") as f:
- for chunk in r.iter_content(chunk_size=8192):
+ for chunk in response.iter_content(chunk_size=8192):
f.write(chunk)
- return save_path
+ return save_path, url
def is_url(string: str):
@@ -240,7 +335,7 @@ def create_vector_db_from_dir(
dir_path: Union[str, List[str]],
max_tokens: int = 4000,
client: API = None,
- db_path: str = "/tmp/chromadb.db",
+ db_path: str = "tmp/chromadb.db",
collection_name: str = "all-my-documents",
get_or_create: bool = False,
chunk_mode: str = "multi_lines",
@@ -260,7 +355,7 @@ def create_vector_db_from_dir(
dir_path (Union[str, List[str]]): the path to the directory, file, url or a list of them.
max_tokens (Optional, int): the maximum number of tokens per chunk. Default is 4000.
client (Optional, API): the chromadb client. Default is None.
- db_path (Optional, str): the path to the chromadb. Default is "/tmp/chromadb.db".
+ db_path (Optional, str): the path to the chromadb. Default is "tmp/chromadb.db". The default was `/tmp/chromadb.db` for version <=0.2.24.
collection_name (Optional, str): the name of the collection. Default is "all-my-documents".
get_or_create (Optional, bool): Whether to get or create the collection. Default is False. If True, the collection
will be returned if it already exists. Will raise ValueError if the collection already exists and get_or_create is False.
@@ -304,12 +399,12 @@ def create_vector_db_from_dir(
length = len(collection.get()["ids"])
if custom_text_split_function is not None:
- chunks = split_files_to_chunks(
+ chunks, sources = split_files_to_chunks(
get_files_from_dir(dir_path, custom_text_types, recursive),
custom_text_split_function=custom_text_split_function,
)
else:
- chunks = split_files_to_chunks(
+ chunks, sources = split_files_to_chunks(
get_files_from_dir(dir_path, custom_text_types, recursive),
max_tokens,
chunk_mode,
@@ -322,6 +417,7 @@ def create_vector_db_from_dir(
collection.upsert(
documents=chunks[i:end_idx],
ids=[f"doc_{j+length}" for j in range(i, end_idx)], # unique for each doc
+ metadatas=sources[i:end_idx],
)
except ValueError as e:
logger.warning(f"{e}")
@@ -332,7 +428,7 @@ def query_vector_db(
query_texts: List[str],
n_results: int = 10,
client: API = None,
- db_path: str = "/tmp/chromadb.db",
+ db_path: str = "tmp/chromadb.db",
collection_name: str = "all-my-documents",
search_string: str = "",
embedding_model: str = "all-MiniLM-L6-v2",
@@ -345,7 +441,7 @@ def query_vector_db(
query_texts (List[str]): the list of strings which will be used to query the vector db.
n_results (Optional, int): the number of results to return. Default is 10.
client (Optional, API): the chromadb compatible client. Default is None, a chromadb client will be used.
- db_path (Optional, str): the path to the vector db. Default is "/tmp/chromadb.db".
+ db_path (Optional, str): the path to the vector db. Default is "tmp/chromadb.db". The default was `/tmp/chromadb.db` for version <=0.2.24.
collection_name (Optional, str): the name of the collection. Default is "all-my-documents".
search_string (Optional, str): the search string. Only docs that contain an exact match of this string will be retrieved. Default is "".
embedding_model (Optional, str): the embedding model to use. Default is "all-MiniLM-L6-v2". Will be ignored if
diff --git a/autogen/runtime_logging.py b/autogen/runtime_logging.py
index 94d7460cd30..d848ca3645e 100644
--- a/autogen/runtime_logging.py
+++ b/autogen/runtime_logging.py
@@ -1,18 +1,18 @@
from __future__ import annotations
-from autogen.logger.logger_factory import LoggerFactory
-from autogen.logger.base_logger import LLMConfig
-
import logging
import sqlite3
-from typing import Any, Dict, List, Optional, TYPE_CHECKING, Union
import uuid
+from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Union
-from openai import OpenAI, AzureOpenAI
+from openai import AzureOpenAI, OpenAI
from openai.types.chat import ChatCompletion
+from autogen.logger.base_logger import BaseLogger, LLMConfig
+from autogen.logger.logger_factory import LoggerFactory
+
if TYPE_CHECKING:
- from autogen import ConversableAgent, OpenAIWrapper
+ from autogen import Agent, ConversableAgent, OpenAIWrapper
logger = logging.getLogger(__name__)
@@ -20,11 +20,27 @@
is_logging = False
-def start(logger_type: str = "sqlite", config: Optional[Dict[str, Any]] = None) -> str:
+def start(
+ logger: Optional[BaseLogger] = None,
+ logger_type: Literal["sqlite", "file"] = "sqlite",
+ config: Optional[Dict[str, Any]] = None,
+) -> str:
+ """
+ Start logging for the runtime.
+ Args:
+ logger (BaseLogger): A logger instance
+ logger_type (str): The type of logger to use (default: sqlite)
+ config (dict): Configuration for the logger
+ Returns:
+ session_id (str(uuid.uuid4)): a unique id for the logging session
+ """
global autogen_logger
global is_logging
- autogen_logger = LoggerFactory.get_logger(logger_type=logger_type, config=config)
+ if logger:
+ autogen_logger = logger
+ else:
+ autogen_logger = LoggerFactory.get_logger(logger_type=logger_type, config=config)
try:
session_id = autogen_logger.start()
@@ -62,6 +78,14 @@ def log_new_agent(agent: ConversableAgent, init_args: Dict[str, Any]) -> None:
autogen_logger.log_new_agent(agent, init_args)
+def log_event(source: Union[str, Agent], name: str, **kwargs: Dict[str, Any]) -> None:
+ if autogen_logger is None:
+ logger.error("[runtime logging] log_event: autogen logger is None")
+ return
+
+ autogen_logger.log_event(source, name, **kwargs)
+
+
def log_new_wrapper(wrapper: OpenAIWrapper, init_args: Dict[str, Union[LLMConfig, List[LLMConfig]]]) -> None:
if autogen_logger is None:
logger.error("[runtime logging] log_new_wrapper: autogen logger is None")
diff --git a/autogen/token_count_utils.py b/autogen/token_count_utils.py
index 84fe147fd8e..589d7b404a7 100644
--- a/autogen/token_count_utils.py
+++ b/autogen/token_count_utils.py
@@ -1,9 +1,9 @@
-from typing import List, Union, Dict
-import logging
import json
-import tiktoken
+import logging
import re
+from typing import Dict, List, Union
+import tiktoken
logger = logging.getLogger(__name__)
@@ -14,7 +14,8 @@ def get_max_token_limit(model: str = "gpt-3.5-turbo-0613") -> int:
model = re.sub(r"^gpt4", "gpt-4", model)
max_token_limit = {
- "gpt-3.5-turbo": 4096,
+ "gpt-3.5-turbo": 16385,
+ "gpt-3.5-turbo-0125": 16385,
"gpt-3.5-turbo-0301": 4096,
"gpt-3.5-turbo-0613": 4096,
"gpt-3.5-turbo-instruct": 4096,
@@ -22,6 +23,8 @@ def get_max_token_limit(model: str = "gpt-3.5-turbo-0613") -> int:
"gpt-3.5-turbo-16k-0613": 16385,
"gpt-3.5-turbo-1106": 16385,
"gpt-4": 8192,
+ "gpt-4-turbo": 128000,
+ "gpt-4-turbo-2024-04-09": 128000,
"gpt-4-32k": 32768,
"gpt-4-32k-0314": 32768, # deprecate in Sep
"gpt-4-0314": 8192, # deprecate in Sep
@@ -66,7 +69,7 @@ def count_token(input: Union[str, List, Dict], model: str = "gpt-3.5-turbo-0613"
elif isinstance(input, list) or isinstance(input, dict):
return _num_token_from_messages(input, model=model)
else:
- raise ValueError("input must be str, list or dict")
+ raise ValueError(f"input must be str, list or dict, but we got {type(input)}")
def _num_token_from_text(text: str, model: str = "gpt-3.5-turbo-0613"):
@@ -111,6 +114,9 @@ def _num_token_from_messages(messages: Union[List, Dict], model="gpt-3.5-turbo-0
elif "gpt-4" in model:
logger.info("gpt-4 may update over time. Returning num tokens assuming gpt-4-0613.")
return _num_token_from_messages(messages, model="gpt-4-0613")
+ elif "gemini" in model:
+ logger.info("Gemini is not supported in tiktoken. Returning num tokens assuming gpt-4-0613.")
+ return _num_token_from_messages(messages, model="gpt-4-0613")
else:
raise NotImplementedError(
f"""_num_token_from_messages() is not implemented for model {model}. See https://github.com/openai/openai-python/blob/main/chatml.md for information on how messages are converted to tokens."""
diff --git a/autogen/version.py b/autogen/version.py
index 198d6db6273..be2d7c2ffe3 100644
--- a/autogen/version.py
+++ b/autogen/version.py
@@ -1 +1 @@
-__version__ = "0.2.20"
+__version__ = "0.2.27"
diff --git a/dotnet/.config/dotnet-tools.json b/dotnet/.config/dotnet-tools.json
index 5b341cff736..6b2517ea2c6 100644
--- a/dotnet/.config/dotnet-tools.json
+++ b/dotnet/.config/dotnet-tools.json
@@ -1,12 +1,18 @@
{
- "version": 1,
- "isRoot": true,
- "tools": {
- "dotnet-repl": {
- "version": "0.1.205",
- "commands": [
- "dotnet-repl"
- ]
- }
+ "version": 1,
+ "isRoot": true,
+ "tools": {
+ "dotnet-repl": {
+ "version": "0.1.205",
+ "commands": [
+ "dotnet-repl"
+ ]
+ },
+ "docfx": {
+ "version": "2.67.5",
+ "commands": [
+ "docfx"
+ ]
}
- }
\ No newline at end of file
+ }
+}
\ No newline at end of file
diff --git a/dotnet/.editorconfig b/dotnet/.editorconfig
new file mode 100644
index 00000000000..4da1adc5de6
--- /dev/null
+++ b/dotnet/.editorconfig
@@ -0,0 +1,178 @@
+# EditorConfig is awesome: http://EditorConfig.org
+
+# top-most EditorConfig file
+root = true
+
+# Don't use tabs for indentation.
+[*]
+indent_style = space
+# (Please don't specify an indent_size here; that has too many unintended consequences.)
+
+# Code files
+[*.{cs,csx,vb,vbx}]
+indent_size = 4
+insert_final_newline = true
+charset = utf-8-bom
+
+[*.xaml]
+indent_size = 4
+
+[*.ps1]
+indent_size = 2
+
+# Xml project files
+[*.{csproj,vbproj,vcxproj,vcxproj.filters,proj,projitems,shproj}]
+indent_size = 2
+
+# Xml config files
+[*.{props,targets,ruleset,config,nuspec,resx,vsixmanifest,vsct}]
+indent_size = 2
+
+# JSON files
+[*.json]
+indent_size = 2
+
+[*.groovy]
+indent_size = 2
+
+# Dotnet code style settings:
+[*.{cs,vb}]
+# Sort using and Import directives with System.* appearing first
+dotnet_sort_system_directives_first = true
+dotnet_style_require_accessibility_modifiers = always:warning
+
+# No blank line between System.* and Microsoft.*
+dotnet_separate_import_directive_groups = false
+
+# Suggest more modern language features when available
+dotnet_style_object_initializer = true:suggestion
+dotnet_style_collection_initializer = true:suggestion
+dotnet_style_coalesce_expression = true:error
+dotnet_style_null_propagation = true:error
+dotnet_style_explicit_tuple_names = true:suggestion
+dotnet_style_prefer_inferred_tuple_names = true:suggestion
+dotnet_style_prefer_inferred_anonymous_type_member_names = true:suggestion
+dotnet_style_prefer_is_null_check_over_reference_equality_method = true:suggestion
+dotnet_style_prefer_conditional_expression_over_return = false
+dotnet_style_prefer_conditional_expression_over_assignment = false
+dotnet_style_prefer_auto_properties = false
+
+# Use language keywords instead of framework type names for type references
+dotnet_style_predefined_type_for_locals_parameters_members = true:error
+dotnet_style_predefined_type_for_member_access = true:error
+
+# Prefer read-only on fields
+dotnet_style_readonly_field = false
+
+# CSharp code style settings:
+[*.cs]
+
+# Prefer "var" only when the type is apparent
+csharp_style_var_for_built_in_types = false:suggestion
+csharp_style_var_when_type_is_apparent = true:suggestion
+csharp_style_var_elsewhere = false:suggestion
+
+# Prefer method-like constructs to have a block body
+csharp_style_expression_bodied_methods = false:none
+csharp_style_expression_bodied_constructors = false:none
+csharp_style_expression_bodied_operators = false:none
+
+# Prefer property-like constructs to have an expression-body
+csharp_style_expression_bodied_properties = true:none
+csharp_style_expression_bodied_indexers = true:none
+csharp_style_expression_bodied_accessors = true:none
+
+# Use block body for local functions
+csharp_style_expression_bodied_local_functions = when_on_single_line:silent
+
+# Suggest more modern language features when available
+csharp_style_pattern_matching_over_is_with_cast_check = true:error
+csharp_style_pattern_matching_over_as_with_null_check = true:error
+csharp_style_inlined_variable_declaration = true:error
+csharp_style_throw_expression = true:suggestion
+csharp_style_conditional_delegate_call = true:suggestion
+csharp_style_deconstructed_variable_declaration = true:suggestion
+
+# Newline settings
+csharp_new_line_before_open_brace = all
+csharp_new_line_before_else = true
+csharp_new_line_before_catch = true
+csharp_new_line_before_finally = true
+csharp_new_line_before_members_in_object_initializers = true
+csharp_new_line_before_members_in_anonymous_types = true
+csharp_new_line_between_query_expression_clauses = true
+
+# Indentation options
+csharp_indent_case_contents = true
+csharp_indent_case_contents_when_block = true
+csharp_indent_switch_labels = true
+csharp_indent_labels = no_change
+csharp_indent_block_contents = true
+csharp_indent_braces = false
+
+# Spacing options
+csharp_space_after_cast = false
+csharp_space_after_keywords_in_control_flow_statements = true
+csharp_space_between_method_call_empty_parameter_list_parentheses = false
+csharp_space_between_method_call_parameter_list_parentheses = false
+csharp_space_between_method_call_name_and_opening_parenthesis = false
+csharp_space_between_method_declaration_parameter_list_parentheses = false
+csharp_space_between_method_declaration_empty_parameter_list_parentheses = false
+csharp_space_between_method_declaration_parameter_list_parentheses = false
+csharp_space_between_method_declaration_name_and_open_parenthesis = false
+csharp_space_between_parentheses = false
+csharp_space_between_square_brackets = false
+csharp_space_between_empty_square_brackets = false
+csharp_space_before_open_square_brackets = false
+csharp_space_around_declaration_statements = false
+csharp_space_around_binary_operators = before_and_after
+csharp_space_after_cast = false
+csharp_space_before_semicolon_in_for_statement = false
+csharp_space_before_dot = false
+csharp_space_after_dot = false
+csharp_space_before_comma = false
+csharp_space_after_comma = true
+csharp_space_before_colon_in_inheritance_clause = true
+csharp_space_after_colon_in_inheritance_clause = true
+csharp_space_after_semicolon_in_for_statement = true
+
+# Wrapping
+csharp_preserve_single_line_statements = true
+csharp_preserve_single_line_blocks = true
+
+# Code block
+csharp_prefer_braces = false:none
+
+# Using statements
+csharp_using_directive_placement = outside_namespace:error
+
+# Modifier settings
+csharp_prefer_static_local_function = true:warning
+csharp_preferred_modifier_order = public,private,protected,internal,static,extern,new,virtual,abstract,sealed,override,readonly,unsafe,volatile,async:warning
+
+# Header template
+file_header_template = Copyright (c) Microsoft Corporation. All rights reserved.\n{fileName}
+dotnet_diagnostic.IDE0073.severity = error
+
+# enable format error
+dotnet_diagnostic.IDE0055.severity = error
+
+# IDE0035: Remove unreachable code
+dotnet_diagnostic.IDE0035.severity = error
+
+# IDE0005: Remove unnecessary usings
+dotnet_diagnostic.CS8019.severity = error
+dotnet_diagnostic.IDE0005.severity = error
+
+# IDE0069: Remove unused local variable
+dotnet_diagnostic.IDE0069.severity = error
+
+# disable CS1573: Parameter has no matching param tag in the XML comment for
+dotnet_diagnostic.CS1573.severity = none
+
+# disable CS1570: XML comment has badly formed XML
+dotnet_diagnostic.CS1570.severity = none
+
+# disable check for generated code
+[*.generated.cs]
+generated_code = true
\ No newline at end of file
diff --git a/dotnet/.gitignore b/dotnet/.gitignore
new file mode 100644
index 00000000000..65e7ba678dd
--- /dev/null
+++ b/dotnet/.gitignore
@@ -0,0 +1,30 @@
+# gitignore file for C#/VS
+
+# Build results
+[Dd]ebug/
+[Dd]ebugPublic/
+[Rr]elease/
+[Rr]eleases/
+x64/
+x86/
+build/
+bld/
+[Bb]in/
+[Oo]bj/
+
+# vs cache
+.vs/
+
+# vs code cache
+.vscode/
+
+# Properties
+Properties/
+
+artifacts/
+output/
+
+*.binlog
+
+# JetBrains Rider
+.idea/
\ No newline at end of file
diff --git a/dotnet/AutoGen.sln b/dotnet/AutoGen.sln
new file mode 100644
index 00000000000..b29e5e21e95
--- /dev/null
+++ b/dotnet/AutoGen.sln
@@ -0,0 +1,145 @@
+ο»Ώ
+Microsoft Visual Studio Solution File, Format Version 12.00
+# Visual Studio Version 17
+VisualStudioVersion = 17.8.34322.80
+MinimumVisualStudioVersion = 10.0.40219.1
+Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "AutoGen", "src\AutoGen\AutoGen.csproj", "{B2B27ACB-AA50-4FED-A06C-3AD6B4218188}"
+EndProject
+Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "src", "src", "{18BF8DD7-0585-48BF-8F97-AD333080CE06}"
+EndProject
+Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "test", "test", "{F823671B-3ECA-4AE6-86DA-25E920D3FE64}"
+EndProject
+Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "AutoGen.Tests", "test\AutoGen.Tests\AutoGen.Tests.csproj", "{FDD99AEC-4C57-4020-B23F-650612856102}"
+EndProject
+Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "AutoGen.SourceGenerator", "src\AutoGen.SourceGenerator\AutoGen.SourceGenerator.csproj", "{3FFD14E3-D6BC-4EA7-97A2-D21733060FD6}"
+EndProject
+Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "AutoGen.SourceGenerator.Tests", "test\AutoGen.SourceGenerator.Tests\AutoGen.SourceGenerator.Tests.csproj", "{05A2FAD8-03B0-4B2F-82AF-2F6BF0F050E5}"
+EndProject
+Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "AutoGen.BasicSample", "sample\AutoGen.BasicSamples\AutoGen.BasicSample.csproj", "{7EBF916A-A7B1-4B74-AF10-D705B7A18F58}"
+EndProject
+Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "sample", "sample", "{FBFEAD1F-29EB-4D99-A672-0CD8473E10B9}"
+EndProject
+Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "AutoGen.DotnetInteractive", "src\AutoGen.DotnetInteractive\AutoGen.DotnetInteractive.csproj", "{B61D8008-7FB7-4C0E-8044-3A74AA63A596}"
+EndProject
+Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "AutoGen.LMStudio", "src\AutoGen.LMStudio\AutoGen.LMStudio.csproj", "{F98BDA9B-8657-4BA8-9B03-BAEA454CAE60}"
+EndProject
+Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "AutoGen.SemanticKernel", "src\AutoGen.SemanticKernel\AutoGen.SemanticKernel.csproj", "{45D6FC80-36F3-4967-9663-E20B63824621}"
+EndProject
+Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "AutoGen.Core", "src\AutoGen.Core\AutoGen.Core.csproj", "{D58D43D1-0617-4A3D-9932-C773E6398535}"
+EndProject
+Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "AutoGen.OpenAI", "src\AutoGen.OpenAI\AutoGen.OpenAI.csproj", "{63445BB7-DBB9-4AEF-9D6F-98BBE75EE1EC}"
+EndProject
+Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "AutoGen.Mistral", "src\AutoGen.Mistral\AutoGen.Mistral.csproj", "{6585D1A4-3D97-4D76-A688-1933B61AEB19}"
+EndProject
+Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "AutoGen.Mistral.Tests", "test\AutoGen.Mistral.Tests\AutoGen.Mistral.Tests.csproj", "{15441693-3659-4868-B6C1-B106F52FF3BA}"
+EndProject
+Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "AutoGen.SemanticKernel.Tests", "test\AutoGen.SemanticKernel.Tests\AutoGen.SemanticKernel.Tests.csproj", "{1DFABC4A-8458-4875-8DCB-59F3802DAC65}"
+EndProject
+Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "AutoGen.OpenAI.Tests", "test\AutoGen.OpenAI.Tests\AutoGen.OpenAI.Tests.csproj", "{D36A85F9-C172-487D-8192-6BFE5D05B4A7}"
+Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "AutoGen.DotnetInteractive.Tests", "test\AutoGen.DotnetInteractive.Tests\AutoGen.DotnetInteractive.Tests.csproj", "{B61388CA-DC73-4B7F-A7B2-7B9A86C9229E}"
+EndProject
+Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Autogen.Ollama", "src\Autogen.Ollama\Autogen.Ollama.csproj", "{A4EFA175-44CC-44A9-B93E-1C7B6FAC38F1}"
+EndProject
+Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Autogen.Ollama.Tests", "test\Autogen.Ollama.Tests\Autogen.Ollama.Tests.csproj", "{C24FDE63-952D-4F8E-A807-AF31D43AD675}"
+EndProject
+Global
+ GlobalSection(SolutionConfigurationPlatforms) = preSolution
+ Debug|Any CPU = Debug|Any CPU
+ Release|Any CPU = Release|Any CPU
+ EndGlobalSection
+ GlobalSection(ProjectConfigurationPlatforms) = postSolution
+ {B2B27ACB-AA50-4FED-A06C-3AD6B4218188}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
+ {B2B27ACB-AA50-4FED-A06C-3AD6B4218188}.Debug|Any CPU.Build.0 = Debug|Any CPU
+ {B2B27ACB-AA50-4FED-A06C-3AD6B4218188}.Release|Any CPU.ActiveCfg = Release|Any CPU
+ {B2B27ACB-AA50-4FED-A06C-3AD6B4218188}.Release|Any CPU.Build.0 = Release|Any CPU
+ {FDD99AEC-4C57-4020-B23F-650612856102}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
+ {FDD99AEC-4C57-4020-B23F-650612856102}.Debug|Any CPU.Build.0 = Debug|Any CPU
+ {FDD99AEC-4C57-4020-B23F-650612856102}.Release|Any CPU.ActiveCfg = Release|Any CPU
+ {FDD99AEC-4C57-4020-B23F-650612856102}.Release|Any CPU.Build.0 = Release|Any CPU
+ {3FFD14E3-D6BC-4EA7-97A2-D21733060FD6}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
+ {3FFD14E3-D6BC-4EA7-97A2-D21733060FD6}.Debug|Any CPU.Build.0 = Debug|Any CPU
+ {3FFD14E3-D6BC-4EA7-97A2-D21733060FD6}.Release|Any CPU.ActiveCfg = Release|Any CPU
+ {3FFD14E3-D6BC-4EA7-97A2-D21733060FD6}.Release|Any CPU.Build.0 = Release|Any CPU
+ {05A2FAD8-03B0-4B2F-82AF-2F6BF0F050E5}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
+ {05A2FAD8-03B0-4B2F-82AF-2F6BF0F050E5}.Debug|Any CPU.Build.0 = Debug|Any CPU
+ {05A2FAD8-03B0-4B2F-82AF-2F6BF0F050E5}.Release|Any CPU.ActiveCfg = Release|Any CPU
+ {05A2FAD8-03B0-4B2F-82AF-2F6BF0F050E5}.Release|Any CPU.Build.0 = Release|Any CPU
+ {7EBF916A-A7B1-4B74-AF10-D705B7A18F58}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
+ {7EBF916A-A7B1-4B74-AF10-D705B7A18F58}.Debug|Any CPU.Build.0 = Debug|Any CPU
+ {7EBF916A-A7B1-4B74-AF10-D705B7A18F58}.Release|Any CPU.ActiveCfg = Release|Any CPU
+ {7EBF916A-A7B1-4B74-AF10-D705B7A18F58}.Release|Any CPU.Build.0 = Release|Any CPU
+ {B61D8008-7FB7-4C0E-8044-3A74AA63A596}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
+ {B61D8008-7FB7-4C0E-8044-3A74AA63A596}.Debug|Any CPU.Build.0 = Debug|Any CPU
+ {B61D8008-7FB7-4C0E-8044-3A74AA63A596}.Release|Any CPU.ActiveCfg = Release|Any CPU
+ {B61D8008-7FB7-4C0E-8044-3A74AA63A596}.Release|Any CPU.Build.0 = Release|Any CPU
+ {F98BDA9B-8657-4BA8-9B03-BAEA454CAE60}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
+ {F98BDA9B-8657-4BA8-9B03-BAEA454CAE60}.Debug|Any CPU.Build.0 = Debug|Any CPU
+ {F98BDA9B-8657-4BA8-9B03-BAEA454CAE60}.Release|Any CPU.ActiveCfg = Release|Any CPU
+ {F98BDA9B-8657-4BA8-9B03-BAEA454CAE60}.Release|Any CPU.Build.0 = Release|Any CPU
+ {45D6FC80-36F3-4967-9663-E20B63824621}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
+ {45D6FC80-36F3-4967-9663-E20B63824621}.Debug|Any CPU.Build.0 = Debug|Any CPU
+ {45D6FC80-36F3-4967-9663-E20B63824621}.Release|Any CPU.ActiveCfg = Release|Any CPU
+ {45D6FC80-36F3-4967-9663-E20B63824621}.Release|Any CPU.Build.0 = Release|Any CPU
+ {D58D43D1-0617-4A3D-9932-C773E6398535}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
+ {D58D43D1-0617-4A3D-9932-C773E6398535}.Debug|Any CPU.Build.0 = Debug|Any CPU
+ {D58D43D1-0617-4A3D-9932-C773E6398535}.Release|Any CPU.ActiveCfg = Release|Any CPU
+ {D58D43D1-0617-4A3D-9932-C773E6398535}.Release|Any CPU.Build.0 = Release|Any CPU
+ {63445BB7-DBB9-4AEF-9D6F-98BBE75EE1EC}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
+ {63445BB7-DBB9-4AEF-9D6F-98BBE75EE1EC}.Debug|Any CPU.Build.0 = Debug|Any CPU
+ {63445BB7-DBB9-4AEF-9D6F-98BBE75EE1EC}.Release|Any CPU.ActiveCfg = Release|Any CPU
+ {63445BB7-DBB9-4AEF-9D6F-98BBE75EE1EC}.Release|Any CPU.Build.0 = Release|Any CPU
+ {6585D1A4-3D97-4D76-A688-1933B61AEB19}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
+ {6585D1A4-3D97-4D76-A688-1933B61AEB19}.Debug|Any CPU.Build.0 = Debug|Any CPU
+ {6585D1A4-3D97-4D76-A688-1933B61AEB19}.Release|Any CPU.ActiveCfg = Release|Any CPU
+ {6585D1A4-3D97-4D76-A688-1933B61AEB19}.Release|Any CPU.Build.0 = Release|Any CPU
+ {15441693-3659-4868-B6C1-B106F52FF3BA}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
+ {15441693-3659-4868-B6C1-B106F52FF3BA}.Debug|Any CPU.Build.0 = Debug|Any CPU
+ {15441693-3659-4868-B6C1-B106F52FF3BA}.Release|Any CPU.ActiveCfg = Release|Any CPU
+ {15441693-3659-4868-B6C1-B106F52FF3BA}.Release|Any CPU.Build.0 = Release|Any CPU
+ {A4EFA175-44CC-44A9-B93E-1C7B6FAC38F1}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
+ {A4EFA175-44CC-44A9-B93E-1C7B6FAC38F1}.Debug|Any CPU.Build.0 = Debug|Any CPU
+ {A4EFA175-44CC-44A9-B93E-1C7B6FAC38F1}.Release|Any CPU.ActiveCfg = Release|Any CPU
+ {A4EFA175-44CC-44A9-B93E-1C7B6FAC38F1}.Release|Any CPU.Build.0 = Release|Any CPU
+ {C24FDE63-952D-4F8E-A807-AF31D43AD675}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
+ {C24FDE63-952D-4F8E-A807-AF31D43AD675}.Debug|Any CPU.Build.0 = Debug|Any CPU
+ {C24FDE63-952D-4F8E-A807-AF31D43AD675}.Release|Any CPU.ActiveCfg = Release|Any CPU
+ {C24FDE63-952D-4F8E-A807-AF31D43AD675}.Release|Any CPU.Build.0 = Release|Any CPU
+ {1DFABC4A-8458-4875-8DCB-59F3802DAC65}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
+ {1DFABC4A-8458-4875-8DCB-59F3802DAC65}.Debug|Any CPU.Build.0 = Debug|Any CPU
+ {1DFABC4A-8458-4875-8DCB-59F3802DAC65}.Release|Any CPU.ActiveCfg = Release|Any CPU
+ {1DFABC4A-8458-4875-8DCB-59F3802DAC65}.Release|Any CPU.Build.0 = Release|Any CPU
+ {D36A85F9-C172-487D-8192-6BFE5D05B4A7}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
+ {D36A85F9-C172-487D-8192-6BFE5D05B4A7}.Debug|Any CPU.Build.0 = Debug|Any CPU
+ {D36A85F9-C172-487D-8192-6BFE5D05B4A7}.Release|Any CPU.ActiveCfg = Release|Any CPU
+ {D36A85F9-C172-487D-8192-6BFE5D05B4A7}.Release|Any CPU.Build.0 = Release|Any CPU
+ {B61388CA-DC73-4B7F-A7B2-7B9A86C9229E}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
+ {B61388CA-DC73-4B7F-A7B2-7B9A86C9229E}.Debug|Any CPU.Build.0 = Debug|Any CPU
+ {B61388CA-DC73-4B7F-A7B2-7B9A86C9229E}.Release|Any CPU.ActiveCfg = Release|Any CPU
+ {B61388CA-DC73-4B7F-A7B2-7B9A86C9229E}.Release|Any CPU.Build.0 = Release|Any CPU
+ EndGlobalSection
+ GlobalSection(SolutionProperties) = preSolution
+ HideSolutionNode = FALSE
+ EndGlobalSection
+ GlobalSection(NestedProjects) = preSolution
+ {B2B27ACB-AA50-4FED-A06C-3AD6B4218188} = {18BF8DD7-0585-48BF-8F97-AD333080CE06}
+ {FDD99AEC-4C57-4020-B23F-650612856102} = {F823671B-3ECA-4AE6-86DA-25E920D3FE64}
+ {3FFD14E3-D6BC-4EA7-97A2-D21733060FD6} = {18BF8DD7-0585-48BF-8F97-AD333080CE06}
+ {05A2FAD8-03B0-4B2F-82AF-2F6BF0F050E5} = {F823671B-3ECA-4AE6-86DA-25E920D3FE64}
+ {7EBF916A-A7B1-4B74-AF10-D705B7A18F58} = {FBFEAD1F-29EB-4D99-A672-0CD8473E10B9}
+ {B61D8008-7FB7-4C0E-8044-3A74AA63A596} = {18BF8DD7-0585-48BF-8F97-AD333080CE06}
+ {F98BDA9B-8657-4BA8-9B03-BAEA454CAE60} = {18BF8DD7-0585-48BF-8F97-AD333080CE06}
+ {45D6FC80-36F3-4967-9663-E20B63824621} = {18BF8DD7-0585-48BF-8F97-AD333080CE06}
+ {D58D43D1-0617-4A3D-9932-C773E6398535} = {18BF8DD7-0585-48BF-8F97-AD333080CE06}
+ {63445BB7-DBB9-4AEF-9D6F-98BBE75EE1EC} = {18BF8DD7-0585-48BF-8F97-AD333080CE06}
+ {6585D1A4-3D97-4D76-A688-1933B61AEB19} = {18BF8DD7-0585-48BF-8F97-AD333080CE06}
+ {15441693-3659-4868-B6C1-B106F52FF3BA} = {F823671B-3ECA-4AE6-86DA-25E920D3FE64}
+ {A4EFA175-44CC-44A9-B93E-1C7B6FAC38F1} = {18BF8DD7-0585-48BF-8F97-AD333080CE06}
+ {C24FDE63-952D-4F8E-A807-AF31D43AD675} = {F823671B-3ECA-4AE6-86DA-25E920D3FE64}
+ {1DFABC4A-8458-4875-8DCB-59F3802DAC65} = {F823671B-3ECA-4AE6-86DA-25E920D3FE64}
+ {D36A85F9-C172-487D-8192-6BFE5D05B4A7} = {F823671B-3ECA-4AE6-86DA-25E920D3FE64}
+ {B61388CA-DC73-4B7F-A7B2-7B9A86C9229E} = {F823671B-3ECA-4AE6-86DA-25E920D3FE64}
+ EndGlobalSection
+ GlobalSection(ExtensibilityGlobals) = postSolution
+ SolutionGuid = {93384647-528D-46C8-922C-8DB36A382F0B}
+ EndGlobalSection
+EndGlobal
diff --git a/dotnet/Directory.Build.props b/dotnet/Directory.Build.props
new file mode 100644
index 00000000000..aeb667438e2
--- /dev/null
+++ b/dotnet/Directory.Build.props
@@ -0,0 +1,24 @@
+
+
+
+
+
+
+ net8.0
+ preview
+ enable
+ True
+ $(MSBuildThisFileDirectory)eng/opensource.snk
+ 0024000004800000940000000602000000240000525341310004000001000100f1d038d0b85ae392ad72011df91e9343b0b5df1bb8080aa21b9424362d696919e0e9ac3a8bca24e283e10f7a569c6f443e1d4e3ebc84377c87ca5caa562e80f9932bf5ea91b7862b538e13b8ba91c7565cf0e8dfeccfea9c805ae3bda044170ecc7fc6f147aeeac422dd96aeb9eb1f5a5882aa650efe2958f2f8107d2038f2ab
+ CS1998;CS1591
+ $(NoWarn);$(CSNoWarn);NU5104
+ true
+ false
+ true
+ true
+
+
+
+ $(MSBuildThisFileDirectory)
+
+
diff --git a/dotnet/NuGet.config b/dotnet/NuGet.config
new file mode 100644
index 00000000000..1d0cf4c2bc7
--- /dev/null
+++ b/dotnet/NuGet.config
@@ -0,0 +1,8 @@
+ο»Ώ
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/dotnet/README.md b/dotnet/README.md
new file mode 100644
index 00000000000..5b0803b6e11
--- /dev/null
+++ b/dotnet/README.md
@@ -0,0 +1,103 @@
+### AutoGen for .NET
+
+[![dotnet-ci](https://github.com/microsoft/autogen/actions/workflows/dotnet-build.yml/badge.svg)](https://github.com/microsoft/autogen/actions/workflows/dotnet-build.yml)
+[![NuGet version](https://badge.fury.io/nu/AutoGen.Core.svg)](https://badge.fury.io/nu/AutoGen.Core)
+
+> [!NOTE]
+> Nightly build is available at:
+> - ![Static Badge](https://img.shields.io/badge/public-blue?style=flat) ![Static Badge](https://img.shields.io/badge/nightly-yellow?style=flat) ![Static Badge](https://img.shields.io/badge/github-grey?style=flat): https://nuget.pkg.github.com/microsoft/index.json
+> - ![Static Badge](https://img.shields.io/badge/public-blue?style=flat) ![Static Badge](https://img.shields.io/badge/nightly-yellow?style=flat) ![Static Badge](https://img.shields.io/badge/myget-grey?style=flat): https://www.myget.org/F/agentchat/api/v3/index.json
+> - ![Static Badge](https://img.shields.io/badge/internal-blue?style=flat) ![Static Badge](https://img.shields.io/badge/nightly-yellow?style=flat) ![Static Badge](https://img.shields.io/badge/azure_devops-grey?style=flat) : https://devdiv.pkgs.visualstudio.com/DevDiv/_packaging/AutoGen/nuget/v3/index.json
+
+
+First, follow the [installation guide](./website/articles/Installation.md) to install the AutoGen packages.
+
+Then you can start with the following code snippet to create a conversable agent and chat with it.
+
+```csharp
+using AutoGen;
+using AutoGen.OpenAI;
+
+var openAIKey = Environment.GetEnvironmentVariable("OPENAI_API_KEY") ?? throw new Exception("Please set OPENAI_API_KEY environment variable.");
+var gpt35Config = new OpenAIConfig(openAIKey, "gpt-3.5-turbo");
+
+var assistantAgent = new AssistantAgent(
+ name: "assistant",
+ systemMessage: "You are an assistant that help user to do some tasks.",
+ llmConfig: new ConversableAgentConfig
+ {
+ Temperature = 0,
+ ConfigList = [gpt35Config],
+ })
+ .RegisterPrintMessage(); // register a hook to print message nicely to console
+
+// set human input mode to ALWAYS so that user always provide input
+var userProxyAgent = new UserProxyAgent(
+ name: "user",
+ humanInputMode: ConversableAgent.HumanInputMode.ALWAYS)
+ .RegisterPrintMessage();
+
+// start the conversation
+await userProxyAgent.InitiateChatAsync(
+ receiver: assistantAgent,
+ message: "Hey assistant, please do me a favor.",
+ maxRound: 10);
+```
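+
+The snippet above starts an interactive two-agent chat in the console: because the user proxy's human input mode is `ALWAYS`, you are asked for input every round, and the conversation stops after at most 10 rounds.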
+
+#### Samples
+You can find more examples under the [sample project](https://github.com/microsoft/autogen/tree/dotnet/dotnet/sample/AutoGen.BasicSamples).
+
+#### Functionality
+- ConversableAgent
+ - [x] function call
+ - [x] code execution (dotnet only, powered by [`dotnet-interactive`](https://github.com/dotnet/interactive))
+
+- Agent communication
+ - [x] Two-agent chat
+ - [x] Group chat
+
+- [ ] Enhanced LLM Inferences
+
+- Exclusive for dotnet
+ - [x] Source generator for type-safe function definition generation
+
+#### Update log
+##### Update on 0.0.11 (2024-03-26)
+- Add link to Discord channel in nuget's readme.md
+- Document improvements
+##### Update on 0.0.10 (2024-03-12)
+- Rename `Workflow` to `Graph`
+- Rename `AddInitializeMessage` to `SendIntroduction`
+- Rename `SequentialGroupChat` to `RoundRobinGroupChat`
+##### Update on 0.0.9 (2024-03-02)
+- Refactor over @AutoGen.Message and introducing `TextMessage`, `ImageMessage`, `MultiModalMessage` and so on. PR [#1676](https://github.com/microsoft/autogen/pull/1676)
+- Add `AutoGen.SemanticKernel` to support seamless integration with Semantic Kernel
+- Move the agent contract abstraction to `AutoGen.Core` package. The `AutoGen.Core` package provides the abstraction for message type, agent and group chat and doesn't contain dependencies over `Azure.AI.OpenAI` or `Semantic Kernel`. This is useful when you want to leverage AutoGen's abstraction only and want to avoid introducing any other dependencies.
+- Move `GPTAgent`, `OpenAIChatAgent` and all openai-dependencies to `AutoGen.OpenAI`
+##### Update on 0.0.8 (2024-02-28)
+- Fix [#1804](https://github.com/microsoft/autogen/pull/1804)
+- Streaming support for IAgent [#1656](https://github.com/microsoft/autogen/pull/1656)
+- Streaming support for middleware via `MiddlewareStreamingAgent` [#1656](https://github.com/microsoft/autogen/pull/1656)
+- Graph chat support with conditional transition workflow [#1761](https://github.com/microsoft/autogen/pull/1761)
+- AutoGen.SourceGenerator: Generate `FunctionContract` from `FunctionAttribute` [#1736](https://github.com/microsoft/autogen/pull/1736)
+##### Update on 0.0.7 (2024-02-11)
+- Add `AutoGen.LMStudio` to support consuming an OpenAI-like API from a local LM Studio server
+##### Update on 0.0.6 (2024-01-23)
+- Add `MiddlewareAgent`
+- Use `MiddlewareAgent` to implement existing agent hooks (RegisterPreProcess, RegisterPostProcess, RegisterReply)
+- Remove `AutoReplyAgent`, `PreProcessAgent`, `PostProcessAgent` because they are replaced by `MiddlewareAgent`
+##### Update on 0.0.5
+- Simplify `IAgent` interface by removing `ChatLLM` Property
+- Add `GenerateReplyOptions` to `IAgent.GenerateReplyAsync` which allows user to specify or override the options when generating reply
+
+##### Update on 0.0.4
+- Move out dependency of Semantic Kernel
+- Add type `IChatLLM` as connector to LLM
+
+##### Update on 0.0.3
+- In AutoGen.SourceGenerator, rename FunctionAttribution to FunctionAttribute
+- In AutoGen, refactor over ConversationAgent, UserProxyAgent, and AssistantAgent
+
+##### Update on 0.0.2
+- Update Azure.AI.OpenAI to 1.0.0-beta.12
+- Update Semantic Kernel to 1.0.1
diff --git a/dotnet/eng/MetaInfo.props b/dotnet/eng/MetaInfo.props
new file mode 100644
index 00000000000..0444dadfd5e
--- /dev/null
+++ b/dotnet/eng/MetaInfo.props
@@ -0,0 +1,12 @@
+
+
+
+ 0.0.14
+ AutoGen
+ https://microsoft.github.io/autogen-for-net/
+ https://github.com/microsoft/autogen
+ git
+ MIT
+ false
+
+
\ No newline at end of file
diff --git a/dotnet/eng/Sign.props b/dotnet/eng/Sign.props
new file mode 100644
index 00000000000..0d69e7797e4
--- /dev/null
+++ b/dotnet/eng/Sign.props
@@ -0,0 +1,22 @@
+
+
+
+
+
+
+
+
+ all
+ runtime; build; native; contentfiles; analyzers
+
+
+
+ Microsoft400
+
+
+
+
+ NuGet
+
+
+
diff --git a/dotnet/eng/Version.props b/dotnet/eng/Version.props
new file mode 100644
index 00000000000..ae213015471
--- /dev/null
+++ b/dotnet/eng/Version.props
@@ -0,0 +1,17 @@
+
+
+
+ 1.0.0-beta.17
+ 1.10.0
+ 1.10.0-alpha
+ 5.0.0
+ 4.3.0
+ 6.0.0
+ 6.8.0
+ 2.4.2
+ 17.7.0
+ 1.0.0-beta.24229.4
+ 8.0.0
+ 4.0.0
+
+
\ No newline at end of file
diff --git a/dotnet/eng/opensource.snk b/dotnet/eng/opensource.snk
new file mode 100644
index 00000000000..779df7c8366
Binary files /dev/null and b/dotnet/eng/opensource.snk differ
diff --git a/dotnet/global.json b/dotnet/global.json
new file mode 100644
index 00000000000..a604954f983
--- /dev/null
+++ b/dotnet/global.json
@@ -0,0 +1,6 @@
+{
+ "sdk": {
+ "version": "8.0.104",
+ "rollForward": "latestMinor"
+ }
+}
\ No newline at end of file
diff --git a/dotnet/nuget/NUGET.md b/dotnet/nuget/NUGET.md
new file mode 100644
index 00000000000..34fdbca33ca
--- /dev/null
+++ b/dotnet/nuget/NUGET.md
@@ -0,0 +1,8 @@
+### About AutoGen for .NET
+`AutoGen for .NET` is the official .NET SDK for [AutoGen](https://github.com/microsoft/autogen). It enables you to create LLM agents and construct multi-agent workflows with ease. It also provides integration with popular platforms like OpenAI, Semantic Kernel, and LM Studio.
+
+### Getting started
+- Find documents and examples on our [document site](https://microsoft.github.io/autogen-for-net/)
+- Join our [Discord channel](https://discord.gg/pAbnFJrkgZ) to get help and discuss with the community
+- Report a bug or request a feature by creating a new issue in our [github repo](https://github.com/microsoft/autogen)
+- Consume the nightly build package from one of the [nightly build feeds](https://microsoft.github.io/autogen-for-net/articles/Installation.html#nighly-build)
\ No newline at end of file
diff --git a/dotnet/nuget/icon.png b/dotnet/nuget/icon.png
new file mode 100644
index 00000000000..076fc48c562
--- /dev/null
+++ b/dotnet/nuget/icon.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:02dbf31fea0b92714c80fdc90888da7e96374a1f52c621a939835fd3c876ddcc
+size 426084
diff --git a/dotnet/nuget/nuget-package.props b/dotnet/nuget/nuget-package.props
new file mode 100644
index 00000000000..c6ddf38916f
--- /dev/null
+++ b/dotnet/nuget/nuget-package.props
@@ -0,0 +1,54 @@
+
+
+ true
+
+
+ AutoGen
+ Microsoft
+ AutoGen
+ A programming framework for agentic AI
+ AI, Artificial Intelligence, SDK
+ $(AssemblyName)
+
+
+ MIT
+ © Microsoft Corporation. All rights reserved.
+ https://microsoft.github.io/autogen-for-net
+ https://github.com/microsoft/autogen
+ true
+
+
+ icon.png
+ icon.png
+ NUGET.md
+
+
+ true
+ snupkg
+
+
+ true
+
+
+ true
+
+
+ bin\$(Configuration)\$(TargetFramework)\$(AssemblyName).xml
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+ true
+
+
\ No newline at end of file
diff --git a/dotnet/sample/AutoGen.BasicSamples/AutoGen.BasicSample.csproj b/dotnet/sample/AutoGen.BasicSamples/AutoGen.BasicSample.csproj
new file mode 100644
index 00000000000..0cafff3c0d0
--- /dev/null
+++ b/dotnet/sample/AutoGen.BasicSamples/AutoGen.BasicSample.csproj
@@ -0,0 +1,25 @@
+ο»Ώ
+
+
+ Exe
+ $(TestTargetFramework)
+ enable
+ enable
+ True
+ $(NoWarn);CS8981;CS8600;CS8602;CS8604;CS8618;CS0219;SKEXP0054;SKEXP0050;SKEXP0110
+
+
+
+
+
+
+
+
+
+
+
+
+ PreserveNewest
+
+
+
diff --git a/dotnet/sample/AutoGen.BasicSamples/CodeSnippet/AgentCodeSnippet.cs b/dotnet/sample/AutoGen.BasicSamples/CodeSnippet/AgentCodeSnippet.cs
new file mode 100644
index 00000000000..abaf94cbd4f
--- /dev/null
+++ b/dotnet/sample/AutoGen.BasicSamples/CodeSnippet/AgentCodeSnippet.cs
@@ -0,0 +1,31 @@
+ο»Ώ// Copyright (c) Microsoft Corporation. All rights reserved.
+// AgentCodeSnippet.cs
+using AutoGen.Core;
+
+namespace AutoGen.BasicSample.CodeSnippet;
+
+internal class AgentCodeSnippet
+{
+ public async Task ChatWithAnAgent(IStreamingAgent agent)
+ {
+ #region ChatWithAnAgent_GenerateReplyAsync
+ var message = new TextMessage(Role.User, "Hello");
+ IMessage reply = await agent.GenerateReplyAsync([message]);
+ #endregion ChatWithAnAgent_GenerateReplyAsync
+
+ #region ChatWithAnAgent_SendAsync
+ reply = await agent.SendAsync("Hello");
+ #endregion ChatWithAnAgent_SendAsync
+
+ #region ChatWithAnAgent_GenerateStreamingReplyAsync
+ var textMessage = new TextMessage(Role.User, "Hello");
+ await foreach (var streamingReply in agent.GenerateStreamingReplyAsync([textMessage]))
+ {
+ if (streamingReply is TextMessageUpdate update)
+ {
+ Console.Write(update.Content);
+ }
+ }
+ #endregion ChatWithAnAgent_GenerateStreamingReplyAsync
+ }
+}
diff --git a/dotnet/sample/AutoGen.BasicSamples/CodeSnippet/BuildInMessageCodeSnippet.cs b/dotnet/sample/AutoGen.BasicSamples/CodeSnippet/BuildInMessageCodeSnippet.cs
new file mode 100644
index 00000000000..f26485116c8
--- /dev/null
+++ b/dotnet/sample/AutoGen.BasicSamples/CodeSnippet/BuildInMessageCodeSnippet.cs
@@ -0,0 +1,42 @@
+ο»Ώ// Copyright (c) Microsoft Corporation. All rights reserved.
+// BuildInMessageCodeSnippet.cs
+
+using AutoGen.Core;
+namespace AutoGen.BasicSample.CodeSnippet;
+
+internal class BuildInMessageCodeSnippet
+{
+ public async Task StreamingCallCodeSnippetAsync()
+ {
+ IStreamingAgent agent = default;
+ #region StreamingCallCodeSnippet
+ var helloTextMessage = new TextMessage(Role.User, "Hello");
+ var reply = agent.GenerateStreamingReplyAsync([helloTextMessage]);
+ var finalTextMessage = new TextMessage(Role.Assistant, string.Empty, from: agent.Name);
+ await foreach (var message in reply)
+ {
+ if (message is TextMessageUpdate textMessage)
+ {
+ Console.Write(textMessage.Content);
+ finalTextMessage.Update(textMessage);
+ }
+ }
+ #endregion StreamingCallCodeSnippet
+
+ #region StreamingCallWithFinalMessage
+ reply = agent.GenerateStreamingReplyAsync([helloTextMessage]);
+ TextMessage finalMessage = null;
+ await foreach (var message in reply)
+ {
+ if (message is TextMessageUpdate textMessage)
+ {
+ Console.Write(textMessage.Content);
+ }
+ else if (message is TextMessage txtMessage)
+ {
+ finalMessage = txtMessage;
+ }
+ }
+ #endregion StreamingCallWithFinalMessage
+ }
+}
diff --git a/dotnet/sample/AutoGen.BasicSamples/CodeSnippet/CreateAnAgent.cs b/dotnet/sample/AutoGen.BasicSamples/CodeSnippet/CreateAnAgent.cs
new file mode 100644
index 00000000000..4833c6195c9
--- /dev/null
+++ b/dotnet/sample/AutoGen.BasicSamples/CodeSnippet/CreateAnAgent.cs
@@ -0,0 +1,142 @@
+ο»Ώ// Copyright (c) Microsoft Corporation. All rights reserved.
+// CreateAnAgent.cs
+
+using AutoGen;
+using AutoGen.Core;
+using AutoGen.OpenAI;
+using FluentAssertions;
+
+public partial class AssistantCodeSnippet
+{
+ public void CodeSnippet1()
+ {
+ #region code_snippet_1
+ // get OpenAI Key and create config
+ var openAIKey = Environment.GetEnvironmentVariable("OPENAI_API_KEY") ?? throw new Exception("Please set OPENAI_API_KEY environment variable.");
+ var llmConfig = new OpenAIConfig(openAIKey, "gpt-3.5-turbo");
+
+ // create assistant agent
+ var assistantAgent = new AssistantAgent(
+ name: "assistant",
+ systemMessage: "You are an assistant that help user to do some tasks.",
+ llmConfig: new ConversableAgentConfig
+ {
+ Temperature = 0,
+ ConfigList = new[] { llmConfig },
+ });
+ #endregion code_snippet_1
+
+ }
+
+ public void CodeSnippet2()
+ {
+ #region code_snippet_2
+ // get OpenAI Key and create config
+ var apiKey = Environment.GetEnvironmentVariable("AZURE_OPENAI_API_KEY");
+ string endPoint = Environment.GetEnvironmentVariable("AZURE_OPENAI_ENDPOINT"); // change to your endpoint
+
+ var llmConfig = new AzureOpenAIConfig(
+ endpoint: endPoint,
+ deploymentName: "gpt-3.5-turbo-16k", // change to your deployment name
+ apiKey: apiKey);
+
+ // create assistant agent
+ var assistantAgent = new AssistantAgent(
+ name: "assistant",
+ systemMessage: "You are an assistant that help user to do some tasks.",
+ llmConfig: new ConversableAgentConfig
+ {
+ Temperature = 0,
+ ConfigList = new[] { llmConfig },
+ });
+ #endregion code_snippet_2
+ }
+
+ #region code_snippet_3
+ /// <summary>
+ /// convert input to upper case
+ /// </summary>
+ /// <param name="input">input</param>
+ [Function]
+ public async Task<string> UpperCase(string input)
+ {
+ var result = input.ToUpper();
+ return result;
+ }
+
+ #endregion code_snippet_3
+
+ public async Task CodeSnippet4()
+ {
+ // get OpenAI Key and create config
+ var apiKey = Environment.GetEnvironmentVariable("AZURE_OPENAI_API_KEY");
+ string endPoint = Environment.GetEnvironmentVariable("AZURE_OPENAI_ENDPOINT"); // change to your endpoint
+
+ var llmConfig = new AzureOpenAIConfig(
+ endpoint: endPoint,
+ deploymentName: "gpt-3.5-turbo-16k", // change to your deployment name
+ apiKey: apiKey);
+ #region code_snippet_4
+ var assistantAgent = new AssistantAgent(
+ name: "assistant",
+ systemMessage: "You are an assistant that convert user input to upper case.",
+ llmConfig: new ConversableAgentConfig
+ {
+ Temperature = 0,
+ ConfigList = new[]
+ {
+ llmConfig
+ },
+ FunctionContracts = new[]
+ {
+ this.UpperCaseFunctionContract, // The FunctionDefinition object for the UpperCase function
+ },
+ });
+
+ var response = await assistantAgent.SendAsync("hello");
+ response.Should().BeOfType<ToolCallMessage>();
+ var toolCallMessage = (ToolCallMessage)response;
+ toolCallMessage.ToolCalls.Count().Should().Be(1);
+ toolCallMessage.ToolCalls.First().FunctionName.Should().Be("UpperCase");
+ #endregion code_snippet_4
+ }
+
+ public async Task CodeSnippet5()
+ {
+ // get OpenAI Key and create config
+ var apiKey = Environment.GetEnvironmentVariable("AZURE_OPENAI_API_KEY");
+ string endPoint = Environment.GetEnvironmentVariable("AZURE_OPENAI_ENDPOINT"); // change to your endpoint
+
+ var llmConfig = new AzureOpenAIConfig(
+ endpoint: endPoint,
+ deploymentName: "gpt-3.5-turbo-16k", // change to your deployment name
+ apiKey: apiKey);
+ #region code_snippet_5
+ var assistantAgent = new AssistantAgent(
+ name: "assistant",
+ systemMessage: "You are an assistant that convert user input to upper case.",
+ llmConfig: new ConversableAgentConfig
+ {
+ Temperature = 0,
+ ConfigList = new[]
+ {
+ llmConfig
+ },
+ FunctionContracts = new[]
+ {
+ this.UpperCaseFunctionContract, // The FunctionDefinition object for the UpperCase function
+ },
+ },
+ functionMap: new Dictionary<string, Func<string, Task<string>>>
+ {
+ { this.UpperCaseFunction.Name, this.UpperCaseWrapper }, // The wrapper function for the UpperCase function
+ });
+
+ var response = await assistantAgent.SendAsync("hello");
+ response.Should().BeOfType<TextMessage>();
+ response.From.Should().Be("assistant");
+ var textMessage = (TextMessage)response;
+ textMessage.Content.Should().Be("HELLO");
+ #endregion code_snippet_5
+ }
+}
diff --git a/dotnet/sample/AutoGen.BasicSamples/CodeSnippet/FunctionCallCodeSnippet.cs b/dotnet/sample/AutoGen.BasicSamples/CodeSnippet/FunctionCallCodeSnippet.cs
new file mode 100644
index 00000000000..2b7e25fee0c
--- /dev/null
+++ b/dotnet/sample/AutoGen.BasicSamples/CodeSnippet/FunctionCallCodeSnippet.cs
@@ -0,0 +1,149 @@
+ο»Ώ// Copyright (c) Microsoft Corporation. All rights reserved.
+// FunctionCallCodeSnippet.cs
+
+using AutoGen;
+using AutoGen.Core;
+using AutoGen.OpenAI;
+using FluentAssertions;
+
+public partial class FunctionCallCodeSnippet
+{
+ public async Task CodeSnippet4()
+ {
+ // get OpenAI Key and create config
+ var apiKey = Environment.GetEnvironmentVariable("AZURE_OPENAI_API_KEY");
+ string endPoint = Environment.GetEnvironmentVariable("AZURE_OPENAI_ENDPOINT"); // change to your endpoint
+
+ var llmConfig = new AzureOpenAIConfig(
+ endpoint: endPoint,
+ deploymentName: "gpt-3.5-turbo-16k", // change to your deployment name
+ apiKey: apiKey);
+ #region code_snippet_4
+ var function = new TypeSafeFunctionCall();
+ var assistantAgent = new AssistantAgent(
+ name: "assistant",
+ systemMessage: "You are an assistant that convert user input to upper case.",
+ llmConfig: new ConversableAgentConfig
+ {
+ Temperature = 0,
+ ConfigList = new[]
+ {
+ llmConfig
+ },
+ FunctionContracts = new[]
+ {
+ function.WeatherReportFunctionContract,
+ },
+ });
+
+ var response = await assistantAgent.SendAsync("hello What's the weather in Seattle today? today is 2024-01-01");
+ response.Should().BeOfType<ToolCallMessage>();
+ var toolCallMessage = (ToolCallMessage)response;
+ toolCallMessage.ToolCalls.Count().Should().Be(1);
+ toolCallMessage.ToolCalls[0].FunctionName.Should().Be("WeatherReport");
+ toolCallMessage.ToolCalls[0].FunctionArguments.Should().Be(@"{""location"":""Seattle"",""date"":""2024-01-01""}");
+ #endregion code_snippet_4
+ }
+
+
+ public async Task CodeSnippet6()
+ {
+ // get OpenAI Key and create config
+ var apiKey = Environment.GetEnvironmentVariable("AZURE_OPENAI_API_KEY");
+ string endPoint = Environment.GetEnvironmentVariable("AZURE_OPENAI_ENDPOINT"); // change to your endpoint
+
+ var llmConfig = new AzureOpenAIConfig(
+ endpoint: endPoint,
+ deploymentName: "gpt-3.5-turbo-16k", // change to your deployment name
+ apiKey: apiKey);
+ #region code_snippet_6
+ var function = new TypeSafeFunctionCall();
+ var assistantAgent = new AssistantAgent(
+ name: "assistant",
+ llmConfig: new ConversableAgentConfig
+ {
+ Temperature = 0,
+ ConfigList = new[]
+ {
+ llmConfig
+ },
+ FunctionContracts = new[]
+ {
+ function.WeatherReportFunctionContract,
+ },
+ },
+ functionMap: new Dictionary<string, Func<string, Task<string>>>
+ {
+ { function.WeatherReportFunctionContract.Name, function.WeatherReportWrapper }, // The function wrapper for the weather report function
+ });
+
+ #endregion code_snippet_6
+
+ #region code_snippet_6_1
+ var response = await assistantAgent.SendAsync("What's the weather in Seattle today? today is 2024-01-01");
+ response.Should().BeOfType<TextMessage>();
+ var textMessage = (TextMessage)response;
+ textMessage.Content.Should().Be("Weather report for Seattle on 2024-01-01 is sunny");
+ #endregion code_snippet_6_1
+ }
+
+ public async Task OverriderFunctionContractAsync()
+ {
+ IAgent agent = default;
+ IEnumerable<IMessage> messages = new List<IMessage>();
+ #region overrider_function_contract
+ var function = new TypeSafeFunctionCall();
+ var reply = agent.GenerateReplyAsync(messages, new GenerateReplyOptions
+ {
+ Functions = new[] { function.WeatherReportFunctionContract },
+ });
+ #endregion overrider_function_contract
+ }
+
+ public async Task RegisterFunctionCallMiddlewareAsync()
+ {
+ IAgent agent = default;
+ #region register_function_call_middleware
+ var function = new TypeSafeFunctionCall();
+ var functionCallMiddleware = new FunctionCallMiddleware(
+ functions: new[] { function.WeatherReportFunctionContract },
+ functionMap: new Dictionary<string, Func<string, Task<string>>>
+ {
+ { function.WeatherReportFunctionContract.Name, function.WeatherReportWrapper },
+ });
+
+ agent = agent!.RegisterMiddleware(functionCallMiddleware);
+ var reply = await agent.SendAsync("What's the weather in Seattle today? today is 2024-01-01");
+ #endregion register_function_call_middleware
+ }
+
+ public async Task TwoAgentWeatherChatTestAsync()
+ {
+ var key = Environment.GetEnvironmentVariable("AZURE_OPENAI_API_KEY") ?? throw new ArgumentException("AZURE_OPENAI_API_KEY is not set");
+ var endpoint = Environment.GetEnvironmentVariable("AZURE_OPENAI_ENDPOINT") ?? throw new ArgumentException("AZURE_OPENAI_ENDPOINT is not set");
+ var deploymentName = "gpt-35-turbo-16k";
+ var config = new AzureOpenAIConfig(endpoint, deploymentName, key);
+ #region two_agent_weather_chat
+ var function = new TypeSafeFunctionCall();
+ var assistant = new AssistantAgent(
+ "assistant",
+ llmConfig: new ConversableAgentConfig
+ {
+ ConfigList = new[] { config },
+ FunctionContracts = new[]
+ {
+ function.WeatherReportFunctionContract,
+ },
+ });
+
+ var user = new UserProxyAgent(
+ name: "user",
+ functionMap: new Dictionary<string, Func<string, Task<string>>>
+ {
+ { function.WeatherReportFunctionContract.Name, function.WeatherReportWrapper },
+ });
+
+ await user.InitiateChatAsync(assistant, "what's weather in Seattle today, today is 2024-01-01", 10);
+ #endregion two_agent_weather_chat
+ }
+}
diff --git a/dotnet/sample/AutoGen.BasicSamples/CodeSnippet/GetStartCodeSnippet.cs b/dotnet/sample/AutoGen.BasicSamples/CodeSnippet/GetStartCodeSnippet.cs
new file mode 100644
index 00000000000..fe97152183a
--- /dev/null
+++ b/dotnet/sample/AutoGen.BasicSamples/CodeSnippet/GetStartCodeSnippet.cs
@@ -0,0 +1,41 @@
+ο»Ώ// Copyright (c) Microsoft Corporation. All rights reserved.
+// GetStartCodeSnippet.cs
+
+#region snippet_GetStartCodeSnippet
+using AutoGen;
+using AutoGen.Core;
+using AutoGen.OpenAI;
+#endregion snippet_GetStartCodeSnippet
+
+public class GetStartCodeSnippet
+{
+ public async Task CodeSnippet1()
+ {
+ #region code_snippet_1
+ var openAIKey = Environment.GetEnvironmentVariable("OPENAI_API_KEY") ?? throw new Exception("Please set OPENAI_API_KEY environment variable.");
+ var gpt35Config = new OpenAIConfig(openAIKey, "gpt-3.5-turbo");
+
+ var assistantAgent = new AssistantAgent(
+ name: "assistant",
+ systemMessage: "You are an assistant that help user to do some tasks.",
+ llmConfig: new ConversableAgentConfig
+ {
+ Temperature = 0,
+ ConfigList = [gpt35Config],
+ })
+ .RegisterPrintMessage(); // register a hook to print message nicely to console
+
+ // set human input mode to ALWAYS so that the user always provides input
+ var userProxyAgent = new UserProxyAgent(
+ name: "user",
+ humanInputMode: HumanInputMode.ALWAYS)
+ .RegisterPrintMessage();
+
+ // start the conversation
+ await userProxyAgent.InitiateChatAsync(
+ receiver: assistantAgent,
+ message: "Hey assistant, please do me a favor.",
+ maxRound: 10);
+ #endregion code_snippet_1
+ }
+}
diff --git a/dotnet/sample/AutoGen.BasicSamples/CodeSnippet/MiddlewareAgentCodeSnippet.cs b/dotnet/sample/AutoGen.BasicSamples/CodeSnippet/MiddlewareAgentCodeSnippet.cs
new file mode 100644
index 00000000000..320afd0de67
--- /dev/null
+++ b/dotnet/sample/AutoGen.BasicSamples/CodeSnippet/MiddlewareAgentCodeSnippet.cs
@@ -0,0 +1,169 @@
+ο»Ώ// Copyright (c) Microsoft Corporation. All rights reserved.
+// MiddlewareAgentCodeSnippet.cs
+
+using System.Text.Json;
+using AutoGen.Core;
+using AutoGen.OpenAI;
+using FluentAssertions;
+
+namespace AutoGen.BasicSample.CodeSnippet;
+
+public class MiddlewareAgentCodeSnippet
+{
+ public async Task CreateMiddlewareAgentAsync()
+ {
+ #region create_middleware_agent_with_original_agent
+ // Create an agent that always replies "Hello World"
+ IAgent agent = new DefaultReplyAgent(name: "assistant", defaultReply: "Hello World");
+
+ // Create a middleware agent on top of default reply agent
+ var middlewareAgent = new MiddlewareAgent(innerAgent: agent);
+ middlewareAgent.Use(async (messages, options, agent, ct) =>
+ {
+ var lastMessage = messages.Last() as TextMessage;
+ lastMessage.Content = $"[middleware 0] {lastMessage.Content}";
+ return await agent.GenerateReplyAsync(messages, options, ct);
+ });
+
+ var reply = await middlewareAgent.SendAsync("Hello World");
+ reply.GetContent().Should().Be("[middleware 0] Hello World");
+ #endregion create_middleware_agent_with_original_agent
+
+ #region register_middleware_agent
+ middlewareAgent = agent.RegisterMiddleware(async (messages, options, agent, ct) =>
+ {
+ var lastMessage = messages.Last() as TextMessage;
+ lastMessage.Content = $"[middleware 0] {lastMessage.Content}";
+ return await agent.GenerateReplyAsync(messages, options, ct);
+ });
+ #endregion register_middleware_agent
+
+ #region short_circuit_middleware_agent
+ // This middleware will short circuit the agent and return the last message directly.
+ middlewareAgent.Use(async (messages, options, agent, ct) =>
+ {
+ var lastMessage = messages.Last() as TextMessage;
+ lastMessage.Content = $"[middleware shortcut]";
+ return lastMessage;
+ });
+ #endregion short_circuit_middleware_agent
+ }
+
+ public async Task RegisterStreamingMiddlewareAsync()
+ {
+ IStreamingAgent streamingAgent = default;
+ #region register_streaming_middleware
+ var connector = new OpenAIChatRequestMessageConnector();
+ var agent = streamingAgent!
+ .RegisterStreamingMiddleware(connector);
+ #endregion register_streaming_middleware
+ }
+
+ public async Task CodeSnippet1()
+ {
+ #region code_snippet_1
+ // Create an agent that always replies "Hello World"
+ IAgent agent = new DefaultReplyAgent(name: "assistant", defaultReply: "Hello World");
+
+ // Create a middleware agent on top of default reply agent
+ var middlewareAgent = new MiddlewareAgent(innerAgent: agent);
+
+ // Since no middleware is added, middlewareAgent will simply proxy into the inner agent to generate reply.
+ var reply = await middlewareAgent.SendAsync("Hello World");
+ reply.From.Should().Be("assistant");
+ reply.GetContent().Should().Be("Hello World");
+ #endregion code_snippet_1
+
+ #region code_snippet_2
+ middlewareAgent.Use(async (messages, options, agent, ct) =>
+ {
+ var lastMessage = messages.Last() as TextMessage;
+ lastMessage.Content = $"[middleware 0] {lastMessage.Content}";
+ return await agent.GenerateReplyAsync(messages, options, ct);
+ });
+
+ reply = await middlewareAgent.SendAsync("Hello World");
+ reply.Should().BeOfType<TextMessage>();
+ var textReply = (TextMessage)reply;
+ textReply.Content.Should().Be("[middleware 0] Hello World");
+ #endregion code_snippet_2
+ #region code_snippet_2_1
+ middlewareAgent = agent.RegisterMiddleware(async (messages, options, agnet, ct) =>
+ {
+ var lastMessage = messages.Last() as TextMessage;
+ lastMessage.Content = $"[middleware 0] {lastMessage.Content}";
+ return await agent.GenerateReplyAsync(messages, options, ct);
+ });
+
+ reply = await middlewareAgent.SendAsync("Hello World");
+ reply.GetContent().Should().Be("[middleware 0] Hello World");
+ #endregion code_snippet_2_1
+ #region code_snippet_3
+ middlewareAgent.Use(async (messages, options, agent, ct) =>
+ {
+ var lastMessage = messages.Last() as TextMessage;
+ lastMessage.Content = $"[middleware 1] {lastMessage.Content}";
+ return await agent.GenerateReplyAsync(messages, options, ct);
+ });
+
+ reply = await middlewareAgent.SendAsync("Hello World");
+ reply.GetContent().Should().Be("[middleware 0] [middleware 1] Hello World");
+ #endregion code_snippet_3
+
+ #region code_snippet_4
+ middlewareAgent.Use(async (messages, options, next, ct) =>
+ {
+ var lastMessage = messages.Last() as TextMessage;
+ lastMessage.Content = $"[middleware shortcut]";
+
+ return lastMessage;
+ });
+
+ reply = await middlewareAgent.SendAsync("Hello World");
+ reply.GetContent().Should().Be("[middleware shortcut]");
+ #endregion code_snippet_4
+
+ #region retrieve_inner_agent
+ var innerAgent = middlewareAgent.Agent;
+ #endregion retrieve_inner_agent
+
+ #region code_snippet_logging_to_console
+ var agentWithLogging = middlewareAgent.RegisterMiddleware(async (messages, options, agent, ct) =>
+ {
+ var reply = await agent.GenerateReplyAsync(messages, options, ct);
+ var formattedMessage = reply.FormatMessage();
+ Console.WriteLine(formattedMessage);
+
+ return reply;
+ });
+ #endregion code_snippet_logging_to_console
+
+ #region code_snippet_response_format_forcement
+ var jsonAgent = middlewareAgent.RegisterMiddleware(async (messages, options, agent, ct) =>
+ {
+ var maxAttempt = 5;
+ var reply = await agent.GenerateReplyAsync(messages, options, ct);
+ while (maxAttempt-- > 0)
+ {
+ if (JsonSerializer.Deserialize<Dictionary<string, object>>(reply.GetContent()) is { } dict)
+ {
+ return reply;
+ }
+ else
+ {
+ await Task.Delay(1000);
+ var reviewPrompt = @"The format is not json, please modify your response to json format
+ -- ORIGINAL MESSAGE --
+ {reply.Content}
+ -- END OF ORIGINAL MESSAGE --
+
+ Reply again with json format.";
+ reply = await agent.SendAsync(reviewPrompt, messages, ct);
+ }
+ }
+
+ throw new Exception("agent fails to generate json response");
+ });
+ #endregion code_snippet_response_format_forcement
+ }
+}
diff --git a/dotnet/sample/AutoGen.BasicSamples/CodeSnippet/MistralAICodeSnippet.cs b/dotnet/sample/AutoGen.BasicSamples/CodeSnippet/MistralAICodeSnippet.cs
new file mode 100644
index 00000000000..0ce1d840d36
--- /dev/null
+++ b/dotnet/sample/AutoGen.BasicSamples/CodeSnippet/MistralAICodeSnippet.cs
@@ -0,0 +1,86 @@
+ο»Ώ// Copyright (c) Microsoft Corporation. All rights reserved.
+// MistralAICodeSnippet.cs
+
+#region using_statement
+using AutoGen.Core;
+using AutoGen.Mistral;
+using AutoGen.Mistral.Extension;
+using FluentAssertions;
+#endregion using_statement
+
+namespace AutoGen.BasicSample.CodeSnippet;
+
+#region weather_function
+public partial class MistralAgentFunction
+{
+ [Function]
+ public async Task<string> GetWeather(string location)
+ {
+ return "The weather in " + location + " is sunny.";
+ }
+}
+#endregion weather_function
+
+internal class MistralAICodeSnippet
+{
+ public async Task CreateMistralAIClientAsync()
+ {
+ #region create_mistral_agent
+ var apiKey = Environment.GetEnvironmentVariable("MISTRAL_API_KEY") ?? throw new Exception("Missing MISTRAL_API_KEY environment variable");
+ var client = new MistralClient(apiKey: apiKey);
+ var agent = new MistralClientAgent(
+ client: client,
+ name: "MistralAI",
+ model: MistralAIModelID.OPEN_MISTRAL_7B)
+ .RegisterMessageConnector(); // support more AutoGen built-in message types.
+
+ await agent.SendAsync("Hello, how are you?");
+ #endregion create_mistral_agent
+
+ #region streaming_chat
+ var reply = agent.GenerateStreamingReplyAsync(
+ messages: [new TextMessage(Role.User, "Hello, how are you?")]
+ );
+
+ await foreach (var message in reply)
+ {
+ if (message is TextMessageUpdate textMessageUpdate && textMessageUpdate.Content is string content)
+ {
+ Console.WriteLine(content);
+ }
+ }
+ #endregion streaming_chat
+ }
+
+ public async Task MistralAIChatAgentGetWeatherToolUsageAsync()
+ {
+ #region create_mistral_function_call_agent
+ var apiKey = Environment.GetEnvironmentVariable("MISTRAL_API_KEY") ?? throw new Exception("Missing MISTRAL_API_KEY environment variable");
+ var client = new MistralClient(apiKey: apiKey);
+ var agent = new MistralClientAgent(
+ client: client,
+ name: "MistralAI",
+ model: MistralAIModelID.MISTRAL_SMALL_LATEST)
+ .RegisterMessageConnector(); // support more AutoGen built-in message types like ToolCallMessage and ToolCallResultMessage
+ #endregion create_mistral_function_call_agent
+
+ #region create_get_weather_function_call_middleware
+ var mistralFunctions = new MistralAgentFunction();
+ var functionCallMiddleware = new FunctionCallMiddleware(
+ functions: [mistralFunctions.GetWeatherFunctionContract],
+ functionMap: new Dictionary<string, Func<string, Task<string>>> // with functionMap, the function will be automatically triggered if the tool name matches one of the keys.
+ {
+ { mistralFunctions.GetWeatherFunctionContract.Name, mistralFunctions.GetWeather }
+ });
+ #endregion create_get_weather_function_call_middleware
+
+ #region register_function_call_middleware
+ agent = agent.RegisterStreamingMiddleware(functionCallMiddleware);
+ #endregion register_function_call_middleware
+
+ #region send_message_with_function_call
+ var reply = await agent.SendAsync("What is the weather in Seattle?");
+ reply.GetContent().Should().Be("The weather in Seattle is sunny.");
+ #endregion send_message_with_function_call
+ }
+}
diff --git a/dotnet/sample/AutoGen.BasicSamples/CodeSnippet/OpenAICodeSnippet.cs b/dotnet/sample/AutoGen.BasicSamples/CodeSnippet/OpenAICodeSnippet.cs
new file mode 100644
index 00000000000..022f7e9f984
--- /dev/null
+++ b/dotnet/sample/AutoGen.BasicSamples/CodeSnippet/OpenAICodeSnippet.cs
@@ -0,0 +1,136 @@
+ο»Ώ// Copyright (c) Microsoft Corporation. All rights reserved.
+// OpenAICodeSnippet.cs
+
+#region using_statement
+using AutoGen.Core;
+using AutoGen.OpenAI;
+using AutoGen.OpenAI.Extension;
+using Azure.AI.OpenAI;
+#endregion using_statement
+using FluentAssertions;
+
+namespace AutoGen.BasicSample.CodeSnippet;
+#region weather_function
+public partial class Functions
+{
+ [Function]
+ public async Task<string> GetWeather(string location)
+ {
+ return "The weather in " + location + " is sunny.";
+ }
+}
+#endregion weather_function
+public partial class OpenAICodeSnippet
+{
+ [Function]
+ public async Task<string> GetWeather(string location)
+ {
+ return "The weather in " + location + " is sunny.";
+ }
+
+ public async Task CreateOpenAIChatAgentAsync()
+ {
+ #region create_openai_chat_agent
+ var openAIKey = Environment.GetEnvironmentVariable("OPENAI_API_KEY") ?? throw new Exception("Please set OPENAI_API_KEY environment variable.");
+ var modelId = "gpt-3.5-turbo";
+ var openAIClient = new OpenAIClient(openAIKey);
+
+ // create an open ai chat agent
+ var openAIChatAgent = new OpenAIChatAgent(
+ openAIClient: openAIClient,
+ name: "assistant",
+ modelName: modelId,
+ systemMessage: "You are an assistant that help user to do some tasks.");
+
+ // OpenAIChatAgent supports the following message types:
+ // - IMessage<ChatRequestMessage> where ChatRequestMessage is from Azure.AI.OpenAI
+
+ var helloMessage = new ChatRequestUserMessage("Hello");
+
+ // Use MessageEnvelope.Create to create an IMessage<ChatRequestMessage>
+ var chatMessageContent = MessageEnvelope.Create(helloMessage);
+ var reply = await openAIChatAgent.SendAsync(chatMessageContent);
+
+ // The type of reply is MessageEnvelope<ChatResponseMessage> where ChatResponseMessage is from Azure.AI.OpenAI
+ reply.Should().BeOfType<MessageEnvelope<ChatResponseMessage>>();
+
+ // You can un-envelop the reply to get the ChatResponseMessage
+ ChatResponseMessage response = reply.As<MessageEnvelope<ChatResponseMessage>>().Content;
+ response.Role.Should().Be(ChatRole.Assistant);
+ #endregion create_openai_chat_agent
+
+ #region create_openai_chat_agent_streaming
+ var streamingReply = openAIChatAgent.GenerateStreamingReplyAsync(new[] { chatMessageContent });
+
+ await foreach (var streamingMessage in streamingReply)
+ {
+ streamingMessage.Should().BeOfType<MessageEnvelope<StreamingChatCompletionsUpdate>>();
+ streamingMessage.As<MessageEnvelope<StreamingChatCompletionsUpdate>>().Content.Role.Should().Be(ChatRole.Assistant);
+ }
+ #endregion create_openai_chat_agent_streaming
+
+ #region register_openai_chat_message_connector
+ // register message connector to support more message types
+ var agentWithConnector = openAIChatAgent
+ .RegisterMessageConnector();
+
+ // now the agentWithConnector supports more message types
+ var messages = new IMessage[]
+ {
+ MessageEnvelope.Create(new ChatRequestUserMessage("Hello")),
+ new TextMessage(Role.Assistant, "Hello", from: "user"),
+ new MultiModalMessage(Role.Assistant,
+ [
+ new TextMessage(Role.Assistant, "Hello", from: "user"),
+ ],
+ from: "user"),
+ new Message(Role.Assistant, "Hello", from: "user"), // Message type is going to be deprecated, please use TextMessage instead
+ };
+
+ foreach (var message in messages)
+ {
+ reply = await agentWithConnector.SendAsync(message);
+
+ reply.Should().BeOfType<TextMessage>();
+ reply.As<TextMessage>().From.Should().Be("assistant");
+ }
+ #endregion register_openai_chat_message_connector
+ }
+
+ public async Task OpenAIChatAgentGetWeatherFunctionCallAsync()
+ {
+ #region openai_chat_agent_get_weather_function_call
+ var openAIKey = Environment.GetEnvironmentVariable("OPENAI_API_KEY") ?? throw new Exception("Please set OPENAI_API_KEY environment variable.");
+ var modelId = "gpt-3.5-turbo";
+ var openAIClient = new OpenAIClient(openAIKey);
+
+ // create an open ai chat agent
+ var openAIChatAgent = new OpenAIChatAgent(
+ openAIClient: openAIClient,
+ name: "assistant",
+ modelName: modelId,
+ systemMessage: "You are an assistant that help user to do some tasks.")
+ .RegisterMessageConnector();
+
+ #endregion openai_chat_agent_get_weather_function_call
+
+ #region create_function_call_middleware
+ var functions = new Functions();
+ var functionCallMiddleware = new FunctionCallMiddleware(
+ functions: [functions.GetWeatherFunctionContract], // GetWeatherFunctionContract is auto-generated from the GetWeather function
+ functionMap: new Dictionary<string, Func<string, Task<string>>>
+ {
+ { functions.GetWeatherFunctionContract.Name, functions.GetWeatherWrapper } // GetWeatherWrapper is a wrapper function for GetWeather, which is also auto-generated
+ });
+
+ openAIChatAgent = openAIChatAgent.RegisterStreamingMiddleware(functionCallMiddleware);
+ #endregion create_function_call_middleware
+
+ #region chat_agent_send_function_call
+ var reply = await openAIChatAgent.SendAsync("what is the weather in Seattle?");
+ reply.GetContent().Should().Be("The weather in Seattle is sunny.");
+ reply.GetToolCalls().Count.Should().Be(1);
+ reply.GetToolCalls().First().FunctionName.Should().Be(this.GetWeatherFunctionContract.Name);
+ #endregion chat_agent_send_function_call
+ }
+}
diff --git a/dotnet/sample/AutoGen.BasicSamples/CodeSnippet/PrintMessageMiddlewareCodeSnippet.cs b/dotnet/sample/AutoGen.BasicSamples/CodeSnippet/PrintMessageMiddlewareCodeSnippet.cs
new file mode 100644
index 00000000000..bf4f9c976e2
--- /dev/null
+++ b/dotnet/sample/AutoGen.BasicSamples/CodeSnippet/PrintMessageMiddlewareCodeSnippet.cs
@@ -0,0 +1,44 @@
+ο»Ώ// Copyright (c) Microsoft Corporation. All rights reserved.
+// PrintMessageMiddlewareCodeSnippet.cs
+
+using AutoGen.Core;
+using AutoGen.OpenAI;
+using AutoGen.OpenAI.Extension;
+using Azure;
+using Azure.AI.OpenAI;
+
+namespace AutoGen.BasicSample.CodeSnippet;
+
+internal class PrintMessageMiddlewareCodeSnippet
+{
+ public async Task PrintMessageMiddlewareAsync()
+ {
+ var config = LLMConfiguration.GetAzureOpenAIGPT3_5_Turbo();
+ var endpoint = new Uri(config.Endpoint);
+ var openaiClient = new OpenAIClient(endpoint, new AzureKeyCredential(config.ApiKey));
+ var agent = new OpenAIChatAgent(openaiClient, "assistant", config.DeploymentName)
+ .RegisterMessageConnector();
+
+ #region PrintMessageMiddleware
+ var agentWithPrintMessageMiddleware = agent
+ .RegisterPrintMessage();
+
+ await agentWithPrintMessageMiddleware.SendAsync("write a long poem");
+ #endregion PrintMessageMiddleware
+ }
+
+ public async Task PrintMessageStreamingMiddlewareAsync()
+ {
+ var config = LLMConfiguration.GetAzureOpenAIGPT3_5_Turbo();
+ var endpoint = new Uri(config.Endpoint);
+ var openaiClient = new OpenAIClient(endpoint, new AzureKeyCredential(config.ApiKey));
+
+ #region print_message_streaming
+ var streamingAgent = new OpenAIChatAgent(openaiClient, "assistant", config.DeploymentName)
+ .RegisterMessageConnector()
+ .RegisterPrintMessage();
+
+ await streamingAgent.SendAsync("write a long poem");
+ #endregion print_message_streaming
+ }
+}
diff --git a/dotnet/sample/AutoGen.BasicSamples/CodeSnippet/RunCodeSnippetCodeSnippet.cs b/dotnet/sample/AutoGen.BasicSamples/CodeSnippet/RunCodeSnippetCodeSnippet.cs
new file mode 100644
index 00000000000..e498650b6aa
--- /dev/null
+++ b/dotnet/sample/AutoGen.BasicSamples/CodeSnippet/RunCodeSnippetCodeSnippet.cs
@@ -0,0 +1,48 @@
+ο»Ώ// Copyright (c) Microsoft Corporation. All rights reserved.
+// RunCodeSnippetCodeSnippet.cs
+
+#region code_snippet_0_1
+using AutoGen.Core;
+using AutoGen.DotnetInteractive;
+#endregion code_snippet_0_1
+
+namespace AutoGen.BasicSample.CodeSnippet;
+public class RunCodeSnippetCodeSnippet
+{
+ public async Task CodeSnippet1()
+ {
+ IAgent agent = default;
+
+ #region code_snippet_1_1
+ var workingDirectory = Path.Combine(Path.GetTempPath(), Path.GetRandomFileName());
+ Directory.CreateDirectory(workingDirectory);
+ var interactiveService = new InteractiveService(installingDirectory: workingDirectory);
+ await interactiveService.StartAsync(workingDirectory: workingDirectory);
+ #endregion code_snippet_1_1
+
+ #region code_snippet_1_2
+ // register dotnet code block execution hook to an arbitrary agent
+ var dotnetCodeAgent = agent.RegisterDotnetCodeBlockExectionHook(interactiveService: interactiveService);
+
+ var codeSnippet = @"
+ ```csharp
+ Console.WriteLine(""Hello World"");
+ ```";
+
+ await dotnetCodeAgent.SendAsync(codeSnippet);
+ // output: Hello World
+ #endregion code_snippet_1_2
+
+ #region code_snippet_1_3
+ var content = @"
+ ```csharp
+ // This is csharp code snippet
+ ```
+
+ ```python
+ // This is python code snippet
+ ```
+ ";
+ #endregion code_snippet_1_3
+ }
+}
diff --git a/dotnet/sample/AutoGen.BasicSamples/CodeSnippet/SemanticKernelCodeSnippet.cs b/dotnet/sample/AutoGen.BasicSamples/CodeSnippet/SemanticKernelCodeSnippet.cs
new file mode 100644
index 00000000000..20dd12d90ce
--- /dev/null
+++ b/dotnet/sample/AutoGen.BasicSamples/CodeSnippet/SemanticKernelCodeSnippet.cs
@@ -0,0 +1,101 @@
+ο»Ώ// Copyright (c) Microsoft Corporation. All rights reserved.
+// SemanticKernelCodeSnippet.cs
+
+using AutoGen.Core;
+using AutoGen.SemanticKernel;
+using AutoGen.SemanticKernel.Extension;
+using FluentAssertions;
+using Microsoft.SemanticKernel;
+using Microsoft.SemanticKernel.ChatCompletion;
+
+namespace AutoGen.BasicSample.CodeSnippet;
+
+public class SemanticKernelCodeSnippet
+{
+ public async Task<string> GetWeather(string location)
+ {
+ return "The weather in " + location + " is sunny.";
+ }
+ public async Task CreateSemanticKernelAgentAsync()
+ {
+ #region create_semantic_kernel_agent
+ var openAIKey = Environment.GetEnvironmentVariable("OPENAI_API_KEY") ?? throw new Exception("Please set OPENAI_API_KEY environment variable.");
+ var modelId = "gpt-3.5-turbo";
+ var builder = Kernel.CreateBuilder()
+ .AddOpenAIChatCompletion(modelId: modelId, apiKey: openAIKey);
+ var kernel = builder.Build();
+
+ // create a semantic kernel agent
+ var semanticKernelAgent = new SemanticKernelAgent(
+ kernel: kernel,
+ name: "assistant",
+ systemMessage: "You are an assistant that help user to do some tasks.");
+
+ // SemanticKernelAgent supports the following message types:
+ // - IMessage<ChatMessageContent> where ChatMessageContent is from Microsoft.SemanticKernel
+
+ var helloMessage = new ChatMessageContent(AuthorRole.User, "Hello");
+
+ // Use MessageEnvelope.Create to create an IMessage<ChatMessageContent>
+ var chatMessageContent = MessageEnvelope.Create(helloMessage);
+ var reply = await semanticKernelAgent.SendAsync(chatMessageContent);
+
+ // The type of reply is MessageEnvelope<ChatMessageContent> where ChatMessageContent is from Microsoft.SemanticKernel
+ reply.Should().BeOfType<MessageEnvelope<ChatMessageContent>>();
+
+ // You can un-envelop the reply to get the ChatMessageContent
+ ChatMessageContent response = reply.As<MessageEnvelope<ChatMessageContent>>().Content;
+ response.Role.Should().Be(AuthorRole.Assistant);
+ #endregion create_semantic_kernel_agent
+
+ #region create_semantic_kernel_agent_streaming
+ var streamingReply = semanticKernelAgent.GenerateStreamingReplyAsync(new[] { chatMessageContent });
+
+ await foreach (var streamingMessage in streamingReply)
+ {
+ streamingMessage.Should().BeOfType<MessageEnvelope<StreamingChatMessageContent>>();
+ streamingMessage.As<MessageEnvelope<StreamingChatMessageContent>>().From.Should().Be("assistant");
+ }
+ #endregion create_semantic_kernel_agent_streaming
+ }
+
+ public async Task SemanticKernelChatMessageContentConnector()
+ {
+ #region register_semantic_kernel_chat_message_content_connector
+ var openAIKey = Environment.GetEnvironmentVariable("OPENAI_API_KEY") ?? throw new Exception("Please set OPENAI_API_KEY environment variable.");
+ var modelId = "gpt-3.5-turbo";
+ var builder = Kernel.CreateBuilder()
+ .AddOpenAIChatCompletion(modelId: modelId, apiKey: openAIKey);
+ var kernel = builder.Build();
+
+ // create a semantic kernel agent
+ var semanticKernelAgent = new SemanticKernelAgent(
+ kernel: kernel,
+ name: "assistant",
+ systemMessage: "You are an assistant that help user to do some tasks.");
+
+ // Register the connector middleware to the kernel agent
+ var semanticKernelAgentWithConnector = semanticKernelAgent
+ .RegisterMessageConnector();
+
+ // now semanticKernelAgentWithConnector supports more message types
+ IMessage[] messages = [
+ MessageEnvelope.Create(new ChatMessageContent(AuthorRole.User, "Hello")),
+ new TextMessage(Role.Assistant, "Hello", from: "user"),
+ new MultiModalMessage(Role.Assistant,
+ [
+ new TextMessage(Role.Assistant, "Hello", from: "user"),
+ ],
+ from: "user"),
+ ];
+
+ foreach (var message in messages)
+ {
+ var reply = await semanticKernelAgentWithConnector.SendAsync(message);
+
+ // SemanticKernelChatMessageContentConnector will convert the reply message to TextMessage
+ reply.Should().BeOfType<TextMessage>();
+ }
+ #endregion register_semantic_kernel_chat_message_content_connector
+ }
+}
diff --git a/dotnet/sample/AutoGen.BasicSamples/CodeSnippet/TypeSafeFunctionCallCodeSnippet.cs b/dotnet/sample/AutoGen.BasicSamples/CodeSnippet/TypeSafeFunctionCallCodeSnippet.cs
new file mode 100644
index 00000000000..50bcd8a8048
--- /dev/null
+++ b/dotnet/sample/AutoGen.BasicSamples/CodeSnippet/TypeSafeFunctionCallCodeSnippet.cs
@@ -0,0 +1,121 @@
+ο»Ώ// Copyright (c) Microsoft Corporation. All rights reserved.
+// TypeSafeFunctionCallCodeSnippet.cs
+
+using System.Text.Json;
+using AutoGen.OpenAI.Extension;
+using Azure.AI.OpenAI;
+#region weather_report_using_statement
+using AutoGen.Core;
+#endregion weather_report_using_statement
+
+#region weather_report
+public partial class TypeSafeFunctionCall
+{
+ /// <summary>
+ /// Get weather report
+ /// </summary>
+ /// <param name="city">city</param>
+ /// <param name="date">date</param>
+ [Function]
+ public async Task<string> WeatherReport(string city, string date)
+ {
+ return $"Weather report for {city} on {date} is sunny";
+ }
+}
+#endregion weather_report
+
+public partial class TypeSafeFunctionCall
+{
+ public async Task Consume()
+ {
+ #region weather_report_consume
+ var functionInstance = new TypeSafeFunctionCall();
+
+ // Get the generated function definition
+ FunctionDefinition functionDefinition = functionInstance.WeatherReportFunctionContract.ToOpenAIFunctionDefinition();
+
+ // Get the generated function wrapper
+ Func<string, Task<string>> functionWrapper = functionInstance.WeatherReportWrapper;
+
+ // ...
+ #endregion weather_report_consume
+ }
+}
+#region code_snippet_3
+// file: FunctionCall.cs
+
+public partial class TypeSafeFunctionCall
+{
+ /// <summary>
+ /// convert input to upper case
+ /// </summary>
+ /// <param name="input">input</param>
+ [Function]
+ public async Task<string> UpperCase(string input)
+ {
+ var result = input.ToUpper();
+ return result;
+ }
+}
+#endregion code_snippet_3
+
+public class TypeSafeFunctionCallCodeSnippet
+{
+ public async Task<string> UpperCase(string input)
+ {
+ var result = input.ToUpper();
+ return result;
+ }
+
+ #region code_snippet_1
+ // file: FunctionDefinition.generated.cs
+ public FunctionDefinition UpperCaseFunction
+ {
+ get => new FunctionDefinition
+ {
+ Name = @"UpperCase",
+ Description = "convert input to upper case",
+ Parameters = BinaryData.FromObjectAsJson(new
+ {
+ Type = "object",
+ Properties = new
+ {
+ input = new
+ {
+ Type = @"string",
+ Description = @"input",
+ },
+ },
+ Required = new[]
+ {
+ "input",
+ },
+ },
+ new JsonSerializerOptions
+ {
+ PropertyNamingPolicy = JsonNamingPolicy.CamelCase,
+ })
+ };
+ }
+ #endregion code_snippet_1
+
+ #region code_snippet_2
+ // file: FunctionDefinition.generated.cs
+ private class UpperCaseSchema
+ {
+ public string input { get; set; }
+ }
+
+ public Task<string> UpperCaseWrapper(string arguments)
+ {
+ var schema = JsonSerializer.Deserialize<UpperCaseSchema>(
+ arguments,
+ new JsonSerializerOptions
+ {
+ PropertyNamingPolicy = JsonNamingPolicy.CamelCase,
+ });
+
+ return UpperCase(schema.input);
+ }
+ #endregion code_snippet_2
+}
diff --git a/dotnet/sample/AutoGen.BasicSamples/CodeSnippet/UserProxyAgentCodeSnippet.cs b/dotnet/sample/AutoGen.BasicSamples/CodeSnippet/UserProxyAgentCodeSnippet.cs
new file mode 100644
index 00000000000..85aecae959e
--- /dev/null
+++ b/dotnet/sample/AutoGen.BasicSamples/CodeSnippet/UserProxyAgentCodeSnippet.cs
@@ -0,0 +1,20 @@
+ο»Ώ// Copyright (c) Microsoft Corporation. All rights reserved.
+// UserProxyAgentCodeSnippet.cs
+using AutoGen.Core;
+
+namespace AutoGen.BasicSample.CodeSnippet;
+
+public class UserProxyAgentCodeSnippet
+{
+ public async Task CodeSnippet1()
+ {
+ #region code_snippet_1
+ // create a user proxy agent which always asks the user for input
+ var agent = new UserProxyAgent(
+ name: "user",
+ humanInputMode: HumanInputMode.ALWAYS);
+
+ await agent.SendAsync("hello");
+ #endregion code_snippet_1
+ }
+}
diff --git a/dotnet/sample/AutoGen.BasicSamples/Example01_AssistantAgent.cs b/dotnet/sample/AutoGen.BasicSamples/Example01_AssistantAgent.cs
new file mode 100644
index 00000000000..3ee363bfc06
--- /dev/null
+++ b/dotnet/sample/AutoGen.BasicSamples/Example01_AssistantAgent.cs
@@ -0,0 +1,46 @@
+ο»Ώ// Copyright (c) Microsoft Corporation. All rights reserved.
+// Example01_AssistantAgent.cs
+
+using AutoGen;
+using AutoGen.BasicSample;
+using AutoGen.Core;
+using FluentAssertions;
+
+///
+/// This example shows the basic usage of class.
+///
+public static class Example01_AssistantAgent
+{
+ public static async Task RunAsync()
+ {
+ var gpt35 = LLMConfiguration.GetAzureOpenAIGPT3_5_Turbo();
+ var config = new ConversableAgentConfig
+ {
+ Temperature = 0,
+ ConfigList = [gpt35],
+ };
+
+ // create assistant agent
+ var assistantAgent = new AssistantAgent(
+ name: "assistant",
+ systemMessage: "You convert what user said to all uppercase.",
+ llmConfig: config)
+ .RegisterPrintMessage();
+
+ // talk to the assistant agent
+ var reply = await assistantAgent.SendAsync("hello world");
+ reply.Should().BeOfType<TextMessage>();
+ reply.GetContent().Should().Be("HELLO WORLD");
+
+ // to carry on the conversation, pass the previous conversation history to the next call
+ var conversationHistory = new List<IMessage>
+ {
+ new TextMessage(Role.User, "hello world"), // first message
+ reply, // reply from assistant agent
+ };
+
+ reply = await assistantAgent.SendAsync("hello world again", conversationHistory);
+ reply.Should().BeOfType<TextMessage>();
+ reply.GetContent().Should().Be("HELLO WORLD AGAIN");
+ }
+}
diff --git a/dotnet/sample/AutoGen.BasicSamples/Example02_TwoAgent_MathChat.cs b/dotnet/sample/AutoGen.BasicSamples/Example02_TwoAgent_MathChat.cs
new file mode 100644
index 00000000000..c2957f32da7
--- /dev/null
+++ b/dotnet/sample/AutoGen.BasicSamples/Example02_TwoAgent_MathChat.cs
@@ -0,0 +1,80 @@
+ο»Ώ// Copyright (c) Microsoft Corporation. All rights reserved.
+// Example02_TwoAgent_MathChat.cs
+
+using AutoGen;
+using AutoGen.BasicSample;
+using AutoGen.Core;
+using FluentAssertions;
+public static class Example02_TwoAgent_MathChat
+{
+ public static async Task RunAsync()
+ {
+ #region code_snippet_1
+ // get gpt-3.5-turbo config
+ var gpt35 = LLMConfiguration.GetAzureOpenAIGPT3_5_Turbo();
+
+ // create teacher agent
+ // teacher agent will create math questions
+ var teacher = new AssistantAgent(
+ name: "teacher",
+ systemMessage: @"You are a teacher that create pre-school math question for student and check answer.
+ If the answer is correct, you stop the conversation by saying [COMPLETE].
+ If the answer is wrong, you ask student to fix it.",
+ llmConfig: new ConversableAgentConfig
+ {
+ Temperature = 0,
+ ConfigList = [gpt35],
+ })
+ .RegisterMiddleware(async (msgs, option, agent, _) =>
+ {
+ var reply = await agent.GenerateReplyAsync(msgs, option);
+ if (reply.GetContent()?.ToLower().Contains("complete") is true)
+ {
+ return new TextMessage(Role.Assistant, GroupChatExtension.TERMINATE, from: reply.From);
+ }
+
+ return reply;
+ })
+ .RegisterPrintMessage();
+
+ // create student agent
+ // student agent will answer the math questions
+ var student = new AssistantAgent(
+ name: "student",
+ systemMessage: "You are a student that answer question from teacher",
+ llmConfig: new ConversableAgentConfig
+ {
+ Temperature = 0,
+ ConfigList = [gpt35],
+ })
+ .RegisterPrintMessage();
+
+ // start the conversation
+ var conversation = await student.InitiateChatAsync(
+ receiver: teacher,
+ message: "Hey teacher, please create math question for me.",
+ maxRound: 10);
+
+ // output
+ // Message from teacher
+ // --------------------
+ // content: Of course!Here's a math question for you:
+ //
+ // What is 2 + 3 ?
+ // --------------------
+ //
+ // Message from student
+ // --------------------
+ // content: The sum of 2 and 3 is 5.
+ // --------------------
+ //
+ // Message from teacher
+ // --------------------
+ // content: [GROUPCHAT_TERMINATE]
+ // --------------------
+ #endregion code_snippet_1
+
+ conversation.Count().Should().BeLessThan(10);
+ conversation.Last().IsGroupChatTerminateMessage().Should().BeTrue();
+ }
+}
diff --git a/dotnet/sample/AutoGen.BasicSamples/Example03_Agent_FunctionCall.cs b/dotnet/sample/AutoGen.BasicSamples/Example03_Agent_FunctionCall.cs
new file mode 100644
index 00000000000..57b9ea76dcb
--- /dev/null
+++ b/dotnet/sample/AutoGen.BasicSamples/Example03_Agent_FunctionCall.cs
@@ -0,0 +1,96 @@
+ο»Ώ// Copyright (c) Microsoft Corporation. All rights reserved.
+// Example03_Agent_FunctionCall.cs
+
+using AutoGen;
+using AutoGen.BasicSample;
+using AutoGen.Core;
+using FluentAssertions;
+
+///
+/// This example shows how to add type-safe function call to an agent.
+///
+public partial class Example03_Agent_FunctionCall
+{
+ /// <summary>
+ /// upper case the message when asked.
+ /// </summary>
+ /// <param name="message"></param>
+ [Function]
+ public async Task<string> UpperCase(string message)
+ {
+ return message.ToUpper();
+ }
+
+ /// <summary>
+ /// Concatenate strings.
+ /// </summary>
+ /// <param name="strings">strings to concatenate</param>
+ [Function]
+ public async Task<string> ConcatString(string[] strings)
+ {
+ return string.Join(" ", strings);
+ }
+
+ /// <summary>
+ /// calculate tax
+ /// </summary>
+ /// <param name="price">price, should be an integer</param>
+ /// <param name="taxRate">tax rate, should be in range (0, 1)</param>
+ [FunctionAttribute]
+ public async Task<string> CalculateTax(int price, float taxRate)
+ {
+ return $"tax is {price * taxRate}";
+ }
+
+ public static async Task RunAsync()
+ {
+ var instance = new Example03_Agent_FunctionCall();
+ var gpt35 = LLMConfiguration.GetAzureOpenAIGPT3_5_Turbo();
+
+ // AutoGen makes use of AutoGen.SourceGenerator to automatically generate FunctionDefinition and FunctionCallWrapper for you.
+ // The FunctionDefinition will be created based on function signature and XML documentation.
+ // The return type of a type-safe function needs to be Task<string>. And to get the best performance, please try to use only primitive types and arrays of primitive types as parameters.
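+ // For illustration only (an assumed sketch, not code generated by this file): for each [Function]
+ // method the source generator emits a FunctionContract property plus a wrapper that deserializes the
+ // JSON arguments and invokes the method, roughly:
+ //
+ //     public FunctionContract UpperCaseFunctionContract { get; }      // metadata built from the signature and XML docs
+ //     public Task<string> UpperCaseWrapper(string arguments) { ... }  // JSON arguments in, string result out
+ //
+ // See TypeSafeFunctionCallCodeSnippet.cs in this change for the expanded form of these artifacts.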
+ var config = new ConversableAgentConfig
+ {
+ Temperature = 0,
+ ConfigList = [gpt35],
+ FunctionContracts = new[]
+ {
+ instance.ConcatStringFunctionContract,
+ instance.UpperCaseFunctionContract,
+ instance.CalculateTaxFunctionContract,
+ },
+ };
+
+ var agent = new AssistantAgent(
+ name: "agent",
+ systemMessage: "You are a helpful AI assistant",
+ llmConfig: config,
+ functionMap: new Dictionary<string, Func<string, Task<string>>>
+ {
+ { nameof(ConcatString), instance.ConcatStringWrapper },
+ { nameof(UpperCase), instance.UpperCaseWrapper },
+ { nameof(CalculateTax), instance.CalculateTaxWrapper },
+ })
+ .RegisterPrintMessage();
+
+ // talk to the assistant agent
+ var upperCase = await agent.SendAsync("convert to upper case: hello world");
+ upperCase.GetContent()?.Should().Be("HELLO WORLD");
+ upperCase.Should().BeOfType<AggregateMessage<ToolCallMessage, ToolCallResultMessage>>();
+ upperCase.GetToolCalls().Should().HaveCount(1);
+ upperCase.GetToolCalls().First().FunctionName.Should().Be(nameof(UpperCase));
+
+ var concatString = await agent.SendAsync("concatenate strings: a, b, c, d, e");
+ concatString.GetContent()?.Should().Be("a b c d e");
+ concatString.Should().BeOfType<AggregateMessage<ToolCallMessage, ToolCallResultMessage>>();
+ concatString.GetToolCalls().Should().HaveCount(1);
+ concatString.GetToolCalls().First().FunctionName.Should().Be(nameof(ConcatString));
+
+ var calculateTax = await agent.SendAsync("calculate tax: 100, 0.1");
+ calculateTax.GetContent().Should().Be("tax is 10");
+ calculateTax.Should().BeOfType<AggregateMessage<ToolCallMessage, ToolCallResultMessage>>();
+ calculateTax.GetToolCalls().Should().HaveCount(1);
+ calculateTax.GetToolCalls().First().FunctionName.Should().Be(nameof(CalculateTax));
+ }
+}
diff --git a/dotnet/sample/AutoGen.BasicSamples/Example04_Dynamic_GroupChat_Coding_Task.cs b/dotnet/sample/AutoGen.BasicSamples/Example04_Dynamic_GroupChat_Coding_Task.cs
new file mode 100644
index 00000000000..c5d9a01f971
--- /dev/null
+++ b/dotnet/sample/AutoGen.BasicSamples/Example04_Dynamic_GroupChat_Coding_Task.cs
@@ -0,0 +1,263 @@
+ο»Ώ// Copyright (c) Microsoft Corporation. All rights reserved.
+// Example04_Dynamic_GroupChat_Coding_Task.cs
+
+using AutoGen;
+using AutoGen.BasicSample;
+using AutoGen.Core;
+using AutoGen.DotnetInteractive;
+using AutoGen.OpenAI;
+using FluentAssertions;
+
+public partial class Example04_Dynamic_GroupChat_Coding_Task
+{
+ public static async Task RunAsync()
+ {
+ var instance = new Example04_Dynamic_GroupChat_Coding_Task();
+
+ // setup dotnet interactive
+ var workDir = Path.Combine(Path.GetTempPath(), "InteractiveService");
+ if (!Directory.Exists(workDir))
+ Directory.CreateDirectory(workDir);
+
+ using var service = new InteractiveService(workDir);
+ var dotnetInteractiveFunctions = new DotnetInteractiveFunction(service);
+
+ var result = Path.Combine(workDir, "result.txt");
+ if (File.Exists(result))
+ File.Delete(result);
+
+ await service.StartAsync(workDir, default);
+
+ var gptConfig = LLMConfiguration.GetAzureOpenAIGPT3_5_Turbo();
+
+ var helperAgent = new GPTAgent(
+ name: "helper",
+ systemMessage: "You are a helpful AI assistant",
+ temperature: 0f,
+ config: gptConfig);
+
+ var groupAdmin = new GPTAgent(
+ name: "groupAdmin",
+ systemMessage: "You are the admin of the group chat",
+ temperature: 0f,
+ config: gptConfig);
+
+ var userProxy = new UserProxyAgent(name: "user", defaultReply: GroupChatExtension.TERMINATE, humanInputMode: HumanInputMode.NEVER)
+ .RegisterPrintMessage();
+
+ // Create admin agent
+ var admin = new AssistantAgent(
+ name: "admin",
+ systemMessage: """
+ You are a manager who takes coding problem from user and resolve problem by splitting them into small tasks and assign each task to the most appropriate agent.
+ Here's available agents who you can assign task to:
+ - coder: write dotnet code to resolve task
+ - runner: run dotnet code from coder
+
+ The workflow is as follows:
+ - You take the coding problem from user
+ - You break the problem into small tasks. For each tasks you first ask coder to write code to resolve the task. Once the code is written, you ask runner to run the code.
+ - Once a small task is resolved, you summarize the completed steps and create the next step.
+ - You repeat the above steps until the coding problem is resolved.
+
+ You can use the following json format to assign task to agents:
+ ```task
+ {
+ "to": "{agent_name}",
+ "task": "{a short description of the task}",
+ "context": "{previous context from scratchpad}"
+ }
+ ```
+
+ If you need to ask user for extra information, you can use the following format:
+ ```ask
+ {
+ "question": "{question}"
+ }
+ ```
+
+ Once the coding problem is resolved, summarize each steps and results and send the summary to the user using the following format:
+ ```summary
+ {
+ "problem": "{coding problem}",
+ "steps": [
+ {
+ "step": "{step}",
+ "result": "{result}"
+ }
+ ]
+ }
+ ```
+
+ Your reply must contain one of [task|ask|summary] to indicate the type of your message.
+ """,
+ llmConfig: new ConversableAgentConfig
+ {
+ Temperature = 0,
+ ConfigList = [gptConfig],
+ })
+ .RegisterPrintMessage();
+
+ // create coder agent
+ // The coder agent is a composite agent that contains dotnet coder, code reviewer and nuget agent.
+ // The dotnet coder writes dotnet code to resolve the task.
+ // The code reviewer reviews the code block from the coder's reply.
+ // The nuget agent installs nuget packages if there are any.
+ var coderAgent = new GPTAgent(
+ name: "coder",
+ systemMessage: @"You act as dotnet coder, you write dotnet code to resolve task. Once you finish writing code, ask runner to run the code for you.
+
+Here're some rules to follow on writing dotnet code:
+- put code between ```csharp and ```
+- When creating http client, use `var httpClient = new HttpClient()`. Don't use `using var httpClient = new HttpClient()` because it will cause error when running the code.
+- Try to use `var` instead of explicit type.
+- Try avoid using external library, use .NET Core library instead.
+- Use top level statement to write code.
+- Always print out the result to console. Don't write code that doesn't print out anything.
+
+If you need to install nuget packages, put nuget packages in the following format:
+```nuget
+nuget_package_name
+```
+
+If your code is incorrect, Fix the error and send the code again.
+
+Here's some external information
+- The link to mlnet repo is: https://github.com/dotnet/machinelearning. you don't need a token to use github pr api. Make sure to include a User-Agent header, otherwise github will reject it.
+",
+ config: gptConfig,
+ temperature: 0.4f)
+ .RegisterPrintMessage();
+
+ // code reviewer agent will review whether the code block from coder's reply satisfies the following conditions:
+ // - There's only one code block
+ // - The code block is csharp code block
+ // - The code block is top level statement
+ // - The code block is not using declaration
+ var codeReviewAgent = new GPTAgent(
+ name: "reviewer",
+ systemMessage: """
+ You are a code reviewer who reviews code from coder. You need to check if the code satisfy the following conditions:
+ - The reply from coder contains at least one code block, e.g ```csharp and ```
+ - There's only one code block and it's csharp code block
+ - The code block is not inside a main function. a.k.a top level statement
+ - The code block is not using declaration when creating http client
+
+ You don't check the code style, only check if the code satisfy the above conditions.
+
+ Put your comment between ```review and ```, if the code satisfies all conditions, put APPROVED in review.result field. Otherwise, put REJECTED along with comments. make sure your comment is clear and easy to understand.
+
+ ## Example 1 ##
+ ```review
+ comment: The code satisfies all conditions.
+ result: APPROVED
+ ```
+
+ ## Example 2 ##
+ ```review
+ comment: The code is inside main function. Please rewrite the code in top level statement.
+ result: REJECTED
+ ```
+
+ """,
+ config: gptConfig,
+ temperature: 0f)
+ .RegisterPrintMessage();
+
+ // create runner agent
+ // The runner agent will run the code block from coder's reply.
+ // It runs dotnet code using dotnet interactive service hook.
+ // It also truncates the output if the output is too long.
+ var runner = new AssistantAgent(
+ name: "runner",
+ defaultReply: "No code available, coder, write code please")
+ .RegisterDotnetCodeBlockExectionHook(interactiveService: service)
+ .RegisterMiddleware(async (msgs, option, agent, ct) =>
+ {
+ var mostRecentCoderMessage = msgs.LastOrDefault(x => x.From == "coder") ?? throw new Exception("No coder message found");
+ return await agent.GenerateReplyAsync(new[] { mostRecentCoderMessage }, option, ct);
+ })
+ .RegisterPrintMessage();
+
+ var adminToCoderTransition = Transition.Create(admin, coderAgent, async (from, to, messages) =>
+ {
+ // the last message should be from admin
+ var lastMessage = messages.Last();
+ if (lastMessage.From != admin.Name)
+ {
+ return false;
+ }
+
+ return true;
+ });
+ var coderToReviewerTransition = Transition.Create(coderAgent, codeReviewAgent);
+ var adminToRunnerTransition = Transition.Create(admin, runner, async (from, to, messages) =>
+ {
+ // the last message should be from admin
+ var lastMessage = messages.Last();
+ if (lastMessage.From != admin.Name)
+ {
+ return false;
+ }
+
+ // the previous messages should contain a message from coder
+ var coderMessage = messages.FirstOrDefault(x => x.From == coderAgent.Name);
+ if (coderMessage is null)
+ {
+ return false;
+ }
+
+ return true;
+ });
+
+ var runnerToAdminTransition = Transition.Create(runner, admin);
+
+ var reviewerToAdminTransition = Transition.Create(codeReviewAgent, admin);
+
+ var adminToUserTransition = Transition.Create(admin, userProxy, async (from, to, messages) =>
+ {
+ // the last message should be from admin
+ var lastMessage = messages.Last();
+ if (lastMessage.From != admin.Name)
+ {
+ return false;
+ }
+
+ return true;
+ });
+
+ var userToAdminTransition = Transition.Create(userProxy, admin);
+
+ var workflow = new Graph(
+ [
+ adminToCoderTransition,
+ coderToReviewerTransition,
+ reviewerToAdminTransition,
+ adminToRunnerTransition,
+ runnerToAdminTransition,
+ adminToUserTransition,
+ userToAdminTransition,
+ ]);
+
+ // create group chat
+ var groupChat = new GroupChat(
+ admin: groupAdmin,
+ members: [admin, coderAgent, runner, codeReviewAgent, userProxy],
+ workflow: workflow);
+
+ // task 1: retrieve the most recent pr from mlnet and save it in result.txt
+ var groupChatManager = new GroupChatManager(groupChat);
+ await userProxy.SendAsync(groupChatManager, "Retrieve the most recent pr from mlnet and save it in result.txt", maxRound: 30);
+ File.Exists(result).Should().BeTrue();
+
+ // task 2: calculate the 39th fibonacci number
+ var answer = 63245986;
+ // clear the result file
+ File.Delete(result);
+
+ var conversationHistory = await userProxy.InitiateChatAsync(groupChatManager, "What's the 39th Fibonacci number? Save the result in result.txt", maxRound: 10);
+ File.Exists(result).Should().BeTrue();
+ var resultContent = File.ReadAllText(result);
+ resultContent.Should().Contain(answer.ToString());
+ }
+}
diff --git a/dotnet/sample/AutoGen.BasicSamples/Example05_Dalle_And_GPT4V.cs b/dotnet/sample/AutoGen.BasicSamples/Example05_Dalle_And_GPT4V.cs
new file mode 100644
index 00000000000..9fccd7ab385
--- /dev/null
+++ b/dotnet/sample/AutoGen.BasicSamples/Example05_Dalle_And_GPT4V.cs
@@ -0,0 +1,152 @@
+ο»Ώ// Copyright (c) Microsoft Corporation. All rights reserved.
+// Example05_Dalle_And_GPT4V.cs
+
+using AutoGen;
+using AutoGen.Core;
+using Azure.AI.OpenAI;
+using FluentAssertions;
+using autogen = AutoGen.LLMConfigAPI;
+
+public partial class Example05_Dalle_And_GPT4V
+{
+ private readonly OpenAIClient openAIClient;
+
+ public Example05_Dalle_And_GPT4V(OpenAIClient openAIClient)
+ {
+ this.openAIClient = openAIClient;
+ }
+
+ /// <summary>
+ /// Generate image from prompt using DALL-E.
+ /// </summary>
+ /// <param name="prompt">prompt with feedback</param>
+ /// <returns></returns>
+ [Function]
+ public async Task<string> GenerateImage(string prompt)
+ {
+ // TODO
+ // generate image from prompt using DALL-E
+ // and return url.
+ var option = new ImageGenerationOptions
+ {
+ Size = ImageSize.Size1024x1024,
+ Style = ImageGenerationStyle.Vivid,
+ ImageCount = 1,
+ Prompt = prompt,
+ Quality = ImageGenerationQuality.Standard,
+ DeploymentName = "dall-e-3",
+ };
+
+ var imageResponse = await openAIClient.GetImageGenerationsAsync(option);
+ var imageUrl = imageResponse.Value.Data.First().Url.OriginalString;
+
+ return $@"// ignore this line [IMAGE_GENERATION]
+The image is generated from prompt {prompt}
+
+{imageUrl}";
+ }
+
+ public static async Task RunAsync()
+ {
+ // This example shows how to use DALL-E and GPT-4V to generate image from prompt and feedback.
+ // The DALL-E agent will generate image from prompt.
+ // The GPT-4V agent will provide feedback to DALL-E agent to help it generate better image.
+ // The conversation will be terminated when the image satisfies the condition.
+ // The image will be saved to image.jpg in current directory.
+
+ // get OpenAI Key and create config
+ var openAIKey = Environment.GetEnvironmentVariable("OPENAI_API_KEY") ?? throw new Exception("Please set OPENAI_API_KEY environment variable.");
+ var gpt35Config = autogen.GetOpenAIConfigList(openAIKey, new[] { "gpt-3.5-turbo" });
+ var gpt4vConfig = autogen.GetOpenAIConfigList(openAIKey, new[] { "gpt-4-vision-preview" });
+ var openAIClient = new OpenAIClient(openAIKey);
+ var instance = new Example05_Dalle_And_GPT4V(openAIClient);
+ var imagePath = Path.Combine(Environment.CurrentDirectory, "image.jpg");
+ if (File.Exists(imagePath))
+ {
+ File.Delete(imagePath);
+ }
+
+ var dalleAgent = new AssistantAgent(
+ name: "dalle",
+ systemMessage: "You are a DALL-E agent that generate image from prompt, when conversation is terminated, return the most recent image url",
+ llmConfig: new ConversableAgentConfig
+ {
+ Temperature = 0,
+ ConfigList = gpt35Config,
+ FunctionContracts = new[]
+ {
+ instance.GenerateImageFunctionContract,
+ },
+ },
+ functionMap: new Dictionary<string, Func<string, Task<string>>>
+ {
+ { nameof(GenerateImage), instance.GenerateImageWrapper },
+ })
+ .RegisterMiddleware(async (msgs, option, agent, ct) =>
+ {
+ // if last message contains [TERMINATE], then find the last image url and terminate the conversation
+ if (msgs.Last().GetContent()?.Contains("TERMINATE") is true)
+ {
+ var lastMessageWithImage = msgs.Last(msg => msg is ImageMessage) as ImageMessage;
+ var lastImageUrl = lastMessageWithImage.Url;
+ Console.WriteLine($"download image from {lastImageUrl} to {imagePath}");
+ var httpClient = new HttpClient();
+ var imageBytes = await httpClient.GetByteArrayAsync(lastImageUrl);
+ File.WriteAllBytes(imagePath, imageBytes);
+
+ var messageContent = $@"{GroupChatExtension.TERMINATE}
+
+{lastImageUrl}";
+ return new TextMessage(Role.Assistant, messageContent)
+ {
+ From = "dalle",
+ };
+ }
+
+ var reply = await agent.GenerateReplyAsync(msgs, option, ct);
+
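+ // if the reply comes from the GenerateImage function call (tagged with [IMAGE_GENERATION]),
+ // wrap the returned url in an ImageMessage so the GPT-4V agent can inspect the generated image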
+ if (reply.GetContent() is string content && content.Contains("IMAGE_GENERATION"))
+ {
+ var imageUrl = content.Split("\n").Last();
+ var imageMessage = new ImageMessage(Role.Assistant, imageUrl, from: reply.From);
+
+ return imageMessage;
+ }
+ else
+ {
+ return reply;
+ }
+ })
+ .RegisterPrintMessage();
+
+ var gpt4VAgent = new AssistantAgent(
+ name: "gpt4v",
+ systemMessage: @"You are a critism that provide feedback to DALL-E agent.
+Carefully check the image generated by DALL-E agent and provide feedback.
+If the image satisfies the condition, then terminate the conversation by saying [TERMINATE].
+Otherwise, provide detailed feedback to DALL-E agent so it can generate better image.
+
+The image should satisfy the following conditions:
+- There should be a cat and a mouse in the image
+- The cat should be chasing after the mouse
+",
+ llmConfig: new ConversableAgentConfig
+ {
+ Temperature = 0,
+ ConfigList = gpt4vConfig,
+ })
+ .RegisterPrintMessage();
+
+ IEnumerable<IMessage> conversation = new List<IMessage>()
+ {
+ new TextMessage(Role.User, "Hey dalle, please generate image from prompt: English short hair blue cat chase after a mouse")
+ };
+ var maxRound = 20;
+ await gpt4VAgent.InitiateChatAsync(
+ receiver: dalleAgent,
+ message: "Hey dalle, please generate image from prompt: English short hair blue cat chase after a mouse",
+ maxRound: maxRound);
+
+ File.Exists(imagePath).Should().BeTrue();
+ }
+}
diff --git a/dotnet/sample/AutoGen.BasicSamples/Example06_UserProxyAgent.cs b/dotnet/sample/AutoGen.BasicSamples/Example06_UserProxyAgent.cs
new file mode 100644
index 00000000000..dd3b5a67192
--- /dev/null
+++ b/dotnet/sample/AutoGen.BasicSamples/Example06_UserProxyAgent.cs
@@ -0,0 +1,32 @@
+ο»Ώ// Copyright (c) Microsoft Corporation. All rights reserved.
+// Example06_UserProxyAgent.cs
+using AutoGen.Core;
+using AutoGen.OpenAI;
+
+namespace AutoGen.BasicSample;
+
+public static class Example06_UserProxyAgent
+{
+ public static async Task RunAsync()
+ {
+ var gpt35 = LLMConfiguration.GetOpenAIGPT3_5_Turbo();
+
+ var assistantAgent = new GPTAgent(
+ name: "assistant",
+ systemMessage: "You are an assistant that help user to do some tasks.",
+ config: gpt35)
+ .RegisterPrintMessage();
+
+ // set human input mode to ALWAYS so that the user always provides input
+ var userProxyAgent = new UserProxyAgent(
+ name: "user",
+ humanInputMode: HumanInputMode.ALWAYS)
+ .RegisterPrintMessage();
+
+ // start the conversation
+ await userProxyAgent.InitiateChatAsync(
+ receiver: assistantAgent,
+ message: "Hey assistant, please help me to do some tasks.",
+ maxRound: 10);
+ }
+}
diff --git a/dotnet/sample/AutoGen.BasicSamples/Example07_Dynamic_GroupChat_Calculate_Fibonacci.cs b/dotnet/sample/AutoGen.BasicSamples/Example07_Dynamic_GroupChat_Calculate_Fibonacci.cs
new file mode 100644
index 00000000000..6584baa5fae
--- /dev/null
+++ b/dotnet/sample/AutoGen.BasicSamples/Example07_Dynamic_GroupChat_Calculate_Fibonacci.cs
@@ -0,0 +1,368 @@
+ο»Ώ// Copyright (c) Microsoft Corporation. All rights reserved.
+// Example07_Dynamic_GroupChat_Calculate_Fibonacci.cs
+
+using System.Text;
+using System.Text.Json;
+using AutoGen;
+using AutoGen.BasicSample;
+using AutoGen.Core;
+using AutoGen.DotnetInteractive;
+using AutoGen.OpenAI;
+using FluentAssertions;
+
+public partial class Example07_Dynamic_GroupChat_Calculate_Fibonacci
+{
+ #region reviewer_function
+ public struct CodeReviewResult
+ {
+ public bool HasMultipleCodeBlocks { get; set; }
+ public bool IsTopLevelStatement { get; set; }
+ public bool IsDotnetCodeBlock { get; set; }
+ public bool IsPrintResultToConsole { get; set; }
+ }
+
+ /// <summary>
+ /// review code block
+ /// </summary>
+ /// <param name="hasMultipleCodeBlocks">true if there're multiple csharp code blocks</param>
+ /// <param name="isTopLevelStatement">true if the code is in top level statement</param>
+ /// <param name="isDotnetCodeBlock">true if the code block is a csharp code block</param>
+ /// <param name="isPrintResultToConsole">true if the code block prints out the result to console</param>
+ [Function]
+ public async Task<string> ReviewCodeBlock(
+ bool hasMultipleCodeBlocks,
+ bool isTopLevelStatement,
+ bool isDotnetCodeBlock,
+ bool isPrintResultToConsole)
+ {
+ var obj = new CodeReviewResult
+ {
+ HasMultipleCodeBlocks = hasMultipleCodeBlocks,
+ IsTopLevelStatement = isTopLevelStatement,
+ IsDotnetCodeBlock = isDotnetCodeBlock,
+ IsPrintResultToConsole = isPrintResultToConsole,
+ };
+
+ return JsonSerializer.Serialize(obj);
+ }
+ #endregion reviewer_function
+
+ #region create_coder
+ public static async Task<IAgent> CreateCoderAgentAsync()
+ {
+ var gpt3Config = LLMConfiguration.GetAzureOpenAIGPT3_5_Turbo();
+ var coder = new GPTAgent(
+ name: "coder",
+ systemMessage: @"You act as dotnet coder, you write dotnet code to resolve task. Once you finish writing code, ask runner to run the code for you.
+
+ Here're some rules to follow on writing dotnet code:
+ - put code between ```csharp and ```
+ - Avoid adding `using` keyword when creating disposable object. e.g `var httpClient = new HttpClient()`
+ - Try to use `var` instead of explicit type.
+ - Try avoid using external library, use .NET Core library instead.
+ - Use top level statement to write code.
+ - Always print out the result to console. Don't write code that doesn't print out anything.
+
+ If you need to install nuget packages, put nuget packages in the following format:
+ ```nuget
+ nuget_package_name
+ ```
+
+ If your code is incorrect, runner will tell you the error message. Fix the error and send the code again.",
+ config: gpt3Config,
+ temperature: 0.4f)
+ .RegisterPrintMessage();
+
+ return coder;
+ }
+ #endregion create_coder
+
+ #region create_runner
+ public static async Task<IAgent> CreateRunnerAgentAsync(InteractiveService service)
+ {
+ var runner = new AssistantAgent(
+ name: "runner",
+ systemMessage: "You run dotnet code",
+ defaultReply: "No code available.")
+ .RegisterDotnetCodeBlockExectionHook(interactiveService: service)
+ .RegisterMiddleware(async (msgs, option, agent, _) =>
+ {
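+ // only forward the most recent message from the coder to the runner;
+ // if no coder message exists yet, ask the coder to write code first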
+ if (msgs.Count() == 0 || msgs.All(msg => msg.From != "coder"))
+ {
+ return new TextMessage(Role.Assistant, "No code available. Coder please write code");
+ }
+ else
+ {
+ var coderMsg = msgs.Last(msg => msg.From == "coder");
+ return await agent.GenerateReplyAsync([coderMsg], option);
+ }
+ })
+ .RegisterPrintMessage();
+
+ return runner;
+ }
+ #endregion create_runner
+
+ #region create_admin
+ public static async Task<IAgent> CreateAdminAsync()
+ {
+ var gpt3Config = LLMConfiguration.GetAzureOpenAIGPT3_5_Turbo();
+ var admin = new GPTAgent(
+ name: "admin",
+ systemMessage: "You are group admin, terminate the group chat once task is completed by saying [TERMINATE] plus the final answer",
+ temperature: 0,
+ config: gpt3Config)
+ .RegisterMiddleware(async (msgs, option, agent, _) =>
+ {
+ var reply = await agent.GenerateReplyAsync(msgs, option);
+ if (reply is TextMessage textMessage && textMessage.Content.Contains("TERMINATE") is true)
+ {
+ var content = $"{textMessage.Content}\n\n {GroupChatExtension.TERMINATE}";
+
+ return new TextMessage(Role.Assistant, content, from: reply.From);
+ }
+
+ return reply;
+ });
+
+ return admin;
+ }
+ #endregion create_admin
+
+ #region create_reviewer
+ public static async Task<IAgent> CreateReviewerAgentAsync()
+ {
+ var gpt3Config = LLMConfiguration.GetAzureOpenAIGPT3_5_Turbo();
+ var functions = new Example07_Dynamic_GroupChat_Calculate_Fibonacci();
+ var reviewer = new GPTAgent(
+ name: "code_reviewer",
+ systemMessage: @"You review code block from coder",
+ config: gpt3Config,
+ functions: [functions.ReviewCodeBlockFunction],
+ functionMap: new Dictionary<string, Func<string, Task<string>>>()
+ {
+ { nameof(ReviewCodeBlock), functions.ReviewCodeBlockWrapper },
+ })
+ .RegisterMiddleware(async (msgs, option, innerAgent, ct) =>
+ {
+ var maxRetry = 3;
+ var reply = await innerAgent.GenerateReplyAsync(msgs, option, ct);
+ while (maxRetry-- > 0)
+ {
+ if (reply.GetToolCalls() is var toolCalls && toolCalls.Count() == 1 && toolCalls[0].FunctionName == nameof(ReviewCodeBlock))
+ {
+ var toolCallResult = reply.GetContent();
+ var reviewResultObj = JsonSerializer.Deserialize<CodeReviewResult>(toolCallResult);
+ var reviews = new List<string>();
+ if (reviewResultObj.HasMultipleCodeBlocks)
+ {
+ var fixCodeBlockPrompt = @"There're multiple code blocks, please combine them into one code block";
+ reviews.Add(fixCodeBlockPrompt);
+ }
+
+ if (reviewResultObj.IsDotnetCodeBlock is false)
+ {
+ var fixCodeBlockPrompt = @"The code block is not csharp code block, please write dotnet code only";
+ reviews.Add(fixCodeBlockPrompt);
+ }
+
+ if (reviewResultObj.IsTopLevelStatement is false)
+ {
+ var fixCodeBlockPrompt = @"The code is not top level statement, please rewrite your dotnet code using top level statement";
+ reviews.Add(fixCodeBlockPrompt);
+ }
+
+ if (reviewResultObj.IsPrintResultToConsole is false)
+ {
+ var fixCodeBlockPrompt = @"The code doesn't print out result to console, please print out result to console";
+ reviews.Add(fixCodeBlockPrompt);
+ }
+
+ if (reviews.Count > 0)
+ {
+ var sb = new StringBuilder();
+ sb.AppendLine("There're some comments from code reviewer, please fix these comments");
+ foreach (var review in reviews)
+ {
+ sb.AppendLine($"- {review}");
+ }
+
+ return new TextMessage(Role.Assistant, sb.ToString(), from: "code_reviewer");
+ }
+ else
+ {
+ var msg = new TextMessage(Role.Assistant, "The code looks good, please ask runner to run the code for you.")
+ {
+ From = "code_reviewer",
+ };
+
+ return msg;
+ }
+ }
+ else
+ {
+ var originalContent = reply.GetContent();
+ var prompt = $@"Please convert the content to ReviewCodeBlock function arguments.
+
+ ## Original Content
+ {originalContent}";
+
+ reply = await innerAgent.SendAsync(prompt, msgs, ct);
+ }
+ }
+
+ throw new Exception("Failed to review code block");
+ })
+ .RegisterPrintMessage();
+
+ return reviewer;
+ }
+ #endregion create_reviewer
+
+ public static async Task RunWorkflowAsync()
+ {
+ long the39thFibonacciNumber = 63245986;
+ var workDir = Path.Combine(Path.GetTempPath(), "InteractiveService");
+ if (!Directory.Exists(workDir))
+ Directory.CreateDirectory(workDir);
+
+ using var service = new InteractiveService(workDir);
+ var dotnetInteractiveFunctions = new DotnetInteractiveFunction(service);
+
+ await service.StartAsync(workDir, default);
+
+ #region create_workflow
+ var reviewer = await CreateReviewerAgentAsync();
+ var coder = await CreateCoderAgentAsync();
+ var runner = await CreateRunnerAgentAsync(service);
+ var admin = await CreateAdminAsync();
+
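+ // Build the finite-state workflow: admin -> coder -> reviewer.
+ // The reviewer either hands approved code to the runner or bounces it back to the coder;
+ // the runner returns to the coder on error, otherwise it reports back to the admin.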
+ var admin2CoderTransition = Transition.Create(admin, coder);
+ var coder2ReviewerTransition = Transition.Create(coder, reviewer);
+ var reviewer2RunnerTransition = Transition.Create(
+ from: reviewer,
+ to: runner,
+ canTransitionAsync: async (from, to, messages) =>
+ {
+ var lastMessage = messages.Last();
+ if (lastMessage is TextMessage textMessage && textMessage.Content.ToLower().Contains("the code looks good, please ask runner to run the code for you.") is true)
+ {
+ // ask runner to run the code
+ return true;
+ }
+
+ return false;
+ });
+ var reviewer2CoderTransition = Transition.Create(
+ from: reviewer,
+ to: coder,
+ canTransitionAsync: async (from, to, messages) =>
+ {
+ var lastMessage = messages.Last();
+ if (lastMessage is TextMessage textMessage && textMessage.Content.ToLower().Contains("there're some comments from code reviewer, please fix these comments") is true)
+ {
+ // ask coder to fix the code based on reviewer's comments
+ return true;
+ }
+
+ return false;
+ });
+
+ var runner2CoderTransition = Transition.Create(
+ from: runner,
+ to: coder,
+ canTransitionAsync: async (from, to, messages) =>
+ {
+ var lastMessage = messages.Last();
+ if (lastMessage is TextMessage textMessage && textMessage.Content.ToLower().Contains("error") is true)
+ {
+ // ask coder to fix the error
+ return true;
+ }
+
+ return false;
+ });
+ var runner2AdminTransition = Transition.Create(runner, admin);
+
+ var workflow = new Graph(
+ [
+ admin2CoderTransition,
+ coder2ReviewerTransition,
+ reviewer2RunnerTransition,
+ reviewer2CoderTransition,
+ runner2CoderTransition,
+ runner2AdminTransition,
+ ]);
+ #endregion create_workflow
+
+ #region create_group_chat_with_workflow
+ var groupChat = new GroupChat(
+ admin: admin,
+ workflow: workflow,
+ members:
+ [
+ admin,
+ coder,
+ runner,
+ reviewer,
+ ]);
+
+ admin.SendIntroduction("Welcome to my group, work together to resolve my task", groupChat);
+ coder.SendIntroduction("I will write dotnet code to resolve task", groupChat);
+ reviewer.SendIntroduction("I will review dotnet code", groupChat);
+ runner.SendIntroduction("I will run dotnet code once the review is done", groupChat);
+
+ var groupChatManager = new GroupChatManager(groupChat);
+ var conversationHistory = await admin.InitiateChatAsync(groupChatManager, "What's the 39th Fibonacci number?", maxRound: 10);
+ #endregion create_group_chat_with_workflow
+ // the last message is from admin, which is the termination message
+ var lastMessage = conversationHistory.Last();
+ lastMessage.From.Should().Be("admin");
+ lastMessage.IsGroupChatTerminateMessage().Should().BeTrue();
+ lastMessage.Should().BeOfType<TextMessage>();
+ lastMessage.GetContent().Should().Contain(the39thFibonacciNumber.ToString());
+ }
+
+ public static async Task RunAsync()
+ {
+ long the39thFibonacciNumber = 63245986;
+ var workDir = Path.Combine(Path.GetTempPath(), "InteractiveService");
+ if (!Directory.Exists(workDir))
+ Directory.CreateDirectory(workDir);
+
+ using var service = new InteractiveService(workDir);
+ var dotnetInteractiveFunctions = new DotnetInteractiveFunction(service);
+
+ await service.StartAsync(workDir, default);
+ #region create_group_chat
+ var reviewer = await CreateReviewerAgentAsync();
+ var coder = await CreateCoderAgentAsync();
+ var runner = await CreateRunnerAgentAsync(service);
+ var admin = await CreateAdminAsync();
+ var groupChat = new GroupChat(
+ admin: admin,
+ members:
+ [
+ admin,
+ coder,
+ runner,
+ reviewer,
+ ]);
+
+ admin.SendIntroduction("Welcome to my group, work together to resolve my task", groupChat);
+ coder.SendIntroduction("I will write dotnet code to resolve task", groupChat);
+ reviewer.SendIntroduction("I will review dotnet code", groupChat);
+ runner.SendIntroduction("I will run dotnet code once the review is done", groupChat);
+
+ var groupChatManager = new GroupChatManager(groupChat);
+ var conversationHistory = await admin.InitiateChatAsync(groupChatManager, "What's the 39th Fibonacci number?", maxRound: 10);
+
+ // the last message is from admin, which is the termination message
+ var lastMessage = conversationHistory.Last();
+ lastMessage.From.Should().Be("admin");
+ lastMessage.IsGroupChatTerminateMessage().Should().BeTrue();
+ lastMessage.Should().BeOfType<TextMessage>();
+ lastMessage.GetContent().Should().Contain(the39thFibonacciNumber.ToString());
+ #endregion create_group_chat
+ }
+}
diff --git a/dotnet/sample/AutoGen.BasicSamples/Example08_LMStudio.cs b/dotnet/sample/AutoGen.BasicSamples/Example08_LMStudio.cs
new file mode 100644
index 00000000000..cce33011762
--- /dev/null
+++ b/dotnet/sample/AutoGen.BasicSamples/Example08_LMStudio.cs
@@ -0,0 +1,44 @@
+ο»Ώ// Copyright (c) Microsoft Corporation. All rights reserved.
+// Example08_LMStudio.cs
+
+#region lmstudio_using_statements
+using AutoGen.Core;
+using AutoGen.LMStudio;
+#endregion lmstudio_using_statements
+
+namespace AutoGen.BasicSample;
+
+public class Example08_LMStudio
+{
+ public static async Task RunAsync()
+ {
+ #region lmstudio_example_1
+ var config = new LMStudioConfig("localhost", 1234);
+ var lmAgent = new LMStudioAgent("asssistant", config: config)
+ .RegisterPrintMessage();
+
+ await lmAgent.SendAsync("Can you write a piece of C# code to calculate 100th of fibonacci?");
+
+ // output from assistant (the output below is generated using llama-2-chat-7b, the output may vary depending on the model used)
+ //
+ // Of course! To calculate the 100th number in the Fibonacci sequence using C#, you can use the following code:```
+ // using System;
+ // class FibonacciSequence {
+ // static int Fibonacci(int n) {
+ // if (n <= 1) {
+ // return 1;
+ // } else {
+ // return Fibonacci(n - 1) + Fibonacci(n - 2);
+ // }
+ // }
+ // static void Main() {
+ // Console.WriteLine("The 100th number in the Fibonacci sequence is: " + Fibonacci(100));
+ // }
+ // }
+ // ```
+ // In this code, we define a function `Fibonacci` that takes an integer `n` as input and returns the `n`-th number in the Fibonacci sequence. The function uses a recursive approach to calculate the value of the sequence.
+ // The `Main` method simply calls the `Fibonacci` function with the argument `100`, and prints the result to the console.
+ // Note that this code will only work for positive integers `n`. If you want to calculate the Fibonacci sequence for other types of numbers, such as real or complex numbers, you will need to modify the code accordingly.
+ #endregion lmstudio_example_1
+ }
+}
diff --git a/dotnet/sample/AutoGen.BasicSamples/Example09_LMStudio_FunctionCall.cs b/dotnet/sample/AutoGen.BasicSamples/Example09_LMStudio_FunctionCall.cs
new file mode 100644
index 00000000000..9a62144df2b
--- /dev/null
+++ b/dotnet/sample/AutoGen.BasicSamples/Example09_LMStudio_FunctionCall.cs
@@ -0,0 +1,135 @@
+ο»Ώ// Copyright (c) Microsoft Corporation. All rights reserved.
+// Example09_LMStudio_FunctionCall.cs
+
+using System.Text.Json;
+using System.Text.Json.Serialization;
+using AutoGen.Core;
+using AutoGen.LMStudio;
+using Azure.AI.OpenAI;
+
+namespace AutoGen.BasicSample;
+
+public class LLaMAFunctionCall
+{
+ [JsonPropertyName("name")]
+ public string Name { get; set; }
+
+ [JsonPropertyName("arguments")]
+ public JsonElement Arguments { get; set; }
+}
+
+public partial class Example09_LMStudio_FunctionCall
+{
+ /// <summary>
+ /// Get weather from location.
+ /// </summary>
+ /// <param name="location">location</param>
+ /// <param name="date">date. type is string</param>
+ [Function]
+ public async Task<string> GetWeather(string location, string date)
+ {
+ return $"[Function] The weather on {date} in {location} is sunny.";
+ }
+
+
+ /// <summary>
+ /// Search query on Google and return the results.
+ /// </summary>
+ /// <param name="query">search query</param>
+ [Function]
+ public async Task<string> GoogleSearch(string query)
+ {
+ return $"[Function] Here are the search results for {query}.";
+ }
+
+ private static object SerializeFunctionDefinition(FunctionDefinition functionDefinition)
+ {
+ return new
+ {
+ type = "function",
+ function = new
+ {
+ name = functionDefinition.Name,
+ description = functionDefinition.Description,
+ parameters = functionDefinition.Parameters.ToObjectFromJson<object>(),
+ }
+ };
+ }
+
+ public static async Task RunAsync()
+ {
+ #region lmstudio_function_call_example
+ // This example has been verified to work with Trelis-Llama-2-7b-chat-hf-function-calling-v3
+ var instance = new Example09_LMStudio_FunctionCall();
+ var config = new LMStudioConfig("localhost", 1234);
+ var systemMessage = @$"You are a helpful AI assistant.";
+
+ // Because the LM studio server doesn't support openai function call yet
+ // To simulate the function call, we can put the function call details in the system message
+ // And ask the agent to respond in function call object format using a few-shot example
+ object[] functionList =
+ [
+ SerializeFunctionDefinition(instance.GetWeatherFunction),
+ SerializeFunctionDefinition(instance.GoogleSearchFunction)
+ ];
+ var functionListString = JsonSerializer.Serialize(functionList, new JsonSerializerOptions { WriteIndented = true });
+ var lmAgent = new LMStudioAgent(
+ name: "assistant",
+ systemMessage: @$"
+You are a helpful AI assistant
+You have access to the following functions. Use them if required:
+
+{functionListString}",
+ config: config)
+ .RegisterMiddleware(async (msgs, option, innerAgent, ct) =>
+ {
+ // inject few-shot example to the message
+ var exampleGetWeather = new TextMessage(Role.User, "Get weather in London");
+ var exampleAnswer = new TextMessage(Role.Assistant, "{\n \"name\": \"GetWeather\",\n \"arguments\": {\n \"location\": \"London\"\n }\n}", from: innerAgent.Name);
+
+ msgs = new[] { exampleGetWeather, exampleAnswer }.Concat(msgs).ToArray();
+ var reply = await innerAgent.GenerateReplyAsync(msgs, option, ct);
+
+ // if reply is a function call, invoke function
+ var content = reply.GetContent();
+ try
+ {
+ if (JsonSerializer.Deserialize<LLaMAFunctionCall>(content) is { } functionCall)
+ {
+ var arguments = JsonSerializer.Serialize(functionCall.Arguments);
+ // invoke function wrapper
+ if (functionCall.Name == instance.GetWeatherFunction.Name)
+ {
+ var result = await instance.GetWeatherWrapper(arguments);
+ return new TextMessage(Role.Assistant, result);
+ }
+ else if (functionCall.Name == instance.GoogleSearchFunction.Name)
+ {
+ var result = await instance.GoogleSearchWrapper(arguments);
+ return new TextMessage(Role.Assistant, result);
+ }
+ else
+ {
+ throw new Exception($"Unknown function call: {functionCall.Name}");
+ }
+ }
+ }
+ catch (JsonException)
+ {
+ // ignore
+ }
+
+ return reply;
+ })
+ .RegisterPrintMessage();
+
+ var userProxyAgent = new UserProxyAgent(
+ name: "user",
+ humanInputMode: HumanInputMode.ALWAYS);
+
+ await userProxyAgent.SendAsync(
+ receiver: lmAgent,
+ "Search the names of the five largest stocks in the US by market cap ");
+ #endregion lmstudio_function_call_example
+ }
+}
diff --git a/dotnet/sample/AutoGen.BasicSamples/Example10_SemanticKernel.cs b/dotnet/sample/AutoGen.BasicSamples/Example10_SemanticKernel.cs
new file mode 100644
index 00000000000..61c341204ec
--- /dev/null
+++ b/dotnet/sample/AutoGen.BasicSamples/Example10_SemanticKernel.cs
@@ -0,0 +1,80 @@
+ο»Ώ// Copyright (c) Microsoft Corporation. All rights reserved.
+// Example10_SemanticKernel.cs
+
+using System.ComponentModel;
+using AutoGen.Core;
+using AutoGen.SemanticKernel.Extension;
+using FluentAssertions;
+using Microsoft.SemanticKernel;
+using Microsoft.SemanticKernel.ChatCompletion;
+using Microsoft.SemanticKernel.Connectors.OpenAI;
+namespace AutoGen.BasicSample;
+
+public class LightPlugin
+{
+ public bool IsOn { get; set; } = false;
+
+ [KernelFunction]
+ [Description("Gets the state of the light.")]
+ public string GetState() => this.IsOn ? "on" : "off";
+
+ [KernelFunction]
+ [Description("Changes the state of the light.'")]
+ public string ChangeState(bool newState)
+ {
+ this.IsOn = newState;
+ var state = this.GetState();
+
+ // Print the state to the console
+ Console.ForegroundColor = ConsoleColor.DarkBlue;
+ Console.WriteLine($"[Light is now {state}]");
+ Console.ResetColor();
+
+ return state;
+ }
+}
+
+public class Example10_SemanticKernel
+{
+ public static async Task RunAsync()
+ {
+ var openAIKey = Environment.GetEnvironmentVariable("OPENAI_API_KEY") ?? throw new Exception("Please set OPENAI_API_KEY environment variable.");
+ var modelId = "gpt-3.5-turbo";
+ var builder = Kernel.CreateBuilder()
+ .AddOpenAIChatCompletion(modelId: modelId, apiKey: openAIKey);
+ var kernel = builder.Build();
+ var settings = new OpenAIPromptExecutionSettings
+ {
+ ToolCallBehavior = ToolCallBehavior.AutoInvokeKernelFunctions,
+ };
+
+ kernel.Plugins.AddFromObject(new LightPlugin());
+ var skAgent = kernel
+ .ToSemanticKernelAgent(name: "assistant", systemMessage: "You control the light", settings);
+
+ // Send a message to the skAgent, the skAgent supports the following message types:
+ // - IMessage<ChatMessageContent>
+ // - (streaming) IMessage<StreamingChatMessageContent>
+ // You can create an IMessage<ChatMessageContent> using MessageEnvelope.Create(ChatMessageContent)
+ var chatMessageContent = MessageEnvelope.Create(new ChatMessageContent(AuthorRole.User, "Toggle the light"));
+ var reply = await skAgent.SendAsync(chatMessageContent);
+ reply.Should().BeOfType<MessageEnvelope<ChatMessageContent>>();
+ Console.WriteLine((reply as IMessage<ChatMessageContent>).Content.Items[0].As<TextContent>().Text);
+
+ var skAgentWithMiddleware = skAgent
+ .RegisterMessageConnector() // Register the message connector to support more AutoGen built-in message types
+ .RegisterPrintMessage();
+
+ // Now the skAgentWithMiddleware supports more IMessage types like TextMessage, ImageMessage or MultiModalMessage
+ // It also registers a print format message hook to print the message in a human-readable format to the console
+ await skAgent.SendAsync(chatMessageContent);
+ await skAgentWithMiddleware.SendAsync(new TextMessage(Role.User, "Toggle the light"));
+
+ // The more message types an agent supports, the more flexible it is to use in different scenarios
+ // For example, since TextMessage is supported, the skAgentWithMiddleware can be used with a user proxy.
+ var userProxy = new UserProxyAgent("user");
+
+ await skAgentWithMiddleware.InitiateChatAsync(userProxy, "how can I help you today");
+ }
+
+}
diff --git a/dotnet/sample/AutoGen.BasicSamples/Example11_Sequential_GroupChat_Example.cs b/dotnet/sample/AutoGen.BasicSamples/Example11_Sequential_GroupChat_Example.cs
new file mode 100644
index 00000000000..00ff321082a
--- /dev/null
+++ b/dotnet/sample/AutoGen.BasicSamples/Example11_Sequential_GroupChat_Example.cs
@@ -0,0 +1,94 @@
+ο»Ώ// Copyright (c) Microsoft Corporation. All rights reserved.
+// Example11_Sequential_GroupChat_Example.cs
+
+#region using_statement
+using AutoGen.Core;
+using AutoGen.OpenAI;
+using AutoGen.OpenAI.Extension;
+using AutoGen.SemanticKernel;
+using AutoGen.SemanticKernel.Extension;
+using Azure.AI.OpenAI;
+using Microsoft.SemanticKernel;
+using Microsoft.SemanticKernel.Plugins.Web;
+using Microsoft.SemanticKernel.Plugins.Web.Bing;
+#endregion using_statement
+
+namespace AutoGen.BasicSample;
+
+public partial class Sequential_GroupChat_Example
+{
+ public static async Task<IAgent> CreateBingSearchAgentAsync()
+ {
+ #region CreateBingSearchAgent
+ var config = LLMConfiguration.GetAzureOpenAIGPT3_5_Turbo();
+ var apiKey = config.ApiKey;
+ var kernelBuilder = Kernel.CreateBuilder()
+ .AddAzureOpenAIChatCompletion(config.DeploymentName, config.Endpoint, apiKey);
+ var bingApiKey = Environment.GetEnvironmentVariable("BING_API_KEY") ?? throw new Exception("BING_API_KEY environment variable is not set");
+ var bingSearch = new BingConnector(bingApiKey);
+ var webSearchPlugin = new WebSearchEnginePlugin(bingSearch);
+ kernelBuilder.Plugins.AddFromObject(webSearchPlugin);
+
+ var kernel = kernelBuilder.Build();
+ var kernelAgent = new SemanticKernelAgent(
+ kernel: kernel,
+ name: "bing-search",
+ systemMessage: """
+ You return search results from Bing as-is.
+ You put the original search result between ```bing and ```
+
+ e.g.
+ ```bing
+ xxx
+ ```
+ """)
+ .RegisterMessageConnector()
+ .RegisterPrintMessage(); // pretty print the message
+
+ return kernelAgent;
+ #endregion CreateBingSearchAgent
+ }
+
+ public static async Task<IAgent> CreateSummarizerAgentAsync()
+ {
+ #region CreateSummarizerAgent
+ var config = LLMConfiguration.GetAzureOpenAIGPT3_5_Turbo();
+ var apiKey = config.ApiKey;
+ var endPoint = new Uri(config.Endpoint);
+
+ var openAIClient = new OpenAIClient(endPoint, new Azure.AzureKeyCredential(apiKey));
+ var openAIClientAgent = new OpenAIChatAgent(
+ openAIClient: openAIClient,
+ name: "summarizer",
+ modelName: config.DeploymentName,
+ systemMessage: "You summarize search result from bing in a short and concise manner");
+
+ return openAIClientAgent
+ .RegisterMessageConnector()
+ .RegisterPrintMessage(); // pretty print the message
+ #endregion CreateSummarizerAgent
+ }
+
+ public static async Task RunAsync()
+ {
+ #region Sequential_GroupChat_Example
+ var userProxyAgent = new UserProxyAgent(
+ name: "user",
+ humanInputMode: HumanInputMode.ALWAYS)
+ .RegisterPrintMessage();
+
+ var bingSearchAgent = await CreateBingSearchAgentAsync();
+ var summarizerAgent = await CreateSummarizerAgentAsync();
+
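+ // RoundRobinGroupChat lets the agents speak in the order they are registered:
+ // user -> bing-search -> summarizer -> user -> ...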
+ var groupChat = new RoundRobinGroupChat(
+ agents: [userProxyAgent, bingSearchAgent, summarizerAgent]);
+
+ var groupChatAgent = new GroupChatManager(groupChat);
+
+ var history = await userProxyAgent.InitiateChatAsync(
+ receiver: groupChatAgent,
+ message: "How to deploy an openai resource on azure",
+ maxRound: 10);
+ #endregion Sequential_GroupChat_Example
+ }
+}
diff --git a/dotnet/sample/AutoGen.BasicSamples/Example12_TwoAgent_Fill_Application.cs b/dotnet/sample/AutoGen.BasicSamples/Example12_TwoAgent_Fill_Application.cs
new file mode 100644
index 00000000000..b622a3e641e
--- /dev/null
+++ b/dotnet/sample/AutoGen.BasicSamples/Example12_TwoAgent_Fill_Application.cs
@@ -0,0 +1,199 @@
+ο»Ώ// Copyright (c) Microsoft Corporation. All rights reserved.
+// Example12_TwoAgent_Fill_Application.cs
+
+using System.Text;
+using AutoGen.Core;
+using AutoGen.OpenAI;
+using AutoGen.OpenAI.Extension;
+using Azure.AI.OpenAI;
+
+namespace AutoGen.BasicSample;
+
+public partial class TwoAgent_Fill_Application
+{
+ private string? name = null;
+ private string? email = null;
+ private string? phone = null;
+ private string? address = null;
+ private bool? receiveUpdates = null;
+
+ [Function]
+ public async Task<string> SaveProgress(
+ string name,
+ string email,
+ string phone,
+ string address,
+ bool? receiveUpdates)
+ {
+ this.name = !string.IsNullOrEmpty(name) ? name : this.name;
+ this.email = !string.IsNullOrEmpty(email) ? email : this.email;
+ this.phone = !string.IsNullOrEmpty(phone) ? phone : this.phone;
+ this.address = !string.IsNullOrEmpty(address) ? address : this.address;
+ this.receiveUpdates = receiveUpdates ?? this.receiveUpdates;
+
+ var missingInformationStringBuilder = new StringBuilder();
+ if (string.IsNullOrEmpty(this.name))
+ {
+ missingInformationStringBuilder.AppendLine("Name is missing.");
+ }
+
+ if (string.IsNullOrEmpty(this.email))
+ {
+ missingInformationStringBuilder.AppendLine("Email is missing.");
+ }
+
+ if (string.IsNullOrEmpty(this.phone))
+ {
+ missingInformationStringBuilder.AppendLine("Phone is missing.");
+ }
+
+ if (string.IsNullOrEmpty(this.address))
+ {
+ missingInformationStringBuilder.AppendLine("Address is missing.");
+ }
+
+ if (this.receiveUpdates == null)
+ {
+ missingInformationStringBuilder.AppendLine("ReceiveUpdates is missing.");
+ }
+
+ if (missingInformationStringBuilder.Length > 0)
+ {
+ return missingInformationStringBuilder.ToString();
+ }
+ else
+ {
+ return "Application information is saved to database.";
+ }
+ }
+
+ public static async Task<IAgent> CreateSaveProgressAgent()
+ {
+ var gpt3Config = LLMConfiguration.GetAzureOpenAIGPT3_5_Turbo();
+ var endPoint = gpt3Config.Endpoint ?? throw new Exception("Please set AZURE_OPENAI_ENDPOINT environment variable.");
+ var apiKey = gpt3Config.ApiKey ?? throw new Exception("Please set AZURE_OPENAI_API_KEY environment variable.");
+ var openaiClient = new OpenAIClient(new Uri(endPoint), new Azure.AzureKeyCredential(apiKey));
+
+ var instance = new TwoAgent_Fill_Application();
+ var functionCallConnector = new FunctionCallMiddleware(
+ functions: [instance.SaveProgressFunctionContract],
+ functionMap: new Dictionary<string, Func<string, Task<string>>>
+ {
+ { instance.SaveProgressFunctionContract.Name, instance.SaveProgressWrapper },
+ });
+
+ var chatAgent = new OpenAIChatAgent(
+ openAIClient: openaiClient,
+ name: "application",
+ modelName: gpt3Config.DeploymentName,
+ systemMessage: """You are a helpful application form assistant who saves progress while user fills application.""")
+ .RegisterMessageConnector()
+ .RegisterMiddleware(functionCallConnector)
+ .RegisterMiddleware(async (msgs, option, agent, ct) =>
+ {
+ var lastUserMessage = msgs.Last() ?? throw new Exception("No user message found.");
+ var prompt = $"""
+ Save progress according to the most recent information provided by user.
+
+ ```user
+ {lastUserMessage.GetContent()}
+ ```
+ """;
+
+ return await agent.GenerateReplyAsync([lastUserMessage], option, ct);
+
+ });
+
+ return chatAgent;
+ }
+
+ public static async Task<IAgent> CreateAssistantAgent()
+ {
+ var gpt3Config = LLMConfiguration.GetAzureOpenAIGPT3_5_Turbo();
+ var endPoint = gpt3Config.Endpoint ?? throw new Exception("Please set AZURE_OPENAI_ENDPOINT environment variable.");
+ var apiKey = gpt3Config.ApiKey ?? throw new Exception("Please set AZURE_OPENAI_API_KEY environment variable.");
+ var openaiClient = new OpenAIClient(new Uri(endPoint), new Azure.AzureKeyCredential(apiKey));
+
+ var chatAgent = new OpenAIChatAgent(
+ openAIClient: openaiClient,
+ name: "assistant",
+ modelName: gpt3Config.DeploymentName,
+ systemMessage: """You create polite prompt to ask user provide missing information""")
+ .RegisterMessageConnector()
+ .RegisterPrintMessage()
+ .RegisterMiddleware(async (msgs, option, agent, ct) =>
+ {
+ var lastReply = msgs.Last() ?? throw new Exception("No reply found.");
+ var reply = await agent.GenerateReplyAsync(msgs, option, ct);
+
+ // if application is complete, exit conversation by sending termination message
+ if (lastReply.GetContent().Contains("Application information is saved to database."))
+ {
+ return new TextMessage(Role.Assistant, GroupChatExtension.TERMINATE, from: agent.Name);
+ }
+ else
+ {
+ return reply;
+ }
+ });
+
+ return chatAgent;
+ }
+
+ public static async Task<IAgent> CreateUserAgent()
+ {
+ var gpt3Config = LLMConfiguration.GetAzureOpenAIGPT3_5_Turbo();
+ var endPoint = gpt3Config.Endpoint ?? throw new Exception("Please set AZURE_OPENAI_ENDPOINT environment variable.");
+ var apiKey = gpt3Config.ApiKey ?? throw new Exception("Please set AZURE_OPENAI_API_KEY environment variable.");
+ var openaiClient = new OpenAIClient(new Uri(endPoint), new Azure.AzureKeyCredential(apiKey));
+
+ var chatAgent = new OpenAIChatAgent(
+ openAIClient: openaiClient,
+ name: "user",
+ modelName: gpt3Config.DeploymentName,
+ systemMessage: """
+ You are a user who is filling an application form. Simply provide the information as requested and answer the questions, don't do anything else.
+
+ here's some personal information about you:
+ - name: John Doe
+ - email: 1234567@gmail.com
+ - phone: 123-456-7890
+ - address: 1234 Main St, Redmond, WA 98052
+ - want to receive update? true
+ """)
+ .RegisterMessageConnector()
+ .RegisterPrintMessage();
+
+ return chatAgent;
+ }
+
+ public static async Task RunAsync()
+ {
+ var applicationAgent = await CreateSaveProgressAgent();
+ var assistantAgent = await CreateAssistantAgent();
+ var userAgent = await CreateUserAgent();
+
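+ // the conversation cycles user -> application (saves progress) -> assistant (asks for missing info) -> user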
+ var userToApplicationTransition = Transition.Create(userAgent, applicationAgent);
+ var applicationToAssistantTransition = Transition.Create(applicationAgent, assistantAgent);
+ var assistantToUserTransition = Transition.Create(assistantAgent, userAgent);
+
+ var workflow = new Graph(
+ [
+ userToApplicationTransition,
+ applicationToAssistantTransition,
+ assistantToUserTransition,
+ ]);
+
+ var groupChat = new GroupChat(
+ members: [userAgent, applicationAgent, assistantAgent],
+ workflow: workflow);
+
+ var groupChatManager = new GroupChatManager(groupChat);
+ var initialMessage = await assistantAgent.SendAsync("Generate a greeting meesage for user and start the conversation by asking what's their name.");
+
+ var chatHistory = await userAgent.SendAsync(groupChatManager, [initialMessage], maxRound: 30);
+
+ var lastMessage = chatHistory.Last();
+ Console.WriteLine(lastMessage.GetContent());
+ }
+}
diff --git a/dotnet/sample/AutoGen.BasicSamples/Example13_OpenAIAgent_JsonMode.cs b/dotnet/sample/AutoGen.BasicSamples/Example13_OpenAIAgent_JsonMode.cs
new file mode 100644
index 00000000000..35b7b7d1d2f
--- /dev/null
+++ b/dotnet/sample/AutoGen.BasicSamples/Example13_OpenAIAgent_JsonMode.cs
@@ -0,0 +1,68 @@
+ο»Ώ// Copyright (c) Microsoft Corporation. All rights reserved.
+// Example13_OpenAIAgent_JsonMode.cs
+
+using System.Text.Json;
+using System.Text.Json.Serialization;
+using AutoGen.Core;
+using AutoGen.OpenAI;
+using AutoGen.OpenAI.Extension;
+using Azure.AI.OpenAI;
+using FluentAssertions;
+
+namespace AutoGen.BasicSample;
+
+public class Example13_OpenAIAgent_JsonMode
+{
+ public static async Task RunAsync()
+ {
+ #region create_agent
+ var config = LLMConfiguration.GetAzureOpenAIGPT3_5_Turbo(deployName: "gpt-35-turbo-0125"); // json mode only works with 0125 and later model.
+ var apiKey = config.ApiKey;
+ var endPoint = new Uri(config.Endpoint);
+
+ var openAIClient = new OpenAIClient(endPoint, new Azure.AzureKeyCredential(apiKey));
+ var openAIClientAgent = new OpenAIChatAgent(
+ openAIClient: openAIClient,
+ name: "assistant",
+ modelName: config.DeploymentName,
+ systemMessage: "You are a helpful assistant designed to output JSON.",
+ seed: 0, // explicitly set a seed to enable deterministic output
+ responseFormat: ChatCompletionsResponseFormat.JsonObject) // set response format to JSON object to enable JSON mode
+ .RegisterMessageConnector()
+ .RegisterPrintMessage();
+ #endregion create_agent
+
+ #region chat_with_agent
+ var reply = await openAIClientAgent.SendAsync("My name is John, I am 25 years old, and I live in Seattle.");
+
+ var person = JsonSerializer.Deserialize<Person>(reply.GetContent());
+ Console.WriteLine($"Name: {person.Name}");
+ Console.WriteLine($"Age: {person.Age}");
+
+ if (!string.IsNullOrEmpty(person.Address))
+ {
+ Console.WriteLine($"Address: {person.Address}");
+ }
+
+ Console.WriteLine("Done.");
+ #endregion chat_with_agent
+
+ person.Name.Should().Be("John");
+ person.Age.Should().Be(25);
+ person.Address.Should().BeNullOrEmpty();
+ }
+}
+
+#region person_class
+public class Person
+{
+ [JsonPropertyName("name")]
+ public string Name { get; set; }
+
+ [JsonPropertyName("age")]
+ public int Age { get; set; }
+
+ [JsonPropertyName("address")]
+ public string Address { get; set; }
+}
+#endregion person_class
diff --git a/dotnet/sample/AutoGen.BasicSamples/Example14_MistralClientAgent_TokenCount.cs b/dotnet/sample/AutoGen.BasicSamples/Example14_MistralClientAgent_TokenCount.cs
new file mode 100644
index 00000000000..4c8794de961
--- /dev/null
+++ b/dotnet/sample/AutoGen.BasicSamples/Example14_MistralClientAgent_TokenCount.cs
@@ -0,0 +1,65 @@
+ο»Ώ// Copyright (c) Microsoft Corporation. All rights reserved.
+// Example14_MistralClientAgent_TokenCount.cs
+
+#region using_statements
+using AutoGen.Core;
+using AutoGen.Mistral;
+#endregion using_statements
+using FluentAssertions;
+
+namespace AutoGen.BasicSample;
+
+public class Example14_MistralClientAgent_TokenCount
+{
+ #region token_counter_middleware
+ public class MistralAITokenCounterMiddleware : IMiddleware
+ {
+ private readonly List<ChatCompletionResponse> responses = new List<ChatCompletionResponse>();
+ public string? Name => nameof(MistralAITokenCounterMiddleware);
+
+ public async Task<IMessage> InvokeAsync(MiddlewareContext context, IAgent agent, CancellationToken cancellationToken = default)
+ {
+ var reply = await agent.GenerateReplyAsync(context.Messages, context.Options, cancellationToken);
+
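+ // capture the raw Mistral chat completion carried by the reply so completion token usage can be summed later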
+ if (reply is IMessage<ChatCompletionResponse> message)
+ {
+ responses.Add(message.Content);
+ }
+
+ return reply;
+ }
+
+ public int GetCompletionTokenCount()
+ {
+ return responses.Sum(r => r.Usage.CompletionTokens);
+ }
+ }
+ #endregion token_counter_middleware
+
+ public static async Task RunAsync()
+ {
+ #region create_mistral_client_agent
+ var apiKey = Environment.GetEnvironmentVariable("MISTRAL_API_KEY") ?? throw new Exception("Missing MISTRAL_API_KEY environment variable.");
+ var mistralClient = new MistralClient(apiKey);
+ var agent = new MistralClientAgent(
+ client: mistralClient,
+ name: "assistant",
+ model: MistralAIModelID.OPEN_MISTRAL_7B);
+ #endregion create_mistral_client_agent
+
+ #region register_middleware
+ var tokenCounterMiddleware = new MistralAITokenCounterMiddleware();
+ var mistralMessageConnector = new MistralChatMessageConnector();
+ var agentWithTokenCounter = agent
+ .RegisterMiddleware(tokenCounterMiddleware)
+ .RegisterMiddleware(mistralMessageConnector)
+ .RegisterPrintMessage();
+ #endregion register_middleware
+
+ #region chat_with_agent
+ await agentWithTokenCounter.SendAsync("write a long, tedious story");
+ Console.WriteLine($"Completion token count: {tokenCounterMiddleware.GetCompletionTokenCount()}");
+ tokenCounterMiddleware.GetCompletionTokenCount().Should().BeGreaterThan(0);
+ #endregion chat_with_agent
+ }
+}
diff --git a/dotnet/sample/AutoGen.BasicSamples/Example15_GPT4V_BinaryDataImageMessage.cs b/dotnet/sample/AutoGen.BasicSamples/Example15_GPT4V_BinaryDataImageMessage.cs
new file mode 100644
index 00000000000..f376342ed85
--- /dev/null
+++ b/dotnet/sample/AutoGen.BasicSamples/Example15_GPT4V_BinaryDataImageMessage.cs
@@ -0,0 +1,62 @@
+ο»Ώ// Copyright (c) Microsoft Corporation. All rights reserved.
+// Example15_GPT4V_BinaryDataImageMessage.cs
+
+using AutoGen.Core;
+using AutoGen.OpenAI;
+
+namespace AutoGen.BasicSample;
+
+/// <summary>
+/// This example shows usage of ImageMessage. The image is loaded as BinaryData and sent to GPT-4V
+/// </summary>
+/// <remarks>
+/// Add additional images to the ImageResources to load and send more images to GPT-4V
+/// </remarks>
+public static class Example15_GPT4V_BinaryDataImageMessage
+{
+ private static readonly string ImageResourcePath = Path.Combine(AppDomain.CurrentDomain.BaseDirectory, "ImageResources");
+
+ private static Dictionary<string, string> _mediaTypeMappings = new()
+ {
+ { ".png", "image/png" },
+ { ".jpeg", "image/jpeg" },
+ { ".jpg", "image/jpeg" },
+ { ".gif", "image/gif" },
+ { ".webp", "image/webp" }
+ };
+
+ public static async Task RunAsync()
+ {
+ var openAIKey = Environment.GetEnvironmentVariable("OPENAI_API_KEY") ?? throw new Exception("Please set OPENAI_API_KEY environment variable.");
+ var openAiConfig = new OpenAIConfig(openAIKey, "gpt-4-vision-preview");
+
+ var visionAgent = new GPTAgent(
+ name: "gpt",
+ systemMessage: "You are a helpful AI assistant",
+ config: openAiConfig,
+ temperature: 0);
+
+ List<IMessage> messages =
+ [new TextMessage(Role.User, "What is this image?", from: "user")];
+ AddMessagesFromResource(ImageResourcePath, messages);
+
+ var multiModalMessage = new MultiModalMessage(Role.User, messages, from: "user");
+ var response = await visionAgent.SendAsync(multiModalMessage);
+ }
+
+ private static void AddMessagesFromResource(string imageResourcePath, List<IMessage> messages)
+ {
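+ // load every file with a known image extension as BinaryData and wrap it in an ImageMessage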
+ foreach (string file in Directory.GetFiles(imageResourcePath))
+ {
+ if (!_mediaTypeMappings.TryGetValue(Path.GetExtension(file).ToLowerInvariant(), out var mediaType))
+ continue;
+
+ using var fs = new FileStream(file, FileMode.Open, FileAccess.Read);
+ var ms = new MemoryStream();
+ fs.CopyTo(ms);
+ ms.Seek(0, SeekOrigin.Begin);
+ var imageData = BinaryData.FromStream(ms, mediaType);
+ messages.Add(new ImageMessage(Role.Assistant, imageData, from: "user"));
+ }
+ }
+}
diff --git a/dotnet/sample/AutoGen.BasicSamples/Example16_OpenAIChatAgent_ConnectToThirdPartyBackend.cs b/dotnet/sample/AutoGen.BasicSamples/Example16_OpenAIChatAgent_ConnectToThirdPartyBackend.cs
new file mode 100644
index 00000000000..eb8bcb179be
--- /dev/null
+++ b/dotnet/sample/AutoGen.BasicSamples/Example16_OpenAIChatAgent_ConnectToThirdPartyBackend.cs
@@ -0,0 +1,62 @@
+ο»Ώ// Copyright (c) Microsoft Corporation. All rights reserved.
+// Example16_OpenAIChatAgent_ConnectToThirdPartyBackend.cs
+#region using_statement
+using AutoGen.Core;
+using AutoGen.OpenAI;
+using AutoGen.OpenAI.Extension;
+using Azure.AI.OpenAI;
+using Azure.Core.Pipeline;
+#endregion using_statement
+
+namespace AutoGen.BasicSample;
+
+#region CustomHttpClientHandler
+public sealed class CustomHttpClientHandler : HttpClientHandler
+{
+ private string _modelServiceUrl;
+
+ public CustomHttpClientHandler(string modelServiceUrl)
+ {
+ _modelServiceUrl = modelServiceUrl;
+ }
+
+ protected override Task<HttpResponseMessage> SendAsync(HttpRequestMessage request, CancellationToken cancellationToken)
+ {
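+ // redirect every request to the local model server (e.g. an Ollama endpoint) while keeping the original path and query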
+ request.RequestUri = new Uri($"{_modelServiceUrl}{request.RequestUri.PathAndQuery}");
+
+ return base.SendAsync(request, cancellationToken);
+ }
+}
+#endregion CustomHttpClientHandler
+
+public class Example16_OpenAIChatAgent_ConnectToThirdPartyBackend
+{
+ public static async Task RunAsync()
+ {
+ #region create_agent
+ using var client = new HttpClient(new CustomHttpClientHandler("http://localhost:11434"));
+ var option = new OpenAIClientOptions(OpenAIClientOptions.ServiceVersion.V2024_04_01_Preview)
+ {
+ Transport = new HttpClientTransport(client),
+ };
+
+ // api-key is not required for local server
+ // so you can use any string here
+ var openAIClient = new OpenAIClient("api-key", option);
+ var model = "llama3";
+
+ var agent = new OpenAIChatAgent(
+ openAIClient: openAIClient,
+ name: "assistant",
+ modelName: model,
+ systemMessage: "You are a helpful assistant designed to output JSON.",
+ seed: 0)
+ .RegisterMessageConnector()
+ .RegisterPrintMessage();
+ #endregion create_agent
+
+ #region send_message
+ await agent.SendAsync("Can you write a piece of C# code to calculate 100th of fibonacci?");
+ #endregion send_message
+ }
+}
diff --git a/dotnet/sample/AutoGen.BasicSamples/GlobalUsing.cs b/dotnet/sample/AutoGen.BasicSamples/GlobalUsing.cs
new file mode 100644
index 00000000000..87b4ee0ab4c
--- /dev/null
+++ b/dotnet/sample/AutoGen.BasicSamples/GlobalUsing.cs
@@ -0,0 +1,3 @@
+ο»Ώ// Copyright (c) Microsoft Corporation. All rights reserved.
+// GlobalUsing.cs
+
diff --git a/dotnet/sample/AutoGen.BasicSamples/ImageResources/square.png b/dotnet/sample/AutoGen.BasicSamples/ImageResources/square.png
new file mode 100644
index 00000000000..afb4f4cd4df
--- /dev/null
+++ b/dotnet/sample/AutoGen.BasicSamples/ImageResources/square.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:8323d0b8eceb752e14c29543b2e28bb2fc648ed9719095c31b7708867a4dc918
+size 491
diff --git a/dotnet/sample/AutoGen.BasicSamples/LLMConfiguration.cs b/dotnet/sample/AutoGen.BasicSamples/LLMConfiguration.cs
new file mode 100644
index 00000000000..37c9b0d7ade
--- /dev/null
+++ b/dotnet/sample/AutoGen.BasicSamples/LLMConfiguration.cs
@@ -0,0 +1,40 @@
+ο»Ώ// Copyright (c) Microsoft Corporation. All rights reserved.
+// LLMConfiguration.cs
+
+using AutoGen.OpenAI;
+
+namespace AutoGen.BasicSample;
+
+internal static class LLMConfiguration
+{
+ public static OpenAIConfig GetOpenAIGPT3_5_Turbo()
+ {
+ var openAIKey = Environment.GetEnvironmentVariable("OPENAI_API_KEY") ?? throw new Exception("Please set OPENAI_API_KEY environment variable.");
+ var modelId = "gpt-3.5-turbo";
+ return new OpenAIConfig(openAIKey, modelId);
+ }
+
+ public static OpenAIConfig GetOpenAIGPT4()
+ {
+ var openAIKey = Environment.GetEnvironmentVariable("OPENAI_API_KEY") ?? throw new Exception("Please set OPENAI_API_KEY environment variable.");
+ var modelId = "gpt-4";
+
+ return new OpenAIConfig(openAIKey, modelId);
+ }
+
+ public static AzureOpenAIConfig GetAzureOpenAIGPT3_5_Turbo(string deployName = "gpt-35-turbo-16k")
+ {
+ var azureOpenAIKey = Environment.GetEnvironmentVariable("AZURE_OPENAI_API_KEY") ?? throw new Exception("Please set AZURE_OPENAI_API_KEY environment variable.");
+ var endpoint = Environment.GetEnvironmentVariable("AZURE_OPENAI_ENDPOINT") ?? throw new Exception("Please set AZURE_OPENAI_ENDPOINT environment variable.");
+
+ return new AzureOpenAIConfig(endpoint, deployName, azureOpenAIKey);
+ }
+
+ public static AzureOpenAIConfig GetAzureOpenAIGPT4(string deployName = "gpt-4")
+ {
+ var azureOpenAIKey = Environment.GetEnvironmentVariable("AZURE_OPENAI_API_KEY") ?? throw new Exception("Please set AZURE_OPENAI_API_KEY environment variable.");
+ var endpoint = Environment.GetEnvironmentVariable("AZURE_OPENAI_ENDPOINT") ?? throw new Exception("Please set AZURE_OPENAI_ENDPOINT environment variable.");
+
+ return new AzureOpenAIConfig(endpoint, deployName, azureOpenAIKey);
+ }
+}
diff --git a/dotnet/sample/AutoGen.BasicSamples/Program.cs b/dotnet/sample/AutoGen.BasicSamples/Program.cs
new file mode 100644
index 00000000000..11b5127ade0
--- /dev/null
+++ b/dotnet/sample/AutoGen.BasicSamples/Program.cs
@@ -0,0 +1,6 @@
+ο»Ώ// Copyright (c) Microsoft Corporation. All rights reserved.
+// Program.cs
+
+using AutoGen.BasicSample;
+Console.ReadLine();
+await Example16_OpenAIChatAgent_ConnectToThirdPartyBackend.RunAsync();
diff --git a/dotnet/src/AutoGen.Core/Agent/DefaultReplyAgent.cs b/dotnet/src/AutoGen.Core/Agent/DefaultReplyAgent.cs
new file mode 100644
index 00000000000..647a2ece79d
--- /dev/null
+++ b/dotnet/src/AutoGen.Core/Agent/DefaultReplyAgent.cs
@@ -0,0 +1,31 @@
+ο»Ώ// Copyright (c) Microsoft Corporation. All rights reserved.
+// DefaultReplyAgent.cs
+
+using System.Collections.Generic;
+using System.Threading;
+using System.Threading.Tasks;
+
+namespace AutoGen.Core;
+
+public class DefaultReplyAgent : IAgent
+{
+ public DefaultReplyAgent(
+ string name,
+ string? defaultReply)
+ {
+ Name = name;
+ DefaultReply = defaultReply ?? string.Empty;
+ }
+
+ public string Name { get; }
+
+ public string DefaultReply { get; } = string.Empty;
+
+ public async Task<IMessage> GenerateReplyAsync(
+ IEnumerable<IMessage> _,
+ GenerateReplyOptions? __ = null,
+ CancellationToken ___ = default)
+ {
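+ // always return the configured default reply, regardless of the incoming messages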
+ return new TextMessage(Role.Assistant, DefaultReply, from: this.Name);
+ }
+}
diff --git a/dotnet/src/AutoGen.Core/Agent/GroupChatManager.cs b/dotnet/src/AutoGen.Core/Agent/GroupChatManager.cs
new file mode 100644
index 00000000000..db40f801dea
--- /dev/null
+++ b/dotnet/src/AutoGen.Core/Agent/GroupChatManager.cs
@@ -0,0 +1,34 @@
+ο»Ώ// Copyright (c) Microsoft Corporation. All rights reserved.
+// GroupChatManager.cs
+
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using System.Threading;
+using System.Threading.Tasks;
+
+namespace AutoGen.Core;
+
+public class GroupChatManager : IAgent
+{
+ public GroupChatManager(IGroupChat groupChat)
+ {
+ GroupChat = groupChat;
+ }
+ public string Name => throw new ArgumentException("GroupChatManager does not have a name");
+
+ public IEnumerable<IMessage>? Messages { get; private set; }
+
+ public IGroupChat GroupChat { get; }
+
+ public async Task<IMessage> GenerateReplyAsync(
+ IEnumerable<IMessage> messages,
+ GenerateReplyOptions? options,
+ CancellationToken cancellationToken = default)
+ {
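+ // run the group chat over the incoming messages, cache the full transcript, and return the last message as the reply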
+ var response = await GroupChat.CallAsync(messages, ct: cancellationToken);
+ Messages = response;
+
+ return response.Last();
+ }
+}
diff --git a/dotnet/src/AutoGen.Core/Agent/IAgent.cs b/dotnet/src/AutoGen.Core/Agent/IAgent.cs
new file mode 100644
index 00000000000..b9149008480
--- /dev/null
+++ b/dotnet/src/AutoGen.Core/Agent/IAgent.cs
@@ -0,0 +1,50 @@
+ο»Ώ// Copyright (c) Microsoft Corporation. All rights reserved.
+// IAgent.cs
+
+using System.Collections.Generic;
+using System.Linq;
+using System.Threading;
+using System.Threading.Tasks;
+
+namespace AutoGen.Core;
+public interface IAgent
+{
+ public string Name { get; }
+
+ /// <summary>
+ /// Generate reply
+ /// </summary>
+ /// <param name="messages">conversation history</param>
+ /// <param name="options">completion option. If provided, it should override existing option if there's any</param>
+ public Task<IMessage> GenerateReplyAsync(
+ IEnumerable<IMessage> messages,
+ GenerateReplyOptions? options = null,
+ CancellationToken cancellationToken = default);
+}
+
+public class GenerateReplyOptions
+{
+ public GenerateReplyOptions()
+ {
+ }
+
+ /// <summary>
+ /// Copy constructor
+ /// </summary>
+ /// <param name="other">other option to copy from</param>
+ public GenerateReplyOptions(GenerateReplyOptions other)
+ {
+ this.Temperature = other.Temperature;
+ this.MaxToken = other.MaxToken;
+ this.StopSequence = other.StopSequence?.Select(s => s)?.ToArray();
+ this.Functions = other.Functions?.Select(f => f)?.ToArray();
+ }
+
+ public float? Temperature { get; set; }
+
+ public int? MaxToken { get; set; }
+
+ public string[]? StopSequence { get; set; }
+
+ public FunctionContract[]? Functions { get; set; }
+}
diff --git a/dotnet/src/AutoGen.Core/Agent/IMiddlewareAgent.cs b/dotnet/src/AutoGen.Core/Agent/IMiddlewareAgent.cs
new file mode 100644
index 00000000000..a0b01e7c3e2
--- /dev/null
+++ b/dotnet/src/AutoGen.Core/Agent/IMiddlewareAgent.cs
@@ -0,0 +1,54 @@
+ο»Ώ// Copyright (c) Microsoft Corporation. All rights reserved.
+// IMiddlewareAgent.cs
+
+using System.Collections.Generic;
+
+namespace AutoGen.Core;
+
+public interface IMiddlewareAgent : IAgent
+{
+ /// <summary>
+ /// Get the inner agent.
+ /// </summary>
+ IAgent Agent { get; }
+
+ /// <summary>
+ /// Get the middlewares.
+ /// </summary>
+ IEnumerable<IMiddleware> Middlewares { get; }
+
+ /// <summary>
+ /// Use middleware.
+ /// </summary>
+ void Use(IMiddleware middleware);
+}
+
+public interface IMiddlewareStreamAgent : IStreamingAgent
+{
+ ///
+ /// Get the inner agent.
+ ///
+ IStreamingAgent StreamingAgent { get; }
+
+ IEnumerable<IStreamingMiddleware> StreamingMiddlewares { get; }
+
+ void UseStreaming(IStreamingMiddleware middleware);
+}
+
+public interface IMiddlewareAgent<T> : IMiddlewareAgent
+ where T : IAgent
+{
+ ///
+ /// Get the typed inner agent.
+ ///
+ T TAgent { get; }
+}
+
+public interface IMiddlewareStreamAgent<T> : IMiddlewareStreamAgent
+ where T : IStreamingAgent
+{
+ ///
+ /// Get the typed inner agent.
+ ///
+ T TStreamingAgent { get; }
+}
diff --git a/dotnet/src/AutoGen.Core/Agent/IStreamingAgent.cs b/dotnet/src/AutoGen.Core/Agent/IStreamingAgent.cs
new file mode 100644
index 00000000000..665f18bac12
--- /dev/null
+++ b/dotnet/src/AutoGen.Core/Agent/IStreamingAgent.cs
@@ -0,0 +1,18 @@
+ο»Ώ// Copyright (c) Microsoft Corporation. All rights reserved.
+// IStreamingAgent.cs
+
+using System.Collections.Generic;
+using System.Threading;
+
+namespace AutoGen.Core;
+
+///
+/// agent that supports streaming reply
+///
+public interface IStreamingAgent : IAgent
+{
+ public IAsyncEnumerable GenerateStreamingReplyAsync(
+ IEnumerable messages,
+ GenerateReplyOptions? options = null,
+ CancellationToken cancellationToken = default);
+}
diff --git a/dotnet/src/AutoGen.Core/Agent/MiddlewareAgent.cs b/dotnet/src/AutoGen.Core/Agent/MiddlewareAgent.cs
new file mode 100644
index 00000000000..84d0d4b59e6
--- /dev/null
+++ b/dotnet/src/AutoGen.Core/Agent/MiddlewareAgent.cs
@@ -0,0 +1,140 @@
+ο»Ώ// Copyright (c) Microsoft Corporation. All rights reserved.
+// MiddlewareAgent.cs
+
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using System.Threading;
+using System.Threading.Tasks;
+
+namespace AutoGen.Core;
+
+///
+/// An agent that allows you to add middleware and modify the behavior of an existing agent.
+///
+public class MiddlewareAgent : IMiddlewareAgent
+{
+ private IAgent _agent;
+ private readonly List<IMiddleware> middlewares = new();
+
+ ///
+ /// Create a new instance of MiddlewareAgent.
+ ///
+ /// the inner agent where middleware will be added.
+ /// the name of the agent if provided. Otherwise, the name of the inner agent will be used.
+ public MiddlewareAgent(IAgent innerAgent, string? name = null, IEnumerable<IMiddleware>? middlewares = null)
+ {
+ this.Name = name ?? innerAgent.Name;
+ this._agent = innerAgent;
+ if (middlewares != null && middlewares.Any())
+ {
+ foreach (var middleware in middlewares)
+ {
+ this.Use(middleware);
+ }
+ }
+ }
+
+ ///
+ /// Create a new instance of MiddlewareAgent by copying the middlewares from another MiddlewareAgent.
+ ///
+ public MiddlewareAgent(MiddlewareAgent other)
+ {
+ this.Name = other.Name;
+ this._agent = other._agent;
+ this.middlewares.AddRange(other.middlewares);
+ }
+
+ public string Name { get; }
+
+ ///
+ /// Get the inner agent.
+ ///
+ public IAgent Agent => this._agent;
+
+ ///
+ /// Get the middlewares.
+ ///
+ public IEnumerable<IMiddleware> Middlewares => this.middlewares;
+
+ public Task<IMessage> GenerateReplyAsync(
+ IEnumerable<IMessage> messages,
+ GenerateReplyOptions? options = null,
+ CancellationToken cancellationToken = default)
+ {
+ return _agent.GenerateReplyAsync(messages, options, cancellationToken);
+ }
+
+ ///
+ /// Add a middleware to the agent. If multiple middlewares are added, they will be executed in the LIFO order.
+ /// Call into the next function to continue the execution of the next middleware.
+ /// Short cut middleware execution by not calling into the next function.
+ ///
+ public void Use(Func<IEnumerable<IMessage>, GenerateReplyOptions?, IAgent, CancellationToken, Task<IMessage>> func, string? middlewareName = null)
+ {
+ var middleware = new DelegateMiddleware(middlewareName, async (context, agent, cancellationToken) =>
+ {
+ return await func(context.Messages, context.Options, agent, cancellationToken);
+ });
+
+ this.Use(middleware);
+ }
+
+ public void Use(IMiddleware middleware)
+ {
+ this.middlewares.Add(middleware);
+ _agent = new DelegateAgent(middleware, _agent);
+ }
+
+ public override string ToString()
+ {
+ var names = this.Middlewares.Select(m => m.Name ?? "[Unknown middleware]");
+ var namesPlusAgentName = names.Append(this.Name);
+
+ return namesPlusAgentName.Aggregate((a, b) => $"{a} -> {b}");
+ }
+
+ private class DelegateAgent : IAgent
+ {
+ private readonly IAgent innerAgent;
+ private readonly IMiddleware middleware;
+
+ public DelegateAgent(IMiddleware middleware, IAgent innerAgent)
+ {
+ this.middleware = middleware;
+ this.innerAgent = innerAgent;
+ }
+
+ public string Name { get => this.innerAgent.Name; }
+
+ public Task<IMessage> GenerateReplyAsync(
+ IEnumerable<IMessage> messages,
+ GenerateReplyOptions? options = null,
+ CancellationToken cancellationToken = default)
+ {
+ var context = new MiddlewareContext(messages, options);
+ return this.middleware.InvokeAsync(context, this.innerAgent, cancellationToken);
+ }
+ }
+}
+
+public sealed class MiddlewareAgent<T> : MiddlewareAgent, IMiddlewareAgent<T>
+ where T : IAgent
+{
+ public MiddlewareAgent(T innerAgent, string? name = null)
+ : base(innerAgent, name)
+ {
+ this.TAgent = innerAgent;
+ }
+
+ public MiddlewareAgent(MiddlewareAgent<T> other)
+ : base(other)
+ {
+ this.TAgent = other.TAgent;
+ }
+
+ ///
+ /// Get the inner agent of type .
+ ///
+ public T TAgent { get; }
+}
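Note (illustrative, not part of the diff): a usage sketch of the delegate-based Use overload above. The middleware runs before the inner agent and short-circuits when the last message contains "ping"; DefaultReplyAgent, the middleware name, and SendAsync (from AgentExtension later in this diff) are the only assumed names.

```csharp
using System.Linq;
using AutoGen.Core;

var inner = new DefaultReplyAgent("assistant", "default reply");
var agent = new MiddlewareAgent(innerAgent: inner);

// Middlewares run in LIFO order; returning without calling `next` skips the inner agent.
agent.Use(async (messages, options, next, ct) =>
{
    if (messages.LastOrDefault()?.GetContent()?.Contains("ping") == true)
    {
        return new TextMessage(Role.Assistant, "pong", from: agent.Name);
    }

    return await next.GenerateReplyAsync(messages, options, ct);
}, middlewareName: "ping-pong");

var reply = await agent.SendAsync("ping");
```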
diff --git a/dotnet/src/AutoGen.Core/Agent/MiddlewareStreamingAgent.cs b/dotnet/src/AutoGen.Core/Agent/MiddlewareStreamingAgent.cs
new file mode 100644
index 00000000000..251d3c110f9
--- /dev/null
+++ b/dotnet/src/AutoGen.Core/Agent/MiddlewareStreamingAgent.cs
@@ -0,0 +1,119 @@
+ο»Ώ// Copyright (c) Microsoft Corporation. All rights reserved.
+// MiddlewareStreamingAgent.cs
+
+using System.Collections.Generic;
+using System.Linq;
+using System.Threading;
+using System.Threading.Tasks;
+
+namespace AutoGen.Core;
+
+public class MiddlewareStreamingAgent : IMiddlewareStreamAgent
+{
+ private IStreamingAgent _agent;
+ private readonly List<IStreamingMiddleware> _streamingMiddlewares = new();
+
+ public MiddlewareStreamingAgent(
+ IStreamingAgent agent,
+ string? name = null,
+ IEnumerable<IStreamingMiddleware>? streamingMiddlewares = null)
+ {
+ this.Name = name ?? agent.Name;
+ _agent = agent;
+
+ if (streamingMiddlewares != null && streamingMiddlewares.Any())
+ {
+ foreach (var middleware in streamingMiddlewares)
+ {
+ this.UseStreaming(middleware);
+ }
+ }
+ }
+
+ ///
+ /// Get the inner agent.
+ ///
+ public IStreamingAgent StreamingAgent => _agent;
+
+ ///
+ /// Get the streaming middlewares.
+ ///
+ public IEnumerable<IStreamingMiddleware> StreamingMiddlewares => _streamingMiddlewares;
+
+ public string Name { get; }
+
+ public Task<IMessage> GenerateReplyAsync(IEnumerable<IMessage> messages, GenerateReplyOptions? options = null, CancellationToken cancellationToken = default)
+ {
+ return _agent.GenerateReplyAsync(messages, options, cancellationToken);
+ }
+
+ public IAsyncEnumerable GenerateStreamingReplyAsync(IEnumerable messages, GenerateReplyOptions? options = null, CancellationToken cancellationToken = default)
+ {
+
+ return _agent.GenerateStreamingReplyAsync(messages, options, cancellationToken);
+ }
+
+ public void UseStreaming(IStreamingMiddleware middleware)
+ {
+ _streamingMiddlewares.Add(middleware);
+ _agent = new DelegateStreamingAgent(middleware, _agent);
+ }
+
+ private class DelegateStreamingAgent : IStreamingAgent
+ {
+ private IStreamingMiddleware? streamingMiddleware;
+ private IStreamingAgent innerAgent;
+
+ public string Name => innerAgent.Name;
+
+ public DelegateStreamingAgent(IStreamingMiddleware middleware, IStreamingAgent next)
+ {
+ this.streamingMiddleware = middleware;
+ this.innerAgent = next;
+ }
+
+
+ public Task<IMessage> GenerateReplyAsync(IEnumerable<IMessage> messages, GenerateReplyOptions? options = null, CancellationToken cancellationToken = default)
+ {
+ if (this.streamingMiddleware is null)
+ {
+ return innerAgent.GenerateReplyAsync(messages, options, cancellationToken);
+ }
+
+ var context = new MiddlewareContext(messages, options);
+ return this.streamingMiddleware.InvokeAsync(context, (IAgent)innerAgent, cancellationToken);
+ }
+
+ public IAsyncEnumerable GenerateStreamingReplyAsync(IEnumerable messages, GenerateReplyOptions? options = null, CancellationToken cancellationToken = default)
+ {
+ if (streamingMiddleware is null)
+ {
+ return innerAgent.GenerateStreamingReplyAsync(messages, options, cancellationToken);
+ }
+
+ var context = new MiddlewareContext(messages, options);
+ return streamingMiddleware.InvokeAsync(context, innerAgent, cancellationToken);
+ }
+ }
+}
+
+public sealed class MiddlewareStreamingAgent<T> : MiddlewareStreamingAgent, IMiddlewareStreamAgent<T>
+ where T : IStreamingAgent
+{
+ public MiddlewareStreamingAgent(T innerAgent, string? name = null, IEnumerable<IStreamingMiddleware>? streamingMiddlewares = null)
+ : base(innerAgent, name, streamingMiddlewares)
+ {
+ TStreamingAgent = innerAgent;
+ }
+
+ public MiddlewareStreamingAgent(MiddlewareStreamingAgent<T> other)
+ : base(other)
+ {
+ TStreamingAgent = other.TStreamingAgent;
+ }
+
+ ///
+ /// Get the inner agent.
+ ///
+ public T TStreamingAgent { get; }
+}
diff --git a/dotnet/src/AutoGen.Core/AutoGen.Core.csproj b/dotnet/src/AutoGen.Core/AutoGen.Core.csproj
new file mode 100644
index 00000000000..ebbec3f0a46
--- /dev/null
+++ b/dotnet/src/AutoGen.Core/AutoGen.Core.csproj
@@ -0,0 +1,22 @@
+ο»Ώ
+
+ netstandard2.0
+ AutoGen.Core
+
+
+
+
+
+
+ AutoGen.Core
+
+ Core library for AutoGen. This package provides contracts and core functionalities for AutoGen.
+
+
+
+
+
+
+
+
+
diff --git a/dotnet/src/AutoGen.Core/Extension/AgentExtension.cs b/dotnet/src/AutoGen.Core/Extension/AgentExtension.cs
new file mode 100644
index 00000000000..44ce8838b73
--- /dev/null
+++ b/dotnet/src/AutoGen.Core/Extension/AgentExtension.cs
@@ -0,0 +1,174 @@
+ο»Ώ// Copyright (c) Microsoft Corporation. All rights reserved.
+// AgentExtension.cs
+
+using System.Collections.Generic;
+using System.Linq;
+using System.Threading;
+using System.Threading.Tasks;
+
+namespace AutoGen.Core;
+
+public static class AgentExtension
+{
+ ///
+ /// Send message to an agent.
+ ///
+ /// message to send. will be added to the end of if provided
+ /// sender agent.
+ /// chat history.
+ /// conversation history
+ public static async Task<IMessage> SendAsync(
+ this IAgent agent,
+ IMessage? message = null,
+ IEnumerable<IMessage>? chatHistory = null,
+ CancellationToken ct = default)
+ {
+ var messages = new List<IMessage>();
+
+ if (chatHistory != null)
+ {
+ messages.AddRange(chatHistory);
+ }
+
+ if (message != null)
+ {
+ messages.Add(message);
+ }
+
+
+ var result = await agent.GenerateReplyAsync(messages, cancellationToken: ct);
+
+ return result;
+ }
+
+ ///
+ /// Send message to an agent.
+ ///
+ /// sender agent.
+ /// message to send. will be added to the end of if provided
+ /// chat history.
+ /// conversation history
+ public static async Task<IMessage> SendAsync(
+ this IAgent agent,
+ string message,
+ IEnumerable<IMessage>? chatHistory = null,
+ CancellationToken ct = default)
+ {
+ var msg = new TextMessage(Role.User, message);
+
+ return await agent.SendAsync(msg, chatHistory, ct);
+ }
+
+ ///
+ /// Send message to another agent.
+ ///
+ /// sender agent.
+ /// receiver agent.
+ /// chat history.
+ /// max conversation round.
+ /// conversation history
+ public static async Task<IEnumerable<IMessage>> SendAsync(
+ this IAgent agent,
+ IAgent receiver,
+ IEnumerable<IMessage> chatHistory,
+ int maxRound = 10,
+ CancellationToken ct = default)
+ {
+ if (receiver is GroupChatManager manager)
+ {
+ var gc = manager.GroupChat;
+
+ return await agent.SendMessageToGroupAsync(gc, chatHistory, maxRound, ct);
+ }
+
+ var groupChat = new RoundRobinGroupChat(
+ agents: new[]
+ {
+ agent,
+ receiver,
+ });
+
+ return await groupChat.CallAsync(chatHistory, maxRound, ct: ct);
+ }
+
+ ///
+ /// Send message to another agent.
+ ///
+ /// sender agent.
+ /// message to send. will be added to the end of if provided
+ /// receiver agent.
+ /// chat history.
+ /// max conversation round.
+ /// conversation history
+ public static async Task<IEnumerable<IMessage>> SendAsync(
+ this IAgent agent,
+ IAgent receiver,
+ string message,
+ IEnumerable<IMessage>? chatHistory = null,
+ int maxRound = 10,
+ CancellationToken ct = default)
+ {
+ var msg = new TextMessage(Role.User, message)
+ {
+ From = agent.Name,
+ };
+
+ chatHistory = chatHistory ?? new List<IMessage>();
+ chatHistory = chatHistory.Append(msg);
+
+ return await agent.SendAsync(receiver, chatHistory, maxRound, ct);
+ }
+
+ ///
+ /// Shortcut API to send message to another agent.
+ ///
+ /// sender agent
+ /// receiver agent
+ /// message to send
+ /// max round
+ public static async Task<IEnumerable<IMessage>> InitiateChatAsync(
+ this IAgent agent,
+ IAgent receiver,
+ string? message = null,
+ int maxRound = 10,
+ CancellationToken ct = default)
+ {
+ var chatHistory = new List<IMessage>();
+ if (message != null)
+ {
+ var msg = new TextMessage(Role.User, message)
+ {
+ From = agent.Name,
+ };
+
+ chatHistory.Add(msg);
+ }
+
+ return await agent.SendAsync(receiver, chatHistory, maxRound, ct);
+ }
+
+ public static async Task<IEnumerable<IMessage>> SendMessageToGroupAsync(
+ this IAgent agent,
+ IGroupChat groupChat,
+ string msg,
+ IEnumerable<IMessage>? chatHistory = null,
+ int maxRound = 10,
+ CancellationToken ct = default)
+ {
+ var chatMessage = new TextMessage(Role.Assistant, msg, from: agent.Name);
+ chatHistory = chatHistory ?? Enumerable.Empty<IMessage>();
+ chatHistory = chatHistory.Append(chatMessage);
+
+ return await agent.SendMessageToGroupAsync(groupChat, chatHistory, maxRound, ct);
+ }
+
+ public static async Task<IEnumerable<IMessage>> SendMessageToGroupAsync(
+ this IAgent _,
+ IGroupChat groupChat,
+ IEnumerable<IMessage>? chatHistory = null,
+ int maxRound = 10,
+ CancellationToken ct = default)
+ {
+ return await groupChat.CallAsync(chatHistory, maxRound, ct);
+ }
+}
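Note (illustrative, not part of the diff): a short sketch of these conversation helpers; the agents and their replies are made up, and DefaultReplyAgent refers to the fixed-reply agent added earlier in this PR.

```csharp
using AutoGen.Core;

var user = new DefaultReplyAgent("user", "I have a question.");
var assistant = new DefaultReplyAgent("assistant", "Happy to help!");

// Single reply from one agent.
IMessage reply = await assistant.SendAsync("What does SendAsync do?");

// Two-agent conversation for up to 4 rounds (or until a terminate message shows up).
var history = await user.InitiateChatAsync(assistant, message: "Hello!", maxRound: 4);
```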
diff --git a/dotnet/src/AutoGen.Core/Extension/GroupChatExtension.cs b/dotnet/src/AutoGen.Core/Extension/GroupChatExtension.cs
new file mode 100644
index 00000000000..e3e44622c81
--- /dev/null
+++ b/dotnet/src/AutoGen.Core/Extension/GroupChatExtension.cs
@@ -0,0 +1,109 @@
+ο»Ώ// Copyright (c) Microsoft Corporation. All rights reserved.
+// GroupChatExtension.cs
+
+using System;
+using System.Collections.Generic;
+using System.Linq;
+
+namespace AutoGen.Core;
+
+public static class GroupChatExtension
+{
+ public const string TERMINATE = "[GROUPCHAT_TERMINATE]";
+ public const string CLEAR_MESSAGES = "[GROUPCHAT_CLEAR_MESSAGES]";
+
+ [Obsolete("please use SendIntroduction")]
+ public static void AddInitializeMessage(this IAgent agent, string message, IGroupChat groupChat)
+ {
+ var msg = new TextMessage(Role.User, message)
+ {
+ From = agent.Name
+ };
+
+ groupChat.SendIntroduction(msg);
+ }
+
+ ///
+ /// Send an instruction message to the group chat.
+ ///
+ public static void SendIntroduction(this IAgent agent, string message, IGroupChat groupChat)
+ {
+ var msg = new TextMessage(Role.User, message)
+ {
+ From = agent.Name
+ };
+
+ groupChat.SendIntroduction(msg);
+ }
+
+ public static IEnumerable<IMessage> MessageToKeep(
+ this IGroupChat _,
+ IEnumerable<IMessage> messages)
+ {
+ var lastCLRMessageIndex = messages.ToList()
+ .FindLastIndex(x => x.IsGroupChatClearMessage());
+
+ // if multiple clr messages, e.g [msg, clr, msg, clr, msg, clr, msg]
+ // only keep the messages after the second last clr message.
+ if (messages.Count(m => m.IsGroupChatClearMessage()) > 1)
+ {
+ lastCLRMessageIndex = messages.ToList()
+ .FindLastIndex(lastCLRMessageIndex - 1, lastCLRMessageIndex - 1, x => x.IsGroupChatClearMessage());
+ messages = messages.Skip(lastCLRMessageIndex);
+ }
+
+ lastCLRMessageIndex = messages.ToList()
+ .FindLastIndex(x => x.IsGroupChatClearMessage());
+
+ if (lastCLRMessageIndex != -1 && messages.Count() - lastCLRMessageIndex >= 2)
+ {
+ messages = messages.Skip(lastCLRMessageIndex);
+ }
+
+ return messages;
+ }
+
+ ///
+ /// Return true if the message content contains TERMINATE, otherwise false.
+ ///
+ ///
+ ///
+ public static bool IsGroupChatTerminateMessage(this IMessage message)
+ {
+ return message.GetContent()?.Contains(TERMINATE) ?? false;
+ }
+
+ public static bool IsGroupChatClearMessage(this IMessage message)
+ {
+ return message.GetContent()?.Contains(CLEAR_MESSAGES) ?? false;
+ }
+
+ public static IEnumerable<IMessage> ProcessConversationForAgent(
+ this IGroupChat groupChat,
+ IEnumerable<IMessage> initialMessages,
+ IEnumerable<IMessage> messages)
+ {
+ messages = groupChat.MessageToKeep(messages);
+ return initialMessages.Concat(messages);
+ }
+
+ internal static IEnumerable<IMessage> ProcessConversationsForRolePlay(
+ this IGroupChat groupChat,
+ IEnumerable<IMessage> initialMessages,
+ IEnumerable<IMessage> messages)
+ {
+ messages = groupChat.MessageToKeep(messages);
+ var messagesToKeep = initialMessages.Concat(messages);
+
+ return messagesToKeep.Select((x, i) =>
+ {
+ var msg = @$"From {x.From}:
+{x.GetContent()}
+
+round #
+ {i}";
+
+ return new TextMessage(Role.User, content: msg);
+ });
+ }
+}
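Note (illustrative, not part of the diff): the TERMINATE marker above is plain text appended to a message's content; a small sketch of checking for it (the message content is made up).

```csharp
using System;
using AutoGen.Core;

var done = new TextMessage(
    Role.Assistant,
    $"All tasks are finished. {GroupChatExtension.TERMINATE}",
    from: "admin");

if (done.IsGroupChatTerminateMessage())
{
    // Callers typically stop the conversation loop once such a message appears.
    Console.WriteLine("Group chat terminated.");
}
```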
diff --git a/dotnet/src/AutoGen.Core/Extension/MessageExtension.cs b/dotnet/src/AutoGen.Core/Extension/MessageExtension.cs
new file mode 100644
index 00000000000..47dbad55e30
--- /dev/null
+++ b/dotnet/src/AutoGen.Core/Extension/MessageExtension.cs
@@ -0,0 +1,213 @@
+ο»Ώ// Copyright (c) Microsoft Corporation. All rights reserved.
+// MessageExtension.cs
+
+using System.Collections.Generic;
+using System.Linq;
+using System.Text;
+
+namespace AutoGen.Core;
+
+public static class MessageExtension
+{
+ private static string separator = new string('-', 20);
+
+ public static string FormatMessage(this IMessage message)
+ {
+ return message switch
+ {
+ Message msg => msg.FormatMessage(),
+ TextMessage textMessage => textMessage.FormatMessage(),
+ ImageMessage imageMessage => imageMessage.FormatMessage(),
+ ToolCallMessage toolCallMessage => toolCallMessage.FormatMessage(),
+ ToolCallResultMessage toolCallResultMessage => toolCallResultMessage.FormatMessage(),
+ AggregateMessage<ToolCallMessage, ToolCallResultMessage> aggregateMessage => aggregateMessage.FormatMessage(),
+ _ => message.ToString(),
+ };
+ }
+
+ public static string FormatMessage(this TextMessage message)
+ {
+ var sb = new StringBuilder();
+ // write from
+ sb.AppendLine($"TextMessage from {message.From}");
+ // write a seperator
+ sb.AppendLine(separator);
+ sb.AppendLine(message.Content);
+ // write a seperator
+ sb.AppendLine(separator);
+
+ return sb.ToString();
+ }
+
+ public static string FormatMessage(this ImageMessage message)
+ {
+ var sb = new StringBuilder();
+ // write from
+ sb.AppendLine($"ImageMessage from {message.From}");
+ // write a seperator
+ sb.AppendLine(separator);
+ sb.AppendLine($"Image: {message.Url}");
+ // write a seperator
+ sb.AppendLine(separator);
+
+ return sb.ToString();
+ }
+
+ public static string FormatMessage(this ToolCallMessage message)
+ {
+ var sb = new StringBuilder();
+ // write from
+ sb.AppendLine($"ToolCallMessage from {message.From}");
+
+ // write a seperator
+ sb.AppendLine(separator);
+
+ foreach (var toolCall in message.ToolCalls)
+ {
+ sb.AppendLine($"- {toolCall.FunctionName}: {toolCall.FunctionArguments}");
+ }
+
+ sb.AppendLine(separator);
+
+ return sb.ToString();
+ }
+
+ public static string FormatMessage(this ToolCallResultMessage message)
+ {
+ var sb = new StringBuilder();
+ // write from
+ sb.AppendLine($"ToolCallResultMessage from {message.From}");
+
+ // write a seperator
+ sb.AppendLine(separator);
+
+ foreach (var toolCall in message.ToolCalls)
+ {
+ sb.AppendLine($"- {toolCall.FunctionName}: {toolCall.Result}");
+ }
+
+ sb.AppendLine(separator);
+
+ return sb.ToString();
+ }
+
+ public static string FormatMessage(this AggregateMessage<ToolCallMessage, ToolCallResultMessage> message)
+ {
+ var sb = new StringBuilder();
+ // write from
+ sb.AppendLine($"AggregateMessage from {message.From}");
+
+ // write a seperator
+ sb.AppendLine(separator);
+
+ sb.AppendLine("ToolCallMessage:");
+ sb.AppendLine(message.Message1.FormatMessage());
+
+ sb.AppendLine("ToolCallResultMessage:");
+ sb.AppendLine(message.Message2.FormatMessage());
+
+ sb.AppendLine(separator);
+
+ return sb.ToString();
+ }
+ public static string FormatMessage(this Message message)
+ {
+ var sb = new StringBuilder();
+ // write from
+ sb.AppendLine($"Message from {message.From}");
+ // write a seperator
+ sb.AppendLine(separator);
+
+ // write content
+ sb.AppendLine($"content: {message.Content}");
+
+ // write function name if exists
+ if (!string.IsNullOrEmpty(message.FunctionName))
+ {
+ sb.AppendLine($"function name: {message.FunctionName}");
+ sb.AppendLine($"function arguments: {message.FunctionArguments}");
+ }
+
+ // write metadata
+ if (message.Metadata is { Count: > 0 })
+ {
+ sb.AppendLine($"metadata:");
+ foreach (var item in message.Metadata)
+ {
+ sb.AppendLine($"{item.Key}: {item.Value}");
+ }
+ }
+
+ // write a seperator
+ sb.AppendLine(separator);
+
+ return sb.ToString();
+ }
+
+ public static bool IsSystemMessage(this IMessage message)
+ {
+ return message switch
+ {
+ TextMessage textMessage => textMessage.Role == Role.System,
+ Message msg => msg.Role == Role.System,
+ _ => false,
+ };
+ }
+
+ ///
+ /// Get the content from the message
+ /// if the message is a TextMessage or Message, return the content
+ /// if the message is a ToolCallResultMessage and only contains one function call, return the result of that function call
+ /// if the message is an AggregateMessage where TMessage1 is ToolCallMessage and TMessage2 is ToolCallResultMessage and the second message only contains one function call, return the result of that function call
+ /// for all other situations, return null.
+ ///
+ ///
+ public static string? GetContent(this IMessage message)
+ {
+ return message switch
+ {
+ TextMessage textMessage => textMessage.Content,
+ Message msg => msg.Content,
+ ToolCallResultMessage toolCallResultMessage => toolCallResultMessage.ToolCalls.Count == 1 ? toolCallResultMessage.ToolCalls.First().Result : null,
+ AggregateMessage<ToolCallMessage, ToolCallResultMessage> aggregateMessage => aggregateMessage.Message2.ToolCalls.Count == 1 ? aggregateMessage.Message2.ToolCalls.First().Result : null,
+ _ => null,
+ };
+ }
+
+ ///
+ /// Get the role from the message if it's available.
+ ///
+ public static Role? GetRole(this IMessage message)
+ {
+ return message switch
+ {
+ TextMessage textMessage => textMessage.Role,
+ Message msg => msg.Role,
+ ImageMessage img => img.Role,
+ MultiModalMessage multiModal => multiModal.Role,
+ _ => null,
+ };
+ }
+
+ ///
+ /// Return the tool calls from the message if it's available.
+ /// if the message is a ToolCallMessage, return its tool calls
+ /// if the message is a Message and the function name and function arguments are available, return a list of tool calls with one item
+ /// if the message is an AggregateMessage where TMessage1 is ToolCallMessage and TMessage2 is ToolCallResultMessage, return the tool calls from the first message
+ ///
+ ///
+ ///
+ public static IList<ToolCall>? GetToolCalls(this IMessage message)
+ {
+ return message switch
+ {
+ ToolCallMessage toolCallMessage => toolCallMessage.ToolCalls,
+ Message msg => msg.FunctionName is not null && msg.FunctionArguments is not null
+ ? msg.Content is not null ? new List<ToolCall> { new ToolCall(msg.FunctionName, msg.FunctionArguments, result: msg.Content) }
+ : new List<ToolCall> { new ToolCall(msg.FunctionName, msg.FunctionArguments) }
+ : null,
+ AggregateMessage<ToolCallMessage, ToolCallResultMessage> aggregateMessage => aggregateMessage.Message1.ToolCalls,
+ _ => null,
+ };
+ }
+}
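Note (illustrative, not part of the diff): the helpers above applied to a TextMessage; the printed output shown in the comments is approximate.

```csharp
using System;
using AutoGen.Core;

var msg = new TextMessage(Role.User, "hello world", from: "user");

Console.WriteLine(msg.GetRole());       // User
Console.WriteLine(msg.GetContent());    // hello world
Console.WriteLine(msg.FormatMessage()); // "TextMessage from user", a separator line, the content, another separator
```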
diff --git a/dotnet/src/AutoGen.Core/Extension/MiddlewareExtension.cs b/dotnet/src/AutoGen.Core/Extension/MiddlewareExtension.cs
new file mode 100644
index 00000000000..5beed7fd815
--- /dev/null
+++ b/dotnet/src/AutoGen.Core/Extension/MiddlewareExtension.cs
@@ -0,0 +1,145 @@
+ο»Ώ// Copyright (c) Microsoft Corporation. All rights reserved.
+// MiddlewareExtension.cs
+
+using System;
+using System.Collections.Generic;
+using System.Threading;
+using System.Threading.Tasks;
+
+namespace AutoGen.Core;
+
+public static class MiddlewareExtension
+{
+ ///
+ /// Register a auto reply hook to an agent. The hook will be called before the agent generate the reply.
+ /// If the hook return a non-null reply, then that non-null reply will be returned directly without calling the agent.
+ /// Otherwise, the agent will generate the reply.
+ /// This is useful when you want to override the agent reply in some cases.
+ ///
+ ///
+ ///
+ ///
+ /// throw when agent name is null.
+ [Obsolete("Use RegisterMiddleware instead.")]
+ public static MiddlewareAgent<TAgent> RegisterReply<TAgent>(
+ this TAgent agent,
+ Func<IEnumerable<IMessage>, CancellationToken, Task<IMessage>> replyFunc)
+ where TAgent : IAgent
+ {
+ return agent.RegisterMiddleware(async (messages, options, agent, ct) =>
+ {
+ var reply = await replyFunc(messages, ct);
+
+ if (reply != null)
+ {
+ return reply;
+ }
+
+ return await agent.GenerateReplyAsync(messages, options, ct);
+ });
+ }
+
+ ///
+ /// Register a post process hook to an agent. The hook will be called before the agent return the reply and after the agent generate the reply.
+ /// This is useful when you want to customize arbitrary behavior before the agent return the reply.
+ ///
+ /// One example is , which print the formatted message to console before the agent return the reply.
+ ///
+ /// throw when agent name is null.
+ [Obsolete("Use RegisterMiddleware instead.")]
+ public static MiddlewareAgent<TAgent> RegisterPostProcess<TAgent>(
+ this TAgent agent,
+ Func<IEnumerable<IMessage>, IMessage, CancellationToken, Task<IMessage>> postprocessFunc)
+ where TAgent : IAgent
+ {
+ return agent.RegisterMiddleware(async (messages, options, agent, ct) =>
+ {
+ var reply = await agent.GenerateReplyAsync(messages, options, ct);
+
+ return await postprocessFunc(messages, reply, ct);
+ });
+ }
+
+ ///
+ /// Register a pre process hook to an agent. The hook will be called before the agent generate the reply. This is useful when you want to modify the conversation history before the agent generate the reply.
+ ///
+ /// throw when agent name is null.
+ [Obsolete("Use RegisterMiddleware instead.")]
+ public static MiddlewareAgent<TAgent> RegisterPreProcess<TAgent>(
+ this TAgent agent,
+ Func<IEnumerable<IMessage>, CancellationToken, Task<IEnumerable<IMessage>>> preprocessFunc)
+ where TAgent : IAgent
+ {
+ return agent.RegisterMiddleware(async (messages, options, agent, ct) =>
+ {
+ var newMessages = await preprocessFunc(messages, ct);
+
+ return await agent.GenerateReplyAsync(newMessages, options, ct);
+ });
+ }
+
+ ///
+ /// Register a middleware to an existing agent and return a new agent with the middleware.
+ /// To register a streaming middleware, use RegisterStreamingMiddleware.
+ ///
+ public static MiddlewareAgent<TAgent> RegisterMiddleware<TAgent>(
+ this TAgent agent,
+ Func<IEnumerable<IMessage>, GenerateReplyOptions?, IAgent, CancellationToken, Task<IMessage>> func,
+ string? middlewareName = null)
+ where TAgent : IAgent
+ {
+ var middleware = new DelegateMiddleware(middlewareName, async (context, agent, cancellationToken) =>
+ {
+ return await func(context.Messages, context.Options, agent, cancellationToken);
+ });
+
+ return agent.RegisterMiddleware(middleware);
+ }
+
+ ///
+ /// Register a middleware to an existing agent and return a new agent with the middleware.
+ /// To register a streaming middleware, use RegisterStreamingMiddleware.
+ ///
+ public static MiddlewareAgent<TAgent> RegisterMiddleware<TAgent>(
+ this TAgent agent,
+ IMiddleware middleware)
+ where TAgent : IAgent
+ {
+ var middlewareAgent = new MiddlewareAgent<TAgent>(agent);
+
+ return middlewareAgent.RegisterMiddleware(middleware);
+ }
+
+ ///
+ /// Register a middleware to an existing agent and return a new agent with the middleware.
+ /// To register a streaming middleware, use RegisterStreamingMiddleware.
+ ///
+ public static MiddlewareAgent<TAgent> RegisterMiddleware<TAgent>(
+ this MiddlewareAgent<TAgent> agent,
+ Func<IEnumerable<IMessage>, GenerateReplyOptions?, IAgent, CancellationToken, Task<IMessage>> func,
+ string? middlewareName = null)
+ where TAgent : IAgent
+ {
+ var delegateMiddleware = new DelegateMiddleware(middlewareName, async (context, agent, cancellationToken) =>
+ {
+ return await func(context.Messages, context.Options, agent, cancellationToken);
+ });
+
+ return agent.RegisterMiddleware(delegateMiddleware);
+ }
+
+ ///
+ /// Register a middleware to an existing agent and return a new agent with the middleware.
+ /// To register a streaming middleware, use RegisterStreamingMiddleware.
+ ///
+ public static MiddlewareAgent<TAgent> RegisterMiddleware<TAgent>(
+ this MiddlewareAgent<TAgent> agent,
+ IMiddleware middleware)
+ where TAgent : IAgent
+ {
+ var copyAgent = new MiddlewareAgent<TAgent>(agent);
+ copyAgent.Use(middleware);
+
+ return copyAgent;
+ }
+}
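Note (illustrative, not part of the diff): a sketch of RegisterMiddleware with the delegate overload, adding simple console logging around an agent; the agent, the log format, and the middleware name are made up.

```csharp
using System;
using System.Linq;
using AutoGen.Core;

var assistant = new DefaultReplyAgent("assistant", "default reply");

var logged = assistant.RegisterMiddleware(async (messages, options, innerAgent, ct) =>
{
    Console.WriteLine($"-> {innerAgent.Name}: {messages.Count()} message(s)");
    var reply = await innerAgent.GenerateReplyAsync(messages, options, ct);
    Console.WriteLine($"<- {innerAgent.Name}: {reply.GetContent()}");
    return reply;
}, middlewareName: "console-logger");

var answer = await logged.SendAsync("hi");
```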
diff --git a/dotnet/src/AutoGen.Core/Extension/PrintMessageMiddlewareExtension.cs b/dotnet/src/AutoGen.Core/Extension/PrintMessageMiddlewareExtension.cs
new file mode 100644
index 00000000000..262b50d125d
--- /dev/null
+++ b/dotnet/src/AutoGen.Core/Extension/PrintMessageMiddlewareExtension.cs
@@ -0,0 +1,69 @@
+ο»Ώ// Copyright (c) Microsoft Corporation. All rights reserved.
+// PrintMessageMiddlewareExtension.cs
+
+using System;
+
+namespace AutoGen.Core;
+
+public static class PrintMessageMiddlewareExtension
+{
+ [Obsolete("This API will be removed in v0.1.0, Use RegisterPrintMessage instead.")]
+ public static MiddlewareAgent<TAgent> RegisterPrintFormatMessageHook<TAgent>(this TAgent agent)
+ where TAgent : IAgent
+ {
+ return RegisterPrintMessage(agent);
+ }
+
+ [Obsolete("This API will be removed in v0.1.0, Use RegisterPrintMessage instead.")]
+ public static MiddlewareAgent<TAgent> RegisterPrintFormatMessageHook<TAgent>(this MiddlewareAgent<TAgent> agent)
+ where TAgent : IAgent
+ {
+ return RegisterPrintMessage(agent);
+ }
+
+ [Obsolete("This API will be removed in v0.1.0, Use RegisterPrintMessage instead.")]
+ public static MiddlewareStreamingAgent<TAgent> RegisterPrintFormatMessageHook<TAgent>(this MiddlewareStreamingAgent<TAgent> agent)
+ where TAgent : IStreamingAgent
+ {
+ return RegisterPrintMessage(agent);
+ }
+
+ ///
+ /// Register a PrintMessageMiddleware which prints formatted messages to the console.
+ ///
+ public static MiddlewareAgent<TAgent> RegisterPrintMessage<TAgent>(this TAgent agent)
+ where TAgent : IAgent
+ {
+ var middleware = new PrintMessageMiddleware();
+ var middlewareAgent = new MiddlewareAgent<TAgent>(agent);
+ middlewareAgent.Use(middleware);
+
+ return middlewareAgent;
+ }
+
+ ///
+ /// Register a PrintMessageMiddleware which prints formatted messages to the console.
+ ///
+ public static MiddlewareAgent<TAgent> RegisterPrintMessage<TAgent>(this MiddlewareAgent<TAgent> agent)
+ where TAgent : IAgent
+ {
+ var middleware = new PrintMessageMiddleware();
+ var middlewareAgent = new MiddlewareAgent<TAgent>(agent);
+ middlewareAgent.Use(middleware);
+
+ return middlewareAgent;
+ }
+
+ ///
+ /// Register a PrintMessageMiddleware which prints formatted messages to the console.
+ ///
+ public static MiddlewareStreamingAgent<TAgent> RegisterPrintMessage<TAgent>(this MiddlewareStreamingAgent<TAgent> agent)
+ where TAgent : IStreamingAgent
+ {
+ var middleware = new PrintMessageMiddleware();
+ var middlewareAgent = new MiddlewareStreamingAgent<TAgent>(agent);
+ middlewareAgent.UseStreaming(middleware);
+
+ return middlewareAgent;
+ }
+}
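Note (illustrative, not part of the diff): the typical way this extension is chained onto an agent; DefaultReplyAgent and SendAsync come from elsewhere in this PR.

```csharp
using AutoGen.Core;

// Prints each formatted message to the console as the conversation proceeds.
var assistant = new DefaultReplyAgent("assistant", "default reply")
    .RegisterPrintMessage();

await assistant.SendAsync("hello");
```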
diff --git a/dotnet/src/AutoGen.Core/Extension/StreamingMiddlewareExtension.cs b/dotnet/src/AutoGen.Core/Extension/StreamingMiddlewareExtension.cs
new file mode 100644
index 00000000000..2ec7b3f9f3b
--- /dev/null
+++ b/dotnet/src/AutoGen.Core/Extension/StreamingMiddlewareExtension.cs
@@ -0,0 +1,37 @@
+ο»Ώ// Copyright (c) Microsoft Corporation. All rights reserved.
+// StreamingMiddlewareExtension.cs
+
+namespace AutoGen.Core;
+
+public static class StreamingMiddlewareExtension
+{
+ ///
+ /// Register an IStreamingMiddleware to an existing IStreamingAgent and return a new agent with the registered middleware.
+ /// For registering an IMiddleware, please refer to RegisterMiddleware.
+ ///
+ public static MiddlewareStreamingAgent<TStreamingAgent> RegisterStreamingMiddleware<TStreamingAgent>(
+ this TStreamingAgent agent,
+ IStreamingMiddleware middleware)
+ where TStreamingAgent : IStreamingAgent
+ {
+ var middlewareAgent = new MiddlewareStreamingAgent<TStreamingAgent>(agent);
+ middlewareAgent.UseStreaming(middleware);
+
+ return middlewareAgent;
+ }
+
+ ///
+ /// Register an IStreamingMiddleware to an existing IStreamingAgent and return a new agent with the registered middleware.
+ /// For registering an IMiddleware, please refer to RegisterMiddleware.
+ ///
+ public static MiddlewareStreamingAgent<TAgent> RegisterStreamingMiddleware<TAgent>(
+ this MiddlewareStreamingAgent<TAgent> agent,
+ IStreamingMiddleware middleware)
+ where TAgent : IStreamingAgent
+ {
+ var copyAgent = new MiddlewareStreamingAgent<TAgent>(agent);
+ copyAgent.UseStreaming(middleware);
+
+ return copyAgent;
+ }
+}
diff --git a/dotnet/src/AutoGen.Core/Function/FunctionAttribute.cs b/dotnet/src/AutoGen.Core/Function/FunctionAttribute.cs
new file mode 100644
index 00000000000..2c828c26d89
--- /dev/null
+++ b/dotnet/src/AutoGen.Core/Function/FunctionAttribute.cs
@@ -0,0 +1,93 @@
+ο»Ώ// Copyright (c) Microsoft Corporation. All rights reserved.
+// FunctionAttribute.cs
+
+using System;
+using System.Collections.Generic;
+
+namespace AutoGen.Core;
+
+[AttributeUsage(AttributeTargets.Method, Inherited = false, AllowMultiple = false)]
+public class FunctionAttribute : Attribute
+{
+ public string? FunctionName { get; }
+
+ public string? Description { get; }
+
+ public FunctionAttribute(string? functionName = null, string? description = null)
+ {
+ FunctionName = functionName;
+ Description = description;
+ }
+}
+
+public class FunctionContract
+{
+ ///
+ /// The namespace of the function.
+ ///
+ public string? Namespace { get; set; }
+
+ ///
+ /// The class name of the function.
+ ///
+ public string? ClassName { get; set; }
+
+ ///
+ /// The name of the function.
+ ///
+ public string? Name { get; set; }
+
+ ///
+ /// The description of the function.
+ /// If a structured comment is available, the description will be extracted from the summary section.
+ /// Otherwise, the description will be null.
+ ///
+ public string? Description { get; set; }
+
+ ///
+ /// The parameters of the function.
+ ///
+ public IEnumerable<FunctionParameterContract>? Parameters { get; set; }
+
+ ///
+ /// The return type of the function.
+ ///
+ public Type? ReturnType { get; set; }
+
+ ///
+ /// The description of the return section.
+ /// If a structured comment is available, the description will be extracted from the return section.
+ /// Otherwise, the description will be null.
+ ///
+ public string? ReturnDescription { get; set; }
+}
+
+public class FunctionParameterContract
+{
+ ///
+ /// The name of the parameter.
+ ///
+ public string? Name { get; set; }
+
+ ///
+ /// The description of the parameter.
+ /// This will be extracted from the param section of the structured comment if available.
+ /// Otherwise, the description will be null.
+ ///
+ public string? Description { get; set; }
+
+ ///
+ /// The type of the parameter.
+ ///
+ public Type? ParameterType { get; set; }
+
+ ///
+ /// If the parameter is a required parameter.
+ ///
+ public bool IsRequired { get; set; }
+
+ ///
+ /// The default value of the parameter.
+ ///
+ public object? DefaultValue { get; set; }
+}
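Note (illustrative, not part of the diff): a [Function]-annotated method and a hand-written FunctionContract describing it; the class, method, and values are made up, and nothing in this file generates the contract automatically.

```csharp
using System.Collections.Generic;
using AutoGen.Core;

public class WeatherFunctions
{
    [Function("get_weather", "Get the current weather for a city.")]
    public string GetWeather(string city) => $"The weather in {city} is sunny.";

    // A contract an agent could advertise, e.g. via GenerateReplyOptions.Functions.
    public static FunctionContract GetWeatherContract { get; } = new FunctionContract
    {
        ClassName = nameof(WeatherFunctions),
        Name = "get_weather",
        Description = "Get the current weather for a city.",
        ReturnType = typeof(string),
        Parameters = new List<FunctionParameterContract>
        {
            new FunctionParameterContract
            {
                Name = "city",
                Description = "the city to look up",
                ParameterType = typeof(string),
                IsRequired = true,
            },
        },
    };
}
```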
diff --git a/dotnet/src/AutoGen.Core/GroupChat/Graph.cs b/dotnet/src/AutoGen.Core/GroupChat/Graph.cs
new file mode 100644
index 00000000000..02f4da50bae
--- /dev/null
+++ b/dotnet/src/AutoGen.Core/GroupChat/Graph.cs
@@ -0,0 +1,104 @@
+ο»Ώ// Copyright (c) Microsoft Corporation. All rights reserved.
+// Graph.cs
+
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using System.Threading.Tasks;
+
+namespace AutoGen.Core;
+
+public class Graph
+{
+ private readonly List<Transition> transitions = new List<Transition>();
+
+ public Graph(IEnumerable<Transition> transitions)
+ {
+ this.transitions.AddRange(transitions);
+ }
+
+ public void AddTransition(Transition transition)
+ {
+ transitions.Add(transition);
+ }
+
+ ///
+ /// Get the transitions of the workflow.
+ ///
+ public IEnumerable<Transition> Transitions => transitions;
+
+ ///
+ /// Get the next available agents that the messages can be transit to.
+ ///
+ /// the from agent
+ /// messages
+ /// A list of agents that the messages can be transit to
+ public async Task<IEnumerable<IAgent>> TransitToNextAvailableAgentsAsync(IAgent fromAgent, IEnumerable<IMessage> messages)
+ {
+ var nextAgents = new List<IAgent>();
+ var availableTransitions = transitions.FindAll(t => t.From == fromAgent) ?? Enumerable.Empty<Transition>();
+ foreach (var transition in availableTransitions)
+ {
+ if (await transition.CanTransitionAsync(messages))
+ {
+ nextAgents.Add(transition.To);
+ }
+ }
+
+ return nextAgents;
+ }
+}
+
+///
+/// Represents a transition between two agents.
+///
+public class Transition
+{
+ private readonly IAgent _from;
+ private readonly IAgent _to;
+ private readonly Func<IAgent, IAgent, IEnumerable<IMessage>, Task<bool>>? _canTransition;
+
+ ///
+ /// Create a new instance of Transition.
+ /// This constructor is used for testing purposes only.
+ /// To create a new instance of Transition, use Transition.Create.
+ ///
+ /// from agent
+ /// to agent
+ /// detect if the transition is allowed, default to be always true
+ internal Transition(IAgent from, IAgent to, Func<IAgent, IAgent, IEnumerable<IMessage>, Task<bool>>? canTransitionAsync = null)
+ {
+ _from = from;
+ _to = to;
+ _canTransition = canTransitionAsync;
+ }
+
+ ///
+ /// Create a new instance of .
+ ///
+ /// "
+ public static Transition Create(TFromAgent from, TToAgent to, Func, Task>? canTransitionAsync = null)
+ where TFromAgent : IAgent
+ where TToAgent : IAgent
+ {
+ return new Transition(from, to, (fromAgent, toAgent, messages) => canTransitionAsync?.Invoke((TFromAgent)fromAgent, (TToAgent)toAgent, messages) ?? Task.FromResult(true));
+ }
+
+ public IAgent From => _from;
+
+ public IAgent To => _to;
+
+ ///
+ /// Check if the transition is allowed.
+ ///
+ /// messages
+ public Task<bool> CanTransitionAsync(IEnumerable<IMessage> messages)
+ {
+ if (_canTransition == null)
+ {
+ return Task.FromResult(true);
+ }
+
+ return _canTransition(this.From, this.To, messages);
+ }
+}
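Note (illustrative, not part of the diff): building a two-step workflow with Transition.Create and asking it for the next eligible speakers; the agents, the predicate, and the message content are made up, and DefaultReplyAgent is the fixed-reply agent added earlier in this PR.

```csharp
using System.Linq;
using System.Threading.Tasks;
using AutoGen.Core;

var coder = new DefaultReplyAgent("coder", "Here is the code.");
var reviewer = new DefaultReplyAgent("reviewer", "Please fix the naming.");

// reviewer always follows coder; coder only follows reviewer when changes are requested.
var coderToReviewer = Transition.Create(coder, reviewer);
var reviewerToCoder = Transition.Create(reviewer, coder,
    (fromAgent, toAgent, messages) => Task.FromResult(messages.Last().GetContent()?.Contains("fix") == true));

var workflow = new Graph(new[] { coderToReviewer, reviewerToCoder });

// Who may speak after the coder, given the current history?
var nextSpeakers = await workflow.TransitToNextAvailableAgentsAsync(
    coder,
    new IMessage[] { new TextMessage(Role.Assistant, "Here is the code.", from: "coder") });
```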
diff --git a/dotnet/src/AutoGen.Core/GroupChat/GroupChat.cs b/dotnet/src/AutoGen.Core/GroupChat/GroupChat.cs
new file mode 100644
index 00000000000..3b6288ca0a7
--- /dev/null
+++ b/dotnet/src/AutoGen.Core/GroupChat/GroupChat.cs
@@ -0,0 +1,183 @@
+ο»Ώ// Copyright (c) Microsoft Corporation. All rights reserved.
+// GroupChat.cs
+
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using System.Threading;
+using System.Threading.Tasks;
+
+namespace AutoGen.Core;
+
+public class GroupChat : IGroupChat
+{
+ private IAgent? admin;
+ private List<IAgent> agents = new List<IAgent>();
+ private IEnumerable<IMessage> initializeMessages = new List<IMessage>();
+ private Graph? workflow = null;
+
+ public IEnumerable<IMessage>? Messages { get; private set; }
+
+ ///
+ /// Create a group chat. The next speaker will be decided by a combination effort of the admin and the workflow.
+ ///
+ /// admin agent. If provided, the admin will be invoked to decide the next speaker.
+ /// workflow of the group chat. If provided, the next speaker will be decided by the workflow.
+ /// group members.
+ ///
+ public GroupChat(
+ IEnumerable<IAgent> members,
+ IAgent? admin = null,
+ IEnumerable<IMessage>? initializeMessages = null,
+ Graph? workflow = null)
+ {
+ this.admin = admin;
+ this.agents = members.ToList();
+ this.initializeMessages = initializeMessages ?? new List();
+ this.workflow = workflow;
+
+ this.Validation();
+ }
+
+ private void Validation()
+ {
+ // check if all agents has a name
+ if (this.agents.Any(x => string.IsNullOrEmpty(x.Name)))
+ {
+ throw new Exception("All agents must have a name.");
+ }
+
+ // check if any agents has the same name
+ var names = this.agents.Select(x => x.Name).ToList();
+ if (names.Distinct().Count() != names.Count)
+ {
+ throw new Exception("All agents must have a unique name.");
+ }
+
+ // if there's a workflow
+ // check if the agents in that workflow are in the group chat
+ if (this.workflow != null)
+ {
+ var agentNamesInWorkflow = this.workflow.Transitions.Select(x => x.From.Name!).Concat(this.workflow.Transitions.Select(x => x.To.Name!)).Distinct();
+ if (agentNamesInWorkflow.Any(x => !this.agents.Select(a => a.Name).Contains(x)))
+ {
+ throw new Exception("All agents in the workflow must be in the group chat.");
+ }
+ }
+
+ // must provide one of admin or workflow
+ if (this.admin == null && this.workflow == null)
+ {
+ throw new Exception("Must provide one of admin or workflow.");
+ }
+ }
+
+ ///
+ /// Select the next speaker based on the conversation history.
+ /// The next speaker will be decided by a combination effort of the admin and the workflow.
+ /// Firstly, a group of candidates will be selected by the workflow. If there's only one candidate, then that candidate will be the next speaker.
+ /// Otherwise, the admin will be invoked to decide the next speaker using role-play prompt.
+ ///
+ /// current speaker
+ /// conversation history
+ /// next speaker.
+ public async Task<IAgent> SelectNextSpeakerAsync(IAgent currentSpeaker, IEnumerable<IMessage> conversationHistory)
+ {
+ var agentNames = this.agents.Select(x => x.Name).ToList();
+ if (this.workflow != null)
+ {
+ var nextAvailableAgents = await this.workflow.TransitToNextAvailableAgentsAsync(currentSpeaker, conversationHistory);
+ agentNames = nextAvailableAgents.Select(x => x.Name).ToList();
+ if (agentNames.Count() == 0)
+ {
+ throw new Exception("No next available agents found in the current workflow");
+ }
+
+ if (agentNames.Count() == 1)
+ {
+ return this.agents.FirstOrDefault(x => x.Name == agentNames.First());
+ }
+ }
+
+ if (this.admin == null)
+ {
+ throw new Exception("No admin is provided.");
+ }
+
+ var systemMessage = new TextMessage(Role.System,
+ content: $@"You are in a role play game. Carefully read the conversation history and carry on the conversation.
+The available roles are:
+{string.Join(",", agentNames)}
+
+Each message will start with 'From name:', e.g:
+From admin:
+//your message//.");
+
+ var conv = this.ProcessConversationsForRolePlay(this.initializeMessages, conversationHistory);
+
+ var messages = new IMessage[] { systemMessage }.Concat(conv);
+ var response = await this.admin.GenerateReplyAsync(
+ messages: messages,
+ options: new GenerateReplyOptions
+ {
+ Temperature = 0,
+ MaxToken = 128,
+ StopSequence = [":"],
+ Functions = [],
+ });
+
+ var name = response?.GetContent() ?? throw new Exception("No name is returned.");
+
+ // remove From
+ name = name!.Substring(5);
+ return this.agents.First(x => x.Name!.ToLower() == name.ToLower());
+ }
+
+ ///
+ public void AddInitializeMessage(IMessage message)
+ {
+ this.SendIntroduction(message);
+ }
+
+ public async Task<IEnumerable<IMessage>> CallAsync(
+ IEnumerable<IMessage>? conversationWithName = null,
+ int maxRound = 10,
+ CancellationToken ct = default)
+ {
+ var conversationHistory = new List