Skip to content

Commit

Permalink
IterListProvider support for generating images (#2441)
Browse files Browse the repository at this point in the history
* IterListProvider support for generating images
* Add missing get_har_files import in Copilot
* Fix typo in dall-e-3 model name
* Add image client unittests
* Add MicrosoftDesigner provider
* Import MicrosoftDesigner and add it to the model list
  • Loading branch information
hlohaus authored Nov 29, 2024
1 parent 8d5d522 commit 79c407b
Show file tree
Hide file tree
Showing 16 changed files with 392 additions and 136 deletions.
1 change: 1 addition & 0 deletions etc/unittest/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from .main import *
from .model import *
from .client import *
from .image_client import *
from .include import *
from .retry_provider import *

Expand Down
44 changes: 44 additions & 0 deletions etc/unittest/image_client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
from __future__ import annotations

import asyncio
import unittest

from g4f.client import AsyncClient, ImagesResponse
from g4f.providers.retry_provider import IterListProvider
from .mocks import (
YieldImageResponseProviderMock,
MissingAuthProviderMock,
AsyncRaiseExceptionProviderMock,
YieldNoneProviderMock
)

DEFAULT_MESSAGES = [{'role': 'user', 'content': 'Hello'}]

class TestIterListProvider(unittest.IsolatedAsyncioTestCase):

async def test_skip_provider(self):
client = AsyncClient(image_provider=IterListProvider([MissingAuthProviderMock, YieldImageResponseProviderMock], False))
response = await client.images.generate("Hello", "", response_format="orginal")
self.assertIsInstance(response, ImagesResponse)
self.assertEqual("Hello", response.data[0].url)

async def test_only_one_result(self):
client = AsyncClient(image_provider=IterListProvider([YieldImageResponseProviderMock, YieldImageResponseProviderMock], False))
response = await client.images.generate("Hello", "", response_format="orginal")
self.assertIsInstance(response, ImagesResponse)
self.assertEqual("Hello", response.data[0].url)

async def test_skip_none(self):
client = AsyncClient(image_provider=IterListProvider([YieldNoneProviderMock, YieldImageResponseProviderMock], False))
response = await client.images.generate("Hello", "", response_format="orginal")
self.assertIsInstance(response, ImagesResponse)
self.assertEqual("Hello", response.data[0].url)

def test_raise_exception(self):
async def run_exception():
client = AsyncClient(image_provider=IterListProvider([YieldNoneProviderMock, AsyncRaiseExceptionProviderMock], False))
await client.images.generate("Hello", "")
self.assertRaises(RuntimeError, asyncio.run, run_exception())

if __name__ == '__main__':
unittest.main()
21 changes: 21 additions & 0 deletions etc/unittest/mocks.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
from g4f.providers.base_provider import AbstractProvider, AsyncProvider, AsyncGeneratorProvider
from g4f.image import ImageResponse
from g4f.errors import MissingAuthError

class ProviderMock(AbstractProvider):
working = True
Expand Down Expand Up @@ -41,6 +43,25 @@ async def create_async_generator(
for message in messages:
yield message["content"]

class YieldImageResponseProviderMock(AsyncGeneratorProvider):
working = True

@classmethod
async def create_async_generator(
cls, model, messages, stream, prompt: str, **kwargs
):
yield ImageResponse(prompt, "")

class MissingAuthProviderMock(AbstractProvider):
working = True

@classmethod
def create_completion(
cls, model, messages, stream, **kwargs
):
raise MissingAuthError(cls.__name__)
yield cls.__name__

class RaiseExceptionProviderMock(AbstractProvider):
working = True

Expand Down
10 changes: 5 additions & 5 deletions g4f/Provider/AmigoChat.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,9 +65,9 @@
'flux-pro/v1.1-ultra': {'persona_id': "flux-pro-v1.1-ultra"}, # Amigo, your balance is not enough to make the request, wait until 12 UTC or upgrade your plan
'flux-pro/v1.1-ultra-raw': {'persona_id': "flux-pro-v1.1-ultra-raw"}, # Amigo, your balance is not enough to make the request, wait until 12 UTC or upgrade your plan
'flux/dev': {'persona_id': "flux-dev"},
'dalle-e-3': {'persona_id': "dalle-three"},

'dall-e-3': {'persona_id': "dalle-three"},

'recraft-v3': {'persona_id': "recraft"}
}
}
Expand Down Expand Up @@ -129,8 +129,8 @@ class AmigoChat(AsyncGeneratorProvider, ProviderModelMixin):
### image ###
"flux-realism": "flux-realism",
"flux-dev": "flux/dev",
"dalle-3": "dalle-e-3",

