diff --git a/explorer/admin.py b/explorer/admin.py index 0cad25ac..219c87ce 100644 --- a/explorer/admin.py +++ b/explorer/admin.py @@ -7,7 +7,7 @@ @admin.register(Query) class QueryAdmin(admin.ModelAdmin): - list_display = ("title", "description", "created_by_user",) + list_display = ("title", "description", "created_by_user", "few_shot") list_filter = ("title",) raw_id_fields = ("created_by_user",) actions = [generate_report_action()] diff --git a/explorer/assistant/utils.py b/explorer/assistant/utils.py index cee02d8e..9f44e25f 100644 --- a/explorer/assistant/utils.py +++ b/explorer/assistant/utils.py @@ -1,8 +1,9 @@ from explorer import app_settings from explorer.schema import schema_info -from explorer.models import ExplorerValue +from explorer.models import ExplorerValue, Query from django.db.utils import OperationalError from django.db.models.functions import Lower +from django.db.models import Q from explorer.assistant.models import TableDescription @@ -130,6 +131,21 @@ def get_relevant_annotations(db_connection, included_tables): ) +def get_relevant_few_shots(db_connection, included_tables): + included_tables_lower = [t.lower() for t in included_tables] + + query_conditions = Q() + for table in included_tables_lower: + query_conditions |= Q(sql__icontains=table) + + return Query.objects.annotate( + sql_lower=Lower("sql") + ).filter( + database_connection=db_connection, + few_shot=True + ).filter(query_conditions) + + def build_prompt(db_connection, assistant_request, included_tables, query_error=None, sql=None): djc = db_connection.as_django_connection() sp = build_system_prompt(djc.vendor) @@ -154,6 +170,11 @@ def build_prompt(db_connection, assistant_request, included_tables, query_error= for td in get_relevant_annotations(db_connection, included_tables): user_prompt += f"## Usage Notes about Table {td.table_name} ##\n{td.description}\n\n" + for fs in get_relevant_few_shots(db_connection, included_tables): + user_prompt += f"""## Example queries using these tables, written by expert analysts ##\n + Description of query:\n{fs.title}: {fs.description}\n\n + SQL:\n{fs.sql}\n\n""" + user_prompt += f"## User's Request to Assistant ##\n{assistant_request}\n\n" prompt = { diff --git a/explorer/forms.py b/explorer/forms.py index fb6511c2..e025b558 100644 --- a/explorer/forms.py +++ b/explorer/forms.py @@ -34,6 +34,7 @@ class QueryForm(ModelForm): sql = SqlField() snapshot = BooleanField(widget=CheckboxInput, required=False) + few_shot = BooleanField(widget=CheckboxInput, required=False) database_connection = CharField(widget=Select, required=False) def __init__(self, *args, **kwargs): @@ -77,4 +78,4 @@ def connections(self): class Meta: model = Query - fields = ["title", "sql", "description", "snapshot", "database_connection"] + fields = ["title", "sql", "description", "snapshot", "database_connection", "few_shot"] diff --git a/explorer/migrations/0027_query_few_shot.py b/explorer/migrations/0027_query_few_shot.py new file mode 100644 index 00000000..18e901fb --- /dev/null +++ b/explorer/migrations/0027_query_few_shot.py @@ -0,0 +1,18 @@ +# Generated by Django 5.0.4 on 2024-08-25 21:26 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ('explorer', '0026_tabledescription'), + ] + + operations = [ + migrations.AddField( + model_name='query', + name='few_shot', + field=models.BooleanField(default=False, help_text='Will be included as a good example of SQL in assistant queries that use relevant tables'), + ), + ] diff --git a/explorer/models.py b/explorer/models.py index be4576d8..67939df8 100644 --- a/explorer/models.py +++ b/explorer/models.py @@ -56,6 +56,8 @@ class Query(models.Model): ) ) database_connection = models.ForeignKey(to=DatabaseConnection, on_delete=models.SET_NULL, null=True) + few_shot = models.BooleanField(default=False, help_text=_( + "Will be included as a good example of SQL in assistant queries that use relevant tables")) def __init__(self, *args, **kwargs): self.params = kwargs.get("params") diff --git a/explorer/templates/explorer/query.html b/explorer/templates/explorer/query.html index ed60550f..40e79167 100644 --- a/explorer/templates/explorer/query.html +++ b/explorer/templates/explorer/query.html @@ -150,5 +150,6 @@