Cleaned up the directories
This commit is contained in:
parent
f708506d68
commit
a683fcffea
1340 changed files with 554582 additions and 6840 deletions
|
@ -0,0 +1,11 @@
|
|||
# ext/__init__.py
|
||||
# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors
|
||||
# <see AUTHORS file>
|
||||
#
|
||||
# This module is part of SQLAlchemy and is released under
|
||||
# the MIT License: https://www.opensource.org/licenses/mit-license.php
|
||||
|
||||
from .. import util as _sa_util
|
||||
|
||||
|
||||
_sa_util.preloaded.import_prefix("sqlalchemy.ext")
|
File diff suppressed because it is too large
Load diff
|
@ -0,0 +1,25 @@
|
|||
# ext/asyncio/__init__.py
|
||||
# Copyright (C) 2020-2024 the SQLAlchemy authors and contributors
|
||||
# <see AUTHORS file>
|
||||
#
|
||||
# This module is part of SQLAlchemy and is released under
|
||||
# the MIT License: https://www.opensource.org/licenses/mit-license.php
|
||||
|
||||
from .engine import async_engine_from_config as async_engine_from_config
|
||||
from .engine import AsyncConnection as AsyncConnection
|
||||
from .engine import AsyncEngine as AsyncEngine
|
||||
from .engine import AsyncTransaction as AsyncTransaction
|
||||
from .engine import create_async_engine as create_async_engine
|
||||
from .engine import create_async_pool_from_url as create_async_pool_from_url
|
||||
from .result import AsyncMappingResult as AsyncMappingResult
|
||||
from .result import AsyncResult as AsyncResult
|
||||
from .result import AsyncScalarResult as AsyncScalarResult
|
||||
from .result import AsyncTupleResult as AsyncTupleResult
|
||||
from .scoping import async_scoped_session as async_scoped_session
|
||||
from .session import async_object_session as async_object_session
|
||||
from .session import async_session as async_session
|
||||
from .session import async_sessionmaker as async_sessionmaker
|
||||
from .session import AsyncAttrs as AsyncAttrs
|
||||
from .session import AsyncSession as AsyncSession
|
||||
from .session import AsyncSessionTransaction as AsyncSessionTransaction
|
||||
from .session import close_all_sessions as close_all_sessions
|
|
@ -0,0 +1,279 @@
|
|||
# ext/asyncio/base.py
|
||||
# Copyright (C) 2020-2024 the SQLAlchemy authors and contributors
|
||||
# <see AUTHORS file>
|
||||
#
|
||||
# This module is part of SQLAlchemy and is released under
|
||||
# the MIT License: https://www.opensource.org/licenses/mit-license.php
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import abc
|
||||
import functools
|
||||
from typing import Any
|
||||
from typing import AsyncGenerator
|
||||
from typing import AsyncIterator
|
||||
from typing import Awaitable
|
||||
from typing import Callable
|
||||
from typing import ClassVar
|
||||
from typing import Dict
|
||||
from typing import Generator
|
||||
from typing import Generic
|
||||
from typing import NoReturn
|
||||
from typing import Optional
|
||||
from typing import overload
|
||||
from typing import Tuple
|
||||
from typing import TypeVar
|
||||
import weakref
|
||||
|
||||
from . import exc as async_exc
|
||||
from ... import util
|
||||
from ...util.typing import Literal
|
||||
from ...util.typing import Self
|
||||
|
||||
_T = TypeVar("_T", bound=Any)
|
||||
_T_co = TypeVar("_T_co", bound=Any, covariant=True)
|
||||
|
||||
|
||||
_PT = TypeVar("_PT", bound=Any)
|
||||
|
||||
|
||||
class ReversibleProxy(Generic[_PT]):
|
||||
_proxy_objects: ClassVar[
|
||||
Dict[weakref.ref[Any], weakref.ref[ReversibleProxy[Any]]]
|
||||
] = {}
|
||||
__slots__ = ("__weakref__",)
|
||||
|
||||
@overload
|
||||
def _assign_proxied(self, target: _PT) -> _PT: ...
|
||||
|
||||
@overload
|
||||
def _assign_proxied(self, target: None) -> None: ...
|
||||
|
||||
def _assign_proxied(self, target: Optional[_PT]) -> Optional[_PT]:
|
||||
if target is not None:
|
||||
target_ref: weakref.ref[_PT] = weakref.ref(
|
||||
target, ReversibleProxy._target_gced
|
||||
)
|
||||
proxy_ref = weakref.ref(
|
||||
self,
|
||||
functools.partial(ReversibleProxy._target_gced, target_ref),
|
||||
)
|
||||
ReversibleProxy._proxy_objects[target_ref] = proxy_ref
|
||||
|
||||
return target
|
||||
|
||||
@classmethod
|
||||
def _target_gced(
|
||||
cls,
|
||||
ref: weakref.ref[_PT],
|
||||
proxy_ref: Optional[weakref.ref[Self]] = None, # noqa: U100
|
||||
) -> None:
|
||||
cls._proxy_objects.pop(ref, None)
|
||||
|
||||
@classmethod
|
||||
def _regenerate_proxy_for_target(cls, target: _PT) -> Self:
|
||||
raise NotImplementedError()
|
||||
|
||||
@overload
|
||||
@classmethod
|
||||
def _retrieve_proxy_for_target(
|
||||
cls,
|
||||
target: _PT,
|
||||
regenerate: Literal[True] = ...,
|
||||
) -> Self: ...
|
||||
|
||||
@overload
|
||||
@classmethod
|
||||
def _retrieve_proxy_for_target(
|
||||
cls, target: _PT, regenerate: bool = True
|
||||
) -> Optional[Self]: ...
|
||||
|
||||
@classmethod
|
||||
def _retrieve_proxy_for_target(
|
||||
cls, target: _PT, regenerate: bool = True
|
||||
) -> Optional[Self]:
|
||||
try:
|
||||
proxy_ref = cls._proxy_objects[weakref.ref(target)]
|
||||
except KeyError:
|
||||
pass
|
||||
else:
|
||||
proxy = proxy_ref()
|
||||
if proxy is not None:
|
||||
return proxy # type: ignore
|
||||
|
||||
if regenerate:
|
||||
return cls._regenerate_proxy_for_target(target)
|
||||
else:
|
||||
return None
|
||||
|
||||
|
||||
class StartableContext(Awaitable[_T_co], abc.ABC):
|
||||
__slots__ = ()
|
||||
|
||||
@abc.abstractmethod
|
||||
async def start(self, is_ctxmanager: bool = False) -> _T_co:
|
||||
raise NotImplementedError()
|
||||
|
||||
def __await__(self) -> Generator[Any, Any, _T_co]:
|
||||
return self.start().__await__()
|
||||
|
||||
async def __aenter__(self) -> _T_co:
|
||||
return await self.start(is_ctxmanager=True)
|
||||
|
||||
@abc.abstractmethod
|
||||
async def __aexit__(
|
||||
self, type_: Any, value: Any, traceback: Any
|
||||
) -> Optional[bool]:
|
||||
pass
|
||||
|
||||
def _raise_for_not_started(self) -> NoReturn:
|
||||
raise async_exc.AsyncContextNotStarted(
|
||||
"%s context has not been started and object has not been awaited."
|
||||
% (self.__class__.__name__)
|
||||
)
|
||||
|
||||
|
||||
class GeneratorStartableContext(StartableContext[_T_co]):
|
||||
__slots__ = ("gen",)
|
||||
|
||||
gen: AsyncGenerator[_T_co, Any]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
func: Callable[..., AsyncIterator[_T_co]],
|
||||
args: Tuple[Any, ...],
|
||||
kwds: Dict[str, Any],
|
||||
):
|
||||
self.gen = func(*args, **kwds) # type: ignore
|
||||
|
||||
async def start(self, is_ctxmanager: bool = False) -> _T_co:
|
||||
try:
|
||||
start_value = await util.anext_(self.gen)
|
||||
except StopAsyncIteration:
|
||||
raise RuntimeError("generator didn't yield") from None
|
||||
|
||||
# if not a context manager, then interrupt the generator, don't
|
||||
# let it complete. this step is technically not needed, as the
|
||||
# generator will close in any case at gc time. not clear if having
|
||||
# this here is a good idea or not (though it helps for clarity IMO)
|
||||
if not is_ctxmanager:
|
||||
await self.gen.aclose()
|
||||
|
||||
return start_value
|
||||
|
||||
async def __aexit__(
|
||||
self, typ: Any, value: Any, traceback: Any
|
||||
) -> Optional[bool]:
|
||||
# vendored from contextlib.py
|
||||
if typ is None:
|
||||
try:
|
||||
await util.anext_(self.gen)
|
||||
except StopAsyncIteration:
|
||||
return False
|
||||
else:
|
||||
raise RuntimeError("generator didn't stop")
|
||||
else:
|
||||
if value is None:
|
||||
# Need to force instantiation so we can reliably
|
||||
# tell if we get the same exception back
|
||||
value = typ()
|
||||
try:
|
||||
await self.gen.athrow(value)
|
||||
except StopAsyncIteration as exc:
|
||||
# Suppress StopIteration *unless* it's the same exception that
|
||||
# was passed to throw(). This prevents a StopIteration
|
||||
# raised inside the "with" statement from being suppressed.
|
||||
return exc is not value
|
||||
except RuntimeError as exc:
|
||||
# Don't re-raise the passed in exception. (issue27122)
|
||||
if exc is value:
|
||||
return False
|
||||
# Avoid suppressing if a Stop(Async)Iteration exception
|
||||
# was passed to athrow() and later wrapped into a RuntimeError
|
||||
# (see PEP 479 for sync generators; async generators also
|
||||
# have this behavior). But do this only if the exception
|
||||
# wrapped
|
||||
# by the RuntimeError is actully Stop(Async)Iteration (see
|
||||
# issue29692).
|
||||
if (
|
||||
isinstance(value, (StopIteration, StopAsyncIteration))
|
||||
and exc.__cause__ is value
|
||||
):
|
||||
return False
|
||||
raise
|
||||
except BaseException as exc:
|
||||
# only re-raise if it's *not* the exception that was
|
||||
# passed to throw(), because __exit__() must not raise
|
||||
# an exception unless __exit__() itself failed. But throw()
|
||||
# has to raise the exception to signal propagation, so this
|
||||
# fixes the impedance mismatch between the throw() protocol
|
||||
# and the __exit__() protocol.
|
||||
if exc is not value:
|
||||
raise
|
||||
return False
|
||||
raise RuntimeError("generator didn't stop after athrow()")
|
||||
|
||||
|
||||
def asyncstartablecontext(
|
||||
func: Callable[..., AsyncIterator[_T_co]]
|
||||
) -> Callable[..., GeneratorStartableContext[_T_co]]:
|
||||
"""@asyncstartablecontext decorator.
|
||||
|
||||
the decorated function can be called either as ``async with fn()``, **or**
|
||||
``await fn()``. This is decidedly different from what
|
||||
``@contextlib.asynccontextmanager`` supports, and the usage pattern
|
||||
is different as well.
|
||||
|
||||
Typical usage::
|
||||
|
||||
@asyncstartablecontext
|
||||
async def some_async_generator(<arguments>):
|
||||
<setup>
|
||||
try:
|
||||
yield <value>
|
||||
except GeneratorExit:
|
||||
# return value was awaited, no context manager is present
|
||||
# and caller will .close() the resource explicitly
|
||||
pass
|
||||
else:
|
||||
<context manager cleanup>
|
||||
|
||||
|
||||
Above, ``GeneratorExit`` is caught if the function were used as an
|
||||
``await``. In this case, it's essential that the cleanup does **not**
|
||||
occur, so there should not be a ``finally`` block.
|
||||
|
||||
If ``GeneratorExit`` is not invoked, this means we're in ``__aexit__``
|
||||
and we were invoked as a context manager, and cleanup should proceed.
|
||||
|
||||
|
||||
"""
|
||||
|
||||
@functools.wraps(func)
|
||||
def helper(*args: Any, **kwds: Any) -> GeneratorStartableContext[_T_co]:
|
||||
return GeneratorStartableContext(func, args, kwds)
|
||||
|
||||
return helper
|
||||
|
||||
|
||||
class ProxyComparable(ReversibleProxy[_PT]):
|
||||
__slots__ = ()
|
||||
|
||||
@util.ro_non_memoized_property
|
||||
def _proxied(self) -> _PT:
|
||||
raise NotImplementedError()
|
||||
|
||||
def __hash__(self) -> int:
|
||||
return id(self)
|
||||
|
||||
def __eq__(self, other: Any) -> bool:
|
||||
return (
|
||||
isinstance(other, self.__class__)
|
||||
and self._proxied == other._proxied
|
||||
)
|
||||
|
||||
def __ne__(self, other: Any) -> bool:
|
||||
return (
|
||||
not isinstance(other, self.__class__)
|
||||
or self._proxied != other._proxied
|
||||
)
|
File diff suppressed because it is too large
Load diff
|
@ -0,0 +1,21 @@
|
|||
# ext/asyncio/exc.py
|
||||
# Copyright (C) 2020-2024 the SQLAlchemy authors and contributors
|
||||
# <see AUTHORS file>
|
||||
#
|
||||
# This module is part of SQLAlchemy and is released under
|
||||
# the MIT License: https://www.opensource.org/licenses/mit-license.php
|
||||
|
||||
from ... import exc
|
||||
|
||||
|
||||
class AsyncMethodRequired(exc.InvalidRequestError):
|
||||
"""an API can't be used because its result would not be
|
||||
compatible with async"""
|
||||
|
||||
|
||||
class AsyncContextNotStarted(exc.InvalidRequestError):
|
||||
"""a startable context manager has not been started."""
|
||||
|
||||
|
||||
class AsyncContextAlreadyStarted(exc.InvalidRequestError):
|
||||
"""a startable context manager is already started."""
|
|
@ -0,0 +1,961 @@
|
|||
# ext/asyncio/result.py
|
||||
# Copyright (C) 2020-2024 the SQLAlchemy authors and contributors
|
||||
# <see AUTHORS file>
|
||||
#
|
||||
# This module is part of SQLAlchemy and is released under
|
||||
# the MIT License: https://www.opensource.org/licenses/mit-license.php
|
||||
from __future__ import annotations
|
||||
|
||||
import operator
|
||||
from typing import Any
|
||||
from typing import AsyncIterator
|
||||
from typing import Optional
|
||||
from typing import overload
|
||||
from typing import Sequence
|
||||
from typing import Tuple
|
||||
from typing import TYPE_CHECKING
|
||||
from typing import TypeVar
|
||||
|
||||
from . import exc as async_exc
|
||||
from ... import util
|
||||
from ...engine import Result
|
||||
from ...engine.result import _NO_ROW
|
||||
from ...engine.result import _R
|
||||
from ...engine.result import _WithKeys
|
||||
from ...engine.result import FilterResult
|
||||
from ...engine.result import FrozenResult
|
||||
from ...engine.result import ResultMetaData
|
||||
from ...engine.row import Row
|
||||
from ...engine.row import RowMapping
|
||||
from ...sql.base import _generative
|
||||
from ...util.concurrency import greenlet_spawn
|
||||
from ...util.typing import Literal
|
||||
from ...util.typing import Self
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ...engine import CursorResult
|
||||
from ...engine.result import _KeyIndexType
|
||||
from ...engine.result import _UniqueFilterType
|
||||
|
||||
_T = TypeVar("_T", bound=Any)
|
||||
_TP = TypeVar("_TP", bound=Tuple[Any, ...])
|
||||
|
||||
|
||||
class AsyncCommon(FilterResult[_R]):
|
||||
__slots__ = ()
|
||||
|
||||
_real_result: Result[Any]
|
||||
_metadata: ResultMetaData
|
||||
|
||||
async def close(self) -> None: # type: ignore[override]
|
||||
"""Close this result."""
|
||||
|
||||
await greenlet_spawn(self._real_result.close)
|
||||
|
||||
@property
|
||||
def closed(self) -> bool:
|
||||
"""proxies the .closed attribute of the underlying result object,
|
||||
if any, else raises ``AttributeError``.
|
||||
|
||||
.. versionadded:: 2.0.0b3
|
||||
|
||||
"""
|
||||
return self._real_result.closed
|
||||
|
||||
|
||||
class AsyncResult(_WithKeys, AsyncCommon[Row[_TP]]):
|
||||
"""An asyncio wrapper around a :class:`_result.Result` object.
|
||||
|
||||
The :class:`_asyncio.AsyncResult` only applies to statement executions that
|
||||
use a server-side cursor. It is returned only from the
|
||||
:meth:`_asyncio.AsyncConnection.stream` and
|
||||
:meth:`_asyncio.AsyncSession.stream` methods.
|
||||
|
||||
.. note:: As is the case with :class:`_engine.Result`, this object is
|
||||
used for ORM results returned by :meth:`_asyncio.AsyncSession.execute`,
|
||||
which can yield instances of ORM mapped objects either individually or
|
||||
within tuple-like rows. Note that these result objects do not
|
||||
deduplicate instances or rows automatically as is the case with the
|
||||
legacy :class:`_orm.Query` object. For in-Python de-duplication of
|
||||
instances or rows, use the :meth:`_asyncio.AsyncResult.unique` modifier
|
||||
method.
|
||||
|
||||
.. versionadded:: 1.4
|
||||
|
||||
"""
|
||||
|
||||
__slots__ = ()
|
||||
|
||||
_real_result: Result[_TP]
|
||||
|
||||
def __init__(self, real_result: Result[_TP]):
|
||||
self._real_result = real_result
|
||||
|
||||
self._metadata = real_result._metadata
|
||||
self._unique_filter_state = real_result._unique_filter_state
|
||||
self._post_creational_filter = None
|
||||
|
||||
# BaseCursorResult pre-generates the "_row_getter". Use that
|
||||
# if available rather than building a second one
|
||||
if "_row_getter" in real_result.__dict__:
|
||||
self._set_memoized_attribute(
|
||||
"_row_getter", real_result.__dict__["_row_getter"]
|
||||
)
|
||||
|
||||
@property
|
||||
def t(self) -> AsyncTupleResult[_TP]:
|
||||
"""Apply a "typed tuple" typing filter to returned rows.
|
||||
|
||||
The :attr:`_asyncio.AsyncResult.t` attribute is a synonym for
|
||||
calling the :meth:`_asyncio.AsyncResult.tuples` method.
|
||||
|
||||
.. versionadded:: 2.0
|
||||
|
||||
"""
|
||||
return self # type: ignore
|
||||
|
||||
def tuples(self) -> AsyncTupleResult[_TP]:
|
||||
"""Apply a "typed tuple" typing filter to returned rows.
|
||||
|
||||
This method returns the same :class:`_asyncio.AsyncResult` object
|
||||
at runtime,
|
||||
however annotates as returning a :class:`_asyncio.AsyncTupleResult`
|
||||
object that will indicate to :pep:`484` typing tools that plain typed
|
||||
``Tuple`` instances are returned rather than rows. This allows
|
||||
tuple unpacking and ``__getitem__`` access of :class:`_engine.Row`
|
||||
objects to by typed, for those cases where the statement invoked
|
||||
itself included typing information.
|
||||
|
||||
.. versionadded:: 2.0
|
||||
|
||||
:return: the :class:`_result.AsyncTupleResult` type at typing time.
|
||||
|
||||
.. seealso::
|
||||
|
||||
:attr:`_asyncio.AsyncResult.t` - shorter synonym
|
||||
|
||||
:attr:`_engine.Row.t` - :class:`_engine.Row` version
|
||||
|
||||
"""
|
||||
|
||||
return self # type: ignore
|
||||
|
||||
@_generative
|
||||
def unique(self, strategy: Optional[_UniqueFilterType] = None) -> Self:
|
||||
"""Apply unique filtering to the objects returned by this
|
||||
:class:`_asyncio.AsyncResult`.
|
||||
|
||||
Refer to :meth:`_engine.Result.unique` in the synchronous
|
||||
SQLAlchemy API for a complete behavioral description.
|
||||
|
||||
"""
|
||||
self._unique_filter_state = (set(), strategy)
|
||||
return self
|
||||
|
||||
def columns(self, *col_expressions: _KeyIndexType) -> Self:
|
||||
r"""Establish the columns that should be returned in each row.
|
||||
|
||||
Refer to :meth:`_engine.Result.columns` in the synchronous
|
||||
SQLAlchemy API for a complete behavioral description.
|
||||
|
||||
"""
|
||||
return self._column_slices(col_expressions)
|
||||
|
||||
async def partitions(
|
||||
self, size: Optional[int] = None
|
||||
) -> AsyncIterator[Sequence[Row[_TP]]]:
|
||||
"""Iterate through sub-lists of rows of the size given.
|
||||
|
||||
An async iterator is returned::
|
||||
|
||||
async def scroll_results(connection):
|
||||
result = await connection.stream(select(users_table))
|
||||
|
||||
async for partition in result.partitions(100):
|
||||
print("list of rows: %s" % partition)
|
||||
|
||||
Refer to :meth:`_engine.Result.partitions` in the synchronous
|
||||
SQLAlchemy API for a complete behavioral description.
|
||||
|
||||
"""
|
||||
|
||||
getter = self._manyrow_getter
|
||||
|
||||
while True:
|
||||
partition = await greenlet_spawn(getter, self, size)
|
||||
if partition:
|
||||
yield partition
|
||||
else:
|
||||
break
|
||||
|
||||
async def fetchall(self) -> Sequence[Row[_TP]]:
|
||||
"""A synonym for the :meth:`_asyncio.AsyncResult.all` method.
|
||||
|
||||
.. versionadded:: 2.0
|
||||
|
||||
"""
|
||||
|
||||
return await greenlet_spawn(self._allrows)
|
||||
|
||||
async def fetchone(self) -> Optional[Row[_TP]]:
|
||||
"""Fetch one row.
|
||||
|
||||
When all rows are exhausted, returns None.
|
||||
|
||||
This method is provided for backwards compatibility with
|
||||
SQLAlchemy 1.x.x.
|
||||
|
||||
To fetch the first row of a result only, use the
|
||||
:meth:`_asyncio.AsyncResult.first` method. To iterate through all
|
||||
rows, iterate the :class:`_asyncio.AsyncResult` object directly.
|
||||
|
||||
:return: a :class:`_engine.Row` object if no filters are applied,
|
||||
or ``None`` if no rows remain.
|
||||
|
||||
"""
|
||||
row = await greenlet_spawn(self._onerow_getter, self)
|
||||
if row is _NO_ROW:
|
||||
return None
|
||||
else:
|
||||
return row
|
||||
|
||||
async def fetchmany(
|
||||
self, size: Optional[int] = None
|
||||
) -> Sequence[Row[_TP]]:
|
||||
"""Fetch many rows.
|
||||
|
||||
When all rows are exhausted, returns an empty list.
|
||||
|
||||
This method is provided for backwards compatibility with
|
||||
SQLAlchemy 1.x.x.
|
||||
|
||||
To fetch rows in groups, use the
|
||||
:meth:`._asyncio.AsyncResult.partitions` method.
|
||||
|
||||
:return: a list of :class:`_engine.Row` objects.
|
||||
|
||||
.. seealso::
|
||||
|
||||
:meth:`_asyncio.AsyncResult.partitions`
|
||||
|
||||
"""
|
||||
|
||||
return await greenlet_spawn(self._manyrow_getter, self, size)
|
||||
|
||||
async def all(self) -> Sequence[Row[_TP]]:
|
||||
"""Return all rows in a list.
|
||||
|
||||
Closes the result set after invocation. Subsequent invocations
|
||||
will return an empty list.
|
||||
|
||||
:return: a list of :class:`_engine.Row` objects.
|
||||
|
||||
"""
|
||||
|
||||
return await greenlet_spawn(self._allrows)
|
||||
|
||||
def __aiter__(self) -> AsyncResult[_TP]:
|
||||
return self
|
||||
|
||||
async def __anext__(self) -> Row[_TP]:
|
||||
row = await greenlet_spawn(self._onerow_getter, self)
|
||||
if row is _NO_ROW:
|
||||
raise StopAsyncIteration()
|
||||
else:
|
||||
return row
|
||||
|
||||
async def first(self) -> Optional[Row[_TP]]:
|
||||
"""Fetch the first row or ``None`` if no row is present.
|
||||
|
||||
Closes the result set and discards remaining rows.
|
||||
|
||||
.. note:: This method returns one **row**, e.g. tuple, by default.
|
||||
To return exactly one single scalar value, that is, the first
|
||||
column of the first row, use the
|
||||
:meth:`_asyncio.AsyncResult.scalar` method,
|
||||
or combine :meth:`_asyncio.AsyncResult.scalars` and
|
||||
:meth:`_asyncio.AsyncResult.first`.
|
||||
|
||||
Additionally, in contrast to the behavior of the legacy ORM
|
||||
:meth:`_orm.Query.first` method, **no limit is applied** to the
|
||||
SQL query which was invoked to produce this
|
||||
:class:`_asyncio.AsyncResult`;
|
||||
for a DBAPI driver that buffers results in memory before yielding
|
||||
rows, all rows will be sent to the Python process and all but
|
||||
the first row will be discarded.
|
||||
|
||||
.. seealso::
|
||||
|
||||
:ref:`migration_20_unify_select`
|
||||
|
||||
:return: a :class:`_engine.Row` object, or None
|
||||
if no rows remain.
|
||||
|
||||
.. seealso::
|
||||
|
||||
:meth:`_asyncio.AsyncResult.scalar`
|
||||
|
||||
:meth:`_asyncio.AsyncResult.one`
|
||||
|
||||
"""
|
||||
return await greenlet_spawn(self._only_one_row, False, False, False)
|
||||
|
||||
async def one_or_none(self) -> Optional[Row[_TP]]:
|
||||
"""Return at most one result or raise an exception.
|
||||
|
||||
Returns ``None`` if the result has no rows.
|
||||
Raises :class:`.MultipleResultsFound`
|
||||
if multiple rows are returned.
|
||||
|
||||
.. versionadded:: 1.4
|
||||
|
||||
:return: The first :class:`_engine.Row` or ``None`` if no row
|
||||
is available.
|
||||
|
||||
:raises: :class:`.MultipleResultsFound`
|
||||
|
||||
.. seealso::
|
||||
|
||||
:meth:`_asyncio.AsyncResult.first`
|
||||
|
||||
:meth:`_asyncio.AsyncResult.one`
|
||||
|
||||
"""
|
||||
return await greenlet_spawn(self._only_one_row, True, False, False)
|
||||
|
||||
@overload
|
||||
async def scalar_one(self: AsyncResult[Tuple[_T]]) -> _T: ...
|
||||
|
||||
@overload
|
||||
async def scalar_one(self) -> Any: ...
|
||||
|
||||
async def scalar_one(self) -> Any:
|
||||
"""Return exactly one scalar result or raise an exception.
|
||||
|
||||
This is equivalent to calling :meth:`_asyncio.AsyncResult.scalars` and
|
||||
then :meth:`_asyncio.AsyncResult.one`.
|
||||
|
||||
.. seealso::
|
||||
|
||||
:meth:`_asyncio.AsyncResult.one`
|
||||
|
||||
:meth:`_asyncio.AsyncResult.scalars`
|
||||
|
||||
"""
|
||||
return await greenlet_spawn(self._only_one_row, True, True, True)
|
||||
|
||||
@overload
|
||||
async def scalar_one_or_none(
|
||||
self: AsyncResult[Tuple[_T]],
|
||||
) -> Optional[_T]: ...
|
||||
|
||||
@overload
|
||||
async def scalar_one_or_none(self) -> Optional[Any]: ...
|
||||
|
||||
async def scalar_one_or_none(self) -> Optional[Any]:
|
||||
"""Return exactly one scalar result or ``None``.
|
||||
|
||||
This is equivalent to calling :meth:`_asyncio.AsyncResult.scalars` and
|
||||
then :meth:`_asyncio.AsyncResult.one_or_none`.
|
||||
|
||||
.. seealso::
|
||||
|
||||
:meth:`_asyncio.AsyncResult.one_or_none`
|
||||
|
||||
:meth:`_asyncio.AsyncResult.scalars`
|
||||
|
||||
"""
|
||||
return await greenlet_spawn(self._only_one_row, True, False, True)
|
||||
|
||||
async def one(self) -> Row[_TP]:
|
||||
"""Return exactly one row or raise an exception.
|
||||
|
||||
Raises :class:`.NoResultFound` if the result returns no
|
||||
rows, or :class:`.MultipleResultsFound` if multiple rows
|
||||
would be returned.
|
||||
|
||||
.. note:: This method returns one **row**, e.g. tuple, by default.
|
||||
To return exactly one single scalar value, that is, the first
|
||||
column of the first row, use the
|
||||
:meth:`_asyncio.AsyncResult.scalar_one` method, or combine
|
||||
:meth:`_asyncio.AsyncResult.scalars` and
|
||||
:meth:`_asyncio.AsyncResult.one`.
|
||||
|
||||
.. versionadded:: 1.4
|
||||
|
||||
:return: The first :class:`_engine.Row`.
|
||||
|
||||
:raises: :class:`.MultipleResultsFound`, :class:`.NoResultFound`
|
||||
|
||||
.. seealso::
|
||||
|
||||
:meth:`_asyncio.AsyncResult.first`
|
||||
|
||||
:meth:`_asyncio.AsyncResult.one_or_none`
|
||||
|
||||
:meth:`_asyncio.AsyncResult.scalar_one`
|
||||
|
||||
"""
|
||||
return await greenlet_spawn(self._only_one_row, True, True, False)
|
||||
|
||||
@overload
|
||||
async def scalar(self: AsyncResult[Tuple[_T]]) -> Optional[_T]: ...
|
||||
|
||||
@overload
|
||||
async def scalar(self) -> Any: ...
|
||||
|
||||
async def scalar(self) -> Any:
|
||||
"""Fetch the first column of the first row, and close the result set.
|
||||
|
||||
Returns ``None`` if there are no rows to fetch.
|
||||
|
||||
No validation is performed to test if additional rows remain.
|
||||
|
||||
After calling this method, the object is fully closed,
|
||||
e.g. the :meth:`_engine.CursorResult.close`
|
||||
method will have been called.
|
||||
|
||||
:return: a Python scalar value, or ``None`` if no rows remain.
|
||||
|
||||
"""
|
||||
return await greenlet_spawn(self._only_one_row, False, False, True)
|
||||
|
||||
async def freeze(self) -> FrozenResult[_TP]:
|
||||
"""Return a callable object that will produce copies of this
|
||||
:class:`_asyncio.AsyncResult` when invoked.
|
||||
|
||||
The callable object returned is an instance of
|
||||
:class:`_engine.FrozenResult`.
|
||||
|
||||
This is used for result set caching. The method must be called
|
||||
on the result when it has been unconsumed, and calling the method
|
||||
will consume the result fully. When the :class:`_engine.FrozenResult`
|
||||
is retrieved from a cache, it can be called any number of times where
|
||||
it will produce a new :class:`_engine.Result` object each time
|
||||
against its stored set of rows.
|
||||
|
||||
.. seealso::
|
||||
|
||||
:ref:`do_orm_execute_re_executing` - example usage within the
|
||||
ORM to implement a result-set cache.
|
||||
|
||||
"""
|
||||
|
||||
return await greenlet_spawn(FrozenResult, self)
|
||||
|
||||
@overload
|
||||
def scalars(
|
||||
self: AsyncResult[Tuple[_T]], index: Literal[0]
|
||||
) -> AsyncScalarResult[_T]: ...
|
||||
|
||||
@overload
|
||||
def scalars(self: AsyncResult[Tuple[_T]]) -> AsyncScalarResult[_T]: ...
|
||||
|
||||
@overload
|
||||
def scalars(self, index: _KeyIndexType = 0) -> AsyncScalarResult[Any]: ...
|
||||
|
||||
def scalars(self, index: _KeyIndexType = 0) -> AsyncScalarResult[Any]:
|
||||
"""Return an :class:`_asyncio.AsyncScalarResult` filtering object which
|
||||
will return single elements rather than :class:`_row.Row` objects.
|
||||
|
||||
Refer to :meth:`_result.Result.scalars` in the synchronous
|
||||
SQLAlchemy API for a complete behavioral description.
|
||||
|
||||
:param index: integer or row key indicating the column to be fetched
|
||||
from each row, defaults to ``0`` indicating the first column.
|
||||
|
||||
:return: a new :class:`_asyncio.AsyncScalarResult` filtering object
|
||||
referring to this :class:`_asyncio.AsyncResult` object.
|
||||
|
||||
"""
|
||||
return AsyncScalarResult(self._real_result, index)
|
||||
|
||||
def mappings(self) -> AsyncMappingResult:
|
||||
"""Apply a mappings filter to returned rows, returning an instance of
|
||||
:class:`_asyncio.AsyncMappingResult`.
|
||||
|
||||
When this filter is applied, fetching rows will return
|
||||
:class:`_engine.RowMapping` objects instead of :class:`_engine.Row`
|
||||
objects.
|
||||
|
||||
:return: a new :class:`_asyncio.AsyncMappingResult` filtering object
|
||||
referring to the underlying :class:`_result.Result` object.
|
||||
|
||||
"""
|
||||
|
||||
return AsyncMappingResult(self._real_result)
|
||||
|
||||
|
||||
class AsyncScalarResult(AsyncCommon[_R]):
|
||||
"""A wrapper for a :class:`_asyncio.AsyncResult` that returns scalar values
|
||||
rather than :class:`_row.Row` values.
|
||||
|
||||
The :class:`_asyncio.AsyncScalarResult` object is acquired by calling the
|
||||
:meth:`_asyncio.AsyncResult.scalars` method.
|
||||
|
||||
Refer to the :class:`_result.ScalarResult` object in the synchronous
|
||||
SQLAlchemy API for a complete behavioral description.
|
||||
|
||||
.. versionadded:: 1.4
|
||||
|
||||
"""
|
||||
|
||||
__slots__ = ()
|
||||
|
||||
_generate_rows = False
|
||||
|
||||
def __init__(self, real_result: Result[Any], index: _KeyIndexType):
|
||||
self._real_result = real_result
|
||||
|
||||
if real_result._source_supports_scalars:
|
||||
self._metadata = real_result._metadata
|
||||
self._post_creational_filter = None
|
||||
else:
|
||||
self._metadata = real_result._metadata._reduce([index])
|
||||
self._post_creational_filter = operator.itemgetter(0)
|
||||
|
||||
self._unique_filter_state = real_result._unique_filter_state
|
||||
|
||||
def unique(
|
||||
self,
|
||||
strategy: Optional[_UniqueFilterType] = None,
|
||||
) -> Self:
|
||||
"""Apply unique filtering to the objects returned by this
|
||||
:class:`_asyncio.AsyncScalarResult`.
|
||||
|
||||
See :meth:`_asyncio.AsyncResult.unique` for usage details.
|
||||
|
||||
"""
|
||||
self._unique_filter_state = (set(), strategy)
|
||||
return self
|
||||
|
||||
async def partitions(
|
||||
self, size: Optional[int] = None
|
||||
) -> AsyncIterator[Sequence[_R]]:
|
||||
"""Iterate through sub-lists of elements of the size given.
|
||||
|
||||
Equivalent to :meth:`_asyncio.AsyncResult.partitions` except that
|
||||
scalar values, rather than :class:`_engine.Row` objects,
|
||||
are returned.
|
||||
|
||||
"""
|
||||
|
||||
getter = self._manyrow_getter
|
||||
|
||||
while True:
|
||||
partition = await greenlet_spawn(getter, self, size)
|
||||
if partition:
|
||||
yield partition
|
||||
else:
|
||||
break
|
||||
|
||||
async def fetchall(self) -> Sequence[_R]:
|
||||
"""A synonym for the :meth:`_asyncio.AsyncScalarResult.all` method."""
|
||||
|
||||
return await greenlet_spawn(self._allrows)
|
||||
|
||||
async def fetchmany(self, size: Optional[int] = None) -> Sequence[_R]:
|
||||
"""Fetch many objects.
|
||||
|
||||
Equivalent to :meth:`_asyncio.AsyncResult.fetchmany` except that
|
||||
scalar values, rather than :class:`_engine.Row` objects,
|
||||
are returned.
|
||||
|
||||
"""
|
||||
return await greenlet_spawn(self._manyrow_getter, self, size)
|
||||
|
||||
async def all(self) -> Sequence[_R]:
|
||||
"""Return all scalar values in a list.
|
||||
|
||||
Equivalent to :meth:`_asyncio.AsyncResult.all` except that
|
||||
scalar values, rather than :class:`_engine.Row` objects,
|
||||
are returned.
|
||||
|
||||
"""
|
||||
return await greenlet_spawn(self._allrows)
|
||||
|
||||
def __aiter__(self) -> AsyncScalarResult[_R]:
|
||||
return self
|
||||
|
||||
async def __anext__(self) -> _R:
|
||||
row = await greenlet_spawn(self._onerow_getter, self)
|
||||
if row is _NO_ROW:
|
||||
raise StopAsyncIteration()
|
||||
else:
|
||||
return row
|
||||
|
||||
async def first(self) -> Optional[_R]:
|
||||
"""Fetch the first object or ``None`` if no object is present.
|
||||
|
||||
Equivalent to :meth:`_asyncio.AsyncResult.first` except that
|
||||
scalar values, rather than :class:`_engine.Row` objects,
|
||||
are returned.
|
||||
|
||||
"""
|
||||
return await greenlet_spawn(self._only_one_row, False, False, False)
|
||||
|
||||
async def one_or_none(self) -> Optional[_R]:
|
||||
"""Return at most one object or raise an exception.
|
||||
|
||||
Equivalent to :meth:`_asyncio.AsyncResult.one_or_none` except that
|
||||
scalar values, rather than :class:`_engine.Row` objects,
|
||||
are returned.
|
||||
|
||||
"""
|
||||
return await greenlet_spawn(self._only_one_row, True, False, False)
|
||||
|
||||
async def one(self) -> _R:
|
||||
"""Return exactly one object or raise an exception.
|
||||
|
||||
Equivalent to :meth:`_asyncio.AsyncResult.one` except that
|
||||
scalar values, rather than :class:`_engine.Row` objects,
|
||||
are returned.
|
||||
|
||||
"""
|
||||
return await greenlet_spawn(self._only_one_row, True, True, False)
|
||||
|
||||
|
||||
class AsyncMappingResult(_WithKeys, AsyncCommon[RowMapping]):
|
||||
"""A wrapper for a :class:`_asyncio.AsyncResult` that returns dictionary
|
||||
values rather than :class:`_engine.Row` values.
|
||||
|
||||
The :class:`_asyncio.AsyncMappingResult` object is acquired by calling the
|
||||
:meth:`_asyncio.AsyncResult.mappings` method.
|
||||
|
||||
Refer to the :class:`_result.MappingResult` object in the synchronous
|
||||
SQLAlchemy API for a complete behavioral description.
|
||||
|
||||
.. versionadded:: 1.4
|
||||
|
||||
"""
|
||||
|
||||
__slots__ = ()
|
||||
|
||||
_generate_rows = True
|
||||
|
||||
_post_creational_filter = operator.attrgetter("_mapping")
|
||||
|
||||
def __init__(self, result: Result[Any]):
|
||||
self._real_result = result
|
||||
self._unique_filter_state = result._unique_filter_state
|
||||
self._metadata = result._metadata
|
||||
if result._source_supports_scalars:
|
||||
self._metadata = self._metadata._reduce([0])
|
||||
|
||||
def unique(
|
||||
self,
|
||||
strategy: Optional[_UniqueFilterType] = None,
|
||||
) -> Self:
|
||||
"""Apply unique filtering to the objects returned by this
|
||||
:class:`_asyncio.AsyncMappingResult`.
|
||||
|
||||
See :meth:`_asyncio.AsyncResult.unique` for usage details.
|
||||
|
||||
"""
|
||||
self._unique_filter_state = (set(), strategy)
|
||||
return self
|
||||
|
||||
def columns(self, *col_expressions: _KeyIndexType) -> Self:
|
||||
r"""Establish the columns that should be returned in each row."""
|
||||
return self._column_slices(col_expressions)
|
||||
|
||||
async def partitions(
|
||||
self, size: Optional[int] = None
|
||||
) -> AsyncIterator[Sequence[RowMapping]]:
|
||||
"""Iterate through sub-lists of elements of the size given.
|
||||
|
||||
Equivalent to :meth:`_asyncio.AsyncResult.partitions` except that
|
||||
:class:`_engine.RowMapping` values, rather than :class:`_engine.Row`
|
||||
objects, are returned.
|
||||
|
||||
"""
|
||||
|
||||
getter = self._manyrow_getter
|
||||
|
||||
while True:
|
||||
partition = await greenlet_spawn(getter, self, size)
|
||||
if partition:
|
||||
yield partition
|
||||
else:
|
||||
break
|
||||
|
||||
async def fetchall(self) -> Sequence[RowMapping]:
|
||||
"""A synonym for the :meth:`_asyncio.AsyncMappingResult.all` method."""
|
||||
|
||||
return await greenlet_spawn(self._allrows)
|
||||
|
||||
async def fetchone(self) -> Optional[RowMapping]:
|
||||
"""Fetch one object.
|
||||
|
||||
Equivalent to :meth:`_asyncio.AsyncResult.fetchone` except that
|
||||
:class:`_engine.RowMapping` values, rather than :class:`_engine.Row`
|
||||
objects, are returned.
|
||||
|
||||
"""
|
||||
|
||||
row = await greenlet_spawn(self._onerow_getter, self)
|
||||
if row is _NO_ROW:
|
||||
return None
|
||||
else:
|
||||
return row
|
||||
|
||||
async def fetchmany(
|
||||
self, size: Optional[int] = None
|
||||
) -> Sequence[RowMapping]:
|
||||
"""Fetch many rows.
|
||||
|
||||
Equivalent to :meth:`_asyncio.AsyncResult.fetchmany` except that
|
||||
:class:`_engine.RowMapping` values, rather than :class:`_engine.Row`
|
||||
objects, are returned.
|
||||
|
||||
"""
|
||||
|
||||
return await greenlet_spawn(self._manyrow_getter, self, size)
|
||||
|
||||
async def all(self) -> Sequence[RowMapping]:
|
||||
"""Return all rows in a list.
|
||||
|
||||
Equivalent to :meth:`_asyncio.AsyncResult.all` except that
|
||||
:class:`_engine.RowMapping` values, rather than :class:`_engine.Row`
|
||||
objects, are returned.
|
||||
|
||||
"""
|
||||
|
||||
return await greenlet_spawn(self._allrows)
|
||||
|
||||
def __aiter__(self) -> AsyncMappingResult:
|
||||
return self
|
||||
|
||||
async def __anext__(self) -> RowMapping:
|
||||
row = await greenlet_spawn(self._onerow_getter, self)
|
||||
if row is _NO_ROW:
|
||||
raise StopAsyncIteration()
|
||||
else:
|
||||
return row
|
||||
|
||||
async def first(self) -> Optional[RowMapping]:
|
||||
"""Fetch the first object or ``None`` if no object is present.
|
||||
|
||||
Equivalent to :meth:`_asyncio.AsyncResult.first` except that
|
||||
:class:`_engine.RowMapping` values, rather than :class:`_engine.Row`
|
||||
objects, are returned.
|
||||
|
||||
"""
|
||||
return await greenlet_spawn(self._only_one_row, False, False, False)
|
||||
|
||||
async def one_or_none(self) -> Optional[RowMapping]:
|
||||
"""Return at most one object or raise an exception.
|
||||
|
||||
Equivalent to :meth:`_asyncio.AsyncResult.one_or_none` except that
|
||||
:class:`_engine.RowMapping` values, rather than :class:`_engine.Row`
|
||||
objects, are returned.
|
||||
|
||||
"""
|
||||
return await greenlet_spawn(self._only_one_row, True, False, False)
|
||||
|
||||
async def one(self) -> RowMapping:
|
||||
"""Return exactly one object or raise an exception.
|
||||
|
||||
Equivalent to :meth:`_asyncio.AsyncResult.one` except that
|
||||
:class:`_engine.RowMapping` values, rather than :class:`_engine.Row`
|
||||
objects, are returned.
|
||||
|
||||
"""
|
||||
return await greenlet_spawn(self._only_one_row, True, True, False)
|
||||
|
||||
|
||||
class AsyncTupleResult(AsyncCommon[_R], util.TypingOnly):
|
||||
"""A :class:`_asyncio.AsyncResult` that's typed as returning plain
|
||||
Python tuples instead of rows.
|
||||
|
||||
Since :class:`_engine.Row` acts like a tuple in every way already,
|
||||
this class is a typing only class, regular :class:`_asyncio.AsyncResult` is
|
||||
still used at runtime.
|
||||
|
||||
"""
|
||||
|
||||
__slots__ = ()
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
||||
async def partitions(
|
||||
self, size: Optional[int] = None
|
||||
) -> AsyncIterator[Sequence[_R]]:
|
||||
"""Iterate through sub-lists of elements of the size given.
|
||||
|
||||
Equivalent to :meth:`_result.Result.partitions` except that
|
||||
tuple values, rather than :class:`_engine.Row` objects,
|
||||
are returned.
|
||||
|
||||
"""
|
||||
...
|
||||
|
||||
async def fetchone(self) -> Optional[_R]:
|
||||
"""Fetch one tuple.
|
||||
|
||||
Equivalent to :meth:`_result.Result.fetchone` except that
|
||||
tuple values, rather than :class:`_engine.Row`
|
||||
objects, are returned.
|
||||
|
||||
"""
|
||||
...
|
||||
|
||||
async def fetchall(self) -> Sequence[_R]:
|
||||
"""A synonym for the :meth:`_engine.ScalarResult.all` method."""
|
||||
...
|
||||
|
||||
async def fetchmany(self, size: Optional[int] = None) -> Sequence[_R]:
|
||||
"""Fetch many objects.
|
||||
|
||||
Equivalent to :meth:`_result.Result.fetchmany` except that
|
||||
tuple values, rather than :class:`_engine.Row` objects,
|
||||
are returned.
|
||||
|
||||
"""
|
||||
...
|
||||
|
||||
async def all(self) -> Sequence[_R]: # noqa: A001
|
||||
"""Return all scalar values in a list.
|
||||
|
||||
Equivalent to :meth:`_result.Result.all` except that
|
||||
tuple values, rather than :class:`_engine.Row` objects,
|
||||
are returned.
|
||||
|
||||
"""
|
||||
...
|
||||
|
||||
async def __aiter__(self) -> AsyncIterator[_R]: ...
|
||||
|
||||
async def __anext__(self) -> _R: ...
|
||||
|
||||
async def first(self) -> Optional[_R]:
|
||||
"""Fetch the first object or ``None`` if no object is present.
|
||||
|
||||
Equivalent to :meth:`_result.Result.first` except that
|
||||
tuple values, rather than :class:`_engine.Row` objects,
|
||||
are returned.
|
||||
|
||||
|
||||
"""
|
||||
...
|
||||
|
||||
async def one_or_none(self) -> Optional[_R]:
|
||||
"""Return at most one object or raise an exception.
|
||||
|
||||
Equivalent to :meth:`_result.Result.one_or_none` except that
|
||||
tuple values, rather than :class:`_engine.Row` objects,
|
||||
are returned.
|
||||
|
||||
"""
|
||||
...
|
||||
|
||||
async def one(self) -> _R:
|
||||
"""Return exactly one object or raise an exception.
|
||||
|
||||
Equivalent to :meth:`_result.Result.one` except that
|
||||
tuple values, rather than :class:`_engine.Row` objects,
|
||||
are returned.
|
||||
|
||||
"""
|
||||
...
|
||||
|
||||
@overload
|
||||
async def scalar_one(self: AsyncTupleResult[Tuple[_T]]) -> _T: ...
|
||||
|
||||
@overload
|
||||
async def scalar_one(self) -> Any: ...
|
||||
|
||||
async def scalar_one(self) -> Any:
|
||||
"""Return exactly one scalar result or raise an exception.
|
||||
|
||||
This is equivalent to calling :meth:`_engine.Result.scalars`
|
||||
and then :meth:`_engine.Result.one`.
|
||||
|
||||
.. seealso::
|
||||
|
||||
:meth:`_engine.Result.one`
|
||||
|
||||
:meth:`_engine.Result.scalars`
|
||||
|
||||
"""
|
||||
...
|
||||
|
||||
@overload
|
||||
async def scalar_one_or_none(
|
||||
self: AsyncTupleResult[Tuple[_T]],
|
||||
) -> Optional[_T]: ...
|
||||
|
||||
@overload
|
||||
async def scalar_one_or_none(self) -> Optional[Any]: ...
|
||||
|
||||
async def scalar_one_or_none(self) -> Optional[Any]:
|
||||
"""Return exactly one or no scalar result.
|
||||
|
||||
This is equivalent to calling :meth:`_engine.Result.scalars`
|
||||
and then :meth:`_engine.Result.one_or_none`.
|
||||
|
||||
.. seealso::
|
||||
|
||||
:meth:`_engine.Result.one_or_none`
|
||||
|
||||
:meth:`_engine.Result.scalars`
|
||||
|
||||
"""
|
||||
...
|
||||
|
||||
@overload
|
||||
async def scalar(
|
||||
self: AsyncTupleResult[Tuple[_T]],
|
||||
) -> Optional[_T]: ...
|
||||
|
||||
@overload
|
||||
async def scalar(self) -> Any: ...
|
||||
|
||||
async def scalar(self) -> Any:
|
||||
"""Fetch the first column of the first row, and close the result
|
||||
set.
|
||||
|
||||
Returns ``None`` if there are no rows to fetch.
|
||||
|
||||
No validation is performed to test if additional rows remain.
|
||||
|
||||
After calling this method, the object is fully closed,
|
||||
e.g. the :meth:`_engine.CursorResult.close`
|
||||
method will have been called.
|
||||
|
||||
:return: a Python scalar value , or ``None`` if no rows remain.
|
||||
|
||||
"""
|
||||
...
|
||||
|
||||
|
||||
_RT = TypeVar("_RT", bound="Result[Any]")
|
||||
|
||||
|
||||
async def _ensure_sync_result(result: _RT, calling_method: Any) -> _RT:
|
||||
cursor_result: CursorResult[Any]
|
||||
|
||||
try:
|
||||
is_cursor = result._is_cursor
|
||||
except AttributeError:
|
||||
# legacy execute(DefaultGenerator) case
|
||||
return result
|
||||
|
||||
if not is_cursor:
|
||||
cursor_result = getattr(result, "raw", None) # type: ignore
|
||||
else:
|
||||
cursor_result = result # type: ignore
|
||||
if cursor_result and cursor_result.context._is_server_side:
|
||||
await greenlet_spawn(cursor_result.close)
|
||||
raise async_exc.AsyncMethodRequired(
|
||||
"Can't use the %s.%s() method with a "
|
||||
"server-side cursor. "
|
||||
"Use the %s.stream() method for an async "
|
||||
"streaming result set."
|
||||
% (
|
||||
calling_method.__self__.__class__.__name__,
|
||||
calling_method.__name__,
|
||||
calling_method.__self__.__class__.__name__,
|
||||
)
|
||||
)
|
||||
return result
|
File diff suppressed because it is too large
Load diff
File diff suppressed because it is too large
Load diff
1652
elitebot/lib/python3.11/site-packages/sqlalchemy/ext/automap.py
Normal file
1652
elitebot/lib/python3.11/site-packages/sqlalchemy/ext/automap.py
Normal file
File diff suppressed because it is too large
Load diff
574
elitebot/lib/python3.11/site-packages/sqlalchemy/ext/baked.py
Normal file
574
elitebot/lib/python3.11/site-packages/sqlalchemy/ext/baked.py
Normal file
|
@ -0,0 +1,574 @@
|
|||
# ext/baked.py
|
||||
# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors
|
||||
# <see AUTHORS file>
|
||||
#
|
||||
# This module is part of SQLAlchemy and is released under
|
||||
# the MIT License: https://www.opensource.org/licenses/mit-license.php
|
||||
# mypy: ignore-errors
|
||||
|
||||
|
||||
"""Baked query extension.
|
||||
|
||||
Provides a creational pattern for the :class:`.query.Query` object which
|
||||
allows the fully constructed object, Core select statement, and string
|
||||
compiled result to be fully cached.
|
||||
|
||||
|
||||
"""
|
||||
|
||||
import collections.abc as collections_abc
|
||||
import logging
|
||||
|
||||
from .. import exc as sa_exc
|
||||
from .. import util
|
||||
from ..orm import exc as orm_exc
|
||||
from ..orm.query import Query
|
||||
from ..orm.session import Session
|
||||
from ..sql import func
|
||||
from ..sql import literal_column
|
||||
from ..sql import util as sql_util
|
||||
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Bakery:
|
||||
"""Callable which returns a :class:`.BakedQuery`.
|
||||
|
||||
This object is returned by the class method
|
||||
:meth:`.BakedQuery.bakery`. It exists as an object
|
||||
so that the "cache" can be easily inspected.
|
||||
|
||||
.. versionadded:: 1.2
|
||||
|
||||
|
||||
"""
|
||||
|
||||
__slots__ = "cls", "cache"
|
||||
|
||||
def __init__(self, cls_, cache):
|
||||
self.cls = cls_
|
||||
self.cache = cache
|
||||
|
||||
def __call__(self, initial_fn, *args):
|
||||
return self.cls(self.cache, initial_fn, args)
|
||||
|
||||
|
||||
class BakedQuery:
|
||||
"""A builder object for :class:`.query.Query` objects."""
|
||||
|
||||
__slots__ = "steps", "_bakery", "_cache_key", "_spoiled"
|
||||
|
||||
def __init__(self, bakery, initial_fn, args=()):
|
||||
self._cache_key = ()
|
||||
self._update_cache_key(initial_fn, args)
|
||||
self.steps = [initial_fn]
|
||||
self._spoiled = False
|
||||
self._bakery = bakery
|
||||
|
||||
@classmethod
|
||||
def bakery(cls, size=200, _size_alert=None):
|
||||
"""Construct a new bakery.
|
||||
|
||||
:return: an instance of :class:`.Bakery`
|
||||
|
||||
"""
|
||||
|
||||
return Bakery(cls, util.LRUCache(size, size_alert=_size_alert))
|
||||
|
||||
def _clone(self):
|
||||
b1 = BakedQuery.__new__(BakedQuery)
|
||||
b1._cache_key = self._cache_key
|
||||
b1.steps = list(self.steps)
|
||||
b1._bakery = self._bakery
|
||||
b1._spoiled = self._spoiled
|
||||
return b1
|
||||
|
||||
def _update_cache_key(self, fn, args=()):
|
||||
self._cache_key += (fn.__code__,) + args
|
||||
|
||||
def __iadd__(self, other):
|
||||
if isinstance(other, tuple):
|
||||
self.add_criteria(*other)
|
||||
else:
|
||||
self.add_criteria(other)
|
||||
return self
|
||||
|
||||
def __add__(self, other):
|
||||
if isinstance(other, tuple):
|
||||
return self.with_criteria(*other)
|
||||
else:
|
||||
return self.with_criteria(other)
|
||||
|
||||
def add_criteria(self, fn, *args):
|
||||
"""Add a criteria function to this :class:`.BakedQuery`.
|
||||
|
||||
This is equivalent to using the ``+=`` operator to
|
||||
modify a :class:`.BakedQuery` in-place.
|
||||
|
||||
"""
|
||||
self._update_cache_key(fn, args)
|
||||
self.steps.append(fn)
|
||||
return self
|
||||
|
||||
def with_criteria(self, fn, *args):
|
||||
"""Add a criteria function to a :class:`.BakedQuery` cloned from this
|
||||
one.
|
||||
|
||||
This is equivalent to using the ``+`` operator to
|
||||
produce a new :class:`.BakedQuery` with modifications.
|
||||
|
||||
"""
|
||||
return self._clone().add_criteria(fn, *args)
|
||||
|
||||
def for_session(self, session):
|
||||
"""Return a :class:`_baked.Result` object for this
|
||||
:class:`.BakedQuery`.
|
||||
|
||||
This is equivalent to calling the :class:`.BakedQuery` as a
|
||||
Python callable, e.g. ``result = my_baked_query(session)``.
|
||||
|
||||
"""
|
||||
return Result(self, session)
|
||||
|
||||
def __call__(self, session):
|
||||
return self.for_session(session)
|
||||
|
||||
def spoil(self, full=False):
|
||||
"""Cancel any query caching that will occur on this BakedQuery object.
|
||||
|
||||
The BakedQuery can continue to be used normally, however additional
|
||||
creational functions will not be cached; they will be called
|
||||
on every invocation.
|
||||
|
||||
This is to support the case where a particular step in constructing
|
||||
a baked query disqualifies the query from being cacheable, such
|
||||
as a variant that relies upon some uncacheable value.
|
||||
|
||||
:param full: if False, only functions added to this
|
||||
:class:`.BakedQuery` object subsequent to the spoil step will be
|
||||
non-cached; the state of the :class:`.BakedQuery` up until
|
||||
this point will be pulled from the cache. If True, then the
|
||||
entire :class:`_query.Query` object is built from scratch each
|
||||
time, with all creational functions being called on each
|
||||
invocation.
|
||||
|
||||
"""
|
||||
if not full and not self._spoiled:
|
||||
_spoil_point = self._clone()
|
||||
_spoil_point._cache_key += ("_query_only",)
|
||||
self.steps = [_spoil_point._retrieve_baked_query]
|
||||
self._spoiled = True
|
||||
return self
|
||||
|
||||
def _effective_key(self, session):
|
||||
"""Return the key that actually goes into the cache dictionary for
|
||||
this :class:`.BakedQuery`, taking into account the given
|
||||
:class:`.Session`.
|
||||
|
||||
This basically means we also will include the session's query_class,
|
||||
as the actual :class:`_query.Query` object is part of what's cached
|
||||
and needs to match the type of :class:`_query.Query` that a later
|
||||
session will want to use.
|
||||
|
||||
"""
|
||||
return self._cache_key + (session._query_cls,)
|
||||
|
||||
def _with_lazyload_options(self, options, effective_path, cache_path=None):
|
||||
"""Cloning version of _add_lazyload_options."""
|
||||
q = self._clone()
|
||||
q._add_lazyload_options(options, effective_path, cache_path=cache_path)
|
||||
return q
|
||||
|
||||
def _add_lazyload_options(self, options, effective_path, cache_path=None):
|
||||
"""Used by per-state lazy loaders to add options to the
|
||||
"lazy load" query from a parent query.
|
||||
|
||||
Creates a cache key based on given load path and query options;
|
||||
if a repeatable cache key cannot be generated, the query is
|
||||
"spoiled" so that it won't use caching.
|
||||
|
||||
"""
|
||||
|
||||
key = ()
|
||||
|
||||
if not cache_path:
|
||||
cache_path = effective_path
|
||||
|
||||
for opt in options:
|
||||
if opt._is_legacy_option or opt._is_compile_state:
|
||||
ck = opt._generate_cache_key()
|
||||
if ck is None:
|
||||
self.spoil(full=True)
|
||||
else:
|
||||
assert not ck[1], (
|
||||
"loader options with variable bound parameters "
|
||||
"not supported with baked queries. Please "
|
||||
"use new-style select() statements for cached "
|
||||
"ORM queries."
|
||||
)
|
||||
key += ck[0]
|
||||
|
||||
self.add_criteria(
|
||||
lambda q: q._with_current_path(effective_path).options(*options),
|
||||
cache_path.path,
|
||||
key,
|
||||
)
|
||||
|
||||
def _retrieve_baked_query(self, session):
|
||||
query = self._bakery.get(self._effective_key(session), None)
|
||||
if query is None:
|
||||
query = self._as_query(session)
|
||||
self._bakery[self._effective_key(session)] = query.with_session(
|
||||
None
|
||||
)
|
||||
return query.with_session(session)
|
||||
|
||||
def _bake(self, session):
|
||||
query = self._as_query(session)
|
||||
query.session = None
|
||||
|
||||
# in 1.4, this is where before_compile() event is
|
||||
# invoked
|
||||
statement = query._statement_20()
|
||||
|
||||
# if the query is not safe to cache, we still do everything as though
|
||||
# we did cache it, since the receiver of _bake() assumes subqueryload
|
||||
# context was set up, etc.
|
||||
#
|
||||
# note also we want to cache the statement itself because this
|
||||
# allows the statement itself to hold onto its cache key that is
|
||||
# used by the Connection, which in itself is more expensive to
|
||||
# generate than what BakedQuery was able to provide in 1.3 and prior
|
||||
|
||||
if statement._compile_options._bake_ok:
|
||||
self._bakery[self._effective_key(session)] = (
|
||||
query,
|
||||
statement,
|
||||
)
|
||||
|
||||
return query, statement
|
||||
|
||||
def to_query(self, query_or_session):
|
||||
"""Return the :class:`_query.Query` object for use as a subquery.
|
||||
|
||||
This method should be used within the lambda callable being used
|
||||
to generate a step of an enclosing :class:`.BakedQuery`. The
|
||||
parameter should normally be the :class:`_query.Query` object that
|
||||
is passed to the lambda::
|
||||
|
||||
sub_bq = self.bakery(lambda s: s.query(User.name))
|
||||
sub_bq += lambda q: q.filter(
|
||||
User.id == Address.user_id).correlate(Address)
|
||||
|
||||
main_bq = self.bakery(lambda s: s.query(Address))
|
||||
main_bq += lambda q: q.filter(
|
||||
sub_bq.to_query(q).exists())
|
||||
|
||||
In the case where the subquery is used in the first callable against
|
||||
a :class:`.Session`, the :class:`.Session` is also accepted::
|
||||
|
||||
sub_bq = self.bakery(lambda s: s.query(User.name))
|
||||
sub_bq += lambda q: q.filter(
|
||||
User.id == Address.user_id).correlate(Address)
|
||||
|
||||
main_bq = self.bakery(
|
||||
lambda s: s.query(
|
||||
Address.id, sub_bq.to_query(q).scalar_subquery())
|
||||
)
|
||||
|
||||
:param query_or_session: a :class:`_query.Query` object or a class
|
||||
:class:`.Session` object, that is assumed to be within the context
|
||||
of an enclosing :class:`.BakedQuery` callable.
|
||||
|
||||
|
||||
.. versionadded:: 1.3
|
||||
|
||||
|
||||
"""
|
||||
|
||||
if isinstance(query_or_session, Session):
|
||||
session = query_or_session
|
||||
elif isinstance(query_or_session, Query):
|
||||
session = query_or_session.session
|
||||
if session is None:
|
||||
raise sa_exc.ArgumentError(
|
||||
"Given Query needs to be associated with a Session"
|
||||
)
|
||||
else:
|
||||
raise TypeError(
|
||||
"Query or Session object expected, got %r."
|
||||
% type(query_or_session)
|
||||
)
|
||||
return self._as_query(session)
|
||||
|
||||
def _as_query(self, session):
|
||||
query = self.steps[0](session)
|
||||
|
||||
for step in self.steps[1:]:
|
||||
query = step(query)
|
||||
|
||||
return query
|
||||
|
||||
|
||||
class Result:
|
||||
"""Invokes a :class:`.BakedQuery` against a :class:`.Session`.
|
||||
|
||||
The :class:`_baked.Result` object is where the actual :class:`.query.Query`
|
||||
object gets created, or retrieved from the cache,
|
||||
against a target :class:`.Session`, and is then invoked for results.
|
||||
|
||||
"""
|
||||
|
||||
__slots__ = "bq", "session", "_params", "_post_criteria"
|
||||
|
||||
def __init__(self, bq, session):
|
||||
self.bq = bq
|
||||
self.session = session
|
||||
self._params = {}
|
||||
self._post_criteria = []
|
||||
|
||||
def params(self, *args, **kw):
|
||||
"""Specify parameters to be replaced into the string SQL statement."""
|
||||
|
||||
if len(args) == 1:
|
||||
kw.update(args[0])
|
||||
elif len(args) > 0:
|
||||
raise sa_exc.ArgumentError(
|
||||
"params() takes zero or one positional argument, "
|
||||
"which is a dictionary."
|
||||
)
|
||||
self._params.update(kw)
|
||||
return self
|
||||
|
||||
def _using_post_criteria(self, fns):
|
||||
if fns:
|
||||
self._post_criteria.extend(fns)
|
||||
return self
|
||||
|
||||
def with_post_criteria(self, fn):
|
||||
"""Add a criteria function that will be applied post-cache.
|
||||
|
||||
This adds a function that will be run against the
|
||||
:class:`_query.Query` object after it is retrieved from the
|
||||
cache. This currently includes **only** the
|
||||
:meth:`_query.Query.params` and :meth:`_query.Query.execution_options`
|
||||
methods.
|
||||
|
||||
.. warning:: :meth:`_baked.Result.with_post_criteria`
|
||||
functions are applied
|
||||
to the :class:`_query.Query`
|
||||
object **after** the query's SQL statement
|
||||
object has been retrieved from the cache. Only
|
||||
:meth:`_query.Query.params` and
|
||||
:meth:`_query.Query.execution_options`
|
||||
methods should be used.
|
||||
|
||||
|
||||
.. versionadded:: 1.2
|
||||
|
||||
|
||||
"""
|
||||
return self._using_post_criteria([fn])
|
||||
|
||||
def _as_query(self):
|
||||
q = self.bq._as_query(self.session).params(self._params)
|
||||
for fn in self._post_criteria:
|
||||
q = fn(q)
|
||||
return q
|
||||
|
||||
def __str__(self):
|
||||
return str(self._as_query())
|
||||
|
||||
def __iter__(self):
|
||||
return self._iter().__iter__()
|
||||
|
||||
def _iter(self):
|
||||
bq = self.bq
|
||||
|
||||
if not self.session.enable_baked_queries or bq._spoiled:
|
||||
return self._as_query()._iter()
|
||||
|
||||
query, statement = bq._bakery.get(
|
||||
bq._effective_key(self.session), (None, None)
|
||||
)
|
||||
if query is None:
|
||||
query, statement = bq._bake(self.session)
|
||||
|
||||
if self._params:
|
||||
q = query.params(self._params)
|
||||
else:
|
||||
q = query
|
||||
for fn in self._post_criteria:
|
||||
q = fn(q)
|
||||
|
||||
params = q._params
|
||||
execution_options = dict(q._execution_options)
|
||||
execution_options.update(
|
||||
{
|
||||
"_sa_orm_load_options": q.load_options,
|
||||
"compiled_cache": bq._bakery,
|
||||
}
|
||||
)
|
||||
|
||||
result = self.session.execute(
|
||||
statement, params, execution_options=execution_options
|
||||
)
|
||||
if result._attributes.get("is_single_entity", False):
|
||||
result = result.scalars()
|
||||
|
||||
if result._attributes.get("filtered", False):
|
||||
result = result.unique()
|
||||
|
||||
return result
|
||||
|
||||
def count(self):
|
||||
"""return the 'count'.
|
||||
|
||||
Equivalent to :meth:`_query.Query.count`.
|
||||
|
||||
Note this uses a subquery to ensure an accurate count regardless
|
||||
of the structure of the original statement.
|
||||
|
||||
"""
|
||||
|
||||
col = func.count(literal_column("*"))
|
||||
bq = self.bq.with_criteria(lambda q: q._legacy_from_self(col))
|
||||
return bq.for_session(self.session).params(self._params).scalar()
|
||||
|
||||
def scalar(self):
|
||||
"""Return the first element of the first result or None
|
||||
if no rows present. If multiple rows are returned,
|
||||
raises MultipleResultsFound.
|
||||
|
||||
Equivalent to :meth:`_query.Query.scalar`.
|
||||
|
||||
"""
|
||||
try:
|
||||
ret = self.one()
|
||||
if not isinstance(ret, collections_abc.Sequence):
|
||||
return ret
|
||||
return ret[0]
|
||||
except orm_exc.NoResultFound:
|
||||
return None
|
||||
|
||||
def first(self):
|
||||
"""Return the first row.
|
||||
|
||||
Equivalent to :meth:`_query.Query.first`.
|
||||
|
||||
"""
|
||||
|
||||
bq = self.bq.with_criteria(lambda q: q.slice(0, 1))
|
||||
return (
|
||||
bq.for_session(self.session)
|
||||
.params(self._params)
|
||||
._using_post_criteria(self._post_criteria)
|
||||
._iter()
|
||||
.first()
|
||||
)
|
||||
|
||||
def one(self):
|
||||
"""Return exactly one result or raise an exception.
|
||||
|
||||
Equivalent to :meth:`_query.Query.one`.
|
||||
|
||||
"""
|
||||
return self._iter().one()
|
||||
|
||||
def one_or_none(self):
|
||||
"""Return one or zero results, or raise an exception for multiple
|
||||
rows.
|
||||
|
||||
Equivalent to :meth:`_query.Query.one_or_none`.
|
||||
|
||||
"""
|
||||
return self._iter().one_or_none()
|
||||
|
||||
def all(self):
|
||||
"""Return all rows.
|
||||
|
||||
Equivalent to :meth:`_query.Query.all`.
|
||||
|
||||
"""
|
||||
return self._iter().all()
|
||||
|
||||
def get(self, ident):
|
||||
"""Retrieve an object based on identity.
|
||||
|
||||
Equivalent to :meth:`_query.Query.get`.
|
||||
|
||||
"""
|
||||
|
||||
query = self.bq.steps[0](self.session)
|
||||
return query._get_impl(ident, self._load_on_pk_identity)
|
||||
|
||||
def _load_on_pk_identity(self, session, query, primary_key_identity, **kw):
|
||||
"""Load the given primary key identity from the database."""
|
||||
|
||||
mapper = query._raw_columns[0]._annotations["parententity"]
|
||||
|
||||
_get_clause, _get_params = mapper._get_clause
|
||||
|
||||
def setup(query):
|
||||
_lcl_get_clause = _get_clause
|
||||
q = query._clone()
|
||||
q._get_condition()
|
||||
q._order_by = None
|
||||
|
||||
# None present in ident - turn those comparisons
|
||||
# into "IS NULL"
|
||||
if None in primary_key_identity:
|
||||
nones = {
|
||||
_get_params[col].key
|
||||
for col, value in zip(
|
||||
mapper.primary_key, primary_key_identity
|
||||
)
|
||||
if value is None
|
||||
}
|
||||
_lcl_get_clause = sql_util.adapt_criterion_to_null(
|
||||
_lcl_get_clause, nones
|
||||
)
|
||||
|
||||
# TODO: can mapper._get_clause be pre-adapted?
|
||||
q._where_criteria = (
|
||||
sql_util._deep_annotate(_lcl_get_clause, {"_orm_adapt": True}),
|
||||
)
|
||||
|
||||
for fn in self._post_criteria:
|
||||
q = fn(q)
|
||||
return q
|
||||
|
||||
# cache the query against a key that includes
|
||||
# which positions in the primary key are NULL
|
||||
# (remember, we can map to an OUTER JOIN)
|
||||
bq = self.bq
|
||||
|
||||
# add the clause we got from mapper._get_clause to the cache
|
||||
# key so that if a race causes multiple calls to _get_clause,
|
||||
# we've cached on ours
|
||||
bq = bq._clone()
|
||||
bq._cache_key += (_get_clause,)
|
||||
|
||||
bq = bq.with_criteria(
|
||||
setup, tuple(elem is None for elem in primary_key_identity)
|
||||
)
|
||||
|
||||
params = {
|
||||
_get_params[primary_key].key: id_val
|
||||
for id_val, primary_key in zip(
|
||||
primary_key_identity, mapper.primary_key
|
||||
)
|
||||
}
|
||||
|
||||
result = list(bq.for_session(self.session).params(**params))
|
||||
l = len(result)
|
||||
if l > 1:
|
||||
raise orm_exc.MultipleResultsFound()
|
||||
elif l:
|
||||
return result[0]
|
||||
else:
|
||||
return None
|
||||
|
||||
|
||||
bakery = BakedQuery.bakery
|
555
elitebot/lib/python3.11/site-packages/sqlalchemy/ext/compiler.py
Normal file
555
elitebot/lib/python3.11/site-packages/sqlalchemy/ext/compiler.py
Normal file
|
@ -0,0 +1,555 @@
|
|||
# ext/compiler.py
|
||||
# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors
|
||||
# <see AUTHORS file>
|
||||
#
|
||||
# This module is part of SQLAlchemy and is released under
|
||||
# the MIT License: https://www.opensource.org/licenses/mit-license.php
|
||||
# mypy: ignore-errors
|
||||
|
||||
r"""Provides an API for creation of custom ClauseElements and compilers.
|
||||
|
||||
Synopsis
|
||||
========
|
||||
|
||||
Usage involves the creation of one or more
|
||||
:class:`~sqlalchemy.sql.expression.ClauseElement` subclasses and one or
|
||||
more callables defining its compilation::
|
||||
|
||||
from sqlalchemy.ext.compiler import compiles
|
||||
from sqlalchemy.sql.expression import ColumnClause
|
||||
|
||||
class MyColumn(ColumnClause):
|
||||
inherit_cache = True
|
||||
|
||||
@compiles(MyColumn)
|
||||
def compile_mycolumn(element, compiler, **kw):
|
||||
return "[%s]" % element.name
|
||||
|
||||
Above, ``MyColumn`` extends :class:`~sqlalchemy.sql.expression.ColumnClause`,
|
||||
the base expression element for named column objects. The ``compiles``
|
||||
decorator registers itself with the ``MyColumn`` class so that it is invoked
|
||||
when the object is compiled to a string::
|
||||
|
||||
from sqlalchemy import select
|
||||
|
||||
s = select(MyColumn('x'), MyColumn('y'))
|
||||
print(str(s))
|
||||
|
||||
Produces::
|
||||
|
||||
SELECT [x], [y]
|
||||
|
||||
Dialect-specific compilation rules
|
||||
==================================
|
||||
|
||||
Compilers can also be made dialect-specific. The appropriate compiler will be
|
||||
invoked for the dialect in use::
|
||||
|
||||
from sqlalchemy.schema import DDLElement
|
||||
|
||||
class AlterColumn(DDLElement):
|
||||
inherit_cache = False
|
||||
|
||||
def __init__(self, column, cmd):
|
||||
self.column = column
|
||||
self.cmd = cmd
|
||||
|
||||
@compiles(AlterColumn)
|
||||
def visit_alter_column(element, compiler, **kw):
|
||||
return "ALTER COLUMN %s ..." % element.column.name
|
||||
|
||||
@compiles(AlterColumn, 'postgresql')
|
||||
def visit_alter_column(element, compiler, **kw):
|
||||
return "ALTER TABLE %s ALTER COLUMN %s ..." % (element.table.name,
|
||||
element.column.name)
|
||||
|
||||
The second ``visit_alter_table`` will be invoked when any ``postgresql``
|
||||
dialect is used.
|
||||
|
||||
.. _compilerext_compiling_subelements:
|
||||
|
||||
Compiling sub-elements of a custom expression construct
|
||||
=======================================================
|
||||
|
||||
The ``compiler`` argument is the
|
||||
:class:`~sqlalchemy.engine.interfaces.Compiled` object in use. This object
|
||||
can be inspected for any information about the in-progress compilation,
|
||||
including ``compiler.dialect``, ``compiler.statement`` etc. The
|
||||
:class:`~sqlalchemy.sql.compiler.SQLCompiler` and
|
||||
:class:`~sqlalchemy.sql.compiler.DDLCompiler` both include a ``process()``
|
||||
method which can be used for compilation of embedded attributes::
|
||||
|
||||
from sqlalchemy.sql.expression import Executable, ClauseElement
|
||||
|
||||
class InsertFromSelect(Executable, ClauseElement):
|
||||
inherit_cache = False
|
||||
|
||||
def __init__(self, table, select):
|
||||
self.table = table
|
||||
self.select = select
|
||||
|
||||
@compiles(InsertFromSelect)
|
||||
def visit_insert_from_select(element, compiler, **kw):
|
||||
return "INSERT INTO %s (%s)" % (
|
||||
compiler.process(element.table, asfrom=True, **kw),
|
||||
compiler.process(element.select, **kw)
|
||||
)
|
||||
|
||||
insert = InsertFromSelect(t1, select(t1).where(t1.c.x>5))
|
||||
print(insert)
|
||||
|
||||
Produces::
|
||||
|
||||
"INSERT INTO mytable (SELECT mytable.x, mytable.y, mytable.z
|
||||
FROM mytable WHERE mytable.x > :x_1)"
|
||||
|
||||
.. note::
|
||||
|
||||
The above ``InsertFromSelect`` construct is only an example, this actual
|
||||
functionality is already available using the
|
||||
:meth:`_expression.Insert.from_select` method.
|
||||
|
||||
|
||||
Cross Compiling between SQL and DDL compilers
|
||||
---------------------------------------------
|
||||
|
||||
SQL and DDL constructs are each compiled using different base compilers -
|
||||
``SQLCompiler`` and ``DDLCompiler``. A common need is to access the
|
||||
compilation rules of SQL expressions from within a DDL expression. The
|
||||
``DDLCompiler`` includes an accessor ``sql_compiler`` for this reason, such as
|
||||
below where we generate a CHECK constraint that embeds a SQL expression::
|
||||
|
||||
@compiles(MyConstraint)
|
||||
def compile_my_constraint(constraint, ddlcompiler, **kw):
|
||||
kw['literal_binds'] = True
|
||||
return "CONSTRAINT %s CHECK (%s)" % (
|
||||
constraint.name,
|
||||
ddlcompiler.sql_compiler.process(
|
||||
constraint.expression, **kw)
|
||||
)
|
||||
|
||||
Above, we add an additional flag to the process step as called by
|
||||
:meth:`.SQLCompiler.process`, which is the ``literal_binds`` flag. This
|
||||
indicates that any SQL expression which refers to a :class:`.BindParameter`
|
||||
object or other "literal" object such as those which refer to strings or
|
||||
integers should be rendered **in-place**, rather than being referred to as
|
||||
a bound parameter; when emitting DDL, bound parameters are typically not
|
||||
supported.
|
||||
|
||||
|
||||
Changing the default compilation of existing constructs
|
||||
=======================================================
|
||||
|
||||
The compiler extension applies just as well to the existing constructs. When
|
||||
overriding the compilation of a built in SQL construct, the @compiles
|
||||
decorator is invoked upon the appropriate class (be sure to use the class,
|
||||
i.e. ``Insert`` or ``Select``, instead of the creation function such
|
||||
as ``insert()`` or ``select()``).
|
||||
|
||||
Within the new compilation function, to get at the "original" compilation
|
||||
routine, use the appropriate visit_XXX method - this
|
||||
because compiler.process() will call upon the overriding routine and cause
|
||||
an endless loop. Such as, to add "prefix" to all insert statements::
|
||||
|
||||
from sqlalchemy.sql.expression import Insert
|
||||
|
||||
@compiles(Insert)
|
||||
def prefix_inserts(insert, compiler, **kw):
|
||||
return compiler.visit_insert(insert.prefix_with("some prefix"), **kw)
|
||||
|
||||
The above compiler will prefix all INSERT statements with "some prefix" when
|
||||
compiled.
|
||||
|
||||
.. _type_compilation_extension:
|
||||
|
||||
Changing Compilation of Types
|
||||
=============================
|
||||
|
||||
``compiler`` works for types, too, such as below where we implement the
|
||||
MS-SQL specific 'max' keyword for ``String``/``VARCHAR``::
|
||||
|
||||
@compiles(String, 'mssql')
|
||||
@compiles(VARCHAR, 'mssql')
|
||||
def compile_varchar(element, compiler, **kw):
|
||||
if element.length == 'max':
|
||||
return "VARCHAR('max')"
|
||||
else:
|
||||
return compiler.visit_VARCHAR(element, **kw)
|
||||
|
||||
foo = Table('foo', metadata,
|
||||
Column('data', VARCHAR('max'))
|
||||
)
|
||||
|
||||
Subclassing Guidelines
|
||||
======================
|
||||
|
||||
A big part of using the compiler extension is subclassing SQLAlchemy
|
||||
expression constructs. To make this easier, the expression and
|
||||
schema packages feature a set of "bases" intended for common tasks.
|
||||
A synopsis is as follows:
|
||||
|
||||
* :class:`~sqlalchemy.sql.expression.ClauseElement` - This is the root
|
||||
expression class. Any SQL expression can be derived from this base, and is
|
||||
probably the best choice for longer constructs such as specialized INSERT
|
||||
statements.
|
||||
|
||||
* :class:`~sqlalchemy.sql.expression.ColumnElement` - The root of all
|
||||
"column-like" elements. Anything that you'd place in the "columns" clause of
|
||||
a SELECT statement (as well as order by and group by) can derive from this -
|
||||
the object will automatically have Python "comparison" behavior.
|
||||
|
||||
:class:`~sqlalchemy.sql.expression.ColumnElement` classes want to have a
|
||||
``type`` member which is expression's return type. This can be established
|
||||
at the instance level in the constructor, or at the class level if its
|
||||
generally constant::
|
||||
|
||||
class timestamp(ColumnElement):
|
||||
type = TIMESTAMP()
|
||||
inherit_cache = True
|
||||
|
||||
* :class:`~sqlalchemy.sql.functions.FunctionElement` - This is a hybrid of a
|
||||
``ColumnElement`` and a "from clause" like object, and represents a SQL
|
||||
function or stored procedure type of call. Since most databases support
|
||||
statements along the line of "SELECT FROM <some function>"
|
||||
``FunctionElement`` adds in the ability to be used in the FROM clause of a
|
||||
``select()`` construct::
|
||||
|
||||
from sqlalchemy.sql.expression import FunctionElement
|
||||
|
||||
class coalesce(FunctionElement):
|
||||
name = 'coalesce'
|
||||
inherit_cache = True
|
||||
|
||||
@compiles(coalesce)
|
||||
def compile(element, compiler, **kw):
|
||||
return "coalesce(%s)" % compiler.process(element.clauses, **kw)
|
||||
|
||||
@compiles(coalesce, 'oracle')
|
||||
def compile(element, compiler, **kw):
|
||||
if len(element.clauses) > 2:
|
||||
raise TypeError("coalesce only supports two arguments on Oracle")
|
||||
return "nvl(%s)" % compiler.process(element.clauses, **kw)
|
||||
|
||||
* :class:`.ExecutableDDLElement` - The root of all DDL expressions,
|
||||
like CREATE TABLE, ALTER TABLE, etc. Compilation of
|
||||
:class:`.ExecutableDDLElement` subclasses is issued by a
|
||||
:class:`.DDLCompiler` instead of a :class:`.SQLCompiler`.
|
||||
:class:`.ExecutableDDLElement` can also be used as an event hook in
|
||||
conjunction with event hooks like :meth:`.DDLEvents.before_create` and
|
||||
:meth:`.DDLEvents.after_create`, allowing the construct to be invoked
|
||||
automatically during CREATE TABLE and DROP TABLE sequences.
|
||||
|
||||
.. seealso::
|
||||
|
||||
:ref:`metadata_ddl_toplevel` - contains examples of associating
|
||||
:class:`.DDL` objects (which are themselves :class:`.ExecutableDDLElement`
|
||||
instances) with :class:`.DDLEvents` event hooks.
|
||||
|
||||
* :class:`~sqlalchemy.sql.expression.Executable` - This is a mixin which
|
||||
should be used with any expression class that represents a "standalone"
|
||||
SQL statement that can be passed directly to an ``execute()`` method. It
|
||||
is already implicit within ``DDLElement`` and ``FunctionElement``.
|
||||
|
||||
Most of the above constructs also respond to SQL statement caching. A
|
||||
subclassed construct will want to define the caching behavior for the object,
|
||||
which usually means setting the flag ``inherit_cache`` to the value of
|
||||
``False`` or ``True``. See the next section :ref:`compilerext_caching`
|
||||
for background.
|
||||
|
||||
|
||||
.. _compilerext_caching:
|
||||
|
||||
Enabling Caching Support for Custom Constructs
|
||||
==============================================
|
||||
|
||||
SQLAlchemy as of version 1.4 includes a
|
||||
:ref:`SQL compilation caching facility <sql_caching>` which will allow
|
||||
equivalent SQL constructs to cache their stringified form, along with other
|
||||
structural information used to fetch results from the statement.
|
||||
|
||||
For reasons discussed at :ref:`caching_caveats`, the implementation of this
|
||||
caching system takes a conservative approach towards including custom SQL
|
||||
constructs and/or subclasses within the caching system. This includes that
|
||||
any user-defined SQL constructs, including all the examples for this
|
||||
extension, will not participate in caching by default unless they positively
|
||||
assert that they are able to do so. The :attr:`.HasCacheKey.inherit_cache`
|
||||
attribute when set to ``True`` at the class level of a specific subclass
|
||||
will indicate that instances of this class may be safely cached, using the
|
||||
cache key generation scheme of the immediate superclass. This applies
|
||||
for example to the "synopsis" example indicated previously::
|
||||
|
||||
class MyColumn(ColumnClause):
|
||||
inherit_cache = True
|
||||
|
||||
@compiles(MyColumn)
|
||||
def compile_mycolumn(element, compiler, **kw):
|
||||
return "[%s]" % element.name
|
||||
|
||||
Above, the ``MyColumn`` class does not include any new state that
|
||||
affects its SQL compilation; the cache key of ``MyColumn`` instances will
|
||||
make use of that of the ``ColumnClause`` superclass, meaning it will take
|
||||
into account the class of the object (``MyColumn``), the string name and
|
||||
datatype of the object::
|
||||
|
||||
>>> MyColumn("some_name", String())._generate_cache_key()
|
||||
CacheKey(
|
||||
key=('0', <class '__main__.MyColumn'>,
|
||||
'name', 'some_name',
|
||||
'type', (<class 'sqlalchemy.sql.sqltypes.String'>,
|
||||
('length', None), ('collation', None))
|
||||
), bindparams=[])
|
||||
|
||||
For objects that are likely to be **used liberally as components within many
|
||||
larger statements**, such as :class:`_schema.Column` subclasses and custom SQL
|
||||
datatypes, it's important that **caching be enabled as much as possible**, as
|
||||
this may otherwise negatively affect performance.
|
||||
|
||||
An example of an object that **does** contain state which affects its SQL
|
||||
compilation is the one illustrated at :ref:`compilerext_compiling_subelements`;
|
||||
this is an "INSERT FROM SELECT" construct that combines together a
|
||||
:class:`_schema.Table` as well as a :class:`_sql.Select` construct, each of
|
||||
which independently affect the SQL string generation of the construct. For
|
||||
this class, the example illustrates that it simply does not participate in
|
||||
caching::
|
||||
|
||||
class InsertFromSelect(Executable, ClauseElement):
|
||||
inherit_cache = False
|
||||
|
||||
def __init__(self, table, select):
|
||||
self.table = table
|
||||
self.select = select
|
||||
|
||||
@compiles(InsertFromSelect)
|
||||
def visit_insert_from_select(element, compiler, **kw):
|
||||
return "INSERT INTO %s (%s)" % (
|
||||
compiler.process(element.table, asfrom=True, **kw),
|
||||
compiler.process(element.select, **kw)
|
||||
)
|
||||
|
||||
While it is also possible that the above ``InsertFromSelect`` could be made to
|
||||
produce a cache key that is composed of that of the :class:`_schema.Table` and
|
||||
:class:`_sql.Select` components together, the API for this is not at the moment
|
||||
fully public. However, for an "INSERT FROM SELECT" construct, which is only
|
||||
used by itself for specific operations, caching is not as critical as in the
|
||||
previous example.
|
||||
|
||||
For objects that are **used in relative isolation and are generally
|
||||
standalone**, such as custom :term:`DML` constructs like an "INSERT FROM
|
||||
SELECT", **caching is generally less critical** as the lack of caching for such
|
||||
a construct will have only localized implications for that specific operation.
|
||||
|
||||
|
||||
Further Examples
|
||||
================
|
||||
|
||||
"UTC timestamp" function
|
||||
-------------------------
|
||||
|
||||
A function that works like "CURRENT_TIMESTAMP" except applies the
|
||||
appropriate conversions so that the time is in UTC time. Timestamps are best
|
||||
stored in relational databases as UTC, without time zones. UTC so that your
|
||||
database doesn't think time has gone backwards in the hour when daylight
|
||||
savings ends, without timezones because timezones are like character
|
||||
encodings - they're best applied only at the endpoints of an application
|
||||
(i.e. convert to UTC upon user input, re-apply desired timezone upon display).
|
||||
|
||||
For PostgreSQL and Microsoft SQL Server::
|
||||
|
||||
from sqlalchemy.sql import expression
|
||||
from sqlalchemy.ext.compiler import compiles
|
||||
from sqlalchemy.types import DateTime
|
||||
|
||||
class utcnow(expression.FunctionElement):
|
||||
type = DateTime()
|
||||
inherit_cache = True
|
||||
|
||||
@compiles(utcnow, 'postgresql')
|
||||
def pg_utcnow(element, compiler, **kw):
|
||||
return "TIMEZONE('utc', CURRENT_TIMESTAMP)"
|
||||
|
||||
@compiles(utcnow, 'mssql')
|
||||
def ms_utcnow(element, compiler, **kw):
|
||||
return "GETUTCDATE()"
|
||||
|
||||
Example usage::
|
||||
|
||||
from sqlalchemy import (
|
||||
Table, Column, Integer, String, DateTime, MetaData
|
||||
)
|
||||
metadata = MetaData()
|
||||
event = Table("event", metadata,
|
||||
Column("id", Integer, primary_key=True),
|
||||
Column("description", String(50), nullable=False),
|
||||
Column("timestamp", DateTime, server_default=utcnow())
|
||||
)
|
||||
|
||||
"GREATEST" function
|
||||
-------------------
|
||||
|
||||
The "GREATEST" function is given any number of arguments and returns the one
|
||||
that is of the highest value - its equivalent to Python's ``max``
|
||||
function. A SQL standard version versus a CASE based version which only
|
||||
accommodates two arguments::
|
||||
|
||||
from sqlalchemy.sql import expression, case
|
||||
from sqlalchemy.ext.compiler import compiles
|
||||
from sqlalchemy.types import Numeric
|
||||
|
||||
class greatest(expression.FunctionElement):
|
||||
type = Numeric()
|
||||
name = 'greatest'
|
||||
inherit_cache = True
|
||||
|
||||
@compiles(greatest)
|
||||
def default_greatest(element, compiler, **kw):
|
||||
return compiler.visit_function(element)
|
||||
|
||||
@compiles(greatest, 'sqlite')
|
||||
@compiles(greatest, 'mssql')
|
||||
@compiles(greatest, 'oracle')
|
||||
def case_greatest(element, compiler, **kw):
|
||||
arg1, arg2 = list(element.clauses)
|
||||
return compiler.process(case((arg1 > arg2, arg1), else_=arg2), **kw)
|
||||
|
||||
Example usage::
|
||||
|
||||
Session.query(Account).\
|
||||
filter(
|
||||
greatest(
|
||||
Account.checking_balance,
|
||||
Account.savings_balance) > 10000
|
||||
)
|
||||
|
||||
"false" expression
|
||||
------------------
|
||||
|
||||
Render a "false" constant expression, rendering as "0" on platforms that
|
||||
don't have a "false" constant::
|
||||
|
||||
from sqlalchemy.sql import expression
|
||||
from sqlalchemy.ext.compiler import compiles
|
||||
|
||||
class sql_false(expression.ColumnElement):
|
||||
inherit_cache = True
|
||||
|
||||
@compiles(sql_false)
|
||||
def default_false(element, compiler, **kw):
|
||||
return "false"
|
||||
|
||||
@compiles(sql_false, 'mssql')
|
||||
@compiles(sql_false, 'mysql')
|
||||
@compiles(sql_false, 'oracle')
|
||||
def int_false(element, compiler, **kw):
|
||||
return "0"
|
||||
|
||||
Example usage::
|
||||
|
||||
from sqlalchemy import select, union_all
|
||||
|
||||
exp = union_all(
|
||||
select(users.c.name, sql_false().label("enrolled")),
|
||||
select(customers.c.name, customers.c.enrolled)
|
||||
)
|
||||
|
||||
"""
|
||||
from .. import exc
|
||||
from ..sql import sqltypes
|
||||
|
||||
|
||||
def compiles(class_, *specs):
|
||||
"""Register a function as a compiler for a
|
||||
given :class:`_expression.ClauseElement` type."""
|
||||
|
||||
def decorate(fn):
|
||||
# get an existing @compiles handler
|
||||
existing = class_.__dict__.get("_compiler_dispatcher", None)
|
||||
|
||||
# get the original handler. All ClauseElement classes have one
|
||||
# of these, but some TypeEngine classes will not.
|
||||
existing_dispatch = getattr(class_, "_compiler_dispatch", None)
|
||||
|
||||
if not existing:
|
||||
existing = _dispatcher()
|
||||
|
||||
if existing_dispatch:
|
||||
|
||||
def _wrap_existing_dispatch(element, compiler, **kw):
|
||||
try:
|
||||
return existing_dispatch(element, compiler, **kw)
|
||||
except exc.UnsupportedCompilationError as uce:
|
||||
raise exc.UnsupportedCompilationError(
|
||||
compiler,
|
||||
type(element),
|
||||
message="%s construct has no default "
|
||||
"compilation handler." % type(element),
|
||||
) from uce
|
||||
|
||||
existing.specs["default"] = _wrap_existing_dispatch
|
||||
|
||||
# TODO: why is the lambda needed ?
|
||||
setattr(
|
||||
class_,
|
||||
"_compiler_dispatch",
|
||||
lambda *arg, **kw: existing(*arg, **kw),
|
||||
)
|
||||
setattr(class_, "_compiler_dispatcher", existing)
|
||||
|
||||
if specs:
|
||||
for s in specs:
|
||||
existing.specs[s] = fn
|
||||
|
||||
else:
|
||||
existing.specs["default"] = fn
|
||||
return fn
|
||||
|
||||
return decorate
|
||||
|
||||
|
||||
def deregister(class_):
|
||||
"""Remove all custom compilers associated with a given
|
||||
:class:`_expression.ClauseElement` type.
|
||||
|
||||
"""
|
||||
|
||||
if hasattr(class_, "_compiler_dispatcher"):
|
||||
class_._compiler_dispatch = class_._original_compiler_dispatch
|
||||
del class_._compiler_dispatcher
|
||||
|
||||
|
||||
class _dispatcher:
|
||||
def __init__(self):
|
||||
self.specs = {}
|
||||
|
||||
def __call__(self, element, compiler, **kw):
|
||||
# TODO: yes, this could also switch off of DBAPI in use.
|
||||
fn = self.specs.get(compiler.dialect.name, None)
|
||||
if not fn:
|
||||
try:
|
||||
fn = self.specs["default"]
|
||||
except KeyError as ke:
|
||||
raise exc.UnsupportedCompilationError(
|
||||
compiler,
|
||||
type(element),
|
||||
message="%s construct has no default "
|
||||
"compilation handler." % type(element),
|
||||
) from ke
|
||||
|
||||
# if compilation includes add_to_result_map, collect add_to_result_map
|
||||
# arguments from the user-defined callable, which are probably none
|
||||
# because this is not public API. if it wasn't called, then call it
|
||||
# ourselves.
|
||||
arm = kw.get("add_to_result_map", None)
|
||||
if arm:
|
||||
arm_collection = []
|
||||
kw["add_to_result_map"] = lambda *args: arm_collection.append(args)
|
||||
|
||||
expr = fn(element, compiler, **kw)
|
||||
|
||||
if arm:
|
||||
if not arm_collection:
|
||||
arm_collection.append(
|
||||
(None, None, (element,), sqltypes.NULLTYPE)
|
||||
)
|
||||
for tup in arm_collection:
|
||||
arm(*tup)
|
||||
return expr
|
|
@ -0,0 +1,65 @@
|
|||
# ext/declarative/__init__.py
|
||||
# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors
|
||||
# <see AUTHORS file>
|
||||
#
|
||||
# This module is part of SQLAlchemy and is released under
|
||||
# the MIT License: https://www.opensource.org/licenses/mit-license.php
|
||||
# mypy: ignore-errors
|
||||
|
||||
|
||||
from .extensions import AbstractConcreteBase
|
||||
from .extensions import ConcreteBase
|
||||
from .extensions import DeferredReflection
|
||||
from ... import util
|
||||
from ...orm.decl_api import as_declarative as _as_declarative
|
||||
from ...orm.decl_api import declarative_base as _declarative_base
|
||||
from ...orm.decl_api import DeclarativeMeta
|
||||
from ...orm.decl_api import declared_attr
|
||||
from ...orm.decl_api import has_inherited_table as _has_inherited_table
|
||||
from ...orm.decl_api import synonym_for as _synonym_for
|
||||
|
||||
|
||||
@util.moved_20(
|
||||
"The ``declarative_base()`` function is now available as "
|
||||
":func:`sqlalchemy.orm.declarative_base`."
|
||||
)
|
||||
def declarative_base(*arg, **kw):
|
||||
return _declarative_base(*arg, **kw)
|
||||
|
||||
|
||||
@util.moved_20(
|
||||
"The ``as_declarative()`` function is now available as "
|
||||
":func:`sqlalchemy.orm.as_declarative`"
|
||||
)
|
||||
def as_declarative(*arg, **kw):
|
||||
return _as_declarative(*arg, **kw)
|
||||
|
||||
|
||||
@util.moved_20(
|
||||
"The ``has_inherited_table()`` function is now available as "
|
||||
":func:`sqlalchemy.orm.has_inherited_table`."
|
||||
)
|
||||
def has_inherited_table(*arg, **kw):
|
||||
return _has_inherited_table(*arg, **kw)
|
||||
|
||||
|
||||
@util.moved_20(
|
||||
"The ``synonym_for()`` function is now available as "
|
||||
":func:`sqlalchemy.orm.synonym_for`"
|
||||
)
|
||||
def synonym_for(*arg, **kw):
|
||||
return _synonym_for(*arg, **kw)
|
||||
|
||||
|
||||
__all__ = [
|
||||
"declarative_base",
|
||||
"synonym_for",
|
||||
"has_inherited_table",
|
||||
"instrument_declarative",
|
||||
"declared_attr",
|
||||
"as_declarative",
|
||||
"ConcreteBase",
|
||||
"AbstractConcreteBase",
|
||||
"DeclarativeMeta",
|
||||
"DeferredReflection",
|
||||
]
|
|
@ -0,0 +1,548 @@
|
|||
# ext/declarative/extensions.py
|
||||
# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors
|
||||
# <see AUTHORS file>
|
||||
#
|
||||
# This module is part of SQLAlchemy and is released under
|
||||
# the MIT License: https://www.opensource.org/licenses/mit-license.php
|
||||
# mypy: ignore-errors
|
||||
|
||||
|
||||
"""Public API functions and helpers for declarative."""
|
||||
from __future__ import annotations
|
||||
|
||||
import collections
|
||||
import contextlib
|
||||
from typing import Any
|
||||
from typing import Callable
|
||||
from typing import TYPE_CHECKING
|
||||
from typing import Union
|
||||
|
||||
from ... import exc as sa_exc
|
||||
from ...engine import Connection
|
||||
from ...engine import Engine
|
||||
from ...orm import exc as orm_exc
|
||||
from ...orm import relationships
|
||||
from ...orm.base import _mapper_or_none
|
||||
from ...orm.clsregistry import _resolver
|
||||
from ...orm.decl_base import _DeferredMapperConfig
|
||||
from ...orm.util import polymorphic_union
|
||||
from ...schema import Table
|
||||
from ...util import OrderedDict
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ...sql.schema import MetaData
|
||||
|
||||
|
||||
class ConcreteBase:
|
||||
"""A helper class for 'concrete' declarative mappings.
|
||||
|
||||
:class:`.ConcreteBase` will use the :func:`.polymorphic_union`
|
||||
function automatically, against all tables mapped as a subclass
|
||||
to this class. The function is called via the
|
||||
``__declare_last__()`` function, which is essentially
|
||||
a hook for the :meth:`.after_configured` event.
|
||||
|
||||
:class:`.ConcreteBase` produces a mapped
|
||||
table for the class itself. Compare to :class:`.AbstractConcreteBase`,
|
||||
which does not.
|
||||
|
||||
Example::
|
||||
|
||||
from sqlalchemy.ext.declarative import ConcreteBase
|
||||
|
||||
class Employee(ConcreteBase, Base):
|
||||
__tablename__ = 'employee'
|
||||
employee_id = Column(Integer, primary_key=True)
|
||||
name = Column(String(50))
|
||||
__mapper_args__ = {
|
||||
'polymorphic_identity':'employee',
|
||||
'concrete':True}
|
||||
|
||||
class Manager(Employee):
|
||||
__tablename__ = 'manager'
|
||||
employee_id = Column(Integer, primary_key=True)
|
||||
name = Column(String(50))
|
||||
manager_data = Column(String(40))
|
||||
__mapper_args__ = {
|
||||
'polymorphic_identity':'manager',
|
||||
'concrete':True}
|
||||
|
||||
|
||||
The name of the discriminator column used by :func:`.polymorphic_union`
|
||||
defaults to the name ``type``. To suit the use case of a mapping where an
|
||||
actual column in a mapped table is already named ``type``, the
|
||||
discriminator name can be configured by setting the
|
||||
``_concrete_discriminator_name`` attribute::
|
||||
|
||||
class Employee(ConcreteBase, Base):
|
||||
_concrete_discriminator_name = '_concrete_discriminator'
|
||||
|
||||
.. versionadded:: 1.3.19 Added the ``_concrete_discriminator_name``
|
||||
attribute to :class:`_declarative.ConcreteBase` so that the
|
||||
virtual discriminator column name can be customized.
|
||||
|
||||
.. versionchanged:: 1.4.2 The ``_concrete_discriminator_name`` attribute
|
||||
need only be placed on the basemost class to take correct effect for
|
||||
all subclasses. An explicit error message is now raised if the
|
||||
mapped column names conflict with the discriminator name, whereas
|
||||
in the 1.3.x series there would be some warnings and then a non-useful
|
||||
query would be generated.
|
||||
|
||||
.. seealso::
|
||||
|
||||
:class:`.AbstractConcreteBase`
|
||||
|
||||
:ref:`concrete_inheritance`
|
||||
|
||||
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def _create_polymorphic_union(cls, mappers, discriminator_name):
|
||||
return polymorphic_union(
|
||||
OrderedDict(
|
||||
(mp.polymorphic_identity, mp.local_table) for mp in mappers
|
||||
),
|
||||
discriminator_name,
|
||||
"pjoin",
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def __declare_first__(cls):
|
||||
m = cls.__mapper__
|
||||
if m.with_polymorphic:
|
||||
return
|
||||
|
||||
discriminator_name = (
|
||||
getattr(cls, "_concrete_discriminator_name", None) or "type"
|
||||
)
|
||||
|
||||
mappers = list(m.self_and_descendants)
|
||||
pjoin = cls._create_polymorphic_union(mappers, discriminator_name)
|
||||
m._set_with_polymorphic(("*", pjoin))
|
||||
m._set_polymorphic_on(pjoin.c[discriminator_name])
|
||||
|
||||
|
||||
class AbstractConcreteBase(ConcreteBase):
|
||||
"""A helper class for 'concrete' declarative mappings.
|
||||
|
||||
:class:`.AbstractConcreteBase` will use the :func:`.polymorphic_union`
|
||||
function automatically, against all tables mapped as a subclass
|
||||
to this class. The function is called via the
|
||||
``__declare_first__()`` function, which is essentially
|
||||
a hook for the :meth:`.before_configured` event.
|
||||
|
||||
:class:`.AbstractConcreteBase` applies :class:`_orm.Mapper` for its
|
||||
immediately inheriting class, as would occur for any other
|
||||
declarative mapped class. However, the :class:`_orm.Mapper` is not
|
||||
mapped to any particular :class:`.Table` object. Instead, it's
|
||||
mapped directly to the "polymorphic" selectable produced by
|
||||
:func:`.polymorphic_union`, and performs no persistence operations on its
|
||||
own. Compare to :class:`.ConcreteBase`, which maps its
|
||||
immediately inheriting class to an actual
|
||||
:class:`.Table` that stores rows directly.
|
||||
|
||||
.. note::
|
||||
|
||||
The :class:`.AbstractConcreteBase` delays the mapper creation of the
|
||||
base class until all the subclasses have been defined,
|
||||
as it needs to create a mapping against a selectable that will include
|
||||
all subclass tables. In order to achieve this, it waits for the
|
||||
**mapper configuration event** to occur, at which point it scans
|
||||
through all the configured subclasses and sets up a mapping that will
|
||||
query against all subclasses at once.
|
||||
|
||||
While this event is normally invoked automatically, in the case of
|
||||
:class:`.AbstractConcreteBase`, it may be necessary to invoke it
|
||||
explicitly after **all** subclass mappings are defined, if the first
|
||||
operation is to be a query against this base class. To do so, once all
|
||||
the desired classes have been configured, the
|
||||
:meth:`_orm.registry.configure` method on the :class:`_orm.registry`
|
||||
in use can be invoked, which is available in relation to a particular
|
||||
declarative base class::
|
||||
|
||||
Base.registry.configure()
|
||||
|
||||
Example::
|
||||
|
||||
from sqlalchemy.orm import DeclarativeBase
|
||||
from sqlalchemy.ext.declarative import AbstractConcreteBase
|
||||
|
||||
class Base(DeclarativeBase):
|
||||
pass
|
||||
|
||||
class Employee(AbstractConcreteBase, Base):
|
||||
pass
|
||||
|
||||
class Manager(Employee):
|
||||
__tablename__ = 'manager'
|
||||
employee_id = Column(Integer, primary_key=True)
|
||||
name = Column(String(50))
|
||||
manager_data = Column(String(40))
|
||||
|
||||
__mapper_args__ = {
|
||||
'polymorphic_identity':'manager',
|
||||
'concrete':True
|
||||
}
|
||||
|
||||
Base.registry.configure()
|
||||
|
||||
The abstract base class is handled by declarative in a special way;
|
||||
at class configuration time, it behaves like a declarative mixin
|
||||
or an ``__abstract__`` base class. Once classes are configured
|
||||
and mappings are produced, it then gets mapped itself, but
|
||||
after all of its descendants. This is a very unique system of mapping
|
||||
not found in any other SQLAlchemy API feature.
|
||||
|
||||
Using this approach, we can specify columns and properties
|
||||
that will take place on mapped subclasses, in the way that
|
||||
we normally do as in :ref:`declarative_mixins`::
|
||||
|
||||
from sqlalchemy.ext.declarative import AbstractConcreteBase
|
||||
|
||||
class Company(Base):
|
||||
__tablename__ = 'company'
|
||||
id = Column(Integer, primary_key=True)
|
||||
|
||||
class Employee(AbstractConcreteBase, Base):
|
||||
strict_attrs = True
|
||||
|
||||
employee_id = Column(Integer, primary_key=True)
|
||||
|
||||
@declared_attr
|
||||
def company_id(cls):
|
||||
return Column(ForeignKey('company.id'))
|
||||
|
||||
@declared_attr
|
||||
def company(cls):
|
||||
return relationship("Company")
|
||||
|
||||
class Manager(Employee):
|
||||
__tablename__ = 'manager'
|
||||
|
||||
name = Column(String(50))
|
||||
manager_data = Column(String(40))
|
||||
|
||||
__mapper_args__ = {
|
||||
'polymorphic_identity':'manager',
|
||||
'concrete':True
|
||||
}
|
||||
|
||||
Base.registry.configure()
|
||||
|
||||
When we make use of our mappings however, both ``Manager`` and
|
||||
``Employee`` will have an independently usable ``.company`` attribute::
|
||||
|
||||
session.execute(
|
||||
select(Employee).filter(Employee.company.has(id=5))
|
||||
)
|
||||
|
||||
:param strict_attrs: when specified on the base class, "strict" attribute
|
||||
mode is enabled which attempts to limit ORM mapped attributes on the
|
||||
base class to only those that are immediately present, while still
|
||||
preserving "polymorphic" loading behavior.
|
||||
|
||||
.. versionadded:: 2.0
|
||||
|
||||
.. seealso::
|
||||
|
||||
:class:`.ConcreteBase`
|
||||
|
||||
:ref:`concrete_inheritance`
|
||||
|
||||
:ref:`abstract_concrete_base`
|
||||
|
||||
"""
|
||||
|
||||
__no_table__ = True
|
||||
|
||||
@classmethod
|
||||
def __declare_first__(cls):
|
||||
cls._sa_decl_prepare_nocascade()
|
||||
|
||||
@classmethod
|
||||
def _sa_decl_prepare_nocascade(cls):
|
||||
if getattr(cls, "__mapper__", None):
|
||||
return
|
||||
|
||||
to_map = _DeferredMapperConfig.config_for_cls(cls)
|
||||
|
||||
# can't rely on 'self_and_descendants' here
|
||||
# since technically an immediate subclass
|
||||
# might not be mapped, but a subclass
|
||||
# may be.
|
||||
mappers = []
|
||||
stack = list(cls.__subclasses__())
|
||||
while stack:
|
||||
klass = stack.pop()
|
||||
stack.extend(klass.__subclasses__())
|
||||
mn = _mapper_or_none(klass)
|
||||
if mn is not None:
|
||||
mappers.append(mn)
|
||||
|
||||
discriminator_name = (
|
||||
getattr(cls, "_concrete_discriminator_name", None) or "type"
|
||||
)
|
||||
pjoin = cls._create_polymorphic_union(mappers, discriminator_name)
|
||||
|
||||
# For columns that were declared on the class, these
|
||||
# are normally ignored with the "__no_table__" mapping,
|
||||
# unless they have a different attribute key vs. col name
|
||||
# and are in the properties argument.
|
||||
# In that case, ensure we update the properties entry
|
||||
# to the correct column from the pjoin target table.
|
||||
declared_cols = set(to_map.declared_columns)
|
||||
declared_col_keys = {c.key for c in declared_cols}
|
||||
for k, v in list(to_map.properties.items()):
|
||||
if v in declared_cols:
|
||||
to_map.properties[k] = pjoin.c[v.key]
|
||||
declared_col_keys.remove(v.key)
|
||||
|
||||
to_map.local_table = pjoin
|
||||
|
||||
strict_attrs = cls.__dict__.get("strict_attrs", False)
|
||||
|
||||
m_args = to_map.mapper_args_fn or dict
|
||||
|
||||
def mapper_args():
|
||||
args = m_args()
|
||||
args["polymorphic_on"] = pjoin.c[discriminator_name]
|
||||
args["polymorphic_abstract"] = True
|
||||
if strict_attrs:
|
||||
args["include_properties"] = (
|
||||
set(pjoin.primary_key)
|
||||
| declared_col_keys
|
||||
| {discriminator_name}
|
||||
)
|
||||
args["with_polymorphic"] = ("*", pjoin)
|
||||
return args
|
||||
|
||||
to_map.mapper_args_fn = mapper_args
|
||||
|
||||
to_map.map()
|
||||
|
||||
stack = [cls]
|
||||
while stack:
|
||||
scls = stack.pop(0)
|
||||
stack.extend(scls.__subclasses__())
|
||||
sm = _mapper_or_none(scls)
|
||||
if sm and sm.concrete and sm.inherits is None:
|
||||
for sup_ in scls.__mro__[1:]:
|
||||
sup_sm = _mapper_or_none(sup_)
|
||||
if sup_sm:
|
||||
sm._set_concrete_base(sup_sm)
|
||||
break
|
||||
|
||||
@classmethod
|
||||
def _sa_raise_deferred_config(cls):
|
||||
raise orm_exc.UnmappedClassError(
|
||||
cls,
|
||||
msg="Class %s is a subclass of AbstractConcreteBase and "
|
||||
"has a mapping pending until all subclasses are defined. "
|
||||
"Call the sqlalchemy.orm.configure_mappers() function after "
|
||||
"all subclasses have been defined to "
|
||||
"complete the mapping of this class."
|
||||
% orm_exc._safe_cls_name(cls),
|
||||
)
|
||||
|
||||
|
||||
class DeferredReflection:
|
||||
"""A helper class for construction of mappings based on
|
||||
a deferred reflection step.
|
||||
|
||||
Normally, declarative can be used with reflection by
|
||||
setting a :class:`_schema.Table` object using autoload_with=engine
|
||||
as the ``__table__`` attribute on a declarative class.
|
||||
The caveat is that the :class:`_schema.Table` must be fully
|
||||
reflected, or at the very least have a primary key column,
|
||||
at the point at which a normal declarative mapping is
|
||||
constructed, meaning the :class:`_engine.Engine` must be available
|
||||
at class declaration time.
|
||||
|
||||
The :class:`.DeferredReflection` mixin moves the construction
|
||||
of mappers to be at a later point, after a specific
|
||||
method is called which first reflects all :class:`_schema.Table`
|
||||
objects created so far. Classes can define it as such::
|
||||
|
||||
from sqlalchemy.ext.declarative import declarative_base
|
||||
from sqlalchemy.ext.declarative import DeferredReflection
|
||||
Base = declarative_base()
|
||||
|
||||
class MyClass(DeferredReflection, Base):
|
||||
__tablename__ = 'mytable'
|
||||
|
||||
Above, ``MyClass`` is not yet mapped. After a series of
|
||||
classes have been defined in the above fashion, all tables
|
||||
can be reflected and mappings created using
|
||||
:meth:`.prepare`::
|
||||
|
||||
engine = create_engine("someengine://...")
|
||||
DeferredReflection.prepare(engine)
|
||||
|
||||
The :class:`.DeferredReflection` mixin can be applied to individual
|
||||
classes, used as the base for the declarative base itself,
|
||||
or used in a custom abstract class. Using an abstract base
|
||||
allows that only a subset of classes to be prepared for a
|
||||
particular prepare step, which is necessary for applications
|
||||
that use more than one engine. For example, if an application
|
||||
has two engines, you might use two bases, and prepare each
|
||||
separately, e.g.::
|
||||
|
||||
class ReflectedOne(DeferredReflection, Base):
|
||||
__abstract__ = True
|
||||
|
||||
class ReflectedTwo(DeferredReflection, Base):
|
||||
__abstract__ = True
|
||||
|
||||
class MyClass(ReflectedOne):
|
||||
__tablename__ = 'mytable'
|
||||
|
||||
class MyOtherClass(ReflectedOne):
|
||||
__tablename__ = 'myothertable'
|
||||
|
||||
class YetAnotherClass(ReflectedTwo):
|
||||
__tablename__ = 'yetanothertable'
|
||||
|
||||
# ... etc.
|
||||
|
||||
Above, the class hierarchies for ``ReflectedOne`` and
|
||||
``ReflectedTwo`` can be configured separately::
|
||||
|
||||
ReflectedOne.prepare(engine_one)
|
||||
ReflectedTwo.prepare(engine_two)
|
||||
|
||||
.. seealso::
|
||||
|
||||
:ref:`orm_declarative_reflected_deferred_reflection` - in the
|
||||
:ref:`orm_declarative_table_config_toplevel` section.
|
||||
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def prepare(
|
||||
cls, bind: Union[Engine, Connection], **reflect_kw: Any
|
||||
) -> None:
|
||||
r"""Reflect all :class:`_schema.Table` objects for all current
|
||||
:class:`.DeferredReflection` subclasses
|
||||
|
||||
:param bind: :class:`_engine.Engine` or :class:`_engine.Connection`
|
||||
instance
|
||||
|
||||
..versionchanged:: 2.0.16 a :class:`_engine.Connection` is also
|
||||
accepted.
|
||||
|
||||
:param \**reflect_kw: additional keyword arguments passed to
|
||||
:meth:`_schema.MetaData.reflect`, such as
|
||||
:paramref:`_schema.MetaData.reflect.views`.
|
||||
|
||||
.. versionadded:: 2.0.16
|
||||
|
||||
"""
|
||||
|
||||
to_map = _DeferredMapperConfig.classes_for_base(cls)
|
||||
|
||||
metadata_to_table = collections.defaultdict(set)
|
||||
|
||||
# first collect the primary __table__ for each class into a
|
||||
# collection of metadata/schemaname -> table names
|
||||
for thingy in to_map:
|
||||
if thingy.local_table is not None:
|
||||
metadata_to_table[
|
||||
(thingy.local_table.metadata, thingy.local_table.schema)
|
||||
].add(thingy.local_table.name)
|
||||
|
||||
# then reflect all those tables into their metadatas
|
||||
|
||||
if isinstance(bind, Connection):
|
||||
conn = bind
|
||||
ctx = contextlib.nullcontext(enter_result=conn)
|
||||
elif isinstance(bind, Engine):
|
||||
ctx = bind.connect()
|
||||
else:
|
||||
raise sa_exc.ArgumentError(
|
||||
f"Expected Engine or Connection, got {bind!r}"
|
||||
)
|
||||
|
||||
with ctx as conn:
|
||||
for (metadata, schema), table_names in metadata_to_table.items():
|
||||
metadata.reflect(
|
||||
conn,
|
||||
only=table_names,
|
||||
schema=schema,
|
||||
extend_existing=True,
|
||||
autoload_replace=False,
|
||||
**reflect_kw,
|
||||
)
|
||||
|
||||
metadata_to_table.clear()
|
||||
|
||||
# .map() each class, then go through relationships and look
|
||||
# for secondary
|
||||
for thingy in to_map:
|
||||
thingy.map()
|
||||
|
||||
mapper = thingy.cls.__mapper__
|
||||
metadata = mapper.class_.metadata
|
||||
|
||||
for rel in mapper._props.values():
|
||||
if (
|
||||
isinstance(rel, relationships.RelationshipProperty)
|
||||
and rel._init_args.secondary._is_populated()
|
||||
):
|
||||
secondary_arg = rel._init_args.secondary
|
||||
|
||||
if isinstance(secondary_arg.argument, Table):
|
||||
secondary_table = secondary_arg.argument
|
||||
metadata_to_table[
|
||||
(
|
||||
secondary_table.metadata,
|
||||
secondary_table.schema,
|
||||
)
|
||||
].add(secondary_table.name)
|
||||
elif isinstance(secondary_arg.argument, str):
|
||||
_, resolve_arg = _resolver(rel.parent.class_, rel)
|
||||
|
||||
resolver = resolve_arg(
|
||||
secondary_arg.argument, True
|
||||
)
|
||||
metadata_to_table[
|
||||
(metadata, thingy.local_table.schema)
|
||||
].add(secondary_arg.argument)
|
||||
|
||||
resolver._resolvers += (
|
||||
cls._sa_deferred_table_resolver(metadata),
|
||||
)
|
||||
|
||||
secondary_arg.argument = resolver()
|
||||
|
||||
for (metadata, schema), table_names in metadata_to_table.items():
|
||||
metadata.reflect(
|
||||
conn,
|
||||
only=table_names,
|
||||
schema=schema,
|
||||
extend_existing=True,
|
||||
autoload_replace=False,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _sa_deferred_table_resolver(
|
||||
cls, metadata: MetaData
|
||||
) -> Callable[[str], Table]:
|
||||
def _resolve(key: str) -> Table:
|
||||
# reflection has already occurred so this Table would have
|
||||
# its contents already
|
||||
return Table(key, metadata)
|
||||
|
||||
return _resolve
|
||||
|
||||
_sa_decl_prepare = True
|
||||
|
||||
@classmethod
|
||||
def _sa_raise_deferred_config(cls):
|
||||
raise orm_exc.UnmappedClassError(
|
||||
cls,
|
||||
msg="Class %s is a subclass of DeferredReflection. "
|
||||
"Mappings are not produced until the .prepare() "
|
||||
"method is called on the class hierarchy."
|
||||
% orm_exc._safe_cls_name(cls),
|
||||
)
|
|
@ -0,0 +1,481 @@
|
|||
# ext/horizontal_shard.py
|
||||
# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors
|
||||
# <see AUTHORS file>
|
||||
#
|
||||
# This module is part of SQLAlchemy and is released under
|
||||
# the MIT License: https://www.opensource.org/licenses/mit-license.php
|
||||
|
||||
"""Horizontal sharding support.
|
||||
|
||||
Defines a rudimental 'horizontal sharding' system which allows a Session to
|
||||
distribute queries and persistence operations across multiple databases.
|
||||
|
||||
For a usage example, see the :ref:`examples_sharding` example included in
|
||||
the source distribution.
|
||||
|
||||
.. deepalchemy:: The horizontal sharding extension is an advanced feature,
|
||||
involving a complex statement -> database interaction as well as
|
||||
use of semi-public APIs for non-trivial cases. Simpler approaches to
|
||||
refering to multiple database "shards", most commonly using a distinct
|
||||
:class:`_orm.Session` per "shard", should always be considered first
|
||||
before using this more complex and less-production-tested system.
|
||||
|
||||
|
||||
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
from typing import Callable
|
||||
from typing import Dict
|
||||
from typing import Iterable
|
||||
from typing import Optional
|
||||
from typing import Tuple
|
||||
from typing import Type
|
||||
from typing import TYPE_CHECKING
|
||||
from typing import TypeVar
|
||||
from typing import Union
|
||||
|
||||
from .. import event
|
||||
from .. import exc
|
||||
from .. import inspect
|
||||
from .. import util
|
||||
from ..orm import PassiveFlag
|
||||
from ..orm._typing import OrmExecuteOptionsParameter
|
||||
from ..orm.interfaces import ORMOption
|
||||
from ..orm.mapper import Mapper
|
||||
from ..orm.query import Query
|
||||
from ..orm.session import _BindArguments
|
||||
from ..orm.session import _PKIdentityArgument
|
||||
from ..orm.session import Session
|
||||
from ..util.typing import Protocol
|
||||
from ..util.typing import Self
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..engine.base import Connection
|
||||
from ..engine.base import Engine
|
||||
from ..engine.base import OptionEngine
|
||||
from ..engine.result import IteratorResult
|
||||
from ..engine.result import Result
|
||||
from ..orm import LoaderCallableStatus
|
||||
from ..orm._typing import _O
|
||||
from ..orm.bulk_persistence import BulkUDCompileState
|
||||
from ..orm.context import QueryContext
|
||||
from ..orm.session import _EntityBindKey
|
||||
from ..orm.session import _SessionBind
|
||||
from ..orm.session import ORMExecuteState
|
||||
from ..orm.state import InstanceState
|
||||
from ..sql import Executable
|
||||
from ..sql._typing import _TP
|
||||
from ..sql.elements import ClauseElement
|
||||
|
||||
__all__ = ["ShardedSession", "ShardedQuery"]
|
||||
|
||||
_T = TypeVar("_T", bound=Any)
|
||||
|
||||
|
||||
ShardIdentifier = str
|
||||
|
||||
|
||||
class ShardChooser(Protocol):
|
||||
def __call__(
|
||||
self,
|
||||
mapper: Optional[Mapper[_T]],
|
||||
instance: Any,
|
||||
clause: Optional[ClauseElement],
|
||||
) -> Any: ...
|
||||
|
||||
|
||||
class IdentityChooser(Protocol):
|
||||
def __call__(
|
||||
self,
|
||||
mapper: Mapper[_T],
|
||||
primary_key: _PKIdentityArgument,
|
||||
*,
|
||||
lazy_loaded_from: Optional[InstanceState[Any]],
|
||||
execution_options: OrmExecuteOptionsParameter,
|
||||
bind_arguments: _BindArguments,
|
||||
**kw: Any,
|
||||
) -> Any: ...
|
||||
|
||||
|
||||
class ShardedQuery(Query[_T]):
|
||||
"""Query class used with :class:`.ShardedSession`.
|
||||
|
||||
.. legacy:: The :class:`.ShardedQuery` is a subclass of the legacy
|
||||
:class:`.Query` class. The :class:`.ShardedSession` now supports
|
||||
2.0 style execution via the :meth:`.ShardedSession.execute` method.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, *args: Any, **kwargs: Any) -> None:
|
||||
super().__init__(*args, **kwargs)
|
||||
assert isinstance(self.session, ShardedSession)
|
||||
|
||||
self.identity_chooser = self.session.identity_chooser
|
||||
self.execute_chooser = self.session.execute_chooser
|
||||
self._shard_id = None
|
||||
|
||||
def set_shard(self, shard_id: ShardIdentifier) -> Self:
|
||||
"""Return a new query, limited to a single shard ID.
|
||||
|
||||
All subsequent operations with the returned query will
|
||||
be against the single shard regardless of other state.
|
||||
|
||||
The shard_id can be passed for a 2.0 style execution to the
|
||||
bind_arguments dictionary of :meth:`.Session.execute`::
|
||||
|
||||
results = session.execute(
|
||||
stmt,
|
||||
bind_arguments={"shard_id": "my_shard"}
|
||||
)
|
||||
|
||||
"""
|
||||
return self.execution_options(_sa_shard_id=shard_id)
|
||||
|
||||
|
||||
class ShardedSession(Session):
|
||||
shard_chooser: ShardChooser
|
||||
identity_chooser: IdentityChooser
|
||||
execute_chooser: Callable[[ORMExecuteState], Iterable[Any]]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
shard_chooser: ShardChooser,
|
||||
identity_chooser: Optional[IdentityChooser] = None,
|
||||
execute_chooser: Optional[
|
||||
Callable[[ORMExecuteState], Iterable[Any]]
|
||||
] = None,
|
||||
shards: Optional[Dict[str, Any]] = None,
|
||||
query_cls: Type[Query[_T]] = ShardedQuery,
|
||||
*,
|
||||
id_chooser: Optional[
|
||||
Callable[[Query[_T], Iterable[_T]], Iterable[Any]]
|
||||
] = None,
|
||||
query_chooser: Optional[Callable[[Executable], Iterable[Any]]] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Construct a ShardedSession.
|
||||
|
||||
:param shard_chooser: A callable which, passed a Mapper, a mapped
|
||||
instance, and possibly a SQL clause, returns a shard ID. This id
|
||||
may be based off of the attributes present within the object, or on
|
||||
some round-robin scheme. If the scheme is based on a selection, it
|
||||
should set whatever state on the instance to mark it in the future as
|
||||
participating in that shard.
|
||||
|
||||
:param identity_chooser: A callable, passed a Mapper and primary key
|
||||
argument, which should return a list of shard ids where this
|
||||
primary key might reside.
|
||||
|
||||
.. versionchanged:: 2.0 The ``identity_chooser`` parameter
|
||||
supersedes the ``id_chooser`` parameter.
|
||||
|
||||
:param execute_chooser: For a given :class:`.ORMExecuteState`,
|
||||
returns the list of shard_ids
|
||||
where the query should be issued. Results from all shards returned
|
||||
will be combined together into a single listing.
|
||||
|
||||
.. versionchanged:: 1.4 The ``execute_chooser`` parameter
|
||||
supersedes the ``query_chooser`` parameter.
|
||||
|
||||
:param shards: A dictionary of string shard names
|
||||
to :class:`~sqlalchemy.engine.Engine` objects.
|
||||
|
||||
"""
|
||||
super().__init__(query_cls=query_cls, **kwargs)
|
||||
|
||||
event.listen(
|
||||
self, "do_orm_execute", execute_and_instances, retval=True
|
||||
)
|
||||
self.shard_chooser = shard_chooser
|
||||
|
||||
if id_chooser:
|
||||
_id_chooser = id_chooser
|
||||
util.warn_deprecated(
|
||||
"The ``id_chooser`` parameter is deprecated; "
|
||||
"please use ``identity_chooser``.",
|
||||
"2.0",
|
||||
)
|
||||
|
||||
def _legacy_identity_chooser(
|
||||
mapper: Mapper[_T],
|
||||
primary_key: _PKIdentityArgument,
|
||||
*,
|
||||
lazy_loaded_from: Optional[InstanceState[Any]],
|
||||
execution_options: OrmExecuteOptionsParameter,
|
||||
bind_arguments: _BindArguments,
|
||||
**kw: Any,
|
||||
) -> Any:
|
||||
q = self.query(mapper)
|
||||
if lazy_loaded_from:
|
||||
q = q._set_lazyload_from(lazy_loaded_from)
|
||||
return _id_chooser(q, primary_key)
|
||||
|
||||
self.identity_chooser = _legacy_identity_chooser
|
||||
elif identity_chooser:
|
||||
self.identity_chooser = identity_chooser
|
||||
else:
|
||||
raise exc.ArgumentError(
|
||||
"identity_chooser or id_chooser is required"
|
||||
)
|
||||
|
||||
if query_chooser:
|
||||
_query_chooser = query_chooser
|
||||
util.warn_deprecated(
|
||||
"The ``query_chooser`` parameter is deprecated; "
|
||||
"please use ``execute_chooser``.",
|
||||
"1.4",
|
||||
)
|
||||
if execute_chooser:
|
||||
raise exc.ArgumentError(
|
||||
"Can't pass query_chooser and execute_chooser "
|
||||
"at the same time."
|
||||
)
|
||||
|
||||
def _default_execute_chooser(
|
||||
orm_context: ORMExecuteState,
|
||||
) -> Iterable[Any]:
|
||||
return _query_chooser(orm_context.statement)
|
||||
|
||||
if execute_chooser is None:
|
||||
execute_chooser = _default_execute_chooser
|
||||
|
||||
if execute_chooser is None:
|
||||
raise exc.ArgumentError(
|
||||
"execute_chooser or query_chooser is required"
|
||||
)
|
||||
self.execute_chooser = execute_chooser
|
||||
self.__shards: Dict[ShardIdentifier, _SessionBind] = {}
|
||||
if shards is not None:
|
||||
for k in shards:
|
||||
self.bind_shard(k, shards[k])
|
||||
|
||||
def _identity_lookup(
|
||||
self,
|
||||
mapper: Mapper[_O],
|
||||
primary_key_identity: Union[Any, Tuple[Any, ...]],
|
||||
identity_token: Optional[Any] = None,
|
||||
passive: PassiveFlag = PassiveFlag.PASSIVE_OFF,
|
||||
lazy_loaded_from: Optional[InstanceState[Any]] = None,
|
||||
execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT,
|
||||
bind_arguments: Optional[_BindArguments] = None,
|
||||
**kw: Any,
|
||||
) -> Union[Optional[_O], LoaderCallableStatus]:
|
||||
"""override the default :meth:`.Session._identity_lookup` method so
|
||||
that we search for a given non-token primary key identity across all
|
||||
possible identity tokens (e.g. shard ids).
|
||||
|
||||
.. versionchanged:: 1.4 Moved :meth:`.Session._identity_lookup` from
|
||||
the :class:`_query.Query` object to the :class:`.Session`.
|
||||
|
||||
"""
|
||||
|
||||
if identity_token is not None:
|
||||
obj = super()._identity_lookup(
|
||||
mapper,
|
||||
primary_key_identity,
|
||||
identity_token=identity_token,
|
||||
**kw,
|
||||
)
|
||||
|
||||
return obj
|
||||
else:
|
||||
for shard_id in self.identity_chooser(
|
||||
mapper,
|
||||
primary_key_identity,
|
||||
lazy_loaded_from=lazy_loaded_from,
|
||||
execution_options=execution_options,
|
||||
bind_arguments=dict(bind_arguments) if bind_arguments else {},
|
||||
):
|
||||
obj2 = super()._identity_lookup(
|
||||
mapper,
|
||||
primary_key_identity,
|
||||
identity_token=shard_id,
|
||||
lazy_loaded_from=lazy_loaded_from,
|
||||
**kw,
|
||||
)
|
||||
if obj2 is not None:
|
||||
return obj2
|
||||
|
||||
return None
|
||||
|
||||
def _choose_shard_and_assign(
|
||||
self,
|
||||
mapper: Optional[_EntityBindKey[_O]],
|
||||
instance: Any,
|
||||
**kw: Any,
|
||||
) -> Any:
|
||||
if instance is not None:
|
||||
state = inspect(instance)
|
||||
if state.key:
|
||||
token = state.key[2]
|
||||
assert token is not None
|
||||
return token
|
||||
elif state.identity_token:
|
||||
return state.identity_token
|
||||
|
||||
assert isinstance(mapper, Mapper)
|
||||
shard_id = self.shard_chooser(mapper, instance, **kw)
|
||||
if instance is not None:
|
||||
state.identity_token = shard_id
|
||||
return shard_id
|
||||
|
||||
def connection_callable( # type: ignore [override]
|
||||
self,
|
||||
mapper: Optional[Mapper[_T]] = None,
|
||||
instance: Optional[Any] = None,
|
||||
shard_id: Optional[ShardIdentifier] = None,
|
||||
**kw: Any,
|
||||
) -> Connection:
|
||||
"""Provide a :class:`_engine.Connection` to use in the unit of work
|
||||
flush process.
|
||||
|
||||
"""
|
||||
|
||||
if shard_id is None:
|
||||
shard_id = self._choose_shard_and_assign(mapper, instance)
|
||||
|
||||
if self.in_transaction():
|
||||
trans = self.get_transaction()
|
||||
assert trans is not None
|
||||
return trans.connection(mapper, shard_id=shard_id)
|
||||
else:
|
||||
bind = self.get_bind(
|
||||
mapper=mapper, shard_id=shard_id, instance=instance
|
||||
)
|
||||
|
||||
if isinstance(bind, Engine):
|
||||
return bind.connect(**kw)
|
||||
else:
|
||||
assert isinstance(bind, Connection)
|
||||
return bind
|
||||
|
||||
def get_bind(
|
||||
self,
|
||||
mapper: Optional[_EntityBindKey[_O]] = None,
|
||||
*,
|
||||
shard_id: Optional[ShardIdentifier] = None,
|
||||
instance: Optional[Any] = None,
|
||||
clause: Optional[ClauseElement] = None,
|
||||
**kw: Any,
|
||||
) -> _SessionBind:
|
||||
if shard_id is None:
|
||||
shard_id = self._choose_shard_and_assign(
|
||||
mapper, instance=instance, clause=clause
|
||||
)
|
||||
assert shard_id is not None
|
||||
return self.__shards[shard_id]
|
||||
|
||||
def bind_shard(
|
||||
self, shard_id: ShardIdentifier, bind: Union[Engine, OptionEngine]
|
||||
) -> None:
|
||||
self.__shards[shard_id] = bind
|
||||
|
||||
|
||||
class set_shard_id(ORMOption):
|
||||
"""a loader option for statements to apply a specific shard id to the
|
||||
primary query as well as for additional relationship and column
|
||||
loaders.
|
||||
|
||||
The :class:`_horizontal.set_shard_id` option may be applied using
|
||||
the :meth:`_sql.Executable.options` method of any executable statement::
|
||||
|
||||
stmt = (
|
||||
select(MyObject).
|
||||
where(MyObject.name == 'some name').
|
||||
options(set_shard_id("shard1"))
|
||||
)
|
||||
|
||||
Above, the statement when invoked will limit to the "shard1" shard
|
||||
identifier for the primary query as well as for all relationship and
|
||||
column loading strategies, including eager loaders such as
|
||||
:func:`_orm.selectinload`, deferred column loaders like :func:`_orm.defer`,
|
||||
and the lazy relationship loader :func:`_orm.lazyload`.
|
||||
|
||||
In this way, the :class:`_horizontal.set_shard_id` option has much wider
|
||||
scope than using the "shard_id" argument within the
|
||||
:paramref:`_orm.Session.execute.bind_arguments` dictionary.
|
||||
|
||||
|
||||
.. versionadded:: 2.0.0
|
||||
|
||||
"""
|
||||
|
||||
__slots__ = ("shard_id", "propagate_to_loaders")
|
||||
|
||||
def __init__(
|
||||
self, shard_id: ShardIdentifier, propagate_to_loaders: bool = True
|
||||
):
|
||||
"""Construct a :class:`_horizontal.set_shard_id` option.
|
||||
|
||||
:param shard_id: shard identifier
|
||||
:param propagate_to_loaders: if left at its default of ``True``, the
|
||||
shard option will take place for lazy loaders such as
|
||||
:func:`_orm.lazyload` and :func:`_orm.defer`; if False, the option
|
||||
will not be propagated to loaded objects. Note that :func:`_orm.defer`
|
||||
always limits to the shard_id of the parent row in any case, so the
|
||||
parameter only has a net effect on the behavior of the
|
||||
:func:`_orm.lazyload` strategy.
|
||||
|
||||
"""
|
||||
self.shard_id = shard_id
|
||||
self.propagate_to_loaders = propagate_to_loaders
|
||||
|
||||
|
||||
def execute_and_instances(
|
||||
orm_context: ORMExecuteState,
|
||||
) -> Union[Result[_T], IteratorResult[_TP]]:
|
||||
active_options: Union[
|
||||
None,
|
||||
QueryContext.default_load_options,
|
||||
Type[QueryContext.default_load_options],
|
||||
BulkUDCompileState.default_update_options,
|
||||
Type[BulkUDCompileState.default_update_options],
|
||||
]
|
||||
|
||||
if orm_context.is_select:
|
||||
active_options = orm_context.load_options
|
||||
|
||||
elif orm_context.is_update or orm_context.is_delete:
|
||||
active_options = orm_context.update_delete_options
|
||||
else:
|
||||
active_options = None
|
||||
|
||||
session = orm_context.session
|
||||
assert isinstance(session, ShardedSession)
|
||||
|
||||
def iter_for_shard(
|
||||
shard_id: ShardIdentifier,
|
||||
) -> Union[Result[_T], IteratorResult[_TP]]:
|
||||
bind_arguments = dict(orm_context.bind_arguments)
|
||||
bind_arguments["shard_id"] = shard_id
|
||||
|
||||
orm_context.update_execution_options(identity_token=shard_id)
|
||||
return orm_context.invoke_statement(bind_arguments=bind_arguments)
|
||||
|
||||
for orm_opt in orm_context._non_compile_orm_options:
|
||||
# TODO: if we had an ORMOption that gets applied at ORM statement
|
||||
# execution time, that would allow this to be more generalized.
|
||||
# for now just iterate and look for our options
|
||||
if isinstance(orm_opt, set_shard_id):
|
||||
shard_id = orm_opt.shard_id
|
||||
break
|
||||
else:
|
||||
if active_options and active_options._identity_token is not None:
|
||||
shard_id = active_options._identity_token
|
||||
elif "_sa_shard_id" in orm_context.execution_options:
|
||||
shard_id = orm_context.execution_options["_sa_shard_id"]
|
||||
elif "shard_id" in orm_context.bind_arguments:
|
||||
shard_id = orm_context.bind_arguments["shard_id"]
|
||||
else:
|
||||
shard_id = None
|
||||
|
||||
if shard_id is not None:
|
||||
return iter_for_shard(shard_id)
|
||||
else:
|
||||
partial = []
|
||||
for shard_id in session.execute_chooser(orm_context):
|
||||
result_ = iter_for_shard(shard_id)
|
||||
partial.append(result_)
|
||||
return partial[0].merge(*partial[1:])
|
1514
elitebot/lib/python3.11/site-packages/sqlalchemy/ext/hybrid.py
Normal file
1514
elitebot/lib/python3.11/site-packages/sqlalchemy/ext/hybrid.py
Normal file
File diff suppressed because it is too large
Load diff
|
@ -0,0 +1,341 @@
|
|||
# ext/indexable.py
|
||||
# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors
|
||||
# <see AUTHORS file>
|
||||
#
|
||||
# This module is part of SQLAlchemy and is released under
|
||||
# the MIT License: https://www.opensource.org/licenses/mit-license.php
|
||||
# mypy: ignore-errors
|
||||
|
||||
"""Define attributes on ORM-mapped classes that have "index" attributes for
|
||||
columns with :class:`_types.Indexable` types.
|
||||
|
||||
"index" means the attribute is associated with an element of an
|
||||
:class:`_types.Indexable` column with the predefined index to access it.
|
||||
The :class:`_types.Indexable` types include types such as
|
||||
:class:`_types.ARRAY`, :class:`_types.JSON` and
|
||||
:class:`_postgresql.HSTORE`.
|
||||
|
||||
|
||||
|
||||
The :mod:`~sqlalchemy.ext.indexable` extension provides
|
||||
:class:`_schema.Column`-like interface for any element of an
|
||||
:class:`_types.Indexable` typed column. In simple cases, it can be
|
||||
treated as a :class:`_schema.Column` - mapped attribute.
|
||||
|
||||
Synopsis
|
||||
========
|
||||
|
||||
Given ``Person`` as a model with a primary key and JSON data field.
|
||||
While this field may have any number of elements encoded within it,
|
||||
we would like to refer to the element called ``name`` individually
|
||||
as a dedicated attribute which behaves like a standalone column::
|
||||
|
||||
from sqlalchemy import Column, JSON, Integer
|
||||
from sqlalchemy.ext.declarative import declarative_base
|
||||
from sqlalchemy.ext.indexable import index_property
|
||||
|
||||
Base = declarative_base()
|
||||
|
||||
class Person(Base):
|
||||
__tablename__ = 'person'
|
||||
|
||||
id = Column(Integer, primary_key=True)
|
||||
data = Column(JSON)
|
||||
|
||||
name = index_property('data', 'name')
|
||||
|
||||
|
||||
Above, the ``name`` attribute now behaves like a mapped column. We
|
||||
can compose a new ``Person`` and set the value of ``name``::
|
||||
|
||||
>>> person = Person(name='Alchemist')
|
||||
|
||||
The value is now accessible::
|
||||
|
||||
>>> person.name
|
||||
'Alchemist'
|
||||
|
||||
Behind the scenes, the JSON field was initialized to a new blank dictionary
|
||||
and the field was set::
|
||||
|
||||
>>> person.data
|
||||
{"name": "Alchemist'}
|
||||
|
||||
The field is mutable in place::
|
||||
|
||||
>>> person.name = 'Renamed'
|
||||
>>> person.name
|
||||
'Renamed'
|
||||
>>> person.data
|
||||
{'name': 'Renamed'}
|
||||
|
||||
When using :class:`.index_property`, the change that we make to the indexable
|
||||
structure is also automatically tracked as history; we no longer need
|
||||
to use :class:`~.mutable.MutableDict` in order to track this change
|
||||
for the unit of work.
|
||||
|
||||
Deletions work normally as well::
|
||||
|
||||
>>> del person.name
|
||||
>>> person.data
|
||||
{}
|
||||
|
||||
Above, deletion of ``person.name`` deletes the value from the dictionary,
|
||||
but not the dictionary itself.
|
||||
|
||||
A missing key will produce ``AttributeError``::
|
||||
|
||||
>>> person = Person()
|
||||
>>> person.name
|
||||
...
|
||||
AttributeError: 'name'
|
||||
|
||||
Unless you set a default value::
|
||||
|
||||
>>> class Person(Base):
|
||||
>>> __tablename__ = 'person'
|
||||
>>>
|
||||
>>> id = Column(Integer, primary_key=True)
|
||||
>>> data = Column(JSON)
|
||||
>>>
|
||||
>>> name = index_property('data', 'name', default=None) # See default
|
||||
|
||||
>>> person = Person()
|
||||
>>> print(person.name)
|
||||
None
|
||||
|
||||
|
||||
The attributes are also accessible at the class level.
|
||||
Below, we illustrate ``Person.name`` used to generate
|
||||
an indexed SQL criteria::
|
||||
|
||||
>>> from sqlalchemy.orm import Session
|
||||
>>> session = Session()
|
||||
>>> query = session.query(Person).filter(Person.name == 'Alchemist')
|
||||
|
||||
The above query is equivalent to::
|
||||
|
||||
>>> query = session.query(Person).filter(Person.data['name'] == 'Alchemist')
|
||||
|
||||
Multiple :class:`.index_property` objects can be chained to produce
|
||||
multiple levels of indexing::
|
||||
|
||||
from sqlalchemy import Column, JSON, Integer
|
||||
from sqlalchemy.ext.declarative import declarative_base
|
||||
from sqlalchemy.ext.indexable import index_property
|
||||
|
||||
Base = declarative_base()
|
||||
|
||||
class Person(Base):
|
||||
__tablename__ = 'person'
|
||||
|
||||
id = Column(Integer, primary_key=True)
|
||||
data = Column(JSON)
|
||||
|
||||
birthday = index_property('data', 'birthday')
|
||||
year = index_property('birthday', 'year')
|
||||
month = index_property('birthday', 'month')
|
||||
day = index_property('birthday', 'day')
|
||||
|
||||
Above, a query such as::
|
||||
|
||||
q = session.query(Person).filter(Person.year == '1980')
|
||||
|
||||
On a PostgreSQL backend, the above query will render as::
|
||||
|
||||
SELECT person.id, person.data
|
||||
FROM person
|
||||
WHERE person.data -> %(data_1)s -> %(param_1)s = %(param_2)s
|
||||
|
||||
Default Values
|
||||
==============
|
||||
|
||||
:class:`.index_property` includes special behaviors for when the indexed
|
||||
data structure does not exist, and a set operation is called:
|
||||
|
||||
* For an :class:`.index_property` that is given an integer index value,
|
||||
the default data structure will be a Python list of ``None`` values,
|
||||
at least as long as the index value; the value is then set at its
|
||||
place in the list. This means for an index value of zero, the list
|
||||
will be initialized to ``[None]`` before setting the given value,
|
||||
and for an index value of five, the list will be initialized to
|
||||
``[None, None, None, None, None]`` before setting the fifth element
|
||||
to the given value. Note that an existing list is **not** extended
|
||||
in place to receive a value.
|
||||
|
||||
* for an :class:`.index_property` that is given any other kind of index
|
||||
value (e.g. strings usually), a Python dictionary is used as the
|
||||
default data structure.
|
||||
|
||||
* The default data structure can be set to any Python callable using the
|
||||
:paramref:`.index_property.datatype` parameter, overriding the previous
|
||||
rules.
|
||||
|
||||
|
||||
Subclassing
|
||||
===========
|
||||
|
||||
:class:`.index_property` can be subclassed, in particular for the common
|
||||
use case of providing coercion of values or SQL expressions as they are
|
||||
accessed. Below is a common recipe for use with a PostgreSQL JSON type,
|
||||
where we want to also include automatic casting plus ``astext()``::
|
||||
|
||||
class pg_json_property(index_property):
|
||||
def __init__(self, attr_name, index, cast_type):
|
||||
super(pg_json_property, self).__init__(attr_name, index)
|
||||
self.cast_type = cast_type
|
||||
|
||||
def expr(self, model):
|
||||
expr = super(pg_json_property, self).expr(model)
|
||||
return expr.astext.cast(self.cast_type)
|
||||
|
||||
The above subclass can be used with the PostgreSQL-specific
|
||||
version of :class:`_postgresql.JSON`::
|
||||
|
||||
from sqlalchemy import Column, Integer
|
||||
from sqlalchemy.ext.declarative import declarative_base
|
||||
from sqlalchemy.dialects.postgresql import JSON
|
||||
|
||||
Base = declarative_base()
|
||||
|
||||
class Person(Base):
|
||||
__tablename__ = 'person'
|
||||
|
||||
id = Column(Integer, primary_key=True)
|
||||
data = Column(JSON)
|
||||
|
||||
age = pg_json_property('data', 'age', Integer)
|
||||
|
||||
The ``age`` attribute at the instance level works as before; however
|
||||
when rendering SQL, PostgreSQL's ``->>`` operator will be used
|
||||
for indexed access, instead of the usual index operator of ``->``::
|
||||
|
||||
>>> query = session.query(Person).filter(Person.age < 20)
|
||||
|
||||
The above query will render::
|
||||
|
||||
SELECT person.id, person.data
|
||||
FROM person
|
||||
WHERE CAST(person.data ->> %(data_1)s AS INTEGER) < %(param_1)s
|
||||
|
||||
""" # noqa
|
||||
from .. import inspect
|
||||
from ..ext.hybrid import hybrid_property
|
||||
from ..orm.attributes import flag_modified
|
||||
|
||||
|
||||
__all__ = ["index_property"]
|
||||
|
||||
|
||||
class index_property(hybrid_property): # noqa
|
||||
"""A property generator. The generated property describes an object
|
||||
attribute that corresponds to an :class:`_types.Indexable`
|
||||
column.
|
||||
|
||||
.. seealso::
|
||||
|
||||
:mod:`sqlalchemy.ext.indexable`
|
||||
|
||||
"""
|
||||
|
||||
_NO_DEFAULT_ARGUMENT = object()
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
attr_name,
|
||||
index,
|
||||
default=_NO_DEFAULT_ARGUMENT,
|
||||
datatype=None,
|
||||
mutable=True,
|
||||
onebased=True,
|
||||
):
|
||||
"""Create a new :class:`.index_property`.
|
||||
|
||||
:param attr_name:
|
||||
An attribute name of an `Indexable` typed column, or other
|
||||
attribute that returns an indexable structure.
|
||||
:param index:
|
||||
The index to be used for getting and setting this value. This
|
||||
should be the Python-side index value for integers.
|
||||
:param default:
|
||||
A value which will be returned instead of `AttributeError`
|
||||
when there is not a value at given index.
|
||||
:param datatype: default datatype to use when the field is empty.
|
||||
By default, this is derived from the type of index used; a
|
||||
Python list for an integer index, or a Python dictionary for
|
||||
any other style of index. For a list, the list will be
|
||||
initialized to a list of None values that is at least
|
||||
``index`` elements long.
|
||||
:param mutable: if False, writes and deletes to the attribute will
|
||||
be disallowed.
|
||||
:param onebased: assume the SQL representation of this value is
|
||||
one-based; that is, the first index in SQL is 1, not zero.
|
||||
"""
|
||||
|
||||
if mutable:
|
||||
super().__init__(self.fget, self.fset, self.fdel, self.expr)
|
||||
else:
|
||||
super().__init__(self.fget, None, None, self.expr)
|
||||
self.attr_name = attr_name
|
||||
self.index = index
|
||||
self.default = default
|
||||
is_numeric = isinstance(index, int)
|
||||
onebased = is_numeric and onebased
|
||||
|
||||
if datatype is not None:
|
||||
self.datatype = datatype
|
||||
else:
|
||||
if is_numeric:
|
||||
self.datatype = lambda: [None for x in range(index + 1)]
|
||||
else:
|
||||
self.datatype = dict
|
||||
self.onebased = onebased
|
||||
|
||||
def _fget_default(self, err=None):
|
||||
if self.default == self._NO_DEFAULT_ARGUMENT:
|
||||
raise AttributeError(self.attr_name) from err
|
||||
else:
|
||||
return self.default
|
||||
|
||||
def fget(self, instance):
|
||||
attr_name = self.attr_name
|
||||
column_value = getattr(instance, attr_name)
|
||||
if column_value is None:
|
||||
return self._fget_default()
|
||||
try:
|
||||
value = column_value[self.index]
|
||||
except (KeyError, IndexError) as err:
|
||||
return self._fget_default(err)
|
||||
else:
|
||||
return value
|
||||
|
||||
def fset(self, instance, value):
|
||||
attr_name = self.attr_name
|
||||
column_value = getattr(instance, attr_name, None)
|
||||
if column_value is None:
|
||||
column_value = self.datatype()
|
||||
setattr(instance, attr_name, column_value)
|
||||
column_value[self.index] = value
|
||||
setattr(instance, attr_name, column_value)
|
||||
if attr_name in inspect(instance).mapper.attrs:
|
||||
flag_modified(instance, attr_name)
|
||||
|
||||
def fdel(self, instance):
|
||||
attr_name = self.attr_name
|
||||
column_value = getattr(instance, attr_name)
|
||||
if column_value is None:
|
||||
raise AttributeError(self.attr_name)
|
||||
try:
|
||||
del column_value[self.index]
|
||||
except KeyError as err:
|
||||
raise AttributeError(self.attr_name) from err
|
||||
else:
|
||||
setattr(instance, attr_name, column_value)
|
||||
flag_modified(instance, attr_name)
|
||||
|
||||
def expr(self, model):
|
||||
column = getattr(model, self.attr_name)
|
||||
index = self.index
|
||||
if self.onebased:
|
||||
index += 1
|
||||
return column[index]
|
|
@ -0,0 +1,450 @@
|
|||
# ext/instrumentation.py
|
||||
# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors
|
||||
# <see AUTHORS file>
|
||||
#
|
||||
# This module is part of SQLAlchemy and is released under
|
||||
# the MIT License: https://www.opensource.org/licenses/mit-license.php
|
||||
# mypy: ignore-errors
|
||||
|
||||
"""Extensible class instrumentation.
|
||||
|
||||
The :mod:`sqlalchemy.ext.instrumentation` package provides for alternate
|
||||
systems of class instrumentation within the ORM. Class instrumentation
|
||||
refers to how the ORM places attributes on the class which maintain
|
||||
data and track changes to that data, as well as event hooks installed
|
||||
on the class.
|
||||
|
||||
.. note::
|
||||
The extension package is provided for the benefit of integration
|
||||
with other object management packages, which already perform
|
||||
their own instrumentation. It is not intended for general use.
|
||||
|
||||
For examples of how the instrumentation extension is used,
|
||||
see the example :ref:`examples_instrumentation`.
|
||||
|
||||
"""
|
||||
import weakref
|
||||
|
||||
from .. import util
|
||||
from ..orm import attributes
|
||||
from ..orm import base as orm_base
|
||||
from ..orm import collections
|
||||
from ..orm import exc as orm_exc
|
||||
from ..orm import instrumentation as orm_instrumentation
|
||||
from ..orm import util as orm_util
|
||||
from ..orm.instrumentation import _default_dict_getter
|
||||
from ..orm.instrumentation import _default_manager_getter
|
||||
from ..orm.instrumentation import _default_opt_manager_getter
|
||||
from ..orm.instrumentation import _default_state_getter
|
||||
from ..orm.instrumentation import ClassManager
|
||||
from ..orm.instrumentation import InstrumentationFactory
|
||||
|
||||
|
||||
INSTRUMENTATION_MANAGER = "__sa_instrumentation_manager__"
|
||||
"""Attribute, elects custom instrumentation when present on a mapped class.
|
||||
|
||||
Allows a class to specify a slightly or wildly different technique for
|
||||
tracking changes made to mapped attributes and collections.
|
||||
|
||||
Only one instrumentation implementation is allowed in a given object
|
||||
inheritance hierarchy.
|
||||
|
||||
The value of this attribute must be a callable and will be passed a class
|
||||
object. The callable must return one of:
|
||||
|
||||
- An instance of an :class:`.InstrumentationManager` or subclass
|
||||
- An object implementing all or some of InstrumentationManager (TODO)
|
||||
- A dictionary of callables, implementing all or some of the above (TODO)
|
||||
- An instance of a :class:`.ClassManager` or subclass
|
||||
|
||||
This attribute is consulted by SQLAlchemy instrumentation
|
||||
resolution, once the :mod:`sqlalchemy.ext.instrumentation` module
|
||||
has been imported. If custom finders are installed in the global
|
||||
instrumentation_finders list, they may or may not choose to honor this
|
||||
attribute.
|
||||
|
||||
"""
|
||||
|
||||
|
||||
def find_native_user_instrumentation_hook(cls):
|
||||
"""Find user-specified instrumentation management for a class."""
|
||||
return getattr(cls, INSTRUMENTATION_MANAGER, None)
|
||||
|
||||
|
||||
instrumentation_finders = [find_native_user_instrumentation_hook]
|
||||
"""An extensible sequence of callables which return instrumentation
|
||||
implementations
|
||||
|
||||
When a class is registered, each callable will be passed a class object.
|
||||
If None is returned, the
|
||||
next finder in the sequence is consulted. Otherwise the return must be an
|
||||
instrumentation factory that follows the same guidelines as
|
||||
sqlalchemy.ext.instrumentation.INSTRUMENTATION_MANAGER.
|
||||
|
||||
By default, the only finder is find_native_user_instrumentation_hook, which
|
||||
searches for INSTRUMENTATION_MANAGER. If all finders return None, standard
|
||||
ClassManager instrumentation is used.
|
||||
|
||||
"""
|
||||
|
||||
|
||||
class ExtendedInstrumentationRegistry(InstrumentationFactory):
|
||||
"""Extends :class:`.InstrumentationFactory` with additional
|
||||
bookkeeping, to accommodate multiple types of
|
||||
class managers.
|
||||
|
||||
"""
|
||||
|
||||
_manager_finders = weakref.WeakKeyDictionary()
|
||||
_state_finders = weakref.WeakKeyDictionary()
|
||||
_dict_finders = weakref.WeakKeyDictionary()
|
||||
_extended = False
|
||||
|
||||
def _locate_extended_factory(self, class_):
|
||||
for finder in instrumentation_finders:
|
||||
factory = finder(class_)
|
||||
if factory is not None:
|
||||
manager = self._extended_class_manager(class_, factory)
|
||||
return manager, factory
|
||||
else:
|
||||
return None, None
|
||||
|
||||
def _check_conflicts(self, class_, factory):
|
||||
existing_factories = self._collect_management_factories_for(
|
||||
class_
|
||||
).difference([factory])
|
||||
if existing_factories:
|
||||
raise TypeError(
|
||||
"multiple instrumentation implementations specified "
|
||||
"in %s inheritance hierarchy: %r"
|
||||
% (class_.__name__, list(existing_factories))
|
||||
)
|
||||
|
||||
def _extended_class_manager(self, class_, factory):
|
||||
manager = factory(class_)
|
||||
if not isinstance(manager, ClassManager):
|
||||
manager = _ClassInstrumentationAdapter(class_, manager)
|
||||
|
||||
if factory != ClassManager and not self._extended:
|
||||
# somebody invoked a custom ClassManager.
|
||||
# reinstall global "getter" functions with the more
|
||||
# expensive ones.
|
||||
self._extended = True
|
||||
_install_instrumented_lookups()
|
||||
|
||||
self._manager_finders[class_] = manager.manager_getter()
|
||||
self._state_finders[class_] = manager.state_getter()
|
||||
self._dict_finders[class_] = manager.dict_getter()
|
||||
return manager
|
||||
|
||||
def _collect_management_factories_for(self, cls):
|
||||
"""Return a collection of factories in play or specified for a
|
||||
hierarchy.
|
||||
|
||||
Traverses the entire inheritance graph of a cls and returns a
|
||||
collection of instrumentation factories for those classes. Factories
|
||||
are extracted from active ClassManagers, if available, otherwise
|
||||
instrumentation_finders is consulted.
|
||||
|
||||
"""
|
||||
hierarchy = util.class_hierarchy(cls)
|
||||
factories = set()
|
||||
for member in hierarchy:
|
||||
manager = self.opt_manager_of_class(member)
|
||||
if manager is not None:
|
||||
factories.add(manager.factory)
|
||||
else:
|
||||
for finder in instrumentation_finders:
|
||||
factory = finder(member)
|
||||
if factory is not None:
|
||||
break
|
||||
else:
|
||||
factory = None
|
||||
factories.add(factory)
|
||||
factories.discard(None)
|
||||
return factories
|
||||
|
||||
def unregister(self, class_):
|
||||
super().unregister(class_)
|
||||
if class_ in self._manager_finders:
|
||||
del self._manager_finders[class_]
|
||||
del self._state_finders[class_]
|
||||
del self._dict_finders[class_]
|
||||
|
||||
def opt_manager_of_class(self, cls):
|
||||
try:
|
||||
finder = self._manager_finders.get(
|
||||
cls, _default_opt_manager_getter
|
||||
)
|
||||
except TypeError:
|
||||
# due to weakref lookup on invalid object
|
||||
return None
|
||||
else:
|
||||
return finder(cls)
|
||||
|
||||
def manager_of_class(self, cls):
|
||||
try:
|
||||
finder = self._manager_finders.get(cls, _default_manager_getter)
|
||||
except TypeError:
|
||||
# due to weakref lookup on invalid object
|
||||
raise orm_exc.UnmappedClassError(
|
||||
cls, f"Can't locate an instrumentation manager for class {cls}"
|
||||
)
|
||||
else:
|
||||
manager = finder(cls)
|
||||
if manager is None:
|
||||
raise orm_exc.UnmappedClassError(
|
||||
cls,
|
||||
f"Can't locate an instrumentation manager for class {cls}",
|
||||
)
|
||||
return manager
|
||||
|
||||
def state_of(self, instance):
|
||||
if instance is None:
|
||||
raise AttributeError("None has no persistent state.")
|
||||
return self._state_finders.get(
|
||||
instance.__class__, _default_state_getter
|
||||
)(instance)
|
||||
|
||||
def dict_of(self, instance):
|
||||
if instance is None:
|
||||
raise AttributeError("None has no persistent state.")
|
||||
return self._dict_finders.get(
|
||||
instance.__class__, _default_dict_getter
|
||||
)(instance)
|
||||
|
||||
|
||||
orm_instrumentation._instrumentation_factory = _instrumentation_factory = (
|
||||
ExtendedInstrumentationRegistry()
|
||||
)
|
||||
orm_instrumentation.instrumentation_finders = instrumentation_finders
|
||||
|
||||
|
||||
class InstrumentationManager:
|
||||
"""User-defined class instrumentation extension.
|
||||
|
||||
:class:`.InstrumentationManager` can be subclassed in order
|
||||
to change
|
||||
how class instrumentation proceeds. This class exists for
|
||||
the purposes of integration with other object management
|
||||
frameworks which would like to entirely modify the
|
||||
instrumentation methodology of the ORM, and is not intended
|
||||
for regular usage. For interception of class instrumentation
|
||||
events, see :class:`.InstrumentationEvents`.
|
||||
|
||||
The API for this class should be considered as semi-stable,
|
||||
and may change slightly with new releases.
|
||||
|
||||
"""
|
||||
|
||||
# r4361 added a mandatory (cls) constructor to this interface.
|
||||
# given that, perhaps class_ should be dropped from all of these
|
||||
# signatures.
|
||||
|
||||
def __init__(self, class_):
|
||||
pass
|
||||
|
||||
def manage(self, class_, manager):
|
||||
setattr(class_, "_default_class_manager", manager)
|
||||
|
||||
def unregister(self, class_, manager):
|
||||
delattr(class_, "_default_class_manager")
|
||||
|
||||
def manager_getter(self, class_):
|
||||
def get(cls):
|
||||
return cls._default_class_manager
|
||||
|
||||
return get
|
||||
|
||||
def instrument_attribute(self, class_, key, inst):
|
||||
pass
|
||||
|
||||
def post_configure_attribute(self, class_, key, inst):
|
||||
pass
|
||||
|
||||
def install_descriptor(self, class_, key, inst):
|
||||
setattr(class_, key, inst)
|
||||
|
||||
def uninstall_descriptor(self, class_, key):
|
||||
delattr(class_, key)
|
||||
|
||||
def install_member(self, class_, key, implementation):
|
||||
setattr(class_, key, implementation)
|
||||
|
||||
def uninstall_member(self, class_, key):
|
||||
delattr(class_, key)
|
||||
|
||||
def instrument_collection_class(self, class_, key, collection_class):
|
||||
return collections.prepare_instrumentation(collection_class)
|
||||
|
||||
def get_instance_dict(self, class_, instance):
|
||||
return instance.__dict__
|
||||
|
||||
def initialize_instance_dict(self, class_, instance):
|
||||
pass
|
||||
|
||||
def install_state(self, class_, instance, state):
|
||||
setattr(instance, "_default_state", state)
|
||||
|
||||
def remove_state(self, class_, instance):
|
||||
delattr(instance, "_default_state")
|
||||
|
||||
def state_getter(self, class_):
|
||||
return lambda instance: getattr(instance, "_default_state")
|
||||
|
||||
def dict_getter(self, class_):
|
||||
return lambda inst: self.get_instance_dict(class_, inst)
|
||||
|
||||
|
||||
class _ClassInstrumentationAdapter(ClassManager):
|
||||
"""Adapts a user-defined InstrumentationManager to a ClassManager."""
|
||||
|
||||
def __init__(self, class_, override):
|
||||
self._adapted = override
|
||||
self._get_state = self._adapted.state_getter(class_)
|
||||
self._get_dict = self._adapted.dict_getter(class_)
|
||||
|
||||
ClassManager.__init__(self, class_)
|
||||
|
||||
def manage(self):
|
||||
self._adapted.manage(self.class_, self)
|
||||
|
||||
def unregister(self):
|
||||
self._adapted.unregister(self.class_, self)
|
||||
|
||||
def manager_getter(self):
|
||||
return self._adapted.manager_getter(self.class_)
|
||||
|
||||
def instrument_attribute(self, key, inst, propagated=False):
|
||||
ClassManager.instrument_attribute(self, key, inst, propagated)
|
||||
if not propagated:
|
||||
self._adapted.instrument_attribute(self.class_, key, inst)
|
||||
|
||||
def post_configure_attribute(self, key):
|
||||
super().post_configure_attribute(key)
|
||||
self._adapted.post_configure_attribute(self.class_, key, self[key])
|
||||
|
||||
def install_descriptor(self, key, inst):
|
||||
self._adapted.install_descriptor(self.class_, key, inst)
|
||||
|
||||
def uninstall_descriptor(self, key):
|
||||
self._adapted.uninstall_descriptor(self.class_, key)
|
||||
|
||||
def install_member(self, key, implementation):
|
||||
self._adapted.install_member(self.class_, key, implementation)
|
||||
|
||||
def uninstall_member(self, key):
|
||||
self._adapted.uninstall_member(self.class_, key)
|
||||
|
||||
def instrument_collection_class(self, key, collection_class):
|
||||
return self._adapted.instrument_collection_class(
|
||||
self.class_, key, collection_class
|
||||
)
|
||||
|
||||
def initialize_collection(self, key, state, factory):
|
||||
delegate = getattr(self._adapted, "initialize_collection", None)
|
||||
if delegate:
|
||||
return delegate(key, state, factory)
|
||||
else:
|
||||
return ClassManager.initialize_collection(
|
||||
self, key, state, factory
|
||||
)
|
||||
|
||||
def new_instance(self, state=None):
|
||||
instance = self.class_.__new__(self.class_)
|
||||
self.setup_instance(instance, state)
|
||||
return instance
|
||||
|
||||
def _new_state_if_none(self, instance):
|
||||
"""Install a default InstanceState if none is present.
|
||||
|
||||
A private convenience method used by the __init__ decorator.
|
||||
"""
|
||||
if self.has_state(instance):
|
||||
return False
|
||||
else:
|
||||
return self.setup_instance(instance)
|
||||
|
||||
def setup_instance(self, instance, state=None):
|
||||
self._adapted.initialize_instance_dict(self.class_, instance)
|
||||
|
||||
if state is None:
|
||||
state = self._state_constructor(instance, self)
|
||||
|
||||
# the given instance is assumed to have no state
|
||||
self._adapted.install_state(self.class_, instance, state)
|
||||
return state
|
||||
|
||||
def teardown_instance(self, instance):
|
||||
self._adapted.remove_state(self.class_, instance)
|
||||
|
||||
def has_state(self, instance):
|
||||
try:
|
||||
self._get_state(instance)
|
||||
except orm_exc.NO_STATE:
|
||||
return False
|
||||
else:
|
||||
return True
|
||||
|
||||
def state_getter(self):
|
||||
return self._get_state
|
||||
|
||||
def dict_getter(self):
|
||||
return self._get_dict
|
||||
|
||||
|
||||
def _install_instrumented_lookups():
|
||||
"""Replace global class/object management functions
|
||||
with ExtendedInstrumentationRegistry implementations, which
|
||||
allow multiple types of class managers to be present,
|
||||
at the cost of performance.
|
||||
|
||||
This function is called only by ExtendedInstrumentationRegistry
|
||||
and unit tests specific to this behavior.
|
||||
|
||||
The _reinstall_default_lookups() function can be called
|
||||
after this one to re-establish the default functions.
|
||||
|
||||
"""
|
||||
_install_lookups(
|
||||
dict(
|
||||
instance_state=_instrumentation_factory.state_of,
|
||||
instance_dict=_instrumentation_factory.dict_of,
|
||||
manager_of_class=_instrumentation_factory.manager_of_class,
|
||||
opt_manager_of_class=_instrumentation_factory.opt_manager_of_class,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def _reinstall_default_lookups():
|
||||
"""Restore simplified lookups."""
|
||||
_install_lookups(
|
||||
dict(
|
||||
instance_state=_default_state_getter,
|
||||
instance_dict=_default_dict_getter,
|
||||
manager_of_class=_default_manager_getter,
|
||||
opt_manager_of_class=_default_opt_manager_getter,
|
||||
)
|
||||
)
|
||||
_instrumentation_factory._extended = False
|
||||
|
||||
|
||||
def _install_lookups(lookups):
|
||||
global instance_state, instance_dict
|
||||
global manager_of_class, opt_manager_of_class
|
||||
instance_state = lookups["instance_state"]
|
||||
instance_dict = lookups["instance_dict"]
|
||||
manager_of_class = lookups["manager_of_class"]
|
||||
opt_manager_of_class = lookups["opt_manager_of_class"]
|
||||
orm_base.instance_state = attributes.instance_state = (
|
||||
orm_instrumentation.instance_state
|
||||
) = instance_state
|
||||
orm_base.instance_dict = attributes.instance_dict = (
|
||||
orm_instrumentation.instance_dict
|
||||
) = instance_dict
|
||||
orm_base.manager_of_class = attributes.manager_of_class = (
|
||||
orm_instrumentation.manager_of_class
|
||||
) = manager_of_class
|
||||
orm_base.opt_manager_of_class = orm_util.opt_manager_of_class = (
|
||||
attributes.opt_manager_of_class
|
||||
) = orm_instrumentation.opt_manager_of_class = opt_manager_of_class
|
1073
elitebot/lib/python3.11/site-packages/sqlalchemy/ext/mutable.py
Normal file
1073
elitebot/lib/python3.11/site-packages/sqlalchemy/ext/mutable.py
Normal file
File diff suppressed because it is too large
Load diff
|
@ -0,0 +1,6 @@
|
|||
# ext/mypy/__init__.py
|
||||
# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors
|
||||
# <see AUTHORS file>
|
||||
#
|
||||
# This module is part of SQLAlchemy and is released under
|
||||
# the MIT License: https://www.opensource.org/licenses/mit-license.php
|
|
@ -0,0 +1,320 @@
|
|||
# ext/mypy/apply.py
|
||||
# Copyright (C) 2021-2024 the SQLAlchemy authors and contributors
|
||||
# <see AUTHORS file>
|
||||
#
|
||||
# This module is part of SQLAlchemy and is released under
|
||||
# the MIT License: https://www.opensource.org/licenses/mit-license.php
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import List
|
||||
from typing import Optional
|
||||
from typing import Union
|
||||
|
||||
from mypy.nodes import ARG_NAMED_OPT
|
||||
from mypy.nodes import Argument
|
||||
from mypy.nodes import AssignmentStmt
|
||||
from mypy.nodes import CallExpr
|
||||
from mypy.nodes import ClassDef
|
||||
from mypy.nodes import MDEF
|
||||
from mypy.nodes import MemberExpr
|
||||
from mypy.nodes import NameExpr
|
||||
from mypy.nodes import RefExpr
|
||||
from mypy.nodes import StrExpr
|
||||
from mypy.nodes import SymbolTableNode
|
||||
from mypy.nodes import TempNode
|
||||
from mypy.nodes import TypeInfo
|
||||
from mypy.nodes import Var
|
||||
from mypy.plugin import SemanticAnalyzerPluginInterface
|
||||
from mypy.plugins.common import add_method_to_class
|
||||
from mypy.types import AnyType
|
||||
from mypy.types import get_proper_type
|
||||
from mypy.types import Instance
|
||||
from mypy.types import NoneTyp
|
||||
from mypy.types import ProperType
|
||||
from mypy.types import TypeOfAny
|
||||
from mypy.types import UnboundType
|
||||
from mypy.types import UnionType
|
||||
|
||||
from . import infer
|
||||
from . import util
|
||||
from .names import expr_to_mapped_constructor
|
||||
from .names import NAMED_TYPE_SQLA_MAPPED
|
||||
|
||||
|
||||
def apply_mypy_mapped_attr(
|
||||
cls: ClassDef,
|
||||
api: SemanticAnalyzerPluginInterface,
|
||||
item: Union[NameExpr, StrExpr],
|
||||
attributes: List[util.SQLAlchemyAttribute],
|
||||
) -> None:
|
||||
if isinstance(item, NameExpr):
|
||||
name = item.name
|
||||
elif isinstance(item, StrExpr):
|
||||
name = item.value
|
||||
else:
|
||||
return None
|
||||
|
||||
for stmt in cls.defs.body:
|
||||
if (
|
||||
isinstance(stmt, AssignmentStmt)
|
||||
and isinstance(stmt.lvalues[0], NameExpr)
|
||||
and stmt.lvalues[0].name == name
|
||||
):
|
||||
break
|
||||
else:
|
||||
util.fail(api, f"Can't find mapped attribute {name}", cls)
|
||||
return None
|
||||
|
||||
if stmt.type is None:
|
||||
util.fail(
|
||||
api,
|
||||
"Statement linked from _mypy_mapped_attrs has no "
|
||||
"typing information",
|
||||
stmt,
|
||||
)
|
||||
return None
|
||||
|
||||
left_hand_explicit_type = get_proper_type(stmt.type)
|
||||
assert isinstance(
|
||||
left_hand_explicit_type, (Instance, UnionType, UnboundType)
|
||||
)
|
||||
|
||||
attributes.append(
|
||||
util.SQLAlchemyAttribute(
|
||||
name=name,
|
||||
line=item.line,
|
||||
column=item.column,
|
||||
typ=left_hand_explicit_type,
|
||||
info=cls.info,
|
||||
)
|
||||
)
|
||||
|
||||
apply_type_to_mapped_statement(
|
||||
api, stmt, stmt.lvalues[0], left_hand_explicit_type, None
|
||||
)
|
||||
|
||||
|
||||
def re_apply_declarative_assignments(
|
||||
cls: ClassDef,
|
||||
api: SemanticAnalyzerPluginInterface,
|
||||
attributes: List[util.SQLAlchemyAttribute],
|
||||
) -> None:
|
||||
"""For multiple class passes, re-apply our left-hand side types as mypy
|
||||
seems to reset them in place.
|
||||
|
||||
"""
|
||||
mapped_attr_lookup = {attr.name: attr for attr in attributes}
|
||||
update_cls_metadata = False
|
||||
|
||||
for stmt in cls.defs.body:
|
||||
# for a re-apply, all of our statements are AssignmentStmt;
|
||||
# @declared_attr calls will have been converted and this
|
||||
# currently seems to be preserved by mypy (but who knows if this
|
||||
# will change).
|
||||
if (
|
||||
isinstance(stmt, AssignmentStmt)
|
||||
and isinstance(stmt.lvalues[0], NameExpr)
|
||||
and stmt.lvalues[0].name in mapped_attr_lookup
|
||||
and isinstance(stmt.lvalues[0].node, Var)
|
||||
):
|
||||
left_node = stmt.lvalues[0].node
|
||||
|
||||
python_type_for_type = mapped_attr_lookup[
|
||||
stmt.lvalues[0].name
|
||||
].type
|
||||
|
||||
left_node_proper_type = get_proper_type(left_node.type)
|
||||
|
||||
# if we have scanned an UnboundType and now there's a more
|
||||
# specific type than UnboundType, call the re-scan so we
|
||||
# can get that set up correctly
|
||||
if (
|
||||
isinstance(python_type_for_type, UnboundType)
|
||||
and not isinstance(left_node_proper_type, UnboundType)
|
||||
and (
|
||||
isinstance(stmt.rvalue, CallExpr)
|
||||
and isinstance(stmt.rvalue.callee, MemberExpr)
|
||||
and isinstance(stmt.rvalue.callee.expr, NameExpr)
|
||||
and stmt.rvalue.callee.expr.node is not None
|
||||
and stmt.rvalue.callee.expr.node.fullname
|
||||
== NAMED_TYPE_SQLA_MAPPED
|
||||
and stmt.rvalue.callee.name == "_empty_constructor"
|
||||
and isinstance(stmt.rvalue.args[0], CallExpr)
|
||||
and isinstance(stmt.rvalue.args[0].callee, RefExpr)
|
||||
)
|
||||
):
|
||||
new_python_type_for_type = (
|
||||
infer.infer_type_from_right_hand_nameexpr(
|
||||
api,
|
||||
stmt,
|
||||
left_node,
|
||||
left_node_proper_type,
|
||||
stmt.rvalue.args[0].callee,
|
||||
)
|
||||
)
|
||||
|
||||
if new_python_type_for_type is not None and not isinstance(
|
||||
new_python_type_for_type, UnboundType
|
||||
):
|
||||
python_type_for_type = new_python_type_for_type
|
||||
|
||||
# update the SQLAlchemyAttribute with the better
|
||||
# information
|
||||
mapped_attr_lookup[stmt.lvalues[0].name].type = (
|
||||
python_type_for_type
|
||||
)
|
||||
|
||||
update_cls_metadata = True
|
||||
|
||||
if (
|
||||
not isinstance(left_node.type, Instance)
|
||||
or left_node.type.type.fullname != NAMED_TYPE_SQLA_MAPPED
|
||||
):
|
||||
assert python_type_for_type is not None
|
||||
left_node.type = api.named_type(
|
||||
NAMED_TYPE_SQLA_MAPPED, [python_type_for_type]
|
||||
)
|
||||
|
||||
if update_cls_metadata:
|
||||
util.set_mapped_attributes(cls.info, attributes)
|
||||
|
||||
|
||||
def apply_type_to_mapped_statement(
|
||||
api: SemanticAnalyzerPluginInterface,
|
||||
stmt: AssignmentStmt,
|
||||
lvalue: NameExpr,
|
||||
left_hand_explicit_type: Optional[ProperType],
|
||||
python_type_for_type: Optional[ProperType],
|
||||
) -> None:
|
||||
"""Apply the Mapped[<type>] annotation and right hand object to a
|
||||
declarative assignment statement.
|
||||
|
||||
This converts a Python declarative class statement such as::
|
||||
|
||||
class User(Base):
|
||||
# ...
|
||||
|
||||
attrname = Column(Integer)
|
||||
|
||||
To one that describes the final Python behavior to Mypy::
|
||||
|
||||
class User(Base):
|
||||
# ...
|
||||
|
||||
attrname : Mapped[Optional[int]] = <meaningless temp node>
|
||||
|
||||
"""
|
||||
left_node = lvalue.node
|
||||
assert isinstance(left_node, Var)
|
||||
|
||||
# to be completely honest I have no idea what the difference between
|
||||
# left_node.type and stmt.type is, what it means if these are different
|
||||
# vs. the same, why in order to get tests to pass I have to assign
|
||||
# to stmt.type for the second case and not the first. this is complete
|
||||
# trying every combination until it works stuff.
|
||||
|
||||
if left_hand_explicit_type is not None:
|
||||
lvalue.is_inferred_def = False
|
||||
left_node.type = api.named_type(
|
||||
NAMED_TYPE_SQLA_MAPPED, [left_hand_explicit_type]
|
||||
)
|
||||
else:
|
||||
lvalue.is_inferred_def = False
|
||||
left_node.type = api.named_type(
|
||||
NAMED_TYPE_SQLA_MAPPED,
|
||||
(
|
||||
[AnyType(TypeOfAny.special_form)]
|
||||
if python_type_for_type is None
|
||||
else [python_type_for_type]
|
||||
),
|
||||
)
|
||||
|
||||
# so to have it skip the right side totally, we can do this:
|
||||
# stmt.rvalue = TempNode(AnyType(TypeOfAny.special_form))
|
||||
|
||||
# however, if we instead manufacture a new node that uses the old
|
||||
# one, then we can still get type checking for the call itself,
|
||||
# e.g. the Column, relationship() call, etc.
|
||||
|
||||
# rewrite the node as:
|
||||
# <attr> : Mapped[<typ>] =
|
||||
# _sa_Mapped._empty_constructor(<original CallExpr from rvalue>)
|
||||
# the original right-hand side is maintained so it gets type checked
|
||||
# internally
|
||||
stmt.rvalue = expr_to_mapped_constructor(stmt.rvalue)
|
||||
|
||||
if stmt.type is not None and python_type_for_type is not None:
|
||||
stmt.type = python_type_for_type
|
||||
|
||||
|
||||
def add_additional_orm_attributes(
|
||||
cls: ClassDef,
|
||||
api: SemanticAnalyzerPluginInterface,
|
||||
attributes: List[util.SQLAlchemyAttribute],
|
||||
) -> None:
|
||||
"""Apply __init__, __table__ and other attributes to the mapped class."""
|
||||
|
||||
info = util.info_for_cls(cls, api)
|
||||
|
||||
if info is None:
|
||||
return
|
||||
|
||||
is_base = util.get_is_base(info)
|
||||
|
||||
if "__init__" not in info.names and not is_base:
|
||||
mapped_attr_names = {attr.name: attr.type for attr in attributes}
|
||||
|
||||
for base in info.mro[1:-1]:
|
||||
if "sqlalchemy" not in info.metadata:
|
||||
continue
|
||||
|
||||
base_cls_attributes = util.get_mapped_attributes(base, api)
|
||||
if base_cls_attributes is None:
|
||||
continue
|
||||
|
||||
for attr in base_cls_attributes:
|
||||
mapped_attr_names.setdefault(attr.name, attr.type)
|
||||
|
||||
arguments = []
|
||||
for name, typ in mapped_attr_names.items():
|
||||
if typ is None:
|
||||
typ = AnyType(TypeOfAny.special_form)
|
||||
arguments.append(
|
||||
Argument(
|
||||
variable=Var(name, typ),
|
||||
type_annotation=typ,
|
||||
initializer=TempNode(typ),
|
||||
kind=ARG_NAMED_OPT,
|
||||
)
|
||||
)
|
||||
|
||||
add_method_to_class(api, cls, "__init__", arguments, NoneTyp())
|
||||
|
||||
if "__table__" not in info.names and util.get_has_table(info):
|
||||
_apply_placeholder_attr_to_class(
|
||||
api, cls, "sqlalchemy.sql.schema.Table", "__table__"
|
||||
)
|
||||
if not is_base:
|
||||
_apply_placeholder_attr_to_class(
|
||||
api, cls, "sqlalchemy.orm.mapper.Mapper", "__mapper__"
|
||||
)
|
||||
|
||||
|
||||
def _apply_placeholder_attr_to_class(
|
||||
api: SemanticAnalyzerPluginInterface,
|
||||
cls: ClassDef,
|
||||
qualified_name: str,
|
||||
attrname: str,
|
||||
) -> None:
|
||||
sym = api.lookup_fully_qualified_or_none(qualified_name)
|
||||
if sym:
|
||||
assert isinstance(sym.node, TypeInfo)
|
||||
type_: ProperType = Instance(sym.node, [])
|
||||
else:
|
||||
type_ = AnyType(TypeOfAny.special_form)
|
||||
var = Var(attrname)
|
||||
var._fullname = cls.fullname + "." + attrname
|
||||
var.info = cls.info
|
||||
var.type = type_
|
||||
cls.info.names[attrname] = SymbolTableNode(MDEF, var)
|
|
@ -0,0 +1,515 @@
|
|||
# ext/mypy/decl_class.py
|
||||
# Copyright (C) 2021-2024 the SQLAlchemy authors and contributors
|
||||
# <see AUTHORS file>
|
||||
#
|
||||
# This module is part of SQLAlchemy and is released under
|
||||
# the MIT License: https://www.opensource.org/licenses/mit-license.php
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import List
|
||||
from typing import Optional
|
||||
from typing import Union
|
||||
|
||||
from mypy.nodes import AssignmentStmt
|
||||
from mypy.nodes import CallExpr
|
||||
from mypy.nodes import ClassDef
|
||||
from mypy.nodes import Decorator
|
||||
from mypy.nodes import LambdaExpr
|
||||
from mypy.nodes import ListExpr
|
||||
from mypy.nodes import MemberExpr
|
||||
from mypy.nodes import NameExpr
|
||||
from mypy.nodes import PlaceholderNode
|
||||
from mypy.nodes import RefExpr
|
||||
from mypy.nodes import StrExpr
|
||||
from mypy.nodes import SymbolNode
|
||||
from mypy.nodes import SymbolTableNode
|
||||
from mypy.nodes import TempNode
|
||||
from mypy.nodes import TypeInfo
|
||||
from mypy.nodes import Var
|
||||
from mypy.plugin import SemanticAnalyzerPluginInterface
|
||||
from mypy.types import AnyType
|
||||
from mypy.types import CallableType
|
||||
from mypy.types import get_proper_type
|
||||
from mypy.types import Instance
|
||||
from mypy.types import NoneType
|
||||
from mypy.types import ProperType
|
||||
from mypy.types import Type
|
||||
from mypy.types import TypeOfAny
|
||||
from mypy.types import UnboundType
|
||||
from mypy.types import UnionType
|
||||
|
||||
from . import apply
|
||||
from . import infer
|
||||
from . import names
|
||||
from . import util
|
||||
|
||||
|
||||
def scan_declarative_assignments_and_apply_types(
|
||||
cls: ClassDef,
|
||||
api: SemanticAnalyzerPluginInterface,
|
||||
is_mixin_scan: bool = False,
|
||||
) -> Optional[List[util.SQLAlchemyAttribute]]:
|
||||
info = util.info_for_cls(cls, api)
|
||||
|
||||
if info is None:
|
||||
# this can occur during cached passes
|
||||
return None
|
||||
elif cls.fullname.startswith("builtins"):
|
||||
return None
|
||||
|
||||
mapped_attributes: Optional[List[util.SQLAlchemyAttribute]] = (
|
||||
util.get_mapped_attributes(info, api)
|
||||
)
|
||||
|
||||
# used by assign.add_additional_orm_attributes among others
|
||||
util.establish_as_sqlalchemy(info)
|
||||
|
||||
if mapped_attributes is not None:
|
||||
# ensure that a class that's mapped is always picked up by
|
||||
# its mapped() decorator or declarative metaclass before
|
||||
# it would be detected as an unmapped mixin class
|
||||
|
||||
if not is_mixin_scan:
|
||||
# mypy can call us more than once. it then *may* have reset the
|
||||
# left hand side of everything, but not the right that we removed,
|
||||
# removing our ability to re-scan. but we have the types
|
||||
# here, so lets re-apply them, or if we have an UnboundType,
|
||||
# we can re-scan
|
||||
|
||||
apply.re_apply_declarative_assignments(cls, api, mapped_attributes)
|
||||
|
||||
return mapped_attributes
|
||||
|
||||
mapped_attributes = []
|
||||
|
||||
if not cls.defs.body:
|
||||
# when we get a mixin class from another file, the body is
|
||||
# empty (!) but the names are in the symbol table. so use that.
|
||||
|
||||
for sym_name, sym in info.names.items():
|
||||
_scan_symbol_table_entry(
|
||||
cls, api, sym_name, sym, mapped_attributes
|
||||
)
|
||||
else:
|
||||
for stmt in util.flatten_typechecking(cls.defs.body):
|
||||
if isinstance(stmt, AssignmentStmt):
|
||||
_scan_declarative_assignment_stmt(
|
||||
cls, api, stmt, mapped_attributes
|
||||
)
|
||||
elif isinstance(stmt, Decorator):
|
||||
_scan_declarative_decorator_stmt(
|
||||
cls, api, stmt, mapped_attributes
|
||||
)
|
||||
_scan_for_mapped_bases(cls, api)
|
||||
|
||||
if not is_mixin_scan:
|
||||
apply.add_additional_orm_attributes(cls, api, mapped_attributes)
|
||||
|
||||
util.set_mapped_attributes(info, mapped_attributes)
|
||||
|
||||
return mapped_attributes
|
||||
|
||||
|
||||
def _scan_symbol_table_entry(
|
||||
cls: ClassDef,
|
||||
api: SemanticAnalyzerPluginInterface,
|
||||
name: str,
|
||||
value: SymbolTableNode,
|
||||
attributes: List[util.SQLAlchemyAttribute],
|
||||
) -> None:
|
||||
"""Extract mapping information from a SymbolTableNode that's in the
|
||||
type.names dictionary.
|
||||
|
||||
"""
|
||||
value_type = get_proper_type(value.type)
|
||||
if not isinstance(value_type, Instance):
|
||||
return
|
||||
|
||||
left_hand_explicit_type = None
|
||||
type_id = names.type_id_for_named_node(value_type.type)
|
||||
# type_id = names._type_id_for_unbound_type(value.type.type, cls, api)
|
||||
|
||||
err = False
|
||||
|
||||
# TODO: this is nearly the same logic as that of
|
||||
# _scan_declarative_decorator_stmt, likely can be merged
|
||||
if type_id in {
|
||||
names.MAPPED,
|
||||
names.RELATIONSHIP,
|
||||
names.COMPOSITE_PROPERTY,
|
||||
names.MAPPER_PROPERTY,
|
||||
names.SYNONYM_PROPERTY,
|
||||
names.COLUMN_PROPERTY,
|
||||
}:
|
||||
if value_type.args:
|
||||
left_hand_explicit_type = get_proper_type(value_type.args[0])
|
||||
else:
|
||||
err = True
|
||||
elif type_id is names.COLUMN:
|
||||
if not value_type.args:
|
||||
err = True
|
||||
else:
|
||||
typeengine_arg: Union[ProperType, TypeInfo] = get_proper_type(
|
||||
value_type.args[0]
|
||||
)
|
||||
if isinstance(typeengine_arg, Instance):
|
||||
typeengine_arg = typeengine_arg.type
|
||||
|
||||
if isinstance(typeengine_arg, (UnboundType, TypeInfo)):
|
||||
sym = api.lookup_qualified(typeengine_arg.name, typeengine_arg)
|
||||
if sym is not None and isinstance(sym.node, TypeInfo):
|
||||
if names.has_base_type_id(sym.node, names.TYPEENGINE):
|
||||
left_hand_explicit_type = UnionType(
|
||||
[
|
||||
infer.extract_python_type_from_typeengine(
|
||||
api, sym.node, []
|
||||
),
|
||||
NoneType(),
|
||||
]
|
||||
)
|
||||
else:
|
||||
util.fail(
|
||||
api,
|
||||
"Column type should be a TypeEngine "
|
||||
"subclass not '{}'".format(sym.node.fullname),
|
||||
value_type,
|
||||
)
|
||||
|
||||
if err:
|
||||
msg = (
|
||||
"Can't infer type from attribute {} on class {}. "
|
||||
"please specify a return type from this function that is "
|
||||
"one of: Mapped[<python type>], relationship[<target class>], "
|
||||
"Column[<TypeEngine>], MapperProperty[<python type>]"
|
||||
)
|
||||
util.fail(api, msg.format(name, cls.name), cls)
|
||||
|
||||
left_hand_explicit_type = AnyType(TypeOfAny.special_form)
|
||||
|
||||
if left_hand_explicit_type is not None:
|
||||
assert value.node is not None
|
||||
attributes.append(
|
||||
util.SQLAlchemyAttribute(
|
||||
name=name,
|
||||
line=value.node.line,
|
||||
column=value.node.column,
|
||||
typ=left_hand_explicit_type,
|
||||
info=cls.info,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def _scan_declarative_decorator_stmt(
|
||||
cls: ClassDef,
|
||||
api: SemanticAnalyzerPluginInterface,
|
||||
stmt: Decorator,
|
||||
attributes: List[util.SQLAlchemyAttribute],
|
||||
) -> None:
|
||||
"""Extract mapping information from a @declared_attr in a declarative
|
||||
class.
|
||||
|
||||
E.g.::
|
||||
|
||||
@reg.mapped
|
||||
class MyClass:
|
||||
# ...
|
||||
|
||||
@declared_attr
|
||||
def updated_at(cls) -> Column[DateTime]:
|
||||
return Column(DateTime)
|
||||
|
||||
Will resolve in mypy as::
|
||||
|
||||
@reg.mapped
|
||||
class MyClass:
|
||||
# ...
|
||||
|
||||
updated_at: Mapped[Optional[datetime.datetime]]
|
||||
|
||||
"""
|
||||
for dec in stmt.decorators:
|
||||
if (
|
||||
isinstance(dec, (NameExpr, MemberExpr, SymbolNode))
|
||||
and names.type_id_for_named_node(dec) is names.DECLARED_ATTR
|
||||
):
|
||||
break
|
||||
else:
|
||||
return
|
||||
|
||||
dec_index = cls.defs.body.index(stmt)
|
||||
|
||||
left_hand_explicit_type: Optional[ProperType] = None
|
||||
|
||||
if util.name_is_dunder(stmt.name):
|
||||
# for dunder names like __table_args__, __tablename__,
|
||||
# __mapper_args__ etc., rewrite these as simple assignment
|
||||
# statements; otherwise mypy doesn't like if the decorated
|
||||
# function has an annotation like ``cls: Type[Foo]`` because
|
||||
# it isn't @classmethod
|
||||
any_ = AnyType(TypeOfAny.special_form)
|
||||
left_node = NameExpr(stmt.var.name)
|
||||
left_node.node = stmt.var
|
||||
new_stmt = AssignmentStmt([left_node], TempNode(any_))
|
||||
new_stmt.type = left_node.node.type
|
||||
cls.defs.body[dec_index] = new_stmt
|
||||
return
|
||||
elif isinstance(stmt.func.type, CallableType):
|
||||
func_type = stmt.func.type.ret_type
|
||||
if isinstance(func_type, UnboundType):
|
||||
type_id = names.type_id_for_unbound_type(func_type, cls, api)
|
||||
else:
|
||||
# this does not seem to occur unless the type argument is
|
||||
# incorrect
|
||||
return
|
||||
|
||||
if (
|
||||
type_id
|
||||
in {
|
||||
names.MAPPED,
|
||||
names.RELATIONSHIP,
|
||||
names.COMPOSITE_PROPERTY,
|
||||
names.MAPPER_PROPERTY,
|
||||
names.SYNONYM_PROPERTY,
|
||||
names.COLUMN_PROPERTY,
|
||||
}
|
||||
and func_type.args
|
||||
):
|
||||
left_hand_explicit_type = get_proper_type(func_type.args[0])
|
||||
elif type_id is names.COLUMN and func_type.args:
|
||||
typeengine_arg = func_type.args[0]
|
||||
if isinstance(typeengine_arg, UnboundType):
|
||||
sym = api.lookup_qualified(typeengine_arg.name, typeengine_arg)
|
||||
if sym is not None and isinstance(sym.node, TypeInfo):
|
||||
if names.has_base_type_id(sym.node, names.TYPEENGINE):
|
||||
left_hand_explicit_type = UnionType(
|
||||
[
|
||||
infer.extract_python_type_from_typeengine(
|
||||
api, sym.node, []
|
||||
),
|
||||
NoneType(),
|
||||
]
|
||||
)
|
||||
else:
|
||||
util.fail(
|
||||
api,
|
||||
"Column type should be a TypeEngine "
|
||||
"subclass not '{}'".format(sym.node.fullname),
|
||||
func_type,
|
||||
)
|
||||
|
||||
if left_hand_explicit_type is None:
|
||||
# no type on the decorated function. our option here is to
|
||||
# dig into the function body and get the return type, but they
|
||||
# should just have an annotation.
|
||||
msg = (
|
||||
"Can't infer type from @declared_attr on function '{}'; "
|
||||
"please specify a return type from this function that is "
|
||||
"one of: Mapped[<python type>], relationship[<target class>], "
|
||||
"Column[<TypeEngine>], MapperProperty[<python type>]"
|
||||
)
|
||||
util.fail(api, msg.format(stmt.var.name), stmt)
|
||||
|
||||
left_hand_explicit_type = AnyType(TypeOfAny.special_form)
|
||||
|
||||
left_node = NameExpr(stmt.var.name)
|
||||
left_node.node = stmt.var
|
||||
|
||||
# totally feeling around in the dark here as I don't totally understand
|
||||
# the significance of UnboundType. It seems to be something that is
|
||||
# not going to do what's expected when it is applied as the type of
|
||||
# an AssignmentStatement. So do a feeling-around-in-the-dark version
|
||||
# of converting it to the regular Instance/TypeInfo/UnionType structures
|
||||
# we see everywhere else.
|
||||
if isinstance(left_hand_explicit_type, UnboundType):
|
||||
left_hand_explicit_type = get_proper_type(
|
||||
util.unbound_to_instance(api, left_hand_explicit_type)
|
||||
)
|
||||
|
||||
left_node.node.type = api.named_type(
|
||||
names.NAMED_TYPE_SQLA_MAPPED, [left_hand_explicit_type]
|
||||
)
|
||||
|
||||
# this will ignore the rvalue entirely
|
||||
# rvalue = TempNode(AnyType(TypeOfAny.special_form))
|
||||
|
||||
# rewrite the node as:
|
||||
# <attr> : Mapped[<typ>] =
|
||||
# _sa_Mapped._empty_constructor(lambda: <function body>)
|
||||
# the function body is maintained so it gets type checked internally
|
||||
rvalue = names.expr_to_mapped_constructor(
|
||||
LambdaExpr(stmt.func.arguments, stmt.func.body)
|
||||
)
|
||||
|
||||
new_stmt = AssignmentStmt([left_node], rvalue)
|
||||
new_stmt.type = left_node.node.type
|
||||
|
||||
attributes.append(
|
||||
util.SQLAlchemyAttribute(
|
||||
name=left_node.name,
|
||||
line=stmt.line,
|
||||
column=stmt.column,
|
||||
typ=left_hand_explicit_type,
|
||||
info=cls.info,
|
||||
)
|
||||
)
|
||||
cls.defs.body[dec_index] = new_stmt
|
||||
|
||||
|
||||
def _scan_declarative_assignment_stmt(
|
||||
cls: ClassDef,
|
||||
api: SemanticAnalyzerPluginInterface,
|
||||
stmt: AssignmentStmt,
|
||||
attributes: List[util.SQLAlchemyAttribute],
|
||||
) -> None:
|
||||
"""Extract mapping information from an assignment statement in a
|
||||
declarative class.
|
||||
|
||||
"""
|
||||
lvalue = stmt.lvalues[0]
|
||||
if not isinstance(lvalue, NameExpr):
|
||||
return
|
||||
|
||||
sym = cls.info.names.get(lvalue.name)
|
||||
|
||||
# this establishes that semantic analysis has taken place, which
|
||||
# means the nodes are populated and we are called from an appropriate
|
||||
# hook.
|
||||
assert sym is not None
|
||||
node = sym.node
|
||||
|
||||
if isinstance(node, PlaceholderNode):
|
||||
return
|
||||
|
||||
assert node is lvalue.node
|
||||
assert isinstance(node, Var)
|
||||
|
||||
if node.name == "__abstract__":
|
||||
if api.parse_bool(stmt.rvalue) is True:
|
||||
util.set_is_base(cls.info)
|
||||
return
|
||||
elif node.name == "__tablename__":
|
||||
util.set_has_table(cls.info)
|
||||
elif node.name.startswith("__"):
|
||||
return
|
||||
elif node.name == "_mypy_mapped_attrs":
|
||||
if not isinstance(stmt.rvalue, ListExpr):
|
||||
util.fail(api, "_mypy_mapped_attrs is expected to be a list", stmt)
|
||||
else:
|
||||
for item in stmt.rvalue.items:
|
||||
if isinstance(item, (NameExpr, StrExpr)):
|
||||
apply.apply_mypy_mapped_attr(cls, api, item, attributes)
|
||||
|
||||
left_hand_mapped_type: Optional[Type] = None
|
||||
left_hand_explicit_type: Optional[ProperType] = None
|
||||
|
||||
if node.is_inferred or node.type is None:
|
||||
if isinstance(stmt.type, UnboundType):
|
||||
# look for an explicit Mapped[] type annotation on the left
|
||||
# side with nothing on the right
|
||||
|
||||
# print(stmt.type)
|
||||
# Mapped?[Optional?[A?]]
|
||||
|
||||
left_hand_explicit_type = stmt.type
|
||||
|
||||
if stmt.type.name == "Mapped":
|
||||
mapped_sym = api.lookup_qualified("Mapped", cls)
|
||||
if (
|
||||
mapped_sym is not None
|
||||
and mapped_sym.node is not None
|
||||
and names.type_id_for_named_node(mapped_sym.node)
|
||||
is names.MAPPED
|
||||
):
|
||||
left_hand_explicit_type = get_proper_type(
|
||||
stmt.type.args[0]
|
||||
)
|
||||
left_hand_mapped_type = stmt.type
|
||||
|
||||
# TODO: do we need to convert from unbound for this case?
|
||||
# left_hand_explicit_type = util._unbound_to_instance(
|
||||
# api, left_hand_explicit_type
|
||||
# )
|
||||
else:
|
||||
node_type = get_proper_type(node.type)
|
||||
if (
|
||||
isinstance(node_type, Instance)
|
||||
and names.type_id_for_named_node(node_type.type) is names.MAPPED
|
||||
):
|
||||
# print(node.type)
|
||||
# sqlalchemy.orm.attributes.Mapped[<python type>]
|
||||
left_hand_explicit_type = get_proper_type(node_type.args[0])
|
||||
left_hand_mapped_type = node_type
|
||||
else:
|
||||
# print(node.type)
|
||||
# <python type>
|
||||
left_hand_explicit_type = node_type
|
||||
left_hand_mapped_type = None
|
||||
|
||||
if isinstance(stmt.rvalue, TempNode) and left_hand_mapped_type is not None:
|
||||
# annotation without assignment and Mapped is present
|
||||
# as type annotation
|
||||
# equivalent to using _infer_type_from_left_hand_type_only.
|
||||
|
||||
python_type_for_type = left_hand_explicit_type
|
||||
elif isinstance(stmt.rvalue, CallExpr) and isinstance(
|
||||
stmt.rvalue.callee, RefExpr
|
||||
):
|
||||
python_type_for_type = infer.infer_type_from_right_hand_nameexpr(
|
||||
api, stmt, node, left_hand_explicit_type, stmt.rvalue.callee
|
||||
)
|
||||
|
||||
if python_type_for_type is None:
|
||||
return
|
||||
|
||||
else:
|
||||
return
|
||||
|
||||
assert python_type_for_type is not None
|
||||
|
||||
attributes.append(
|
||||
util.SQLAlchemyAttribute(
|
||||
name=node.name,
|
||||
line=stmt.line,
|
||||
column=stmt.column,
|
||||
typ=python_type_for_type,
|
||||
info=cls.info,
|
||||
)
|
||||
)
|
||||
|
||||
apply.apply_type_to_mapped_statement(
|
||||
api,
|
||||
stmt,
|
||||
lvalue,
|
||||
left_hand_explicit_type,
|
||||
python_type_for_type,
|
||||
)
|
||||
|
||||
|
||||
def _scan_for_mapped_bases(
|
||||
cls: ClassDef,
|
||||
api: SemanticAnalyzerPluginInterface,
|
||||
) -> None:
|
||||
"""Given a class, iterate through its superclass hierarchy to find
|
||||
all other classes that are considered as ORM-significant.
|
||||
|
||||
Locates non-mapped mixins and scans them for mapped attributes to be
|
||||
applied to subclasses.
|
||||
|
||||
"""
|
||||
|
||||
info = util.info_for_cls(cls, api)
|
||||
|
||||
if info is None:
|
||||
return
|
||||
|
||||
for base_info in info.mro[1:-1]:
|
||||
if base_info.fullname.startswith("builtins"):
|
||||
continue
|
||||
|
||||
# scan each base for mapped attributes. if they are not already
|
||||
# scanned (but have all their type info), that means they are unmapped
|
||||
# mixins
|
||||
scan_declarative_assignments_and_apply_types(
|
||||
base_info.defn, api, is_mixin_scan=True
|
||||
)
|
|
@ -0,0 +1,590 @@
|
|||
# ext/mypy/infer.py
|
||||
# Copyright (C) 2021-2024 the SQLAlchemy authors and contributors
|
||||
# <see AUTHORS file>
|
||||
#
|
||||
# This module is part of SQLAlchemy and is released under
|
||||
# the MIT License: https://www.opensource.org/licenses/mit-license.php
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Optional
|
||||
from typing import Sequence
|
||||
|
||||
from mypy.maptype import map_instance_to_supertype
|
||||
from mypy.nodes import AssignmentStmt
|
||||
from mypy.nodes import CallExpr
|
||||
from mypy.nodes import Expression
|
||||
from mypy.nodes import FuncDef
|
||||
from mypy.nodes import LambdaExpr
|
||||
from mypy.nodes import MemberExpr
|
||||
from mypy.nodes import NameExpr
|
||||
from mypy.nodes import RefExpr
|
||||
from mypy.nodes import StrExpr
|
||||
from mypy.nodes import TypeInfo
|
||||
from mypy.nodes import Var
|
||||
from mypy.plugin import SemanticAnalyzerPluginInterface
|
||||
from mypy.subtypes import is_subtype
|
||||
from mypy.types import AnyType
|
||||
from mypy.types import CallableType
|
||||
from mypy.types import get_proper_type
|
||||
from mypy.types import Instance
|
||||
from mypy.types import NoneType
|
||||
from mypy.types import ProperType
|
||||
from mypy.types import TypeOfAny
|
||||
from mypy.types import UnionType
|
||||
|
||||
from . import names
|
||||
from . import util
|
||||
|
||||
|
||||
def infer_type_from_right_hand_nameexpr(
|
||||
api: SemanticAnalyzerPluginInterface,
|
||||
stmt: AssignmentStmt,
|
||||
node: Var,
|
||||
left_hand_explicit_type: Optional[ProperType],
|
||||
infer_from_right_side: RefExpr,
|
||||
) -> Optional[ProperType]:
|
||||
type_id = names.type_id_for_callee(infer_from_right_side)
|
||||
if type_id is None:
|
||||
return None
|
||||
elif type_id is names.MAPPED:
|
||||
python_type_for_type = _infer_type_from_mapped(
|
||||
api, stmt, node, left_hand_explicit_type, infer_from_right_side
|
||||
)
|
||||
elif type_id is names.COLUMN:
|
||||
python_type_for_type = _infer_type_from_decl_column(
|
||||
api, stmt, node, left_hand_explicit_type
|
||||
)
|
||||
elif type_id is names.RELATIONSHIP:
|
||||
python_type_for_type = _infer_type_from_relationship(
|
||||
api, stmt, node, left_hand_explicit_type
|
||||
)
|
||||
elif type_id is names.COLUMN_PROPERTY:
|
||||
python_type_for_type = _infer_type_from_decl_column_property(
|
||||
api, stmt, node, left_hand_explicit_type
|
||||
)
|
||||
elif type_id is names.SYNONYM_PROPERTY:
|
||||
python_type_for_type = infer_type_from_left_hand_type_only(
|
||||
api, node, left_hand_explicit_type
|
||||
)
|
||||
elif type_id is names.COMPOSITE_PROPERTY:
|
||||
python_type_for_type = _infer_type_from_decl_composite_property(
|
||||
api, stmt, node, left_hand_explicit_type
|
||||
)
|
||||
else:
|
||||
return None
|
||||
|
||||
return python_type_for_type
|
||||
|
||||
|
||||
def _infer_type_from_relationship(
|
||||
api: SemanticAnalyzerPluginInterface,
|
||||
stmt: AssignmentStmt,
|
||||
node: Var,
|
||||
left_hand_explicit_type: Optional[ProperType],
|
||||
) -> Optional[ProperType]:
|
||||
"""Infer the type of mapping from a relationship.
|
||||
|
||||
E.g.::
|
||||
|
||||
@reg.mapped
|
||||
class MyClass:
|
||||
# ...
|
||||
|
||||
addresses = relationship(Address, uselist=True)
|
||||
|
||||
order: Mapped["Order"] = relationship("Order")
|
||||
|
||||
Will resolve in mypy as::
|
||||
|
||||
@reg.mapped
|
||||
class MyClass:
|
||||
# ...
|
||||
|
||||
addresses: Mapped[List[Address]]
|
||||
|
||||
order: Mapped["Order"]
|
||||
|
||||
"""
|
||||
|
||||
assert isinstance(stmt.rvalue, CallExpr)
|
||||
target_cls_arg = stmt.rvalue.args[0]
|
||||
python_type_for_type: Optional[ProperType] = None
|
||||
|
||||
if isinstance(target_cls_arg, NameExpr) and isinstance(
|
||||
target_cls_arg.node, TypeInfo
|
||||
):
|
||||
# type
|
||||
related_object_type = target_cls_arg.node
|
||||
python_type_for_type = Instance(related_object_type, [])
|
||||
|
||||
# other cases not covered - an error message directs the user
|
||||
# to set an explicit type annotation
|
||||
#
|
||||
# node.type == str, it's a string
|
||||
# if isinstance(target_cls_arg, NameExpr) and isinstance(
|
||||
# target_cls_arg.node, Var
|
||||
# )
|
||||
# points to a type
|
||||
# isinstance(target_cls_arg, NameExpr) and isinstance(
|
||||
# target_cls_arg.node, TypeAlias
|
||||
# )
|
||||
# string expression
|
||||
# isinstance(target_cls_arg, StrExpr)
|
||||
|
||||
uselist_arg = util.get_callexpr_kwarg(stmt.rvalue, "uselist")
|
||||
collection_cls_arg: Optional[Expression] = util.get_callexpr_kwarg(
|
||||
stmt.rvalue, "collection_class"
|
||||
)
|
||||
type_is_a_collection = False
|
||||
|
||||
# this can be used to determine Optional for a many-to-one
|
||||
# in the same way nullable=False could be used, if we start supporting
|
||||
# that.
|
||||
# innerjoin_arg = util.get_callexpr_kwarg(stmt.rvalue, "innerjoin")
|
||||
|
||||
if (
|
||||
uselist_arg is not None
|
||||
and api.parse_bool(uselist_arg) is True
|
||||
and collection_cls_arg is None
|
||||
):
|
||||
type_is_a_collection = True
|
||||
if python_type_for_type is not None:
|
||||
python_type_for_type = api.named_type(
|
||||
names.NAMED_TYPE_BUILTINS_LIST, [python_type_for_type]
|
||||
)
|
||||
elif (
|
||||
uselist_arg is None or api.parse_bool(uselist_arg) is True
|
||||
) and collection_cls_arg is not None:
|
||||
type_is_a_collection = True
|
||||
if isinstance(collection_cls_arg, CallExpr):
|
||||
collection_cls_arg = collection_cls_arg.callee
|
||||
|
||||
if isinstance(collection_cls_arg, NameExpr) and isinstance(
|
||||
collection_cls_arg.node, TypeInfo
|
||||
):
|
||||
if python_type_for_type is not None:
|
||||
# this can still be overridden by the left hand side
|
||||
# within _infer_Type_from_left_and_inferred_right
|
||||
python_type_for_type = Instance(
|
||||
collection_cls_arg.node, [python_type_for_type]
|
||||
)
|
||||
elif (
|
||||
isinstance(collection_cls_arg, NameExpr)
|
||||
and isinstance(collection_cls_arg.node, FuncDef)
|
||||
and collection_cls_arg.node.type is not None
|
||||
):
|
||||
if python_type_for_type is not None:
|
||||
# this can still be overridden by the left hand side
|
||||
# within _infer_Type_from_left_and_inferred_right
|
||||
|
||||
# TODO: handle mypy.types.Overloaded
|
||||
if isinstance(collection_cls_arg.node.type, CallableType):
|
||||
rt = get_proper_type(collection_cls_arg.node.type.ret_type)
|
||||
|
||||
if isinstance(rt, CallableType):
|
||||
callable_ret_type = get_proper_type(rt.ret_type)
|
||||
if isinstance(callable_ret_type, Instance):
|
||||
python_type_for_type = Instance(
|
||||
callable_ret_type.type,
|
||||
[python_type_for_type],
|
||||
)
|
||||
else:
|
||||
util.fail(
|
||||
api,
|
||||
"Expected Python collection type for "
|
||||
"collection_class parameter",
|
||||
stmt.rvalue,
|
||||
)
|
||||
python_type_for_type = None
|
||||
elif uselist_arg is not None and api.parse_bool(uselist_arg) is False:
|
||||
if collection_cls_arg is not None:
|
||||
util.fail(
|
||||
api,
|
||||
"Sending uselist=False and collection_class at the same time "
|
||||
"does not make sense",
|
||||
stmt.rvalue,
|
||||
)
|
||||
if python_type_for_type is not None:
|
||||
python_type_for_type = UnionType(
|
||||
[python_type_for_type, NoneType()]
|
||||
)
|
||||
|
||||
else:
|
||||
if left_hand_explicit_type is None:
|
||||
msg = (
|
||||
"Can't infer scalar or collection for ORM mapped expression "
|
||||
"assigned to attribute '{}' if both 'uselist' and "
|
||||
"'collection_class' arguments are absent from the "
|
||||
"relationship(); please specify a "
|
||||
"type annotation on the left hand side."
|
||||
)
|
||||
util.fail(api, msg.format(node.name), node)
|
||||
|
||||
if python_type_for_type is None:
|
||||
return infer_type_from_left_hand_type_only(
|
||||
api, node, left_hand_explicit_type
|
||||
)
|
||||
elif left_hand_explicit_type is not None:
|
||||
if type_is_a_collection:
|
||||
assert isinstance(left_hand_explicit_type, Instance)
|
||||
assert isinstance(python_type_for_type, Instance)
|
||||
return _infer_collection_type_from_left_and_inferred_right(
|
||||
api, node, left_hand_explicit_type, python_type_for_type
|
||||
)
|
||||
else:
|
||||
return _infer_type_from_left_and_inferred_right(
|
||||
api,
|
||||
node,
|
||||
left_hand_explicit_type,
|
||||
python_type_for_type,
|
||||
)
|
||||
else:
|
||||
return python_type_for_type
|
||||
|
||||
|
||||
def _infer_type_from_decl_composite_property(
|
||||
api: SemanticAnalyzerPluginInterface,
|
||||
stmt: AssignmentStmt,
|
||||
node: Var,
|
||||
left_hand_explicit_type: Optional[ProperType],
|
||||
) -> Optional[ProperType]:
|
||||
"""Infer the type of mapping from a Composite."""
|
||||
|
||||
assert isinstance(stmt.rvalue, CallExpr)
|
||||
target_cls_arg = stmt.rvalue.args[0]
|
||||
python_type_for_type = None
|
||||
|
||||
if isinstance(target_cls_arg, NameExpr) and isinstance(
|
||||
target_cls_arg.node, TypeInfo
|
||||
):
|
||||
related_object_type = target_cls_arg.node
|
||||
python_type_for_type = Instance(related_object_type, [])
|
||||
else:
|
||||
python_type_for_type = None
|
||||
|
||||
if python_type_for_type is None:
|
||||
return infer_type_from_left_hand_type_only(
|
||||
api, node, left_hand_explicit_type
|
||||
)
|
||||
elif left_hand_explicit_type is not None:
|
||||
return _infer_type_from_left_and_inferred_right(
|
||||
api, node, left_hand_explicit_type, python_type_for_type
|
||||
)
|
||||
else:
|
||||
return python_type_for_type
|
||||
|
||||
|
||||
def _infer_type_from_mapped(
|
||||
api: SemanticAnalyzerPluginInterface,
|
||||
stmt: AssignmentStmt,
|
||||
node: Var,
|
||||
left_hand_explicit_type: Optional[ProperType],
|
||||
infer_from_right_side: RefExpr,
|
||||
) -> Optional[ProperType]:
|
||||
"""Infer the type of mapping from a right side expression
|
||||
that returns Mapped.
|
||||
|
||||
|
||||
"""
|
||||
assert isinstance(stmt.rvalue, CallExpr)
|
||||
|
||||
# (Pdb) print(stmt.rvalue.callee)
|
||||
# NameExpr(query_expression [sqlalchemy.orm._orm_constructors.query_expression]) # noqa: E501
|
||||
# (Pdb) stmt.rvalue.callee.node
|
||||
# <mypy.nodes.FuncDef object at 0x7f8d92fb5940>
|
||||
# (Pdb) stmt.rvalue.callee.node.type
|
||||
# def [_T] (default_expr: sqlalchemy.sql.elements.ColumnElement[_T`-1] =) -> sqlalchemy.orm.base.Mapped[_T`-1] # noqa: E501
|
||||
# sqlalchemy.orm.base.Mapped[_T`-1]
|
||||
# the_mapped_type = stmt.rvalue.callee.node.type.ret_type
|
||||
|
||||
# TODO: look at generic ref and either use that,
|
||||
# or reconcile w/ what's present, etc.
|
||||
the_mapped_type = util.type_for_callee(infer_from_right_side) # noqa
|
||||
|
||||
return infer_type_from_left_hand_type_only(
|
||||
api, node, left_hand_explicit_type
|
||||
)
|
||||
|
||||
|
||||
def _infer_type_from_decl_column_property(
|
||||
api: SemanticAnalyzerPluginInterface,
|
||||
stmt: AssignmentStmt,
|
||||
node: Var,
|
||||
left_hand_explicit_type: Optional[ProperType],
|
||||
) -> Optional[ProperType]:
|
||||
"""Infer the type of mapping from a ColumnProperty.
|
||||
|
||||
This includes mappings against ``column_property()`` as well as the
|
||||
``deferred()`` function.
|
||||
|
||||
"""
|
||||
assert isinstance(stmt.rvalue, CallExpr)
|
||||
|
||||
if stmt.rvalue.args:
|
||||
first_prop_arg = stmt.rvalue.args[0]
|
||||
|
||||
if isinstance(first_prop_arg, CallExpr):
|
||||
type_id = names.type_id_for_callee(first_prop_arg.callee)
|
||||
|
||||
# look for column_property() / deferred() etc with Column as first
|
||||
# argument
|
||||
if type_id is names.COLUMN:
|
||||
return _infer_type_from_decl_column(
|
||||
api,
|
||||
stmt,
|
||||
node,
|
||||
left_hand_explicit_type,
|
||||
right_hand_expression=first_prop_arg,
|
||||
)
|
||||
|
||||
if isinstance(stmt.rvalue, CallExpr):
|
||||
type_id = names.type_id_for_callee(stmt.rvalue.callee)
|
||||
# this is probably not strictly necessary as we have to use the left
|
||||
# hand type for query expression in any case. any other no-arg
|
||||
# column prop objects would go here also
|
||||
if type_id is names.QUERY_EXPRESSION:
|
||||
return _infer_type_from_decl_column(
|
||||
api,
|
||||
stmt,
|
||||
node,
|
||||
left_hand_explicit_type,
|
||||
)
|
||||
|
||||
return infer_type_from_left_hand_type_only(
|
||||
api, node, left_hand_explicit_type
|
||||
)
|
||||
|
||||
|
||||
def _infer_type_from_decl_column(
|
||||
api: SemanticAnalyzerPluginInterface,
|
||||
stmt: AssignmentStmt,
|
||||
node: Var,
|
||||
left_hand_explicit_type: Optional[ProperType],
|
||||
right_hand_expression: Optional[CallExpr] = None,
|
||||
) -> Optional[ProperType]:
|
||||
"""Infer the type of mapping from a Column.
|
||||
|
||||
E.g.::
|
||||
|
||||
@reg.mapped
|
||||
class MyClass:
|
||||
# ...
|
||||
|
||||
a = Column(Integer)
|
||||
|
||||
b = Column("b", String)
|
||||
|
||||
c: Mapped[int] = Column(Integer)
|
||||
|
||||
d: bool = Column(Boolean)
|
||||
|
||||
Will resolve in MyPy as::
|
||||
|
||||
@reg.mapped
|
||||
class MyClass:
|
||||
# ...
|
||||
|
||||
a : Mapped[int]
|
||||
|
||||
b : Mapped[str]
|
||||
|
||||
c: Mapped[int]
|
||||
|
||||
d: Mapped[bool]
|
||||
|
||||
"""
|
||||
assert isinstance(node, Var)
|
||||
|
||||
callee = None
|
||||
|
||||
if right_hand_expression is None:
|
||||
if not isinstance(stmt.rvalue, CallExpr):
|
||||
return None
|
||||
|
||||
right_hand_expression = stmt.rvalue
|
||||
|
||||
for column_arg in right_hand_expression.args[0:2]:
|
||||
if isinstance(column_arg, CallExpr):
|
||||
if isinstance(column_arg.callee, RefExpr):
|
||||
# x = Column(String(50))
|
||||
callee = column_arg.callee
|
||||
type_args: Sequence[Expression] = column_arg.args
|
||||
break
|
||||
elif isinstance(column_arg, (NameExpr, MemberExpr)):
|
||||
if isinstance(column_arg.node, TypeInfo):
|
||||
# x = Column(String)
|
||||
callee = column_arg
|
||||
type_args = ()
|
||||
break
|
||||
else:
|
||||
# x = Column(some_name, String), go to next argument
|
||||
continue
|
||||
elif isinstance(column_arg, (StrExpr,)):
|
||||
# x = Column("name", String), go to next argument
|
||||
continue
|
||||
elif isinstance(column_arg, (LambdaExpr,)):
|
||||
# x = Column("name", String, default=lambda: uuid.uuid4())
|
||||
# go to next argument
|
||||
continue
|
||||
else:
|
||||
assert False
|
||||
|
||||
if callee is None:
|
||||
return None
|
||||
|
||||
if isinstance(callee.node, TypeInfo) and names.mro_has_id(
|
||||
callee.node.mro, names.TYPEENGINE
|
||||
):
|
||||
python_type_for_type = extract_python_type_from_typeengine(
|
||||
api, callee.node, type_args
|
||||
)
|
||||
|
||||
if left_hand_explicit_type is not None:
|
||||
return _infer_type_from_left_and_inferred_right(
|
||||
api, node, left_hand_explicit_type, python_type_for_type
|
||||
)
|
||||
|
||||
else:
|
||||
return UnionType([python_type_for_type, NoneType()])
|
||||
else:
|
||||
# it's not TypeEngine, it's typically implicitly typed
|
||||
# like ForeignKey. we can't infer from the right side.
|
||||
return infer_type_from_left_hand_type_only(
|
||||
api, node, left_hand_explicit_type
|
||||
)
|
||||
|
||||
|
||||
def _infer_type_from_left_and_inferred_right(
|
||||
api: SemanticAnalyzerPluginInterface,
|
||||
node: Var,
|
||||
left_hand_explicit_type: ProperType,
|
||||
python_type_for_type: ProperType,
|
||||
orig_left_hand_type: Optional[ProperType] = None,
|
||||
orig_python_type_for_type: Optional[ProperType] = None,
|
||||
) -> Optional[ProperType]:
|
||||
"""Validate type when a left hand annotation is present and we also
|
||||
could infer the right hand side::
|
||||
|
||||
attrname: SomeType = Column(SomeDBType)
|
||||
|
||||
"""
|
||||
|
||||
if orig_left_hand_type is None:
|
||||
orig_left_hand_type = left_hand_explicit_type
|
||||
if orig_python_type_for_type is None:
|
||||
orig_python_type_for_type = python_type_for_type
|
||||
|
||||
if not is_subtype(left_hand_explicit_type, python_type_for_type):
|
||||
effective_type = api.named_type(
|
||||
names.NAMED_TYPE_SQLA_MAPPED, [orig_python_type_for_type]
|
||||
)
|
||||
|
||||
msg = (
|
||||
"Left hand assignment '{}: {}' not compatible "
|
||||
"with ORM mapped expression of type {}"
|
||||
)
|
||||
util.fail(
|
||||
api,
|
||||
msg.format(
|
||||
node.name,
|
||||
util.format_type(orig_left_hand_type, api.options),
|
||||
util.format_type(effective_type, api.options),
|
||||
),
|
||||
node,
|
||||
)
|
||||
|
||||
return orig_left_hand_type
|
||||
|
||||
|
||||
def _infer_collection_type_from_left_and_inferred_right(
|
||||
api: SemanticAnalyzerPluginInterface,
|
||||
node: Var,
|
||||
left_hand_explicit_type: Instance,
|
||||
python_type_for_type: Instance,
|
||||
) -> Optional[ProperType]:
|
||||
orig_left_hand_type = left_hand_explicit_type
|
||||
orig_python_type_for_type = python_type_for_type
|
||||
|
||||
if left_hand_explicit_type.args:
|
||||
left_hand_arg = get_proper_type(left_hand_explicit_type.args[0])
|
||||
python_type_arg = get_proper_type(python_type_for_type.args[0])
|
||||
else:
|
||||
left_hand_arg = left_hand_explicit_type
|
||||
python_type_arg = python_type_for_type
|
||||
|
||||
assert isinstance(left_hand_arg, (Instance, UnionType))
|
||||
assert isinstance(python_type_arg, (Instance, UnionType))
|
||||
|
||||
return _infer_type_from_left_and_inferred_right(
|
||||
api,
|
||||
node,
|
||||
left_hand_arg,
|
||||
python_type_arg,
|
||||
orig_left_hand_type=orig_left_hand_type,
|
||||
orig_python_type_for_type=orig_python_type_for_type,
|
||||
)
|
||||
|
||||
|
||||
def infer_type_from_left_hand_type_only(
|
||||
api: SemanticAnalyzerPluginInterface,
|
||||
node: Var,
|
||||
left_hand_explicit_type: Optional[ProperType],
|
||||
) -> Optional[ProperType]:
|
||||
"""Determine the type based on explicit annotation only.
|
||||
|
||||
if no annotation were present, note that we need one there to know
|
||||
the type.
|
||||
|
||||
"""
|
||||
if left_hand_explicit_type is None:
|
||||
msg = (
|
||||
"Can't infer type from ORM mapped expression "
|
||||
"assigned to attribute '{}'; please specify a "
|
||||
"Python type or "
|
||||
"Mapped[<python type>] on the left hand side."
|
||||
)
|
||||
util.fail(api, msg.format(node.name), node)
|
||||
|
||||
return api.named_type(
|
||||
names.NAMED_TYPE_SQLA_MAPPED, [AnyType(TypeOfAny.special_form)]
|
||||
)
|
||||
|
||||
else:
|
||||
# use type from the left hand side
|
||||
return left_hand_explicit_type
|
||||
|
||||
|
||||
def extract_python_type_from_typeengine(
|
||||
api: SemanticAnalyzerPluginInterface,
|
||||
node: TypeInfo,
|
||||
type_args: Sequence[Expression],
|
||||
) -> ProperType:
|
||||
if node.fullname == "sqlalchemy.sql.sqltypes.Enum" and type_args:
|
||||
first_arg = type_args[0]
|
||||
if isinstance(first_arg, RefExpr) and isinstance(
|
||||
first_arg.node, TypeInfo
|
||||
):
|
||||
for base_ in first_arg.node.mro:
|
||||
if base_.fullname == "enum.Enum":
|
||||
return Instance(first_arg.node, [])
|
||||
# TODO: support other pep-435 types here
|
||||
else:
|
||||
return api.named_type(names.NAMED_TYPE_BUILTINS_STR, [])
|
||||
|
||||
assert node.has_base("sqlalchemy.sql.type_api.TypeEngine"), (
|
||||
"could not extract Python type from node: %s" % node
|
||||
)
|
||||
|
||||
type_engine_sym = api.lookup_fully_qualified_or_none(
|
||||
"sqlalchemy.sql.type_api.TypeEngine"
|
||||
)
|
||||
|
||||
assert type_engine_sym is not None and isinstance(
|
||||
type_engine_sym.node, TypeInfo
|
||||
)
|
||||
type_engine = map_instance_to_supertype(
|
||||
Instance(node, []),
|
||||
type_engine_sym.node,
|
||||
)
|
||||
return get_proper_type(type_engine.args[-1])
|
|
@ -0,0 +1,342 @@
|
|||
# ext/mypy/names.py
|
||||
# Copyright (C) 2021-2024 the SQLAlchemy authors and contributors
|
||||
# <see AUTHORS file>
|
||||
#
|
||||
# This module is part of SQLAlchemy and is released under
|
||||
# the MIT License: https://www.opensource.org/licenses/mit-license.php
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Dict
|
||||
from typing import List
|
||||
from typing import Optional
|
||||
from typing import Set
|
||||
from typing import Tuple
|
||||
from typing import Union
|
||||
|
||||
from mypy.nodes import ARG_POS
|
||||
from mypy.nodes import CallExpr
|
||||
from mypy.nodes import ClassDef
|
||||
from mypy.nodes import Decorator
|
||||
from mypy.nodes import Expression
|
||||
from mypy.nodes import FuncDef
|
||||
from mypy.nodes import MemberExpr
|
||||
from mypy.nodes import NameExpr
|
||||
from mypy.nodes import OverloadedFuncDef
|
||||
from mypy.nodes import SymbolNode
|
||||
from mypy.nodes import TypeAlias
|
||||
from mypy.nodes import TypeInfo
|
||||
from mypy.plugin import SemanticAnalyzerPluginInterface
|
||||
from mypy.types import CallableType
|
||||
from mypy.types import get_proper_type
|
||||
from mypy.types import Instance
|
||||
from mypy.types import UnboundType
|
||||
|
||||
from ... import util
|
||||
|
||||
COLUMN: int = util.symbol("COLUMN")
|
||||
RELATIONSHIP: int = util.symbol("RELATIONSHIP")
|
||||
REGISTRY: int = util.symbol("REGISTRY")
|
||||
COLUMN_PROPERTY: int = util.symbol("COLUMN_PROPERTY")
|
||||
TYPEENGINE: int = util.symbol("TYPEENGNE")
|
||||
MAPPED: int = util.symbol("MAPPED")
|
||||
DECLARATIVE_BASE: int = util.symbol("DECLARATIVE_BASE")
|
||||
DECLARATIVE_META: int = util.symbol("DECLARATIVE_META")
|
||||
MAPPED_DECORATOR: int = util.symbol("MAPPED_DECORATOR")
|
||||
SYNONYM_PROPERTY: int = util.symbol("SYNONYM_PROPERTY")
|
||||
COMPOSITE_PROPERTY: int = util.symbol("COMPOSITE_PROPERTY")
|
||||
DECLARED_ATTR: int = util.symbol("DECLARED_ATTR")
|
||||
MAPPER_PROPERTY: int = util.symbol("MAPPER_PROPERTY")
|
||||
AS_DECLARATIVE: int = util.symbol("AS_DECLARATIVE")
|
||||
AS_DECLARATIVE_BASE: int = util.symbol("AS_DECLARATIVE_BASE")
|
||||
DECLARATIVE_MIXIN: int = util.symbol("DECLARATIVE_MIXIN")
|
||||
QUERY_EXPRESSION: int = util.symbol("QUERY_EXPRESSION")
|
||||
|
||||
# names that must succeed with mypy.api.named_type
|
||||
NAMED_TYPE_BUILTINS_OBJECT = "builtins.object"
|
||||
NAMED_TYPE_BUILTINS_STR = "builtins.str"
|
||||
NAMED_TYPE_BUILTINS_LIST = "builtins.list"
|
||||
NAMED_TYPE_SQLA_MAPPED = "sqlalchemy.orm.base.Mapped"
|
||||
|
||||
_lookup: Dict[str, Tuple[int, Set[str]]] = {
|
||||
"Column": (
|
||||
COLUMN,
|
||||
{
|
||||
"sqlalchemy.sql.schema.Column",
|
||||
"sqlalchemy.sql.Column",
|
||||
},
|
||||
),
|
||||
"Relationship": (
|
||||
RELATIONSHIP,
|
||||
{
|
||||
"sqlalchemy.orm.relationships.Relationship",
|
||||
"sqlalchemy.orm.relationships.RelationshipProperty",
|
||||
"sqlalchemy.orm.Relationship",
|
||||
"sqlalchemy.orm.RelationshipProperty",
|
||||
},
|
||||
),
|
||||
"RelationshipProperty": (
|
||||
RELATIONSHIP,
|
||||
{
|
||||
"sqlalchemy.orm.relationships.Relationship",
|
||||
"sqlalchemy.orm.relationships.RelationshipProperty",
|
||||
"sqlalchemy.orm.Relationship",
|
||||
"sqlalchemy.orm.RelationshipProperty",
|
||||
},
|
||||
),
|
||||
"registry": (
|
||||
REGISTRY,
|
||||
{
|
||||
"sqlalchemy.orm.decl_api.registry",
|
||||
"sqlalchemy.orm.registry",
|
||||
},
|
||||
),
|
||||
"ColumnProperty": (
|
||||
COLUMN_PROPERTY,
|
||||
{
|
||||
"sqlalchemy.orm.properties.MappedSQLExpression",
|
||||
"sqlalchemy.orm.MappedSQLExpression",
|
||||
"sqlalchemy.orm.properties.ColumnProperty",
|
||||
"sqlalchemy.orm.ColumnProperty",
|
||||
},
|
||||
),
|
||||
"MappedSQLExpression": (
|
||||
COLUMN_PROPERTY,
|
||||
{
|
||||
"sqlalchemy.orm.properties.MappedSQLExpression",
|
||||
"sqlalchemy.orm.MappedSQLExpression",
|
||||
"sqlalchemy.orm.properties.ColumnProperty",
|
||||
"sqlalchemy.orm.ColumnProperty",
|
||||
},
|
||||
),
|
||||
"Synonym": (
|
||||
SYNONYM_PROPERTY,
|
||||
{
|
||||
"sqlalchemy.orm.descriptor_props.Synonym",
|
||||
"sqlalchemy.orm.Synonym",
|
||||
"sqlalchemy.orm.descriptor_props.SynonymProperty",
|
||||
"sqlalchemy.orm.SynonymProperty",
|
||||
},
|
||||
),
|
||||
"SynonymProperty": (
|
||||
SYNONYM_PROPERTY,
|
||||
{
|
||||
"sqlalchemy.orm.descriptor_props.Synonym",
|
||||
"sqlalchemy.orm.Synonym",
|
||||
"sqlalchemy.orm.descriptor_props.SynonymProperty",
|
||||
"sqlalchemy.orm.SynonymProperty",
|
||||
},
|
||||
),
|
||||
"Composite": (
|
||||
COMPOSITE_PROPERTY,
|
||||
{
|
||||
"sqlalchemy.orm.descriptor_props.Composite",
|
||||
"sqlalchemy.orm.Composite",
|
||||
"sqlalchemy.orm.descriptor_props.CompositeProperty",
|
||||
"sqlalchemy.orm.CompositeProperty",
|
||||
},
|
||||
),
|
||||
"CompositeProperty": (
|
||||
COMPOSITE_PROPERTY,
|
||||
{
|
||||
"sqlalchemy.orm.descriptor_props.Composite",
|
||||
"sqlalchemy.orm.Composite",
|
||||
"sqlalchemy.orm.descriptor_props.CompositeProperty",
|
||||
"sqlalchemy.orm.CompositeProperty",
|
||||
},
|
||||
),
|
||||
"MapperProperty": (
|
||||
MAPPER_PROPERTY,
|
||||
{
|
||||
"sqlalchemy.orm.interfaces.MapperProperty",
|
||||
"sqlalchemy.orm.MapperProperty",
|
||||
},
|
||||
),
|
||||
"TypeEngine": (TYPEENGINE, {"sqlalchemy.sql.type_api.TypeEngine"}),
|
||||
"Mapped": (MAPPED, {NAMED_TYPE_SQLA_MAPPED}),
|
||||
"declarative_base": (
|
||||
DECLARATIVE_BASE,
|
||||
{
|
||||
"sqlalchemy.ext.declarative.declarative_base",
|
||||
"sqlalchemy.orm.declarative_base",
|
||||
"sqlalchemy.orm.decl_api.declarative_base",
|
||||
},
|
||||
),
|
||||
"DeclarativeMeta": (
|
||||
DECLARATIVE_META,
|
||||
{
|
||||
"sqlalchemy.ext.declarative.DeclarativeMeta",
|
||||
"sqlalchemy.orm.DeclarativeMeta",
|
||||
"sqlalchemy.orm.decl_api.DeclarativeMeta",
|
||||
},
|
||||
),
|
||||
"mapped": (
|
||||
MAPPED_DECORATOR,
|
||||
{
|
||||
"sqlalchemy.orm.decl_api.registry.mapped",
|
||||
"sqlalchemy.orm.registry.mapped",
|
||||
},
|
||||
),
|
||||
"as_declarative": (
|
||||
AS_DECLARATIVE,
|
||||
{
|
||||
"sqlalchemy.ext.declarative.as_declarative",
|
||||
"sqlalchemy.orm.decl_api.as_declarative",
|
||||
"sqlalchemy.orm.as_declarative",
|
||||
},
|
||||
),
|
||||
"as_declarative_base": (
|
||||
AS_DECLARATIVE_BASE,
|
||||
{
|
||||
"sqlalchemy.orm.decl_api.registry.as_declarative_base",
|
||||
"sqlalchemy.orm.registry.as_declarative_base",
|
||||
},
|
||||
),
|
||||
"declared_attr": (
|
||||
DECLARED_ATTR,
|
||||
{
|
||||
"sqlalchemy.orm.decl_api.declared_attr",
|
||||
"sqlalchemy.orm.declared_attr",
|
||||
},
|
||||
),
|
||||
"declarative_mixin": (
|
||||
DECLARATIVE_MIXIN,
|
||||
{
|
||||
"sqlalchemy.orm.decl_api.declarative_mixin",
|
||||
"sqlalchemy.orm.declarative_mixin",
|
||||
},
|
||||
),
|
||||
"query_expression": (
|
||||
QUERY_EXPRESSION,
|
||||
{
|
||||
"sqlalchemy.orm.query_expression",
|
||||
"sqlalchemy.orm._orm_constructors.query_expression",
|
||||
},
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
def has_base_type_id(info: TypeInfo, type_id: int) -> bool:
|
||||
for mr in info.mro:
|
||||
check_type_id, fullnames = _lookup.get(mr.name, (None, None))
|
||||
if check_type_id == type_id:
|
||||
break
|
||||
else:
|
||||
return False
|
||||
|
||||
if fullnames is None:
|
||||
return False
|
||||
|
||||
return mr.fullname in fullnames
|
||||
|
||||
|
||||
def mro_has_id(mro: List[TypeInfo], type_id: int) -> bool:
|
||||
for mr in mro:
|
||||
check_type_id, fullnames = _lookup.get(mr.name, (None, None))
|
||||
if check_type_id == type_id:
|
||||
break
|
||||
else:
|
||||
return False
|
||||
|
||||
if fullnames is None:
|
||||
return False
|
||||
|
||||
return mr.fullname in fullnames
|
||||
|
||||
|
||||
def type_id_for_unbound_type(
|
||||
type_: UnboundType, cls: ClassDef, api: SemanticAnalyzerPluginInterface
|
||||
) -> Optional[int]:
|
||||
sym = api.lookup_qualified(type_.name, type_)
|
||||
if sym is not None:
|
||||
if isinstance(sym.node, TypeAlias):
|
||||
target_type = get_proper_type(sym.node.target)
|
||||
if isinstance(target_type, Instance):
|
||||
return type_id_for_named_node(target_type.type)
|
||||
elif isinstance(sym.node, TypeInfo):
|
||||
return type_id_for_named_node(sym.node)
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def type_id_for_callee(callee: Expression) -> Optional[int]:
|
||||
if isinstance(callee, (MemberExpr, NameExpr)):
|
||||
if isinstance(callee.node, Decorator) and isinstance(
|
||||
callee.node.func, FuncDef
|
||||
):
|
||||
if callee.node.func.type and isinstance(
|
||||
callee.node.func.type, CallableType
|
||||
):
|
||||
ret_type = get_proper_type(callee.node.func.type.ret_type)
|
||||
|
||||
if isinstance(ret_type, Instance):
|
||||
return type_id_for_fullname(ret_type.type.fullname)
|
||||
|
||||
return None
|
||||
|
||||
elif isinstance(callee.node, OverloadedFuncDef):
|
||||
if (
|
||||
callee.node.impl
|
||||
and callee.node.impl.type
|
||||
and isinstance(callee.node.impl.type, CallableType)
|
||||
):
|
||||
ret_type = get_proper_type(callee.node.impl.type.ret_type)
|
||||
|
||||
if isinstance(ret_type, Instance):
|
||||
return type_id_for_fullname(ret_type.type.fullname)
|
||||
|
||||
return None
|
||||
elif isinstance(callee.node, FuncDef):
|
||||
if callee.node.type and isinstance(callee.node.type, CallableType):
|
||||
ret_type = get_proper_type(callee.node.type.ret_type)
|
||||
|
||||
if isinstance(ret_type, Instance):
|
||||
return type_id_for_fullname(ret_type.type.fullname)
|
||||
|
||||
return None
|
||||
elif isinstance(callee.node, TypeAlias):
|
||||
target_type = get_proper_type(callee.node.target)
|
||||
if isinstance(target_type, Instance):
|
||||
return type_id_for_fullname(target_type.type.fullname)
|
||||
elif isinstance(callee.node, TypeInfo):
|
||||
return type_id_for_named_node(callee)
|
||||
return None
|
||||
|
||||
|
||||
def type_id_for_named_node(
|
||||
node: Union[NameExpr, MemberExpr, SymbolNode]
|
||||
) -> Optional[int]:
|
||||
type_id, fullnames = _lookup.get(node.name, (None, None))
|
||||
|
||||
if type_id is None or fullnames is None:
|
||||
return None
|
||||
elif node.fullname in fullnames:
|
||||
return type_id
|
||||
else:
|
||||
return None
|
||||
|
||||
|
||||
def type_id_for_fullname(fullname: str) -> Optional[int]:
|
||||
tokens = fullname.split(".")
|
||||
immediate = tokens[-1]
|
||||
|
||||
type_id, fullnames = _lookup.get(immediate, (None, None))
|
||||
|
||||
if type_id is None or fullnames is None:
|
||||
return None
|
||||
elif fullname in fullnames:
|
||||
return type_id
|
||||
else:
|
||||
return None
|
||||
|
||||
|
||||
def expr_to_mapped_constructor(expr: Expression) -> CallExpr:
|
||||
column_descriptor = NameExpr("__sa_Mapped")
|
||||
column_descriptor.fullname = NAMED_TYPE_SQLA_MAPPED
|
||||
member_expr = MemberExpr(column_descriptor, "_empty_constructor")
|
||||
return CallExpr(
|
||||
member_expr,
|
||||
[expr],
|
||||
[ARG_POS],
|
||||
["arg1"],
|
||||
)
|
|
@ -0,0 +1,303 @@
|
|||
# ext/mypy/plugin.py
|
||||
# Copyright (C) 2021-2024 the SQLAlchemy authors and contributors
|
||||
# <see AUTHORS file>
|
||||
#
|
||||
# This module is part of SQLAlchemy and is released under
|
||||
# the MIT License: https://www.opensource.org/licenses/mit-license.php
|
||||
|
||||
"""
|
||||
Mypy plugin for SQLAlchemy ORM.
|
||||
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Callable
|
||||
from typing import List
|
||||
from typing import Optional
|
||||
from typing import Tuple
|
||||
from typing import Type as TypingType
|
||||
from typing import Union
|
||||
|
||||
from mypy import nodes
|
||||
from mypy.mro import calculate_mro
|
||||
from mypy.mro import MroError
|
||||
from mypy.nodes import Block
|
||||
from mypy.nodes import ClassDef
|
||||
from mypy.nodes import GDEF
|
||||
from mypy.nodes import MypyFile
|
||||
from mypy.nodes import NameExpr
|
||||
from mypy.nodes import SymbolTable
|
||||
from mypy.nodes import SymbolTableNode
|
||||
from mypy.nodes import TypeInfo
|
||||
from mypy.plugin import AttributeContext
|
||||
from mypy.plugin import ClassDefContext
|
||||
from mypy.plugin import DynamicClassDefContext
|
||||
from mypy.plugin import Plugin
|
||||
from mypy.plugin import SemanticAnalyzerPluginInterface
|
||||
from mypy.types import get_proper_type
|
||||
from mypy.types import Instance
|
||||
from mypy.types import Type
|
||||
|
||||
from . import decl_class
|
||||
from . import names
|
||||
from . import util
|
||||
|
||||
try:
|
||||
__import__("sqlalchemy-stubs")
|
||||
except ImportError:
|
||||
pass
|
||||
else:
|
||||
raise ImportError(
|
||||
"The SQLAlchemy mypy plugin in SQLAlchemy "
|
||||
"2.0 does not work with sqlalchemy-stubs or "
|
||||
"sqlalchemy2-stubs installed, as well as with any other third party "
|
||||
"SQLAlchemy stubs. Please uninstall all SQLAlchemy stubs "
|
||||
"packages."
|
||||
)
|
||||
|
||||
|
||||
class SQLAlchemyPlugin(Plugin):
|
||||
def get_dynamic_class_hook(
|
||||
self, fullname: str
|
||||
) -> Optional[Callable[[DynamicClassDefContext], None]]:
|
||||
if names.type_id_for_fullname(fullname) is names.DECLARATIVE_BASE:
|
||||
return _dynamic_class_hook
|
||||
return None
|
||||
|
||||
def get_customize_class_mro_hook(
|
||||
self, fullname: str
|
||||
) -> Optional[Callable[[ClassDefContext], None]]:
|
||||
return _fill_in_decorators
|
||||
|
||||
def get_class_decorator_hook(
|
||||
self, fullname: str
|
||||
) -> Optional[Callable[[ClassDefContext], None]]:
|
||||
sym = self.lookup_fully_qualified(fullname)
|
||||
|
||||
if sym is not None and sym.node is not None:
|
||||
type_id = names.type_id_for_named_node(sym.node)
|
||||
if type_id is names.MAPPED_DECORATOR:
|
||||
return _cls_decorator_hook
|
||||
elif type_id in (
|
||||
names.AS_DECLARATIVE,
|
||||
names.AS_DECLARATIVE_BASE,
|
||||
):
|
||||
return _base_cls_decorator_hook
|
||||
elif type_id is names.DECLARATIVE_MIXIN:
|
||||
return _declarative_mixin_hook
|
||||
|
||||
return None
|
||||
|
||||
def get_metaclass_hook(
|
||||
self, fullname: str
|
||||
) -> Optional[Callable[[ClassDefContext], None]]:
|
||||
if names.type_id_for_fullname(fullname) is names.DECLARATIVE_META:
|
||||
# Set any classes that explicitly have metaclass=DeclarativeMeta
|
||||
# as declarative so the check in `get_base_class_hook()` works
|
||||
return _metaclass_cls_hook
|
||||
|
||||
return None
|
||||
|
||||
def get_base_class_hook(
|
||||
self, fullname: str
|
||||
) -> Optional[Callable[[ClassDefContext], None]]:
|
||||
sym = self.lookup_fully_qualified(fullname)
|
||||
|
||||
if (
|
||||
sym
|
||||
and isinstance(sym.node, TypeInfo)
|
||||
and util.has_declarative_base(sym.node)
|
||||
):
|
||||
return _base_cls_hook
|
||||
|
||||
return None
|
||||
|
||||
def get_attribute_hook(
|
||||
self, fullname: str
|
||||
) -> Optional[Callable[[AttributeContext], Type]]:
|
||||
if fullname.startswith(
|
||||
"sqlalchemy.orm.attributes.QueryableAttribute."
|
||||
):
|
||||
return _queryable_getattr_hook
|
||||
|
||||
return None
|
||||
|
||||
def get_additional_deps(
|
||||
self, file: MypyFile
|
||||
) -> List[Tuple[int, str, int]]:
|
||||
return [
|
||||
#
|
||||
(10, "sqlalchemy.orm", -1),
|
||||
(10, "sqlalchemy.orm.attributes", -1),
|
||||
(10, "sqlalchemy.orm.decl_api", -1),
|
||||
]
|
||||
|
||||
|
||||
def plugin(version: str) -> TypingType[SQLAlchemyPlugin]:
|
||||
return SQLAlchemyPlugin
|
||||
|
||||
|
||||
def _dynamic_class_hook(ctx: DynamicClassDefContext) -> None:
|
||||
"""Generate a declarative Base class when the declarative_base() function
|
||||
is encountered."""
|
||||
|
||||
_add_globals(ctx)
|
||||
|
||||
cls = ClassDef(ctx.name, Block([]))
|
||||
cls.fullname = ctx.api.qualified_name(ctx.name)
|
||||
|
||||
info = TypeInfo(SymbolTable(), cls, ctx.api.cur_mod_id)
|
||||
cls.info = info
|
||||
_set_declarative_metaclass(ctx.api, cls)
|
||||
|
||||
cls_arg = util.get_callexpr_kwarg(ctx.call, "cls", expr_types=(NameExpr,))
|
||||
if cls_arg is not None and isinstance(cls_arg.node, TypeInfo):
|
||||
util.set_is_base(cls_arg.node)
|
||||
decl_class.scan_declarative_assignments_and_apply_types(
|
||||
cls_arg.node.defn, ctx.api, is_mixin_scan=True
|
||||
)
|
||||
info.bases = [Instance(cls_arg.node, [])]
|
||||
else:
|
||||
obj = ctx.api.named_type(names.NAMED_TYPE_BUILTINS_OBJECT)
|
||||
|
||||
info.bases = [obj]
|
||||
|
||||
try:
|
||||
calculate_mro(info)
|
||||
except MroError:
|
||||
util.fail(
|
||||
ctx.api, "Not able to calculate MRO for declarative base", ctx.call
|
||||
)
|
||||
obj = ctx.api.named_type(names.NAMED_TYPE_BUILTINS_OBJECT)
|
||||
info.bases = [obj]
|
||||
info.fallback_to_any = True
|
||||
|
||||
ctx.api.add_symbol_table_node(ctx.name, SymbolTableNode(GDEF, info))
|
||||
util.set_is_base(info)
|
||||
|
||||
|
||||
def _fill_in_decorators(ctx: ClassDefContext) -> None:
|
||||
for decorator in ctx.cls.decorators:
|
||||
# set the ".fullname" attribute of a class decorator
|
||||
# that is a MemberExpr. This causes the logic in
|
||||
# semanal.py->apply_class_plugin_hooks to invoke the
|
||||
# get_class_decorator_hook for our "registry.map_class()"
|
||||
# and "registry.as_declarative_base()" methods.
|
||||
# this seems like a bug in mypy that these decorators are otherwise
|
||||
# skipped.
|
||||
|
||||
if (
|
||||
isinstance(decorator, nodes.CallExpr)
|
||||
and isinstance(decorator.callee, nodes.MemberExpr)
|
||||
and decorator.callee.name == "as_declarative_base"
|
||||
):
|
||||
target = decorator.callee
|
||||
elif (
|
||||
isinstance(decorator, nodes.MemberExpr)
|
||||
and decorator.name == "mapped"
|
||||
):
|
||||
target = decorator
|
||||
else:
|
||||
continue
|
||||
|
||||
if isinstance(target.expr, NameExpr):
|
||||
sym = ctx.api.lookup_qualified(
|
||||
target.expr.name, target, suppress_errors=True
|
||||
)
|
||||
else:
|
||||
continue
|
||||
|
||||
if sym and sym.node:
|
||||
sym_type = get_proper_type(sym.type)
|
||||
if isinstance(sym_type, Instance):
|
||||
target.fullname = f"{sym_type.type.fullname}.{target.name}"
|
||||
else:
|
||||
# if the registry is in the same file as where the
|
||||
# decorator is used, it might not have semantic
|
||||
# symbols applied and we can't get a fully qualified
|
||||
# name or an inferred type, so we are actually going to
|
||||
# flag an error in this case that they need to annotate
|
||||
# it. The "registry" is declared just
|
||||
# once (or few times), so they have to just not use
|
||||
# type inference for its assignment in this one case.
|
||||
util.fail(
|
||||
ctx.api,
|
||||
"Class decorator called %s(), but we can't "
|
||||
"tell if it's from an ORM registry. Please "
|
||||
"annotate the registry assignment, e.g. "
|
||||
"my_registry: registry = registry()" % target.name,
|
||||
sym.node,
|
||||
)
|
||||
|
||||
|
||||
def _cls_decorator_hook(ctx: ClassDefContext) -> None:
|
||||
_add_globals(ctx)
|
||||
assert isinstance(ctx.reason, nodes.MemberExpr)
|
||||
expr = ctx.reason.expr
|
||||
|
||||
assert isinstance(expr, nodes.RefExpr) and isinstance(expr.node, nodes.Var)
|
||||
|
||||
node_type = get_proper_type(expr.node.type)
|
||||
|
||||
assert (
|
||||
isinstance(node_type, Instance)
|
||||
and names.type_id_for_named_node(node_type.type) is names.REGISTRY
|
||||
)
|
||||
|
||||
decl_class.scan_declarative_assignments_and_apply_types(ctx.cls, ctx.api)
|
||||
|
||||
|
||||
def _base_cls_decorator_hook(ctx: ClassDefContext) -> None:
|
||||
_add_globals(ctx)
|
||||
|
||||
cls = ctx.cls
|
||||
|
||||
_set_declarative_metaclass(ctx.api, cls)
|
||||
|
||||
util.set_is_base(ctx.cls.info)
|
||||
decl_class.scan_declarative_assignments_and_apply_types(
|
||||
cls, ctx.api, is_mixin_scan=True
|
||||
)
|
||||
|
||||
|
||||
def _declarative_mixin_hook(ctx: ClassDefContext) -> None:
|
||||
_add_globals(ctx)
|
||||
util.set_is_base(ctx.cls.info)
|
||||
decl_class.scan_declarative_assignments_and_apply_types(
|
||||
ctx.cls, ctx.api, is_mixin_scan=True
|
||||
)
|
||||
|
||||
|
||||
def _metaclass_cls_hook(ctx: ClassDefContext) -> None:
|
||||
util.set_is_base(ctx.cls.info)
|
||||
|
||||
|
||||
def _base_cls_hook(ctx: ClassDefContext) -> None:
|
||||
_add_globals(ctx)
|
||||
decl_class.scan_declarative_assignments_and_apply_types(ctx.cls, ctx.api)
|
||||
|
||||
|
||||
def _queryable_getattr_hook(ctx: AttributeContext) -> Type:
|
||||
# how do I....tell it it has no attribute of a certain name?
|
||||
# can't find any Type that seems to match that
|
||||
return ctx.default_attr_type
|
||||
|
||||
|
||||
def _add_globals(ctx: Union[ClassDefContext, DynamicClassDefContext]) -> None:
|
||||
"""Add __sa_DeclarativeMeta and __sa_Mapped symbol to the global space
|
||||
for all class defs
|
||||
|
||||
"""
|
||||
|
||||
util.add_global(ctx, "sqlalchemy.orm", "Mapped", "__sa_Mapped")
|
||||
|
||||
|
||||
def _set_declarative_metaclass(
|
||||
api: SemanticAnalyzerPluginInterface, target_cls: ClassDef
|
||||
) -> None:
|
||||
info = target_cls.info
|
||||
sym = api.lookup_fully_qualified_or_none(
|
||||
"sqlalchemy.orm.decl_api.DeclarativeMeta"
|
||||
)
|
||||
assert sym is not None and isinstance(sym.node, TypeInfo)
|
||||
info.declared_metaclass = info.metaclass_type = Instance(sym.node, [])
|
|
@ -0,0 +1,338 @@
|
|||
# ext/mypy/util.py
|
||||
# Copyright (C) 2021-2024 the SQLAlchemy authors and contributors
|
||||
# <see AUTHORS file>
|
||||
#
|
||||
# This module is part of SQLAlchemy and is released under
|
||||
# the MIT License: https://www.opensource.org/licenses/mit-license.php
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
from typing import Any
|
||||
from typing import Iterable
|
||||
from typing import Iterator
|
||||
from typing import List
|
||||
from typing import Optional
|
||||
from typing import overload
|
||||
from typing import Tuple
|
||||
from typing import Type as TypingType
|
||||
from typing import TypeVar
|
||||
from typing import Union
|
||||
|
||||
from mypy import version
|
||||
from mypy.messages import format_type as _mypy_format_type
|
||||
from mypy.nodes import CallExpr
|
||||
from mypy.nodes import ClassDef
|
||||
from mypy.nodes import CLASSDEF_NO_INFO
|
||||
from mypy.nodes import Context
|
||||
from mypy.nodes import Expression
|
||||
from mypy.nodes import FuncDef
|
||||
from mypy.nodes import IfStmt
|
||||
from mypy.nodes import JsonDict
|
||||
from mypy.nodes import MemberExpr
|
||||
from mypy.nodes import NameExpr
|
||||
from mypy.nodes import Statement
|
||||
from mypy.nodes import SymbolTableNode
|
||||
from mypy.nodes import TypeAlias
|
||||
from mypy.nodes import TypeInfo
|
||||
from mypy.options import Options
|
||||
from mypy.plugin import ClassDefContext
|
||||
from mypy.plugin import DynamicClassDefContext
|
||||
from mypy.plugin import SemanticAnalyzerPluginInterface
|
||||
from mypy.plugins.common import deserialize_and_fixup_type
|
||||
from mypy.typeops import map_type_from_supertype
|
||||
from mypy.types import CallableType
|
||||
from mypy.types import get_proper_type
|
||||
from mypy.types import Instance
|
||||
from mypy.types import NoneType
|
||||
from mypy.types import Type
|
||||
from mypy.types import TypeVarType
|
||||
from mypy.types import UnboundType
|
||||
from mypy.types import UnionType
|
||||
|
||||
_vers = tuple(
|
||||
[int(x) for x in version.__version__.split(".") if re.match(r"^\d+$", x)]
|
||||
)
|
||||
mypy_14 = _vers >= (1, 4)
|
||||
|
||||
|
||||
_TArgType = TypeVar("_TArgType", bound=Union[CallExpr, NameExpr])
|
||||
|
||||
|
||||
class SQLAlchemyAttribute:
|
||||
def __init__(
|
||||
self,
|
||||
name: str,
|
||||
line: int,
|
||||
column: int,
|
||||
typ: Optional[Type],
|
||||
info: TypeInfo,
|
||||
) -> None:
|
||||
self.name = name
|
||||
self.line = line
|
||||
self.column = column
|
||||
self.type = typ
|
||||
self.info = info
|
||||
|
||||
def serialize(self) -> JsonDict:
|
||||
assert self.type
|
||||
return {
|
||||
"name": self.name,
|
||||
"line": self.line,
|
||||
"column": self.column,
|
||||
"type": self.type.serialize(),
|
||||
}
|
||||
|
||||
def expand_typevar_from_subtype(self, sub_type: TypeInfo) -> None:
|
||||
"""Expands type vars in the context of a subtype when an attribute is
|
||||
inherited from a generic super type.
|
||||
"""
|
||||
if not isinstance(self.type, TypeVarType):
|
||||
return
|
||||
|
||||
self.type = map_type_from_supertype(self.type, sub_type, self.info)
|
||||
|
||||
@classmethod
|
||||
def deserialize(
|
||||
cls,
|
||||
info: TypeInfo,
|
||||
data: JsonDict,
|
||||
api: SemanticAnalyzerPluginInterface,
|
||||
) -> SQLAlchemyAttribute:
|
||||
data = data.copy()
|
||||
typ = deserialize_and_fixup_type(data.pop("type"), api)
|
||||
return cls(typ=typ, info=info, **data)
|
||||
|
||||
|
||||
def name_is_dunder(name: str) -> bool:
|
||||
return bool(re.match(r"^__.+?__$", name))
|
||||
|
||||
|
||||
def _set_info_metadata(info: TypeInfo, key: str, data: Any) -> None:
|
||||
info.metadata.setdefault("sqlalchemy", {})[key] = data
|
||||
|
||||
|
||||
def _get_info_metadata(info: TypeInfo, key: str) -> Optional[Any]:
|
||||
return info.metadata.get("sqlalchemy", {}).get(key, None)
|
||||
|
||||
|
||||
def _get_info_mro_metadata(info: TypeInfo, key: str) -> Optional[Any]:
|
||||
if info.mro:
|
||||
for base in info.mro:
|
||||
metadata = _get_info_metadata(base, key)
|
||||
if metadata is not None:
|
||||
return metadata
|
||||
return None
|
||||
|
||||
|
||||
def establish_as_sqlalchemy(info: TypeInfo) -> None:
|
||||
info.metadata.setdefault("sqlalchemy", {})
|
||||
|
||||
|
||||
def set_is_base(info: TypeInfo) -> None:
|
||||
_set_info_metadata(info, "is_base", True)
|
||||
|
||||
|
||||
def get_is_base(info: TypeInfo) -> bool:
|
||||
is_base = _get_info_metadata(info, "is_base")
|
||||
return is_base is True
|
||||
|
||||
|
||||
def has_declarative_base(info: TypeInfo) -> bool:
|
||||
is_base = _get_info_mro_metadata(info, "is_base")
|
||||
return is_base is True
|
||||
|
||||
|
||||
def set_has_table(info: TypeInfo) -> None:
|
||||
_set_info_metadata(info, "has_table", True)
|
||||
|
||||
|
||||
def get_has_table(info: TypeInfo) -> bool:
|
||||
is_base = _get_info_metadata(info, "has_table")
|
||||
return is_base is True
|
||||
|
||||
|
||||
def get_mapped_attributes(
|
||||
info: TypeInfo, api: SemanticAnalyzerPluginInterface
|
||||
) -> Optional[List[SQLAlchemyAttribute]]:
|
||||
mapped_attributes: Optional[List[JsonDict]] = _get_info_metadata(
|
||||
info, "mapped_attributes"
|
||||
)
|
||||
if mapped_attributes is None:
|
||||
return None
|
||||
|
||||
attributes: List[SQLAlchemyAttribute] = []
|
||||
|
||||
for data in mapped_attributes:
|
||||
attr = SQLAlchemyAttribute.deserialize(info, data, api)
|
||||
attr.expand_typevar_from_subtype(info)
|
||||
attributes.append(attr)
|
||||
|
||||
return attributes
|
||||
|
||||
|
||||
def format_type(typ_: Type, options: Options) -> str:
|
||||
if mypy_14:
|
||||
return _mypy_format_type(typ_, options)
|
||||
else:
|
||||
return _mypy_format_type(typ_) # type: ignore
|
||||
|
||||
|
||||
def set_mapped_attributes(
|
||||
info: TypeInfo, attributes: List[SQLAlchemyAttribute]
|
||||
) -> None:
|
||||
_set_info_metadata(
|
||||
info,
|
||||
"mapped_attributes",
|
||||
[attribute.serialize() for attribute in attributes],
|
||||
)
|
||||
|
||||
|
||||
def fail(api: SemanticAnalyzerPluginInterface, msg: str, ctx: Context) -> None:
|
||||
msg = "[SQLAlchemy Mypy plugin] %s" % msg
|
||||
return api.fail(msg, ctx)
|
||||
|
||||
|
||||
def add_global(
|
||||
ctx: Union[ClassDefContext, DynamicClassDefContext],
|
||||
module: str,
|
||||
symbol_name: str,
|
||||
asname: str,
|
||||
) -> None:
|
||||
module_globals = ctx.api.modules[ctx.api.cur_mod_id].names
|
||||
|
||||
if asname not in module_globals:
|
||||
lookup_sym: SymbolTableNode = ctx.api.modules[module].names[
|
||||
symbol_name
|
||||
]
|
||||
|
||||
module_globals[asname] = lookup_sym
|
||||
|
||||
|
||||
@overload
|
||||
def get_callexpr_kwarg(
|
||||
callexpr: CallExpr, name: str, *, expr_types: None = ...
|
||||
) -> Optional[Union[CallExpr, NameExpr]]: ...
|
||||
|
||||
|
||||
@overload
|
||||
def get_callexpr_kwarg(
|
||||
callexpr: CallExpr,
|
||||
name: str,
|
||||
*,
|
||||
expr_types: Tuple[TypingType[_TArgType], ...],
|
||||
) -> Optional[_TArgType]: ...
|
||||
|
||||
|
||||
def get_callexpr_kwarg(
|
||||
callexpr: CallExpr,
|
||||
name: str,
|
||||
*,
|
||||
expr_types: Optional[Tuple[TypingType[Any], ...]] = None,
|
||||
) -> Optional[Any]:
|
||||
try:
|
||||
arg_idx = callexpr.arg_names.index(name)
|
||||
except ValueError:
|
||||
return None
|
||||
|
||||
kwarg = callexpr.args[arg_idx]
|
||||
if isinstance(
|
||||
kwarg, expr_types if expr_types is not None else (NameExpr, CallExpr)
|
||||
):
|
||||
return kwarg
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def flatten_typechecking(stmts: Iterable[Statement]) -> Iterator[Statement]:
|
||||
for stmt in stmts:
|
||||
if (
|
||||
isinstance(stmt, IfStmt)
|
||||
and isinstance(stmt.expr[0], NameExpr)
|
||||
and stmt.expr[0].fullname == "typing.TYPE_CHECKING"
|
||||
):
|
||||
yield from stmt.body[0].body
|
||||
else:
|
||||
yield stmt
|
||||
|
||||
|
||||
def type_for_callee(callee: Expression) -> Optional[Union[Instance, TypeInfo]]:
|
||||
if isinstance(callee, (MemberExpr, NameExpr)):
|
||||
if isinstance(callee.node, FuncDef):
|
||||
if callee.node.type and isinstance(callee.node.type, CallableType):
|
||||
ret_type = get_proper_type(callee.node.type.ret_type)
|
||||
|
||||
if isinstance(ret_type, Instance):
|
||||
return ret_type
|
||||
|
||||
return None
|
||||
elif isinstance(callee.node, TypeAlias):
|
||||
target_type = get_proper_type(callee.node.target)
|
||||
if isinstance(target_type, Instance):
|
||||
return target_type
|
||||
elif isinstance(callee.node, TypeInfo):
|
||||
return callee.node
|
||||
return None
|
||||
|
||||
|
||||
def unbound_to_instance(
|
||||
api: SemanticAnalyzerPluginInterface, typ: Type
|
||||
) -> Type:
|
||||
"""Take the UnboundType that we seem to get as the ret_type from a FuncDef
|
||||
and convert it into an Instance/TypeInfo kind of structure that seems
|
||||
to work as the left-hand type of an AssignmentStatement.
|
||||
|
||||
"""
|
||||
|
||||
if not isinstance(typ, UnboundType):
|
||||
return typ
|
||||
|
||||
# TODO: figure out a more robust way to check this. The node is some
|
||||
# kind of _SpecialForm, there's a typing.Optional that's _SpecialForm,
|
||||
# but I can't figure out how to get them to match up
|
||||
if typ.name == "Optional":
|
||||
# convert from "Optional?" to the more familiar
|
||||
# UnionType[..., NoneType()]
|
||||
return unbound_to_instance(
|
||||
api,
|
||||
UnionType(
|
||||
[unbound_to_instance(api, typ_arg) for typ_arg in typ.args]
|
||||
+ [NoneType()]
|
||||
),
|
||||
)
|
||||
|
||||
node = api.lookup_qualified(typ.name, typ)
|
||||
|
||||
if (
|
||||
node is not None
|
||||
and isinstance(node, SymbolTableNode)
|
||||
and isinstance(node.node, TypeInfo)
|
||||
):
|
||||
bound_type = node.node
|
||||
|
||||
return Instance(
|
||||
bound_type,
|
||||
[
|
||||
(
|
||||
unbound_to_instance(api, arg)
|
||||
if isinstance(arg, UnboundType)
|
||||
else arg
|
||||
)
|
||||
for arg in typ.args
|
||||
],
|
||||
)
|
||||
else:
|
||||
return typ
|
||||
|
||||
|
||||
def info_for_cls(
|
||||
cls: ClassDef, api: SemanticAnalyzerPluginInterface
|
||||
) -> Optional[TypeInfo]:
|
||||
if cls.info is CLASSDEF_NO_INFO:
|
||||
sym = api.lookup_qualified(cls.name, cls)
|
||||
if sym is None:
|
||||
return None
|
||||
assert sym and isinstance(sym.node, TypeInfo)
|
||||
return sym.node
|
||||
|
||||
return cls.info
|
|
@ -0,0 +1,416 @@
|
|||
# ext/orderinglist.py
|
||||
# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors
|
||||
# <see AUTHORS file>
|
||||
#
|
||||
# This module is part of SQLAlchemy and is released under
|
||||
# the MIT License: https://www.opensource.org/licenses/mit-license.php
|
||||
# mypy: ignore-errors
|
||||
|
||||
"""A custom list that manages index/position information for contained
|
||||
elements.
|
||||
|
||||
:author: Jason Kirtland
|
||||
|
||||
``orderinglist`` is a helper for mutable ordered relationships. It will
|
||||
intercept list operations performed on a :func:`_orm.relationship`-managed
|
||||
collection and
|
||||
automatically synchronize changes in list position onto a target scalar
|
||||
attribute.
|
||||
|
||||
Example: A ``slide`` table, where each row refers to zero or more entries
|
||||
in a related ``bullet`` table. The bullets within a slide are
|
||||
displayed in order based on the value of the ``position`` column in the
|
||||
``bullet`` table. As entries are reordered in memory, the value of the
|
||||
``position`` attribute should be updated to reflect the new sort order::
|
||||
|
||||
|
||||
Base = declarative_base()
|
||||
|
||||
class Slide(Base):
|
||||
__tablename__ = 'slide'
|
||||
|
||||
id = Column(Integer, primary_key=True)
|
||||
name = Column(String)
|
||||
|
||||
bullets = relationship("Bullet", order_by="Bullet.position")
|
||||
|
||||
class Bullet(Base):
|
||||
__tablename__ = 'bullet'
|
||||
id = Column(Integer, primary_key=True)
|
||||
slide_id = Column(Integer, ForeignKey('slide.id'))
|
||||
position = Column(Integer)
|
||||
text = Column(String)
|
||||
|
||||
The standard relationship mapping will produce a list-like attribute on each
|
||||
``Slide`` containing all related ``Bullet`` objects,
|
||||
but coping with changes in ordering is not handled automatically.
|
||||
When appending a ``Bullet`` into ``Slide.bullets``, the ``Bullet.position``
|
||||
attribute will remain unset until manually assigned. When the ``Bullet``
|
||||
is inserted into the middle of the list, the following ``Bullet`` objects
|
||||
will also need to be renumbered.
|
||||
|
||||
The :class:`.OrderingList` object automates this task, managing the
|
||||
``position`` attribute on all ``Bullet`` objects in the collection. It is
|
||||
constructed using the :func:`.ordering_list` factory::
|
||||
|
||||
from sqlalchemy.ext.orderinglist import ordering_list
|
||||
|
||||
Base = declarative_base()
|
||||
|
||||
class Slide(Base):
|
||||
__tablename__ = 'slide'
|
||||
|
||||
id = Column(Integer, primary_key=True)
|
||||
name = Column(String)
|
||||
|
||||
bullets = relationship("Bullet", order_by="Bullet.position",
|
||||
collection_class=ordering_list('position'))
|
||||
|
||||
class Bullet(Base):
|
||||
__tablename__ = 'bullet'
|
||||
id = Column(Integer, primary_key=True)
|
||||
slide_id = Column(Integer, ForeignKey('slide.id'))
|
||||
position = Column(Integer)
|
||||
text = Column(String)
|
||||
|
||||
With the above mapping the ``Bullet.position`` attribute is managed::
|
||||
|
||||
s = Slide()
|
||||
s.bullets.append(Bullet())
|
||||
s.bullets.append(Bullet())
|
||||
s.bullets[1].position
|
||||
>>> 1
|
||||
s.bullets.insert(1, Bullet())
|
||||
s.bullets[2].position
|
||||
>>> 2
|
||||
|
||||
The :class:`.OrderingList` construct only works with **changes** to a
|
||||
collection, and not the initial load from the database, and requires that the
|
||||
list be sorted when loaded. Therefore, be sure to specify ``order_by`` on the
|
||||
:func:`_orm.relationship` against the target ordering attribute, so that the
|
||||
ordering is correct when first loaded.
|
||||
|
||||
.. warning::
|
||||
|
||||
:class:`.OrderingList` only provides limited functionality when a primary
|
||||
key column or unique column is the target of the sort. Operations
|
||||
that are unsupported or are problematic include:
|
||||
|
||||
* two entries must trade values. This is not supported directly in the
|
||||
case of a primary key or unique constraint because it means at least
|
||||
one row would need to be temporarily removed first, or changed to
|
||||
a third, neutral value while the switch occurs.
|
||||
|
||||
* an entry must be deleted in order to make room for a new entry.
|
||||
SQLAlchemy's unit of work performs all INSERTs before DELETEs within a
|
||||
single flush. In the case of a primary key, it will trade
|
||||
an INSERT/DELETE of the same primary key for an UPDATE statement in order
|
||||
to lessen the impact of this limitation, however this does not take place
|
||||
for a UNIQUE column.
|
||||
A future feature will allow the "DELETE before INSERT" behavior to be
|
||||
possible, alleviating this limitation, though this feature will require
|
||||
explicit configuration at the mapper level for sets of columns that
|
||||
are to be handled in this way.
|
||||
|
||||
:func:`.ordering_list` takes the name of the related object's ordering
|
||||
attribute as an argument. By default, the zero-based integer index of the
|
||||
object's position in the :func:`.ordering_list` is synchronized with the
|
||||
ordering attribute: index 0 will get position 0, index 1 position 1, etc. To
|
||||
start numbering at 1 or some other integer, provide ``count_from=1``.
|
||||
|
||||
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Callable
|
||||
from typing import List
|
||||
from typing import Optional
|
||||
from typing import Sequence
|
||||
from typing import TypeVar
|
||||
|
||||
from ..orm.collections import collection
|
||||
from ..orm.collections import collection_adapter
|
||||
|
||||
_T = TypeVar("_T")
|
||||
OrderingFunc = Callable[[int, Sequence[_T]], int]
|
||||
|
||||
|
||||
__all__ = ["ordering_list"]
|
||||
|
||||
|
||||
def ordering_list(
|
||||
attr: str,
|
||||
count_from: Optional[int] = None,
|
||||
ordering_func: Optional[OrderingFunc] = None,
|
||||
reorder_on_append: bool = False,
|
||||
) -> Callable[[], OrderingList]:
|
||||
"""Prepares an :class:`OrderingList` factory for use in mapper definitions.
|
||||
|
||||
Returns an object suitable for use as an argument to a Mapper
|
||||
relationship's ``collection_class`` option. e.g.::
|
||||
|
||||
from sqlalchemy.ext.orderinglist import ordering_list
|
||||
|
||||
class Slide(Base):
|
||||
__tablename__ = 'slide'
|
||||
|
||||
id = Column(Integer, primary_key=True)
|
||||
name = Column(String)
|
||||
|
||||
bullets = relationship("Bullet", order_by="Bullet.position",
|
||||
collection_class=ordering_list('position'))
|
||||
|
||||
:param attr:
|
||||
Name of the mapped attribute to use for storage and retrieval of
|
||||
ordering information
|
||||
|
||||
:param count_from:
|
||||
Set up an integer-based ordering, starting at ``count_from``. For
|
||||
example, ``ordering_list('pos', count_from=1)`` would create a 1-based
|
||||
list in SQL, storing the value in the 'pos' column. Ignored if
|
||||
``ordering_func`` is supplied.
|
||||
|
||||
Additional arguments are passed to the :class:`.OrderingList` constructor.
|
||||
|
||||
"""
|
||||
|
||||
kw = _unsugar_count_from(
|
||||
count_from=count_from,
|
||||
ordering_func=ordering_func,
|
||||
reorder_on_append=reorder_on_append,
|
||||
)
|
||||
return lambda: OrderingList(attr, **kw)
|
||||
|
||||
|
||||
# Ordering utility functions
|
||||
|
||||
|
||||
def count_from_0(index, collection):
|
||||
"""Numbering function: consecutive integers starting at 0."""
|
||||
|
||||
return index
|
||||
|
||||
|
||||
def count_from_1(index, collection):
|
||||
"""Numbering function: consecutive integers starting at 1."""
|
||||
|
||||
return index + 1
|
||||
|
||||
|
||||
def count_from_n_factory(start):
|
||||
"""Numbering function: consecutive integers starting at arbitrary start."""
|
||||
|
||||
def f(index, collection):
|
||||
return index + start
|
||||
|
||||
try:
|
||||
f.__name__ = "count_from_%i" % start
|
||||
except TypeError:
|
||||
pass
|
||||
return f
|
||||
|
||||
|
||||
def _unsugar_count_from(**kw):
|
||||
"""Builds counting functions from keyword arguments.
|
||||
|
||||
Keyword argument filter, prepares a simple ``ordering_func`` from a
|
||||
``count_from`` argument, otherwise passes ``ordering_func`` on unchanged.
|
||||
"""
|
||||
|
||||
count_from = kw.pop("count_from", None)
|
||||
if kw.get("ordering_func", None) is None and count_from is not None:
|
||||
if count_from == 0:
|
||||
kw["ordering_func"] = count_from_0
|
||||
elif count_from == 1:
|
||||
kw["ordering_func"] = count_from_1
|
||||
else:
|
||||
kw["ordering_func"] = count_from_n_factory(count_from)
|
||||
return kw
|
||||
|
||||
|
||||
class OrderingList(List[_T]):
|
||||
"""A custom list that manages position information for its children.
|
||||
|
||||
The :class:`.OrderingList` object is normally set up using the
|
||||
:func:`.ordering_list` factory function, used in conjunction with
|
||||
the :func:`_orm.relationship` function.
|
||||
|
||||
"""
|
||||
|
||||
ordering_attr: str
|
||||
ordering_func: OrderingFunc
|
||||
reorder_on_append: bool
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
ordering_attr: Optional[str] = None,
|
||||
ordering_func: Optional[OrderingFunc] = None,
|
||||
reorder_on_append: bool = False,
|
||||
):
|
||||
"""A custom list that manages position information for its children.
|
||||
|
||||
``OrderingList`` is a ``collection_class`` list implementation that
|
||||
syncs position in a Python list with a position attribute on the
|
||||
mapped objects.
|
||||
|
||||
This implementation relies on the list starting in the proper order,
|
||||
so be **sure** to put an ``order_by`` on your relationship.
|
||||
|
||||
:param ordering_attr:
|
||||
Name of the attribute that stores the object's order in the
|
||||
relationship.
|
||||
|
||||
:param ordering_func: Optional. A function that maps the position in
|
||||
the Python list to a value to store in the
|
||||
``ordering_attr``. Values returned are usually (but need not be!)
|
||||
integers.
|
||||
|
||||
An ``ordering_func`` is called with two positional parameters: the
|
||||
index of the element in the list, and the list itself.
|
||||
|
||||
If omitted, Python list indexes are used for the attribute values.
|
||||
Two basic pre-built numbering functions are provided in this module:
|
||||
``count_from_0`` and ``count_from_1``. For more exotic examples
|
||||
like stepped numbering, alphabetical and Fibonacci numbering, see
|
||||
the unit tests.
|
||||
|
||||
:param reorder_on_append:
|
||||
Default False. When appending an object with an existing (non-None)
|
||||
ordering value, that value will be left untouched unless
|
||||
``reorder_on_append`` is true. This is an optimization to avoid a
|
||||
variety of dangerous unexpected database writes.
|
||||
|
||||
SQLAlchemy will add instances to the list via append() when your
|
||||
object loads. If for some reason the result set from the database
|
||||
skips a step in the ordering (say, row '1' is missing but you get
|
||||
'2', '3', and '4'), reorder_on_append=True would immediately
|
||||
renumber the items to '1', '2', '3'. If you have multiple sessions
|
||||
making changes, any of whom happen to load this collection even in
|
||||
passing, all of the sessions would try to "clean up" the numbering
|
||||
in their commits, possibly causing all but one to fail with a
|
||||
concurrent modification error.
|
||||
|
||||
Recommend leaving this with the default of False, and just call
|
||||
``reorder()`` if you're doing ``append()`` operations with
|
||||
previously ordered instances or when doing some housekeeping after
|
||||
manual sql operations.
|
||||
|
||||
"""
|
||||
self.ordering_attr = ordering_attr
|
||||
if ordering_func is None:
|
||||
ordering_func = count_from_0
|
||||
self.ordering_func = ordering_func
|
||||
self.reorder_on_append = reorder_on_append
|
||||
|
||||
# More complex serialization schemes (multi column, e.g.) are possible by
|
||||
# subclassing and reimplementing these two methods.
|
||||
def _get_order_value(self, entity):
|
||||
return getattr(entity, self.ordering_attr)
|
||||
|
||||
def _set_order_value(self, entity, value):
|
||||
setattr(entity, self.ordering_attr, value)
|
||||
|
||||
def reorder(self) -> None:
|
||||
"""Synchronize ordering for the entire collection.
|
||||
|
||||
Sweeps through the list and ensures that each object has accurate
|
||||
ordering information set.
|
||||
|
||||
"""
|
||||
for index, entity in enumerate(self):
|
||||
self._order_entity(index, entity, True)
|
||||
|
||||
# As of 0.5, _reorder is no longer semi-private
|
||||
_reorder = reorder
|
||||
|
||||
def _order_entity(self, index, entity, reorder=True):
|
||||
have = self._get_order_value(entity)
|
||||
|
||||
# Don't disturb existing ordering if reorder is False
|
||||
if have is not None and not reorder:
|
||||
return
|
||||
|
||||
should_be = self.ordering_func(index, self)
|
||||
if have != should_be:
|
||||
self._set_order_value(entity, should_be)
|
||||
|
||||
def append(self, entity):
|
||||
super().append(entity)
|
||||
self._order_entity(len(self) - 1, entity, self.reorder_on_append)
|
||||
|
||||
def _raw_append(self, entity):
|
||||
"""Append without any ordering behavior."""
|
||||
|
||||
super().append(entity)
|
||||
|
||||
_raw_append = collection.adds(1)(_raw_append)
|
||||
|
||||
def insert(self, index, entity):
|
||||
super().insert(index, entity)
|
||||
self._reorder()
|
||||
|
||||
def remove(self, entity):
|
||||
super().remove(entity)
|
||||
|
||||
adapter = collection_adapter(self)
|
||||
if adapter and adapter._referenced_by_owner:
|
||||
self._reorder()
|
||||
|
||||
def pop(self, index=-1):
|
||||
entity = super().pop(index)
|
||||
self._reorder()
|
||||
return entity
|
||||
|
||||
def __setitem__(self, index, entity):
|
||||
if isinstance(index, slice):
|
||||
step = index.step or 1
|
||||
start = index.start or 0
|
||||
if start < 0:
|
||||
start += len(self)
|
||||
stop = index.stop or len(self)
|
||||
if stop < 0:
|
||||
stop += len(self)
|
||||
|
||||
for i in range(start, stop, step):
|
||||
self.__setitem__(i, entity[i])
|
||||
else:
|
||||
self._order_entity(index, entity, True)
|
||||
super().__setitem__(index, entity)
|
||||
|
||||
def __delitem__(self, index):
|
||||
super().__delitem__(index)
|
||||
self._reorder()
|
||||
|
||||
def __setslice__(self, start, end, values):
|
||||
super().__setslice__(start, end, values)
|
||||
self._reorder()
|
||||
|
||||
def __delslice__(self, start, end):
|
||||
super().__delslice__(start, end)
|
||||
self._reorder()
|
||||
|
||||
def __reduce__(self):
|
||||
return _reconstitute, (self.__class__, self.__dict__, list(self))
|
||||
|
||||
for func_name, func in list(locals().items()):
|
||||
if (
|
||||
callable(func)
|
||||
and func.__name__ == func_name
|
||||
and not func.__doc__
|
||||
and hasattr(list, func_name)
|
||||
):
|
||||
func.__doc__ = getattr(list, func_name).__doc__
|
||||
del func_name, func
|
||||
|
||||
|
||||
def _reconstitute(cls, dict_, items):
|
||||
"""Reconstitute an :class:`.OrderingList`.
|
||||
|
||||
This is the adjoint to :meth:`.OrderingList.__reduce__`. It is used for
|
||||
unpickling :class:`.OrderingList` objects.
|
||||
|
||||
"""
|
||||
obj = cls.__new__(cls)
|
||||
obj.__dict__.update(dict_)
|
||||
list.extend(obj, items)
|
||||
return obj
|
|
@ -0,0 +1,185 @@
|
|||
# ext/serializer.py
|
||||
# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors
|
||||
# <see AUTHORS file>
|
||||
#
|
||||
# This module is part of SQLAlchemy and is released under
|
||||
# the MIT License: https://www.opensource.org/licenses/mit-license.php
|
||||
# mypy: ignore-errors
|
||||
|
||||
"""Serializer/Deserializer objects for usage with SQLAlchemy query structures,
|
||||
allowing "contextual" deserialization.
|
||||
|
||||
.. legacy::
|
||||
|
||||
The serializer extension is **legacy** and should not be used for
|
||||
new development.
|
||||
|
||||
Any SQLAlchemy query structure, either based on sqlalchemy.sql.*
|
||||
or sqlalchemy.orm.* can be used. The mappers, Tables, Columns, Session
|
||||
etc. which are referenced by the structure are not persisted in serialized
|
||||
form, but are instead re-associated with the query structure
|
||||
when it is deserialized.
|
||||
|
||||
.. warning:: The serializer extension uses pickle to serialize and
|
||||
deserialize objects, so the same security consideration mentioned
|
||||
in the `python documentation
|
||||
<https://docs.python.org/3/library/pickle.html>`_ apply.
|
||||
|
||||
Usage is nearly the same as that of the standard Python pickle module::
|
||||
|
||||
from sqlalchemy.ext.serializer import loads, dumps
|
||||
metadata = MetaData(bind=some_engine)
|
||||
Session = scoped_session(sessionmaker())
|
||||
|
||||
# ... define mappers
|
||||
|
||||
query = Session.query(MyClass).
|
||||
filter(MyClass.somedata=='foo').order_by(MyClass.sortkey)
|
||||
|
||||
# pickle the query
|
||||
serialized = dumps(query)
|
||||
|
||||
# unpickle. Pass in metadata + scoped_session
|
||||
query2 = loads(serialized, metadata, Session)
|
||||
|
||||
print query2.all()
|
||||
|
||||
Similar restrictions as when using raw pickle apply; mapped classes must be
|
||||
themselves be pickleable, meaning they are importable from a module-level
|
||||
namespace.
|
||||
|
||||
The serializer module is only appropriate for query structures. It is not
|
||||
needed for:
|
||||
|
||||
* instances of user-defined classes. These contain no references to engines,
|
||||
sessions or expression constructs in the typical case and can be serialized
|
||||
directly.
|
||||
|
||||
* Table metadata that is to be loaded entirely from the serialized structure
|
||||
(i.e. is not already declared in the application). Regular
|
||||
pickle.loads()/dumps() can be used to fully dump any ``MetaData`` object,
|
||||
typically one which was reflected from an existing database at some previous
|
||||
point in time. The serializer module is specifically for the opposite case,
|
||||
where the Table metadata is already present in memory.
|
||||
|
||||
"""
|
||||
|
||||
from io import BytesIO
|
||||
import pickle
|
||||
import re
|
||||
|
||||
from .. import Column
|
||||
from .. import Table
|
||||
from ..engine import Engine
|
||||
from ..orm import class_mapper
|
||||
from ..orm.interfaces import MapperProperty
|
||||
from ..orm.mapper import Mapper
|
||||
from ..orm.session import Session
|
||||
from ..util import b64decode
|
||||
from ..util import b64encode
|
||||
|
||||
|
||||
__all__ = ["Serializer", "Deserializer", "dumps", "loads"]
|
||||
|
||||
|
||||
def Serializer(*args, **kw):
|
||||
pickler = pickle.Pickler(*args, **kw)
|
||||
|
||||
def persistent_id(obj):
|
||||
# print "serializing:", repr(obj)
|
||||
if isinstance(obj, Mapper) and not obj.non_primary:
|
||||
id_ = "mapper:" + b64encode(pickle.dumps(obj.class_))
|
||||
elif isinstance(obj, MapperProperty) and not obj.parent.non_primary:
|
||||
id_ = (
|
||||
"mapperprop:"
|
||||
+ b64encode(pickle.dumps(obj.parent.class_))
|
||||
+ ":"
|
||||
+ obj.key
|
||||
)
|
||||
elif isinstance(obj, Table):
|
||||
if "parententity" in obj._annotations:
|
||||
id_ = "mapper_selectable:" + b64encode(
|
||||
pickle.dumps(obj._annotations["parententity"].class_)
|
||||
)
|
||||
else:
|
||||
id_ = f"table:{obj.key}"
|
||||
elif isinstance(obj, Column) and isinstance(obj.table, Table):
|
||||
id_ = f"column:{obj.table.key}:{obj.key}"
|
||||
elif isinstance(obj, Session):
|
||||
id_ = "session:"
|
||||
elif isinstance(obj, Engine):
|
||||
id_ = "engine:"
|
||||
else:
|
||||
return None
|
||||
return id_
|
||||
|
||||
pickler.persistent_id = persistent_id
|
||||
return pickler
|
||||
|
||||
|
||||
our_ids = re.compile(
|
||||
r"(mapperprop|mapper|mapper_selectable|table|column|"
|
||||
r"session|attribute|engine):(.*)"
|
||||
)
|
||||
|
||||
|
||||
def Deserializer(file, metadata=None, scoped_session=None, engine=None):
|
||||
unpickler = pickle.Unpickler(file)
|
||||
|
||||
def get_engine():
|
||||
if engine:
|
||||
return engine
|
||||
elif scoped_session and scoped_session().bind:
|
||||
return scoped_session().bind
|
||||
elif metadata and metadata.bind:
|
||||
return metadata.bind
|
||||
else:
|
||||
return None
|
||||
|
||||
def persistent_load(id_):
|
||||
m = our_ids.match(str(id_))
|
||||
if not m:
|
||||
return None
|
||||
else:
|
||||
type_, args = m.group(1, 2)
|
||||
if type_ == "attribute":
|
||||
key, clsarg = args.split(":")
|
||||
cls = pickle.loads(b64decode(clsarg))
|
||||
return getattr(cls, key)
|
||||
elif type_ == "mapper":
|
||||
cls = pickle.loads(b64decode(args))
|
||||
return class_mapper(cls)
|
||||
elif type_ == "mapper_selectable":
|
||||
cls = pickle.loads(b64decode(args))
|
||||
return class_mapper(cls).__clause_element__()
|
||||
elif type_ == "mapperprop":
|
||||
mapper, keyname = args.split(":")
|
||||
cls = pickle.loads(b64decode(mapper))
|
||||
return class_mapper(cls).attrs[keyname]
|
||||
elif type_ == "table":
|
||||
return metadata.tables[args]
|
||||
elif type_ == "column":
|
||||
table, colname = args.split(":")
|
||||
return metadata.tables[table].c[colname]
|
||||
elif type_ == "session":
|
||||
return scoped_session()
|
||||
elif type_ == "engine":
|
||||
return get_engine()
|
||||
else:
|
||||
raise Exception("Unknown token: %s" % type_)
|
||||
|
||||
unpickler.persistent_load = persistent_load
|
||||
return unpickler
|
||||
|
||||
|
||||
def dumps(obj, protocol=pickle.HIGHEST_PROTOCOL):
|
||||
buf = BytesIO()
|
||||
pickler = Serializer(buf, protocol)
|
||||
pickler.dump(obj)
|
||||
return buf.getvalue()
|
||||
|
||||
|
||||
def loads(data, metadata=None, scoped_session=None, engine=None):
|
||||
buf = BytesIO(data)
|
||||
unpickler = Deserializer(buf, metadata, scoped_session, engine)
|
||||
return unpickler.load()
|
Loading…
Add table
Add a link
Reference in a new issue