Skip to content

Commit

Permalink
first pass, introduce new default /$DB/-/query endpoint
Browse files Browse the repository at this point in the history
  • Loading branch information
asg017 committed Jul 3, 2024
1 parent c2e8e50 commit fb17380
Show file tree
Hide file tree
Showing 17 changed files with 123 additions and 76 deletions.
6 changes: 5 additions & 1 deletion datasette/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
from .events import Event
from .views import Context
from .views.base import ureg
from .views.database import database_download, DatabaseView, TableCreateView
from .views.database import database_download, DatabaseView, TableCreateView, QueryView
from .views.index import IndexView
from .views.special import (
JsonDataView,
Expand Down Expand Up @@ -1578,6 +1578,10 @@ def add_route(view, regex):
r"/(?P<database>[^\/\.]+)(\.(?P<format>\w+))?$",
)
add_route(TableCreateView.as_view(self), r"/(?P<database>[^\/\.]+)/-/create$")
add_route(
wrap_view(QueryView, self),
r"/(?P<database>[^\/\.]+)/-/query(\.(?P<format>\w+))?$",
)
add_route(
wrap_view(table_view, self),
r"/(?P<database>[^\/\.]+)/(?P<table>[^\/\.]+)(\.(?P<format>\w+))?$",
Expand Down
8 changes: 8 additions & 0 deletions datasette/views/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,11 @@ async def get(self, request, datasette):

sql = (request.args.get("sql") or "").strip()
if sql:
redirect_url = "/" + request.url_vars.get("database") + "/-/query"
if request.url_vars.get("format"):
redirect_url += "." + request.url_vars.get("format")
redirect_url += "?" + request.query_string
return Response.redirect(redirect_url)
return await QueryView()(request, datasette)

if format_ not in ("html", "json"):
Expand Down Expand Up @@ -433,6 +438,8 @@ async def post(self, request, datasette):
async def get(self, request, datasette):
from datasette.app import TableNotFound

await datasette.refresh_schemas()

db = await datasette.resolve_database(request)
database = db.name

Expand Down Expand Up @@ -686,6 +693,7 @@ async def fetch_data_for_csv(request, _next=None):
if allow_execute_sql and is_validated_sql and ":_" not in sql:
edit_sql_url = (
datasette.urls.database(database)
+ "/-/query"
+ "?"
+ urlencode(
{
Expand Down
1 change: 1 addition & 0 deletions tests/plugins/my_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -411,6 +411,7 @@ def query_actions(datasette, database, query_name, sql):
return [
{
"href": datasette.urls.database(database)
+ "/-/query"
+ "?"
+ urllib.parse.urlencode(
{
Expand Down
38 changes: 27 additions & 11 deletions tests/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -623,7 +623,7 @@ def test_no_files_uses_memory_database(app_client_no_files):
} == response.json
# Try that SQL query
response = app_client_no_files.get(
"/_memory.json?sql=select+sqlite_version()&_shape=array"
"/_memory/-/query.json?sql=select+sqlite_version()&_shape=array"
)
assert 1 == len(response.json)
assert ["sqlite_version()"] == list(response.json[0].keys())
Expand Down Expand Up @@ -653,7 +653,7 @@ def test_database_page_for_database_with_dot_in_name(app_client_with_dot):
@pytest.mark.asyncio
async def test_custom_sql(ds_client):
response = await ds_client.get(
"/fixtures.json?sql=select+content+from+simple_primary_key"
"/fixtures/-/query.json?sql=select+content+from+simple_primary_key",
)
data = response.json()
assert data == {
Expand All @@ -670,7 +670,9 @@ async def test_custom_sql(ds_client):


def test_sql_time_limit(app_client_shorter_time_limit):
response = app_client_shorter_time_limit.get("/fixtures.json?sql=select+sleep(0.5)")
response = app_client_shorter_time_limit.get(
"/fixtures/-/query.json?sql=select+sleep(0.5)",
)
assert 400 == response.status
assert response.json == {
"ok": False,
Expand All @@ -691,16 +693,22 @@ def test_sql_time_limit(app_client_shorter_time_limit):

@pytest.mark.asyncio
async def test_custom_sql_time_limit(ds_client):
response = await ds_client.get("/fixtures.json?sql=select+sleep(0.01)")
response = await ds_client.get(
"/fixtures/-/query.json?sql=select+sleep(0.01)",
)
assert response.status_code == 200
response = await ds_client.get("/fixtures.json?sql=select+sleep(0.01)&_timelimit=5")
response = await ds_client.get(
"/fixtures/-/query.json?sql=select+sleep(0.01)&_timelimit=5",
)
assert response.status_code == 400
assert response.json()["title"] == "SQL Interrupted"


@pytest.mark.asyncio
async def test_invalid_custom_sql(ds_client):
response = await ds_client.get("/fixtures.json?sql=.schema")
response = await ds_client.get(
"/fixtures/-/query.json?sql=.schema",
)
assert response.status_code == 400
assert response.json()["ok"] is False
assert "Statement must be a SELECT" == response.json()["error"]
Expand Down Expand Up @@ -883,9 +891,13 @@ async def test_json_columns(ds_client, extra_args, expected):
select 1 as intval, "s" as strval, 0.5 as floatval,
'{"foo": "bar"}' as jsonval
"""
path = "/fixtures.json?" + urllib.parse.urlencode({"sql": sql, "_shape": "array"})
path = "/fixtures/-/query.json?" + urllib.parse.urlencode(
{"sql": sql, "_shape": "array"}
)
path += extra_args
response = await ds_client.get(path)
response = await ds_client.get(
path,
)
assert response.json() == expected


Expand Down Expand Up @@ -917,7 +929,7 @@ def test_config_force_https_urls():
("/fixtures.json", 200),
("/fixtures/no_primary_key.json", 200),
# A 400 invalid SQL query should still have the header:
("/fixtures.json?sql=select+blah", 400),
("/fixtures/-/query.json?sql=select+blah", 400),
# Write APIs
("/fixtures/-/create", 405),
("/fixtures/facetable/-/insert", 405),
Expand All @@ -930,7 +942,9 @@ def test_cors(
path,
status_code,
):
response = app_client_with_cors.get(path)
response = app_client_with_cors.get(
path,
)
assert response.status == status_code
assert response.headers["Access-Control-Allow-Origin"] == "*"
assert (
Expand All @@ -946,7 +960,9 @@ def test_cors(
# should not have those headers - I'm using that fixture because
# regular app_client doesn't have immutable fixtures.db which means
# the test for /fixtures.db returns a 403 error
response = app_client_two_attached_databases_one_immutable.get(path)
response = app_client_two_attached_databases_one_immutable.get(
path,
)
assert response.status == status_code
assert "Access-Control-Allow-Origin" not in response.headers
assert "Access-Control-Allow-Headers" not in response.headers
Expand Down
12 changes: 9 additions & 3 deletions tests/test_api_write.py
Original file line number Diff line number Diff line change
Expand Up @@ -637,15 +637,19 @@ async def test_delete_row(ds_write, table, row_for_create, pks, delete_path):
# Should be a single row
assert (
await ds_write.client.get(
"/data.json?_shape=arrayfirst&sql=select+count(*)+from+{}".format(table)
"/data/-/query.json?_shape=arrayfirst&sql=select+count(*)+from+{}".format(
table
)
)
).json() == [1]
# Now delete the row
if delete_path is None:
# Special case for that rowid table
delete_path = (
await ds_write.client.get(
"/data.json?_shape=arrayfirst&sql=select+rowid+from+{}".format(table)
"/data/-/query.json?_shape=arrayfirst&sql=select+rowid+from+{}".format(
table
)
)
).json()[0]

Expand All @@ -663,7 +667,9 @@ async def test_delete_row(ds_write, table, row_for_create, pks, delete_path):
assert event.pks == str(delete_path).split(",")
assert (
await ds_write.client.get(
"/data.json?_shape=arrayfirst&sql=select+count(*)+from+{}".format(table)
"/data/-/query.json?_shape=arrayfirst&sql=select+count(*)+from+{}".format(
table
)
)
).json() == [0]

Expand Down
2 changes: 1 addition & 1 deletion tests/test_canned_queries.py
Original file line number Diff line number Diff line change
Expand Up @@ -412,7 +412,7 @@ def test_magic_parameters_csrf_json(magic_parameters_client, use_csrf, return_js

def test_magic_parameters_cannot_be_used_in_arbitrary_queries(magic_parameters_client):
response = magic_parameters_client.get(
"/data.json?sql=select+:_header_host&_shape=array"
"/data/-/query.json?sql=select+:_header_host&_shape=array"
)
assert 400 == response.status
assert response.json["error"].startswith("You did not supply a value for binding")
Expand Down
8 changes: 4 additions & 4 deletions tests/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,7 +250,7 @@ def test_plugin_s_overwrite():
"--plugins-dir",
plugins_dir,
"--get",
"/_memory.json?sql=select+prepare_connection_args()",
"/_memory/-/query.json?sql=select+prepare_connection_args()",
],
)
assert result.exit_code == 0, result.output
Expand All @@ -265,7 +265,7 @@ def test_plugin_s_overwrite():
"--plugins-dir",
plugins_dir,
"--get",
"/_memory.json?sql=select+prepare_connection_args()",
"/_memory/-/query.json?sql=select+prepare_connection_args()",
"-s",
"plugins.name-of-plugin",
"OVERRIDE",
Expand Down Expand Up @@ -295,7 +295,7 @@ def test_setting_default_allow_sql(default_allow_sql):
"default_allow_sql",
"on" if default_allow_sql else "off",
"--get",
"/_memory.json?sql=select+21&_shape=objects",
"/_memory/-/query.json?sql=select+21&_shape=objects",
],
)
if default_allow_sql:
Expand All @@ -309,7 +309,7 @@ def test_setting_default_allow_sql(default_allow_sql):

def test_sql_errors_logged_to_stderr():
runner = CliRunner(mix_stderr=False)
result = runner.invoke(cli, ["--get", "/_memory.json?sql=select+blah"])
result = runner.invoke(cli, ["--get", "/_memory/-/query.json?sql=select+blah"])
assert result.exit_code == 1
assert "sql = 'select blah', params = {}: no such column: blah\n" in result.stderr

Expand Down
2 changes: 1 addition & 1 deletion tests/test_cli_serve_get.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def startup(datasette):
"--plugins-dir",
str(plugins_dir),
"--get",
"/_memory.json?sql=select+sqlite_version()",
"/_memory/-/query.json?sql=select+sqlite_version()",
],
)
assert result.exit_code == 0, result.output
Expand Down
3 changes: 2 additions & 1 deletion tests/test_crossdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,8 @@ def test_crossdb_join(app_client_two_attached_databases_crossdb_enabled):
fixtures.searchable
"""
response = app_client.get(
"/_memory.json?" + urllib.parse.urlencode({"sql": sql, "_shape": "array"})
"/_memory/-/query.json?"
+ urllib.parse.urlencode({"sql": sql, "_shape": "array"})
)
assert response.status == 200
assert response.json == [
Expand Down
10 changes: 5 additions & 5 deletions tests/test_csv.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,22 +146,22 @@ async def test_table_csv_blob_columns(ds_client):
@pytest.mark.asyncio
async def test_custom_sql_csv_blob_columns(ds_client):
response = await ds_client.get(
"/fixtures.csv?sql=select+rowid,+data+from+binary_data"
"/fixtures/-/query.csv?sql=select+rowid,+data+from+binary_data"
)
assert response.status_code == 200
assert response.headers["content-type"] == "text/plain; charset=utf-8"
assert response.text == (
"rowid,data\r\n"
'1,"http://localhost/fixtures.blob?sql=select+rowid,+data+from+binary_data&_blob_column=data&_blob_hash=f3088978da8f9aea479ffc7f631370b968d2e855eeb172bea7f6c7a04262bb6d"\r\n'
'2,"http://localhost/fixtures.blob?sql=select+rowid,+data+from+binary_data&_blob_column=data&_blob_hash=b835b0483cedb86130b9a2c280880bf5fadc5318ddf8c18d0df5204d40df1724"\r\n'
'1,"http://localhost/fixtures/-/query.blob?sql=select+rowid,+data+from+binary_data&_blob_column=data&_blob_hash=f3088978da8f9aea479ffc7f631370b968d2e855eeb172bea7f6c7a04262bb6d"\r\n'
'2,"http://localhost/fixtures/-/query.blob?sql=select+rowid,+data+from+binary_data&_blob_column=data&_blob_hash=b835b0483cedb86130b9a2c280880bf5fadc5318ddf8c18d0df5204d40df1724"\r\n'
"3,\r\n"
)


@pytest.mark.asyncio
async def test_custom_sql_csv(ds_client):
response = await ds_client.get(
"/fixtures.csv?sql=select+content+from+simple_primary_key+limit+2"
"/fixtures/-/query.csv?sql=select+content+from+simple_primary_key+limit+2"
)
assert response.status_code == 200
assert response.headers["content-type"] == "text/plain; charset=utf-8"
Expand All @@ -182,7 +182,7 @@ async def test_table_csv_download(ds_client):
@pytest.mark.asyncio
async def test_csv_with_non_ascii_characters(ds_client):
response = await ds_client.get(
"/fixtures.csv?sql=select%0D%0A++%27%F0%9D%90%9C%F0%9D%90%A2%F0%9D%90%AD%F0%9D%90%A2%F0%9D%90%9E%F0%9D%90%AC%27+as+text%2C%0D%0A++1+as+number%0D%0Aunion%0D%0Aselect%0D%0A++%27bob%27+as+text%2C%0D%0A++2+as+number%0D%0Aorder+by%0D%0A++number"
"/fixtures/-/query.csv?sql=select%0D%0A++%27%F0%9D%90%9C%F0%9D%90%A2%F0%9D%90%AD%F0%9D%90%A2%F0%9D%90%9E%F0%9D%90%AC%27+as+text%2C%0D%0A++1+as+number%0D%0Aunion%0D%0Aselect%0D%0A++%27bob%27+as+text%2C%0D%0A++2+as+number%0D%0Aorder+by%0D%0A++number"
)
assert response.status_code == 200
assert response.headers["content-type"] == "text/plain; charset=utf-8"
Expand Down
Loading

0 comments on commit fb17380

Please sign in to comment.