Skip to content

Commit

Permalink
add optional support for query parameter declarations
Browse files Browse the repository at this point in the history
  • Loading branch information
Fabien Coelho committed Nov 9, 2024
1 parent d41dfb6 commit 18a09e0
Show file tree
Hide file tree
Showing 6 changed files with 66 additions and 22 deletions.
23 changes: 17 additions & 6 deletions aiosql/query_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

# extract a valid query name followed by an optional operation spec
# FIXME this accepts "1st" but seems to reject "é"
_NAME_OP = re.compile(r"^(?P<name>\w+)(?P<op>(|\^|\$|!|<!|\*!|#))$")
_NAME_OP = re.compile(r"^(?P<name>\w+)(|\((?P<params>(\s*|\s*\w+\s*(,\s*\w+\s*)*))\))(?P<op>(|\^|\$|!|<!|\*!|#))$")

# forbid numbers as first character
_BAD_PREFIX = re.compile(r"^\d")
Expand Down Expand Up @@ -118,12 +118,12 @@ def _make_query_datum(
# - ns_parts: name space parts, i.e. subdirectories of loaded files
# - floc: file name and lineno the query was extracted from
lines = [line.strip() for line in query.strip().splitlines()]
qname, qop = self._get_name_op(lines[0])
qname, qop, qsig = self._get_name_op(lines[0])
if re.search(r"[^A-Za-z0-9_]", qname):
log.warning(f"non ASCII character in query name: {qname}")
record_class = self._get_record_class(lines[1])
sql, doc = self._get_sql_doc(lines[2 if record_class else 1 :])
signature = self._build_signature(sql)
signature = self._build_signature(sql, qname, qsig)
query_fqn = ".".join(ns_parts + [qname])
if self.attribute: # :u.a -> :u__a, **after** signature generation
sql, attributes = _preprocess_object_attributes(self.attribute, sql)
Expand All @@ -132,13 +132,18 @@ def _make_query_datum(
sql = self.driver_adapter.process_sql(query_fqn, qop, sql)
return QueryDatum(query_fqn, doc, qop, sql, record_class, signature, floc, attributes)

def _get_name_op(self, text: str) -> Tuple[str, SQLOperationType]:
def _get_name_op(self, text: str) -> Tuple[str, SQLOperationType, List[str]|None]:
qname_spec = text.replace("-", "_")
matched = _NAME_OP.match(qname_spec)
if not matched or _BAD_PREFIX.match(qname_spec):
raise SQLParseException(f'invalid query name and operation spec: "{qname_spec}"')
nameop = matched.groupdict()
return nameop["name"], _OP_TYPES[nameop["op"]]
params, rawparams = None, nameop["params"]
if rawparams is not None:
params = [p.strip() for p in rawparams.split(",")]
if params == ['']: # handle "( )"
params = []
return nameop["name"], _OP_TYPES[nameop["op"]], params

def _get_record_class(self, text: str) -> Optional[Type]:
rc_match = _RECORD_DEF.match(text)
Expand All @@ -157,7 +162,7 @@ def _get_sql_doc(self, lines: Sequence[str]) -> Tuple[str, str]:

return sql.strip(), doc.rstrip()

def _build_signature(self, sql: str) -> inspect.Signature:
def _build_signature(self, sql: str, qname: str, sig: List[str]|None) -> inspect.Signature:
params = [inspect.Parameter("self", inspect.Parameter.POSITIONAL_OR_KEYWORD)]
names = set()
for match in VAR_REF.finditer(sql):
Expand All @@ -167,13 +172,19 @@ def _build_signature(self, sql: str) -> inspect.Signature:
name = gd["var_name"]
if name.isdigit() or name in names:
continue
if sig is not None: # optional parameter declarations
if name not in sig:
raise SQLParseException(f"undeclared parameter name in query {qname}: {name}")
names.add(name)
params.append(
inspect.Parameter(
name=name,
kind=inspect.Parameter.KEYWORD_ONLY,
)
)
if sig is not None and len(sig) != len(names):
unused = sorted(n for n in sig if n not in names)
raise SQLParseException(f"unused declared parameter in query {qname}: {unused}")
return inspect.Signature(parameters=params)

def load_query_data_from_sql(
Expand Down
11 changes: 11 additions & 0 deletions docs/source/defining-sql-queries.rst
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,17 @@ into underlines (``_``).

This query will be available in aiosql under the python method name ``.get_all_blogs(conn)``

Query Parameters
----------------

Query parameters may be declared in parentheses just after the method name.

.. literalinclude:: ../../tests/blogdb/sql/blogs/blogs.sql
:language: sql
:lines: 55,56

When declared they are checked, raising errors when parameters are unused or undeclared.

Query Comments
--------------

Expand Down
10 changes: 6 additions & 4 deletions docs/source/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ Badges
..
NOTE all tests
# MIST
loading: 15
loading: 16
patterns: 5
# SYNC
sqlite3: 17
Expand All @@ -92,7 +92,7 @@ Badges
# ASYNC
aiosqlite: 13
asyncpg: 18
.. image:: https://img.shields.io/badge/tests-245%20✓-success
.. image:: https://img.shields.io/badge/tests-246%20✓-success
:alt: Tests
:target: https://github.com/nackjicholson/aiosql/actions/
.. image:: https://img.shields.io/github/issues/nackjicholson/aiosql?style=flat
Expand Down Expand Up @@ -148,13 +148,13 @@ eg this *greetings.sql* file:

.. code:: sql
-- name: get_all_greetings
-- name: get_all_greetings()
-- Get all the greetings in the database
select greeting_id, greeting
from greetings
order by 1;
-- name: get_user_by_username^
-- name: get_user_by_username(username)^
-- Get a user from the database using a named parameter
select user_id, username, name
from users
Expand All @@ -164,6 +164,8 @@ This example has an imaginary SQLite database with greetings and users.
It prints greetings in various languages to the user and showcases the basic
feature of being able to load queries from a SQL file and call them by name
in python code.
Query parameter declarations (eg ``(username)``) are optional, but enforced
when provided.

You can use ``aiosql`` to load the queries in this file for use in your Python
application:
Expand Down
11 changes: 3 additions & 8 deletions docs/source/versions.rst
Original file line number Diff line number Diff line change
Expand Up @@ -12,21 +12,16 @@ TODO
- rethink record classes? we just really want a row conversion function?
- add documentation about docker runs.
- allow tagging queries, eg whether it can be cached
- add ability to _declare_ named query parameters for readability and reliability,
allowing to check for unused or undeclared parameters

```sql
-- name: get_foo_by_id(id)^
SELECT * FROM Foo WHERE fooid = :id:
```

? on ?
------

- improve Makefile.
- add optional parameter declarations to queries, and check them when provided.
- warn on probable mission operation.
- add *psycopg2* to CI.
- improve documentation.
- improve Makefile.
- silent some test warnings.

12.2 on 2024-10-02
------------------
Expand Down
8 changes: 4 additions & 4 deletions tests/blogdb/sql/blogs/blogs.sql
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ values (
-- Remove a blog from the database
delete from blogs where blogid = :blogid;

-- name: get-user-blogs
-- name: get-user-blogs(userid)
-- record_class: UserBlogSummary
-- Get blogs authored by a user.
select title AS title,
Expand All @@ -43,7 +43,7 @@ delete from blogs where blogid = :blogid;
order by published desc;


-- name: get-latest-user-blog^
-- name: get-latest-user-blog(userid)^
-- record_class: UserBlogSummary
-- Get latest blog by user.
select title AS title, published AS published
Expand All @@ -52,8 +52,8 @@ where userid = :userid
order by published desc
limit 1;

-- name: search
select title from blogs where title = :title and published = :published;
-- name: search(title, published)
select title from blogs where title LIKE :title and published = :published;

-- name: blog_title^
select blogid, title from blogs where blogid = :blogid;
Expand Down
25 changes: 25 additions & 0 deletions tests/test_loading.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,3 +224,28 @@ def test_kwargs():
pytest.fail("must raise an exception") # pragma: no cover
except ValueError as e:
assert "mix" in str(e)

def test_parameter_declarations():
# ok
import sqlite3
conn = sqlite3.connect(":memory:")
q = aiosql.from_str(
"-- name: xlii()$\nSELECT 42;\n"
"-- name: next(n)$\nSELECT :n+1;\n"
"-- name: add(n, m)$\nSELECT :n+:m;\n",
"sqlite3")
assert q.xlii(conn) == 42
assert q.next(conn, n=41) == 42
assert q.add(conn, n=20, m=22) == 42
conn.close()
# errors
try:
aiosql.from_str("-- name: foo()\nSELECT :N + 1;\n", "sqlite3")
pytest.fail("must raise an exception")
except SQLParseException as e:
assert "undeclared" in str(e) and "N" in str(e)
try:
aiosql.from_str("-- name: foo(N, M)\nSELECT :N + 1;\n", "sqlite3")
pytest.fail("must raise an exception")
except SQLParseException as e:
assert "unused" in str(e) and "M" in str(e)

0 comments on commit 18a09e0

Please sign in to comment.