forked from Raiza.dev/EliteBot
Cleaned up the directories
This commit is contained in:
parent
f708506d68
commit
a683fcffea
1340 changed files with 554582 additions and 6840 deletions
|
@ -0,0 +1,95 @@
|
|||
# testing/__init__.py
|
||||
# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors
|
||||
# <see AUTHORS file>
|
||||
#
|
||||
# This module is part of SQLAlchemy and is released under
|
||||
# the MIT License: https://www.opensource.org/licenses/mit-license.php
|
||||
# mypy: ignore-errors
|
||||
|
||||
|
||||
from unittest import mock
|
||||
|
||||
from . import config
|
||||
from .assertions import assert_raises
|
||||
from .assertions import assert_raises_context_ok
|
||||
from .assertions import assert_raises_message
|
||||
from .assertions import assert_raises_message_context_ok
|
||||
from .assertions import assert_warns
|
||||
from .assertions import assert_warns_message
|
||||
from .assertions import AssertsCompiledSQL
|
||||
from .assertions import AssertsExecutionResults
|
||||
from .assertions import ComparesIndexes
|
||||
from .assertions import ComparesTables
|
||||
from .assertions import emits_warning
|
||||
from .assertions import emits_warning_on
|
||||
from .assertions import eq_
|
||||
from .assertions import eq_ignore_whitespace
|
||||
from .assertions import eq_regex
|
||||
from .assertions import expect_deprecated
|
||||
from .assertions import expect_deprecated_20
|
||||
from .assertions import expect_raises
|
||||
from .assertions import expect_raises_message
|
||||
from .assertions import expect_warnings
|
||||
from .assertions import in_
|
||||
from .assertions import int_within_variance
|
||||
from .assertions import is_
|
||||
from .assertions import is_false
|
||||
from .assertions import is_instance_of
|
||||
from .assertions import is_none
|
||||
from .assertions import is_not
|
||||
from .assertions import is_not_
|
||||
from .assertions import is_not_none
|
||||
from .assertions import is_true
|
||||
from .assertions import le_
|
||||
from .assertions import ne_
|
||||
from .assertions import not_in
|
||||
from .assertions import not_in_
|
||||
from .assertions import startswith_
|
||||
from .assertions import uses_deprecated
|
||||
from .config import add_to_marker
|
||||
from .config import async_test
|
||||
from .config import combinations
|
||||
from .config import combinations_list
|
||||
from .config import db
|
||||
from .config import fixture
|
||||
from .config import requirements as requires
|
||||
from .config import skip_test
|
||||
from .config import Variation
|
||||
from .config import variation
|
||||
from .config import variation_fixture
|
||||
from .exclusions import _is_excluded
|
||||
from .exclusions import _server_version
|
||||
from .exclusions import against as _against
|
||||
from .exclusions import db_spec
|
||||
from .exclusions import exclude
|
||||
from .exclusions import fails
|
||||
from .exclusions import fails_if
|
||||
from .exclusions import fails_on
|
||||
from .exclusions import fails_on_everything_except
|
||||
from .exclusions import future
|
||||
from .exclusions import only_if
|
||||
from .exclusions import only_on
|
||||
from .exclusions import skip
|
||||
from .exclusions import skip_if
|
||||
from .schema import eq_clause_element
|
||||
from .schema import eq_type_affinity
|
||||
from .util import adict
|
||||
from .util import fail
|
||||
from .util import flag_combinations
|
||||
from .util import force_drop_names
|
||||
from .util import lambda_combinations
|
||||
from .util import metadata_fixture
|
||||
from .util import provide_metadata
|
||||
from .util import resolve_lambda
|
||||
from .util import rowset
|
||||
from .util import run_as_contextmanager
|
||||
from .util import teardown_events
|
||||
from .warnings import assert_warnings
|
||||
from .warnings import warn_test_suite
|
||||
|
||||
|
||||
def against(*queries):
|
||||
return _against(config._current, *queries)
|
||||
|
||||
|
||||
crashes = skip
|
|
@ -0,0 +1,989 @@
|
|||
# testing/assertions.py
|
||||
# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors
|
||||
# <see AUTHORS file>
|
||||
#
|
||||
# This module is part of SQLAlchemy and is released under
|
||||
# the MIT License: https://www.opensource.org/licenses/mit-license.php
|
||||
# mypy: ignore-errors
|
||||
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections import defaultdict
|
||||
import contextlib
|
||||
from copy import copy
|
||||
from itertools import filterfalse
|
||||
import re
|
||||
import sys
|
||||
import warnings
|
||||
|
||||
from . import assertsql
|
||||
from . import config
|
||||
from . import engines
|
||||
from . import mock
|
||||
from .exclusions import db_spec
|
||||
from .util import fail
|
||||
from .. import exc as sa_exc
|
||||
from .. import schema
|
||||
from .. import sql
|
||||
from .. import types as sqltypes
|
||||
from .. import util
|
||||
from ..engine import default
|
||||
from ..engine import url
|
||||
from ..sql.selectable import LABEL_STYLE_TABLENAME_PLUS_COL
|
||||
from ..util import decorator
|
||||
|
||||
|
||||
def expect_warnings(*messages, **kw):
|
||||
"""Context manager which expects one or more warnings.
|
||||
|
||||
With no arguments, squelches all SAWarning emitted via
|
||||
sqlalchemy.util.warn and sqlalchemy.util.warn_limited. Otherwise
|
||||
pass string expressions that will match selected warnings via regex;
|
||||
all non-matching warnings are sent through.
|
||||
|
||||
The expect version **asserts** that the warnings were in fact seen.
|
||||
|
||||
Note that the test suite sets SAWarning warnings to raise exceptions.
|
||||
|
||||
""" # noqa
|
||||
return _expect_warnings_sqla_only(sa_exc.SAWarning, messages, **kw)
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def expect_warnings_on(db, *messages, **kw):
|
||||
"""Context manager which expects one or more warnings on specific
|
||||
dialects.
|
||||
|
||||
The expect version **asserts** that the warnings were in fact seen.
|
||||
|
||||
"""
|
||||
spec = db_spec(db)
|
||||
|
||||
if isinstance(db, str) and not spec(config._current):
|
||||
yield
|
||||
else:
|
||||
with expect_warnings(*messages, **kw):
|
||||
yield
|
||||
|
||||
|
||||
def emits_warning(*messages):
|
||||
"""Decorator form of expect_warnings().
|
||||
|
||||
Note that emits_warning does **not** assert that the warnings
|
||||
were in fact seen.
|
||||
|
||||
"""
|
||||
|
||||
@decorator
|
||||
def decorate(fn, *args, **kw):
|
||||
with expect_warnings(assert_=False, *messages):
|
||||
return fn(*args, **kw)
|
||||
|
||||
return decorate
|
||||
|
||||
|
||||
def expect_deprecated(*messages, **kw):
|
||||
return _expect_warnings_sqla_only(
|
||||
sa_exc.SADeprecationWarning, messages, **kw
|
||||
)
|
||||
|
||||
|
||||
def expect_deprecated_20(*messages, **kw):
|
||||
return _expect_warnings_sqla_only(
|
||||
sa_exc.Base20DeprecationWarning, messages, **kw
|
||||
)
|
||||
|
||||
|
||||
def emits_warning_on(db, *messages):
|
||||
"""Mark a test as emitting a warning on a specific dialect.
|
||||
|
||||
With no arguments, squelches all SAWarning failures. Or pass one or more
|
||||
strings; these will be matched to the root of the warning description by
|
||||
warnings.filterwarnings().
|
||||
|
||||
Note that emits_warning_on does **not** assert that the warnings
|
||||
were in fact seen.
|
||||
|
||||
"""
|
||||
|
||||
@decorator
|
||||
def decorate(fn, *args, **kw):
|
||||
with expect_warnings_on(db, assert_=False, *messages):
|
||||
return fn(*args, **kw)
|
||||
|
||||
return decorate
|
||||
|
||||
|
||||
def uses_deprecated(*messages):
|
||||
"""Mark a test as immune from fatal deprecation warnings.
|
||||
|
||||
With no arguments, squelches all SADeprecationWarning failures.
|
||||
Or pass one or more strings; these will be matched to the root
|
||||
of the warning description by warnings.filterwarnings().
|
||||
|
||||
As a special case, you may pass a function name prefixed with //
|
||||
and it will be re-written as needed to match the standard warning
|
||||
verbiage emitted by the sqlalchemy.util.deprecated decorator.
|
||||
|
||||
Note that uses_deprecated does **not** assert that the warnings
|
||||
were in fact seen.
|
||||
|
||||
"""
|
||||
|
||||
@decorator
|
||||
def decorate(fn, *args, **kw):
|
||||
with expect_deprecated(*messages, assert_=False):
|
||||
return fn(*args, **kw)
|
||||
|
||||
return decorate
|
||||
|
||||
|
||||
_FILTERS = None
|
||||
_SEEN = None
|
||||
_EXC_CLS = None
|
||||
|
||||
|
||||
def _expect_warnings_sqla_only(
|
||||
exc_cls,
|
||||
messages,
|
||||
regex=True,
|
||||
search_msg=False,
|
||||
assert_=True,
|
||||
):
|
||||
"""SQLAlchemy internal use only _expect_warnings().
|
||||
|
||||
Alembic is using _expect_warnings() directly, and should be updated
|
||||
to use this new interface.
|
||||
|
||||
"""
|
||||
return _expect_warnings(
|
||||
exc_cls,
|
||||
messages,
|
||||
regex=regex,
|
||||
search_msg=search_msg,
|
||||
assert_=assert_,
|
||||
raise_on_any_unexpected=True,
|
||||
)
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def _expect_warnings(
|
||||
exc_cls,
|
||||
messages,
|
||||
regex=True,
|
||||
search_msg=False,
|
||||
assert_=True,
|
||||
raise_on_any_unexpected=False,
|
||||
squelch_other_warnings=False,
|
||||
):
|
||||
global _FILTERS, _SEEN, _EXC_CLS
|
||||
|
||||
if regex or search_msg:
|
||||
filters = [re.compile(msg, re.I | re.S) for msg in messages]
|
||||
else:
|
||||
filters = list(messages)
|
||||
|
||||
if _FILTERS is not None:
|
||||
# nested call; update _FILTERS and _SEEN, return. outer
|
||||
# block will assert our messages
|
||||
assert _SEEN is not None
|
||||
assert _EXC_CLS is not None
|
||||
_FILTERS.extend(filters)
|
||||
_SEEN.update(filters)
|
||||
_EXC_CLS += (exc_cls,)
|
||||
yield
|
||||
else:
|
||||
seen = _SEEN = set(filters)
|
||||
_FILTERS = filters
|
||||
_EXC_CLS = (exc_cls,)
|
||||
|
||||
if raise_on_any_unexpected:
|
||||
|
||||
def real_warn(msg, *arg, **kw):
|
||||
raise AssertionError("Got unexpected warning: %r" % msg)
|
||||
|
||||
else:
|
||||
real_warn = warnings.warn
|
||||
|
||||
def our_warn(msg, *arg, **kw):
|
||||
if isinstance(msg, _EXC_CLS):
|
||||
exception = type(msg)
|
||||
msg = str(msg)
|
||||
elif arg:
|
||||
exception = arg[0]
|
||||
else:
|
||||
exception = None
|
||||
|
||||
if not exception or not issubclass(exception, _EXC_CLS):
|
||||
if not squelch_other_warnings:
|
||||
return real_warn(msg, *arg, **kw)
|
||||
else:
|
||||
return
|
||||
|
||||
if not filters and not raise_on_any_unexpected:
|
||||
return
|
||||
|
||||
for filter_ in filters:
|
||||
if (
|
||||
(search_msg and filter_.search(msg))
|
||||
or (regex and filter_.match(msg))
|
||||
or (not regex and filter_ == msg)
|
||||
):
|
||||
seen.discard(filter_)
|
||||
break
|
||||
else:
|
||||
if not squelch_other_warnings:
|
||||
real_warn(msg, *arg, **kw)
|
||||
|
||||
with mock.patch("warnings.warn", our_warn):
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
_SEEN = _FILTERS = _EXC_CLS = None
|
||||
|
||||
if assert_:
|
||||
assert not seen, "Warnings were not seen: %s" % ", ".join(
|
||||
"%r" % (s.pattern if regex else s) for s in seen
|
||||
)
|
||||
|
||||
|
||||
def global_cleanup_assertions():
|
||||
"""Check things that have to be finalized at the end of a test suite.
|
||||
|
||||
Hardcoded at the moment, a modular system can be built here
|
||||
to support things like PG prepared transactions, tables all
|
||||
dropped, etc.
|
||||
|
||||
"""
|
||||
_assert_no_stray_pool_connections()
|
||||
|
||||
|
||||
def _assert_no_stray_pool_connections():
|
||||
engines.testing_reaper.assert_all_closed()
|
||||
|
||||
|
||||
def int_within_variance(expected, received, variance):
|
||||
deviance = int(expected * variance)
|
||||
assert (
|
||||
abs(received - expected) < deviance
|
||||
), "Given int value %s is not within %d%% of expected value %s" % (
|
||||
received,
|
||||
variance * 100,
|
||||
expected,
|
||||
)
|
||||
|
||||
|
||||
def eq_regex(a, b, msg=None):
|
||||
assert re.match(b, a), msg or "%r !~ %r" % (a, b)
|
||||
|
||||
|
||||
def eq_(a, b, msg=None):
|
||||
"""Assert a == b, with repr messaging on failure."""
|
||||
assert a == b, msg or "%r != %r" % (a, b)
|
||||
|
||||
|
||||
def ne_(a, b, msg=None):
|
||||
"""Assert a != b, with repr messaging on failure."""
|
||||
assert a != b, msg or "%r == %r" % (a, b)
|
||||
|
||||
|
||||
def le_(a, b, msg=None):
|
||||
"""Assert a <= b, with repr messaging on failure."""
|
||||
assert a <= b, msg or "%r != %r" % (a, b)
|
||||
|
||||
|
||||
def is_instance_of(a, b, msg=None):
|
||||
assert isinstance(a, b), msg or "%r is not an instance of %r" % (a, b)
|
||||
|
||||
|
||||
def is_none(a, msg=None):
|
||||
is_(a, None, msg=msg)
|
||||
|
||||
|
||||
def is_not_none(a, msg=None):
|
||||
is_not(a, None, msg=msg)
|
||||
|
||||
|
||||
def is_true(a, msg=None):
|
||||
is_(bool(a), True, msg=msg)
|
||||
|
||||
|
||||
def is_false(a, msg=None):
|
||||
is_(bool(a), False, msg=msg)
|
||||
|
||||
|
||||
def is_(a, b, msg=None):
|
||||
"""Assert a is b, with repr messaging on failure."""
|
||||
assert a is b, msg or "%r is not %r" % (a, b)
|
||||
|
||||
|
||||
def is_not(a, b, msg=None):
|
||||
"""Assert a is not b, with repr messaging on failure."""
|
||||
assert a is not b, msg or "%r is %r" % (a, b)
|
||||
|
||||
|
||||
# deprecated. See #5429
|
||||
is_not_ = is_not
|
||||
|
||||
|
||||
def in_(a, b, msg=None):
|
||||
"""Assert a in b, with repr messaging on failure."""
|
||||
assert a in b, msg or "%r not in %r" % (a, b)
|
||||
|
||||
|
||||
def not_in(a, b, msg=None):
|
||||
"""Assert a in not b, with repr messaging on failure."""
|
||||
assert a not in b, msg or "%r is in %r" % (a, b)
|
||||
|
||||
|
||||
# deprecated. See #5429
|
||||
not_in_ = not_in
|
||||
|
||||
|
||||
def startswith_(a, fragment, msg=None):
|
||||
"""Assert a.startswith(fragment), with repr messaging on failure."""
|
||||
assert a.startswith(fragment), msg or "%r does not start with %r" % (
|
||||
a,
|
||||
fragment,
|
||||
)
|
||||
|
||||
|
||||
def eq_ignore_whitespace(a, b, msg=None):
|
||||
a = re.sub(r"^\s+?|\n", "", a)
|
||||
a = re.sub(r" {2,}", " ", a)
|
||||
a = re.sub(r"\t", "", a)
|
||||
b = re.sub(r"^\s+?|\n", "", b)
|
||||
b = re.sub(r" {2,}", " ", b)
|
||||
b = re.sub(r"\t", "", b)
|
||||
|
||||
assert a == b, msg or "%r != %r" % (a, b)
|
||||
|
||||
|
||||
def _assert_proper_exception_context(exception):
|
||||
"""assert that any exception we're catching does not have a __context__
|
||||
without a __cause__, and that __suppress_context__ is never set.
|
||||
|
||||
Python 3 will report nested as exceptions as "during the handling of
|
||||
error X, error Y occurred". That's not what we want to do. we want
|
||||
these exceptions in a cause chain.
|
||||
|
||||
"""
|
||||
|
||||
if (
|
||||
exception.__context__ is not exception.__cause__
|
||||
and not exception.__suppress_context__
|
||||
):
|
||||
assert False, (
|
||||
"Exception %r was correctly raised but did not set a cause, "
|
||||
"within context %r as its cause."
|
||||
% (exception, exception.__context__)
|
||||
)
|
||||
|
||||
|
||||
def assert_raises(except_cls, callable_, *args, **kw):
|
||||
return _assert_raises(except_cls, callable_, args, kw, check_context=True)
|
||||
|
||||
|
||||
def assert_raises_context_ok(except_cls, callable_, *args, **kw):
|
||||
return _assert_raises(except_cls, callable_, args, kw)
|
||||
|
||||
|
||||
def assert_raises_message(except_cls, msg, callable_, *args, **kwargs):
|
||||
return _assert_raises(
|
||||
except_cls, callable_, args, kwargs, msg=msg, check_context=True
|
||||
)
|
||||
|
||||
|
||||
def assert_warns(except_cls, callable_, *args, **kwargs):
|
||||
"""legacy adapter function for functions that were previously using
|
||||
assert_raises with SAWarning or similar.
|
||||
|
||||
has some workarounds to accommodate the fact that the callable completes
|
||||
with this approach rather than stopping at the exception raise.
|
||||
|
||||
|
||||
"""
|
||||
with _expect_warnings_sqla_only(except_cls, [".*"]):
|
||||
return callable_(*args, **kwargs)
|
||||
|
||||
|
||||
def assert_warns_message(except_cls, msg, callable_, *args, **kwargs):
|
||||
"""legacy adapter function for functions that were previously using
|
||||
assert_raises with SAWarning or similar.
|
||||
|
||||
has some workarounds to accommodate the fact that the callable completes
|
||||
with this approach rather than stopping at the exception raise.
|
||||
|
||||
Also uses regex.search() to match the given message to the error string
|
||||
rather than regex.match().
|
||||
|
||||
"""
|
||||
with _expect_warnings_sqla_only(
|
||||
except_cls,
|
||||
[msg],
|
||||
search_msg=True,
|
||||
regex=False,
|
||||
):
|
||||
return callable_(*args, **kwargs)
|
||||
|
||||
|
||||
def assert_raises_message_context_ok(
|
||||
except_cls, msg, callable_, *args, **kwargs
|
||||
):
|
||||
return _assert_raises(except_cls, callable_, args, kwargs, msg=msg)
|
||||
|
||||
|
||||
def _assert_raises(
|
||||
except_cls, callable_, args, kwargs, msg=None, check_context=False
|
||||
):
|
||||
with _expect_raises(except_cls, msg, check_context) as ec:
|
||||
callable_(*args, **kwargs)
|
||||
return ec.error
|
||||
|
||||
|
||||
class _ErrorContainer:
|
||||
error = None
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def _expect_raises(except_cls, msg=None, check_context=False):
|
||||
if (
|
||||
isinstance(except_cls, type)
|
||||
and issubclass(except_cls, Warning)
|
||||
or isinstance(except_cls, Warning)
|
||||
):
|
||||
raise TypeError(
|
||||
"Use expect_warnings for warnings, not "
|
||||
"expect_raises / assert_raises"
|
||||
)
|
||||
ec = _ErrorContainer()
|
||||
if check_context:
|
||||
are_we_already_in_a_traceback = sys.exc_info()[0]
|
||||
try:
|
||||
yield ec
|
||||
success = False
|
||||
except except_cls as err:
|
||||
ec.error = err
|
||||
success = True
|
||||
if msg is not None:
|
||||
# I'm often pdbing here, and "err" above isn't
|
||||
# in scope, so assign the string explicitly
|
||||
error_as_string = str(err)
|
||||
assert re.search(msg, error_as_string, re.UNICODE), "%r !~ %s" % (
|
||||
msg,
|
||||
error_as_string,
|
||||
)
|
||||
if check_context and not are_we_already_in_a_traceback:
|
||||
_assert_proper_exception_context(err)
|
||||
print(str(err).encode("utf-8"))
|
||||
|
||||
# it's generally a good idea to not carry traceback objects outside
|
||||
# of the except: block, but in this case especially we seem to have
|
||||
# hit some bug in either python 3.10.0b2 or greenlet or both which
|
||||
# this seems to fix:
|
||||
# https://github.com/python-greenlet/greenlet/issues/242
|
||||
del ec
|
||||
|
||||
# assert outside the block so it works for AssertionError too !
|
||||
assert success, "Callable did not raise an exception"
|
||||
|
||||
|
||||
def expect_raises(except_cls, check_context=True):
|
||||
return _expect_raises(except_cls, check_context=check_context)
|
||||
|
||||
|
||||
def expect_raises_message(except_cls, msg, check_context=True):
|
||||
return _expect_raises(except_cls, msg=msg, check_context=check_context)
|
||||
|
||||
|
||||
class AssertsCompiledSQL:
|
||||
def assert_compile(
|
||||
self,
|
||||
clause,
|
||||
result,
|
||||
params=None,
|
||||
checkparams=None,
|
||||
for_executemany=False,
|
||||
check_literal_execute=None,
|
||||
check_post_param=None,
|
||||
dialect=None,
|
||||
checkpositional=None,
|
||||
check_prefetch=None,
|
||||
use_default_dialect=False,
|
||||
allow_dialect_select=False,
|
||||
supports_default_values=True,
|
||||
supports_default_metavalue=True,
|
||||
literal_binds=False,
|
||||
render_postcompile=False,
|
||||
schema_translate_map=None,
|
||||
render_schema_translate=False,
|
||||
default_schema_name=None,
|
||||
from_linting=False,
|
||||
check_param_order=True,
|
||||
use_literal_execute_for_simple_int=False,
|
||||
):
|
||||
if use_default_dialect:
|
||||
dialect = default.DefaultDialect()
|
||||
dialect.supports_default_values = supports_default_values
|
||||
dialect.supports_default_metavalue = supports_default_metavalue
|
||||
elif allow_dialect_select:
|
||||
dialect = None
|
||||
else:
|
||||
if dialect is None:
|
||||
dialect = getattr(self, "__dialect__", None)
|
||||
|
||||
if dialect is None:
|
||||
dialect = config.db.dialect
|
||||
elif dialect == "default" or dialect == "default_qmark":
|
||||
if dialect == "default":
|
||||
dialect = default.DefaultDialect()
|
||||
else:
|
||||
dialect = default.DefaultDialect("qmark")
|
||||
dialect.supports_default_values = supports_default_values
|
||||
dialect.supports_default_metavalue = supports_default_metavalue
|
||||
elif dialect == "default_enhanced":
|
||||
dialect = default.StrCompileDialect()
|
||||
elif isinstance(dialect, str):
|
||||
dialect = url.URL.create(dialect).get_dialect()()
|
||||
|
||||
if default_schema_name:
|
||||
dialect.default_schema_name = default_schema_name
|
||||
|
||||
kw = {}
|
||||
compile_kwargs = {}
|
||||
|
||||
if schema_translate_map:
|
||||
kw["schema_translate_map"] = schema_translate_map
|
||||
|
||||
if params is not None:
|
||||
kw["column_keys"] = list(params)
|
||||
|
||||
if literal_binds:
|
||||
compile_kwargs["literal_binds"] = True
|
||||
|
||||
if render_postcompile:
|
||||
compile_kwargs["render_postcompile"] = True
|
||||
|
||||
if use_literal_execute_for_simple_int:
|
||||
compile_kwargs["use_literal_execute_for_simple_int"] = True
|
||||
|
||||
if for_executemany:
|
||||
kw["for_executemany"] = True
|
||||
|
||||
if render_schema_translate:
|
||||
kw["render_schema_translate"] = True
|
||||
|
||||
if from_linting or getattr(self, "assert_from_linting", False):
|
||||
kw["linting"] = sql.FROM_LINTING
|
||||
|
||||
from sqlalchemy import orm
|
||||
|
||||
if isinstance(clause, orm.Query):
|
||||
stmt = clause._statement_20()
|
||||
stmt._label_style = LABEL_STYLE_TABLENAME_PLUS_COL
|
||||
clause = stmt
|
||||
|
||||
if compile_kwargs:
|
||||
kw["compile_kwargs"] = compile_kwargs
|
||||
|
||||
class DontAccess:
|
||||
def __getattribute__(self, key):
|
||||
raise NotImplementedError(
|
||||
"compiler accessed .statement; use "
|
||||
"compiler.current_executable"
|
||||
)
|
||||
|
||||
class CheckCompilerAccess:
|
||||
def __init__(self, test_statement):
|
||||
self.test_statement = test_statement
|
||||
self._annotations = {}
|
||||
self.supports_execution = getattr(
|
||||
test_statement, "supports_execution", False
|
||||
)
|
||||
|
||||
if self.supports_execution:
|
||||
self._execution_options = test_statement._execution_options
|
||||
|
||||
if hasattr(test_statement, "_returning"):
|
||||
self._returning = test_statement._returning
|
||||
if hasattr(test_statement, "_inline"):
|
||||
self._inline = test_statement._inline
|
||||
if hasattr(test_statement, "_return_defaults"):
|
||||
self._return_defaults = test_statement._return_defaults
|
||||
|
||||
@property
|
||||
def _variant_mapping(self):
|
||||
return self.test_statement._variant_mapping
|
||||
|
||||
def _default_dialect(self):
|
||||
return self.test_statement._default_dialect()
|
||||
|
||||
def compile(self, dialect, **kw):
|
||||
return self.test_statement.compile.__func__(
|
||||
self, dialect=dialect, **kw
|
||||
)
|
||||
|
||||
def _compiler(self, dialect, **kw):
|
||||
return self.test_statement._compiler.__func__(
|
||||
self, dialect, **kw
|
||||
)
|
||||
|
||||
def _compiler_dispatch(self, compiler, **kwargs):
|
||||
if hasattr(compiler, "statement"):
|
||||
with mock.patch.object(
|
||||
compiler, "statement", DontAccess()
|
||||
):
|
||||
return self.test_statement._compiler_dispatch(
|
||||
compiler, **kwargs
|
||||
)
|
||||
else:
|
||||
return self.test_statement._compiler_dispatch(
|
||||
compiler, **kwargs
|
||||
)
|
||||
|
||||
# no construct can assume it's the "top level" construct in all cases
|
||||
# as anything can be nested. ensure constructs don't assume they
|
||||
# are the "self.statement" element
|
||||
c = CheckCompilerAccess(clause).compile(dialect=dialect, **kw)
|
||||
|
||||
if isinstance(clause, sqltypes.TypeEngine):
|
||||
cache_key_no_warnings = clause._static_cache_key
|
||||
if cache_key_no_warnings:
|
||||
hash(cache_key_no_warnings)
|
||||
else:
|
||||
cache_key_no_warnings = clause._generate_cache_key()
|
||||
if cache_key_no_warnings:
|
||||
hash(cache_key_no_warnings[0])
|
||||
|
||||
param_str = repr(getattr(c, "params", {}))
|
||||
param_str = param_str.encode("utf-8").decode("ascii", "ignore")
|
||||
print(("\nSQL String:\n" + str(c) + param_str).encode("utf-8"))
|
||||
|
||||
cc = re.sub(r"[\n\t]", "", str(c))
|
||||
|
||||
eq_(cc, result, "%r != %r on dialect %r" % (cc, result, dialect))
|
||||
|
||||
if checkparams is not None:
|
||||
if render_postcompile:
|
||||
expanded_state = c.construct_expanded_state(
|
||||
params, escape_names=False
|
||||
)
|
||||
eq_(expanded_state.parameters, checkparams)
|
||||
else:
|
||||
eq_(c.construct_params(params), checkparams)
|
||||
if checkpositional is not None:
|
||||
if render_postcompile:
|
||||
expanded_state = c.construct_expanded_state(
|
||||
params, escape_names=False
|
||||
)
|
||||
eq_(
|
||||
tuple(
|
||||
[
|
||||
expanded_state.parameters[x]
|
||||
for x in expanded_state.positiontup
|
||||
]
|
||||
),
|
||||
checkpositional,
|
||||
)
|
||||
else:
|
||||
p = c.construct_params(params, escape_names=False)
|
||||
eq_(tuple([p[x] for x in c.positiontup]), checkpositional)
|
||||
if check_prefetch is not None:
|
||||
eq_(c.prefetch, check_prefetch)
|
||||
if check_literal_execute is not None:
|
||||
eq_(
|
||||
{
|
||||
c.bind_names[b]: b.effective_value
|
||||
for b in c.literal_execute_params
|
||||
},
|
||||
check_literal_execute,
|
||||
)
|
||||
if check_post_param is not None:
|
||||
eq_(
|
||||
{
|
||||
c.bind_names[b]: b.effective_value
|
||||
for b in c.post_compile_params
|
||||
},
|
||||
check_post_param,
|
||||
)
|
||||
if check_param_order and getattr(c, "params", None):
|
||||
|
||||
def get_dialect(paramstyle, positional):
|
||||
cp = copy(dialect)
|
||||
cp.paramstyle = paramstyle
|
||||
cp.positional = positional
|
||||
return cp
|
||||
|
||||
pyformat_dialect = get_dialect("pyformat", False)
|
||||
pyformat_c = clause.compile(dialect=pyformat_dialect, **kw)
|
||||
stmt = re.sub(r"[\n\t]", "", str(pyformat_c))
|
||||
|
||||
qmark_dialect = get_dialect("qmark", True)
|
||||
qmark_c = clause.compile(dialect=qmark_dialect, **kw)
|
||||
values = list(qmark_c.positiontup)
|
||||
escaped = qmark_c.escaped_bind_names
|
||||
|
||||
for post_param in (
|
||||
qmark_c.post_compile_params | qmark_c.literal_execute_params
|
||||
):
|
||||
name = qmark_c.bind_names[post_param]
|
||||
if name in values:
|
||||
values = [v for v in values if v != name]
|
||||
positions = []
|
||||
pos_by_value = defaultdict(list)
|
||||
for v in values:
|
||||
try:
|
||||
if v in pos_by_value:
|
||||
start = pos_by_value[v][-1]
|
||||
else:
|
||||
start = 0
|
||||
esc = escaped.get(v, v)
|
||||
pos = stmt.index("%%(%s)s" % (esc,), start) + 2
|
||||
positions.append(pos)
|
||||
pos_by_value[v].append(pos)
|
||||
except ValueError:
|
||||
msg = "Expected to find bindparam %r in %r" % (v, stmt)
|
||||
assert False, msg
|
||||
|
||||
ordered = all(
|
||||
positions[i - 1] < positions[i]
|
||||
for i in range(1, len(positions))
|
||||
)
|
||||
|
||||
expected = [v for _, v in sorted(zip(positions, values))]
|
||||
|
||||
msg = (
|
||||
"Order of parameters %s does not match the order "
|
||||
"in the statement %s. Statement %r" % (values, expected, stmt)
|
||||
)
|
||||
|
||||
is_true(ordered, msg)
|
||||
|
||||
|
||||
class ComparesTables:
|
||||
def assert_tables_equal(
|
||||
self,
|
||||
table,
|
||||
reflected_table,
|
||||
strict_types=False,
|
||||
strict_constraints=True,
|
||||
):
|
||||
assert len(table.c) == len(reflected_table.c)
|
||||
for c, reflected_c in zip(table.c, reflected_table.c):
|
||||
eq_(c.name, reflected_c.name)
|
||||
assert reflected_c is reflected_table.c[c.name]
|
||||
|
||||
if strict_constraints:
|
||||
eq_(c.primary_key, reflected_c.primary_key)
|
||||
eq_(c.nullable, reflected_c.nullable)
|
||||
|
||||
if strict_types:
|
||||
msg = "Type '%s' doesn't correspond to type '%s'"
|
||||
assert isinstance(reflected_c.type, type(c.type)), msg % (
|
||||
reflected_c.type,
|
||||
c.type,
|
||||
)
|
||||
else:
|
||||
self.assert_types_base(reflected_c, c)
|
||||
|
||||
if isinstance(c.type, sqltypes.String):
|
||||
eq_(c.type.length, reflected_c.type.length)
|
||||
|
||||
if strict_constraints:
|
||||
eq_(
|
||||
{f.column.name for f in c.foreign_keys},
|
||||
{f.column.name for f in reflected_c.foreign_keys},
|
||||
)
|
||||
if c.server_default:
|
||||
assert isinstance(
|
||||
reflected_c.server_default, schema.FetchedValue
|
||||
)
|
||||
|
||||
if strict_constraints:
|
||||
assert len(table.primary_key) == len(reflected_table.primary_key)
|
||||
for c in table.primary_key:
|
||||
assert reflected_table.primary_key.columns[c.name] is not None
|
||||
|
||||
def assert_types_base(self, c1, c2):
|
||||
assert c1.type._compare_type_affinity(
|
||||
c2.type
|
||||
), "On column %r, type '%s' doesn't correspond to type '%s'" % (
|
||||
c1.name,
|
||||
c1.type,
|
||||
c2.type,
|
||||
)
|
||||
|
||||
|
||||
class AssertsExecutionResults:
|
||||
def assert_result(self, result, class_, *objects):
|
||||
result = list(result)
|
||||
print(repr(result))
|
||||
self.assert_list(result, class_, objects)
|
||||
|
||||
def assert_list(self, result, class_, list_):
|
||||
self.assert_(
|
||||
len(result) == len(list_),
|
||||
"result list is not the same size as test list, "
|
||||
+ "for class "
|
||||
+ class_.__name__,
|
||||
)
|
||||
for i in range(0, len(list_)):
|
||||
self.assert_row(class_, result[i], list_[i])
|
||||
|
||||
def assert_row(self, class_, rowobj, desc):
|
||||
self.assert_(
|
||||
rowobj.__class__ is class_, "item class is not " + repr(class_)
|
||||
)
|
||||
for key, value in desc.items():
|
||||
if isinstance(value, tuple):
|
||||
if isinstance(value[1], list):
|
||||
self.assert_list(getattr(rowobj, key), value[0], value[1])
|
||||
else:
|
||||
self.assert_row(value[0], getattr(rowobj, key), value[1])
|
||||
else:
|
||||
self.assert_(
|
||||
getattr(rowobj, key) == value,
|
||||
"attribute %s value %s does not match %s"
|
||||
% (key, getattr(rowobj, key), value),
|
||||
)
|
||||
|
||||
def assert_unordered_result(self, result, cls, *expected):
|
||||
"""As assert_result, but the order of objects is not considered.
|
||||
|
||||
The algorithm is very expensive but not a big deal for the small
|
||||
numbers of rows that the test suite manipulates.
|
||||
"""
|
||||
|
||||
class immutabledict(dict):
|
||||
def __hash__(self):
|
||||
return id(self)
|
||||
|
||||
found = util.IdentitySet(result)
|
||||
expected = {immutabledict(e) for e in expected}
|
||||
|
||||
for wrong in filterfalse(lambda o: isinstance(o, cls), found):
|
||||
fail(
|
||||
'Unexpected type "%s", expected "%s"'
|
||||
% (type(wrong).__name__, cls.__name__)
|
||||
)
|
||||
|
||||
if len(found) != len(expected):
|
||||
fail(
|
||||
'Unexpected object count "%s", expected "%s"'
|
||||
% (len(found), len(expected))
|
||||
)
|
||||
|
||||
NOVALUE = object()
|
||||
|
||||
def _compare_item(obj, spec):
|
||||
for key, value in spec.items():
|
||||
if isinstance(value, tuple):
|
||||
try:
|
||||
self.assert_unordered_result(
|
||||
getattr(obj, key), value[0], *value[1]
|
||||
)
|
||||
except AssertionError:
|
||||
return False
|
||||
else:
|
||||
if getattr(obj, key, NOVALUE) != value:
|
||||
return False
|
||||
return True
|
||||
|
||||
for expected_item in expected:
|
||||
for found_item in found:
|
||||
if _compare_item(found_item, expected_item):
|
||||
found.remove(found_item)
|
||||
break
|
||||
else:
|
||||
fail(
|
||||
"Expected %s instance with attributes %s not found."
|
||||
% (cls.__name__, repr(expected_item))
|
||||
)
|
||||
return True
|
||||
|
||||
def sql_execution_asserter(self, db=None):
|
||||
if db is None:
|
||||
from . import db as db
|
||||
|
||||
return assertsql.assert_engine(db)
|
||||
|
||||
def assert_sql_execution(self, db, callable_, *rules):
|
||||
with self.sql_execution_asserter(db) as asserter:
|
||||
result = callable_()
|
||||
asserter.assert_(*rules)
|
||||
return result
|
||||
|
||||
def assert_sql(self, db, callable_, rules):
|
||||
newrules = []
|
||||
for rule in rules:
|
||||
if isinstance(rule, dict):
|
||||
newrule = assertsql.AllOf(
|
||||
*[assertsql.CompiledSQL(k, v) for k, v in rule.items()]
|
||||
)
|
||||
else:
|
||||
newrule = assertsql.CompiledSQL(*rule)
|
||||
newrules.append(newrule)
|
||||
|
||||
return self.assert_sql_execution(db, callable_, *newrules)
|
||||
|
||||
def assert_sql_count(self, db, callable_, count):
|
||||
return self.assert_sql_execution(
|
||||
db, callable_, assertsql.CountStatements(count)
|
||||
)
|
||||
|
||||
@contextlib.contextmanager
|
||||
def assert_execution(self, db, *rules):
|
||||
with self.sql_execution_asserter(db) as asserter:
|
||||
yield
|
||||
asserter.assert_(*rules)
|
||||
|
||||
def assert_statement_count(self, db, count):
|
||||
return self.assert_execution(db, assertsql.CountStatements(count))
|
||||
|
||||
@contextlib.contextmanager
|
||||
def assert_statement_count_multi_db(self, dbs, counts):
|
||||
recs = [
|
||||
(self.sql_execution_asserter(db), db, count)
|
||||
for (db, count) in zip(dbs, counts)
|
||||
]
|
||||
asserters = []
|
||||
for ctx, db, count in recs:
|
||||
asserters.append(ctx.__enter__())
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
for asserter, (ctx, db, count) in zip(asserters, recs):
|
||||
ctx.__exit__(None, None, None)
|
||||
asserter.assert_(assertsql.CountStatements(count))
|
||||
|
||||
|
||||
class ComparesIndexes:
|
||||
def compare_table_index_with_expected(
|
||||
self, table: schema.Table, expected: list, dialect_name: str
|
||||
):
|
||||
eq_(len(table.indexes), len(expected))
|
||||
idx_dict = {idx.name: idx for idx in table.indexes}
|
||||
for exp in expected:
|
||||
idx = idx_dict[exp["name"]]
|
||||
eq_(idx.unique, exp["unique"])
|
||||
cols = [c for c in exp["column_names"] if c is not None]
|
||||
eq_(len(idx.columns), len(cols))
|
||||
for c in cols:
|
||||
is_true(c in idx.columns)
|
||||
exprs = exp.get("expressions")
|
||||
if exprs:
|
||||
eq_(len(idx.expressions), len(exprs))
|
||||
for idx_exp, expr, col in zip(
|
||||
idx.expressions, exprs, exp["column_names"]
|
||||
):
|
||||
if col is None:
|
||||
eq_(idx_exp.text, expr)
|
||||
if (
|
||||
exp.get("dialect_options")
|
||||
and f"{dialect_name}_include" in exp["dialect_options"]
|
||||
):
|
||||
eq_(
|
||||
idx.dialect_options[dialect_name]["include"],
|
||||
exp["dialect_options"][f"{dialect_name}_include"],
|
||||
)
|
|
@ -0,0 +1,516 @@
|
|||
# testing/assertsql.py
|
||||
# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors
|
||||
# <see AUTHORS file>
|
||||
#
|
||||
# This module is part of SQLAlchemy and is released under
|
||||
# the MIT License: https://www.opensource.org/licenses/mit-license.php
|
||||
# mypy: ignore-errors
|
||||
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import collections
|
||||
import contextlib
|
||||
import itertools
|
||||
import re
|
||||
|
||||
from .. import event
|
||||
from ..engine import url
|
||||
from ..engine.default import DefaultDialect
|
||||
from ..schema import BaseDDLElement
|
||||
|
||||
|
||||
class AssertRule:
|
||||
is_consumed = False
|
||||
errormessage = None
|
||||
consume_statement = True
|
||||
|
||||
def process_statement(self, execute_observed):
|
||||
pass
|
||||
|
||||
def no_more_statements(self):
|
||||
assert False, (
|
||||
"All statements are complete, but pending "
|
||||
"assertion rules remain"
|
||||
)
|
||||
|
||||
|
||||
class SQLMatchRule(AssertRule):
|
||||
pass
|
||||
|
||||
|
||||
class CursorSQL(SQLMatchRule):
|
||||
def __init__(self, statement, params=None, consume_statement=True):
|
||||
self.statement = statement
|
||||
self.params = params
|
||||
self.consume_statement = consume_statement
|
||||
|
||||
def process_statement(self, execute_observed):
|
||||
stmt = execute_observed.statements[0]
|
||||
if self.statement != stmt.statement or (
|
||||
self.params is not None and self.params != stmt.parameters
|
||||
):
|
||||
self.consume_statement = True
|
||||
self.errormessage = (
|
||||
"Testing for exact SQL %s parameters %s received %s %s"
|
||||
% (
|
||||
self.statement,
|
||||
self.params,
|
||||
stmt.statement,
|
||||
stmt.parameters,
|
||||
)
|
||||
)
|
||||
else:
|
||||
execute_observed.statements.pop(0)
|
||||
self.is_consumed = True
|
||||
if not execute_observed.statements:
|
||||
self.consume_statement = True
|
||||
|
||||
|
||||
class CompiledSQL(SQLMatchRule):
|
||||
def __init__(
|
||||
self, statement, params=None, dialect="default", enable_returning=True
|
||||
):
|
||||
self.statement = statement
|
||||
self.params = params
|
||||
self.dialect = dialect
|
||||
self.enable_returning = enable_returning
|
||||
|
||||
def _compare_sql(self, execute_observed, received_statement):
|
||||
stmt = re.sub(r"[\n\t]", "", self.statement)
|
||||
return received_statement == stmt
|
||||
|
||||
def _compile_dialect(self, execute_observed):
|
||||
if self.dialect == "default":
|
||||
dialect = DefaultDialect()
|
||||
# this is currently what tests are expecting
|
||||
# dialect.supports_default_values = True
|
||||
dialect.supports_default_metavalue = True
|
||||
|
||||
if self.enable_returning:
|
||||
dialect.insert_returning = dialect.update_returning = (
|
||||
dialect.delete_returning
|
||||
) = True
|
||||
dialect.use_insertmanyvalues = True
|
||||
dialect.supports_multivalues_insert = True
|
||||
dialect.update_returning_multifrom = True
|
||||
dialect.delete_returning_multifrom = True
|
||||
# dialect.favor_returning_over_lastrowid = True
|
||||
# dialect.insert_null_pk_still_autoincrements = True
|
||||
|
||||
# this is calculated but we need it to be True for this
|
||||
# to look like all the current RETURNING dialects
|
||||
assert dialect.insert_executemany_returning
|
||||
|
||||
return dialect
|
||||
else:
|
||||
return url.URL.create(self.dialect).get_dialect()()
|
||||
|
||||
def _received_statement(self, execute_observed):
|
||||
"""reconstruct the statement and params in terms
|
||||
of a target dialect, which for CompiledSQL is just DefaultDialect."""
|
||||
|
||||
context = execute_observed.context
|
||||
compare_dialect = self._compile_dialect(execute_observed)
|
||||
|
||||
# received_statement runs a full compile(). we should not need to
|
||||
# consider extracted_parameters; if we do this indicates some state
|
||||
# is being sent from a previous cached query, which some misbehaviors
|
||||
# in the ORM can cause, see #6881
|
||||
cache_key = None # execute_observed.context.compiled.cache_key
|
||||
extracted_parameters = (
|
||||
None # execute_observed.context.extracted_parameters
|
||||
)
|
||||
|
||||
if "schema_translate_map" in context.execution_options:
|
||||
map_ = context.execution_options["schema_translate_map"]
|
||||
else:
|
||||
map_ = None
|
||||
|
||||
if isinstance(execute_observed.clauseelement, BaseDDLElement):
|
||||
compiled = execute_observed.clauseelement.compile(
|
||||
dialect=compare_dialect,
|
||||
schema_translate_map=map_,
|
||||
)
|
||||
else:
|
||||
compiled = execute_observed.clauseelement.compile(
|
||||
cache_key=cache_key,
|
||||
dialect=compare_dialect,
|
||||
column_keys=context.compiled.column_keys,
|
||||
for_executemany=context.compiled.for_executemany,
|
||||
schema_translate_map=map_,
|
||||
)
|
||||
_received_statement = re.sub(r"[\n\t]", "", str(compiled))
|
||||
parameters = execute_observed.parameters
|
||||
|
||||
if not parameters:
|
||||
_received_parameters = [
|
||||
compiled.construct_params(
|
||||
extracted_parameters=extracted_parameters
|
||||
)
|
||||
]
|
||||
else:
|
||||
_received_parameters = [
|
||||
compiled.construct_params(
|
||||
m, extracted_parameters=extracted_parameters
|
||||
)
|
||||
for m in parameters
|
||||
]
|
||||
|
||||
return _received_statement, _received_parameters
|
||||
|
||||
def process_statement(self, execute_observed):
|
||||
context = execute_observed.context
|
||||
|
||||
_received_statement, _received_parameters = self._received_statement(
|
||||
execute_observed
|
||||
)
|
||||
params = self._all_params(context)
|
||||
|
||||
equivalent = self._compare_sql(execute_observed, _received_statement)
|
||||
|
||||
if equivalent:
|
||||
if params is not None:
|
||||
all_params = list(params)
|
||||
all_received = list(_received_parameters)
|
||||
while all_params and all_received:
|
||||
param = dict(all_params.pop(0))
|
||||
|
||||
for idx, received in enumerate(list(all_received)):
|
||||
# do a positive compare only
|
||||
for param_key in param:
|
||||
# a key in param did not match current
|
||||
# 'received'
|
||||
if (
|
||||
param_key not in received
|
||||
or received[param_key] != param[param_key]
|
||||
):
|
||||
break
|
||||
else:
|
||||
# all keys in param matched 'received';
|
||||
# onto next param
|
||||
del all_received[idx]
|
||||
break
|
||||
else:
|
||||
# param did not match any entry
|
||||
# in all_received
|
||||
equivalent = False
|
||||
break
|
||||
if all_params or all_received:
|
||||
equivalent = False
|
||||
|
||||
if equivalent:
|
||||
self.is_consumed = True
|
||||
self.errormessage = None
|
||||
else:
|
||||
self.errormessage = self._failure_message(
|
||||
execute_observed, params
|
||||
) % {
|
||||
"received_statement": _received_statement,
|
||||
"received_parameters": _received_parameters,
|
||||
}
|
||||
|
||||
def _all_params(self, context):
|
||||
if self.params:
|
||||
if callable(self.params):
|
||||
params = self.params(context)
|
||||
else:
|
||||
params = self.params
|
||||
if not isinstance(params, list):
|
||||
params = [params]
|
||||
return params
|
||||
else:
|
||||
return None
|
||||
|
||||
def _failure_message(self, execute_observed, expected_params):
|
||||
return (
|
||||
"Testing for compiled statement\n%r partial params %s, "
|
||||
"received\n%%(received_statement)r with params "
|
||||
"%%(received_parameters)r"
|
||||
% (
|
||||
self.statement.replace("%", "%%"),
|
||||
repr(expected_params).replace("%", "%%"),
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
class RegexSQL(CompiledSQL):
|
||||
def __init__(
|
||||
self, regex, params=None, dialect="default", enable_returning=False
|
||||
):
|
||||
SQLMatchRule.__init__(self)
|
||||
self.regex = re.compile(regex)
|
||||
self.orig_regex = regex
|
||||
self.params = params
|
||||
self.dialect = dialect
|
||||
self.enable_returning = enable_returning
|
||||
|
||||
def _failure_message(self, execute_observed, expected_params):
|
||||
return (
|
||||
"Testing for compiled statement ~%r partial params %s, "
|
||||
"received %%(received_statement)r with params "
|
||||
"%%(received_parameters)r"
|
||||
% (
|
||||
self.orig_regex.replace("%", "%%"),
|
||||
repr(expected_params).replace("%", "%%"),
|
||||
)
|
||||
)
|
||||
|
||||
def _compare_sql(self, execute_observed, received_statement):
|
||||
return bool(self.regex.match(received_statement))
|
||||
|
||||
|
||||
class DialectSQL(CompiledSQL):
|
||||
def _compile_dialect(self, execute_observed):
|
||||
return execute_observed.context.dialect
|
||||
|
||||
def _compare_no_space(self, real_stmt, received_stmt):
|
||||
stmt = re.sub(r"[\n\t]", "", real_stmt)
|
||||
return received_stmt == stmt
|
||||
|
||||
def _received_statement(self, execute_observed):
|
||||
received_stmt, received_params = super()._received_statement(
|
||||
execute_observed
|
||||
)
|
||||
|
||||
# TODO: why do we need this part?
|
||||
for real_stmt in execute_observed.statements:
|
||||
if self._compare_no_space(real_stmt.statement, received_stmt):
|
||||
break
|
||||
else:
|
||||
raise AssertionError(
|
||||
"Can't locate compiled statement %r in list of "
|
||||
"statements actually invoked" % received_stmt
|
||||
)
|
||||
|
||||
return received_stmt, execute_observed.context.compiled_parameters
|
||||
|
||||
def _dialect_adjusted_statement(self, dialect):
|
||||
paramstyle = dialect.paramstyle
|
||||
stmt = re.sub(r"[\n\t]", "", self.statement)
|
||||
|
||||
# temporarily escape out PG double colons
|
||||
stmt = stmt.replace("::", "!!")
|
||||
|
||||
if paramstyle == "pyformat":
|
||||
stmt = re.sub(r":([\w_]+)", r"%(\1)s", stmt)
|
||||
else:
|
||||
# positional params
|
||||
repl = None
|
||||
if paramstyle == "qmark":
|
||||
repl = "?"
|
||||
elif paramstyle == "format":
|
||||
repl = r"%s"
|
||||
elif paramstyle.startswith("numeric"):
|
||||
counter = itertools.count(1)
|
||||
|
||||
num_identifier = "$" if paramstyle == "numeric_dollar" else ":"
|
||||
|
||||
def repl(m):
|
||||
return f"{num_identifier}{next(counter)}"
|
||||
|
||||
stmt = re.sub(r":([\w_]+)", repl, stmt)
|
||||
|
||||
# put them back
|
||||
stmt = stmt.replace("!!", "::")
|
||||
|
||||
return stmt
|
||||
|
||||
def _compare_sql(self, execute_observed, received_statement):
|
||||
stmt = self._dialect_adjusted_statement(
|
||||
execute_observed.context.dialect
|
||||
)
|
||||
return received_statement == stmt
|
||||
|
||||
def _failure_message(self, execute_observed, expected_params):
|
||||
return (
|
||||
"Testing for compiled statement\n%r partial params %s, "
|
||||
"received\n%%(received_statement)r with params "
|
||||
"%%(received_parameters)r"
|
||||
% (
|
||||
self._dialect_adjusted_statement(
|
||||
execute_observed.context.dialect
|
||||
).replace("%", "%%"),
|
||||
repr(expected_params).replace("%", "%%"),
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
class CountStatements(AssertRule):
|
||||
def __init__(self, count):
|
||||
self.count = count
|
||||
self._statement_count = 0
|
||||
|
||||
def process_statement(self, execute_observed):
|
||||
self._statement_count += 1
|
||||
|
||||
def no_more_statements(self):
|
||||
if self.count != self._statement_count:
|
||||
assert False, "desired statement count %d does not match %d" % (
|
||||
self.count,
|
||||
self._statement_count,
|
||||
)
|
||||
|
||||
|
||||
class AllOf(AssertRule):
|
||||
def __init__(self, *rules):
|
||||
self.rules = set(rules)
|
||||
|
||||
def process_statement(self, execute_observed):
|
||||
for rule in list(self.rules):
|
||||
rule.errormessage = None
|
||||
rule.process_statement(execute_observed)
|
||||
if rule.is_consumed:
|
||||
self.rules.discard(rule)
|
||||
if not self.rules:
|
||||
self.is_consumed = True
|
||||
break
|
||||
elif not rule.errormessage:
|
||||
# rule is not done yet
|
||||
self.errormessage = None
|
||||
break
|
||||
else:
|
||||
self.errormessage = list(self.rules)[0].errormessage
|
||||
|
||||
|
||||
class EachOf(AssertRule):
|
||||
def __init__(self, *rules):
|
||||
self.rules = list(rules)
|
||||
|
||||
def process_statement(self, execute_observed):
|
||||
if not self.rules:
|
||||
self.is_consumed = True
|
||||
self.consume_statement = False
|
||||
|
||||
while self.rules:
|
||||
rule = self.rules[0]
|
||||
rule.process_statement(execute_observed)
|
||||
if rule.is_consumed:
|
||||
self.rules.pop(0)
|
||||
elif rule.errormessage:
|
||||
self.errormessage = rule.errormessage
|
||||
if rule.consume_statement:
|
||||
break
|
||||
|
||||
if not self.rules:
|
||||
self.is_consumed = True
|
||||
|
||||
def no_more_statements(self):
|
||||
if self.rules and not self.rules[0].is_consumed:
|
||||
self.rules[0].no_more_statements()
|
||||
elif self.rules:
|
||||
super().no_more_statements()
|
||||
|
||||
|
||||
class Conditional(EachOf):
|
||||
def __init__(self, condition, rules, else_rules):
|
||||
if condition:
|
||||
super().__init__(*rules)
|
||||
else:
|
||||
super().__init__(*else_rules)
|
||||
|
||||
|
||||
class Or(AllOf):
|
||||
def process_statement(self, execute_observed):
|
||||
for rule in self.rules:
|
||||
rule.process_statement(execute_observed)
|
||||
if rule.is_consumed:
|
||||
self.is_consumed = True
|
||||
break
|
||||
else:
|
||||
self.errormessage = list(self.rules)[0].errormessage
|
||||
|
||||
|
||||
class SQLExecuteObserved:
|
||||
def __init__(self, context, clauseelement, multiparams, params):
|
||||
self.context = context
|
||||
self.clauseelement = clauseelement
|
||||
|
||||
if multiparams:
|
||||
self.parameters = multiparams
|
||||
elif params:
|
||||
self.parameters = [params]
|
||||
else:
|
||||
self.parameters = []
|
||||
self.statements = []
|
||||
|
||||
def __repr__(self):
|
||||
return str(self.statements)
|
||||
|
||||
|
||||
class SQLCursorExecuteObserved(
|
||||
collections.namedtuple(
|
||||
"SQLCursorExecuteObserved",
|
||||
["statement", "parameters", "context", "executemany"],
|
||||
)
|
||||
):
|
||||
pass
|
||||
|
||||
|
||||
class SQLAsserter:
|
||||
def __init__(self):
|
||||
self.accumulated = []
|
||||
|
||||
def _close(self):
|
||||
self._final = self.accumulated
|
||||
del self.accumulated
|
||||
|
||||
def assert_(self, *rules):
|
||||
rule = EachOf(*rules)
|
||||
|
||||
observed = list(self._final)
|
||||
while observed:
|
||||
statement = observed.pop(0)
|
||||
rule.process_statement(statement)
|
||||
if rule.is_consumed:
|
||||
break
|
||||
elif rule.errormessage:
|
||||
assert False, rule.errormessage
|
||||
if observed:
|
||||
assert False, "Additional SQL statements remain:\n%s" % observed
|
||||
elif not rule.is_consumed:
|
||||
rule.no_more_statements()
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def assert_engine(engine):
|
||||
asserter = SQLAsserter()
|
||||
|
||||
orig = []
|
||||
|
||||
@event.listens_for(engine, "before_execute")
|
||||
def connection_execute(
|
||||
conn, clauseelement, multiparams, params, execution_options
|
||||
):
|
||||
# grab the original statement + params before any cursor
|
||||
# execution
|
||||
orig[:] = clauseelement, multiparams, params
|
||||
|
||||
@event.listens_for(engine, "after_cursor_execute")
|
||||
def cursor_execute(
|
||||
conn, cursor, statement, parameters, context, executemany
|
||||
):
|
||||
if not context:
|
||||
return
|
||||
# then grab real cursor statements and associate them all
|
||||
# around a single context
|
||||
if (
|
||||
asserter.accumulated
|
||||
and asserter.accumulated[-1].context is context
|
||||
):
|
||||
obs = asserter.accumulated[-1]
|
||||
else:
|
||||
obs = SQLExecuteObserved(context, orig[0], orig[1], orig[2])
|
||||
asserter.accumulated.append(obs)
|
||||
obs.statements.append(
|
||||
SQLCursorExecuteObserved(
|
||||
statement, parameters, context, executemany
|
||||
)
|
||||
)
|
||||
|
||||
try:
|
||||
yield asserter
|
||||
finally:
|
||||
event.remove(engine, "after_cursor_execute", cursor_execute)
|
||||
event.remove(engine, "before_execute", connection_execute)
|
||||
asserter._close()
|
|
@ -0,0 +1,130 @@
|
|||
# testing/asyncio.py
|
||||
# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors
|
||||
# <see AUTHORS file>
|
||||
#
|
||||
# This module is part of SQLAlchemy and is released under
|
||||
# the MIT License: https://www.opensource.org/licenses/mit-license.php
|
||||
# mypy: ignore-errors
|
||||
|
||||
|
||||
# functions and wrappers to run tests, fixtures, provisioning and
|
||||
# setup/teardown in an asyncio event loop, conditionally based on the
|
||||
# current DB driver being used for a test.
|
||||
|
||||
# note that SQLAlchemy's asyncio integration also supports a method
|
||||
# of running individual asyncio functions inside of separate event loops
|
||||
# using "async_fallback" mode; however running whole functions in the event
|
||||
# loop is a more accurate test for how SQLAlchemy's asyncio features
|
||||
# would run in the real world.
|
||||
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from functools import wraps
|
||||
import inspect
|
||||
|
||||
from . import config
|
||||
from ..util.concurrency import _util_async_run
|
||||
from ..util.concurrency import _util_async_run_coroutine_function
|
||||
|
||||
# may be set to False if the
|
||||
# --disable-asyncio flag is passed to the test runner.
|
||||
ENABLE_ASYNCIO = True
|
||||
|
||||
|
||||
def _run_coroutine_function(fn, *args, **kwargs):
|
||||
return _util_async_run_coroutine_function(fn, *args, **kwargs)
|
||||
|
||||
|
||||
def _assume_async(fn, *args, **kwargs):
|
||||
"""Run a function in an asyncio loop unconditionally.
|
||||
|
||||
This function is used for provisioning features like
|
||||
testing a database connection for server info.
|
||||
|
||||
Note that for blocking IO database drivers, this means they block the
|
||||
event loop.
|
||||
|
||||
"""
|
||||
|
||||
if not ENABLE_ASYNCIO:
|
||||
return fn(*args, **kwargs)
|
||||
|
||||
return _util_async_run(fn, *args, **kwargs)
|
||||
|
||||
|
||||
def _maybe_async_provisioning(fn, *args, **kwargs):
|
||||
"""Run a function in an asyncio loop if any current drivers might need it.
|
||||
|
||||
This function is used for provisioning features that take
|
||||
place outside of a specific database driver being selected, so if the
|
||||
current driver that happens to be used for the provisioning operation
|
||||
is an async driver, it will run in asyncio and not fail.
|
||||
|
||||
Note that for blocking IO database drivers, this means they block the
|
||||
event loop.
|
||||
|
||||
"""
|
||||
if not ENABLE_ASYNCIO:
|
||||
return fn(*args, **kwargs)
|
||||
|
||||
if config.any_async:
|
||||
return _util_async_run(fn, *args, **kwargs)
|
||||
else:
|
||||
return fn(*args, **kwargs)
|
||||
|
||||
|
||||
def _maybe_async(fn, *args, **kwargs):
|
||||
"""Run a function in an asyncio loop if the current selected driver is
|
||||
async.
|
||||
|
||||
This function is used for test setup/teardown and tests themselves
|
||||
where the current DB driver is known.
|
||||
|
||||
|
||||
"""
|
||||
if not ENABLE_ASYNCIO:
|
||||
return fn(*args, **kwargs)
|
||||
|
||||
is_async = config._current.is_async
|
||||
|
||||
if is_async:
|
||||
return _util_async_run(fn, *args, **kwargs)
|
||||
else:
|
||||
return fn(*args, **kwargs)
|
||||
|
||||
|
||||
def _maybe_async_wrapper(fn):
|
||||
"""Apply the _maybe_async function to an existing function and return
|
||||
as a wrapped callable, supporting generator functions as well.
|
||||
|
||||
This is currently used for pytest fixtures that support generator use.
|
||||
|
||||
"""
|
||||
|
||||
if inspect.isgeneratorfunction(fn):
|
||||
_stop = object()
|
||||
|
||||
def call_next(gen):
|
||||
try:
|
||||
return next(gen)
|
||||
# can't raise StopIteration in an awaitable.
|
||||
except StopIteration:
|
||||
return _stop
|
||||
|
||||
@wraps(fn)
|
||||
def wrap_fixture(*args, **kwargs):
|
||||
gen = fn(*args, **kwargs)
|
||||
while True:
|
||||
value = _maybe_async(call_next, gen)
|
||||
if value is _stop:
|
||||
break
|
||||
yield value
|
||||
|
||||
else:
|
||||
|
||||
@wraps(fn)
|
||||
def wrap_fixture(*args, **kwargs):
|
||||
return _maybe_async(fn, *args, **kwargs)
|
||||
|
||||
return wrap_fixture
|
|
@ -0,0 +1,427 @@
|
|||
# testing/config.py
|
||||
# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors
|
||||
# <see AUTHORS file>
|
||||
#
|
||||
# This module is part of SQLAlchemy and is released under
|
||||
# the MIT License: https://www.opensource.org/licenses/mit-license.php
|
||||
# mypy: ignore-errors
|
||||
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from argparse import Namespace
|
||||
import collections
|
||||
import inspect
|
||||
import typing
|
||||
from typing import Any
|
||||
from typing import Callable
|
||||
from typing import Iterable
|
||||
from typing import NoReturn
|
||||
from typing import Optional
|
||||
from typing import Tuple
|
||||
from typing import TypeVar
|
||||
from typing import Union
|
||||
|
||||
from . import mock
|
||||
from . import requirements as _requirements
|
||||
from .util import fail
|
||||
from .. import util
|
||||
|
||||
# default requirements; this is replaced by plugin_base when pytest
|
||||
# is run
|
||||
requirements = _requirements.SuiteRequirements()
|
||||
|
||||
db = None
|
||||
db_url = None
|
||||
db_opts = None
|
||||
file_config = None
|
||||
test_schema = None
|
||||
test_schema_2 = None
|
||||
any_async = False
|
||||
_current = None
|
||||
ident = "main"
|
||||
options: Namespace = None # type: ignore
|
||||
|
||||
if typing.TYPE_CHECKING:
|
||||
from .plugin.plugin_base import FixtureFunctions
|
||||
|
||||
_fixture_functions: FixtureFunctions
|
||||
else:
|
||||
|
||||
class _NullFixtureFunctions:
|
||||
def _null_decorator(self):
|
||||
def go(fn):
|
||||
return fn
|
||||
|
||||
return go
|
||||
|
||||
def skip_test_exception(self, *arg, **kw):
|
||||
return Exception()
|
||||
|
||||
@property
|
||||
def add_to_marker(self):
|
||||
return mock.Mock()
|
||||
|
||||
def mark_base_test_class(self):
|
||||
return self._null_decorator()
|
||||
|
||||
def combinations(self, *arg_sets, **kw):
|
||||
return self._null_decorator()
|
||||
|
||||
def param_ident(self, *parameters):
|
||||
return self._null_decorator()
|
||||
|
||||
def fixture(self, *arg, **kw):
|
||||
return self._null_decorator()
|
||||
|
||||
def get_current_test_name(self):
|
||||
return None
|
||||
|
||||
def async_test(self, fn):
|
||||
return fn
|
||||
|
||||
# default fixture functions; these are replaced by plugin_base when
|
||||
# pytest runs
|
||||
_fixture_functions = _NullFixtureFunctions()
|
||||
|
||||
|
||||
_FN = TypeVar("_FN", bound=Callable[..., Any])
|
||||
|
||||
|
||||
def combinations(
|
||||
*comb: Union[Any, Tuple[Any, ...]],
|
||||
argnames: Optional[str] = None,
|
||||
id_: Optional[str] = None,
|
||||
**kw: str,
|
||||
) -> Callable[[_FN], _FN]:
|
||||
r"""Deliver multiple versions of a test based on positional combinations.
|
||||
|
||||
This is a facade over pytest.mark.parametrize.
|
||||
|
||||
|
||||
:param \*comb: argument combinations. These are tuples that will be passed
|
||||
positionally to the decorated function.
|
||||
|
||||
:param argnames: optional list of argument names. These are the names
|
||||
of the arguments in the test function that correspond to the entries
|
||||
in each argument tuple. pytest.mark.parametrize requires this, however
|
||||
the combinations function will derive it automatically if not present
|
||||
by using ``inspect.getfullargspec(fn).args[1:]``. Note this assumes the
|
||||
first argument is "self" which is discarded.
|
||||
|
||||
:param id\_: optional id template. This is a string template that
|
||||
describes how the "id" for each parameter set should be defined, if any.
|
||||
The number of characters in the template should match the number of
|
||||
entries in each argument tuple. Each character describes how the
|
||||
corresponding entry in the argument tuple should be handled, as far as
|
||||
whether or not it is included in the arguments passed to the function, as
|
||||
well as if it is included in the tokens used to create the id of the
|
||||
parameter set.
|
||||
|
||||
If omitted, the argument combinations are passed to parametrize as is. If
|
||||
passed, each argument combination is turned into a pytest.param() object,
|
||||
mapping the elements of the argument tuple to produce an id based on a
|
||||
character value in the same position within the string template using the
|
||||
following scheme::
|
||||
|
||||
i - the given argument is a string that is part of the id only, don't
|
||||
pass it as an argument
|
||||
|
||||
n - the given argument should be passed and it should be added to the
|
||||
id by calling the .__name__ attribute
|
||||
|
||||
r - the given argument should be passed and it should be added to the
|
||||
id by calling repr()
|
||||
|
||||
s - the given argument should be passed and it should be added to the
|
||||
id by calling str()
|
||||
|
||||
a - (argument) the given argument should be passed and it should not
|
||||
be used to generated the id
|
||||
|
||||
e.g.::
|
||||
|
||||
@testing.combinations(
|
||||
(operator.eq, "eq"),
|
||||
(operator.ne, "ne"),
|
||||
(operator.gt, "gt"),
|
||||
(operator.lt, "lt"),
|
||||
id_="na"
|
||||
)
|
||||
def test_operator(self, opfunc, name):
|
||||
pass
|
||||
|
||||
The above combination will call ``.__name__`` on the first member of
|
||||
each tuple and use that as the "id" to pytest.param().
|
||||
|
||||
|
||||
"""
|
||||
return _fixture_functions.combinations(
|
||||
*comb, id_=id_, argnames=argnames, **kw
|
||||
)
|
||||
|
||||
|
||||
def combinations_list(arg_iterable: Iterable[Tuple[Any, ...]], **kw):
|
||||
"As combination, but takes a single iterable"
|
||||
return combinations(*arg_iterable, **kw)
|
||||
|
||||
|
||||
class Variation:
|
||||
__slots__ = ("_name", "_argname")
|
||||
|
||||
def __init__(self, case, argname, case_names):
|
||||
self._name = case
|
||||
self._argname = argname
|
||||
for casename in case_names:
|
||||
setattr(self, casename, casename == case)
|
||||
|
||||
if typing.TYPE_CHECKING:
|
||||
|
||||
def __getattr__(self, key: str) -> bool: ...
|
||||
|
||||
@property
|
||||
def name(self):
|
||||
return self._name
|
||||
|
||||
def __bool__(self):
|
||||
return self._name == self._argname
|
||||
|
||||
def __nonzero__(self):
|
||||
return not self.__bool__()
|
||||
|
||||
def __str__(self):
|
||||
return f"{self._argname}={self._name!r}"
|
||||
|
||||
def __repr__(self):
|
||||
return str(self)
|
||||
|
||||
def fail(self) -> NoReturn:
|
||||
fail(f"Unknown {self}")
|
||||
|
||||
@classmethod
|
||||
def idfn(cls, variation):
|
||||
return variation.name
|
||||
|
||||
@classmethod
|
||||
def generate_cases(cls, argname, cases):
|
||||
case_names = [
|
||||
argname if c is True else "not_" + argname if c is False else c
|
||||
for c in cases
|
||||
]
|
||||
|
||||
typ = type(
|
||||
argname,
|
||||
(Variation,),
|
||||
{
|
||||
"__slots__": tuple(case_names),
|
||||
},
|
||||
)
|
||||
|
||||
return [typ(casename, argname, case_names) for casename in case_names]
|
||||
|
||||
|
||||
def variation(argname_or_fn, cases=None):
|
||||
"""a helper around testing.combinations that provides a single namespace
|
||||
that can be used as a switch.
|
||||
|
||||
e.g.::
|
||||
|
||||
@testing.variation("querytyp", ["select", "subquery", "legacy_query"])
|
||||
@testing.variation("lazy", ["select", "raise", "raise_on_sql"])
|
||||
def test_thing(
|
||||
self,
|
||||
querytyp,
|
||||
lazy,
|
||||
decl_base
|
||||
):
|
||||
class Thing(decl_base):
|
||||
__tablename__ = 'thing'
|
||||
|
||||
# use name directly
|
||||
rel = relationship("Rel", lazy=lazy.name)
|
||||
|
||||
# use as a switch
|
||||
if querytyp.select:
|
||||
stmt = select(Thing)
|
||||
elif querytyp.subquery:
|
||||
stmt = select(Thing).subquery()
|
||||
elif querytyp.legacy_query:
|
||||
stmt = Session.query(Thing)
|
||||
else:
|
||||
querytyp.fail()
|
||||
|
||||
|
||||
The variable provided is a slots object of boolean variables, as well
|
||||
as the name of the case itself under the attribute ".name"
|
||||
|
||||
"""
|
||||
|
||||
if inspect.isfunction(argname_or_fn):
|
||||
argname = argname_or_fn.__name__
|
||||
cases = argname_or_fn(None)
|
||||
|
||||
@variation_fixture(argname, cases)
|
||||
def go(self, request):
|
||||
yield request.param
|
||||
|
||||
return go
|
||||
else:
|
||||
argname = argname_or_fn
|
||||
cases_plus_limitations = [
|
||||
(
|
||||
entry
|
||||
if (isinstance(entry, tuple) and len(entry) == 2)
|
||||
else (entry, None)
|
||||
)
|
||||
for entry in cases
|
||||
]
|
||||
|
||||
variations = Variation.generate_cases(
|
||||
argname, [c for c, l in cases_plus_limitations]
|
||||
)
|
||||
return combinations(
|
||||
*[
|
||||
(
|
||||
(variation._name, variation, limitation)
|
||||
if limitation is not None
|
||||
else (variation._name, variation)
|
||||
)
|
||||
for variation, (case, limitation) in zip(
|
||||
variations, cases_plus_limitations
|
||||
)
|
||||
],
|
||||
id_="ia",
|
||||
argnames=argname,
|
||||
)
|
||||
|
||||
|
||||
def variation_fixture(argname, cases, scope="function"):
|
||||
return fixture(
|
||||
params=Variation.generate_cases(argname, cases),
|
||||
ids=Variation.idfn,
|
||||
scope=scope,
|
||||
)
|
||||
|
||||
|
||||
def fixture(*arg: Any, **kw: Any) -> Any:
|
||||
return _fixture_functions.fixture(*arg, **kw)
|
||||
|
||||
|
||||
def get_current_test_name() -> str:
|
||||
return _fixture_functions.get_current_test_name()
|
||||
|
||||
|
||||
def mark_base_test_class() -> Any:
|
||||
return _fixture_functions.mark_base_test_class()
|
||||
|
||||
|
||||
class _AddToMarker:
|
||||
def __getattr__(self, attr: str) -> Any:
|
||||
return getattr(_fixture_functions.add_to_marker, attr)
|
||||
|
||||
|
||||
add_to_marker = _AddToMarker()
|
||||
|
||||
|
||||
class Config:
|
||||
def __init__(self, db, db_opts, options, file_config):
|
||||
self._set_name(db)
|
||||
self.db = db
|
||||
self.db_opts = db_opts
|
||||
self.options = options
|
||||
self.file_config = file_config
|
||||
self.test_schema = "test_schema"
|
||||
self.test_schema_2 = "test_schema_2"
|
||||
|
||||
self.is_async = db.dialect.is_async and not util.asbool(
|
||||
db.url.query.get("async_fallback", False)
|
||||
)
|
||||
|
||||
_stack = collections.deque()
|
||||
_configs = set()
|
||||
|
||||
def _set_name(self, db):
|
||||
suffix = "_async" if db.dialect.is_async else ""
|
||||
if db.dialect.server_version_info:
|
||||
svi = ".".join(str(tok) for tok in db.dialect.server_version_info)
|
||||
self.name = "%s+%s%s_[%s]" % (db.name, db.driver, suffix, svi)
|
||||
else:
|
||||
self.name = "%s+%s%s" % (db.name, db.driver, suffix)
|
||||
|
||||
@classmethod
|
||||
def register(cls, db, db_opts, options, file_config):
|
||||
"""add a config as one of the global configs.
|
||||
|
||||
If there are no configs set up yet, this config also
|
||||
gets set as the "_current".
|
||||
"""
|
||||
global any_async
|
||||
|
||||
cfg = Config(db, db_opts, options, file_config)
|
||||
|
||||
# if any backends include an async driver, then ensure
|
||||
# all setup/teardown and tests are wrapped in the maybe_async()
|
||||
# decorator that will set up a greenlet context for async drivers.
|
||||
any_async = any_async or cfg.is_async
|
||||
|
||||
cls._configs.add(cfg)
|
||||
return cfg
|
||||
|
||||
@classmethod
|
||||
def set_as_current(cls, config, namespace):
|
||||
global db, _current, db_url, test_schema, test_schema_2, db_opts
|
||||
_current = config
|
||||
db_url = config.db.url
|
||||
db_opts = config.db_opts
|
||||
test_schema = config.test_schema
|
||||
test_schema_2 = config.test_schema_2
|
||||
namespace.db = db = config.db
|
||||
|
||||
@classmethod
|
||||
def push_engine(cls, db, namespace):
|
||||
assert _current, "Can't push without a default Config set up"
|
||||
cls.push(
|
||||
Config(
|
||||
db, _current.db_opts, _current.options, _current.file_config
|
||||
),
|
||||
namespace,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def push(cls, config, namespace):
|
||||
cls._stack.append(_current)
|
||||
cls.set_as_current(config, namespace)
|
||||
|
||||
@classmethod
|
||||
def pop(cls, namespace):
|
||||
if cls._stack:
|
||||
# a failed test w/ -x option can call reset() ahead of time
|
||||
_current = cls._stack[-1]
|
||||
del cls._stack[-1]
|
||||
cls.set_as_current(_current, namespace)
|
||||
|
||||
@classmethod
|
||||
def reset(cls, namespace):
|
||||
if cls._stack:
|
||||
cls.set_as_current(cls._stack[0], namespace)
|
||||
cls._stack.clear()
|
||||
|
||||
@classmethod
|
||||
def all_configs(cls):
|
||||
return cls._configs
|
||||
|
||||
@classmethod
|
||||
def all_dbs(cls):
|
||||
for cfg in cls.all_configs():
|
||||
yield cfg.db
|
||||
|
||||
def skip_test(self, msg):
|
||||
skip_test(msg)
|
||||
|
||||
|
||||
def skip_test(msg):
|
||||
raise _fixture_functions.skip_test_exception(msg)
|
||||
|
||||
|
||||
def async_test(fn):
|
||||
return _fixture_functions.async_test(fn)
|
|
@ -0,0 +1,467 @@
|
|||
# testing/engines.py
|
||||
# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors
|
||||
# <see AUTHORS file>
|
||||
#
|
||||
# This module is part of SQLAlchemy and is released under
|
||||
# the MIT License: https://www.opensource.org/licenses/mit-license.php
|
||||
# mypy: ignore-errors
|
||||
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import collections
|
||||
import re
|
||||
import typing
|
||||
from typing import Any
|
||||
from typing import Dict
|
||||
from typing import Optional
|
||||
import warnings
|
||||
import weakref
|
||||
|
||||
from . import config
|
||||
from .util import decorator
|
||||
from .util import gc_collect
|
||||
from .. import event
|
||||
from .. import pool
|
||||
from ..util import await_only
|
||||
from ..util.typing import Literal
|
||||
|
||||
|
||||
if typing.TYPE_CHECKING:
|
||||
from ..engine import Engine
|
||||
from ..engine.url import URL
|
||||
from ..ext.asyncio import AsyncEngine
|
||||
|
||||
|
||||
class ConnectionKiller:
|
||||
def __init__(self):
|
||||
self.proxy_refs = weakref.WeakKeyDictionary()
|
||||
self.testing_engines = collections.defaultdict(set)
|
||||
self.dbapi_connections = set()
|
||||
|
||||
def add_pool(self, pool):
|
||||
event.listen(pool, "checkout", self._add_conn)
|
||||
event.listen(pool, "checkin", self._remove_conn)
|
||||
event.listen(pool, "close", self._remove_conn)
|
||||
event.listen(pool, "close_detached", self._remove_conn)
|
||||
# note we are keeping "invalidated" here, as those are still
|
||||
# opened connections we would like to roll back
|
||||
|
||||
def _add_conn(self, dbapi_con, con_record, con_proxy):
|
||||
self.dbapi_connections.add(dbapi_con)
|
||||
self.proxy_refs[con_proxy] = True
|
||||
|
||||
def _remove_conn(self, dbapi_conn, *arg):
|
||||
self.dbapi_connections.discard(dbapi_conn)
|
||||
|
||||
def add_engine(self, engine, scope):
|
||||
self.add_pool(engine.pool)
|
||||
|
||||
assert scope in ("class", "global", "function", "fixture")
|
||||
self.testing_engines[scope].add(engine)
|
||||
|
||||
def _safe(self, fn):
|
||||
try:
|
||||
fn()
|
||||
except Exception as e:
|
||||
warnings.warn(
|
||||
"testing_reaper couldn't rollback/close connection: %s" % e
|
||||
)
|
||||
|
||||
def rollback_all(self):
|
||||
for rec in list(self.proxy_refs):
|
||||
if rec is not None and rec.is_valid:
|
||||
self._safe(rec.rollback)
|
||||
|
||||
def checkin_all(self):
|
||||
# run pool.checkin() for all ConnectionFairy instances we have
|
||||
# tracked.
|
||||
|
||||
for rec in list(self.proxy_refs):
|
||||
if rec is not None and rec.is_valid:
|
||||
self.dbapi_connections.discard(rec.dbapi_connection)
|
||||
self._safe(rec._checkin)
|
||||
|
||||
# for fairy refs that were GCed and could not close the connection,
|
||||
# such as asyncio, roll back those remaining connections
|
||||
for con in self.dbapi_connections:
|
||||
self._safe(con.rollback)
|
||||
self.dbapi_connections.clear()
|
||||
|
||||
def close_all(self):
|
||||
self.checkin_all()
|
||||
|
||||
def prepare_for_drop_tables(self, connection):
|
||||
# don't do aggressive checks for third party test suites
|
||||
if not config.bootstrapped_as_sqlalchemy:
|
||||
return
|
||||
|
||||
from . import provision
|
||||
|
||||
provision.prepare_for_drop_tables(connection.engine.url, connection)
|
||||
|
||||
def _drop_testing_engines(self, scope):
|
||||
eng = self.testing_engines[scope]
|
||||
for rec in list(eng):
|
||||
for proxy_ref in list(self.proxy_refs):
|
||||
if proxy_ref is not None and proxy_ref.is_valid:
|
||||
if (
|
||||
proxy_ref._pool is not None
|
||||
and proxy_ref._pool is rec.pool
|
||||
):
|
||||
self._safe(proxy_ref._checkin)
|
||||
|
||||
if hasattr(rec, "sync_engine"):
|
||||
await_only(rec.dispose())
|
||||
else:
|
||||
rec.dispose()
|
||||
eng.clear()
|
||||
|
||||
def after_test(self):
|
||||
self._drop_testing_engines("function")
|
||||
|
||||
def after_test_outside_fixtures(self, test):
|
||||
# don't do aggressive checks for third party test suites
|
||||
if not config.bootstrapped_as_sqlalchemy:
|
||||
return
|
||||
|
||||
if test.__class__.__leave_connections_for_teardown__:
|
||||
return
|
||||
|
||||
self.checkin_all()
|
||||
|
||||
# on PostgreSQL, this will test for any "idle in transaction"
|
||||
# connections. useful to identify tests with unusual patterns
|
||||
# that can't be cleaned up correctly.
|
||||
from . import provision
|
||||
|
||||
with config.db.connect() as conn:
|
||||
provision.prepare_for_drop_tables(conn.engine.url, conn)
|
||||
|
||||
def stop_test_class_inside_fixtures(self):
|
||||
self.checkin_all()
|
||||
self._drop_testing_engines("function")
|
||||
self._drop_testing_engines("class")
|
||||
|
||||
def stop_test_class_outside_fixtures(self):
|
||||
# ensure no refs to checked out connections at all.
|
||||
|
||||
if pool.base._strong_ref_connection_records:
|
||||
gc_collect()
|
||||
|
||||
if pool.base._strong_ref_connection_records:
|
||||
ln = len(pool.base._strong_ref_connection_records)
|
||||
pool.base._strong_ref_connection_records.clear()
|
||||
assert (
|
||||
False
|
||||
), "%d connection recs not cleared after test suite" % (ln)
|
||||
|
||||
def final_cleanup(self):
|
||||
self.checkin_all()
|
||||
for scope in self.testing_engines:
|
||||
self._drop_testing_engines(scope)
|
||||
|
||||
def assert_all_closed(self):
|
||||
for rec in self.proxy_refs:
|
||||
if rec.is_valid:
|
||||
assert False
|
||||
|
||||
|
||||
testing_reaper = ConnectionKiller()
|
||||
|
||||
|
||||
@decorator
|
||||
def assert_conns_closed(fn, *args, **kw):
|
||||
try:
|
||||
fn(*args, **kw)
|
||||
finally:
|
||||
testing_reaper.assert_all_closed()
|
||||
|
||||
|
||||
@decorator
|
||||
def rollback_open_connections(fn, *args, **kw):
|
||||
"""Decorator that rolls back all open connections after fn execution."""
|
||||
|
||||
try:
|
||||
fn(*args, **kw)
|
||||
finally:
|
||||
testing_reaper.rollback_all()
|
||||
|
||||
|
||||
@decorator
|
||||
def close_first(fn, *args, **kw):
|
||||
"""Decorator that closes all connections before fn execution."""
|
||||
|
||||
testing_reaper.checkin_all()
|
||||
fn(*args, **kw)
|
||||
|
||||
|
||||
@decorator
|
||||
def close_open_connections(fn, *args, **kw):
|
||||
"""Decorator that closes all connections after fn execution."""
|
||||
try:
|
||||
fn(*args, **kw)
|
||||
finally:
|
||||
testing_reaper.checkin_all()
|
||||
|
||||
|
||||
def all_dialects(exclude=None):
|
||||
import sqlalchemy.dialects as d
|
||||
|
||||
for name in d.__all__:
|
||||
# TEMPORARY
|
||||
if exclude and name in exclude:
|
||||
continue
|
||||
mod = getattr(d, name, None)
|
||||
if not mod:
|
||||
mod = getattr(
|
||||
__import__("sqlalchemy.dialects.%s" % name).dialects, name
|
||||
)
|
||||
yield mod.dialect()
|
||||
|
||||
|
||||
class ReconnectFixture:
|
||||
def __init__(self, dbapi):
|
||||
self.dbapi = dbapi
|
||||
self.connections = []
|
||||
self.is_stopped = False
|
||||
|
||||
def __getattr__(self, key):
|
||||
return getattr(self.dbapi, key)
|
||||
|
||||
def connect(self, *args, **kwargs):
|
||||
conn = self.dbapi.connect(*args, **kwargs)
|
||||
if self.is_stopped:
|
||||
self._safe(conn.close)
|
||||
curs = conn.cursor() # should fail on Oracle etc.
|
||||
# should fail for everything that didn't fail
|
||||
# above, connection is closed
|
||||
curs.execute("select 1")
|
||||
assert False, "simulated connect failure didn't work"
|
||||
else:
|
||||
self.connections.append(conn)
|
||||
return conn
|
||||
|
||||
def _safe(self, fn):
|
||||
try:
|
||||
fn()
|
||||
except Exception as e:
|
||||
warnings.warn("ReconnectFixture couldn't close connection: %s" % e)
|
||||
|
||||
def shutdown(self, stop=False):
|
||||
# TODO: this doesn't cover all cases
|
||||
# as nicely as we'd like, namely MySQLdb.
|
||||
# would need to implement R. Brewer's
|
||||
# proxy server idea to get better
|
||||
# coverage.
|
||||
self.is_stopped = stop
|
||||
for c in list(self.connections):
|
||||
self._safe(c.close)
|
||||
self.connections = []
|
||||
|
||||
def restart(self):
|
||||
self.is_stopped = False
|
||||
|
||||
|
||||
def reconnecting_engine(url=None, options=None):
|
||||
url = url or config.db.url
|
||||
dbapi = config.db.dialect.dbapi
|
||||
if not options:
|
||||
options = {}
|
||||
options["module"] = ReconnectFixture(dbapi)
|
||||
engine = testing_engine(url, options)
|
||||
_dispose = engine.dispose
|
||||
|
||||
def dispose():
|
||||
engine.dialect.dbapi.shutdown()
|
||||
engine.dialect.dbapi.is_stopped = False
|
||||
_dispose()
|
||||
|
||||
engine.test_shutdown = engine.dialect.dbapi.shutdown
|
||||
engine.test_restart = engine.dialect.dbapi.restart
|
||||
engine.dispose = dispose
|
||||
return engine
|
||||
|
||||
|
||||
@typing.overload
|
||||
def testing_engine(
|
||||
url: Optional[URL] = None,
|
||||
options: Optional[Dict[str, Any]] = None,
|
||||
asyncio: Literal[False] = False,
|
||||
transfer_staticpool: bool = False,
|
||||
) -> Engine: ...
|
||||
|
||||
|
||||
@typing.overload
|
||||
def testing_engine(
|
||||
url: Optional[URL] = None,
|
||||
options: Optional[Dict[str, Any]] = None,
|
||||
asyncio: Literal[True] = True,
|
||||
transfer_staticpool: bool = False,
|
||||
) -> AsyncEngine: ...
|
||||
|
||||
|
||||
def testing_engine(
|
||||
url=None,
|
||||
options=None,
|
||||
asyncio=False,
|
||||
transfer_staticpool=False,
|
||||
share_pool=False,
|
||||
_sqlite_savepoint=False,
|
||||
):
|
||||
if asyncio:
|
||||
assert not _sqlite_savepoint
|
||||
from sqlalchemy.ext.asyncio import (
|
||||
create_async_engine as create_engine,
|
||||
)
|
||||
else:
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.engine.url import make_url
|
||||
|
||||
if not options:
|
||||
use_reaper = True
|
||||
scope = "function"
|
||||
sqlite_savepoint = False
|
||||
else:
|
||||
use_reaper = options.pop("use_reaper", True)
|
||||
scope = options.pop("scope", "function")
|
||||
sqlite_savepoint = options.pop("sqlite_savepoint", False)
|
||||
|
||||
url = url or config.db.url
|
||||
|
||||
url = make_url(url)
|
||||
if options is None:
|
||||
if config.db is None or url.drivername == config.db.url.drivername:
|
||||
options = config.db_opts
|
||||
else:
|
||||
options = {}
|
||||
elif config.db is not None and url.drivername == config.db.url.drivername:
|
||||
default_opt = config.db_opts.copy()
|
||||
default_opt.update(options)
|
||||
|
||||
engine = create_engine(url, **options)
|
||||
|
||||
if sqlite_savepoint and engine.name == "sqlite":
|
||||
# apply SQLite savepoint workaround
|
||||
@event.listens_for(engine, "connect")
|
||||
def do_connect(dbapi_connection, connection_record):
|
||||
dbapi_connection.isolation_level = None
|
||||
|
||||
@event.listens_for(engine, "begin")
|
||||
def do_begin(conn):
|
||||
conn.exec_driver_sql("BEGIN")
|
||||
|
||||
if transfer_staticpool:
|
||||
from sqlalchemy.pool import StaticPool
|
||||
|
||||
if config.db is not None and isinstance(config.db.pool, StaticPool):
|
||||
use_reaper = False
|
||||
engine.pool._transfer_from(config.db.pool)
|
||||
elif share_pool:
|
||||
engine.pool = config.db.pool
|
||||
|
||||
if scope == "global":
|
||||
if asyncio:
|
||||
engine.sync_engine._has_events = True
|
||||
else:
|
||||
engine._has_events = (
|
||||
True # enable event blocks, helps with profiling
|
||||
)
|
||||
|
||||
if isinstance(engine.pool, pool.QueuePool):
|
||||
engine.pool._timeout = 0
|
||||
engine.pool._max_overflow = 0
|
||||
if use_reaper:
|
||||
testing_reaper.add_engine(engine, scope)
|
||||
|
||||
return engine
|
||||
|
||||
|
||||
def mock_engine(dialect_name=None):
|
||||
"""Provides a mocking engine based on the current testing.db.
|
||||
|
||||
This is normally used to test DDL generation flow as emitted
|
||||
by an Engine.
|
||||
|
||||
It should not be used in other cases, as assert_compile() and
|
||||
assert_sql_execution() are much better choices with fewer
|
||||
moving parts.
|
||||
|
||||
"""
|
||||
|
||||
from sqlalchemy import create_mock_engine
|
||||
|
||||
if not dialect_name:
|
||||
dialect_name = config.db.name
|
||||
|
||||
buffer = []
|
||||
|
||||
def executor(sql, *a, **kw):
|
||||
buffer.append(sql)
|
||||
|
||||
def assert_sql(stmts):
|
||||
recv = [re.sub(r"[\n\t]", "", str(s)) for s in buffer]
|
||||
assert recv == stmts, recv
|
||||
|
||||
def print_sql():
|
||||
d = engine.dialect
|
||||
return "\n".join(str(s.compile(dialect=d)) for s in engine.mock)
|
||||
|
||||
engine = create_mock_engine(dialect_name + "://", executor)
|
||||
assert not hasattr(engine, "mock")
|
||||
engine.mock = buffer
|
||||
engine.assert_sql = assert_sql
|
||||
engine.print_sql = print_sql
|
||||
return engine
|
||||
|
||||
|
||||
class DBAPIProxyCursor:
|
||||
"""Proxy a DBAPI cursor.
|
||||
|
||||
Tests can provide subclasses of this to intercept
|
||||
DBAPI-level cursor operations.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, engine, conn, *args, **kwargs):
|
||||
self.engine = engine
|
||||
self.connection = conn
|
||||
self.cursor = conn.cursor(*args, **kwargs)
|
||||
|
||||
def execute(self, stmt, parameters=None, **kw):
|
||||
if parameters:
|
||||
return self.cursor.execute(stmt, parameters, **kw)
|
||||
else:
|
||||
return self.cursor.execute(stmt, **kw)
|
||||
|
||||
def executemany(self, stmt, params, **kw):
|
||||
return self.cursor.executemany(stmt, params, **kw)
|
||||
|
||||
def __iter__(self):
|
||||
return iter(self.cursor)
|
||||
|
||||
def __getattr__(self, key):
|
||||
return getattr(self.cursor, key)
|
||||
|
||||
|
||||
class DBAPIProxyConnection:
|
||||
"""Proxy a DBAPI connection.
|
||||
|
||||
Tests can provide subclasses of this to intercept
|
||||
DBAPI-level connection operations.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, engine, conn, cursor_cls):
|
||||
self.conn = conn
|
||||
self.engine = engine
|
||||
self.cursor_cls = cursor_cls
|
||||
|
||||
def cursor(self, *args, **kwargs):
|
||||
return self.cursor_cls(self.engine, self.conn, *args, **kwargs)
|
||||
|
||||
def close(self):
|
||||
self.conn.close()
|
||||
|
||||
def __getattr__(self, key):
|
||||
return getattr(self.conn, key)
|
|
@ -0,0 +1,117 @@
|
|||
# testing/entities.py
|
||||
# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors
|
||||
# <see AUTHORS file>
|
||||
#
|
||||
# This module is part of SQLAlchemy and is released under
|
||||
# the MIT License: https://www.opensource.org/licenses/mit-license.php
|
||||
# mypy: ignore-errors
|
||||
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import sqlalchemy as sa
|
||||
from .. import exc as sa_exc
|
||||
from ..orm.writeonly import WriteOnlyCollection
|
||||
|
||||
_repr_stack = set()
|
||||
|
||||
|
||||
class BasicEntity:
|
||||
def __init__(self, **kw):
|
||||
for key, value in kw.items():
|
||||
setattr(self, key, value)
|
||||
|
||||
def __repr__(self):
|
||||
if id(self) in _repr_stack:
|
||||
return object.__repr__(self)
|
||||
_repr_stack.add(id(self))
|
||||
try:
|
||||
return "%s(%s)" % (
|
||||
(self.__class__.__name__),
|
||||
", ".join(
|
||||
[
|
||||
"%s=%r" % (key, getattr(self, key))
|
||||
for key in sorted(self.__dict__.keys())
|
||||
if not key.startswith("_")
|
||||
]
|
||||
),
|
||||
)
|
||||
finally:
|
||||
_repr_stack.remove(id(self))
|
||||
|
||||
|
||||
_recursion_stack = set()
|
||||
|
||||
|
||||
class ComparableMixin:
|
||||
def __ne__(self, other):
|
||||
return not self.__eq__(other)
|
||||
|
||||
def __eq__(self, other):
|
||||
"""'Deep, sparse compare.
|
||||
|
||||
Deeply compare two entities, following the non-None attributes of the
|
||||
non-persisted object, if possible.
|
||||
|
||||
"""
|
||||
if other is self:
|
||||
return True
|
||||
elif not self.__class__ == other.__class__:
|
||||
return False
|
||||
|
||||
if id(self) in _recursion_stack:
|
||||
return True
|
||||
_recursion_stack.add(id(self))
|
||||
|
||||
try:
|
||||
# pick the entity that's not SA persisted as the source
|
||||
try:
|
||||
self_key = sa.orm.attributes.instance_state(self).key
|
||||
except sa.orm.exc.NO_STATE:
|
||||
self_key = None
|
||||
|
||||
if other is None:
|
||||
a = self
|
||||
b = other
|
||||
elif self_key is not None:
|
||||
a = other
|
||||
b = self
|
||||
else:
|
||||
a = self
|
||||
b = other
|
||||
|
||||
for attr in list(a.__dict__):
|
||||
if attr.startswith("_"):
|
||||
continue
|
||||
|
||||
value = getattr(a, attr)
|
||||
|
||||
if isinstance(value, WriteOnlyCollection):
|
||||
continue
|
||||
|
||||
try:
|
||||
# handle lazy loader errors
|
||||
battr = getattr(b, attr)
|
||||
except (AttributeError, sa_exc.UnboundExecutionError):
|
||||
return False
|
||||
|
||||
if hasattr(value, "__iter__") and not isinstance(value, str):
|
||||
if hasattr(value, "__getitem__") and not hasattr(
|
||||
value, "keys"
|
||||
):
|
||||
if list(value) != list(battr):
|
||||
return False
|
||||
else:
|
||||
if set(value) != set(battr):
|
||||
return False
|
||||
else:
|
||||
if value is not None and value != battr:
|
||||
return False
|
||||
return True
|
||||
finally:
|
||||
_recursion_stack.remove(id(self))
|
||||
|
||||
|
||||
class ComparableEntity(ComparableMixin, BasicEntity):
|
||||
def __hash__(self):
|
||||
return hash(self.__class__)
|
|
@ -0,0 +1,435 @@
|
|||
# testing/exclusions.py
|
||||
# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors
|
||||
# <see AUTHORS file>
|
||||
#
|
||||
# This module is part of SQLAlchemy and is released under
|
||||
# the MIT License: https://www.opensource.org/licenses/mit-license.php
|
||||
# mypy: ignore-errors
|
||||
|
||||
import contextlib
|
||||
import operator
|
||||
import re
|
||||
import sys
|
||||
|
||||
from . import config
|
||||
from .. import util
|
||||
from ..util import decorator
|
||||
from ..util.compat import inspect_getfullargspec
|
||||
|
||||
|
||||
def skip_if(predicate, reason=None):
|
||||
rule = compound()
|
||||
pred = _as_predicate(predicate, reason)
|
||||
rule.skips.add(pred)
|
||||
return rule
|
||||
|
||||
|
||||
def fails_if(predicate, reason=None):
|
||||
rule = compound()
|
||||
pred = _as_predicate(predicate, reason)
|
||||
rule.fails.add(pred)
|
||||
return rule
|
||||
|
||||
|
||||
class compound:
|
||||
def __init__(self):
|
||||
self.fails = set()
|
||||
self.skips = set()
|
||||
|
||||
def __add__(self, other):
|
||||
return self.add(other)
|
||||
|
||||
def as_skips(self):
|
||||
rule = compound()
|
||||
rule.skips.update(self.skips)
|
||||
rule.skips.update(self.fails)
|
||||
return rule
|
||||
|
||||
def add(self, *others):
|
||||
copy = compound()
|
||||
copy.fails.update(self.fails)
|
||||
copy.skips.update(self.skips)
|
||||
|
||||
for other in others:
|
||||
copy.fails.update(other.fails)
|
||||
copy.skips.update(other.skips)
|
||||
return copy
|
||||
|
||||
def not_(self):
|
||||
copy = compound()
|
||||
copy.fails.update(NotPredicate(fail) for fail in self.fails)
|
||||
copy.skips.update(NotPredicate(skip) for skip in self.skips)
|
||||
return copy
|
||||
|
||||
@property
|
||||
def enabled(self):
|
||||
return self.enabled_for_config(config._current)
|
||||
|
||||
def enabled_for_config(self, config):
|
||||
for predicate in self.skips.union(self.fails):
|
||||
if predicate(config):
|
||||
return False
|
||||
else:
|
||||
return True
|
||||
|
||||
def matching_config_reasons(self, config):
|
||||
return [
|
||||
predicate._as_string(config)
|
||||
for predicate in self.skips.union(self.fails)
|
||||
if predicate(config)
|
||||
]
|
||||
|
||||
def _extend(self, other):
|
||||
self.skips.update(other.skips)
|
||||
self.fails.update(other.fails)
|
||||
|
||||
def __call__(self, fn):
|
||||
if hasattr(fn, "_sa_exclusion_extend"):
|
||||
fn._sa_exclusion_extend._extend(self)
|
||||
return fn
|
||||
|
||||
@decorator
|
||||
def decorate(fn, *args, **kw):
|
||||
return self._do(config._current, fn, *args, **kw)
|
||||
|
||||
decorated = decorate(fn)
|
||||
decorated._sa_exclusion_extend = self
|
||||
return decorated
|
||||
|
||||
@contextlib.contextmanager
|
||||
def fail_if(self):
|
||||
all_fails = compound()
|
||||
all_fails.fails.update(self.skips.union(self.fails))
|
||||
|
||||
try:
|
||||
yield
|
||||
except Exception as ex:
|
||||
all_fails._expect_failure(config._current, ex)
|
||||
else:
|
||||
all_fails._expect_success(config._current)
|
||||
|
||||
def _do(self, cfg, fn, *args, **kw):
|
||||
for skip in self.skips:
|
||||
if skip(cfg):
|
||||
msg = "'%s' : %s" % (
|
||||
config.get_current_test_name(),
|
||||
skip._as_string(cfg),
|
||||
)
|
||||
config.skip_test(msg)
|
||||
|
||||
try:
|
||||
return_value = fn(*args, **kw)
|
||||
except Exception as ex:
|
||||
self._expect_failure(cfg, ex, name=fn.__name__)
|
||||
else:
|
||||
self._expect_success(cfg, name=fn.__name__)
|
||||
return return_value
|
||||
|
||||
def _expect_failure(self, config, ex, name="block"):
|
||||
for fail in self.fails:
|
||||
if fail(config):
|
||||
print(
|
||||
"%s failed as expected (%s): %s "
|
||||
% (name, fail._as_string(config), ex)
|
||||
)
|
||||
break
|
||||
else:
|
||||
raise ex.with_traceback(sys.exc_info()[2])
|
||||
|
||||
def _expect_success(self, config, name="block"):
|
||||
if not self.fails:
|
||||
return
|
||||
|
||||
for fail in self.fails:
|
||||
if fail(config):
|
||||
raise AssertionError(
|
||||
"Unexpected success for '%s' (%s)"
|
||||
% (
|
||||
name,
|
||||
" and ".join(
|
||||
fail._as_string(config) for fail in self.fails
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def only_if(predicate, reason=None):
|
||||
predicate = _as_predicate(predicate)
|
||||
return skip_if(NotPredicate(predicate), reason)
|
||||
|
||||
|
||||
def succeeds_if(predicate, reason=None):
|
||||
predicate = _as_predicate(predicate)
|
||||
return fails_if(NotPredicate(predicate), reason)
|
||||
|
||||
|
||||
class Predicate:
|
||||
@classmethod
|
||||
def as_predicate(cls, predicate, description=None):
|
||||
if isinstance(predicate, compound):
|
||||
return cls.as_predicate(predicate.enabled_for_config, description)
|
||||
elif isinstance(predicate, Predicate):
|
||||
if description and predicate.description is None:
|
||||
predicate.description = description
|
||||
return predicate
|
||||
elif isinstance(predicate, (list, set)):
|
||||
return OrPredicate(
|
||||
[cls.as_predicate(pred) for pred in predicate], description
|
||||
)
|
||||
elif isinstance(predicate, tuple):
|
||||
return SpecPredicate(*predicate)
|
||||
elif isinstance(predicate, str):
|
||||
tokens = re.match(
|
||||
r"([\+\w]+)\s*(?:(>=|==|!=|<=|<|>)\s*([\d\.]+))?", predicate
|
||||
)
|
||||
if not tokens:
|
||||
raise ValueError(
|
||||
"Couldn't locate DB name in predicate: %r" % predicate
|
||||
)
|
||||
db = tokens.group(1)
|
||||
op = tokens.group(2)
|
||||
spec = (
|
||||
tuple(int(d) for d in tokens.group(3).split("."))
|
||||
if tokens.group(3)
|
||||
else None
|
||||
)
|
||||
|
||||
return SpecPredicate(db, op, spec, description=description)
|
||||
elif callable(predicate):
|
||||
return LambdaPredicate(predicate, description)
|
||||
else:
|
||||
assert False, "unknown predicate type: %s" % predicate
|
||||
|
||||
def _format_description(self, config, negate=False):
|
||||
bool_ = self(config)
|
||||
if negate:
|
||||
bool_ = not negate
|
||||
return self.description % {
|
||||
"driver": (
|
||||
config.db.url.get_driver_name() if config else "<no driver>"
|
||||
),
|
||||
"database": (
|
||||
config.db.url.get_backend_name() if config else "<no database>"
|
||||
),
|
||||
"doesnt_support": "doesn't support" if bool_ else "does support",
|
||||
"does_support": "does support" if bool_ else "doesn't support",
|
||||
}
|
||||
|
||||
def _as_string(self, config=None, negate=False):
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
class BooleanPredicate(Predicate):
|
||||
def __init__(self, value, description=None):
|
||||
self.value = value
|
||||
self.description = description or "boolean %s" % value
|
||||
|
||||
def __call__(self, config):
|
||||
return self.value
|
||||
|
||||
def _as_string(self, config, negate=False):
|
||||
return self._format_description(config, negate=negate)
|
||||
|
||||
|
||||
class SpecPredicate(Predicate):
|
||||
def __init__(self, db, op=None, spec=None, description=None):
|
||||
self.db = db
|
||||
self.op = op
|
||||
self.spec = spec
|
||||
self.description = description
|
||||
|
||||
_ops = {
|
||||
"<": operator.lt,
|
||||
">": operator.gt,
|
||||
"==": operator.eq,
|
||||
"!=": operator.ne,
|
||||
"<=": operator.le,
|
||||
">=": operator.ge,
|
||||
"in": operator.contains,
|
||||
"between": lambda val, pair: val >= pair[0] and val <= pair[1],
|
||||
}
|
||||
|
||||
def __call__(self, config):
|
||||
if config is None:
|
||||
return False
|
||||
|
||||
engine = config.db
|
||||
|
||||
if "+" in self.db:
|
||||
dialect, driver = self.db.split("+")
|
||||
else:
|
||||
dialect, driver = self.db, None
|
||||
|
||||
if dialect and engine.name != dialect:
|
||||
return False
|
||||
if driver is not None and engine.driver != driver:
|
||||
return False
|
||||
|
||||
if self.op is not None:
|
||||
assert driver is None, "DBAPI version specs not supported yet"
|
||||
|
||||
version = _server_version(engine)
|
||||
oper = (
|
||||
hasattr(self.op, "__call__") and self.op or self._ops[self.op]
|
||||
)
|
||||
return oper(version, self.spec)
|
||||
else:
|
||||
return True
|
||||
|
||||
def _as_string(self, config, negate=False):
|
||||
if self.description is not None:
|
||||
return self._format_description(config)
|
||||
elif self.op is None:
|
||||
if negate:
|
||||
return "not %s" % self.db
|
||||
else:
|
||||
return "%s" % self.db
|
||||
else:
|
||||
if negate:
|
||||
return "not %s %s %s" % (self.db, self.op, self.spec)
|
||||
else:
|
||||
return "%s %s %s" % (self.db, self.op, self.spec)
|
||||
|
||||
|
||||
class LambdaPredicate(Predicate):
|
||||
def __init__(self, lambda_, description=None, args=None, kw=None):
|
||||
spec = inspect_getfullargspec(lambda_)
|
||||
if not spec[0]:
|
||||
self.lambda_ = lambda db: lambda_()
|
||||
else:
|
||||
self.lambda_ = lambda_
|
||||
self.args = args or ()
|
||||
self.kw = kw or {}
|
||||
if description:
|
||||
self.description = description
|
||||
elif lambda_.__doc__:
|
||||
self.description = lambda_.__doc__
|
||||
else:
|
||||
self.description = "custom function"
|
||||
|
||||
def __call__(self, config):
|
||||
return self.lambda_(config)
|
||||
|
||||
def _as_string(self, config, negate=False):
|
||||
return self._format_description(config)
|
||||
|
||||
|
||||
class NotPredicate(Predicate):
|
||||
def __init__(self, predicate, description=None):
|
||||
self.predicate = predicate
|
||||
self.description = description
|
||||
|
||||
def __call__(self, config):
|
||||
return not self.predicate(config)
|
||||
|
||||
def _as_string(self, config, negate=False):
|
||||
if self.description:
|
||||
return self._format_description(config, not negate)
|
||||
else:
|
||||
return self.predicate._as_string(config, not negate)
|
||||
|
||||
|
||||
class OrPredicate(Predicate):
|
||||
def __init__(self, predicates, description=None):
|
||||
self.predicates = predicates
|
||||
self.description = description
|
||||
|
||||
def __call__(self, config):
|
||||
for pred in self.predicates:
|
||||
if pred(config):
|
||||
return True
|
||||
return False
|
||||
|
||||
def _eval_str(self, config, negate=False):
|
||||
if negate:
|
||||
conjunction = " and "
|
||||
else:
|
||||
conjunction = " or "
|
||||
return conjunction.join(
|
||||
p._as_string(config, negate=negate) for p in self.predicates
|
||||
)
|
||||
|
||||
def _negation_str(self, config):
|
||||
if self.description is not None:
|
||||
return "Not " + self._format_description(config)
|
||||
else:
|
||||
return self._eval_str(config, negate=True)
|
||||
|
||||
def _as_string(self, config, negate=False):
|
||||
if negate:
|
||||
return self._negation_str(config)
|
||||
else:
|
||||
if self.description is not None:
|
||||
return self._format_description(config)
|
||||
else:
|
||||
return self._eval_str(config)
|
||||
|
||||
|
||||
_as_predicate = Predicate.as_predicate
|
||||
|
||||
|
||||
def _is_excluded(db, op, spec):
|
||||
return SpecPredicate(db, op, spec)(config._current)
|
||||
|
||||
|
||||
def _server_version(engine):
|
||||
"""Return a server_version_info tuple."""
|
||||
|
||||
# force metadata to be retrieved
|
||||
conn = engine.connect()
|
||||
version = getattr(engine.dialect, "server_version_info", None)
|
||||
if version is None:
|
||||
version = ()
|
||||
conn.close()
|
||||
return version
|
||||
|
||||
|
||||
def db_spec(*dbs):
|
||||
return OrPredicate([Predicate.as_predicate(db) for db in dbs])
|
||||
|
||||
|
||||
def open(): # noqa
|
||||
return skip_if(BooleanPredicate(False, "mark as execute"))
|
||||
|
||||
|
||||
def closed():
|
||||
return skip_if(BooleanPredicate(True, "marked as skip"))
|
||||
|
||||
|
||||
def fails(reason=None):
|
||||
return fails_if(BooleanPredicate(True, reason or "expected to fail"))
|
||||
|
||||
|
||||
def future():
|
||||
return fails_if(BooleanPredicate(True, "Future feature"))
|
||||
|
||||
|
||||
def fails_on(db, reason=None):
|
||||
return fails_if(db, reason)
|
||||
|
||||
|
||||
def fails_on_everything_except(*dbs):
|
||||
return succeeds_if(OrPredicate([Predicate.as_predicate(db) for db in dbs]))
|
||||
|
||||
|
||||
def skip(db, reason=None):
|
||||
return skip_if(db, reason)
|
||||
|
||||
|
||||
def only_on(dbs, reason=None):
|
||||
return only_if(
|
||||
OrPredicate(
|
||||
[Predicate.as_predicate(db, reason) for db in util.to_list(dbs)]
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def exclude(db, op, spec, reason=None):
|
||||
return skip_if(SpecPredicate(db, op, spec), reason)
|
||||
|
||||
|
||||
def against(config, *queries):
|
||||
assert queries, "no queries sent!"
|
||||
return OrPredicate([Predicate.as_predicate(query) for query in queries])(
|
||||
config
|
||||
)
|
|
@ -0,0 +1,28 @@
|
|||
# testing/fixtures/__init__.py
|
||||
# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors
|
||||
# <see AUTHORS file>
|
||||
#
|
||||
# This module is part of SQLAlchemy and is released under
|
||||
# the MIT License: https://www.opensource.org/licenses/mit-license.php
|
||||
# mypy: ignore-errors
|
||||
from .base import FutureEngineMixin as FutureEngineMixin
|
||||
from .base import TestBase as TestBase
|
||||
from .mypy import MypyTest as MypyTest
|
||||
from .orm import after_test as after_test
|
||||
from .orm import close_all_sessions as close_all_sessions
|
||||
from .orm import DeclarativeMappedTest as DeclarativeMappedTest
|
||||
from .orm import fixture_session as fixture_session
|
||||
from .orm import MappedTest as MappedTest
|
||||
from .orm import ORMTest as ORMTest
|
||||
from .orm import RemoveORMEventsGlobally as RemoveORMEventsGlobally
|
||||
from .orm import (
|
||||
stop_test_class_inside_fixtures as stop_test_class_inside_fixtures,
|
||||
)
|
||||
from .sql import CacheKeyFixture as CacheKeyFixture
|
||||
from .sql import (
|
||||
ComputedReflectionFixtureTest as ComputedReflectionFixtureTest,
|
||||
)
|
||||
from .sql import insertmanyvalues_fixture as insertmanyvalues_fixture
|
||||
from .sql import NoCache as NoCache
|
||||
from .sql import RemovesEvents as RemovesEvents
|
||||
from .sql import TablesTest as TablesTest
|
|
@ -0,0 +1,366 @@
|
|||
# testing/fixtures/base.py
|
||||
# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors
|
||||
# <see AUTHORS file>
|
||||
#
|
||||
# This module is part of SQLAlchemy and is released under
|
||||
# the MIT License: https://www.opensource.org/licenses/mit-license.php
|
||||
# mypy: ignore-errors
|
||||
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import sqlalchemy as sa
|
||||
from .. import assertions
|
||||
from .. import config
|
||||
from ..assertions import eq_
|
||||
from ..util import drop_all_tables_from_metadata
|
||||
from ... import Column
|
||||
from ... import func
|
||||
from ... import Integer
|
||||
from ... import select
|
||||
from ... import Table
|
||||
from ...orm import DeclarativeBase
|
||||
from ...orm import MappedAsDataclass
|
||||
from ...orm import registry
|
||||
|
||||
|
||||
@config.mark_base_test_class()
|
||||
class TestBase:
|
||||
# A sequence of requirement names matching testing.requires decorators
|
||||
__requires__ = ()
|
||||
|
||||
# A sequence of dialect names to exclude from the test class.
|
||||
__unsupported_on__ = ()
|
||||
|
||||
# If present, test class is only runnable for the *single* specified
|
||||
# dialect. If you need multiple, use __unsupported_on__ and invert.
|
||||
__only_on__ = None
|
||||
|
||||
# A sequence of no-arg callables. If any are True, the entire testcase is
|
||||
# skipped.
|
||||
__skip_if__ = None
|
||||
|
||||
# if True, the testing reaper will not attempt to touch connection
|
||||
# state after a test is completed and before the outer teardown
|
||||
# starts
|
||||
__leave_connections_for_teardown__ = False
|
||||
|
||||
def assert_(self, val, msg=None):
|
||||
assert val, msg
|
||||
|
||||
@config.fixture()
|
||||
def nocache(self):
|
||||
_cache = config.db._compiled_cache
|
||||
config.db._compiled_cache = None
|
||||
yield
|
||||
config.db._compiled_cache = _cache
|
||||
|
||||
@config.fixture()
|
||||
def connection_no_trans(self):
|
||||
eng = getattr(self, "bind", None) or config.db
|
||||
|
||||
with eng.connect() as conn:
|
||||
yield conn
|
||||
|
||||
@config.fixture()
|
||||
def connection(self):
|
||||
global _connection_fixture_connection
|
||||
|
||||
eng = getattr(self, "bind", None) or config.db
|
||||
|
||||
conn = eng.connect()
|
||||
trans = conn.begin()
|
||||
|
||||
_connection_fixture_connection = conn
|
||||
yield conn
|
||||
|
||||
_connection_fixture_connection = None
|
||||
|
||||
if trans.is_active:
|
||||
trans.rollback()
|
||||
# trans would not be active here if the test is using
|
||||
# the legacy @provide_metadata decorator still, as it will
|
||||
# run a close all connections.
|
||||
conn.close()
|
||||
|
||||
@config.fixture()
|
||||
def close_result_when_finished(self):
|
||||
to_close = []
|
||||
to_consume = []
|
||||
|
||||
def go(result, consume=False):
|
||||
to_close.append(result)
|
||||
if consume:
|
||||
to_consume.append(result)
|
||||
|
||||
yield go
|
||||
for r in to_consume:
|
||||
try:
|
||||
r.all()
|
||||
except:
|
||||
pass
|
||||
for r in to_close:
|
||||
try:
|
||||
r.close()
|
||||
except:
|
||||
pass
|
||||
|
||||
@config.fixture()
|
||||
def registry(self, metadata):
|
||||
reg = registry(
|
||||
metadata=metadata,
|
||||
type_annotation_map={
|
||||
str: sa.String().with_variant(
|
||||
sa.String(50), "mysql", "mariadb", "oracle"
|
||||
)
|
||||
},
|
||||
)
|
||||
yield reg
|
||||
reg.dispose()
|
||||
|
||||
@config.fixture
|
||||
def decl_base(self, metadata):
|
||||
_md = metadata
|
||||
|
||||
class Base(DeclarativeBase):
|
||||
metadata = _md
|
||||
type_annotation_map = {
|
||||
str: sa.String().with_variant(
|
||||
sa.String(50), "mysql", "mariadb", "oracle"
|
||||
)
|
||||
}
|
||||
|
||||
yield Base
|
||||
Base.registry.dispose()
|
||||
|
||||
@config.fixture
|
||||
def dc_decl_base(self, metadata):
|
||||
_md = metadata
|
||||
|
||||
class Base(MappedAsDataclass, DeclarativeBase):
|
||||
metadata = _md
|
||||
type_annotation_map = {
|
||||
str: sa.String().with_variant(
|
||||
sa.String(50), "mysql", "mariadb"
|
||||
)
|
||||
}
|
||||
|
||||
yield Base
|
||||
Base.registry.dispose()
|
||||
|
||||
@config.fixture()
|
||||
def future_connection(self, future_engine, connection):
|
||||
# integrate the future_engine and connection fixtures so
|
||||
# that users of the "connection" fixture will get at the
|
||||
# "future" connection
|
||||
yield connection
|
||||
|
||||
@config.fixture()
|
||||
def future_engine(self):
|
||||
yield
|
||||
|
||||
@config.fixture()
|
||||
def testing_engine(self):
|
||||
from .. import engines
|
||||
|
||||
def gen_testing_engine(
|
||||
url=None,
|
||||
options=None,
|
||||
future=None,
|
||||
asyncio=False,
|
||||
transfer_staticpool=False,
|
||||
share_pool=False,
|
||||
):
|
||||
if options is None:
|
||||
options = {}
|
||||
options["scope"] = "fixture"
|
||||
return engines.testing_engine(
|
||||
url=url,
|
||||
options=options,
|
||||
asyncio=asyncio,
|
||||
transfer_staticpool=transfer_staticpool,
|
||||
share_pool=share_pool,
|
||||
)
|
||||
|
||||
yield gen_testing_engine
|
||||
|
||||
engines.testing_reaper._drop_testing_engines("fixture")
|
||||
|
||||
@config.fixture()
|
||||
def async_testing_engine(self, testing_engine):
|
||||
def go(**kw):
|
||||
kw["asyncio"] = True
|
||||
return testing_engine(**kw)
|
||||
|
||||
return go
|
||||
|
||||
@config.fixture()
|
||||
def metadata(self, request):
|
||||
"""Provide bound MetaData for a single test, dropping afterwards."""
|
||||
|
||||
from ...sql import schema
|
||||
|
||||
metadata = schema.MetaData()
|
||||
request.instance.metadata = metadata
|
||||
yield metadata
|
||||
del request.instance.metadata
|
||||
|
||||
if (
|
||||
_connection_fixture_connection
|
||||
and _connection_fixture_connection.in_transaction()
|
||||
):
|
||||
trans = _connection_fixture_connection.get_transaction()
|
||||
trans.rollback()
|
||||
with _connection_fixture_connection.begin():
|
||||
drop_all_tables_from_metadata(
|
||||
metadata, _connection_fixture_connection
|
||||
)
|
||||
else:
|
||||
drop_all_tables_from_metadata(metadata, config.db)
|
||||
|
||||
@config.fixture(
|
||||
params=[
|
||||
(rollback, second_operation, begin_nested)
|
||||
for rollback in (True, False)
|
||||
for second_operation in ("none", "execute", "begin")
|
||||
for begin_nested in (
|
||||
True,
|
||||
False,
|
||||
)
|
||||
]
|
||||
)
|
||||
def trans_ctx_manager_fixture(self, request, metadata):
|
||||
rollback, second_operation, begin_nested = request.param
|
||||
|
||||
t = Table("test", metadata, Column("data", Integer))
|
||||
eng = getattr(self, "bind", None) or config.db
|
||||
|
||||
t.create(eng)
|
||||
|
||||
def run_test(subject, trans_on_subject, execute_on_subject):
|
||||
with subject.begin() as trans:
|
||||
if begin_nested:
|
||||
if not config.requirements.savepoints.enabled:
|
||||
config.skip_test("savepoints not enabled")
|
||||
if execute_on_subject:
|
||||
nested_trans = subject.begin_nested()
|
||||
else:
|
||||
nested_trans = trans.begin_nested()
|
||||
|
||||
with nested_trans:
|
||||
if execute_on_subject:
|
||||
subject.execute(t.insert(), {"data": 10})
|
||||
else:
|
||||
trans.execute(t.insert(), {"data": 10})
|
||||
|
||||
# for nested trans, we always commit/rollback on the
|
||||
# "nested trans" object itself.
|
||||
# only Session(future=False) will affect savepoint
|
||||
# transaction for session.commit/rollback
|
||||
|
||||
if rollback:
|
||||
nested_trans.rollback()
|
||||
else:
|
||||
nested_trans.commit()
|
||||
|
||||
if second_operation != "none":
|
||||
with assertions.expect_raises_message(
|
||||
sa.exc.InvalidRequestError,
|
||||
"Can't operate on closed transaction "
|
||||
"inside context "
|
||||
"manager. Please complete the context "
|
||||
"manager "
|
||||
"before emitting further commands.",
|
||||
):
|
||||
if second_operation == "execute":
|
||||
if execute_on_subject:
|
||||
subject.execute(
|
||||
t.insert(), {"data": 12}
|
||||
)
|
||||
else:
|
||||
trans.execute(t.insert(), {"data": 12})
|
||||
elif second_operation == "begin":
|
||||
if execute_on_subject:
|
||||
subject.begin_nested()
|
||||
else:
|
||||
trans.begin_nested()
|
||||
|
||||
# outside the nested trans block, but still inside the
|
||||
# transaction block, we can run SQL, and it will be
|
||||
# committed
|
||||
if execute_on_subject:
|
||||
subject.execute(t.insert(), {"data": 14})
|
||||
else:
|
||||
trans.execute(t.insert(), {"data": 14})
|
||||
|
||||
else:
|
||||
if execute_on_subject:
|
||||
subject.execute(t.insert(), {"data": 10})
|
||||
else:
|
||||
trans.execute(t.insert(), {"data": 10})
|
||||
|
||||
if trans_on_subject:
|
||||
if rollback:
|
||||
subject.rollback()
|
||||
else:
|
||||
subject.commit()
|
||||
else:
|
||||
if rollback:
|
||||
trans.rollback()
|
||||
else:
|
||||
trans.commit()
|
||||
|
||||
if second_operation != "none":
|
||||
with assertions.expect_raises_message(
|
||||
sa.exc.InvalidRequestError,
|
||||
"Can't operate on closed transaction inside "
|
||||
"context "
|
||||
"manager. Please complete the context manager "
|
||||
"before emitting further commands.",
|
||||
):
|
||||
if second_operation == "execute":
|
||||
if execute_on_subject:
|
||||
subject.execute(t.insert(), {"data": 12})
|
||||
else:
|
||||
trans.execute(t.insert(), {"data": 12})
|
||||
elif second_operation == "begin":
|
||||
if hasattr(trans, "begin"):
|
||||
trans.begin()
|
||||
else:
|
||||
subject.begin()
|
||||
elif second_operation == "begin_nested":
|
||||
if execute_on_subject:
|
||||
subject.begin_nested()
|
||||
else:
|
||||
trans.begin_nested()
|
||||
|
||||
expected_committed = 0
|
||||
if begin_nested:
|
||||
# begin_nested variant, we inserted a row after the nested
|
||||
# block
|
||||
expected_committed += 1
|
||||
if not rollback:
|
||||
# not rollback variant, our row inserted in the target
|
||||
# block itself would be committed
|
||||
expected_committed += 1
|
||||
|
||||
if execute_on_subject:
|
||||
eq_(
|
||||
subject.scalar(select(func.count()).select_from(t)),
|
||||
expected_committed,
|
||||
)
|
||||
else:
|
||||
with subject.connect() as conn:
|
||||
eq_(
|
||||
conn.scalar(select(func.count()).select_from(t)),
|
||||
expected_committed,
|
||||
)
|
||||
|
||||
return run_test
|
||||
|
||||
|
||||
_connection_fixture_connection = None
|
||||
|
||||
|
||||
class FutureEngineMixin:
|
||||
"""alembic's suite still using this"""
|
|
@ -0,0 +1,312 @@
|
|||
# testing/fixtures/mypy.py
|
||||
# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors
|
||||
# <see AUTHORS file>
|
||||
#
|
||||
# This module is part of SQLAlchemy and is released under
|
||||
# the MIT License: https://www.opensource.org/licenses/mit-license.php
|
||||
# mypy: ignore-errors
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import inspect
|
||||
import os
|
||||
from pathlib import Path
|
||||
import re
|
||||
import shutil
|
||||
import sys
|
||||
import tempfile
|
||||
|
||||
from .base import TestBase
|
||||
from .. import config
|
||||
from ..assertions import eq_
|
||||
from ... import util
|
||||
|
||||
|
||||
@config.add_to_marker.mypy
|
||||
class MypyTest(TestBase):
|
||||
__requires__ = ("no_sqlalchemy2_stubs",)
|
||||
|
||||
@config.fixture(scope="function")
|
||||
def per_func_cachedir(self):
|
||||
yield from self._cachedir()
|
||||
|
||||
@config.fixture(scope="class")
|
||||
def cachedir(self):
|
||||
yield from self._cachedir()
|
||||
|
||||
def _cachedir(self):
|
||||
# as of mypy 0.971 i think we need to keep mypy_path empty
|
||||
mypy_path = ""
|
||||
|
||||
with tempfile.TemporaryDirectory() as cachedir:
|
||||
with open(
|
||||
Path(cachedir) / "sqla_mypy_config.cfg", "w"
|
||||
) as config_file:
|
||||
config_file.write(
|
||||
f"""
|
||||
[mypy]\n
|
||||
plugins = sqlalchemy.ext.mypy.plugin\n
|
||||
show_error_codes = True\n
|
||||
{mypy_path}
|
||||
disable_error_code = no-untyped-call
|
||||
|
||||
[mypy-sqlalchemy.*]
|
||||
ignore_errors = True
|
||||
|
||||
"""
|
||||
)
|
||||
with open(
|
||||
Path(cachedir) / "plain_mypy_config.cfg", "w"
|
||||
) as config_file:
|
||||
config_file.write(
|
||||
f"""
|
||||
[mypy]\n
|
||||
show_error_codes = True\n
|
||||
{mypy_path}
|
||||
disable_error_code = var-annotated,no-untyped-call
|
||||
[mypy-sqlalchemy.*]
|
||||
ignore_errors = True
|
||||
|
||||
"""
|
||||
)
|
||||
yield cachedir
|
||||
|
||||
@config.fixture()
|
||||
def mypy_runner(self, cachedir):
|
||||
from mypy import api
|
||||
|
||||
def run(path, use_plugin=False, use_cachedir=None):
|
||||
if use_cachedir is None:
|
||||
use_cachedir = cachedir
|
||||
args = [
|
||||
"--strict",
|
||||
"--raise-exceptions",
|
||||
"--cache-dir",
|
||||
use_cachedir,
|
||||
"--config-file",
|
||||
os.path.join(
|
||||
use_cachedir,
|
||||
(
|
||||
"sqla_mypy_config.cfg"
|
||||
if use_plugin
|
||||
else "plain_mypy_config.cfg"
|
||||
),
|
||||
),
|
||||
]
|
||||
|
||||
# mypy as of 0.990 is more aggressively blocking messaging
|
||||
# for paths that are in sys.path, and as pytest puts currdir,
|
||||
# test/ etc in sys.path, just copy the source file to the
|
||||
# tempdir we are working in so that we don't have to try to
|
||||
# manipulate sys.path and/or guess what mypy is doing
|
||||
filename = os.path.basename(path)
|
||||
test_program = os.path.join(use_cachedir, filename)
|
||||
if path != test_program:
|
||||
shutil.copyfile(path, test_program)
|
||||
args.append(test_program)
|
||||
|
||||
# I set this locally but for the suite here needs to be
|
||||
# disabled
|
||||
os.environ.pop("MYPY_FORCE_COLOR", None)
|
||||
|
||||
stdout, stderr, exitcode = api.run(args)
|
||||
return stdout, stderr, exitcode
|
||||
|
||||
return run
|
||||
|
||||
@config.fixture
|
||||
def mypy_typecheck_file(self, mypy_runner):
|
||||
def run(path, use_plugin=False):
|
||||
expected_messages = self._collect_messages(path)
|
||||
stdout, stderr, exitcode = mypy_runner(path, use_plugin=use_plugin)
|
||||
self._check_output(
|
||||
path, expected_messages, stdout, stderr, exitcode
|
||||
)
|
||||
|
||||
return run
|
||||
|
||||
@staticmethod
|
||||
def file_combinations(dirname):
|
||||
if os.path.isabs(dirname):
|
||||
path = dirname
|
||||
else:
|
||||
caller_path = inspect.stack()[1].filename
|
||||
path = os.path.join(os.path.dirname(caller_path), dirname)
|
||||
files = list(Path(path).glob("**/*.py"))
|
||||
|
||||
for extra_dir in config.options.mypy_extra_test_paths:
|
||||
if extra_dir and os.path.isdir(extra_dir):
|
||||
files.extend((Path(extra_dir) / dirname).glob("**/*.py"))
|
||||
return files
|
||||
|
||||
def _collect_messages(self, path):
|
||||
from sqlalchemy.ext.mypy.util import mypy_14
|
||||
|
||||
expected_messages = []
|
||||
expected_re = re.compile(r"\s*# EXPECTED(_MYPY)?(_RE)?(_TYPE)?: (.+)")
|
||||
py_ver_re = re.compile(r"^#\s*PYTHON_VERSION\s?>=\s?(\d+\.\d+)")
|
||||
with open(path) as file_:
|
||||
current_assert_messages = []
|
||||
for num, line in enumerate(file_, 1):
|
||||
m = py_ver_re.match(line)
|
||||
if m:
|
||||
major, _, minor = m.group(1).partition(".")
|
||||
if sys.version_info < (int(major), int(minor)):
|
||||
config.skip_test(
|
||||
"Requires python >= %s" % (m.group(1))
|
||||
)
|
||||
continue
|
||||
|
||||
m = expected_re.match(line)
|
||||
if m:
|
||||
is_mypy = bool(m.group(1))
|
||||
is_re = bool(m.group(2))
|
||||
is_type = bool(m.group(3))
|
||||
|
||||
expected_msg = re.sub(r"# noqa[:]? ?.*", "", m.group(4))
|
||||
if is_type:
|
||||
if not is_re:
|
||||
# the goal here is that we can cut-and-paste
|
||||
# from vscode -> pylance into the
|
||||
# EXPECTED_TYPE: line, then the test suite will
|
||||
# validate that line against what mypy produces
|
||||
expected_msg = re.sub(
|
||||
r"([\[\]])",
|
||||
lambda m: rf"\{m.group(0)}",
|
||||
expected_msg,
|
||||
)
|
||||
|
||||
# note making sure preceding text matches
|
||||
# with a dot, so that an expect for "Select"
|
||||
# does not match "TypedSelect"
|
||||
expected_msg = re.sub(
|
||||
r"([\w_]+)",
|
||||
lambda m: rf"(?:.*\.)?{m.group(1)}\*?",
|
||||
expected_msg,
|
||||
)
|
||||
|
||||
expected_msg = re.sub(
|
||||
"List", "builtins.list", expected_msg
|
||||
)
|
||||
|
||||
expected_msg = re.sub(
|
||||
r"\b(int|str|float|bool)\b",
|
||||
lambda m: rf"builtins.{m.group(0)}\*?",
|
||||
expected_msg,
|
||||
)
|
||||
# expected_msg = re.sub(
|
||||
# r"(Sequence|Tuple|List|Union)",
|
||||
# lambda m: fr"typing.{m.group(0)}\*?",
|
||||
# expected_msg,
|
||||
# )
|
||||
|
||||
is_mypy = is_re = True
|
||||
expected_msg = f'Revealed type is "{expected_msg}"'
|
||||
|
||||
if mypy_14 and util.py39:
|
||||
# use_lowercase_names, py39 and above
|
||||
# https://github.com/python/mypy/blob/304997bfb85200fb521ac727ee0ce3e6085e5278/mypy/options.py#L363 # noqa: E501
|
||||
|
||||
# skip first character which could be capitalized
|
||||
# "List item x not found" type of message
|
||||
expected_msg = expected_msg[0] + re.sub(
|
||||
(
|
||||
r"\b(List|Tuple|Dict|Set)\b"
|
||||
if is_type
|
||||
else r"\b(List|Tuple|Dict|Set|Type)\b"
|
||||
),
|
||||
lambda m: m.group(1).lower(),
|
||||
expected_msg[1:],
|
||||
)
|
||||
|
||||
if mypy_14 and util.py310:
|
||||
# use_or_syntax, py310 and above
|
||||
# https://github.com/python/mypy/blob/304997bfb85200fb521ac727ee0ce3e6085e5278/mypy/options.py#L368 # noqa: E501
|
||||
expected_msg = re.sub(
|
||||
r"Optional\[(.*?)\]",
|
||||
lambda m: f"{m.group(1)} | None",
|
||||
expected_msg,
|
||||
)
|
||||
current_assert_messages.append(
|
||||
(is_mypy, is_re, expected_msg.strip())
|
||||
)
|
||||
elif current_assert_messages:
|
||||
expected_messages.extend(
|
||||
(num, is_mypy, is_re, expected_msg)
|
||||
for (
|
||||
is_mypy,
|
||||
is_re,
|
||||
expected_msg,
|
||||
) in current_assert_messages
|
||||
)
|
||||
current_assert_messages[:] = []
|
||||
|
||||
return expected_messages
|
||||
|
||||
def _check_output(self, path, expected_messages, stdout, stderr, exitcode):
|
||||
not_located = []
|
||||
filename = os.path.basename(path)
|
||||
if expected_messages:
|
||||
# mypy 0.990 changed how return codes work, so don't assume a
|
||||
# 1 or a 0 return code here, could be either depending on if
|
||||
# errors were generated or not
|
||||
|
||||
output = []
|
||||
|
||||
raw_lines = stdout.split("\n")
|
||||
while raw_lines:
|
||||
e = raw_lines.pop(0)
|
||||
if re.match(r".+\.py:\d+: error: .*", e):
|
||||
output.append(("error", e))
|
||||
elif re.match(
|
||||
r".+\.py:\d+: note: +(?:Possible overload|def ).*", e
|
||||
):
|
||||
while raw_lines:
|
||||
ol = raw_lines.pop(0)
|
||||
if not re.match(r".+\.py:\d+: note: +def \[.*", ol):
|
||||
break
|
||||
elif re.match(
|
||||
r".+\.py:\d+: note: .*(?:perhaps|suggestion)", e, re.I
|
||||
):
|
||||
pass
|
||||
elif re.match(r".+\.py:\d+: note: .*", e):
|
||||
output.append(("note", e))
|
||||
|
||||
for num, is_mypy, is_re, msg in expected_messages:
|
||||
msg = msg.replace("'", '"')
|
||||
prefix = "[SQLAlchemy Mypy plugin] " if not is_mypy else ""
|
||||
for idx, (typ, errmsg) in enumerate(output):
|
||||
if is_re:
|
||||
if re.match(
|
||||
rf".*{filename}\:{num}\: {typ}\: {prefix}{msg}",
|
||||
errmsg,
|
||||
):
|
||||
break
|
||||
elif (
|
||||
f"{filename}:{num}: {typ}: {prefix}{msg}"
|
||||
in errmsg.replace("'", '"')
|
||||
):
|
||||
break
|
||||
else:
|
||||
not_located.append(msg)
|
||||
continue
|
||||
del output[idx]
|
||||
|
||||
if not_located:
|
||||
missing = "\n".join(not_located)
|
||||
print("Couldn't locate expected messages:", missing, sep="\n")
|
||||
if output:
|
||||
extra = "\n".join(msg for _, msg in output)
|
||||
print("Remaining messages:", extra, sep="\n")
|
||||
assert False, "expected messages not found, see stdout"
|
||||
|
||||
if output:
|
||||
print(f"{len(output)} messages from mypy were not consumed:")
|
||||
print("\n".join(msg for _, msg in output))
|
||||
assert False, "errors and/or notes remain, see stdout"
|
||||
|
||||
else:
|
||||
if exitcode != 0:
|
||||
print(stdout, stderr, sep="\n")
|
||||
|
||||
eq_(exitcode, 0, msg=stdout)
|
|
@ -0,0 +1,227 @@
|
|||
# testing/fixtures/orm.py
|
||||
# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors
|
||||
# <see AUTHORS file>
|
||||
#
|
||||
# This module is part of SQLAlchemy and is released under
|
||||
# the MIT License: https://www.opensource.org/licenses/mit-license.php
|
||||
# mypy: ignore-errors
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
import sqlalchemy as sa
|
||||
from .base import TestBase
|
||||
from .sql import TablesTest
|
||||
from .. import assertions
|
||||
from .. import config
|
||||
from .. import schema
|
||||
from ..entities import BasicEntity
|
||||
from ..entities import ComparableEntity
|
||||
from ..util import adict
|
||||
from ... import orm
|
||||
from ...orm import DeclarativeBase
|
||||
from ...orm import events as orm_events
|
||||
from ...orm import registry
|
||||
|
||||
|
||||
class ORMTest(TestBase):
|
||||
@config.fixture
|
||||
def fixture_session(self):
|
||||
return fixture_session()
|
||||
|
||||
|
||||
class MappedTest(ORMTest, TablesTest, assertions.AssertsExecutionResults):
|
||||
# 'once', 'each', None
|
||||
run_setup_classes = "once"
|
||||
|
||||
# 'once', 'each', None
|
||||
run_setup_mappers = "each"
|
||||
|
||||
classes: Any = None
|
||||
|
||||
@config.fixture(autouse=True, scope="class")
|
||||
def _setup_tables_test_class(self):
|
||||
cls = self.__class__
|
||||
cls._init_class()
|
||||
|
||||
if cls.classes is None:
|
||||
cls.classes = adict()
|
||||
|
||||
cls._setup_once_tables()
|
||||
cls._setup_once_classes()
|
||||
cls._setup_once_mappers()
|
||||
cls._setup_once_inserts()
|
||||
|
||||
yield
|
||||
|
||||
cls._teardown_once_class()
|
||||
cls._teardown_once_metadata_bind()
|
||||
|
||||
@config.fixture(autouse=True, scope="function")
|
||||
def _setup_tables_test_instance(self):
|
||||
self._setup_each_tables()
|
||||
self._setup_each_classes()
|
||||
self._setup_each_mappers()
|
||||
self._setup_each_inserts()
|
||||
|
||||
yield
|
||||
|
||||
orm.session.close_all_sessions()
|
||||
self._teardown_each_mappers()
|
||||
self._teardown_each_classes()
|
||||
self._teardown_each_tables()
|
||||
|
||||
@classmethod
|
||||
def _teardown_once_class(cls):
|
||||
cls.classes.clear()
|
||||
|
||||
@classmethod
|
||||
def _setup_once_classes(cls):
|
||||
if cls.run_setup_classes == "once":
|
||||
cls._with_register_classes(cls.setup_classes)
|
||||
|
||||
@classmethod
|
||||
def _setup_once_mappers(cls):
|
||||
if cls.run_setup_mappers == "once":
|
||||
cls.mapper_registry, cls.mapper = cls._generate_registry()
|
||||
cls._with_register_classes(cls.setup_mappers)
|
||||
|
||||
def _setup_each_mappers(self):
|
||||
if self.run_setup_mappers != "once":
|
||||
(
|
||||
self.__class__.mapper_registry,
|
||||
self.__class__.mapper,
|
||||
) = self._generate_registry()
|
||||
|
||||
if self.run_setup_mappers == "each":
|
||||
self._with_register_classes(self.setup_mappers)
|
||||
|
||||
def _setup_each_classes(self):
|
||||
if self.run_setup_classes == "each":
|
||||
self._with_register_classes(self.setup_classes)
|
||||
|
||||
@classmethod
|
||||
def _generate_registry(cls):
|
||||
decl = registry(metadata=cls._tables_metadata)
|
||||
return decl, decl.map_imperatively
|
||||
|
||||
@classmethod
|
||||
def _with_register_classes(cls, fn):
|
||||
"""Run a setup method, framing the operation with a Base class
|
||||
that will catch new subclasses to be established within
|
||||
the "classes" registry.
|
||||
|
||||
"""
|
||||
cls_registry = cls.classes
|
||||
|
||||
class _Base:
|
||||
def __init_subclass__(cls) -> None:
|
||||
assert cls_registry is not None
|
||||
cls_registry[cls.__name__] = cls
|
||||
super().__init_subclass__()
|
||||
|
||||
class Basic(BasicEntity, _Base):
|
||||
pass
|
||||
|
||||
class Comparable(ComparableEntity, _Base):
|
||||
pass
|
||||
|
||||
cls.Basic = Basic
|
||||
cls.Comparable = Comparable
|
||||
fn()
|
||||
|
||||
def _teardown_each_mappers(self):
|
||||
# some tests create mappers in the test bodies
|
||||
# and will define setup_mappers as None -
|
||||
# clear mappers in any case
|
||||
if self.run_setup_mappers != "once":
|
||||
orm.clear_mappers()
|
||||
|
||||
def _teardown_each_classes(self):
|
||||
if self.run_setup_classes != "once":
|
||||
self.classes.clear()
|
||||
|
||||
@classmethod
|
||||
def setup_classes(cls):
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def setup_mappers(cls):
|
||||
pass
|
||||
|
||||
|
||||
class DeclarativeMappedTest(MappedTest):
|
||||
run_setup_classes = "once"
|
||||
run_setup_mappers = "once"
|
||||
|
||||
@classmethod
|
||||
def _setup_once_tables(cls):
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def _with_register_classes(cls, fn):
|
||||
cls_registry = cls.classes
|
||||
|
||||
class _DeclBase(DeclarativeBase):
|
||||
__table_cls__ = schema.Table
|
||||
metadata = cls._tables_metadata
|
||||
type_annotation_map = {
|
||||
str: sa.String().with_variant(
|
||||
sa.String(50), "mysql", "mariadb", "oracle"
|
||||
)
|
||||
}
|
||||
|
||||
def __init_subclass__(cls, **kw) -> None:
|
||||
assert cls_registry is not None
|
||||
cls_registry[cls.__name__] = cls
|
||||
super().__init_subclass__(**kw)
|
||||
|
||||
cls.DeclarativeBasic = _DeclBase
|
||||
|
||||
# sets up cls.Basic which is helpful for things like composite
|
||||
# classes
|
||||
super()._with_register_classes(fn)
|
||||
|
||||
if cls._tables_metadata.tables and cls.run_create_tables:
|
||||
cls._tables_metadata.create_all(config.db)
|
||||
|
||||
|
||||
class RemoveORMEventsGlobally:
|
||||
@config.fixture(autouse=True)
|
||||
def _remove_listeners(self):
|
||||
yield
|
||||
orm_events.MapperEvents._clear()
|
||||
orm_events.InstanceEvents._clear()
|
||||
orm_events.SessionEvents._clear()
|
||||
orm_events.InstrumentationEvents._clear()
|
||||
orm_events.QueryEvents._clear()
|
||||
|
||||
|
||||
_fixture_sessions = set()
|
||||
|
||||
|
||||
def fixture_session(**kw):
|
||||
kw.setdefault("autoflush", True)
|
||||
kw.setdefault("expire_on_commit", True)
|
||||
|
||||
bind = kw.pop("bind", config.db)
|
||||
|
||||
sess = orm.Session(bind, **kw)
|
||||
_fixture_sessions.add(sess)
|
||||
return sess
|
||||
|
||||
|
||||
def close_all_sessions():
|
||||
# will close all still-referenced sessions
|
||||
orm.close_all_sessions()
|
||||
_fixture_sessions.clear()
|
||||
|
||||
|
||||
def stop_test_class_inside_fixtures(cls):
|
||||
close_all_sessions()
|
||||
orm.clear_mappers()
|
||||
|
||||
|
||||
def after_test():
|
||||
if _fixture_sessions:
|
||||
close_all_sessions()
|
|
@ -0,0 +1,492 @@
|
|||
# testing/fixtures/sql.py
|
||||
# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors
|
||||
# <see AUTHORS file>
|
||||
#
|
||||
# This module is part of SQLAlchemy and is released under
|
||||
# the MIT License: https://www.opensource.org/licenses/mit-license.php
|
||||
# mypy: ignore-errors
|
||||
from __future__ import annotations
|
||||
|
||||
import itertools
|
||||
import random
|
||||
import re
|
||||
import sys
|
||||
|
||||
import sqlalchemy as sa
|
||||
from .base import TestBase
|
||||
from .. import config
|
||||
from .. import mock
|
||||
from ..assertions import eq_
|
||||
from ..assertions import ne_
|
||||
from ..util import adict
|
||||
from ..util import drop_all_tables_from_metadata
|
||||
from ... import event
|
||||
from ... import util
|
||||
from ...schema import sort_tables_and_constraints
|
||||
from ...sql import visitors
|
||||
from ...sql.elements import ClauseElement
|
||||
|
||||
|
||||
class TablesTest(TestBase):
|
||||
# 'once', None
|
||||
run_setup_bind = "once"
|
||||
|
||||
# 'once', 'each', None
|
||||
run_define_tables = "once"
|
||||
|
||||
# 'once', 'each', None
|
||||
run_create_tables = "once"
|
||||
|
||||
# 'once', 'each', None
|
||||
run_inserts = "each"
|
||||
|
||||
# 'each', None
|
||||
run_deletes = "each"
|
||||
|
||||
# 'once', None
|
||||
run_dispose_bind = None
|
||||
|
||||
bind = None
|
||||
_tables_metadata = None
|
||||
tables = None
|
||||
other = None
|
||||
sequences = None
|
||||
|
||||
@config.fixture(autouse=True, scope="class")
|
||||
def _setup_tables_test_class(self):
|
||||
cls = self.__class__
|
||||
cls._init_class()
|
||||
|
||||
cls._setup_once_tables()
|
||||
|
||||
cls._setup_once_inserts()
|
||||
|
||||
yield
|
||||
|
||||
cls._teardown_once_metadata_bind()
|
||||
|
||||
@config.fixture(autouse=True, scope="function")
|
||||
def _setup_tables_test_instance(self):
|
||||
self._setup_each_tables()
|
||||
self._setup_each_inserts()
|
||||
|
||||
yield
|
||||
|
||||
self._teardown_each_tables()
|
||||
|
||||
@property
|
||||
def tables_test_metadata(self):
|
||||
return self._tables_metadata
|
||||
|
||||
@classmethod
|
||||
def _init_class(cls):
|
||||
if cls.run_define_tables == "each":
|
||||
if cls.run_create_tables == "once":
|
||||
cls.run_create_tables = "each"
|
||||
assert cls.run_inserts in ("each", None)
|
||||
|
||||
cls.other = adict()
|
||||
cls.tables = adict()
|
||||
cls.sequences = adict()
|
||||
|
||||
cls.bind = cls.setup_bind()
|
||||
cls._tables_metadata = sa.MetaData()
|
||||
|
||||
@classmethod
|
||||
def _setup_once_inserts(cls):
|
||||
if cls.run_inserts == "once":
|
||||
cls._load_fixtures()
|
||||
with cls.bind.begin() as conn:
|
||||
cls.insert_data(conn)
|
||||
|
||||
@classmethod
|
||||
def _setup_once_tables(cls):
|
||||
if cls.run_define_tables == "once":
|
||||
cls.define_tables(cls._tables_metadata)
|
||||
if cls.run_create_tables == "once":
|
||||
cls._tables_metadata.create_all(cls.bind)
|
||||
cls.tables.update(cls._tables_metadata.tables)
|
||||
cls.sequences.update(cls._tables_metadata._sequences)
|
||||
|
||||
def _setup_each_tables(self):
|
||||
if self.run_define_tables == "each":
|
||||
self.define_tables(self._tables_metadata)
|
||||
if self.run_create_tables == "each":
|
||||
self._tables_metadata.create_all(self.bind)
|
||||
self.tables.update(self._tables_metadata.tables)
|
||||
self.sequences.update(self._tables_metadata._sequences)
|
||||
elif self.run_create_tables == "each":
|
||||
self._tables_metadata.create_all(self.bind)
|
||||
|
||||
def _setup_each_inserts(self):
|
||||
if self.run_inserts == "each":
|
||||
self._load_fixtures()
|
||||
with self.bind.begin() as conn:
|
||||
self.insert_data(conn)
|
||||
|
||||
def _teardown_each_tables(self):
|
||||
if self.run_define_tables == "each":
|
||||
self.tables.clear()
|
||||
if self.run_create_tables == "each":
|
||||
drop_all_tables_from_metadata(self._tables_metadata, self.bind)
|
||||
self._tables_metadata.clear()
|
||||
elif self.run_create_tables == "each":
|
||||
drop_all_tables_from_metadata(self._tables_metadata, self.bind)
|
||||
|
||||
savepoints = getattr(config.requirements, "savepoints", False)
|
||||
if savepoints:
|
||||
savepoints = savepoints.enabled
|
||||
|
||||
# no need to run deletes if tables are recreated on setup
|
||||
if (
|
||||
self.run_define_tables != "each"
|
||||
and self.run_create_tables != "each"
|
||||
and self.run_deletes == "each"
|
||||
):
|
||||
with self.bind.begin() as conn:
|
||||
for table in reversed(
|
||||
[
|
||||
t
|
||||
for (t, fks) in sort_tables_and_constraints(
|
||||
self._tables_metadata.tables.values()
|
||||
)
|
||||
if t is not None
|
||||
]
|
||||
):
|
||||
try:
|
||||
if savepoints:
|
||||
with conn.begin_nested():
|
||||
conn.execute(table.delete())
|
||||
else:
|
||||
conn.execute(table.delete())
|
||||
except sa.exc.DBAPIError as ex:
|
||||
print(
|
||||
("Error emptying table %s: %r" % (table, ex)),
|
||||
file=sys.stderr,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _teardown_once_metadata_bind(cls):
|
||||
if cls.run_create_tables:
|
||||
drop_all_tables_from_metadata(cls._tables_metadata, cls.bind)
|
||||
|
||||
if cls.run_dispose_bind == "once":
|
||||
cls.dispose_bind(cls.bind)
|
||||
|
||||
cls._tables_metadata.bind = None
|
||||
|
||||
if cls.run_setup_bind is not None:
|
||||
cls.bind = None
|
||||
|
||||
@classmethod
|
||||
def setup_bind(cls):
|
||||
return config.db
|
||||
|
||||
@classmethod
|
||||
def dispose_bind(cls, bind):
|
||||
if hasattr(bind, "dispose"):
|
||||
bind.dispose()
|
||||
elif hasattr(bind, "close"):
|
||||
bind.close()
|
||||
|
||||
@classmethod
|
||||
def define_tables(cls, metadata):
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def fixtures(cls):
|
||||
return {}
|
||||
|
||||
@classmethod
|
||||
def insert_data(cls, connection):
|
||||
pass
|
||||
|
||||
def sql_count_(self, count, fn):
|
||||
self.assert_sql_count(self.bind, fn, count)
|
||||
|
||||
def sql_eq_(self, callable_, statements):
|
||||
self.assert_sql(self.bind, callable_, statements)
|
||||
|
||||
@classmethod
|
||||
def _load_fixtures(cls):
|
||||
"""Insert rows as represented by the fixtures() method."""
|
||||
headers, rows = {}, {}
|
||||
for table, data in cls.fixtures().items():
|
||||
if len(data) < 2:
|
||||
continue
|
||||
if isinstance(table, str):
|
||||
table = cls.tables[table]
|
||||
headers[table] = data[0]
|
||||
rows[table] = data[1:]
|
||||
for table, fks in sort_tables_and_constraints(
|
||||
cls._tables_metadata.tables.values()
|
||||
):
|
||||
if table is None:
|
||||
continue
|
||||
if table not in headers:
|
||||
continue
|
||||
with cls.bind.begin() as conn:
|
||||
conn.execute(
|
||||
table.insert(),
|
||||
[
|
||||
dict(zip(headers[table], column_values))
|
||||
for column_values in rows[table]
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
class NoCache:
|
||||
@config.fixture(autouse=True, scope="function")
|
||||
def _disable_cache(self):
|
||||
_cache = config.db._compiled_cache
|
||||
config.db._compiled_cache = None
|
||||
yield
|
||||
config.db._compiled_cache = _cache
|
||||
|
||||
|
||||
class RemovesEvents:
|
||||
@util.memoized_property
|
||||
def _event_fns(self):
|
||||
return set()
|
||||
|
||||
def event_listen(self, target, name, fn, **kw):
|
||||
self._event_fns.add((target, name, fn))
|
||||
event.listen(target, name, fn, **kw)
|
||||
|
||||
@config.fixture(autouse=True, scope="function")
|
||||
def _remove_events(self):
|
||||
yield
|
||||
for key in self._event_fns:
|
||||
event.remove(*key)
|
||||
|
||||
|
||||
class ComputedReflectionFixtureTest(TablesTest):
|
||||
run_inserts = run_deletes = None
|
||||
|
||||
__backend__ = True
|
||||
__requires__ = ("computed_columns", "table_reflection")
|
||||
|
||||
regexp = re.compile(r"[\[\]\(\)\s`'\"]*")
|
||||
|
||||
def normalize(self, text):
|
||||
return self.regexp.sub("", text).lower()
|
||||
|
||||
@classmethod
|
||||
def define_tables(cls, metadata):
|
||||
from ... import Integer
|
||||
from ... import testing
|
||||
from ...schema import Column
|
||||
from ...schema import Computed
|
||||
from ...schema import Table
|
||||
|
||||
Table(
|
||||
"computed_default_table",
|
||||
metadata,
|
||||
Column("id", Integer, primary_key=True),
|
||||
Column("normal", Integer),
|
||||
Column("computed_col", Integer, Computed("normal + 42")),
|
||||
Column("with_default", Integer, server_default="42"),
|
||||
)
|
||||
|
||||
t = Table(
|
||||
"computed_column_table",
|
||||
metadata,
|
||||
Column("id", Integer, primary_key=True),
|
||||
Column("normal", Integer),
|
||||
Column("computed_no_flag", Integer, Computed("normal + 42")),
|
||||
)
|
||||
|
||||
if testing.requires.schemas.enabled:
|
||||
t2 = Table(
|
||||
"computed_column_table",
|
||||
metadata,
|
||||
Column("id", Integer, primary_key=True),
|
||||
Column("normal", Integer),
|
||||
Column("computed_no_flag", Integer, Computed("normal / 42")),
|
||||
schema=config.test_schema,
|
||||
)
|
||||
|
||||
if testing.requires.computed_columns_virtual.enabled:
|
||||
t.append_column(
|
||||
Column(
|
||||
"computed_virtual",
|
||||
Integer,
|
||||
Computed("normal + 2", persisted=False),
|
||||
)
|
||||
)
|
||||
if testing.requires.schemas.enabled:
|
||||
t2.append_column(
|
||||
Column(
|
||||
"computed_virtual",
|
||||
Integer,
|
||||
Computed("normal / 2", persisted=False),
|
||||
)
|
||||
)
|
||||
if testing.requires.computed_columns_stored.enabled:
|
||||
t.append_column(
|
||||
Column(
|
||||
"computed_stored",
|
||||
Integer,
|
||||
Computed("normal - 42", persisted=True),
|
||||
)
|
||||
)
|
||||
if testing.requires.schemas.enabled:
|
||||
t2.append_column(
|
||||
Column(
|
||||
"computed_stored",
|
||||
Integer,
|
||||
Computed("normal * 42", persisted=True),
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
class CacheKeyFixture:
|
||||
def _compare_equal(self, a, b, compare_values):
|
||||
a_key = a._generate_cache_key()
|
||||
b_key = b._generate_cache_key()
|
||||
|
||||
if a_key is None:
|
||||
assert a._annotations.get("nocache")
|
||||
|
||||
assert b_key is None
|
||||
else:
|
||||
eq_(a_key.key, b_key.key)
|
||||
eq_(hash(a_key.key), hash(b_key.key))
|
||||
|
||||
for a_param, b_param in zip(a_key.bindparams, b_key.bindparams):
|
||||
assert a_param.compare(b_param, compare_values=compare_values)
|
||||
return a_key, b_key
|
||||
|
||||
def _run_cache_key_fixture(self, fixture, compare_values):
|
||||
case_a = fixture()
|
||||
case_b = fixture()
|
||||
|
||||
for a, b in itertools.combinations_with_replacement(
|
||||
range(len(case_a)), 2
|
||||
):
|
||||
if a == b:
|
||||
a_key, b_key = self._compare_equal(
|
||||
case_a[a], case_b[b], compare_values
|
||||
)
|
||||
if a_key is None:
|
||||
continue
|
||||
else:
|
||||
a_key = case_a[a]._generate_cache_key()
|
||||
b_key = case_b[b]._generate_cache_key()
|
||||
|
||||
if a_key is None or b_key is None:
|
||||
if a_key is None:
|
||||
assert case_a[a]._annotations.get("nocache")
|
||||
if b_key is None:
|
||||
assert case_b[b]._annotations.get("nocache")
|
||||
continue
|
||||
|
||||
if a_key.key == b_key.key:
|
||||
for a_param, b_param in zip(
|
||||
a_key.bindparams, b_key.bindparams
|
||||
):
|
||||
if not a_param.compare(
|
||||
b_param, compare_values=compare_values
|
||||
):
|
||||
break
|
||||
else:
|
||||
# this fails unconditionally since we could not
|
||||
# find bound parameter values that differed.
|
||||
# Usually we intended to get two distinct keys here
|
||||
# so the failure will be more descriptive using the
|
||||
# ne_() assertion.
|
||||
ne_(a_key.key, b_key.key)
|
||||
else:
|
||||
ne_(a_key.key, b_key.key)
|
||||
|
||||
# ClauseElement-specific test to ensure the cache key
|
||||
# collected all the bound parameters that aren't marked
|
||||
# as "literal execute"
|
||||
if isinstance(case_a[a], ClauseElement) and isinstance(
|
||||
case_b[b], ClauseElement
|
||||
):
|
||||
assert_a_params = []
|
||||
assert_b_params = []
|
||||
|
||||
for elem in visitors.iterate(case_a[a]):
|
||||
if elem.__visit_name__ == "bindparam":
|
||||
assert_a_params.append(elem)
|
||||
|
||||
for elem in visitors.iterate(case_b[b]):
|
||||
if elem.__visit_name__ == "bindparam":
|
||||
assert_b_params.append(elem)
|
||||
|
||||
# note we're asserting the order of the params as well as
|
||||
# if there are dupes or not. ordering has to be
|
||||
# deterministic and matches what a traversal would provide.
|
||||
eq_(
|
||||
sorted(a_key.bindparams, key=lambda b: b.key),
|
||||
sorted(
|
||||
util.unique_list(assert_a_params), key=lambda b: b.key
|
||||
),
|
||||
)
|
||||
eq_(
|
||||
sorted(b_key.bindparams, key=lambda b: b.key),
|
||||
sorted(
|
||||
util.unique_list(assert_b_params), key=lambda b: b.key
|
||||
),
|
||||
)
|
||||
|
||||
def _run_cache_key_equal_fixture(self, fixture, compare_values):
|
||||
case_a = fixture()
|
||||
case_b = fixture()
|
||||
|
||||
for a, b in itertools.combinations_with_replacement(
|
||||
range(len(case_a)), 2
|
||||
):
|
||||
self._compare_equal(case_a[a], case_b[b], compare_values)
|
||||
|
||||
|
||||
def insertmanyvalues_fixture(
|
||||
connection, randomize_rows=False, warn_on_downgraded=False
|
||||
):
|
||||
dialect = connection.dialect
|
||||
orig_dialect = dialect._deliver_insertmanyvalues_batches
|
||||
orig_conn = connection._exec_insertmany_context
|
||||
|
||||
class RandomCursor:
|
||||
__slots__ = ("cursor",)
|
||||
|
||||
def __init__(self, cursor):
|
||||
self.cursor = cursor
|
||||
|
||||
# only this method is called by the deliver method.
|
||||
# by not having the other methods we assert that those aren't being
|
||||
# used
|
||||
|
||||
def fetchall(self):
|
||||
rows = self.cursor.fetchall()
|
||||
rows = list(rows)
|
||||
random.shuffle(rows)
|
||||
return rows
|
||||
|
||||
def _deliver_insertmanyvalues_batches(
|
||||
cursor, statement, parameters, generic_setinputsizes, context
|
||||
):
|
||||
if randomize_rows:
|
||||
cursor = RandomCursor(cursor)
|
||||
for batch in orig_dialect(
|
||||
cursor, statement, parameters, generic_setinputsizes, context
|
||||
):
|
||||
if warn_on_downgraded and batch.is_downgraded:
|
||||
util.warn("Batches were downgraded for sorted INSERT")
|
||||
|
||||
yield batch
|
||||
|
||||
def _exec_insertmany_context(
|
||||
dialect,
|
||||
context,
|
||||
):
|
||||
with mock.patch.object(
|
||||
dialect,
|
||||
"_deliver_insertmanyvalues_batches",
|
||||
new=_deliver_insertmanyvalues_batches,
|
||||
):
|
||||
return orig_conn(dialect, context)
|
||||
|
||||
connection._exec_insertmany_context = _exec_insertmany_context
|
|
@ -0,0 +1,155 @@
|
|||
# testing/pickleable.py
|
||||
# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors
|
||||
# <see AUTHORS file>
|
||||
#
|
||||
# This module is part of SQLAlchemy and is released under
|
||||
# the MIT License: https://www.opensource.org/licenses/mit-license.php
|
||||
# mypy: ignore-errors
|
||||
|
||||
|
||||
"""Classes used in pickling tests, need to be at the module level for
|
||||
unpickling.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from .entities import ComparableEntity
|
||||
from ..schema import Column
|
||||
from ..types import String
|
||||
|
||||
|
||||
class User(ComparableEntity):
|
||||
pass
|
||||
|
||||
|
||||
class Order(ComparableEntity):
|
||||
pass
|
||||
|
||||
|
||||
class Dingaling(ComparableEntity):
|
||||
pass
|
||||
|
||||
|
||||
class EmailUser(User):
|
||||
pass
|
||||
|
||||
|
||||
class Address(ComparableEntity):
|
||||
pass
|
||||
|
||||
|
||||
# TODO: these are kind of arbitrary....
|
||||
class Child1(ComparableEntity):
|
||||
pass
|
||||
|
||||
|
||||
class Child2(ComparableEntity):
|
||||
pass
|
||||
|
||||
|
||||
class Parent(ComparableEntity):
|
||||
pass
|
||||
|
||||
|
||||
class Screen:
|
||||
def __init__(self, obj, parent=None):
|
||||
self.obj = obj
|
||||
self.parent = parent
|
||||
|
||||
|
||||
class Mixin:
|
||||
email_address = Column(String)
|
||||
|
||||
|
||||
class AddressWMixin(Mixin, ComparableEntity):
|
||||
pass
|
||||
|
||||
|
||||
class Foo:
|
||||
def __init__(self, moredata, stuff="im stuff"):
|
||||
self.data = "im data"
|
||||
self.stuff = stuff
|
||||
self.moredata = moredata
|
||||
|
||||
__hash__ = object.__hash__
|
||||
|
||||
def __eq__(self, other):
|
||||
return (
|
||||
other.data == self.data
|
||||
and other.stuff == self.stuff
|
||||
and other.moredata == self.moredata
|
||||
)
|
||||
|
||||
|
||||
class Bar:
|
||||
def __init__(self, x, y):
|
||||
self.x = x
|
||||
self.y = y
|
||||
|
||||
__hash__ = object.__hash__
|
||||
|
||||
def __eq__(self, other):
|
||||
return (
|
||||
other.__class__ is self.__class__
|
||||
and other.x == self.x
|
||||
and other.y == self.y
|
||||
)
|
||||
|
||||
def __str__(self):
|
||||
return "Bar(%d, %d)" % (self.x, self.y)
|
||||
|
||||
|
||||
class OldSchool:
|
||||
def __init__(self, x, y):
|
||||
self.x = x
|
||||
self.y = y
|
||||
|
||||
def __eq__(self, other):
|
||||
return (
|
||||
other.__class__ is self.__class__
|
||||
and other.x == self.x
|
||||
and other.y == self.y
|
||||
)
|
||||
|
||||
|
||||
class OldSchoolWithoutCompare:
|
||||
def __init__(self, x, y):
|
||||
self.x = x
|
||||
self.y = y
|
||||
|
||||
|
||||
class BarWithoutCompare:
|
||||
def __init__(self, x, y):
|
||||
self.x = x
|
||||
self.y = y
|
||||
|
||||
def __str__(self):
|
||||
return "Bar(%d, %d)" % (self.x, self.y)
|
||||
|
||||
|
||||
class NotComparable:
|
||||
def __init__(self, data):
|
||||
self.data = data
|
||||
|
||||
def __hash__(self):
|
||||
return id(self)
|
||||
|
||||
def __eq__(self, other):
|
||||
return NotImplemented
|
||||
|
||||
def __ne__(self, other):
|
||||
return NotImplemented
|
||||
|
||||
|
||||
class BrokenComparable:
|
||||
def __init__(self, data):
|
||||
self.data = data
|
||||
|
||||
def __hash__(self):
|
||||
return id(self)
|
||||
|
||||
def __eq__(self, other):
|
||||
raise NotImplementedError
|
||||
|
||||
def __ne__(self, other):
|
||||
raise NotImplementedError
|
|
@ -0,0 +1,6 @@
|
|||
# testing/plugin/__init__.py
|
||||
# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors
|
||||
# <see AUTHORS file>
|
||||
#
|
||||
# This module is part of SQLAlchemy and is released under
|
||||
# the MIT License: https://www.opensource.org/licenses/mit-license.php
|
|
@ -0,0 +1,51 @@
|
|||
# testing/plugin/bootstrap.py
|
||||
# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors
|
||||
# <see AUTHORS file>
|
||||
#
|
||||
# This module is part of SQLAlchemy and is released under
|
||||
# the MIT License: https://www.opensource.org/licenses/mit-license.php
|
||||
# mypy: ignore-errors
|
||||
|
||||
"""
|
||||
Bootstrapper for test framework plugins.
|
||||
|
||||
The entire rationale for this system is to get the modules in plugin/
|
||||
imported without importing all of the supporting library, so that we can
|
||||
set up things for testing before coverage starts.
|
||||
|
||||
The rationale for all of plugin/ being *in* the supporting library in the
|
||||
first place is so that the testing and plugin suite is available to other
|
||||
libraries, mainly external SQLAlchemy and Alembic dialects, to make use
|
||||
of the same test environment and standard suites available to
|
||||
SQLAlchemy/Alembic themselves without the need to ship/install a separate
|
||||
package outside of SQLAlchemy.
|
||||
|
||||
|
||||
"""
|
||||
|
||||
import importlib.util
|
||||
import os
|
||||
import sys
|
||||
|
||||
|
||||
bootstrap_file = locals()["bootstrap_file"]
|
||||
to_bootstrap = locals()["to_bootstrap"]
|
||||
|
||||
|
||||
def load_file_as_module(name):
|
||||
path = os.path.join(os.path.dirname(bootstrap_file), "%s.py" % name)
|
||||
|
||||
spec = importlib.util.spec_from_file_location(name, path)
|
||||
assert spec is not None
|
||||
assert spec.loader is not None
|
||||
mod = importlib.util.module_from_spec(spec)
|
||||
spec.loader.exec_module(mod)
|
||||
return mod
|
||||
|
||||
|
||||
if to_bootstrap == "pytest":
|
||||
sys.modules["sqla_plugin_base"] = load_file_as_module("plugin_base")
|
||||
sys.modules["sqla_plugin_base"].bootstrapped_as_sqlalchemy = True
|
||||
sys.modules["sqla_pytestplugin"] = load_file_as_module("pytestplugin")
|
||||
else:
|
||||
raise Exception("unknown bootstrap: %s" % to_bootstrap) # noqa
|
|
@ -0,0 +1,779 @@
|
|||
# testing/plugin/plugin_base.py
|
||||
# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors
|
||||
# <see AUTHORS file>
|
||||
#
|
||||
# This module is part of SQLAlchemy and is released under
|
||||
# the MIT License: https://www.opensource.org/licenses/mit-license.php
|
||||
# mypy: ignore-errors
|
||||
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import abc
|
||||
from argparse import Namespace
|
||||
import configparser
|
||||
import logging
|
||||
import os
|
||||
from pathlib import Path
|
||||
import re
|
||||
import sys
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy.testing import asyncio
|
||||
|
||||
"""Testing extensions.
|
||||
|
||||
this module is designed to work as a testing-framework-agnostic library,
|
||||
created so that multiple test frameworks can be supported at once
|
||||
(mostly so that we can migrate to new ones). The current target
|
||||
is pytest.
|
||||
|
||||
"""
|
||||
|
||||
# flag which indicates we are in the SQLAlchemy testing suite,
|
||||
# and not that of Alembic or a third party dialect.
|
||||
bootstrapped_as_sqlalchemy = False
|
||||
|
||||
log = logging.getLogger("sqlalchemy.testing.plugin_base")
|
||||
|
||||
# late imports
|
||||
fixtures = None
|
||||
engines = None
|
||||
exclusions = None
|
||||
warnings = None
|
||||
profiling = None
|
||||
provision = None
|
||||
assertions = None
|
||||
requirements = None
|
||||
config = None
|
||||
testing = None
|
||||
util = None
|
||||
file_config = None
|
||||
|
||||
logging = None
|
||||
include_tags = set()
|
||||
exclude_tags = set()
|
||||
options: Namespace = None # type: ignore
|
||||
|
||||
|
||||
def setup_options(make_option):
|
||||
make_option(
|
||||
"--log-info",
|
||||
action="callback",
|
||||
type=str,
|
||||
callback=_log,
|
||||
help="turn on info logging for <LOG> (multiple OK)",
|
||||
)
|
||||
make_option(
|
||||
"--log-debug",
|
||||
action="callback",
|
||||
type=str,
|
||||
callback=_log,
|
||||
help="turn on debug logging for <LOG> (multiple OK)",
|
||||
)
|
||||
make_option(
|
||||
"--db",
|
||||
action="append",
|
||||
type=str,
|
||||
dest="db",
|
||||
help="Use prefab database uri. Multiple OK, "
|
||||
"first one is run by default.",
|
||||
)
|
||||
make_option(
|
||||
"--dbs",
|
||||
action="callback",
|
||||
zeroarg_callback=_list_dbs,
|
||||
help="List available prefab dbs",
|
||||
)
|
||||
make_option(
|
||||
"--dburi",
|
||||
action="append",
|
||||
type=str,
|
||||
dest="dburi",
|
||||
help="Database uri. Multiple OK, first one is run by default.",
|
||||
)
|
||||
make_option(
|
||||
"--dbdriver",
|
||||
action="append",
|
||||
type=str,
|
||||
dest="dbdriver",
|
||||
help="Additional database drivers to include in tests. "
|
||||
"These are linked to the existing database URLs by the "
|
||||
"provisioning system.",
|
||||
)
|
||||
make_option(
|
||||
"--dropfirst",
|
||||
action="store_true",
|
||||
dest="dropfirst",
|
||||
help="Drop all tables in the target database first",
|
||||
)
|
||||
make_option(
|
||||
"--disable-asyncio",
|
||||
action="store_true",
|
||||
help="disable test / fixtures / provisoning running in asyncio",
|
||||
)
|
||||
make_option(
|
||||
"--backend-only",
|
||||
action="callback",
|
||||
zeroarg_callback=_set_tag_include("backend"),
|
||||
help=(
|
||||
"Run only tests marked with __backend__ or __sparse_backend__; "
|
||||
"this is now equivalent to the pytest -m backend mark expression"
|
||||
),
|
||||
)
|
||||
make_option(
|
||||
"--nomemory",
|
||||
action="callback",
|
||||
zeroarg_callback=_set_tag_exclude("memory_intensive"),
|
||||
help="Don't run memory profiling tests; "
|
||||
"this is now equivalent to the pytest -m 'not memory_intensive' "
|
||||
"mark expression",
|
||||
)
|
||||
make_option(
|
||||
"--notimingintensive",
|
||||
action="callback",
|
||||
zeroarg_callback=_set_tag_exclude("timing_intensive"),
|
||||
help="Don't run timing intensive tests; "
|
||||
"this is now equivalent to the pytest -m 'not timing_intensive' "
|
||||
"mark expression",
|
||||
)
|
||||
make_option(
|
||||
"--nomypy",
|
||||
action="callback",
|
||||
zeroarg_callback=_set_tag_exclude("mypy"),
|
||||
help="Don't run mypy typing tests; "
|
||||
"this is now equivalent to the pytest -m 'not mypy' mark expression",
|
||||
)
|
||||
make_option(
|
||||
"--profile-sort",
|
||||
type=str,
|
||||
default="cumulative",
|
||||
dest="profilesort",
|
||||
help="Type of sort for profiling standard output",
|
||||
)
|
||||
make_option(
|
||||
"--profile-dump",
|
||||
type=str,
|
||||
dest="profiledump",
|
||||
help="Filename where a single profile run will be dumped",
|
||||
)
|
||||
make_option(
|
||||
"--low-connections",
|
||||
action="store_true",
|
||||
dest="low_connections",
|
||||
help="Use a low number of distinct connections - "
|
||||
"i.e. for Oracle TNS",
|
||||
)
|
||||
make_option(
|
||||
"--write-idents",
|
||||
type=str,
|
||||
dest="write_idents",
|
||||
help="write out generated follower idents to <file>, "
|
||||
"when -n<num> is used",
|
||||
)
|
||||
make_option(
|
||||
"--requirements",
|
||||
action="callback",
|
||||
type=str,
|
||||
callback=_requirements_opt,
|
||||
help="requirements class for testing, overrides setup.cfg",
|
||||
)
|
||||
make_option(
|
||||
"--include-tag",
|
||||
action="callback",
|
||||
callback=_include_tag,
|
||||
type=str,
|
||||
help="Include tests with tag <tag>; "
|
||||
"legacy, use pytest -m 'tag' instead",
|
||||
)
|
||||
make_option(
|
||||
"--exclude-tag",
|
||||
action="callback",
|
||||
callback=_exclude_tag,
|
||||
type=str,
|
||||
help="Exclude tests with tag <tag>; "
|
||||
"legacy, use pytest -m 'not tag' instead",
|
||||
)
|
||||
make_option(
|
||||
"--write-profiles",
|
||||
action="store_true",
|
||||
dest="write_profiles",
|
||||
default=False,
|
||||
help="Write/update failing profiling data.",
|
||||
)
|
||||
make_option(
|
||||
"--force-write-profiles",
|
||||
action="store_true",
|
||||
dest="force_write_profiles",
|
||||
default=False,
|
||||
help="Unconditionally write/update profiling data.",
|
||||
)
|
||||
make_option(
|
||||
"--dump-pyannotate",
|
||||
type=str,
|
||||
dest="dump_pyannotate",
|
||||
help="Run pyannotate and dump json info to given file",
|
||||
)
|
||||
make_option(
|
||||
"--mypy-extra-test-path",
|
||||
type=str,
|
||||
action="append",
|
||||
default=[],
|
||||
dest="mypy_extra_test_paths",
|
||||
help="Additional test directories to add to the mypy tests. "
|
||||
"This is used only when running mypy tests. Multiple OK",
|
||||
)
|
||||
# db specific options
|
||||
make_option(
|
||||
"--postgresql-templatedb",
|
||||
type=str,
|
||||
help="name of template database to use for PostgreSQL "
|
||||
"CREATE DATABASE (defaults to current database)",
|
||||
)
|
||||
make_option(
|
||||
"--oracledb-thick-mode",
|
||||
action="store_true",
|
||||
help="enables the 'thick mode' when testing with oracle+oracledb",
|
||||
)
|
||||
|
||||
|
||||
def configure_follower(follower_ident):
|
||||
"""Configure required state for a follower.
|
||||
|
||||
This invokes in the parent process and typically includes
|
||||
database creation.
|
||||
|
||||
"""
|
||||
from sqlalchemy.testing import provision
|
||||
|
||||
provision.FOLLOWER_IDENT = follower_ident
|
||||
|
||||
|
||||
def memoize_important_follower_config(dict_):
|
||||
"""Store important configuration we will need to send to a follower.
|
||||
|
||||
This invokes in the parent process after normal config is set up.
|
||||
|
||||
Hook is currently not used.
|
||||
|
||||
"""
|
||||
|
||||
|
||||
def restore_important_follower_config(dict_):
|
||||
"""Restore important configuration needed by a follower.
|
||||
|
||||
This invokes in the follower process.
|
||||
|
||||
Hook is currently not used.
|
||||
|
||||
"""
|
||||
|
||||
|
||||
def read_config(root_path):
|
||||
global file_config
|
||||
file_config = configparser.ConfigParser()
|
||||
file_config.read(
|
||||
[str(root_path / "setup.cfg"), str(root_path / "test.cfg")]
|
||||
)
|
||||
|
||||
|
||||
def pre_begin(opt):
|
||||
"""things to set up early, before coverage might be setup."""
|
||||
global options
|
||||
options = opt
|
||||
for fn in pre_configure:
|
||||
fn(options, file_config)
|
||||
|
||||
|
||||
def set_coverage_flag(value):
|
||||
options.has_coverage = value
|
||||
|
||||
|
||||
def post_begin():
|
||||
"""things to set up later, once we know coverage is running."""
|
||||
# Lazy setup of other options (post coverage)
|
||||
for fn in post_configure:
|
||||
fn(options, file_config)
|
||||
|
||||
# late imports, has to happen after config.
|
||||
global util, fixtures, engines, exclusions, assertions, provision
|
||||
global warnings, profiling, config, testing
|
||||
from sqlalchemy import testing # noqa
|
||||
from sqlalchemy.testing import fixtures, engines, exclusions # noqa
|
||||
from sqlalchemy.testing import assertions, warnings, profiling # noqa
|
||||
from sqlalchemy.testing import config, provision # noqa
|
||||
from sqlalchemy import util # noqa
|
||||
|
||||
warnings.setup_filters()
|
||||
|
||||
|
||||
def _log(opt_str, value, parser):
|
||||
global logging
|
||||
if not logging:
|
||||
import logging
|
||||
|
||||
logging.basicConfig()
|
||||
|
||||
if opt_str.endswith("-info"):
|
||||
logging.getLogger(value).setLevel(logging.INFO)
|
||||
elif opt_str.endswith("-debug"):
|
||||
logging.getLogger(value).setLevel(logging.DEBUG)
|
||||
|
||||
|
||||
def _list_dbs(*args):
|
||||
if file_config is None:
|
||||
# assume the current working directory is the one containing the
|
||||
# setup file
|
||||
read_config(Path.cwd())
|
||||
print("Available --db options (use --dburi to override)")
|
||||
for macro in sorted(file_config.options("db")):
|
||||
print("%20s\t%s" % (macro, file_config.get("db", macro)))
|
||||
sys.exit(0)
|
||||
|
||||
|
||||
def _requirements_opt(opt_str, value, parser):
|
||||
_setup_requirements(value)
|
||||
|
||||
|
||||
def _set_tag_include(tag):
|
||||
def _do_include_tag(opt_str, value, parser):
|
||||
_include_tag(opt_str, tag, parser)
|
||||
|
||||
return _do_include_tag
|
||||
|
||||
|
||||
def _set_tag_exclude(tag):
|
||||
def _do_exclude_tag(opt_str, value, parser):
|
||||
_exclude_tag(opt_str, tag, parser)
|
||||
|
||||
return _do_exclude_tag
|
||||
|
||||
|
||||
def _exclude_tag(opt_str, value, parser):
|
||||
exclude_tags.add(value.replace("-", "_"))
|
||||
|
||||
|
||||
def _include_tag(opt_str, value, parser):
|
||||
include_tags.add(value.replace("-", "_"))
|
||||
|
||||
|
||||
pre_configure = []
|
||||
post_configure = []
|
||||
|
||||
|
||||
def pre(fn):
|
||||
pre_configure.append(fn)
|
||||
return fn
|
||||
|
||||
|
||||
def post(fn):
|
||||
post_configure.append(fn)
|
||||
return fn
|
||||
|
||||
|
||||
@pre
|
||||
def _setup_options(opt, file_config):
|
||||
global options
|
||||
options = opt
|
||||
|
||||
|
||||
@pre
|
||||
def _register_sqlite_numeric_dialect(opt, file_config):
|
||||
from sqlalchemy.dialects import registry
|
||||
|
||||
registry.register(
|
||||
"sqlite.pysqlite_numeric",
|
||||
"sqlalchemy.dialects.sqlite.pysqlite",
|
||||
"_SQLiteDialect_pysqlite_numeric",
|
||||
)
|
||||
registry.register(
|
||||
"sqlite.pysqlite_dollar",
|
||||
"sqlalchemy.dialects.sqlite.pysqlite",
|
||||
"_SQLiteDialect_pysqlite_dollar",
|
||||
)
|
||||
|
||||
|
||||
@post
|
||||
def __ensure_cext(opt, file_config):
|
||||
if os.environ.get("REQUIRE_SQLALCHEMY_CEXT", "0") == "1":
|
||||
from sqlalchemy.util import has_compiled_ext
|
||||
|
||||
try:
|
||||
has_compiled_ext(raise_=True)
|
||||
except ImportError as err:
|
||||
raise AssertionError(
|
||||
"REQUIRE_SQLALCHEMY_CEXT is set but can't import the "
|
||||
"cython extensions"
|
||||
) from err
|
||||
|
||||
|
||||
@post
|
||||
def _init_symbols(options, file_config):
|
||||
from sqlalchemy.testing import config
|
||||
|
||||
config._fixture_functions = _fixture_fn_class()
|
||||
|
||||
|
||||
@pre
|
||||
def _set_disable_asyncio(opt, file_config):
|
||||
if opt.disable_asyncio:
|
||||
asyncio.ENABLE_ASYNCIO = False
|
||||
|
||||
|
||||
@post
|
||||
def _engine_uri(options, file_config):
|
||||
from sqlalchemy import testing
|
||||
from sqlalchemy.testing import config
|
||||
from sqlalchemy.testing import provision
|
||||
from sqlalchemy.engine import url as sa_url
|
||||
|
||||
if options.dburi:
|
||||
db_urls = list(options.dburi)
|
||||
else:
|
||||
db_urls = []
|
||||
|
||||
extra_drivers = options.dbdriver or []
|
||||
|
||||
if options.db:
|
||||
for db_token in options.db:
|
||||
for db in re.split(r"[,\s]+", db_token):
|
||||
if db not in file_config.options("db"):
|
||||
raise RuntimeError(
|
||||
"Unknown URI specifier '%s'. "
|
||||
"Specify --dbs for known uris." % db
|
||||
)
|
||||
else:
|
||||
db_urls.append(file_config.get("db", db))
|
||||
|
||||
if not db_urls:
|
||||
db_urls.append(file_config.get("db", "default"))
|
||||
|
||||
config._current = None
|
||||
|
||||
if options.write_idents and provision.FOLLOWER_IDENT:
|
||||
for db_url in [sa_url.make_url(db_url) for db_url in db_urls]:
|
||||
with open(options.write_idents, "a") as file_:
|
||||
file_.write(
|
||||
f"{provision.FOLLOWER_IDENT} "
|
||||
f"{db_url.render_as_string(hide_password=False)}\n"
|
||||
)
|
||||
|
||||
expanded_urls = list(provision.generate_db_urls(db_urls, extra_drivers))
|
||||
|
||||
for db_url in expanded_urls:
|
||||
log.info("Adding database URL: %s", db_url)
|
||||
|
||||
cfg = provision.setup_config(
|
||||
db_url, options, file_config, provision.FOLLOWER_IDENT
|
||||
)
|
||||
if not config._current:
|
||||
cfg.set_as_current(cfg, testing)
|
||||
|
||||
|
||||
@post
|
||||
def _requirements(options, file_config):
|
||||
requirement_cls = file_config.get("sqla_testing", "requirement_cls")
|
||||
_setup_requirements(requirement_cls)
|
||||
|
||||
|
||||
def _setup_requirements(argument):
|
||||
from sqlalchemy.testing import config
|
||||
from sqlalchemy import testing
|
||||
|
||||
modname, clsname = argument.split(":")
|
||||
|
||||
# importlib.import_module() only introduced in 2.7, a little
|
||||
# late
|
||||
mod = __import__(modname)
|
||||
for component in modname.split(".")[1:]:
|
||||
mod = getattr(mod, component)
|
||||
req_cls = getattr(mod, clsname)
|
||||
|
||||
config.requirements = testing.requires = req_cls()
|
||||
|
||||
config.bootstrapped_as_sqlalchemy = bootstrapped_as_sqlalchemy
|
||||
|
||||
|
||||
@post
|
||||
def _prep_testing_database(options, file_config):
|
||||
from sqlalchemy.testing import config
|
||||
|
||||
if options.dropfirst:
|
||||
from sqlalchemy.testing import provision
|
||||
|
||||
for cfg in config.Config.all_configs():
|
||||
provision.drop_all_schema_objects(cfg, cfg.db)
|
||||
|
||||
|
||||
@post
|
||||
def _post_setup_options(opt, file_config):
|
||||
from sqlalchemy.testing import config
|
||||
|
||||
config.options = options
|
||||
config.file_config = file_config
|
||||
|
||||
|
||||
@post
|
||||
def _setup_profiling(options, file_config):
|
||||
from sqlalchemy.testing import profiling
|
||||
|
||||
profiling._profile_stats = profiling.ProfileStatsFile(
|
||||
file_config.get("sqla_testing", "profile_file"),
|
||||
sort=options.profilesort,
|
||||
dump=options.profiledump,
|
||||
)
|
||||
|
||||
|
||||
def want_class(name, cls):
|
||||
if not issubclass(cls, fixtures.TestBase):
|
||||
return False
|
||||
elif name.startswith("_"):
|
||||
return False
|
||||
else:
|
||||
return True
|
||||
|
||||
|
||||
def want_method(cls, fn):
|
||||
if not fn.__name__.startswith("test_"):
|
||||
return False
|
||||
elif fn.__module__ is None:
|
||||
return False
|
||||
else:
|
||||
return True
|
||||
|
||||
|
||||
def generate_sub_tests(cls, module, markers):
|
||||
if "backend" in markers or "sparse_backend" in markers:
|
||||
sparse = "sparse_backend" in markers
|
||||
for cfg in _possible_configs_for_cls(cls, sparse=sparse):
|
||||
orig_name = cls.__name__
|
||||
|
||||
# we can have special chars in these names except for the
|
||||
# pytest junit plugin, which is tripped up by the brackets
|
||||
# and periods, so sanitize
|
||||
|
||||
alpha_name = re.sub(r"[_\[\]\.]+", "_", cfg.name)
|
||||
alpha_name = re.sub(r"_+$", "", alpha_name)
|
||||
name = "%s_%s" % (cls.__name__, alpha_name)
|
||||
subcls = type(
|
||||
name,
|
||||
(cls,),
|
||||
{"_sa_orig_cls_name": orig_name, "__only_on_config__": cfg},
|
||||
)
|
||||
setattr(module, name, subcls)
|
||||
yield subcls
|
||||
else:
|
||||
yield cls
|
||||
|
||||
|
||||
def start_test_class_outside_fixtures(cls):
|
||||
_do_skips(cls)
|
||||
_setup_engine(cls)
|
||||
|
||||
|
||||
def stop_test_class(cls):
|
||||
# close sessions, immediate connections, etc.
|
||||
fixtures.stop_test_class_inside_fixtures(cls)
|
||||
|
||||
# close outstanding connection pool connections, dispose of
|
||||
# additional engines
|
||||
engines.testing_reaper.stop_test_class_inside_fixtures()
|
||||
|
||||
|
||||
def stop_test_class_outside_fixtures(cls):
|
||||
engines.testing_reaper.stop_test_class_outside_fixtures()
|
||||
provision.stop_test_class_outside_fixtures(config, config.db, cls)
|
||||
try:
|
||||
if not options.low_connections:
|
||||
assertions.global_cleanup_assertions()
|
||||
finally:
|
||||
_restore_engine()
|
||||
|
||||
|
||||
def _restore_engine():
|
||||
if config._current:
|
||||
config._current.reset(testing)
|
||||
|
||||
|
||||
def final_process_cleanup():
|
||||
engines.testing_reaper.final_cleanup()
|
||||
assertions.global_cleanup_assertions()
|
||||
_restore_engine()
|
||||
|
||||
|
||||
def _setup_engine(cls):
|
||||
if getattr(cls, "__engine_options__", None):
|
||||
opts = dict(cls.__engine_options__)
|
||||
opts["scope"] = "class"
|
||||
eng = engines.testing_engine(options=opts)
|
||||
config._current.push_engine(eng, testing)
|
||||
|
||||
|
||||
def before_test(test, test_module_name, test_class, test_name):
|
||||
# format looks like:
|
||||
# "test.aaa_profiling.test_compiler.CompileTest.test_update_whereclause"
|
||||
|
||||
name = getattr(test_class, "_sa_orig_cls_name", test_class.__name__)
|
||||
|
||||
id_ = "%s.%s.%s" % (test_module_name, name, test_name)
|
||||
|
||||
profiling._start_current_test(id_)
|
||||
|
||||
|
||||
def after_test(test):
|
||||
fixtures.after_test()
|
||||
engines.testing_reaper.after_test()
|
||||
|
||||
|
||||
def after_test_fixtures(test):
|
||||
engines.testing_reaper.after_test_outside_fixtures(test)
|
||||
|
||||
|
||||
def _possible_configs_for_cls(cls, reasons=None, sparse=False):
|
||||
all_configs = set(config.Config.all_configs())
|
||||
|
||||
if cls.__unsupported_on__:
|
||||
spec = exclusions.db_spec(*cls.__unsupported_on__)
|
||||
for config_obj in list(all_configs):
|
||||
if spec(config_obj):
|
||||
all_configs.remove(config_obj)
|
||||
|
||||
if getattr(cls, "__only_on__", None):
|
||||
spec = exclusions.db_spec(*util.to_list(cls.__only_on__))
|
||||
for config_obj in list(all_configs):
|
||||
if not spec(config_obj):
|
||||
all_configs.remove(config_obj)
|
||||
|
||||
if getattr(cls, "__only_on_config__", None):
|
||||
all_configs.intersection_update([cls.__only_on_config__])
|
||||
|
||||
if hasattr(cls, "__requires__"):
|
||||
requirements = config.requirements
|
||||
for config_obj in list(all_configs):
|
||||
for requirement in cls.__requires__:
|
||||
check = getattr(requirements, requirement)
|
||||
|
||||
skip_reasons = check.matching_config_reasons(config_obj)
|
||||
if skip_reasons:
|
||||
all_configs.remove(config_obj)
|
||||
if reasons is not None:
|
||||
reasons.extend(skip_reasons)
|
||||
break
|
||||
|
||||
if hasattr(cls, "__prefer_requires__"):
|
||||
non_preferred = set()
|
||||
requirements = config.requirements
|
||||
for config_obj in list(all_configs):
|
||||
for requirement in cls.__prefer_requires__:
|
||||
check = getattr(requirements, requirement)
|
||||
|
||||
if not check.enabled_for_config(config_obj):
|
||||
non_preferred.add(config_obj)
|
||||
if all_configs.difference(non_preferred):
|
||||
all_configs.difference_update(non_preferred)
|
||||
|
||||
if sparse:
|
||||
# pick only one config from each base dialect
|
||||
# sorted so we get the same backend each time selecting the highest
|
||||
# server version info.
|
||||
per_dialect = {}
|
||||
for cfg in reversed(
|
||||
sorted(
|
||||
all_configs,
|
||||
key=lambda cfg: (
|
||||
cfg.db.name,
|
||||
cfg.db.driver,
|
||||
cfg.db.dialect.server_version_info,
|
||||
),
|
||||
)
|
||||
):
|
||||
db = cfg.db.name
|
||||
if db not in per_dialect:
|
||||
per_dialect[db] = cfg
|
||||
return per_dialect.values()
|
||||
|
||||
return all_configs
|
||||
|
||||
|
||||
def _do_skips(cls):
|
||||
reasons = []
|
||||
all_configs = _possible_configs_for_cls(cls, reasons)
|
||||
|
||||
if getattr(cls, "__skip_if__", False):
|
||||
for c in getattr(cls, "__skip_if__"):
|
||||
if c():
|
||||
config.skip_test(
|
||||
"'%s' skipped by %s" % (cls.__name__, c.__name__)
|
||||
)
|
||||
|
||||
if not all_configs:
|
||||
msg = "'%s.%s' unsupported on any DB implementation %s%s" % (
|
||||
cls.__module__,
|
||||
cls.__name__,
|
||||
", ".join(
|
||||
"'%s(%s)+%s'"
|
||||
% (
|
||||
config_obj.db.name,
|
||||
".".join(
|
||||
str(dig)
|
||||
for dig in exclusions._server_version(config_obj.db)
|
||||
),
|
||||
config_obj.db.driver,
|
||||
)
|
||||
for config_obj in config.Config.all_configs()
|
||||
),
|
||||
", ".join(reasons),
|
||||
)
|
||||
config.skip_test(msg)
|
||||
elif hasattr(cls, "__prefer_backends__"):
|
||||
non_preferred = set()
|
||||
spec = exclusions.db_spec(*util.to_list(cls.__prefer_backends__))
|
||||
for config_obj in all_configs:
|
||||
if not spec(config_obj):
|
||||
non_preferred.add(config_obj)
|
||||
if all_configs.difference(non_preferred):
|
||||
all_configs.difference_update(non_preferred)
|
||||
|
||||
if config._current not in all_configs:
|
||||
_setup_config(all_configs.pop(), cls)
|
||||
|
||||
|
||||
def _setup_config(config_obj, ctx):
|
||||
config._current.push(config_obj, testing)
|
||||
|
||||
|
||||
class FixtureFunctions(abc.ABC):
|
||||
@abc.abstractmethod
|
||||
def skip_test_exception(self, *arg, **kw):
|
||||
raise NotImplementedError()
|
||||
|
||||
@abc.abstractmethod
|
||||
def combinations(self, *args, **kw):
|
||||
raise NotImplementedError()
|
||||
|
||||
@abc.abstractmethod
|
||||
def param_ident(self, *args, **kw):
|
||||
raise NotImplementedError()
|
||||
|
||||
@abc.abstractmethod
|
||||
def fixture(self, *arg, **kw):
|
||||
raise NotImplementedError()
|
||||
|
||||
def get_current_test_name(self):
|
||||
raise NotImplementedError()
|
||||
|
||||
@abc.abstractmethod
|
||||
def mark_base_test_class(self) -> Any:
|
||||
raise NotImplementedError()
|
||||
|
||||
@abc.abstractproperty
|
||||
def add_to_marker(self):
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
_fixture_fn_class = None
|
||||
|
||||
|
||||
def set_fixture_functions(fixture_fn_class):
|
||||
global _fixture_fn_class
|
||||
_fixture_fn_class = fixture_fn_class
|
|
@ -0,0 +1,862 @@
|
|||
# testing/plugin/pytestplugin.py
|
||||
# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors
|
||||
# <see AUTHORS file>
|
||||
#
|
||||
# This module is part of SQLAlchemy and is released under
|
||||
# the MIT License: https://www.opensource.org/licenses/mit-license.php
|
||||
# mypy: ignore-errors
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import collections
|
||||
from functools import update_wrapper
|
||||
import inspect
|
||||
import itertools
|
||||
import operator
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
from typing import TYPE_CHECKING
|
||||
import uuid
|
||||
|
||||
import pytest
|
||||
|
||||
try:
|
||||
# installed by bootstrap.py
|
||||
if not TYPE_CHECKING:
|
||||
import sqla_plugin_base as plugin_base
|
||||
except ImportError:
|
||||
# assume we're a package, use traditional import
|
||||
from . import plugin_base
|
||||
|
||||
|
||||
def pytest_addoption(parser):
|
||||
group = parser.getgroup("sqlalchemy")
|
||||
|
||||
def make_option(name, **kw):
|
||||
callback_ = kw.pop("callback", None)
|
||||
if callback_:
|
||||
|
||||
class CallableAction(argparse.Action):
|
||||
def __call__(
|
||||
self, parser, namespace, values, option_string=None
|
||||
):
|
||||
callback_(option_string, values, parser)
|
||||
|
||||
kw["action"] = CallableAction
|
||||
|
||||
zeroarg_callback = kw.pop("zeroarg_callback", None)
|
||||
if zeroarg_callback:
|
||||
|
||||
class CallableAction(argparse.Action):
|
||||
def __init__(
|
||||
self,
|
||||
option_strings,
|
||||
dest,
|
||||
default=False,
|
||||
required=False,
|
||||
help=None, # noqa
|
||||
):
|
||||
super().__init__(
|
||||
option_strings=option_strings,
|
||||
dest=dest,
|
||||
nargs=0,
|
||||
const=True,
|
||||
default=default,
|
||||
required=required,
|
||||
help=help,
|
||||
)
|
||||
|
||||
def __call__(
|
||||
self, parser, namespace, values, option_string=None
|
||||
):
|
||||
zeroarg_callback(option_string, values, parser)
|
||||
|
||||
kw["action"] = CallableAction
|
||||
|
||||
group.addoption(name, **kw)
|
||||
|
||||
plugin_base.setup_options(make_option)
|
||||
|
||||
|
||||
def pytest_configure(config: pytest.Config):
|
||||
plugin_base.read_config(config.rootpath)
|
||||
if plugin_base.exclude_tags or plugin_base.include_tags:
|
||||
new_expr = " and ".join(
|
||||
list(plugin_base.include_tags)
|
||||
+ [f"not {tag}" for tag in plugin_base.exclude_tags]
|
||||
)
|
||||
|
||||
if config.option.markexpr:
|
||||
config.option.markexpr += f" and {new_expr}"
|
||||
else:
|
||||
config.option.markexpr = new_expr
|
||||
|
||||
if config.pluginmanager.hasplugin("xdist"):
|
||||
config.pluginmanager.register(XDistHooks())
|
||||
|
||||
if hasattr(config, "workerinput"):
|
||||
plugin_base.restore_important_follower_config(config.workerinput)
|
||||
plugin_base.configure_follower(config.workerinput["follower_ident"])
|
||||
else:
|
||||
if config.option.write_idents and os.path.exists(
|
||||
config.option.write_idents
|
||||
):
|
||||
os.remove(config.option.write_idents)
|
||||
|
||||
plugin_base.pre_begin(config.option)
|
||||
|
||||
plugin_base.set_coverage_flag(
|
||||
bool(getattr(config.option, "cov_source", False))
|
||||
)
|
||||
|
||||
plugin_base.set_fixture_functions(PytestFixtureFunctions)
|
||||
|
||||
if config.option.dump_pyannotate:
|
||||
global DUMP_PYANNOTATE
|
||||
DUMP_PYANNOTATE = True
|
||||
|
||||
|
||||
DUMP_PYANNOTATE = False
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def collect_types_fixture():
|
||||
if DUMP_PYANNOTATE:
|
||||
from pyannotate_runtime import collect_types
|
||||
|
||||
collect_types.start()
|
||||
yield
|
||||
if DUMP_PYANNOTATE:
|
||||
collect_types.stop()
|
||||
|
||||
|
||||
def _log_sqlalchemy_info(session):
|
||||
import sqlalchemy
|
||||
from sqlalchemy import __version__
|
||||
from sqlalchemy.util import has_compiled_ext
|
||||
from sqlalchemy.util._has_cy import _CYEXTENSION_MSG
|
||||
|
||||
greet = "sqlalchemy installation"
|
||||
site = "no user site" if sys.flags.no_user_site else "user site loaded"
|
||||
msgs = [
|
||||
f"SQLAlchemy {__version__} ({site})",
|
||||
f"Path: {sqlalchemy.__file__}",
|
||||
]
|
||||
|
||||
if has_compiled_ext():
|
||||
from sqlalchemy.cyextension import util
|
||||
|
||||
msgs.append(f"compiled extension enabled, e.g. {util.__file__} ")
|
||||
else:
|
||||
msgs.append(f"compiled extension not enabled; {_CYEXTENSION_MSG}")
|
||||
|
||||
pm = session.config.pluginmanager.get_plugin("terminalreporter")
|
||||
if pm:
|
||||
pm.write_sep("=", greet)
|
||||
for m in msgs:
|
||||
pm.write_line(m)
|
||||
else:
|
||||
# fancy pants reporter not found, fallback to plain print
|
||||
print("=" * 25, greet, "=" * 25)
|
||||
for m in msgs:
|
||||
print(m)
|
||||
|
||||
|
||||
def pytest_sessionstart(session):
|
||||
from sqlalchemy.testing import asyncio
|
||||
|
||||
_log_sqlalchemy_info(session)
|
||||
asyncio._assume_async(plugin_base.post_begin)
|
||||
|
||||
|
||||
def pytest_sessionfinish(session):
|
||||
from sqlalchemy.testing import asyncio
|
||||
|
||||
asyncio._maybe_async_provisioning(plugin_base.final_process_cleanup)
|
||||
|
||||
if session.config.option.dump_pyannotate:
|
||||
from pyannotate_runtime import collect_types
|
||||
|
||||
collect_types.dump_stats(session.config.option.dump_pyannotate)
|
||||
|
||||
|
||||
def pytest_collection_finish(session):
|
||||
if session.config.option.dump_pyannotate:
|
||||
from pyannotate_runtime import collect_types
|
||||
|
||||
lib_sqlalchemy = os.path.abspath("lib/sqlalchemy")
|
||||
|
||||
def _filter(filename):
|
||||
filename = os.path.normpath(os.path.abspath(filename))
|
||||
if "lib/sqlalchemy" not in os.path.commonpath(
|
||||
[filename, lib_sqlalchemy]
|
||||
):
|
||||
return None
|
||||
if "testing" in filename:
|
||||
return None
|
||||
|
||||
return filename
|
||||
|
||||
collect_types.init_types_collection(filter_filename=_filter)
|
||||
|
||||
|
||||
class XDistHooks:
|
||||
def pytest_configure_node(self, node):
|
||||
from sqlalchemy.testing import provision
|
||||
from sqlalchemy.testing import asyncio
|
||||
|
||||
# the master for each node fills workerinput dictionary
|
||||
# which pytest-xdist will transfer to the subprocess
|
||||
|
||||
plugin_base.memoize_important_follower_config(node.workerinput)
|
||||
|
||||
node.workerinput["follower_ident"] = "test_%s" % uuid.uuid4().hex[0:12]
|
||||
|
||||
asyncio._maybe_async_provisioning(
|
||||
provision.create_follower_db, node.workerinput["follower_ident"]
|
||||
)
|
||||
|
||||
def pytest_testnodedown(self, node, error):
|
||||
from sqlalchemy.testing import provision
|
||||
from sqlalchemy.testing import asyncio
|
||||
|
||||
asyncio._maybe_async_provisioning(
|
||||
provision.drop_follower_db, node.workerinput["follower_ident"]
|
||||
)
|
||||
|
||||
|
||||
def pytest_collection_modifyitems(session, config, items):
|
||||
# look for all those classes that specify __backend__ and
|
||||
# expand them out into per-database test cases.
|
||||
|
||||
# this is much easier to do within pytest_pycollect_makeitem, however
|
||||
# pytest is iterating through cls.__dict__ as makeitem is
|
||||
# called which causes a "dictionary changed size" error on py3k.
|
||||
# I'd submit a pullreq for them to turn it into a list first, but
|
||||
# it's to suit the rather odd use case here which is that we are adding
|
||||
# new classes to a module on the fly.
|
||||
|
||||
from sqlalchemy.testing import asyncio
|
||||
|
||||
rebuilt_items = collections.defaultdict(
|
||||
lambda: collections.defaultdict(list)
|
||||
)
|
||||
|
||||
items[:] = [
|
||||
item
|
||||
for item in items
|
||||
if item.getparent(pytest.Class) is not None
|
||||
and not item.getparent(pytest.Class).name.startswith("_")
|
||||
]
|
||||
|
||||
test_classes = {item.getparent(pytest.Class) for item in items}
|
||||
|
||||
def collect(element):
|
||||
for inst_or_fn in element.collect():
|
||||
if isinstance(inst_or_fn, pytest.Collector):
|
||||
yield from collect(inst_or_fn)
|
||||
else:
|
||||
yield inst_or_fn
|
||||
|
||||
def setup_test_classes():
|
||||
for test_class in test_classes:
|
||||
# transfer legacy __backend__ and __sparse_backend__ symbols
|
||||
# to be markers
|
||||
add_markers = set()
|
||||
if getattr(test_class.cls, "__backend__", False) or getattr(
|
||||
test_class.cls, "__only_on__", False
|
||||
):
|
||||
add_markers = {"backend"}
|
||||
elif getattr(test_class.cls, "__sparse_backend__", False):
|
||||
add_markers = {"sparse_backend"}
|
||||
else:
|
||||
add_markers = frozenset()
|
||||
|
||||
existing_markers = {
|
||||
mark.name for mark in test_class.iter_markers()
|
||||
}
|
||||
add_markers = add_markers - existing_markers
|
||||
all_markers = existing_markers.union(add_markers)
|
||||
|
||||
for marker in add_markers:
|
||||
test_class.add_marker(marker)
|
||||
|
||||
for sub_cls in plugin_base.generate_sub_tests(
|
||||
test_class.cls, test_class.module, all_markers
|
||||
):
|
||||
if sub_cls is not test_class.cls:
|
||||
per_cls_dict = rebuilt_items[test_class.cls]
|
||||
|
||||
module = test_class.getparent(pytest.Module)
|
||||
|
||||
new_cls = pytest.Class.from_parent(
|
||||
name=sub_cls.__name__, parent=module
|
||||
)
|
||||
for marker in add_markers:
|
||||
new_cls.add_marker(marker)
|
||||
|
||||
for fn in collect(new_cls):
|
||||
per_cls_dict[fn.name].append(fn)
|
||||
|
||||
# class requirements will sometimes need to access the DB to check
|
||||
# capabilities, so need to do this for async
|
||||
asyncio._maybe_async_provisioning(setup_test_classes)
|
||||
|
||||
newitems = []
|
||||
for item in items:
|
||||
cls_ = item.cls
|
||||
if cls_ in rebuilt_items:
|
||||
newitems.extend(rebuilt_items[cls_][item.name])
|
||||
else:
|
||||
newitems.append(item)
|
||||
|
||||
# seems like the functions attached to a test class aren't sorted already?
|
||||
# is that true and why's that? (when using unittest, they're sorted)
|
||||
items[:] = sorted(
|
||||
newitems,
|
||||
key=lambda item: (
|
||||
item.getparent(pytest.Module).name,
|
||||
item.getparent(pytest.Class).name,
|
||||
item.name,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def pytest_pycollect_makeitem(collector, name, obj):
|
||||
if inspect.isclass(obj) and plugin_base.want_class(name, obj):
|
||||
from sqlalchemy.testing import config
|
||||
|
||||
if config.any_async:
|
||||
obj = _apply_maybe_async(obj)
|
||||
|
||||
return [
|
||||
pytest.Class.from_parent(
|
||||
name=parametrize_cls.__name__, parent=collector
|
||||
)
|
||||
for parametrize_cls in _parametrize_cls(collector.module, obj)
|
||||
]
|
||||
elif (
|
||||
inspect.isfunction(obj)
|
||||
and collector.cls is not None
|
||||
and plugin_base.want_method(collector.cls, obj)
|
||||
):
|
||||
# None means, fall back to default logic, which includes
|
||||
# method-level parametrize
|
||||
return None
|
||||
else:
|
||||
# empty list means skip this item
|
||||
return []
|
||||
|
||||
|
||||
def _is_wrapped_coroutine_function(fn):
|
||||
while hasattr(fn, "__wrapped__"):
|
||||
fn = fn.__wrapped__
|
||||
|
||||
return inspect.iscoroutinefunction(fn)
|
||||
|
||||
|
||||
def _apply_maybe_async(obj, recurse=True):
|
||||
from sqlalchemy.testing import asyncio
|
||||
|
||||
for name, value in vars(obj).items():
|
||||
if (
|
||||
(callable(value) or isinstance(value, classmethod))
|
||||
and not getattr(value, "_maybe_async_applied", False)
|
||||
and (name.startswith("test_"))
|
||||
and not _is_wrapped_coroutine_function(value)
|
||||
):
|
||||
is_classmethod = False
|
||||
if isinstance(value, classmethod):
|
||||
value = value.__func__
|
||||
is_classmethod = True
|
||||
|
||||
@_pytest_fn_decorator
|
||||
def make_async(fn, *args, **kwargs):
|
||||
return asyncio._maybe_async(fn, *args, **kwargs)
|
||||
|
||||
do_async = make_async(value)
|
||||
if is_classmethod:
|
||||
do_async = classmethod(do_async)
|
||||
do_async._maybe_async_applied = True
|
||||
|
||||
setattr(obj, name, do_async)
|
||||
if recurse:
|
||||
for cls in obj.mro()[1:]:
|
||||
if cls != object:
|
||||
_apply_maybe_async(cls, False)
|
||||
return obj
|
||||
|
||||
|
||||
def _parametrize_cls(module, cls):
|
||||
"""implement a class-based version of pytest parametrize."""
|
||||
|
||||
if "_sa_parametrize" not in cls.__dict__:
|
||||
return [cls]
|
||||
|
||||
_sa_parametrize = cls._sa_parametrize
|
||||
classes = []
|
||||
for full_param_set in itertools.product(
|
||||
*[params for argname, params in _sa_parametrize]
|
||||
):
|
||||
cls_variables = {}
|
||||
|
||||
for argname, param in zip(
|
||||
[_sa_param[0] for _sa_param in _sa_parametrize], full_param_set
|
||||
):
|
||||
if not argname:
|
||||
raise TypeError("need argnames for class-based combinations")
|
||||
argname_split = re.split(r",\s*", argname)
|
||||
for arg, val in zip(argname_split, param.values):
|
||||
cls_variables[arg] = val
|
||||
parametrized_name = "_".join(
|
||||
re.sub(r"\W", "", token)
|
||||
for param in full_param_set
|
||||
for token in param.id.split("-")
|
||||
)
|
||||
name = "%s_%s" % (cls.__name__, parametrized_name)
|
||||
newcls = type.__new__(type, name, (cls,), cls_variables)
|
||||
setattr(module, name, newcls)
|
||||
classes.append(newcls)
|
||||
return classes
|
||||
|
||||
|
||||
_current_class = None
|
||||
|
||||
|
||||
def pytest_runtest_setup(item):
|
||||
from sqlalchemy.testing import asyncio
|
||||
|
||||
# pytest_runtest_setup runs *before* pytest fixtures with scope="class".
|
||||
# plugin_base.start_test_class_outside_fixtures may opt to raise SkipTest
|
||||
# for the whole class and has to run things that are across all current
|
||||
# databases, so we run this outside of the pytest fixture system altogether
|
||||
# and ensure asyncio greenlet if any engines are async
|
||||
|
||||
global _current_class
|
||||
|
||||
if isinstance(item, pytest.Function) and _current_class is None:
|
||||
asyncio._maybe_async_provisioning(
|
||||
plugin_base.start_test_class_outside_fixtures,
|
||||
item.cls,
|
||||
)
|
||||
_current_class = item.getparent(pytest.Class)
|
||||
|
||||
|
||||
@pytest.hookimpl(hookwrapper=True)
|
||||
def pytest_runtest_teardown(item, nextitem):
|
||||
# runs inside of pytest function fixture scope
|
||||
# after test function runs
|
||||
|
||||
from sqlalchemy.testing import asyncio
|
||||
|
||||
asyncio._maybe_async(plugin_base.after_test, item)
|
||||
|
||||
yield
|
||||
# this is now after all the fixture teardown have run, the class can be
|
||||
# finalized. Since pytest v7 this finalizer can no longer be added in
|
||||
# pytest_runtest_setup since the class has not yet been setup at that
|
||||
# time.
|
||||
# See https://github.com/pytest-dev/pytest/issues/9343
|
||||
global _current_class, _current_report
|
||||
|
||||
if _current_class is not None and (
|
||||
# last test or a new class
|
||||
nextitem is None
|
||||
or nextitem.getparent(pytest.Class) is not _current_class
|
||||
):
|
||||
_current_class = None
|
||||
|
||||
try:
|
||||
asyncio._maybe_async_provisioning(
|
||||
plugin_base.stop_test_class_outside_fixtures, item.cls
|
||||
)
|
||||
except Exception as e:
|
||||
# in case of an exception during teardown attach the original
|
||||
# error to the exception message, otherwise it will get lost
|
||||
if _current_report.failed:
|
||||
if not e.args:
|
||||
e.args = (
|
||||
"__Original test failure__:\n"
|
||||
+ _current_report.longreprtext,
|
||||
)
|
||||
elif e.args[-1] and isinstance(e.args[-1], str):
|
||||
args = list(e.args)
|
||||
args[-1] += (
|
||||
"\n__Original test failure__:\n"
|
||||
+ _current_report.longreprtext
|
||||
)
|
||||
e.args = tuple(args)
|
||||
else:
|
||||
e.args += (
|
||||
"__Original test failure__",
|
||||
_current_report.longreprtext,
|
||||
)
|
||||
raise
|
||||
finally:
|
||||
_current_report = None
|
||||
|
||||
|
||||
def pytest_runtest_call(item):
|
||||
# runs inside of pytest function fixture scope
|
||||
# before test function runs
|
||||
|
||||
from sqlalchemy.testing import asyncio
|
||||
|
||||
asyncio._maybe_async(
|
||||
plugin_base.before_test,
|
||||
item,
|
||||
item.module.__name__,
|
||||
item.cls,
|
||||
item.name,
|
||||
)
|
||||
|
||||
|
||||
_current_report = None
|
||||
|
||||
|
||||
def pytest_runtest_logreport(report):
|
||||
global _current_report
|
||||
if report.when == "call":
|
||||
_current_report = report
|
||||
|
||||
|
||||
@pytest.fixture(scope="class")
|
||||
def setup_class_methods(request):
|
||||
from sqlalchemy.testing import asyncio
|
||||
|
||||
cls = request.cls
|
||||
|
||||
if hasattr(cls, "setup_test_class"):
|
||||
asyncio._maybe_async(cls.setup_test_class)
|
||||
|
||||
yield
|
||||
|
||||
if hasattr(cls, "teardown_test_class"):
|
||||
asyncio._maybe_async(cls.teardown_test_class)
|
||||
|
||||
asyncio._maybe_async(plugin_base.stop_test_class, cls)
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def setup_test_methods(request):
|
||||
from sqlalchemy.testing import asyncio
|
||||
|
||||
# called for each test
|
||||
|
||||
self = request.instance
|
||||
|
||||
# before this fixture runs:
|
||||
|
||||
# 1. function level "autouse" fixtures under py3k (examples: TablesTest
|
||||
# define tables / data, MappedTest define tables / mappers / data)
|
||||
|
||||
# 2. was for p2k. no longer applies
|
||||
|
||||
# 3. run outer xdist-style setup
|
||||
if hasattr(self, "setup_test"):
|
||||
asyncio._maybe_async(self.setup_test)
|
||||
|
||||
# alembic test suite is using setUp and tearDown
|
||||
# xdist methods; support these in the test suite
|
||||
# for the near term
|
||||
if hasattr(self, "setUp"):
|
||||
asyncio._maybe_async(self.setUp)
|
||||
|
||||
# inside the yield:
|
||||
# 4. function level fixtures defined on test functions themselves,
|
||||
# e.g. "connection", "metadata" run next
|
||||
|
||||
# 5. pytest hook pytest_runtest_call then runs
|
||||
|
||||
# 6. test itself runs
|
||||
|
||||
yield
|
||||
|
||||
# yield finishes:
|
||||
|
||||
# 7. function level fixtures defined on test functions
|
||||
# themselves, e.g. "connection" rolls back the transaction, "metadata"
|
||||
# emits drop all
|
||||
|
||||
# 8. pytest hook pytest_runtest_teardown hook runs, this is associated
|
||||
# with fixtures close all sessions, provisioning.stop_test_class(),
|
||||
# engines.testing_reaper -> ensure all connection pool connections
|
||||
# are returned, engines created by testing_engine that aren't the
|
||||
# config engine are disposed
|
||||
|
||||
asyncio._maybe_async(plugin_base.after_test_fixtures, self)
|
||||
|
||||
# 10. run xdist-style teardown
|
||||
if hasattr(self, "tearDown"):
|
||||
asyncio._maybe_async(self.tearDown)
|
||||
|
||||
if hasattr(self, "teardown_test"):
|
||||
asyncio._maybe_async(self.teardown_test)
|
||||
|
||||
# 11. was for p2k. no longer applies
|
||||
|
||||
# 12. function level "autouse" fixtures under py3k (examples: TablesTest /
|
||||
# MappedTest delete table data, possibly drop tables and clear mappers
|
||||
# depending on the flags defined by the test class)
|
||||
|
||||
|
||||
def _pytest_fn_decorator(target):
|
||||
"""Port of langhelpers.decorator with pytest-specific tricks."""
|
||||
|
||||
from sqlalchemy.util.langhelpers import format_argspec_plus
|
||||
from sqlalchemy.util.compat import inspect_getfullargspec
|
||||
|
||||
def _exec_code_in_env(code, env, fn_name):
|
||||
# note this is affected by "from __future__ import annotations" at
|
||||
# the top; exec'ed code will use non-evaluated annotations
|
||||
# which allows us to be more flexible with code rendering
|
||||
# in format_argpsec_plus()
|
||||
exec(code, env)
|
||||
return env[fn_name]
|
||||
|
||||
def decorate(fn, add_positional_parameters=()):
|
||||
spec = inspect_getfullargspec(fn)
|
||||
if add_positional_parameters:
|
||||
spec.args.extend(add_positional_parameters)
|
||||
|
||||
metadata = dict(
|
||||
__target_fn="__target_fn", __orig_fn="__orig_fn", name=fn.__name__
|
||||
)
|
||||
metadata.update(format_argspec_plus(spec, grouped=False))
|
||||
code = (
|
||||
"""\
|
||||
def %(name)s%(grouped_args)s:
|
||||
return %(__target_fn)s(%(__orig_fn)s, %(apply_kw)s)
|
||||
"""
|
||||
% metadata
|
||||
)
|
||||
decorated = _exec_code_in_env(
|
||||
code, {"__target_fn": target, "__orig_fn": fn}, fn.__name__
|
||||
)
|
||||
if not add_positional_parameters:
|
||||
decorated.__defaults__ = getattr(fn, "__func__", fn).__defaults__
|
||||
decorated.__wrapped__ = fn
|
||||
return update_wrapper(decorated, fn)
|
||||
else:
|
||||
# this is the pytest hacky part. don't do a full update wrapper
|
||||
# because pytest is really being sneaky about finding the args
|
||||
# for the wrapped function
|
||||
decorated.__module__ = fn.__module__
|
||||
decorated.__name__ = fn.__name__
|
||||
if hasattr(fn, "pytestmark"):
|
||||
decorated.pytestmark = fn.pytestmark
|
||||
return decorated
|
||||
|
||||
return decorate
|
||||
|
||||
|
||||
class PytestFixtureFunctions(plugin_base.FixtureFunctions):
|
||||
def skip_test_exception(self, *arg, **kw):
|
||||
return pytest.skip.Exception(*arg, **kw)
|
||||
|
||||
@property
|
||||
def add_to_marker(self):
|
||||
return pytest.mark
|
||||
|
||||
def mark_base_test_class(self):
|
||||
return pytest.mark.usefixtures(
|
||||
"setup_class_methods", "setup_test_methods"
|
||||
)
|
||||
|
||||
_combination_id_fns = {
|
||||
"i": lambda obj: obj,
|
||||
"r": repr,
|
||||
"s": str,
|
||||
"n": lambda obj: (
|
||||
obj.__name__ if hasattr(obj, "__name__") else type(obj).__name__
|
||||
),
|
||||
}
|
||||
|
||||
def combinations(self, *arg_sets, **kw):
|
||||
"""Facade for pytest.mark.parametrize.
|
||||
|
||||
Automatically derives argument names from the callable which in our
|
||||
case is always a method on a class with positional arguments.
|
||||
|
||||
ids for parameter sets are derived using an optional template.
|
||||
|
||||
"""
|
||||
from sqlalchemy.testing import exclusions
|
||||
|
||||
if len(arg_sets) == 1 and hasattr(arg_sets[0], "__next__"):
|
||||
arg_sets = list(arg_sets[0])
|
||||
|
||||
argnames = kw.pop("argnames", None)
|
||||
|
||||
def _filter_exclusions(args):
|
||||
result = []
|
||||
gathered_exclusions = []
|
||||
for a in args:
|
||||
if isinstance(a, exclusions.compound):
|
||||
gathered_exclusions.append(a)
|
||||
else:
|
||||
result.append(a)
|
||||
|
||||
return result, gathered_exclusions
|
||||
|
||||
id_ = kw.pop("id_", None)
|
||||
|
||||
tobuild_pytest_params = []
|
||||
has_exclusions = False
|
||||
if id_:
|
||||
_combination_id_fns = self._combination_id_fns
|
||||
|
||||
# because itemgetter is not consistent for one argument vs.
|
||||
# multiple, make it multiple in all cases and use a slice
|
||||
# to omit the first argument
|
||||
_arg_getter = operator.itemgetter(
|
||||
0,
|
||||
*[
|
||||
idx
|
||||
for idx, char in enumerate(id_)
|
||||
if char in ("n", "r", "s", "a")
|
||||
],
|
||||
)
|
||||
fns = [
|
||||
(operator.itemgetter(idx), _combination_id_fns[char])
|
||||
for idx, char in enumerate(id_)
|
||||
if char in _combination_id_fns
|
||||
]
|
||||
|
||||
for arg in arg_sets:
|
||||
if not isinstance(arg, tuple):
|
||||
arg = (arg,)
|
||||
|
||||
fn_params, param_exclusions = _filter_exclusions(arg)
|
||||
|
||||
parameters = _arg_getter(fn_params)[1:]
|
||||
|
||||
if param_exclusions:
|
||||
has_exclusions = True
|
||||
|
||||
tobuild_pytest_params.append(
|
||||
(
|
||||
parameters,
|
||||
param_exclusions,
|
||||
"-".join(
|
||||
comb_fn(getter(arg)) for getter, comb_fn in fns
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
else:
|
||||
for arg in arg_sets:
|
||||
if not isinstance(arg, tuple):
|
||||
arg = (arg,)
|
||||
|
||||
fn_params, param_exclusions = _filter_exclusions(arg)
|
||||
|
||||
if param_exclusions:
|
||||
has_exclusions = True
|
||||
|
||||
tobuild_pytest_params.append(
|
||||
(fn_params, param_exclusions, None)
|
||||
)
|
||||
|
||||
pytest_params = []
|
||||
for parameters, param_exclusions, id_ in tobuild_pytest_params:
|
||||
if has_exclusions:
|
||||
parameters += (param_exclusions,)
|
||||
|
||||
param = pytest.param(*parameters, id=id_)
|
||||
pytest_params.append(param)
|
||||
|
||||
def decorate(fn):
|
||||
if inspect.isclass(fn):
|
||||
if has_exclusions:
|
||||
raise NotImplementedError(
|
||||
"exclusions not supported for class level combinations"
|
||||
)
|
||||
if "_sa_parametrize" not in fn.__dict__:
|
||||
fn._sa_parametrize = []
|
||||
fn._sa_parametrize.append((argnames, pytest_params))
|
||||
return fn
|
||||
else:
|
||||
_fn_argnames = inspect.getfullargspec(fn).args[1:]
|
||||
if argnames is None:
|
||||
_argnames = _fn_argnames
|
||||
else:
|
||||
_argnames = re.split(r", *", argnames)
|
||||
|
||||
if has_exclusions:
|
||||
existing_exl = sum(
|
||||
1 for n in _fn_argnames if n.startswith("_exclusions")
|
||||
)
|
||||
current_exclusion_name = f"_exclusions_{existing_exl}"
|
||||
_argnames += [current_exclusion_name]
|
||||
|
||||
@_pytest_fn_decorator
|
||||
def check_exclusions(fn, *args, **kw):
|
||||
_exclusions = args[-1]
|
||||
if _exclusions:
|
||||
exlu = exclusions.compound().add(*_exclusions)
|
||||
fn = exlu(fn)
|
||||
return fn(*args[:-1], **kw)
|
||||
|
||||
fn = check_exclusions(
|
||||
fn, add_positional_parameters=(current_exclusion_name,)
|
||||
)
|
||||
|
||||
return pytest.mark.parametrize(_argnames, pytest_params)(fn)
|
||||
|
||||
return decorate
|
||||
|
||||
def param_ident(self, *parameters):
|
||||
ident = parameters[0]
|
||||
return pytest.param(*parameters[1:], id=ident)
|
||||
|
||||
def fixture(self, *arg, **kw):
|
||||
from sqlalchemy.testing import config
|
||||
from sqlalchemy.testing import asyncio
|
||||
|
||||
# wrapping pytest.fixture function. determine if
|
||||
# decorator was called as @fixture or @fixture().
|
||||
if len(arg) > 0 and callable(arg[0]):
|
||||
# was called as @fixture(), we have the function to wrap.
|
||||
fn = arg[0]
|
||||
arg = arg[1:]
|
||||
else:
|
||||
# was called as @fixture, don't have the function yet.
|
||||
fn = None
|
||||
|
||||
# create a pytest.fixture marker. because the fn is not being
|
||||
# passed, this is always a pytest.FixtureFunctionMarker()
|
||||
# object (or whatever pytest is calling it when you read this)
|
||||
# that is waiting for a function.
|
||||
fixture = pytest.fixture(*arg, **kw)
|
||||
|
||||
# now apply wrappers to the function, including fixture itself
|
||||
|
||||
def wrap(fn):
|
||||
if config.any_async:
|
||||
fn = asyncio._maybe_async_wrapper(fn)
|
||||
# other wrappers may be added here
|
||||
|
||||
# now apply FixtureFunctionMarker
|
||||
fn = fixture(fn)
|
||||
|
||||
return fn
|
||||
|
||||
if fn:
|
||||
return wrap(fn)
|
||||
else:
|
||||
return wrap
|
||||
|
||||
def get_current_test_name(self):
|
||||
return os.environ.get("PYTEST_CURRENT_TEST")
|
||||
|
||||
def async_test(self, fn):
|
||||
from sqlalchemy.testing import asyncio
|
||||
|
||||
@_pytest_fn_decorator
|
||||
def decorate(fn, *args, **kwargs):
|
||||
asyncio._run_coroutine_function(fn, *args, **kwargs)
|
||||
|
||||
return decorate(fn)
|
|
@ -0,0 +1,324 @@
|
|||
# testing/profiling.py
|
||||
# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors
|
||||
# <see AUTHORS file>
|
||||
#
|
||||
# This module is part of SQLAlchemy and is released under
|
||||
# the MIT License: https://www.opensource.org/licenses/mit-license.php
|
||||
# mypy: ignore-errors
|
||||
|
||||
|
||||
"""Profiling support for unit and performance tests.
|
||||
|
||||
These are special purpose profiling methods which operate
|
||||
in a more fine-grained way than nose's profiling plugin.
|
||||
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import collections
|
||||
import contextlib
|
||||
import os
|
||||
import platform
|
||||
import pstats
|
||||
import re
|
||||
import sys
|
||||
|
||||
from . import config
|
||||
from .util import gc_collect
|
||||
from ..util import has_compiled_ext
|
||||
|
||||
|
||||
try:
|
||||
import cProfile
|
||||
except ImportError:
|
||||
cProfile = None
|
||||
|
||||
_profile_stats = None
|
||||
"""global ProfileStatsFileInstance.
|
||||
|
||||
plugin_base assigns this at the start of all tests.
|
||||
|
||||
"""
|
||||
|
||||
|
||||
_current_test = None
|
||||
"""String id of current test.
|
||||
|
||||
plugin_base assigns this at the start of each test using
|
||||
_start_current_test.
|
||||
|
||||
"""
|
||||
|
||||
|
||||
def _start_current_test(id_):
|
||||
global _current_test
|
||||
_current_test = id_
|
||||
|
||||
if _profile_stats.force_write:
|
||||
_profile_stats.reset_count()
|
||||
|
||||
|
||||
class ProfileStatsFile:
|
||||
"""Store per-platform/fn profiling results in a file.
|
||||
|
||||
There was no json module available when this was written, but now
|
||||
the file format which is very deterministically line oriented is kind of
|
||||
handy in any case for diffs and merges.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, filename, sort="cumulative", dump=None):
|
||||
self.force_write = (
|
||||
config.options is not None and config.options.force_write_profiles
|
||||
)
|
||||
self.write = self.force_write or (
|
||||
config.options is not None and config.options.write_profiles
|
||||
)
|
||||
self.fname = os.path.abspath(filename)
|
||||
self.short_fname = os.path.split(self.fname)[-1]
|
||||
self.data = collections.defaultdict(
|
||||
lambda: collections.defaultdict(dict)
|
||||
)
|
||||
self.dump = dump
|
||||
self.sort = sort
|
||||
self._read()
|
||||
if self.write:
|
||||
# rewrite for the case where features changed,
|
||||
# etc.
|
||||
self._write()
|
||||
|
||||
@property
|
||||
def platform_key(self):
|
||||
dbapi_key = config.db.name + "_" + config.db.driver
|
||||
|
||||
if config.db.name == "sqlite" and config.db.dialect._is_url_file_db(
|
||||
config.db.url
|
||||
):
|
||||
dbapi_key += "_file"
|
||||
|
||||
# keep it at 2.7, 3.1, 3.2, etc. for now.
|
||||
py_version = ".".join([str(v) for v in sys.version_info[0:2]])
|
||||
|
||||
platform_tokens = [
|
||||
platform.machine(),
|
||||
platform.system().lower(),
|
||||
platform.python_implementation().lower(),
|
||||
py_version,
|
||||
dbapi_key,
|
||||
]
|
||||
|
||||
platform_tokens.append("dbapiunicode")
|
||||
_has_cext = has_compiled_ext()
|
||||
platform_tokens.append(_has_cext and "cextensions" or "nocextensions")
|
||||
return "_".join(platform_tokens)
|
||||
|
||||
def has_stats(self):
|
||||
test_key = _current_test
|
||||
return (
|
||||
test_key in self.data and self.platform_key in self.data[test_key]
|
||||
)
|
||||
|
||||
def result(self, callcount):
|
||||
test_key = _current_test
|
||||
per_fn = self.data[test_key]
|
||||
per_platform = per_fn[self.platform_key]
|
||||
|
||||
if "counts" not in per_platform:
|
||||
per_platform["counts"] = counts = []
|
||||
else:
|
||||
counts = per_platform["counts"]
|
||||
|
||||
if "current_count" not in per_platform:
|
||||
per_platform["current_count"] = current_count = 0
|
||||
else:
|
||||
current_count = per_platform["current_count"]
|
||||
|
||||
has_count = len(counts) > current_count
|
||||
|
||||
if not has_count:
|
||||
counts.append(callcount)
|
||||
if self.write:
|
||||
self._write()
|
||||
result = None
|
||||
else:
|
||||
result = per_platform["lineno"], counts[current_count]
|
||||
per_platform["current_count"] += 1
|
||||
return result
|
||||
|
||||
def reset_count(self):
|
||||
test_key = _current_test
|
||||
# since self.data is a defaultdict, don't access a key
|
||||
# if we don't know it's there first.
|
||||
if test_key not in self.data:
|
||||
return
|
||||
per_fn = self.data[test_key]
|
||||
if self.platform_key not in per_fn:
|
||||
return
|
||||
per_platform = per_fn[self.platform_key]
|
||||
if "counts" in per_platform:
|
||||
per_platform["counts"][:] = []
|
||||
|
||||
def replace(self, callcount):
|
||||
test_key = _current_test
|
||||
per_fn = self.data[test_key]
|
||||
per_platform = per_fn[self.platform_key]
|
||||
counts = per_platform["counts"]
|
||||
current_count = per_platform["current_count"]
|
||||
if current_count < len(counts):
|
||||
counts[current_count - 1] = callcount
|
||||
else:
|
||||
counts[-1] = callcount
|
||||
if self.write:
|
||||
self._write()
|
||||
|
||||
def _header(self):
|
||||
return (
|
||||
"# %s\n"
|
||||
"# This file is written out on a per-environment basis.\n"
|
||||
"# For each test in aaa_profiling, the corresponding "
|
||||
"function and \n"
|
||||
"# environment is located within this file. "
|
||||
"If it doesn't exist,\n"
|
||||
"# the test is skipped.\n"
|
||||
"# If a callcount does exist, it is compared "
|
||||
"to what we received. \n"
|
||||
"# assertions are raised if the counts do not match.\n"
|
||||
"# \n"
|
||||
"# To add a new callcount test, apply the function_call_count \n"
|
||||
"# decorator and re-run the tests using the --write-profiles \n"
|
||||
"# option - this file will be rewritten including the new count.\n"
|
||||
"# \n"
|
||||
) % (self.fname)
|
||||
|
||||
def _read(self):
|
||||
try:
|
||||
profile_f = open(self.fname)
|
||||
except OSError:
|
||||
return
|
||||
for lineno, line in enumerate(profile_f):
|
||||
line = line.strip()
|
||||
if not line or line.startswith("#"):
|
||||
continue
|
||||
|
||||
test_key, platform_key, counts = line.split()
|
||||
per_fn = self.data[test_key]
|
||||
per_platform = per_fn[platform_key]
|
||||
c = [int(count) for count in counts.split(",")]
|
||||
per_platform["counts"] = c
|
||||
per_platform["lineno"] = lineno + 1
|
||||
per_platform["current_count"] = 0
|
||||
profile_f.close()
|
||||
|
||||
def _write(self):
|
||||
print("Writing profile file %s" % self.fname)
|
||||
profile_f = open(self.fname, "w")
|
||||
profile_f.write(self._header())
|
||||
for test_key in sorted(self.data):
|
||||
per_fn = self.data[test_key]
|
||||
profile_f.write("\n# TEST: %s\n\n" % test_key)
|
||||
for platform_key in sorted(per_fn):
|
||||
per_platform = per_fn[platform_key]
|
||||
c = ",".join(str(count) for count in per_platform["counts"])
|
||||
profile_f.write("%s %s %s\n" % (test_key, platform_key, c))
|
||||
profile_f.close()
|
||||
|
||||
|
||||
def function_call_count(variance=0.05, times=1, warmup=0):
|
||||
"""Assert a target for a test case's function call count.
|
||||
|
||||
The main purpose of this assertion is to detect changes in
|
||||
callcounts for various functions - the actual number is not as important.
|
||||
Callcounts are stored in a file keyed to Python version and OS platform
|
||||
information. This file is generated automatically for new tests,
|
||||
and versioned so that unexpected changes in callcounts will be detected.
|
||||
|
||||
"""
|
||||
|
||||
# use signature-rewriting decorator function so that pytest fixtures
|
||||
# still work on py27. In Py3, update_wrapper() alone is good enough,
|
||||
# likely due to the introduction of __signature__.
|
||||
|
||||
from sqlalchemy.util import decorator
|
||||
|
||||
@decorator
|
||||
def wrap(fn, *args, **kw):
|
||||
for warm in range(warmup):
|
||||
fn(*args, **kw)
|
||||
|
||||
timerange = range(times)
|
||||
with count_functions(variance=variance):
|
||||
for time in timerange:
|
||||
rv = fn(*args, **kw)
|
||||
return rv
|
||||
|
||||
return wrap
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def count_functions(variance=0.05):
|
||||
if cProfile is None:
|
||||
raise config._skip_test_exception("cProfile is not installed")
|
||||
|
||||
if not _profile_stats.has_stats() and not _profile_stats.write:
|
||||
config.skip_test(
|
||||
"No profiling stats available on this "
|
||||
"platform for this function. Run tests with "
|
||||
"--write-profiles to add statistics to %s for "
|
||||
"this platform." % _profile_stats.short_fname
|
||||
)
|
||||
|
||||
gc_collect()
|
||||
|
||||
pr = cProfile.Profile()
|
||||
pr.enable()
|
||||
# began = time.time()
|
||||
yield
|
||||
# ended = time.time()
|
||||
pr.disable()
|
||||
|
||||
# s = StringIO()
|
||||
stats = pstats.Stats(pr, stream=sys.stdout)
|
||||
|
||||
# timespent = ended - began
|
||||
callcount = stats.total_calls
|
||||
|
||||
expected = _profile_stats.result(callcount)
|
||||
|
||||
if expected is None:
|
||||
expected_count = None
|
||||
else:
|
||||
line_no, expected_count = expected
|
||||
|
||||
print("Pstats calls: %d Expected %s" % (callcount, expected_count))
|
||||
stats.sort_stats(*re.split(r"[, ]", _profile_stats.sort))
|
||||
stats.print_stats()
|
||||
if _profile_stats.dump:
|
||||
base, ext = os.path.splitext(_profile_stats.dump)
|
||||
test_name = _current_test.split(".")[-1]
|
||||
dumpfile = "%s_%s%s" % (base, test_name, ext or ".profile")
|
||||
stats.dump_stats(dumpfile)
|
||||
print("Dumped stats to file %s" % dumpfile)
|
||||
# stats.print_callers()
|
||||
if _profile_stats.force_write:
|
||||
_profile_stats.replace(callcount)
|
||||
elif expected_count:
|
||||
deviance = int(callcount * variance)
|
||||
failed = abs(callcount - expected_count) > deviance
|
||||
|
||||
if failed:
|
||||
if _profile_stats.write:
|
||||
_profile_stats.replace(callcount)
|
||||
else:
|
||||
raise AssertionError(
|
||||
"Adjusted function call count %s not within %s%% "
|
||||
"of expected %s, platform %s. Rerun with "
|
||||
"--write-profiles to "
|
||||
"regenerate this callcount."
|
||||
% (
|
||||
callcount,
|
||||
(variance * 100),
|
||||
expected_count,
|
||||
_profile_stats.platform_key,
|
||||
)
|
||||
)
|
|
@ -0,0 +1,496 @@
|
|||
# testing/provision.py
|
||||
# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors
|
||||
# <see AUTHORS file>
|
||||
#
|
||||
# This module is part of SQLAlchemy and is released under
|
||||
# the MIT License: https://www.opensource.org/licenses/mit-license.php
|
||||
# mypy: ignore-errors
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import collections
|
||||
import logging
|
||||
|
||||
from . import config
|
||||
from . import engines
|
||||
from . import util
|
||||
from .. import exc
|
||||
from .. import inspect
|
||||
from ..engine import url as sa_url
|
||||
from ..sql import ddl
|
||||
from ..sql import schema
|
||||
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
FOLLOWER_IDENT = None
|
||||
|
||||
|
||||
class register:
|
||||
def __init__(self, decorator=None):
|
||||
self.fns = {}
|
||||
self.decorator = decorator
|
||||
|
||||
@classmethod
|
||||
def init(cls, fn):
|
||||
return register().for_db("*")(fn)
|
||||
|
||||
@classmethod
|
||||
def init_decorator(cls, decorator):
|
||||
return register(decorator).for_db("*")
|
||||
|
||||
def for_db(self, *dbnames):
|
||||
def decorate(fn):
|
||||
if self.decorator:
|
||||
fn = self.decorator(fn)
|
||||
for dbname in dbnames:
|
||||
self.fns[dbname] = fn
|
||||
return self
|
||||
|
||||
return decorate
|
||||
|
||||
def __call__(self, cfg, *arg, **kw):
|
||||
if isinstance(cfg, str):
|
||||
url = sa_url.make_url(cfg)
|
||||
elif isinstance(cfg, sa_url.URL):
|
||||
url = cfg
|
||||
else:
|
||||
url = cfg.db.url
|
||||
backend = url.get_backend_name()
|
||||
if backend in self.fns:
|
||||
return self.fns[backend](cfg, *arg, **kw)
|
||||
else:
|
||||
return self.fns["*"](cfg, *arg, **kw)
|
||||
|
||||
|
||||
def create_follower_db(follower_ident):
|
||||
for cfg in _configs_for_db_operation():
|
||||
log.info("CREATE database %s, URI %r", follower_ident, cfg.db.url)
|
||||
create_db(cfg, cfg.db, follower_ident)
|
||||
|
||||
|
||||
def setup_config(db_url, options, file_config, follower_ident):
|
||||
# load the dialect, which should also have it set up its provision
|
||||
# hooks
|
||||
|
||||
dialect = sa_url.make_url(db_url).get_dialect()
|
||||
|
||||
dialect.load_provisioning()
|
||||
|
||||
if follower_ident:
|
||||
db_url = follower_url_from_main(db_url, follower_ident)
|
||||
db_opts = {}
|
||||
update_db_opts(db_url, db_opts, options)
|
||||
db_opts["scope"] = "global"
|
||||
eng = engines.testing_engine(db_url, db_opts)
|
||||
post_configure_engine(db_url, eng, follower_ident)
|
||||
eng.connect().close()
|
||||
|
||||
cfg = config.Config.register(eng, db_opts, options, file_config)
|
||||
|
||||
# a symbolic name that tests can use if they need to disambiguate
|
||||
# names across databases
|
||||
if follower_ident:
|
||||
config.ident = follower_ident
|
||||
|
||||
if follower_ident:
|
||||
configure_follower(cfg, follower_ident)
|
||||
return cfg
|
||||
|
||||
|
||||
def drop_follower_db(follower_ident):
|
||||
for cfg in _configs_for_db_operation():
|
||||
log.info("DROP database %s, URI %r", follower_ident, cfg.db.url)
|
||||
drop_db(cfg, cfg.db, follower_ident)
|
||||
|
||||
|
||||
def generate_db_urls(db_urls, extra_drivers):
|
||||
"""Generate a set of URLs to test given configured URLs plus additional
|
||||
driver names.
|
||||
|
||||
Given::
|
||||
|
||||
--dburi postgresql://db1 \
|
||||
--dburi postgresql://db2 \
|
||||
--dburi postgresql://db2 \
|
||||
--dbdriver=psycopg2 --dbdriver=asyncpg?async_fallback=true
|
||||
|
||||
Noting that the default postgresql driver is psycopg2, the output
|
||||
would be::
|
||||
|
||||
postgresql+psycopg2://db1
|
||||
postgresql+asyncpg://db1
|
||||
postgresql+psycopg2://db2
|
||||
postgresql+psycopg2://db3
|
||||
|
||||
That is, for the driver in a --dburi, we want to keep that and use that
|
||||
driver for each URL it's part of . For a driver that is only
|
||||
in --dbdrivers, we want to use it just once for one of the URLs.
|
||||
for a driver that is both coming from --dburi as well as --dbdrivers,
|
||||
we want to keep it in that dburi.
|
||||
|
||||
Driver specific query options can be specified by added them to the
|
||||
driver name. For example, to enable the async fallback option for
|
||||
asyncpg::
|
||||
|
||||
--dburi postgresql://db1 \
|
||||
--dbdriver=asyncpg?async_fallback=true
|
||||
|
||||
"""
|
||||
urls = set()
|
||||
|
||||
backend_to_driver_we_already_have = collections.defaultdict(set)
|
||||
|
||||
urls_plus_dialects = [
|
||||
(url_obj, url_obj.get_dialect())
|
||||
for url_obj in [sa_url.make_url(db_url) for db_url in db_urls]
|
||||
]
|
||||
|
||||
for url_obj, dialect in urls_plus_dialects:
|
||||
# use get_driver_name instead of dialect.driver to account for
|
||||
# "_async" virtual drivers like oracledb and psycopg
|
||||
driver_name = url_obj.get_driver_name()
|
||||
backend_to_driver_we_already_have[dialect.name].add(driver_name)
|
||||
|
||||
backend_to_driver_we_need = {}
|
||||
|
||||
for url_obj, dialect in urls_plus_dialects:
|
||||
backend = dialect.name
|
||||
dialect.load_provisioning()
|
||||
|
||||
if backend not in backend_to_driver_we_need:
|
||||
backend_to_driver_we_need[backend] = extra_per_backend = set(
|
||||
extra_drivers
|
||||
).difference(backend_to_driver_we_already_have[backend])
|
||||
else:
|
||||
extra_per_backend = backend_to_driver_we_need[backend]
|
||||
|
||||
for driver_url in _generate_driver_urls(url_obj, extra_per_backend):
|
||||
if driver_url in urls:
|
||||
continue
|
||||
urls.add(driver_url)
|
||||
yield driver_url
|
||||
|
||||
|
||||
def _generate_driver_urls(url, extra_drivers):
|
||||
main_driver = url.get_driver_name()
|
||||
extra_drivers.discard(main_driver)
|
||||
|
||||
url = generate_driver_url(url, main_driver, "")
|
||||
yield url
|
||||
|
||||
for drv in list(extra_drivers):
|
||||
if "?" in drv:
|
||||
driver_only, query_str = drv.split("?", 1)
|
||||
|
||||
else:
|
||||
driver_only = drv
|
||||
query_str = None
|
||||
|
||||
new_url = generate_driver_url(url, driver_only, query_str)
|
||||
if new_url:
|
||||
extra_drivers.remove(drv)
|
||||
|
||||
yield new_url
|
||||
|
||||
|
||||
@register.init
|
||||
def generate_driver_url(url, driver, query_str):
|
||||
backend = url.get_backend_name()
|
||||
|
||||
new_url = url.set(
|
||||
drivername="%s+%s" % (backend, driver),
|
||||
)
|
||||
if query_str:
|
||||
new_url = new_url.update_query_string(query_str)
|
||||
|
||||
try:
|
||||
new_url.get_dialect()
|
||||
except exc.NoSuchModuleError:
|
||||
return None
|
||||
else:
|
||||
return new_url
|
||||
|
||||
|
||||
def _configs_for_db_operation():
|
||||
hosts = set()
|
||||
|
||||
for cfg in config.Config.all_configs():
|
||||
cfg.db.dispose()
|
||||
|
||||
for cfg in config.Config.all_configs():
|
||||
url = cfg.db.url
|
||||
backend = url.get_backend_name()
|
||||
host_conf = (backend, url.username, url.host, url.database)
|
||||
|
||||
if host_conf not in hosts:
|
||||
yield cfg
|
||||
hosts.add(host_conf)
|
||||
|
||||
for cfg in config.Config.all_configs():
|
||||
cfg.db.dispose()
|
||||
|
||||
|
||||
@register.init
|
||||
def drop_all_schema_objects_pre_tables(cfg, eng):
|
||||
pass
|
||||
|
||||
|
||||
@register.init
|
||||
def drop_all_schema_objects_post_tables(cfg, eng):
|
||||
pass
|
||||
|
||||
|
||||
def drop_all_schema_objects(cfg, eng):
|
||||
drop_all_schema_objects_pre_tables(cfg, eng)
|
||||
|
||||
drop_views(cfg, eng)
|
||||
|
||||
if config.requirements.materialized_views.enabled:
|
||||
drop_materialized_views(cfg, eng)
|
||||
|
||||
inspector = inspect(eng)
|
||||
|
||||
consider_schemas = (None,)
|
||||
if config.requirements.schemas.enabled_for_config(cfg):
|
||||
consider_schemas += (cfg.test_schema, cfg.test_schema_2)
|
||||
util.drop_all_tables(eng, inspector, consider_schemas=consider_schemas)
|
||||
|
||||
drop_all_schema_objects_post_tables(cfg, eng)
|
||||
|
||||
if config.requirements.sequences.enabled_for_config(cfg):
|
||||
with eng.begin() as conn:
|
||||
for seq in inspector.get_sequence_names():
|
||||
conn.execute(ddl.DropSequence(schema.Sequence(seq)))
|
||||
if config.requirements.schemas.enabled_for_config(cfg):
|
||||
for schema_name in [cfg.test_schema, cfg.test_schema_2]:
|
||||
for seq in inspector.get_sequence_names(
|
||||
schema=schema_name
|
||||
):
|
||||
conn.execute(
|
||||
ddl.DropSequence(
|
||||
schema.Sequence(seq, schema=schema_name)
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def drop_views(cfg, eng):
|
||||
inspector = inspect(eng)
|
||||
|
||||
try:
|
||||
view_names = inspector.get_view_names()
|
||||
except NotImplementedError:
|
||||
pass
|
||||
else:
|
||||
with eng.begin() as conn:
|
||||
for vname in view_names:
|
||||
conn.execute(
|
||||
ddl._DropView(schema.Table(vname, schema.MetaData()))
|
||||
)
|
||||
|
||||
if config.requirements.schemas.enabled_for_config(cfg):
|
||||
try:
|
||||
view_names = inspector.get_view_names(schema=cfg.test_schema)
|
||||
except NotImplementedError:
|
||||
pass
|
||||
else:
|
||||
with eng.begin() as conn:
|
||||
for vname in view_names:
|
||||
conn.execute(
|
||||
ddl._DropView(
|
||||
schema.Table(
|
||||
vname,
|
||||
schema.MetaData(),
|
||||
schema=cfg.test_schema,
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def drop_materialized_views(cfg, eng):
|
||||
inspector = inspect(eng)
|
||||
|
||||
mview_names = inspector.get_materialized_view_names()
|
||||
|
||||
with eng.begin() as conn:
|
||||
for vname in mview_names:
|
||||
conn.exec_driver_sql(f"DROP MATERIALIZED VIEW {vname}")
|
||||
|
||||
if config.requirements.schemas.enabled_for_config(cfg):
|
||||
mview_names = inspector.get_materialized_view_names(
|
||||
schema=cfg.test_schema
|
||||
)
|
||||
with eng.begin() as conn:
|
||||
for vname in mview_names:
|
||||
conn.exec_driver_sql(
|
||||
f"DROP MATERIALIZED VIEW {cfg.test_schema}.{vname}"
|
||||
)
|
||||
|
||||
|
||||
@register.init
|
||||
def create_db(cfg, eng, ident):
|
||||
"""Dynamically create a database for testing.
|
||||
|
||||
Used when a test run will employ multiple processes, e.g., when run
|
||||
via `tox` or `pytest -n4`.
|
||||
"""
|
||||
raise NotImplementedError(
|
||||
"no DB creation routine for cfg: %s" % (eng.url,)
|
||||
)
|
||||
|
||||
|
||||
@register.init
|
||||
def drop_db(cfg, eng, ident):
|
||||
"""Drop a database that we dynamically created for testing."""
|
||||
raise NotImplementedError("no DB drop routine for cfg: %s" % (eng.url,))
|
||||
|
||||
|
||||
def _adapt_update_db_opts(fn):
|
||||
insp = util.inspect_getfullargspec(fn)
|
||||
if len(insp.args) == 3:
|
||||
return fn
|
||||
else:
|
||||
return lambda db_url, db_opts, _options: fn(db_url, db_opts)
|
||||
|
||||
|
||||
@register.init_decorator(_adapt_update_db_opts)
|
||||
def update_db_opts(db_url, db_opts, options):
|
||||
"""Set database options (db_opts) for a test database that we created."""
|
||||
|
||||
|
||||
@register.init
|
||||
def post_configure_engine(url, engine, follower_ident):
|
||||
"""Perform extra steps after configuring an engine for testing.
|
||||
|
||||
(For the internal dialects, currently only used by sqlite, oracle)
|
||||
"""
|
||||
|
||||
|
||||
@register.init
|
||||
def follower_url_from_main(url, ident):
|
||||
"""Create a connection URL for a dynamically-created test database.
|
||||
|
||||
:param url: the connection URL specified when the test run was invoked
|
||||
:param ident: the pytest-xdist "worker identifier" to be used as the
|
||||
database name
|
||||
"""
|
||||
url = sa_url.make_url(url)
|
||||
return url.set(database=ident)
|
||||
|
||||
|
||||
@register.init
|
||||
def configure_follower(cfg, ident):
|
||||
"""Create dialect-specific config settings for a follower database."""
|
||||
pass
|
||||
|
||||
|
||||
@register.init
|
||||
def run_reap_dbs(url, ident):
|
||||
"""Remove databases that were created during the test process, after the
|
||||
process has ended.
|
||||
|
||||
This is an optional step that is invoked for certain backends that do not
|
||||
reliably release locks on the database as long as a process is still in
|
||||
use. For the internal dialects, this is currently only necessary for
|
||||
mssql and oracle.
|
||||
"""
|
||||
|
||||
|
||||
def reap_dbs(idents_file):
|
||||
log.info("Reaping databases...")
|
||||
|
||||
urls = collections.defaultdict(set)
|
||||
idents = collections.defaultdict(set)
|
||||
dialects = {}
|
||||
|
||||
with open(idents_file) as file_:
|
||||
for line in file_:
|
||||
line = line.strip()
|
||||
db_name, db_url = line.split(" ")
|
||||
url_obj = sa_url.make_url(db_url)
|
||||
if db_name not in dialects:
|
||||
dialects[db_name] = url_obj.get_dialect()
|
||||
dialects[db_name].load_provisioning()
|
||||
url_key = (url_obj.get_backend_name(), url_obj.host)
|
||||
urls[url_key].add(db_url)
|
||||
idents[url_key].add(db_name)
|
||||
|
||||
for url_key in urls:
|
||||
url = list(urls[url_key])[0]
|
||||
ident = idents[url_key]
|
||||
run_reap_dbs(url, ident)
|
||||
|
||||
|
||||
@register.init
|
||||
def temp_table_keyword_args(cfg, eng):
|
||||
"""Specify keyword arguments for creating a temporary Table.
|
||||
|
||||
Dialect-specific implementations of this method will return the
|
||||
kwargs that are passed to the Table method when creating a temporary
|
||||
table for testing, e.g., in the define_temp_tables method of the
|
||||
ComponentReflectionTest class in suite/test_reflection.py
|
||||
"""
|
||||
raise NotImplementedError(
|
||||
"no temp table keyword args routine for cfg: %s" % (eng.url,)
|
||||
)
|
||||
|
||||
|
||||
@register.init
|
||||
def prepare_for_drop_tables(config, connection):
|
||||
pass
|
||||
|
||||
|
||||
@register.init
|
||||
def stop_test_class_outside_fixtures(config, db, testcls):
|
||||
pass
|
||||
|
||||
|
||||
@register.init
|
||||
def get_temp_table_name(cfg, eng, base_name):
|
||||
"""Specify table name for creating a temporary Table.
|
||||
|
||||
Dialect-specific implementations of this method will return the
|
||||
name to use when creating a temporary table for testing,
|
||||
e.g., in the define_temp_tables method of the
|
||||
ComponentReflectionTest class in suite/test_reflection.py
|
||||
|
||||
Default to just the base name since that's what most dialects will
|
||||
use. The mssql dialect's implementation will need a "#" prepended.
|
||||
"""
|
||||
return base_name
|
||||
|
||||
|
||||
@register.init
|
||||
def set_default_schema_on_connection(cfg, dbapi_connection, schema_name):
|
||||
raise NotImplementedError(
|
||||
"backend does not implement a schema name set function: %s"
|
||||
% (cfg.db.url,)
|
||||
)
|
||||
|
||||
|
||||
@register.init
|
||||
def upsert(
|
||||
cfg, table, returning, *, set_lambda=None, sort_by_parameter_order=False
|
||||
):
|
||||
"""return the backends insert..on conflict / on dupe etc. construct.
|
||||
|
||||
while we should add a backend-neutral upsert construct as well, such as
|
||||
insert().upsert(), it's important that we continue to test the
|
||||
backend-specific insert() constructs since if we do implement
|
||||
insert().upsert(), that would be using a different codepath for the things
|
||||
we need to test like insertmanyvalues, etc.
|
||||
|
||||
"""
|
||||
raise NotImplementedError(
|
||||
f"backend does not include an upsert implementation: {cfg.db.url}"
|
||||
)
|
||||
|
||||
|
||||
@register.init
|
||||
def normalize_sequence(cfg, sequence):
|
||||
"""Normalize sequence parameters for dialect that don't start with 1
|
||||
by default.
|
||||
|
||||
The default implementation does nothing
|
||||
"""
|
||||
return sequence
|
File diff suppressed because it is too large
Load diff
|
@ -0,0 +1,224 @@
|
|||
# testing/schema.py
|
||||
# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors
|
||||
# <see AUTHORS file>
|
||||
#
|
||||
# This module is part of SQLAlchemy and is released under
|
||||
# the MIT License: https://www.opensource.org/licenses/mit-license.php
|
||||
# mypy: ignore-errors
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import sys
|
||||
|
||||
from . import config
|
||||
from . import exclusions
|
||||
from .. import event
|
||||
from .. import schema
|
||||
from .. import types as sqltypes
|
||||
from ..orm import mapped_column as _orm_mapped_column
|
||||
from ..util import OrderedDict
|
||||
|
||||
__all__ = ["Table", "Column"]
|
||||
|
||||
table_options = {}
|
||||
|
||||
|
||||
def Table(*args, **kw) -> schema.Table:
|
||||
"""A schema.Table wrapper/hook for dialect-specific tweaks."""
|
||||
|
||||
test_opts = {k: kw.pop(k) for k in list(kw) if k.startswith("test_")}
|
||||
|
||||
kw.update(table_options)
|
||||
|
||||
if exclusions.against(config._current, "mysql"):
|
||||
if (
|
||||
"mysql_engine" not in kw
|
||||
and "mysql_type" not in kw
|
||||
and "autoload_with" not in kw
|
||||
):
|
||||
if "test_needs_fk" in test_opts or "test_needs_acid" in test_opts:
|
||||
kw["mysql_engine"] = "InnoDB"
|
||||
else:
|
||||
# there are in fact test fixtures that rely upon MyISAM,
|
||||
# due to MySQL / MariaDB having poor FK behavior under innodb,
|
||||
# such as a self-referential table can't be deleted from at
|
||||
# once without attending to per-row dependencies. We'd need to
|
||||
# add special steps to some fixtures if we want to not
|
||||
# explicitly state MyISAM here
|
||||
kw["mysql_engine"] = "MyISAM"
|
||||
elif exclusions.against(config._current, "mariadb"):
|
||||
if (
|
||||
"mariadb_engine" not in kw
|
||||
and "mariadb_type" not in kw
|
||||
and "autoload_with" not in kw
|
||||
):
|
||||
if "test_needs_fk" in test_opts or "test_needs_acid" in test_opts:
|
||||
kw["mariadb_engine"] = "InnoDB"
|
||||
else:
|
||||
kw["mariadb_engine"] = "MyISAM"
|
||||
|
||||
return schema.Table(*args, **kw)
|
||||
|
||||
|
||||
def mapped_column(*args, **kw):
|
||||
"""An orm.mapped_column wrapper/hook for dialect-specific tweaks."""
|
||||
|
||||
return _schema_column(_orm_mapped_column, args, kw)
|
||||
|
||||
|
||||
def Column(*args, **kw):
|
||||
"""A schema.Column wrapper/hook for dialect-specific tweaks."""
|
||||
|
||||
return _schema_column(schema.Column, args, kw)
|
||||
|
||||
|
||||
def _schema_column(factory, args, kw):
|
||||
test_opts = {k: kw.pop(k) for k in list(kw) if k.startswith("test_")}
|
||||
|
||||
if not config.requirements.foreign_key_ddl.enabled_for_config(config):
|
||||
args = [arg for arg in args if not isinstance(arg, schema.ForeignKey)]
|
||||
|
||||
construct = factory(*args, **kw)
|
||||
|
||||
if factory is schema.Column:
|
||||
col = construct
|
||||
else:
|
||||
col = construct.column
|
||||
|
||||
if test_opts.get("test_needs_autoincrement", False) and kw.get(
|
||||
"primary_key", False
|
||||
):
|
||||
if col.default is None and col.server_default is None:
|
||||
col.autoincrement = True
|
||||
|
||||
# allow any test suite to pick up on this
|
||||
col.info["test_needs_autoincrement"] = True
|
||||
|
||||
# hardcoded rule for oracle; this should
|
||||
# be moved out
|
||||
if exclusions.against(config._current, "oracle"):
|
||||
|
||||
def add_seq(c, tbl):
|
||||
c._init_items(
|
||||
schema.Sequence(
|
||||
_truncate_name(
|
||||
config.db.dialect, tbl.name + "_" + c.name + "_seq"
|
||||
),
|
||||
optional=True,
|
||||
)
|
||||
)
|
||||
|
||||
event.listen(col, "after_parent_attach", add_seq, propagate=True)
|
||||
return construct
|
||||
|
||||
|
||||
class eq_type_affinity:
|
||||
"""Helper to compare types inside of datastructures based on affinity.
|
||||
|
||||
E.g.::
|
||||
|
||||
eq_(
|
||||
inspect(connection).get_columns("foo"),
|
||||
[
|
||||
{
|
||||
"name": "id",
|
||||
"type": testing.eq_type_affinity(sqltypes.INTEGER),
|
||||
"nullable": False,
|
||||
"default": None,
|
||||
"autoincrement": False,
|
||||
},
|
||||
{
|
||||
"name": "data",
|
||||
"type": testing.eq_type_affinity(sqltypes.NullType),
|
||||
"nullable": True,
|
||||
"default": None,
|
||||
"autoincrement": False,
|
||||
},
|
||||
],
|
||||
)
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, target):
|
||||
self.target = sqltypes.to_instance(target)
|
||||
|
||||
def __eq__(self, other):
|
||||
return self.target._type_affinity is other._type_affinity
|
||||
|
||||
def __ne__(self, other):
|
||||
return self.target._type_affinity is not other._type_affinity
|
||||
|
||||
|
||||
class eq_compile_type:
|
||||
"""similar to eq_type_affinity but uses compile"""
|
||||
|
||||
def __init__(self, target):
|
||||
self.target = target
|
||||
|
||||
def __eq__(self, other):
|
||||
return self.target == other.compile()
|
||||
|
||||
def __ne__(self, other):
|
||||
return self.target != other.compile()
|
||||
|
||||
|
||||
class eq_clause_element:
|
||||
"""Helper to compare SQL structures based on compare()"""
|
||||
|
||||
def __init__(self, target):
|
||||
self.target = target
|
||||
|
||||
def __eq__(self, other):
|
||||
return self.target.compare(other)
|
||||
|
||||
def __ne__(self, other):
|
||||
return not self.target.compare(other)
|
||||
|
||||
|
||||
def _truncate_name(dialect, name):
|
||||
if len(name) > dialect.max_identifier_length:
|
||||
return (
|
||||
name[0 : max(dialect.max_identifier_length - 6, 0)]
|
||||
+ "_"
|
||||
+ hex(hash(name) % 64)[2:]
|
||||
)
|
||||
else:
|
||||
return name
|
||||
|
||||
|
||||
def pep435_enum(name):
|
||||
# Implements PEP 435 in the minimal fashion needed by SQLAlchemy
|
||||
__members__ = OrderedDict()
|
||||
|
||||
def __init__(self, name, value, alias=None):
|
||||
self.name = name
|
||||
self.value = value
|
||||
self.__members__[name] = self
|
||||
value_to_member[value] = self
|
||||
setattr(self.__class__, name, self)
|
||||
if alias:
|
||||
self.__members__[alias] = self
|
||||
setattr(self.__class__, alias, self)
|
||||
|
||||
value_to_member = {}
|
||||
|
||||
@classmethod
|
||||
def get(cls, value):
|
||||
return value_to_member[value]
|
||||
|
||||
someenum = type(
|
||||
name,
|
||||
(object,),
|
||||
{"__members__": __members__, "__init__": __init__, "get": get},
|
||||
)
|
||||
|
||||
# getframe() trick for pickling I don't understand courtesy
|
||||
# Python namedtuple()
|
||||
try:
|
||||
module = sys._getframe(1).f_globals.get("__name__", "__main__")
|
||||
except (AttributeError, ValueError):
|
||||
pass
|
||||
if module is not None:
|
||||
someenum.__module__ = module
|
||||
|
||||
return someenum
|
|
@ -0,0 +1,19 @@
|
|||
# testing/suite/__init__.py
|
||||
# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors
|
||||
# <see AUTHORS file>
|
||||
#
|
||||
# This module is part of SQLAlchemy and is released under
|
||||
# the MIT License: https://www.opensource.org/licenses/mit-license.php
|
||||
from .test_cte import * # noqa
|
||||
from .test_ddl import * # noqa
|
||||
from .test_deprecations import * # noqa
|
||||
from .test_dialect import * # noqa
|
||||
from .test_insert import * # noqa
|
||||
from .test_reflection import * # noqa
|
||||
from .test_results import * # noqa
|
||||
from .test_rowcount import * # noqa
|
||||
from .test_select import * # noqa
|
||||
from .test_sequence import * # noqa
|
||||
from .test_types import * # noqa
|
||||
from .test_unicode_ddl import * # noqa
|
||||
from .test_update_delete import * # noqa
|
|
@ -0,0 +1,211 @@
|
|||
# testing/suite/test_cte.py
|
||||
# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors
|
||||
# <see AUTHORS file>
|
||||
#
|
||||
# This module is part of SQLAlchemy and is released under
|
||||
# the MIT License: https://www.opensource.org/licenses/mit-license.php
|
||||
# mypy: ignore-errors
|
||||
|
||||
from .. import fixtures
|
||||
from ..assertions import eq_
|
||||
from ..schema import Column
|
||||
from ..schema import Table
|
||||
from ... import ForeignKey
|
||||
from ... import Integer
|
||||
from ... import select
|
||||
from ... import String
|
||||
from ... import testing
|
||||
|
||||
|
||||
class CTETest(fixtures.TablesTest):
|
||||
__backend__ = True
|
||||
__requires__ = ("ctes",)
|
||||
|
||||
run_inserts = "each"
|
||||
run_deletes = "each"
|
||||
|
||||
@classmethod
|
||||
def define_tables(cls, metadata):
|
||||
Table(
|
||||
"some_table",
|
||||
metadata,
|
||||
Column("id", Integer, primary_key=True),
|
||||
Column("data", String(50)),
|
||||
Column("parent_id", ForeignKey("some_table.id")),
|
||||
)
|
||||
|
||||
Table(
|
||||
"some_other_table",
|
||||
metadata,
|
||||
Column("id", Integer, primary_key=True),
|
||||
Column("data", String(50)),
|
||||
Column("parent_id", Integer),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def insert_data(cls, connection):
|
||||
connection.execute(
|
||||
cls.tables.some_table.insert(),
|
||||
[
|
||||
{"id": 1, "data": "d1", "parent_id": None},
|
||||
{"id": 2, "data": "d2", "parent_id": 1},
|
||||
{"id": 3, "data": "d3", "parent_id": 1},
|
||||
{"id": 4, "data": "d4", "parent_id": 3},
|
||||
{"id": 5, "data": "d5", "parent_id": 3},
|
||||
],
|
||||
)
|
||||
|
||||
def test_select_nonrecursive_round_trip(self, connection):
|
||||
some_table = self.tables.some_table
|
||||
|
||||
cte = (
|
||||
select(some_table)
|
||||
.where(some_table.c.data.in_(["d2", "d3", "d4"]))
|
||||
.cte("some_cte")
|
||||
)
|
||||
result = connection.execute(
|
||||
select(cte.c.data).where(cte.c.data.in_(["d4", "d5"]))
|
||||
)
|
||||
eq_(result.fetchall(), [("d4",)])
|
||||
|
||||
def test_select_recursive_round_trip(self, connection):
|
||||
some_table = self.tables.some_table
|
||||
|
||||
cte = (
|
||||
select(some_table)
|
||||
.where(some_table.c.data.in_(["d2", "d3", "d4"]))
|
||||
.cte("some_cte", recursive=True)
|
||||
)
|
||||
|
||||
cte_alias = cte.alias("c1")
|
||||
st1 = some_table.alias()
|
||||
# note that SQL Server requires this to be UNION ALL,
|
||||
# can't be UNION
|
||||
cte = cte.union_all(
|
||||
select(st1).where(st1.c.id == cte_alias.c.parent_id)
|
||||
)
|
||||
result = connection.execute(
|
||||
select(cte.c.data)
|
||||
.where(cte.c.data != "d2")
|
||||
.order_by(cte.c.data.desc())
|
||||
)
|
||||
eq_(
|
||||
result.fetchall(),
|
||||
[("d4",), ("d3",), ("d3",), ("d1",), ("d1",), ("d1",)],
|
||||
)
|
||||
|
||||
def test_insert_from_select_round_trip(self, connection):
|
||||
some_table = self.tables.some_table
|
||||
some_other_table = self.tables.some_other_table
|
||||
|
||||
cte = (
|
||||
select(some_table)
|
||||
.where(some_table.c.data.in_(["d2", "d3", "d4"]))
|
||||
.cte("some_cte")
|
||||
)
|
||||
connection.execute(
|
||||
some_other_table.insert().from_select(
|
||||
["id", "data", "parent_id"], select(cte)
|
||||
)
|
||||
)
|
||||
eq_(
|
||||
connection.execute(
|
||||
select(some_other_table).order_by(some_other_table.c.id)
|
||||
).fetchall(),
|
||||
[(2, "d2", 1), (3, "d3", 1), (4, "d4", 3)],
|
||||
)
|
||||
|
||||
@testing.requires.ctes_with_update_delete
|
||||
@testing.requires.update_from
|
||||
def test_update_from_round_trip(self, connection):
|
||||
some_table = self.tables.some_table
|
||||
some_other_table = self.tables.some_other_table
|
||||
|
||||
connection.execute(
|
||||
some_other_table.insert().from_select(
|
||||
["id", "data", "parent_id"], select(some_table)
|
||||
)
|
||||
)
|
||||
|
||||
cte = (
|
||||
select(some_table)
|
||||
.where(some_table.c.data.in_(["d2", "d3", "d4"]))
|
||||
.cte("some_cte")
|
||||
)
|
||||
connection.execute(
|
||||
some_other_table.update()
|
||||
.values(parent_id=5)
|
||||
.where(some_other_table.c.data == cte.c.data)
|
||||
)
|
||||
eq_(
|
||||
connection.execute(
|
||||
select(some_other_table).order_by(some_other_table.c.id)
|
||||
).fetchall(),
|
||||
[
|
||||
(1, "d1", None),
|
||||
(2, "d2", 5),
|
||||
(3, "d3", 5),
|
||||
(4, "d4", 5),
|
||||
(5, "d5", 3),
|
||||
],
|
||||
)
|
||||
|
||||
@testing.requires.ctes_with_update_delete
|
||||
@testing.requires.delete_from
|
||||
def test_delete_from_round_trip(self, connection):
|
||||
some_table = self.tables.some_table
|
||||
some_other_table = self.tables.some_other_table
|
||||
|
||||
connection.execute(
|
||||
some_other_table.insert().from_select(
|
||||
["id", "data", "parent_id"], select(some_table)
|
||||
)
|
||||
)
|
||||
|
||||
cte = (
|
||||
select(some_table)
|
||||
.where(some_table.c.data.in_(["d2", "d3", "d4"]))
|
||||
.cte("some_cte")
|
||||
)
|
||||
connection.execute(
|
||||
some_other_table.delete().where(
|
||||
some_other_table.c.data == cte.c.data
|
||||
)
|
||||
)
|
||||
eq_(
|
||||
connection.execute(
|
||||
select(some_other_table).order_by(some_other_table.c.id)
|
||||
).fetchall(),
|
||||
[(1, "d1", None), (5, "d5", 3)],
|
||||
)
|
||||
|
||||
@testing.requires.ctes_with_update_delete
|
||||
def test_delete_scalar_subq_round_trip(self, connection):
|
||||
some_table = self.tables.some_table
|
||||
some_other_table = self.tables.some_other_table
|
||||
|
||||
connection.execute(
|
||||
some_other_table.insert().from_select(
|
||||
["id", "data", "parent_id"], select(some_table)
|
||||
)
|
||||
)
|
||||
|
||||
cte = (
|
||||
select(some_table)
|
||||
.where(some_table.c.data.in_(["d2", "d3", "d4"]))
|
||||
.cte("some_cte")
|
||||
)
|
||||
connection.execute(
|
||||
some_other_table.delete().where(
|
||||
some_other_table.c.data
|
||||
== select(cte.c.data)
|
||||
.where(cte.c.id == some_other_table.c.id)
|
||||
.scalar_subquery()
|
||||
)
|
||||
)
|
||||
eq_(
|
||||
connection.execute(
|
||||
select(some_other_table).order_by(some_other_table.c.id)
|
||||
).fetchall(),
|
||||
[(1, "d1", None), (5, "d5", 3)],
|
||||
)
|
|
@ -0,0 +1,389 @@
|
|||
# testing/suite/test_ddl.py
|
||||
# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors
|
||||
# <see AUTHORS file>
|
||||
#
|
||||
# This module is part of SQLAlchemy and is released under
|
||||
# the MIT License: https://www.opensource.org/licenses/mit-license.php
|
||||
# mypy: ignore-errors
|
||||
|
||||
import random
|
||||
|
||||
from . import testing
|
||||
from .. import config
|
||||
from .. import fixtures
|
||||
from .. import util
|
||||
from ..assertions import eq_
|
||||
from ..assertions import is_false
|
||||
from ..assertions import is_true
|
||||
from ..config import requirements
|
||||
from ..schema import Table
|
||||
from ... import CheckConstraint
|
||||
from ... import Column
|
||||
from ... import ForeignKeyConstraint
|
||||
from ... import Index
|
||||
from ... import inspect
|
||||
from ... import Integer
|
||||
from ... import schema
|
||||
from ... import String
|
||||
from ... import UniqueConstraint
|
||||
|
||||
|
||||
class TableDDLTest(fixtures.TestBase):
|
||||
__backend__ = True
|
||||
|
||||
def _simple_fixture(self, schema=None):
|
||||
return Table(
|
||||
"test_table",
|
||||
self.metadata,
|
||||
Column("id", Integer, primary_key=True, autoincrement=False),
|
||||
Column("data", String(50)),
|
||||
schema=schema,
|
||||
)
|
||||
|
||||
def _underscore_fixture(self):
|
||||
return Table(
|
||||
"_test_table",
|
||||
self.metadata,
|
||||
Column("id", Integer, primary_key=True, autoincrement=False),
|
||||
Column("_data", String(50)),
|
||||
)
|
||||
|
||||
def _table_index_fixture(self, schema=None):
|
||||
table = self._simple_fixture(schema=schema)
|
||||
idx = Index("test_index", table.c.data)
|
||||
return table, idx
|
||||
|
||||
def _simple_roundtrip(self, table):
|
||||
with config.db.begin() as conn:
|
||||
conn.execute(table.insert().values((1, "some data")))
|
||||
result = conn.execute(table.select())
|
||||
eq_(result.first(), (1, "some data"))
|
||||
|
||||
@requirements.create_table
|
||||
@util.provide_metadata
|
||||
def test_create_table(self):
|
||||
table = self._simple_fixture()
|
||||
table.create(config.db, checkfirst=False)
|
||||
self._simple_roundtrip(table)
|
||||
|
||||
@requirements.create_table
|
||||
@requirements.schemas
|
||||
@util.provide_metadata
|
||||
def test_create_table_schema(self):
|
||||
table = self._simple_fixture(schema=config.test_schema)
|
||||
table.create(config.db, checkfirst=False)
|
||||
self._simple_roundtrip(table)
|
||||
|
||||
@requirements.drop_table
|
||||
@util.provide_metadata
|
||||
def test_drop_table(self):
|
||||
table = self._simple_fixture()
|
||||
table.create(config.db, checkfirst=False)
|
||||
table.drop(config.db, checkfirst=False)
|
||||
|
||||
@requirements.create_table
|
||||
@util.provide_metadata
|
||||
def test_underscore_names(self):
|
||||
table = self._underscore_fixture()
|
||||
table.create(config.db, checkfirst=False)
|
||||
self._simple_roundtrip(table)
|
||||
|
||||
@requirements.comment_reflection
|
||||
@util.provide_metadata
|
||||
def test_add_table_comment(self, connection):
|
||||
table = self._simple_fixture()
|
||||
table.create(connection, checkfirst=False)
|
||||
table.comment = "a comment"
|
||||
connection.execute(schema.SetTableComment(table))
|
||||
eq_(
|
||||
inspect(connection).get_table_comment("test_table"),
|
||||
{"text": "a comment"},
|
||||
)
|
||||
|
||||
@requirements.comment_reflection
|
||||
@util.provide_metadata
|
||||
def test_drop_table_comment(self, connection):
|
||||
table = self._simple_fixture()
|
||||
table.create(connection, checkfirst=False)
|
||||
table.comment = "a comment"
|
||||
connection.execute(schema.SetTableComment(table))
|
||||
connection.execute(schema.DropTableComment(table))
|
||||
eq_(
|
||||
inspect(connection).get_table_comment("test_table"), {"text": None}
|
||||
)
|
||||
|
||||
@requirements.table_ddl_if_exists
|
||||
@util.provide_metadata
|
||||
def test_create_table_if_not_exists(self, connection):
|
||||
table = self._simple_fixture()
|
||||
|
||||
connection.execute(schema.CreateTable(table, if_not_exists=True))
|
||||
|
||||
is_true(inspect(connection).has_table("test_table"))
|
||||
connection.execute(schema.CreateTable(table, if_not_exists=True))
|
||||
|
||||
@requirements.index_ddl_if_exists
|
||||
@util.provide_metadata
|
||||
def test_create_index_if_not_exists(self, connection):
|
||||
table, idx = self._table_index_fixture()
|
||||
|
||||
connection.execute(schema.CreateTable(table, if_not_exists=True))
|
||||
is_true(inspect(connection).has_table("test_table"))
|
||||
is_false(
|
||||
"test_index"
|
||||
in [
|
||||
ix["name"]
|
||||
for ix in inspect(connection).get_indexes("test_table")
|
||||
]
|
||||
)
|
||||
|
||||
connection.execute(schema.CreateIndex(idx, if_not_exists=True))
|
||||
|
||||
is_true(
|
||||
"test_index"
|
||||
in [
|
||||
ix["name"]
|
||||
for ix in inspect(connection).get_indexes("test_table")
|
||||
]
|
||||
)
|
||||
|
||||
connection.execute(schema.CreateIndex(idx, if_not_exists=True))
|
||||
|
||||
@requirements.table_ddl_if_exists
|
||||
@util.provide_metadata
|
||||
def test_drop_table_if_exists(self, connection):
|
||||
table = self._simple_fixture()
|
||||
|
||||
table.create(connection)
|
||||
|
||||
is_true(inspect(connection).has_table("test_table"))
|
||||
|
||||
connection.execute(schema.DropTable(table, if_exists=True))
|
||||
|
||||
is_false(inspect(connection).has_table("test_table"))
|
||||
|
||||
connection.execute(schema.DropTable(table, if_exists=True))
|
||||
|
||||
@requirements.index_ddl_if_exists
|
||||
@util.provide_metadata
|
||||
def test_drop_index_if_exists(self, connection):
|
||||
table, idx = self._table_index_fixture()
|
||||
|
||||
table.create(connection)
|
||||
|
||||
is_true(
|
||||
"test_index"
|
||||
in [
|
||||
ix["name"]
|
||||
for ix in inspect(connection).get_indexes("test_table")
|
||||
]
|
||||
)
|
||||
|
||||
connection.execute(schema.DropIndex(idx, if_exists=True))
|
||||
|
||||
is_false(
|
||||
"test_index"
|
||||
in [
|
||||
ix["name"]
|
||||
for ix in inspect(connection).get_indexes("test_table")
|
||||
]
|
||||
)
|
||||
|
||||
connection.execute(schema.DropIndex(idx, if_exists=True))
|
||||
|
||||
|
||||
class FutureTableDDLTest(fixtures.FutureEngineMixin, TableDDLTest):
|
||||
pass
|
||||
|
||||
|
||||
class LongNameBlowoutTest(fixtures.TestBase):
|
||||
"""test the creation of a variety of DDL structures and ensure
|
||||
label length limits pass on backends
|
||||
|
||||
"""
|
||||
|
||||
__backend__ = True
|
||||
|
||||
def fk(self, metadata, connection):
|
||||
convention = {
|
||||
"fk": "foreign_key_%(table_name)s_"
|
||||
"%(column_0_N_name)s_"
|
||||
"%(referred_table_name)s_"
|
||||
+ (
|
||||
"_".join(
|
||||
"".join(random.choice("abcdef") for j in range(20))
|
||||
for i in range(10)
|
||||
)
|
||||
),
|
||||
}
|
||||
metadata.naming_convention = convention
|
||||
|
||||
Table(
|
||||
"a_things_with_stuff",
|
||||
metadata,
|
||||
Column("id_long_column_name", Integer, primary_key=True),
|
||||
test_needs_fk=True,
|
||||
)
|
||||
|
||||
cons = ForeignKeyConstraint(
|
||||
["aid"], ["a_things_with_stuff.id_long_column_name"]
|
||||
)
|
||||
Table(
|
||||
"b_related_things_of_value",
|
||||
metadata,
|
||||
Column(
|
||||
"aid",
|
||||
),
|
||||
cons,
|
||||
test_needs_fk=True,
|
||||
)
|
||||
actual_name = cons.name
|
||||
|
||||
metadata.create_all(connection)
|
||||
|
||||
if testing.requires.foreign_key_constraint_name_reflection.enabled:
|
||||
insp = inspect(connection)
|
||||
fks = insp.get_foreign_keys("b_related_things_of_value")
|
||||
reflected_name = fks[0]["name"]
|
||||
|
||||
return actual_name, reflected_name
|
||||
else:
|
||||
return actual_name, None
|
||||
|
||||
def pk(self, metadata, connection):
|
||||
convention = {
|
||||
"pk": "primary_key_%(table_name)s_"
|
||||
"%(column_0_N_name)s"
|
||||
+ (
|
||||
"_".join(
|
||||
"".join(random.choice("abcdef") for j in range(30))
|
||||
for i in range(10)
|
||||
)
|
||||
),
|
||||
}
|
||||
metadata.naming_convention = convention
|
||||
|
||||
a = Table(
|
||||
"a_things_with_stuff",
|
||||
metadata,
|
||||
Column("id_long_column_name", Integer, primary_key=True),
|
||||
Column("id_another_long_name", Integer, primary_key=True),
|
||||
)
|
||||
cons = a.primary_key
|
||||
actual_name = cons.name
|
||||
|
||||
metadata.create_all(connection)
|
||||
insp = inspect(connection)
|
||||
pk = insp.get_pk_constraint("a_things_with_stuff")
|
||||
reflected_name = pk["name"]
|
||||
return actual_name, reflected_name
|
||||
|
||||
def ix(self, metadata, connection):
|
||||
convention = {
|
||||
"ix": "index_%(table_name)s_"
|
||||
"%(column_0_N_name)s"
|
||||
+ (
|
||||
"_".join(
|
||||
"".join(random.choice("abcdef") for j in range(30))
|
||||
for i in range(10)
|
||||
)
|
||||
),
|
||||
}
|
||||
metadata.naming_convention = convention
|
||||
|
||||
a = Table(
|
||||
"a_things_with_stuff",
|
||||
metadata,
|
||||
Column("id_long_column_name", Integer, primary_key=True),
|
||||
Column("id_another_long_name", Integer),
|
||||
)
|
||||
cons = Index(None, a.c.id_long_column_name, a.c.id_another_long_name)
|
||||
actual_name = cons.name
|
||||
|
||||
metadata.create_all(connection)
|
||||
insp = inspect(connection)
|
||||
ix = insp.get_indexes("a_things_with_stuff")
|
||||
reflected_name = ix[0]["name"]
|
||||
return actual_name, reflected_name
|
||||
|
||||
def uq(self, metadata, connection):
|
||||
convention = {
|
||||
"uq": "unique_constraint_%(table_name)s_"
|
||||
"%(column_0_N_name)s"
|
||||
+ (
|
||||
"_".join(
|
||||
"".join(random.choice("abcdef") for j in range(30))
|
||||
for i in range(10)
|
||||
)
|
||||
),
|
||||
}
|
||||
metadata.naming_convention = convention
|
||||
|
||||
cons = UniqueConstraint("id_long_column_name", "id_another_long_name")
|
||||
Table(
|
||||
"a_things_with_stuff",
|
||||
metadata,
|
||||
Column("id_long_column_name", Integer, primary_key=True),
|
||||
Column("id_another_long_name", Integer),
|
||||
cons,
|
||||
)
|
||||
actual_name = cons.name
|
||||
|
||||
metadata.create_all(connection)
|
||||
insp = inspect(connection)
|
||||
uq = insp.get_unique_constraints("a_things_with_stuff")
|
||||
reflected_name = uq[0]["name"]
|
||||
return actual_name, reflected_name
|
||||
|
||||
def ck(self, metadata, connection):
|
||||
convention = {
|
||||
"ck": "check_constraint_%(table_name)s"
|
||||
+ (
|
||||
"_".join(
|
||||
"".join(random.choice("abcdef") for j in range(30))
|
||||
for i in range(10)
|
||||
)
|
||||
),
|
||||
}
|
||||
metadata.naming_convention = convention
|
||||
|
||||
cons = CheckConstraint("some_long_column_name > 5")
|
||||
Table(
|
||||
"a_things_with_stuff",
|
||||
metadata,
|
||||
Column("id_long_column_name", Integer, primary_key=True),
|
||||
Column("some_long_column_name", Integer),
|
||||
cons,
|
||||
)
|
||||
actual_name = cons.name
|
||||
|
||||
metadata.create_all(connection)
|
||||
insp = inspect(connection)
|
||||
ck = insp.get_check_constraints("a_things_with_stuff")
|
||||
reflected_name = ck[0]["name"]
|
||||
return actual_name, reflected_name
|
||||
|
||||
@testing.combinations(
|
||||
("fk",),
|
||||
("pk",),
|
||||
("ix",),
|
||||
("ck", testing.requires.check_constraint_reflection.as_skips()),
|
||||
("uq", testing.requires.unique_constraint_reflection.as_skips()),
|
||||
argnames="type_",
|
||||
)
|
||||
def test_long_convention_name(self, type_, metadata, connection):
|
||||
actual_name, reflected_name = getattr(self, type_)(
|
||||
metadata, connection
|
||||
)
|
||||
|
||||
assert len(actual_name) > 255
|
||||
|
||||
if reflected_name is not None:
|
||||
overlap = actual_name[0 : len(reflected_name)]
|
||||
if len(overlap) < len(actual_name):
|
||||
eq_(overlap[0:-5], reflected_name[0 : len(overlap) - 5])
|
||||
else:
|
||||
eq_(overlap, reflected_name)
|
||||
|
||||
|
||||
__all__ = ("TableDDLTest", "FutureTableDDLTest", "LongNameBlowoutTest")
|
|
@ -0,0 +1,153 @@
|
|||
# testing/suite/test_deprecations.py
|
||||
# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors
|
||||
# <see AUTHORS file>
|
||||
#
|
||||
# This module is part of SQLAlchemy and is released under
|
||||
# the MIT License: https://www.opensource.org/licenses/mit-license.php
|
||||
# mypy: ignore-errors
|
||||
|
||||
from .. import fixtures
|
||||
from ..assertions import eq_
|
||||
from ..schema import Column
|
||||
from ..schema import Table
|
||||
from ... import Integer
|
||||
from ... import select
|
||||
from ... import testing
|
||||
from ... import union
|
||||
|
||||
|
||||
class DeprecatedCompoundSelectTest(fixtures.TablesTest):
|
||||
__backend__ = True
|
||||
|
||||
@classmethod
|
||||
def define_tables(cls, metadata):
|
||||
Table(
|
||||
"some_table",
|
||||
metadata,
|
||||
Column("id", Integer, primary_key=True),
|
||||
Column("x", Integer),
|
||||
Column("y", Integer),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def insert_data(cls, connection):
|
||||
connection.execute(
|
||||
cls.tables.some_table.insert(),
|
||||
[
|
||||
{"id": 1, "x": 1, "y": 2},
|
||||
{"id": 2, "x": 2, "y": 3},
|
||||
{"id": 3, "x": 3, "y": 4},
|
||||
{"id": 4, "x": 4, "y": 5},
|
||||
],
|
||||
)
|
||||
|
||||
def _assert_result(self, conn, select, result, params=()):
|
||||
eq_(conn.execute(select, params).fetchall(), result)
|
||||
|
||||
def test_plain_union(self, connection):
|
||||
table = self.tables.some_table
|
||||
s1 = select(table).where(table.c.id == 2)
|
||||
s2 = select(table).where(table.c.id == 3)
|
||||
|
||||
u1 = union(s1, s2)
|
||||
with testing.expect_deprecated(
|
||||
"The SelectBase.c and SelectBase.columns "
|
||||
"attributes are deprecated"
|
||||
):
|
||||
self._assert_result(
|
||||
connection, u1.order_by(u1.c.id), [(2, 2, 3), (3, 3, 4)]
|
||||
)
|
||||
|
||||
# note we've had to remove one use case entirely, which is this
|
||||
# one. the Select gets its FROMS from the WHERE clause and the
|
||||
# columns clause, but not the ORDER BY, which means the old ".c" system
|
||||
# allowed you to "order_by(s.c.foo)" to get an unnamed column in the
|
||||
# ORDER BY without adding the SELECT into the FROM and breaking the
|
||||
# query. Users will have to adjust for this use case if they were doing
|
||||
# it before.
|
||||
def _dont_test_select_from_plain_union(self, connection):
|
||||
table = self.tables.some_table
|
||||
s1 = select(table).where(table.c.id == 2)
|
||||
s2 = select(table).where(table.c.id == 3)
|
||||
|
||||
u1 = union(s1, s2).alias().select()
|
||||
with testing.expect_deprecated(
|
||||
"The SelectBase.c and SelectBase.columns "
|
||||
"attributes are deprecated"
|
||||
):
|
||||
self._assert_result(
|
||||
connection, u1.order_by(u1.c.id), [(2, 2, 3), (3, 3, 4)]
|
||||
)
|
||||
|
||||
@testing.requires.order_by_col_from_union
|
||||
@testing.requires.parens_in_union_contained_select_w_limit_offset
|
||||
def test_limit_offset_selectable_in_unions(self, connection):
|
||||
table = self.tables.some_table
|
||||
s1 = select(table).where(table.c.id == 2).limit(1).order_by(table.c.id)
|
||||
s2 = select(table).where(table.c.id == 3).limit(1).order_by(table.c.id)
|
||||
|
||||
u1 = union(s1, s2).limit(2)
|
||||
with testing.expect_deprecated(
|
||||
"The SelectBase.c and SelectBase.columns "
|
||||
"attributes are deprecated"
|
||||
):
|
||||
self._assert_result(
|
||||
connection, u1.order_by(u1.c.id), [(2, 2, 3), (3, 3, 4)]
|
||||
)
|
||||
|
||||
@testing.requires.parens_in_union_contained_select_wo_limit_offset
|
||||
def test_order_by_selectable_in_unions(self, connection):
|
||||
table = self.tables.some_table
|
||||
s1 = select(table).where(table.c.id == 2).order_by(table.c.id)
|
||||
s2 = select(table).where(table.c.id == 3).order_by(table.c.id)
|
||||
|
||||
u1 = union(s1, s2).limit(2)
|
||||
with testing.expect_deprecated(
|
||||
"The SelectBase.c and SelectBase.columns "
|
||||
"attributes are deprecated"
|
||||
):
|
||||
self._assert_result(
|
||||
connection, u1.order_by(u1.c.id), [(2, 2, 3), (3, 3, 4)]
|
||||
)
|
||||
|
||||
def test_distinct_selectable_in_unions(self, connection):
|
||||
table = self.tables.some_table
|
||||
s1 = select(table).where(table.c.id == 2).distinct()
|
||||
s2 = select(table).where(table.c.id == 3).distinct()
|
||||
|
||||
u1 = union(s1, s2).limit(2)
|
||||
with testing.expect_deprecated(
|
||||
"The SelectBase.c and SelectBase.columns "
|
||||
"attributes are deprecated"
|
||||
):
|
||||
self._assert_result(
|
||||
connection, u1.order_by(u1.c.id), [(2, 2, 3), (3, 3, 4)]
|
||||
)
|
||||
|
||||
def test_limit_offset_aliased_selectable_in_unions(self, connection):
|
||||
table = self.tables.some_table
|
||||
s1 = (
|
||||
select(table)
|
||||
.where(table.c.id == 2)
|
||||
.limit(1)
|
||||
.order_by(table.c.id)
|
||||
.alias()
|
||||
.select()
|
||||
)
|
||||
s2 = (
|
||||
select(table)
|
||||
.where(table.c.id == 3)
|
||||
.limit(1)
|
||||
.order_by(table.c.id)
|
||||
.alias()
|
||||
.select()
|
||||
)
|
||||
|
||||
u1 = union(s1, s2).limit(2)
|
||||
with testing.expect_deprecated(
|
||||
"The SelectBase.c and SelectBase.columns "
|
||||
"attributes are deprecated"
|
||||
):
|
||||
self._assert_result(
|
||||
connection, u1.order_by(u1.c.id), [(2, 2, 3), (3, 3, 4)]
|
||||
)
|
|
@ -0,0 +1,740 @@
|
|||
# testing/suite/test_dialect.py
|
||||
# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors
|
||||
# <see AUTHORS file>
|
||||
#
|
||||
# This module is part of SQLAlchemy and is released under
|
||||
# the MIT License: https://www.opensource.org/licenses/mit-license.php
|
||||
# mypy: ignore-errors
|
||||
|
||||
|
||||
import importlib
|
||||
|
||||
from . import testing
|
||||
from .. import assert_raises
|
||||
from .. import config
|
||||
from .. import engines
|
||||
from .. import eq_
|
||||
from .. import fixtures
|
||||
from .. import is_not_none
|
||||
from .. import is_true
|
||||
from .. import ne_
|
||||
from .. import provide_metadata
|
||||
from ..assertions import expect_raises
|
||||
from ..assertions import expect_raises_message
|
||||
from ..config import requirements
|
||||
from ..provision import set_default_schema_on_connection
|
||||
from ..schema import Column
|
||||
from ..schema import Table
|
||||
from ... import bindparam
|
||||
from ... import dialects
|
||||
from ... import event
|
||||
from ... import exc
|
||||
from ... import Integer
|
||||
from ... import literal_column
|
||||
from ... import select
|
||||
from ... import String
|
||||
from ...sql.compiler import Compiled
|
||||
from ...util import inspect_getfullargspec
|
||||
|
||||
|
||||
class PingTest(fixtures.TestBase):
|
||||
__backend__ = True
|
||||
|
||||
def test_do_ping(self):
|
||||
with testing.db.connect() as conn:
|
||||
is_true(
|
||||
testing.db.dialect.do_ping(conn.connection.dbapi_connection)
|
||||
)
|
||||
|
||||
|
||||
class ArgSignatureTest(fixtures.TestBase):
|
||||
"""test that all visit_XYZ() in :class:`_sql.Compiler` subclasses have
|
||||
``**kw``, for #8988.
|
||||
|
||||
This test uses runtime code inspection. Does not need to be a
|
||||
``__backend__`` test as it only needs to run once provided all target
|
||||
dialects have been imported.
|
||||
|
||||
For third party dialects, the suite would be run with that third
|
||||
party as a "--dburi", which means its compiler classes will have been
|
||||
imported by the time this test runs.
|
||||
|
||||
"""
|
||||
|
||||
def _all_subclasses(): # type: ignore # noqa
|
||||
for d in dialects.__all__:
|
||||
if not d.startswith("_"):
|
||||
importlib.import_module("sqlalchemy.dialects.%s" % d)
|
||||
|
||||
stack = [Compiled]
|
||||
|
||||
while stack:
|
||||
cls = stack.pop(0)
|
||||
stack.extend(cls.__subclasses__())
|
||||
yield cls
|
||||
|
||||
@testing.fixture(params=list(_all_subclasses()))
|
||||
def all_subclasses(self, request):
|
||||
yield request.param
|
||||
|
||||
def test_all_visit_methods_accept_kw(self, all_subclasses):
|
||||
cls = all_subclasses
|
||||
|
||||
for k in cls.__dict__:
|
||||
if k.startswith("visit_"):
|
||||
meth = getattr(cls, k)
|
||||
|
||||
insp = inspect_getfullargspec(meth)
|
||||
is_not_none(
|
||||
insp.varkw,
|
||||
f"Compiler visit method {cls.__name__}.{k}() does "
|
||||
"not accommodate for **kw in its argument signature",
|
||||
)
|
||||
|
||||
|
||||
class ExceptionTest(fixtures.TablesTest):
|
||||
"""Test basic exception wrapping.
|
||||
|
||||
DBAPIs vary a lot in exception behavior so to actually anticipate
|
||||
specific exceptions from real round trips, we need to be conservative.
|
||||
|
||||
"""
|
||||
|
||||
run_deletes = "each"
|
||||
|
||||
__backend__ = True
|
||||
|
||||
@classmethod
|
||||
def define_tables(cls, metadata):
|
||||
Table(
|
||||
"manual_pk",
|
||||
metadata,
|
||||
Column("id", Integer, primary_key=True, autoincrement=False),
|
||||
Column("data", String(50)),
|
||||
)
|
||||
|
||||
@requirements.duplicate_key_raises_integrity_error
|
||||
def test_integrity_error(self):
|
||||
with config.db.connect() as conn:
|
||||
trans = conn.begin()
|
||||
conn.execute(
|
||||
self.tables.manual_pk.insert(), {"id": 1, "data": "d1"}
|
||||
)
|
||||
|
||||
assert_raises(
|
||||
exc.IntegrityError,
|
||||
conn.execute,
|
||||
self.tables.manual_pk.insert(),
|
||||
{"id": 1, "data": "d1"},
|
||||
)
|
||||
|
||||
trans.rollback()
|
||||
|
||||
def test_exception_with_non_ascii(self):
|
||||
with config.db.connect() as conn:
|
||||
try:
|
||||
# try to create an error message that likely has non-ascii
|
||||
# characters in the DBAPI's message string. unfortunately
|
||||
# there's no way to make this happen with some drivers like
|
||||
# mysqlclient, pymysql. this at least does produce a non-
|
||||
# ascii error message for cx_oracle, psycopg2
|
||||
conn.execute(select(literal_column("méil")))
|
||||
assert False
|
||||
except exc.DBAPIError as err:
|
||||
err_str = str(err)
|
||||
|
||||
assert str(err.orig) in str(err)
|
||||
|
||||
assert isinstance(err_str, str)
|
||||
|
||||
|
||||
class IsolationLevelTest(fixtures.TestBase):
|
||||
__backend__ = True
|
||||
|
||||
__requires__ = ("isolation_level",)
|
||||
|
||||
def _get_non_default_isolation_level(self):
|
||||
levels = requirements.get_isolation_levels(config)
|
||||
|
||||
default = levels["default"]
|
||||
supported = levels["supported"]
|
||||
|
||||
s = set(supported).difference(["AUTOCOMMIT", default])
|
||||
if s:
|
||||
return s.pop()
|
||||
else:
|
||||
config.skip_test("no non-default isolation level available")
|
||||
|
||||
def test_default_isolation_level(self):
|
||||
eq_(
|
||||
config.db.dialect.default_isolation_level,
|
||||
requirements.get_isolation_levels(config)["default"],
|
||||
)
|
||||
|
||||
def test_non_default_isolation_level(self):
|
||||
non_default = self._get_non_default_isolation_level()
|
||||
|
||||
with config.db.connect() as conn:
|
||||
existing = conn.get_isolation_level()
|
||||
|
||||
ne_(existing, non_default)
|
||||
|
||||
conn.execution_options(isolation_level=non_default)
|
||||
|
||||
eq_(conn.get_isolation_level(), non_default)
|
||||
|
||||
conn.dialect.reset_isolation_level(
|
||||
conn.connection.dbapi_connection
|
||||
)
|
||||
|
||||
eq_(conn.get_isolation_level(), existing)
|
||||
|
||||
def test_all_levels(self):
|
||||
levels = requirements.get_isolation_levels(config)
|
||||
|
||||
all_levels = levels["supported"]
|
||||
|
||||
for level in set(all_levels).difference(["AUTOCOMMIT"]):
|
||||
with config.db.connect() as conn:
|
||||
conn.execution_options(isolation_level=level)
|
||||
|
||||
eq_(conn.get_isolation_level(), level)
|
||||
|
||||
trans = conn.begin()
|
||||
trans.rollback()
|
||||
|
||||
eq_(conn.get_isolation_level(), level)
|
||||
|
||||
with config.db.connect() as conn:
|
||||
eq_(
|
||||
conn.get_isolation_level(),
|
||||
levels["default"],
|
||||
)
|
||||
|
||||
@testing.requires.get_isolation_level_values
|
||||
def test_invalid_level_execution_option(self, connection_no_trans):
|
||||
"""test for the new get_isolation_level_values() method"""
|
||||
|
||||
connection = connection_no_trans
|
||||
with expect_raises_message(
|
||||
exc.ArgumentError,
|
||||
"Invalid value '%s' for isolation_level. "
|
||||
"Valid isolation levels for '%s' are %s"
|
||||
% (
|
||||
"FOO",
|
||||
connection.dialect.name,
|
||||
", ".join(
|
||||
requirements.get_isolation_levels(config)["supported"]
|
||||
),
|
||||
),
|
||||
):
|
||||
connection.execution_options(isolation_level="FOO")
|
||||
|
||||
@testing.requires.get_isolation_level_values
|
||||
@testing.requires.dialect_level_isolation_level_param
|
||||
def test_invalid_level_engine_param(self, testing_engine):
|
||||
"""test for the new get_isolation_level_values() method
|
||||
and support for the dialect-level 'isolation_level' parameter.
|
||||
|
||||
"""
|
||||
|
||||
eng = testing_engine(options=dict(isolation_level="FOO"))
|
||||
with expect_raises_message(
|
||||
exc.ArgumentError,
|
||||
"Invalid value '%s' for isolation_level. "
|
||||
"Valid isolation levels for '%s' are %s"
|
||||
% (
|
||||
"FOO",
|
||||
eng.dialect.name,
|
||||
", ".join(
|
||||
requirements.get_isolation_levels(config)["supported"]
|
||||
),
|
||||
),
|
||||
):
|
||||
eng.connect()
|
||||
|
||||
@testing.requires.independent_readonly_connections
|
||||
def test_dialect_user_setting_is_restored(self, testing_engine):
|
||||
levels = requirements.get_isolation_levels(config)
|
||||
default = levels["default"]
|
||||
supported = (
|
||||
sorted(
|
||||
set(levels["supported"]).difference([default, "AUTOCOMMIT"])
|
||||
)
|
||||
)[0]
|
||||
|
||||
e = testing_engine(options={"isolation_level": supported})
|
||||
|
||||
with e.connect() as conn:
|
||||
eq_(conn.get_isolation_level(), supported)
|
||||
|
||||
with e.connect() as conn:
|
||||
conn.execution_options(isolation_level=default)
|
||||
eq_(conn.get_isolation_level(), default)
|
||||
|
||||
with e.connect() as conn:
|
||||
eq_(conn.get_isolation_level(), supported)
|
||||
|
||||
|
||||
class AutocommitIsolationTest(fixtures.TablesTest):
|
||||
run_deletes = "each"
|
||||
|
||||
__requires__ = ("autocommit",)
|
||||
|
||||
__backend__ = True
|
||||
|
||||
@classmethod
|
||||
def define_tables(cls, metadata):
|
||||
Table(
|
||||
"some_table",
|
||||
metadata,
|
||||
Column("id", Integer, primary_key=True, autoincrement=False),
|
||||
Column("data", String(50)),
|
||||
test_needs_acid=True,
|
||||
)
|
||||
|
||||
def _test_conn_autocommits(self, conn, autocommit):
|
||||
trans = conn.begin()
|
||||
conn.execute(
|
||||
self.tables.some_table.insert(), {"id": 1, "data": "some data"}
|
||||
)
|
||||
trans.rollback()
|
||||
|
||||
eq_(
|
||||
conn.scalar(select(self.tables.some_table.c.id)),
|
||||
1 if autocommit else None,
|
||||
)
|
||||
conn.rollback()
|
||||
|
||||
with conn.begin():
|
||||
conn.execute(self.tables.some_table.delete())
|
||||
|
||||
def test_autocommit_on(self, connection_no_trans):
|
||||
conn = connection_no_trans
|
||||
c2 = conn.execution_options(isolation_level="AUTOCOMMIT")
|
||||
self._test_conn_autocommits(c2, True)
|
||||
|
||||
c2.dialect.reset_isolation_level(c2.connection.dbapi_connection)
|
||||
|
||||
self._test_conn_autocommits(conn, False)
|
||||
|
||||
def test_autocommit_off(self, connection_no_trans):
|
||||
conn = connection_no_trans
|
||||
self._test_conn_autocommits(conn, False)
|
||||
|
||||
def test_turn_autocommit_off_via_default_iso_level(
|
||||
self, connection_no_trans
|
||||
):
|
||||
conn = connection_no_trans
|
||||
conn = conn.execution_options(isolation_level="AUTOCOMMIT")
|
||||
self._test_conn_autocommits(conn, True)
|
||||
|
||||
conn.execution_options(
|
||||
isolation_level=requirements.get_isolation_levels(config)[
|
||||
"default"
|
||||
]
|
||||
)
|
||||
self._test_conn_autocommits(conn, False)
|
||||
|
||||
@testing.requires.independent_readonly_connections
|
||||
@testing.variation("use_dialect_setting", [True, False])
|
||||
def test_dialect_autocommit_is_restored(
|
||||
self, testing_engine, use_dialect_setting
|
||||
):
|
||||
"""test #10147"""
|
||||
|
||||
if use_dialect_setting:
|
||||
e = testing_engine(options={"isolation_level": "AUTOCOMMIT"})
|
||||
else:
|
||||
e = testing_engine().execution_options(
|
||||
isolation_level="AUTOCOMMIT"
|
||||
)
|
||||
|
||||
levels = requirements.get_isolation_levels(config)
|
||||
|
||||
default = levels["default"]
|
||||
|
||||
with e.connect() as conn:
|
||||
self._test_conn_autocommits(conn, True)
|
||||
|
||||
with e.connect() as conn:
|
||||
conn.execution_options(isolation_level=default)
|
||||
self._test_conn_autocommits(conn, False)
|
||||
|
||||
with e.connect() as conn:
|
||||
self._test_conn_autocommits(conn, True)
|
||||
|
||||
|
||||
class EscapingTest(fixtures.TestBase):
|
||||
@provide_metadata
|
||||
def test_percent_sign_round_trip(self):
|
||||
"""test that the DBAPI accommodates for escaped / nonescaped
|
||||
percent signs in a way that matches the compiler
|
||||
|
||||
"""
|
||||
m = self.metadata
|
||||
t = Table("t", m, Column("data", String(50)))
|
||||
t.create(config.db)
|
||||
with config.db.begin() as conn:
|
||||
conn.execute(t.insert(), dict(data="some % value"))
|
||||
conn.execute(t.insert(), dict(data="some %% other value"))
|
||||
|
||||
eq_(
|
||||
conn.scalar(
|
||||
select(t.c.data).where(
|
||||
t.c.data == literal_column("'some % value'")
|
||||
)
|
||||
),
|
||||
"some % value",
|
||||
)
|
||||
|
||||
eq_(
|
||||
conn.scalar(
|
||||
select(t.c.data).where(
|
||||
t.c.data == literal_column("'some %% other value'")
|
||||
)
|
||||
),
|
||||
"some %% other value",
|
||||
)
|
||||
|
||||
|
||||
class WeCanSetDefaultSchemaWEventsTest(fixtures.TestBase):
|
||||
__backend__ = True
|
||||
|
||||
__requires__ = ("default_schema_name_switch",)
|
||||
|
||||
def test_control_case(self):
|
||||
default_schema_name = config.db.dialect.default_schema_name
|
||||
|
||||
eng = engines.testing_engine()
|
||||
with eng.connect():
|
||||
pass
|
||||
|
||||
eq_(eng.dialect.default_schema_name, default_schema_name)
|
||||
|
||||
def test_wont_work_wo_insert(self):
|
||||
default_schema_name = config.db.dialect.default_schema_name
|
||||
|
||||
eng = engines.testing_engine()
|
||||
|
||||
@event.listens_for(eng, "connect")
|
||||
def on_connect(dbapi_connection, connection_record):
|
||||
set_default_schema_on_connection(
|
||||
config, dbapi_connection, config.test_schema
|
||||
)
|
||||
|
||||
with eng.connect() as conn:
|
||||
what_it_should_be = eng.dialect._get_default_schema_name(conn)
|
||||
eq_(what_it_should_be, config.test_schema)
|
||||
|
||||
eq_(eng.dialect.default_schema_name, default_schema_name)
|
||||
|
||||
def test_schema_change_on_connect(self):
|
||||
eng = engines.testing_engine()
|
||||
|
||||
@event.listens_for(eng, "connect", insert=True)
|
||||
def on_connect(dbapi_connection, connection_record):
|
||||
set_default_schema_on_connection(
|
||||
config, dbapi_connection, config.test_schema
|
||||
)
|
||||
|
||||
with eng.connect() as conn:
|
||||
what_it_should_be = eng.dialect._get_default_schema_name(conn)
|
||||
eq_(what_it_should_be, config.test_schema)
|
||||
|
||||
eq_(eng.dialect.default_schema_name, config.test_schema)
|
||||
|
||||
def test_schema_change_works_w_transactions(self):
|
||||
eng = engines.testing_engine()
|
||||
|
||||
@event.listens_for(eng, "connect", insert=True)
|
||||
def on_connect(dbapi_connection, *arg):
|
||||
set_default_schema_on_connection(
|
||||
config, dbapi_connection, config.test_schema
|
||||
)
|
||||
|
||||
with eng.connect() as conn:
|
||||
trans = conn.begin()
|
||||
what_it_should_be = eng.dialect._get_default_schema_name(conn)
|
||||
eq_(what_it_should_be, config.test_schema)
|
||||
trans.rollback()
|
||||
|
||||
what_it_should_be = eng.dialect._get_default_schema_name(conn)
|
||||
eq_(what_it_should_be, config.test_schema)
|
||||
|
||||
eq_(eng.dialect.default_schema_name, config.test_schema)
|
||||
|
||||
|
||||
class FutureWeCanSetDefaultSchemaWEventsTest(
|
||||
fixtures.FutureEngineMixin, WeCanSetDefaultSchemaWEventsTest
|
||||
):
|
||||
pass
|
||||
|
||||
|
||||
class DifficultParametersTest(fixtures.TestBase):
|
||||
__backend__ = True
|
||||
|
||||
tough_parameters = testing.combinations(
|
||||
("boring",),
|
||||
("per cent",),
|
||||
("per % cent",),
|
||||
("%percent",),
|
||||
("par(ens)",),
|
||||
("percent%(ens)yah",),
|
||||
("col:ons",),
|
||||
("_starts_with_underscore",),
|
||||
("dot.s",),
|
||||
("more :: %colons%",),
|
||||
("_name",),
|
||||
("___name",),
|
||||
("[BracketsAndCase]",),
|
||||
("42numbers",),
|
||||
("percent%signs",),
|
||||
("has spaces",),
|
||||
("/slashes/",),
|
||||
("more/slashes",),
|
||||
("q?marks",),
|
||||
("1param",),
|
||||
("1col:on",),
|
||||
argnames="paramname",
|
||||
)
|
||||
|
||||
@tough_parameters
|
||||
@config.requirements.unusual_column_name_characters
|
||||
def test_round_trip_same_named_column(
|
||||
self, paramname, connection, metadata
|
||||
):
|
||||
name = paramname
|
||||
|
||||
t = Table(
|
||||
"t",
|
||||
metadata,
|
||||
Column("id", Integer, primary_key=True),
|
||||
Column(name, String(50), nullable=False),
|
||||
)
|
||||
|
||||
# table is created
|
||||
t.create(connection)
|
||||
|
||||
# automatic param generated by insert
|
||||
connection.execute(t.insert().values({"id": 1, name: "some name"}))
|
||||
|
||||
# automatic param generated by criteria, plus selecting the column
|
||||
stmt = select(t.c[name]).where(t.c[name] == "some name")
|
||||
|
||||
eq_(connection.scalar(stmt), "some name")
|
||||
|
||||
# use the name in a param explicitly
|
||||
stmt = select(t.c[name]).where(t.c[name] == bindparam(name))
|
||||
|
||||
row = connection.execute(stmt, {name: "some name"}).first()
|
||||
|
||||
# name works as the key from cursor.description
|
||||
eq_(row._mapping[name], "some name")
|
||||
|
||||
# use expanding IN
|
||||
stmt = select(t.c[name]).where(
|
||||
t.c[name].in_(["some name", "some other_name"])
|
||||
)
|
||||
|
||||
row = connection.execute(stmt).first()
|
||||
|
||||
@testing.fixture
|
||||
def multirow_fixture(self, metadata, connection):
|
||||
mytable = Table(
|
||||
"mytable",
|
||||
metadata,
|
||||
Column("myid", Integer),
|
||||
Column("name", String(50)),
|
||||
Column("desc", String(50)),
|
||||
)
|
||||
|
||||
mytable.create(connection)
|
||||
|
||||
connection.execute(
|
||||
mytable.insert(),
|
||||
[
|
||||
{"myid": 1, "name": "a", "desc": "a_desc"},
|
||||
{"myid": 2, "name": "b", "desc": "b_desc"},
|
||||
{"myid": 3, "name": "c", "desc": "c_desc"},
|
||||
{"myid": 4, "name": "d", "desc": "d_desc"},
|
||||
],
|
||||
)
|
||||
yield mytable
|
||||
|
||||
@tough_parameters
|
||||
def test_standalone_bindparam_escape(
|
||||
self, paramname, connection, multirow_fixture
|
||||
):
|
||||
tbl1 = multirow_fixture
|
||||
stmt = select(tbl1.c.myid).where(
|
||||
tbl1.c.name == bindparam(paramname, value="x")
|
||||
)
|
||||
res = connection.scalar(stmt, {paramname: "c"})
|
||||
eq_(res, 3)
|
||||
|
||||
@tough_parameters
|
||||
def test_standalone_bindparam_escape_expanding(
|
||||
self, paramname, connection, multirow_fixture
|
||||
):
|
||||
tbl1 = multirow_fixture
|
||||
stmt = (
|
||||
select(tbl1.c.myid)
|
||||
.where(tbl1.c.name.in_(bindparam(paramname, value=["a", "b"])))
|
||||
.order_by(tbl1.c.myid)
|
||||
)
|
||||
|
||||
res = connection.scalars(stmt, {paramname: ["d", "a"]}).all()
|
||||
eq_(res, [1, 4])
|
||||
|
||||
|
||||
class ReturningGuardsTest(fixtures.TablesTest):
|
||||
"""test that the various 'returning' flags are set appropriately"""
|
||||
|
||||
__backend__ = True
|
||||
|
||||
@classmethod
|
||||
def define_tables(cls, metadata):
|
||||
Table(
|
||||
"t",
|
||||
metadata,
|
||||
Column("id", Integer, primary_key=True, autoincrement=False),
|
||||
Column("data", String(50)),
|
||||
)
|
||||
|
||||
@testing.fixture
|
||||
def run_stmt(self, connection):
|
||||
t = self.tables.t
|
||||
|
||||
def go(stmt, executemany, id_param_name, expect_success):
|
||||
stmt = stmt.returning(t.c.id)
|
||||
|
||||
if executemany:
|
||||
if not expect_success:
|
||||
# for RETURNING executemany(), we raise our own
|
||||
# error as this is independent of general RETURNING
|
||||
# support
|
||||
with expect_raises_message(
|
||||
exc.StatementError,
|
||||
rf"Dialect {connection.dialect.name}\+"
|
||||
f"{connection.dialect.driver} with "
|
||||
f"current server capabilities does not support "
|
||||
f".*RETURNING when executemany is used",
|
||||
):
|
||||
result = connection.execute(
|
||||
stmt,
|
||||
[
|
||||
{id_param_name: 1, "data": "d1"},
|
||||
{id_param_name: 2, "data": "d2"},
|
||||
{id_param_name: 3, "data": "d3"},
|
||||
],
|
||||
)
|
||||
else:
|
||||
result = connection.execute(
|
||||
stmt,
|
||||
[
|
||||
{id_param_name: 1, "data": "d1"},
|
||||
{id_param_name: 2, "data": "d2"},
|
||||
{id_param_name: 3, "data": "d3"},
|
||||
],
|
||||
)
|
||||
eq_(result.all(), [(1,), (2,), (3,)])
|
||||
else:
|
||||
if not expect_success:
|
||||
# for RETURNING execute(), we pass all the way to the DB
|
||||
# and let it fail
|
||||
with expect_raises(exc.DBAPIError):
|
||||
connection.execute(
|
||||
stmt, {id_param_name: 1, "data": "d1"}
|
||||
)
|
||||
else:
|
||||
result = connection.execute(
|
||||
stmt, {id_param_name: 1, "data": "d1"}
|
||||
)
|
||||
eq_(result.all(), [(1,)])
|
||||
|
||||
return go
|
||||
|
||||
def test_insert_single(self, connection, run_stmt):
|
||||
t = self.tables.t
|
||||
|
||||
stmt = t.insert()
|
||||
|
||||
run_stmt(stmt, False, "id", connection.dialect.insert_returning)
|
||||
|
||||
def test_insert_many(self, connection, run_stmt):
|
||||
t = self.tables.t
|
||||
|
||||
stmt = t.insert()
|
||||
|
||||
run_stmt(
|
||||
stmt, True, "id", connection.dialect.insert_executemany_returning
|
||||
)
|
||||
|
||||
def test_update_single(self, connection, run_stmt):
|
||||
t = self.tables.t
|
||||
|
||||
connection.execute(
|
||||
t.insert(),
|
||||
[
|
||||
{"id": 1, "data": "d1"},
|
||||
{"id": 2, "data": "d2"},
|
||||
{"id": 3, "data": "d3"},
|
||||
],
|
||||
)
|
||||
|
||||
stmt = t.update().where(t.c.id == bindparam("b_id"))
|
||||
|
||||
run_stmt(stmt, False, "b_id", connection.dialect.update_returning)
|
||||
|
||||
def test_update_many(self, connection, run_stmt):
|
||||
t = self.tables.t
|
||||
|
||||
connection.execute(
|
||||
t.insert(),
|
||||
[
|
||||
{"id": 1, "data": "d1"},
|
||||
{"id": 2, "data": "d2"},
|
||||
{"id": 3, "data": "d3"},
|
||||
],
|
||||
)
|
||||
|
||||
stmt = t.update().where(t.c.id == bindparam("b_id"))
|
||||
|
||||
run_stmt(
|
||||
stmt, True, "b_id", connection.dialect.update_executemany_returning
|
||||
)
|
||||
|
||||
def test_delete_single(self, connection, run_stmt):
|
||||
t = self.tables.t
|
||||
|
||||
connection.execute(
|
||||
t.insert(),
|
||||
[
|
||||
{"id": 1, "data": "d1"},
|
||||
{"id": 2, "data": "d2"},
|
||||
{"id": 3, "data": "d3"},
|
||||
],
|
||||
)
|
||||
|
||||
stmt = t.delete().where(t.c.id == bindparam("b_id"))
|
||||
|
||||
run_stmt(stmt, False, "b_id", connection.dialect.delete_returning)
|
||||
|
||||
def test_delete_many(self, connection, run_stmt):
|
||||
t = self.tables.t
|
||||
|
||||
connection.execute(
|
||||
t.insert(),
|
||||
[
|
||||
{"id": 1, "data": "d1"},
|
||||
{"id": 2, "data": "d2"},
|
||||
{"id": 3, "data": "d3"},
|
||||
],
|
||||
)
|
||||
|
||||
stmt = t.delete().where(t.c.id == bindparam("b_id"))
|
||||
|
||||
run_stmt(
|
||||
stmt, True, "b_id", connection.dialect.delete_executemany_returning
|
||||
)
|
|
@ -0,0 +1,630 @@
|
|||
# testing/suite/test_insert.py
|
||||
# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors
|
||||
# <see AUTHORS file>
|
||||
#
|
||||
# This module is part of SQLAlchemy and is released under
|
||||
# the MIT License: https://www.opensource.org/licenses/mit-license.php
|
||||
# mypy: ignore-errors
|
||||
|
||||
from decimal import Decimal
|
||||
import uuid
|
||||
|
||||
from . import testing
|
||||
from .. import fixtures
|
||||
from ..assertions import eq_
|
||||
from ..config import requirements
|
||||
from ..schema import Column
|
||||
from ..schema import Table
|
||||
from ... import Double
|
||||
from ... import Float
|
||||
from ... import Identity
|
||||
from ... import Integer
|
||||
from ... import literal
|
||||
from ... import literal_column
|
||||
from ... import Numeric
|
||||
from ... import select
|
||||
from ... import String
|
||||
from ...types import LargeBinary
|
||||
from ...types import UUID
|
||||
from ...types import Uuid
|
||||
|
||||
|
||||
class LastrowidTest(fixtures.TablesTest):
|
||||
run_deletes = "each"
|
||||
|
||||
__backend__ = True
|
||||
|
||||
__requires__ = "implements_get_lastrowid", "autoincrement_insert"
|
||||
|
||||
@classmethod
|
||||
def define_tables(cls, metadata):
|
||||
Table(
|
||||
"autoinc_pk",
|
||||
metadata,
|
||||
Column(
|
||||
"id", Integer, primary_key=True, test_needs_autoincrement=True
|
||||
),
|
||||
Column("data", String(50)),
|
||||
implicit_returning=False,
|
||||
)
|
||||
|
||||
Table(
|
||||
"manual_pk",
|
||||
metadata,
|
||||
Column("id", Integer, primary_key=True, autoincrement=False),
|
||||
Column("data", String(50)),
|
||||
implicit_returning=False,
|
||||
)
|
||||
|
||||
def _assert_round_trip(self, table, conn):
|
||||
row = conn.execute(table.select()).first()
|
||||
eq_(
|
||||
row,
|
||||
(
|
||||
conn.dialect.default_sequence_base,
|
||||
"some data",
|
||||
),
|
||||
)
|
||||
|
||||
def test_autoincrement_on_insert(self, connection):
|
||||
connection.execute(
|
||||
self.tables.autoinc_pk.insert(), dict(data="some data")
|
||||
)
|
||||
self._assert_round_trip(self.tables.autoinc_pk, connection)
|
||||
|
||||
def test_last_inserted_id(self, connection):
|
||||
r = connection.execute(
|
||||
self.tables.autoinc_pk.insert(), dict(data="some data")
|
||||
)
|
||||
pk = connection.scalar(select(self.tables.autoinc_pk.c.id))
|
||||
eq_(r.inserted_primary_key, (pk,))
|
||||
|
||||
@requirements.dbapi_lastrowid
|
||||
def test_native_lastrowid_autoinc(self, connection):
|
||||
r = connection.execute(
|
||||
self.tables.autoinc_pk.insert(), dict(data="some data")
|
||||
)
|
||||
lastrowid = r.lastrowid
|
||||
pk = connection.scalar(select(self.tables.autoinc_pk.c.id))
|
||||
eq_(lastrowid, pk)
|
||||
|
||||
|
||||
class InsertBehaviorTest(fixtures.TablesTest):
|
||||
run_deletes = "each"
|
||||
__backend__ = True
|
||||
|
||||
@classmethod
|
||||
def define_tables(cls, metadata):
|
||||
Table(
|
||||
"autoinc_pk",
|
||||
metadata,
|
||||
Column(
|
||||
"id", Integer, primary_key=True, test_needs_autoincrement=True
|
||||
),
|
||||
Column("data", String(50)),
|
||||
)
|
||||
Table(
|
||||
"manual_pk",
|
||||
metadata,
|
||||
Column("id", Integer, primary_key=True, autoincrement=False),
|
||||
Column("data", String(50)),
|
||||
)
|
||||
Table(
|
||||
"no_implicit_returning",
|
||||
metadata,
|
||||
Column(
|
||||
"id", Integer, primary_key=True, test_needs_autoincrement=True
|
||||
),
|
||||
Column("data", String(50)),
|
||||
implicit_returning=False,
|
||||
)
|
||||
Table(
|
||||
"includes_defaults",
|
||||
metadata,
|
||||
Column(
|
||||
"id", Integer, primary_key=True, test_needs_autoincrement=True
|
||||
),
|
||||
Column("data", String(50)),
|
||||
Column("x", Integer, default=5),
|
||||
Column(
|
||||
"y",
|
||||
Integer,
|
||||
default=literal_column("2", type_=Integer) + literal(2),
|
||||
),
|
||||
)
|
||||
|
||||
@testing.variation("style", ["plain", "return_defaults"])
|
||||
@testing.variation("executemany", [True, False])
|
||||
def test_no_results_for_non_returning_insert(
|
||||
self, connection, style, executemany
|
||||
):
|
||||
"""test another INSERT issue found during #10453"""
|
||||
|
||||
table = self.tables.no_implicit_returning
|
||||
|
||||
stmt = table.insert()
|
||||
if style.return_defaults:
|
||||
stmt = stmt.return_defaults()
|
||||
|
||||
if executemany:
|
||||
data = [
|
||||
{"data": "d1"},
|
||||
{"data": "d2"},
|
||||
{"data": "d3"},
|
||||
{"data": "d4"},
|
||||
{"data": "d5"},
|
||||
]
|
||||
else:
|
||||
data = {"data": "d1"}
|
||||
|
||||
r = connection.execute(stmt, data)
|
||||
assert not r.returns_rows
|
||||
|
||||
@requirements.autoincrement_insert
|
||||
def test_autoclose_on_insert(self, connection):
|
||||
r = connection.execute(
|
||||
self.tables.autoinc_pk.insert(), dict(data="some data")
|
||||
)
|
||||
assert r._soft_closed
|
||||
assert not r.closed
|
||||
assert r.is_insert
|
||||
|
||||
# new as of I8091919d45421e3f53029b8660427f844fee0228; for the moment
|
||||
# an insert where the PK was taken from a row that the dialect
|
||||
# selected, as is the case for mssql/pyodbc, will still report
|
||||
# returns_rows as true because there's a cursor description. in that
|
||||
# case, the row had to have been consumed at least.
|
||||
assert not r.returns_rows or r.fetchone() is None
|
||||
|
||||
@requirements.insert_returning
|
||||
def test_autoclose_on_insert_implicit_returning(self, connection):
|
||||
r = connection.execute(
|
||||
# return_defaults() ensures RETURNING will be used,
|
||||
# new in 2.0 as sqlite/mariadb offer both RETURNING and
|
||||
# cursor.lastrowid
|
||||
self.tables.autoinc_pk.insert().return_defaults(),
|
||||
dict(data="some data"),
|
||||
)
|
||||
assert r._soft_closed
|
||||
assert not r.closed
|
||||
assert r.is_insert
|
||||
|
||||
# note we are experimenting with having this be True
|
||||
# as of I8091919d45421e3f53029b8660427f844fee0228 .
|
||||
# implicit returning has fetched the row, but it still is a
|
||||
# "returns rows"
|
||||
assert r.returns_rows
|
||||
|
||||
# and we should be able to fetchone() on it, we just get no row
|
||||
eq_(r.fetchone(), None)
|
||||
|
||||
# and the keys, etc.
|
||||
eq_(r.keys(), ["id"])
|
||||
|
||||
# but the dialect took in the row already. not really sure
|
||||
# what the best behavior is.
|
||||
|
||||
@requirements.empty_inserts
|
||||
def test_empty_insert(self, connection):
|
||||
r = connection.execute(self.tables.autoinc_pk.insert())
|
||||
assert r._soft_closed
|
||||
assert not r.closed
|
||||
|
||||
r = connection.execute(
|
||||
self.tables.autoinc_pk.select().where(
|
||||
self.tables.autoinc_pk.c.id != None
|
||||
)
|
||||
)
|
||||
eq_(len(r.all()), 1)
|
||||
|
||||
@requirements.empty_inserts_executemany
|
||||
def test_empty_insert_multiple(self, connection):
|
||||
r = connection.execute(self.tables.autoinc_pk.insert(), [{}, {}, {}])
|
||||
assert r._soft_closed
|
||||
assert not r.closed
|
||||
|
||||
r = connection.execute(
|
||||
self.tables.autoinc_pk.select().where(
|
||||
self.tables.autoinc_pk.c.id != None
|
||||
)
|
||||
)
|
||||
|
||||
eq_(len(r.all()), 3)
|
||||
|
||||
@requirements.insert_from_select
|
||||
def test_insert_from_select_autoinc(self, connection):
|
||||
src_table = self.tables.manual_pk
|
||||
dest_table = self.tables.autoinc_pk
|
||||
connection.execute(
|
||||
src_table.insert(),
|
||||
[
|
||||
dict(id=1, data="data1"),
|
||||
dict(id=2, data="data2"),
|
||||
dict(id=3, data="data3"),
|
||||
],
|
||||
)
|
||||
|
||||
result = connection.execute(
|
||||
dest_table.insert().from_select(
|
||||
("data",),
|
||||
select(src_table.c.data).where(
|
||||
src_table.c.data.in_(["data2", "data3"])
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
eq_(result.inserted_primary_key, (None,))
|
||||
|
||||
result = connection.execute(
|
||||
select(dest_table.c.data).order_by(dest_table.c.data)
|
||||
)
|
||||
eq_(result.fetchall(), [("data2",), ("data3",)])
|
||||
|
||||
@requirements.insert_from_select
|
||||
def test_insert_from_select_autoinc_no_rows(self, connection):
|
||||
src_table = self.tables.manual_pk
|
||||
dest_table = self.tables.autoinc_pk
|
||||
|
||||
result = connection.execute(
|
||||
dest_table.insert().from_select(
|
||||
("data",),
|
||||
select(src_table.c.data).where(
|
||||
src_table.c.data.in_(["data2", "data3"])
|
||||
),
|
||||
)
|
||||
)
|
||||
eq_(result.inserted_primary_key, (None,))
|
||||
|
||||
result = connection.execute(
|
||||
select(dest_table.c.data).order_by(dest_table.c.data)
|
||||
)
|
||||
|
||||
eq_(result.fetchall(), [])
|
||||
|
||||
@requirements.insert_from_select
|
||||
def test_insert_from_select(self, connection):
|
||||
table = self.tables.manual_pk
|
||||
connection.execute(
|
||||
table.insert(),
|
||||
[
|
||||
dict(id=1, data="data1"),
|
||||
dict(id=2, data="data2"),
|
||||
dict(id=3, data="data3"),
|
||||
],
|
||||
)
|
||||
|
||||
connection.execute(
|
||||
table.insert()
|
||||
.inline()
|
||||
.from_select(
|
||||
("id", "data"),
|
||||
select(table.c.id + 5, table.c.data).where(
|
||||
table.c.data.in_(["data2", "data3"])
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
eq_(
|
||||
connection.execute(
|
||||
select(table.c.data).order_by(table.c.data)
|
||||
).fetchall(),
|
||||
[("data1",), ("data2",), ("data2",), ("data3",), ("data3",)],
|
||||
)
|
||||
|
||||
@requirements.insert_from_select
|
||||
def test_insert_from_select_with_defaults(self, connection):
|
||||
table = self.tables.includes_defaults
|
||||
connection.execute(
|
||||
table.insert(),
|
||||
[
|
||||
dict(id=1, data="data1"),
|
||||
dict(id=2, data="data2"),
|
||||
dict(id=3, data="data3"),
|
||||
],
|
||||
)
|
||||
|
||||
connection.execute(
|
||||
table.insert()
|
||||
.inline()
|
||||
.from_select(
|
||||
("id", "data"),
|
||||
select(table.c.id + 5, table.c.data).where(
|
||||
table.c.data.in_(["data2", "data3"])
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
eq_(
|
||||
connection.execute(
|
||||
select(table).order_by(table.c.data, table.c.id)
|
||||
).fetchall(),
|
||||
[
|
||||
(1, "data1", 5, 4),
|
||||
(2, "data2", 5, 4),
|
||||
(7, "data2", 5, 4),
|
||||
(3, "data3", 5, 4),
|
||||
(8, "data3", 5, 4),
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
class ReturningTest(fixtures.TablesTest):
|
||||
run_create_tables = "each"
|
||||
__requires__ = "insert_returning", "autoincrement_insert"
|
||||
__backend__ = True
|
||||
|
||||
def _assert_round_trip(self, table, conn):
|
||||
row = conn.execute(table.select()).first()
|
||||
eq_(
|
||||
row,
|
||||
(
|
||||
conn.dialect.default_sequence_base,
|
||||
"some data",
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def define_tables(cls, metadata):
|
||||
Table(
|
||||
"autoinc_pk",
|
||||
metadata,
|
||||
Column(
|
||||
"id", Integer, primary_key=True, test_needs_autoincrement=True
|
||||
),
|
||||
Column("data", String(50)),
|
||||
)
|
||||
|
||||
@requirements.fetch_rows_post_commit
|
||||
def test_explicit_returning_pk_autocommit(self, connection):
|
||||
table = self.tables.autoinc_pk
|
||||
r = connection.execute(
|
||||
table.insert().returning(table.c.id), dict(data="some data")
|
||||
)
|
||||
pk = r.first()[0]
|
||||
fetched_pk = connection.scalar(select(table.c.id))
|
||||
eq_(fetched_pk, pk)
|
||||
|
||||
def test_explicit_returning_pk_no_autocommit(self, connection):
|
||||
table = self.tables.autoinc_pk
|
||||
r = connection.execute(
|
||||
table.insert().returning(table.c.id), dict(data="some data")
|
||||
)
|
||||
|
||||
pk = r.first()[0]
|
||||
fetched_pk = connection.scalar(select(table.c.id))
|
||||
eq_(fetched_pk, pk)
|
||||
|
||||
def test_autoincrement_on_insert_implicit_returning(self, connection):
|
||||
connection.execute(
|
||||
self.tables.autoinc_pk.insert(), dict(data="some data")
|
||||
)
|
||||
self._assert_round_trip(self.tables.autoinc_pk, connection)
|
||||
|
||||
def test_last_inserted_id_implicit_returning(self, connection):
|
||||
r = connection.execute(
|
||||
self.tables.autoinc_pk.insert(), dict(data="some data")
|
||||
)
|
||||
pk = connection.scalar(select(self.tables.autoinc_pk.c.id))
|
||||
eq_(r.inserted_primary_key, (pk,))
|
||||
|
||||
@requirements.insert_executemany_returning
|
||||
def test_insertmanyvalues_returning(self, connection):
|
||||
r = connection.execute(
|
||||
self.tables.autoinc_pk.insert().returning(
|
||||
self.tables.autoinc_pk.c.id
|
||||
),
|
||||
[
|
||||
{"data": "d1"},
|
||||
{"data": "d2"},
|
||||
{"data": "d3"},
|
||||
{"data": "d4"},
|
||||
{"data": "d5"},
|
||||
],
|
||||
)
|
||||
rall = r.all()
|
||||
|
||||
pks = connection.execute(select(self.tables.autoinc_pk.c.id))
|
||||
|
||||
eq_(rall, pks.all())
|
||||
|
||||
@testing.combinations(
|
||||
(Double(), 8.5514716, True),
|
||||
(
|
||||
Double(53),
|
||||
8.5514716,
|
||||
True,
|
||||
testing.requires.float_or_double_precision_behaves_generically,
|
||||
),
|
||||
(Float(), 8.5514, True),
|
||||
(
|
||||
Float(8),
|
||||
8.5514,
|
||||
True,
|
||||
testing.requires.float_or_double_precision_behaves_generically,
|
||||
),
|
||||
(
|
||||
Numeric(precision=15, scale=12, asdecimal=False),
|
||||
8.5514716,
|
||||
True,
|
||||
testing.requires.literal_float_coercion,
|
||||
),
|
||||
(
|
||||
Numeric(precision=15, scale=12, asdecimal=True),
|
||||
Decimal("8.5514716"),
|
||||
False,
|
||||
),
|
||||
argnames="type_,value,do_rounding",
|
||||
)
|
||||
@testing.variation("sort_by_parameter_order", [True, False])
|
||||
@testing.variation("multiple_rows", [True, False])
|
||||
def test_insert_w_floats(
|
||||
self,
|
||||
connection,
|
||||
metadata,
|
||||
sort_by_parameter_order,
|
||||
type_,
|
||||
value,
|
||||
do_rounding,
|
||||
multiple_rows,
|
||||
):
|
||||
"""test #9701.
|
||||
|
||||
this tests insertmanyvalues as well as decimal / floating point
|
||||
RETURNING types
|
||||
|
||||
"""
|
||||
|
||||
t = Table(
|
||||
# Oracle backends seems to be getting confused if
|
||||
# this table is named the same as the one
|
||||
# in test_imv_returning_datatypes. use a different name
|
||||
"f_t",
|
||||
metadata,
|
||||
Column("id", Integer, Identity(), primary_key=True),
|
||||
Column("value", type_),
|
||||
)
|
||||
|
||||
t.create(connection)
|
||||
|
||||
result = connection.execute(
|
||||
t.insert().returning(
|
||||
t.c.id,
|
||||
t.c.value,
|
||||
sort_by_parameter_order=bool(sort_by_parameter_order),
|
||||
),
|
||||
(
|
||||
[{"value": value} for i in range(10)]
|
||||
if multiple_rows
|
||||
else {"value": value}
|
||||
),
|
||||
)
|
||||
|
||||
if multiple_rows:
|
||||
i_range = range(1, 11)
|
||||
else:
|
||||
i_range = range(1, 2)
|
||||
|
||||
# we want to test only that we are getting floating points back
|
||||
# with some degree of the original value maintained, that it is not
|
||||
# being truncated to an integer. there's too much variation in how
|
||||
# drivers return floats, which should not be relied upon to be
|
||||
# exact, for us to just compare as is (works for PG drivers but not
|
||||
# others) so we use rounding here. There's precedent for this
|
||||
# in suite/test_types.py::NumericTest as well
|
||||
|
||||
if do_rounding:
|
||||
eq_(
|
||||
{(id_, round(val_, 5)) for id_, val_ in result},
|
||||
{(id_, round(value, 5)) for id_ in i_range},
|
||||
)
|
||||
|
||||
eq_(
|
||||
{
|
||||
round(val_, 5)
|
||||
for val_ in connection.scalars(select(t.c.value))
|
||||
},
|
||||
{round(value, 5)},
|
||||
)
|
||||
else:
|
||||
eq_(
|
||||
set(result),
|
||||
{(id_, value) for id_ in i_range},
|
||||
)
|
||||
|
||||
eq_(
|
||||
set(connection.scalars(select(t.c.value))),
|
||||
{value},
|
||||
)
|
||||
|
||||
@testing.combinations(
|
||||
(
|
||||
"non_native_uuid",
|
||||
Uuid(native_uuid=False),
|
||||
uuid.uuid4(),
|
||||
),
|
||||
(
|
||||
"non_native_uuid_str",
|
||||
Uuid(as_uuid=False, native_uuid=False),
|
||||
str(uuid.uuid4()),
|
||||
),
|
||||
(
|
||||
"generic_native_uuid",
|
||||
Uuid(native_uuid=True),
|
||||
uuid.uuid4(),
|
||||
testing.requires.uuid_data_type,
|
||||
),
|
||||
(
|
||||
"generic_native_uuid_str",
|
||||
Uuid(as_uuid=False, native_uuid=True),
|
||||
str(uuid.uuid4()),
|
||||
testing.requires.uuid_data_type,
|
||||
),
|
||||
("UUID", UUID(), uuid.uuid4(), testing.requires.uuid_data_type),
|
||||
(
|
||||
"LargeBinary1",
|
||||
LargeBinary(),
|
||||
b"this is binary",
|
||||
),
|
||||
("LargeBinary2", LargeBinary(), b"7\xe7\x9f"),
|
||||
argnames="type_,value",
|
||||
id_="iaa",
|
||||
)
|
||||
@testing.variation("sort_by_parameter_order", [True, False])
|
||||
@testing.variation("multiple_rows", [True, False])
|
||||
@testing.requires.insert_returning
|
||||
def test_imv_returning_datatypes(
|
||||
self,
|
||||
connection,
|
||||
metadata,
|
||||
sort_by_parameter_order,
|
||||
type_,
|
||||
value,
|
||||
multiple_rows,
|
||||
):
|
||||
"""test #9739, #9808 (similar to #9701).
|
||||
|
||||
this tests insertmanyvalues in conjunction with various datatypes.
|
||||
|
||||
These tests are particularly for the asyncpg driver which needs
|
||||
most types to be explicitly cast for the new IMV format
|
||||
|
||||
"""
|
||||
t = Table(
|
||||
"d_t",
|
||||
metadata,
|
||||
Column("id", Integer, Identity(), primary_key=True),
|
||||
Column("value", type_),
|
||||
)
|
||||
|
||||
t.create(connection)
|
||||
|
||||
result = connection.execute(
|
||||
t.insert().returning(
|
||||
t.c.id,
|
||||
t.c.value,
|
||||
sort_by_parameter_order=bool(sort_by_parameter_order),
|
||||
),
|
||||
(
|
||||
[{"value": value} for i in range(10)]
|
||||
if multiple_rows
|
||||
else {"value": value}
|
||||
),
|
||||
)
|
||||
|
||||
if multiple_rows:
|
||||
i_range = range(1, 11)
|
||||
else:
|
||||
i_range = range(1, 2)
|
||||
|
||||
eq_(
|
||||
set(result),
|
||||
{(id_, value) for id_ in i_range},
|
||||
)
|
||||
|
||||
eq_(
|
||||
set(connection.scalars(select(t.c.value))),
|
||||
{value},
|
||||
)
|
||||
|
||||
|
||||
__all__ = ("LastrowidTest", "InsertBehaviorTest", "ReturningTest")
|
File diff suppressed because it is too large
Load diff
|
@ -0,0 +1,468 @@
|
|||
# testing/suite/test_results.py
|
||||
# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors
|
||||
# <see AUTHORS file>
|
||||
#
|
||||
# This module is part of SQLAlchemy and is released under
|
||||
# the MIT License: https://www.opensource.org/licenses/mit-license.php
|
||||
# mypy: ignore-errors
|
||||
|
||||
import datetime
|
||||
|
||||
from .. import engines
|
||||
from .. import fixtures
|
||||
from ..assertions import eq_
|
||||
from ..config import requirements
|
||||
from ..schema import Column
|
||||
from ..schema import Table
|
||||
from ... import DateTime
|
||||
from ... import func
|
||||
from ... import Integer
|
||||
from ... import select
|
||||
from ... import sql
|
||||
from ... import String
|
||||
from ... import testing
|
||||
from ... import text
|
||||
|
||||
|
||||
class RowFetchTest(fixtures.TablesTest):
|
||||
__backend__ = True
|
||||
|
||||
@classmethod
|
||||
def define_tables(cls, metadata):
|
||||
Table(
|
||||
"plain_pk",
|
||||
metadata,
|
||||
Column("id", Integer, primary_key=True),
|
||||
Column("data", String(50)),
|
||||
)
|
||||
Table(
|
||||
"has_dates",
|
||||
metadata,
|
||||
Column("id", Integer, primary_key=True),
|
||||
Column("today", DateTime),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def insert_data(cls, connection):
|
||||
connection.execute(
|
||||
cls.tables.plain_pk.insert(),
|
||||
[
|
||||
{"id": 1, "data": "d1"},
|
||||
{"id": 2, "data": "d2"},
|
||||
{"id": 3, "data": "d3"},
|
||||
],
|
||||
)
|
||||
|
||||
connection.execute(
|
||||
cls.tables.has_dates.insert(),
|
||||
[{"id": 1, "today": datetime.datetime(2006, 5, 12, 12, 0, 0)}],
|
||||
)
|
||||
|
||||
def test_via_attr(self, connection):
|
||||
row = connection.execute(
|
||||
self.tables.plain_pk.select().order_by(self.tables.plain_pk.c.id)
|
||||
).first()
|
||||
|
||||
eq_(row.id, 1)
|
||||
eq_(row.data, "d1")
|
||||
|
||||
def test_via_string(self, connection):
|
||||
row = connection.execute(
|
||||
self.tables.plain_pk.select().order_by(self.tables.plain_pk.c.id)
|
||||
).first()
|
||||
|
||||
eq_(row._mapping["id"], 1)
|
||||
eq_(row._mapping["data"], "d1")
|
||||
|
||||
def test_via_int(self, connection):
|
||||
row = connection.execute(
|
||||
self.tables.plain_pk.select().order_by(self.tables.plain_pk.c.id)
|
||||
).first()
|
||||
|
||||
eq_(row[0], 1)
|
||||
eq_(row[1], "d1")
|
||||
|
||||
def test_via_col_object(self, connection):
|
||||
row = connection.execute(
|
||||
self.tables.plain_pk.select().order_by(self.tables.plain_pk.c.id)
|
||||
).first()
|
||||
|
||||
eq_(row._mapping[self.tables.plain_pk.c.id], 1)
|
||||
eq_(row._mapping[self.tables.plain_pk.c.data], "d1")
|
||||
|
||||
@requirements.duplicate_names_in_cursor_description
|
||||
def test_row_with_dupe_names(self, connection):
|
||||
result = connection.execute(
|
||||
select(
|
||||
self.tables.plain_pk.c.data,
|
||||
self.tables.plain_pk.c.data.label("data"),
|
||||
).order_by(self.tables.plain_pk.c.id)
|
||||
)
|
||||
row = result.first()
|
||||
eq_(result.keys(), ["data", "data"])
|
||||
eq_(row, ("d1", "d1"))
|
||||
|
||||
def test_row_w_scalar_select(self, connection):
|
||||
"""test that a scalar select as a column is returned as such
|
||||
and that type conversion works OK.
|
||||
|
||||
(this is half a SQLAlchemy Core test and half to catch database
|
||||
backends that may have unusual behavior with scalar selects.)
|
||||
|
||||
"""
|
||||
datetable = self.tables.has_dates
|
||||
s = select(datetable.alias("x").c.today).scalar_subquery()
|
||||
s2 = select(datetable.c.id, s.label("somelabel"))
|
||||
row = connection.execute(s2).first()
|
||||
|
||||
eq_(row.somelabel, datetime.datetime(2006, 5, 12, 12, 0, 0))
|
||||
|
||||
|
||||
class PercentSchemaNamesTest(fixtures.TablesTest):
|
||||
"""tests using percent signs, spaces in table and column names.
|
||||
|
||||
This didn't work for PostgreSQL / MySQL drivers for a long time
|
||||
but is now supported.
|
||||
|
||||
"""
|
||||
|
||||
__requires__ = ("percent_schema_names",)
|
||||
|
||||
__backend__ = True
|
||||
|
||||
@classmethod
|
||||
def define_tables(cls, metadata):
|
||||
cls.tables.percent_table = Table(
|
||||
"percent%table",
|
||||
metadata,
|
||||
Column("percent%", Integer),
|
||||
Column("spaces % more spaces", Integer),
|
||||
)
|
||||
cls.tables.lightweight_percent_table = sql.table(
|
||||
"percent%table",
|
||||
sql.column("percent%"),
|
||||
sql.column("spaces % more spaces"),
|
||||
)
|
||||
|
||||
def test_single_roundtrip(self, connection):
|
||||
percent_table = self.tables.percent_table
|
||||
for params in [
|
||||
{"percent%": 5, "spaces % more spaces": 12},
|
||||
{"percent%": 7, "spaces % more spaces": 11},
|
||||
{"percent%": 9, "spaces % more spaces": 10},
|
||||
{"percent%": 11, "spaces % more spaces": 9},
|
||||
]:
|
||||
connection.execute(percent_table.insert(), params)
|
||||
self._assert_table(connection)
|
||||
|
||||
def test_executemany_roundtrip(self, connection):
|
||||
percent_table = self.tables.percent_table
|
||||
connection.execute(
|
||||
percent_table.insert(), {"percent%": 5, "spaces % more spaces": 12}
|
||||
)
|
||||
connection.execute(
|
||||
percent_table.insert(),
|
||||
[
|
||||
{"percent%": 7, "spaces % more spaces": 11},
|
||||
{"percent%": 9, "spaces % more spaces": 10},
|
||||
{"percent%": 11, "spaces % more spaces": 9},
|
||||
],
|
||||
)
|
||||
self._assert_table(connection)
|
||||
|
||||
@requirements.insert_executemany_returning
|
||||
def test_executemany_returning_roundtrip(self, connection):
|
||||
percent_table = self.tables.percent_table
|
||||
connection.execute(
|
||||
percent_table.insert(), {"percent%": 5, "spaces % more spaces": 12}
|
||||
)
|
||||
result = connection.execute(
|
||||
percent_table.insert().returning(
|
||||
percent_table.c["percent%"],
|
||||
percent_table.c["spaces % more spaces"],
|
||||
),
|
||||
[
|
||||
{"percent%": 7, "spaces % more spaces": 11},
|
||||
{"percent%": 9, "spaces % more spaces": 10},
|
||||
{"percent%": 11, "spaces % more spaces": 9},
|
||||
],
|
||||
)
|
||||
eq_(result.all(), [(7, 11), (9, 10), (11, 9)])
|
||||
self._assert_table(connection)
|
||||
|
||||
def _assert_table(self, conn):
|
||||
percent_table = self.tables.percent_table
|
||||
lightweight_percent_table = self.tables.lightweight_percent_table
|
||||
|
||||
for table in (
|
||||
percent_table,
|
||||
percent_table.alias(),
|
||||
lightweight_percent_table,
|
||||
lightweight_percent_table.alias(),
|
||||
):
|
||||
eq_(
|
||||
list(
|
||||
conn.execute(table.select().order_by(table.c["percent%"]))
|
||||
),
|
||||
[(5, 12), (7, 11), (9, 10), (11, 9)],
|
||||
)
|
||||
|
||||
eq_(
|
||||
list(
|
||||
conn.execute(
|
||||
table.select()
|
||||
.where(table.c["spaces % more spaces"].in_([9, 10]))
|
||||
.order_by(table.c["percent%"])
|
||||
)
|
||||
),
|
||||
[(9, 10), (11, 9)],
|
||||
)
|
||||
|
||||
row = conn.execute(
|
||||
table.select().order_by(table.c["percent%"])
|
||||
).first()
|
||||
eq_(row._mapping["percent%"], 5)
|
||||
eq_(row._mapping["spaces % more spaces"], 12)
|
||||
|
||||
eq_(row._mapping[table.c["percent%"]], 5)
|
||||
eq_(row._mapping[table.c["spaces % more spaces"]], 12)
|
||||
|
||||
conn.execute(
|
||||
percent_table.update().values(
|
||||
{percent_table.c["spaces % more spaces"]: 15}
|
||||
)
|
||||
)
|
||||
|
||||
eq_(
|
||||
list(
|
||||
conn.execute(
|
||||
percent_table.select().order_by(
|
||||
percent_table.c["percent%"]
|
||||
)
|
||||
)
|
||||
),
|
||||
[(5, 15), (7, 15), (9, 15), (11, 15)],
|
||||
)
|
||||
|
||||
|
||||
class ServerSideCursorsTest(
|
||||
fixtures.TestBase, testing.AssertsExecutionResults
|
||||
):
|
||||
__requires__ = ("server_side_cursors",)
|
||||
|
||||
__backend__ = True
|
||||
|
||||
def _is_server_side(self, cursor):
|
||||
# TODO: this is a huge issue as it prevents these tests from being
|
||||
# usable by third party dialects.
|
||||
if self.engine.dialect.driver == "psycopg2":
|
||||
return bool(cursor.name)
|
||||
elif self.engine.dialect.driver == "pymysql":
|
||||
sscursor = __import__("pymysql.cursors").cursors.SSCursor
|
||||
return isinstance(cursor, sscursor)
|
||||
elif self.engine.dialect.driver in ("aiomysql", "asyncmy", "aioodbc"):
|
||||
return cursor.server_side
|
||||
elif self.engine.dialect.driver == "mysqldb":
|
||||
sscursor = __import__("MySQLdb.cursors").cursors.SSCursor
|
||||
return isinstance(cursor, sscursor)
|
||||
elif self.engine.dialect.driver == "mariadbconnector":
|
||||
return not cursor.buffered
|
||||
elif self.engine.dialect.driver in ("asyncpg", "aiosqlite"):
|
||||
return cursor.server_side
|
||||
elif self.engine.dialect.driver == "pg8000":
|
||||
return getattr(cursor, "server_side", False)
|
||||
elif self.engine.dialect.driver == "psycopg":
|
||||
return bool(getattr(cursor, "name", False))
|
||||
else:
|
||||
return False
|
||||
|
||||
def _fixture(self, server_side_cursors):
|
||||
if server_side_cursors:
|
||||
with testing.expect_deprecated(
|
||||
"The create_engine.server_side_cursors parameter is "
|
||||
"deprecated and will be removed in a future release. "
|
||||
"Please use the Connection.execution_options.stream_results "
|
||||
"parameter."
|
||||
):
|
||||
self.engine = engines.testing_engine(
|
||||
options={"server_side_cursors": server_side_cursors}
|
||||
)
|
||||
else:
|
||||
self.engine = engines.testing_engine(
|
||||
options={"server_side_cursors": server_side_cursors}
|
||||
)
|
||||
return self.engine
|
||||
|
||||
@testing.combinations(
|
||||
("global_string", True, "select 1", True),
|
||||
("global_text", True, text("select 1"), True),
|
||||
("global_expr", True, select(1), True),
|
||||
("global_off_explicit", False, text("select 1"), False),
|
||||
(
|
||||
"stmt_option",
|
||||
False,
|
||||
select(1).execution_options(stream_results=True),
|
||||
True,
|
||||
),
|
||||
(
|
||||
"stmt_option_disabled",
|
||||
True,
|
||||
select(1).execution_options(stream_results=False),
|
||||
False,
|
||||
),
|
||||
("for_update_expr", True, select(1).with_for_update(), True),
|
||||
# TODO: need a real requirement for this, or dont use this test
|
||||
(
|
||||
"for_update_string",
|
||||
True,
|
||||
"SELECT 1 FOR UPDATE",
|
||||
True,
|
||||
testing.skip_if(["sqlite", "mssql"]),
|
||||
),
|
||||
("text_no_ss", False, text("select 42"), False),
|
||||
(
|
||||
"text_ss_option",
|
||||
False,
|
||||
text("select 42").execution_options(stream_results=True),
|
||||
True,
|
||||
),
|
||||
id_="iaaa",
|
||||
argnames="engine_ss_arg, statement, cursor_ss_status",
|
||||
)
|
||||
def test_ss_cursor_status(
|
||||
self, engine_ss_arg, statement, cursor_ss_status
|
||||
):
|
||||
engine = self._fixture(engine_ss_arg)
|
||||
with engine.begin() as conn:
|
||||
if isinstance(statement, str):
|
||||
result = conn.exec_driver_sql(statement)
|
||||
else:
|
||||
result = conn.execute(statement)
|
||||
eq_(self._is_server_side(result.cursor), cursor_ss_status)
|
||||
result.close()
|
||||
|
||||
def test_conn_option(self):
|
||||
engine = self._fixture(False)
|
||||
|
||||
with engine.connect() as conn:
|
||||
# should be enabled for this one
|
||||
result = conn.execution_options(
|
||||
stream_results=True
|
||||
).exec_driver_sql("select 1")
|
||||
assert self._is_server_side(result.cursor)
|
||||
|
||||
# the connection has autobegun, which means at the end of the
|
||||
# block, we will roll back, which on MySQL at least will fail
|
||||
# with "Commands out of sync" if the result set
|
||||
# is not closed, so we close it first.
|
||||
#
|
||||
# fun fact! why did we not have this result.close() in this test
|
||||
# before 2.0? don't we roll back in the connection pool
|
||||
# unconditionally? yes! and in fact if you run this test in 1.4
|
||||
# with stdout shown, there is in fact "Exception during reset or
|
||||
# similar" with "Commands out sync" emitted a warning! 2.0's
|
||||
# architecture finds and fixes what was previously an expensive
|
||||
# silent error condition.
|
||||
result.close()
|
||||
|
||||
def test_stmt_enabled_conn_option_disabled(self):
|
||||
engine = self._fixture(False)
|
||||
|
||||
s = select(1).execution_options(stream_results=True)
|
||||
|
||||
with engine.connect() as conn:
|
||||
# not this one
|
||||
result = conn.execution_options(stream_results=False).execute(s)
|
||||
assert not self._is_server_side(result.cursor)
|
||||
|
||||
def test_aliases_and_ss(self):
|
||||
engine = self._fixture(False)
|
||||
s1 = (
|
||||
select(sql.literal_column("1").label("x"))
|
||||
.execution_options(stream_results=True)
|
||||
.subquery()
|
||||
)
|
||||
|
||||
# options don't propagate out when subquery is used as a FROM clause
|
||||
with engine.begin() as conn:
|
||||
result = conn.execute(s1.select())
|
||||
assert not self._is_server_side(result.cursor)
|
||||
result.close()
|
||||
|
||||
s2 = select(1).select_from(s1)
|
||||
with engine.begin() as conn:
|
||||
result = conn.execute(s2)
|
||||
assert not self._is_server_side(result.cursor)
|
||||
result.close()
|
||||
|
||||
def test_roundtrip_fetchall(self, metadata):
|
||||
md = self.metadata
|
||||
|
||||
engine = self._fixture(True)
|
||||
test_table = Table(
|
||||
"test_table",
|
||||
md,
|
||||
Column("id", Integer, primary_key=True),
|
||||
Column("data", String(50)),
|
||||
)
|
||||
|
||||
with engine.begin() as connection:
|
||||
test_table.create(connection, checkfirst=True)
|
||||
connection.execute(test_table.insert(), dict(data="data1"))
|
||||
connection.execute(test_table.insert(), dict(data="data2"))
|
||||
eq_(
|
||||
connection.execute(
|
||||
test_table.select().order_by(test_table.c.id)
|
||||
).fetchall(),
|
||||
[(1, "data1"), (2, "data2")],
|
||||
)
|
||||
connection.execute(
|
||||
test_table.update()
|
||||
.where(test_table.c.id == 2)
|
||||
.values(data=test_table.c.data + " updated")
|
||||
)
|
||||
eq_(
|
||||
connection.execute(
|
||||
test_table.select().order_by(test_table.c.id)
|
||||
).fetchall(),
|
||||
[(1, "data1"), (2, "data2 updated")],
|
||||
)
|
||||
connection.execute(test_table.delete())
|
||||
eq_(
|
||||
connection.scalar(
|
||||
select(func.count("*")).select_from(test_table)
|
||||
),
|
||||
0,
|
||||
)
|
||||
|
||||
def test_roundtrip_fetchmany(self, metadata):
|
||||
md = self.metadata
|
||||
|
||||
engine = self._fixture(True)
|
||||
test_table = Table(
|
||||
"test_table",
|
||||
md,
|
||||
Column("id", Integer, primary_key=True),
|
||||
Column("data", String(50)),
|
||||
)
|
||||
|
||||
with engine.begin() as connection:
|
||||
test_table.create(connection, checkfirst=True)
|
||||
connection.execute(
|
||||
test_table.insert(),
|
||||
[dict(data="data%d" % i) for i in range(1, 20)],
|
||||
)
|
||||
|
||||
result = connection.execute(
|
||||
test_table.select().order_by(test_table.c.id)
|
||||
)
|
||||
|
||||
eq_(
|
||||
result.fetchmany(5),
|
||||
[(i, "data%d" % i) for i in range(1, 6)],
|
||||
)
|
||||
eq_(
|
||||
result.fetchmany(10),
|
||||
[(i, "data%d" % i) for i in range(6, 16)],
|
||||
)
|
||||
eq_(result.fetchall(), [(i, "data%d" % i) for i in range(16, 20)])
|
|
@ -0,0 +1,258 @@
|
|||
# testing/suite/test_rowcount.py
|
||||
# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors
|
||||
# <see AUTHORS file>
|
||||
#
|
||||
# This module is part of SQLAlchemy and is released under
|
||||
# the MIT License: https://www.opensource.org/licenses/mit-license.php
|
||||
# mypy: ignore-errors
|
||||
|
||||
from sqlalchemy import bindparam
|
||||
from sqlalchemy import Column
|
||||
from sqlalchemy import Integer
|
||||
from sqlalchemy import MetaData
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy import String
|
||||
from sqlalchemy import Table
|
||||
from sqlalchemy import testing
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy.testing import eq_
|
||||
from sqlalchemy.testing import fixtures
|
||||
|
||||
|
||||
class RowCountTest(fixtures.TablesTest):
|
||||
"""test rowcount functionality"""
|
||||
|
||||
__requires__ = ("sane_rowcount",)
|
||||
__backend__ = True
|
||||
|
||||
@classmethod
|
||||
def define_tables(cls, metadata):
|
||||
Table(
|
||||
"employees",
|
||||
metadata,
|
||||
Column(
|
||||
"employee_id",
|
||||
Integer,
|
||||
autoincrement=False,
|
||||
primary_key=True,
|
||||
),
|
||||
Column("name", String(50)),
|
||||
Column("department", String(1)),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def insert_data(cls, connection):
|
||||
cls.data = data = [
|
||||
("Angela", "A"),
|
||||
("Andrew", "A"),
|
||||
("Anand", "A"),
|
||||
("Bob", "B"),
|
||||
("Bobette", "B"),
|
||||
("Buffy", "B"),
|
||||
("Charlie", "C"),
|
||||
("Cynthia", "C"),
|
||||
("Chris", "C"),
|
||||
]
|
||||
|
||||
employees_table = cls.tables.employees
|
||||
connection.execute(
|
||||
employees_table.insert(),
|
||||
[
|
||||
{"employee_id": i, "name": n, "department": d}
|
||||
for i, (n, d) in enumerate(data)
|
||||
],
|
||||
)
|
||||
|
||||
def test_basic(self, connection):
|
||||
employees_table = self.tables.employees
|
||||
s = select(
|
||||
employees_table.c.name, employees_table.c.department
|
||||
).order_by(employees_table.c.employee_id)
|
||||
rows = connection.execute(s).fetchall()
|
||||
|
||||
eq_(rows, self.data)
|
||||
|
||||
@testing.variation("statement", ["update", "delete", "insert", "select"])
|
||||
@testing.variation("close_first", [True, False])
|
||||
def test_non_rowcount_scenarios_no_raise(
|
||||
self, connection, statement, close_first
|
||||
):
|
||||
employees_table = self.tables.employees
|
||||
|
||||
# WHERE matches 3, 3 rows changed
|
||||
department = employees_table.c.department
|
||||
|
||||
if statement.update:
|
||||
r = connection.execute(
|
||||
employees_table.update().where(department == "C"),
|
||||
{"department": "Z"},
|
||||
)
|
||||
elif statement.delete:
|
||||
r = connection.execute(
|
||||
employees_table.delete().where(department == "C"),
|
||||
{"department": "Z"},
|
||||
)
|
||||
elif statement.insert:
|
||||
r = connection.execute(
|
||||
employees_table.insert(),
|
||||
[
|
||||
{"employee_id": 25, "name": "none 1", "department": "X"},
|
||||
{"employee_id": 26, "name": "none 2", "department": "Z"},
|
||||
{"employee_id": 27, "name": "none 3", "department": "Z"},
|
||||
],
|
||||
)
|
||||
elif statement.select:
|
||||
s = select(
|
||||
employees_table.c.name, employees_table.c.department
|
||||
).where(employees_table.c.department == "C")
|
||||
r = connection.execute(s)
|
||||
r.all()
|
||||
else:
|
||||
statement.fail()
|
||||
|
||||
if close_first:
|
||||
r.close()
|
||||
|
||||
assert r.rowcount in (-1, 3)
|
||||
|
||||
def test_update_rowcount1(self, connection):
|
||||
employees_table = self.tables.employees
|
||||
|
||||
# WHERE matches 3, 3 rows changed
|
||||
department = employees_table.c.department
|
||||
r = connection.execute(
|
||||
employees_table.update().where(department == "C"),
|
||||
{"department": "Z"},
|
||||
)
|
||||
assert r.rowcount == 3
|
||||
|
||||
def test_update_rowcount2(self, connection):
|
||||
employees_table = self.tables.employees
|
||||
|
||||
# WHERE matches 3, 0 rows changed
|
||||
department = employees_table.c.department
|
||||
|
||||
r = connection.execute(
|
||||
employees_table.update().where(department == "C"),
|
||||
{"department": "C"},
|
||||
)
|
||||
eq_(r.rowcount, 3)
|
||||
|
||||
@testing.variation("implicit_returning", [True, False])
|
||||
@testing.variation(
|
||||
"dml",
|
||||
[
|
||||
("update", testing.requires.update_returning),
|
||||
("delete", testing.requires.delete_returning),
|
||||
],
|
||||
)
|
||||
def test_update_delete_rowcount_return_defaults(
|
||||
self, connection, implicit_returning, dml
|
||||
):
|
||||
"""note this test should succeed for all RETURNING backends
|
||||
as of 2.0. In
|
||||
Idf28379f8705e403a3c6a937f6a798a042ef2540 we changed rowcount to use
|
||||
len(rows) when we have implicit returning
|
||||
|
||||
"""
|
||||
|
||||
if implicit_returning:
|
||||
employees_table = self.tables.employees
|
||||
else:
|
||||
employees_table = Table(
|
||||
"employees",
|
||||
MetaData(),
|
||||
Column(
|
||||
"employee_id",
|
||||
Integer,
|
||||
autoincrement=False,
|
||||
primary_key=True,
|
||||
),
|
||||
Column("name", String(50)),
|
||||
Column("department", String(1)),
|
||||
implicit_returning=False,
|
||||
)
|
||||
|
||||
department = employees_table.c.department
|
||||
|
||||
if dml.update:
|
||||
stmt = (
|
||||
employees_table.update()
|
||||
.where(department == "C")
|
||||
.values(name=employees_table.c.department + "Z")
|
||||
.return_defaults()
|
||||
)
|
||||
elif dml.delete:
|
||||
stmt = (
|
||||
employees_table.delete()
|
||||
.where(department == "C")
|
||||
.return_defaults()
|
||||
)
|
||||
else:
|
||||
dml.fail()
|
||||
|
||||
r = connection.execute(stmt)
|
||||
eq_(r.rowcount, 3)
|
||||
|
||||
def test_raw_sql_rowcount(self, connection):
|
||||
# test issue #3622, make sure eager rowcount is called for text
|
||||
result = connection.exec_driver_sql(
|
||||
"update employees set department='Z' where department='C'"
|
||||
)
|
||||
eq_(result.rowcount, 3)
|
||||
|
||||
def test_text_rowcount(self, connection):
|
||||
# test issue #3622, make sure eager rowcount is called for text
|
||||
result = connection.execute(
|
||||
text("update employees set department='Z' where department='C'")
|
||||
)
|
||||
eq_(result.rowcount, 3)
|
||||
|
||||
def test_delete_rowcount(self, connection):
|
||||
employees_table = self.tables.employees
|
||||
|
||||
# WHERE matches 3, 3 rows deleted
|
||||
department = employees_table.c.department
|
||||
r = connection.execute(
|
||||
employees_table.delete().where(department == "C")
|
||||
)
|
||||
eq_(r.rowcount, 3)
|
||||
|
||||
@testing.requires.sane_multi_rowcount
|
||||
def test_multi_update_rowcount(self, connection):
|
||||
employees_table = self.tables.employees
|
||||
stmt = (
|
||||
employees_table.update()
|
||||
.where(employees_table.c.name == bindparam("emp_name"))
|
||||
.values(department="C")
|
||||
)
|
||||
|
||||
r = connection.execute(
|
||||
stmt,
|
||||
[
|
||||
{"emp_name": "Bob"},
|
||||
{"emp_name": "Cynthia"},
|
||||
{"emp_name": "nonexistent"},
|
||||
],
|
||||
)
|
||||
|
||||
eq_(r.rowcount, 2)
|
||||
|
||||
@testing.requires.sane_multi_rowcount
|
||||
def test_multi_delete_rowcount(self, connection):
|
||||
employees_table = self.tables.employees
|
||||
|
||||
stmt = employees_table.delete().where(
|
||||
employees_table.c.name == bindparam("emp_name")
|
||||
)
|
||||
|
||||
r = connection.execute(
|
||||
stmt,
|
||||
[
|
||||
{"emp_name": "Bob"},
|
||||
{"emp_name": "Cynthia"},
|
||||
{"emp_name": "nonexistent"},
|
||||
],
|
||||
)
|
||||
|
||||
eq_(r.rowcount, 2)
|
File diff suppressed because it is too large
Load diff
|
@ -0,0 +1,317 @@
|
|||
# testing/suite/test_sequence.py
|
||||
# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors
|
||||
# <see AUTHORS file>
|
||||
#
|
||||
# This module is part of SQLAlchemy and is released under
|
||||
# the MIT License: https://www.opensource.org/licenses/mit-license.php
|
||||
# mypy: ignore-errors
|
||||
|
||||
from .. import config
|
||||
from .. import fixtures
|
||||
from ..assertions import eq_
|
||||
from ..assertions import is_true
|
||||
from ..config import requirements
|
||||
from ..provision import normalize_sequence
|
||||
from ..schema import Column
|
||||
from ..schema import Table
|
||||
from ... import inspect
|
||||
from ... import Integer
|
||||
from ... import MetaData
|
||||
from ... import Sequence
|
||||
from ... import String
|
||||
from ... import testing
|
||||
|
||||
|
||||
class SequenceTest(fixtures.TablesTest):
|
||||
__requires__ = ("sequences",)
|
||||
__backend__ = True
|
||||
|
||||
run_create_tables = "each"
|
||||
|
||||
@classmethod
|
||||
def define_tables(cls, metadata):
|
||||
Table(
|
||||
"seq_pk",
|
||||
metadata,
|
||||
Column(
|
||||
"id",
|
||||
Integer,
|
||||
normalize_sequence(config, Sequence("tab_id_seq")),
|
||||
primary_key=True,
|
||||
),
|
||||
Column("data", String(50)),
|
||||
)
|
||||
|
||||
Table(
|
||||
"seq_opt_pk",
|
||||
metadata,
|
||||
Column(
|
||||
"id",
|
||||
Integer,
|
||||
normalize_sequence(
|
||||
config,
|
||||
Sequence("tab_id_seq", data_type=Integer, optional=True),
|
||||
),
|
||||
primary_key=True,
|
||||
),
|
||||
Column("data", String(50)),
|
||||
)
|
||||
|
||||
Table(
|
||||
"seq_no_returning",
|
||||
metadata,
|
||||
Column(
|
||||
"id",
|
||||
Integer,
|
||||
normalize_sequence(config, Sequence("noret_id_seq")),
|
||||
primary_key=True,
|
||||
),
|
||||
Column("data", String(50)),
|
||||
implicit_returning=False,
|
||||
)
|
||||
|
||||
if testing.requires.schemas.enabled:
|
||||
Table(
|
||||
"seq_no_returning_sch",
|
||||
metadata,
|
||||
Column(
|
||||
"id",
|
||||
Integer,
|
||||
normalize_sequence(
|
||||
config,
|
||||
Sequence(
|
||||
"noret_sch_id_seq", schema=config.test_schema
|
||||
),
|
||||
),
|
||||
primary_key=True,
|
||||
),
|
||||
Column("data", String(50)),
|
||||
implicit_returning=False,
|
||||
schema=config.test_schema,
|
||||
)
|
||||
|
||||
def test_insert_roundtrip(self, connection):
|
||||
connection.execute(self.tables.seq_pk.insert(), dict(data="some data"))
|
||||
self._assert_round_trip(self.tables.seq_pk, connection)
|
||||
|
||||
def test_insert_lastrowid(self, connection):
|
||||
r = connection.execute(
|
||||
self.tables.seq_pk.insert(), dict(data="some data")
|
||||
)
|
||||
eq_(
|
||||
r.inserted_primary_key, (testing.db.dialect.default_sequence_base,)
|
||||
)
|
||||
|
||||
def test_nextval_direct(self, connection):
|
||||
r = connection.scalar(self.tables.seq_pk.c.id.default)
|
||||
eq_(r, testing.db.dialect.default_sequence_base)
|
||||
|
||||
@requirements.sequences_optional
|
||||
def test_optional_seq(self, connection):
|
||||
r = connection.execute(
|
||||
self.tables.seq_opt_pk.insert(), dict(data="some data")
|
||||
)
|
||||
eq_(r.inserted_primary_key, (1,))
|
||||
|
||||
def _assert_round_trip(self, table, conn):
|
||||
row = conn.execute(table.select()).first()
|
||||
eq_(row, (testing.db.dialect.default_sequence_base, "some data"))
|
||||
|
||||
def test_insert_roundtrip_no_implicit_returning(self, connection):
|
||||
connection.execute(
|
||||
self.tables.seq_no_returning.insert(), dict(data="some data")
|
||||
)
|
||||
self._assert_round_trip(self.tables.seq_no_returning, connection)
|
||||
|
||||
@testing.combinations((True,), (False,), argnames="implicit_returning")
|
||||
@testing.requires.schemas
|
||||
def test_insert_roundtrip_translate(self, connection, implicit_returning):
|
||||
seq_no_returning = Table(
|
||||
"seq_no_returning_sch",
|
||||
MetaData(),
|
||||
Column(
|
||||
"id",
|
||||
Integer,
|
||||
normalize_sequence(
|
||||
config, Sequence("noret_sch_id_seq", schema="alt_schema")
|
||||
),
|
||||
primary_key=True,
|
||||
),
|
||||
Column("data", String(50)),
|
||||
implicit_returning=implicit_returning,
|
||||
schema="alt_schema",
|
||||
)
|
||||
|
||||
connection = connection.execution_options(
|
||||
schema_translate_map={"alt_schema": config.test_schema}
|
||||
)
|
||||
connection.execute(seq_no_returning.insert(), dict(data="some data"))
|
||||
self._assert_round_trip(seq_no_returning, connection)
|
||||
|
||||
@testing.requires.schemas
|
||||
def test_nextval_direct_schema_translate(self, connection):
|
||||
seq = normalize_sequence(
|
||||
config, Sequence("noret_sch_id_seq", schema="alt_schema")
|
||||
)
|
||||
connection = connection.execution_options(
|
||||
schema_translate_map={"alt_schema": config.test_schema}
|
||||
)
|
||||
|
||||
r = connection.scalar(seq)
|
||||
eq_(r, testing.db.dialect.default_sequence_base)
|
||||
|
||||
|
||||
class SequenceCompilerTest(testing.AssertsCompiledSQL, fixtures.TestBase):
|
||||
__requires__ = ("sequences",)
|
||||
__backend__ = True
|
||||
|
||||
def test_literal_binds_inline_compile(self, connection):
|
||||
table = Table(
|
||||
"x",
|
||||
MetaData(),
|
||||
Column(
|
||||
"y", Integer, normalize_sequence(config, Sequence("y_seq"))
|
||||
),
|
||||
Column("q", Integer),
|
||||
)
|
||||
|
||||
stmt = table.insert().values(q=5)
|
||||
|
||||
seq_nextval = connection.dialect.statement_compiler(
|
||||
statement=None, dialect=connection.dialect
|
||||
).visit_sequence(normalize_sequence(config, Sequence("y_seq")))
|
||||
self.assert_compile(
|
||||
stmt,
|
||||
"INSERT INTO x (y, q) VALUES (%s, 5)" % (seq_nextval,),
|
||||
literal_binds=True,
|
||||
dialect=connection.dialect,
|
||||
)
|
||||
|
||||
|
||||
class HasSequenceTest(fixtures.TablesTest):
|
||||
run_deletes = None
|
||||
|
||||
__requires__ = ("sequences",)
|
||||
__backend__ = True
|
||||
|
||||
@classmethod
|
||||
def define_tables(cls, metadata):
|
||||
normalize_sequence(config, Sequence("user_id_seq", metadata=metadata))
|
||||
normalize_sequence(
|
||||
config,
|
||||
Sequence(
|
||||
"other_seq",
|
||||
metadata=metadata,
|
||||
nomaxvalue=True,
|
||||
nominvalue=True,
|
||||
),
|
||||
)
|
||||
if testing.requires.schemas.enabled:
|
||||
normalize_sequence(
|
||||
config,
|
||||
Sequence(
|
||||
"user_id_seq", schema=config.test_schema, metadata=metadata
|
||||
),
|
||||
)
|
||||
normalize_sequence(
|
||||
config,
|
||||
Sequence(
|
||||
"schema_seq", schema=config.test_schema, metadata=metadata
|
||||
),
|
||||
)
|
||||
Table(
|
||||
"user_id_table",
|
||||
metadata,
|
||||
Column("id", Integer, primary_key=True),
|
||||
)
|
||||
|
||||
def test_has_sequence(self, connection):
|
||||
eq_(inspect(connection).has_sequence("user_id_seq"), True)
|
||||
|
||||
def test_has_sequence_cache(self, connection, metadata):
|
||||
insp = inspect(connection)
|
||||
eq_(insp.has_sequence("user_id_seq"), True)
|
||||
ss = normalize_sequence(config, Sequence("new_seq", metadata=metadata))
|
||||
eq_(insp.has_sequence("new_seq"), False)
|
||||
ss.create(connection)
|
||||
try:
|
||||
eq_(insp.has_sequence("new_seq"), False)
|
||||
insp.clear_cache()
|
||||
eq_(insp.has_sequence("new_seq"), True)
|
||||
finally:
|
||||
ss.drop(connection)
|
||||
|
||||
def test_has_sequence_other_object(self, connection):
|
||||
eq_(inspect(connection).has_sequence("user_id_table"), False)
|
||||
|
||||
@testing.requires.schemas
|
||||
def test_has_sequence_schema(self, connection):
|
||||
eq_(
|
||||
inspect(connection).has_sequence(
|
||||
"user_id_seq", schema=config.test_schema
|
||||
),
|
||||
True,
|
||||
)
|
||||
|
||||
def test_has_sequence_neg(self, connection):
|
||||
eq_(inspect(connection).has_sequence("some_sequence"), False)
|
||||
|
||||
@testing.requires.schemas
|
||||
def test_has_sequence_schemas_neg(self, connection):
|
||||
eq_(
|
||||
inspect(connection).has_sequence(
|
||||
"some_sequence", schema=config.test_schema
|
||||
),
|
||||
False,
|
||||
)
|
||||
|
||||
@testing.requires.schemas
|
||||
def test_has_sequence_default_not_in_remote(self, connection):
|
||||
eq_(
|
||||
inspect(connection).has_sequence(
|
||||
"other_sequence", schema=config.test_schema
|
||||
),
|
||||
False,
|
||||
)
|
||||
|
||||
@testing.requires.schemas
|
||||
def test_has_sequence_remote_not_in_default(self, connection):
|
||||
eq_(inspect(connection).has_sequence("schema_seq"), False)
|
||||
|
||||
def test_get_sequence_names(self, connection):
|
||||
exp = {"other_seq", "user_id_seq"}
|
||||
|
||||
res = set(inspect(connection).get_sequence_names())
|
||||
is_true(res.intersection(exp) == exp)
|
||||
is_true("schema_seq" not in res)
|
||||
|
||||
@testing.requires.schemas
|
||||
def test_get_sequence_names_no_sequence_schema(self, connection):
|
||||
eq_(
|
||||
inspect(connection).get_sequence_names(
|
||||
schema=config.test_schema_2
|
||||
),
|
||||
[],
|
||||
)
|
||||
|
||||
@testing.requires.schemas
|
||||
def test_get_sequence_names_sequences_schema(self, connection):
|
||||
eq_(
|
||||
sorted(
|
||||
inspect(connection).get_sequence_names(
|
||||
schema=config.test_schema
|
||||
)
|
||||
),
|
||||
["schema_seq", "user_id_seq"],
|
||||
)
|
||||
|
||||
|
||||
class HasSequenceTestEmpty(fixtures.TestBase):
|
||||
__requires__ = ("sequences",)
|
||||
__backend__ = True
|
||||
|
||||
def test_get_sequence_names_no_sequence(self, connection):
|
||||
eq_(
|
||||
inspect(connection).get_sequence_names(),
|
||||
[],
|
||||
)
|
File diff suppressed because it is too large
Load diff
|
@ -0,0 +1,189 @@
|
|||
# testing/suite/test_unicode_ddl.py
|
||||
# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors
|
||||
# <see AUTHORS file>
|
||||
#
|
||||
# This module is part of SQLAlchemy and is released under
|
||||
# the MIT License: https://www.opensource.org/licenses/mit-license.php
|
||||
# mypy: ignore-errors
|
||||
|
||||
|
||||
from sqlalchemy import desc
|
||||
from sqlalchemy import ForeignKey
|
||||
from sqlalchemy import Integer
|
||||
from sqlalchemy import MetaData
|
||||
from sqlalchemy import testing
|
||||
from sqlalchemy.testing import eq_
|
||||
from sqlalchemy.testing import fixtures
|
||||
from sqlalchemy.testing.schema import Column
|
||||
from sqlalchemy.testing.schema import Table
|
||||
|
||||
|
||||
class UnicodeSchemaTest(fixtures.TablesTest):
|
||||
__requires__ = ("unicode_ddl",)
|
||||
__backend__ = True
|
||||
|
||||
@classmethod
|
||||
def define_tables(cls, metadata):
|
||||
global t1, t2, t3
|
||||
|
||||
t1 = Table(
|
||||
"unitable1",
|
||||
metadata,
|
||||
Column("méil", Integer, primary_key=True),
|
||||
Column("\u6e2c\u8a66", Integer),
|
||||
test_needs_fk=True,
|
||||
)
|
||||
t2 = Table(
|
||||
"Unitéble2",
|
||||
metadata,
|
||||
Column("méil", Integer, primary_key=True, key="a"),
|
||||
Column(
|
||||
"\u6e2c\u8a66",
|
||||
Integer,
|
||||
ForeignKey("unitable1.méil"),
|
||||
key="b",
|
||||
),
|
||||
test_needs_fk=True,
|
||||
)
|
||||
|
||||
# Few DBs support Unicode foreign keys
|
||||
if testing.against("sqlite"):
|
||||
t3 = Table(
|
||||
"\u6e2c\u8a66",
|
||||
metadata,
|
||||
Column(
|
||||
"\u6e2c\u8a66_id",
|
||||
Integer,
|
||||
primary_key=True,
|
||||
autoincrement=False,
|
||||
),
|
||||
Column(
|
||||
"unitable1_\u6e2c\u8a66",
|
||||
Integer,
|
||||
ForeignKey("unitable1.\u6e2c\u8a66"),
|
||||
),
|
||||
Column("Unitéble2_b", Integer, ForeignKey("Unitéble2.b")),
|
||||
Column(
|
||||
"\u6e2c\u8a66_self",
|
||||
Integer,
|
||||
ForeignKey("\u6e2c\u8a66.\u6e2c\u8a66_id"),
|
||||
),
|
||||
test_needs_fk=True,
|
||||
)
|
||||
else:
|
||||
t3 = Table(
|
||||
"\u6e2c\u8a66",
|
||||
metadata,
|
||||
Column(
|
||||
"\u6e2c\u8a66_id",
|
||||
Integer,
|
||||
primary_key=True,
|
||||
autoincrement=False,
|
||||
),
|
||||
Column("unitable1_\u6e2c\u8a66", Integer),
|
||||
Column("Unitéble2_b", Integer),
|
||||
Column("\u6e2c\u8a66_self", Integer),
|
||||
test_needs_fk=True,
|
||||
)
|
||||
|
||||
def test_insert(self, connection):
|
||||
connection.execute(t1.insert(), {"méil": 1, "\u6e2c\u8a66": 5})
|
||||
connection.execute(t2.insert(), {"a": 1, "b": 1})
|
||||
connection.execute(
|
||||
t3.insert(),
|
||||
{
|
||||
"\u6e2c\u8a66_id": 1,
|
||||
"unitable1_\u6e2c\u8a66": 5,
|
||||
"Unitéble2_b": 1,
|
||||
"\u6e2c\u8a66_self": 1,
|
||||
},
|
||||
)
|
||||
|
||||
eq_(connection.execute(t1.select()).fetchall(), [(1, 5)])
|
||||
eq_(connection.execute(t2.select()).fetchall(), [(1, 1)])
|
||||
eq_(connection.execute(t3.select()).fetchall(), [(1, 5, 1, 1)])
|
||||
|
||||
def test_col_targeting(self, connection):
|
||||
connection.execute(t1.insert(), {"méil": 1, "\u6e2c\u8a66": 5})
|
||||
connection.execute(t2.insert(), {"a": 1, "b": 1})
|
||||
connection.execute(
|
||||
t3.insert(),
|
||||
{
|
||||
"\u6e2c\u8a66_id": 1,
|
||||
"unitable1_\u6e2c\u8a66": 5,
|
||||
"Unitéble2_b": 1,
|
||||
"\u6e2c\u8a66_self": 1,
|
||||
},
|
||||
)
|
||||
|
||||
row = connection.execute(t1.select()).first()
|
||||
eq_(row._mapping[t1.c["méil"]], 1)
|
||||
eq_(row._mapping[t1.c["\u6e2c\u8a66"]], 5)
|
||||
|
||||
row = connection.execute(t2.select()).first()
|
||||
eq_(row._mapping[t2.c["a"]], 1)
|
||||
eq_(row._mapping[t2.c["b"]], 1)
|
||||
|
||||
row = connection.execute(t3.select()).first()
|
||||
eq_(row._mapping[t3.c["\u6e2c\u8a66_id"]], 1)
|
||||
eq_(row._mapping[t3.c["unitable1_\u6e2c\u8a66"]], 5)
|
||||
eq_(row._mapping[t3.c["Unitéble2_b"]], 1)
|
||||
eq_(row._mapping[t3.c["\u6e2c\u8a66_self"]], 1)
|
||||
|
||||
def test_reflect(self, connection):
|
||||
connection.execute(t1.insert(), {"méil": 2, "\u6e2c\u8a66": 7})
|
||||
connection.execute(t2.insert(), {"a": 2, "b": 2})
|
||||
connection.execute(
|
||||
t3.insert(),
|
||||
{
|
||||
"\u6e2c\u8a66_id": 2,
|
||||
"unitable1_\u6e2c\u8a66": 7,
|
||||
"Unitéble2_b": 2,
|
||||
"\u6e2c\u8a66_self": 2,
|
||||
},
|
||||
)
|
||||
|
||||
meta = MetaData()
|
||||
tt1 = Table(t1.name, meta, autoload_with=connection)
|
||||
tt2 = Table(t2.name, meta, autoload_with=connection)
|
||||
tt3 = Table(t3.name, meta, autoload_with=connection)
|
||||
|
||||
connection.execute(tt1.insert(), {"méil": 1, "\u6e2c\u8a66": 5})
|
||||
connection.execute(tt2.insert(), {"méil": 1, "\u6e2c\u8a66": 1})
|
||||
connection.execute(
|
||||
tt3.insert(),
|
||||
{
|
||||
"\u6e2c\u8a66_id": 1,
|
||||
"unitable1_\u6e2c\u8a66": 5,
|
||||
"Unitéble2_b": 1,
|
||||
"\u6e2c\u8a66_self": 1,
|
||||
},
|
||||
)
|
||||
|
||||
eq_(
|
||||
connection.execute(tt1.select().order_by(desc("méil"))).fetchall(),
|
||||
[(2, 7), (1, 5)],
|
||||
)
|
||||
eq_(
|
||||
connection.execute(tt2.select().order_by(desc("méil"))).fetchall(),
|
||||
[(2, 2), (1, 1)],
|
||||
)
|
||||
eq_(
|
||||
connection.execute(
|
||||
tt3.select().order_by(desc("\u6e2c\u8a66_id"))
|
||||
).fetchall(),
|
||||
[(2, 7, 2, 2), (1, 5, 1, 1)],
|
||||
)
|
||||
|
||||
def test_repr(self):
|
||||
meta = MetaData()
|
||||
t = Table("\u6e2c\u8a66", meta, Column("\u6e2c\u8a66_id", Integer))
|
||||
eq_(
|
||||
repr(t),
|
||||
(
|
||||
"Table('測試', MetaData(), "
|
||||
"Column('測試_id', Integer(), "
|
||||
"table=<測試>), "
|
||||
"schema=None)"
|
||||
),
|
||||
)
|
|
@ -0,0 +1,139 @@
|
|||
# testing/suite/test_update_delete.py
|
||||
# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors
|
||||
# <see AUTHORS file>
|
||||
#
|
||||
# This module is part of SQLAlchemy and is released under
|
||||
# the MIT License: https://www.opensource.org/licenses/mit-license.php
|
||||
# mypy: ignore-errors
|
||||
|
||||
from .. import fixtures
|
||||
from ..assertions import eq_
|
||||
from ..schema import Column
|
||||
from ..schema import Table
|
||||
from ... import Integer
|
||||
from ... import String
|
||||
from ... import testing
|
||||
|
||||
|
||||
class SimpleUpdateDeleteTest(fixtures.TablesTest):
|
||||
run_deletes = "each"
|
||||
__requires__ = ("sane_rowcount",)
|
||||
__backend__ = True
|
||||
|
||||
@classmethod
|
||||
def define_tables(cls, metadata):
|
||||
Table(
|
||||
"plain_pk",
|
||||
metadata,
|
||||
Column("id", Integer, primary_key=True),
|
||||
Column("data", String(50)),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def insert_data(cls, connection):
|
||||
connection.execute(
|
||||
cls.tables.plain_pk.insert(),
|
||||
[
|
||||
{"id": 1, "data": "d1"},
|
||||
{"id": 2, "data": "d2"},
|
||||
{"id": 3, "data": "d3"},
|
||||
],
|
||||
)
|
||||
|
||||
def test_update(self, connection):
|
||||
t = self.tables.plain_pk
|
||||
r = connection.execute(
|
||||
t.update().where(t.c.id == 2), dict(data="d2_new")
|
||||
)
|
||||
assert not r.is_insert
|
||||
assert not r.returns_rows
|
||||
assert r.rowcount == 1
|
||||
|
||||
eq_(
|
||||
connection.execute(t.select().order_by(t.c.id)).fetchall(),
|
||||
[(1, "d1"), (2, "d2_new"), (3, "d3")],
|
||||
)
|
||||
|
||||
def test_delete(self, connection):
|
||||
t = self.tables.plain_pk
|
||||
r = connection.execute(t.delete().where(t.c.id == 2))
|
||||
assert not r.is_insert
|
||||
assert not r.returns_rows
|
||||
assert r.rowcount == 1
|
||||
eq_(
|
||||
connection.execute(t.select().order_by(t.c.id)).fetchall(),
|
||||
[(1, "d1"), (3, "d3")],
|
||||
)
|
||||
|
||||
@testing.variation("criteria", ["rows", "norows", "emptyin"])
|
||||
@testing.requires.update_returning
|
||||
def test_update_returning(self, connection, criteria):
|
||||
t = self.tables.plain_pk
|
||||
|
||||
stmt = t.update().returning(t.c.id, t.c.data)
|
||||
|
||||
if criteria.norows:
|
||||
stmt = stmt.where(t.c.id == 10)
|
||||
elif criteria.rows:
|
||||
stmt = stmt.where(t.c.id == 2)
|
||||
elif criteria.emptyin:
|
||||
stmt = stmt.where(t.c.id.in_([]))
|
||||
else:
|
||||
criteria.fail()
|
||||
|
||||
r = connection.execute(stmt, dict(data="d2_new"))
|
||||
assert not r.is_insert
|
||||
assert r.returns_rows
|
||||
eq_(r.keys(), ["id", "data"])
|
||||
|
||||
if criteria.rows:
|
||||
eq_(r.all(), [(2, "d2_new")])
|
||||
else:
|
||||
eq_(r.all(), [])
|
||||
|
||||
eq_(
|
||||
connection.execute(t.select().order_by(t.c.id)).fetchall(),
|
||||
(
|
||||
[(1, "d1"), (2, "d2_new"), (3, "d3")]
|
||||
if criteria.rows
|
||||
else [(1, "d1"), (2, "d2"), (3, "d3")]
|
||||
),
|
||||
)
|
||||
|
||||
@testing.variation("criteria", ["rows", "norows", "emptyin"])
|
||||
@testing.requires.delete_returning
|
||||
def test_delete_returning(self, connection, criteria):
|
||||
t = self.tables.plain_pk
|
||||
|
||||
stmt = t.delete().returning(t.c.id, t.c.data)
|
||||
|
||||
if criteria.norows:
|
||||
stmt = stmt.where(t.c.id == 10)
|
||||
elif criteria.rows:
|
||||
stmt = stmt.where(t.c.id == 2)
|
||||
elif criteria.emptyin:
|
||||
stmt = stmt.where(t.c.id.in_([]))
|
||||
else:
|
||||
criteria.fail()
|
||||
|
||||
r = connection.execute(stmt)
|
||||
assert not r.is_insert
|
||||
assert r.returns_rows
|
||||
eq_(r.keys(), ["id", "data"])
|
||||
|
||||
if criteria.rows:
|
||||
eq_(r.all(), [(2, "d2")])
|
||||
else:
|
||||
eq_(r.all(), [])
|
||||
|
||||
eq_(
|
||||
connection.execute(t.select().order_by(t.c.id)).fetchall(),
|
||||
(
|
||||
[(1, "d1"), (3, "d3")]
|
||||
if criteria.rows
|
||||
else [(1, "d1"), (2, "d2"), (3, "d3")]
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
__all__ = ("SimpleUpdateDeleteTest",)
|
519
elitebot/lib/python3.11/site-packages/sqlalchemy/testing/util.py
Normal file
519
elitebot/lib/python3.11/site-packages/sqlalchemy/testing/util.py
Normal file
|
@ -0,0 +1,519 @@
|
|||
# testing/util.py
|
||||
# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors
|
||||
# <see AUTHORS file>
|
||||
#
|
||||
# This module is part of SQLAlchemy and is released under
|
||||
# the MIT License: https://www.opensource.org/licenses/mit-license.php
|
||||
# mypy: ignore-errors
|
||||
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections import deque
|
||||
import decimal
|
||||
import gc
|
||||
from itertools import chain
|
||||
import random
|
||||
import sys
|
||||
from sys import getsizeof
|
||||
import types
|
||||
|
||||
from . import config
|
||||
from . import mock
|
||||
from .. import inspect
|
||||
from ..engine import Connection
|
||||
from ..schema import Column
|
||||
from ..schema import DropConstraint
|
||||
from ..schema import DropTable
|
||||
from ..schema import ForeignKeyConstraint
|
||||
from ..schema import MetaData
|
||||
from ..schema import Table
|
||||
from ..sql import schema
|
||||
from ..sql.sqltypes import Integer
|
||||
from ..util import decorator
|
||||
from ..util import defaultdict
|
||||
from ..util import has_refcount_gc
|
||||
from ..util import inspect_getfullargspec
|
||||
|
||||
|
||||
if not has_refcount_gc:
|
||||
|
||||
def non_refcount_gc_collect(*args):
|
||||
gc.collect()
|
||||
gc.collect()
|
||||
|
||||
gc_collect = lazy_gc = non_refcount_gc_collect
|
||||
else:
|
||||
# assume CPython - straight gc.collect, lazy_gc() is a pass
|
||||
gc_collect = gc.collect
|
||||
|
||||
def lazy_gc():
|
||||
pass
|
||||
|
||||
|
||||
def picklers():
|
||||
picklers = set()
|
||||
import pickle
|
||||
|
||||
picklers.add(pickle)
|
||||
|
||||
# yes, this thing needs this much testing
|
||||
for pickle_ in picklers:
|
||||
for protocol in range(-2, pickle.HIGHEST_PROTOCOL + 1):
|
||||
yield pickle_.loads, lambda d: pickle_.dumps(d, protocol)
|
||||
|
||||
|
||||
def random_choices(population, k=1):
|
||||
return random.choices(population, k=k)
|
||||
|
||||
|
||||
def round_decimal(value, prec):
|
||||
if isinstance(value, float):
|
||||
return round(value, prec)
|
||||
|
||||
# can also use shift() here but that is 2.6 only
|
||||
return (value * decimal.Decimal("1" + "0" * prec)).to_integral(
|
||||
decimal.ROUND_FLOOR
|
||||
) / pow(10, prec)
|
||||
|
||||
|
||||
class RandomSet(set):
|
||||
def __iter__(self):
|
||||
l = list(set.__iter__(self))
|
||||
random.shuffle(l)
|
||||
return iter(l)
|
||||
|
||||
def pop(self):
|
||||
index = random.randint(0, len(self) - 1)
|
||||
item = list(set.__iter__(self))[index]
|
||||
self.remove(item)
|
||||
return item
|
||||
|
||||
def union(self, other):
|
||||
return RandomSet(set.union(self, other))
|
||||
|
||||
def difference(self, other):
|
||||
return RandomSet(set.difference(self, other))
|
||||
|
||||
def intersection(self, other):
|
||||
return RandomSet(set.intersection(self, other))
|
||||
|
||||
def copy(self):
|
||||
return RandomSet(self)
|
||||
|
||||
|
||||
def conforms_partial_ordering(tuples, sorted_elements):
|
||||
"""True if the given sorting conforms to the given partial ordering."""
|
||||
|
||||
deps = defaultdict(set)
|
||||
for parent, child in tuples:
|
||||
deps[parent].add(child)
|
||||
for i, node in enumerate(sorted_elements):
|
||||
for n in sorted_elements[i:]:
|
||||
if node in deps[n]:
|
||||
return False
|
||||
else:
|
||||
return True
|
||||
|
||||
|
||||
def all_partial_orderings(tuples, elements):
|
||||
edges = defaultdict(set)
|
||||
for parent, child in tuples:
|
||||
edges[child].add(parent)
|
||||
|
||||
def _all_orderings(elements):
|
||||
if len(elements) == 1:
|
||||
yield list(elements)
|
||||
else:
|
||||
for elem in elements:
|
||||
subset = set(elements).difference([elem])
|
||||
if not subset.intersection(edges[elem]):
|
||||
for sub_ordering in _all_orderings(subset):
|
||||
yield [elem] + sub_ordering
|
||||
|
||||
return iter(_all_orderings(elements))
|
||||
|
||||
|
||||
def function_named(fn, name):
|
||||
"""Return a function with a given __name__.
|
||||
|
||||
Will assign to __name__ and return the original function if possible on
|
||||
the Python implementation, otherwise a new function will be constructed.
|
||||
|
||||
This function should be phased out as much as possible
|
||||
in favor of @decorator. Tests that "generate" many named tests
|
||||
should be modernized.
|
||||
|
||||
"""
|
||||
try:
|
||||
fn.__name__ = name
|
||||
except TypeError:
|
||||
fn = types.FunctionType(
|
||||
fn.__code__, fn.__globals__, name, fn.__defaults__, fn.__closure__
|
||||
)
|
||||
return fn
|
||||
|
||||
|
||||
def run_as_contextmanager(ctx, fn, *arg, **kw):
|
||||
"""Run the given function under the given contextmanager,
|
||||
simulating the behavior of 'with' to support older
|
||||
Python versions.
|
||||
|
||||
This is not necessary anymore as we have placed 2.6
|
||||
as minimum Python version, however some tests are still using
|
||||
this structure.
|
||||
|
||||
"""
|
||||
|
||||
obj = ctx.__enter__()
|
||||
try:
|
||||
result = fn(obj, *arg, **kw)
|
||||
ctx.__exit__(None, None, None)
|
||||
return result
|
||||
except:
|
||||
exc_info = sys.exc_info()
|
||||
raise_ = ctx.__exit__(*exc_info)
|
||||
if not raise_:
|
||||
raise
|
||||
else:
|
||||
return raise_
|
||||
|
||||
|
||||
def rowset(results):
|
||||
"""Converts the results of sql execution into a plain set of column tuples.
|
||||
|
||||
Useful for asserting the results of an unordered query.
|
||||
"""
|
||||
|
||||
return {tuple(row) for row in results}
|
||||
|
||||
|
||||
def fail(msg):
|
||||
assert False, msg
|
||||
|
||||
|
||||
@decorator
|
||||
def provide_metadata(fn, *args, **kw):
|
||||
"""Provide bound MetaData for a single test, dropping afterwards.
|
||||
|
||||
Legacy; use the "metadata" pytest fixture.
|
||||
|
||||
"""
|
||||
|
||||
from . import fixtures
|
||||
|
||||
metadata = schema.MetaData()
|
||||
self = args[0]
|
||||
prev_meta = getattr(self, "metadata", None)
|
||||
self.metadata = metadata
|
||||
try:
|
||||
return fn(*args, **kw)
|
||||
finally:
|
||||
# close out some things that get in the way of dropping tables.
|
||||
# when using the "metadata" fixture, there is a set ordering
|
||||
# of things that makes sure things are cleaned up in order, however
|
||||
# the simple "decorator" nature of this legacy function means
|
||||
# we have to hardcode some of that cleanup ahead of time.
|
||||
|
||||
# close ORM sessions
|
||||
fixtures.close_all_sessions()
|
||||
|
||||
# integrate with the "connection" fixture as there are many
|
||||
# tests where it is used along with provide_metadata
|
||||
cfc = fixtures.base._connection_fixture_connection
|
||||
if cfc:
|
||||
# TODO: this warning can be used to find all the places
|
||||
# this is used with connection fixture
|
||||
# warn("mixing legacy provide metadata with connection fixture")
|
||||
drop_all_tables_from_metadata(metadata, cfc)
|
||||
# as the provide_metadata fixture is often used with "testing.db",
|
||||
# when we do the drop we have to commit the transaction so that
|
||||
# the DB is actually updated as the CREATE would have been
|
||||
# committed
|
||||
cfc.get_transaction().commit()
|
||||
else:
|
||||
drop_all_tables_from_metadata(metadata, config.db)
|
||||
self.metadata = prev_meta
|
||||
|
||||
|
||||
def flag_combinations(*combinations):
|
||||
"""A facade around @testing.combinations() oriented towards boolean
|
||||
keyword-based arguments.
|
||||
|
||||
Basically generates a nice looking identifier based on the keywords
|
||||
and also sets up the argument names.
|
||||
|
||||
E.g.::
|
||||
|
||||
@testing.flag_combinations(
|
||||
dict(lazy=False, passive=False),
|
||||
dict(lazy=True, passive=False),
|
||||
dict(lazy=False, passive=True),
|
||||
dict(lazy=False, passive=True, raiseload=True),
|
||||
)
|
||||
|
||||
|
||||
would result in::
|
||||
|
||||
@testing.combinations(
|
||||
('', False, False, False),
|
||||
('lazy', True, False, False),
|
||||
('lazy_passive', True, True, False),
|
||||
('lazy_passive', True, True, True),
|
||||
id_='iaaa',
|
||||
argnames='lazy,passive,raiseload'
|
||||
)
|
||||
|
||||
"""
|
||||
|
||||
keys = set()
|
||||
|
||||
for d in combinations:
|
||||
keys.update(d)
|
||||
|
||||
keys = sorted(keys)
|
||||
|
||||
return config.combinations(
|
||||
*[
|
||||
("_".join(k for k in keys if d.get(k, False)),)
|
||||
+ tuple(d.get(k, False) for k in keys)
|
||||
for d in combinations
|
||||
],
|
||||
id_="i" + ("a" * len(keys)),
|
||||
argnames=",".join(keys),
|
||||
)
|
||||
|
||||
|
||||
def lambda_combinations(lambda_arg_sets, **kw):
|
||||
args = inspect_getfullargspec(lambda_arg_sets)
|
||||
|
||||
arg_sets = lambda_arg_sets(*[mock.Mock() for arg in args[0]])
|
||||
|
||||
def create_fixture(pos):
|
||||
def fixture(**kw):
|
||||
return lambda_arg_sets(**kw)[pos]
|
||||
|
||||
fixture.__name__ = "fixture_%3.3d" % pos
|
||||
return fixture
|
||||
|
||||
return config.combinations(
|
||||
*[(create_fixture(i),) for i in range(len(arg_sets))], **kw
|
||||
)
|
||||
|
||||
|
||||
def resolve_lambda(__fn, **kw):
|
||||
"""Given a no-arg lambda and a namespace, return a new lambda that
|
||||
has all the values filled in.
|
||||
|
||||
This is used so that we can have module-level fixtures that
|
||||
refer to instance-level variables using lambdas.
|
||||
|
||||
"""
|
||||
|
||||
pos_args = inspect_getfullargspec(__fn)[0]
|
||||
pass_pos_args = {arg: kw.pop(arg) for arg in pos_args}
|
||||
glb = dict(__fn.__globals__)
|
||||
glb.update(kw)
|
||||
new_fn = types.FunctionType(__fn.__code__, glb)
|
||||
return new_fn(**pass_pos_args)
|
||||
|
||||
|
||||
def metadata_fixture(ddl="function"):
|
||||
"""Provide MetaData for a pytest fixture."""
|
||||
|
||||
def decorate(fn):
|
||||
def run_ddl(self):
|
||||
metadata = self.metadata = schema.MetaData()
|
||||
try:
|
||||
result = fn(self, metadata)
|
||||
metadata.create_all(config.db)
|
||||
# TODO:
|
||||
# somehow get a per-function dml erase fixture here
|
||||
yield result
|
||||
finally:
|
||||
metadata.drop_all(config.db)
|
||||
|
||||
return config.fixture(scope=ddl)(run_ddl)
|
||||
|
||||
return decorate
|
||||
|
||||
|
||||
def force_drop_names(*names):
|
||||
"""Force the given table names to be dropped after test complete,
|
||||
isolating for foreign key cycles
|
||||
|
||||
"""
|
||||
|
||||
@decorator
|
||||
def go(fn, *args, **kw):
|
||||
try:
|
||||
return fn(*args, **kw)
|
||||
finally:
|
||||
drop_all_tables(config.db, inspect(config.db), include_names=names)
|
||||
|
||||
return go
|
||||
|
||||
|
||||
class adict(dict):
|
||||
"""Dict keys available as attributes. Shadows."""
|
||||
|
||||
def __getattribute__(self, key):
|
||||
try:
|
||||
return self[key]
|
||||
except KeyError:
|
||||
return dict.__getattribute__(self, key)
|
||||
|
||||
def __call__(self, *keys):
|
||||
return tuple([self[key] for key in keys])
|
||||
|
||||
get_all = __call__
|
||||
|
||||
|
||||
def drop_all_tables_from_metadata(metadata, engine_or_connection):
|
||||
from . import engines
|
||||
|
||||
def go(connection):
|
||||
engines.testing_reaper.prepare_for_drop_tables(connection)
|
||||
|
||||
if not connection.dialect.supports_alter:
|
||||
from . import assertions
|
||||
|
||||
with assertions.expect_warnings(
|
||||
"Can't sort tables", assert_=False
|
||||
):
|
||||
metadata.drop_all(connection)
|
||||
else:
|
||||
metadata.drop_all(connection)
|
||||
|
||||
if not isinstance(engine_or_connection, Connection):
|
||||
with engine_or_connection.begin() as connection:
|
||||
go(connection)
|
||||
else:
|
||||
go(engine_or_connection)
|
||||
|
||||
|
||||
def drop_all_tables(
|
||||
engine,
|
||||
inspector,
|
||||
schema=None,
|
||||
consider_schemas=(None,),
|
||||
include_names=None,
|
||||
):
|
||||
if include_names is not None:
|
||||
include_names = set(include_names)
|
||||
|
||||
if schema is not None:
|
||||
assert consider_schemas == (
|
||||
None,
|
||||
), "consider_schemas and schema are mutually exclusive"
|
||||
consider_schemas = (schema,)
|
||||
|
||||
with engine.begin() as conn:
|
||||
for table_key, fkcs in reversed(
|
||||
inspector.sort_tables_on_foreign_key_dependency(
|
||||
consider_schemas=consider_schemas
|
||||
)
|
||||
):
|
||||
if table_key:
|
||||
if (
|
||||
include_names is not None
|
||||
and table_key[1] not in include_names
|
||||
):
|
||||
continue
|
||||
conn.execute(
|
||||
DropTable(
|
||||
Table(table_key[1], MetaData(), schema=table_key[0])
|
||||
)
|
||||
)
|
||||
elif fkcs:
|
||||
if not engine.dialect.supports_alter:
|
||||
continue
|
||||
for t_key, fkc in fkcs:
|
||||
if (
|
||||
include_names is not None
|
||||
and t_key[1] not in include_names
|
||||
):
|
||||
continue
|
||||
tb = Table(
|
||||
t_key[1],
|
||||
MetaData(),
|
||||
Column("x", Integer),
|
||||
Column("y", Integer),
|
||||
schema=t_key[0],
|
||||
)
|
||||
conn.execute(
|
||||
DropConstraint(
|
||||
ForeignKeyConstraint([tb.c.x], [tb.c.y], name=fkc)
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def teardown_events(event_cls):
|
||||
@decorator
|
||||
def decorate(fn, *arg, **kw):
|
||||
try:
|
||||
return fn(*arg, **kw)
|
||||
finally:
|
||||
event_cls._clear()
|
||||
|
||||
return decorate
|
||||
|
||||
|
||||
def total_size(o):
|
||||
"""Returns the approximate memory footprint an object and all of its
|
||||
contents.
|
||||
|
||||
source: https://code.activestate.com/recipes/577504/
|
||||
|
||||
|
||||
"""
|
||||
|
||||
def dict_handler(d):
|
||||
return chain.from_iterable(d.items())
|
||||
|
||||
all_handlers = {
|
||||
tuple: iter,
|
||||
list: iter,
|
||||
deque: iter,
|
||||
dict: dict_handler,
|
||||
set: iter,
|
||||
frozenset: iter,
|
||||
}
|
||||
seen = set() # track which object id's have already been seen
|
||||
default_size = getsizeof(0) # estimate sizeof object without __sizeof__
|
||||
|
||||
def sizeof(o):
|
||||
if id(o) in seen: # do not double count the same object
|
||||
return 0
|
||||
seen.add(id(o))
|
||||
s = getsizeof(o, default_size)
|
||||
|
||||
for typ, handler in all_handlers.items():
|
||||
if isinstance(o, typ):
|
||||
s += sum(map(sizeof, handler(o)))
|
||||
break
|
||||
return s
|
||||
|
||||
return sizeof(o)
|
||||
|
||||
|
||||
def count_cache_key_tuples(tup):
|
||||
"""given a cache key tuple, counts how many instances of actual
|
||||
tuples are found.
|
||||
|
||||
used to alert large jumps in cache key complexity.
|
||||
|
||||
"""
|
||||
stack = [tup]
|
||||
|
||||
sentinel = object()
|
||||
num_elements = 0
|
||||
|
||||
while stack:
|
||||
elem = stack.pop(0)
|
||||
if elem is sentinel:
|
||||
num_elements += 1
|
||||
elif isinstance(elem, tuple):
|
||||
if elem:
|
||||
stack = list(elem) + [sentinel] + stack
|
||||
return num_elements
|
|
@ -0,0 +1,52 @@
|
|||
# testing/warnings.py
|
||||
# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors
|
||||
# <see AUTHORS file>
|
||||
#
|
||||
# This module is part of SQLAlchemy and is released under
|
||||
# the MIT License: https://www.opensource.org/licenses/mit-license.php
|
||||
# mypy: ignore-errors
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import warnings
|
||||
|
||||
from . import assertions
|
||||
from .. import exc
|
||||
from .. import exc as sa_exc
|
||||
from ..exc import SATestSuiteWarning
|
||||
from ..util.langhelpers import _warnings_warn
|
||||
|
||||
|
||||
def warn_test_suite(message):
|
||||
_warnings_warn(message, category=SATestSuiteWarning)
|
||||
|
||||
|
||||
def setup_filters():
|
||||
"""hook for setting up warnings filters.
|
||||
|
||||
SQLAlchemy-specific classes must only be here and not in pytest config,
|
||||
as we need to delay importing SQLAlchemy until conftest.py has been
|
||||
processed.
|
||||
|
||||
NOTE: filters on subclasses of DeprecationWarning or
|
||||
PendingDeprecationWarning have no effect if added here, since pytest
|
||||
will add at each test the following filters
|
||||
``always::PendingDeprecationWarning`` and ``always::DeprecationWarning``
|
||||
that will take precedence over any added here.
|
||||
|
||||
"""
|
||||
warnings.filterwarnings("error", category=exc.SAWarning)
|
||||
warnings.filterwarnings("always", category=exc.SATestSuiteWarning)
|
||||
|
||||
|
||||
def assert_warnings(fn, warning_msgs, regex=False):
|
||||
"""Assert that each of the given warnings are emitted by fn.
|
||||
|
||||
Deprecated. Please use assertions.expect_warnings().
|
||||
|
||||
"""
|
||||
|
||||
with assertions._expect_warnings(
|
||||
sa_exc.SAWarning, warning_msgs, regex=regex
|
||||
):
|
||||
return fn()
|
Loading…
Add table
Add a link
Reference in a new issue