Skip to content

Commit

Permalink
Fix deprecated construct method, fix unittests
Browse files Browse the repository at this point in the history
  • Loading branch information
hlohaus committed Dec 7, 2024
1 parent 6a624ac commit 20ad080
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 32 deletions.
2 changes: 1 addition & 1 deletion etc/unittest/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def test_get_models(self):

def test_get_providers(self):
response = self.api.get_providers()
self.assertIsInstance(response, dict)
self.assertIsInstance(response, list)
self.assertTrue(len(response) > 0)

def test_search(self):
Expand Down
20 changes: 10 additions & 10 deletions g4f/client/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def iter_response(
finish_reason = "stop"

if stream:
yield ChatCompletionChunk.construct(chunk, None, completion_id, int(time.time()))
yield ChatCompletionChunk.model_construct(chunk, None, completion_id, int(time.time()))

if finish_reason is not None:
break
Expand All @@ -84,12 +84,12 @@ def iter_response(
finish_reason = "stop" if finish_reason is None else finish_reason

if stream:
yield ChatCompletionChunk.construct(None, finish_reason, completion_id, int(time.time()))
yield ChatCompletionChunk.model_construct(None, finish_reason, completion_id, int(time.time()))
else:
if response_format is not None and "type" in response_format:
if response_format["type"] == "json_object":
content = filter_json(content)
yield ChatCompletion.construct(content, finish_reason, completion_id, int(time.time()))
yield ChatCompletion.model_construct(content, finish_reason, completion_id, int(time.time()))

# Synchronous iter_append_model_and_provider function
def iter_append_model_and_provider(response: ChatCompletionResponseType) -> ChatCompletionResponseType:
Expand Down Expand Up @@ -138,20 +138,20 @@ async def async_iter_response(
finish_reason = "stop"

if stream:
yield ChatCompletionChunk.construct(chunk, None, completion_id, int(time.time()))
yield ChatCompletionChunk.model_construct(chunk, None, completion_id, int(time.time()))

if finish_reason is not None:
break

finish_reason = "stop" if finish_reason is None else finish_reason

if stream:
yield ChatCompletionChunk.construct(None, finish_reason, completion_id, int(time.time()))
yield ChatCompletionChunk.model_construct(None, finish_reason, completion_id, int(time.time()))
else:
if response_format is not None and "type" in response_format:
if response_format["type"] == "json_object":
content = filter_json(content)
yield ChatCompletion.construct(content, finish_reason, completion_id, int(time.time()))
yield ChatCompletion.model_construct(content, finish_reason, completion_id, int(time.time()))
finally:
await safe_aclose(response)

Expand Down Expand Up @@ -422,19 +422,19 @@ async def _process_image_response(
last_provider = get_last_provider(True)
if response_format == "url":
# Return original URLs without saving locally
images = [Image.construct(url=image, revised_prompt=response.alt) for image in response.get_list()]
images = [Image.model_construct(url=image, revised_prompt=response.alt) for image in response.get_list()]
else:
# Save locally for None (default) case
images = await copy_images(response.get_list(), response.get("cookies"), proxy)
if response_format == "b64_json":
async def process_image_item(image_file: str) -> Image:
with open(os.path.join(images_dir, os.path.basename(image_file)), "rb") as file:
image_data = base64.b64encode(file.read()).decode()
return Image.construct(b64_json=image_data, revised_prompt=response.alt)
return Image.model_construct(b64_json=image_data, revised_prompt=response.alt)
images = await asyncio.gather(*[process_image_item(image) for image in images])
else:
images = [Image.construct(url=f"/images/{os.path.basename(image)}", revised_prompt=response.alt) for image in images]
return ImagesResponse.construct(
images = [Image.model_construct(url=f"/images/{os.path.basename(image)}", revised_prompt=response.alt) for image in images]
return ImagesResponse.model_construct(
created=int(time.time()),
data=images,
model=last_provider.get("model") if model is None else model,
Expand Down
49 changes: 28 additions & 21 deletions g4f/client/stubs.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
except ImportError:
class BaseModel():
@classmethod
def construct(cls, **data):
def model_construct(cls, **data):
new = cls()
for key, value in data.items():
setattr(new, key, value)
Expand All @@ -19,6 +19,13 @@ class Field():
def __init__(self, **config):
pass

class BaseModel(BaseModel):
@classmethod
def model_construct(cls, **data):
if hasattr(super(), "model_construct"):
return super().model_construct(**data)
return cls.construct(**data)

class ChatCompletionChunk(BaseModel):
id: str
object: str
Expand All @@ -28,21 +35,21 @@ class ChatCompletionChunk(BaseModel):
choices: List[ChatCompletionDeltaChoice]

@classmethod
def construct(
def model_construct(
cls,
content: str,
finish_reason: str,
completion_id: str = None,
created: int = None
):
return super().construct(
return super().model_construct(
id=f"chatcmpl-{completion_id}" if completion_id else None,
object="chat.completion.cunk",
created=created,
model=None,
provider=None,
choices=[ChatCompletionDeltaChoice.construct(
ChatCompletionDelta.construct(content),
choices=[ChatCompletionDeltaChoice.model_construct(
ChatCompletionDelta.model_construct(content),
finish_reason
)]
)
Expand All @@ -52,17 +59,17 @@ class ChatCompletionMessage(BaseModel):
content: str

@classmethod
def construct(cls, content: str):
return super().construct(role="assistant", content=content)
def model_construct(cls, content: str):
return super().model_construct(role="assistant", content=content)

class ChatCompletionChoice(BaseModel):
index: int
message: ChatCompletionMessage
finish_reason: str

@classmethod
def construct(cls, message: ChatCompletionMessage, finish_reason: str):
return super().construct(index=0, message=message, finish_reason=finish_reason)
def model_construct(cls, message: ChatCompletionMessage, finish_reason: str):
return super().model_construct(index=0, message=message, finish_reason=finish_reason)

class ChatCompletion(BaseModel):
id: str
Expand All @@ -78,21 +85,21 @@ class ChatCompletion(BaseModel):
}])

@classmethod
def construct(
def model_construct(
cls,
content: str,
finish_reason: str,
completion_id: str = None,
created: int = None
):
return super().construct(
return super().model_construct(
id=f"chatcmpl-{completion_id}" if completion_id else None,
object="chat.completion",
created=created,
model=None,
provider=None,
choices=[ChatCompletionChoice.construct(
ChatCompletionMessage.construct(content),
choices=[ChatCompletionChoice.model_construct(
ChatCompletionMessage.model_construct(content),
finish_reason
)],
usage={
Expand All @@ -107,26 +114,26 @@ class ChatCompletionDelta(BaseModel):
content: str

@classmethod
def construct(cls, content: Optional[str]):
return super().construct(role="assistant", content=content)
def model_construct(cls, content: Optional[str]):
return super().model_construct(role="assistant", content=content)

class ChatCompletionDeltaChoice(BaseModel):
index: int
delta: ChatCompletionDelta
finish_reason: Optional[str]

@classmethod
def construct(cls, delta: ChatCompletionDelta, finish_reason: Optional[str]):
return super().construct(index=0, delta=delta, finish_reason=finish_reason)
def model_construct(cls, delta: ChatCompletionDelta, finish_reason: Optional[str]):
return super().model_construct(index=0, delta=delta, finish_reason=finish_reason)

class Image(BaseModel):
url: Optional[str]
b64_json: Optional[str]
revised_prompt: Optional[str]

@classmethod
def construct(cls, url: str = None, b64_json: str = None, revised_prompt: str = None):
return super().construct(**filter_none(
def model_construct(cls, url: str = None, b64_json: str = None, revised_prompt: str = None):
return super().model_construct(**filter_none(
url=url,
b64_json=b64_json,
revised_prompt=revised_prompt
Expand All @@ -139,10 +146,10 @@ class ImagesResponse(BaseModel):
created: int

@classmethod
def construct(cls, data: List[Image], created: int = None, model: str = None, provider: str = None):
def model_construct(cls, data: List[Image], created: int = None, model: str = None, provider: str = None):
if created is None:
created = int(time())
return super().construct(
return super().model_construct(
data=data,
model=model,
provider=provider,
Expand Down

0 comments on commit 20ad080

Please sign in to comment.