"dalle-3": "dall-e-3",
}

@classmethod
Expand Down
79 changes: 37 additions & 42 deletions g4f/Provider/Copilot.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from __future__ import annotations

import os
import json
import asyncio
from http.cookiejar import CookieJar
Expand All @@ -20,10 +19,10 @@
from .base_provider import AbstractProvider, ProviderModelMixin, BaseConversation
from .helper import format_prompt
from ..typing import CreateResult, Messages, ImageType
from ..errors import MissingRequirementsError
from ..errors import MissingRequirementsError, NoValidHarFileError
from ..requests.raise_for_status import raise_for_status
from ..providers.asyncio import get_running_loop
from .openai.har_file import NoValidHarFileError, get_headers
from .openai.har_file import get_headers, get_har_files
from ..requests import get_nodriver
from ..image import ImageResponse, to_bytes, is_accepted_format
from .. import debug
Expand Down Expand Up @@ -76,12 +75,12 @@ def create_completion(
if cls.needs_auth or image is not None:
if conversation is None or conversation.access_token is None:
try:
access_token, cookies = readHAR()
access_token, cookies = readHAR(cls.url)
except NoValidHarFileError as h:
debug.log(f"Copilot: {h}")
try:
get_running_loop(check_nested=True)
access_token, cookies = asyncio.run(cls.get_access_token_and_cookies(proxy))
access_token, cookies = asyncio.run(get_access_token_and_cookies(cls.url, proxy))
except MissingRequirementsError:
raise h
else:
Expand Down Expand Up @@ -162,35 +161,34 @@ def create_completion(
if not is_started:
raise RuntimeError(f"Invalid response: {last_msg}")

@classmethod
async def get_access_token_and_cookies(cls, proxy: str = None):
browser = await get_nodriver(proxy=proxy)
page = await browser.get(cls.url)
access_token = None
while access_token is None:
access_token = await page.evaluate("""
(() => {
for (var i = 0; i < localStorage.length; i++) {
try {
item = JSON.parse(localStorage.getItem(localStorage.key(i)));
if (item.credentialType == "AccessToken"
&& item.expiresOn > Math.floor(Date.now() / 1000)
&& item.target.includes("ChatAI")) {
return item.secret;
}
} catch(e) {}
}
})()
""")
if access_token is None:
await asyncio.sleep(1)
cookies = {}
for c in await page.send(nodriver.cdp.network.get_cookies([cls.url])):
cookies[c.name] = c.value
await page.close()
return access_token, cookies

def readHAR():
async def get_access_token_and_cookies(url: str, proxy: str = None, target: str = "ChatAI",):
browser = await get_nodriver(proxy=proxy)
page = await browser.get(url)
access_token = None
while access_token is None:
access_token = await page.evaluate("""
(() => {
for (var i = 0; i < localStorage.length; i++) {
try {
item = JSON.parse(localStorage.getItem(localStorage.key(i)));
if (item.credentialType == "AccessToken"
&& item.expiresOn > Math.floor(Date.now() / 1000)
&& item.target.includes("target")) {
return item.secret;
}
} catch(e) {}
}
})()
""".replace('"target"', json.dumps(target)))
if access_token is None:
await asyncio.sleep(1)
cookies = {}
for c in await page.send(nodriver.cdp.network.get_cookies([url])):
cookies[c.name] = c.value
await page.close()
return access_token, cookies

def readHAR(url: str):
api_key = None
cookies = None
for path in get_har_files():
Expand All @@ -201,16 +199,13 @@ def readHAR():
# Error: not a HAR file!
continue
for v in harFile['log']['entries']:
v_headers = get_headers(v)
if v['request']['url'].startswith(Copilot.url):
try:
if "authorization" in v_headers:
api_key = v_headers["authorization"].split(maxsplit=1).pop()
except Exception as e:
debug.log(f"Error on read headers: {e}")
if v['request']['url'].startswith(url):
v_headers = get_headers(v)
if "authorization" in v_headers:
api_key = v_headers["authorization"].split(maxsplit=1).pop()
if v['request']['cookies']:
cookies = {c['name']: c['value'] for c in v['request']['cookies']}
if api_key is None:
raise NoValidHarFileError("No access token found in .har files")

return api_key, cookies
return api_key, cookies
Loading

0 comments on commit 79c407b

Please sign in to comment.