bpo-45138: Revert GH-28240: Expand traced SQL statements by erlend-aasland · Pull Request #31788 · python/cpython
import contextlib import sqlite3 as sqlite import unittest import sqlite3 as sqlite
from test.support.os_helper import TESTFN, unlink
from test.test_sqlite3.test_dbapi import memory_database, cx_limit from test.test_sqlite3.test_userfunctions import with_tracebacks
class CollationTests(unittest.TestCase): def test_create_collation_not_string(self): con = sqlite.connect(":memory:")
class TraceCallbackTests(unittest.TestCase): @contextlib.contextmanager def check_stmt_trace(self, cx, expected): try: traced = [] cx.set_trace_callback(lambda stmt: traced.append(stmt)) yield finally: self.assertEqual(traced, expected) cx.set_trace_callback(None)
def test_trace_callback_used(self): """ Test that the trace callback is invoked once it is set.
@unittest.skipIf(sqlite.sqlite_version_info < (3, 14, 0), "Requires SQLite 3.14.0 or newer") def test_trace_expanded_sql(self): expected = [ "create table t(t)", "BEGIN ", "insert into t values(0)", "insert into t values(1)", "insert into t values(2)", "COMMIT", ] with memory_database() as cx, self.check_stmt_trace(cx, expected): with cx: cx.execute("create table t(t)") cx.executemany("insert into t values(?)", ((v,) for v in range(3)))
@with_tracebacks( sqlite.DataError, regex="Expanded SQL string exceeds the maximum string length" ) def test_trace_too_much_expanded_sql(self): # If the expanded string is too large, we'll fall back to the # unexpanded SQL statement. The resulting string length is limited by # SQLITE_LIMIT_LENGTH. template = "select 'b' as \"a\" from sqlite_master where \"a\"=" category = sqlite.SQLITE_LIMIT_LENGTH with memory_database() as cx, cx_limit(cx, category=category) as lim: nextra = lim - (len(template) + 2) - 1 ok_param = "a" * nextra bad_param = "a" * (nextra + 1)
unexpanded_query = template + "?" with self.check_stmt_trace(cx, [unexpanded_query]): cx.execute(unexpanded_query, (bad_param,))
expanded_query = f"{template}'{ok_param}'" with self.check_stmt_trace(cx, [expanded_query]): cx.execute(unexpanded_query, (ok_param,))
@with_tracebacks(ZeroDivisionError, regex="division by zero") def test_trace_bad_handler(self): with memory_database() as cx: cx.set_trace_callback(lambda stmt: 5/0) cx.execute("select 1")
if __name__ == "__main__": unittest.main()