From de32d790369943f53f3a913b5dc49d330688f205 Mon Sep 17 00:00:00 2001 From: Simon Willison Date: Wed, 13 Mar 2024 08:30:05 -0700 Subject: [PATCH] Obey custom instructions, closes #17 --- datasette_extract/__init__.py | 44 ++++++++++++++++--- .../templates/extract_create_table.html | 4 ++ .../templates/extract_to_table.html | 8 ++++ tests/test_web.py | 2 + 4 files changed, 52 insertions(+), 6 deletions(-) diff --git a/datasette_extract/__init__.py b/datasette_extract/__init__.py index 64962ac..46d7ff5 100644 --- a/datasette_extract/__init__.py +++ b/datasette_extract/__init__.py @@ -87,6 +87,7 @@ async def extract_create_table(datasette, request, scope, receive): post_vars = await starlette_request.form() content = (post_vars.get("content") or "").strip() image = post_vars.get("image") or "" + instructions = post_vars.get("instructions") or "" if not content and not image_is_provided(image): return Response.text("No content provided", status=400) table = post_vars.get("table") @@ -107,7 +108,14 @@ async def extract_create_table(datasette, request, scope, receive): properties[value]["description"] = hint return await extract_to_table_post( - datasette, request, content, image, database, table, properties + datasette, + request, + instructions, + content, + image, + database, + table, + properties, ) return Response.html( @@ -167,9 +175,17 @@ async def extract_to_table(datasette, request, scope, receive): properties[name]["description"] = description image = post_vars.get("image") or "" + instructions = post_vars.get("instructions") or "" content = (post_vars.get("content") or "").strip() return await extract_to_table_post( - datasette, request, content, image, database, table, properties + datasette, + request, + instructions, + content, + image, + database, + table, + properties, ) # Restore properties from previous run, if possible @@ -180,7 +196,7 @@ async def extract_to_table(datasette, request, scope, receive): for row in ( await db.execute( """ - select id, database_name, table_name, created, properties, completed, error, num_items + select id, database_name, table_name, created, properties, instructions, completed, error, num_items from _datasette_extract where database_name = :database_name and table_name = :table_name order by id desc limit 20 @@ -195,6 +211,8 @@ async def extract_to_table(datasette, request, scope, receive): for name, value in schema.items() ] + instructions = "" + # If there are previous runs, use the properties from the last one to update columns if previous_runs: properties = json.loads(previous_runs[0]["properties"]) @@ -204,6 +222,7 @@ async def extract_to_table(datasette, request, scope, receive): column["hint"] = (properties.get(column_name) or {}).get( "description" ) or "" + instructions = previous_runs[0]["instructions"] or "" return Response.html( await datasette.render_template( @@ -213,6 +232,7 @@ async def extract_to_table(datasette, request, scope, receive): "table": table, "schema": schema, "columns": columns, + "instructions": instructions, "previous_runs": previous_runs, }, request=request, @@ -221,7 +241,7 @@ async def extract_to_table(datasette, request, scope, receive): async def extract_table_task( - datasette, database, table, properties, content, image, task_id + datasette, database, table, properties, instructions, content, image, task_id ): # This task runs in the background events = ijson.sendable_list() @@ -234,6 +254,7 @@ async def extract_table_task( "items": items, "database": database, "table": table, + "instructions": instructions, "properties": properties, "error": None, "done": False, @@ -251,12 +272,14 @@ def start_write(conn): "database_name": database, "table_name": table, "created": datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S"), + "instructions": instructions.strip() or None, "properties": json.dumps(properties), "completed": None, "error": None, "num_items": 0, }, pk="id", + alter=True, ) async_client = AsyncOpenAI() @@ -298,6 +321,8 @@ async def ocr_image(image_bytes): try: messages = [] + if instructions: + messages.append({"role": "system", "content": instructions}) if content: messages.append({"role": "user", "content": content}) if image_is_provided(image): @@ -379,7 +404,7 @@ def end_write(conn): async def extract_to_table_post( - datasette, request, content, image, database, table, properties + datasette, request, instructions, content, image, database, table, properties ): # Here we go! if not content and not image_is_provided(image): @@ -389,7 +414,14 @@ async def extract_to_table_post( asyncio.create_task( extract_table_task( - datasette, database, table, properties, content, image, task_id + datasette, + database, + table, + properties, + instructions, + content, + image, + task_id, ) ) return Response.redirect( diff --git a/datasette_extract/templates/extract_create_table.html b/datasette_extract/templates/extract_create_table.html index 501b82e..039adff 100644 --- a/datasette_extract/templates/extract_create_table.html +++ b/datasette_extract/templates/extract_create_table.html @@ -37,6 +37,10 @@

Extract data and create a new table in {{ database }}

+

+

+ +

diff --git a/datasette_extract/templates/extract_to_table.html b/datasette_extract/templates/extract_to_table.html index 220e082..f8e27dc 100644 --- a/datasette_extract/templates/extract_to_table.html +++ b/datasette_extract/templates/extract_to_table.html @@ -44,6 +44,10 @@

Extract data into {{ database }} / {{ table }}

+

+

+ +

@@ -51,12 +55,14 @@

Extract data into {{ database }} / {{ table }}

{% if previous_runs %}

Previous extraction tasks

+
+ @@ -66,11 +72,13 @@

Previous extraction tasks

+ {% endfor %}
ID created completed propertiesinstructions error num_items
{{ run.created }} {{ run.completed or "" }} {{ run.properties }}{{ run.instructions }} {{ run.error or "" }} {{ run.num_items }}
+
{% endif %} {% endblock %} diff --git a/tests/test_web.py b/tests/test_web.py index 3add255..801b9be 100644 --- a/tests/test_web.py +++ b/tests/test_web.py @@ -25,6 +25,7 @@ async def test_extract_flow(): "type_0": "string", "name_1": "age", "type_1": "integer", + "instructions": "Be nice", }, files={ # Send an empty image too @@ -51,6 +52,7 @@ async def test_extract_flow(): assert data == { "items": [{"name": "Sergei", "age": 4}, {"name": "Cynthia", "age": 7}], + "instructions": "Be nice", "database": "data", "table": "ages", "properties": {"name": {"type": "string"}, "age": {"type": "integer"}},