Skip to content

Commit

Permalink
few shot examples
Browse files Browse the repository at this point in the history
  • Loading branch information
chrisclark committed Aug 25, 2024
1 parent c8d2e16 commit 0c8c6ec
Show file tree
Hide file tree
Showing 7 changed files with 84 additions and 7 deletions.
2 changes: 1 addition & 1 deletion explorer/admin.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

@admin.register(Query)
class QueryAdmin(admin.ModelAdmin):
list_display = ("title", "description", "created_by_user",)
list_display = ("title", "description", "created_by_user", "few_shot")
list_filter = ("title",)
raw_id_fields = ("created_by_user",)
actions = [generate_report_action()]
Expand Down
23 changes: 22 additions & 1 deletion explorer/assistant/utils.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from explorer import app_settings
from explorer.schema import schema_info
from explorer.models import ExplorerValue
from explorer.models import ExplorerValue, Query
from django.db.utils import OperationalError
from django.db.models.functions import Lower
from django.db.models import Q
from explorer.assistant.models import TableDescription


Expand Down Expand Up @@ -130,6 +131,21 @@ def get_relevant_annotations(db_connection, included_tables):
)


def get_relevant_few_shots(db_connection, included_tables):
    """Return the few-shot example queries that mention any included table.

    A query qualifies when it belongs to ``db_connection``, is flagged
    ``few_shot`` and its SQL contains at least one of the names in
    ``included_tables`` (case-insensitive).

    Matching is a plain substring test on the SQL text, so a table name that
    is a substring of another identifier (``orders`` vs ``orders_archive``)
    also matches.

    Args:
        db_connection: The connection the example queries must belong to.
        included_tables: Iterable of table names; may be empty or ``None``.

    Returns:
        A ``Query`` queryset, annotated with ``sql_lower``. It is empty when
        no tables are supplied.
    """
    if not included_tables:
        # An empty Q() adds no restriction when passed to filter(), so
        # without this guard every few-shot query for the connection would be
        # returned even though no table was asked for.
        return Query.objects.none()

    query_conditions = Q()
    for table in included_tables:
        query_conditions |= Q(sql__icontains=table.lower())

    return Query.objects.annotate(
        sql_lower=Lower("sql")
    ).filter(
        database_connection=db_connection,
        few_shot=True
    ).filter(query_conditions)


def build_prompt(db_connection, assistant_request, included_tables, query_error=None, sql=None):
djc = db_connection.as_django_connection()
sp = build_system_prompt(djc.vendor)
Expand All @@ -154,6 +170,11 @@ def build_prompt(db_connection, assistant_request, included_tables, query_error=
for td in get_relevant_annotations(db_connection, included_tables):
user_prompt += f"## Usage Notes about Table {td.table_name} ##\n{td.description}\n\n"

for fs in get_relevant_few_shots(db_connection, included_tables):
user_prompt += f"""## Example queries using these tables, written by expert analysts ##\n
Description of query:\n{fs.title}: {fs.description}\n\n
SQL:\n{fs.sql}\n\n"""

user_prompt += f"## User's Request to Assistant ##\n{assistant_request}\n\n"

prompt = {
Expand Down
3 changes: 2 additions & 1 deletion explorer/forms.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ class QueryForm(ModelForm):

sql = SqlField()
snapshot = BooleanField(widget=CheckboxInput, required=False)
few_shot = BooleanField(widget=CheckboxInput, required=False)
database_connection = CharField(widget=Select, required=False)

def __init__(self, *args, **kwargs):
Expand Down Expand Up @@ -77,4 +78,4 @@ def connections(self):

class Meta:
model = Query
fields = ["title", "sql", "description", "snapshot", "database_connection"]
fields = ["title", "sql", "description", "snapshot", "database_connection", "few_shot"]
18 changes: 18 additions & 0 deletions explorer/migrations/0027_query_few_shot.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
# Generated by Django 5.0.4 on 2024-08-25 21:26

from django.db import migrations, models


class Migration(migrations.Migration):
    """Add the boolean ``few_shot`` field to the ``query`` model."""

    dependencies = [("explorer", "0026_tabledescription")]

    operations = [
        migrations.AddField(
            model_name="query",
            name="few_shot",
            field=models.BooleanField(
                default=False,
                help_text="Will be included as a good example of SQL in assistant queries that use relevant tables",
            ),
        ),
    ]
2 changes: 2 additions & 0 deletions explorer/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,8 @@ class Query(models.Model):
)
)
database_connection = models.ForeignKey(to=DatabaseConnection, on_delete=models.SET_NULL, null=True)
few_shot = models.BooleanField(default=False, help_text=_(
"Will be included as a good example of SQL in assistant queries that use relevant tables"))

