Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Better extract to table support, image support #11

Merged
merged 10 commits
Mar 7, 2024
193 changes: 156 additions & 37 deletions datasette_extract/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import asyncio
import base64
from datasette import hookimpl, Response, NotFound
from datetime import datetime, timezone
from openai import AsyncOpenAI, OpenAIError
from sqlite_utils import Database
import click
from starlette.requests import Request as StarletteRequest
import ijson
import json
import ulid
Expand All @@ -24,15 +26,16 @@
"""


async def extract_create_table(datasette, request):
async def extract_create_table(datasette, request, scope, receive):
database = request.url_vars["database"]
try:
db = datasette.get_database(database)
except KeyError:
raise NotFound("Database '{}' does not exist".format(database))

if request.method == "POST":
post_vars = await request.post_vars()
starlette_request = StarletteRequest(scope, receive)
post_vars = await starlette_request.form()
content = (post_vars.get("content") or "").strip()
if not content:
return Response.text("No content provided", status=400)
Expand All @@ -53,8 +56,10 @@ async def extract_create_table(datasette, request):
if hint:
properties[value]["description"] = hint

image = post_vars.get("image") or ""

return await extract_to_table_post(
datasette, request, content, database, table, properties
datasette, request, content, image, database, table, properties
)

return Response.html(
Expand All @@ -69,7 +74,7 @@ async def extract_create_table(datasette, request):
)


async def extract_to_table(datasette, request):
async def extract_to_table(datasette, request, scope, receive):
database = request.url_vars["database"]
table = request.url_vars["table"]
# Do they exist?
Expand All @@ -84,34 +89,89 @@ async def extract_to_table(datasette, request):
schema = await db.execute_fn(lambda conn: Database(conn)[table].columns_dict)

if request.method == "POST":
# Turn schema into a properties dict
properties = {
name: {
"type": get_type(type_),
# "description": "..."
}
for name, type_ in schema.items()
starlette_request = StarletteRequest(scope, receive)
post_vars = await starlette_request.form()

# We only use columns that have their use_{colname} set
use_columns = [
key[len("use_") :]
for key, value in post_vars.items()
if key.startswith("use_") and value
]

# Grab all of the hints
column_hints = {
key[len("hint_") :]: value.strip()
for key, value in post_vars.items()
if key.startswith("hint_") and value.strip()
}
post_vars = await request.post_vars()
# Turn schema into a properties dict
properties = {}
for name, type_ in schema.items():
if name in use_columns:
properties[name] = {"type": get_type(type_)}
description = column_hints.get(name) or ""
if description:
properties[name]["description"] = description

image = post_vars.get("image") or ""
content = (post_vars.get("content") or "").strip()
return await extract_to_table_post(
datasette, request, content, database, table, properties
datasette, request, content, image, database, table, properties
)

# Restore properties from previous run, if possible
previous_runs = []
if await db.table_exists("_datasette_extract"):
previous_runs = [
dict(row)
for row in (
await db.execute(
"""
select id, database_name, table_name, created, properties, completed, error, num_items
from _datasette_extract
where database_name = :database_name and table_name = :table_name
order by id desc limit 20
""",
{"database_name": database, "table_name": table},
)
).rows
]

columns = [
{"name": name, "type": value, "hint": "", "checked": True}
for name, value in schema.items()
]

# If there are previous runs, use the properties from the last one to update columns
if previous_runs:
properties = json.loads(previous_runs[0]["properties"])
print(properties)
for column in columns:
column_name = column["name"]
column["checked"] = column_name in properties
column["hint"] = (properties.get(column_name) or {}).get(
"description"
) or ""

return Response.html(
await datasette.render_template(
"extract_to_table.html",
{
"database": database,
"table": table,
"schema": schema,
"columns": columns,
"previous_runs": previous_runs,
},
request=request,
)
)


async def extract_table_task(datasette, database, table, properties, content, task_id):
async def extract_table_task(
datasette, database, table, properties, content, image, task_id
):
# This task runs in the background
events = ijson.sendable_list()
coro = ijson.items_coro(events, "items.item")
Expand All @@ -129,9 +189,30 @@ async def extract_table_task(datasette, database, table, properties, content, ta
}
datasette._extract_tasks[task_id] = task_info

# We record tasks to the _datasette_extract table, mainly so we can reuse
# property definitions later on
def start_write(conn):
    """Insert a row describing this extraction task into _datasette_extract.

    Runs on Datasette's write connection; the ``with conn`` block wraps the
    insert in a transaction. ``completed``/``error`` start as None and
    ``num_items`` as 0 — they are filled in when the task finishes.
    """
    row = {
        "id": task_id,
        "database_name": database,
        "table_name": table,
        # Store an ISO-ish UTC timestamp as text for SQLite
        "created": datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S"),
        # Persist the property definitions so later runs can reuse them
        "properties": json.dumps(properties),
        "completed": None,
        "error": None,
        "num_items": 0,
    }
    with conn:
        Database(conn)["_datasette_extract"].insert(row, pk="id")

async_client = AsyncOpenAI()
db = datasette.get_database(database)

await db.execute_write_fn(start_write)

def make_row_writer(row):
def _write(conn):
with conn:
Expand All @@ -140,11 +221,46 @@ def _write(conn):

return _write

error = None

async def ocr_image(image_bytes):
    """Extract text from an image using the GPT-4 vision model.

    The raw bytes are base64-encoded into a data: URL and sent alongside a
    system prompt asking for an OCR transcription. Returns the model's text
    response (which may be None if the model returned no content).
    """
    encoded = base64.b64encode(image_bytes).decode("utf-8")
    data_url = f"data:image/jpeg;base64,{encoded}"
    system_message = {
        "role": "system",
        "content": "Run OCR and return all of the text in this image, with newlines where appropriate",
    }
    user_message = {
        "role": "user",
        "content": [
            {
                "type": "image_url",
                "image_url": {"url": data_url},
            }
        ],
    }
    response = await async_client.chat.completions.create(
        model="gpt-4-vision-preview",
        messages=[system_message, user_message],
        max_tokens=400,
    )
    return response.choices[0].message.content

try:
messages = []
if content:
messages.append({"role": "user", "content": content})
if image:
# Run a separate thing to OCR the image first, because gpt-4-vision can't handle tools yet
image_content = await ocr_image(await image.read())
if image_content:
messages.append({"role": "user", "content": image_content})
else:
raise ValueError("Could not extract text from image")

async for chunk in await async_client.chat.completions.create(
stream=True,
model="gpt-4-turbo-preview",
messages=[{"role": "user", "content": content}],
messages=messages,
tools=[
{
"type": "function",
Expand All @@ -169,6 +285,7 @@ def _write(conn):
},
],
tool_choice={"type": "function", "function": {"name": "extract_data"}},
max_tokens=4096,
):
try:
content = chunk.choices[0].delta.tool_calls[0].function.arguments
Expand All @@ -187,23 +304,41 @@ def _write(conn):
items.append(event)
await db.execute_write_fn(make_row_writer(event))

except OpenAIError as ex:
except Exception as ex:
task_info["error"] = str(ex)
return
error = str(ex)
finally:
task_info["done"] = True

def end_write(conn):
    """Mark this task's _datasette_extract row as finished.

    Runs in the ``finally`` of the extraction task, so it fires whether the
    run succeeded or failed: ``error`` is None on success, otherwise the
    stringified exception captured by the except handler.
    """
    # ``with conn`` wraps the update in a transaction on the write connection
    with conn:
        db = Database(conn)
        db["_datasette_extract"].update(
            task_id,
            {
                # UTC completion timestamp, stored as text for SQLite
                "completed": datetime.now(timezone.utc).strftime(
                    "%Y-%m-%d %H:%M:%S"
                ),
                # How many rows were successfully extracted before finishing
                "num_items": len(items),
                "error": error,
            },
        )

await db.execute_write_fn(end_write)


async def extract_to_table_post(
datasette, request, content, database, table, properties
datasette, request, content, image, database, table, properties
):
# Here we go!
if not content:
if not content and not image:
return Response.text("No content provided")

task_id = str(ulid.ULID())
asyncio.create_task(
extract_table_task(datasette, database, table, properties, content, task_id)
extract_table_task(
datasette, database, table, properties, content, image, task_id
)
)
return Response.redirect(
datasette.urls.path("/-/extract/progress/{}".format(task_id))
Expand Down Expand Up @@ -240,22 +375,6 @@ async def extract_progress_json(datasette, request):
return Response.json(task_info)


@click.command()
@click.argument(
    "database",
    required=True,
    type=click.Path(dir_okay=False, file_okay=True, allow_dash=False),
)
@click.argument("table", required=True)
def extract(database, table):
    """Announce the planned extraction target.

    Placeholder CLI command: it only echoes which table in which database
    file would be extracted to; no extraction is performed here.
    """
    message = "Will extract to {} in {}".format(table, database)
    click.echo(message)


@hookimpl
def register_commands(cli):
    """Datasette plugin hook: attach the `extract` click command to the CLI,
    exposed as `datasette extract`."""
    cli.add_command(extract, name="extract")


@hookimpl
def register_routes():
return [
Expand Down
11 changes: 11 additions & 0 deletions datasette_extract/static/heic2any-0.0.4.min.js

Large diffs are not rendered by default.

Loading
Loading