diff --git a/Lib/test/test_sqlite3/test_dbapi.py b/Lib/test/test_sqlite3/test_dbapi.py index 39dbbb95300331..825f335caef3f9 100644 --- a/Lib/test/test_sqlite3/test_dbapi.py +++ b/Lib/test/test_sqlite3/test_dbapi.py @@ -751,22 +751,44 @@ def test_execute_illegal_sql(self): with self.assertRaises(sqlite.OperationalError): self.cu.execute("select asdf") - def test_execute_too_much_sql(self): - self.assertRaisesRegex(sqlite.ProgrammingError, - "You can only execute one statement at a time", - self.cu.execute, "select 5+4; select 4+5") - - def test_execute_too_much_sql2(self): - self.cu.execute("select 5+4; -- foo bar") + def test_execute_multiple_statements(self): + msg = "You can only execute one statement at a time" + dataset = ( + "select 1; select 2", + "select 1; // c++ comments are not allowed", + "select 1; *not a comment", + "select 1; -*not a comment", + "select 1; /* */ a", + "select 1; /**/a", + "select 1; -", + "select 1; /", + "select 1; -\n- select 2", + """select 1; + -- comment + select 2 + """, + ) + for query in dataset: + with self.subTest(query=query): + with self.assertRaisesRegex(sqlite.ProgrammingError, msg): + self.cu.execute(query) - def test_execute_too_much_sql3(self): - self.cu.execute(""" + def test_execute_with_appended_comments(self): + dataset = ( + "select 1; -- foo bar", + "select 1; --", + "select 1; /*", # Unclosed comments ending in \0 are skipped. + """ select 5+4; /* foo */ - """) + """, + ) + for query in dataset: + with self.subTest(query=query): + self.cu.execute(query) def test_execute_wrong_sql_arg(self): with self.assertRaises(TypeError): @@ -911,6 +933,30 @@ def test_rowcount_update_returning(self): self.assertEqual(self.cu.fetchone()[0], 1) self.assertEqual(self.cu.rowcount, 1) + def test_rowcount_prefixed_with_comment(self): + # gh-79579: rowcount is updated even if query is prefixed with comments + self.cu.execute(""" + -- foo + insert into test(name) values ('foo'), ('foo') + """) + self.assertEqual(self.cu.rowcount, 2) + self.cu.execute(""" + /* -- messy *r /* /* ** *- *-- + */ + /* one more */ insert into test(name) values ('messy') + """) + self.assertEqual(self.cu.rowcount, 1) + self.cu.execute("/* bar */ update test set name='bar' where name='foo'") + self.assertEqual(self.cu.rowcount, 3) + + def test_rowcount_vaccuum(self): + data = ((1,), (2,), (3,)) + self.cu.executemany("insert into test(income) values(?)", data) + self.assertEqual(self.cu.rowcount, 3) + self.cx.commit() + self.cu.execute("vacuum") + self.assertEqual(self.cu.rowcount, -1) + def test_total_changes(self): self.cu.execute("insert into test(name) values ('foo')") self.cu.execute("insert into test(name) values ('foo')") diff --git a/Misc/NEWS.d/next/Library/2022-06-06-12-58-27.gh-issue-79579.e8rB-M.rst b/Misc/NEWS.d/next/Library/2022-06-06-12-58-27.gh-issue-79579.e8rB-M.rst new file mode 100644 index 00000000000000..82b1a1c28a6001 --- /dev/null +++ b/Misc/NEWS.d/next/Library/2022-06-06-12-58-27.gh-issue-79579.e8rB-M.rst @@ -0,0 +1,2 @@ +:mod:`sqlite3` now correctly detects DML queries with leading comments. +Patch by Erlend E. Aasland. diff --git a/Modules/_sqlite/statement.c b/Modules/_sqlite/statement.c index f9cb70f0ef146c..aee460747b45f4 100644 --- a/Modules/_sqlite/statement.c +++ b/Modules/_sqlite/statement.c @@ -26,16 +26,7 @@ #include "util.h" /* prototypes */ -static int pysqlite_check_remaining_sql(const char* tail); - -typedef enum { - LINECOMMENT_1, - IN_LINECOMMENT, - COMMENTSTART_1, - IN_COMMENT, - COMMENTEND_1, - NORMAL -} parse_remaining_sql_state; +static const char *lstrip_sql(const char *sql); pysqlite_Statement * pysqlite_statement_create(pysqlite_Connection *connection, PyObject *sql) @@ -73,7 +64,7 @@ pysqlite_statement_create(pysqlite_Connection *connection, PyObject *sql) return NULL; } - if (pysqlite_check_remaining_sql(tail)) { + if (lstrip_sql(tail) != NULL) { PyErr_SetString(connection->ProgrammingError, "You can only execute one statement at a time."); goto error; @@ -82,20 +73,12 @@ pysqlite_statement_create(pysqlite_Connection *connection, PyObject *sql) /* Determine if the statement is a DML statement. SELECT is the only exception. See #9924. */ int is_dml = 0; - for (const char *p = sql_cstr; *p != 0; p++) { - switch (*p) { - case ' ': - case '\r': - case '\n': - case '\t': - continue; - } - + const char *p = lstrip_sql(sql_cstr); + if (p != NULL) { is_dml = (PyOS_strnicmp(p, "insert", 6) == 0) || (PyOS_strnicmp(p, "update", 6) == 0) || (PyOS_strnicmp(p, "delete", 6) == 0) || (PyOS_strnicmp(p, "replace", 7) == 0); - break; } pysqlite_Statement *self = PyObject_GC_New(pysqlite_Statement, @@ -139,73 +122,61 @@ stmt_traverse(pysqlite_Statement *self, visitproc visit, void *arg) } /* - * Checks if there is anything left in an SQL string after SQLite compiled it. - * This is used to check if somebody tried to execute more than one SQL command - * with one execute()/executemany() command, which the DB-API and we don't - * allow. + * Strip leading whitespace and comments from incoming SQL (null terminated C + * string) and return a pointer to the first non-whitespace, non-comment + * character. * - * Returns 1 if there is more left than should be. 0 if ok. + * This is used to check if somebody tries to execute more than one SQL query + * with one execute()/executemany() command, which the DB-API don't allow. + * + * It is also used to harden DML query detection. */ -static int pysqlite_check_remaining_sql(const char* tail) +static inline const char * +lstrip_sql(const char *sql) { - const char* pos = tail; - - parse_remaining_sql_state state = NORMAL; - - for (;;) { + // This loop is borrowed from the SQLite source code. + for (const char *pos = sql; *pos; pos++) { switch (*pos) { - case 0: - return 0; - case '-': - if (state == NORMAL) { - state = LINECOMMENT_1; - } else if (state == LINECOMMENT_1) { - state = IN_LINECOMMENT; - } - break; case ' ': case '\t': - break; + case '\f': case '\n': - case 13: - if (state == IN_LINECOMMENT) { - state = NORMAL; - } + case '\r': + // Skip whitespace. break; - case '/': - if (state == NORMAL) { - state = COMMENTSTART_1; - } else if (state == COMMENTEND_1) { - state = NORMAL; - } else if (state == COMMENTSTART_1) { - return 1; + case '-': + // Skip line comments. + if (pos[1] == '-') { + pos += 2; + while (pos[0] && pos[0] != '\n') { + pos++; + } + if (pos[0] == '\0') { + return NULL; + } + continue; } - break; - case '*': - if (state == NORMAL) { - return 1; - } else if (state == LINECOMMENT_1) { - return 1; - } else if (state == COMMENTSTART_1) { - state = IN_COMMENT; - } else if (state == IN_COMMENT) { - state = COMMENTEND_1; + return pos; + case '/': + // Skip C style comments. + if (pos[1] == '*') { + pos += 2; + while (pos[0] && (pos[0] != '*' || pos[1] != '/')) { + pos++; + } + if (pos[0] == '\0') { + return NULL; + } + pos++; + continue; } - break; + return pos; default: - if (state == COMMENTEND_1) { - state = IN_COMMENT; - } else if (state == IN_LINECOMMENT) { - } else if (state == IN_COMMENT) { - } else { - return 1; - } + return pos; } - - pos++; } - return 0; + return NULL; } static PyType_Slot stmt_slots[] = {