diff --git a/explorer/admin.py b/explorer/admin.py index 0cad25ac..219c87ce 100644 --- a/explorer/admin.py +++ b/explorer/admin.py @@ -7,7 +7,7 @@ @admin.register(Query) class QueryAdmin(admin.ModelAdmin): - list_display = ("title", "description", "created_by_user",) + list_display = ("title", "description", "created_by_user", "few_shot") list_filter = ("title",) raw_id_fields = ("created_by_user",) actions = [generate_report_action()] diff --git a/explorer/assistant/utils.py b/explorer/assistant/utils.py index cee02d8e..9f44e25f 100644 --- a/explorer/assistant/utils.py +++ b/explorer/assistant/utils.py @@ -1,8 +1,9 @@ from explorer import app_settings from explorer.schema import schema_info -from explorer.models import ExplorerValue +from explorer.models import ExplorerValue, Query from django.db.utils import OperationalError from django.db.models.functions import Lower +from django.db.models import Q from explorer.assistant.models import TableDescription @@ -130,6 +131,21 @@ def get_relevant_annotations(db_connection, included_tables): ) +def get_relevant_few_shots(db_connection, included_tables): + included_tables_lower = [t.lower() for t in included_tables] + + query_conditions = Q() + for table in included_tables_lower: + query_conditions |= Q(sql__icontains=table) + + return Query.objects.annotate( + sql_lower=Lower("sql") + ).filter( + database_connection=db_connection, + few_shot=True + ).filter(query_conditions) + + def build_prompt(db_connection, assistant_request, included_tables, query_error=None, sql=None): djc = db_connection.as_django_connection() sp = build_system_prompt(djc.vendor) @@ -154,6 +170,11 @@ def build_prompt(db_connection, assistant_request, included_tables, query_error= for td in get_relevant_annotations(db_connection, included_tables): user_prompt += f"## Usage Notes about Table {td.table_name} ##\n{td.description}\n\n" + for fs in get_relevant_few_shots(db_connection, included_tables): + user_prompt += f"""## Example queries using these tables, written by expert analysts ##\n + Description of query:\n{fs.title}: {fs.description}\n\n + SQL:\n{fs.sql}\n\n""" + user_prompt += f"## User's Request to Assistant ##\n{assistant_request}\n\n" prompt = { diff --git a/explorer/forms.py b/explorer/forms.py index fb6511c2..e025b558 100644 --- a/explorer/forms.py +++ b/explorer/forms.py @@ -34,6 +34,7 @@ class QueryForm(ModelForm): sql = SqlField() snapshot = BooleanField(widget=CheckboxInput, required=False) + few_shot = BooleanField(widget=CheckboxInput, required=False) database_connection = CharField(widget=Select, required=False) def __init__(self, *args, **kwargs): @@ -77,4 +78,4 @@ def connections(self): class Meta: model = Query - fields = ["title", "sql", "description", "snapshot", "database_connection"] + fields = ["title", "sql", "description", "snapshot", "database_connection", "few_shot"] diff --git a/explorer/migrations/0027_query_few_shot.py b/explorer/migrations/0027_query_few_shot.py new file mode 100644 index 00000000..18e901fb --- /dev/null +++ b/explorer/migrations/0027_query_few_shot.py @@ -0,0 +1,18 @@ +# Generated by Django 5.0.4 on 2024-08-25 21:26 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ('explorer', '0026_tabledescription'), + ] + + operations = [ + migrations.AddField( + model_name='query', + name='few_shot', + field=models.BooleanField(default=False, help_text='Will be included as a good example of SQL in assistant queries that use relevant tables'), + ), + ] diff --git a/explorer/models.py b/explorer/models.py index be4576d8..67939df8 100644 --- a/explorer/models.py +++ b/explorer/models.py @@ -56,6 +56,8 @@ class Query(models.Model): ) ) database_connection = models.ForeignKey(to=DatabaseConnection, on_delete=models.SET_NULL, null=True) + few_shot = models.BooleanField(default=False, help_text=_( + "Will be included as a good example of SQL in assistant queries that use relevant tables")) def __init__(self, *args, **kwargs): self.params = kwargs.get("params") diff --git a/explorer/templates/explorer/query.html b/explorer/templates/explorer/query.html index ed60550f..40e79167 100644 --- a/explorer/templates/explorer/query.html +++ b/explorer/templates/explorer/query.html @@ -150,5 +150,6 @@

{% if query and can_change and tasks_enabled %}{{ form.snapshot }} {% translate "Snapshot" %}{% endif %} + {% if query and can_change and assistant_enabled %}{{ form.few_shot }} {% translate "Include as few-shot example" %}{% endif %}
{% endblock %} diff --git a/explorer/tests/test_assistant.py b/explorer/tests/test_assistant.py index 4fa6dc46..f5b51a75 100644 --- a/explorer/tests/test_assistant.py +++ b/explorer/tests/test_assistant.py @@ -9,7 +9,14 @@ from django.contrib.auth.models import User from django.db import OperationalError from explorer.ee.db_connections.utils import default_db_connection -from explorer.assistant.utils import sample_rows_from_table, ROW_SAMPLE_SIZE, build_prompt +from explorer.assistant.utils import ( + sample_rows_from_table, + ROW_SAMPLE_SIZE, + build_prompt, + get_relevant_annotations, + get_relevant_few_shots +) + from explorer.assistant.models import TableDescription @@ -87,6 +94,20 @@ def test_build_prompt_with_sql_and_annotation(self, mock_get_item, mock_fits_in_ self.assertIn("## User's Request to Assistant ##\nHelp me with SQL\n\n", result["user"]) self.assertIn("## Usage Notes about Table foo ##\nannotated\n\n", result["user"]) + @patch("explorer.assistant.utils.sample_rows_from_tables", return_value="sample data") + @patch("explorer.assistant.utils.fits_in_window", return_value=True) + @patch("explorer.models.ExplorerValue.objects.get_item") + def test_build_prompt_with_few_shot(self, mock_get_item, mock_fits_in_window, mock_sample_rows): + mock_get_item.return_value.value = "system prompt" + + included_tables = ["magic"] + SimpleQueryFactory(title="Few shot", description="the quick brown fox", sql="select 'magic value';", + few_shot=True) + + result = build_prompt(default_db_connection(), + "Help me with SQL", included_tables, sql="SELECT * FROM table;") + self.assertIn("Example queries using these tables", result["user"]) + self.assertIn("magic value", result["user"]) @patch("explorer.assistant.utils.sample_rows_from_tables", return_value="sample data") @patch("explorer.assistant.utils.fits_in_window", return_value=True) @@ -256,10 +277,23 @@ def test_sample_rows_from_tables_no_tables(self): ret = sample_rows_from_tables(conn(), []) self.assertEqual(ret, "") + def test_relevant_few_shots(self): + relevant_q1 = SimpleQueryFactory(sql="select * from relevant_table", few_shot=True) + relevant_q2 = SimpleQueryFactory(sql="select * from conn.RELEVANT_TABLE limit 10", few_shot=True) + irrelevant_q2 = SimpleQueryFactory(sql="select * from conn.RELEVANT_TABLE limit 10", few_shot=False) + relevant_q3 = SimpleQueryFactory(sql="select * from conn.another_good_table limit 10", few_shot=True) + irrelevant_q1 = SimpleQueryFactory(sql="select * from irrelevant_table") + included_tables = ["relevant_table", "ANOTHER_GOOD_TABLE"] + res = get_relevant_few_shots(relevant_q1.database_connection, included_tables) + res_ids = [td.id for td in res] + self.assertIn(relevant_q1.id, res_ids) + self.assertIn(relevant_q2.id, res_ids) + self.assertIn(relevant_q3.id, res_ids) + self.assertNotIn(irrelevant_q1.id, res_ids) + self.assertNotIn(irrelevant_q2.id, res_ids) + def test_get_relevant_annotations(self): - from explorer.assistant.models import TableDescription - from explorer.assistant.utils import get_relevant_annotations - from explorer.ee.db_connections.utils import default_db_connection + included_tables = ["Fruit", "vegetables"] relevant1 = TableDescription( database_connection=default_db_connection(),