Skip to content

Commit

Permalink
Remove webview js api, Add unittest for provider has model, Use cooki… (
Browse files Browse the repository at this point in the history
#2470)

* Remove webview js api, Add unittest for provider has model, Use cookies dir for cache
  • Loading branch information
hlohaus authored Dec 8, 2024
1 parent a358b28 commit 76c3683
Show file tree
Hide file tree
Showing 18 changed files with 63 additions and 293 deletions.
2 changes: 1 addition & 1 deletion etc/tool/create_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ async def create_async_generator(
print("Create code...")
response = []
for chunk in g4f.ChatCompletion.create(
model=g4f.models.default,
model=g4f.models.gpt_4o,
messages=[{"role": "user", "content": prompt}],
timeout=300,
stream=True,
Expand Down
1 change: 1 addition & 0 deletions etc/unittest/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,5 +8,6 @@
from .image_client import *
from .include import *
from .retry_provider import *
from .models import *

unittest.main()
8 changes: 4 additions & 4 deletions etc/unittest/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,24 +4,24 @@
from .mocks import ModelProviderMock

# Single-turn conversation used by every test case below.
DEFAULT_MESSAGES = [{'role': 'user', 'content': 'Hello'}]

# Fixture model whose best_provider is a mock that echoes the model name
# back as the response, letting the tests verify model resolution.
test_model = g4f.models.Model(
    name = "test/test_model",
    base_provider = "",
    best_provider = ModelProviderMock
)
# Register the fixture under a short key so it can be looked up by name.
g4f.models.ModelUtils.convert["test_model"] = test_model

class TestPassModel(unittest.TestCase):
    """Verify that a model can be selected by instance, by its registered
    short name, or by its full name with an explicit provider."""

    def _assert_model_used(self, response):
        # The mock provider echoes the model name, so a matching response
        # proves the requested model actually reached the provider.
        self.assertEqual(test_model.name, response)

    def test_model_instance(self):
        self._assert_model_used(
            ChatCompletion.create(test_model, DEFAULT_MESSAGES))

    def test_model_name(self):
        self._assert_model_used(
            ChatCompletion.create("test_model", DEFAULT_MESSAGES))

    def test_model_pass(self):
        self._assert_model_used(
            ChatCompletion.create("test/test_model", DEFAULT_MESSAGES, ModelProviderMock))
23 changes: 23 additions & 0 deletions etc/unittest/models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
import unittest
from typing import Type
import asyncio

from g4f.models import __models__
from g4f.providers.base_provider import BaseProvider, ProviderModelMixin
from g4f.models import Model

class TestProviderHasModel(unittest.IsolatedAsyncioTestCase):
    """Check that every provider listed for a model actually advertises
    that model in its get_models() output."""

    # Per-provider-class cache of get_models() results, so each provider's
    # model list is fetched at most once across all assertions.
    cache: dict = {}

    async def test_provider_has_model(self):
        for model, providers in __models__.values():
            for provider in providers:
                # Only providers with a model list can be checked.
                if not issubclass(provider, ProviderModelMixin):
                    continue
                # Aliased names map to a provider-specific model id; the
                # canonical name is not expected in the provider's list.
                if model.name in provider.model_aliases:
                    continue
                await asyncio.wait_for(self.provider_has_model(provider, model), 10)

    async def provider_has_model(self, provider: Type[BaseProvider], model: Model):
        key = provider.__name__
        if key not in self.cache:
            self.cache[key] = provider.get_models()
        models = self.cache[key]
        # An empty list means the provider could not report its models
        # (e.g. offline); skip the assertion rather than fail spuriously.
        if models:
            self.assertIn(model.name, models, key)
2 changes: 1 addition & 1 deletion g4f/Provider/Airforce.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ class Airforce(AsyncGeneratorProvider, ProviderModelMixin):
"evil": "any-uncensored",
"sdxl": "stable-diffusion-xl-base",
"flux-pro": "flux-1.1-pro",
"llama-3.1-8b": "llama-3.1-8b-chat"
}

@classmethod
Expand All @@ -85,7 +86,6 @@ def get_models(cls):
cls.models = [model for model in cls.models if model not in cls.hidden_models]
except Exception as e:
debug.log(f"Error fetching text models: {e}")
cls.models = [cls.default_model]

return cls.models

Expand Down
21 changes: 6 additions & 15 deletions g4f/Provider/Blackbox.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,21 +7,20 @@
import re
import aiohttp

import os
import json
from pathlib import Path

from ..typing import AsyncResult, Messages, ImageType
from .base_provider import AsyncGeneratorProvider, ProviderModelMixin
from ..image import ImageResponse, to_data_uri

from ..cookies import get_cookies_dir
from .helper import format_prompt

class Blackbox(AsyncGeneratorProvider, ProviderModelMixin):
label = "Blackbox AI"
url = "https://www.blackbox.ai"
api_endpoint = "https://www.blackbox.ai/api/chat"

working = True
supports_stream = True
supports_system_message = True
Expand All @@ -38,7 +37,7 @@ class Blackbox(AsyncGeneratorProvider, ProviderModelMixin):
agentMode = {
'ImageGeneration': {'mode': True, 'id': "ImageGenerationLV45LJp", 'name': "Image Generation"}
}

trendingAgentMode = {
"gemini-1.5-flash": {'mode': True, 'id': 'Gemini'},
"llama-3.1-8b": {'mode': True, 'id': "llama-3.1-8b"},
Expand Down Expand Up @@ -108,19 +107,11 @@ class Blackbox(AsyncGeneratorProvider, ProviderModelMixin):
"flux": "ImageGeneration",
}

@classmethod
def _get_cache_dir(cls) -> Path:
# Get the path to the current file
current_file = Path(__file__)
# Create the path to the .cache directory
cache_dir = current_file.parent / '.cache'
# Create a directory if it does not exist
cache_dir.mkdir(exist_ok=True)
return cache_dir

@classmethod
def _get_cache_file(cls) -> Path:
return cls._get_cache_dir() / 'blackbox.json'
dir = Path(get_cookies_dir())
dir.mkdir(exist_ok=True)
return dir / 'blackbox.json'

@classmethod
def _load_cached_value(cls) -> str | None:
Expand Down
2 changes: 1 addition & 1 deletion g4f/Provider/Flux.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ class Flux(AsyncGeneratorProvider, ProviderModelMixin):
url = "https://black-forest-labs-flux-1-dev.hf.space"
api_endpoint = "/gradio_api/call/infer"
working = True
default_model = 'flux-1-dev'
default_model = 'flux-dev'
models = [default_model]
image_models = [default_model]

Expand Down
3 changes: 2 additions & 1 deletion g4f/Provider/RobocodersAPI.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from aiohttp import ClientTimeout
from ..errors import MissingRequirementsError
from ..typing import AsyncResult, Messages
from ..cookies import get_cookies_dir
from .base_provider import AsyncGeneratorProvider, ProviderModelMixin
from .helper import format_prompt

Expand All @@ -30,7 +31,7 @@ class RobocodersAPI(AsyncGeneratorProvider, ProviderModelMixin):
agent = [default_model, "RepoAgent", "FrontEndAgent"]
models = [*agent]

CACHE_DIR = Path(__file__).parent / ".cache"
CACHE_DIR = Path(get_cookies_dir())
CACHE_FILE = CACHE_DIR / "robocoders.json"

@classmethod
Expand Down
5 changes: 4 additions & 1 deletion g4f/Provider/needs_auth/CopilotAccount.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,7 @@ class CopilotAccount(Copilot, ProviderModelMixin):
default_model = "Copilot"
default_vision_model = default_model
models = [default_model]
image_models = models
image_models = models
model_aliases = {
"dall-e-3": default_model
}
2 changes: 1 addition & 1 deletion g4f/Provider/needs_auth/MetaAI.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ class MetaAI(AsyncGeneratorProvider, ProviderModelMixin):
label = "Meta AI"
url = "https://www.meta.ai"
working = True
default_model = ''
default_model = 'meta-ai'

def __init__(self, proxy: str = None, connector: BaseConnector = None):
self.session = ClientSession(connector=get_connector(connector, proxy), headers=DEFAULT_HEADERS)
Expand Down
26 changes: 8 additions & 18 deletions g4f/gui/client/index.html
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
<!DOCTYPE html>
<html lang="en" data-framework="javascript">

<head>
<meta charset="UTF-8">
<meta http-equiv="X-UA-Compatible" content="IE=edge">
Expand Down Expand Up @@ -62,9 +61,8 @@
onInit: (el, pswp) => {
lightbox.pswp.on('change', () => {
const currSlideElement = lightbox.pswp.currSlide.data.element;
let captionHTML = '';
if (currSlideElement) {
el.innerHTML = currSlideElement.querySelector('img').getAttribute('alt');
el.innerText = currSlideElement.querySelector('img').getAttribute('alt');
}
});
}
Expand All @@ -80,7 +78,6 @@
<script>window.conversation_id = "{{chat_id}}"</script>
<title>g4f - gui</title>
</head>

<body>
<div class="gradient"></div>
<div class="row">
Expand All @@ -92,12 +89,6 @@
</button>
</div>
<div class="bottom_buttons">
<!--
<button onclick="open_album();">
<i class="fa-solid fa-toolbox"></i>
<span>Images Album</span>
</button>
-->
<button onclick="open_settings();">
<i class="fa-solid fa-toolbox"></i>
<span>Open Settings</span>
Expand All @@ -118,8 +109,6 @@
</div>
</div>
</div>
<div class="images hidden">
</div>
<div class="settings hidden">
<div class="paper">
<h3>Settings</h3>
Expand Down Expand Up @@ -151,14 +140,14 @@ <h3>Settings</h3>
<div class="field">
<span class="label">Auto continue in ChatGPT</span>
<input id="auto_continue" type="checkbox" name="auto_continue" checked/>
<label for="auto_continue" class="toogle" title="Continue large responses in OpenaiChat"></label>
<label for="auto_continue" class="toogle" title="Continue large responses in OpenAI ChatGPT"></label>
</div>
<div class="field box">
<label for="message-input-height" class="label" title="">Input max. height</label>
<input type="number" id="message-input-height" value="200"/>
</div>
<div class="field box">
<label for="recognition-language" class="label" title="">Speech recognition lang</label>
<label for="recognition-language" class="label" title="">Speech recognition language</label>
<input type="text" id="recognition-language" value="" placeholder="navigator.language"/>
</div>
<div class="field box">
Expand Down Expand Up @@ -250,7 +239,7 @@ <h3>Settings</h3>
<div class="box input-box">
<textarea id="message-input" placeholder="Ask a question" cols="30" rows="10"
style="white-space: pre-wrap;resize: none;"></textarea>
<label class="file-label image-label" for="image" title="Works with Bing, Gemini, OpenaiChat and You">
<label class="file-label image-label" for="image" title="">
<input type="file" id="image" name="image" accept="image/*" required/>
<i class="fa-regular fa-image"></i>
</label>
Expand Down Expand Up @@ -278,12 +267,13 @@ <h3>Settings</h3>
<option value="gpt-4o">gpt-4o</option>
<option value="gpt-4o-mini">gpt-4o-mini</option>
<option value="llama-3.1-70b">llama-3.1-70b</option>
<option value="llama-3.1-70b">llama-3.1-405b</option>
<option value="llama-3.1-70b">mixtral-8x7b</option>
<option value="llama-3.1-405b">llama-3.1-405b</option>
<option value="mixtral-8x7b">mixtral-8x7b</option>
<option value="gemini-pro">gemini-pro</option>
<option value="gemini-flash">gemini-flash</option>
<option value="claude-3-haiku">claude-3-haiku</option>
<option value="claude-3.5-sonnet">claude-3.5-sonnet</option>
<option value="flux">flux (Image Generation)</option>
<option value="dall-e-3">dall-e-3 (Image Generation)</option>
<option disabled="disabled">----</option>
</select>
<select name="model2" id="model2" class="hidden"></select>
Expand Down
49 changes: 5 additions & 44 deletions g4f/gui/client/static/js/chat.v1.js
Original file line number Diff line number Diff line change
Expand Up @@ -333,7 +333,7 @@ const handle_ask = async () => {
}
messageInput.value = "";
await count_input()
await add_conversation(window.conversation_id, message);
await add_conversation(window.conversation_id);

if ("text" in fileInput.dataset) {
message += '\n```' + fileInput.dataset.type + '\n';
Expand Down Expand Up @@ -544,20 +544,6 @@ async function add_message_chunk(message, message_id) {
}
}

cameraInput?.addEventListener("click", (e) => {
if (window?.pywebview) {
e.preventDefault();
pywebview.api.take_picture();
}
});

imageInput?.addEventListener("click", (e) => {
if (window?.pywebview) {
e.preventDefault();
pywebview.api.choose_image();
}
});

const ask_gpt = async (message_id, message_index = -1, regenerate = false, provider = null, model = null) => {
if (!model && !provider) {
model = get_selected_model()?.value || null;
Expand Down Expand Up @@ -861,7 +847,7 @@ const load_conversation = async (conversation_id, scroll=true) => {
if (window.GPTTokenizer_cl100k_base) {
const filtered = prepare_messages(messages, null);
if (filtered.length > 0) {
last_model = last_model?.startsWith("gpt-4") ? "gpt-4" : "gpt-3.5-turbo"
last_model = last_model?.startsWith("gpt-3") ? "gpt-3.5-turbo" : "gpt-4"
let count_total = GPTTokenizer_cl100k_base?.encodeChat(filtered, last_model).length
if (count_total > 0) {
elements += `<div class="count_total">(${count_total} tokens used)</div>`;
Expand Down Expand Up @@ -916,7 +902,7 @@ async function get_messages(conversation_id) {
return conversation?.items || [];
}

async function add_conversation(conversation_id, content) {
async function add_conversation(conversation_id) {
if (appStorage.getItem(`conversation:${conversation_id}`) == null) {
await save_conversation(conversation_id, {
id: conversation_id,
Expand Down Expand Up @@ -1134,17 +1120,6 @@ function open_settings() {
log_storage.classList.add("hidden");
}

function open_album() {
if (album.classList.contains("hidden")) {
sidebar.classList.remove("shown");
settings.classList.add("hidden");
album.classList.remove("hidden");
history.pushState({}, null, "/images/");
} else {
album.classList.add("hidden");
}
}

const register_settings_storage = async () => {
const optionElements = document.querySelectorAll(optionElementsSelector);
optionElements.forEach((element) => {
Expand Down Expand Up @@ -1277,18 +1252,12 @@ window.addEventListener('load', async function() {
await on_load();
if (window.conversation_id == "{{chat_id}}") {
window.conversation_id = uuid();
} else {
await on_api();
}
});

window.addEventListener('pywebviewready', async function() {
await on_api();
});

async function on_load() {
count_input();

if (/\/chat\/.+/.test(window.location.href)) {
load_conversation(window.conversation_id);
} else {
Expand Down Expand Up @@ -1334,7 +1303,7 @@ async function on_api() {
messageInput.addEventListener("keydown", async (evt) => {
if (prompt_lock) return;

// If not mobile
// If not mobile and not shift enter
if (!window.matchMedia("(pointer:coarse)").matches && evt.keyCode === 13 && !evt.shiftKey) {
evt.preventDefault();
console.log("pressed enter");
Expand Down Expand Up @@ -1396,6 +1365,7 @@ async function on_api() {
await load_provider_models(appStorage.getItem("provider"));
} catch (e) {
console.error(e)
// Redirect to show basic authentication
if (document.location.pathname == "/chat/") {
document.location.href = `/chat/error`;
}
Expand Down Expand Up @@ -1552,15 +1522,6 @@ function get_selected_model() {
}

async function api(ressource, args=null, file=null, message_id=null) {
if (window?.pywebview) {
if (args !== null) {
if (ressource == "models") {
ressource = "provider_models";
}
return pywebview.api[`get_${ressource}`](args);
}
return pywebview.api[`get_${ressource}`]();
}
let api_key;
if (ressource == "models" && args) {
api_key = get_api_key_by_provider(args);
Expand Down
Loading

0 comments on commit 76c3683

Please sign in to comment.