Skip to content

Commit

Permalink
Obey custom instructions, closes #17
Browse files: browse the repository at this point in the history
  • Loading branch information
simonw committed Mar 13, 2024
1 parent 68ce739 commit de32d79
Show file tree
Hide file tree
Showing 4 changed files with 52 additions and 6 deletions.
44 changes: 38 additions & 6 deletions datasette_extract/__init__.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -87,6 +87,7 @@ async def extract_create_table(datasette, request, scope, receive):
post_vars = await starlette_request.form()
content = (post_vars.get("content") or "").strip()
image = post_vars.get("image") or ""
instructions = post_vars.get("instructions") or ""
if not content and not image_is_provided(image):
return Response.text("No content provided", status=400)
table = post_vars.get("table")
Expand All @@ -107,7 +108,14 @@ async def extract_create_table(datasette, request, scope, receive):
properties[value]["description"] = hint

return await extract_to_table_post(
datasette, request, content, image, database, table, properties
datasette,
request,
instructions,
content,
image,
database,
table,
properties,
)

return Response.html(
Expand Down Expand Up @@ -167,9 +175,17 @@ async def extract_to_table(datasette, request, scope, receive):
properties[name]["description"] = description

image = post_vars.get("image") or ""
instructions = post_vars.get("instructions") or ""
content = (post_vars.get("content") or "").strip()
return await extract_to_table_post(
datasette, request, content, image, database, table, properties
datasette,
request,
instructions,
content,
image,
database,
table,
properties,
)

# Restore properties from previous run, if possible
Expand All @@ -180,7 +196,7 @@ async def extract_to_table(datasette, request, scope, receive):
for row in (
await db.execute(
"""
select id, database_name, table_name, created, properties, completed, error, num_items
select id, database_name, table_name, created, properties, instructions, completed, error, num_items
from _datasette_extract
where database_name = :database_name and table_name = :table_name
order by id desc limit 20
Expand All @@ -195,6 +211,8 @@ async def extract_to_table(datasette, request, scope, receive):
for name, value in schema.items()
]

instructions = ""

# If there are previous runs, use the properties from the last one to update columns
if previous_runs:
properties = json.loads(previous_runs[0]["properties"])
Expand All @@ -204,6 +222,7 @@ async def extract_to_table(datasette, request, scope, receive):
column["hint"] = (properties.get(column_name) or {}).get(
"description"
) or ""
instructions = previous_runs[0]["instructions"] or ""

return Response.html(
await datasette.render_template(
Expand All @@ -213,6 +232,7 @@ async def extract_to_table(datasette, request, scope, receive):
"table": table,
"schema": schema,
"columns": columns,
"instructions": instructions,
"previous_runs": previous_runs,
},
request=request,
Expand All @@ -221,7 +241,7 @@ async def extract_to_table(datasette, request, scope, receive):


async def extract_table_task(
datasette, database, table, properties, content, image, task_id
datasette, database, table, properties, instructions, content, image, task_id
):
# This task runs in the background
events = ijson.sendable_list()
Expand All @@ -234,6 +254,7 @@ async def extract_table_task(
"items": items,
"database": database,
"table": table,
"instructions": instructions,
"properties": properties,
"error": None,
"done": False,
Expand All @@ -251,12 +272,14 @@ def start_write(conn):
"database_name": database,
"table_name": table,
"created": datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S"),
"instructions": instructions.strip() or None,
"properties": json.dumps(properties),
"completed": None,
"error": None,
"num_items": 0,
},
pk="id",
alter=True,
)

async_client = AsyncOpenAI()
Expand Down Expand Up @@ -298,6 +321,8 @@ async def ocr_image(image_bytes):

try:
messages = []
if instructions:
messages.append({"role": "system", "content": instructions})
if content:
messages.append({"role": "user", "content": content})
if image_is_provided(image):
Expand Down Expand Up @@ -379,7 +404,7 @@ def end_write(conn):


async def extract_to_table_post(
datasette, request, content, image, database, table, properties
datasette, request, instructions, content, image, database, table, properties
):
# Here we go!
if not content and not image_is_provided(image):
Expand All @@ -389,7 +414,14 @@ async def extract_to_table_post(

asyncio.create_task(
extract_table_task(
datasette, database, table, properties, content, image, task_id
datasette,
database,
table,
properties,
instructions,
content,
image,
task_id,
)
)
return Response.redirect(
Expand Down
4 changes: 4 additions & 0 deletions datasette_extract/templates/extract_create_table.html
Original file line number | Diff line number | Diff line change
Expand Up @@ -37,6 +37,10 @@ <h1>Extract data and create a new table in {{ database }}</h1>
<p>
<label>Or upload an image: <input type="file" id="id_image" name="image"></label>
</p>
<p><label for="id_instructions">Additional instructions:</label></p>
<p>
<textarea name="instructions" id="id_instructions" style="width: 100%; height: 5em;" placeholder="Optional additional instructions"></textarea>
</p>
<p>
<input type="submit" value="Extract">
</p>
Expand Down
8 changes: 8 additions & 0 deletions datasette_extract/templates/extract_to_table.html
Original file line number | Diff line number | Diff line change
Expand Up @@ -44,19 +44,25 @@ <h1>Extract data into {{ database }} / {{ table }}</h1>
<p>
<label>Or upload an image: <input type="file" id="id_image" name="image"></label>
</p>
<p><label for="id_instructions">Additional instructions:</label></p>
<p>
<textarea name="instructions" id="id_instructions" style="width: 100%; height: 5em;" placeholder="Optional additional instructions">{{ instructions }}</textarea>
</p>
<p><input type="submit" value="Extract"></p>
</form>

{% include "_extract_drop_handler.html" %}

{% if previous_runs %}
<h2>Previous extraction tasks</h2>
<div style="overflow: auto">
<table>
<tr>
<th>ID</th>
<th>created</th>
<th>completed</th>
<th>properties</th>
<th>instructions</th>
<th>error</th>
<th>num_items</th>
</tr>
Expand All @@ -66,11 +72,13 @@ <h2>Previous extraction tasks</h2>
<td>{{ run.created }}</td>
<td>{{ run.completed or "" }}</td>
<td>{{ run.properties }}</td>
<td>{{ run.instructions }}</td>
<td>{{ run.error or "" }}</td>
<td>{{ run.num_items }}</td>
</tr>
{% endfor %}
</table>
</div>
{% endif %}

{% endblock %}
2 changes: 2 additions & 0 deletions tests/test_web.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -25,6 +25,7 @@ async def test_extract_flow():
"type_0": "string",
"name_1": "age",
"type_1": "integer",
"instructions": "Be nice",
},
files={
# Send an empty image too
Expand All @@ -51,6 +52,7 @@ async def test_extract_flow():

assert data == {
"items": [{"name": "Sergei", "age": 4}, {"name": "Cynthia", "age": 7}],
"instructions": "Be nice",
"database": "data",
"table": "ages",
"properties": {"name": {"type": "string"}, "age": {"type": "integer"}},
Expand Down

0 comments on commit de32d79

Please sign in to comment.