def __init__(self, *args, **kwargs):
self.params = kwargs.get("params")
Expand Down
1 change: 1 addition & 0 deletions explorer/templates/explorer/query.html
Original file line number Diff line number Diff line change
Expand Up @@ -150,5 +150,6 @@ <h2>
</div>
<div class="container mt-1 text-end small">
{% if query and can_change and tasks_enabled %}{{ form.snapshot }} {% translate "Snapshot" %}{% endif %}
{% if query and can_change and assistant_enabled %}{{ form.few_shot }} {% translate "Include as few-shot example" %}{% endif %}
</div>
{% endblock %}
42 changes: 38 additions & 4 deletions explorer/tests/test_assistant.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,14 @@
from django.contrib.auth.models import User
from django.db import OperationalError
from explorer.ee.db_connections.utils import default_db_connection
from explorer.assistant.utils import sample_rows_from_table, ROW_SAMPLE_SIZE, build_prompt
from explorer.assistant.utils import (
sample_rows_from_table,
ROW_SAMPLE_SIZE,
build_prompt,
get_relevant_annotations,
get_relevant_few_shots
)

from explorer.assistant.models import TableDescription


Expand Down Expand Up @@ -87,6 +94,20 @@ def test_build_prompt_with_sql_and_annotation(self, mock_get_item, mock_fits_in_
self.assertIn("## User's Request to Assistant ##\nHelp me with SQL\n\n", result["user"])
self.assertIn("## Usage Notes about Table foo ##\nannotated\n\n", result["user"])

@patch("explorer.assistant.utils.sample_rows_from_tables", return_value="sample data")
@patch("explorer.assistant.utils.fits_in_window", return_value=True)
@patch("explorer.models.ExplorerValue.objects.get_item")
def test_build_prompt_with_few_shot(self, mock_get_item, mock_fits_in_window, mock_sample_rows):
    """A few-shot query mentioning an included table shows up in the user prompt."""
    mock_get_item.return_value.value = "system prompt"

    # The SQL contains "magic", which is the only table passed to build_prompt.
    SimpleQueryFactory(
        title="Few shot",
        description="the quick brown fox",
        sql="select 'magic value';",
        few_shot=True,
    )

    prompt = build_prompt(
        default_db_connection(),
        "Help me with SQL",
        ["magic"],
        sql="SELECT * FROM table;",
    )

    user_prompt = prompt["user"]
    for expected in ("Example queries using these tables", "magic value"):
        self.assertIn(expected, user_prompt)

@patch("explorer.assistant.utils.sample_rows_from_tables", return_value="sample data")
@patch("explorer.assistant.utils.fits_in_window", return_value=True)
Expand Down Expand Up @@ -256,10 +277,23 @@ def test_sample_rows_from_tables_no_tables(self):
ret = sample_rows_from_tables(conn(), [])
self.assertEqual(ret, "")

def test_relevant_few_shots(self):
    """Only few-shot queries whose SQL mentions an included table are returned.

    Covers case-insensitive matching in both directions, exclusion of
    queries that are not flagged ``few_shot``, and exclusion of few-shot
    queries that reference none of the included tables.
    """
    relevant_q1 = SimpleQueryFactory(sql="select * from relevant_table", few_shot=True)
    relevant_q2 = SimpleQueryFactory(sql="select * from conn.RELEVANT_TABLE limit 10", few_shot=True)
    relevant_q3 = SimpleQueryFactory(sql="select * from conn.another_good_table limit 10", few_shot=True)
    # Mentions an included table but is not marked as a few-shot example.
    not_few_shot = SimpleQueryFactory(sql="select * from conn.RELEVANT_TABLE limit 10", few_shot=False)
    # Marked as few-shot but mentions none of the included tables. The name
    # must not contain an included table name as a substring (as
    # "irrelevant_table" does), because matching uses icontains.
    wrong_table = SimpleQueryFactory(sql="select * from something_else", few_shot=True)

    included_tables = ["relevant_table", "ANOTHER_GOOD_TABLE"]
    res = get_relevant_few_shots(relevant_q1.database_connection, included_tables)
    res_ids = [query.id for query in res]

    self.assertIn(relevant_q1.id, res_ids)
    self.assertIn(relevant_q2.id, res_ids)
    self.assertIn(relevant_q3.id, res_ids)
    self.assertNotIn(not_few_shot.id, res_ids)
    self.assertNotIn(wrong_table.id, res_ids)

def test_get_relevant_annotations(self):
from explorer.assistant.models import TableDescription
from explorer.assistant.utils import get_relevant_annotations
from explorer.ee.db_connections.utils import default_db_connection

included_tables = ["Fruit", "vegetables"]
relevant1 = TableDescription(
database_connection=default_db_connection(),
Expand Down

0 comments on commit 0c8c6ec

Please sign in to comment.