Skip to content

Commit

Permalink
Structured outputs support with examples (#354)
Browse files Browse the repository at this point in the history
  • Loading branch information
ParthSareen authored Dec 5, 2024
1 parent e956a33 commit 4b10dee
Show file tree
Hide file tree
Showing 8 changed files with 355 additions and 18 deletions.
3 changes: 0 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,6 @@ See [_types.py](ollama/_types.py) for more information on the response types.

Response streaming can be enabled by setting `stream=True`.

> [!NOTE]
> Streaming Tool/Function calling is not yet supported.
```python
from ollama import chat

Expand Down
6 changes: 6 additions & 0 deletions examples/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,12 @@ python3 examples/<example>.py
- [multimodal_generate.py](multimodal_generate.py)


### Structured Outputs - Generate structured outputs with a model
- [structured-outputs.py](structured-outputs.py)
- [async-structured-outputs.py](async-structured-outputs.py)
- [structured-outputs-image.py](structured-outputs-image.py)


### Ollama List - List all downloaded models and their properties
- [list.py](list.py)

Expand Down
32 changes: 32 additions & 0 deletions examples/async-structured-outputs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
from pydantic import BaseModel
from ollama import AsyncClient
import asyncio


# Define the schema for the response
class FriendInfo(BaseModel):
  """Schema for one friend entry parsed from the model's JSON response."""

  name: str  # friend's name as free text
  age: int  # age in whole years
  is_available: bool  # whether the friend is free to hang out


class FriendList(BaseModel):
  """Top-level response schema: a list of FriendInfo records under 'friends'."""

  friends: list[FriendInfo]


async def main():
  """Request friend data from the model and validate the reply with Pydantic."""
  prompt = 'I have two friends. The first is Ollama 22 years old busy saving the world, and the second is Alonso 23 years old and wants to hang out. Return a list of friends in JSON format'
  reply = await AsyncClient().chat(
    model='llama3.1:8b',
    messages=[{'role': 'user', 'content': prompt}],
    format=FriendList.model_json_schema(),  # Pydantic generates the JSON schema for us
    options={'temperature': 0},  # temperature 0 keeps the output deterministic
  )

  # Parse and validate the raw JSON text against the FriendList schema
  print(FriendList.model_validate_json(reply.message.content))


if __name__ == '__main__':
  asyncio.run(main())
50 changes: 50 additions & 0 deletions examples/structured-outputs-image.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
from pathlib import Path
from pydantic import BaseModel
from typing import List, Optional, Literal
from ollama import chat
from rich import print


# Define the schema for image objects
class Object(BaseModel):
  """Schema for a single object detected in the image."""

  name: str  # label for the detected object
  confidence: float  # model-reported confidence score
  attributes: Optional[dict] = None  # free-form extra attributes, if any


class ImageDescription(BaseModel):
  """Structured description of an image returned by the vision model."""

  summary: str  # one-paragraph overview of the image
  objects: List[Object]  # objects detected in the scene
  scene: str  # short description of the overall scene
  colors: List[str]  # dominant colors
  time_of_day: Literal['Morning', 'Afternoon', 'Evening', 'Night']
  setting: Literal['Indoor', 'Outdoor', 'Unknown']
  text_content: Optional[str] = None  # any text detected in the image


# Get path from user input.
# Strip surrounding whitespace and quotes (common when a path is pasted from a
# file manager) and expand a leading '~' so the existence check below does not
# reject otherwise-valid paths.
path = Path(input('Enter the path to your image: ').strip().strip('"\'')).expanduser()

# Fail fast with a clear error before making the request
if not path.exists():
  raise FileNotFoundError(f'Image not found at: {path}')

# Set up chat as usual; format= passes the JSON schema the model must follow
response = chat(
  model='llama3.2-vision',
  format=ImageDescription.model_json_schema(),  # Pass in the schema for the response
  messages=[
    {
      'role': 'user',
      'content': 'Analyze this image and return a detailed JSON description including objects, scene, colors and any text detected. If you cannot determine certain details, leave those fields empty.',
      'images': [path],
    },
  ],
  options={'temperature': 0},  # Set temperature to 0 for more deterministic output
)


# Convert received content to the schema (raises ValidationError on mismatch)
image_analysis = ImageDescription.model_validate_json(response.message.content)
print(image_analysis)
26 changes: 26 additions & 0 deletions examples/structured-outputs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
from ollama import chat
from pydantic import BaseModel


# Define the schema for the response
class FriendInfo(BaseModel):
  """Schema for one friend entry parsed from the model's JSON response."""

  name: str  # friend's name as free text
  age: int  # age in whole years
  is_available: bool  # whether the friend is free to hang out


class FriendList(BaseModel):
  """Top-level response schema: a list of FriendInfo records under 'friends'."""

  friends: list[FriendInfo]


# A hand-written JSON schema dict works too, passed as format=schema:
# schema = {'type': 'object', 'properties': {'friends': {'type': 'array', 'items': {'type': 'object', 'properties': {'name': {'type': 'string'}, 'age': {'type': 'integer'}, 'is_available': {'type': 'boolean'}}, 'required': ['name', 'age', 'is_available']}}}, 'required': ['friends']}
prompt = 'I have two friends. The first is Ollama 22 years old busy saving the world, and the second is Alonso 23 years old and wants to hang out. Return a list of friends in JSON format'
response = chat(
  model='llama3.1:8b',
  messages=[{'role': 'user', 'content': prompt}],
  format=FriendList.model_json_schema(),  # Pydantic generates the schema; or format=schema for the raw dict above
  options={'temperature': 0},  # temperature 0 keeps the output deterministic
)

# Validate the raw JSON text against the FriendList schema, then show it
friends_response = FriendList.model_validate_json(response.message.content)
print(friends_response)
27 changes: 14 additions & 13 deletions ollama/_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@

import sys

from pydantic.json_schema import JsonSchemaValue


from ollama._utils import convert_function_to_tool

Expand Down Expand Up @@ -186,7 +188,7 @@ def generate(
context: Optional[Sequence[int]] = None,
stream: Literal[False] = False,
raw: bool = False,
format: Optional[Literal['', 'json']] = None,
format: Optional[Union[Literal['json'], JsonSchemaValue]] = None,
images: Optional[Sequence[Union[str, bytes]]] = None,
options: Optional[Union[Mapping[str, Any], Options]] = None,
keep_alive: Optional[Union[float, str]] = None,
Expand All @@ -204,7 +206,7 @@ def generate(
context: Optional[Sequence[int]] = None,
stream: Literal[True] = True,
raw: bool = False,
format: Optional[Literal['', 'json']] = None,
format: Optional[Union[Literal['json'], JsonSchemaValue]] = None,
images: Optional[Sequence[Union[str, bytes]]] = None,
options: Optional[Union[Mapping[str, Any], Options]] = None,
keep_alive: Optional[Union[float, str]] = None,
Expand All @@ -221,7 +223,7 @@ def generate(
context: Optional[Sequence[int]] = None,
stream: bool = False,
raw: Optional[bool] = None,
format: Optional[Literal['', 'json']] = None,
format: Optional[Union[Literal['json'], JsonSchemaValue]] = None,
images: Optional[Sequence[Union[str, bytes]]] = None,
options: Optional[Union[Mapping[str, Any], Options]] = None,
keep_alive: Optional[Union[float, str]] = None,
Expand Down Expand Up @@ -265,7 +267,7 @@ def chat(
*,
tools: Optional[Sequence[Union[Mapping[str, Any], Tool, Callable]]] = None,
stream: Literal[False] = False,
format: Optional[Literal['', 'json']] = None,
format: Optional[Union[Literal['json'], JsonSchemaValue]] = None,
options: Optional[Union[Mapping[str, Any], Options]] = None,
keep_alive: Optional[Union[float, str]] = None,
) -> ChatResponse: ...
Expand All @@ -278,7 +280,7 @@ def chat(
*,
tools: Optional[Sequence[Union[Mapping[str, Any], Tool, Callable]]] = None,
stream: Literal[True] = True,
format: Optional[Literal['', 'json']] = None,
format: Optional[Union[Literal['json'], JsonSchemaValue]] = None,
options: Optional[Union[Mapping[str, Any], Options]] = None,
keep_alive: Optional[Union[float, str]] = None,
) -> Iterator[ChatResponse]: ...
Expand All @@ -290,7 +292,7 @@ def chat(
*,
tools: Optional[Sequence[Union[Mapping[str, Any], Tool, Callable]]] = None,
stream: bool = False,
format: Optional[Literal['', 'json']] = None,
format: Optional[Union[Literal['json'], JsonSchemaValue]] = None,
options: Optional[Union[Mapping[str, Any], Options]] = None,
keep_alive: Optional[Union[float, str]] = None,
) -> Union[ChatResponse, Iterator[ChatResponse]]:
Expand Down Expand Up @@ -327,7 +329,6 @@ def add_two_numbers(a: int, b: int) -> int:
Returns `ChatResponse` if `stream` is `False`, otherwise returns a `ChatResponse` generator.
"""

return self._request(
ChatResponse,
'POST',
Expand Down Expand Up @@ -689,7 +690,7 @@ async def generate(
context: Optional[Sequence[int]] = None,
stream: Literal[False] = False,
raw: bool = False,
format: Optional[Literal['', 'json']] = None,
format: Optional[Union[Literal['', 'json'], JsonSchemaValue]] = None,
images: Optional[Sequence[Union[str, bytes]]] = None,
options: Optional[Union[Mapping[str, Any], Options]] = None,
keep_alive: Optional[Union[float, str]] = None,
Expand All @@ -707,7 +708,7 @@ async def generate(
context: Optional[Sequence[int]] = None,
stream: Literal[True] = True,
raw: bool = False,
format: Optional[Literal['', 'json']] = None,
format: Optional[Union[Literal['', 'json'], JsonSchemaValue]] = None,
images: Optional[Sequence[Union[str, bytes]]] = None,
options: Optional[Union[Mapping[str, Any], Options]] = None,
keep_alive: Optional[Union[float, str]] = None,
Expand All @@ -724,7 +725,7 @@ async def generate(
context: Optional[Sequence[int]] = None,
stream: bool = False,
raw: Optional[bool] = None,
format: Optional[Literal['', 'json']] = None,
format: Optional[Union[Literal['', 'json'], JsonSchemaValue]] = None,
images: Optional[Sequence[Union[str, bytes]]] = None,
options: Optional[Union[Mapping[str, Any], Options]] = None,
keep_alive: Optional[Union[float, str]] = None,
Expand Down Expand Up @@ -767,7 +768,7 @@ async def chat(
*,
tools: Optional[Sequence[Union[Mapping[str, Any], Tool, Callable]]] = None,
stream: Literal[False] = False,
format: Optional[Literal['', 'json']] = None,
format: Optional[Union[Literal['', 'json'], JsonSchemaValue]] = None,
options: Optional[Union[Mapping[str, Any], Options]] = None,
keep_alive: Optional[Union[float, str]] = None,
) -> ChatResponse: ...
Expand All @@ -780,7 +781,7 @@ async def chat(
*,
tools: Optional[Sequence[Union[Mapping[str, Any], Tool, Callable]]] = None,
stream: Literal[True] = True,
format: Optional[Literal['', 'json']] = None,
format: Optional[Union[Literal['', 'json'], JsonSchemaValue]] = None,
options: Optional[Union[Mapping[str, Any], Options]] = None,
keep_alive: Optional[Union[float, str]] = None,
) -> AsyncIterator[ChatResponse]: ...
Expand All @@ -792,7 +793,7 @@ async def chat(
*,
tools: Optional[Sequence[Union[Mapping[str, Any], Tool, Callable]]] = None,
stream: bool = False,
format: Optional[Literal['', 'json']] = None,
format: Optional[Union[Literal['', 'json'], JsonSchemaValue]] = None,
options: Optional[Union[Mapping[str, Any], Options]] = None,
keep_alive: Optional[Union[float, str]] = None,
) -> Union[ChatResponse, AsyncIterator[ChatResponse]]:
Expand Down
3 changes: 2 additions & 1 deletion ollama/_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from datetime import datetime
from typing import Any, Mapping, Optional, Union, Sequence

from pydantic.json_schema import JsonSchemaValue
from typing_extensions import Annotated, Literal

from pydantic import (
Expand Down Expand Up @@ -150,7 +151,7 @@ class BaseGenerateRequest(BaseStreamableRequest):
options: Optional[Union[Mapping[str, Any], Options]] = None
'Options to use for the request.'

format: Optional[Literal['', 'json']] = None
format: Optional[Union[Literal['json'], JsonSchemaValue]] = None
'Format of the response.'

keep_alive: Optional[Union[float, str]] = None
Expand Down
Loading

0 comments on commit 4b10dee

Please sign in to comment.