Cleaned up the directories

ComputerTech312 2024-02-19 15:34:25 +01:00
parent f708506d68
commit a683fcffea
1340 changed files with 554582 additions and 6840 deletions

@@ -0,0 +1,294 @@
# __init__.py
# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
from __future__ import annotations
from typing import Any
from . import util as _util
from .engine import AdaptedConnection as AdaptedConnection
from .engine import BaseRow as BaseRow
from .engine import BindTyping as BindTyping
from .engine import ChunkedIteratorResult as ChunkedIteratorResult
from .engine import Compiled as Compiled
from .engine import Connection as Connection
from .engine import create_engine as create_engine
from .engine import create_mock_engine as create_mock_engine
from .engine import create_pool_from_url as create_pool_from_url
from .engine import CreateEnginePlugin as CreateEnginePlugin
from .engine import CursorResult as CursorResult
from .engine import Dialect as Dialect
from .engine import Engine as Engine
from .engine import engine_from_config as engine_from_config
from .engine import ExceptionContext as ExceptionContext
from .engine import ExecutionContext as ExecutionContext
from .engine import FrozenResult as FrozenResult
from .engine import Inspector as Inspector
from .engine import IteratorResult as IteratorResult
from .engine import make_url as make_url
from .engine import MappingResult as MappingResult
from .engine import MergedResult as MergedResult
from .engine import NestedTransaction as NestedTransaction
from .engine import Result as Result
from .engine import result_tuple as result_tuple
from .engine import ResultProxy as ResultProxy
from .engine import RootTransaction as RootTransaction
from .engine import Row as Row
from .engine import RowMapping as RowMapping
from .engine import ScalarResult as ScalarResult
from .engine import Transaction as Transaction
from .engine import TwoPhaseTransaction as TwoPhaseTransaction
from .engine import TypeCompiler as TypeCompiler
from .engine import URL as URL
from .inspection import inspect as inspect
from .pool import AssertionPool as AssertionPool
from .pool import AsyncAdaptedQueuePool as AsyncAdaptedQueuePool
from .pool import (
FallbackAsyncAdaptedQueuePool as FallbackAsyncAdaptedQueuePool,
)
from .pool import NullPool as NullPool
from .pool import Pool as Pool
from .pool import PoolProxiedConnection as PoolProxiedConnection
from .pool import PoolResetState as PoolResetState
from .pool import QueuePool as QueuePool
from .pool import SingletonThreadPool as SingletonThreadPool
from .pool import StaticPool as StaticPool
from .schema import BaseDDLElement as BaseDDLElement
from .schema import BLANK_SCHEMA as BLANK_SCHEMA
from .schema import CheckConstraint as CheckConstraint
from .schema import Column as Column
from .schema import ColumnDefault as ColumnDefault
from .schema import Computed as Computed
from .schema import Constraint as Constraint
from .schema import DDL as DDL
from .schema import DDLElement as DDLElement
from .schema import DefaultClause as DefaultClause
from .schema import ExecutableDDLElement as ExecutableDDLElement
from .schema import FetchedValue as FetchedValue
from .schema import ForeignKey as ForeignKey
from .schema import ForeignKeyConstraint as ForeignKeyConstraint
from .schema import Identity as Identity
from .schema import Index as Index
from .schema import insert_sentinel as insert_sentinel
from .schema import MetaData as MetaData
from .schema import PrimaryKeyConstraint as PrimaryKeyConstraint
from .schema import Sequence as Sequence
from .schema import Table as Table
from .schema import UniqueConstraint as UniqueConstraint
from .sql import ColumnExpressionArgument as ColumnExpressionArgument
from .sql import NotNullable as NotNullable
from .sql import Nullable as Nullable
from .sql import SelectLabelStyle as SelectLabelStyle
from .sql.expression import Alias as Alias
from .sql.expression import alias as alias
from .sql.expression import AliasedReturnsRows as AliasedReturnsRows
from .sql.expression import all_ as all_
from .sql.expression import and_ as and_
from .sql.expression import any_ as any_
from .sql.expression import asc as asc
from .sql.expression import between as between
from .sql.expression import BinaryExpression as BinaryExpression
from .sql.expression import bindparam as bindparam
from .sql.expression import BindParameter as BindParameter
from .sql.expression import bitwise_not as bitwise_not
from .sql.expression import BooleanClauseList as BooleanClauseList
from .sql.expression import CacheKey as CacheKey
from .sql.expression import Case as Case
from .sql.expression import case as case
from .sql.expression import Cast as Cast
from .sql.expression import cast as cast
from .sql.expression import ClauseElement as ClauseElement
from .sql.expression import ClauseList as ClauseList
from .sql.expression import collate as collate
from .sql.expression import CollectionAggregate as CollectionAggregate
from .sql.expression import column as column
from .sql.expression import ColumnClause as ColumnClause
from .sql.expression import ColumnCollection as ColumnCollection
from .sql.expression import ColumnElement as ColumnElement
from .sql.expression import ColumnOperators as ColumnOperators
from .sql.expression import CompoundSelect as CompoundSelect
from .sql.expression import CTE as CTE
from .sql.expression import cte as cte
from .sql.expression import custom_op as custom_op
from .sql.expression import Delete as Delete
from .sql.expression import delete as delete
from .sql.expression import desc as desc
from .sql.expression import distinct as distinct
from .sql.expression import except_ as except_
from .sql.expression import except_all as except_all
from .sql.expression import Executable as Executable
from .sql.expression import Exists as Exists
from .sql.expression import exists as exists
from .sql.expression import Extract as Extract
from .sql.expression import extract as extract
from .sql.expression import false as false
from .sql.expression import False_ as False_
from .sql.expression import FromClause as FromClause
from .sql.expression import FromGrouping as FromGrouping
from .sql.expression import func as func
from .sql.expression import funcfilter as funcfilter
from .sql.expression import Function as Function
from .sql.expression import FunctionElement as FunctionElement
from .sql.expression import FunctionFilter as FunctionFilter
from .sql.expression import GenerativeSelect as GenerativeSelect
from .sql.expression import Grouping as Grouping
from .sql.expression import HasCTE as HasCTE
from .sql.expression import HasPrefixes as HasPrefixes
from .sql.expression import HasSuffixes as HasSuffixes
from .sql.expression import Insert as Insert
from .sql.expression import insert as insert
from .sql.expression import intersect as intersect
from .sql.expression import intersect_all as intersect_all
from .sql.expression import Join as Join
from .sql.expression import join as join
from .sql.expression import Label as Label
from .sql.expression import label as label
from .sql.expression import LABEL_STYLE_DEFAULT as LABEL_STYLE_DEFAULT
from .sql.expression import (
LABEL_STYLE_DISAMBIGUATE_ONLY as LABEL_STYLE_DISAMBIGUATE_ONLY,
)
from .sql.expression import LABEL_STYLE_NONE as LABEL_STYLE_NONE
from .sql.expression import (
LABEL_STYLE_TABLENAME_PLUS_COL as LABEL_STYLE_TABLENAME_PLUS_COL,
)
from .sql.expression import lambda_stmt as lambda_stmt
from .sql.expression import LambdaElement as LambdaElement
from .sql.expression import Lateral as Lateral
from .sql.expression import lateral as lateral
from .sql.expression import literal as literal
from .sql.expression import literal_column as literal_column
from .sql.expression import modifier as modifier
from .sql.expression import not_ as not_
from .sql.expression import Null as Null
from .sql.expression import null as null
from .sql.expression import nulls_first as nulls_first
from .sql.expression import nulls_last as nulls_last
from .sql.expression import nullsfirst as nullsfirst
from .sql.expression import nullslast as nullslast
from .sql.expression import Operators as Operators
from .sql.expression import or_ as or_
from .sql.expression import outerjoin as outerjoin
from .sql.expression import outparam as outparam
from .sql.expression import Over as Over
from .sql.expression import over as over
from .sql.expression import quoted_name as quoted_name
from .sql.expression import ReleaseSavepointClause as ReleaseSavepointClause
from .sql.expression import ReturnsRows as ReturnsRows
from .sql.expression import (
RollbackToSavepointClause as RollbackToSavepointClause,
)
from .sql.expression import SavepointClause as SavepointClause
from .sql.expression import ScalarSelect as ScalarSelect
from .sql.expression import Select as Select
from .sql.expression import select as select
from .sql.expression import Selectable as Selectable
from .sql.expression import SelectBase as SelectBase
from .sql.expression import SQLColumnExpression as SQLColumnExpression
from .sql.expression import StatementLambdaElement as StatementLambdaElement
from .sql.expression import Subquery as Subquery
from .sql.expression import table as table
from .sql.expression import TableClause as TableClause
from .sql.expression import TableSample as TableSample
from .sql.expression import tablesample as tablesample
from .sql.expression import TableValuedAlias as TableValuedAlias
from .sql.expression import text as text
from .sql.expression import TextAsFrom as TextAsFrom
from .sql.expression import TextClause as TextClause
from .sql.expression import TextualSelect as TextualSelect
from .sql.expression import true as true
from .sql.expression import True_ as True_
from .sql.expression import try_cast as try_cast
from .sql.expression import TryCast as TryCast
from .sql.expression import Tuple as Tuple
from .sql.expression import tuple_ as tuple_
from .sql.expression import type_coerce as type_coerce
from .sql.expression import TypeClause as TypeClause
from .sql.expression import TypeCoerce as TypeCoerce
from .sql.expression import UnaryExpression as UnaryExpression
from .sql.expression import union as union
from .sql.expression import union_all as union_all
from .sql.expression import Update as Update
from .sql.expression import update as update
from .sql.expression import UpdateBase as UpdateBase
from .sql.expression import Values as Values
from .sql.expression import values as values
from .sql.expression import ValuesBase as ValuesBase
from .sql.expression import Visitable as Visitable
from .sql.expression import within_group as within_group
from .sql.expression import WithinGroup as WithinGroup
from .types import ARRAY as ARRAY
from .types import BIGINT as BIGINT
from .types import BigInteger as BigInteger
from .types import BINARY as BINARY
from .types import BLOB as BLOB
from .types import BOOLEAN as BOOLEAN
from .types import Boolean as Boolean
from .types import CHAR as CHAR
from .types import CLOB as CLOB
from .types import DATE as DATE
from .types import Date as Date
from .types import DATETIME as DATETIME
from .types import DateTime as DateTime
from .types import DECIMAL as DECIMAL
from .types import DOUBLE as DOUBLE
from .types import Double as Double
from .types import DOUBLE_PRECISION as DOUBLE_PRECISION
from .types import Enum as Enum
from .types import FLOAT as FLOAT
from .types import Float as Float
from .types import INT as INT
from .types import INTEGER as INTEGER
from .types import Integer as Integer
from .types import Interval as Interval
from .types import JSON as JSON
from .types import LargeBinary as LargeBinary
from .types import NCHAR as NCHAR
from .types import NUMERIC as NUMERIC
from .types import Numeric as Numeric
from .types import NVARCHAR as NVARCHAR
from .types import PickleType as PickleType
from .types import REAL as REAL
from .types import SMALLINT as SMALLINT
from .types import SmallInteger as SmallInteger
from .types import String as String
from .types import TEXT as TEXT
from .types import Text as Text
from .types import TIME as TIME
from .types import Time as Time
from .types import TIMESTAMP as TIMESTAMP
from .types import TupleType as TupleType
from .types import TypeDecorator as TypeDecorator
from .types import Unicode as Unicode
from .types import UnicodeText as UnicodeText
from .types import UUID as UUID
from .types import Uuid as Uuid
from .types import VARBINARY as VARBINARY
from .types import VARCHAR as VARCHAR
__version__ = "2.0.27"
def __go(lcls: Any) -> None:
_util.preloaded.import_prefix("sqlalchemy")
from . import exc
exc._version_token = "".join(__version__.split(".")[0:2])
__go(locals())
def __getattr__(name: str) -> Any:
if name == "SingleonThreadPool":
_util.warn_deprecated(
"SingleonThreadPool was a typo in the v2 series. "
"Please use the correct SingletonThreadPool name.",
"2.0.24",
)
return SingletonThreadPool
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
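
# Illustrative sketch (hypothetical, not part of the upstream file): the
# module-level __getattr__ above is a PEP 562 hook, so the misspelled legacy
# name still resolves while emitting a deprecation warning.  The helper name
# ``_demo_deprecated_alias`` is an assumption made for this example only.
def _demo_deprecated_alias() -> None:
    import warnings

    import sqlalchemy

    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        pool_cls = sqlalchemy.SingleonThreadPool  # deprecated misspelling

    assert pool_cls is sqlalchemy.SingletonThreadPool
    assert any("SingleonThreadPool" in str(w.message) for w in caught)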

@@ -0,0 +1,18 @@
# connectors/__init__.py
# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
from ..engine.interfaces import Dialect
class Connector(Dialect):
"""Base class for dialect mixins, for DBAPIs that work
across entirely different database backends.
Currently the only such mixin is pyodbc.
"""

@@ -0,0 +1,174 @@
# connectors/aioodbc.py
# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
# mypy: ignore-errors
from __future__ import annotations
from typing import TYPE_CHECKING
from .asyncio import AsyncAdapt_dbapi_connection
from .asyncio import AsyncAdapt_dbapi_cursor
from .asyncio import AsyncAdapt_dbapi_ss_cursor
from .asyncio import AsyncAdaptFallback_dbapi_connection
from .pyodbc import PyODBCConnector
from .. import pool
from .. import util
from ..util.concurrency import await_fallback
from ..util.concurrency import await_only
if TYPE_CHECKING:
from ..engine.interfaces import ConnectArgsType
from ..engine.url import URL
class AsyncAdapt_aioodbc_cursor(AsyncAdapt_dbapi_cursor):
__slots__ = ()
def setinputsizes(self, *inputsizes):
# see https://github.com/aio-libs/aioodbc/issues/451
return self._cursor._impl.setinputsizes(*inputsizes)
# how it's supposed to work
# return self.await_(self._cursor.setinputsizes(*inputsizes))
class AsyncAdapt_aioodbc_ss_cursor(
AsyncAdapt_aioodbc_cursor, AsyncAdapt_dbapi_ss_cursor
):
__slots__ = ()
class AsyncAdapt_aioodbc_connection(AsyncAdapt_dbapi_connection):
_cursor_cls = AsyncAdapt_aioodbc_cursor
_ss_cursor_cls = AsyncAdapt_aioodbc_ss_cursor
__slots__ = ()
@property
def autocommit(self):
return self._connection.autocommit
@autocommit.setter
def autocommit(self, value):
# https://github.com/aio-libs/aioodbc/issues/448
# self._connection.autocommit = value
self._connection._conn.autocommit = value
def cursor(self, server_side=False):
# aioodbc sets connection=None when closed and just fails with
# AttributeError here. Here we use the same ProgrammingError +
# message that pyodbc uses, so it triggers is_disconnect() as well.
if self._connection.closed:
raise self.dbapi.ProgrammingError(
"Attempt to use a closed connection."
)
return super().cursor(server_side=server_side)
def rollback(self):
# aioodbc sets connection=None when closed and just fails with
# AttributeError here. should be a no-op
if not self._connection.closed:
super().rollback()
def commit(self):
# aioodbc sets connection=None when closed and just fails with
# AttributeError here. should be a no-op
if not self._connection.closed:
super().commit()
def close(self):
# aioodbc sets connection=None when closed and just fails with
# AttributeError here. should be a no-op
if not self._connection.closed:
super().close()
class AsyncAdaptFallback_aioodbc_connection(
AsyncAdaptFallback_dbapi_connection, AsyncAdapt_aioodbc_connection
):
__slots__ = ()
class AsyncAdapt_aioodbc_dbapi:
def __init__(self, aioodbc, pyodbc):
self.aioodbc = aioodbc
self.pyodbc = pyodbc
self.paramstyle = pyodbc.paramstyle
self._init_dbapi_attributes()
self.Cursor = AsyncAdapt_dbapi_cursor
self.version = pyodbc.version
def _init_dbapi_attributes(self):
for name in (
"Warning",
"Error",
"InterfaceError",
"DataError",
"DatabaseError",
"OperationalError",
"InterfaceError",
"IntegrityError",
"ProgrammingError",
"InternalError",
"NotSupportedError",
"NUMBER",
"STRING",
"DATETIME",
"BINARY",
"Binary",
"BinaryNull",
"SQL_VARCHAR",
"SQL_WVARCHAR",
):
setattr(self, name, getattr(self.pyodbc, name))
def connect(self, *arg, **kw):
async_fallback = kw.pop("async_fallback", False)
creator_fn = kw.pop("async_creator_fn", self.aioodbc.connect)
if util.asbool(async_fallback):
return AsyncAdaptFallback_aioodbc_connection(
self,
await_fallback(creator_fn(*arg, **kw)),
)
else:
return AsyncAdapt_aioodbc_connection(
self,
await_only(creator_fn(*arg, **kw)),
)
class aiodbcConnector(PyODBCConnector):
is_async = True
supports_statement_cache = True
supports_server_side_cursors = True
@classmethod
def import_dbapi(cls):
return AsyncAdapt_aioodbc_dbapi(
__import__("aioodbc"), __import__("pyodbc")
)
def create_connect_args(self, url: URL) -> ConnectArgsType:
arg, kw = super().create_connect_args(url)
if arg and arg[0]:
kw["dsn"] = arg[0]
return (), kw
@classmethod
def get_pool_class(cls, url):
async_fallback = url.query.get("async_fallback", False)
if util.asbool(async_fallback):
return pool.FallbackAsyncAdaptedQueuePool
else:
return pool.AsyncAdaptedQueuePool
def get_driver_connection(self, connection):
return connection._connection
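
# A hypothetical usage sketch of the pool selection implemented in
# ``get_pool_class`` above: the ``async_fallback`` URL query parameter picks
# the greenlet-fallback pool, otherwise the plain asyncio-adapted pool is
# used.  The URLs are made up for illustration.
def _demo_pool_selection() -> None:
    from sqlalchemy import pool
    from sqlalchemy.engine import make_url

    plain = make_url("mssql+aioodbc://scott:tiger@some_dsn")
    fallback = make_url("mssql+aioodbc://scott:tiger@some_dsn?async_fallback=true")

    assert aiodbcConnector.get_pool_class(plain) is pool.AsyncAdaptedQueuePool
    assert (
        aiodbcConnector.get_pool_class(fallback)
        is pool.FallbackAsyncAdaptedQueuePool
    )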

@@ -0,0 +1,208 @@
# connectors/asyncio.py
# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
# mypy: ignore-errors
"""generic asyncio-adapted versions of DBAPI connection and cursor"""
from __future__ import annotations
import collections
import itertools
from ..engine import AdaptedConnection
from ..util.concurrency import asyncio
from ..util.concurrency import await_fallback
from ..util.concurrency import await_only
class AsyncAdapt_dbapi_cursor:
server_side = False
__slots__ = (
"_adapt_connection",
"_connection",
"await_",
"_cursor",
"_rows",
)
def __init__(self, adapt_connection):
self._adapt_connection = adapt_connection
self._connection = adapt_connection._connection
self.await_ = adapt_connection.await_
cursor = self._connection.cursor()
self._cursor = self._aenter_cursor(cursor)
self._rows = collections.deque()
def _aenter_cursor(self, cursor):
return self.await_(cursor.__aenter__())
@property
def description(self):
return self._cursor.description
@property
def rowcount(self):
return self._cursor.rowcount
@property
def arraysize(self):
return self._cursor.arraysize
@arraysize.setter
def arraysize(self, value):
self._cursor.arraysize = value
@property
def lastrowid(self):
return self._cursor.lastrowid
def close(self):
# note we aren't actually closing the cursor here,
# we are just letting GC do it. see notes in aiomysql dialect
self._rows.clear()
def execute(self, operation, parameters=None):
return self.await_(self._execute_async(operation, parameters))
def executemany(self, operation, seq_of_parameters):
return self.await_(
self._executemany_async(operation, seq_of_parameters)
)
async def _execute_async(self, operation, parameters):
async with self._adapt_connection._execute_mutex:
result = await self._cursor.execute(operation, parameters or ())
if self._cursor.description and not self.server_side:
self._rows = collections.deque(await self._cursor.fetchall())
return result
async def _executemany_async(self, operation, seq_of_parameters):
async with self._adapt_connection._execute_mutex:
return await self._cursor.executemany(operation, seq_of_parameters)
def nextset(self):
self.await_(self._cursor.nextset())
if self._cursor.description and not self.server_side:
self._rows = collections.deque(
self.await_(self._cursor.fetchall())
)
def setinputsizes(self, *inputsizes):
        # NOTE: this is overridden in aioodbc right now due to
        # https://github.com/aio-libs/aioodbc/issues/451
return self.await_(self._cursor.setinputsizes(*inputsizes))
def __iter__(self):
while self._rows:
yield self._rows.popleft()
def fetchone(self):
if self._rows:
return self._rows.popleft()
else:
return None
def fetchmany(self, size=None):
if size is None:
size = self.arraysize
rr = iter(self._rows)
retval = list(itertools.islice(rr, 0, size))
self._rows = collections.deque(rr)
return retval
def fetchall(self):
retval = list(self._rows)
self._rows.clear()
return retval
class AsyncAdapt_dbapi_ss_cursor(AsyncAdapt_dbapi_cursor):
__slots__ = ()
server_side = True
def __init__(self, adapt_connection):
self._adapt_connection = adapt_connection
self._connection = adapt_connection._connection
self.await_ = adapt_connection.await_
cursor = self._connection.cursor()
self._cursor = self.await_(cursor.__aenter__())
def close(self):
if self._cursor is not None:
self.await_(self._cursor.close())
self._cursor = None
def fetchone(self):
return self.await_(self._cursor.fetchone())
def fetchmany(self, size=None):
return self.await_(self._cursor.fetchmany(size=size))
def fetchall(self):
return self.await_(self._cursor.fetchall())
class AsyncAdapt_dbapi_connection(AdaptedConnection):
_cursor_cls = AsyncAdapt_dbapi_cursor
_ss_cursor_cls = AsyncAdapt_dbapi_ss_cursor
await_ = staticmethod(await_only)
__slots__ = ("dbapi", "_execute_mutex")
def __init__(self, dbapi, connection):
self.dbapi = dbapi
self._connection = connection
self._execute_mutex = asyncio.Lock()
def ping(self, reconnect):
return self.await_(self._connection.ping(reconnect))
def add_output_converter(self, *arg, **kw):
self._connection.add_output_converter(*arg, **kw)
def character_set_name(self):
return self._connection.character_set_name()
@property
def autocommit(self):
return self._connection.autocommit
@autocommit.setter
def autocommit(self, value):
# https://github.com/aio-libs/aioodbc/issues/448
# self._connection.autocommit = value
self._connection._conn.autocommit = value
def cursor(self, server_side=False):
if server_side:
return self._ss_cursor_cls(self)
else:
return self._cursor_cls(self)
def rollback(self):
self.await_(self._connection.rollback())
def commit(self):
self.await_(self._connection.commit())
def close(self):
self.await_(self._connection.close())
class AsyncAdaptFallback_dbapi_connection(AsyncAdapt_dbapi_connection):
__slots__ = ()
await_ = staticmethod(await_fallback)
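
# A standalone toy sketch (not SQLAlchemy API) of the buffering strategy used
# by AsyncAdapt_dbapi_cursor above: the full result set is fetched once into
# a deque, and fetchone()/fetchmany()/fetchall() are then served
# synchronously from that buffer, mirroring the islice/deque logic in
# ``fetchmany``.
def _demo_row_buffer() -> None:
    import collections
    import itertools

    rows = collections.deque([(1,), (2,), (3,), (4,)])

    def fetchmany(size):
        it = iter(rows)
        chunk = list(itertools.islice(it, 0, size))
        rest = collections.deque(it)
        rows.clear()
        rows.extend(rest)
        return chunk

    assert fetchmany(2) == [(1,), (2,)]
    assert rows.popleft() == (3,)
    assert list(rows) == [(4,)]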

@@ -0,0 +1,249 @@
# connectors/pyodbc.py
# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
from __future__ import annotations
import re
from types import ModuleType
import typing
from typing import Any
from typing import Dict
from typing import List
from typing import Optional
from typing import Tuple
from typing import Union
from urllib.parse import unquote_plus
from . import Connector
from .. import ExecutionContext
from .. import pool
from .. import util
from ..engine import ConnectArgsType
from ..engine import Connection
from ..engine import interfaces
from ..engine import URL
from ..sql.type_api import TypeEngine
if typing.TYPE_CHECKING:
from ..engine.interfaces import IsolationLevel
class PyODBCConnector(Connector):
driver = "pyodbc"
# this is no longer False for pyodbc in general
supports_sane_rowcount_returning = True
supports_sane_multi_rowcount = False
supports_native_decimal = True
default_paramstyle = "named"
fast_executemany = False
# for non-DSN connections, this *may* be used to
# hold the desired driver name
pyodbc_driver_name: Optional[str] = None
dbapi: ModuleType
def __init__(self, use_setinputsizes: bool = False, **kw: Any):
super().__init__(**kw)
if use_setinputsizes:
self.bind_typing = interfaces.BindTyping.SETINPUTSIZES
@classmethod
def import_dbapi(cls) -> ModuleType:
return __import__("pyodbc")
def create_connect_args(self, url: URL) -> ConnectArgsType:
opts = url.translate_connect_args(username="user")
opts.update(url.query)
keys = opts
query = url.query
connect_args: Dict[str, Any] = {}
connectors: List[str]
for param in ("ansi", "unicode_results", "autocommit"):
if param in keys:
connect_args[param] = util.asbool(keys.pop(param))
if "odbc_connect" in keys:
connectors = [unquote_plus(keys.pop("odbc_connect"))]
else:
def check_quote(token: str) -> str:
if ";" in str(token) or str(token).startswith("{"):
token = "{%s}" % token.replace("}", "}}")
return token
keys = {k: check_quote(v) for k, v in keys.items()}
dsn_connection = "dsn" in keys or (
"host" in keys and "database" not in keys
)
if dsn_connection:
connectors = [
"dsn=%s" % (keys.pop("host", "") or keys.pop("dsn", ""))
]
else:
port = ""
if "port" in keys and "port" not in query:
port = ",%d" % int(keys.pop("port"))
connectors = []
driver = keys.pop("driver", self.pyodbc_driver_name)
if driver is None and keys:
# note if keys is empty, this is a totally blank URL
util.warn(
"No driver name specified; "
"this is expected by PyODBC when using "
"DSN-less connections"
)
else:
connectors.append("DRIVER={%s}" % driver)
connectors.extend(
[
"Server=%s%s" % (keys.pop("host", ""), port),
"Database=%s" % keys.pop("database", ""),
]
)
user = keys.pop("user", None)
if user:
connectors.append("UID=%s" % user)
pwd = keys.pop("password", "")
if pwd:
connectors.append("PWD=%s" % pwd)
else:
authentication = keys.pop("authentication", None)
if authentication:
connectors.append("Authentication=%s" % authentication)
else:
connectors.append("Trusted_Connection=Yes")
# if set to 'Yes', the ODBC layer will try to automagically
# convert textual data from your database encoding to your
# client encoding. This should obviously be set to 'No' if
# you query a cp1253 encoded database from a latin1 client...
if "odbc_autotranslate" in keys:
connectors.append(
"AutoTranslate=%s" % keys.pop("odbc_autotranslate")
)
connectors.extend(["%s=%s" % (k, v) for k, v in keys.items()])
return ((";".join(connectors),), connect_args)
def is_disconnect(
self,
e: Exception,
connection: Optional[
Union[pool.PoolProxiedConnection, interfaces.DBAPIConnection]
],
cursor: Optional[interfaces.DBAPICursor],
) -> bool:
if isinstance(e, self.dbapi.ProgrammingError):
return "The cursor's connection has been closed." in str(
e
) or "Attempt to use a closed connection." in str(e)
else:
return False
def _dbapi_version(self) -> interfaces.VersionInfoType:
if not self.dbapi:
return ()
return self._parse_dbapi_version(self.dbapi.version)
def _parse_dbapi_version(self, vers: str) -> interfaces.VersionInfoType:
m = re.match(r"(?:py.*-)?([\d\.]+)(?:-(\w+))?", vers)
if not m:
return ()
vers_tuple: interfaces.VersionInfoType = tuple(
[int(x) for x in m.group(1).split(".")]
)
if m.group(2):
vers_tuple += (m.group(2),)
return vers_tuple
def _get_server_version_info(
self, connection: Connection
) -> interfaces.VersionInfoType:
# NOTE: this function is not reliable, particularly when
# freetds is in use. Implement database-specific server version
# queries.
dbapi_con = connection.connection.dbapi_connection
version: Tuple[Union[int, str], ...] = ()
r = re.compile(r"[.\-]")
for n in r.split(dbapi_con.getinfo(self.dbapi.SQL_DBMS_VER)): # type: ignore[union-attr] # noqa: E501
try:
version += (int(n),)
except ValueError:
pass
return tuple(version)
def do_set_input_sizes(
self,
cursor: interfaces.DBAPICursor,
list_of_tuples: List[Tuple[str, Any, TypeEngine[Any]]],
context: ExecutionContext,
) -> None:
        # the rules for these types seem a little strange, as you can pass
# non-tuples as well as tuples, however it seems to assume "0"
# for the subsequent values if you don't pass a tuple which fails
# for types such as pyodbc.SQL_WLONGVARCHAR, which is the datatype
# that ticket #5649 is targeting.
# NOTE: as of #6058, this won't be called if the use_setinputsizes
# parameter were not passed to the dialect, or if no types were
# specified in list_of_tuples
# as of #8177 for 2.0 we assume use_setinputsizes=True and only
# omit the setinputsizes calls for .executemany() with
# fast_executemany=True
if (
context.execute_style is interfaces.ExecuteStyle.EXECUTEMANY
and self.fast_executemany
):
return
cursor.setinputsizes(
[
(
(dbtype, None, None)
if not isinstance(dbtype, tuple)
else dbtype
)
for key, dbtype, sqltype in list_of_tuples
]
)
def get_isolation_level_values(
self, dbapi_connection: interfaces.DBAPIConnection
) -> List[IsolationLevel]:
return super().get_isolation_level_values(dbapi_connection) + [
"AUTOCOMMIT"
]
def set_isolation_level(
self,
dbapi_connection: interfaces.DBAPIConnection,
level: IsolationLevel,
) -> None:
# adjust for ConnectionFairy being present
# allows attribute set e.g. "connection.autocommit = True"
# to work properly
if level == "AUTOCOMMIT":
dbapi_connection.autocommit = True
else:
dbapi_connection.autocommit = False
super().set_isolation_level(dbapi_connection, level)
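
# Hypothetical sketch of what ``_parse_dbapi_version`` above produces: pyodbc
# version strings such as "4.0.39" or "py3-3.0.1-beta" are normalized into
# comparable tuples.  The standalone ``parse`` helper below just restates the
# same regex for illustration.
def _demo_parse_dbapi_version() -> None:
    import re

    def parse(vers: str):
        m = re.match(r"(?:py.*-)?([\d\.]+)(?:-(\w+))?", vers)
        if not m:
            return ()
        out = tuple(int(x) for x in m.group(1).split("."))
        if m.group(2):
            out += (m.group(2),)
        return out

    assert parse("4.0.39") == (4, 0, 39)
    assert parse("py3-3.0.1-beta") == (3, 0, 1, "beta")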

@@ -0,0 +1,6 @@
# cyextension/__init__.py
# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php

@@ -0,0 +1,409 @@
# cyextension/collections.pyx
# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
cimport cython
from cpython.long cimport PyLong_FromLongLong
from cpython.set cimport PySet_Add
from collections.abc import Collection
from itertools import filterfalse
cdef bint add_not_present(set seen, object item, hashfunc):
hash_value = hashfunc(item)
if hash_value not in seen:
PySet_Add(seen, hash_value)
return True
else:
return False
cdef list cunique_list(seq, hashfunc=None):
cdef set seen = set()
if not hashfunc:
return [x for x in seq if x not in seen and not PySet_Add(seen, x)]
else:
return [x for x in seq if add_not_present(seen, x, hashfunc)]
def unique_list(seq, hashfunc=None):
return cunique_list(seq, hashfunc)
cdef class OrderedSet(set):
cdef list _list
@classmethod
def __class_getitem__(cls, key):
return cls
def __init__(self, d=None):
set.__init__(self)
if d is not None:
self._list = cunique_list(d)
set.update(self, self._list)
else:
self._list = []
cpdef OrderedSet copy(self):
cdef OrderedSet cp = OrderedSet.__new__(OrderedSet)
cp._list = list(self._list)
set.update(cp, cp._list)
return cp
@cython.final
cdef OrderedSet _from_list(self, list new_list):
cdef OrderedSet new = OrderedSet.__new__(OrderedSet)
new._list = new_list
set.update(new, new_list)
return new
def add(self, element):
if element not in self:
self._list.append(element)
PySet_Add(self, element)
def remove(self, element):
# set.remove will raise if element is not in self
set.remove(self, element)
self._list.remove(element)
def pop(self):
try:
value = self._list.pop()
except IndexError:
raise KeyError("pop from an empty set") from None
set.remove(self, value)
return value
def insert(self, Py_ssize_t pos, element):
if element not in self:
self._list.insert(pos, element)
PySet_Add(self, element)
def discard(self, element):
if element in self:
set.remove(self, element)
self._list.remove(element)
def clear(self):
set.clear(self)
self._list = []
def __getitem__(self, key):
return self._list[key]
def __iter__(self):
return iter(self._list)
def __add__(self, other):
return self.union(other)
def __repr__(self):
return "%s(%r)" % (self.__class__.__name__, self._list)
__str__ = __repr__
def update(self, *iterables):
for iterable in iterables:
for e in iterable:
if e not in self:
self._list.append(e)
set.add(self, e)
def __ior__(self, iterable):
self.update(iterable)
return self
def union(self, *other):
result = self.copy()
result.update(*other)
return result
def __or__(self, other):
return self.union(other)
def intersection(self, *other):
cdef set other_set = set.intersection(self, *other)
return self._from_list([a for a in self._list if a in other_set])
def __and__(self, other):
return self.intersection(other)
def symmetric_difference(self, other):
cdef set other_set
if isinstance(other, set):
other_set = <set> other
collection = other_set
elif isinstance(other, Collection):
collection = other
other_set = set(other)
else:
collection = list(other)
other_set = set(collection)
result = self._from_list([a for a in self._list if a not in other_set])
result.update(a for a in collection if a not in self)
return result
def __xor__(self, other):
return self.symmetric_difference(other)
def difference(self, *other):
cdef set other_set = set.difference(self, *other)
return self._from_list([a for a in self._list if a in other_set])
def __sub__(self, other):
return self.difference(other)
def intersection_update(self, *other):
set.intersection_update(self, *other)
self._list = [a for a in self._list if a in self]
def __iand__(self, other):
self.intersection_update(other)
return self
cpdef symmetric_difference_update(self, other):
collection = other if isinstance(other, Collection) else list(other)
set.symmetric_difference_update(self, collection)
self._list = [a for a in self._list if a in self]
self._list += [a for a in collection if a in self]
def __ixor__(self, other):
self.symmetric_difference_update(other)
return self
def difference_update(self, *other):
set.difference_update(self, *other)
self._list = [a for a in self._list if a in self]
def __isub__(self, other):
self.difference_update(other)
return self
cdef object cy_id(object item):
return PyLong_FromLongLong(<long long> (<void *>item))
# NOTE: cython 0.x will call __add__, __sub__, etc with the parameter swapped
# instead of the __rmeth__, so they need to check that also self is of the
# correct type. This is fixed in cython 3.x. See:
# https://docs.cython.org/en/latest/src/userguide/special_methods.html#arithmetic-methods
cdef class IdentitySet:
"""A set that considers only object id() for uniqueness.
    This strategy has edge cases for builtin types: it's possible to have
    two 'foo' strings in one of these sets, for example. Use sparingly.
"""
cdef dict _members
def __init__(self, iterable=None):
self._members = {}
if iterable:
self.update(iterable)
def add(self, value):
self._members[cy_id(value)] = value
def __contains__(self, value):
return cy_id(value) in self._members
cpdef remove(self, value):
del self._members[cy_id(value)]
def discard(self, value):
try:
self.remove(value)
except KeyError:
pass
def pop(self):
cdef tuple pair
try:
pair = self._members.popitem()
return pair[1]
except KeyError:
raise KeyError("pop from an empty set")
def clear(self):
self._members.clear()
def __eq__(self, other):
cdef IdentitySet other_
if isinstance(other, IdentitySet):
other_ = other
return self._members == other_._members
else:
return False
def __ne__(self, other):
cdef IdentitySet other_
if isinstance(other, IdentitySet):
other_ = other
return self._members != other_._members
else:
return True
cpdef issubset(self, iterable):
cdef IdentitySet other
if isinstance(iterable, self.__class__):
other = iterable
else:
other = self.__class__(iterable)
if len(self) > len(other):
return False
for m in filterfalse(other._members.__contains__, self._members):
return False
return True
def __le__(self, other):
if not isinstance(other, IdentitySet):
return NotImplemented
return self.issubset(other)
def __lt__(self, other):
if not isinstance(other, IdentitySet):
return NotImplemented
return len(self) < len(other) and self.issubset(other)
cpdef issuperset(self, iterable):
cdef IdentitySet other
if isinstance(iterable, self.__class__):
other = iterable
else:
other = self.__class__(iterable)
if len(self) < len(other):
return False
for m in filterfalse(self._members.__contains__, other._members):
return False
return True
def __ge__(self, other):
if not isinstance(other, IdentitySet):
return NotImplemented
return self.issuperset(other)
def __gt__(self, other):
if not isinstance(other, IdentitySet):
return NotImplemented
return len(self) > len(other) and self.issuperset(other)
cpdef IdentitySet union(self, iterable):
cdef IdentitySet result = self.__class__()
result._members.update(self._members)
result.update(iterable)
return result
def __or__(self, other):
if not isinstance(other, IdentitySet) or not isinstance(self, IdentitySet):
return NotImplemented
return self.union(other)
cpdef update(self, iterable):
for obj in iterable:
self._members[cy_id(obj)] = obj
def __ior__(self, other):
if not isinstance(other, IdentitySet):
return NotImplemented
self.update(other)
return self
cpdef IdentitySet difference(self, iterable):
cdef IdentitySet result = self.__new__(self.__class__)
if isinstance(iterable, self.__class__):
other = (<IdentitySet>iterable)._members
else:
other = {cy_id(obj) for obj in iterable}
result._members = {k:v for k, v in self._members.items() if k not in other}
return result
def __sub__(self, other):
if not isinstance(other, IdentitySet) or not isinstance(self, IdentitySet):
return NotImplemented
return self.difference(other)
cpdef difference_update(self, iterable):
cdef IdentitySet other = self.difference(iterable)
self._members = other._members
def __isub__(self, other):
if not isinstance(other, IdentitySet):
return NotImplemented
self.difference_update(other)
return self
cpdef IdentitySet intersection(self, iterable):
cdef IdentitySet result = self.__new__(self.__class__)
if isinstance(iterable, self.__class__):
other = (<IdentitySet>iterable)._members
else:
other = {cy_id(obj) for obj in iterable}
result._members = {k: v for k, v in self._members.items() if k in other}
return result
def __and__(self, other):
if not isinstance(other, IdentitySet) or not isinstance(self, IdentitySet):
return NotImplemented
return self.intersection(other)
cpdef intersection_update(self, iterable):
cdef IdentitySet other = self.intersection(iterable)
self._members = other._members
def __iand__(self, other):
if not isinstance(other, IdentitySet):
return NotImplemented
self.intersection_update(other)
return self
cpdef IdentitySet symmetric_difference(self, iterable):
cdef IdentitySet result = self.__new__(self.__class__)
cdef dict other
if isinstance(iterable, self.__class__):
other = (<IdentitySet>iterable)._members
else:
other = {cy_id(obj): obj for obj in iterable}
result._members = {k: v for k, v in self._members.items() if k not in other}
result._members.update(
[(k, v) for k, v in other.items() if k not in self._members]
)
return result
def __xor__(self, other):
if not isinstance(other, IdentitySet) or not isinstance(self, IdentitySet):
return NotImplemented
return self.symmetric_difference(other)
cpdef symmetric_difference_update(self, iterable):
cdef IdentitySet other = self.symmetric_difference(iterable)
self._members = other._members
def __ixor__(self, other):
if not isinstance(other, IdentitySet):
return NotImplemented
        self.symmetric_difference_update(other)
return self
cpdef IdentitySet copy(self):
cdef IdentitySet cp = self.__new__(self.__class__)
cp._members = self._members.copy()
return cp
def __copy__(self):
return self.copy()
def __len__(self):
return len(self._members)
def __iter__(self):
return iter(self._members.values())
def __hash__(self):
raise TypeError("set objects are unhashable")
def __repr__(self):
return "%s(%r)" % (type(self).__name__, list(self._members.values()))

@@ -0,0 +1,8 @@
# cyextension/immutabledict.pxd
# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
cdef class immutabledict(dict):
pass

@@ -0,0 +1,133 @@
# cyextension/immutabledict.pyx
# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
from cpython.dict cimport PyDict_New, PyDict_Update, PyDict_Size
def _readonly_fn(obj):
raise TypeError(
"%s object is immutable and/or readonly" % obj.__class__.__name__)
def _immutable_fn(obj):
raise TypeError(
"%s object is immutable" % obj.__class__.__name__)
class ReadOnlyContainer:
__slots__ = ()
def _readonly(self, *a,**kw):
_readonly_fn(self)
__delitem__ = __setitem__ = __setattr__ = _readonly
class ImmutableDictBase(dict):
def _immutable(self, *a,**kw):
_immutable_fn(self)
@classmethod
def __class_getitem__(cls, key):
return cls
__delitem__ = __setitem__ = __setattr__ = _immutable
clear = pop = popitem = setdefault = update = _immutable
cdef class immutabledict(dict):
def __repr__(self):
return f"immutabledict({dict.__repr__(self)})"
@classmethod
def __class_getitem__(cls, key):
return cls
def union(self, *args, **kw):
cdef dict to_merge = None
cdef immutabledict result
cdef Py_ssize_t args_len = len(args)
if args_len > 1:
raise TypeError(
f'union expected at most 1 argument, got {args_len}'
)
if args_len == 1:
attribute = args[0]
if isinstance(attribute, dict):
to_merge = <dict> attribute
if to_merge is None:
to_merge = dict(*args, **kw)
if PyDict_Size(to_merge) == 0:
return self
# new + update is faster than immutabledict(self)
result = immutabledict()
PyDict_Update(result, self)
PyDict_Update(result, to_merge)
return result
def merge_with(self, *other):
cdef immutabledict result = None
cdef object d
cdef bint update = False
if not other:
return self
for d in other:
if d:
if update == False:
update = True
# new + update is faster than immutabledict(self)
result = immutabledict()
PyDict_Update(result, self)
PyDict_Update(
result, <dict>(d if isinstance(d, dict) else dict(d))
)
return self if update == False else result
def copy(self):
return self
def __reduce__(self):
return immutabledict, (dict(self), )
def __delitem__(self, k):
_immutable_fn(self)
def __setitem__(self, k, v):
_immutable_fn(self)
def __setattr__(self, k, v):
_immutable_fn(self)
def clear(self, *args, **kw):
_immutable_fn(self)
def pop(self, *args, **kw):
_immutable_fn(self)
def popitem(self, *args, **kw):
_immutable_fn(self)
def setdefault(self, *args, **kw):
_immutable_fn(self)
def update(self, *args, **kw):
_immutable_fn(self)
# PEP 584
def __ior__(self, other):
_immutable_fn(self)
def __or__(self, other):
return immutabledict(dict.__or__(self, other))
def __ror__(self, other):
# NOTE: this is used only in cython 3.x;
# version 0.x will call __or__ with args inversed
return immutabledict(dict.__ror__(self, other))
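
# A hypothetical usage sketch of the immutabledict defined above: in-place
# mutation raises TypeError, while ``union``/``merge_with``/``|`` build new
# instances (returning ``self`` unchanged when there is nothing to merge).
def _demo_immutabledict():
    base = immutabledict({"a": 1})

    merged = base.union({"b": 2})
    assert dict(merged) == {"a": 1, "b": 2}
    assert base.union({}) is base            # empty merge returns self
    assert base.merge_with(None, {}) is base

    try:
        base["c"] = 3
    except TypeError:
        pass
    else:
        raise AssertionError("immutabledict should refuse item assignment")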

@@ -0,0 +1,68 @@
# cyextension/processors.pyx
# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
import datetime
from datetime import datetime as datetime_cls
from datetime import time as time_cls
from datetime import date as date_cls
import re
from cpython.object cimport PyObject_Str
from cpython.unicode cimport PyUnicode_AsASCIIString, PyUnicode_Check, PyUnicode_Decode
from libc.stdio cimport sscanf
def int_to_boolean(value):
if value is None:
return None
return True if value else False
def to_str(value):
return PyObject_Str(value) if value is not None else None
def to_float(value):
return float(value) if value is not None else None
cdef inline bytes to_bytes(object value, str type_name):
try:
return PyUnicode_AsASCIIString(value)
except Exception as e:
raise ValueError(
f"Couldn't parse {type_name} string '{value!r}' "
"- value is not a string."
) from e
def str_to_datetime(value):
if value is not None:
value = datetime_cls.fromisoformat(value)
return value
def str_to_time(value):
if value is not None:
value = time_cls.fromisoformat(value)
return value
def str_to_date(value):
if value is not None:
value = date_cls.fromisoformat(value)
return value
cdef class DecimalResultProcessor:
cdef object type_
cdef str format_
def __cinit__(self, type_, format_):
self.type_ = type_
self.format_ = format_
def process(self, object value):
if value is None:
return None
else:
return self.type_(self.format_ % value)
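
# Hypothetical usage sketch of the processors defined above (assumes the
# compiled module, e.g. sqlalchemy.cyextension.processors): ISO strings are
# parsed with ``fromisoformat`` and DecimalResultProcessor renders a value
# through its format string before handing it to the target type.
def _demo_processors():
    from decimal import Decimal

    assert str_to_datetime("2024-02-19 15:34:25") == datetime_cls(
        2024, 2, 19, 15, 34, 25
    )
    assert int_to_boolean(0) is False and int_to_boolean(3) is True

    proc = DecimalResultProcessor(Decimal, "%.4f").process
    assert proc(10) == Decimal("10.0000")
    assert proc(None) is None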

@@ -0,0 +1,102 @@
# cyextension/resultproxy.pyx
# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
import operator
cdef class BaseRow:
cdef readonly object _parent
cdef readonly dict _key_to_index
cdef readonly tuple _data
def __init__(self, object parent, object processors, dict key_to_index, object data):
"""Row objects are constructed by CursorResult objects."""
self._parent = parent
self._key_to_index = key_to_index
if processors:
self._data = _apply_processors(processors, data)
else:
self._data = tuple(data)
def __reduce__(self):
return (
rowproxy_reconstructor,
(self.__class__, self.__getstate__()),
)
def __getstate__(self):
return {"_parent": self._parent, "_data": self._data}
def __setstate__(self, dict state):
parent = state["_parent"]
self._parent = parent
self._data = state["_data"]
self._key_to_index = parent._key_to_index
def _values_impl(self):
return list(self)
def __iter__(self):
return iter(self._data)
def __len__(self):
return len(self._data)
def __hash__(self):
return hash(self._data)
def __getitem__(self, index):
return self._data[index]
def _get_by_key_impl_mapping(self, key):
return self._get_by_key_impl(key, 0)
cdef _get_by_key_impl(self, object key, int attr_err):
index = self._key_to_index.get(key)
if index is not None:
return self._data[<int>index]
self._parent._key_not_found(key, attr_err != 0)
def __getattr__(self, name):
return self._get_by_key_impl(name, 1)
def _to_tuple_instance(self):
return self._data
cdef tuple _apply_processors(proc, data):
res = []
for i in range(len(proc)):
p = proc[i]
if p is None:
res.append(data[i])
else:
res.append(p(data[i]))
return tuple(res)
def rowproxy_reconstructor(cls, state):
obj = cls.__new__(cls)
obj.__setstate__(state)
return obj
cdef int is_contiguous(tuple indexes):
cdef int i
for i in range(1, len(indexes)):
if indexes[i-1] != indexes[i] -1:
return 0
return 1
def tuplegetter(*indexes):
if len(indexes) == 1 or is_contiguous(indexes) != 0:
# slice form is faster but returns a list if input is list
return operator.itemgetter(slice(indexes[0], indexes[-1] + 1))
else:
return operator.itemgetter(*indexes)
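
# Hypothetical usage sketch of ``tuplegetter`` above: contiguous indexes are
# served by a single slice (faster), non-contiguous indexes fall back to
# ``operator.itemgetter``; either way the callable extracts those positions
# from a row.
def _demo_tuplegetter():
    row = ("a", "b", "c", "d")

    contiguous = tuplegetter(1, 2, 3)    # compiled to row[1:4]
    assert contiguous(row) == ("b", "c", "d")

    sparse = tuplegetter(0, 3)           # falls back to itemgetter(0, 3)
    assert sparse(row) == ("a", "d")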

@@ -0,0 +1,91 @@
# cyextension/util.pyx
# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
from collections.abc import Mapping
from sqlalchemy import exc
cdef tuple _Empty_Tuple = ()
cdef inline bint _mapping_or_tuple(object value):
return isinstance(value, dict) or isinstance(value, tuple) or isinstance(value, Mapping)
cdef inline bint _check_item(object params) except 0:
cdef object item
cdef bint ret = 1
if params:
item = params[0]
if not _mapping_or_tuple(item):
ret = 0
raise exc.ArgumentError(
"List argument must consist only of tuples or dictionaries"
)
return ret
def _distill_params_20(object params):
if params is None:
return _Empty_Tuple
elif isinstance(params, list) or isinstance(params, tuple):
_check_item(params)
return params
elif isinstance(params, dict) or isinstance(params, Mapping):
return [params]
else:
raise exc.ArgumentError("mapping or list expected for parameters")
def _distill_raw_params(object params):
if params is None:
return _Empty_Tuple
elif isinstance(params, list):
_check_item(params)
return params
elif _mapping_or_tuple(params):
return [params]
else:
raise exc.ArgumentError("mapping or sequence expected for parameters")
cdef class prefix_anon_map(dict):
def __missing__(self, str key):
cdef str derived
cdef int anonymous_counter
cdef dict self_dict = self
derived = key.split(" ", 1)[1]
anonymous_counter = self_dict.get(derived, 1)
self_dict[derived] = anonymous_counter + 1
value = f"{derived}_{anonymous_counter}"
self_dict[key] = value
return value
cdef class cache_anon_map(dict):
cdef int _index
def __init__(self):
self._index = 0
def get_anon(self, obj):
cdef long long idself
cdef str id_
cdef dict self_dict = self
idself = id(obj)
if idself in self_dict:
return self_dict[idself], True
else:
id_ = self.__missing__(idself)
return id_, False
def __missing__(self, key):
cdef str val
cdef dict self_dict = self
self_dict[key] = val = str(self._index)
self._index += 1
return val
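
# Hypothetical usage sketch of ``_distill_params_20`` above: a single mapping
# becomes a one-element list, a list of mappings/tuples passes through
# unchanged, ``None`` becomes an empty tuple, and anything else raises
# ArgumentError.
def _demo_distill_params():
    assert _distill_params_20(None) == ()
    assert _distill_params_20({"x": 1}) == [{"x": 1}]
    assert _distill_params_20([{"x": 1}, {"x": 2}]) == [{"x": 1}, {"x": 2}]

    try:
        _distill_params_20("not parameters")
    except exc.ArgumentError:
        pass
    else:
        raise AssertionError("strings are rejected as parameters")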

@@ -0,0 +1,61 @@
# dialects/__init__.py
# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
from __future__ import annotations
from typing import Callable
from typing import Optional
from typing import Type
from typing import TYPE_CHECKING
from .. import util
if TYPE_CHECKING:
from ..engine.interfaces import Dialect
__all__ = ("mssql", "mysql", "oracle", "postgresql", "sqlite")
def _auto_fn(name: str) -> Optional[Callable[[], Type[Dialect]]]:
"""default dialect importer.
plugs into the :class:`.PluginLoader`
as a first-hit system.
"""
if "." in name:
dialect, driver = name.split(".")
else:
dialect = name
driver = "base"
try:
if dialect == "mariadb":
# it's "OK" for us to hardcode here since _auto_fn is already
# hardcoded. if mysql / mariadb etc were third party dialects
# they would just publish all the entrypoints, which would actually
# look much nicer.
module = __import__(
"sqlalchemy.dialects.mysql.mariadb"
).dialects.mysql.mariadb
return module.loader(driver) # type: ignore
else:
module = __import__("sqlalchemy.dialects.%s" % (dialect,)).dialects
module = getattr(module, dialect)
except ImportError:
return None
if hasattr(module, driver):
module = getattr(module, driver)
return lambda: module.dialect
else:
return None
registry = util.PluginLoader("sqlalchemy.dialects", auto_fn=_auto_fn)
plugins = util.PluginLoader("sqlalchemy.plugins")
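
# Hypothetical usage sketch of the registry above: dialect names are resolved
# lazily through ``_auto_fn``, so "sqlite" loads the default pysqlite driver
# module and "sqlite.pysqlite" addresses it explicitly, each returning the
# dialect class.
def _demo_dialect_registry():
    from sqlalchemy.dialects.sqlite.pysqlite import SQLiteDialect_pysqlite

    assert registry.load("sqlite.pysqlite") is SQLiteDialect_pysqlite
    assert registry.load("sqlite") is SQLiteDialect_pysqlite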

@@ -0,0 +1,25 @@
# dialects/_typing.py
# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
from __future__ import annotations
from typing import Any
from typing import Iterable
from typing import Mapping
from typing import Optional
from typing import Union
from ..sql._typing import _DDLColumnArgument
from ..sql.elements import DQLDMLClauseElement
from ..sql.schema import ColumnCollectionConstraint
from ..sql.schema import Index
_OnConflictConstraintT = Union[str, ColumnCollectionConstraint, Index, None]
_OnConflictIndexElementsT = Optional[Iterable[_DDLColumnArgument]]
_OnConflictIndexWhereT = Optional[DQLDMLClauseElement]
_OnConflictSetT = Optional[Mapping[Any, Any]]
_OnConflictWhereT = Union[DQLDMLClauseElement, str, None]

@@ -0,0 +1,88 @@
# dialects/mssql/__init__.py
# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
# mypy: ignore-errors
from . import aioodbc # noqa
from . import base # noqa
from . import pymssql # noqa
from . import pyodbc # noqa
from .base import BIGINT
from .base import BINARY
from .base import BIT
from .base import CHAR
from .base import DATE
from .base import DATETIME
from .base import DATETIME2
from .base import DATETIMEOFFSET
from .base import DECIMAL
from .base import DOUBLE_PRECISION
from .base import FLOAT
from .base import IMAGE
from .base import INTEGER
from .base import JSON
from .base import MONEY
from .base import NCHAR
from .base import NTEXT
from .base import NUMERIC
from .base import NVARCHAR
from .base import REAL
from .base import ROWVERSION
from .base import SMALLDATETIME
from .base import SMALLINT
from .base import SMALLMONEY
from .base import SQL_VARIANT
from .base import TEXT
from .base import TIME
from .base import TIMESTAMP
from .base import TINYINT
from .base import UNIQUEIDENTIFIER
from .base import VARBINARY
from .base import VARCHAR
from .base import XML
from ...sql import try_cast
base.dialect = dialect = pyodbc.dialect
__all__ = (
"JSON",
"INTEGER",
"BIGINT",
"SMALLINT",
"TINYINT",
"VARCHAR",
"NVARCHAR",
"CHAR",
"NCHAR",
"TEXT",
"NTEXT",
"DECIMAL",
"NUMERIC",
"FLOAT",
"DATETIME",
"DATETIME2",
"DATETIMEOFFSET",
"DATE",
"DOUBLE_PRECISION",
"TIME",
"SMALLDATETIME",
"BINARY",
"VARBINARY",
"BIT",
"REAL",
"IMAGE",
"TIMESTAMP",
"ROWVERSION",
"MONEY",
"SMALLMONEY",
"UNIQUEIDENTIFIER",
"SQL_VARIANT",
"XML",
"dialect",
"try_cast",
)

@@ -0,0 +1,64 @@
# dialects/mssql/aioodbc.py
# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
# mypy: ignore-errors
r"""
.. dialect:: mssql+aioodbc
:name: aioodbc
:dbapi: aioodbc
:connectstring: mssql+aioodbc://<username>:<password>@<dsnname>
:url: https://pypi.org/project/aioodbc/
Support for the SQL Server database in asyncio style, using the aioodbc
driver which itself is a thread-wrapper around pyodbc.
.. versionadded:: 2.0.23 Added the mssql+aioodbc dialect which builds
on top of the pyodbc and general aio* dialect architecture.
Using a special asyncio mediation layer, the aioodbc dialect is usable
as the backend for the :ref:`SQLAlchemy asyncio <asyncio_toplevel>`
extension package.
Most behaviors and caveats for this driver are the same as that of the
pyodbc dialect used on SQL Server; see :ref:`mssql_pyodbc` for general
background.
This dialect should normally be used only with the
:func:`_asyncio.create_async_engine` engine creation function; connection
styles are otherwise equivalent to those documented in the pyodbc section::
from sqlalchemy.ext.asyncio import create_async_engine
engine = create_async_engine(
"mssql+aioodbc://scott:tiger@mssql2017:1433/test?"
"driver=ODBC+Driver+18+for+SQL+Server&TrustServerCertificate=yes"
)
"""
from __future__ import annotations
from .pyodbc import MSDialect_pyodbc
from .pyodbc import MSExecutionContext_pyodbc
from ...connectors.aioodbc import aiodbcConnector
class MSExecutionContext_aioodbc(MSExecutionContext_pyodbc):
def create_server_side_cursor(self):
return self._dbapi_connection.cursor(server_side=True)
class MSDialectAsync_aioodbc(aiodbcConnector, MSDialect_pyodbc):
driver = "aioodbc"
supports_statement_cache = True
execution_ctx_cls = MSExecutionContext_aioodbc
dialect = MSDialectAsync_aioodbc

File diff suppressed because it is too large

@@ -0,0 +1,254 @@
# dialects/mssql/information_schema.py
# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
# mypy: ignore-errors
from ... import cast
from ... import Column
from ... import MetaData
from ... import Table
from ...ext.compiler import compiles
from ...sql import expression
from ...types import Boolean
from ...types import Integer
from ...types import Numeric
from ...types import NVARCHAR
from ...types import String
from ...types import TypeDecorator
from ...types import Unicode
ischema = MetaData()
class CoerceUnicode(TypeDecorator):
impl = Unicode
cache_ok = True
def bind_expression(self, bindvalue):
return _cast_on_2005(bindvalue)
class _cast_on_2005(expression.ColumnElement):
def __init__(self, bindvalue):
self.bindvalue = bindvalue
@compiles(_cast_on_2005)
def _compile(element, compiler, **kw):
from . import base
if (
compiler.dialect.server_version_info is None
or compiler.dialect.server_version_info < base.MS_2005_VERSION
):
return compiler.process(element.bindvalue, **kw)
else:
return compiler.process(cast(element.bindvalue, Unicode), **kw)
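
# A minimal standalone sketch of the same sqlalchemy.ext.compiler pattern used
# by ``_cast_on_2005`` above: a custom ColumnElement plus an @compiles rule
# that decides the rendered SQL.  The ``coalesce_default`` element is made up
# purely for illustration and is not part of this module.
def _demo_compiles_pattern():
    from sqlalchemy import column
    from sqlalchemy.ext.compiler import compiles
    from sqlalchemy.sql import expression as sql_expr

    class coalesce_default(sql_expr.ColumnElement):
        inherit_cache = True

        def __init__(self, expr):
            self.expr = expr

    @compiles(coalesce_default)
    def _render(element, compiler, **kw):
        return "COALESCE(%s, '')" % compiler.process(element.expr, **kw)

    expr = coalesce_default(column("name"))
    assert str(expr) == "COALESCE(name, '')"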
schemata = Table(
"SCHEMATA",
ischema,
Column("CATALOG_NAME", CoerceUnicode, key="catalog_name"),
Column("SCHEMA_NAME", CoerceUnicode, key="schema_name"),
Column("SCHEMA_OWNER", CoerceUnicode, key="schema_owner"),
schema="INFORMATION_SCHEMA",
)
tables = Table(
"TABLES",
ischema,
Column("TABLE_CATALOG", CoerceUnicode, key="table_catalog"),
Column("TABLE_SCHEMA", CoerceUnicode, key="table_schema"),
Column("TABLE_NAME", CoerceUnicode, key="table_name"),
Column("TABLE_TYPE", CoerceUnicode, key="table_type"),
schema="INFORMATION_SCHEMA",
)
columns = Table(
"COLUMNS",
ischema,
Column("TABLE_SCHEMA", CoerceUnicode, key="table_schema"),
Column("TABLE_NAME", CoerceUnicode, key="table_name"),
Column("COLUMN_NAME", CoerceUnicode, key="column_name"),
Column("IS_NULLABLE", Integer, key="is_nullable"),
Column("DATA_TYPE", String, key="data_type"),
Column("ORDINAL_POSITION", Integer, key="ordinal_position"),
Column(
"CHARACTER_MAXIMUM_LENGTH", Integer, key="character_maximum_length"
),
Column("NUMERIC_PRECISION", Integer, key="numeric_precision"),
Column("NUMERIC_SCALE", Integer, key="numeric_scale"),
Column("COLUMN_DEFAULT", Integer, key="column_default"),
Column("COLLATION_NAME", String, key="collation_name"),
schema="INFORMATION_SCHEMA",
)
mssql_temp_table_columns = Table(
"COLUMNS",
ischema,
Column("TABLE_SCHEMA", CoerceUnicode, key="table_schema"),
Column("TABLE_NAME", CoerceUnicode, key="table_name"),
Column("COLUMN_NAME", CoerceUnicode, key="column_name"),
Column("IS_NULLABLE", Integer, key="is_nullable"),
Column("DATA_TYPE", String, key="data_type"),
Column("ORDINAL_POSITION", Integer, key="ordinal_position"),
Column(
"CHARACTER_MAXIMUM_LENGTH", Integer, key="character_maximum_length"
),
Column("NUMERIC_PRECISION", Integer, key="numeric_precision"),
Column("NUMERIC_SCALE", Integer, key="numeric_scale"),
Column("COLUMN_DEFAULT", Integer, key="column_default"),
Column("COLLATION_NAME", String, key="collation_name"),
schema="tempdb.INFORMATION_SCHEMA",
)
constraints = Table(
"TABLE_CONSTRAINTS",
ischema,
Column("TABLE_SCHEMA", CoerceUnicode, key="table_schema"),
Column("TABLE_NAME", CoerceUnicode, key="table_name"),
Column("CONSTRAINT_NAME", CoerceUnicode, key="constraint_name"),
Column("CONSTRAINT_TYPE", CoerceUnicode, key="constraint_type"),
schema="INFORMATION_SCHEMA",
)
column_constraints = Table(
"CONSTRAINT_COLUMN_USAGE",
ischema,
Column("TABLE_SCHEMA", CoerceUnicode, key="table_schema"),
Column("TABLE_NAME", CoerceUnicode, key="table_name"),
Column("COLUMN_NAME", CoerceUnicode, key="column_name"),
Column("CONSTRAINT_NAME", CoerceUnicode, key="constraint_name"),
schema="INFORMATION_SCHEMA",
)
key_constraints = Table(
"KEY_COLUMN_USAGE",
ischema,
Column("TABLE_SCHEMA", CoerceUnicode, key="table_schema"),
Column("TABLE_NAME", CoerceUnicode, key="table_name"),
Column("COLUMN_NAME", CoerceUnicode, key="column_name"),
Column("CONSTRAINT_NAME", CoerceUnicode, key="constraint_name"),
Column("CONSTRAINT_SCHEMA", CoerceUnicode, key="constraint_schema"),
Column("ORDINAL_POSITION", Integer, key="ordinal_position"),
schema="INFORMATION_SCHEMA",
)
ref_constraints = Table(
"REFERENTIAL_CONSTRAINTS",
ischema,
Column("CONSTRAINT_CATALOG", CoerceUnicode, key="constraint_catalog"),
Column("CONSTRAINT_SCHEMA", CoerceUnicode, key="constraint_schema"),
Column("CONSTRAINT_NAME", CoerceUnicode, key="constraint_name"),
# TODO: is CATLOG misspelled ?
Column(
"UNIQUE_CONSTRAINT_CATLOG",
CoerceUnicode,
key="unique_constraint_catalog",
),
Column(
"UNIQUE_CONSTRAINT_SCHEMA",
CoerceUnicode,
key="unique_constraint_schema",
),
Column(
"UNIQUE_CONSTRAINT_NAME", CoerceUnicode, key="unique_constraint_name"
),
Column("MATCH_OPTION", String, key="match_option"),
Column("UPDATE_RULE", String, key="update_rule"),
Column("DELETE_RULE", String, key="delete_rule"),
schema="INFORMATION_SCHEMA",
)
views = Table(
"VIEWS",
ischema,
Column("TABLE_CATALOG", CoerceUnicode, key="table_catalog"),
Column("TABLE_SCHEMA", CoerceUnicode, key="table_schema"),
Column("TABLE_NAME", CoerceUnicode, key="table_name"),
Column("VIEW_DEFINITION", CoerceUnicode, key="view_definition"),
Column("CHECK_OPTION", String, key="check_option"),
Column("IS_UPDATABLE", String, key="is_updatable"),
schema="INFORMATION_SCHEMA",
)
computed_columns = Table(
"computed_columns",
ischema,
Column("object_id", Integer),
Column("name", CoerceUnicode),
Column("is_computed", Boolean),
Column("is_persisted", Boolean),
Column("definition", CoerceUnicode),
schema="sys",
)
sequences = Table(
"SEQUENCES",
ischema,
Column("SEQUENCE_CATALOG", CoerceUnicode, key="sequence_catalog"),
Column("SEQUENCE_SCHEMA", CoerceUnicode, key="sequence_schema"),
Column("SEQUENCE_NAME", CoerceUnicode, key="sequence_name"),
schema="INFORMATION_SCHEMA",
)
class NumericSqlVariant(TypeDecorator):
r"""This type casts sql_variant columns in the identity_columns view
to numeric. This is required because:
* pyodbc does not support sql_variant
* pymssql under Python 2 returns the byte representation of the number;
int 1 is returned as "\x01\x00\x00\x00". On Python 3 it returns the
correct value as a string.
"""
impl = Unicode
cache_ok = True
def column_expression(self, colexpr):
return cast(colexpr, Numeric(38, 0))
identity_columns = Table(
"identity_columns",
ischema,
Column("object_id", Integer),
Column("name", CoerceUnicode),
Column("is_identity", Boolean),
Column("seed_value", NumericSqlVariant),
Column("increment_value", NumericSqlVariant),
Column("last_value", NumericSqlVariant),
Column("is_not_for_replication", Boolean),
schema="sys",
)
class NVarcharSqlVariant(TypeDecorator):
"""This type casts sql_variant columns in the extended_properties view
to nvarchar. This is required because pyodbc does not support sql_variant.
"""
impl = Unicode
cache_ok = True
def column_expression(self, colexpr):
return cast(colexpr, NVARCHAR)
extended_properties = Table(
"extended_properties",
ischema,
Column("class", Integer), # TINYINT
Column("class_desc", CoerceUnicode),
Column("major_id", Integer),
Column("minor_id", Integer),
Column("name", CoerceUnicode),
Column("value", NVarcharSqlVariant),
schema="sys",
)

View file

@ -0,0 +1,133 @@
# dialects/mssql/json.py
# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
# mypy: ignore-errors
from ... import types as sqltypes
# technically, all the dialect-specific datatypes that don't have any special
# behaviors would be private with names like _MSJson. However, we haven't been
# doing this for mysql.JSON or sqlite.JSON which both have JSON / JSONIndexType
# / JSONPathType in their json.py files, so keep consistent with that
# sub-convention for now. A future change can update them all to be
# package-private at once.
class JSON(sqltypes.JSON):
"""MSSQL JSON type.
MSSQL supports JSON-formatted data as of SQL Server 2016.
The :class:`_mssql.JSON` datatype at the DDL level will represent the
datatype as ``NVARCHAR(max)``, but provides for JSON-level comparison
functions as well as Python coercion behavior.
:class:`_mssql.JSON` is used automatically whenever the base
:class:`_types.JSON` datatype is used against a SQL Server backend.
.. seealso::
:class:`_types.JSON` - main documentation for the generic
cross-platform JSON datatype.
The :class:`_mssql.JSON` type supports persistence of JSON values
as well as the core index operations provided by :class:`_types.JSON`
datatype, by adapting the operations to render the ``JSON_VALUE``
or ``JSON_QUERY`` functions at the database level.
The SQL Server :class:`_mssql.JSON` type necessarily makes use of the
``JSON_QUERY`` and ``JSON_VALUE`` functions when querying for elements
of a JSON object. These two functions have a major restriction in that
they are **mutually exclusive** based on the type of object to be returned.
The ``JSON_QUERY`` function **only** returns a JSON dictionary or list,
but not an individual string, numeric, or boolean element; the
``JSON_VALUE`` function **only** returns an individual string, numeric,
or boolean element. **Both functions either return NULL or raise
an error if they are not used against the correct expected value**.
To handle this awkward requirement, indexed access rules are as follows:
1. When extracting a sub element from a JSON that is itself a JSON
dictionary or list, the :meth:`_types.JSON.Comparator.as_json` accessor
should be used::
stmt = select(
data_table.c.data["some key"].as_json()
).where(
data_table.c.data["some key"].as_json() == {"sub": "structure"}
)
2. When extracting a sub element from a JSON that is a plain boolean,
string, integer, or float, use the appropriate method among
:meth:`_types.JSON.Comparator.as_boolean`,
:meth:`_types.JSON.Comparator.as_string`,
:meth:`_types.JSON.Comparator.as_integer`,
:meth:`_types.JSON.Comparator.as_float`::
stmt = select(
data_table.c.data["some key"].as_string()
).where(
data_table.c.data["some key"].as_string() == "some string"
)
.. versionadded:: 1.4
"""
# note there was a result processor here that was looking for "number",
# but none of the tests seem to exercise it.
# Note: these objects currently match exactly those of MySQL, however since
# these are not generalizable to all JSON implementations, remain separately
# implemented for each dialect.
class _FormatTypeMixin:
def _format_value(self, value):
raise NotImplementedError()
def bind_processor(self, dialect):
super_proc = self.string_bind_processor(dialect)
def process(value):
value = self._format_value(value)
if super_proc:
value = super_proc(value)
return value
return process
def literal_processor(self, dialect):
super_proc = self.string_literal_processor(dialect)
def process(value):
value = self._format_value(value)
if super_proc:
value = super_proc(value)
return value
return process
class JSONIndexType(_FormatTypeMixin, sqltypes.JSON.JSONIndexType):
def _format_value(self, value):
if isinstance(value, int):
value = "$[%s]" % value
else:
value = '$."%s"' % value
return value
class JSONPathType(_FormatTypeMixin, sqltypes.JSON.JSONPathType):
def _format_value(self, value):
return "$%s" % (
"".join(
[
"[%s]" % elem if isinstance(elem, int) else '."%s"' % elem
for elem in value
]
)
)
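# Illustrative examples, not part of the original module: given the formatting
# rules above, an integer index of 3 handled by JSONIndexType._format_value()
# yields the JSON path string "$[3]", a string key of "some key" yields
# '$."some key"', and a path tuple such as ("data", 5, "nested") handled by
# JSONPathType._format_value() yields '$."data"[5]."nested"'.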

View file

@ -0,0 +1,155 @@
# dialects/mssql/provision.py
# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
# mypy: ignore-errors
from sqlalchemy import inspect
from sqlalchemy import Integer
from ... import create_engine
from ... import exc
from ...schema import Column
from ...schema import DropConstraint
from ...schema import ForeignKeyConstraint
from ...schema import MetaData
from ...schema import Table
from ...testing.provision import create_db
from ...testing.provision import drop_all_schema_objects_pre_tables
from ...testing.provision import drop_db
from ...testing.provision import generate_driver_url
from ...testing.provision import get_temp_table_name
from ...testing.provision import log
from ...testing.provision import normalize_sequence
from ...testing.provision import run_reap_dbs
from ...testing.provision import temp_table_keyword_args
@generate_driver_url.for_db("mssql")
def generate_driver_url(url, driver, query_str):
backend = url.get_backend_name()
new_url = url.set(drivername="%s+%s" % (backend, driver))
if driver not in ("pyodbc", "aioodbc"):
new_url = new_url.set(query="")
if driver == "aioodbc":
new_url = new_url.update_query_dict({"MARS_Connection": "Yes"})
if query_str:
new_url = new_url.update_query_string(query_str)
try:
new_url.get_dialect()
except exc.NoSuchModuleError:
return None
else:
return new_url
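# Illustrative note, not part of the original module: given a base URL such as
# "mssql://scott:tiger@host/test?driver=ODBC+Driver+18+for+SQL+Server" and
# driver="aioodbc", the function above returns a URL with the drivername
# rewritten to "mssql+aioodbc", the existing query string preserved, and
# "MARS_Connection=Yes" appended; for drivers other than pyodbc/aioodbc the
# query string is cleared instead. None is returned if the resulting dialect
# cannot be imported.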
@create_db.for_db("mssql")
def _mssql_create_db(cfg, eng, ident):
with eng.connect().execution_options(isolation_level="AUTOCOMMIT") as conn:
conn.exec_driver_sql("create database %s" % ident)
conn.exec_driver_sql(
"ALTER DATABASE %s SET ALLOW_SNAPSHOT_ISOLATION ON" % ident
)
conn.exec_driver_sql(
"ALTER DATABASE %s SET READ_COMMITTED_SNAPSHOT ON" % ident
)
conn.exec_driver_sql("use %s" % ident)
conn.exec_driver_sql("create schema test_schema")
conn.exec_driver_sql("create schema test_schema_2")
@drop_db.for_db("mssql")
def _mssql_drop_db(cfg, eng, ident):
with eng.connect().execution_options(isolation_level="AUTOCOMMIT") as conn:
_mssql_drop_ignore(conn, ident)
def _mssql_drop_ignore(conn, ident):
try:
# typically when this happens, we can't KILL the session anyway,
# so let the cleanup process drop the DBs
# for row in conn.exec_driver_sql(
# "select session_id from sys.dm_exec_sessions "
# "where database_id=db_id('%s')" % ident):
# log.info("killing SQL server session %s", row['session_id'])
# conn.exec_driver_sql("kill %s" % row['session_id'])
conn.exec_driver_sql("drop database %s" % ident)
log.info("Reaped db: %s", ident)
return True
except exc.DatabaseError as err:
log.warning("couldn't drop db: %s", err)
return False
@run_reap_dbs.for_db("mssql")
def _reap_mssql_dbs(url, idents):
log.info("db reaper connecting to %r", url)
eng = create_engine(url)
with eng.connect().execution_options(isolation_level="AUTOCOMMIT") as conn:
log.info("identifiers in file: %s", ", ".join(idents))
to_reap = conn.exec_driver_sql(
"select d.name from sys.databases as d where name "
"like 'TEST_%' and not exists (select session_id "
"from sys.dm_exec_sessions "
"where database_id=d.database_id)"
)
all_names = {dbname.lower() for (dbname,) in to_reap}
to_drop = set()
for name in all_names:
if name in idents:
to_drop.add(name)
dropped = total = 0
for total, dbname in enumerate(to_drop, 1):
if _mssql_drop_ignore(conn, dbname):
dropped += 1
log.info(
"Dropped %d out of %d stale databases detected", dropped, total
)
@temp_table_keyword_args.for_db("mssql")
def _mssql_temp_table_keyword_args(cfg, eng):
return {}
@get_temp_table_name.for_db("mssql")
def _mssql_get_temp_table_name(cfg, eng, base_name):
return "##" + base_name
@drop_all_schema_objects_pre_tables.for_db("mssql")
def drop_all_schema_objects_pre_tables(cfg, eng):
with eng.connect().execution_options(isolation_level="AUTOCOMMIT") as conn:
inspector = inspect(conn)
for schema in (None, "dbo", cfg.test_schema, cfg.test_schema_2):
for tname in inspector.get_table_names(schema=schema):
tb = Table(
tname,
MetaData(),
Column("x", Integer),
Column("y", Integer),
schema=schema,
)
for fk in inspect(conn).get_foreign_keys(tname, schema=schema):
conn.execute(
DropConstraint(
ForeignKeyConstraint(
[tb.c.x], [tb.c.y], name=fk["name"]
)
)
)
@normalize_sequence.for_db("mssql")
def normalize_sequence(cfg, sequence):
if sequence.start is None:
sequence.start = 1
return sequence

View file

@ -0,0 +1,125 @@
# dialects/mssql/pymssql.py
# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
# mypy: ignore-errors
"""
.. dialect:: mssql+pymssql
:name: pymssql
:dbapi: pymssql
:connectstring: mssql+pymssql://<username>:<password>@<freetds_name>/?charset=utf8
pymssql is a Python module that provides a Python DBAPI interface around
`FreeTDS <https://www.freetds.org/>`_.
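A basic connection following the connectstring shown above may look like
the following (an illustrative sketch; the FreeTDS name, database name and
credentials are placeholders)::

    from sqlalchemy import create_engine

    engine = create_engine(
        "mssql+pymssql://scott:tiger@freetds_name/mydatabase?charset=utf8"
    )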
.. versionchanged:: 2.0.5
pymssql was restored to SQLAlchemy's continuous integration testing
""" # noqa
import re
from .base import MSDialect
from .base import MSIdentifierPreparer
from ... import types as sqltypes
from ... import util
from ...engine import processors
class _MSNumeric_pymssql(sqltypes.Numeric):
def result_processor(self, dialect, type_):
if not self.asdecimal:
return processors.to_float
else:
return sqltypes.Numeric.result_processor(self, dialect, type_)
class MSIdentifierPreparer_pymssql(MSIdentifierPreparer):
def __init__(self, dialect):
super().__init__(dialect)
# pymssql has the very unusual behavior that it uses pyformat
# yet does not require that percent signs be doubled
self._double_percents = False
class MSDialect_pymssql(MSDialect):
supports_statement_cache = True
supports_native_decimal = True
supports_native_uuid = True
driver = "pymssql"
preparer = MSIdentifierPreparer_pymssql
colspecs = util.update_copy(
MSDialect.colspecs,
{sqltypes.Numeric: _MSNumeric_pymssql, sqltypes.Float: sqltypes.Float},
)
@classmethod
def import_dbapi(cls):
module = __import__("pymssql")
# pymssql < 2.1.1 doesn't have a Binary method; we use string
client_ver = tuple(int(x) for x in module.__version__.split("."))
if client_ver < (2, 1, 1):
# TODO: monkeypatching here is less than ideal
module.Binary = lambda x: x if hasattr(x, "decode") else str(x)
if client_ver < (1,):
util.warn(
"The pymssql dialect expects at least "
"the 1.0 series of the pymssql DBAPI."
)
return module
def _get_server_version_info(self, connection):
vers = connection.exec_driver_sql("select @@version").scalar()
m = re.match(r"Microsoft .*? - (\d+)\.(\d+)\.(\d+)\.(\d+)", vers)
if m:
return tuple(int(x) for x in m.group(1, 2, 3, 4))
else:
return None
def create_connect_args(self, url):
opts = url.translate_connect_args(username="user")
opts.update(url.query)
port = opts.pop("port", None)
if port and "host" in opts:
opts["host"] = "%s:%s" % (opts["host"], port)
return ([], opts)
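# Illustrative note, not part of the original module: for a URL such as
# "mssql+pymssql://scott:tiger@somehost:1433/test", the method above produces
# connect arguments equivalent to
# dict(user="scott", password="tiger", host="somehost:1433", database="test"),
# i.e. the port is folded into the host string before being passed to pymssql.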
def is_disconnect(self, e, connection, cursor):
for msg in (
"Adaptive Server connection timed out",
"Net-Lib error during Connection reset by peer",
"message 20003", # connection timeout
"Error 10054",
"Not connected to any MS SQL server",
"Connection is closed",
"message 20006", # Write to the server failed
"message 20017", # Unexpected EOF from the server
"message 20047", # DBPROCESS is dead or not enabled
):
if msg in str(e):
return True
else:
return False
def get_isolation_level_values(self, dbapi_connection):
return super().get_isolation_level_values(dbapi_connection) + [
"AUTOCOMMIT"
]
def set_isolation_level(self, dbapi_connection, level):
if level == "AUTOCOMMIT":
dbapi_connection.autocommit(True)
else:
dbapi_connection.autocommit(False)
super().set_isolation_level(dbapi_connection, level)
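# Illustrative sketch, not part of the original module: with "AUTOCOMMIT"
# added to the isolation level values above, an application may opt into
# driver-level autocommit using the standard execution option, e.g.
#
#     with engine.connect().execution_options(
#         isolation_level="AUTOCOMMIT"
#     ) as conn:
#         conn.exec_driver_sql("SELECT 1")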
dialect = MSDialect_pymssql

View file

@ -0,0 +1,745 @@
# dialects/mssql/pyodbc.py
# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
# mypy: ignore-errors
r"""
.. dialect:: mssql+pyodbc
:name: PyODBC
:dbapi: pyodbc
:connectstring: mssql+pyodbc://<username>:<password>@<dsnname>
:url: https://pypi.org/project/pyodbc/
Connecting to PyODBC
--------------------
The URL here is to be translated to PyODBC connection strings, as
detailed in `ConnectionStrings <https://code.google.com/p/pyodbc/wiki/ConnectionStrings>`_.
DSN Connections
^^^^^^^^^^^^^^^
A DSN connection in ODBC means that a pre-existing ODBC datasource is
configured on the client machine. The application then specifies the name
of this datasource, which encompasses details such as the specific ODBC driver
in use as well as the network address of the database. Assuming a datasource
is configured on the client, a basic DSN-based connection looks like::
engine = create_engine("mssql+pyodbc://scott:tiger@some_dsn")
The above URL will pass the following connection string to PyODBC::
DSN=some_dsn;UID=scott;PWD=tiger
If the username and password are omitted, the DSN form will also add
the ``Trusted_Connection=yes`` directive to the ODBC string.
Hostname Connections
^^^^^^^^^^^^^^^^^^^^
Hostname-based connections are also supported by pyodbc. These are often
easier to use than a DSN and have the additional advantage that the specific
database name to connect towards may be specified locally in the URL, rather
than it being fixed as part of a datasource configuration.
When using a hostname connection, the driver name must also be specified in the
query parameters of the URL. As these names usually have spaces in them, the
name must be URL encoded, which means using plus signs for spaces::
engine = create_engine("mssql+pyodbc://scott:tiger@myhost:port/databasename?driver=ODBC+Driver+17+for+SQL+Server")
The ``driver`` keyword is significant to the pyodbc dialect and must be
specified in lowercase.
Any other names passed in the query string are passed through in the pyodbc
connect string, such as ``authentication``, ``TrustServerCertificate``, etc.
Multiple keyword arguments must be separated by an ampersand (``&``); these
will be translated to semicolons when the pyodbc connect string is generated
internally::
e = create_engine(
"mssql+pyodbc://scott:tiger@mssql2017:1433/test?"
"driver=ODBC+Driver+18+for+SQL+Server&TrustServerCertificate=yes"
"&authentication=ActiveDirectoryIntegrated"
)
The equivalent URL can be constructed using :class:`_sa.engine.URL`::
from sqlalchemy.engine import URL
connection_url = URL.create(
"mssql+pyodbc",
username="scott",
password="tiger",
host="mssql2017",
port=1433,
database="test",
query={
"driver": "ODBC Driver 18 for SQL Server",
"TrustServerCertificate": "yes",
"authentication": "ActiveDirectoryIntegrated",
},
)
Pass through exact Pyodbc string
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
A PyODBC connection string can also be sent in pyodbc's format directly, as
specified in `the PyODBC documentation
<https://github.com/mkleehammer/pyodbc/wiki/Connecting-to-databases>`_,
using the parameter ``odbc_connect``. A :class:`_sa.engine.URL` object
can help make this easier::
from sqlalchemy.engine import URL
connection_string = "DRIVER={SQL Server Native Client 10.0};SERVER=dagger;DATABASE=test;UID=user;PWD=password"
connection_url = URL.create("mssql+pyodbc", query={"odbc_connect": connection_string})
engine = create_engine(connection_url)
.. _mssql_pyodbc_access_tokens:
Connecting to databases with access tokens
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
Some database servers are set up to only accept access tokens for login. For
example, SQL Server allows the use of Azure Active Directory tokens to connect
to databases. This requires creating a credential object using the
``azure-identity`` library. More information about the authentication step can be
found in `Microsoft's documentation
<https://docs.microsoft.com/en-us/azure/developer/python/azure-sdk-authenticate?tabs=bash>`_.
After getting an engine, the credentials need to be sent to ``pyodbc.connect``
each time a connection is requested. One way to do this is to set up an event
listener on the engine that adds the credential token to the dialect's connect
call. This is discussed more generally in :ref:`engines_dynamic_tokens`. For
SQL Server in particular, this is passed as an ODBC connection attribute with
a data structure `described by Microsoft
<https://docs.microsoft.com/en-us/sql/connect/odbc/using-azure-active-directory#authenticating-with-an-access-token>`_.
The following code snippet will create an engine that connects to an Azure SQL
database using Azure credentials::
import struct
from sqlalchemy import create_engine, event
from sqlalchemy.engine.url import URL
from azure import identity
SQL_COPT_SS_ACCESS_TOKEN = 1256 # Connection option for access tokens, as defined in msodbcsql.h
TOKEN_URL = "https://database.windows.net/" # The token URL for any Azure SQL database
connection_string = "mssql+pyodbc://@my-server.database.windows.net/myDb?driver=ODBC+Driver+17+for+SQL+Server"
engine = create_engine(connection_string)
azure_credentials = identity.DefaultAzureCredential()
@event.listens_for(engine, "do_connect")
def provide_token(dialect, conn_rec, cargs, cparams):
# remove the "Trusted_Connection" parameter that SQLAlchemy adds
cargs[0] = cargs[0].replace(";Trusted_Connection=Yes", "")
# create token credential
raw_token = azure_credentials.get_token(TOKEN_URL).token.encode("utf-16-le")
token_struct = struct.pack(f"<I{len(raw_token)}s", len(raw_token), raw_token)
# apply it to keyword arguments
cparams["attrs_before"] = {SQL_COPT_SS_ACCESS_TOKEN: token_struct}
.. tip::
The ``Trusted_Connection`` token is currently added by the SQLAlchemy
pyodbc dialect when no username or password is present. This needs
to be removed per Microsoft's
`documentation for Azure access tokens
<https://docs.microsoft.com/en-us/sql/connect/odbc/using-azure-active-directory#authenticating-with-an-access-token>`_,
stating that a connection string when using an access token must not contain
``UID``, ``PWD``, ``Authentication`` or ``Trusted_Connection`` parameters.
.. _azure_synapse_ignore_no_transaction_on_rollback:
Avoiding transaction-related exceptions on Azure Synapse Analytics
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
Azure Synapse Analytics has a significant difference in its transaction
handling compared to plain SQL Server; in some cases an error within a Synapse
transaction can cause it to be arbitrarily terminated on the server side, which
then causes the DBAPI ``.rollback()`` method (as well as ``.commit()``) to
fail. The issue prevents the usual DBAPI contract of allowing ``.rollback()``
to pass silently if no transaction is present, as the driver does not expect
this condition. The symptom of this failure is an exception with a message
resembling 'No corresponding transaction found. (111214)' when attempting to
emit a ``.rollback()`` after an operation had a failure of some kind.
This specific case can be handled by passing ``ignore_no_transaction_on_rollback=True`` to
the SQL Server dialect via the :func:`_sa.create_engine` function as follows::
engine = create_engine(connection_url, ignore_no_transaction_on_rollback=True)
Using the above parameter, the dialect will catch ``ProgrammingError``
exceptions raised during ``connection.rollback()`` and emit a warning
if the error message contains code ``111214``; however, it will not raise
an exception.
.. versionadded:: 1.4.40 Added the
``ignore_no_transaction_on_rollback=True`` parameter.
Enable autocommit for Azure SQL Data Warehouse (DW) connections
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
Azure SQL Data Warehouse does not support transactions,
and that can cause problems with SQLAlchemy's "autobegin" (and implicit
commit/rollback) behavior. We can avoid these problems by enabling autocommit
at both the pyodbc and engine levels::
connection_url = sa.engine.URL.create(
"mssql+pyodbc",
username="scott",
password="tiger",
host="dw.azure.example.com",
database="mydb",
query={
"driver": "ODBC Driver 17 for SQL Server",
"autocommit": "True",
},
)
engine = create_engine(connection_url).execution_options(
isolation_level="AUTOCOMMIT"
)
Avoiding sending large string parameters as TEXT/NTEXT
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
By default, for historical reasons, Microsoft's ODBC drivers for SQL Server
send long string parameters (greater than 4000 SBCS characters or 2000 Unicode
characters) as TEXT/NTEXT values. TEXT and NTEXT have been deprecated for many
years and are starting to cause compatibility issues with newer versions of
SQL Server/Azure. For example, see `this
issue <https://github.com/mkleehammer/pyodbc/issues/835>`_.
Starting with ODBC Driver 18 for SQL Server we can override the legacy
behavior and pass long strings as varchar(max)/nvarchar(max) using the
``LongAsMax=Yes`` connection string parameter::
connection_url = sa.engine.URL.create(
"mssql+pyodbc",
username="scott",
password="tiger",
host="mssqlserver.example.com",
database="mydb",
query={
"driver": "ODBC Driver 18 for SQL Server",
"LongAsMax": "Yes",
},
)
Pyodbc Pooling / connection close behavior
------------------------------------------
PyODBC uses internal `pooling
<https://github.com/mkleehammer/pyodbc/wiki/The-pyodbc-Module#pooling>`_ by
default, which means connections will be longer lived than they are within
SQLAlchemy itself. As SQLAlchemy has its own pooling behavior, it is often
preferable to disable this behavior. This behavior can only be disabled
globally at the PyODBC module level, **before** any connections are made::
import pyodbc
pyodbc.pooling = False
# don't use the engine before pooling is set to False
engine = create_engine("mssql+pyodbc://user:pass@dsn")
If this variable is left at its default value of ``True``, **the application
will continue to maintain active database connections**, even when the
SQLAlchemy engine itself fully discards a connection or if the engine is
disposed.
.. seealso::
`pooling <https://github.com/mkleehammer/pyodbc/wiki/The-pyodbc-Module#pooling>`_ -
in the PyODBC documentation.
Driver / Unicode Support
-------------------------
PyODBC works best with Microsoft ODBC drivers, particularly in the area
of Unicode support on both Python 2 and Python 3.
Using the FreeTDS ODBC drivers on Linux or OSX with PyODBC is **not**
recommended; there have historically been many Unicode-related issues
in this area, including before Microsoft offered ODBC drivers for Linux
and OSX. Now that Microsoft offers drivers for all platforms, they are the
recommended choice for PyODBC support. FreeTDS remains relevant for
non-ODBC drivers such as pymssql, where it works very well.
Rowcount Support
----------------
Previous limitations with the SQLAlchemy ORM's "versioned rows" feature with
Pyodbc have been resolved as of SQLAlchemy 2.0.5. See the notes at
:ref:`mssql_rowcount_versioning`.
.. _mssql_pyodbc_fastexecutemany:
Fast Executemany Mode
---------------------
The PyODBC driver includes support for a "fast executemany" mode of execution
which greatly reduces round trips for a DBAPI ``executemany()`` call when using
Microsoft ODBC drivers, for **limited size batches that fit in memory**. The
feature is enabled by setting the attribute ``.fast_executemany`` on the DBAPI
cursor when an executemany call is to be used. The SQLAlchemy PyODBC SQL
Server dialect supports this parameter by passing the
``fast_executemany`` parameter to
:func:`_sa.create_engine`, when using the **Microsoft ODBC driver only**::
engine = create_engine(
"mssql+pyodbc://scott:tiger@mssql2017:1433/test?driver=ODBC+Driver+17+for+SQL+Server",
fast_executemany=True)
.. versionchanged:: 2.0.9 - the ``fast_executemany`` parameter now has its
intended effect of this PyODBC feature taking effect for all INSERT
statements that are executed with multiple parameter sets, which don't
include RETURNING. Previously, SQLAlchemy 2.0's :term:`insertmanyvalues`
feature would cause ``fast_executemany`` to not be used in most cases
even if specified.
.. versionadded:: 1.3
.. seealso::
`fast executemany <https://github.com/mkleehammer/pyodbc/wiki/Features-beyond-the-DB-API#fast_executemany>`_
- on github
.. _mssql_pyodbc_setinputsizes:
Setinputsizes Support
-----------------------
As of version 2.0, the pyodbc ``cursor.setinputsizes()`` method is used for
all statement executions, except for ``cursor.executemany()`` calls when
``fast_executemany=True``, where it is not supported (assuming
:ref:`insertmanyvalues <engine_insertmanyvalues>` is kept enabled,
"fast executemany" will not take place for INSERT statements in any case).
The use of ``cursor.setinputsizes()`` can be disabled by passing
``use_setinputsizes=False`` to :func:`_sa.create_engine`.
When ``use_setinputsizes`` is left at its default of ``True``, the
specific per-type symbols passed to ``cursor.setinputsizes()`` can be
programmatically customized using the :meth:`.DialectEvents.do_setinputsizes`
hook. See that method for usage examples.
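As a rough sketch (hedged; the handler below only inspects the mapping and is
not part of this dialect), such an event hook might look like::

    from sqlalchemy import create_engine, event

    engine = create_engine("mssql+pyodbc://scott:tiger@some_dsn")

    @event.listens_for(engine, "do_setinputsizes")
    def _inspect_setinputsizes(
        inputsizes, cursor, statement, parameters, context
    ):
        # "inputsizes" maps BindParameter objects to DBAPI type symbols;
        # entries may be modified or removed here before setinputsizes()
        # is invoked on the cursor.
        pass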
.. versionchanged:: 2.0 The mssql+pyodbc dialect now defaults to using
``use_setinputsizes=True`` for all statement executions with the exception of
cursor.executemany() calls when fast_executemany=True. The behavior can
be turned off by passing ``use_setinputsizes=False`` to
:func:`_sa.create_engine`.
""" # noqa
import datetime
import decimal
import re
import struct
from .base import _MSDateTime
from .base import _MSUnicode
from .base import _MSUnicodeText
from .base import BINARY
from .base import DATETIMEOFFSET
from .base import MSDialect
from .base import MSExecutionContext
from .base import VARBINARY
from .json import JSON as _MSJson
from .json import JSONIndexType as _MSJsonIndexType
from .json import JSONPathType as _MSJsonPathType
from ... import exc
from ... import types as sqltypes
from ... import util
from ...connectors.pyodbc import PyODBCConnector
from ...engine import cursor as _cursor
class _ms_numeric_pyodbc:
"""Turns Decimals with adjusted() < 0 or > 7 into strings.
The routines here are needed for older pyodbc versions
as well as current mxODBC versions.
"""
def bind_processor(self, dialect):
super_process = super().bind_processor(dialect)
if not dialect._need_decimal_fix:
return super_process
def process(value):
if self.asdecimal and isinstance(value, decimal.Decimal):
adjusted = value.adjusted()
if adjusted < 0:
return self._small_dec_to_string(value)
elif adjusted > 7:
return self._large_dec_to_string(value)
if super_process:
return super_process(value)
else:
return value
return process
# these routines needed for older versions of pyodbc.
# as of 2.1.8 this logic is integrated.
def _small_dec_to_string(self, value):
return "%s0.%s%s" % (
(value < 0 and "-" or ""),
"0" * (abs(value.adjusted()) - 1),
"".join([str(nint) for nint in value.as_tuple()[1]]),
)
def _large_dec_to_string(self, value):
_int = value.as_tuple()[1]
if "E" in str(value):
result = "%s%s%s" % (
(value < 0 and "-" or ""),
"".join([str(s) for s in _int]),
"0" * (value.adjusted() - (len(_int) - 1)),
)
else:
if (len(_int) - 1) > value.adjusted():
result = "%s%s.%s" % (
(value < 0 and "-" or ""),
"".join([str(s) for s in _int][0 : value.adjusted() + 1]),
"".join([str(s) for s in _int][value.adjusted() + 1 :]),
)
else:
result = "%s%s" % (
(value < 0 and "-" or ""),
"".join([str(s) for s in _int][0 : value.adjusted() + 1]),
)
return result
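# Illustrative examples, not part of the original module: with the helpers
# above, Decimal("1.23E-8") (adjusted() == -8) is converted by
# _small_dec_to_string() to "0.0000000123", while Decimal("1.23E+10")
# (adjusted() == 10) is converted by _large_dec_to_string() to "12300000000".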
class _MSNumeric_pyodbc(_ms_numeric_pyodbc, sqltypes.Numeric):
pass
class _MSFloat_pyodbc(_ms_numeric_pyodbc, sqltypes.Float):
pass
class _ms_binary_pyodbc:
"""Wraps binary values in dialect-specific Binary wrapper.
If the value is null, return a pyodbc-specific BinaryNull
object to prevent pyODBC [and FreeTDS] from defaulting binary
NULL types to SQLWCHAR and causing implicit conversion errors.
"""
def bind_processor(self, dialect):
if dialect.dbapi is None:
return None
DBAPIBinary = dialect.dbapi.Binary
def process(value):
if value is not None:
return DBAPIBinary(value)
else:
# pyodbc-specific
return dialect.dbapi.BinaryNull
return process
class _ODBCDateTimeBindProcessor:
"""Add bind processors to handle datetimeoffset behaviors"""
has_tz = False
def bind_processor(self, dialect):
def process(value):
if value is None:
return None
elif isinstance(value, str):
# if a string was passed directly, allow it through
return value
elif not value.tzinfo or (not self.timezone and not self.has_tz):
# for DateTime(timezone=False)
return value
else:
# for DATETIMEOFFSET or DateTime(timezone=True)
#
# Convert to string format required by T-SQL
dto_string = value.strftime("%Y-%m-%d %H:%M:%S.%f %z")
# offset needs a colon, e.g., -0700 -> -07:00
# "UTC offset in the form (+-)HHMM[SS[.ffffff]]"
# backend currently rejects seconds / fractional seconds
dto_string = re.sub(
r"([\+\-]\d{2})([\d\.]+)$", r"\1:\2", dto_string
)
return dto_string
return process
class _ODBCDateTime(_ODBCDateTimeBindProcessor, _MSDateTime):
pass
class _ODBCDATETIMEOFFSET(_ODBCDateTimeBindProcessor, DATETIMEOFFSET):
has_tz = True
class _VARBINARY_pyodbc(_ms_binary_pyodbc, VARBINARY):
pass
class _BINARY_pyodbc(_ms_binary_pyodbc, BINARY):
pass
class _String_pyodbc(sqltypes.String):
def get_dbapi_type(self, dbapi):
if self.length in (None, "max") or self.length >= 2000:
return (dbapi.SQL_VARCHAR, 0, 0)
else:
return dbapi.SQL_VARCHAR
class _Unicode_pyodbc(_MSUnicode):
def get_dbapi_type(self, dbapi):
if self.length in (None, "max") or self.length >= 2000:
return (dbapi.SQL_WVARCHAR, 0, 0)
else:
return dbapi.SQL_WVARCHAR
class _UnicodeText_pyodbc(_MSUnicodeText):
def get_dbapi_type(self, dbapi):
if self.length in (None, "max") or self.length >= 2000:
return (dbapi.SQL_WVARCHAR, 0, 0)
else:
return dbapi.SQL_WVARCHAR
class _JSON_pyodbc(_MSJson):
def get_dbapi_type(self, dbapi):
return (dbapi.SQL_WVARCHAR, 0, 0)
class _JSONIndexType_pyodbc(_MSJsonIndexType):
def get_dbapi_type(self, dbapi):
return dbapi.SQL_WVARCHAR
class _JSONPathType_pyodbc(_MSJsonPathType):
def get_dbapi_type(self, dbapi):
return dbapi.SQL_WVARCHAR
class MSExecutionContext_pyodbc(MSExecutionContext):
_embedded_scope_identity = False
def pre_exec(self):
"""where appropriate, issue "select scope_identity()" in the same
statement.
Background on why "scope_identity()" is preferable to "@@identity":
https://msdn.microsoft.com/en-us/library/ms190315.aspx
Background on why we attempt to embed "scope_identity()" into the same
statement as the INSERT:
https://code.google.com/p/pyodbc/wiki/FAQs#How_do_I_retrieve_autogenerated/identity_values?
"""
super().pre_exec()
# don't embed the scope_identity select into an
# "INSERT .. DEFAULT VALUES"
if (
self._select_lastrowid
and self.dialect.use_scope_identity
and len(self.parameters[0])
):
self._embedded_scope_identity = True
self.statement += "; select scope_identity()"
def post_exec(self):
if self._embedded_scope_identity:
# Fetch the last inserted id from the manipulated statement
# We may have to skip over a number of result sets with
# no data (due to triggers, etc.)
while True:
try:
# fetchall() ensures the cursor is consumed
# without closing it (FreeTDS particularly)
rows = self.cursor.fetchall()
except self.dialect.dbapi.Error:
# no way around this - nextset() consumes the previous set
# so we need to just keep flipping
self.cursor.nextset()
else:
if not rows:
# async adapter drivers just return None here
self.cursor.nextset()
continue
row = rows[0]
break
self._lastrowid = int(row[0])
self.cursor_fetch_strategy = _cursor._NO_CURSOR_DML
else:
super().post_exec()
class MSDialect_pyodbc(PyODBCConnector, MSDialect):
supports_statement_cache = True
# note this parameter is no longer used by the ORM or default dialect
# see #9414
supports_sane_rowcount_returning = False
execution_ctx_cls = MSExecutionContext_pyodbc
colspecs = util.update_copy(
MSDialect.colspecs,
{
sqltypes.Numeric: _MSNumeric_pyodbc,
sqltypes.Float: _MSFloat_pyodbc,
BINARY: _BINARY_pyodbc,
# support DateTime(timezone=True)
sqltypes.DateTime: _ODBCDateTime,
DATETIMEOFFSET: _ODBCDATETIMEOFFSET,
# SQL Server dialect has a VARBINARY that is just to support
# "deprecate_large_types" w/ VARBINARY(max), but also we must
# handle the usual SQL standard VARBINARY
VARBINARY: _VARBINARY_pyodbc,
sqltypes.VARBINARY: _VARBINARY_pyodbc,
sqltypes.LargeBinary: _VARBINARY_pyodbc,
sqltypes.String: _String_pyodbc,
sqltypes.Unicode: _Unicode_pyodbc,
sqltypes.UnicodeText: _UnicodeText_pyodbc,
sqltypes.JSON: _JSON_pyodbc,
sqltypes.JSON.JSONIndexType: _JSONIndexType_pyodbc,
sqltypes.JSON.JSONPathType: _JSONPathType_pyodbc,
# this excludes Enum from the string/VARCHAR thing for now
# it looks like Enum's adaptation doesn't really support the
# String type itself having a dialect-level impl
sqltypes.Enum: sqltypes.Enum,
},
)
def __init__(
self,
fast_executemany=False,
use_setinputsizes=True,
**params,
):
super().__init__(use_setinputsizes=use_setinputsizes, **params)
self.use_scope_identity = (
self.use_scope_identity
and self.dbapi
and hasattr(self.dbapi.Cursor, "nextset")
)
self._need_decimal_fix = self.dbapi and self._dbapi_version() < (
2,
1,
8,
)
self.fast_executemany = fast_executemany
if fast_executemany:
self.use_insertmanyvalues_wo_returning = False
def _get_server_version_info(self, connection):
try:
# "Version of the instance of SQL Server, in the form
# of 'major.minor.build.revision'"
raw = connection.exec_driver_sql(
"SELECT CAST(SERVERPROPERTY('ProductVersion') AS VARCHAR)"
).scalar()
except exc.DBAPIError:
# SQL Server docs indicate this function isn't present prior to
# 2008. Before we had the VARCHAR cast above, pyodbc would also
# fail on this query.
return super()._get_server_version_info(connection)
else:
version = []
r = re.compile(r"[.\-]")
for n in r.split(raw):
try:
version.append(int(n))
except ValueError:
pass
return tuple(version)
def on_connect(self):
super_ = super().on_connect()
def on_connect(conn):
if super_ is not None:
super_(conn)
self._setup_timestampoffset_type(conn)
return on_connect
def _setup_timestampoffset_type(self, connection):
# output converter function for datetimeoffset
def _handle_datetimeoffset(dto_value):
tup = struct.unpack("<6hI2h", dto_value)
return datetime.datetime(
tup[0],
tup[1],
tup[2],
tup[3],
tup[4],
tup[5],
tup[6] // 1000,
datetime.timezone(
datetime.timedelta(hours=tup[7], minutes=tup[8])
),
)
odbc_SQL_SS_TIMESTAMPOFFSET = -155 # as defined in SQLNCLI.h
connection.add_output_converter(
odbc_SQL_SS_TIMESTAMPOFFSET, _handle_datetimeoffset
)
def do_executemany(self, cursor, statement, parameters, context=None):
if self.fast_executemany:
cursor.fast_executemany = True
super().do_executemany(cursor, statement, parameters, context=context)
def is_disconnect(self, e, connection, cursor):
if isinstance(e, self.dbapi.Error):
code = e.args[0]
if code in {
"08S01",
"01000",
"01002",
"08003",
"08007",
"08S02",
"08001",
"HYT00",
"HY010",
"10054",
}:
return True
return super().is_disconnect(e, connection, cursor)
dialect = MSDialect_pyodbc

View file

@ -0,0 +1,101 @@
# dialects/mysql/__init__.py
# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
# mypy: ignore-errors
from . import aiomysql # noqa
from . import asyncmy # noqa
from . import base # noqa
from . import cymysql # noqa
from . import mariadbconnector # noqa
from . import mysqlconnector # noqa
from . import mysqldb # noqa
from . import pymysql # noqa
from . import pyodbc # noqa
from .base import BIGINT
from .base import BINARY
from .base import BIT
from .base import BLOB
from .base import BOOLEAN
from .base import CHAR
from .base import DATE
from .base import DATETIME
from .base import DECIMAL
from .base import DOUBLE
from .base import ENUM
from .base import FLOAT
from .base import INTEGER
from .base import JSON
from .base import LONGBLOB
from .base import LONGTEXT
from .base import MEDIUMBLOB
from .base import MEDIUMINT
from .base import MEDIUMTEXT
from .base import NCHAR
from .base import NUMERIC
from .base import NVARCHAR
from .base import REAL
from .base import SET
from .base import SMALLINT
from .base import TEXT
from .base import TIME
from .base import TIMESTAMP
from .base import TINYBLOB
from .base import TINYINT
from .base import TINYTEXT
from .base import VARBINARY
from .base import VARCHAR
from .base import YEAR
from .dml import Insert
from .dml import insert
from .expression import match
from ...util import compat
# default dialect
base.dialect = dialect = mysqldb.dialect
__all__ = (
"BIGINT",
"BINARY",
"BIT",
"BLOB",
"BOOLEAN",
"CHAR",
"DATE",
"DATETIME",
"DECIMAL",
"DOUBLE",
"ENUM",
"FLOAT",
"INTEGER",
"INTEGER",
"JSON",
"LONGBLOB",
"LONGTEXT",
"MEDIUMBLOB",
"MEDIUMINT",
"MEDIUMTEXT",
"NCHAR",
"NVARCHAR",
"NUMERIC",
"SET",
"SMALLINT",
"REAL",
"TEXT",
"TIME",
"TIMESTAMP",
"TINYBLOB",
"TINYINT",
"TINYTEXT",
"VARBINARY",
"VARCHAR",
"YEAR",
"dialect",
"insert",
"Insert",
"match",
)

View file

@ -0,0 +1,332 @@
# dialects/mysql/aiomysql.py
# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors <see AUTHORS
# file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
# mypy: ignore-errors
r"""
.. dialect:: mysql+aiomysql
:name: aiomysql
:dbapi: aiomysql
:connectstring: mysql+aiomysql://user:password@host:port/dbname[?key=value&key=value...]
:url: https://github.com/aio-libs/aiomysql
The aiomysql dialect is SQLAlchemy's second Python asyncio dialect.
Using a special asyncio mediation layer, the aiomysql dialect is usable
as the backend for the :ref:`SQLAlchemy asyncio <asyncio_toplevel>`
extension package.
This dialect should normally be used only with the
:func:`_asyncio.create_async_engine` engine creation function::
from sqlalchemy.ext.asyncio import create_async_engine
engine = create_async_engine("mysql+aiomysql://user:pass@hostname/dbname?charset=utf8mb4")
""" # noqa
from .pymysql import MySQLDialect_pymysql
from ... import pool
from ... import util
from ...engine import AdaptedConnection
from ...util.concurrency import asyncio
from ...util.concurrency import await_fallback
from ...util.concurrency import await_only
class AsyncAdapt_aiomysql_cursor:
# TODO: base on connectors/asyncio.py
# see #10415
server_side = False
__slots__ = (
"_adapt_connection",
"_connection",
"await_",
"_cursor",
"_rows",
)
def __init__(self, adapt_connection):
self._adapt_connection = adapt_connection
self._connection = adapt_connection._connection
self.await_ = adapt_connection.await_
cursor = self._connection.cursor(adapt_connection.dbapi.Cursor)
# see https://github.com/aio-libs/aiomysql/issues/543
self._cursor = self.await_(cursor.__aenter__())
self._rows = []
@property
def description(self):
return self._cursor.description
@property
def rowcount(self):
return self._cursor.rowcount
@property
def arraysize(self):
return self._cursor.arraysize
@arraysize.setter
def arraysize(self, value):
self._cursor.arraysize = value
@property
def lastrowid(self):
return self._cursor.lastrowid
def close(self):
# note we aren't actually closing the cursor here,
# we are just letting GC do it. to allow this to be async
# we would need the Result to change how it does "Safe close cursor".
# MySQL "cursors" don't actually have state to be "closed" besides
# exhausting rows, which we already have done for sync cursor.
# another option would be to emulate aiosqlite dialect and assign
# cursor only if we are doing server side cursor operation.
self._rows[:] = []
def execute(self, operation, parameters=None):
return self.await_(self._execute_async(operation, parameters))
def executemany(self, operation, seq_of_parameters):
return self.await_(
self._executemany_async(operation, seq_of_parameters)
)
async def _execute_async(self, operation, parameters):
async with self._adapt_connection._execute_mutex:
result = await self._cursor.execute(operation, parameters)
if not self.server_side:
# aiomysql has a "fake" async result, so we have to pull it out
# of that here since our default result is not async.
# we could just as easily grab "_rows" here and be done with it
# but this is safer.
self._rows = list(await self._cursor.fetchall())
return result
async def _executemany_async(self, operation, seq_of_parameters):
async with self._adapt_connection._execute_mutex:
return await self._cursor.executemany(operation, seq_of_parameters)
def setinputsizes(self, *inputsizes):
pass
def __iter__(self):
while self._rows:
yield self._rows.pop(0)
def fetchone(self):
if self._rows:
return self._rows.pop(0)
else:
return None
def fetchmany(self, size=None):
if size is None:
size = self.arraysize
retval = self._rows[0:size]
self._rows[:] = self._rows[size:]
return retval
def fetchall(self):
retval = self._rows[:]
self._rows[:] = []
return retval
class AsyncAdapt_aiomysql_ss_cursor(AsyncAdapt_aiomysql_cursor):
# TODO: base on connectors/asyncio.py
# see #10415
__slots__ = ()
server_side = True
def __init__(self, adapt_connection):
self._adapt_connection = adapt_connection
self._connection = adapt_connection._connection
self.await_ = adapt_connection.await_
cursor = self._connection.cursor(adapt_connection.dbapi.SSCursor)
self._cursor = self.await_(cursor.__aenter__())
def close(self):
if self._cursor is not None:
self.await_(self._cursor.close())
self._cursor = None
def fetchone(self):
return self.await_(self._cursor.fetchone())
def fetchmany(self, size=None):
return self.await_(self._cursor.fetchmany(size=size))
def fetchall(self):
return self.await_(self._cursor.fetchall())
class AsyncAdapt_aiomysql_connection(AdaptedConnection):
# TODO: base on connectors/asyncio.py
# see #10415
await_ = staticmethod(await_only)
__slots__ = ("dbapi", "_execute_mutex")
def __init__(self, dbapi, connection):
self.dbapi = dbapi
self._connection = connection
self._execute_mutex = asyncio.Lock()
def ping(self, reconnect):
return self.await_(self._connection.ping(reconnect))
def character_set_name(self):
return self._connection.character_set_name()
def autocommit(self, value):
self.await_(self._connection.autocommit(value))
def cursor(self, server_side=False):
if server_side:
return AsyncAdapt_aiomysql_ss_cursor(self)
else:
return AsyncAdapt_aiomysql_cursor(self)
def rollback(self):
self.await_(self._connection.rollback())
def commit(self):
self.await_(self._connection.commit())
def terminate(self):
# it's not awaitable.
self._connection.close()
def close(self) -> None:
self.await_(self._connection.ensure_closed())
class AsyncAdaptFallback_aiomysql_connection(AsyncAdapt_aiomysql_connection):
# TODO: base on connectors/asyncio.py
# see #10415
__slots__ = ()
await_ = staticmethod(await_fallback)
class AsyncAdapt_aiomysql_dbapi:
def __init__(self, aiomysql, pymysql):
self.aiomysql = aiomysql
self.pymysql = pymysql
self.paramstyle = "format"
self._init_dbapi_attributes()
self.Cursor, self.SSCursor = self._init_cursors_subclasses()
def _init_dbapi_attributes(self):
for name in (
"Warning",
"Error",
"InterfaceError",
"DataError",
"DatabaseError",
"OperationalError",
"InterfaceError",
"IntegrityError",
"ProgrammingError",
"InternalError",
"NotSupportedError",
):
setattr(self, name, getattr(self.aiomysql, name))
for name in (
"NUMBER",
"STRING",
"DATETIME",
"BINARY",
"TIMESTAMP",
"Binary",
):
setattr(self, name, getattr(self.pymysql, name))
def connect(self, *arg, **kw):
async_fallback = kw.pop("async_fallback", False)
creator_fn = kw.pop("async_creator_fn", self.aiomysql.connect)
if util.asbool(async_fallback):
return AsyncAdaptFallback_aiomysql_connection(
self,
await_fallback(creator_fn(*arg, **kw)),
)
else:
return AsyncAdapt_aiomysql_connection(
self,
await_only(creator_fn(*arg, **kw)),
)
def _init_cursors_subclasses(self):
# suppress unconditional warning emitted by aiomysql
class Cursor(self.aiomysql.Cursor):
async def _show_warnings(self, conn):
pass
class SSCursor(self.aiomysql.SSCursor):
async def _show_warnings(self, conn):
pass
return Cursor, SSCursor
class MySQLDialect_aiomysql(MySQLDialect_pymysql):
driver = "aiomysql"
supports_statement_cache = True
supports_server_side_cursors = True
_sscursor = AsyncAdapt_aiomysql_ss_cursor
is_async = True
has_terminate = True
@classmethod
def import_dbapi(cls):
return AsyncAdapt_aiomysql_dbapi(
__import__("aiomysql"), __import__("pymysql")
)
@classmethod
def get_pool_class(cls, url):
async_fallback = url.query.get("async_fallback", False)
if util.asbool(async_fallback):
return pool.FallbackAsyncAdaptedQueuePool
else:
return pool.AsyncAdaptedQueuePool
def do_terminate(self, dbapi_connection) -> None:
dbapi_connection.terminate()
def create_connect_args(self, url):
return super().create_connect_args(
url, _translate_args=dict(username="user", database="db")
)
def is_disconnect(self, e, connection, cursor):
if super().is_disconnect(e, connection, cursor):
return True
else:
str_e = str(e).lower()
return "not connected" in str_e
def _found_rows_client_flag(self):
from pymysql.constants import CLIENT
return CLIENT.FOUND_ROWS
def get_driver_connection(self, connection):
return connection._connection
dialect = MySQLDialect_aiomysql

View file

@ -0,0 +1,337 @@
# dialects/mysql/asyncmy.py
# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors <see AUTHORS
# file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
# mypy: ignore-errors
r"""
.. dialect:: mysql+asyncmy
:name: asyncmy
:dbapi: asyncmy
:connectstring: mysql+asyncmy://user:password@host:port/dbname[?key=value&key=value...]
:url: https://github.com/long2ice/asyncmy
Using a special asyncio mediation layer, the asyncmy dialect is usable
as the backend for the :ref:`SQLAlchemy asyncio <asyncio_toplevel>`
extension package.
This dialect should normally be used only with the
:func:`_asyncio.create_async_engine` engine creation function::
from sqlalchemy.ext.asyncio import create_async_engine
engine = create_async_engine("mysql+asyncmy://user:pass@hostname/dbname?charset=utf8mb4")
""" # noqa
from contextlib import asynccontextmanager
from .pymysql import MySQLDialect_pymysql
from ... import pool
from ... import util
from ...engine import AdaptedConnection
from ...util.concurrency import asyncio
from ...util.concurrency import await_fallback
from ...util.concurrency import await_only
class AsyncAdapt_asyncmy_cursor:
# TODO: base on connectors/asyncio.py
# see #10415
server_side = False
__slots__ = (
"_adapt_connection",
"_connection",
"await_",
"_cursor",
"_rows",
)
def __init__(self, adapt_connection):
self._adapt_connection = adapt_connection
self._connection = adapt_connection._connection
self.await_ = adapt_connection.await_
cursor = self._connection.cursor()
self._cursor = self.await_(cursor.__aenter__())
self._rows = []
@property
def description(self):
return self._cursor.description
@property
def rowcount(self):
return self._cursor.rowcount
@property
def arraysize(self):
return self._cursor.arraysize
@arraysize.setter
def arraysize(self, value):
self._cursor.arraysize = value
@property
def lastrowid(self):
return self._cursor.lastrowid
def close(self):
# note we aren't actually closing the cursor here,
# we are just letting GC do it. to allow this to be async
# we would need the Result to change how it does "Safe close cursor".
# MySQL "cursors" don't actually have state to be "closed" besides
# exhausting rows, which we already have done for sync cursor.
# another option would be to emulate aiosqlite dialect and assign
# cursor only if we are doing server side cursor operation.
self._rows[:] = []
def execute(self, operation, parameters=None):
return self.await_(self._execute_async(operation, parameters))
def executemany(self, operation, seq_of_parameters):
return self.await_(
self._executemany_async(operation, seq_of_parameters)
)
async def _execute_async(self, operation, parameters):
async with self._adapt_connection._mutex_and_adapt_errors():
if parameters is None:
result = await self._cursor.execute(operation)
else:
result = await self._cursor.execute(operation, parameters)
if not self.server_side:
# asyncmy has a "fake" async result, so we have to pull it out
# of that here since our default result is not async.
# we could just as easily grab "_rows" here and be done with it
# but this is safer.
self._rows = list(await self._cursor.fetchall())
return result
async def _executemany_async(self, operation, seq_of_parameters):
async with self._adapt_connection._mutex_and_adapt_errors():
return await self._cursor.executemany(operation, seq_of_parameters)
def setinputsizes(self, *inputsizes):
pass
def __iter__(self):
while self._rows:
yield self._rows.pop(0)
def fetchone(self):
if self._rows:
return self._rows.pop(0)
else:
return None
def fetchmany(self, size=None):
if size is None:
size = self.arraysize
retval = self._rows[0:size]
self._rows[:] = self._rows[size:]
return retval
def fetchall(self):
retval = self._rows[:]
self._rows[:] = []
return retval
class AsyncAdapt_asyncmy_ss_cursor(AsyncAdapt_asyncmy_cursor):
# TODO: base on connectors/asyncio.py
# see #10415
__slots__ = ()
server_side = True
def __init__(self, adapt_connection):
self._adapt_connection = adapt_connection
self._connection = adapt_connection._connection
self.await_ = adapt_connection.await_
cursor = self._connection.cursor(
adapt_connection.dbapi.asyncmy.cursors.SSCursor
)
self._cursor = self.await_(cursor.__aenter__())
def close(self):
if self._cursor is not None:
self.await_(self._cursor.close())
self._cursor = None
def fetchone(self):
return self.await_(self._cursor.fetchone())
def fetchmany(self, size=None):
return self.await_(self._cursor.fetchmany(size=size))
def fetchall(self):
return self.await_(self._cursor.fetchall())
class AsyncAdapt_asyncmy_connection(AdaptedConnection):
# TODO: base on connectors/asyncio.py
# see #10415
await_ = staticmethod(await_only)
__slots__ = ("dbapi", "_execute_mutex")
def __init__(self, dbapi, connection):
self.dbapi = dbapi
self._connection = connection
self._execute_mutex = asyncio.Lock()
@asynccontextmanager
async def _mutex_and_adapt_errors(self):
async with self._execute_mutex:
try:
yield
except AttributeError:
raise self.dbapi.InternalError(
"network operation failed due to asyncmy attribute error"
)
def ping(self, reconnect):
assert not reconnect
return self.await_(self._do_ping())
async def _do_ping(self):
async with self._mutex_and_adapt_errors():
return await self._connection.ping(False)
def character_set_name(self):
return self._connection.character_set_name()
def autocommit(self, value):
self.await_(self._connection.autocommit(value))
def cursor(self, server_side=False):
if server_side:
return AsyncAdapt_asyncmy_ss_cursor(self)
else:
return AsyncAdapt_asyncmy_cursor(self)
def rollback(self):
self.await_(self._connection.rollback())
def commit(self):
self.await_(self._connection.commit())
def terminate(self):
# it's not awaitable.
self._connection.close()
def close(self) -> None:
self.await_(self._connection.ensure_closed())
class AsyncAdaptFallback_asyncmy_connection(AsyncAdapt_asyncmy_connection):
__slots__ = ()
await_ = staticmethod(await_fallback)
def _Binary(x):
"""Return x as a binary type."""
return bytes(x)
class AsyncAdapt_asyncmy_dbapi:
def __init__(self, asyncmy):
self.asyncmy = asyncmy
self.paramstyle = "format"
self._init_dbapi_attributes()
def _init_dbapi_attributes(self):
for name in (
"Warning",
"Error",
"InterfaceError",
"DataError",
"DatabaseError",
"OperationalError",
"InterfaceError",
"IntegrityError",
"ProgrammingError",
"InternalError",
"NotSupportedError",
):
setattr(self, name, getattr(self.asyncmy.errors, name))
STRING = util.symbol("STRING")
NUMBER = util.symbol("NUMBER")
BINARY = util.symbol("BINARY")
DATETIME = util.symbol("DATETIME")
TIMESTAMP = util.symbol("TIMESTAMP")
Binary = staticmethod(_Binary)
def connect(self, *arg, **kw):
async_fallback = kw.pop("async_fallback", False)
creator_fn = kw.pop("async_creator_fn", self.asyncmy.connect)
if util.asbool(async_fallback):
return AsyncAdaptFallback_asyncmy_connection(
self,
await_fallback(creator_fn(*arg, **kw)),
)
else:
return AsyncAdapt_asyncmy_connection(
self,
await_only(creator_fn(*arg, **kw)),
)
class MySQLDialect_asyncmy(MySQLDialect_pymysql):
driver = "asyncmy"
supports_statement_cache = True
supports_server_side_cursors = True
_sscursor = AsyncAdapt_asyncmy_ss_cursor
is_async = True
has_terminate = True
@classmethod
def import_dbapi(cls):
return AsyncAdapt_asyncmy_dbapi(__import__("asyncmy"))
@classmethod
def get_pool_class(cls, url):
async_fallback = url.query.get("async_fallback", False)
if util.asbool(async_fallback):
return pool.FallbackAsyncAdaptedQueuePool
else:
return pool.AsyncAdaptedQueuePool
def do_terminate(self, dbapi_connection) -> None:
dbapi_connection.terminate()
def create_connect_args(self, url):
return super().create_connect_args(
url, _translate_args=dict(username="user", database="db")
)
def is_disconnect(self, e, connection, cursor):
if super().is_disconnect(e, connection, cursor):
return True
else:
str_e = str(e).lower()
return (
"not connected" in str_e or "network operation failed" in str_e
)
def _found_rows_client_flag(self):
from asyncmy.constants import CLIENT
return CLIENT.FOUND_ROWS
def get_driver_connection(self, connection):
return connection._connection
dialect = MySQLDialect_asyncmy
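# Illustrative usage sketch (editor's addition, not part of this module): the
# asyncmy dialect above is selected through the standard asyncio engine API;
# the URL credentials below are placeholders.
from sqlalchemy.ext.asyncio import create_async_engine

async_engine = create_async_engine("mysql+asyncmy://scott:tiger@localhost/test")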

File diff suppressed because it is too large

View file

@ -0,0 +1,84 @@
# dialects/mysql/cymysql.py
# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
# mypy: ignore-errors
r"""
.. dialect:: mysql+cymysql
:name: CyMySQL
:dbapi: cymysql
:connectstring: mysql+cymysql://<username>:<password>@<host>/<dbname>[?<options>]
:url: https://github.com/nakagami/CyMySQL
.. note::
The CyMySQL dialect is **not tested as part of SQLAlchemy's continuous
integration** and may have unresolved issues. The recommended MySQL
dialects are mysqlclient and PyMySQL.
""" # noqa
from .base import BIT
from .base import MySQLDialect
from .mysqldb import MySQLDialect_mysqldb
from ... import util
class _cymysqlBIT(BIT):
def result_processor(self, dialect, coltype):
"""Convert MySQL's 64 bit, variable length binary string to a long."""
def process(value):
if value is not None:
v = 0
for i in iter(value):
v = v << 8 | i
return v
return value
return process
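# Worked example (editor's addition, not part of this module): the processor
# above folds each byte of the BIT value into an integer, most significant byte
# first, e.g. b"\x02\x0f" becomes (2 << 8) | 15 == 527.
_example_bit_processor = _cymysqlBIT().result_processor(None, None)
assert _example_bit_processor(bytes([2, 15])) == 527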
class MySQLDialect_cymysql(MySQLDialect_mysqldb):
driver = "cymysql"
supports_statement_cache = True
description_encoding = None
supports_sane_rowcount = True
supports_sane_multi_rowcount = False
supports_unicode_statements = True
colspecs = util.update_copy(MySQLDialect.colspecs, {BIT: _cymysqlBIT})
@classmethod
def import_dbapi(cls):
return __import__("cymysql")
def _detect_charset(self, connection):
return connection.connection.charset
def _extract_error_code(self, exception):
return exception.errno
def is_disconnect(self, e, connection, cursor):
if isinstance(e, self.dbapi.OperationalError):
return self._extract_error_code(e) in (
2006,
2013,
2014,
2045,
2055,
)
elif isinstance(e, self.dbapi.InterfaceError):
# if underlying connection is closed,
# this is the error you get
return True
else:
return False
dialect = MySQLDialect_cymysql

View file

@ -0,0 +1,219 @@
# dialects/mysql/dml.py
# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
from __future__ import annotations
from typing import Any
from typing import List
from typing import Mapping
from typing import Optional
from typing import Tuple
from typing import Union
from ... import exc
from ... import util
from ...sql._typing import _DMLTableArgument
from ...sql.base import _exclusive_against
from ...sql.base import _generative
from ...sql.base import ColumnCollection
from ...sql.base import ReadOnlyColumnCollection
from ...sql.dml import Insert as StandardInsert
from ...sql.elements import ClauseElement
from ...sql.elements import KeyedColumnElement
from ...sql.expression import alias
from ...sql.selectable import NamedFromClause
from ...util.typing import Self
__all__ = ("Insert", "insert")
def insert(table: _DMLTableArgument) -> Insert:
"""Construct a MySQL/MariaDB-specific variant :class:`_mysql.Insert`
construct.
.. container:: inherited_member
The :func:`sqlalchemy.dialects.mysql.insert` function creates
a :class:`sqlalchemy.dialects.mysql.Insert`. This class is based
on the dialect-agnostic :class:`_sql.Insert` construct which may
be constructed using the :func:`_sql.insert` function in
SQLAlchemy Core.
The :class:`_mysql.Insert` construct includes additional methods
:meth:`_mysql.Insert.on_duplicate_key_update`.
"""
return Insert(table)
class Insert(StandardInsert):
"""MySQL-specific implementation of INSERT.
Adds methods for MySQL-specific syntaxes such as ON DUPLICATE KEY UPDATE.
The :class:`~.mysql.Insert` object is created using the
:func:`sqlalchemy.dialects.mysql.insert` function.
.. versionadded:: 1.2
"""
stringify_dialect = "mysql"
inherit_cache = False
@property
def inserted(
self,
) -> ReadOnlyColumnCollection[str, KeyedColumnElement[Any]]:
"""Provide the "inserted" namespace for an ON DUPLICATE KEY UPDATE
statement
MySQL's ON DUPLICATE KEY UPDATE clause allows reference to the row
that would be inserted, via a special function called ``VALUES()``.
This attribute provides all columns in this row to be referenceable
such that they will render within a ``VALUES()`` function inside the
ON DUPLICATE KEY UPDATE clause. The attribute is named ``.inserted``
so as not to conflict with the existing
:meth:`_expression.Insert.values` method.
.. tip:: The :attr:`_mysql.Insert.inserted` attribute is an instance
of :class:`_expression.ColumnCollection`, which provides an
interface the same as that of the :attr:`_schema.Table.c`
collection described at :ref:`metadata_tables_and_columns`.
With this collection, ordinary names are accessible like attributes
(e.g. ``stmt.inserted.some_column``), but special names and
dictionary method names should be accessed using indexed access,
such as ``stmt.inserted["column name"]`` or
``stmt.inserted["values"]``. See the docstring for
:class:`_expression.ColumnCollection` for further examples.
.. seealso::
:ref:`mysql_insert_on_duplicate_key_update` - example of how
to use :attr:`_expression.Insert.inserted`
"""
return self.inserted_alias.columns
@util.memoized_property
def inserted_alias(self) -> NamedFromClause:
return alias(self.table, name="inserted")
@_generative
@_exclusive_against(
"_post_values_clause",
msgs={
"_post_values_clause": "This Insert construct already "
"has an ON DUPLICATE KEY clause present"
},
)
def on_duplicate_key_update(self, *args: _UpdateArg, **kw: Any) -> Self:
r"""
Specifies the ON DUPLICATE KEY UPDATE clause.
:param \**kw: Column keys linked to UPDATE values. The
values may be any SQL expression or supported literal Python
values.
.. warning:: This dictionary does **not** take into account
Python-specified default UPDATE values or generation functions,
e.g. those specified using :paramref:`_schema.Column.onupdate`.
These values will not be exercised for an ON DUPLICATE KEY UPDATE
style of UPDATE, unless values are manually specified here.
:param \*args: As an alternative to passing key/value parameters,
a dictionary or list of 2-tuples can be passed as a single positional
argument.
Passing a single dictionary is equivalent to the keyword argument
form::
insert().on_duplicate_key_update({"name": "some name"})
Passing a list of 2-tuples indicates that the parameter assignments
in the UPDATE clause should be ordered as sent, in a manner similar
to that described for the :class:`_expression.Update`
construct overall
in :ref:`tutorial_parameter_ordered_updates`::
insert().on_duplicate_key_update(
[("name", "some name"), ("value", "some value")])
.. versionchanged:: 1.3 parameters can be specified as a dictionary
or list of 2-tuples; the latter form provides for parameter
ordering.
.. versionadded:: 1.2
.. seealso::
:ref:`mysql_insert_on_duplicate_key_update`
"""
if args and kw:
raise exc.ArgumentError(
"Can't pass kwargs and positional arguments simultaneously"
)
if args:
if len(args) > 1:
raise exc.ArgumentError(
"Only a single dictionary or list of tuples "
"is accepted positionally."
)
values = args[0]
else:
values = kw
self._post_values_clause = OnDuplicateClause(
self.inserted_alias, values
)
return self
class OnDuplicateClause(ClauseElement):
__visit_name__ = "on_duplicate_key_update"
_parameter_ordering: Optional[List[str]] = None
stringify_dialect = "mysql"
def __init__(
self, inserted_alias: NamedFromClause, update: _UpdateArg
) -> None:
self.inserted_alias = inserted_alias
# auto-detect that parameters should be ordered. This is copied from
        # Update._process_colparams(); however, we don't look for a special flag
# in this case since we are not disambiguating from other use cases as
# we are in Update.values().
if isinstance(update, list) and (
update and isinstance(update[0], tuple)
):
self._parameter_ordering = [key for key, value in update]
update = dict(update)
if isinstance(update, dict):
if not update:
raise ValueError(
"update parameter dictionary must not be empty"
)
elif isinstance(update, ColumnCollection):
update = dict(update)
else:
raise ValueError(
"update parameter must be a non-empty dictionary "
"or a ColumnCollection such as the `.c.` collection "
"of a Table object"
)
self.update = update
_UpdateArg = Union[
Mapping[Any, Any], List[Tuple[str, Any]], ColumnCollection[Any, Any]
]
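# Illustrative usage sketch (editor's addition, not part of this module): a
# typical "upsert" built with the construct above; "user_table" is a
# placeholder table with "id" and "name" columns.
from sqlalchemy import Column, Integer, MetaData, String, Table

metadata = MetaData()
user_table = Table(
    "user_account",
    metadata,
    Column("id", Integer, primary_key=True),
    Column("name", String(50)),
)

stmt = insert(user_table).values(id=1, name="spongebob")
# refer to the row that would have been inserted via the VALUES() function
stmt = stmt.on_duplicate_key_update(name=stmt.inserted.name)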

View file

@ -0,0 +1,244 @@
# dialects/mysql/enumerated.py
# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
# mypy: ignore-errors
import re
from .types import _StringType
from ... import exc
from ... import sql
from ... import util
from ...sql import sqltypes
class ENUM(sqltypes.NativeForEmulated, sqltypes.Enum, _StringType):
"""MySQL ENUM type."""
__visit_name__ = "ENUM"
native_enum = True
def __init__(self, *enums, **kw):
"""Construct an ENUM.
E.g.::
Column('myenum', ENUM("foo", "bar", "baz"))
:param enums: The range of valid values for this ENUM. Values in
enums are not quoted, they will be escaped and surrounded by single
quotes when generating the schema. This object may also be a
PEP-435-compliant enumerated type.
        .. versionadded:: 1.1 added support for PEP-435-compliant enumerated
types.
:param strict: This flag has no effect.
.. versionchanged:: The MySQL ENUM type as well as the base Enum
type now validates all Python data values.
:param charset: Optional, a column-level character set for this string
          value. Takes precedence over the 'ascii' or 'unicode' short-hand.
        :param collation: Optional, a column-level collation for this string
          value. Takes precedence over the 'binary' short-hand.
:param ascii: Defaults to False: short-hand for the ``latin1``
character set, generates ASCII in schema.
:param unicode: Defaults to False: short-hand for the ``ucs2``
character set, generates UNICODE in schema.
:param binary: Defaults to False: short-hand, pick the binary
collation type that matches the column's character set. Generates
BINARY in schema. This does not affect the type of data stored,
only the collation of character data.
"""
kw.pop("strict", None)
self._enum_init(enums, kw)
_StringType.__init__(self, length=self.length, **kw)
@classmethod
def adapt_emulated_to_native(cls, impl, **kw):
"""Produce a MySQL native :class:`.mysql.ENUM` from plain
:class:`.Enum`.
"""
kw.setdefault("validate_strings", impl.validate_strings)
kw.setdefault("values_callable", impl.values_callable)
kw.setdefault("omit_aliases", impl._omit_aliases)
return cls(**kw)
def _object_value_for_elem(self, elem):
# mysql sends back a blank string for any value that
# was persisted that was not in the enums; that is, it does no
# validation on the incoming data, it "truncates" it to be
# the blank string. Return it straight.
if elem == "":
return elem
else:
return super()._object_value_for_elem(elem)
def __repr__(self):
return util.generic_repr(
self, to_inspect=[ENUM, _StringType, sqltypes.Enum]
)
class SET(_StringType):
"""MySQL SET type."""
__visit_name__ = "SET"
def __init__(self, *values, **kw):
"""Construct a SET.
E.g.::
Column('myset', SET("foo", "bar", "baz"))
The list of potential values is required in the case that this
set will be used to generate DDL for a table, or if the
:paramref:`.SET.retrieve_as_bitwise` flag is set to True.
:param values: The range of valid values for this SET. The values
are not quoted, they will be escaped and surrounded by single
quotes when generating the schema.
:param convert_unicode: Same flag as that of
:paramref:`.String.convert_unicode`.
:param collation: same as that of :paramref:`.String.collation`
:param charset: same as that of :paramref:`.VARCHAR.charset`.
:param ascii: same as that of :paramref:`.VARCHAR.ascii`.
:param unicode: same as that of :paramref:`.VARCHAR.unicode`.
:param binary: same as that of :paramref:`.VARCHAR.binary`.
:param retrieve_as_bitwise: if True, the data for the set type will be
persisted and selected using an integer value, where a set is coerced
into a bitwise mask for persistence. MySQL allows this mode which
has the advantage of being able to store values unambiguously,
such as the blank string ``''``. The datatype will appear
as the expression ``col + 0`` in a SELECT statement, so that the
value is coerced into an integer value in result sets.
This flag is required if one wishes
to persist a set that can store the blank string ``''`` as a value.
.. warning::
When using :paramref:`.mysql.SET.retrieve_as_bitwise`, it is
essential that the list of set values is expressed in the
**exact same order** as exists on the MySQL database.
"""
self.retrieve_as_bitwise = kw.pop("retrieve_as_bitwise", False)
self.values = tuple(values)
if not self.retrieve_as_bitwise and "" in values:
raise exc.ArgumentError(
"Can't use the blank value '' in a SET without "
"setting retrieve_as_bitwise=True"
)
if self.retrieve_as_bitwise:
self._bitmap = {
value: 2**idx for idx, value in enumerate(self.values)
}
self._bitmap.update(
(2**idx, value) for idx, value in enumerate(self.values)
)
length = max([len(v) for v in values] + [0])
kw.setdefault("length", length)
super().__init__(**kw)
def column_expression(self, colexpr):
if self.retrieve_as_bitwise:
return sql.type_coerce(
sql.type_coerce(colexpr, sqltypes.Integer) + 0, self
)
else:
return colexpr
def result_processor(self, dialect, coltype):
if self.retrieve_as_bitwise:
def process(value):
if value is not None:
value = int(value)
return set(util.map_bits(self._bitmap.__getitem__, value))
else:
return None
else:
super_convert = super().result_processor(dialect, coltype)
def process(value):
if isinstance(value, str):
# MySQLdb returns a string, let's parse
if super_convert:
value = super_convert(value)
return set(re.findall(r"[^,]+", value))
else:
# mysql-connector-python does a naive
# split(",") which throws in an empty string
if value is not None:
value.discard("")
return value
return process
def bind_processor(self, dialect):
super_convert = super().bind_processor(dialect)
if self.retrieve_as_bitwise:
def process(value):
if value is None:
return None
elif isinstance(value, (int, str)):
if super_convert:
return super_convert(value)
else:
return value
else:
int_value = 0
for v in value:
int_value |= self._bitmap[v]
return int_value
else:
def process(value):
# accept strings and int (actually bitflag) values directly
if value is not None and not isinstance(value, (int, str)):
value = ",".join(value)
if super_convert:
return super_convert(value)
else:
return value
return process
def adapt(self, impltype, **kw):
kw["retrieve_as_bitwise"] = self.retrieve_as_bitwise
return util.constructor_copy(self, impltype, *self.values, **kw)
def __repr__(self):
return util.generic_repr(
self,
to_inspect=[SET, _StringType],
additional_kw=[
("retrieve_as_bitwise", False),
],
)
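# Illustrative usage sketch (editor's addition, not part of this module):
# declaring ENUM and SET columns; with retrieve_as_bitwise=True the SET is
# persisted as an integer bitmask ("read" -> 1, "write" -> 2, "admin" -> 4),
# which is exactly the _bitmap built in SET.__init__ above.
from sqlalchemy import Column, MetaData, Table

metadata = MetaData()
account = Table(
    "account",
    metadata,
    Column("status", ENUM("active", "disabled")),
    Column("permissions", SET("read", "write", "admin", retrieve_as_bitwise=True)),
)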

View file

@ -0,0 +1,141 @@
# dialects/mysql/expression.py
# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
# mypy: ignore-errors
from ... import exc
from ... import util
from ...sql import coercions
from ...sql import elements
from ...sql import operators
from ...sql import roles
from ...sql.base import _generative
from ...sql.base import Generative
from ...util.typing import Self
class match(Generative, elements.BinaryExpression):
"""Produce a ``MATCH (X, Y) AGAINST ('TEXT')`` clause.
E.g.::
from sqlalchemy import desc
from sqlalchemy.dialects.mysql import match
match_expr = match(
users_table.c.firstname,
users_table.c.lastname,
against="Firstname Lastname",
)
stmt = (
select(users_table)
.where(match_expr.in_boolean_mode())
.order_by(desc(match_expr))
)
Would produce SQL resembling::
SELECT id, firstname, lastname
FROM user
WHERE MATCH(firstname, lastname) AGAINST (:param_1 IN BOOLEAN MODE)
ORDER BY MATCH(firstname, lastname) AGAINST (:param_2) DESC
    The :func:`_mysql.match` function is a standalone version of the
    :meth:`_sql.ColumnElement.match` method available on all
    SQL expressions; unlike :meth:`_expression.ColumnElement.match`, it
    allows passing multiple columns.
:param cols: column expressions to match against
:param against: expression to be compared towards
:param in_boolean_mode: boolean, set "boolean mode" to true
    :param in_natural_language_mode: boolean, set "natural language" to true
:param with_query_expansion: boolean, set "query expansion" to true
.. versionadded:: 1.4.19
.. seealso::
:meth:`_expression.ColumnElement.match`
"""
__visit_name__ = "mysql_match"
inherit_cache = True
def __init__(self, *cols, **kw):
if not cols:
raise exc.ArgumentError("columns are required")
against = kw.pop("against", None)
if against is None:
raise exc.ArgumentError("against is required")
against = coercions.expect(
roles.ExpressionElementRole,
against,
)
left = elements.BooleanClauseList._construct_raw(
operators.comma_op,
clauses=cols,
)
left.group = False
flags = util.immutabledict(
{
"mysql_boolean_mode": kw.pop("in_boolean_mode", False),
"mysql_natural_language": kw.pop(
"in_natural_language_mode", False
),
"mysql_query_expansion": kw.pop("with_query_expansion", False),
}
)
if kw:
raise exc.ArgumentError("unknown arguments: %s" % (", ".join(kw)))
super().__init__(left, against, operators.match_op, modifiers=flags)
@_generative
def in_boolean_mode(self) -> Self:
"""Apply the "IN BOOLEAN MODE" modifier to the MATCH expression.
:return: a new :class:`_mysql.match` instance with modifications
applied.
"""
self.modifiers = self.modifiers.union({"mysql_boolean_mode": True})
return self
@_generative
def in_natural_language_mode(self) -> Self:
"""Apply the "IN NATURAL LANGUAGE MODE" modifier to the MATCH
expression.
:return: a new :class:`_mysql.match` instance with modifications
applied.
"""
self.modifiers = self.modifiers.union({"mysql_natural_language": True})
return self
@_generative
def with_query_expansion(self) -> Self:
"""Apply the "WITH QUERY EXPANSION" modifier to the MATCH expression.
:return: a new :class:`_mysql.match` instance with modifications
applied.
"""
self.modifiers = self.modifiers.union({"mysql_query_expansion": True})
return self
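# Illustrative usage sketch (editor's addition, not part of this module):
# combining the generative modifiers above; "article" is a placeholder table.
from sqlalchemy import Column, MetaData, Table, Text, select

metadata = MetaData()
article = Table("article", metadata, Column("title", Text), Column("body", Text))

expr = match(article.c.title, article.c.body, against="database").in_boolean_mode()
stmt = select(article).where(expr)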

View file

@ -0,0 +1,81 @@
# dialects/mysql/json.py
# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
# mypy: ignore-errors
from ... import types as sqltypes
class JSON(sqltypes.JSON):
"""MySQL JSON type.
MySQL supports JSON as of version 5.7.
MariaDB supports JSON (as an alias for LONGTEXT) as of version 10.2.
:class:`_mysql.JSON` is used automatically whenever the base
:class:`_types.JSON` datatype is used against a MySQL or MariaDB backend.
.. seealso::
:class:`_types.JSON` - main documentation for the generic
cross-platform JSON datatype.
The :class:`.mysql.JSON` type supports persistence of JSON values
as well as the core index operations provided by :class:`_types.JSON`
datatype, by adapting the operations to render the ``JSON_EXTRACT``
function at the database level.
"""
pass
class _FormatTypeMixin:
def _format_value(self, value):
raise NotImplementedError()
def bind_processor(self, dialect):
super_proc = self.string_bind_processor(dialect)
def process(value):
value = self._format_value(value)
if super_proc:
value = super_proc(value)
return value
return process
def literal_processor(self, dialect):
super_proc = self.string_literal_processor(dialect)
def process(value):
value = self._format_value(value)
if super_proc:
value = super_proc(value)
return value
return process
class JSONIndexType(_FormatTypeMixin, sqltypes.JSON.JSONIndexType):
def _format_value(self, value):
if isinstance(value, int):
value = "$[%s]" % value
else:
value = '$."%s"' % value
return value
class JSONPathType(_FormatTypeMixin, sqltypes.JSON.JSONPathType):
def _format_value(self, value):
return "$%s" % (
"".join(
[
"[%s]" % elem if isinstance(elem, int) else '."%s"' % elem
for elem in value
]
)
)
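# Worked example (editor's addition, not part of this module): how the mixin
# above renders index and path values into MySQL JSON path strings.
assert JSONIndexType()._format_value(5) == "$[5]"
assert JSONIndexType()._format_value("key") == '$."key"'
assert JSONPathType()._format_value(["a", 0, "b"]) == '$."a"[0]."b"'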

View file

@ -0,0 +1,32 @@
# dialects/mysql/mariadb.py
# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
# mypy: ignore-errors
from .base import MariaDBIdentifierPreparer
from .base import MySQLDialect
class MariaDBDialect(MySQLDialect):
is_mariadb = True
supports_statement_cache = True
name = "mariadb"
preparer = MariaDBIdentifierPreparer
def loader(driver):
driver_mod = __import__(
"sqlalchemy.dialects.mysql.%s" % driver
).dialects.mysql
driver_cls = getattr(driver_mod, driver).dialect
return type(
"MariaDBDialect_%s" % driver,
(
MariaDBDialect,
driver_cls,
),
{"supports_statement_cache": True},
)
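# Illustrative sketch (editor's addition, not part of this module): loader()
# composes a MariaDB-flavored dialect class for a given driver module at
# runtime, e.g. a pymysql-based variant.
MariaDBDialect_pymysql = loader("pymysql")
assert issubclass(MariaDBDialect_pymysql, MariaDBDialect)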

View file

@ -0,0 +1,282 @@
# dialects/mysql/mariadbconnector.py
# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
# mypy: ignore-errors
"""
.. dialect:: mysql+mariadbconnector
:name: MariaDB Connector/Python
:dbapi: mariadb
:connectstring: mariadb+mariadbconnector://<user>:<password>@<host>[:<port>]/<dbname>
:url: https://pypi.org/project/mariadb/
Driver Status
-------------
MariaDB Connector/Python enables Python programs to access MariaDB and MySQL
databases using an API which is compliant with the Python DB API 2.0 (PEP-249).
It is written in C and uses the MariaDB Connector/C client library for
client-server communication.
Note that the default driver for a ``mariadb://`` connection URI continues to
be ``mysqldb``. ``mariadb+mariadbconnector://`` is required to use this driver.
.. _mariadb: https://github.com/mariadb-corporation/mariadb-connector-python
""" # noqa
import re
from uuid import UUID as _python_UUID
from .base import MySQLCompiler
from .base import MySQLDialect
from .base import MySQLExecutionContext
from ... import sql
from ... import util
from ...sql import sqltypes
mariadb_cpy_minimum_version = (1, 0, 1)
class _MariaDBUUID(sqltypes.UUID[sqltypes._UUID_RETURN]):
# work around JIRA issue
# https://jira.mariadb.org/browse/CONPY-270. When that issue is fixed,
# this type can be removed.
def result_processor(self, dialect, coltype):
if self.as_uuid:
def process(value):
if value is not None:
if hasattr(value, "decode"):
value = value.decode("ascii")
value = _python_UUID(value)
return value
return process
else:
def process(value):
if value is not None:
if hasattr(value, "decode"):
value = value.decode("ascii")
value = str(_python_UUID(value))
return value
return process
class MySQLExecutionContext_mariadbconnector(MySQLExecutionContext):
_lastrowid = None
def create_server_side_cursor(self):
return self._dbapi_connection.cursor(buffered=False)
def create_default_cursor(self):
return self._dbapi_connection.cursor(buffered=True)
def post_exec(self):
super().post_exec()
self._rowcount = self.cursor.rowcount
if self.isinsert and self.compiled.postfetch_lastrowid:
self._lastrowid = self.cursor.lastrowid
@property
def rowcount(self):
if self._rowcount is not None:
return self._rowcount
else:
return self.cursor.rowcount
def get_lastrowid(self):
return self._lastrowid
class MySQLCompiler_mariadbconnector(MySQLCompiler):
pass
class MySQLDialect_mariadbconnector(MySQLDialect):
driver = "mariadbconnector"
supports_statement_cache = True
# set this to True at the module level to prevent the driver from running
    # against a backend that the server detects as MySQL. Currently this
    # appears to be unnecessary, as MariaDB client libraries have always worked
    # against MySQL databases.  However, if this changes at some point, this
    # can be adjusted; PLEASE ADD A TEST in test/dialect/mysql/test_dialect.py
    # if such a change is made, to ensure the correct exception is raised at
    # the correct point when running the driver against a MySQL backend.
# is_mariadb = True
supports_unicode_statements = True
encoding = "utf8mb4"
convert_unicode = True
supports_sane_rowcount = True
supports_sane_multi_rowcount = True
supports_native_decimal = True
default_paramstyle = "qmark"
execution_ctx_cls = MySQLExecutionContext_mariadbconnector
statement_compiler = MySQLCompiler_mariadbconnector
supports_server_side_cursors = True
colspecs = util.update_copy(
MySQLDialect.colspecs, {sqltypes.Uuid: _MariaDBUUID}
)
@util.memoized_property
def _dbapi_version(self):
if self.dbapi and hasattr(self.dbapi, "__version__"):
return tuple(
[
int(x)
for x in re.findall(
r"(\d+)(?:[-\.]?|$)", self.dbapi.__version__
)
]
)
else:
return (99, 99, 99)
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.paramstyle = "qmark"
if self.dbapi is not None:
if self._dbapi_version < mariadb_cpy_minimum_version:
raise NotImplementedError(
"The minimum required version for MariaDB "
"Connector/Python is %s"
% ".".join(str(x) for x in mariadb_cpy_minimum_version)
)
@classmethod
def import_dbapi(cls):
return __import__("mariadb")
def is_disconnect(self, e, connection, cursor):
if super().is_disconnect(e, connection, cursor):
return True
elif isinstance(e, self.dbapi.Error):
str_e = str(e).lower()
return "not connected" in str_e or "isn't valid" in str_e
else:
return False
def create_connect_args(self, url):
opts = url.translate_connect_args()
int_params = [
"connect_timeout",
"read_timeout",
"write_timeout",
"client_flag",
"port",
"pool_size",
]
bool_params = [
"local_infile",
"ssl_verify_cert",
"ssl",
"pool_reset_connection",
]
for key in int_params:
util.coerce_kw_type(opts, key, int)
for key in bool_params:
util.coerce_kw_type(opts, key, bool)
# FOUND_ROWS must be set in CLIENT_FLAGS to enable
# supports_sane_rowcount.
client_flag = opts.get("client_flag", 0)
if self.dbapi is not None:
try:
CLIENT_FLAGS = __import__(
self.dbapi.__name__ + ".constants.CLIENT"
).constants.CLIENT
client_flag |= CLIENT_FLAGS.FOUND_ROWS
except (AttributeError, ImportError):
self.supports_sane_rowcount = False
opts["client_flag"] = client_flag
return [[], opts]
def _extract_error_code(self, exception):
try:
rc = exception.errno
        except Exception:
rc = -1
return rc
def _detect_charset(self, connection):
return "utf8mb4"
def get_isolation_level_values(self, dbapi_connection):
return (
"SERIALIZABLE",
"READ UNCOMMITTED",
"READ COMMITTED",
"REPEATABLE READ",
"AUTOCOMMIT",
)
def set_isolation_level(self, connection, level):
if level == "AUTOCOMMIT":
connection.autocommit = True
else:
connection.autocommit = False
super().set_isolation_level(connection, level)
def do_begin_twophase(self, connection, xid):
connection.execute(
sql.text("XA BEGIN :xid").bindparams(
sql.bindparam("xid", xid, literal_execute=True)
)
)
def do_prepare_twophase(self, connection, xid):
connection.execute(
sql.text("XA END :xid").bindparams(
sql.bindparam("xid", xid, literal_execute=True)
)
)
connection.execute(
sql.text("XA PREPARE :xid").bindparams(
sql.bindparam("xid", xid, literal_execute=True)
)
)
def do_rollback_twophase(
self, connection, xid, is_prepared=True, recover=False
):
if not is_prepared:
connection.execute(
sql.text("XA END :xid").bindparams(
sql.bindparam("xid", xid, literal_execute=True)
)
)
connection.execute(
sql.text("XA ROLLBACK :xid").bindparams(
sql.bindparam("xid", xid, literal_execute=True)
)
)
def do_commit_twophase(
self, connection, xid, is_prepared=True, recover=False
):
if not is_prepared:
self.do_prepare_twophase(connection, xid)
connection.execute(
sql.text("XA COMMIT :xid").bindparams(
sql.bindparam("xid", xid, literal_execute=True)
)
)
dialect = MySQLDialect_mariadbconnector
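# Illustrative usage sketch (editor's addition, not part of this module):
# selecting this dialect through the connection URL; host and credentials are
# placeholders.
from sqlalchemy import create_engine

engine = create_engine("mariadb+mariadbconnector://scott:tiger@127.0.0.1:3306/test")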

View file

@ -0,0 +1,179 @@
# dialects/mysql/mysqlconnector.py
# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
# mypy: ignore-errors
r"""
.. dialect:: mysql+mysqlconnector
:name: MySQL Connector/Python
:dbapi: myconnpy
:connectstring: mysql+mysqlconnector://<user>:<password>@<host>[:<port>]/<dbname>
:url: https://pypi.org/project/mysql-connector-python/
.. note::
The MySQL Connector/Python DBAPI has had many issues since its release,
some of which may remain unresolved, and the mysqlconnector dialect is
**not tested as part of SQLAlchemy's continuous integration**.
The recommended MySQL dialects are mysqlclient and PyMySQL.
""" # noqa
import re
from .base import BIT
from .base import MySQLCompiler
from .base import MySQLDialect
from .base import MySQLIdentifierPreparer
from ... import util
class MySQLCompiler_mysqlconnector(MySQLCompiler):
def visit_mod_binary(self, binary, operator, **kw):
return (
self.process(binary.left, **kw)
+ " % "
+ self.process(binary.right, **kw)
)
class MySQLIdentifierPreparer_mysqlconnector(MySQLIdentifierPreparer):
@property
def _double_percents(self):
return False
@_double_percents.setter
def _double_percents(self, value):
pass
def _escape_identifier(self, value):
value = value.replace(self.escape_quote, self.escape_to_quote)
return value
class _myconnpyBIT(BIT):
def result_processor(self, dialect, coltype):
"""MySQL-connector already converts mysql bits, so."""
return None
class MySQLDialect_mysqlconnector(MySQLDialect):
driver = "mysqlconnector"
supports_statement_cache = True
supports_sane_rowcount = True
supports_sane_multi_rowcount = True
supports_native_decimal = True
default_paramstyle = "format"
statement_compiler = MySQLCompiler_mysqlconnector
preparer = MySQLIdentifierPreparer_mysqlconnector
colspecs = util.update_copy(MySQLDialect.colspecs, {BIT: _myconnpyBIT})
@classmethod
def import_dbapi(cls):
from mysql import connector
return connector
def do_ping(self, dbapi_connection):
dbapi_connection.ping(False)
return True
def create_connect_args(self, url):
opts = url.translate_connect_args(username="user")
opts.update(url.query)
util.coerce_kw_type(opts, "allow_local_infile", bool)
util.coerce_kw_type(opts, "autocommit", bool)
util.coerce_kw_type(opts, "buffered", bool)
util.coerce_kw_type(opts, "compress", bool)
util.coerce_kw_type(opts, "connection_timeout", int)
util.coerce_kw_type(opts, "connect_timeout", int)
util.coerce_kw_type(opts, "consume_results", bool)
util.coerce_kw_type(opts, "force_ipv6", bool)
util.coerce_kw_type(opts, "get_warnings", bool)
util.coerce_kw_type(opts, "pool_reset_session", bool)
util.coerce_kw_type(opts, "pool_size", int)
util.coerce_kw_type(opts, "raise_on_warnings", bool)
util.coerce_kw_type(opts, "raw", bool)
util.coerce_kw_type(opts, "ssl_verify_cert", bool)
util.coerce_kw_type(opts, "use_pure", bool)
util.coerce_kw_type(opts, "use_unicode", bool)
# unfortunately, MySQL/connector python refuses to release a
# cursor without reading fully, so non-buffered isn't an option
opts.setdefault("buffered", True)
# FOUND_ROWS must be set in ClientFlag to enable
# supports_sane_rowcount.
if self.dbapi is not None:
try:
from mysql.connector.constants import ClientFlag
client_flags = opts.get(
"client_flags", ClientFlag.get_default()
)
client_flags |= ClientFlag.FOUND_ROWS
opts["client_flags"] = client_flags
except Exception:
pass
return [[], opts]
@util.memoized_property
def _mysqlconnector_version_info(self):
if self.dbapi and hasattr(self.dbapi, "__version__"):
m = re.match(r"(\d+)\.(\d+)(?:\.(\d+))?", self.dbapi.__version__)
if m:
return tuple(int(x) for x in m.group(1, 2, 3) if x is not None)
def _detect_charset(self, connection):
return connection.connection.charset
def _extract_error_code(self, exception):
return exception.errno
def is_disconnect(self, e, connection, cursor):
errnos = (2006, 2013, 2014, 2045, 2055, 2048)
exceptions = (self.dbapi.OperationalError, self.dbapi.InterfaceError)
if isinstance(e, exceptions):
return (
e.errno in errnos
or "MySQL Connection not available." in str(e)
or "Connection to MySQL is not available" in str(e)
)
else:
return False
def _compat_fetchall(self, rp, charset=None):
return rp.fetchall()
def _compat_fetchone(self, rp, charset=None):
return rp.fetchone()
_isolation_lookup = {
"SERIALIZABLE",
"READ UNCOMMITTED",
"READ COMMITTED",
"REPEATABLE READ",
"AUTOCOMMIT",
}
def _set_isolation_level(self, connection, level):
if level == "AUTOCOMMIT":
connection.autocommit = True
else:
connection.autocommit = False
super()._set_isolation_level(connection, level)
dialect = MySQLDialect_mysqlconnector
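# Illustrative usage sketch (editor's addition, not part of this module):
# driver keyword arguments such as "buffered" or "raise_on_warnings" may be
# passed on the URL query string; create_connect_args() above coerces them to
# the proper Python types.
from sqlalchemy import create_engine

engine = create_engine(
    "mysql+mysqlconnector://scott:tiger@localhost/test"
    "?buffered=true&raise_on_warnings=true"
)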

View file

@ -0,0 +1,308 @@
# dialects/mysql/mysqldb.py
# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
# mypy: ignore-errors
"""
.. dialect:: mysql+mysqldb
:name: mysqlclient (maintained fork of MySQL-Python)
:dbapi: mysqldb
:connectstring: mysql+mysqldb://<user>:<password>@<host>[:<port>]/<dbname>
:url: https://pypi.org/project/mysqlclient/
Driver Status
-------------
The mysqlclient DBAPI is a maintained fork of the
`MySQL-Python <https://sourceforge.net/projects/mysql-python>`_ DBAPI,
which is itself no longer maintained.  `mysqlclient`_ supports Python 2 and Python 3
and is very stable.
.. _mysqlclient: https://github.com/PyMySQL/mysqlclient-python
.. _mysqldb_unicode:
Unicode
-------
Please see :ref:`mysql_unicode` for current recommendations on unicode
handling.
.. _mysqldb_ssl:
SSL Connections
----------------
The mysqlclient and PyMySQL DBAPIs accept an additional dictionary under the
key "ssl", which may be specified using the
:paramref:`_sa.create_engine.connect_args` dictionary::
engine = create_engine(
"mysql+mysqldb://scott:tiger@192.168.0.134/test",
connect_args={
"ssl": {
"ca": "/home/gord/client-ssl/ca.pem",
"cert": "/home/gord/client-ssl/client-cert.pem",
"key": "/home/gord/client-ssl/client-key.pem"
}
}
)
For convenience, the following keys may also be specified inline within the URL
where they will be interpreted into the "ssl" dictionary automatically:
"ssl_ca", "ssl_cert", "ssl_key", "ssl_capath", "ssl_cipher",
"ssl_check_hostname". An example is as follows::
connection_uri = (
"mysql+mysqldb://scott:tiger@192.168.0.134/test"
"?ssl_ca=/home/gord/client-ssl/ca.pem"
"&ssl_cert=/home/gord/client-ssl/client-cert.pem"
"&ssl_key=/home/gord/client-ssl/client-key.pem"
)
.. seealso::
:ref:`pymysql_ssl` in the PyMySQL dialect
Using MySQLdb with Google Cloud SQL
-----------------------------------
Google Cloud SQL now recommends use of the MySQLdb dialect. Connect
using a URL like the following::
mysql+mysqldb://root@/<dbname>?unix_socket=/cloudsql/<projectid>:<instancename>
Server Side Cursors
-------------------
The mysqldb dialect supports server-side cursors. See :ref:`mysql_ss_cursors`.
"""
import re
from .base import MySQLCompiler
from .base import MySQLDialect
from .base import MySQLExecutionContext
from .base import MySQLIdentifierPreparer
from .base import TEXT
from ... import sql
from ... import util
class MySQLExecutionContext_mysqldb(MySQLExecutionContext):
@property
def rowcount(self):
if hasattr(self, "_rowcount"):
return self._rowcount
else:
return self.cursor.rowcount
class MySQLCompiler_mysqldb(MySQLCompiler):
pass
class MySQLDialect_mysqldb(MySQLDialect):
driver = "mysqldb"
supports_statement_cache = True
supports_unicode_statements = True
supports_sane_rowcount = True
supports_sane_multi_rowcount = True
supports_native_decimal = True
default_paramstyle = "format"
execution_ctx_cls = MySQLExecutionContext_mysqldb
statement_compiler = MySQLCompiler_mysqldb
preparer = MySQLIdentifierPreparer
def __init__(self, **kwargs):
super().__init__(**kwargs)
self._mysql_dbapi_version = (
self._parse_dbapi_version(self.dbapi.__version__)
if self.dbapi is not None and hasattr(self.dbapi, "__version__")
else (0, 0, 0)
)
def _parse_dbapi_version(self, version):
m = re.match(r"(\d+)\.(\d+)(?:\.(\d+))?", version)
if m:
return tuple(int(x) for x in m.group(1, 2, 3) if x is not None)
else:
return (0, 0, 0)
@util.langhelpers.memoized_property
def supports_server_side_cursors(self):
try:
cursors = __import__("MySQLdb.cursors").cursors
self._sscursor = cursors.SSCursor
return True
except (ImportError, AttributeError):
return False
@classmethod
def import_dbapi(cls):
return __import__("MySQLdb")
def on_connect(self):
super_ = super().on_connect()
def on_connect(conn):
if super_ is not None:
super_(conn)
charset_name = conn.character_set_name()
if charset_name is not None:
cursor = conn.cursor()
cursor.execute("SET NAMES %s" % charset_name)
cursor.close()
return on_connect
def do_ping(self, dbapi_connection):
dbapi_connection.ping()
return True
def do_executemany(self, cursor, statement, parameters, context=None):
rowcount = cursor.executemany(statement, parameters)
if context is not None:
context._rowcount = rowcount
def _check_unicode_returns(self, connection):
# work around issue fixed in
# https://github.com/farcepest/MySQLdb1/commit/cd44524fef63bd3fcb71947392326e9742d520e8
# specific issue w/ the utf8mb4_bin collation and unicode returns
collation = connection.exec_driver_sql(
"show collation where %s = 'utf8mb4' and %s = 'utf8mb4_bin'"
% (
self.identifier_preparer.quote("Charset"),
self.identifier_preparer.quote("Collation"),
)
).scalar()
has_utf8mb4_bin = self.server_version_info > (5,) and collation
if has_utf8mb4_bin:
additional_tests = [
sql.collate(
sql.cast(
sql.literal_column("'test collated returns'"),
TEXT(charset="utf8mb4"),
),
"utf8mb4_bin",
)
]
else:
additional_tests = []
return super()._check_unicode_returns(connection, additional_tests)
def create_connect_args(self, url, _translate_args=None):
if _translate_args is None:
_translate_args = dict(
database="db", username="user", password="passwd"
)
opts = url.translate_connect_args(**_translate_args)
opts.update(url.query)
util.coerce_kw_type(opts, "compress", bool)
util.coerce_kw_type(opts, "connect_timeout", int)
util.coerce_kw_type(opts, "read_timeout", int)
util.coerce_kw_type(opts, "write_timeout", int)
util.coerce_kw_type(opts, "client_flag", int)
util.coerce_kw_type(opts, "local_infile", int)
# Note: using either of the below will cause all strings to be
# returned as Unicode, both in raw SQL operations and with column
# types like String and MSString.
util.coerce_kw_type(opts, "use_unicode", bool)
util.coerce_kw_type(opts, "charset", str)
# Rich values 'cursorclass' and 'conv' are not supported via
# query string.
ssl = {}
keys = [
("ssl_ca", str),
("ssl_key", str),
("ssl_cert", str),
("ssl_capath", str),
("ssl_cipher", str),
("ssl_check_hostname", bool),
]
for key, kw_type in keys:
if key in opts:
ssl[key[4:]] = opts[key]
util.coerce_kw_type(ssl, key[4:], kw_type)
del opts[key]
if ssl:
opts["ssl"] = ssl
# FOUND_ROWS must be set in CLIENT_FLAGS to enable
# supports_sane_rowcount.
client_flag = opts.get("client_flag", 0)
client_flag_found_rows = self._found_rows_client_flag()
if client_flag_found_rows is not None:
client_flag |= client_flag_found_rows
opts["client_flag"] = client_flag
return [[], opts]
def _found_rows_client_flag(self):
if self.dbapi is not None:
try:
CLIENT_FLAGS = __import__(
self.dbapi.__name__ + ".constants.CLIENT"
).constants.CLIENT
except (AttributeError, ImportError):
return None
else:
return CLIENT_FLAGS.FOUND_ROWS
else:
return None
def _extract_error_code(self, exception):
return exception.args[0]
def _detect_charset(self, connection):
"""Sniff out the character set in use for connection results."""
try:
# note: the SQL here would be
# "SHOW VARIABLES LIKE 'character_set%%'"
cset_name = connection.connection.character_set_name
except AttributeError:
util.warn(
"No 'character_set_name' can be detected with "
"this MySQL-Python version; "
"please upgrade to a recent version of MySQL-Python. "
"Assuming latin1."
)
return "latin1"
else:
return cset_name()
def get_isolation_level_values(self, dbapi_connection):
return (
"SERIALIZABLE",
"READ UNCOMMITTED",
"READ COMMITTED",
"REPEATABLE READ",
"AUTOCOMMIT",
)
def set_isolation_level(self, dbapi_connection, level):
if level == "AUTOCOMMIT":
dbapi_connection.autocommit(True)
else:
dbapi_connection.autocommit(False)
super().set_isolation_level(dbapi_connection, level)
dialect = MySQLDialect_mysqldb
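# Illustrative usage sketch (editor's addition, not part of this module): the
# server-side cursor support mentioned in the docstring above is requested per
# statement via the stream_results execution option; credentials and table name
# are placeholders.
from sqlalchemy import create_engine, text

engine = create_engine("mysql+mysqldb://scott:tiger@192.168.0.134/test")

with engine.connect() as conn:
    result = conn.execution_options(stream_results=True).execute(
        text("SELECT * FROM big_table")
    )
    for row in result:
        pass  # rows are fetched incrementally from a server-side cursor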

View file

@ -0,0 +1,107 @@
# dialects/mysql/provision.py
# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
# mypy: ignore-errors
from ... import exc
from ...testing.provision import configure_follower
from ...testing.provision import create_db
from ...testing.provision import drop_db
from ...testing.provision import generate_driver_url
from ...testing.provision import temp_table_keyword_args
from ...testing.provision import upsert
@generate_driver_url.for_db("mysql", "mariadb")
def generate_driver_url(url, driver, query_str):
backend = url.get_backend_name()
# NOTE: at the moment, tests are running mariadbconnector
# against both mariadb and mysql backends. if we want this to be
# limited, do the decision making here to reject a "mysql+mariadbconnector"
# URL. Optionally also re-enable the module level
# MySQLDialect_mariadbconnector.is_mysql flag as well, which must include
# a unit and/or functional test.
# all the Jenkins tests have been running mysqlclient Python library
# built against mariadb client drivers for years against all MySQL /
# MariaDB versions going back to MySQL 5.6, currently they can talk
# to MySQL databases without problems.
if backend == "mysql":
dialect_cls = url.get_dialect()
if dialect_cls._is_mariadb_from_url(url):
backend = "mariadb"
new_url = url.set(
drivername="%s+%s" % (backend, driver)
).update_query_string(query_str)
try:
new_url.get_dialect()
except exc.NoSuchModuleError:
return None
else:
return new_url
@create_db.for_db("mysql", "mariadb")
def _mysql_create_db(cfg, eng, ident):
with eng.begin() as conn:
try:
_mysql_drop_db(cfg, conn, ident)
except Exception:
pass
with eng.begin() as conn:
conn.exec_driver_sql(
"CREATE DATABASE %s CHARACTER SET utf8mb4" % ident
)
conn.exec_driver_sql(
"CREATE DATABASE %s_test_schema CHARACTER SET utf8mb4" % ident
)
conn.exec_driver_sql(
"CREATE DATABASE %s_test_schema_2 CHARACTER SET utf8mb4" % ident
)
@configure_follower.for_db("mysql", "mariadb")
def _mysql_configure_follower(config, ident):
config.test_schema = "%s_test_schema" % ident
config.test_schema_2 = "%s_test_schema_2" % ident
@drop_db.for_db("mysql", "mariadb")
def _mysql_drop_db(cfg, eng, ident):
with eng.begin() as conn:
conn.exec_driver_sql("DROP DATABASE %s_test_schema" % ident)
conn.exec_driver_sql("DROP DATABASE %s_test_schema_2" % ident)
conn.exec_driver_sql("DROP DATABASE %s" % ident)
@temp_table_keyword_args.for_db("mysql", "mariadb")
def _mysql_temp_table_keyword_args(cfg, eng):
return {"prefixes": ["TEMPORARY"]}
@upsert.for_db("mariadb")
def _upsert(
cfg, table, returning, *, set_lambda=None, sort_by_parameter_order=False
):
from sqlalchemy.dialects.mysql import insert
stmt = insert(table)
if set_lambda:
stmt = stmt.on_duplicate_key_update(**set_lambda(stmt.inserted))
else:
pk1 = table.primary_key.c[0]
stmt = stmt.on_duplicate_key_update({pk1.key: pk1})
stmt = stmt.returning(
*returning, sort_by_parameter_order=sort_by_parameter_order
)
return stmt

View file

@ -0,0 +1,137 @@
# dialects/mysql/pymysql.py
# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
# mypy: ignore-errors
r"""
.. dialect:: mysql+pymysql
:name: PyMySQL
:dbapi: pymysql
:connectstring: mysql+pymysql://<username>:<password>@<host>/<dbname>[?<options>]
:url: https://pymysql.readthedocs.io/
Unicode
-------
Please see :ref:`mysql_unicode` for current recommendations on unicode
handling.
.. _pymysql_ssl:
SSL Connections
------------------
The PyMySQL DBAPI accepts the same SSL arguments as that of MySQLdb,
described at :ref:`mysqldb_ssl`. See that section for additional examples.
If the server uses an automatically-generated certificate that is self-signed
or does not match the host name (as seen from the client), it may also be
necessary to indicate ``ssl_check_hostname=false`` in PyMySQL::
connection_uri = (
"mysql+pymysql://scott:tiger@192.168.0.134/test"
"?ssl_ca=/home/gord/client-ssl/ca.pem"
"&ssl_cert=/home/gord/client-ssl/client-cert.pem"
"&ssl_key=/home/gord/client-ssl/client-key.pem"
"&ssl_check_hostname=false"
)
MySQL-Python Compatibility
--------------------------
The pymysql DBAPI is a pure Python port of the MySQL-python (MySQLdb) driver,
and targets 100% compatibility. Most behavioral notes for MySQL-python apply
to the pymysql driver as well.
""" # noqa
from .mysqldb import MySQLDialect_mysqldb
from ...util import langhelpers
class MySQLDialect_pymysql(MySQLDialect_mysqldb):
driver = "pymysql"
supports_statement_cache = True
description_encoding = None
@langhelpers.memoized_property
def supports_server_side_cursors(self):
try:
cursors = __import__("pymysql.cursors").cursors
self._sscursor = cursors.SSCursor
return True
except (ImportError, AttributeError):
return False
@classmethod
def import_dbapi(cls):
return __import__("pymysql")
@langhelpers.memoized_property
def _send_false_to_ping(self):
"""determine if pymysql has deprecated, changed the default of,
or removed the 'reconnect' argument of connection.ping().
See #10492 and
https://github.com/PyMySQL/mysqlclient/discussions/651#discussioncomment-7308971
for background.
""" # noqa: E501
try:
Connection = __import__(
"pymysql.connections"
).connections.Connection
except (ImportError, AttributeError):
return True
else:
insp = langhelpers.get_callable_argspec(Connection.ping)
try:
reconnect_arg = insp.args[1]
except IndexError:
return False
else:
return reconnect_arg == "reconnect" and (
not insp.defaults or insp.defaults[0] is not False
)
def do_ping(self, dbapi_connection):
if self._send_false_to_ping:
dbapi_connection.ping(False)
else:
dbapi_connection.ping()
return True
def create_connect_args(self, url, _translate_args=None):
if _translate_args is None:
_translate_args = dict(username="user")
return super().create_connect_args(
url, _translate_args=_translate_args
)
def is_disconnect(self, e, connection, cursor):
if super().is_disconnect(e, connection, cursor):
return True
elif isinstance(e, self.dbapi.Error):
str_e = str(e).lower()
return (
"already closed" in str_e or "connection was killed" in str_e
)
else:
return False
def _extract_error_code(self, exception):
if isinstance(exception.args[0], Exception):
exception = exception.args[0]
return exception.args[0]
dialect = MySQLDialect_pymysql
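# Illustrative usage sketch (editor's addition, not part of this module):
# do_ping() above is what the engine invokes when pool pre-ping is enabled;
# credentials are placeholders.
from sqlalchemy import create_engine

engine = create_engine(
    "mysql+pymysql://scott:tiger@localhost/test", pool_pre_ping=True
)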

View file

@ -0,0 +1,138 @@
# dialects/mysql/pyodbc.py
# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
# mypy: ignore-errors
r"""
.. dialect:: mysql+pyodbc
:name: PyODBC
:dbapi: pyodbc
:connectstring: mysql+pyodbc://<username>:<password>@<dsnname>
:url: https://pypi.org/project/pyodbc/
.. note::
The PyODBC for MySQL dialect is **not tested as part of
SQLAlchemy's continuous integration**.
The recommended MySQL dialects are mysqlclient and PyMySQL.
However, if you want to use the mysql+pyodbc dialect and require
full support for ``utf8mb4`` characters (including supplementary
    characters like emoji), be sure to use a current release of
MySQL Connector/ODBC and specify the "ANSI" (**not** "Unicode")
version of the driver in your DSN or connection string.
Pass through exact pyodbc connection string::
import urllib
connection_string = (
'DRIVER=MySQL ODBC 8.0 ANSI Driver;'
'SERVER=localhost;'
'PORT=3307;'
'DATABASE=mydb;'
'UID=root;'
'PWD=(whatever);'
'charset=utf8mb4;'
)
params = urllib.parse.quote_plus(connection_string)
connection_uri = "mysql+pyodbc:///?odbc_connect=%s" % params
""" # noqa
import re
from .base import MySQLDialect
from .base import MySQLExecutionContext
from .types import TIME
from ... import exc
from ... import util
from ...connectors.pyodbc import PyODBCConnector
from ...sql.sqltypes import Time
class _pyodbcTIME(TIME):
def result_processor(self, dialect, coltype):
def process(value):
# pyodbc returns a datetime.time object; no need to convert
return value
return process
class MySQLExecutionContext_pyodbc(MySQLExecutionContext):
def get_lastrowid(self):
cursor = self.create_cursor()
cursor.execute("SELECT LAST_INSERT_ID()")
lastrowid = cursor.fetchone()[0]
cursor.close()
return lastrowid
class MySQLDialect_pyodbc(PyODBCConnector, MySQLDialect):
supports_statement_cache = True
colspecs = util.update_copy(MySQLDialect.colspecs, {Time: _pyodbcTIME})
supports_unicode_statements = True
execution_ctx_cls = MySQLExecutionContext_pyodbc
pyodbc_driver_name = "MySQL"
def _detect_charset(self, connection):
"""Sniff out the character set in use for connection results."""
# Prefer 'character_set_results' for the current connection over the
# value in the driver. SET NAMES or individual variable SETs will
# change the charset without updating the driver's view of the world.
#
# If it's decided that issuing that sort of SQL leaves you SOL, then
# this can prefer the driver value.
# set this to None as _fetch_setting attempts to use it (None is OK)
self._connection_charset = None
try:
value = self._fetch_setting(connection, "character_set_client")
if value:
return value
except exc.DBAPIError:
pass
util.warn(
"Could not detect the connection character set. "
"Assuming latin1."
)
return "latin1"
def _get_server_version_info(self, connection):
return MySQLDialect._get_server_version_info(self, connection)
def _extract_error_code(self, exception):
m = re.compile(r"\((\d+)\)").search(str(exception.args))
c = m.group(1)
if c:
return int(c)
else:
return None
def on_connect(self):
super_ = super().on_connect()
def on_connect(conn):
if super_ is not None:
super_(conn)
# declare Unicode encoding for pyodbc as per
# https://github.com/mkleehammer/pyodbc/wiki/Unicode
pyodbc_SQL_CHAR = 1 # pyodbc.SQL_CHAR
pyodbc_SQL_WCHAR = -8 # pyodbc.SQL_WCHAR
conn.setdecoding(pyodbc_SQL_CHAR, encoding="utf-8")
conn.setdecoding(pyodbc_SQL_WCHAR, encoding="utf-8")
conn.setencoding(encoding="utf-8")
return on_connect
dialect = MySQLDialect_pyodbc

View file

@ -0,0 +1,677 @@
# dialects/mysql/reflection.py
# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
# mypy: ignore-errors
import re
from .enumerated import ENUM
from .enumerated import SET
from .types import DATETIME
from .types import TIME
from .types import TIMESTAMP
from ... import log
from ... import types as sqltypes
from ... import util
class ReflectedState:
"""Stores raw information about a SHOW CREATE TABLE statement."""
def __init__(self):
self.columns = []
self.table_options = {}
self.table_name = None
self.keys = []
self.fk_constraints = []
self.ck_constraints = []
@log.class_logger
class MySQLTableDefinitionParser:
"""Parses the results of a SHOW CREATE TABLE statement."""
def __init__(self, dialect, preparer):
self.dialect = dialect
self.preparer = preparer
self._prep_regexes()
def parse(self, show_create, charset):
state = ReflectedState()
state.charset = charset
for line in re.split(r"\r?\n", show_create):
if line.startswith(" " + self.preparer.initial_quote):
self._parse_column(line, state)
# a regular table options line
elif line.startswith(") "):
self._parse_table_options(line, state)
# an ANSI-mode table options line
elif line == ")":
pass
elif line.startswith("CREATE "):
self._parse_table_name(line, state)
elif "PARTITION" in line:
self._parse_partition_options(line, state)
# Not present in real reflection, but may be if
# loading from a file.
elif not line:
pass
else:
type_, spec = self._parse_constraints(line)
if type_ is None:
util.warn("Unknown schema content: %r" % line)
elif type_ == "key":
state.keys.append(spec)
elif type_ == "fk_constraint":
state.fk_constraints.append(spec)
elif type_ == "ck_constraint":
state.ck_constraints.append(spec)
else:
pass
return state
def _check_view(self, sql: str) -> bool:
return bool(self._re_is_view.match(sql))
def _parse_constraints(self, line):
"""Parse a KEY or CONSTRAINT line.
:param line: A line of SHOW CREATE TABLE output
"""
# KEY
m = self._re_key.match(line)
if m:
spec = m.groupdict()
# convert columns into name, length pairs
# NOTE: we may want to consider SHOW INDEX as the
# format of indexes in MySQL becomes more complex
spec["columns"] = self._parse_keyexprs(spec["columns"])
if spec["version_sql"]:
m2 = self._re_key_version_sql.match(spec["version_sql"])
if m2 and m2.groupdict()["parser"]:
spec["parser"] = m2.groupdict()["parser"]
if spec["parser"]:
spec["parser"] = self.preparer.unformat_identifiers(
spec["parser"]
)[0]
return "key", spec
# FOREIGN KEY CONSTRAINT
m = self._re_fk_constraint.match(line)
if m:
spec = m.groupdict()
spec["table"] = self.preparer.unformat_identifiers(spec["table"])
spec["local"] = [c[0] for c in self._parse_keyexprs(spec["local"])]
spec["foreign"] = [
c[0] for c in self._parse_keyexprs(spec["foreign"])
]
return "fk_constraint", spec
# CHECK constraint
m = self._re_ck_constraint.match(line)
if m:
spec = m.groupdict()
return "ck_constraint", spec
# PARTITION and SUBPARTITION
m = self._re_partition.match(line)
if m:
# Punt!
return "partition", line
# No match.
return (None, line)
def _parse_table_name(self, line, state):
"""Extract the table name.
:param line: The first line of SHOW CREATE TABLE
"""
regex, cleanup = self._pr_name
m = regex.match(line)
if m:
state.table_name = cleanup(m.group("name"))
def _parse_table_options(self, line, state):
"""Build a dictionary of all reflected table-level options.
:param line: The final line of SHOW CREATE TABLE output.
"""
options = {}
if line and line != ")":
rest_of_line = line
for regex, cleanup in self._pr_options:
m = regex.search(rest_of_line)
if not m:
continue
directive, value = m.group("directive"), m.group("val")
if cleanup:
value = cleanup(value)
options[directive.lower()] = value
rest_of_line = regex.sub("", rest_of_line)
for nope in ("auto_increment", "data directory", "index directory"):
options.pop(nope, None)
for opt, val in options.items():
state.table_options["%s_%s" % (self.dialect.name, opt)] = val
def _parse_partition_options(self, line, state):
options = {}
new_line = line[:]
while new_line.startswith("(") or new_line.startswith(" "):
new_line = new_line[1:]
for regex, cleanup in self._pr_options:
m = regex.search(new_line)
if not m or "PARTITION" not in regex.pattern:
continue
directive = m.group("directive")
directive = directive.lower()
is_subpartition = directive == "subpartition"
if directive == "partition" or is_subpartition:
new_line = new_line.replace(") */", "")
new_line = new_line.replace(",", "")
if is_subpartition and new_line.endswith(")"):
new_line = new_line[:-1]
if self.dialect.name == "mariadb" and new_line.endswith(")"):
if (
"MAXVALUE" in new_line
or "MINVALUE" in new_line
or "ENGINE" in new_line
):
# final line of MariaDB partition endswith ")"
new_line = new_line[:-1]
defs = "%s_%s_definitions" % (self.dialect.name, directive)
options[defs] = new_line
else:
directive = directive.replace(" ", "_")
value = m.group("val")
if cleanup:
value = cleanup(value)
options[directive] = value
break
for opt, val in options.items():
part_def = "%s_partition_definitions" % (self.dialect.name)
subpart_def = "%s_subpartition_definitions" % (self.dialect.name)
if opt == part_def or opt == subpart_def:
# builds a string of definitions
if opt not in state.table_options:
state.table_options[opt] = val
else:
state.table_options[opt] = "%s, %s" % (
state.table_options[opt],
val,
)
else:
state.table_options["%s_%s" % (self.dialect.name, opt)] = val
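# Rough illustration (assumed, not from the original source): a line such as
# "/*!50100 PARTITION BY HASH (id) PARTITIONS 4 */" is expected to end up as
# dialect-prefixed entries in state.table_options, e.g. keys along the lines
# of "mysql_partition_by" and "mysql_partitions".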
def _parse_column(self, line, state):
"""Extract column details.
Falls back to a 'minimal support' variant if full parse fails.
:param line: Any column-bearing line from SHOW CREATE TABLE
"""
spec = None
m = self._re_column.match(line)
if m:
spec = m.groupdict()
spec["full"] = True
else:
m = self._re_column_loose.match(line)
if m:
spec = m.groupdict()
spec["full"] = False
if not spec:
util.warn("Unknown column definition %r" % line)
return
if not spec["full"]:
util.warn("Incomplete reflection of column definition %r" % line)
name, type_, args = spec["name"], spec["coltype"], spec["arg"]
try:
col_type = self.dialect.ischema_names[type_]
except KeyError:
util.warn(
"Did not recognize type '%s' of column '%s'" % (type_, name)
)
col_type = sqltypes.NullType
# Column type positional arguments eg. varchar(32)
if args is None or args == "":
type_args = []
elif args[0] == "'" and args[-1] == "'":
type_args = self._re_csv_str.findall(args)
else:
type_args = [int(v) for v in self._re_csv_int.findall(args)]
# Column type keyword options
type_kw = {}
if issubclass(col_type, (DATETIME, TIME, TIMESTAMP)):
if type_args:
type_kw["fsp"] = type_args.pop(0)
for kw in ("unsigned", "zerofill"):
if spec.get(kw, False):
type_kw[kw] = True
for kw in ("charset", "collate"):
if spec.get(kw, False):
type_kw[kw] = spec[kw]
if issubclass(col_type, (ENUM, SET)):
type_args = _strip_values(type_args)
if issubclass(col_type, SET) and "" in type_args:
type_kw["retrieve_as_bitwise"] = True
type_instance = col_type(*type_args, **type_kw)
col_kw = {}
# NOT NULL
col_kw["nullable"] = True
# this can be "NULL" in the case of TIMESTAMP
if spec.get("notnull", False) == "NOT NULL":
col_kw["nullable"] = False
# For generated columns, the nullability is marked in a different place
if spec.get("notnull_generated", False) == "NOT NULL":
col_kw["nullable"] = False
# AUTO_INCREMENT
if spec.get("autoincr", False):
col_kw["autoincrement"] = True
elif issubclass(col_type, sqltypes.Integer):
col_kw["autoincrement"] = False
# DEFAULT
default = spec.get("default", None)
if default == "NULL":
# eliminates the need to deal with this later.
default = None
comment = spec.get("comment", None)
if comment is not None:
comment = cleanup_text(comment)
sqltext = spec.get("generated")
if sqltext is not None:
computed = dict(sqltext=sqltext)
persisted = spec.get("persistence")
if persisted is not None:
computed["persisted"] = persisted == "STORED"
col_kw["computed"] = computed
col_d = dict(
name=name, type=type_instance, default=default, comment=comment
)
col_d.update(col_kw)
state.columns.append(col_d)
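# Hedged sketch of the resulting column dictionary (values illustrative, not
# taken from the original source): a line like
# "  `id` int(11) NOT NULL AUTO_INCREMENT," is expected to append roughly
#     {"name": "id", "type": INTEGER(display_width=11), "default": None,
#      "comment": None, "nullable": False, "autoincrement": True}
# to state.columns.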
def _describe_to_create(self, table_name, columns):
"""Re-format DESCRIBE output as a SHOW CREATE TABLE string.
DESCRIBE is a much simpler reflection and is sufficient for
reflecting views for runtime use. This method formats DDL
for columns only; keys are omitted.
:param columns: A sequence of DESCRIBE or SHOW COLUMNS 6-tuples.
SHOW FULL COLUMNS FROM rows must be rearranged for use with
this function.
"""
buffer = []
for row in columns:
(name, col_type, nullable, default, extra) = (
row[i] for i in (0, 1, 2, 4, 5)
)
line = [" "]
line.append(self.preparer.quote_identifier(name))
line.append(col_type)
if not nullable:
line.append("NOT NULL")
if default:
if "auto_increment" in default:
pass
elif col_type.startswith("timestamp") and default.startswith(
"C"
):
line.append("DEFAULT")
line.append(default)
elif default == "NULL":
line.append("DEFAULT")
line.append(default)
else:
line.append("DEFAULT")
line.append("'%s'" % default.replace("'", "''"))
if extra:
line.append(extra)
buffer.append(" ".join(line))
return "".join(
[
(
"CREATE TABLE %s (\n"
% self.preparer.quote_identifier(table_name)
),
",\n".join(buffer),
"\n) ",
]
)
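# Minimal usage sketch (assumed, not from the original source): given
# rearranged 6-tuples of (name, type, nullable, key, default, extra), the
# output is DDL-ish text, roughly:
#
#     rows = [("id", "int(11)", "", "", None, "auto_increment")]
#     ddl = parser._describe_to_create("t", rows)
#     # "CREATE TABLE `t` (\n  `id` int(11) NOT NULL auto_increment\n) "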
def _parse_keyexprs(self, identifiers):
"""Unpack '"col"(2),"col" ASC'-ish strings into components."""
return [
(colname, int(length) if length else None, modifiers)
for colname, length, modifiers in self._re_keyexprs.findall(
identifiers
)
]
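# Hedged example (not from the original source): an input string such as
# '`col`,`col2`(32),`col3` DESC' is expected to unpack to roughly
# [("col", None, ""), ("col2", 32, ""), ("col3", None, "DESC")].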
def _prep_regexes(self):
"""Pre-compile regular expressions."""
self._re_columns = []
self._pr_options = []
_final = self.preparer.final_quote
quotes = dict(
zip(
("iq", "fq", "esc_fq"),
[
re.escape(s)
for s in (
self.preparer.initial_quote,
_final,
self.preparer._escape_identifier(_final),
)
],
)
)
self._pr_name = _pr_compile(
r"^CREATE (?:\w+ +)?TABLE +"
r"%(iq)s(?P<name>(?:%(esc_fq)s|[^%(fq)s])+)%(fq)s +\($" % quotes,
self.preparer._unescape_identifier,
)
self._re_is_view = _re_compile(r"^CREATE(?! TABLE)(\s.*)?\sVIEW")
# `col`,`col2`(32),`col3`(15) DESC
#
self._re_keyexprs = _re_compile(
r"(?:"
r"(?:%(iq)s((?:%(esc_fq)s|[^%(fq)s])+)%(fq)s)"
r"(?:\((\d+)\))?(?: +(ASC|DESC))?(?=\,|$))+" % quotes
)
# 'foo' or 'foo','bar' or 'fo,o','ba''a''r'
self._re_csv_str = _re_compile(r"\x27(?:\x27\x27|[^\x27])*\x27")
# 123 or 123,456
self._re_csv_int = _re_compile(r"\d+")
# `colname` <type> [type opts]
# (NOT NULL | NULL)
# DEFAULT ('value' | CURRENT_TIMESTAMP...)
# COMMENT 'comment'
# COLUMN_FORMAT (FIXED|DYNAMIC|DEFAULT)
# STORAGE (DISK|MEMORY)
self._re_column = _re_compile(
r" "
r"%(iq)s(?P<name>(?:%(esc_fq)s|[^%(fq)s])+)%(fq)s +"
r"(?P<coltype>\w+)"
r"(?:\((?P<arg>(?:\d+|\d+,\d+|"
r"(?:'(?:''|[^'])*',?)+))\))?"
r"(?: +(?P<unsigned>UNSIGNED))?"
r"(?: +(?P<zerofill>ZEROFILL))?"
r"(?: +CHARACTER SET +(?P<charset>[\w_]+))?"
r"(?: +COLLATE +(?P<collate>[\w_]+))?"
r"(?: +(?P<notnull>(?:NOT )?NULL))?"
r"(?: +DEFAULT +(?P<default>"
r"(?:NULL|'(?:''|[^'])*'|[\-\w\.\(\)]+"
r"(?: +ON UPDATE [\-\w\.\(\)]+)?)"
r"))?"
r"(?: +(?:GENERATED ALWAYS)? ?AS +(?P<generated>\("
r".*\))? ?(?P<persistence>VIRTUAL|STORED)?"
r"(?: +(?P<notnull_generated>(?:NOT )?NULL))?"
r")?"
r"(?: +(?P<autoincr>AUTO_INCREMENT))?"
r"(?: +COMMENT +'(?P<comment>(?:''|[^'])*)')?"
r"(?: +COLUMN_FORMAT +(?P<colfmt>\w+))?"
r"(?: +STORAGE +(?P<storage>\w+))?"
r"(?: +(?P<extra>.*))?"
r",?$" % quotes
)
# Fallback, try to parse as little as possible
self._re_column_loose = _re_compile(
r" "
r"%(iq)s(?P<name>(?:%(esc_fq)s|[^%(fq)s])+)%(fq)s +"
r"(?P<coltype>\w+)"
r"(?:\((?P<arg>(?:\d+|\d+,\d+|\x27(?:\x27\x27|[^\x27])+\x27))\))?"
r".*?(?P<notnull>(?:NOT )NULL)?" % quotes
)
# (PRIMARY|UNIQUE|FULLTEXT|SPATIAL) INDEX `name` (USING (BTREE|HASH))?
# (`col` (ASC|DESC)?, `col` (ASC|DESC)?)
# KEY_BLOCK_SIZE size | WITH PARSER name /*!50100 WITH PARSER name */
self._re_key = _re_compile(
r" "
r"(?:(?P<type>\S+) )?KEY"
r"(?: +%(iq)s(?P<name>(?:%(esc_fq)s|[^%(fq)s])+)%(fq)s)?"
r"(?: +USING +(?P<using_pre>\S+))?"
r" +\((?P<columns>.+?)\)"
r"(?: +USING +(?P<using_post>\S+))?"
r"(?: +KEY_BLOCK_SIZE *[ =]? *(?P<keyblock>\S+))?"
r"(?: +WITH PARSER +(?P<parser>\S+))?"
r"(?: +COMMENT +(?P<comment>(\x27\x27|\x27([^\x27])*?\x27)+))?"
r"(?: +/\*(?P<version_sql>.+)\*/ *)?"
r",?$" % quotes
)
# https://forums.mysql.com/read.php?20,567102,567111#msg-567111
# i.e. if the MySQL version is >= the given number, execute what's in the comment
self._re_key_version_sql = _re_compile(
r"\!\d+ " r"(?: *WITH PARSER +(?P<parser>\S+) *)?"
)
# CONSTRAINT `name` FOREIGN KEY (`local_col`)
# REFERENCES `remote` (`remote_col`)
# MATCH FULL | MATCH PARTIAL | MATCH SIMPLE
# ON DELETE CASCADE ON UPDATE RESTRICT
#
# unique constraints come back as KEYs
kw = quotes.copy()
kw["on"] = "RESTRICT|CASCADE|SET NULL|NO ACTION"
self._re_fk_constraint = _re_compile(
r" "
r"CONSTRAINT +"
r"%(iq)s(?P<name>(?:%(esc_fq)s|[^%(fq)s])+)%(fq)s +"
r"FOREIGN KEY +"
r"\((?P<local>[^\)]+?)\) REFERENCES +"
r"(?P<table>%(iq)s[^%(fq)s]+%(fq)s"
r"(?:\.%(iq)s[^%(fq)s]+%(fq)s)?) +"
r"\((?P<foreign>(?:%(iq)s[^%(fq)s]+%(fq)s(?: *, *)?)+)\)"
r"(?: +(?P<match>MATCH \w+))?"
r"(?: +ON DELETE (?P<ondelete>%(on)s))?"
r"(?: +ON UPDATE (?P<onupdate>%(on)s))?" % kw
)
# CONSTRAINT `CONSTRAINT_1` CHECK (`x` > 5)
# testing on MariaDB 10.2 shows that the CHECK constraint
# is returned on a line by itself, so to match without worrying
# about parenthesis in the expression we go to the end of the line
self._re_ck_constraint = _re_compile(
r" "
r"CONSTRAINT +"
r"%(iq)s(?P<name>(?:%(esc_fq)s|[^%(fq)s])+)%(fq)s +"
r"CHECK +"
r"\((?P<sqltext>.+)\),?" % kw
)
# PARTITION
#
# punt!
self._re_partition = _re_compile(r"(?:.*)(?:SUB)?PARTITION(?:.*)")
# Table-level options (COLLATE, ENGINE, etc.)
# Do the string options first, since they have quoted
# strings we need to get rid of.
for option in _options_of_type_string:
self._add_option_string(option)
for option in (
"ENGINE",
"TYPE",
"AUTO_INCREMENT",
"AVG_ROW_LENGTH",
"CHARACTER SET",
"DEFAULT CHARSET",
"CHECKSUM",
"COLLATE",
"DELAY_KEY_WRITE",
"INSERT_METHOD",
"MAX_ROWS",
"MIN_ROWS",
"PACK_KEYS",
"ROW_FORMAT",
"KEY_BLOCK_SIZE",
"STATS_SAMPLE_PAGES",
):
self._add_option_word(option)
for option in (
"PARTITION BY",
"SUBPARTITION BY",
"PARTITIONS",
"SUBPARTITIONS",
"PARTITION",
"SUBPARTITION",
):
self._add_partition_option_word(option)
self._add_option_regex("UNION", r"\([^\)]+\)")
self._add_option_regex("TABLESPACE", r".*? STORAGE DISK")
self._add_option_regex(
"RAID_TYPE",
r"\w+\s+RAID_CHUNKS\s*\=\s*\w+RAID_CHUNKSIZE\s*=\s*\w+",
)
_optional_equals = r"(?:\s*(?:=\s*)|\s+)"
def _add_option_string(self, directive):
regex = r"(?P<directive>%s)%s" r"'(?P<val>(?:[^']|'')*?)'(?!')" % (
re.escape(directive),
self._optional_equals,
)
self._pr_options.append(_pr_compile(regex, cleanup_text))
def _add_option_word(self, directive):
regex = r"(?P<directive>%s)%s" r"(?P<val>\w+)" % (
re.escape(directive),
self._optional_equals,
)
self._pr_options.append(_pr_compile(regex))
def _add_partition_option_word(self, directive):
if directive == "PARTITION BY" or directive == "SUBPARTITION BY":
regex = r"(?<!\S)(?P<directive>%s)%s" r"(?P<val>\w+.*)" % (
re.escape(directive),
self._optional_equals,
)
elif directive == "SUBPARTITIONS" or directive == "PARTITIONS":
regex = r"(?<!\S)(?P<directive>%s)%s" r"(?P<val>\d+)" % (
re.escape(directive),
self._optional_equals,
)
else:
regex = r"(?<!\S)(?P<directive>%s)(?!\S)" % (re.escape(directive),)
self._pr_options.append(_pr_compile(regex))
def _add_option_regex(self, directive, regex):
regex = r"(?P<directive>%s)%s" r"(?P<val>%s)" % (
re.escape(directive),
self._optional_equals,
regex,
)
self._pr_options.append(_pr_compile(regex))
_options_of_type_string = (
"COMMENT",
"DATA DIRECTORY",
"INDEX DIRECTORY",
"PASSWORD",
"CONNECTION",
)
def _pr_compile(regex, cleanup=None):
"""Prepare a 2-tuple of compiled regex and callable."""
return (_re_compile(regex), cleanup)
def _re_compile(regex):
"""Compile a string to regex, I and UNICODE."""
return re.compile(regex, re.I | re.UNICODE)
def _strip_values(values):
"Strip reflected values quotes"
strip_values = []
for a in values:
if a[0:1] == '"' or a[0:1] == "'":
# strip enclosing quotes and unquote interior
a = a[1:-1].replace(a[0] * 2, a[0])
strip_values.append(a)
return strip_values
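# Illustrative example (assumed, not from the original source):
# _strip_values(["'abc'", '"d""e"', "'it''s'"]) is expected to return
# roughly ["abc", 'd"e', "it's"].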
def cleanup_text(raw_text: str) -> str:
if "\\" in raw_text:
raw_text = re.sub(
_control_char_regexp, lambda s: _control_char_map[s[0]], raw_text
)
return raw_text.replace("''", "'")
_control_char_map = {
"\\\\": "\\",
"\\0": "\0",
"\\a": "\a",
"\\b": "\b",
"\\t": "\t",
"\\n": "\n",
"\\v": "\v",
"\\f": "\f",
"\\r": "\r",
# '\\e':'\e',
}
_control_char_regexp = re.compile(
"|".join(re.escape(k) for k in _control_char_map)
)
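# Illustrative example (assumed, not from the original source):
# cleanup_text(r"line1\nline2 can''t") is expected to return
# "line1\nline2 can't": backslash escapes are mapped through
# _control_char_map and doubled single quotes are collapsed.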

View file

@@ -0,0 +1,567 @@
# dialects/mysql/reserved_words.py
# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
# generated using:
# https://gist.github.com/kkirsche/4f31f2153ed7a3248be1ec44ca6ddbc9
#
# https://mariadb.com/kb/en/reserved-words/
# includes: Reserved Words, Oracle Mode (separate set unioned)
# excludes: Exceptions, Function Names
# mypy: ignore-errors
RESERVED_WORDS_MARIADB = {
"accessible",
"add",
"all",
"alter",
"analyze",
"and",
"as",
"asc",
"asensitive",
"before",
"between",
"bigint",
"binary",
"blob",
"both",
"by",
"call",
"cascade",
"case",
"change",
"char",
"character",
"check",
"collate",
"column",
"condition",
"constraint",
"continue",
"convert",
"create",
"cross",
"current_date",
"current_role",
"current_time",
"current_timestamp",
"current_user",
"cursor",
"database",
"databases",
"day_hour",
"day_microsecond",
"day_minute",
"day_second",
"dec",
"decimal",
"declare",
"default",
"delayed",
"delete",
"desc",
"describe",
"deterministic",
"distinct",
"distinctrow",
"div",
"do_domain_ids",
"double",
"drop",
"dual",
"each",
"else",
"elseif",
"enclosed",
"escaped",
"except",
"exists",
"exit",
"explain",
"false",
"fetch",
"float",
"float4",
"float8",
"for",
"force",
"foreign",
"from",
"fulltext",
"general",
"grant",
"group",
"having",
"high_priority",
"hour_microsecond",
"hour_minute",
"hour_second",
"if",
"ignore",
"ignore_domain_ids",
"ignore_server_ids",
"in",
"index",
"infile",
"inner",
"inout",
"insensitive",
"insert",
"int",
"int1",
"int2",
"int3",
"int4",
"int8",
"integer",
"intersect",
"interval",
"into",
"is",
"iterate",
"join",
"key",
"keys",
"kill",
"leading",
"leave",
"left",
"like",
"limit",
"linear",
"lines",
"load",
"localtime",
"localtimestamp",
"lock",
"long",
"longblob",
"longtext",
"loop",
"low_priority",
"master_heartbeat_period",
"master_ssl_verify_server_cert",
"match",
"maxvalue",
"mediumblob",
"mediumint",
"mediumtext",
"middleint",
"minute_microsecond",
"minute_second",
"mod",
"modifies",
"natural",
"no_write_to_binlog",
"not",
"null",
"numeric",
"offset",
"on",
"optimize",
"option",
"optionally",
"or",
"order",
"out",
"outer",
"outfile",
"over",
"page_checksum",
"parse_vcol_expr",
"partition",
"position",
"precision",
"primary",
"procedure",
"purge",
"range",
"read",
"read_write",
"reads",
"real",
"recursive",
"ref_system_id",
"references",
"regexp",
"release",
"rename",
"repeat",
"replace",
"require",
"resignal",
"restrict",
"return",
"returning",
"revoke",
"right",
"rlike",
"rows",
"row_number",
"schema",
"schemas",
"second_microsecond",
"select",
"sensitive",
"separator",
"set",
"show",
"signal",
"slow",
"smallint",
"spatial",
"specific",
"sql",
"sql_big_result",
"sql_calc_found_rows",
"sql_small_result",
"sqlexception",
"sqlstate",
"sqlwarning",
"ssl",
"starting",
"stats_auto_recalc",
"stats_persistent",
"stats_sample_pages",
"straight_join",
"table",
"terminated",
"then",
"tinyblob",
"tinyint",
"tinytext",
"to",
"trailing",
"trigger",
"true",
"undo",
"union",
"unique",
"unlock",
"unsigned",
"update",
"usage",
"use",
"using",
"utc_date",
"utc_time",
"utc_timestamp",
"values",
"varbinary",
"varchar",
"varcharacter",
"varying",
"when",
"where",
"while",
"window",
"with",
"write",
"xor",
"year_month",
"zerofill",
}.union(
{
"body",
"elsif",
"goto",
"history",
"others",
"package",
"period",
"raise",
"rowtype",
"system",
"system_time",
"versioning",
"without",
}
)
# https://dev.mysql.com/doc/refman/8.0/en/keywords.html
# https://dev.mysql.com/doc/refman/5.7/en/keywords.html
# https://dev.mysql.com/doc/refman/5.6/en/keywords.html
# includes: MySQL x.0 Keywords and Reserved Words
# excludes: MySQL x.0 New Keywords and Reserved Words,
# MySQL x.0 Removed Keywords and Reserved Words
RESERVED_WORDS_MYSQL = {
"accessible",
"add",
"admin",
"all",
"alter",
"analyze",
"and",
"array",
"as",
"asc",
"asensitive",
"before",
"between",
"bigint",
"binary",
"blob",
"both",
"by",
"call",
"cascade",
"case",
"change",
"char",
"character",
"check",
"collate",
"column",
"condition",
"constraint",
"continue",
"convert",
"create",
"cross",
"cube",
"cume_dist",
"current_date",
"current_time",
"current_timestamp",
"current_user",
"cursor",
"database",
"databases",
"day_hour",
"day_microsecond",
"day_minute",
"day_second",
"dec",
"decimal",
"declare",
"default",
"delayed",
"delete",
"dense_rank",
"desc",
"describe",
"deterministic",
"distinct",
"distinctrow",
"div",
"double",
"drop",
"dual",
"each",
"else",
"elseif",
"empty",
"enclosed",
"escaped",
"except",
"exists",
"exit",
"explain",
"false",
"fetch",
"first_value",
"float",
"float4",
"float8",
"for",
"force",
"foreign",
"from",
"fulltext",
"function",
"general",
"generated",
"get",
"get_master_public_key",
"grant",
"group",
"grouping",
"groups",
"having",
"high_priority",
"hour_microsecond",
"hour_minute",
"hour_second",
"if",
"ignore",
"ignore_server_ids",
"in",
"index",
"infile",
"inner",
"inout",
"insensitive",
"insert",
"int",
"int1",
"int2",
"int3",
"int4",
"int8",
"integer",
"interval",
"into",
"io_after_gtids",
"io_before_gtids",
"is",
"iterate",
"join",
"json_table",
"key",
"keys",
"kill",
"lag",
"last_value",
"lateral",
"lead",
"leading",
"leave",
"left",
"like",
"limit",
"linear",
"lines",
"load",
"localtime",
"localtimestamp",
"lock",
"long",
"longblob",
"longtext",
"loop",
"low_priority",
"master_bind",
"master_heartbeat_period",
"master_ssl_verify_server_cert",
"match",
"maxvalue",
"mediumblob",
"mediumint",
"mediumtext",
"member",
"middleint",
"minute_microsecond",
"minute_second",
"mod",
"modifies",
"natural",
"no_write_to_binlog",
"not",
"nth_value",
"ntile",
"null",
"numeric",
"of",
"on",
"optimize",
"optimizer_costs",
"option",
"optionally",
"or",
"order",
"out",
"outer",
"outfile",
"over",
"parse_gcol_expr",
"partition",
"percent_rank",
"persist",
"persist_only",
"precision",
"primary",
"procedure",
"purge",
"range",
"rank",
"read",
"read_write",
"reads",
"real",
"recursive",
"references",
"regexp",
"release",
"rename",
"repeat",
"replace",
"require",
"resignal",
"restrict",
"return",
"revoke",
"right",
"rlike",
"role",
"row",
"row_number",
"rows",
"schema",
"schemas",
"second_microsecond",
"select",
"sensitive",
"separator",
"set",
"show",
"signal",
"slow",
"smallint",
"spatial",
"specific",
"sql",
"sql_after_gtids",
"sql_before_gtids",
"sql_big_result",
"sql_calc_found_rows",
"sql_small_result",
"sqlexception",
"sqlstate",
"sqlwarning",
"ssl",
"starting",
"stored",
"straight_join",
"system",
"table",
"terminated",
"then",
"tinyblob",
"tinyint",
"tinytext",
"to",
"trailing",
"trigger",
"true",
"undo",
"union",
"unique",
"unlock",
"unsigned",
"update",
"usage",
"use",
"using",
"utc_date",
"utc_time",
"utc_timestamp",
"values",
"varbinary",
"varchar",
"varcharacter",
"varying",
"virtual",
"when",
"where",
"while",
"window",
"with",
"write",
"xor",
"year_month",
"zerofill",
}

View file

@@ -0,0 +1,773 @@
# dialects/mysql/types.py
# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
# mypy: ignore-errors
import datetime
from ... import exc
from ... import util
from ...sql import sqltypes
class _NumericType:
"""Base for MySQL numeric types.
This is the base for both NUMERIC and INTEGER, hence
it's a mixin.
"""
def __init__(self, unsigned=False, zerofill=False, **kw):
self.unsigned = unsigned
self.zerofill = zerofill
super().__init__(**kw)
def __repr__(self):
return util.generic_repr(
self, to_inspect=[_NumericType, sqltypes.Numeric]
)
class _FloatType(_NumericType, sqltypes.Float):
def __init__(self, precision=None, scale=None, asdecimal=True, **kw):
if isinstance(self, (REAL, DOUBLE)) and (
(precision is None and scale is not None)
or (precision is not None and scale is None)
):
raise exc.ArgumentError(
"You must specify both precision and scale or omit "
"both altogether."
)
super().__init__(precision=precision, asdecimal=asdecimal, **kw)
self.scale = scale
def __repr__(self):
return util.generic_repr(
self, to_inspect=[_FloatType, _NumericType, sqltypes.Float]
)
class _IntegerType(_NumericType, sqltypes.Integer):
def __init__(self, display_width=None, **kw):
self.display_width = display_width
super().__init__(**kw)
def __repr__(self):
return util.generic_repr(
self, to_inspect=[_IntegerType, _NumericType, sqltypes.Integer]
)
class _StringType(sqltypes.String):
"""Base for MySQL string types."""
def __init__(
self,
charset=None,
collation=None,
ascii=False, # noqa
binary=False,
unicode=False,
national=False,
**kw,
):
self.charset = charset
# allow collate= or collation=
kw.setdefault("collation", kw.pop("collate", collation))
self.ascii = ascii
self.unicode = unicode
self.binary = binary
self.national = national
super().__init__(**kw)
def __repr__(self):
return util.generic_repr(
self, to_inspect=[_StringType, sqltypes.String]
)
class _MatchType(sqltypes.Float, sqltypes.MatchType):
def __init__(self, **kw):
# TODO: float arguments?
sqltypes.Float.__init__(self)
sqltypes.MatchType.__init__(self)
class NUMERIC(_NumericType, sqltypes.NUMERIC):
"""MySQL NUMERIC type."""
__visit_name__ = "NUMERIC"
def __init__(self, precision=None, scale=None, asdecimal=True, **kw):
"""Construct a NUMERIC.
:param precision: Total digits in this number. If scale and precision
are both None, values are stored to limits allowed by the server.
:param scale: The number of digits after the decimal point.
:param unsigned: a boolean, optional.
:param zerofill: Optional. If true, values will be stored as strings
left-padded with zeros. Note that this does not affect the values
returned by the underlying database API, which continue to be
numeric.
"""
super().__init__(
precision=precision, scale=scale, asdecimal=asdecimal, **kw
)
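# Hedged usage sketch (not part of the original source; assumes the public
# sqlalchemy.dialects.mysql import path):
#
#     from sqlalchemy import Column, MetaData, Table
#     from sqlalchemy.dialects.mysql import NUMERIC
#
#     t = Table("prices", MetaData(),
#               Column("amount", NUMERIC(precision=12, scale=4, unsigned=True)))
#     # expected to render roughly as "amount NUMERIC(12, 4) UNSIGNED" in DDL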
class DECIMAL(_NumericType, sqltypes.DECIMAL):
"""MySQL DECIMAL type."""
__visit_name__ = "DECIMAL"
def __init__(self, precision=None, scale=None, asdecimal=True, **kw):
"""Construct a DECIMAL.
:param precision: Total digits in this number. If scale and precision
are both None, values are stored to limits allowed by the server.
:param scale: The number of digits after the decimal point.
:param unsigned: a boolean, optional.
:param zerofill: Optional. If true, values will be stored as strings
left-padded with zeros. Note that this does not affect the values
returned by the underlying database API, which continue to be
numeric.
"""
super().__init__(
precision=precision, scale=scale, asdecimal=asdecimal, **kw
)
class DOUBLE(_FloatType, sqltypes.DOUBLE):
"""MySQL DOUBLE type."""
__visit_name__ = "DOUBLE"
def __init__(self, precision=None, scale=None, asdecimal=True, **kw):
"""Construct a DOUBLE.
.. note::
The :class:`.DOUBLE` type by default converts from float
to Decimal, using a truncation that defaults to 10 digits.
Specify either ``scale=n`` or ``decimal_return_scale=n`` in order
to change this scale, or ``asdecimal=False`` to return values
directly as Python floating points.
:param precision: Total digits in this number. If scale and precision
are both None, values are stored to limits allowed by the server.
:param scale: The number of digits after the decimal point.
:param unsigned: a boolean, optional.
:param zerofill: Optional. If true, values will be stored as strings
left-padded with zeros. Note that this does not affect the values
returned by the underlying database API, which continue to be
numeric.
"""
super().__init__(
precision=precision, scale=scale, asdecimal=asdecimal, **kw
)
class REAL(_FloatType, sqltypes.REAL):
"""MySQL REAL type."""
__visit_name__ = "REAL"
def __init__(self, precision=None, scale=None, asdecimal=True, **kw):
"""Construct a REAL.
.. note::
The :class:`.REAL` type by default converts from float
to Decimal, using a truncation that defaults to 10 digits.
Specify either ``scale=n`` or ``decimal_return_scale=n`` in order
to change this scale, or ``asdecimal=False`` to return values
directly as Python floating points.
:param precision: Total digits in this number. If scale and precision
are both None, values are stored to limits allowed by the server.
:param scale: The number of digits after the decimal point.
:param unsigned: a boolean, optional.
:param zerofill: Optional. If true, values will be stored as strings
left-padded with zeros. Note that this does not affect the values
returned by the underlying database API, which continue to be
numeric.
"""
super().__init__(
precision=precision, scale=scale, asdecimal=asdecimal, **kw
)
class FLOAT(_FloatType, sqltypes.FLOAT):
"""MySQL FLOAT type."""
__visit_name__ = "FLOAT"
def __init__(self, precision=None, scale=None, asdecimal=False, **kw):
"""Construct a FLOAT.
:param precision: Total digits in this number. If scale and precision
are both None, values are stored to limits allowed by the server.
:param scale: The number of digits after the decimal point.
:param unsigned: a boolean, optional.
:param zerofill: Optional. If true, values will be stored as strings
left-padded with zeros. Note that this does not affect the values
returned by the underlying database API, which continue to be
numeric.
"""
super().__init__(
precision=precision, scale=scale, asdecimal=asdecimal, **kw
)
def bind_processor(self, dialect):
return None
class INTEGER(_IntegerType, sqltypes.INTEGER):
"""MySQL INTEGER type."""
__visit_name__ = "INTEGER"
def __init__(self, display_width=None, **kw):
"""Construct an INTEGER.
:param display_width: Optional, maximum display width for this number.
:param unsigned: a boolean, optional.
:param zerofill: Optional. If true, values will be stored as strings
left-padded with zeros. Note that this does not affect the values
returned by the underlying database API, which continue to be
numeric.
"""
super().__init__(display_width=display_width, **kw)
class BIGINT(_IntegerType, sqltypes.BIGINT):
"""MySQL BIGINTEGER type."""
__visit_name__ = "BIGINT"
def __init__(self, display_width=None, **kw):
"""Construct a BIGINTEGER.
:param display_width: Optional, maximum display width for this number.
:param unsigned: a boolean, optional.
:param zerofill: Optional. If true, values will be stored as strings
left-padded with zeros. Note that this does not affect the values
returned by the underlying database API, which continue to be
numeric.
"""
super().__init__(display_width=display_width, **kw)
class MEDIUMINT(_IntegerType):
"""MySQL MEDIUMINTEGER type."""
__visit_name__ = "MEDIUMINT"
def __init__(self, display_width=None, **kw):
"""Construct a MEDIUMINTEGER
:param display_width: Optional, maximum display width for this number.
:param unsigned: a boolean, optional.
:param zerofill: Optional. If true, values will be stored as strings
left-padded with zeros. Note that this does not affect the values
returned by the underlying database API, which continue to be
numeric.
"""
super().__init__(display_width=display_width, **kw)
class TINYINT(_IntegerType):
"""MySQL TINYINT type."""
__visit_name__ = "TINYINT"
def __init__(self, display_width=None, **kw):
"""Construct a TINYINT.
:param display_width: Optional, maximum display width for this number.
:param unsigned: a boolean, optional.
:param zerofill: Optional. If true, values will be stored as strings
left-padded with zeros. Note that this does not affect the values
returned by the underlying database API, which continue to be
numeric.
"""
super().__init__(display_width=display_width, **kw)
class SMALLINT(_IntegerType, sqltypes.SMALLINT):
"""MySQL SMALLINTEGER type."""
__visit_name__ = "SMALLINT"
def __init__(self, display_width=None, **kw):
"""Construct a SMALLINTEGER.
:param display_width: Optional, maximum display width for this number.
:param unsigned: a boolean, optional.
:param zerofill: Optional. If true, values will be stored as strings
left-padded with zeros. Note that this does not affect the values
returned by the underlying database API, which continue to be
numeric.
"""
super().__init__(display_width=display_width, **kw)
class BIT(sqltypes.TypeEngine):
"""MySQL BIT type.
This type is for MySQL 5.0.3 or greater for MyISAM, and 5.0.5 or greater
for MyISAM, MEMORY, InnoDB and BDB. For older versions, use a
MSTinyInteger() type.
"""
__visit_name__ = "BIT"
def __init__(self, length=None):
"""Construct a BIT.
:param length: Optional, number of bits.
"""
self.length = length
def result_processor(self, dialect, coltype):
"""Convert a MySQL's 64 bit, variable length binary string to a long.
TODO: this is MySQL-db, pyodbc specific. OurSQL and mysqlconnector
already do this, so this logic should be moved to those dialects.
"""
def process(value):
if value is not None:
v = 0
for i in value:
if not isinstance(i, int):
i = ord(i) # convert byte to int on Python 2
v = v << 8 | i
return v
return value
return process
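# Illustrative behavior (assumed, not from the original source):
# process(b"\x02\x0f") is expected to return 0x020f == 527, i.e. the bytes
# are folded big-endian into a single integer.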
class TIME(sqltypes.TIME):
"""MySQL TIME type."""
__visit_name__ = "TIME"
def __init__(self, timezone=False, fsp=None):
"""Construct a MySQL TIME type.
:param timezone: not used by the MySQL dialect.
:param fsp: fractional seconds precision value.
MySQL 5.6 supports storage of fractional seconds;
this parameter will be used when emitting DDL
for the TIME type.
.. note::
DBAPI driver support for fractional seconds may
be limited; current support includes
MySQL Connector/Python.
"""
super().__init__(timezone=timezone)
self.fsp = fsp
def result_processor(self, dialect, coltype):
time = datetime.time
def process(value):
# convert from a timedelta value
if value is not None:
microseconds = value.microseconds
seconds = value.seconds
minutes = seconds // 60
return time(
minutes // 60,
minutes % 60,
seconds - minutes * 60,
microsecond=microseconds,
)
else:
return None
return process
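# Illustrative behavior (assumed, not from the original source): a DBAPI
# value of datetime.timedelta(hours=5, minutes=30, seconds=12, microseconds=7)
# is expected to come back as datetime.time(5, 30, 12, 7).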
class TIMESTAMP(sqltypes.TIMESTAMP):
"""MySQL TIMESTAMP type."""
__visit_name__ = "TIMESTAMP"
def __init__(self, timezone=False, fsp=None):
"""Construct a MySQL TIMESTAMP type.
:param timezone: not used by the MySQL dialect.
:param fsp: fractional seconds precision value.
MySQL 5.6.4 supports storage of fractional seconds;
this parameter will be used when emitting DDL
for the TIMESTAMP type.
.. note::
DBAPI driver support for fractional seconds may
be limited; current support includes
MySQL Connector/Python.
"""
super().__init__(timezone=timezone)
self.fsp = fsp
class DATETIME(sqltypes.DATETIME):
"""MySQL DATETIME type."""
__visit_name__ = "DATETIME"
def __init__(self, timezone=False, fsp=None):
"""Construct a MySQL DATETIME type.
:param timezone: not used by the MySQL dialect.
:param fsp: fractional seconds precision value.
MySQL 5.6.4 supports storage of fractional seconds;
this parameter will be used when emitting DDL
for the DATETIME type.
.. note::
DBAPI driver support for fractional seconds may
be limited; current support includes
MySQL Connector/Python.
"""
super().__init__(timezone=timezone)
self.fsp = fsp
class YEAR(sqltypes.TypeEngine):
"""MySQL YEAR type, for single byte storage of years 1901-2155."""
__visit_name__ = "YEAR"
def __init__(self, display_width=None):
self.display_width = display_width
class TEXT(_StringType, sqltypes.TEXT):
"""MySQL TEXT type, for text up to 2^16 characters."""
__visit_name__ = "TEXT"
def __init__(self, length=None, **kw):
"""Construct a TEXT.
:param length: Optional, if provided the server may optimize storage
by substituting the smallest TEXT type sufficient to store
``length`` characters.
:param charset: Optional, a column-level character set for this string
value. Takes precedence over 'ascii' or 'unicode' short-hand.
:param collation: Optional, a column-level collation for this string
value. Takes precedence over 'binary' short-hand.
:param ascii: Defaults to False: short-hand for the ``latin1``
character set, generates ASCII in schema.
:param unicode: Defaults to False: short-hand for the ``ucs2``
character set, generates UNICODE in schema.
:param national: Optional. If true, use the server's configured
national character set.
:param binary: Defaults to False: short-hand, pick the binary
collation type that matches the column's character set. Generates
BINARY in schema. This does not affect the type of data stored,
only the collation of character data.
"""
super().__init__(length=length, **kw)
class TINYTEXT(_StringType):
"""MySQL TINYTEXT type, for text up to 2^8 characters."""
__visit_name__ = "TINYTEXT"
def __init__(self, **kwargs):
"""Construct a TINYTEXT.
:param charset: Optional, a column-level character set for this string
value. Takes precedence over 'ascii' or 'unicode' short-hand.
:param collation: Optional, a column-level collation for this string
value. Takes precedence over 'binary' short-hand.
:param ascii: Defaults to False: short-hand for the ``latin1``
character set, generates ASCII in schema.
:param unicode: Defaults to False: short-hand for the ``ucs2``
character set, generates UNICODE in schema.
:param national: Optional. If true, use the server's configured
national character set.
:param binary: Defaults to False: short-hand, pick the binary
collation type that matches the column's character set. Generates
BINARY in schema. This does not affect the type of data stored,
only the collation of character data.
"""
super().__init__(**kwargs)
class MEDIUMTEXT(_StringType):
"""MySQL MEDIUMTEXT type, for text up to 2^24 characters."""
__visit_name__ = "MEDIUMTEXT"
def __init__(self, **kwargs):
"""Construct a MEDIUMTEXT.
:param charset: Optional, a column-level character set for this string
value. Takes precedence over 'ascii' or 'unicode' short-hand.
:param collation: Optional, a column-level collation for this string
value. Takes precedence over 'binary' short-hand.
:param ascii: Defaults to False: short-hand for the ``latin1``
character set, generates ASCII in schema.
:param unicode: Defaults to False: short-hand for the ``ucs2``
character set, generates UNICODE in schema.
:param national: Optional. If true, use the server's configured
national character set.
:param binary: Defaults to False: short-hand, pick the binary
collation type that matches the column's character set. Generates
BINARY in schema. This does not affect the type of data stored,
only the collation of character data.
"""
super().__init__(**kwargs)
class LONGTEXT(_StringType):
"""MySQL LONGTEXT type, for text up to 2^32 characters."""
__visit_name__ = "LONGTEXT"
def __init__(self, **kwargs):
"""Construct a LONGTEXT.
:param charset: Optional, a column-level character set for this string
value. Takes precedence over 'ascii' or 'unicode' short-hand.
:param collation: Optional, a column-level collation for this string
value. Takes precedence over 'binary' short-hand.
:param ascii: Defaults to False: short-hand for the ``latin1``
character set, generates ASCII in schema.
:param unicode: Defaults to False: short-hand for the ``ucs2``
character set, generates UNICODE in schema.
:param national: Optional. If true, use the server's configured
national character set.
:param binary: Defaults to False: short-hand, pick the binary
collation type that matches the column's character set. Generates
BINARY in schema. This does not affect the type of data stored,
only the collation of character data.
"""
super().__init__(**kwargs)
class VARCHAR(_StringType, sqltypes.VARCHAR):
"""MySQL VARCHAR type, for variable-length character data."""
__visit_name__ = "VARCHAR"
def __init__(self, length=None, **kwargs):
"""Construct a VARCHAR.
:param charset: Optional, a column-level character set for this string
value. Takes precedence over 'ascii' or 'unicode' short-hand.
:param collation: Optional, a column-level collation for this string
value. Takes precedence over 'binary' short-hand.
:param ascii: Defaults to False: short-hand for the ``latin1``
character set, generates ASCII in schema.
:param unicode: Defaults to False: short-hand for the ``ucs2``
character set, generates UNICODE in schema.
:param national: Optional. If true, use the server's configured
national character set.
:param binary: Defaults to False: short-hand, pick the binary
collation type that matches the column's character set. Generates
BINARY in schema. This does not affect the type of data stored,
only the collation of character data.
"""
super().__init__(length=length, **kwargs)
class CHAR(_StringType, sqltypes.CHAR):
"""MySQL CHAR type, for fixed-length character data."""
__visit_name__ = "CHAR"
def __init__(self, length=None, **kwargs):
"""Construct a CHAR.
:param length: Maximum data length, in characters.
:param binary: Optional, use the default binary collation for the
national character set. This does not affect the type of data
stored, use a BINARY type for binary data.
:param collation: Optional, request a particular collation. Must be
compatible with the national character set.
"""
super().__init__(length=length, **kwargs)
@classmethod
def _adapt_string_for_cast(self, type_):
# copy the given string type into a CHAR
# for the purposes of rendering a CAST expression
type_ = sqltypes.to_instance(type_)
if isinstance(type_, sqltypes.CHAR):
return type_
elif isinstance(type_, _StringType):
return CHAR(
length=type_.length,
charset=type_.charset,
collation=type_.collation,
ascii=type_.ascii,
binary=type_.binary,
unicode=type_.unicode,
national=False, # not supported in CAST
)
else:
return CHAR(length=type_.length)
class NVARCHAR(_StringType, sqltypes.NVARCHAR):
"""MySQL NVARCHAR type.
For variable-length character data in the server's configured national
character set.
"""
__visit_name__ = "NVARCHAR"
def __init__(self, length=None, **kwargs):
"""Construct an NVARCHAR.
:param length: Maximum data length, in characters.
:param binary: Optional, use the default binary collation for the
national character set. This does not affect the type of data
stored, use a BINARY type for binary data.
:param collation: Optional, request a particular collation. Must be
compatible with the national character set.
"""
kwargs["national"] = True
super().__init__(length=length, **kwargs)
class NCHAR(_StringType, sqltypes.NCHAR):
"""MySQL NCHAR type.
For fixed-length character data in the server's configured national
character set.
"""
__visit_name__ = "NCHAR"
def __init__(self, length=None, **kwargs):
"""Construct an NCHAR.
:param length: Maximum data length, in characters.
:param binary: Optional, use the default binary collation for the
national character set. This does not affect the type of data
stored, use a BINARY type for binary data.
:param collation: Optional, request a particular collation. Must be
compatible with the national character set.
"""
kwargs["national"] = True
super().__init__(length=length, **kwargs)
class TINYBLOB(sqltypes._Binary):
"""MySQL TINYBLOB type, for binary data up to 2^8 bytes."""
__visit_name__ = "TINYBLOB"
class MEDIUMBLOB(sqltypes._Binary):
"""MySQL MEDIUMBLOB type, for binary data up to 2^24 bytes."""
__visit_name__ = "MEDIUMBLOB"
class LONGBLOB(sqltypes._Binary):
"""MySQL LONGBLOB type, for binary data up to 2^32 bytes."""
__visit_name__ = "LONGBLOB"

View file

@@ -0,0 +1,67 @@
# dialects/oracle/__init__.py
# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
# mypy: ignore-errors
from types import ModuleType
from . import base # noqa
from . import cx_oracle # noqa
from . import oracledb # noqa
from .base import BFILE
from .base import BINARY_DOUBLE
from .base import BINARY_FLOAT
from .base import BLOB
from .base import CHAR
from .base import CLOB
from .base import DATE
from .base import DOUBLE_PRECISION
from .base import FLOAT
from .base import INTERVAL
from .base import LONG
from .base import NCHAR
from .base import NCLOB
from .base import NUMBER
from .base import NVARCHAR
from .base import NVARCHAR2
from .base import RAW
from .base import REAL
from .base import ROWID
from .base import TIMESTAMP
from .base import VARCHAR
from .base import VARCHAR2
# Alias oracledb also as oracledb_async
oracledb_async = type(
"oracledb_async", (ModuleType,), {"dialect": oracledb.dialect_async}
)
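# Hedged note (not part of the original source): this alias appears to be
# what allows the explicit async spelling to resolve, e.g.:
#
#     from sqlalchemy.ext.asyncio import create_async_engine
#     engine = create_async_engine(
#         "oracle+oracledb_async://scott:tiger@localhost/?service_name=XEPDB1"
#     )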
base.dialect = dialect = cx_oracle.dialect
__all__ = (
"VARCHAR",
"NVARCHAR",
"CHAR",
"NCHAR",
"DATE",
"NUMBER",
"BLOB",
"BFILE",
"CLOB",
"NCLOB",
"TIMESTAMP",
"RAW",
"FLOAT",
"DOUBLE_PRECISION",
"BINARY_DOUBLE",
"BINARY_FLOAT",
"LONG",
"dialect",
"INTERVAL",
"VARCHAR2",
"NVARCHAR2",
"ROWID",
"REAL",
)

File diff suppressed because it is too large

File diff suppressed because it is too large

View file

@@ -0,0 +1,507 @@
# dialects/oracle/dictionary.py
# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
# mypy: ignore-errors
from .types import DATE
from .types import LONG
from .types import NUMBER
from .types import RAW
from .types import VARCHAR2
from ... import Column
from ... import MetaData
from ... import Table
from ... import table
from ...sql.sqltypes import CHAR
# constants
DB_LINK_PLACEHOLDER = "__$sa_dblink$__"
# tables
dual = table("dual")
dictionary_meta = MetaData()
# NOTE: all the dictionary_meta tables are aliases because Oracle does not
# like using the full table@dblink for every column in the query, and
# complains with ORA-00960: ambiguous column naming in select list
all_tables = Table(
"all_tables" + DB_LINK_PLACEHOLDER,
dictionary_meta,
Column("owner", VARCHAR2(128), nullable=False),
Column("table_name", VARCHAR2(128), nullable=False),
Column("tablespace_name", VARCHAR2(30)),
Column("cluster_name", VARCHAR2(128)),
Column("iot_name", VARCHAR2(128)),
Column("status", VARCHAR2(8)),
Column("pct_free", NUMBER),
Column("pct_used", NUMBER),
Column("ini_trans", NUMBER),
Column("max_trans", NUMBER),
Column("initial_extent", NUMBER),
Column("next_extent", NUMBER),
Column("min_extents", NUMBER),
Column("max_extents", NUMBER),
Column("pct_increase", NUMBER),
Column("freelists", NUMBER),
Column("freelist_groups", NUMBER),
Column("logging", VARCHAR2(3)),
Column("backed_up", VARCHAR2(1)),
Column("num_rows", NUMBER),
Column("blocks", NUMBER),
Column("empty_blocks", NUMBER),
Column("avg_space", NUMBER),
Column("chain_cnt", NUMBER),
Column("avg_row_len", NUMBER),
Column("avg_space_freelist_blocks", NUMBER),
Column("num_freelist_blocks", NUMBER),
Column("degree", VARCHAR2(10)),
Column("instances", VARCHAR2(10)),
Column("cache", VARCHAR2(5)),
Column("table_lock", VARCHAR2(8)),
Column("sample_size", NUMBER),
Column("last_analyzed", DATE),
Column("partitioned", VARCHAR2(3)),
Column("iot_type", VARCHAR2(12)),
Column("temporary", VARCHAR2(1)),
Column("secondary", VARCHAR2(1)),
Column("nested", VARCHAR2(3)),
Column("buffer_pool", VARCHAR2(7)),
Column("flash_cache", VARCHAR2(7)),
Column("cell_flash_cache", VARCHAR2(7)),
Column("row_movement", VARCHAR2(8)),
Column("global_stats", VARCHAR2(3)),
Column("user_stats", VARCHAR2(3)),
Column("duration", VARCHAR2(15)),
Column("skip_corrupt", VARCHAR2(8)),
Column("monitoring", VARCHAR2(3)),
Column("cluster_owner", VARCHAR2(128)),
Column("dependencies", VARCHAR2(8)),
Column("compression", VARCHAR2(8)),
Column("compress_for", VARCHAR2(30)),
Column("dropped", VARCHAR2(3)),
Column("read_only", VARCHAR2(3)),
Column("segment_created", VARCHAR2(3)),
Column("result_cache", VARCHAR2(7)),
Column("clustering", VARCHAR2(3)),
Column("activity_tracking", VARCHAR2(23)),
Column("dml_timestamp", VARCHAR2(25)),
Column("has_identity", VARCHAR2(3)),
Column("container_data", VARCHAR2(3)),
Column("inmemory", VARCHAR2(8)),
Column("inmemory_priority", VARCHAR2(8)),
Column("inmemory_distribute", VARCHAR2(15)),
Column("inmemory_compression", VARCHAR2(17)),
Column("inmemory_duplicate", VARCHAR2(13)),
Column("default_collation", VARCHAR2(100)),
Column("duplicated", VARCHAR2(1)),
Column("sharded", VARCHAR2(1)),
Column("externally_sharded", VARCHAR2(1)),
Column("externally_duplicated", VARCHAR2(1)),
Column("external", VARCHAR2(3)),
Column("hybrid", VARCHAR2(3)),
Column("cellmemory", VARCHAR2(24)),
Column("containers_default", VARCHAR2(3)),
Column("container_map", VARCHAR2(3)),
Column("extended_data_link", VARCHAR2(3)),
Column("extended_data_link_map", VARCHAR2(3)),
Column("inmemory_service", VARCHAR2(12)),
Column("inmemory_service_name", VARCHAR2(1000)),
Column("container_map_object", VARCHAR2(3)),
Column("memoptimize_read", VARCHAR2(8)),
Column("memoptimize_write", VARCHAR2(8)),
Column("has_sensitive_column", VARCHAR2(3)),
Column("admit_null", VARCHAR2(3)),
Column("data_link_dml_enabled", VARCHAR2(3)),
Column("logical_replication", VARCHAR2(8)),
).alias("a_tables")
all_views = Table(
"all_views" + DB_LINK_PLACEHOLDER,
dictionary_meta,
Column("owner", VARCHAR2(128), nullable=False),
Column("view_name", VARCHAR2(128), nullable=False),
Column("text_length", NUMBER),
Column("text", LONG),
Column("text_vc", VARCHAR2(4000)),
Column("type_text_length", NUMBER),
Column("type_text", VARCHAR2(4000)),
Column("oid_text_length", NUMBER),
Column("oid_text", VARCHAR2(4000)),
Column("view_type_owner", VARCHAR2(128)),
Column("view_type", VARCHAR2(128)),
Column("superview_name", VARCHAR2(128)),
Column("editioning_view", VARCHAR2(1)),
Column("read_only", VARCHAR2(1)),
Column("container_data", VARCHAR2(1)),
Column("bequeath", VARCHAR2(12)),
Column("origin_con_id", VARCHAR2(256)),
Column("default_collation", VARCHAR2(100)),
Column("containers_default", VARCHAR2(3)),
Column("container_map", VARCHAR2(3)),
Column("extended_data_link", VARCHAR2(3)),
Column("extended_data_link_map", VARCHAR2(3)),
Column("has_sensitive_column", VARCHAR2(3)),
Column("admit_null", VARCHAR2(3)),
Column("pdb_local_only", VARCHAR2(3)),
).alias("a_views")
all_sequences = Table(
"all_sequences" + DB_LINK_PLACEHOLDER,
dictionary_meta,
Column("sequence_owner", VARCHAR2(128), nullable=False),
Column("sequence_name", VARCHAR2(128), nullable=False),
Column("min_value", NUMBER),
Column("max_value", NUMBER),
Column("increment_by", NUMBER, nullable=False),
Column("cycle_flag", VARCHAR2(1)),
Column("order_flag", VARCHAR2(1)),
Column("cache_size", NUMBER, nullable=False),
Column("last_number", NUMBER, nullable=False),
Column("scale_flag", VARCHAR2(1)),
Column("extend_flag", VARCHAR2(1)),
Column("sharded_flag", VARCHAR2(1)),
Column("session_flag", VARCHAR2(1)),
Column("keep_value", VARCHAR2(1)),
).alias("a_sequences")
all_users = Table(
"all_users" + DB_LINK_PLACEHOLDER,
dictionary_meta,
Column("username", VARCHAR2(128), nullable=False),
Column("user_id", NUMBER, nullable=False),
Column("created", DATE, nullable=False),
Column("common", VARCHAR2(3)),
Column("oracle_maintained", VARCHAR2(1)),
Column("inherited", VARCHAR2(3)),
Column("default_collation", VARCHAR2(100)),
Column("implicit", VARCHAR2(3)),
Column("all_shard", VARCHAR2(3)),
Column("external_shard", VARCHAR2(3)),
).alias("a_users")
all_mviews = Table(
"all_mviews" + DB_LINK_PLACEHOLDER,
dictionary_meta,
Column("owner", VARCHAR2(128), nullable=False),
Column("mview_name", VARCHAR2(128), nullable=False),
Column("container_name", VARCHAR2(128), nullable=False),
Column("query", LONG),
Column("query_len", NUMBER(38)),
Column("updatable", VARCHAR2(1)),
Column("update_log", VARCHAR2(128)),
Column("master_rollback_seg", VARCHAR2(128)),
Column("master_link", VARCHAR2(128)),
Column("rewrite_enabled", VARCHAR2(1)),
Column("rewrite_capability", VARCHAR2(9)),
Column("refresh_mode", VARCHAR2(6)),
Column("refresh_method", VARCHAR2(8)),
Column("build_mode", VARCHAR2(9)),
Column("fast_refreshable", VARCHAR2(18)),
Column("last_refresh_type", VARCHAR2(8)),
Column("last_refresh_date", DATE),
Column("last_refresh_end_time", DATE),
Column("staleness", VARCHAR2(19)),
Column("after_fast_refresh", VARCHAR2(19)),
Column("unknown_prebuilt", VARCHAR2(1)),
Column("unknown_plsql_func", VARCHAR2(1)),
Column("unknown_external_table", VARCHAR2(1)),
Column("unknown_consider_fresh", VARCHAR2(1)),
Column("unknown_import", VARCHAR2(1)),
Column("unknown_trusted_fd", VARCHAR2(1)),
Column("compile_state", VARCHAR2(19)),
Column("use_no_index", VARCHAR2(1)),
Column("stale_since", DATE),
Column("num_pct_tables", NUMBER),
Column("num_fresh_pct_regions", NUMBER),
Column("num_stale_pct_regions", NUMBER),
Column("segment_created", VARCHAR2(3)),
Column("evaluation_edition", VARCHAR2(128)),
Column("unusable_before", VARCHAR2(128)),
Column("unusable_beginning", VARCHAR2(128)),
Column("default_collation", VARCHAR2(100)),
Column("on_query_computation", VARCHAR2(1)),
Column("auto", VARCHAR2(3)),
).alias("a_mviews")
all_tab_identity_cols = Table(
"all_tab_identity_cols" + DB_LINK_PLACEHOLDER,
dictionary_meta,
Column("owner", VARCHAR2(128), nullable=False),
Column("table_name", VARCHAR2(128), nullable=False),
Column("column_name", VARCHAR2(128), nullable=False),
Column("generation_type", VARCHAR2(10)),
Column("sequence_name", VARCHAR2(128), nullable=False),
Column("identity_options", VARCHAR2(298)),
).alias("a_tab_identity_cols")
all_tab_cols = Table(
"all_tab_cols" + DB_LINK_PLACEHOLDER,
dictionary_meta,
Column("owner", VARCHAR2(128), nullable=False),
Column("table_name", VARCHAR2(128), nullable=False),
Column("column_name", VARCHAR2(128), nullable=False),
Column("data_type", VARCHAR2(128)),
Column("data_type_mod", VARCHAR2(3)),
Column("data_type_owner", VARCHAR2(128)),
Column("data_length", NUMBER, nullable=False),
Column("data_precision", NUMBER),
Column("data_scale", NUMBER),
Column("nullable", VARCHAR2(1)),
Column("column_id", NUMBER),
Column("default_length", NUMBER),
Column("data_default", LONG),
Column("num_distinct", NUMBER),
Column("low_value", RAW(1000)),
Column("high_value", RAW(1000)),
Column("density", NUMBER),
Column("num_nulls", NUMBER),
Column("num_buckets", NUMBER),
Column("last_analyzed", DATE),
Column("sample_size", NUMBER),
Column("character_set_name", VARCHAR2(44)),
Column("char_col_decl_length", NUMBER),
Column("global_stats", VARCHAR2(3)),
Column("user_stats", VARCHAR2(3)),
Column("avg_col_len", NUMBER),
Column("char_length", NUMBER),
Column("char_used", VARCHAR2(1)),
Column("v80_fmt_image", VARCHAR2(3)),
Column("data_upgraded", VARCHAR2(3)),
Column("hidden_column", VARCHAR2(3)),
Column("virtual_column", VARCHAR2(3)),
Column("segment_column_id", NUMBER),
Column("internal_column_id", NUMBER, nullable=False),
Column("histogram", VARCHAR2(15)),
Column("qualified_col_name", VARCHAR2(4000)),
Column("user_generated", VARCHAR2(3)),
Column("default_on_null", VARCHAR2(3)),
Column("identity_column", VARCHAR2(3)),
Column("evaluation_edition", VARCHAR2(128)),
Column("unusable_before", VARCHAR2(128)),
Column("unusable_beginning", VARCHAR2(128)),
Column("collation", VARCHAR2(100)),
Column("collated_column_id", NUMBER),
).alias("a_tab_cols")
all_tab_comments = Table(
"all_tab_comments" + DB_LINK_PLACEHOLDER,
dictionary_meta,
Column("owner", VARCHAR2(128), nullable=False),
Column("table_name", VARCHAR2(128), nullable=False),
Column("table_type", VARCHAR2(11)),
Column("comments", VARCHAR2(4000)),
Column("origin_con_id", NUMBER),
).alias("a_tab_comments")
all_col_comments = Table(
"all_col_comments" + DB_LINK_PLACEHOLDER,
dictionary_meta,
Column("owner", VARCHAR2(128), nullable=False),
Column("table_name", VARCHAR2(128), nullable=False),
Column("column_name", VARCHAR2(128), nullable=False),
Column("comments", VARCHAR2(4000)),
Column("origin_con_id", NUMBER),
).alias("a_col_comments")
all_mview_comments = Table(
"all_mview_comments" + DB_LINK_PLACEHOLDER,
dictionary_meta,
Column("owner", VARCHAR2(128), nullable=False),
Column("mview_name", VARCHAR2(128), nullable=False),
Column("comments", VARCHAR2(4000)),
).alias("a_mview_comments")
all_ind_columns = Table(
"all_ind_columns" + DB_LINK_PLACEHOLDER,
dictionary_meta,
Column("index_owner", VARCHAR2(128), nullable=False),
Column("index_name", VARCHAR2(128), nullable=False),
Column("table_owner", VARCHAR2(128), nullable=False),
Column("table_name", VARCHAR2(128), nullable=False),
Column("column_name", VARCHAR2(4000)),
Column("column_position", NUMBER, nullable=False),
Column("column_length", NUMBER, nullable=False),
Column("char_length", NUMBER),
Column("descend", VARCHAR2(4)),
Column("collated_column_id", NUMBER),
).alias("a_ind_columns")
all_indexes = Table(
"all_indexes" + DB_LINK_PLACEHOLDER,
dictionary_meta,
Column("owner", VARCHAR2(128), nullable=False),
Column("index_name", VARCHAR2(128), nullable=False),
Column("index_type", VARCHAR2(27)),
Column("table_owner", VARCHAR2(128), nullable=False),
Column("table_name", VARCHAR2(128), nullable=False),
Column("table_type", CHAR(11)),
Column("uniqueness", VARCHAR2(9)),
Column("compression", VARCHAR2(13)),
Column("prefix_length", NUMBER),
Column("tablespace_name", VARCHAR2(30)),
Column("ini_trans", NUMBER),
Column("max_trans", NUMBER),
Column("initial_extent", NUMBER),
Column("next_extent", NUMBER),
Column("min_extents", NUMBER),
Column("max_extents", NUMBER),
Column("pct_increase", NUMBER),
Column("pct_threshold", NUMBER),
Column("include_column", NUMBER),
Column("freelists", NUMBER),
Column("freelist_groups", NUMBER),
Column("pct_free", NUMBER),
Column("logging", VARCHAR2(3)),
Column("blevel", NUMBER),
Column("leaf_blocks", NUMBER),
Column("distinct_keys", NUMBER),
Column("avg_leaf_blocks_per_key", NUMBER),
Column("avg_data_blocks_per_key", NUMBER),
Column("clustering_factor", NUMBER),
Column("status", VARCHAR2(8)),
Column("num_rows", NUMBER),
Column("sample_size", NUMBER),
Column("last_analyzed", DATE),
Column("degree", VARCHAR2(40)),
Column("instances", VARCHAR2(40)),
Column("partitioned", VARCHAR2(3)),
Column("temporary", VARCHAR2(1)),
Column("generated", VARCHAR2(1)),
Column("secondary", VARCHAR2(1)),
Column("buffer_pool", VARCHAR2(7)),
Column("flash_cache", VARCHAR2(7)),
Column("cell_flash_cache", VARCHAR2(7)),
Column("user_stats", VARCHAR2(3)),
Column("duration", VARCHAR2(15)),
Column("pct_direct_access", NUMBER),
Column("ityp_owner", VARCHAR2(128)),
Column("ityp_name", VARCHAR2(128)),
Column("parameters", VARCHAR2(1000)),
Column("global_stats", VARCHAR2(3)),
Column("domidx_status", VARCHAR2(12)),
Column("domidx_opstatus", VARCHAR2(6)),
Column("funcidx_status", VARCHAR2(8)),
Column("join_index", VARCHAR2(3)),
Column("iot_redundant_pkey_elim", VARCHAR2(3)),
Column("dropped", VARCHAR2(3)),
Column("visibility", VARCHAR2(9)),
Column("domidx_management", VARCHAR2(14)),
Column("segment_created", VARCHAR2(3)),
Column("orphaned_entries", VARCHAR2(3)),
Column("indexing", VARCHAR2(7)),
Column("auto", VARCHAR2(3)),
).alias("a_indexes")
all_ind_expressions = Table(
"all_ind_expressions" + DB_LINK_PLACEHOLDER,
dictionary_meta,
Column("index_owner", VARCHAR2(128), nullable=False),
Column("index_name", VARCHAR2(128), nullable=False),
Column("table_owner", VARCHAR2(128), nullable=False),
Column("table_name", VARCHAR2(128), nullable=False),
Column("column_expression", LONG),
Column("column_position", NUMBER, nullable=False),
).alias("a_ind_expressions")
all_constraints = Table(
"all_constraints" + DB_LINK_PLACEHOLDER,
dictionary_meta,
Column("owner", VARCHAR2(128)),
Column("constraint_name", VARCHAR2(128)),
Column("constraint_type", VARCHAR2(1)),
Column("table_name", VARCHAR2(128)),
Column("search_condition", LONG),
Column("search_condition_vc", VARCHAR2(4000)),
Column("r_owner", VARCHAR2(128)),
Column("r_constraint_name", VARCHAR2(128)),
Column("delete_rule", VARCHAR2(9)),
Column("status", VARCHAR2(8)),
Column("deferrable", VARCHAR2(14)),
Column("deferred", VARCHAR2(9)),
Column("validated", VARCHAR2(13)),
Column("generated", VARCHAR2(14)),
Column("bad", VARCHAR2(3)),
Column("rely", VARCHAR2(4)),
Column("last_change", DATE),
Column("index_owner", VARCHAR2(128)),
Column("index_name", VARCHAR2(128)),
Column("invalid", VARCHAR2(7)),
Column("view_related", VARCHAR2(14)),
Column("origin_con_id", VARCHAR2(256)),
).alias("a_constraints")
all_cons_columns = Table(
"all_cons_columns" + DB_LINK_PLACEHOLDER,
dictionary_meta,
Column("owner", VARCHAR2(128), nullable=False),
Column("constraint_name", VARCHAR2(128), nullable=False),
Column("table_name", VARCHAR2(128), nullable=False),
Column("column_name", VARCHAR2(4000)),
Column("position", NUMBER),
).alias("a_cons_columns")
# TODO: figure out if this is still relevant, since it is not mentioned here:
# https://docs.oracle.com/en/database/oracle/oracle-database/21/refrn/ALL_DB_LINKS.html
# original note:
# using user_db_links here since all_db_links appears
# to have more restricted permissions.
# https://docs.oracle.com/cd/B28359_01/server.111/b28310/ds_admin005.htm
# will need to hear from more users if we are doing
# the right thing here. See [ticket:2619]
all_db_links = Table(
"all_db_links" + DB_LINK_PLACEHOLDER,
dictionary_meta,
Column("owner", VARCHAR2(128), nullable=False),
Column("db_link", VARCHAR2(128), nullable=False),
Column("username", VARCHAR2(128)),
Column("host", VARCHAR2(2000)),
Column("created", DATE, nullable=False),
Column("hidden", VARCHAR2(3)),
Column("shard_internal", VARCHAR2(3)),
Column("valid", VARCHAR2(3)),
Column("intra_cdb", VARCHAR2(3)),
).alias("a_db_links")
all_synonyms = Table(
"all_synonyms" + DB_LINK_PLACEHOLDER,
dictionary_meta,
Column("owner", VARCHAR2(128)),
Column("synonym_name", VARCHAR2(128)),
Column("table_owner", VARCHAR2(128)),
Column("table_name", VARCHAR2(128)),
Column("db_link", VARCHAR2(128)),
Column("origin_con_id", VARCHAR2(256)),
).alias("a_synonyms")
all_objects = Table(
"all_objects" + DB_LINK_PLACEHOLDER,
dictionary_meta,
Column("owner", VARCHAR2(128), nullable=False),
Column("object_name", VARCHAR2(128), nullable=False),
Column("subobject_name", VARCHAR2(128)),
Column("object_id", NUMBER, nullable=False),
Column("data_object_id", NUMBER),
Column("object_type", VARCHAR2(23)),
Column("created", DATE, nullable=False),
Column("last_ddl_time", DATE, nullable=False),
Column("timestamp", VARCHAR2(19)),
Column("status", VARCHAR2(7)),
Column("temporary", VARCHAR2(1)),
Column("generated", VARCHAR2(1)),
Column("secondary", VARCHAR2(1)),
Column("namespace", NUMBER, nullable=False),
Column("edition_name", VARCHAR2(128)),
Column("sharing", VARCHAR2(13)),
Column("editionable", VARCHAR2(1)),
Column("oracle_maintained", VARCHAR2(1)),
Column("application", VARCHAR2(1)),
Column("default_collation", VARCHAR2(100)),
Column("duplicated", VARCHAR2(1)),
Column("sharded", VARCHAR2(1)),
Column("created_appid", NUMBER),
Column("created_vsnid", NUMBER),
Column("modified_appid", NUMBER),
Column("modified_vsnid", NUMBER),
).alias("a_objects")

View file

@ -0,0 +1,311 @@
# dialects/oracle/oracledb.py
# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
# mypy: ignore-errors
r"""
.. dialect:: oracle+oracledb
:name: python-oracledb
:dbapi: oracledb
:connectstring: oracle+oracledb://user:pass@hostname:port[/dbname][?service_name=<service>[&key=value&key=value...]]
:url: https://oracle.github.io/python-oracledb/
python-oracledb is released by Oracle to supersede the cx_Oracle driver.
It is fully compatible with cx_Oracle and features both a "thin" client
mode that requires no dependencies and a "thick" mode that uses
the Oracle Client Interface in the same way as cx_Oracle.
.. seealso::
:ref:`cx_oracle` - all of cx_Oracle's notes apply to the oracledb driver
as well.
The SQLAlchemy ``oracledb`` dialect provides both a sync and an async
implementation under the same dialect name. The proper version is
selected depending on how the engine is created:
* calling :func:`_sa.create_engine` with ``oracle+oracledb://...`` will
automatically select the sync version, e.g.::
from sqlalchemy import create_engine
sync_engine = create_engine("oracle+oracledb://scott:tiger@localhost/?service_name=XEPDB1")
* calling :func:`_asyncio.create_async_engine` with
``oracle+oracledb://...`` will automatically select the async version,
e.g.::
from sqlalchemy.ext.asyncio import create_async_engine
asyncio_engine = create_async_engine("oracle+oracledb://scott:tiger@localhost/?service_name=XEPDB1")
The asyncio version of the dialect may also be specified explicitly using the
``oracledb_async`` suffix, as::
from sqlalchemy.ext.asyncio import create_async_engine
asyncio_engine = create_async_engine("oracle+oracledb_async://scott:tiger@localhost/?service_name=XEPDB1")
.. versionadded:: 2.0.25 added support for the async version of oracledb.
Thick mode support
------------------
By default ``python-oracledb`` starts in thin mode, which does not
require Oracle client libraries to be installed on the system. The
``python-oracledb`` driver also supports a "thick" mode, which behaves
similarly to ``cx_oracle`` and requires that the Oracle Client Interface (OCI)
is installed.
To enable this mode, the user may call ``oracledb.init_oracle_client``
manually, or pass the parameter ``thick_mode=True`` to
:func:`_sa.create_engine`. To pass custom arguments to ``init_oracle_client``,
like the ``lib_dir`` path, a dict may be passed to this parameter, as in::
engine = sa.create_engine("oracle+oracledb://...", thick_mode={
"lib_dir": "/path/to/oracle/client/lib", "driver_name": "my-app"
})
.. seealso::
https://python-oracledb.readthedocs.io/en/latest/api_manual/module.html#oracledb.init_oracle_client
.. versionadded:: 2.0.0 added support for oracledb driver.
""" # noqa
from __future__ import annotations
import collections
import re
from typing import Any
from typing import TYPE_CHECKING
from .cx_oracle import OracleDialect_cx_oracle as _OracleDialect_cx_oracle
from ... import exc
from ... import pool
from ...connectors.asyncio import AsyncAdapt_dbapi_connection
from ...connectors.asyncio import AsyncAdapt_dbapi_cursor
from ...connectors.asyncio import AsyncAdaptFallback_dbapi_connection
from ...util import asbool
from ...util import await_fallback
from ...util import await_only
if TYPE_CHECKING:
from oracledb import AsyncConnection
from oracledb import AsyncCursor
class OracleDialect_oracledb(_OracleDialect_cx_oracle):
supports_statement_cache = True
driver = "oracledb"
_min_version = (1,)
def __init__(
self,
auto_convert_lobs=True,
coerce_to_decimal=True,
arraysize=None,
encoding_errors=None,
thick_mode=None,
**kwargs,
):
super().__init__(
auto_convert_lobs,
coerce_to_decimal,
arraysize,
encoding_errors,
**kwargs,
)
if self.dbapi is not None and (
thick_mode or isinstance(thick_mode, dict)
):
kw = thick_mode if isinstance(thick_mode, dict) else {}
self.dbapi.init_oracle_client(**kw)
@classmethod
def import_dbapi(cls):
import oracledb
return oracledb
@classmethod
def is_thin_mode(cls, connection):
return connection.connection.dbapi_connection.thin
@classmethod
def get_async_dialect_cls(cls, url):
return OracleDialectAsync_oracledb
def _load_version(self, dbapi_module):
version = (0, 0, 0)
if dbapi_module is not None:
m = re.match(r"(\d+)\.(\d+)(?:\.(\d+))?", dbapi_module.version)
if m:
version = tuple(
int(x) for x in m.group(1, 2, 3) if x is not None
)
self.oracledb_ver = version
if (
self.oracledb_ver > (0, 0, 0)
and self.oracledb_ver < self._min_version
):
raise exc.InvalidRequestError(
f"oracledb version {self._min_version} and above are supported"
)
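# Illustrative sketch (hypothetical version strings): with the regex in
# _load_version above, dbapi_module.version == "2.1.0" yields
# self.oracledb_ver == (2, 1, 0) and "2.1" yields (2, 1); a parsed version
# that is non-zero but below _min_version raises InvalidRequestError.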
class AsyncAdapt_oracledb_cursor(AsyncAdapt_dbapi_cursor):
_cursor: AsyncCursor
__slots__ = ()
@property
def outputtypehandler(self):
return self._cursor.outputtypehandler
@outputtypehandler.setter
def outputtypehandler(self, value):
self._cursor.outputtypehandler = value
def var(self, *args, **kwargs):
return self._cursor.var(*args, **kwargs)
def close(self):
self._rows.clear()
self._cursor.close()
def setinputsizes(self, *args: Any, **kwargs: Any) -> Any:
return self._cursor.setinputsizes(*args, **kwargs)
def _aenter_cursor(self, cursor: AsyncCursor) -> AsyncCursor:
try:
return cursor.__enter__()
except Exception as error:
self._adapt_connection._handle_exception(error)
async def _execute_async(self, operation, parameters):
# override to not use mutex, oracledb already has mutex
if parameters is None:
result = await self._cursor.execute(operation)
else:
result = await self._cursor.execute(operation, parameters)
if self._cursor.description and not self.server_side:
self._rows = collections.deque(await self._cursor.fetchall())
return result
async def _executemany_async(
self,
operation,
seq_of_parameters,
):
# override to not use mutex, oracledb already has mutex
return await self._cursor.executemany(operation, seq_of_parameters)
def __enter__(self):
return self
def __exit__(self, type_: Any, value: Any, traceback: Any) -> None:
self.close()
class AsyncAdapt_oracledb_connection(AsyncAdapt_dbapi_connection):
_connection: AsyncConnection
__slots__ = ()
thin = True
_cursor_cls = AsyncAdapt_oracledb_cursor
_ss_cursor_cls = None
@property
def autocommit(self):
return self._connection.autocommit
@autocommit.setter
def autocommit(self, value):
self._connection.autocommit = value
@property
def outputtypehandler(self):
return self._connection.outputtypehandler
@outputtypehandler.setter
def outputtypehandler(self, value):
self._connection.outputtypehandler = value
@property
def version(self):
return self._connection.version
@property
def stmtcachesize(self):
return self._connection.stmtcachesize
@stmtcachesize.setter
def stmtcachesize(self, value):
self._connection.stmtcachesize = value
def cursor(self):
return AsyncAdapt_oracledb_cursor(self)
class AsyncAdaptFallback_oracledb_connection(
AsyncAdaptFallback_dbapi_connection, AsyncAdapt_oracledb_connection
):
__slots__ = ()
class OracledbAdaptDBAPI:
def __init__(self, oracledb) -> None:
self.oracledb = oracledb
for k, v in self.oracledb.__dict__.items():
if k != "connect":
self.__dict__[k] = v
def connect(self, *arg, **kw):
async_fallback = kw.pop("async_fallback", False)
creator_fn = kw.pop("async_creator_fn", self.oracledb.connect_async)
if asbool(async_fallback):
return AsyncAdaptFallback_oracledb_connection(
self, await_fallback(creator_fn(*arg, **kw))
)
else:
return AsyncAdapt_oracledb_connection(
self, await_only(creator_fn(*arg, **kw))
)
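# Hedged usage sketch (connection parameters are assumptions for
# illustration): a call such as
#     OracledbAdaptDBAPI(oracledb).connect(
#         user="scott", password="tiger", dsn="localhost/xepdb1"
#     )
# invokes oracledb.connect_async(...) and wraps the awaited result in
# AsyncAdapt_oracledb_connection, or in the fallback variant when
# ``async_fallback`` is truthy.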
class OracleDialectAsync_oracledb(OracleDialect_oracledb):
is_async = True
supports_statement_cache = True
_min_version = (2,)
# thick_mode is not supported by asyncio; oracledb will raise
@classmethod
def import_dbapi(cls):
import oracledb
return OracledbAdaptDBAPI(oracledb)
@classmethod
def get_pool_class(cls, url):
async_fallback = url.query.get("async_fallback", False)
if asbool(async_fallback):
return pool.FallbackAsyncAdaptedQueuePool
else:
return pool.AsyncAdaptedQueuePool
def get_driver_connection(self, connection):
return connection._connection
dialect = OracleDialect_oracledb
dialect_async = OracleDialectAsync_oracledb

View file

@ -0,0 +1,220 @@
# dialects/oracle/provision.py
# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
# mypy: ignore-errors
from ... import create_engine
from ... import exc
from ... import inspect
from ...engine import url as sa_url
from ...testing.provision import configure_follower
from ...testing.provision import create_db
from ...testing.provision import drop_all_schema_objects_post_tables
from ...testing.provision import drop_all_schema_objects_pre_tables
from ...testing.provision import drop_db
from ...testing.provision import follower_url_from_main
from ...testing.provision import log
from ...testing.provision import post_configure_engine
from ...testing.provision import run_reap_dbs
from ...testing.provision import set_default_schema_on_connection
from ...testing.provision import stop_test_class_outside_fixtures
from ...testing.provision import temp_table_keyword_args
from ...testing.provision import update_db_opts
@create_db.for_db("oracle")
def _oracle_create_db(cfg, eng, ident):
# NOTE: make sure you've run "ALTER DATABASE default tablespace users" or
# similar, so that the default tablespace is not "system"; reflection will
# fail otherwise
with eng.begin() as conn:
conn.exec_driver_sql("create user %s identified by xe" % ident)
conn.exec_driver_sql("create user %s_ts1 identified by xe" % ident)
conn.exec_driver_sql("create user %s_ts2 identified by xe" % ident)
conn.exec_driver_sql("grant dba to %s" % (ident,))
conn.exec_driver_sql("grant unlimited tablespace to %s" % ident)
conn.exec_driver_sql("grant unlimited tablespace to %s_ts1" % ident)
conn.exec_driver_sql("grant unlimited tablespace to %s_ts2" % ident)
# these are needed to create materialized views
conn.exec_driver_sql("grant create table to %s" % ident)
conn.exec_driver_sql("grant create table to %s_ts1" % ident)
conn.exec_driver_sql("grant create table to %s_ts2" % ident)
@configure_follower.for_db("oracle")
def _oracle_configure_follower(config, ident):
config.test_schema = "%s_ts1" % ident
config.test_schema_2 = "%s_ts2" % ident
def _ora_drop_ignore(conn, dbname):
try:
conn.exec_driver_sql("drop user %s cascade" % dbname)
log.info("Reaped db: %s", dbname)
return True
except exc.DatabaseError as err:
log.warning("couldn't drop db: %s", err)
return False
@drop_all_schema_objects_pre_tables.for_db("oracle")
def _ora_drop_all_schema_objects_pre_tables(cfg, eng):
_purge_recyclebin(eng)
_purge_recyclebin(eng, cfg.test_schema)
@drop_all_schema_objects_post_tables.for_db("oracle")
def _ora_drop_all_schema_objects_post_tables(cfg, eng):
with eng.begin() as conn:
for syn in conn.dialect._get_synonyms(conn, None, None, None):
conn.exec_driver_sql(f"drop synonym {syn['synonym_name']}")
for syn in conn.dialect._get_synonyms(
conn, cfg.test_schema, None, None
):
conn.exec_driver_sql(
f"drop synonym {cfg.test_schema}.{syn['synonym_name']}"
)
for tmp_table in inspect(conn).get_temp_table_names():
conn.exec_driver_sql(f"drop table {tmp_table}")
@drop_db.for_db("oracle")
def _oracle_drop_db(cfg, eng, ident):
with eng.begin() as conn:
# cx_Oracle seems to occasionally leak open connections when a large
# suite is run, even if we confirm we have zero references to
# connection objects.
# while there is a "kill session" command in Oracle,
# it unfortunately does not release the connection sufficiently.
_ora_drop_ignore(conn, ident)
_ora_drop_ignore(conn, "%s_ts1" % ident)
_ora_drop_ignore(conn, "%s_ts2" % ident)
@stop_test_class_outside_fixtures.for_db("oracle")
def _ora_stop_test_class_outside_fixtures(config, db, cls):
try:
_purge_recyclebin(db)
except exc.DatabaseError as err:
log.warning("purge recyclebin command failed: %s", err)
# clear statement cache on all connections that were used
# https://github.com/oracle/python-cx_Oracle/issues/519
for cx_oracle_conn in _all_conns:
try:
sc = cx_oracle_conn.stmtcachesize
except db.dialect.dbapi.InterfaceError:
# connection closed
pass
else:
cx_oracle_conn.stmtcachesize = 0
cx_oracle_conn.stmtcachesize = sc
_all_conns.clear()
def _purge_recyclebin(eng, schema=None):
with eng.begin() as conn:
if schema is None:
# run magic command to get rid of identity sequences
# https://floo.bar/2019/11/29/drop-the-underlying-sequence-of-an-identity-column/ # noqa: E501
conn.exec_driver_sql("purge recyclebin")
else:
# per user: https://community.oracle.com/tech/developers/discussion/2255402/how-to-clear-dba-recyclebin-for-a-particular-user # noqa: E501
for owner, object_name, type_ in conn.exec_driver_sql(
"select owner, object_name,type from "
"dba_recyclebin where owner=:schema and type='TABLE'",
{"schema": conn.dialect.denormalize_name(schema)},
).all():
conn.exec_driver_sql(f'purge {type_} {owner}."{object_name}"')
_all_conns = set()
@post_configure_engine.for_db("oracle")
def _oracle_post_configure_engine(url, engine, follower_ident):
from sqlalchemy import event
@event.listens_for(engine, "checkout")
def checkout(dbapi_con, con_record, con_proxy):
_all_conns.add(dbapi_con)
@event.listens_for(engine, "checkin")
def checkin(dbapi_connection, connection_record):
# work around cx_Oracle issue:
# https://github.com/oracle/python-cx_Oracle/issues/530
# invalidate oracle connections that had 2pc set up
if "cx_oracle_xid" in connection_record.info:
connection_record.invalidate()
@run_reap_dbs.for_db("oracle")
def _reap_oracle_dbs(url, idents):
log.info("db reaper connecting to %r", url)
eng = create_engine(url)
with eng.begin() as conn:
log.info("identifiers in file: %s", ", ".join(idents))
to_reap = conn.exec_driver_sql(
"select u.username from all_users u where username "
"like 'TEST_%' and not exists (select username "
"from v$session where username=u.username)"
)
all_names = {username.lower() for (username,) in to_reap}
to_drop = set()
for name in all_names:
if name.endswith("_ts1") or name.endswith("_ts2"):
continue
elif name in idents:
to_drop.add(name)
if "%s_ts1" % name in all_names:
to_drop.add("%s_ts1" % name)
if "%s_ts2" % name in all_names:
to_drop.add("%s_ts2" % name)
dropped = total = 0
for total, username in enumerate(to_drop, 1):
if _ora_drop_ignore(conn, username):
dropped += 1
log.info(
"Dropped %d out of %d stale databases detected", dropped, total
)
@follower_url_from_main.for_db("oracle")
def _oracle_follower_url_from_main(url, ident):
url = sa_url.make_url(url)
return url.set(username=ident, password="xe")
@temp_table_keyword_args.for_db("oracle")
def _oracle_temp_table_keyword_args(cfg, eng):
return {
"prefixes": ["GLOBAL TEMPORARY"],
"oracle_on_commit": "PRESERVE ROWS",
}
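# Hedged illustration: with these keyword arguments, a provisioned temp
# table is expected to render roughly as
#     CREATE GLOBAL TEMPORARY TABLE ... ON COMMIT PRESERVE ROWS
# (the exact DDL depends on the table definition).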
@set_default_schema_on_connection.for_db("oracle")
def _oracle_set_default_schema_on_connection(
cfg, dbapi_connection, schema_name
):
cursor = dbapi_connection.cursor()
cursor.execute("ALTER SESSION SET CURRENT_SCHEMA=%s" % schema_name)
cursor.close()
@update_db_opts.for_db("oracle")
def _update_db_opts(db_url, db_opts, options):
"""Set database options (db_opts) for a test database that we created."""
if (
options.oracledb_thick_mode
and sa_url.make_url(db_url).get_driver_name() == "oracledb"
):
db_opts["thick_mode"] = True

View file

@ -0,0 +1,287 @@
# dialects/oracle/types.py
# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
# mypy: ignore-errors
from __future__ import annotations
import datetime as dt
from typing import Optional
from typing import Type
from typing import TYPE_CHECKING
from ... import exc
from ...sql import sqltypes
from ...types import NVARCHAR
from ...types import VARCHAR
if TYPE_CHECKING:
from ...engine.interfaces import Dialect
from ...sql.type_api import _LiteralProcessorType
class RAW(sqltypes._Binary):
__visit_name__ = "RAW"
OracleRaw = RAW
class NCLOB(sqltypes.Text):
__visit_name__ = "NCLOB"
class VARCHAR2(VARCHAR):
__visit_name__ = "VARCHAR2"
NVARCHAR2 = NVARCHAR
class NUMBER(sqltypes.Numeric, sqltypes.Integer):
__visit_name__ = "NUMBER"
def __init__(self, precision=None, scale=None, asdecimal=None):
if asdecimal is None:
asdecimal = bool(scale and scale > 0)
super().__init__(precision=precision, scale=scale, asdecimal=asdecimal)
def adapt(self, impltype):
ret = super().adapt(impltype)
# leave a hint for the DBAPI handler
ret._is_oracle_number = True
return ret
@property
def _type_affinity(self):
if bool(self.scale and self.scale > 0):
return sqltypes.Numeric
else:
return sqltypes.Integer
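# Illustrative sketch of the behavior above (hypothetical declarations):
#     NUMBER(10)     -> asdecimal=False, type affinity Integer
#     NUMBER(10, 2)  -> asdecimal=True,  type affinity Numeric (Decimal results)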
class FLOAT(sqltypes.FLOAT):
"""Oracle FLOAT.
This is the same as :class:`_sqltypes.FLOAT` except that
an Oracle-specific :paramref:`_oracle.FLOAT.binary_precision`
parameter is accepted, and
the :paramref:`_sqltypes.Float.precision` parameter is not accepted.
Oracle FLOAT types indicate precision in terms of "binary precision", which
defaults to 126. For a REAL type, the value is 63. This parameter does not
cleanly map to a specific number of decimal places but is roughly
equivalent to the desired number of decimal places divided by 0.30103.
.. versionadded:: 2.0
"""
__visit_name__ = "FLOAT"
def __init__(
self,
binary_precision=None,
asdecimal=False,
decimal_return_scale=None,
):
r"""
Construct a FLOAT
:param binary_precision: Oracle binary precision value to be rendered
in DDL. This may be approximated to the number of decimal characters
using the formula "decimal precision = 0.30103 * binary precision".
The default value used by Oracle for FLOAT / DOUBLE PRECISION is 126.
:param asdecimal: See :paramref:`_sqltypes.Float.asdecimal`
:param decimal_return_scale: See
:paramref:`_sqltypes.Float.decimal_return_scale`
"""
super().__init__(
asdecimal=asdecimal, decimal_return_scale=decimal_return_scale
)
self.binary_precision = binary_precision
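# Illustrative sketch (hypothetical column): using the rule
# "decimal precision = 0.30103 * binary precision", a column meant to carry
# roughly 7 decimal digits could be declared as
#     Column("value", FLOAT(binary_precision=23))
# since 7 / 0.30103 is approximately 23.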
class BINARY_DOUBLE(sqltypes.Double):
__visit_name__ = "BINARY_DOUBLE"
class BINARY_FLOAT(sqltypes.Float):
__visit_name__ = "BINARY_FLOAT"
class BFILE(sqltypes.LargeBinary):
__visit_name__ = "BFILE"
class LONG(sqltypes.Text):
__visit_name__ = "LONG"
class _OracleDateLiteralRender:
def _literal_processor_datetime(self, dialect):
def process(value):
if getattr(value, "microsecond", None):
value = (
f"""TO_TIMESTAMP"""
f"""('{value.isoformat().replace("T", " ")}', """
"""'YYYY-MM-DD HH24:MI:SS.FF')"""
)
else:
value = (
f"""TO_DATE"""
f"""('{value.isoformat().replace("T", " ")}', """
"""'YYYY-MM-DD HH24:MI:SS')"""
)
return value
return process
def _literal_processor_date(self, dialect):
def process(value):
if getattr(value, "microsecond", None):
value = (
f"""TO_TIMESTAMP"""
f"""('{value.isoformat().split("T")[0]}', """
"""'YYYY-MM-DD')"""
)
else:
value = (
f"""TO_DATE"""
f"""('{value.isoformat().split("T")[0]}', """
"""'YYYY-MM-DD')"""
)
return value
return process
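# Illustrative sketch of the processors above (hypothetical values):
#     datetime.datetime(2024, 1, 15, 10, 30)  renders as
#         TO_DATE('2024-01-15 10:30:00', 'YYYY-MM-DD HH24:MI:SS')
#     a value with microseconds renders via TO_TIMESTAMP using the
#         'YYYY-MM-DD HH24:MI:SS.FF' format
#     datetime.date(2024, 1, 15)  renders as
#         TO_DATE('2024-01-15', 'YYYY-MM-DD')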
class DATE(_OracleDateLiteralRender, sqltypes.DateTime):
"""Provide the oracle DATE type.
This type has no special Python behavior, except that it subclasses
:class:`_types.DateTime`; this is to suit the fact that the Oracle
``DATE`` type supports a time value.
"""
__visit_name__ = "DATE"
def literal_processor(self, dialect):
return self._literal_processor_datetime(dialect)
def _compare_type_affinity(self, other):
return other._type_affinity in (sqltypes.DateTime, sqltypes.Date)
class _OracleDate(_OracleDateLiteralRender, sqltypes.Date):
def literal_processor(self, dialect):
return self._literal_processor_date(dialect)
class INTERVAL(sqltypes.NativeForEmulated, sqltypes._AbstractInterval):
__visit_name__ = "INTERVAL"
def __init__(self, day_precision=None, second_precision=None):
"""Construct an INTERVAL.
Note that only DAY TO SECOND intervals are currently supported.
This is due to a lack of support for YEAR TO MONTH intervals
within available DBAPIs.
:param day_precision: the day precision value. This is the number of
digits to store for the day field. Defaults to "2".
:param second_precision: the second precision value. This is the
number of digits to store for the fractional seconds field.
Defaults to "6".
"""
self.day_precision = day_precision
self.second_precision = second_precision
@classmethod
def _adapt_from_generic_interval(cls, interval):
return INTERVAL(
day_precision=interval.day_precision,
second_precision=interval.second_precision,
)
@classmethod
def adapt_emulated_to_native(
cls, interval: sqltypes.Interval, **kw # type: ignore[override]
):
return INTERVAL(
day_precision=interval.day_precision,
second_precision=interval.second_precision,
)
@property
def _type_affinity(self):
return sqltypes.Interval
def as_generic(self, allow_nulltype=False):
return sqltypes.Interval(
native=True,
second_precision=self.second_precision,
day_precision=self.day_precision,
)
@property
def python_type(self) -> Type[dt.timedelta]:
return dt.timedelta
def literal_processor(
self, dialect: Dialect
) -> Optional[_LiteralProcessorType[dt.timedelta]]:
def process(value: dt.timedelta) -> str:
return f"NUMTODSINTERVAL({value.total_seconds()}, 'SECOND')"
return process
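# Illustrative sketch (hypothetical value): datetime.timedelta(days=1,
# hours=2) has total_seconds() == 93600.0, so the processor above renders
# it as  NUMTODSINTERVAL(93600.0, 'SECOND').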
class TIMESTAMP(sqltypes.TIMESTAMP):
"""Oracle implementation of ``TIMESTAMP``, which supports additional
Oracle-specific modes
.. versionadded:: 2.0
"""
def __init__(self, timezone: bool = False, local_timezone: bool = False):
"""Construct a new :class:`_oracle.TIMESTAMP`.
:param timezone: boolean. Indicates that the TIMESTAMP type should
use Oracle's ``TIMESTAMP WITH TIME ZONE`` datatype.
:param local_timezone: boolean. Indicates that the TIMESTAMP type
should use Oracle's ``TIMESTAMP WITH LOCAL TIME ZONE`` datatype.
"""
if timezone and local_timezone:
raise exc.ArgumentError(
"timezone and local_timezone are mutually exclusive"
)
super().__init__(timezone=timezone)
self.local_timezone = local_timezone
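# Illustrative sketch of the expected DDL for each mode:
#     TIMESTAMP()                     -> TIMESTAMP
#     TIMESTAMP(timezone=True)        -> TIMESTAMP WITH TIME ZONE
#     TIMESTAMP(local_timezone=True)  -> TIMESTAMP WITH LOCAL TIME ZONE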
class ROWID(sqltypes.TypeEngine):
"""Oracle ROWID type.
When used in a cast() or similar, generates ROWID.
"""
__visit_name__ = "ROWID"
class _OracleBoolean(sqltypes.Boolean):
def get_dbapi_type(self, dbapi):
return dbapi.NUMBER

View file

@ -0,0 +1,167 @@
# dialects/postgresql/__init__.py
# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
# mypy: ignore-errors
from types import ModuleType
from . import array as arraylib # noqa # must be above base and other dialects
from . import asyncpg # noqa
from . import base
from . import pg8000 # noqa
from . import psycopg # noqa
from . import psycopg2 # noqa
from . import psycopg2cffi # noqa
from .array import All
from .array import Any
from .array import ARRAY
from .array import array
from .base import BIGINT
from .base import BOOLEAN
from .base import CHAR
from .base import DATE
from .base import DOMAIN
from .base import DOUBLE_PRECISION
from .base import FLOAT
from .base import INTEGER
from .base import NUMERIC
from .base import REAL
from .base import SMALLINT
from .base import TEXT
from .base import UUID
from .base import VARCHAR
from .dml import Insert
from .dml import insert
from .ext import aggregate_order_by
from .ext import array_agg
from .ext import ExcludeConstraint
from .ext import phraseto_tsquery
from .ext import plainto_tsquery
from .ext import to_tsquery
from .ext import to_tsvector
from .ext import ts_headline
from .ext import websearch_to_tsquery
from .hstore import HSTORE
from .hstore import hstore
from .json import JSON
from .json import JSONB
from .json import JSONPATH
from .named_types import CreateDomainType
from .named_types import CreateEnumType
from .named_types import DropDomainType
from .named_types import DropEnumType
from .named_types import ENUM
from .named_types import NamedType
from .ranges import AbstractMultiRange
from .ranges import AbstractRange
from .ranges import AbstractSingleRange
from .ranges import DATEMULTIRANGE
from .ranges import DATERANGE
from .ranges import INT4MULTIRANGE
from .ranges import INT4RANGE
from .ranges import INT8MULTIRANGE
from .ranges import INT8RANGE
from .ranges import MultiRange
from .ranges import NUMMULTIRANGE
from .ranges import NUMRANGE
from .ranges import Range
from .ranges import TSMULTIRANGE
from .ranges import TSRANGE
from .ranges import TSTZMULTIRANGE
from .ranges import TSTZRANGE
from .types import BIT
from .types import BYTEA
from .types import CIDR
from .types import CITEXT
from .types import INET
from .types import INTERVAL
from .types import MACADDR
from .types import MACADDR8
from .types import MONEY
from .types import OID
from .types import REGCLASS
from .types import REGCONFIG
from .types import TIME
from .types import TIMESTAMP
from .types import TSQUERY
from .types import TSVECTOR
# Alias psycopg also as psycopg_async
psycopg_async = type(
"psycopg_async", (ModuleType,), {"dialect": psycopg.dialect_async}
)
base.dialect = dialect = psycopg2.dialect
__all__ = (
"INTEGER",
"BIGINT",
"SMALLINT",
"VARCHAR",
"CHAR",
"TEXT",
"NUMERIC",
"FLOAT",
"REAL",
"INET",
"CIDR",
"CITEXT",
"UUID",
"BIT",
"MACADDR",
"MACADDR8",
"MONEY",
"OID",
"REGCLASS",
"REGCONFIG",
"TSQUERY",
"TSVECTOR",
"DOUBLE_PRECISION",
"TIMESTAMP",
"TIME",
"DATE",
"BYTEA",
"BOOLEAN",
"INTERVAL",
"ARRAY",
"ENUM",
"DOMAIN",
"dialect",
"array",
"HSTORE",
"hstore",
"INT4RANGE",
"INT8RANGE",
"NUMRANGE",
"DATERANGE",
"INT4MULTIRANGE",
"INT8MULTIRANGE",
"NUMMULTIRANGE",
"DATEMULTIRANGE",
"TSVECTOR",
"TSRANGE",
"TSTZRANGE",
"TSMULTIRANGE",
"TSTZMULTIRANGE",
"JSON",
"JSONB",
"JSONPATH",
"Any",
"All",
"DropEnumType",
"DropDomainType",
"CreateDomainType",
"NamedType",
"CreateEnumType",
"ExcludeConstraint",
"Range",
"aggregate_order_by",
"array_agg",
"insert",
"Insert",
)

View file

@ -0,0 +1,187 @@
# dialects/postgresql/_psycopg_common.py
# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
# mypy: ignore-errors
from __future__ import annotations
import decimal
from .array import ARRAY as PGARRAY
from .base import _DECIMAL_TYPES
from .base import _FLOAT_TYPES
from .base import _INT_TYPES
from .base import PGDialect
from .base import PGExecutionContext
from .hstore import HSTORE
from .pg_catalog import _SpaceVector
from .pg_catalog import INT2VECTOR
from .pg_catalog import OIDVECTOR
from ... import exc
from ... import types as sqltypes
from ... import util
from ...engine import processors
_server_side_id = util.counter()
class _PsycopgNumeric(sqltypes.Numeric):
def bind_processor(self, dialect):
return None
def result_processor(self, dialect, coltype):
if self.asdecimal:
if coltype in _FLOAT_TYPES:
return processors.to_decimal_processor_factory(
decimal.Decimal, self._effective_decimal_return_scale
)
elif coltype in _DECIMAL_TYPES or coltype in _INT_TYPES:
# psycopg returns Decimal natively for 1700
return None
else:
raise exc.InvalidRequestError(
"Unknown PG numeric type: %d" % coltype
)
else:
if coltype in _FLOAT_TYPES:
# psycopg returns float natively for 701
return None
elif coltype in _DECIMAL_TYPES or coltype in _INT_TYPES:
return processors.to_float
else:
raise exc.InvalidRequestError(
"Unknown PG numeric type: %d" % coltype
)
class _PsycopgFloat(_PsycopgNumeric):
__visit_name__ = "float"
class _PsycopgHStore(HSTORE):
def bind_processor(self, dialect):
if dialect._has_native_hstore:
return None
else:
return super().bind_processor(dialect)
def result_processor(self, dialect, coltype):
if dialect._has_native_hstore:
return None
else:
return super().result_processor(dialect, coltype)
class _PsycopgARRAY(PGARRAY):
render_bind_cast = True
class _PsycopgINT2VECTOR(_SpaceVector, INT2VECTOR):
pass
class _PsycopgOIDVECTOR(_SpaceVector, OIDVECTOR):
pass
class _PGExecutionContext_common_psycopg(PGExecutionContext):
def create_server_side_cursor(self):
# use server-side cursors:
# psycopg
# https://www.psycopg.org/psycopg3/docs/advanced/cursors.html#server-side-cursors
# psycopg2
# https://www.psycopg.org/docs/usage.html#server-side-cursors
ident = "c_%s_%s" % (hex(id(self))[2:], hex(_server_side_id())[2:])
return self._dbapi_connection.cursor(ident)
class _PGDialect_common_psycopg(PGDialect):
supports_statement_cache = True
supports_server_side_cursors = True
default_paramstyle = "pyformat"
_has_native_hstore = True
colspecs = util.update_copy(
PGDialect.colspecs,
{
sqltypes.Numeric: _PsycopgNumeric,
sqltypes.Float: _PsycopgFloat,
HSTORE: _PsycopgHStore,
sqltypes.ARRAY: _PsycopgARRAY,
INT2VECTOR: _PsycopgINT2VECTOR,
OIDVECTOR: _PsycopgOIDVECTOR,
},
)
def __init__(
self,
client_encoding=None,
use_native_hstore=True,
**kwargs,
):
PGDialect.__init__(self, **kwargs)
if not use_native_hstore:
self._has_native_hstore = False
self.use_native_hstore = use_native_hstore
self.client_encoding = client_encoding
def create_connect_args(self, url):
opts = url.translate_connect_args(username="user", database="dbname")
multihosts, multiports = self._split_multihost_from_url(url)
if opts or url.query:
if not opts:
opts = {}
if "port" in opts:
opts["port"] = int(opts["port"])
opts.update(url.query)
if multihosts:
opts["host"] = ",".join(multihosts)
comma_ports = ",".join(str(p) if p else "" for p in multiports)
if comma_ports:
opts["port"] = comma_ports
return ([], opts)
else:
# no connection arguments whatsoever; psycopg2.connect()
# requires that "dsn" be present as a blank string.
return ([""], opts)
def get_isolation_level_values(self, dbapi_connection):
return (
"AUTOCOMMIT",
"READ COMMITTED",
"READ UNCOMMITTED",
"REPEATABLE READ",
"SERIALIZABLE",
)
def set_deferrable(self, connection, value):
connection.deferrable = value
def get_deferrable(self, connection):
return connection.deferrable
def _do_autocommit(self, connection, value):
connection.autocommit = value
def do_ping(self, dbapi_connection):
cursor = None
before_autocommit = dbapi_connection.autocommit
if not before_autocommit:
dbapi_connection.autocommit = True
cursor = dbapi_connection.cursor()
try:
cursor.execute(self._dialect_specific_select_one)
finally:
cursor.close()
if not before_autocommit and not dbapi_connection.closed:
dbapi_connection.autocommit = before_autocommit
return True

View file

@ -0,0 +1,424 @@
# dialects/postgresql/array.py
# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
# mypy: ignore-errors
from __future__ import annotations
import re
from typing import Any
from typing import Optional
from typing import TypeVar
from .operators import CONTAINED_BY
from .operators import CONTAINS
from .operators import OVERLAP
from ... import types as sqltypes
from ... import util
from ...sql import expression
from ...sql import operators
from ...sql._typing import _TypeEngineArgument
_T = TypeVar("_T", bound=Any)
def Any(other, arrexpr, operator=operators.eq):
"""A synonym for the ARRAY-level :meth:`.ARRAY.Comparator.any` method.
See that method for details.
"""
return arrexpr.any(other, operator)
def All(other, arrexpr, operator=operators.eq):
"""A synonym for the ARRAY-level :meth:`.ARRAY.Comparator.all` method.
See that method for details.
"""
return arrexpr.all(other, operator)
class array(expression.ExpressionClauseList[_T]):
"""A PostgreSQL ARRAY literal.
This is used to produce ARRAY literals in SQL expressions, e.g.::
from sqlalchemy.dialects.postgresql import array
from sqlalchemy.dialects import postgresql
from sqlalchemy import select, func
stmt = select(array([1,2]) + array([3,4,5]))
print(stmt.compile(dialect=postgresql.dialect()))
Produces the SQL::
SELECT ARRAY[%(param_1)s, %(param_2)s] ||
ARRAY[%(param_3)s, %(param_4)s, %(param_5)s] AS anon_1
An instance of :class:`.array` will always have the datatype
:class:`_types.ARRAY`. The "inner" type of the array is inferred from
the values present, unless the ``type_`` keyword argument is passed::
array(['foo', 'bar'], type_=CHAR)
Multidimensional arrays are produced by nesting :class:`.array` constructs.
The dimensionality of the final :class:`_types.ARRAY`
type is calculated by
recursively adding the dimensions of the inner :class:`_types.ARRAY`
type::
stmt = select(
array([
array([1, 2]), array([3, 4]), array([column('q'), column('x')])
])
)
print(stmt.compile(dialect=postgresql.dialect()))
Produces::
SELECT ARRAY[ARRAY[%(param_1)s, %(param_2)s],
ARRAY[%(param_3)s, %(param_4)s], ARRAY[q, x]] AS anon_1
.. versionadded:: 1.3.6 added support for multidimensional array literals
.. seealso::
:class:`_postgresql.ARRAY`
"""
__visit_name__ = "array"
stringify_dialect = "postgresql"
inherit_cache = True
def __init__(self, clauses, **kw):
type_arg = kw.pop("type_", None)
super().__init__(operators.comma_op, *clauses, **kw)
self._type_tuple = [arg.type for arg in self.clauses]
main_type = (
type_arg
if type_arg is not None
else self._type_tuple[0] if self._type_tuple else sqltypes.NULLTYPE
)
if isinstance(main_type, ARRAY):
self.type = ARRAY(
main_type.item_type,
dimensions=(
main_type.dimensions + 1
if main_type.dimensions is not None
else 2
),
)
else:
self.type = ARRAY(main_type)
@property
def _select_iterable(self):
return (self,)
def _bind_param(self, operator, obj, _assume_scalar=False, type_=None):
if _assume_scalar or operator is operators.getitem:
return expression.BindParameter(
None,
obj,
_compared_to_operator=operator,
type_=type_,
_compared_to_type=self.type,
unique=True,
)
else:
return array(
[
self._bind_param(
operator, o, _assume_scalar=True, type_=type_
)
for o in obj
]
)
def self_group(self, against=None):
if against in (operators.any_op, operators.all_op, operators.getitem):
return expression.Grouping(self)
else:
return self
class ARRAY(sqltypes.ARRAY):
"""PostgreSQL ARRAY type.
The :class:`_postgresql.ARRAY` type is constructed in the same way
as the core :class:`_types.ARRAY` type; a member type is required, and a
number of dimensions is recommended if the type is to be used for more
than one dimension::
from sqlalchemy.dialects import postgresql
mytable = Table("mytable", metadata,
Column("data", postgresql.ARRAY(Integer, dimensions=2))
)
The :class:`_postgresql.ARRAY` type provides all operations defined on the
core :class:`_types.ARRAY` type, including support for "dimensions",
indexed access, and simple matching such as
:meth:`.types.ARRAY.Comparator.any` and
:meth:`.types.ARRAY.Comparator.all`. :class:`_postgresql.ARRAY`
class also
provides PostgreSQL-specific methods for containment operations, including
:meth:`.postgresql.ARRAY.Comparator.contains`
:meth:`.postgresql.ARRAY.Comparator.contained_by`, and
:meth:`.postgresql.ARRAY.Comparator.overlap`, e.g.::
mytable.c.data.contains([1, 2])
The :class:`_postgresql.ARRAY` type may not be supported on all
PostgreSQL DBAPIs; it is currently known to work on psycopg2 only.
Additionally, the :class:`_postgresql.ARRAY`
type does not work directly in
conjunction with the :class:`.ENUM` type. For a workaround, see the
special type at :ref:`postgresql_array_of_enum`.
.. container:: topic
**Detecting Changes in ARRAY columns when using the ORM**
The :class:`_postgresql.ARRAY` type, when used with the SQLAlchemy ORM,
does not detect in-place mutations to the array. In order to detect
these, the :mod:`sqlalchemy.ext.mutable` extension must be used, using
the :class:`.MutableList` class::
from sqlalchemy.dialects.postgresql import ARRAY
from sqlalchemy.ext.mutable import MutableList
class SomeOrmClass(Base):
# ...
data = Column(MutableList.as_mutable(ARRAY(Integer)))
This extension will allow "in-place" changes to the array,
such as ``.append()``, to produce events which will be detected by the
unit of work. Note that changes to elements **inside** the array,
including subarrays that are mutated in place, are **not** detected.
Alternatively, assigning a new array value to an ORM element that
replaces the old one will always trigger a change event.
.. seealso::
:class:`_types.ARRAY` - base array type
:class:`_postgresql.array` - produces a literal array value.
"""
class Comparator(sqltypes.ARRAY.Comparator):
"""Define comparison operations for :class:`_types.ARRAY`.
Note that these operations are in addition to those provided
by the base :class:`.types.ARRAY.Comparator` class, including
:meth:`.types.ARRAY.Comparator.any` and
:meth:`.types.ARRAY.Comparator.all`.
"""
def contains(self, other, **kwargs):
"""Boolean expression. Test if elements are a superset of the
elements of the argument array expression.
kwargs may be ignored by this operator but are required for API
conformance.
"""
return self.operate(CONTAINS, other, result_type=sqltypes.Boolean)
def contained_by(self, other):
"""Boolean expression. Test if elements are a proper subset of the
elements of the argument array expression.
"""
return self.operate(
CONTAINED_BY, other, result_type=sqltypes.Boolean
)
def overlap(self, other):
"""Boolean expression. Test if array has elements in common with
an argument array expression.
"""
return self.operate(OVERLAP, other, result_type=sqltypes.Boolean)
comparator_factory = Comparator
def __init__(
self,
item_type: _TypeEngineArgument[Any],
as_tuple: bool = False,
dimensions: Optional[int] = None,
zero_indexes: bool = False,
):
"""Construct an ARRAY.
E.g.::
Column('myarray', ARRAY(Integer))
Arguments are:
:param item_type: The data type of items of this array. Note that
dimensionality is irrelevant here, so multi-dimensional arrays like
``INTEGER[][]`` are constructed as ``ARRAY(Integer)``, not as
``ARRAY(ARRAY(Integer))`` or such.
:param as_tuple=False: Specify whether return results
should be converted to tuples from lists. DBAPIs such
as psycopg2 return lists by default. When tuples are
returned, the results are hashable.
:param dimensions: if non-None, the ARRAY will assume a fixed
number of dimensions. This will cause the DDL emitted for this
ARRAY to include the exact number of bracket clauses ``[]``,
and will also optimize the performance of the type overall.
Note that PG arrays are always implicitly "non-dimensioned",
meaning they can store any number of dimensions no matter how
they were declared.
:param zero_indexes=False: when True, index values will be converted
between Python zero-based and PostgreSQL one-based indexes, e.g.
a value of one will be added to all index values before passing
to the database.
"""
if isinstance(item_type, ARRAY):
raise ValueError(
"Do not nest ARRAY types; ARRAY(basetype) "
"handles multi-dimensional arrays of basetype"
)
if isinstance(item_type, type):
item_type = item_type()
self.item_type = item_type
self.as_tuple = as_tuple
self.dimensions = dimensions
self.zero_indexes = zero_indexes
@property
def hashable(self):
return self.as_tuple
@property
def python_type(self):
return list
def compare_values(self, x, y):
return x == y
@util.memoized_property
def _against_native_enum(self):
return (
isinstance(self.item_type, sqltypes.Enum)
and self.item_type.native_enum
)
def literal_processor(self, dialect):
item_proc = self.item_type.dialect_impl(dialect).literal_processor(
dialect
)
if item_proc is None:
return None
def to_str(elements):
return f"ARRAY[{', '.join(elements)}]"
def process(value):
inner = self._apply_item_processor(
value, item_proc, self.dimensions, to_str
)
return inner
return process
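# Illustrative sketch (hypothetical literal): for ARRAY(Integer), a literal
# value of [1, 2, 3] is rendered by the processor above as  ARRAY[1, 2, 3],
# with each element first passed through the item type's literal processor.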
def bind_processor(self, dialect):
item_proc = self.item_type.dialect_impl(dialect).bind_processor(
dialect
)
def process(value):
if value is None:
return value
else:
return self._apply_item_processor(
value, item_proc, self.dimensions, list
)
return process
def result_processor(self, dialect, coltype):
item_proc = self.item_type.dialect_impl(dialect).result_processor(
dialect, coltype
)
def process(value):
if value is None:
return value
else:
return self._apply_item_processor(
value,
item_proc,
self.dimensions,
tuple if self.as_tuple else list,
)
if self._against_native_enum:
super_rp = process
pattern = re.compile(r"^{(.*)}$")
def handle_raw_string(value):
inner = pattern.match(value).group(1)
return _split_enum_values(inner)
def process(value):
if value is None:
return value
# isinstance(value, str) is required to handle
# the case where a TypeDecorator for an Array of Enum is
# used, as was required in sa < 1.3.17
return super_rp(
handle_raw_string(value)
if isinstance(value, str)
else value
)
return process
def _split_enum_values(array_string):
if '"' not in array_string:
# no escape char is present so it can just split on the comma
return array_string.split(",") if array_string else []
# handles quoted strings from:
# r'abc,"quoted","also\\\\quoted", "quoted, comma", "esc \" quot", qpr'
# returns
# ['abc', 'quoted', 'also\\quoted', 'quoted, comma', 'esc " quot', 'qpr']
text = array_string.replace(r"\"", "_$ESC_QUOTE$_")
text = text.replace(r"\\", "\\")
result = []
on_quotes = re.split(r'(")', text)
in_quotes = False
for tok in on_quotes:
if tok == '"':
in_quotes = not in_quotes
elif in_quotes:
result.append(tok.replace("_$ESC_QUOTE$_", '"'))
else:
result.extend(re.findall(r"([^\s,]+),?", tok))
return result

File diff suppressed because it is too large

File diff suppressed because it is too large

View file

@ -0,0 +1,310 @@
# dialects/postgresql/dml.py
# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
from __future__ import annotations
from typing import Any
from typing import Optional
from . import ext
from .._typing import _OnConflictConstraintT
from .._typing import _OnConflictIndexElementsT
from .._typing import _OnConflictIndexWhereT
from .._typing import _OnConflictSetT
from .._typing import _OnConflictWhereT
from ... import util
from ...sql import coercions
from ...sql import roles
from ...sql import schema
from ...sql._typing import _DMLTableArgument
from ...sql.base import _exclusive_against
from ...sql.base import _generative
from ...sql.base import ColumnCollection
from ...sql.base import ReadOnlyColumnCollection
from ...sql.dml import Insert as StandardInsert
from ...sql.elements import ClauseElement
from ...sql.elements import KeyedColumnElement
from ...sql.expression import alias
from ...util.typing import Self
__all__ = ("Insert", "insert")
def insert(table: _DMLTableArgument) -> Insert:
"""Construct a PostgreSQL-specific variant :class:`_postgresql.Insert`
construct.
.. container:: inherited_member
The :func:`sqlalchemy.dialects.postgresql.insert` function creates
a :class:`sqlalchemy.dialects.postgresql.Insert`. This class is based
on the dialect-agnostic :class:`_sql.Insert` construct which may
be constructed using the :func:`_sql.insert` function in
SQLAlchemy Core.
The :class:`_postgresql.Insert` construct includes additional methods
:meth:`_postgresql.Insert.on_conflict_do_update`,
:meth:`_postgresql.Insert.on_conflict_do_nothing`.
"""
return Insert(table)
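# Hedged usage sketch (table and column names are assumptions):
#     stmt = insert(user_table).values(id=1, name="spongebob")
#     stmt = stmt.on_conflict_do_update(
#         index_elements=[user_table.c.id],
#         set_={"name": stmt.excluded.name},
#     )
# which renders roughly as
#     INSERT ... ON CONFLICT (id) DO UPDATE SET name = excluded.name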
class Insert(StandardInsert):
"""PostgreSQL-specific implementation of INSERT.
Adds methods for PG-specific syntaxes such as ON CONFLICT.
The :class:`_postgresql.Insert` object is created using the
:func:`sqlalchemy.dialects.postgresql.insert` function.
"""
stringify_dialect = "postgresql"
inherit_cache = False
@util.memoized_property
def excluded(
self,
) -> ReadOnlyColumnCollection[str, KeyedColumnElement[Any]]:
"""Provide the ``excluded`` namespace for an ON CONFLICT statement
PG's ON CONFLICT clause allows reference to the row that would
be inserted, known as ``excluded``. This attribute provides
all columns in this row to be referenceable.
.. tip:: The :attr:`_postgresql.Insert.excluded` attribute is an
instance of :class:`_expression.ColumnCollection`, which provides
an interface the same as that of the :attr:`_schema.Table.c`
collection described at :ref:`metadata_tables_and_columns`.
With this collection, ordinary names are accessible like attributes
(e.g. ``stmt.excluded.some_column``), but special names and
dictionary method names should be accessed using indexed access,
such as ``stmt.excluded["column name"]`` or
``stmt.excluded["values"]``. See the docstring for
:class:`_expression.ColumnCollection` for further examples.
.. seealso::
:ref:`postgresql_insert_on_conflict` - example of how
to use :attr:`_expression.Insert.excluded`
"""
return alias(self.table, name="excluded").columns
_on_conflict_exclusive = _exclusive_against(
"_post_values_clause",
msgs={
"_post_values_clause": "This Insert construct already has "
"an ON CONFLICT clause established"
},
)
@_generative
@_on_conflict_exclusive
def on_conflict_do_update(
self,
constraint: _OnConflictConstraintT = None,
index_elements: _OnConflictIndexElementsT = None,
index_where: _OnConflictIndexWhereT = None,
set_: _OnConflictSetT = None,
where: _OnConflictWhereT = None,
) -> Self:
r"""
Specifies a DO UPDATE SET action for ON CONFLICT clause.
Either the ``constraint`` or ``index_elements`` argument is
required, but only one of these can be specified.
:param constraint:
The name of a unique or exclusion constraint on the table,
or the constraint object itself if it has a .name attribute.
:param index_elements:
A sequence consisting of string column names, :class:`_schema.Column`
objects, or other column expression objects that will be used
to infer a target index.
:param index_where:
Additional WHERE criterion that can be used to infer a
conditional target index.
:param set\_:
A dictionary or other mapping object
where the keys are either names of columns in the target table,
or :class:`_schema.Column` objects or other ORM-mapped columns
matching that of the target table, and expressions or literals
as values, specifying the ``SET`` actions to take.
.. versionadded:: 1.4 The
:paramref:`_postgresql.Insert.on_conflict_do_update.set_`
parameter supports :class:`_schema.Column` objects from the target
:class:`_schema.Table` as keys.
.. warning:: This dictionary does **not** take into account
Python-specified default UPDATE values or generation functions,
e.g. those specified using :paramref:`_schema.Column.onupdate`.
These values will not be exercised for an ON CONFLICT style of
UPDATE, unless they are manually specified in the
:paramref:`.Insert.on_conflict_do_update.set_` dictionary.
:param where:
Optional argument. If present, can be a literal SQL
string or an acceptable expression for a ``WHERE`` clause
that restricts the rows affected by ``DO UPDATE SET``. Rows
not meeting the ``WHERE`` condition will not be updated
(effectively a ``DO NOTHING`` for those rows).
.. seealso::
:ref:`postgresql_insert_on_conflict`
"""
self._post_values_clause = OnConflictDoUpdate(
constraint, index_elements, index_where, set_, where
)
return self
@_generative
@_on_conflict_exclusive
def on_conflict_do_nothing(
self,
constraint: _OnConflictConstraintT = None,
index_elements: _OnConflictIndexElementsT = None,
index_where: _OnConflictIndexWhereT = None,
) -> Self:
"""
Specifies a DO NOTHING action for ON CONFLICT clause.
The ``constraint`` and ``index_elements`` arguments
are optional, but only one of these can be specified.
:param constraint:
The name of a unique or exclusion constraint on the table,
or the constraint object itself if it has a .name attribute.
:param index_elements:
A sequence consisting of string column names, :class:`_schema.Column`
objects, or other column expression objects that will be used
to infer a target index.
:param index_where:
Additional WHERE criterion that can be used to infer a
conditional target index.
.. seealso::
:ref:`postgresql_insert_on_conflict`
"""
self._post_values_clause = OnConflictDoNothing(
constraint, index_elements, index_where
)
return self
class OnConflictClause(ClauseElement):
stringify_dialect = "postgresql"
constraint_target: Optional[str]
inferred_target_elements: _OnConflictIndexElementsT
inferred_target_whereclause: _OnConflictIndexWhereT
def __init__(
self,
constraint: _OnConflictConstraintT = None,
index_elements: _OnConflictIndexElementsT = None,
index_where: _OnConflictIndexWhereT = None,
):
if constraint is not None:
if not isinstance(constraint, str) and isinstance(
constraint,
(schema.Constraint, ext.ExcludeConstraint),
):
constraint = getattr(constraint, "name") or constraint
if constraint is not None:
if index_elements is not None:
raise ValueError(
"'constraint' and 'index_elements' are mutually exclusive"
)
if isinstance(constraint, str):
self.constraint_target = constraint
self.inferred_target_elements = None
self.inferred_target_whereclause = None
elif isinstance(constraint, schema.Index):
index_elements = constraint.expressions
index_where = constraint.dialect_options["postgresql"].get(
"where"
)
elif isinstance(constraint, ext.ExcludeConstraint):
index_elements = constraint.columns
index_where = constraint.where
else:
index_elements = constraint.columns
index_where = constraint.dialect_options["postgresql"].get(
"where"
)
if index_elements is not None:
self.constraint_target = None
self.inferred_target_elements = index_elements
self.inferred_target_whereclause = index_where
elif constraint is None:
self.constraint_target = self.inferred_target_elements = (
self.inferred_target_whereclause
) = None
class OnConflictDoNothing(OnConflictClause):
__visit_name__ = "on_conflict_do_nothing"
class OnConflictDoUpdate(OnConflictClause):
__visit_name__ = "on_conflict_do_update"
def __init__(
self,
constraint: _OnConflictConstraintT = None,
index_elements: _OnConflictIndexElementsT = None,
index_where: _OnConflictIndexWhereT = None,
set_: _OnConflictSetT = None,
where: _OnConflictWhereT = None,
):
super().__init__(
constraint=constraint,
index_elements=index_elements,
index_where=index_where,
)
if (
self.inferred_target_elements is None
and self.constraint_target is None
):
raise ValueError(
"Either constraint or index_elements, "
"but not both, must be specified unless DO NOTHING"
)
if isinstance(set_, dict):
if not set_:
raise ValueError("set parameter dictionary must not be empty")
elif isinstance(set_, ColumnCollection):
set_ = dict(set_)
else:
raise ValueError(
"set parameter must be a non-empty dictionary "
"or a ColumnCollection such as the `.c.` collection "
"of a Table object"
)
self.update_values_to_set = [
(coercions.expect(roles.DMLColumnRole, key), value)
for key, value in set_.items()
]
self.update_whereclause = where

View file

@ -0,0 +1,496 @@
# dialects/postgresql/ext.py
# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
# mypy: ignore-errors
from __future__ import annotations
from typing import Any
from typing import TYPE_CHECKING
from typing import TypeVar
from . import types
from .array import ARRAY
from ...sql import coercions
from ...sql import elements
from ...sql import expression
from ...sql import functions
from ...sql import roles
from ...sql import schema
from ...sql.schema import ColumnCollectionConstraint
from ...sql.sqltypes import TEXT
from ...sql.visitors import InternalTraversal
_T = TypeVar("_T", bound=Any)
if TYPE_CHECKING:
from ...sql.visitors import _TraverseInternalsType
class aggregate_order_by(expression.ColumnElement):
"""Represent a PostgreSQL aggregate order by expression.
E.g.::
from sqlalchemy.dialects.postgresql import aggregate_order_by
expr = func.array_agg(aggregate_order_by(table.c.a, table.c.b.desc()))
stmt = select(expr)
would represent the expression::
SELECT array_agg(a ORDER BY b DESC) FROM table;
Similarly::
expr = func.string_agg(
table.c.a,
aggregate_order_by(literal_column("','"), table.c.a)
)
stmt = select(expr)
Would represent::
SELECT string_agg(a, ',' ORDER BY a) FROM table;
.. versionchanged:: 1.2.13 - the ORDER BY argument may be multiple terms
.. seealso::
:class:`_functions.array_agg`
"""
__visit_name__ = "aggregate_order_by"
stringify_dialect = "postgresql"
_traverse_internals: _TraverseInternalsType = [
("target", InternalTraversal.dp_clauseelement),
("type", InternalTraversal.dp_type),
("order_by", InternalTraversal.dp_clauseelement),
]
def __init__(self, target, *order_by):
self.target = coercions.expect(roles.ExpressionElementRole, target)
self.type = self.target.type
_lob = len(order_by)
if _lob == 0:
raise TypeError("at least one ORDER BY element is required")
elif _lob == 1:
self.order_by = coercions.expect(
roles.ExpressionElementRole, order_by[0]
)
else:
self.order_by = elements.ClauseList(
*order_by, _literal_as_text_role=roles.ExpressionElementRole
)
def self_group(self, against=None):
return self
def get_children(self, **kwargs):
return self.target, self.order_by
def _copy_internals(self, clone=elements._clone, **kw):
self.target = clone(self.target, **kw)
self.order_by = clone(self.order_by, **kw)
@property
def _from_objects(self):
return self.target._from_objects + self.order_by._from_objects
class ExcludeConstraint(ColumnCollectionConstraint):
"""A table-level EXCLUDE constraint.
Defines an EXCLUDE constraint as described in the `PostgreSQL
documentation`__.
__ https://www.postgresql.org/docs/current/static/sql-createtable.html#SQL-CREATETABLE-EXCLUDE
""" # noqa
__visit_name__ = "exclude_constraint"
where = None
inherit_cache = False
create_drop_stringify_dialect = "postgresql"
@elements._document_text_coercion(
"where",
":class:`.ExcludeConstraint`",
":paramref:`.ExcludeConstraint.where`",
)
def __init__(self, *elements, **kw):
r"""
Create an :class:`.ExcludeConstraint` object.
E.g.::
const = ExcludeConstraint(
(Column('period'), '&&'),
(Column('group'), '='),
where=(Column('group') != 'some group'),
ops={'group': 'my_operator_class'}
)
The constraint is normally embedded into the :class:`_schema.Table`
construct
directly, or added later using :meth:`.append_constraint`::
some_table = Table(
'some_table', metadata,
Column('id', Integer, primary_key=True),
Column('period', TSRANGE()),
Column('group', String)
)
some_table.append_constraint(
ExcludeConstraint(
(some_table.c.period, '&&'),
(some_table.c.group, '='),
where=some_table.c.group != 'some group',
name='some_table_excl_const',
ops={'group': 'my_operator_class'}
)
)
The exclude constraint defined in this example requires the
``btree_gist`` extension, which can be created using the
command ``CREATE EXTENSION btree_gist;``.
:param \*elements:
A sequence of two tuples of the form ``(column, operator)`` where
"column" is either a :class:`_schema.Column` object, or a SQL
expression element (e.g. ``func.int8range(table.from, table.to)``)
or the name of a column as string, and "operator" is a string
containing the operator to use (e.g. `"&&"` or `"="`).
In order to specify a column name when a :class:`_schema.Column`
object is not available, while ensuring
that any necessary quoting rules take effect, an ad-hoc
:class:`_schema.Column` or :func:`_expression.column`
object should be used.
The ``column`` may also be a string SQL expression when
passed as :func:`_expression.literal_column` or
:func:`_expression.text`.
:param name:
Optional, the in-database name of this constraint.
:param deferrable:
Optional bool. If set, emit DEFERRABLE or NOT DEFERRABLE when
issuing DDL for this constraint.
:param initially:
Optional string. If set, emit INITIALLY <value> when issuing DDL
for this constraint.
:param using:
Optional string. If set, emit USING <index_method> when issuing DDL
for this constraint. Defaults to 'gist'.
:param where:
Optional SQL expression construct or literal SQL string.
If set, emit WHERE <predicate> when issuing DDL
for this constraint.
:param ops:
Optional dictionary. Used to define operator classes for the
elements; works the same way as that of the
:ref:`postgresql_ops <postgresql_operator_classes>`
parameter specified to the :class:`_schema.Index` construct.
.. versionadded:: 1.3.21
.. seealso::
:ref:`postgresql_operator_classes` - general description of how
PostgreSQL operator classes are specified.
"""
columns = []
render_exprs = []
self.operators = {}
expressions, operators = zip(*elements)
for (expr, column, strname, add_element), operator in zip(
coercions.expect_col_expression_collection(
roles.DDLConstraintColumnRole, expressions
),
operators,
):
if add_element is not None:
columns.append(add_element)
name = column.name if column is not None else strname
if name is not None:
# backwards compat
self.operators[name] = operator
render_exprs.append((expr, name, operator))
self._render_exprs = render_exprs
ColumnCollectionConstraint.__init__(
self,
*columns,
name=kw.get("name"),
deferrable=kw.get("deferrable"),
initially=kw.get("initially"),
)
self.using = kw.get("using", "gist")
where = kw.get("where")
if where is not None:
self.where = coercions.expect(roles.StatementOptionRole, where)
self.ops = kw.get("ops", {})
def _set_parent(self, table, **kw):
super()._set_parent(table)
self._render_exprs = [
(
expr if not isinstance(expr, str) else table.c[expr],
name,
operator,
)
for expr, name, operator in (self._render_exprs)
]
def _copy(self, target_table=None, **kw):
elements = [
(
schema._copy_expression(expr, self.parent, target_table),
operator,
)
for expr, _, operator in self._render_exprs
]
c = self.__class__(
*elements,
name=self.name,
deferrable=self.deferrable,
initially=self.initially,
where=self.where,
using=self.using,
)
c.dispatch._update(self.dispatch)
return c
def array_agg(*arg, **kw):
"""PostgreSQL-specific form of :class:`_functions.array_agg`, ensures
return type is :class:`_postgresql.ARRAY` and not
the plain :class:`_types.ARRAY`, unless an explicit ``type_``
is passed.
"""
kw["_default_array_type"] = ARRAY
return functions.func.array_agg(*arg, **kw)
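# Usage sketch for the helper above (hypothetical table "t" with an Integer
# column "x"); with this module imported, the expression is typed as the
# PostgreSQL-specific ARRAY rather than the generic one:
#
#     from sqlalchemy import select
#     from sqlalchemy.dialects.postgresql import array_agg
#
#     stmt = select(array_agg(t.c.x))  # array_agg(x), typed as postgresql.ARRAY(Integer)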
class _regconfig_fn(functions.GenericFunction[_T]):
inherit_cache = True
def __init__(self, *args, **kwargs):
args = list(args)
if len(args) > 1:
initial_arg = coercions.expect(
roles.ExpressionElementRole,
args.pop(0),
name=getattr(self, "name", None),
apply_propagate_attrs=self,
type_=types.REGCONFIG,
)
initial_arg = [initial_arg]
else:
initial_arg = []
addtl_args = [
coercions.expect(
roles.ExpressionElementRole,
c,
name=getattr(self, "name", None),
apply_propagate_attrs=self,
)
for c in args
]
super().__init__(*(initial_arg + addtl_args), **kwargs)
class to_tsvector(_regconfig_fn):
"""The PostgreSQL ``to_tsvector`` SQL function.
This function applies automatic casting of the REGCONFIG argument
to use the :class:`_postgresql.REGCONFIG` datatype automatically,
and applies a return type of :class:`_postgresql.TSVECTOR`.
Assuming the PostgreSQL dialect has been imported, either by invoking
``from sqlalchemy.dialects import postgresql``, or by creating a PostgreSQL
engine using ``create_engine("postgresql...")``,
:class:`_postgresql.to_tsvector` will be used automatically when invoking
``sqlalchemy.func.to_tsvector()``, ensuring the correct argument and return
type handlers are used at compile and execution time.
.. versionadded:: 2.0.0rc1
"""
inherit_cache = True
type = types.TSVECTOR
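# Usage sketch (hypothetical "documents" table with a Text column "body"):
# once the postgresql dialect module is imported, plain func.to_tsvector()
# resolves to the class above, so the leading REGCONFIG argument is cast
# appropriately and the result is typed as TSVECTOR:
#
#     from sqlalchemy import select, func
#
#     stmt = select(func.to_tsvector("english", documents.c.body))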
class to_tsquery(_regconfig_fn):
"""The PostgreSQL ``to_tsquery`` SQL function.
This function applies automatic casting of the REGCONFIG argument
to use the :class:`_postgresql.REGCONFIG` datatype automatically,
and applies a return type of :class:`_postgresql.TSQUERY`.
Assuming the PostgreSQL dialect has been imported, either by invoking
``from sqlalchemy.dialects import postgresql``, or by creating a PostgreSQL
engine using ``create_engine("postgresql...")``,
:class:`_postgresql.to_tsquery` will be used automatically when invoking
``sqlalchemy.func.to_tsquery()``, ensuring the correct argument and return
type handlers are used at compile and execution time.
.. versionadded:: 2.0.0rc1
"""
inherit_cache = True
type = types.TSQUERY
class plainto_tsquery(_regconfig_fn):
"""The PostgreSQL ``plainto_tsquery`` SQL function.
This function applies automatic casting of the REGCONFIG argument
to use the :class:`_postgresql.REGCONFIG` datatype automatically,
and applies a return type of :class:`_postgresql.TSQUERY`.
Assuming the PostgreSQL dialect has been imported, either by invoking
``from sqlalchemy.dialects import postgresql``, or by creating a PostgreSQL
engine using ``create_engine("postgresql...")``,
:class:`_postgresql.plainto_tsquery` will be used automatically when
invoking ``sqlalchemy.func.plainto_tsquery()``, ensuring the correct
argument and return type handlers are used at compile and execution time.
.. versionadded:: 2.0.0rc1
"""
inherit_cache = True
type = types.TSQUERY
class phraseto_tsquery(_regconfig_fn):
"""The PostgreSQL ``phraseto_tsquery`` SQL function.
This function applies automatic casting of the REGCONFIG argument
to use the :class:`_postgresql.REGCONFIG` datatype automatically,
and applies a return type of :class:`_postgresql.TSQUERY`.
Assuming the PostgreSQL dialect has been imported, either by invoking
``from sqlalchemy.dialects import postgresql``, or by creating a PostgreSQL
engine using ``create_engine("postgresql...")``,
:class:`_postgresql.phraseto_tsquery` will be used automatically when
invoking ``sqlalchemy.func.phraseto_tsquery()``, ensuring the correct
argument and return type handlers are used at compile and execution time.
.. versionadded:: 2.0.0rc1
"""
inherit_cache = True
type = types.TSQUERY
class websearch_to_tsquery(_regconfig_fn):
"""The PostgreSQL ``websearch_to_tsquery`` SQL function.
This function applies automatic casting of the REGCONFIG argument
to use the :class:`_postgresql.REGCONFIG` datatype automatically,
and applies a return type of :class:`_postgresql.TSQUERY`.
Assuming the PostgreSQL dialect has been imported, either by invoking
``from sqlalchemy.dialects import postgresql``, or by creating a PostgreSQL
engine using ``create_engine("postgresql...")``,
:class:`_postgresql.websearch_to_tsquery` will be used automatically when
invoking ``sqlalchemy.func.websearch_to_tsquery()``, ensuring the correct
argument and return type handlers are used at compile and execution time.
.. versionadded:: 2.0.0rc1
"""
inherit_cache = True
type = types.TSQUERY
class ts_headline(_regconfig_fn):
"""The PostgreSQL ``ts_headline`` SQL function.
This function applies automatic casting of the REGCONFIG argument
to use the :class:`_postgresql.REGCONFIG` datatype automatically,
and applies a return type of :class:`_types.TEXT`.
Assuming the PostgreSQL dialect has been imported, either by invoking
``from sqlalchemy.dialects import postgresql``, or by creating a PostgreSQL
engine using ``create_engine("postgresql...")``,
:class:`_postgresql.ts_headline` will be used automatically when invoking
``sqlalchemy.func.ts_headline()``, ensuring the correct argument and return
type handlers are used at compile and execution time.
.. versionadded:: 2.0.0rc1
"""
inherit_cache = True
type = TEXT
def __init__(self, *args, **kwargs):
args = list(args)
# parse types according to
# https://www.postgresql.org/docs/current/textsearch-controls.html#TEXTSEARCH-HEADLINE
if len(args) < 2:
# invalid args; don't do anything
has_regconfig = False
elif (
isinstance(args[1], elements.ColumnElement)
and args[1].type._type_affinity is types.TSQUERY
):
# tsquery is second argument, no regconfig argument
has_regconfig = False
else:
has_regconfig = True
if has_regconfig:
initial_arg = coercions.expect(
roles.ExpressionElementRole,
args.pop(0),
apply_propagate_attrs=self,
name=getattr(self, "name", None),
type_=types.REGCONFIG,
)
initial_arg = [initial_arg]
else:
initial_arg = []
addtl_args = [
coercions.expect(
roles.ExpressionElementRole,
c,
name=getattr(self, "name", None),
apply_propagate_attrs=self,
)
for c in args
]
super().__init__(*(initial_arg + addtl_args), **kwargs)
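# Call-form sketch for the regconfig detection above (hypothetical column
# "documents.c.body"); when the second positional argument is already a
# TSQUERY-typed expression, no leading REGCONFIG argument is assumed:
#
#     func.ts_headline(documents.c.body, func.to_tsquery("english", "sql"))
#     func.ts_headline("english", documents.c.body, func.to_tsquery("english", "sql"))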

@ -0,0 +1,397 @@
# dialects/postgresql/hstore.py
# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
# mypy: ignore-errors
import re
from .array import ARRAY
from .operators import CONTAINED_BY
from .operators import CONTAINS
from .operators import GETITEM
from .operators import HAS_ALL
from .operators import HAS_ANY
from .operators import HAS_KEY
from ... import types as sqltypes
from ...sql import functions as sqlfunc
__all__ = ("HSTORE", "hstore")
class HSTORE(sqltypes.Indexable, sqltypes.Concatenable, sqltypes.TypeEngine):
"""Represent the PostgreSQL HSTORE type.
The :class:`.HSTORE` type stores dictionaries containing strings, e.g.::
data_table = Table('data_table', metadata,
Column('id', Integer, primary_key=True),
Column('data', HSTORE)
)
with engine.connect() as conn:
conn.execute(
data_table.insert(),
data = {"key1": "value1", "key2": "value2"}
)
:class:`.HSTORE` provides for a wide range of operations, including:
* Index operations::
data_table.c.data['some key'] == 'some value'
* Containment operations::
data_table.c.data.has_key('some key')
data_table.c.data.has_all(['one', 'two', 'three'])
* Concatenation::
data_table.c.data + {"k1": "v1"}
For a full list of special methods see
:class:`.HSTORE.comparator_factory`.
.. container:: topic
**Detecting Changes in HSTORE columns when using the ORM**
For usage with the SQLAlchemy ORM, it may be desirable to combine the
usage of :class:`.HSTORE` with the :class:`.MutableDict` dictionary, now
part of the :mod:`sqlalchemy.ext.mutable` extension. This extension
will allow "in-place" changes to the dictionary, e.g. addition of new
keys or replacement/removal of existing keys to/from the current
dictionary, to produce events which will be detected by the unit of
work::
from sqlalchemy.ext.mutable import MutableDict
class MyClass(Base):
__tablename__ = 'data_table'
id = Column(Integer, primary_key=True)
data = Column(MutableDict.as_mutable(HSTORE))
my_object = session.query(MyClass).one()
# in-place mutation, requires Mutable extension
# in order for the ORM to detect
my_object.data['some_key'] = 'some value'
session.commit()
When the :mod:`sqlalchemy.ext.mutable` extension is not used, the ORM
will not be alerted to any changes to the contents of an existing
dictionary, unless that dictionary value is re-assigned to the
HSTORE-attribute itself, thus generating a change event.
.. seealso::
:class:`.hstore` - render the PostgreSQL ``hstore()`` function.
"""
__visit_name__ = "HSTORE"
hashable = False
text_type = sqltypes.Text()
def __init__(self, text_type=None):
"""Construct a new :class:`.HSTORE`.
:param text_type: the type that should be used for indexed values.
Defaults to :class:`_types.Text`.
"""
if text_type is not None:
self.text_type = text_type
class Comparator(
sqltypes.Indexable.Comparator, sqltypes.Concatenable.Comparator
):
"""Define comparison operations for :class:`.HSTORE`."""
def has_key(self, other):
"""Boolean expression. Test for presence of a key. Note that the
key may be a SQLA expression.
"""
return self.operate(HAS_KEY, other, result_type=sqltypes.Boolean)
def has_all(self, other):
"""Boolean expression. Test for presence of all keys in jsonb"""
return self.operate(HAS_ALL, other, result_type=sqltypes.Boolean)
def has_any(self, other):
"""Boolean expression. Test for presence of any key in jsonb"""
return self.operate(HAS_ANY, other, result_type=sqltypes.Boolean)
def contains(self, other, **kwargs):
"""Boolean expression. Test if keys (or array) are a superset
of / contain the keys of the argument hstore expression.
kwargs may be ignored by this operator but are required for API
conformance.
"""
return self.operate(CONTAINS, other, result_type=sqltypes.Boolean)
def contained_by(self, other):
"""Boolean expression. Test if keys are a proper subset of the
keys of the argument hstore expression.
"""
return self.operate(
CONTAINED_BY, other, result_type=sqltypes.Boolean
)
def _setup_getitem(self, index):
return GETITEM, index, self.type.text_type
def defined(self, key):
"""Boolean expression. Test for presence of a non-NULL value for
the key. Note that the key may be a SQLA expression.
"""
return _HStoreDefinedFunction(self.expr, key)
def delete(self, key):
"""HStore expression. Returns the contents of this hstore with the
given key deleted. Note that the key may be a SQLA expression.
"""
if isinstance(key, dict):
key = _serialize_hstore(key)
return _HStoreDeleteFunction(self.expr, key)
def slice(self, array):
"""HStore expression. Returns a subset of an hstore defined by
array of keys.
"""
return _HStoreSliceFunction(self.expr, array)
def keys(self):
"""Text array expression. Returns array of keys."""
return _HStoreKeysFunction(self.expr)
def vals(self):
"""Text array expression. Returns array of values."""
return _HStoreValsFunction(self.expr)
def array(self):
"""Text array expression. Returns array of alternating keys and
values.
"""
return _HStoreArrayFunction(self.expr)
def matrix(self):
"""Text array expression. Returns array of [key, value] pairs."""
return _HStoreMatrixFunction(self.expr)
comparator_factory = Comparator
def bind_processor(self, dialect):
def process(value):
if isinstance(value, dict):
return _serialize_hstore(value)
else:
return value
return process
def result_processor(self, dialect, coltype):
def process(value):
if value is not None:
return _parse_hstore(value)
else:
return value
return process
class hstore(sqlfunc.GenericFunction):
"""Construct an hstore value within a SQL expression using the
PostgreSQL ``hstore()`` function.
The :class:`.hstore` function accepts one or two arguments as described
in the PostgreSQL documentation.
E.g.::
from sqlalchemy.dialects.postgresql import array, hstore
select(hstore('key1', 'value1'))
select(
hstore(
array(['key1', 'key2', 'key3']),
array(['value1', 'value2', 'value3'])
)
)
.. seealso::
:class:`.HSTORE` - the PostgreSQL ``HSTORE`` datatype.
"""
type = HSTORE
name = "hstore"
inherit_cache = True
class _HStoreDefinedFunction(sqlfunc.GenericFunction):
type = sqltypes.Boolean
name = "defined"
inherit_cache = True
class _HStoreDeleteFunction(sqlfunc.GenericFunction):
type = HSTORE
name = "delete"
inherit_cache = True
class _HStoreSliceFunction(sqlfunc.GenericFunction):
type = HSTORE
name = "slice"
inherit_cache = True
class _HStoreKeysFunction(sqlfunc.GenericFunction):
type = ARRAY(sqltypes.Text)
name = "akeys"
inherit_cache = True
class _HStoreValsFunction(sqlfunc.GenericFunction):
type = ARRAY(sqltypes.Text)
name = "avals"
inherit_cache = True
class _HStoreArrayFunction(sqlfunc.GenericFunction):
type = ARRAY(sqltypes.Text)
name = "hstore_to_array"
inherit_cache = True
class _HStoreMatrixFunction(sqlfunc.GenericFunction):
type = ARRAY(sqltypes.Text)
name = "hstore_to_matrix"
inherit_cache = True
#
# parsing. note that none of this is used with the psycopg2 backend,
# which provides its own native extensions.
#
# My best guess at the parsing rules of hstore literals, since no formal
# grammar is given. This is mostly reverse engineered from PG's input parser
# behavior.
HSTORE_PAIR_RE = re.compile(
r"""
(
"(?P<key> (\\ . | [^"])* )" # Quoted key
)
[ ]* => [ ]* # Pair operator, optional adjoining whitespace
(
(?P<value_null> NULL ) # NULL value
| "(?P<value> (\\ . | [^"])* )" # Quoted value
)
""",
re.VERBOSE,
)
HSTORE_DELIMITER_RE = re.compile(
r"""
[ ]* , [ ]*
""",
re.VERBOSE,
)
def _parse_error(hstore_str, pos):
"""format an unmarshalling error."""
ctx = 20
hslen = len(hstore_str)
parsed_tail = hstore_str[max(pos - ctx - 1, 0) : min(pos, hslen)]
residual = hstore_str[min(pos, hslen) : min(pos + ctx + 1, hslen)]
if len(parsed_tail) > ctx:
parsed_tail = "[...]" + parsed_tail[1:]
if len(residual) > ctx:
residual = residual[:-1] + "[...]"
return "After %r, could not parse residual at position %d: %r" % (
parsed_tail,
pos,
residual,
)
def _parse_hstore(hstore_str):
"""Parse an hstore from its literal string representation.
Attempts to approximate PG's hstore input parsing rules as closely as
possible. Although currently this is not strictly necessary, since the
current implementation of hstore's output syntax is stricter than what it
accepts as input, the documentation makes no guarantees that will always
be the case.
"""
result = {}
pos = 0
pair_match = HSTORE_PAIR_RE.match(hstore_str)
while pair_match is not None:
key = pair_match.group("key").replace(r"\"", '"').replace("\\\\", "\\")
if pair_match.group("value_null"):
value = None
else:
value = (
pair_match.group("value")
.replace(r"\"", '"')
.replace("\\\\", "\\")
)
result[key] = value
pos += pair_match.end()
delim_match = HSTORE_DELIMITER_RE.match(hstore_str[pos:])
if delim_match is not None:
pos += delim_match.end()
pair_match = HSTORE_PAIR_RE.match(hstore_str[pos:])
if pos != len(hstore_str):
raise ValueError(_parse_error(hstore_str, pos))
return result
def _serialize_hstore(val):
"""Serialize a dictionary into an hstore literal. Keys and values must
both be strings (except None for values).
"""
def esc(s, position):
if position == "value" and s is None:
return "NULL"
elif isinstance(s, str):
return '"%s"' % s.replace("\\", "\\\\").replace('"', r"\"")
else:
raise ValueError(
"%r in %s position is not a string." % (s, position)
)
return ", ".join(
"%s=>%s" % (esc(k, "key"), esc(v, "value")) for k, v in val.items()
)
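# Round-trip sketch of the two helpers above:
#
#     _serialize_hstore({"a": "1", "b": None})  # -> '"a"=>"1", "b"=>NULL'
#     _parse_hstore('"a"=>"1", "b"=>NULL')      # -> {"a": "1", "b": None}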

@ -0,0 +1,325 @@
# dialects/postgresql/json.py
# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
# mypy: ignore-errors
from .array import ARRAY
from .array import array as _pg_array
from .operators import ASTEXT
from .operators import CONTAINED_BY
from .operators import CONTAINS
from .operators import DELETE_PATH
from .operators import HAS_ALL
from .operators import HAS_ANY
from .operators import HAS_KEY
from .operators import JSONPATH_ASTEXT
from .operators import PATH_EXISTS
from .operators import PATH_MATCH
from ... import types as sqltypes
from ...sql import cast
__all__ = ("JSON", "JSONB")
class JSONPathType(sqltypes.JSON.JSONPathType):
def _processor(self, dialect, super_proc):
def process(value):
if isinstance(value, str):
# If it's already a string assume that it's in json path
# format. This allows using cast with json path literals.
return value
elif value:
# A non-string sequence of path elements is converted to the
# PostgreSQL array-literal path form, e.g. ('a', 'b', 3) -> '{a, b, 3}'
value = "{%s}" % (", ".join(map(str, value)))
else:
value = "{}"
if super_proc:
value = super_proc(value)
return value
return process
def bind_processor(self, dialect):
return self._processor(dialect, self.string_bind_processor(dialect))
def literal_processor(self, dialect):
return self._processor(dialect, self.string_literal_processor(dialect))
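# Sketch of the conversion performed by the processor above:
#
#     ("key1", "key2", 5)  ->  "{key1, key2, 5}"   # sequence becomes a path literal
#     "$.key1.key2"        ->  "$.key1.key2"       # strings pass through unchanged
#     ()                   ->  "{}"                # empty / falsy values become "{}"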
class JSONPATH(JSONPathType):
"""JSON Path Type.
This is usually required to cast literal values to json path when using
json search like function, such as ``jsonb_path_query_array`` or
``jsonb_path_exists``::
stmt = sa.select(
sa.func.jsonb_path_query_array(
table.c.jsonb_col, cast("$.address.id", JSONPATH)
)
)
"""
__visit_name__ = "JSONPATH"
class JSON(sqltypes.JSON):
"""Represent the PostgreSQL JSON type.
:class:`_postgresql.JSON` is used automatically whenever the base
:class:`_types.JSON` datatype is used against a PostgreSQL backend,
however base :class:`_types.JSON` datatype does not provide Python
accessors for PostgreSQL-specific comparison methods such as
:meth:`_postgresql.JSON.Comparator.astext`; additionally, to use
PostgreSQL ``JSONB``, the :class:`_postgresql.JSONB` datatype should
be used explicitly.
.. seealso::
:class:`_types.JSON` - main documentation for the generic
cross-platform JSON datatype.
The operators provided by the PostgreSQL version of :class:`_types.JSON`
include:
* Index operations (the ``->`` operator)::
data_table.c.data['some key']
data_table.c.data[5]
* Index operations returning text (the ``->>`` operator)::
data_table.c.data['some key'].astext == 'some value'
Note that equivalent functionality is available via the
:attr:`.JSON.Comparator.as_string` accessor.
* Index operations with CAST
(equivalent to ``CAST(col ->> ['some key'] AS <type>)``)::
data_table.c.data['some key'].astext.cast(Integer) == 5
Note that equivalent functionality is available via the
:attr:`.JSON.Comparator.as_integer` and similar accessors.
* Path index operations (the ``#>`` operator)::
data_table.c.data[('key_1', 'key_2', 5, ..., 'key_n')]
* Path index operations returning text (the ``#>>`` operator)::
data_table.c.data[('key_1', 'key_2', 5, ..., 'key_n')].astext == 'some value'
Index operations return an expression object whose type defaults to
:class:`_types.JSON`,
so that further JSON-oriented instructions
may be called upon the result type.
Custom serializers and deserializers are specified at the dialect level,
that is using :func:`_sa.create_engine`. The reason for this is that when
using psycopg2, the DBAPI only allows serializers at the per-cursor
or per-connection level. E.g.::
engine = create_engine("postgresql+psycopg2://scott:tiger@localhost/test",
json_serializer=my_serialize_fn,
json_deserializer=my_deserialize_fn
)
When using the psycopg2 dialect, the json_deserializer is registered
against the database using ``psycopg2.extras.register_default_json``.
.. seealso::
:class:`_types.JSON` - Core level JSON type
:class:`_postgresql.JSONB`
""" # noqa
astext_type = sqltypes.Text()
def __init__(self, none_as_null=False, astext_type=None):
"""Construct a :class:`_types.JSON` type.
:param none_as_null: if True, persist the value ``None`` as a
SQL NULL value, not the JSON encoding of ``null``. Note that
when this flag is False, the :func:`.null` construct can still
be used to persist a NULL value::
from sqlalchemy import null
conn.execute(table.insert(), data=null())
.. seealso::
:attr:`_types.JSON.NULL`
:param astext_type: the type to use for the
:attr:`.JSON.Comparator.astext`
accessor on indexed attributes. Defaults to :class:`_types.Text`.
"""
super().__init__(none_as_null=none_as_null)
if astext_type is not None:
self.astext_type = astext_type
class Comparator(sqltypes.JSON.Comparator):
"""Define comparison operations for :class:`_types.JSON`."""
@property
def astext(self):
"""On an indexed expression, use the "astext" (e.g. "->>")
conversion when rendered in SQL.
E.g.::
select(data_table.c.data['some key'].astext)
.. seealso::
:meth:`_expression.ColumnElement.cast`
"""
if isinstance(self.expr.right.type, sqltypes.JSON.JSONPathType):
return self.expr.left.operate(
JSONPATH_ASTEXT,
self.expr.right,
result_type=self.type.astext_type,
)
else:
return self.expr.left.operate(
ASTEXT, self.expr.right, result_type=self.type.astext_type
)
comparator_factory = Comparator
class JSONB(JSON):
"""Represent the PostgreSQL JSONB type.
The :class:`_postgresql.JSONB` type stores arbitrary JSONB format data,
e.g.::
data_table = Table('data_table', metadata,
Column('id', Integer, primary_key=True),
Column('data', JSONB)
)
with engine.connect() as conn:
conn.execute(
data_table.insert(),
data = {"key1": "value1", "key2": "value2"}
)
The :class:`_postgresql.JSONB` type includes all operations provided by
:class:`_types.JSON`, including the same behaviors for indexing
operations.
It also adds additional operators specific to JSONB, including
:meth:`.JSONB.Comparator.has_key`, :meth:`.JSONB.Comparator.has_all`,
:meth:`.JSONB.Comparator.has_any`, :meth:`.JSONB.Comparator.contains`,
:meth:`.JSONB.Comparator.contained_by`,
:meth:`.JSONB.Comparator.delete_path`,
:meth:`.JSONB.Comparator.path_exists` and
:meth:`.JSONB.Comparator.path_match`.
Like the :class:`_types.JSON` type, the :class:`_postgresql.JSONB`
type does not detect
in-place changes when used with the ORM, unless the
:mod:`sqlalchemy.ext.mutable` extension is used.
Custom serializers and deserializers
are shared with the :class:`_types.JSON` class,
using the ``json_serializer``
and ``json_deserializer`` keyword arguments. These must be specified
at the dialect level using :func:`_sa.create_engine`. When using
psycopg2, the serializers are associated with the jsonb type using
``psycopg2.extras.register_default_jsonb`` on a per-connection basis,
in the same way that ``psycopg2.extras.register_default_json`` is used
to register these handlers with the json type.
.. seealso::
:class:`_types.JSON`
"""
__visit_name__ = "JSONB"
class Comparator(JSON.Comparator):
"""Define comparison operations for :class:`_types.JSON`."""
def has_key(self, other):
"""Boolean expression. Test for presence of a key. Note that the
key may be a SQLA expression.
"""
return self.operate(HAS_KEY, other, result_type=sqltypes.Boolean)
def has_all(self, other):
"""Boolean expression. Test for presence of all keys in jsonb"""
return self.operate(HAS_ALL, other, result_type=sqltypes.Boolean)
def has_any(self, other):
"""Boolean expression. Test for presence of any key in jsonb"""
return self.operate(HAS_ANY, other, result_type=sqltypes.Boolean)
def contains(self, other, **kwargs):
"""Boolean expression. Test if keys (or array) are a superset
of / contain the keys of the argument jsonb expression.
kwargs may be ignored by this operator but are required for API
conformance.
"""
return self.operate(CONTAINS, other, result_type=sqltypes.Boolean)
def contained_by(self, other):
"""Boolean expression. Test if keys are a proper subset of the
keys of the argument jsonb expression.
"""
return self.operate(
CONTAINED_BY, other, result_type=sqltypes.Boolean
)
def delete_path(self, array):
"""JSONB expression. Deletes field or array element specified in
the argument array.
The input may be a list of strings that will be coerced to an
``ARRAY`` or an instance of :class:`_postgresql.array`.
.. versionadded:: 2.0
"""
if not isinstance(array, _pg_array):
array = _pg_array(array)
right_side = cast(array, ARRAY(sqltypes.TEXT))
return self.operate(DELETE_PATH, right_side, result_type=JSONB)
def path_exists(self, other):
"""Boolean expression. Test for presence of item given by the
argument JSONPath expression.
.. versionadded:: 2.0
"""
return self.operate(
PATH_EXISTS, other, result_type=sqltypes.Boolean
)
def path_match(self, other):
"""Boolean expression. Test if JSONPath predicate given by the
argument JSONPath expression matches.
Only the first item of the result is taken into account.
.. versionadded:: 2.0
"""
return self.operate(
PATH_MATCH, other, result_type=sqltypes.Boolean
)
comparator_factory = Comparator
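# Operator sketch for the JSONB comparators above (hypothetical "data_table"
# with a JSONB column "data"); each call renders the corresponding PostgreSQL
# operator defined in operators.py:
#
#     data_table.c.data.delete_path(["k1", "k2"])  # data #- CAST(ARRAY[...] AS TEXT[])
#     data_table.c.data.path_exists("$.k1.k2")     # data @? <param>
#     data_table.c.data.path_match("$.k1.k2 > 5")  # data @@ <param>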

@ -0,0 +1,495 @@
# dialects/postgresql/named_types.py
# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
# mypy: ignore-errors
from __future__ import annotations
from typing import Any
from typing import Optional
from typing import Type
from typing import TYPE_CHECKING
from typing import Union
from ... import schema
from ... import util
from ...sql import coercions
from ...sql import elements
from ...sql import roles
from ...sql import sqltypes
from ...sql import type_api
from ...sql.base import _NoArg
from ...sql.ddl import InvokeCreateDDLBase
from ...sql.ddl import InvokeDropDDLBase
if TYPE_CHECKING:
from ...sql._typing import _TypeEngineArgument
class NamedType(sqltypes.TypeEngine):
"""Base for named types."""
__abstract__ = True
DDLGenerator: Type[NamedTypeGenerator]
DDLDropper: Type[NamedTypeDropper]
create_type: bool
def create(self, bind, checkfirst=True, **kw):
"""Emit ``CREATE`` DDL for this type.
:param bind: a connectable :class:`_engine.Engine`,
:class:`_engine.Connection`, or similar object to emit
SQL.
:param checkfirst: if ``True``, a query against
the PG catalog will be first performed to see
if the type does not exist already before
creating.
"""
bind._run_ddl_visitor(self.DDLGenerator, self, checkfirst=checkfirst)
def drop(self, bind, checkfirst=True, **kw):
"""Emit ``DROP`` DDL for this type.
:param bind: a connectable :class:`_engine.Engine`,
:class:`_engine.Connection`, or similar object to emit
SQL.
:param checkfirst: if ``True``, a query against
the PG catalog will be first performed to see
if the type actually exists before dropping.
"""
bind._run_ddl_visitor(self.DDLDropper, self, checkfirst=checkfirst)
def _check_for_name_in_memos(self, checkfirst, kw):
"""Look in the 'ddl runner' for 'memos', then
note our name in that collection.
This is to ensure a particular named type is operated
upon only once within any kind of create/drop
sequence without relying upon "checkfirst".
"""
if not self.create_type:
return True
if "_ddl_runner" in kw:
ddl_runner = kw["_ddl_runner"]
type_name = f"pg_{self.__visit_name__}"
if type_name in ddl_runner.memo:
existing = ddl_runner.memo[type_name]
else:
existing = ddl_runner.memo[type_name] = set()
present = (self.schema, self.name) in existing
existing.add((self.schema, self.name))
return present
else:
return False
def _on_table_create(self, target, bind, checkfirst=False, **kw):
if (
checkfirst
or (
not self.metadata
and not kw.get("_is_metadata_operation", False)
)
) and not self._check_for_name_in_memos(checkfirst, kw):
self.create(bind=bind, checkfirst=checkfirst)
def _on_table_drop(self, target, bind, checkfirst=False, **kw):
if (
not self.metadata
and not kw.get("_is_metadata_operation", False)
and not self._check_for_name_in_memos(checkfirst, kw)
):
self.drop(bind=bind, checkfirst=checkfirst)
def _on_metadata_create(self, target, bind, checkfirst=False, **kw):
if not self._check_for_name_in_memos(checkfirst, kw):
self.create(bind=bind, checkfirst=checkfirst)
def _on_metadata_drop(self, target, bind, checkfirst=False, **kw):
if not self._check_for_name_in_memos(checkfirst, kw):
self.drop(bind=bind, checkfirst=checkfirst)
class NamedTypeGenerator(InvokeCreateDDLBase):
def __init__(self, dialect, connection, checkfirst=False, **kwargs):
super().__init__(connection, **kwargs)
self.checkfirst = checkfirst
def _can_create_type(self, type_):
if not self.checkfirst:
return True
effective_schema = self.connection.schema_for_object(type_)
return not self.connection.dialect.has_type(
self.connection, type_.name, schema=effective_schema
)
class NamedTypeDropper(InvokeDropDDLBase):
def __init__(self, dialect, connection, checkfirst=False, **kwargs):
super().__init__(connection, **kwargs)
self.checkfirst = checkfirst
def _can_drop_type(self, type_):
if not self.checkfirst:
return True
effective_schema = self.connection.schema_for_object(type_)
return self.connection.dialect.has_type(
self.connection, type_.name, schema=effective_schema
)
class EnumGenerator(NamedTypeGenerator):
def visit_enum(self, enum):
if not self._can_create_type(enum):
return
with self.with_ddl_events(enum):
self.connection.execute(CreateEnumType(enum))
class EnumDropper(NamedTypeDropper):
def visit_enum(self, enum):
if not self._can_drop_type(enum):
return
with self.with_ddl_events(enum):
self.connection.execute(DropEnumType(enum))
class ENUM(NamedType, type_api.NativeForEmulated, sqltypes.Enum):
"""PostgreSQL ENUM type.
This is a subclass of :class:`_types.Enum` which includes
support for PG's ``CREATE TYPE`` and ``DROP TYPE``.
When the builtin type :class:`_types.Enum` is used and the
:paramref:`.Enum.native_enum` flag is left at its default of
True, the PostgreSQL backend will use a :class:`_postgresql.ENUM`
type as the implementation, so the special create/drop rules
will be used.
The create/drop behavior of ENUM is necessarily intricate, due to the
awkward relationship the ENUM type has with the
parent table, in that it may be "owned" by just a single table, or
may be shared among many tables.
When using :class:`_types.Enum` or :class:`_postgresql.ENUM`
in an "inline" fashion, the ``CREATE TYPE`` and ``DROP TYPE`` is emitted
corresponding to when the :meth:`_schema.Table.create` and
:meth:`_schema.Table.drop`
methods are called::
table = Table('sometable', metadata,
Column('some_enum', ENUM('a', 'b', 'c', name='myenum'))
)
table.create(engine) # will emit CREATE ENUM and CREATE TABLE
table.drop(engine) # will emit DROP TABLE and DROP ENUM
To use a common enumerated type between multiple tables, the best
practice is to declare the :class:`_types.Enum` or
:class:`_postgresql.ENUM` independently, and associate it with the
:class:`_schema.MetaData` object itself::
my_enum = ENUM('a', 'b', 'c', name='myenum', metadata=metadata)
t1 = Table('sometable_one', metadata,
Column('some_enum', myenum)
)
t2 = Table('sometable_two', metadata,
Column('some_enum', myenum)
)
When this pattern is used, care must still be taken at the level
of individual table creates. Emitting CREATE TABLE without also
specifying ``checkfirst=True`` will still cause issues::
t1.create(engine) # will fail: no such type 'myenum'
If we specify ``checkfirst=True``, the individual table-level create
operation will check for the ``ENUM`` and create if not exists::
# will check if enum exists, and emit CREATE TYPE if not
t1.create(engine, checkfirst=True)
When using a metadata-level ENUM type, the type will always be created
and dropped if either the metadata-wide create or drop is called::
metadata.create_all(engine) # will emit CREATE TYPE
metadata.drop_all(engine) # will emit DROP TYPE
The type can also be created and dropped directly::
my_enum.create(engine)
my_enum.drop(engine)
"""
native_enum = True
DDLGenerator = EnumGenerator
DDLDropper = EnumDropper
def __init__(
self,
*enums,
name: Union[str, _NoArg, None] = _NoArg.NO_ARG,
create_type: bool = True,
**kw,
):
"""Construct an :class:`_postgresql.ENUM`.
Arguments are the same as that of
:class:`_types.Enum`, but also including
the following parameters.
:param create_type: Defaults to True.
Indicates that ``CREATE TYPE`` should be
emitted, after optionally checking for the
presence of the type, when the parent
table is being created; and additionally
that ``DROP TYPE`` is called when the table
is dropped. When ``False``, no check
will be performed and no ``CREATE TYPE``
or ``DROP TYPE`` is emitted, unless
:meth:`~.postgresql.ENUM.create`
or :meth:`~.postgresql.ENUM.drop`
are called directly.
Setting to ``False`` is helpful
when invoking a creation scheme to a SQL file
without access to the actual database -
the :meth:`~.postgresql.ENUM.create` and
:meth:`~.postgresql.ENUM.drop` methods can
be used to emit SQL to a target bind.
"""
native_enum = kw.pop("native_enum", None)
if native_enum is False:
util.warn(
"the native_enum flag does not apply to the "
"sqlalchemy.dialects.postgresql.ENUM datatype; this type "
"always refers to ENUM. Use sqlalchemy.types.Enum for "
"non-native enum."
)
self.create_type = create_type
if name is not _NoArg.NO_ARG:
kw["name"] = name
super().__init__(*enums, **kw)
def coerce_compared_value(self, op, value):
super_coerced_type = super().coerce_compared_value(op, value)
if (
super_coerced_type._type_affinity
is type_api.STRINGTYPE._type_affinity
):
return self
else:
return super_coerced_type
@classmethod
def __test_init__(cls):
return cls(name="name")
@classmethod
def adapt_emulated_to_native(cls, impl, **kw):
"""Produce a PostgreSQL native :class:`_postgresql.ENUM` from plain
:class:`.Enum`.
"""
kw.setdefault("validate_strings", impl.validate_strings)
kw.setdefault("name", impl.name)
kw.setdefault("schema", impl.schema)
kw.setdefault("inherit_schema", impl.inherit_schema)
kw.setdefault("metadata", impl.metadata)
kw.setdefault("_create_events", False)
kw.setdefault("values_callable", impl.values_callable)
kw.setdefault("omit_aliases", impl._omit_aliases)
kw.setdefault("_adapted_from", impl)
if type_api._is_native_for_emulated(impl.__class__):
kw.setdefault("create_type", impl.create_type)
return cls(**kw)
def create(self, bind=None, checkfirst=True):
"""Emit ``CREATE TYPE`` for this
:class:`_postgresql.ENUM`.
If the underlying dialect does not support
PostgreSQL CREATE TYPE, no action is taken.
:param bind: a connectable :class:`_engine.Engine`,
:class:`_engine.Connection`, or similar object to emit
SQL.
:param checkfirst: if ``True``, a query against
the PG catalog will be first performed to see
if the type does not exist already before
creating.
"""
if not bind.dialect.supports_native_enum:
return
super().create(bind, checkfirst=checkfirst)
def drop(self, bind=None, checkfirst=True):
"""Emit ``DROP TYPE`` for this
:class:`_postgresql.ENUM`.
If the underlying dialect does not support
PostgreSQL DROP TYPE, no action is taken.
:param bind: a connectable :class:`_engine.Engine`,
:class:`_engine.Connection`, or similar object to emit
SQL.
:param checkfirst: if ``True``, a query against
the PG catalog will be first performed to see
if the type actually exists before dropping.
"""
if not bind.dialect.supports_native_enum:
return
super().drop(bind, checkfirst=checkfirst)
def get_dbapi_type(self, dbapi):
"""dont return dbapi.STRING for ENUM in PostgreSQL, since that's
a different type"""
return None
class DomainGenerator(NamedTypeGenerator):
def visit_DOMAIN(self, domain):
if not self._can_create_type(domain):
return
with self.with_ddl_events(domain):
self.connection.execute(CreateDomainType(domain))
class DomainDropper(NamedTypeDropper):
def visit_DOMAIN(self, domain):
if not self._can_drop_type(domain):
return
with self.with_ddl_events(domain):
self.connection.execute(DropDomainType(domain))
class DOMAIN(NamedType, sqltypes.SchemaType):
r"""Represent the DOMAIN PostgreSQL type.
A domain is essentially a data type with optional constraints
that restrict the allowed set of values. E.g.::
PositiveInt = DOMAIN(
"pos_int", Integer, check="VALUE > 0", not_null=True
)
UsPostalCode = DOMAIN(
"us_postal_code",
Text,
check="VALUE ~ '^\d{5}$' OR VALUE ~ '^\d{5}-\d{4}$'"
)
See the `PostgreSQL documentation`__ for additional details.
__ https://www.postgresql.org/docs/current/sql-createdomain.html
.. versionadded:: 2.0
"""
DDLGenerator = DomainGenerator
DDLDropper = DomainDropper
__visit_name__ = "DOMAIN"
def __init__(
self,
name: str,
data_type: _TypeEngineArgument[Any],
*,
collation: Optional[str] = None,
default: Optional[Union[str, elements.TextClause]] = None,
constraint_name: Optional[str] = None,
not_null: Optional[bool] = None,
check: Optional[str] = None,
create_type: bool = True,
**kw: Any,
):
"""
Construct a DOMAIN.
:param name: the name of the domain
:param data_type: The underlying data type of the domain.
This can include array specifiers.
:param collation: An optional collation for the domain.
If no collation is specified, the underlying data type's default
collation is used. The underlying type must be collatable if
``collation`` is specified.
:param default: The DEFAULT clause specifies a default value for
columns of the domain data type. The default should be a string
or a :func:`_expression.text` value.
If no default value is specified, then the default value is
the null value.
:param constraint_name: An optional name for a constraint.
If not specified, the backend generates a name.
:param not_null: Values of this domain are prevented from being null.
By default domains are allowed to be null. If not specified,
no nullability clause will be emitted.
:param check: A CHECK clause specifies an integrity constraint or test
which values of the domain must satisfy. A constraint must be
an expression producing a Boolean result that can use the key
word VALUE to refer to the value being tested.
Unlike PostgreSQL, only a single check clause is
currently allowed in SQLAlchemy.
:param schema: optional schema name
:param metadata: optional :class:`_schema.MetaData` object which
this :class:`_postgresql.DOMAIN` will be directly associated with
:param create_type: Defaults to True.
Indicates that ``CREATE TYPE`` should be emitted, after optionally
checking for the presence of the type, when the parent table is
being created; and additionally that ``DROP TYPE`` is called
when the table is dropped.
"""
self.data_type = type_api.to_instance(data_type)
self.default = default
self.collation = collation
self.constraint_name = constraint_name
self.not_null = not_null
if check is not None:
check = coercions.expect(roles.DDLExpressionRole, check)
self.check = check
self.create_type = create_type
super().__init__(name=name, **kw)
@classmethod
def __test_init__(cls):
return cls("name", sqltypes.Integer)
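# Usage sketch (names are hypothetical): a DOMAIN participates in table and
# metadata create/drop like the other named types in this module, emitting
# CREATE DOMAIN / DROP DOMAIN when create_type=True:
#
#     PositiveInt = DOMAIN("pos_int", Integer, check="VALUE > 0", not_null=True)
#     t = Table("orders", metadata, Column("qty", PositiveInt))
#     metadata.create_all(engine)  # CREATE DOMAIN pos_int ..., then CREATE TABLE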
class CreateEnumType(schema._CreateDropBase):
__visit_name__ = "create_enum_type"
class DropEnumType(schema._CreateDropBase):
__visit_name__ = "drop_enum_type"
class CreateDomainType(schema._CreateDropBase):
"""Represent a CREATE DOMAIN statement."""
__visit_name__ = "create_domain_type"
class DropDomainType(schema._CreateDropBase):
"""Represent a DROP DOMAIN statement."""
__visit_name__ = "drop_domain_type"

@ -0,0 +1,129 @@
# dialects/postgresql/operators.py
# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
# mypy: ignore-errors
from ...sql import operators
_getitem_precedence = operators._PRECEDENCE[operators.json_getitem_op]
_eq_precedence = operators._PRECEDENCE[operators.eq]
# JSON + JSONB
ASTEXT = operators.custom_op(
"->>",
precedence=_getitem_precedence,
natural_self_precedent=True,
eager_grouping=True,
)
JSONPATH_ASTEXT = operators.custom_op(
"#>>",
precedence=_getitem_precedence,
natural_self_precedent=True,
eager_grouping=True,
)
# JSONB + HSTORE
HAS_KEY = operators.custom_op(
"?",
precedence=_eq_precedence,
natural_self_precedent=True,
eager_grouping=True,
is_comparison=True,
)
HAS_ALL = operators.custom_op(
"?&",
precedence=_eq_precedence,
natural_self_precedent=True,
eager_grouping=True,
is_comparison=True,
)
HAS_ANY = operators.custom_op(
"?|",
precedence=_eq_precedence,
natural_self_precedent=True,
eager_grouping=True,
is_comparison=True,
)
# JSONB
DELETE_PATH = operators.custom_op(
"#-",
precedence=_getitem_precedence,
natural_self_precedent=True,
eager_grouping=True,
)
PATH_EXISTS = operators.custom_op(
"@?",
precedence=_eq_precedence,
natural_self_precedent=True,
eager_grouping=True,
is_comparison=True,
)
PATH_MATCH = operators.custom_op(
"@@",
precedence=_eq_precedence,
natural_self_precedent=True,
eager_grouping=True,
is_comparison=True,
)
# JSONB + ARRAY + HSTORE + RANGE
CONTAINS = operators.custom_op(
"@>",
precedence=_eq_precedence,
natural_self_precedent=True,
eager_grouping=True,
is_comparison=True,
)
CONTAINED_BY = operators.custom_op(
"<@",
precedence=_eq_precedence,
natural_self_precedent=True,
eager_grouping=True,
is_comparison=True,
)
# ARRAY + RANGE
OVERLAP = operators.custom_op(
"&&",
precedence=_eq_precedence,
is_comparison=True,
)
# RANGE
STRICTLY_LEFT_OF = operators.custom_op(
"<<", precedence=_eq_precedence, is_comparison=True
)
STRICTLY_RIGHT_OF = operators.custom_op(
">>", precedence=_eq_precedence, is_comparison=True
)
NOT_EXTEND_RIGHT_OF = operators.custom_op(
"&<", precedence=_eq_precedence, is_comparison=True
)
NOT_EXTEND_LEFT_OF = operators.custom_op(
"&>", precedence=_eq_precedence, is_comparison=True
)
ADJACENT_TO = operators.custom_op(
"-|-", precedence=_eq_precedence, is_comparison=True
)
# HSTORE
GETITEM = operators.custom_op(
"->",
precedence=_getitem_precedence,
natural_self_precedent=True,
eager_grouping=True,
)
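# Usage sketch: these custom_op objects are consumed by the comparator classes
# in modules such as json.py, hstore.py, array.py and ranges.py, typically via
# operate(), e.g.
#
#     self.operate(CONTAINS, other, result_type=sqltypes.Boolean)
#
# which renders "<left> @> <right>" with comparison semantics and the
# equality precedence configured above.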

@ -0,0 +1,662 @@
# dialects/postgresql/pg8000.py
# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors <see AUTHORS
# file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
# mypy: ignore-errors
r"""
.. dialect:: postgresql+pg8000
:name: pg8000
:dbapi: pg8000
:connectstring: postgresql+pg8000://user:password@host:port/dbname[?key=value&key=value...]
:url: https://pypi.org/project/pg8000/
.. versionchanged:: 1.4 The pg8000 dialect has been updated for version
1.16.6 and higher, and is again part of SQLAlchemy's continuous integration
with full feature support.
.. _pg8000_unicode:
Unicode
-------
pg8000 will encode / decode string values between it and the server using the
PostgreSQL ``client_encoding`` parameter; by default this is the value in
the ``postgresql.conf`` file, which often defaults to ``SQL_ASCII``.
Typically, this can be changed to ``utf-8``, as a more useful default::
#client_encoding = sql_ascii # actually, defaults to database
# encoding
client_encoding = utf8
The ``client_encoding`` can be overridden for a session by executing the SQL:
SET CLIENT_ENCODING TO 'utf8';
SQLAlchemy will execute this SQL on all new connections based on the value
passed to :func:`_sa.create_engine` using the ``client_encoding`` parameter::
engine = create_engine(
"postgresql+pg8000://user:pass@host/dbname", client_encoding='utf8')
.. _pg8000_ssl:
SSL Connections
---------------
pg8000 accepts a Python ``SSLContext`` object which may be specified using the
:paramref:`_sa.create_engine.connect_args` dictionary::
import ssl
ssl_context = ssl.create_default_context()
engine = sa.create_engine(
"postgresql+pg8000://scott:tiger@192.168.0.199/test",
connect_args={"ssl_context": ssl_context},
)
If the server uses an automatically-generated certificate that is self-signed
or does not match the host name (as seen from the client), it may also be
necessary to disable hostname checking::
import ssl
ssl_context = ssl.create_default_context()
ssl_context.check_hostname = False
ssl_context.verify_mode = ssl.CERT_NONE
engine = sa.create_engine(
"postgresql+pg8000://scott:tiger@192.168.0.199/test",
connect_args={"ssl_context": ssl_context},
)
.. _pg8000_isolation_level:
pg8000 Transaction Isolation Level
-------------------------------------
The pg8000 dialect offers the same isolation level settings as that
of the :ref:`psycopg2 <psycopg2_isolation_level>` dialect:
* ``READ COMMITTED``
* ``READ UNCOMMITTED``
* ``REPEATABLE READ``
* ``SERIALIZABLE``
* ``AUTOCOMMIT``
.. seealso::
:ref:`postgresql_isolation_level`
:ref:`psycopg2_isolation_level`
""" # noqa
import decimal
import re
from . import ranges
from .array import ARRAY as PGARRAY
from .base import _DECIMAL_TYPES
from .base import _FLOAT_TYPES
from .base import _INT_TYPES
from .base import ENUM
from .base import INTERVAL
from .base import PGCompiler
from .base import PGDialect
from .base import PGExecutionContext
from .base import PGIdentifierPreparer
from .json import JSON
from .json import JSONB
from .json import JSONPathType
from .pg_catalog import _SpaceVector
from .pg_catalog import OIDVECTOR
from .types import CITEXT
from ... import exc
from ... import util
from ...engine import processors
from ...sql import sqltypes
from ...sql.elements import quoted_name
class _PGString(sqltypes.String):
render_bind_cast = True
class _PGNumeric(sqltypes.Numeric):
render_bind_cast = True
def result_processor(self, dialect, coltype):
if self.asdecimal:
if coltype in _FLOAT_TYPES:
return processors.to_decimal_processor_factory(
decimal.Decimal, self._effective_decimal_return_scale
)
elif coltype in _DECIMAL_TYPES or coltype in _INT_TYPES:
# pg8000 returns Decimal natively for 1700
return None
else:
raise exc.InvalidRequestError(
"Unknown PG numeric type: %d" % coltype
)
else:
if coltype in _FLOAT_TYPES:
# pg8000 returns float natively for 701
return None
elif coltype in _DECIMAL_TYPES or coltype in _INT_TYPES:
return processors.to_float
else:
raise exc.InvalidRequestError(
"Unknown PG numeric type: %d" % coltype
)
class _PGFloat(_PGNumeric, sqltypes.Float):
__visit_name__ = "float"
render_bind_cast = True
class _PGNumericNoBind(_PGNumeric):
def bind_processor(self, dialect):
return None
class _PGJSON(JSON):
render_bind_cast = True
def result_processor(self, dialect, coltype):
return None
class _PGJSONB(JSONB):
render_bind_cast = True
def result_processor(self, dialect, coltype):
return None
class _PGJSONIndexType(sqltypes.JSON.JSONIndexType):
def get_dbapi_type(self, dbapi):
raise NotImplementedError("should not be here")
class _PGJSONIntIndexType(sqltypes.JSON.JSONIntIndexType):
__visit_name__ = "json_int_index"
render_bind_cast = True
class _PGJSONStrIndexType(sqltypes.JSON.JSONStrIndexType):
__visit_name__ = "json_str_index"
render_bind_cast = True
class _PGJSONPathType(JSONPathType):
pass
# DBAPI type 1009
class _PGEnum(ENUM):
def get_dbapi_type(self, dbapi):
return dbapi.UNKNOWN
class _PGInterval(INTERVAL):
render_bind_cast = True
def get_dbapi_type(self, dbapi):
return dbapi.INTERVAL
@classmethod
def adapt_emulated_to_native(cls, interval, **kw):
return _PGInterval(precision=interval.second_precision)
class _PGTimeStamp(sqltypes.DateTime):
render_bind_cast = True
class _PGDate(sqltypes.Date):
render_bind_cast = True
class _PGTime(sqltypes.Time):
render_bind_cast = True
class _PGInteger(sqltypes.Integer):
render_bind_cast = True
class _PGSmallInteger(sqltypes.SmallInteger):
render_bind_cast = True
class _PGNullType(sqltypes.NullType):
pass
class _PGBigInteger(sqltypes.BigInteger):
render_bind_cast = True
class _PGBoolean(sqltypes.Boolean):
render_bind_cast = True
class _PGARRAY(PGARRAY):
render_bind_cast = True
class _PGOIDVECTOR(_SpaceVector, OIDVECTOR):
pass
class _Pg8000Range(ranges.AbstractSingleRangeImpl):
def bind_processor(self, dialect):
pg8000_Range = dialect.dbapi.Range
def to_range(value):
if isinstance(value, ranges.Range):
value = pg8000_Range(
value.lower, value.upper, value.bounds, value.empty
)
return value
return to_range
def result_processor(self, dialect, coltype):
def to_range(value):
if value is not None:
value = ranges.Range(
value.lower,
value.upper,
bounds=value.bounds,
empty=value.is_empty,
)
return value
return to_range
class _Pg8000MultiRange(ranges.AbstractMultiRangeImpl):
def bind_processor(self, dialect):
pg8000_Range = dialect.dbapi.Range
def to_multirange(value):
if isinstance(value, list):
mr = []
for v in value:
if isinstance(v, ranges.Range):
mr.append(
pg8000_Range(v.lower, v.upper, v.bounds, v.empty)
)
else:
mr.append(v)
return mr
else:
return value
return to_multirange
def result_processor(self, dialect, coltype):
def to_multirange(value):
if value is None:
return None
else:
return ranges.MultiRange(
ranges.Range(
v.lower, v.upper, bounds=v.bounds, empty=v.is_empty
)
for v in value
)
return to_multirange
_server_side_id = util.counter()
class PGExecutionContext_pg8000(PGExecutionContext):
def create_server_side_cursor(self):
ident = "c_%s_%s" % (hex(id(self))[2:], hex(_server_side_id())[2:])
return ServerSideCursor(self._dbapi_connection.cursor(), ident)
def pre_exec(self):
if not self.compiled:
return
class ServerSideCursor:
server_side = True
def __init__(self, cursor, ident):
self.ident = ident
self.cursor = cursor
@property
def connection(self):
return self.cursor.connection
@property
def rowcount(self):
return self.cursor.rowcount
@property
def description(self):
return self.cursor.description
def execute(self, operation, args=(), stream=None):
op = "DECLARE " + self.ident + " NO SCROLL CURSOR FOR " + operation
self.cursor.execute(op, args, stream=stream)
return self
def executemany(self, operation, param_sets):
self.cursor.executemany(operation, param_sets)
return self
def fetchone(self):
self.cursor.execute("FETCH FORWARD 1 FROM " + self.ident)
return self.cursor.fetchone()
def fetchmany(self, num=None):
if num is None:
return self.fetchall()
else:
self.cursor.execute(
"FETCH FORWARD " + str(int(num)) + " FROM " + self.ident
)
return self.cursor.fetchall()
def fetchall(self):
self.cursor.execute("FETCH FORWARD ALL FROM " + self.ident)
return self.cursor.fetchall()
def close(self):
self.cursor.execute("CLOSE " + self.ident)
self.cursor.close()
def setinputsizes(self, *sizes):
self.cursor.setinputsizes(*sizes)
def setoutputsize(self, size, column=None):
pass
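# Sketch of the SQL emitted by the emulation above for a server-side result
# (the cursor identifier shown is illustrative):
#
#     DECLARE c_7f3a_1 NO SCROLL CURSOR FOR SELECT ...
#     FETCH FORWARD 50 FROM c_7f3a_1
#     FETCH FORWARD ALL FROM c_7f3a_1
#     CLOSE c_7f3a_1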
class PGCompiler_pg8000(PGCompiler):
def visit_mod_binary(self, binary, operator, **kw):
return (
self.process(binary.left, **kw)
+ " %% "
+ self.process(binary.right, **kw)
)
class PGIdentifierPreparer_pg8000(PGIdentifierPreparer):
def __init__(self, *args, **kwargs):
PGIdentifierPreparer.__init__(self, *args, **kwargs)
self._double_percents = False
class PGDialect_pg8000(PGDialect):
driver = "pg8000"
supports_statement_cache = True
supports_unicode_statements = True
supports_unicode_binds = True
default_paramstyle = "format"
supports_sane_multi_rowcount = True
execution_ctx_cls = PGExecutionContext_pg8000
statement_compiler = PGCompiler_pg8000
preparer = PGIdentifierPreparer_pg8000
supports_server_side_cursors = True
render_bind_cast = True
# reversed as of pg8000 1.16.6. 1.16.5 and lower
# are no longer compatible
description_encoding = None
# description_encoding = "use_encoding"
colspecs = util.update_copy(
PGDialect.colspecs,
{
sqltypes.String: _PGString,
sqltypes.Numeric: _PGNumericNoBind,
sqltypes.Float: _PGFloat,
sqltypes.JSON: _PGJSON,
sqltypes.Boolean: _PGBoolean,
sqltypes.NullType: _PGNullType,
JSONB: _PGJSONB,
CITEXT: CITEXT,
sqltypes.JSON.JSONPathType: _PGJSONPathType,
sqltypes.JSON.JSONIndexType: _PGJSONIndexType,
sqltypes.JSON.JSONIntIndexType: _PGJSONIntIndexType,
sqltypes.JSON.JSONStrIndexType: _PGJSONStrIndexType,
sqltypes.Interval: _PGInterval,
INTERVAL: _PGInterval,
sqltypes.DateTime: _PGTimeStamp,
sqltypes.Date: _PGDate,
sqltypes.Time: _PGTime,
sqltypes.Integer: _PGInteger,
sqltypes.SmallInteger: _PGSmallInteger,
sqltypes.BigInteger: _PGBigInteger,
sqltypes.Enum: _PGEnum,
sqltypes.ARRAY: _PGARRAY,
OIDVECTOR: _PGOIDVECTOR,
ranges.INT4RANGE: _Pg8000Range,
ranges.INT8RANGE: _Pg8000Range,
ranges.NUMRANGE: _Pg8000Range,
ranges.DATERANGE: _Pg8000Range,
ranges.TSRANGE: _Pg8000Range,
ranges.TSTZRANGE: _Pg8000Range,
ranges.INT4MULTIRANGE: _Pg8000MultiRange,
ranges.INT8MULTIRANGE: _Pg8000MultiRange,
ranges.NUMMULTIRANGE: _Pg8000MultiRange,
ranges.DATEMULTIRANGE: _Pg8000MultiRange,
ranges.TSMULTIRANGE: _Pg8000MultiRange,
ranges.TSTZMULTIRANGE: _Pg8000MultiRange,
},
)
def __init__(self, client_encoding=None, **kwargs):
PGDialect.__init__(self, **kwargs)
self.client_encoding = client_encoding
if self._dbapi_version < (1, 16, 6):
raise NotImplementedError("pg8000 1.16.6 or greater is required")
if self._native_inet_types:
raise NotImplementedError(
"The pg8000 dialect does not fully implement "
"ipaddress type handling; INET is supported by default, "
"CIDR is not"
)
@util.memoized_property
def _dbapi_version(self):
if self.dbapi and hasattr(self.dbapi, "__version__"):
return tuple(
[
int(x)
for x in re.findall(
r"(\d+)(?:[-\.]?|$)", self.dbapi.__version__
)
]
)
else:
return (99, 99, 99)
@classmethod
def import_dbapi(cls):
return __import__("pg8000")
def create_connect_args(self, url):
opts = url.translate_connect_args(username="user")
if "port" in opts:
opts["port"] = int(opts["port"])
opts.update(url.query)
return ([], opts)
def is_disconnect(self, e, connection, cursor):
if isinstance(e, self.dbapi.InterfaceError) and "network error" in str(
e
):
# new as of pg8000 1.19.0 for broken connections
return True
# connection was closed normally
return "connection is closed" in str(e)
def get_isolation_level_values(self, dbapi_connection):
return (
"AUTOCOMMIT",
"READ COMMITTED",
"READ UNCOMMITTED",
"REPEATABLE READ",
"SERIALIZABLE",
)
def set_isolation_level(self, dbapi_connection, level):
level = level.replace("_", " ")
if level == "AUTOCOMMIT":
dbapi_connection.autocommit = True
else:
dbapi_connection.autocommit = False
cursor = dbapi_connection.cursor()
cursor.execute(
"SET SESSION CHARACTERISTICS AS TRANSACTION "
f"ISOLATION LEVEL {level}"
)
cursor.execute("COMMIT")
cursor.close()
def set_readonly(self, connection, value):
cursor = connection.cursor()
try:
cursor.execute(
"SET SESSION CHARACTERISTICS AS TRANSACTION %s"
% ("READ ONLY" if value else "READ WRITE")
)
cursor.execute("COMMIT")
finally:
cursor.close()
def get_readonly(self, connection):
cursor = connection.cursor()
try:
cursor.execute("show transaction_read_only")
val = cursor.fetchone()[0]
finally:
cursor.close()
return val == "on"
def set_deferrable(self, connection, value):
cursor = connection.cursor()
try:
cursor.execute(
"SET SESSION CHARACTERISTICS AS TRANSACTION %s"
% ("DEFERRABLE" if value else "NOT DEFERRABLE")
)
cursor.execute("COMMIT")
finally:
cursor.close()
def get_deferrable(self, connection):
cursor = connection.cursor()
try:
cursor.execute("show transaction_deferrable")
val = cursor.fetchone()[0]
finally:
cursor.close()
return val == "on"
def _set_client_encoding(self, dbapi_connection, client_encoding):
cursor = dbapi_connection.cursor()
cursor.execute(
f"""SET CLIENT_ENCODING TO '{
client_encoding.replace("'", "''")
}'"""
)
cursor.execute("COMMIT")
cursor.close()
def do_begin_twophase(self, connection, xid):
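        # pg8000 expects a DBAPI-style two-phase xid tuple of
        # (format_id, global_transaction_id, branch_qualifier);
        # SQLAlchemy's xid string is passed as the global transaction id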
connection.connection.tpc_begin((0, xid, ""))
def do_prepare_twophase(self, connection, xid):
connection.connection.tpc_prepare()
def do_rollback_twophase(
self, connection, xid, is_prepared=True, recover=False
):
connection.connection.tpc_rollback((0, xid, ""))
def do_commit_twophase(
self, connection, xid, is_prepared=True, recover=False
):
connection.connection.tpc_commit((0, xid, ""))
def do_recover_twophase(self, connection):
return [row[1] for row in connection.connection.tpc_recover()]
def on_connect(self):
fns = []
def on_connect(conn):
conn.py_types[quoted_name] = conn.py_types[str]
fns.append(on_connect)
if self.client_encoding is not None:
def on_connect(conn):
self._set_client_encoding(conn, self.client_encoding)
fns.append(on_connect)
if self._native_inet_types is False:
def on_connect(conn):
# inet
conn.register_in_adapter(869, lambda s: s)
# cidr
conn.register_in_adapter(650, lambda s: s)
fns.append(on_connect)
if self._json_deserializer:
def on_connect(conn):
# json
conn.register_in_adapter(114, self._json_deserializer)
# jsonb
conn.register_in_adapter(3802, self._json_deserializer)
fns.append(on_connect)
if len(fns) > 0:
def on_connect(conn):
for fn in fns:
fn(conn)
return on_connect
else:
return None
@util.memoized_property
def _dialect_specific_select_one(self):
return ";"
dialect = PGDialect_pg8000

View file

@@ -0,0 +1,294 @@
# dialects/postgresql/pg_catalog.py
# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
# mypy: ignore-errors
from .array import ARRAY
from .types import OID
from .types import REGCLASS
from ... import Column
from ... import func
from ... import MetaData
from ... import Table
from ...types import BigInteger
from ...types import Boolean
from ...types import CHAR
from ...types import Float
from ...types import Integer
from ...types import SmallInteger
from ...types import String
from ...types import Text
from ...types import TypeDecorator
# types
class NAME(TypeDecorator):
impl = String(64, collation="C")
cache_ok = True
class PG_NODE_TREE(TypeDecorator):
impl = Text(collation="C")
cache_ok = True
class INT2VECTOR(TypeDecorator):
impl = ARRAY(SmallInteger)
cache_ok = True
class OIDVECTOR(TypeDecorator):
impl = ARRAY(OID)
cache_ok = True
class _SpaceVector:
def result_processor(self, dialect, coltype):
def process(value):
if value is None:
return value
return [int(p) for p in value.split(" ")]
return process
REGPROC = REGCLASS # seems an alias
# functions
_pg_cat = func.pg_catalog
quote_ident = _pg_cat.quote_ident
pg_table_is_visible = _pg_cat.pg_table_is_visible
pg_type_is_visible = _pg_cat.pg_type_is_visible
pg_get_viewdef = _pg_cat.pg_get_viewdef
pg_get_serial_sequence = _pg_cat.pg_get_serial_sequence
format_type = _pg_cat.format_type
pg_get_expr = _pg_cat.pg_get_expr
pg_get_constraintdef = _pg_cat.pg_get_constraintdef
pg_get_indexdef = _pg_cat.pg_get_indexdef
# constants
RELKINDS_TABLE_NO_FOREIGN = ("r", "p")
RELKINDS_TABLE = RELKINDS_TABLE_NO_FOREIGN + ("f",)
RELKINDS_VIEW = ("v",)
RELKINDS_MAT_VIEW = ("m",)
RELKINDS_ALL_TABLE_LIKE = RELKINDS_TABLE + RELKINDS_VIEW + RELKINDS_MAT_VIEW
# tables
pg_catalog_meta = MetaData()
pg_namespace = Table(
"pg_namespace",
pg_catalog_meta,
Column("oid", OID),
Column("nspname", NAME),
Column("nspowner", OID),
schema="pg_catalog",
)
pg_class = Table(
"pg_class",
pg_catalog_meta,
Column("oid", OID, info={"server_version": (9, 3)}),
Column("relname", NAME),
Column("relnamespace", OID),
Column("reltype", OID),
Column("reloftype", OID),
Column("relowner", OID),
Column("relam", OID),
Column("relfilenode", OID),
Column("reltablespace", OID),
Column("relpages", Integer),
Column("reltuples", Float),
Column("relallvisible", Integer, info={"server_version": (9, 2)}),
Column("reltoastrelid", OID),
Column("relhasindex", Boolean),
Column("relisshared", Boolean),
Column("relpersistence", CHAR, info={"server_version": (9, 1)}),
Column("relkind", CHAR),
Column("relnatts", SmallInteger),
Column("relchecks", SmallInteger),
Column("relhasrules", Boolean),
Column("relhastriggers", Boolean),
Column("relhassubclass", Boolean),
Column("relrowsecurity", Boolean),
Column("relforcerowsecurity", Boolean, info={"server_version": (9, 5)}),
Column("relispopulated", Boolean, info={"server_version": (9, 3)}),
Column("relreplident", CHAR, info={"server_version": (9, 4)}),
Column("relispartition", Boolean, info={"server_version": (10,)}),
Column("relrewrite", OID, info={"server_version": (11,)}),
Column("reloptions", ARRAY(Text)),
schema="pg_catalog",
)
pg_type = Table(
"pg_type",
pg_catalog_meta,
Column("oid", OID, info={"server_version": (9, 3)}),
Column("typname", NAME),
Column("typnamespace", OID),
Column("typowner", OID),
Column("typlen", SmallInteger),
Column("typbyval", Boolean),
Column("typtype", CHAR),
Column("typcategory", CHAR),
Column("typispreferred", Boolean),
Column("typisdefined", Boolean),
Column("typdelim", CHAR),
Column("typrelid", OID),
Column("typelem", OID),
Column("typarray", OID),
Column("typinput", REGPROC),
Column("typoutput", REGPROC),
Column("typreceive", REGPROC),
Column("typsend", REGPROC),
Column("typmodin", REGPROC),
Column("typmodout", REGPROC),
Column("typanalyze", REGPROC),
Column("typalign", CHAR),
Column("typstorage", CHAR),
Column("typnotnull", Boolean),
Column("typbasetype", OID),
Column("typtypmod", Integer),
Column("typndims", Integer),
Column("typcollation", OID, info={"server_version": (9, 1)}),
Column("typdefault", Text),
schema="pg_catalog",
)
pg_index = Table(
"pg_index",
pg_catalog_meta,
Column("indexrelid", OID),
Column("indrelid", OID),
Column("indnatts", SmallInteger),
Column("indnkeyatts", SmallInteger, info={"server_version": (11,)}),
Column("indisunique", Boolean),
Column("indnullsnotdistinct", Boolean, info={"server_version": (15,)}),
Column("indisprimary", Boolean),
Column("indisexclusion", Boolean, info={"server_version": (9, 1)}),
Column("indimmediate", Boolean),
Column("indisclustered", Boolean),
Column("indisvalid", Boolean),
Column("indcheckxmin", Boolean),
Column("indisready", Boolean),
Column("indislive", Boolean, info={"server_version": (9, 3)}), # 9.3
Column("indisreplident", Boolean),
Column("indkey", INT2VECTOR),
Column("indcollation", OIDVECTOR, info={"server_version": (9, 1)}), # 9.1
Column("indclass", OIDVECTOR),
Column("indoption", INT2VECTOR),
Column("indexprs", PG_NODE_TREE),
Column("indpred", PG_NODE_TREE),
schema="pg_catalog",
)
pg_attribute = Table(
"pg_attribute",
pg_catalog_meta,
Column("attrelid", OID),
Column("attname", NAME),
Column("atttypid", OID),
Column("attstattarget", Integer),
Column("attlen", SmallInteger),
Column("attnum", SmallInteger),
Column("attndims", Integer),
Column("attcacheoff", Integer),
Column("atttypmod", Integer),
Column("attbyval", Boolean),
Column("attstorage", CHAR),
Column("attalign", CHAR),
Column("attnotnull", Boolean),
Column("atthasdef", Boolean),
Column("atthasmissing", Boolean, info={"server_version": (11,)}),
Column("attidentity", CHAR, info={"server_version": (10,)}),
Column("attgenerated", CHAR, info={"server_version": (12,)}),
Column("attisdropped", Boolean),
Column("attislocal", Boolean),
Column("attinhcount", Integer),
Column("attcollation", OID, info={"server_version": (9, 1)}),
schema="pg_catalog",
)
pg_constraint = Table(
"pg_constraint",
pg_catalog_meta,
Column("oid", OID), # 9.3
Column("conname", NAME),
Column("connamespace", OID),
Column("contype", CHAR),
Column("condeferrable", Boolean),
Column("condeferred", Boolean),
Column("convalidated", Boolean, info={"server_version": (9, 1)}),
Column("conrelid", OID),
Column("contypid", OID),
Column("conindid", OID),
Column("conparentid", OID, info={"server_version": (11,)}),
Column("confrelid", OID),
Column("confupdtype", CHAR),
Column("confdeltype", CHAR),
Column("confmatchtype", CHAR),
Column("conislocal", Boolean),
Column("coninhcount", Integer),
Column("connoinherit", Boolean, info={"server_version": (9, 2)}),
Column("conkey", ARRAY(SmallInteger)),
Column("confkey", ARRAY(SmallInteger)),
schema="pg_catalog",
)
pg_sequence = Table(
"pg_sequence",
pg_catalog_meta,
Column("seqrelid", OID),
Column("seqtypid", OID),
Column("seqstart", BigInteger),
Column("seqincrement", BigInteger),
Column("seqmax", BigInteger),
Column("seqmin", BigInteger),
Column("seqcache", BigInteger),
Column("seqcycle", Boolean),
schema="pg_catalog",
info={"server_version": (10,)},
)
pg_attrdef = Table(
"pg_attrdef",
pg_catalog_meta,
Column("oid", OID, info={"server_version": (9, 3)}),
Column("adrelid", OID),
Column("adnum", SmallInteger),
Column("adbin", PG_NODE_TREE),
schema="pg_catalog",
)
pg_description = Table(
"pg_description",
pg_catalog_meta,
Column("objoid", OID),
Column("classoid", OID),
Column("objsubid", Integer),
Column("description", Text(collation="C")),
schema="pg_catalog",
)
pg_enum = Table(
"pg_enum",
pg_catalog_meta,
Column("oid", OID, info={"server_version": (9, 3)}),
Column("enumtypid", OID),
Column("enumsortorder", Float(), info={"server_version": (9, 1)}),
Column("enumlabel", NAME),
schema="pg_catalog",
)
pg_am = Table(
"pg_am",
pg_catalog_meta,
Column("oid", OID, info={"server_version": (9, 3)}),
Column("amname", NAME),
Column("amhandler", REGPROC, info={"server_version": (9, 6)}),
Column("amtype", CHAR, info={"server_version": (9, 6)}),
schema="pg_catalog",
)

View file

@@ -0,0 +1,175 @@
# dialects/postgresql/provision.py
# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
# mypy: ignore-errors
import time
from ... import exc
from ... import inspect
from ... import text
from ...testing import warn_test_suite
from ...testing.provision import create_db
from ...testing.provision import drop_all_schema_objects_post_tables
from ...testing.provision import drop_all_schema_objects_pre_tables
from ...testing.provision import drop_db
from ...testing.provision import log
from ...testing.provision import post_configure_engine
from ...testing.provision import prepare_for_drop_tables
from ...testing.provision import set_default_schema_on_connection
from ...testing.provision import temp_table_keyword_args
from ...testing.provision import upsert
@create_db.for_db("postgresql")
def _pg_create_db(cfg, eng, ident):
template_db = cfg.options.postgresql_templatedb
with eng.execution_options(isolation_level="AUTOCOMMIT").begin() as conn:
if not template_db:
template_db = conn.exec_driver_sql(
"select current_database()"
).scalar()
attempt = 0
while True:
try:
conn.exec_driver_sql(
"CREATE DATABASE %s TEMPLATE %s" % (ident, template_db)
)
except exc.OperationalError as err:
attempt += 1
if attempt >= 3:
raise
if "accessed by other users" in str(err):
log.info(
"Waiting to create %s, URI %r, "
"template DB %s is in use sleeping for .5",
ident,
eng.url,
template_db,
)
time.sleep(0.5)
except:
raise
else:
break
@drop_db.for_db("postgresql")
def _pg_drop_db(cfg, eng, ident):
with eng.connect().execution_options(isolation_level="AUTOCOMMIT") as conn:
with conn.begin():
conn.execute(
text(
"select pg_terminate_backend(pid) from pg_stat_activity "
"where usename=current_user and pid != pg_backend_pid() "
"and datname=:dname"
),
dict(dname=ident),
)
conn.exec_driver_sql("DROP DATABASE %s" % ident)
@temp_table_keyword_args.for_db("postgresql")
def _postgresql_temp_table_keyword_args(cfg, eng):
return {"prefixes": ["TEMPORARY"]}
@set_default_schema_on_connection.for_db("postgresql")
def _postgresql_set_default_schema_on_connection(
cfg, dbapi_connection, schema_name
):
existing_autocommit = dbapi_connection.autocommit
dbapi_connection.autocommit = True
cursor = dbapi_connection.cursor()
cursor.execute("SET SESSION search_path='%s'" % schema_name)
cursor.close()
dbapi_connection.autocommit = existing_autocommit
@drop_all_schema_objects_pre_tables.for_db("postgresql")
def drop_all_schema_objects_pre_tables(cfg, eng):
with eng.connect().execution_options(isolation_level="AUTOCOMMIT") as conn:
for xid in conn.exec_driver_sql(
"select gid from pg_prepared_xacts"
).scalars():
conn.execute("ROLLBACK PREPARED '%s'" % xid)
@drop_all_schema_objects_post_tables.for_db("postgresql")
def drop_all_schema_objects_post_tables(cfg, eng):
from sqlalchemy.dialects import postgresql
inspector = inspect(eng)
with eng.begin() as conn:
for enum in inspector.get_enums("*"):
conn.execute(
postgresql.DropEnumType(
postgresql.ENUM(name=enum["name"], schema=enum["schema"])
)
)
@prepare_for_drop_tables.for_db("postgresql")
def prepare_for_drop_tables(config, connection):
"""Ensure there are no locks on the current username/database."""
result = connection.exec_driver_sql(
"select pid, state, wait_event_type, query "
# "select pg_terminate_backend(pid), state, wait_event_type "
"from pg_stat_activity where "
"usename=current_user "
"and datname=current_database() and state='idle in transaction' "
"and pid != pg_backend_pid()"
)
rows = result.all() # noqa
if rows:
warn_test_suite(
"PostgreSQL may not be able to DROP tables due to "
"idle in transaction: %s"
% ("; ".join(row._mapping["query"] for row in rows))
)
@upsert.for_db("postgresql")
def _upsert(
cfg, table, returning, *, set_lambda=None, sort_by_parameter_order=False
):
from sqlalchemy.dialects.postgresql import insert
stmt = insert(table)
table_pk = inspect(table).selectable
if set_lambda:
stmt = stmt.on_conflict_do_update(
index_elements=table_pk.primary_key, set_=set_lambda(stmt.excluded)
)
else:
stmt = stmt.on_conflict_do_nothing()
stmt = stmt.returning(
*returning, sort_by_parameter_order=sort_by_parameter_order
)
return stmt
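# extensions the test suite creates, as (extension name, minimum server version)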
_extensions = [
("citext", (13,)),
("hstore", (13,)),
]
@post_configure_engine.for_db("postgresql")
def _create_citext_extension(url, engine, follower_ident):
with engine.connect() as conn:
for extension, min_version in _extensions:
if conn.dialect.server_version_info >= min_version:
conn.execute(
text(f"CREATE EXTENSION IF NOT EXISTS {extension}")
)
conn.commit()

View file

@@ -0,0 +1,749 @@
# dialects/postgresql/psycopg.py
# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
# mypy: ignore-errors
r"""
.. dialect:: postgresql+psycopg
:name: psycopg (a.k.a. psycopg 3)
:dbapi: psycopg
:connectstring: postgresql+psycopg://user:password@host:port/dbname[?key=value&key=value...]
:url: https://pypi.org/project/psycopg/
``psycopg`` is the package and module name for version 3 of the ``psycopg``
database driver, formerly known as ``psycopg2``. This driver is different
enough from its ``psycopg2`` predecessor that SQLAlchemy supports it
via a totally separate dialect; support for ``psycopg2`` is expected to remain
for as long as that package continues to function for modern Python versions,
and also remains the default dialect for the ``postgresql://`` dialect
series.
The SQLAlchemy ``psycopg`` dialect provides both a sync and an async
implementation under the same dialect name. The proper version is
selected depending on how the engine is created:
* calling :func:`_sa.create_engine` with ``postgresql+psycopg://...`` will
automatically select the sync version, e.g.::
from sqlalchemy import create_engine
sync_engine = create_engine("postgresql+psycopg://scott:tiger@localhost/test")
* calling :func:`_asyncio.create_async_engine` with
``postgresql+psycopg://...`` will automatically select the async version,
e.g.::
from sqlalchemy.ext.asyncio import create_async_engine
asyncio_engine = create_async_engine("postgresql+psycopg://scott:tiger@localhost/test")
The asyncio version of the dialect may also be specified explicitly using the
``psycopg_async`` suffix, as::
from sqlalchemy.ext.asyncio import create_async_engine
asyncio_engine = create_async_engine("postgresql+psycopg_async://scott:tiger@localhost/test")
.. seealso::
:ref:`postgresql_psycopg2` - The SQLAlchemy ``psycopg``
dialect shares most of its behavior with the ``psycopg2`` dialect.
Further documentation is available there.
""" # noqa
from __future__ import annotations
import logging
import re
from typing import cast
from typing import TYPE_CHECKING
from . import ranges
from ._psycopg_common import _PGDialect_common_psycopg
from ._psycopg_common import _PGExecutionContext_common_psycopg
from .base import INTERVAL
from .base import PGCompiler
from .base import PGIdentifierPreparer
from .base import REGCONFIG
from .json import JSON
from .json import JSONB
from .json import JSONPathType
from .types import CITEXT
from ... import pool
from ... import util
from ...engine import AdaptedConnection
from ...sql import sqltypes
from ...util.concurrency import await_fallback
from ...util.concurrency import await_only
if TYPE_CHECKING:
from typing import Iterable
from psycopg import AsyncConnection
logger = logging.getLogger("sqlalchemy.dialects.postgresql")
class _PGString(sqltypes.String):
render_bind_cast = True
class _PGREGCONFIG(REGCONFIG):
render_bind_cast = True
class _PGJSON(JSON):
render_bind_cast = True
def bind_processor(self, dialect):
return self._make_bind_processor(None, dialect._psycopg_Json)
def result_processor(self, dialect, coltype):
return None
class _PGJSONB(JSONB):
render_bind_cast = True
def bind_processor(self, dialect):
return self._make_bind_processor(None, dialect._psycopg_Jsonb)
def result_processor(self, dialect, coltype):
return None
class _PGJSONIntIndexType(sqltypes.JSON.JSONIntIndexType):
__visit_name__ = "json_int_index"
render_bind_cast = True
class _PGJSONStrIndexType(sqltypes.JSON.JSONStrIndexType):
__visit_name__ = "json_str_index"
render_bind_cast = True
class _PGJSONPathType(JSONPathType):
pass
class _PGInterval(INTERVAL):
render_bind_cast = True
class _PGTimeStamp(sqltypes.DateTime):
render_bind_cast = True
class _PGDate(sqltypes.Date):
render_bind_cast = True
class _PGTime(sqltypes.Time):
render_bind_cast = True
class _PGInteger(sqltypes.Integer):
render_bind_cast = True
class _PGSmallInteger(sqltypes.SmallInteger):
render_bind_cast = True
class _PGNullType(sqltypes.NullType):
render_bind_cast = True
class _PGBigInteger(sqltypes.BigInteger):
render_bind_cast = True
class _PGBoolean(sqltypes.Boolean):
render_bind_cast = True
class _PsycopgRange(ranges.AbstractSingleRangeImpl):
def bind_processor(self, dialect):
psycopg_Range = cast(PGDialect_psycopg, dialect)._psycopg_Range
def to_range(value):
if isinstance(value, ranges.Range):
value = psycopg_Range(
value.lower, value.upper, value.bounds, value.empty
)
return value
return to_range
def result_processor(self, dialect, coltype):
def to_range(value):
if value is not None:
value = ranges.Range(
value._lower,
value._upper,
bounds=value._bounds if value._bounds else "[)",
empty=not value._bounds,
)
return value
return to_range
class _PsycopgMultiRange(ranges.AbstractMultiRangeImpl):
def bind_processor(self, dialect):
psycopg_Range = cast(PGDialect_psycopg, dialect)._psycopg_Range
psycopg_Multirange = cast(
PGDialect_psycopg, dialect
)._psycopg_Multirange
NoneType = type(None)
def to_range(value):
if isinstance(value, (str, NoneType, psycopg_Multirange)):
return value
return psycopg_Multirange(
[
psycopg_Range(
element.lower,
element.upper,
element.bounds,
element.empty,
)
for element in cast("Iterable[ranges.Range]", value)
]
)
return to_range
def result_processor(self, dialect, coltype):
def to_range(value):
if value is None:
return None
else:
return ranges.MultiRange(
ranges.Range(
elem._lower,
elem._upper,
bounds=elem._bounds if elem._bounds else "[)",
empty=not elem._bounds,
)
for elem in value
)
return to_range
class PGExecutionContext_psycopg(_PGExecutionContext_common_psycopg):
pass
class PGCompiler_psycopg(PGCompiler):
pass
class PGIdentifierPreparer_psycopg(PGIdentifierPreparer):
pass
def _log_notices(diagnostic):
logger.info("%s: %s", diagnostic.severity, diagnostic.message_primary)
class PGDialect_psycopg(_PGDialect_common_psycopg):
driver = "psycopg"
supports_statement_cache = True
supports_server_side_cursors = True
default_paramstyle = "pyformat"
supports_sane_multi_rowcount = True
execution_ctx_cls = PGExecutionContext_psycopg
statement_compiler = PGCompiler_psycopg
preparer = PGIdentifierPreparer_psycopg
psycopg_version = (0, 0)
_has_native_hstore = True
_psycopg_adapters_map = None
colspecs = util.update_copy(
_PGDialect_common_psycopg.colspecs,
{
sqltypes.String: _PGString,
REGCONFIG: _PGREGCONFIG,
JSON: _PGJSON,
CITEXT: CITEXT,
sqltypes.JSON: _PGJSON,
JSONB: _PGJSONB,
sqltypes.JSON.JSONPathType: _PGJSONPathType,
sqltypes.JSON.JSONIntIndexType: _PGJSONIntIndexType,
sqltypes.JSON.JSONStrIndexType: _PGJSONStrIndexType,
sqltypes.Interval: _PGInterval,
INTERVAL: _PGInterval,
sqltypes.Date: _PGDate,
sqltypes.DateTime: _PGTimeStamp,
sqltypes.Time: _PGTime,
sqltypes.Integer: _PGInteger,
sqltypes.SmallInteger: _PGSmallInteger,
sqltypes.BigInteger: _PGBigInteger,
ranges.AbstractSingleRange: _PsycopgRange,
ranges.AbstractMultiRange: _PsycopgMultiRange,
},
)
def __init__(self, **kwargs):
super().__init__(**kwargs)
if self.dbapi:
m = re.match(r"(\d+)\.(\d+)(?:\.(\d+))?", self.dbapi.__version__)
if m:
self.psycopg_version = tuple(
int(x) for x in m.group(1, 2, 3) if x is not None
)
if self.psycopg_version < (3, 0, 2):
raise ImportError(
"psycopg version 3.0.2 or higher is required."
)
from psycopg.adapt import AdaptersMap
self._psycopg_adapters_map = adapters_map = AdaptersMap(
self.dbapi.adapters
)
if self._native_inet_types is False:
import psycopg.types.string
adapters_map.register_loader(
"inet", psycopg.types.string.TextLoader
)
adapters_map.register_loader(
"cidr", psycopg.types.string.TextLoader
)
if self._json_deserializer:
from psycopg.types.json import set_json_loads
set_json_loads(self._json_deserializer, adapters_map)
if self._json_serializer:
from psycopg.types.json import set_json_dumps
set_json_dumps(self._json_serializer, adapters_map)
def create_connect_args(self, url):
# see https://github.com/psycopg/psycopg/issues/83
cargs, cparams = super().create_connect_args(url)
if self._psycopg_adapters_map:
cparams["context"] = self._psycopg_adapters_map
if self.client_encoding is not None:
cparams["client_encoding"] = self.client_encoding
return cargs, cparams
def _type_info_fetch(self, connection, name):
from psycopg.types import TypeInfo
return TypeInfo.fetch(connection.connection.driver_connection, name)
def initialize(self, connection):
super().initialize(connection)
# PGDialect.initialize() checks server version for <= 8.2 and sets
# this flag to False if so
if not self.insert_returning:
self.insert_executemany_returning = False
# HSTORE can't be registered until we have a connection so that
# we can look up its OID, so we set up this adapter in
# initialize()
if self.use_native_hstore:
info = self._type_info_fetch(connection, "hstore")
self._has_native_hstore = info is not None
if self._has_native_hstore:
from psycopg.types.hstore import register_hstore
# register the adapter for connections made subsequent to
# this one
register_hstore(info, self._psycopg_adapters_map)
# register the adapter for this connection
register_hstore(info, connection.connection)
@classmethod
def import_dbapi(cls):
import psycopg
return psycopg
@classmethod
def get_async_dialect_cls(cls, url):
return PGDialectAsync_psycopg
@util.memoized_property
def _isolation_lookup(self):
return {
"READ COMMITTED": self.dbapi.IsolationLevel.READ_COMMITTED,
"READ UNCOMMITTED": self.dbapi.IsolationLevel.READ_UNCOMMITTED,
"REPEATABLE READ": self.dbapi.IsolationLevel.REPEATABLE_READ,
"SERIALIZABLE": self.dbapi.IsolationLevel.SERIALIZABLE,
}
@util.memoized_property
def _psycopg_Json(self):
from psycopg.types import json
return json.Json
@util.memoized_property
def _psycopg_Jsonb(self):
from psycopg.types import json
return json.Jsonb
@util.memoized_property
def _psycopg_TransactionStatus(self):
from psycopg.pq import TransactionStatus
return TransactionStatus
@util.memoized_property
def _psycopg_Range(self):
from psycopg.types.range import Range
return Range
@util.memoized_property
def _psycopg_Multirange(self):
from psycopg.types.multirange import Multirange
return Multirange
def _do_isolation_level(self, connection, autocommit, isolation_level):
connection.autocommit = autocommit
connection.isolation_level = isolation_level
def get_isolation_level(self, dbapi_connection):
status_before = dbapi_connection.info.transaction_status
value = super().get_isolation_level(dbapi_connection)
# don't rely on psycopg providing enum symbols, compare with
# eq/ne
if status_before == self._psycopg_TransactionStatus.IDLE:
dbapi_connection.rollback()
return value
def set_isolation_level(self, dbapi_connection, level):
if level == "AUTOCOMMIT":
self._do_isolation_level(
dbapi_connection, autocommit=True, isolation_level=None
)
else:
self._do_isolation_level(
dbapi_connection,
autocommit=False,
isolation_level=self._isolation_lookup[level],
)
def set_readonly(self, connection, value):
connection.read_only = value
def get_readonly(self, connection):
return connection.read_only
def on_connect(self):
def notices(conn):
conn.add_notice_handler(_log_notices)
fns = [notices]
if self.isolation_level is not None:
def on_connect(conn):
self.set_isolation_level(conn, self.isolation_level)
fns.append(on_connect)
# fns always has the notices function
def on_connect(conn):
for fn in fns:
fn(conn)
return on_connect
def is_disconnect(self, e, connection, cursor):
if isinstance(e, self.dbapi.Error) and connection is not None:
if connection.closed or connection.broken:
return True
return False
def _do_prepared_twophase(self, connection, command, recover=False):
dbapi_conn = connection.connection.dbapi_connection
if (
recover
# don't rely on psycopg providing enum symbols, compare with
# eq/ne
or dbapi_conn.info.transaction_status
!= self._psycopg_TransactionStatus.IDLE
):
dbapi_conn.rollback()
before_autocommit = dbapi_conn.autocommit
try:
if not before_autocommit:
self._do_autocommit(dbapi_conn, True)
dbapi_conn.execute(command)
finally:
if not before_autocommit:
self._do_autocommit(dbapi_conn, before_autocommit)
def do_rollback_twophase(
self, connection, xid, is_prepared=True, recover=False
):
if is_prepared:
self._do_prepared_twophase(
connection, f"ROLLBACK PREPARED '{xid}'", recover=recover
)
else:
self.do_rollback(connection.connection)
def do_commit_twophase(
self, connection, xid, is_prepared=True, recover=False
):
if is_prepared:
self._do_prepared_twophase(
connection, f"COMMIT PREPARED '{xid}'", recover=recover
)
else:
self.do_commit(connection.connection)
@util.memoized_property
def _dialect_specific_select_one(self):
return ";"
class AsyncAdapt_psycopg_cursor:
__slots__ = ("_cursor", "await_", "_rows")
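    # populated with psycopg.pq.ExecStatus by PGDialectAsync_psycopg.import_dbapi()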
_psycopg_ExecStatus = None
def __init__(self, cursor, await_) -> None:
self._cursor = cursor
self.await_ = await_
self._rows = []
def __getattr__(self, name):
return getattr(self._cursor, name)
@property
def arraysize(self):
return self._cursor.arraysize
@arraysize.setter
def arraysize(self, value):
self._cursor.arraysize = value
def close(self):
self._rows.clear()
        # the normal cursor just calls the internal _close(), which needs no await
self._cursor._close()
def execute(self, query, params=None, **kw):
result = self.await_(self._cursor.execute(query, params, **kw))
        # the SQLAlchemy result is not async, so we need to pull all rows here
res = self._cursor.pgresult
# don't rely on psycopg providing enum symbols, compare with
# eq/ne
if res and res.status == self._psycopg_ExecStatus.TUPLES_OK:
rows = self.await_(self._cursor.fetchall())
if not isinstance(rows, list):
self._rows = list(rows)
else:
self._rows = rows
return result
def executemany(self, query, params_seq):
return self.await_(self._cursor.executemany(query, params_seq))
def __iter__(self):
# TODO: try to avoid pop(0) on a list
while self._rows:
yield self._rows.pop(0)
def fetchone(self):
if self._rows:
# TODO: try to avoid pop(0) on a list
return self._rows.pop(0)
else:
return None
def fetchmany(self, size=None):
if size is None:
size = self._cursor.arraysize
retval = self._rows[0:size]
self._rows = self._rows[size:]
return retval
def fetchall(self):
retval = self._rows
self._rows = []
return retval
class AsyncAdapt_psycopg_ss_cursor(AsyncAdapt_psycopg_cursor):
def execute(self, query, params=None, **kw):
self.await_(self._cursor.execute(query, params, **kw))
return self
def close(self):
self.await_(self._cursor.close())
def fetchone(self):
return self.await_(self._cursor.fetchone())
def fetchmany(self, size=0):
return self.await_(self._cursor.fetchmany(size))
def fetchall(self):
return self.await_(self._cursor.fetchall())
def __iter__(self):
iterator = self._cursor.__aiter__()
while True:
try:
yield self.await_(iterator.__anext__())
except StopAsyncIteration:
break
class AsyncAdapt_psycopg_connection(AdaptedConnection):
_connection: AsyncConnection
__slots__ = ()
await_ = staticmethod(await_only)
def __init__(self, connection) -> None:
self._connection = connection
def __getattr__(self, name):
return getattr(self._connection, name)
def execute(self, query, params=None, **kw):
cursor = self.await_(self._connection.execute(query, params, **kw))
return AsyncAdapt_psycopg_cursor(cursor, self.await_)
def cursor(self, *args, **kw):
cursor = self._connection.cursor(*args, **kw)
if hasattr(cursor, "name"):
return AsyncAdapt_psycopg_ss_cursor(cursor, self.await_)
else:
return AsyncAdapt_psycopg_cursor(cursor, self.await_)
def commit(self):
self.await_(self._connection.commit())
def rollback(self):
self.await_(self._connection.rollback())
def close(self):
self.await_(self._connection.close())
@property
def autocommit(self):
return self._connection.autocommit
@autocommit.setter
def autocommit(self, value):
self.set_autocommit(value)
def set_autocommit(self, value):
self.await_(self._connection.set_autocommit(value))
def set_isolation_level(self, value):
self.await_(self._connection.set_isolation_level(value))
def set_read_only(self, value):
self.await_(self._connection.set_read_only(value))
def set_deferrable(self, value):
self.await_(self._connection.set_deferrable(value))
class AsyncAdaptFallback_psycopg_connection(AsyncAdapt_psycopg_connection):
__slots__ = ()
await_ = staticmethod(await_fallback)
class PsycopgAdaptDBAPI:
def __init__(self, psycopg) -> None:
self.psycopg = psycopg
for k, v in self.psycopg.__dict__.items():
if k != "connect":
self.__dict__[k] = v
def connect(self, *arg, **kw):
async_fallback = kw.pop("async_fallback", False)
creator_fn = kw.pop(
"async_creator_fn", self.psycopg.AsyncConnection.connect
)
if util.asbool(async_fallback):
return AsyncAdaptFallback_psycopg_connection(
await_fallback(creator_fn(*arg, **kw))
)
else:
return AsyncAdapt_psycopg_connection(
await_only(creator_fn(*arg, **kw))
)
class PGDialectAsync_psycopg(PGDialect_psycopg):
is_async = True
supports_statement_cache = True
@classmethod
def import_dbapi(cls):
import psycopg
from psycopg.pq import ExecStatus
AsyncAdapt_psycopg_cursor._psycopg_ExecStatus = ExecStatus
return PsycopgAdaptDBAPI(psycopg)
@classmethod
def get_pool_class(cls, url):
async_fallback = url.query.get("async_fallback", False)
if util.asbool(async_fallback):
return pool.FallbackAsyncAdaptedQueuePool
else:
return pool.AsyncAdaptedQueuePool
def _type_info_fetch(self, connection, name):
from psycopg.types import TypeInfo
adapted = connection.connection
return adapted.await_(TypeInfo.fetch(adapted.driver_connection, name))
def _do_isolation_level(self, connection, autocommit, isolation_level):
connection.set_autocommit(autocommit)
connection.set_isolation_level(isolation_level)
def _do_autocommit(self, connection, value):
connection.set_autocommit(value)
def set_readonly(self, connection, value):
connection.set_read_only(value)
def set_deferrable(self, connection, value):
connection.set_deferrable(value)
def get_driver_connection(self, connection):
return connection._connection
dialect = PGDialect_psycopg
dialect_async = PGDialectAsync_psycopg

View file

@@ -0,0 +1,876 @@
# dialects/postgresql/psycopg2.py
# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
# mypy: ignore-errors
r"""
.. dialect:: postgresql+psycopg2
:name: psycopg2
:dbapi: psycopg2
:connectstring: postgresql+psycopg2://user:password@host:port/dbname[?key=value&key=value...]
:url: https://pypi.org/project/psycopg2/
.. _psycopg2_toplevel:
psycopg2 Connect Arguments
--------------------------
Keyword arguments that are specific to the SQLAlchemy psycopg2 dialect
may be passed to :func:`_sa.create_engine()`, and include the following:
* ``isolation_level``: This option, available for all PostgreSQL dialects,
includes the ``AUTOCOMMIT`` isolation level when using the psycopg2
dialect. This option sets the **default** isolation level for the
connection that is set immediately upon connection to the database before
the connection is pooled. This option is generally superseded by the more
modern :paramref:`_engine.Connection.execution_options.isolation_level`
execution option, detailed at :ref:`dbapi_autocommit`.
.. seealso::
:ref:`psycopg2_isolation_level`
:ref:`dbapi_autocommit`
* ``client_encoding``: sets the client encoding in a libpq-agnostic way,
using psycopg2's ``set_client_encoding()`` method.
.. seealso::
:ref:`psycopg2_unicode`
* ``executemany_mode``, ``executemany_batch_page_size``,
``executemany_values_page_size``: Allows use of psycopg2
extensions for optimizing "executemany"-style queries. See the referenced
section below for details.
.. seealso::
:ref:`psycopg2_executemany_mode`
.. tip::
The above keyword arguments are **dialect** keyword arguments, meaning
that they are passed as explicit keyword arguments to :func:`_sa.create_engine()`::
engine = create_engine(
"postgresql+psycopg2://scott:tiger@localhost/test",
isolation_level="SERIALIZABLE",
)
These should not be confused with **DBAPI** connect arguments, which
are passed as part of the :paramref:`_sa.create_engine.connect_args`
dictionary and/or are passed in the URL query string, as detailed in
the section :ref:`custom_dbapi_args`.
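For illustration only (this example is not part of the original documentation;
``connect_timeout`` is simply one libpq-level argument), a DBAPI connect
argument would instead be passed through ``connect_args``::

    engine = create_engine(
        "postgresql+psycopg2://scott:tiger@localhost/test",
        connect_args={"connect_timeout": 10},
    )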
.. _psycopg2_ssl:
SSL Connections
---------------
The psycopg2 module has a connection argument named ``sslmode`` for
controlling its behavior regarding secure (SSL) connections. The default is
``sslmode=prefer``; it will attempt an SSL connection and if that fails it
will fall back to an unencrypted connection. ``sslmode=require`` may be used
to ensure that only secure connections are established. Consult the
psycopg2 / libpq documentation for further options that are available.
Note that ``sslmode`` is specific to psycopg2 so it is included in the
connection URI::
engine = sa.create_engine(
"postgresql+psycopg2://scott:tiger@192.168.0.199:5432/test?sslmode=require"
)
Unix Domain Connections
------------------------
psycopg2 supports connecting via Unix domain connections. When the ``host``
portion of the URL is omitted, SQLAlchemy passes ``None`` to psycopg2,
which specifies Unix-domain communication rather than TCP/IP communication::
create_engine("postgresql+psycopg2://user:password@/dbname")
By default, the connection is made to a Unix-domain socket
in ``/tmp``, or whatever socket directory was specified when PostgreSQL
was built. This value can be overridden by passing a pathname to psycopg2,
using ``host`` as an additional keyword argument::
create_engine("postgresql+psycopg2://user:password@/dbname?host=/var/lib/postgresql")
.. warning:: The format accepted here allows for a hostname in the main URL
in addition to the "host" query string argument. **When using this URL
format, the initial host is silently ignored**. That is, this URL::
engine = create_engine("postgresql+psycopg2://user:password@myhost1/dbname?host=myhost2")
Above, the hostname ``myhost1`` is **silently ignored and discarded**; the
connection is instead made to the ``myhost2`` host.
This is to maintain some degree of compatibility with PostgreSQL's own URL
format, which has been tested to behave the same way, and for which tools like
PifPaf hardcode two hostnames.
.. seealso::
`PQconnectdbParams \
<https://www.postgresql.org/docs/current/static/libpq-connect.html#LIBPQ-PQCONNECTDBPARAMS>`_
.. _psycopg2_multi_host:
Specifying multiple fallback hosts
-----------------------------------
psycopg2 supports multiple connection points in the connection string.
When the ``host`` parameter is used multiple times in the query section of
the URL, SQLAlchemy will create a single string of the host and port
information provided to make the connections. Tokens may consist of
``host:port`` or just ``host``; in the latter case, the default port
is selected by libpq. In the example below, three host connections
are specified, for ``HostA:PortA``, ``HostB`` connecting to the default port,
and ``HostC:PortC``::
create_engine(
"postgresql+psycopg2://user:password@/dbname?host=HostA:PortA&host=HostB&host=HostC:PortC"
)
As an alternative, libpq query string format also may be used; this specifies
``host`` and ``port`` as single query string arguments with comma-separated
lists - the default port can be chosen by indicating an empty value
in the comma separated list::
create_engine(
"postgresql+psycopg2://user:password@/dbname?host=HostA,HostB,HostC&port=PortA,,PortC"
)
With either URL style, a connection to each host is attempted based on a
configurable strategy, which may be set using the libpq
``target_session_attrs`` parameter. Per libpq this defaults to ``any``,
which indicates that each host is tried in turn until a connection is successful.
Other strategies include ``primary``, ``prefer-standby``, etc. The complete
list is documented by PostgreSQL at
`libpq connection strings <https://www.postgresql.org/docs/current/libpq-connect.html#LIBPQ-CONNSTRING>`_.
For example, to indicate two hosts using the ``primary`` strategy::
create_engine(
"postgresql+psycopg2://user:password@/dbname?host=HostA:PortA&host=HostB&host=HostC:PortC&target_session_attrs=primary"
)
.. versionchanged:: 1.4.40 Port specification in psycopg2 multiple host format
is repaired; previously, ports were not correctly interpreted in this context.
libpq comma-separated format is also now supported.
.. versionadded:: 1.3.20 Support for multiple hosts in PostgreSQL connection
string.
.. seealso::
`libpq connection strings <https://www.postgresql.org/docs/current/libpq-connect.html#LIBPQ-CONNSTRING>`_ - please refer
to this section in the libpq documentation for complete background on multiple host support.
Empty DSN Connections / Environment Variable Connections
---------------------------------------------------------
The psycopg2 DBAPI can connect to PostgreSQL by passing an empty DSN to the
libpq client library, which by default indicates to connect to a localhost
PostgreSQL database that is open for "trust" connections. This behavior can be
further tailored using a particular set of environment variables which are
prefixed with ``PG...`` (such as ``PGHOST`` and ``PGUSER``), which are consumed
by ``libpq`` to take the place of any or all elements of the connection string.
For this form, the URL can be passed without any elements other than the
initial scheme::
engine = create_engine('postgresql+psycopg2://')
In the above form, a blank "dsn" string is passed to the ``psycopg2.connect()``
function which in turn represents an empty DSN passed to libpq.
.. versionadded:: 1.3.2 support for parameter-less connections with psycopg2.
.. seealso::
`Environment Variables\
<https://www.postgresql.org/docs/current/libpq-envars.html>`_ -
PostgreSQL documentation on how to use ``PG...``
environment variables for connections.
.. _psycopg2_execution_options:
Per-Statement/Connection Execution Options
-------------------------------------------
The following DBAPI-specific options are respected when used with
:meth:`_engine.Connection.execution_options`,
:meth:`.Executable.execution_options`,
:meth:`_query.Query.execution_options`,
in addition to those not specific to DBAPIs:
* ``isolation_level`` - Set the transaction isolation level for the lifespan
of a :class:`_engine.Connection` (can only be set on a connection,
not a statement
or query). See :ref:`psycopg2_isolation_level`.
* ``stream_results`` - Enable or disable usage of psycopg2 server side
cursors - this feature makes use of "named" cursors in combination with
special result handling methods so that result rows are not fully buffered.
Defaults to False, meaning cursors are buffered by default.
* ``max_row_buffer`` - when using ``stream_results``, an integer value that
specifies the maximum number of rows to buffer at a time. This is
interpreted by the :class:`.BufferedRowCursorResult`, and if omitted the
buffer will grow to ultimately store 1000 rows at a time.
.. versionchanged:: 1.4 The ``max_row_buffer`` size can now be greater than
1000, and the buffer will grow to that size.
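As a brief illustrative sketch (not part of the original documentation;
``large_table`` is a placeholder and ``text`` is assumed to be imported from
``sqlalchemy``), ``stream_results`` and ``max_row_buffer`` may be combined on a
single :class:`_engine.Connection`::

    with engine.connect().execution_options(
        stream_results=True, max_row_buffer=500
    ) as conn:
        for row in conn.execute(text("SELECT * FROM large_table")):
            ...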
.. _psycopg2_batch_mode:
.. _psycopg2_executemany_mode:
Psycopg2 Fast Execution Helpers
-------------------------------
Modern versions of psycopg2 include a feature known as
`Fast Execution Helpers \
<https://initd.org/psycopg/docs/extras.html#fast-execution-helpers>`_, which
have been shown in benchmarking to improve psycopg2's executemany()
performance, primarily with INSERT statements, by at least
an order of magnitude.
SQLAlchemy implements a native form of the "insert many values"
handler that will rewrite a single-row INSERT statement to accommodate
many values at once within an extended VALUES clause; this handler is
equivalent to psycopg2's ``execute_values()`` handler; an overview of this
feature and its configuration are at :ref:`engine_insertmanyvalues`.
.. versionadded:: 2.0 Replaced psycopg2's ``execute_values()`` fast execution
helper with a native SQLAlchemy mechanism known as
:ref:`insertmanyvalues <engine_insertmanyvalues>`.
The psycopg2 dialect retains the ability to use the psycopg2-specific
``execute_batch()`` feature, although it is not expected that this is a widely
used feature. The use of this extension may be enabled using the
``executemany_mode`` flag which may be passed to :func:`_sa.create_engine`::
engine = create_engine(
"postgresql+psycopg2://scott:tiger@host/dbname",
executemany_mode='values_plus_batch')
Possible options for ``executemany_mode`` include:
* ``values_only`` - this is the default value. SQLAlchemy's native
:ref:`insertmanyvalues <engine_insertmanyvalues>` handler is used for qualifying
INSERT statements, assuming
:paramref:`_sa.create_engine.use_insertmanyvalues` is left at
its default value of ``True``. This handler rewrites simple
INSERT statements to include multiple VALUES clauses so that many
parameter sets can be inserted with one statement.
* ``values_plus_batch`` - SQLAlchemy's native
:ref:`insertmanyvalues <engine_insertmanyvalues>` handler is used for qualifying
INSERT statements, assuming
:paramref:`_sa.create_engine.use_insertmanyvalues` is left at its default
value of ``True``. Then, psycopg2's ``execute_batch()`` handler is used for
qualifying UPDATE and DELETE statements when executed with multiple parameter
sets. When using this mode, the :attr:`_engine.CursorResult.rowcount`
attribute will not contain a value for executemany-style executions against
UPDATE and DELETE statements.
.. versionchanged:: 2.0 Removed the ``'batch'`` and ``'None'`` options
from psycopg2 ``executemany_mode``. Control over batching for INSERT
statements is now configured via the
:paramref:`_sa.create_engine.use_insertmanyvalues` engine-level parameter.
The term "qualifying statements" refers to the statement being executed
being a Core :func:`_expression.insert`, :func:`_expression.update`
or :func:`_expression.delete` construct, and **not** a plain textual SQL
string or one constructed using :func:`_expression.text`. It also may **not** be
a special "extension" statement such as an "ON CONFLICT" "upsert" statement.
When using the ORM, all insert/update/delete statements used by the ORM flush process
are qualifying.
The "page size" for the psycopg2 "batch" strategy can be affected
by using the ``executemany_batch_page_size`` parameter, which defaults to
100.
For the "insertmanyvalues" feature, the page size can be controlled using the
:paramref:`_sa.create_engine.insertmanyvalues_page_size` parameter,
which defaults to 1000. An example of modifying both parameters
is below::
engine = create_engine(
"postgresql+psycopg2://scott:tiger@host/dbname",
executemany_mode='values_plus_batch',
insertmanyvalues_page_size=5000, executemany_batch_page_size=500)
.. seealso::
:ref:`engine_insertmanyvalues` - background on "insertmanyvalues"
:ref:`tutorial_multiple_parameters` - General information on using the
:class:`_engine.Connection`
object to execute statements in such a way as to make
use of the DBAPI ``.executemany()`` method.
.. _psycopg2_unicode:
Unicode with Psycopg2
----------------------
The psycopg2 DBAPI driver supports Unicode data transparently.
The client character encoding can be controlled for the psycopg2 dialect
in the following ways:
* For PostgreSQL 9.1 and above, the ``client_encoding`` parameter may be
passed in the database URL; this parameter is consumed by the underlying
``libpq`` PostgreSQL client library::
engine = create_engine("postgresql+psycopg2://user:pass@host/dbname?client_encoding=utf8")
Alternatively, the above ``client_encoding`` value may be passed using
:paramref:`_sa.create_engine.connect_args` for programmatic establishment with
``libpq``::
engine = create_engine(
"postgresql+psycopg2://user:pass@host/dbname",
connect_args={'client_encoding': 'utf8'}
)
* For all PostgreSQL versions, psycopg2 supports a client-side encoding
value that will be passed to database connections when they are first
established. The SQLAlchemy psycopg2 dialect supports this using the
``client_encoding`` parameter passed to :func:`_sa.create_engine`::
engine = create_engine(
"postgresql+psycopg2://user:pass@host/dbname",
client_encoding="utf8"
)
.. tip:: The above ``client_encoding`` parameter admittedly is very similar
in appearance to usage of the parameter within the
:paramref:`_sa.create_engine.connect_args` dictionary; the difference
above is that the parameter is consumed by psycopg2 and is
passed to the database connection using ``SET client_encoding TO
'utf8'``; in the previously mentioned style, the parameter is instead
passed through psycopg2 and consumed by the ``libpq`` library.
* A common way to set up client encoding with PostgreSQL databases is to
ensure it is configured within the server-side postgresql.conf file;
this is the recommended way to set encoding for a server that is
consistently of one encoding in all databases::
# postgresql.conf file
# client_encoding = sql_ascii # actually, defaults to database
# encoding
client_encoding = utf8
Transactions
------------
The psycopg2 dialect fully supports SAVEPOINT and two-phase commit operations.
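As a minimal sketch (not part of the original documentation; ``some_table`` is
a placeholder :class:`_schema.Table`), SAVEPOINTs are driven through the
standard :meth:`_engine.Connection.begin_nested` API::

    with engine.begin() as conn:
        savepoint = conn.begin_nested()
        try:
            conn.execute(some_table.insert(), {"x": 2})
            savepoint.commit()
        except Exception:
            savepoint.rollback()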
.. _psycopg2_isolation_level:
Psycopg2 Transaction Isolation Level
-------------------------------------
As discussed in :ref:`postgresql_isolation_level`,
all PostgreSQL dialects support setting of transaction isolation level
both via the ``isolation_level`` parameter passed to :func:`_sa.create_engine`,
as well as the ``isolation_level`` argument used by
:meth:`_engine.Connection.execution_options`. When using the psycopg2 dialect,
these options make use of psycopg2's ``set_isolation_level()`` connection
method, rather than emitting a PostgreSQL directive; this is because psycopg2's
API-level setting is always emitted at the start of each transaction in any
case.
The psycopg2 dialect supports these constants for isolation level:
* ``READ COMMITTED``
* ``READ UNCOMMITTED``
* ``REPEATABLE READ``
* ``SERIALIZABLE``
* ``AUTOCOMMIT``
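As a short sketch (not part of the original documentation), an isolation level
from the list above may be applied engine-wide via :func:`_sa.create_engine`,
or per connection via :meth:`_engine.Connection.execution_options`::

    engine = create_engine(
        "postgresql+psycopg2://scott:tiger@localhost/test",
        isolation_level="REPEATABLE READ",
    )

    # per-connection override, e.g. for statements that require autocommit
    with engine.connect().execution_options(
        isolation_level="AUTOCOMMIT"
    ) as conn:
        conn.exec_driver_sql("VACUUM ANALYZE")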
.. seealso::
:ref:`postgresql_isolation_level`
:ref:`pg8000_isolation_level`
NOTICE logging
---------------
The psycopg2 dialect will log PostgreSQL NOTICE messages
via the ``sqlalchemy.dialects.postgresql`` logger. When this logger
is set to the ``logging.INFO`` level, notice messages will be logged::
import logging
logging.getLogger('sqlalchemy.dialects.postgresql').setLevel(logging.INFO)
Above, it is assumed that logging is configured externally. If this is not
the case, configuration such as ``logging.basicConfig()`` must be utilized::
import logging
logging.basicConfig() # log messages to stdout
logging.getLogger('sqlalchemy.dialects.postgresql').setLevel(logging.INFO)
.. seealso::
`Logging HOWTO <https://docs.python.org/3/howto/logging.html>`_ - on the python.org website
.. _psycopg2_hstore:
HSTORE type
------------
The ``psycopg2`` DBAPI includes an extension to natively handle marshalling of
the HSTORE type. The SQLAlchemy psycopg2 dialect will enable this extension
by default when psycopg2 version 2.4 or greater is used, and
it is detected that the target database has the HSTORE type set up for use.
In other words, when the dialect makes the first
connection, a sequence like the following is performed:
1. Request the available HSTORE oids using
``psycopg2.extras.HstoreAdapter.get_oids()``.
If this function returns a list of HSTORE identifiers, we then determine
that the ``HSTORE`` extension is present.
This function is **skipped** if the version of psycopg2 installed is
less than version 2.4.
2. If the ``use_native_hstore`` flag is at its default of ``True``, and
we've detected that ``HSTORE`` oids are available, the
``psycopg2.extensions.register_hstore()`` extension is invoked for all
connections.
The ``register_hstore()`` extension has the effect of **all Python
dictionaries being accepted as parameters regardless of the type of target
column in SQL**. The dictionaries are converted by this extension into a
textual HSTORE expression. If this behavior is not desired, disable the
use of the hstore extension by setting ``use_native_hstore`` to ``False`` as
follows::
engine = create_engine("postgresql+psycopg2://scott:tiger@localhost/test",
use_native_hstore=False)
The ``HSTORE`` type is **still supported** when the
``psycopg2.extensions.register_hstore()`` extension is not used. It merely
means that the coercion between Python dictionaries and the HSTORE
string format, on both the parameter side and the result side, will take
place within SQLAlchemy's own marshalling logic, and not that of ``psycopg2``
which may be more performant.
""" # noqa
from __future__ import annotations
import collections.abc as collections_abc
import logging
import re
from typing import cast
from . import ranges
from ._psycopg_common import _PGDialect_common_psycopg
from ._psycopg_common import _PGExecutionContext_common_psycopg
from .base import PGIdentifierPreparer
from .json import JSON
from .json import JSONB
from ... import types as sqltypes
from ... import util
from ...util import FastIntFlag
from ...util import parse_user_argument_for_enum
logger = logging.getLogger("sqlalchemy.dialects.postgresql")
class _PGJSON(JSON):
def result_processor(self, dialect, coltype):
return None
class _PGJSONB(JSONB):
def result_processor(self, dialect, coltype):
return None
class _Psycopg2Range(ranges.AbstractSingleRangeImpl):
_psycopg2_range_cls = "none"
def bind_processor(self, dialect):
psycopg2_Range = getattr(
cast(PGDialect_psycopg2, dialect)._psycopg2_extras,
self._psycopg2_range_cls,
)
def to_range(value):
if isinstance(value, ranges.Range):
value = psycopg2_Range(
value.lower, value.upper, value.bounds, value.empty
)
return value
return to_range
def result_processor(self, dialect, coltype):
def to_range(value):
if value is not None:
value = ranges.Range(
value._lower,
value._upper,
bounds=value._bounds if value._bounds else "[)",
empty=not value._bounds,
)
return value
return to_range
class _Psycopg2NumericRange(_Psycopg2Range):
_psycopg2_range_cls = "NumericRange"
class _Psycopg2DateRange(_Psycopg2Range):
_psycopg2_range_cls = "DateRange"
class _Psycopg2DateTimeRange(_Psycopg2Range):
_psycopg2_range_cls = "DateTimeRange"
class _Psycopg2DateTimeTZRange(_Psycopg2Range):
_psycopg2_range_cls = "DateTimeTZRange"
class PGExecutionContext_psycopg2(_PGExecutionContext_common_psycopg):
_psycopg2_fetched_rows = None
def post_exec(self):
self._log_notices(self.cursor)
def _log_notices(self, cursor):
# check also that notices is an iterable, after it's already
# established that we will be iterating through it. This is to get
# around test suites such as SQLAlchemy's using a Mock object for
# cursor
if not cursor.connection.notices or not isinstance(
cursor.connection.notices, collections_abc.Iterable
):
return
for notice in cursor.connection.notices:
# NOTICE messages have a
# newline character at the end
logger.info(notice.rstrip())
cursor.connection.notices[:] = []
class PGIdentifierPreparer_psycopg2(PGIdentifierPreparer):
pass
class ExecutemanyMode(FastIntFlag):
EXECUTEMANY_VALUES = 0
EXECUTEMANY_VALUES_PLUS_BATCH = 1
(
EXECUTEMANY_VALUES,
EXECUTEMANY_VALUES_PLUS_BATCH,
) = ExecutemanyMode.__members__.values()
class PGDialect_psycopg2(_PGDialect_common_psycopg):
driver = "psycopg2"
supports_statement_cache = True
supports_server_side_cursors = True
default_paramstyle = "pyformat"
# set to true based on psycopg2 version
supports_sane_multi_rowcount = False
execution_ctx_cls = PGExecutionContext_psycopg2
preparer = PGIdentifierPreparer_psycopg2
psycopg2_version = (0, 0)
use_insertmanyvalues_wo_returning = True
returns_native_bytes = False
_has_native_hstore = True
colspecs = util.update_copy(
_PGDialect_common_psycopg.colspecs,
{
JSON: _PGJSON,
sqltypes.JSON: _PGJSON,
JSONB: _PGJSONB,
ranges.INT4RANGE: _Psycopg2NumericRange,
ranges.INT8RANGE: _Psycopg2NumericRange,
ranges.NUMRANGE: _Psycopg2NumericRange,
ranges.DATERANGE: _Psycopg2DateRange,
ranges.TSRANGE: _Psycopg2DateTimeRange,
ranges.TSTZRANGE: _Psycopg2DateTimeTZRange,
},
)
def __init__(
self,
executemany_mode="values_only",
executemany_batch_page_size=100,
**kwargs,
):
_PGDialect_common_psycopg.__init__(self, **kwargs)
if self._native_inet_types:
raise NotImplementedError(
"The psycopg2 dialect does not implement "
"ipaddress type handling; native_inet_types cannot be set "
"to ``True`` when using this dialect."
)
# Parse executemany_mode argument, allowing it to be only one of the
# symbol names
self.executemany_mode = parse_user_argument_for_enum(
executemany_mode,
{
EXECUTEMANY_VALUES: ["values_only"],
EXECUTEMANY_VALUES_PLUS_BATCH: ["values_plus_batch"],
},
"executemany_mode",
)
self.executemany_batch_page_size = executemany_batch_page_size
if self.dbapi and hasattr(self.dbapi, "__version__"):
m = re.match(r"(\d+)\.(\d+)(?:\.(\d+))?", self.dbapi.__version__)
if m:
self.psycopg2_version = tuple(
int(x) for x in m.group(1, 2, 3) if x is not None
)
if self.psycopg2_version < (2, 7):
raise ImportError(
"psycopg2 version 2.7 or higher is required."
)
def initialize(self, connection):
super().initialize(connection)
self._has_native_hstore = (
self.use_native_hstore
and self._hstore_oids(connection.connection.dbapi_connection)
is not None
)
self.supports_sane_multi_rowcount = (
self.executemany_mode is not EXECUTEMANY_VALUES_PLUS_BATCH
)
@classmethod
def import_dbapi(cls):
import psycopg2
return psycopg2
@util.memoized_property
def _psycopg2_extensions(cls):
from psycopg2 import extensions
return extensions
@util.memoized_property
def _psycopg2_extras(cls):
from psycopg2 import extras
return extras
@util.memoized_property
def _isolation_lookup(self):
extensions = self._psycopg2_extensions
return {
"AUTOCOMMIT": extensions.ISOLATION_LEVEL_AUTOCOMMIT,
"READ COMMITTED": extensions.ISOLATION_LEVEL_READ_COMMITTED,
"READ UNCOMMITTED": extensions.ISOLATION_LEVEL_READ_UNCOMMITTED,
"REPEATABLE READ": extensions.ISOLATION_LEVEL_REPEATABLE_READ,
"SERIALIZABLE": extensions.ISOLATION_LEVEL_SERIALIZABLE,
}
def set_isolation_level(self, dbapi_connection, level):
dbapi_connection.set_isolation_level(self._isolation_lookup[level])
def set_readonly(self, connection, value):
connection.readonly = value
def get_readonly(self, connection):
return connection.readonly
def set_deferrable(self, connection, value):
connection.deferrable = value
def get_deferrable(self, connection):
return connection.deferrable
def on_connect(self):
extras = self._psycopg2_extras
fns = []
if self.client_encoding is not None:
def on_connect(dbapi_conn):
dbapi_conn.set_client_encoding(self.client_encoding)
fns.append(on_connect)
if self.dbapi:
def on_connect(dbapi_conn):
extras.register_uuid(None, dbapi_conn)
fns.append(on_connect)
if self.dbapi and self.use_native_hstore:
def on_connect(dbapi_conn):
hstore_oids = self._hstore_oids(dbapi_conn)
if hstore_oids is not None:
oid, array_oid = hstore_oids
kw = {"oid": oid}
kw["array_oid"] = array_oid
extras.register_hstore(dbapi_conn, **kw)
fns.append(on_connect)
if self.dbapi and self._json_deserializer:
def on_connect(dbapi_conn):
extras.register_default_json(
dbapi_conn, loads=self._json_deserializer
)
extras.register_default_jsonb(
dbapi_conn, loads=self._json_deserializer
)
fns.append(on_connect)
if fns:
def on_connect(dbapi_conn):
for fn in fns:
fn(dbapi_conn)
return on_connect
else:
return None
def do_executemany(self, cursor, statement, parameters, context=None):
if self.executemany_mode is EXECUTEMANY_VALUES_PLUS_BATCH:
if self.executemany_batch_page_size:
kwargs = {"page_size": self.executemany_batch_page_size}
else:
kwargs = {}
self._psycopg2_extras.execute_batch(
cursor, statement, parameters, **kwargs
)
else:
cursor.executemany(statement, parameters)
def do_begin_twophase(self, connection, xid):
connection.connection.tpc_begin(xid)
def do_prepare_twophase(self, connection, xid):
connection.connection.tpc_prepare()
def _do_twophase(self, dbapi_conn, operation, xid, recover=False):
if recover:
if dbapi_conn.status != self._psycopg2_extensions.STATUS_READY:
dbapi_conn.rollback()
operation(xid)
else:
operation()
def do_rollback_twophase(
self, connection, xid, is_prepared=True, recover=False
):
dbapi_conn = connection.connection.dbapi_connection
self._do_twophase(
dbapi_conn, dbapi_conn.tpc_rollback, xid, recover=recover
)
def do_commit_twophase(
self, connection, xid, is_prepared=True, recover=False
):
dbapi_conn = connection.connection.dbapi_connection
self._do_twophase(
dbapi_conn, dbapi_conn.tpc_commit, xid, recover=recover
)
@util.memoized_instancemethod
def _hstore_oids(self, dbapi_connection):
extras = self._psycopg2_extras
oids = extras.HstoreAdapter.get_oids(dbapi_connection)
if oids is not None and oids[0]:
return oids[0:2]
else:
return None
def is_disconnect(self, e, connection, cursor):
if isinstance(e, self.dbapi.Error):
# check the "closed" flag. this might not be
# present on old psycopg2 versions. Also,
# this flag doesn't actually help in a lot of disconnect
# situations, so don't rely on it.
if getattr(connection, "closed", False):
return True
# checks based on strings. in the case that .closed
# didn't cut it, fall back onto these.
str_e = str(e).partition("\n")[0]
for msg in [
# these error messages from libpq: interfaces/libpq/fe-misc.c
# and interfaces/libpq/fe-secure.c.
"terminating connection",
"closed the connection",
"connection not open",
"could not receive data from server",
"could not send data to server",
# psycopg2 client errors, psycopg2/connection.h,
# psycopg2/cursor.h
"connection already closed",
"cursor already closed",
# not sure where this path is originally from, it may
# be obsolete. It really says "losed", not "closed".
"losed the connection unexpectedly",
# these can occur in newer SSL
"connection has been closed unexpectedly",
"SSL error: decryption failed or bad record mac",
"SSL SYSCALL error: Bad file descriptor",
"SSL SYSCALL error: EOF detected",
"SSL SYSCALL error: Operation timed out",
"SSL SYSCALL error: Bad address",
]:
idx = str_e.find(msg)
if idx >= 0 and '"' not in str_e[:idx]:
return True
return False
dialect = PGDialect_psycopg2
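# A minimal usage sketch, not part of the dialect itself: the executemany_mode
# and executemany_batch_page_size arguments accepted by
# PGDialect_psycopg2.__init__ above are normally supplied through
# create_engine(); the URL and helper name below are placeholders.
def _example_executemany_engine():
    from sqlalchemy import create_engine

    # "values_plus_batch" additionally routes statements that cannot use the
    # INSERT..VALUES approach through psycopg2's execute_batch() helper,
    # paged by executemany_batch_page_size.
    return create_engine(
        "postgresql+psycopg2://scott:tiger@localhost/test",
        executemany_mode="values_plus_batch",
        executemany_batch_page_size=500,
    )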

View file

@@ -0,0 +1,61 @@
# dialects/postgresql/psycopg2cffi.py
# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
# mypy: ignore-errors
r"""
.. dialect:: postgresql+psycopg2cffi
:name: psycopg2cffi
:dbapi: psycopg2cffi
:connectstring: postgresql+psycopg2cffi://user:password@host:port/dbname[?key=value&key=value...]
:url: https://pypi.org/project/psycopg2cffi/
``psycopg2cffi`` is an adaptation of ``psycopg2``, using CFFI for the C
layer. This makes it suitable for use in e.g. PyPy. Documentation
is as per ``psycopg2``.
.. seealso::
:mod:`sqlalchemy.dialects.postgresql.psycopg2`
""" # noqa
from .psycopg2 import PGDialect_psycopg2
from ... import util
class PGDialect_psycopg2cffi(PGDialect_psycopg2):
driver = "psycopg2cffi"
supports_unicode_statements = True
supports_statement_cache = True
# psycopg2cffi's first release is 2.5.0, but reports
# __version__ as 2.4.4. Subsequent releases seem to have
# fixed this.
FEATURE_VERSION_MAP = dict(
native_json=(2, 4, 4),
native_jsonb=(2, 7, 1),
sane_multi_rowcount=(2, 4, 4),
array_oid=(2, 4, 4),
hstore_adapter=(2, 4, 4),
)
@classmethod
def import_dbapi(cls):
return __import__("psycopg2cffi")
@util.memoized_property
def _psycopg2_extensions(cls):
root = __import__("psycopg2cffi", fromlist=["extensions"])
return root.extensions
@util.memoized_property
def _psycopg2_extras(cls):
root = __import__("psycopg2cffi", fromlist=["extras"])
return root.extras
dialect = PGDialect_psycopg2cffi

File diff suppressed because it is too large

View file

@@ -0,0 +1,303 @@
# dialects/postgresql/types.py
# Copyright (C) 2013-2024 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
from __future__ import annotations
import datetime as dt
from typing import Any
from typing import Optional
from typing import overload
from typing import Type
from typing import TYPE_CHECKING
from uuid import UUID as _python_UUID
from ...sql import sqltypes
from ...sql import type_api
from ...util.typing import Literal
if TYPE_CHECKING:
from ...engine.interfaces import Dialect
from ...sql.operators import OperatorType
from ...sql.type_api import _LiteralProcessorType
from ...sql.type_api import TypeEngine
_DECIMAL_TYPES = (1231, 1700)
_FLOAT_TYPES = (700, 701, 1021, 1022)
_INT_TYPES = (20, 21, 23, 26, 1005, 1007, 1016)
class PGUuid(sqltypes.UUID[sqltypes._UUID_RETURN]):
render_bind_cast = True
render_literal_cast = True
if TYPE_CHECKING:
@overload
def __init__(
self: PGUuid[_python_UUID], as_uuid: Literal[True] = ...
) -> None: ...
@overload
def __init__(
self: PGUuid[str], as_uuid: Literal[False] = ...
) -> None: ...
def __init__(self, as_uuid: bool = True) -> None: ...
class BYTEA(sqltypes.LargeBinary):
__visit_name__ = "BYTEA"
class INET(sqltypes.TypeEngine[str]):
__visit_name__ = "INET"
PGInet = INET
class CIDR(sqltypes.TypeEngine[str]):
__visit_name__ = "CIDR"
PGCidr = CIDR
class MACADDR(sqltypes.TypeEngine[str]):
__visit_name__ = "MACADDR"
PGMacAddr = MACADDR
class MACADDR8(sqltypes.TypeEngine[str]):
__visit_name__ = "MACADDR8"
PGMacAddr8 = MACADDR8
class MONEY(sqltypes.TypeEngine[str]):
r"""Provide the PostgreSQL MONEY type.
Depending on driver, result rows using this type may return a
string value which includes currency symbols.
For this reason, it may be preferable to provide conversion to a
numerically-based currency datatype using :class:`_types.TypeDecorator`::
import re
import decimal
from typing import Any
from sqlalchemy import Dialect
from sqlalchemy import TypeDecorator
class NumericMoney(TypeDecorator):
impl = MONEY
def process_result_value(
self, value: Any, dialect: Dialect
) -> Any:
if value is not None:
# adjust this for the currency and numeric
m = re.match(r"\$([\d.]+)", value)
if m:
value = decimal.Decimal(m.group(1))
return value
Alternatively, the conversion may be applied as a CAST using
the :meth:`_types.TypeDecorator.column_expression` method as follows::
import decimal
from typing import Any
from sqlalchemy import cast
from sqlalchemy import Numeric
from sqlalchemy import TypeDecorator
class NumericMoney(TypeDecorator):
impl = MONEY
def column_expression(self, column: Any):
return cast(column, Numeric())
.. versionadded:: 1.2
"""
__visit_name__ = "MONEY"
class OID(sqltypes.TypeEngine[int]):
"""Provide the PostgreSQL OID type."""
__visit_name__ = "OID"
class REGCONFIG(sqltypes.TypeEngine[str]):
"""Provide the PostgreSQL REGCONFIG type.
.. versionadded:: 2.0.0rc1
"""
__visit_name__ = "REGCONFIG"
class TSQUERY(sqltypes.TypeEngine[str]):
"""Provide the PostgreSQL TSQUERY type.
.. versionadded:: 2.0.0rc1
"""
__visit_name__ = "TSQUERY"
class REGCLASS(sqltypes.TypeEngine[str]):
"""Provide the PostgreSQL REGCLASS type.
.. versionadded:: 1.2.7
"""
__visit_name__ = "REGCLASS"
class TIMESTAMP(sqltypes.TIMESTAMP):
"""Provide the PostgreSQL TIMESTAMP type."""
__visit_name__ = "TIMESTAMP"
def __init__(
self, timezone: bool = False, precision: Optional[int] = None
) -> None:
"""Construct a TIMESTAMP.
:param timezone: boolean value if timezone present, default False
:param precision: optional integer precision value
.. versionadded:: 1.4
"""
super().__init__(timezone=timezone)
self.precision = precision
class TIME(sqltypes.TIME):
"""PostgreSQL TIME type."""
__visit_name__ = "TIME"
def __init__(
self, timezone: bool = False, precision: Optional[int] = None
) -> None:
"""Construct a TIME.
:param timezone: boolean value if timezone present, default False
:param precision: optional integer precision value
.. versionadded:: 1.4
"""
super().__init__(timezone=timezone)
self.precision = precision
class INTERVAL(type_api.NativeForEmulated, sqltypes._AbstractInterval):
"""PostgreSQL INTERVAL type."""
__visit_name__ = "INTERVAL"
native = True
def __init__(
self, precision: Optional[int] = None, fields: Optional[str] = None
) -> None:
"""Construct an INTERVAL.
:param precision: optional integer precision value
:param fields: string fields specifier. allows storage of fields
to be limited, such as ``"YEAR"``, ``"MONTH"``, ``"DAY TO HOUR"``,
etc.
.. versionadded:: 1.2
"""
self.precision = precision
self.fields = fields
@classmethod
def adapt_emulated_to_native(
cls, interval: sqltypes.Interval, **kw: Any # type: ignore[override]
) -> INTERVAL:
return INTERVAL(precision=interval.second_precision)
@property
def _type_affinity(self) -> Type[sqltypes.Interval]:
return sqltypes.Interval
def as_generic(self, allow_nulltype: bool = False) -> sqltypes.Interval:
return sqltypes.Interval(native=True, second_precision=self.precision)
@property
def python_type(self) -> Type[dt.timedelta]:
return dt.timedelta
def literal_processor(
self, dialect: Dialect
) -> Optional[_LiteralProcessorType[dt.timedelta]]:
def process(value: dt.timedelta) -> str:
return f"make_interval(secs=>{value.total_seconds()})"
return process
PGInterval = INTERVAL
class BIT(sqltypes.TypeEngine[int]):
__visit_name__ = "BIT"
def __init__(
self, length: Optional[int] = None, varying: bool = False
) -> None:
if varying:
# BIT VARYING can be unlimited-length, so no default
self.length = length
else:
# BIT without VARYING defaults to length 1
self.length = length or 1
self.varying = varying
PGBit = BIT
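# A brief sketch, assuming a throwaway table name, of how the BIT arguments
# above are typically used: BIT() defaults to a single bit, BIT(8) gives a
# fixed width, and BIT(varying=True) leaves the length unconstrained.
def _example_bit_columns():
    from sqlalchemy import Column, MetaData, Table

    return Table(
        "bit_demo",
        MetaData(),
        Column("one_bit", BIT()),
        Column("byte_mask", BIT(8)),
        Column("varying_bits", BIT(varying=True)),
    )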
class TSVECTOR(sqltypes.TypeEngine[str]):
"""The :class:`_postgresql.TSVECTOR` type implements the PostgreSQL
text search type TSVECTOR.
It can be used to do full text queries on natural language
documents.
.. seealso::
:ref:`postgresql_match`
"""
__visit_name__ = "TSVECTOR"
class CITEXT(sqltypes.TEXT):
"""Provide the PostgreSQL CITEXT type.
.. versionadded:: 2.0.7
"""
__visit_name__ = "CITEXT"
def coerce_compared_value(
self, op: Optional[OperatorType], value: Any
) -> TypeEngine[Any]:
return self

View file

@@ -0,0 +1,57 @@
# dialects/sqlite/__init__.py
# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
# mypy: ignore-errors
from . import aiosqlite # noqa
from . import base # noqa
from . import pysqlcipher # noqa
from . import pysqlite # noqa
from .base import BLOB
from .base import BOOLEAN
from .base import CHAR
from .base import DATE
from .base import DATETIME
from .base import DECIMAL
from .base import FLOAT
from .base import INTEGER
from .base import JSON
from .base import NUMERIC
from .base import REAL
from .base import SMALLINT
from .base import TEXT
from .base import TIME
from .base import TIMESTAMP
from .base import VARCHAR
from .dml import Insert
from .dml import insert
# default dialect
base.dialect = dialect = pysqlite.dialect
__all__ = (
"BLOB",
"BOOLEAN",
"CHAR",
"DATE",
"DATETIME",
"DECIMAL",
"FLOAT",
"INTEGER",
"JSON",
"NUMERIC",
"SMALLINT",
"TEXT",
"TIME",
"TIMESTAMP",
"VARCHAR",
"REAL",
"Insert",
"insert",
"dialect",
)

View file

@@ -0,0 +1,396 @@
# dialects/sqlite/aiosqlite.py
# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
# mypy: ignore-errors
r"""
.. dialect:: sqlite+aiosqlite
:name: aiosqlite
:dbapi: aiosqlite
:connectstring: sqlite+aiosqlite:///file_path
:url: https://pypi.org/project/aiosqlite/
The aiosqlite dialect provides support for the SQLAlchemy asyncio interface
running on top of pysqlite.
aiosqlite is a wrapper around pysqlite that uses a background thread for
each connection. It does not actually use non-blocking IO, as SQLite
databases are not socket-based. However, it does provide a working asyncio
interface that's useful for testing and prototyping purposes.
Using a special asyncio mediation layer, the aiosqlite dialect is usable
as the backend for the :ref:`SQLAlchemy asyncio <asyncio_toplevel>`
extension package.
This dialect should normally be used only with the
:func:`_asyncio.create_async_engine` engine creation function::
from sqlalchemy.ext.asyncio import create_async_engine
engine = create_async_engine("sqlite+aiosqlite:///filename")
The URL passes through all arguments to the ``pysqlite`` driver, so all
connection arguments are the same as they are for that of :ref:`pysqlite`.
.. _aiosqlite_udfs:
User-Defined Functions
----------------------
aiosqlite extends pysqlite to support async, so we can create our own user-defined functions (UDFs)
in Python and use them directly in SQLite queries as described here: :ref:`pysqlite_udfs`.
.. _aiosqlite_serializable:
Serializable isolation / Savepoints / Transactional DDL (asyncio version)
-------------------------------------------------------------------------
Similarly to pysqlite, aiosqlite does not support the SAVEPOINT feature.
The workaround is similar to :ref:`pysqlite_serializable`, using event listeners registered against the underlying sync engine::
from sqlalchemy import create_engine, event
from sqlalchemy.ext.asyncio import create_async_engine
engine = create_async_engine("sqlite+aiosqlite:///myfile.db")
@event.listens_for(engine.sync_engine, "connect")
def do_connect(dbapi_connection, connection_record):
# disable aiosqlite's emitting of the BEGIN statement entirely.
# also stops it from emitting COMMIT before any DDL.
dbapi_connection.isolation_level = None
@event.listens_for(engine.sync_engine, "begin")
def do_begin(conn):
# emit our own BEGIN
conn.exec_driver_sql("BEGIN")
.. warning:: When using the above recipe, it is advised to not use the
:paramref:`.Connection.execution_options.isolation_level` setting on
:class:`_engine.Connection` and :func:`_sa.create_engine`
with the SQLite driver,
as this function necessarily will also alter the ".isolation_level" setting.
""" # noqa
import asyncio
from functools import partial
from .base import SQLiteExecutionContext
from .pysqlite import SQLiteDialect_pysqlite
from ... import pool
from ... import util
from ...engine import AdaptedConnection
from ...util.concurrency import await_fallback
from ...util.concurrency import await_only
class AsyncAdapt_aiosqlite_cursor:
# TODO: base on connectors/asyncio.py
# see #10415
__slots__ = (
"_adapt_connection",
"_connection",
"description",
"await_",
"_rows",
"arraysize",
"rowcount",
"lastrowid",
)
server_side = False
def __init__(self, adapt_connection):
self._adapt_connection = adapt_connection
self._connection = adapt_connection._connection
self.await_ = adapt_connection.await_
self.arraysize = 1
self.rowcount = -1
self.description = None
self._rows = []
def close(self):
self._rows[:] = []
def execute(self, operation, parameters=None):
try:
_cursor = self.await_(self._connection.cursor())
if parameters is None:
self.await_(_cursor.execute(operation))
else:
self.await_(_cursor.execute(operation, parameters))
if _cursor.description:
self.description = _cursor.description
self.lastrowid = self.rowcount = -1
if not self.server_side:
self._rows = self.await_(_cursor.fetchall())
else:
self.description = None
self.lastrowid = _cursor.lastrowid
self.rowcount = _cursor.rowcount
if not self.server_side:
self.await_(_cursor.close())
else:
self._cursor = _cursor
except Exception as error:
self._adapt_connection._handle_exception(error)
def executemany(self, operation, seq_of_parameters):
try:
_cursor = self.await_(self._connection.cursor())
self.await_(_cursor.executemany(operation, seq_of_parameters))
self.description = None
self.lastrowid = _cursor.lastrowid
self.rowcount = _cursor.rowcount
self.await_(_cursor.close())
except Exception as error:
self._adapt_connection._handle_exception(error)
def setinputsizes(self, *inputsizes):
pass
def __iter__(self):
while self._rows:
yield self._rows.pop(0)
def fetchone(self):
if self._rows:
return self._rows.pop(0)
else:
return None
def fetchmany(self, size=None):
if size is None:
size = self.arraysize
retval = self._rows[0:size]
self._rows[:] = self._rows[size:]
return retval
def fetchall(self):
retval = self._rows[:]
self._rows[:] = []
return retval
class AsyncAdapt_aiosqlite_ss_cursor(AsyncAdapt_aiosqlite_cursor):
# TODO: base on connectors/asyncio.py
# see #10415
__slots__ = "_cursor"
server_side = True
def __init__(self, *arg, **kw):
super().__init__(*arg, **kw)
self._cursor = None
def close(self):
if self._cursor is not None:
self.await_(self._cursor.close())
self._cursor = None
def fetchone(self):
return self.await_(self._cursor.fetchone())
def fetchmany(self, size=None):
if size is None:
size = self.arraysize
return self.await_(self._cursor.fetchmany(size=size))
def fetchall(self):
return self.await_(self._cursor.fetchall())
class AsyncAdapt_aiosqlite_connection(AdaptedConnection):
await_ = staticmethod(await_only)
__slots__ = ("dbapi",)
def __init__(self, dbapi, connection):
self.dbapi = dbapi
self._connection = connection
@property
def isolation_level(self):
return self._connection.isolation_level
@isolation_level.setter
def isolation_level(self, value):
# aiosqlite's isolation_level setter works outside the Thread
# that it's supposed to, necessitating setting check_same_thread=False.
# for improved stability, we instead invent our own awaitable version
# using aiosqlite's async queue directly.
def set_iso(connection, value):
connection.isolation_level = value
function = partial(set_iso, self._connection._conn, value)
future = asyncio.get_event_loop().create_future()
self._connection._tx.put_nowait((future, function))
try:
return self.await_(future)
except Exception as error:
self._handle_exception(error)
def create_function(self, *args, **kw):
try:
self.await_(self._connection.create_function(*args, **kw))
except Exception as error:
self._handle_exception(error)
def cursor(self, server_side=False):
if server_side:
return AsyncAdapt_aiosqlite_ss_cursor(self)
else:
return AsyncAdapt_aiosqlite_cursor(self)
def execute(self, *args, **kw):
return self.await_(self._connection.execute(*args, **kw))
def rollback(self):
try:
self.await_(self._connection.rollback())
except Exception as error:
self._handle_exception(error)
def commit(self):
try:
self.await_(self._connection.commit())
except Exception as error:
self._handle_exception(error)
def close(self):
try:
self.await_(self._connection.close())
except ValueError:
# this is undocumented for aiosqlite, that ValueError
# was raised if .close() was called more than once, which is
# both not customary for DBAPI and is also not a DBAPI.Error
# exception. This is now fixed in aiosqlite via my PR
# https://github.com/omnilib/aiosqlite/pull/238, so we can be
# assured this will not become some other kind of exception,
# since it doesn't raise anymore.
pass
except Exception as error:
self._handle_exception(error)
def _handle_exception(self, error):
if (
isinstance(error, ValueError)
and error.args[0] == "no active connection"
):
raise self.dbapi.sqlite.OperationalError(
"no active connection"
) from error
else:
raise error
class AsyncAdaptFallback_aiosqlite_connection(AsyncAdapt_aiosqlite_connection):
__slots__ = ()
await_ = staticmethod(await_fallback)
class AsyncAdapt_aiosqlite_dbapi:
def __init__(self, aiosqlite, sqlite):
self.aiosqlite = aiosqlite
self.sqlite = sqlite
self.paramstyle = "qmark"
self._init_dbapi_attributes()
def _init_dbapi_attributes(self):
for name in (
"DatabaseError",
"Error",
"IntegrityError",
"NotSupportedError",
"OperationalError",
"ProgrammingError",
"sqlite_version",
"sqlite_version_info",
):
setattr(self, name, getattr(self.aiosqlite, name))
for name in ("PARSE_COLNAMES", "PARSE_DECLTYPES"):
setattr(self, name, getattr(self.sqlite, name))
for name in ("Binary",):
setattr(self, name, getattr(self.sqlite, name))
def connect(self, *arg, **kw):
async_fallback = kw.pop("async_fallback", False)
creator_fn = kw.pop("async_creator_fn", None)
if creator_fn:
connection = creator_fn(*arg, **kw)
else:
connection = self.aiosqlite.connect(*arg, **kw)
# it's a Thread. you'll thank us later
connection.daemon = True
if util.asbool(async_fallback):
return AsyncAdaptFallback_aiosqlite_connection(
self,
await_fallback(connection),
)
else:
return AsyncAdapt_aiosqlite_connection(
self,
await_only(connection),
)
class SQLiteExecutionContext_aiosqlite(SQLiteExecutionContext):
def create_server_side_cursor(self):
return self._dbapi_connection.cursor(server_side=True)
class SQLiteDialect_aiosqlite(SQLiteDialect_pysqlite):
driver = "aiosqlite"
supports_statement_cache = True
is_async = True
supports_server_side_cursors = True
execution_ctx_cls = SQLiteExecutionContext_aiosqlite
@classmethod
def import_dbapi(cls):
return AsyncAdapt_aiosqlite_dbapi(
__import__("aiosqlite"), __import__("sqlite3")
)
@classmethod
def get_pool_class(cls, url):
if cls._is_url_file_db(url):
return pool.NullPool
else:
return pool.StaticPool
def is_disconnect(self, e, connection, cursor):
if isinstance(
e, self.dbapi.OperationalError
) and "no active connection" in str(e):
return True
return super().is_disconnect(e, connection, cursor)
def get_driver_connection(self, connection):
return connection._connection
dialect = SQLiteDialect_aiosqlite
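# Layering note: import_dbapi() above wraps the real drivers as
# AsyncAdapt_aiosqlite_dbapi(aiosqlite, sqlite3), and each pooled connection is
# an AsyncAdapt_aiosqlite_connection driving aiosqlite's background thread,
# which is how the synchronous Engine protocol and the asyncio extension share
# one code path.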

File diff suppressed because it is too large

View file

@@ -0,0 +1,240 @@
# dialects/sqlite/dml.py
# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
from __future__ import annotations
from typing import Any
from .._typing import _OnConflictIndexElementsT
from .._typing import _OnConflictIndexWhereT
from .._typing import _OnConflictSetT
from .._typing import _OnConflictWhereT
from ... import util
from ...sql import coercions
from ...sql import roles
from ...sql._typing import _DMLTableArgument
from ...sql.base import _exclusive_against
from ...sql.base import _generative
from ...sql.base import ColumnCollection
from ...sql.base import ReadOnlyColumnCollection
from ...sql.dml import Insert as StandardInsert
from ...sql.elements import ClauseElement
from ...sql.elements import KeyedColumnElement
from ...sql.expression import alias
from ...util.typing import Self
__all__ = ("Insert", "insert")
def insert(table: _DMLTableArgument) -> Insert:
"""Construct a sqlite-specific variant :class:`_sqlite.Insert`
construct.
.. container:: inherited_member
The :func:`sqlalchemy.dialects.sqlite.insert` function creates
a :class:`sqlalchemy.dialects.sqlite.Insert`. This class is based
on the dialect-agnostic :class:`_sql.Insert` construct which may
be constructed using the :func:`_sql.insert` function in
SQLAlchemy Core.
The :class:`_sqlite.Insert` construct includes additional methods
:meth:`_sqlite.Insert.on_conflict_do_update`,
:meth:`_sqlite.Insert.on_conflict_do_nothing`.
"""
return Insert(table)
class Insert(StandardInsert):
"""SQLite-specific implementation of INSERT.
Adds methods for SQLite-specific syntaxes such as ON CONFLICT.
The :class:`_sqlite.Insert` object is created using the
:func:`sqlalchemy.dialects.sqlite.insert` function.
.. versionadded:: 1.4
.. seealso::
:ref:`sqlite_on_conflict_insert`
"""
stringify_dialect = "sqlite"
inherit_cache = False
@util.memoized_property
def excluded(
self,
) -> ReadOnlyColumnCollection[str, KeyedColumnElement[Any]]:
"""Provide the ``excluded`` namespace for an ON CONFLICT statement
SQLite's ON CONFLICT clause allows reference to the row that would
be inserted, known as ``excluded``. This attribute provides
all columns in this row to be referenceable.
.. tip:: The :attr:`_sqlite.Insert.excluded` attribute is an instance
of :class:`_expression.ColumnCollection`, which provides an
interface the same as that of the :attr:`_schema.Table.c`
collection described at :ref:`metadata_tables_and_columns`.
With this collection, ordinary names are accessible like attributes
(e.g. ``stmt.excluded.some_column``), but special names and
dictionary method names should be accessed using indexed access,
such as ``stmt.excluded["column name"]`` or
``stmt.excluded["values"]``. See the docstring for
:class:`_expression.ColumnCollection` for further examples.
"""
return alias(self.table, name="excluded").columns
_on_conflict_exclusive = _exclusive_against(
"_post_values_clause",
msgs={
"_post_values_clause": "This Insert construct already has "
"an ON CONFLICT clause established"
},
)
@_generative
@_on_conflict_exclusive
def on_conflict_do_update(
self,
index_elements: _OnConflictIndexElementsT = None,
index_where: _OnConflictIndexWhereT = None,
set_: _OnConflictSetT = None,
where: _OnConflictWhereT = None,
) -> Self:
r"""
Specifies a DO UPDATE SET action for ON CONFLICT clause.
:param index_elements:
A sequence consisting of string column names, :class:`_schema.Column`
objects, or other column expression objects that will be used
to infer a target index or unique constraint.
:param index_where:
Additional WHERE criterion that can be used to infer a
conditional target index.
:param set\_:
A dictionary or other mapping object
where the keys are either names of columns in the target table,
or :class:`_schema.Column` objects or other ORM-mapped columns
matching that of the target table, and expressions or literals
as values, specifying the ``SET`` actions to take.
.. versionadded:: 1.4 The
:paramref:`_sqlite.Insert.on_conflict_do_update.set_`
parameter supports :class:`_schema.Column` objects from the target
:class:`_schema.Table` as keys.
.. warning:: This dictionary does **not** take into account
Python-specified default UPDATE values or generation functions,
e.g. those specified using :paramref:`_schema.Column.onupdate`.
These values will not be exercised for an ON CONFLICT style of
UPDATE, unless they are manually specified in the
:paramref:`.Insert.on_conflict_do_update.set_` dictionary.
:param where:
Optional argument. If present, can be a literal SQL
string or an acceptable expression for a ``WHERE`` clause
that restricts the rows affected by ``DO UPDATE SET``. Rows
not meeting the ``WHERE`` condition will not be updated
(effectively a ``DO NOTHING`` for those rows).
"""
self._post_values_clause = OnConflictDoUpdate(
index_elements, index_where, set_, where
)
return self
@_generative
@_on_conflict_exclusive
def on_conflict_do_nothing(
self,
index_elements: _OnConflictIndexElementsT = None,
index_where: _OnConflictIndexWhereT = None,
) -> Self:
"""
Specifies a DO NOTHING action for ON CONFLICT clause.
:param index_elements:
A sequence consisting of string column names, :class:`_schema.Column`
objects, or other column expression objects that will be used
to infer a target index or unique constraint.
:param index_where:
Additional WHERE criterion that can be used to infer a
conditional target index.
"""
self._post_values_clause = OnConflictDoNothing(
index_elements, index_where
)
return self
class OnConflictClause(ClauseElement):
stringify_dialect = "sqlite"
constraint_target: None
inferred_target_elements: _OnConflictIndexElementsT
inferred_target_whereclause: _OnConflictIndexWhereT
def __init__(
self,
index_elements: _OnConflictIndexElementsT = None,
index_where: _OnConflictIndexWhereT = None,
):
if index_elements is not None:
self.constraint_target = None
self.inferred_target_elements = index_elements
self.inferred_target_whereclause = index_where
else:
self.constraint_target = self.inferred_target_elements = (
self.inferred_target_whereclause
) = None
class OnConflictDoNothing(OnConflictClause):
__visit_name__ = "on_conflict_do_nothing"
class OnConflictDoUpdate(OnConflictClause):
__visit_name__ = "on_conflict_do_update"
def __init__(
self,
index_elements: _OnConflictIndexElementsT = None,
index_where: _OnConflictIndexWhereT = None,
set_: _OnConflictSetT = None,
where: _OnConflictWhereT = None,
):
super().__init__(
index_elements=index_elements,
index_where=index_where,
)
if isinstance(set_, dict):
if not set_:
raise ValueError("set parameter dictionary must not be empty")
elif isinstance(set_, ColumnCollection):
set_ = dict(set_)
else:
raise ValueError(
"set parameter must be a non-empty dictionary "
"or a ColumnCollection such as the `.c.` collection "
"of a Table object"
)
self.update_values_to_set = [
(coercions.expect(roles.DMLColumnRole, key), value)
for key, value in set_.items()
]
self.update_whereclause = where
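# A short usage sketch tying together insert(), Insert.excluded and the
# ON CONFLICT constructs defined above; the table and column names are
# placeholders supplied by the caller.
def _example_sqlite_upsert(my_table):
    stmt = insert(my_table).values(id=1, data="inserted")
    # on primary-key conflict, overwrite "data" with the value that would
    # have been inserted, available via the "excluded" namespace
    return stmt.on_conflict_do_update(
        index_elements=[my_table.c.id],
        set_={"data": stmt.excluded.data},
    )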

View file

@@ -0,0 +1,92 @@
# dialects/sqlite/json.py
# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
# mypy: ignore-errors
from ... import types as sqltypes
class JSON(sqltypes.JSON):
"""SQLite JSON type.
SQLite supports JSON as of version 3.9 through its JSON1_ extension. Note
that JSON1_ is a
`loadable extension <https://www.sqlite.org/loadext.html>`_ and as such
may not be available, or may require run-time loading.
:class:`_sqlite.JSON` is used automatically whenever the base
:class:`_types.JSON` datatype is used against a SQLite backend.
.. seealso::
:class:`_types.JSON` - main documentation for the generic
cross-platform JSON datatype.
The :class:`_sqlite.JSON` type supports persistence of JSON values
as well as the core index operations provided by :class:`_types.JSON`
datatype, by adapting the operations to render the ``JSON_EXTRACT``
function wrapped in the ``JSON_QUOTE`` function at the database level.
Extracted values are quoted in order to ensure that the results are
always JSON string values.
.. versionadded:: 1.3
.. _JSON1: https://www.sqlite.org/json1.html
"""
# Note: these objects currently match exactly those of MySQL; however, since
# these are not generalizable to all JSON implementations, they remain
# separately implemented for each dialect.
class _FormatTypeMixin:
def _format_value(self, value):
raise NotImplementedError()
def bind_processor(self, dialect):
super_proc = self.string_bind_processor(dialect)
def process(value):
value = self._format_value(value)
if super_proc:
value = super_proc(value)
return value
return process
def literal_processor(self, dialect):
super_proc = self.string_literal_processor(dialect)
def process(value):
value = self._format_value(value)
if super_proc:
value = super_proc(value)
return value
return process
class JSONIndexType(_FormatTypeMixin, sqltypes.JSON.JSONIndexType):
def _format_value(self, value):
if isinstance(value, int):
value = "$[%s]" % value
else:
value = '$."%s"' % value
return value
class JSONPathType(_FormatTypeMixin, sqltypes.JSON.JSONPathType):
def _format_value(self, value):
return "$%s" % (
"".join(
[
"[%s]" % elem if isinstance(elem, int) else '."%s"' % elem
for elem in value
]
)
)
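# Illustrative examples of the formatting above: an integer index 5 renders the
# JSON path "$[5]", a string key "note" renders '$."note"', and a path list
# such as ["a", 0, "b"] renders '$."a"[0]."b"'.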

View file

@@ -0,0 +1,198 @@
# dialects/sqlite/provision.py
# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
# mypy: ignore-errors
import os
import re
from ... import exc
from ...engine import url as sa_url
from ...testing.provision import create_db
from ...testing.provision import drop_db
from ...testing.provision import follower_url_from_main
from ...testing.provision import generate_driver_url
from ...testing.provision import log
from ...testing.provision import post_configure_engine
from ...testing.provision import run_reap_dbs
from ...testing.provision import stop_test_class_outside_fixtures
from ...testing.provision import temp_table_keyword_args
from ...testing.provision import upsert
# TODO: I can't get this to build dynamically with pytest-xdist procs
_drivernames = {
"pysqlite",
"aiosqlite",
"pysqlcipher",
"pysqlite_numeric",
"pysqlite_dollar",
}
def _format_url(url, driver, ident):
"""given a sqlite url + desired driver + ident, make a canonical
URL out of it
"""
url = sa_url.make_url(url)
if driver is None:
driver = url.get_driver_name()
filename = url.database
needs_enc = driver == "pysqlcipher"
name_token = None
if filename and filename != ":memory:":
assert "test_schema" not in filename
tokens = re.split(r"[_\.]", filename)
new_filename = f"{driver}"
for token in tokens:
if token in _drivernames:
if driver is None:
driver = token
continue
elif token in ("db", "enc"):
continue
elif name_token is None:
name_token = token.strip("_")
assert name_token, f"sqlite filename has no name token: {url.database}"
new_filename = f"{name_token}_{driver}"
if ident:
new_filename += f"_{ident}"
new_filename += ".db"
if needs_enc:
new_filename += ".enc"
url = url.set(database=new_filename)
if needs_enc:
url = url.set(password="test")
url = url.set(drivername="sqlite+%s" % (driver,))
return url
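# For example (roughly): _format_url("sqlite:///foo.db", "aiosqlite", "xyz")
# returns a URL along the lines of "sqlite+aiosqlite:///foo_aiosqlite_xyz.db",
# while the pysqlcipher driver additionally gets a ".enc" suffix and a test
# password.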
@generate_driver_url.for_db("sqlite")
def generate_driver_url(url, driver, query_str):
url = _format_url(url, driver, None)
try:
url.get_dialect()
except exc.NoSuchModuleError:
return None
else:
return url
@follower_url_from_main.for_db("sqlite")
def _sqlite_follower_url_from_main(url, ident):
return _format_url(url, None, ident)
@post_configure_engine.for_db("sqlite")
def _sqlite_post_configure_engine(url, engine, follower_ident):
from sqlalchemy import event
if follower_ident:
attach_path = f"{follower_ident}_{engine.driver}_test_schema.db"
else:
attach_path = f"{engine.driver}_test_schema.db"
@event.listens_for(engine, "connect")
def connect(dbapi_connection, connection_record):
# use file DBs in all cases, memory acts kind of strangely
# as an attached
# NOTE! this has to be done *per connection*. New sqlite connection,
# as we get with say, QueuePool, the attaches are gone.
# so schemes to delete those attached files have to be done at the
# filesystem level and not rely upon what attachments are in a
# particular SQLite connection
dbapi_connection.execute(
f'ATTACH DATABASE "{attach_path}" AS test_schema'
)
@event.listens_for(engine, "engine_disposed")
def dispose(engine):
"""most databases should be dropped using
stop_test_class_outside_fixtures
however a few tests like AttachedDBTest might not get triggered on
that main hook
"""
if os.path.exists(attach_path):
os.remove(attach_path)
filename = engine.url.database
if filename and filename != ":memory:" and os.path.exists(filename):
os.remove(filename)
@create_db.for_db("sqlite")
def _sqlite_create_db(cfg, eng, ident):
pass
@drop_db.for_db("sqlite")
def _sqlite_drop_db(cfg, eng, ident):
_drop_dbs_w_ident(eng.url.database, eng.driver, ident)
def _drop_dbs_w_ident(databasename, driver, ident):
for path in os.listdir("."):
# os.path.split() returns (directory, filename) rather than an extension, so
# match the file suffix directly; database files end in ".db" or ".db.enc".
if ident in path and (path.endswith(".db") or path.endswith(".db.enc")):
log.info("deleting SQLite database file: %s", path)
os.remove(path)
@stop_test_class_outside_fixtures.for_db("sqlite")
def stop_test_class_outside_fixtures(config, db, cls):
db.dispose()
@temp_table_keyword_args.for_db("sqlite")
def _sqlite_temp_table_keyword_args(cfg, eng):
return {"prefixes": ["TEMPORARY"]}
@run_reap_dbs.for_db("sqlite")
def _reap_sqlite_dbs(url, idents):
log.info("db reaper connecting to %r", url)
log.info("identifiers in file: %s", ", ".join(idents))
url = sa_url.make_url(url)
for ident in idents:
for drivername in _drivernames:
_drop_dbs_w_ident(url.database, drivername, ident)
@upsert.for_db("sqlite")
def _upsert(
cfg, table, returning, *, set_lambda=None, sort_by_parameter_order=False
):
from sqlalchemy.dialects.sqlite import insert
stmt = insert(table)
if set_lambda:
stmt = stmt.on_conflict_do_update(set_=set_lambda(stmt.excluded))
else:
stmt = stmt.on_conflict_do_nothing()
stmt = stmt.returning(
*returning, sort_by_parameter_order=sort_by_parameter_order
)
return stmt

View file

@@ -0,0 +1,155 @@
# dialects/sqlite/pysqlcipher.py
# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
# mypy: ignore-errors
"""
.. dialect:: sqlite+pysqlcipher
:name: pysqlcipher
:dbapi: sqlcipher 3 or pysqlcipher
:connectstring: sqlite+pysqlcipher://:passphrase@/file_path[?kdf_iter=<iter>]
Dialect for support of DBAPIs that make use of the
`SQLCipher <https://www.zetetic.net/sqlcipher>`_ backend.
Driver
------
Current dialect selection logic is:
* If the :paramref:`_sa.create_engine.module` parameter supplies a DBAPI module,
that module is used.
* Otherwise for Python 3, choose https://pypi.org/project/sqlcipher3/
* If not available, fall back to https://pypi.org/project/pysqlcipher3/
* For Python 2, https://pypi.org/project/pysqlcipher/ is used.
.. warning:: The ``pysqlcipher3`` and ``pysqlcipher`` DBAPI drivers are no
longer maintained; the ``sqlcipher3`` driver as of this writing appears
to be current. For future compatibility, any pysqlcipher-compatible DBAPI
may be used as follows::
import sqlcipher_compatible_driver
from sqlalchemy import create_engine
e = create_engine(
"sqlite+pysqlcipher://:password@/dbname.db",
module=sqlcipher_compatible_driver
)
These drivers make use of the SQLCipher engine. This system essentially
introduces new PRAGMA commands to SQLite which allows the setting of a
passphrase and other encryption parameters, allowing the database file to be
encrypted.
Connect Strings
---------------
The format of the connect string is in every way the same as that
of the :mod:`~sqlalchemy.dialects.sqlite.pysqlite` driver, except that the
"password" field is now accepted, which should contain a passphrase::
e = create_engine('sqlite+pysqlcipher://:testing@/foo.db')
For an absolute file path, two leading slashes should be used for the
database name::
e = create_engine('sqlite+pysqlcipher://:testing@//path/to/foo.db')
A selection of additional encryption-related pragmas supported by SQLCipher
as documented at https://www.zetetic.net/sqlcipher/sqlcipher-api/ can be passed
in the query string, and will result in that PRAGMA being called for each
new connection. Currently, ``cipher``, ``kdf_iter``,
``cipher_page_size`` and ``cipher_use_hmac`` are supported::
e = create_engine('sqlite+pysqlcipher://:testing@/foo.db?cipher=aes-256-cfb&kdf_iter=64000')
.. warning:: Previous versions of SQLAlchemy silently ignored the
encryption-related pragmas passed in the URL string. This may cause errors
when opening files saved by a previous SQLAlchemy version if the encryption
options do not match.
Pooling Behavior
----------------
The driver makes a change to the default pool behavior of pysqlite
as described in :ref:`pysqlite_threading_pooling`. The pysqlcipher driver
has been observed to be significantly slower on connection than the
pysqlite driver, most likely due to the encryption overhead, so the
dialect here defaults to using the :class:`.SingletonThreadPool`
implementation,
instead of the :class:`.NullPool` pool used by pysqlite. As always, the pool
implementation is entirely configurable using the
:paramref:`_sa.create_engine.poolclass` parameter; the :class:`.StaticPool` may
be more feasible for single-threaded use, or :class:`.NullPool` may be used
to prevent unencrypted connections from being held open for long periods of
time, at the expense of slower startup time for new connections.
""" # noqa
from .pysqlite import SQLiteDialect_pysqlite
from ... import pool
class SQLiteDialect_pysqlcipher(SQLiteDialect_pysqlite):
driver = "pysqlcipher"
supports_statement_cache = True
pragmas = ("kdf_iter", "cipher", "cipher_page_size", "cipher_use_hmac")
@classmethod
def import_dbapi(cls):
try:
import sqlcipher3 as sqlcipher
except ImportError:
pass
else:
return sqlcipher
from pysqlcipher3 import dbapi2 as sqlcipher
return sqlcipher
@classmethod
def get_pool_class(cls, url):
return pool.SingletonThreadPool
def on_connect_url(self, url):
super_on_connect = super().on_connect_url(url)
# pull the info we need from the URL early. Even though URL
# is immutable, we don't want any in-place changes to the URL
# to affect things
passphrase = url.password or ""
url_query = dict(url.query)
def on_connect(conn):
cursor = conn.cursor()
cursor.execute('pragma key="%s"' % passphrase)
for prag in self.pragmas:
value = url_query.get(prag, None)
if value is not None:
cursor.execute('pragma %s="%s"' % (prag, value))
cursor.close()
if super_on_connect:
super_on_connect(conn)
return on_connect
def create_connect_args(self, url):
plain_url = url._replace(password=None)
plain_url = plain_url.difference_update_query(self.pragmas)
return super().create_connect_args(plain_url)
dialect = SQLiteDialect_pysqlcipher

View file

@@ -0,0 +1,753 @@
# dialects/sqlite/pysqlite.py
# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
# mypy: ignore-errors
r"""
.. dialect:: sqlite+pysqlite
:name: pysqlite
:dbapi: sqlite3
:connectstring: sqlite+pysqlite:///file_path
:url: https://docs.python.org/library/sqlite3.html
Note that ``pysqlite`` is the same driver as the ``sqlite3``
module included with the Python distribution.
Driver
------
The ``sqlite3`` Python DBAPI is standard on all modern Python versions;
for cPython and Pypy, no additional installation is necessary.
Connect Strings
---------------
The file specification for the SQLite database is taken as the "database"
portion of the URL. Note that the format of a SQLAlchemy url is::
driver://user:pass@host/database
This means that the actual filename to be used starts with the characters to
the **right** of the third slash. So connecting to a relative filepath
looks like::
# relative path
e = create_engine('sqlite:///path/to/database.db')
An absolute path, which is denoted by starting with a slash, means you
need **four** slashes::
# absolute path
e = create_engine('sqlite:////path/to/database.db')
To use a Windows path, regular drive specifications and backslashes can be
used. Double backslashes are probably needed::
# absolute path on Windows
e = create_engine('sqlite:///C:\\path\\to\\database.db')
The sqlite ``:memory:`` identifier is the default if no filepath is
present. Specify ``sqlite://`` and nothing else::
# in-memory database
e = create_engine('sqlite://')
.. _pysqlite_uri_connections:
URI Connections
^^^^^^^^^^^^^^^
Modern versions of SQLite support an alternative system of connecting using a
`driver level URI <https://www.sqlite.org/uri.html>`_, which has the advantage
that additional driver-level arguments can be passed including options such as
"read only". The Python sqlite3 driver supports this mode under modern Python
3 versions. The SQLAlchemy pysqlite driver supports this mode of use by
specifying "uri=true" in the URL query string. The SQLite-level "URI" is kept
as the "database" portion of the SQLAlchemy url (that is, following a slash)::
e = create_engine("sqlite:///file:path/to/database?mode=ro&uri=true")
.. note:: The "uri=true" parameter must appear in the **query string**
of the URL. It will not currently work as expected if it is only
present in the :paramref:`_sa.create_engine.connect_args`
parameter dictionary.
The logic reconciles the simultaneous presence of SQLAlchemy's query string and
SQLite's query string by separating out the parameters that belong to the
Python sqlite3 driver vs. those that belong to the SQLite URI. This is
achieved through the use of a fixed list of parameters known to be accepted by
the Python side of the driver. For example, to include a URL that indicates
the Python sqlite3 "timeout" and "check_same_thread" parameters, along with the
SQLite "mode" and "nolock" parameters, they can all be passed together on the
query string::
e = create_engine(
"sqlite:///file:path/to/database?"
"check_same_thread=true&timeout=10&mode=ro&nolock=1&uri=true"
)
Above, the pysqlite / sqlite3 DBAPI would be passed arguments as::
sqlite3.connect(
"file:path/to/database?mode=ro&nolock=1",
check_same_thread=True, timeout=10, uri=True
)
Regarding future parameters added to either the Python or native drivers: new
parameter names added to the SQLite URI scheme should be automatically
accommodated by this scheme. New parameter names added to the Python driver
side can be accommodated by specifying them in the
:paramref:`_sa.create_engine.connect_args` dictionary,
until dialect support is
added by SQLAlchemy. For the less likely case that the native SQLite driver
adds a new parameter name that overlaps with one of the existing, known Python
driver parameters (such as "timeout" perhaps), SQLAlchemy's dialect would
require adjustment for the URL scheme to continue to support this.
As is always the case for all SQLAlchemy dialects, the entire "URL" process
can be bypassed in :func:`_sa.create_engine` through the use of the
:paramref:`_sa.create_engine.creator`
parameter which allows for a custom callable
that creates a Python sqlite3 driver level connection directly.
.. versionadded:: 1.3.9
.. seealso::
`Uniform Resource Identifiers <https://www.sqlite.org/uri.html>`_ - in
the SQLite documentation
.. _pysqlite_regexp:
Regular Expression Support
---------------------------
.. versionadded:: 1.4
Support for the :meth:`_sql.ColumnOperators.regexp_match` operator is provided
using Python's re.search_ function. SQLite itself does not include a working
regular expression operator; instead, it includes a non-implemented placeholder
operator ``REGEXP`` that calls a user-defined function that must be provided.
SQLAlchemy's implementation makes use of the pysqlite create_function_ hook
as follows::
def regexp(a, b):
return re.search(a, b) is not None
sqlite_connection.create_function(
"regexp", 2, regexp,
)
There is currently no support for regular expression flags as a separate
argument, as these are not supported by SQLite's REGEXP operator; however, flags
may be included inline within the regular expression string. See `Python regular expressions`_ for
details.
.. seealso::
`Python regular expressions`_: Documentation for Python's regular expression syntax.
.. _create_function: https://docs.python.org/3/library/sqlite3.html#sqlite3.Connection.create_function
.. _re.search: https://docs.python.org/3/library/re.html#re.search
.. _Python regular expressions: https://docs.python.org/3/library/re.html#re.search
Compatibility with sqlite3 "native" date and datetime types
-----------------------------------------------------------
The pysqlite driver includes the sqlite3.PARSE_DECLTYPES and
sqlite3.PARSE_COLNAMES options, which have the effect of any column
or expression explicitly cast as "date" or "timestamp" will be converted
to a Python date or datetime object. The date and datetime types provided
with the pysqlite dialect are not currently compatible with these options,
since they render the ISO date/datetime including microseconds, which
pysqlite's driver does not. Additionally, SQLAlchemy does not at
this time automatically render the "cast" syntax required for the
freestanding functions "current_timestamp" and "current_date" to return
datetime/date types natively. Unfortunately, pysqlite
does not provide the standard DBAPI types in ``cursor.description``,
leaving SQLAlchemy with no way to detect these types on the fly
without expensive per-row type checks.
Keeping in mind that pysqlite's parsing option is not recommended,
nor should be necessary, for use with SQLAlchemy, usage of PARSE_DECLTYPES
can be forced if one configures "native_datetime=True" on create_engine()::
engine = create_engine('sqlite://',
connect_args={'detect_types':
sqlite3.PARSE_DECLTYPES|sqlite3.PARSE_COLNAMES},
native_datetime=True
)
With this flag enabled, the DATE and TIMESTAMP types (but note - not the
DATETIME or TIME types...confused yet ?) will not perform any bind parameter
or result processing. Execution of "func.current_date()" will return a string.
"func.current_timestamp()" is registered as returning a DATETIME type in
SQLAlchemy, so this function still receives SQLAlchemy-level result
processing.
.. _pysqlite_threading_pooling:
Threading/Pooling Behavior
---------------------------
The ``sqlite3`` DBAPI by default prohibits the use of a particular connection
in a thread which is not the one in which it was created. As SQLite has
matured, its behavior under multiple threads has improved, and even includes
options for memory only databases to be used in multiple threads.
The thread prohibition is known as "check same thread" and may be controlled
using the ``sqlite3`` parameter ``check_same_thread``, which will disable or
enable this check. SQLAlchemy's default behavior here is to set
``check_same_thread`` to ``False`` automatically whenever a file-based database
is in use, to establish compatibility with the default pool class
:class:`.QueuePool`.
The SQLAlchemy ``pysqlite`` DBAPI establishes the connection pool differently
based on the kind of SQLite database that's requested:
* When a ``:memory:`` SQLite database is specified, the dialect by default
will use :class:`.SingletonThreadPool`. This pool maintains a single
connection per thread, so that all access to the engine within the current
thread use the same ``:memory:`` database - other threads would access a
different ``:memory:`` database. The ``check_same_thread`` parameter
defaults to ``True``.
* When a file-based database is specified, the dialect will use
:class:`.QueuePool` as the source of connections. At the same time,
the ``check_same_thread`` flag is set to False by default unless overridden.
.. versionchanged:: 2.0
SQLite file database engines now use :class:`.QueuePool` by default.
Previously, :class:`.NullPool` was used. The :class:`.NullPool` class
may be used by specifying it via the
:paramref:`_sa.create_engine.poolclass` parameter.
Disabling Connection Pooling for File Databases
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
Pooling may be disabled for a file based database by specifying the
:class:`.NullPool` implementation for the :func:`_sa.create_engine.poolclass`
parameter::
from sqlalchemy import NullPool
engine = create_engine("sqlite:///myfile.db", poolclass=NullPool)
It's been observed that the :class:`.NullPool` implementation incurs an
extremely small performance overhead for repeated checkouts due to the lack of
connection re-use implemented by :class:`.QueuePool`. However, it still
may be beneficial to use this class if the application is experiencing
issues with files being locked.
Using a Memory Database in Multiple Threads
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
To use a ``:memory:`` database in a multithreaded scenario, the same
connection object must be shared among threads, since the database exists
only within the scope of that connection. The
:class:`.StaticPool` implementation will maintain a single connection
globally, and the ``check_same_thread`` flag can be passed to Pysqlite
as ``False``::
from sqlalchemy.pool import StaticPool
engine = create_engine('sqlite://',
connect_args={'check_same_thread':False},
poolclass=StaticPool)
Note that using a ``:memory:`` database in multiple threads requires a recent
version of SQLite.
Using Temporary Tables with SQLite
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
Due to the way SQLite deals with temporary tables, if you wish to use a
temporary table in a file-based SQLite database across multiple checkouts
from the connection pool, such as when using an ORM :class:`.Session` where
the temporary table should continue to remain after :meth:`.Session.commit` or
:meth:`.Session.rollback` is called, a pool which maintains a single
connection must be used. Use :class:`.SingletonThreadPool` if the scope is
only needed within the current thread, or :class:`.StaticPool` if scope is
needed within multiple threads for this case::
# maintain the same connection per thread
from sqlalchemy.pool import SingletonThreadPool
engine = create_engine('sqlite:///mydb.db',
poolclass=SingletonThreadPool)
# maintain the same connection across all threads
from sqlalchemy.pool import StaticPool
engine = create_engine('sqlite:///mydb.db',
poolclass=StaticPool)
Note that :class:`.SingletonThreadPool` should be configured for the number
of threads that are to be used; beyond that number, connections will be
closed out in a non deterministic way.
Dealing with Mixed String / Binary Columns
------------------------------------------------------
The SQLite database is weakly typed, and as such it is possible when using
binary values, which in Python are represented as ``b'some string'``, that a
particular SQLite database can have data values within different rows where
some of them will be returned as a ``b''`` value by the Pysqlite driver, and
others will be returned as Python strings, e.g. ``''`` values. This situation
is not known to occur if the SQLAlchemy :class:`.LargeBinary` datatype is used
consistently, however if a particular SQLite database has data that was
inserted using the Pysqlite driver directly, or when using the SQLAlchemy
:class:`.String` type which was later changed to :class:`.LargeBinary`, the
table will not be consistently readable because SQLAlchemy's
:class:`.LargeBinary` datatype does not handle strings so it has no way of
"encoding" a value that is in string format.
To deal with a SQLite table that has mixed string / binary data in the
same column, use a custom type that will check each row individually::
from sqlalchemy import String
from sqlalchemy import TypeDecorator
class MixedBinary(TypeDecorator):
impl = String
cache_ok = True
def process_result_value(self, value, dialect):
if isinstance(value, str):
value = bytes(value, 'utf-8')
elif value is not None:
value = bytes(value)
return value
Then use the above ``MixedBinary`` datatype in the place where
:class:`.LargeBinary` would normally be used.
.. _pysqlite_serializable:
Serializable isolation / Savepoints / Transactional DDL
-------------------------------------------------------
In the section :ref:`sqlite_concurrency`, we refer to the pysqlite
driver's assortment of issues that prevent several features of SQLite
from working correctly. The pysqlite DBAPI driver has several
long-standing bugs which impact the correctness of its transactional
behavior. In its default mode of operation, SQLite features such as
SERIALIZABLE isolation, transactional DDL, and SAVEPOINT support are
non-functional, and in order to use these features, workarounds must
be taken.
The issue is essentially that the driver attempts to second-guess the user's
intent, failing to start transactions and sometimes ending them prematurely, in
an effort to minimize the SQLite database's file locking behavior, even
though SQLite itself uses "shared" locks for read-only activities.
SQLAlchemy chooses to not alter this behavior by default, as it is the
long-expected behavior of the pysqlite driver; if and when the pysqlite
driver attempts to repair these issues, that will be a stronger driver towards
changing SQLAlchemy's defaults.
The good news is that with a few events, we can implement transactional
support fully, by disabling pysqlite's feature entirely and emitting BEGIN
ourselves. This is achieved using two event listeners::
from sqlalchemy import create_engine, event
engine = create_engine("sqlite:///myfile.db")
@event.listens_for(engine, "connect")
def do_connect(dbapi_connection, connection_record):
# disable pysqlite's emitting of the BEGIN statement entirely.
# also stops it from emitting COMMIT before any DDL.
dbapi_connection.isolation_level = None
@event.listens_for(engine, "begin")
def do_begin(conn):
# emit our own BEGIN
conn.exec_driver_sql("BEGIN")
.. warning:: When using the above recipe, it is advised not to use the
   :paramref:`.Connection.execution_options.isolation_level` setting on
   :class:`_engine.Connection` and :func:`_sa.create_engine`
   with the SQLite driver,
   as these functions necessarily will also alter the ``.isolation_level`` setting.
Above, we intercept a new pysqlite connection and disable any transactional
integration. Then, at the point at which SQLAlchemy knows that transaction
scope is to begin, we emit ``"BEGIN"`` ourselves.
When we take control of ``"BEGIN"``, we can also directly control SQLite's
locking modes, described at
`BEGIN TRANSACTION <https://sqlite.org/lang_transaction.html>`_,
by adding the desired locking mode to our ``"BEGIN"``::
@event.listens_for(engine, "begin")
def do_begin(conn):
conn.exec_driver_sql("BEGIN EXCLUSIVE")
.. seealso::
`BEGIN TRANSACTION <https://sqlite.org/lang_transaction.html>`_ -
on the SQLite site
`sqlite3 SELECT does not BEGIN a transaction <https://bugs.python.org/issue9924>`_ -
on the Python bug tracker
`sqlite3 module breaks transactions and potentially corrupts data <https://bugs.python.org/issue10740>`_ -
on the Python bug tracker
.. _pysqlite_udfs:
User-Defined Functions
----------------------
pysqlite supports a `create_function() <https://docs.python.org/3/library/sqlite3.html#sqlite3.Connection.create_function>`_
method that allows us to create our own user-defined functions (UDFs) in Python and use them directly in SQLite queries.
These functions are registered with a specific DBAPI Connection.
SQLAlchemy uses connection pooling with file-based SQLite databases, so we need to ensure that the UDF is attached to the
connection when it is created. That is accomplished with an event listener::
from sqlalchemy import create_engine
from sqlalchemy import event
from sqlalchemy import text
def udf():
return "udf-ok"
engine = create_engine("sqlite:///./db_file")
@event.listens_for(engine, "connect")
def connect(conn, rec):
conn.create_function("udf", 0, udf)
for i in range(5):
with engine.connect() as conn:
print(conn.scalar(text("SELECT UDF()")))
""" # noqa
import math
import os
import re
from .base import DATE
from .base import DATETIME
from .base import SQLiteDialect
from ... import exc
from ... import pool
from ... import types as sqltypes
from ... import util
class _SQLite_pysqliteTimeStamp(DATETIME):
def bind_processor(self, dialect):
if dialect.native_datetime:
return None
else:
return DATETIME.bind_processor(self, dialect)
def result_processor(self, dialect, coltype):
if dialect.native_datetime:
return None
else:
return DATETIME.result_processor(self, dialect, coltype)
class _SQLite_pysqliteDate(DATE):
def bind_processor(self, dialect):
if dialect.native_datetime:
return None
else:
return DATE.bind_processor(self, dialect)
def result_processor(self, dialect, coltype):
if dialect.native_datetime:
return None
else:
return DATE.result_processor(self, dialect, coltype)
class SQLiteDialect_pysqlite(SQLiteDialect):
default_paramstyle = "qmark"
supports_statement_cache = True
returns_native_bytes = True
colspecs = util.update_copy(
SQLiteDialect.colspecs,
{
sqltypes.Date: _SQLite_pysqliteDate,
sqltypes.TIMESTAMP: _SQLite_pysqliteTimeStamp,
},
)
description_encoding = None
driver = "pysqlite"
@classmethod
def import_dbapi(cls):
from sqlite3 import dbapi2 as sqlite
return sqlite
@classmethod
def _is_url_file_db(cls, url):
if (url.database and url.database != ":memory:") and (
url.query.get("mode", None) != "memory"
):
return True
else:
return False
@classmethod
def get_pool_class(cls, url):
if cls._is_url_file_db(url):
return pool.QueuePool
else:
return pool.SingletonThreadPool
def _get_server_version_info(self, connection):
return self.dbapi.sqlite_version_info
_isolation_lookup = SQLiteDialect._isolation_lookup.union(
{
"AUTOCOMMIT": None,
}
)
def set_isolation_level(self, dbapi_connection, level):
if level == "AUTOCOMMIT":
dbapi_connection.isolation_level = None
else:
dbapi_connection.isolation_level = ""
return super().set_isolation_level(dbapi_connection, level)
def on_connect(self):
def regexp(a, b):
if b is None:
return None
return re.search(a, b) is not None
if util.py38 and self._get_server_version_info(None) >= (3, 9):
# sqlite must be greater than 3.8.3 for deterministic=True
# https://docs.python.org/3/library/sqlite3.html#sqlite3.Connection.create_function
# the check is more conservative since there were still issues
# with following 3.8 sqlite versions
create_func_kw = {"deterministic": True}
else:
create_func_kw = {}
def set_regexp(dbapi_connection):
dbapi_connection.create_function(
"regexp", 2, regexp, **create_func_kw
)
def floor_func(dbapi_connection):
# NOTE: floor is optionally present in sqlite 3.35+, however
# as it is normally not present we deliver floor() unconditionally
# for now.
# https://www.sqlite.org/lang_mathfunc.html
dbapi_connection.create_function(
"floor", 1, math.floor, **create_func_kw
)
fns = [set_regexp, floor_func]
def connect(conn):
for fn in fns:
fn(conn)
return connect
def create_connect_args(self, url):
if url.username or url.password or url.host or url.port:
raise exc.ArgumentError(
"Invalid SQLite URL: %s\n"
"Valid SQLite URL forms are:\n"
" sqlite:///:memory: (or, sqlite://)\n"
" sqlite:///relative/path/to/file.db\n"
" sqlite:////absolute/path/to/file.db" % (url,)
)
# theoretically, this list can be augmented, at least as far as
# parameter names accepted by sqlite3/pysqlite, using
# inspect.getfullargspec(). for the moment this seems like overkill
# as these parameters don't change very often, and as always,
# parameters passed to connect_args will always go to the
# sqlite3/pysqlite driver.
pysqlite_args = [
("uri", bool),
("timeout", float),
("isolation_level", str),
("detect_types", int),
("check_same_thread", bool),
("cached_statements", int),
]
opts = url.query
pysqlite_opts = {}
for key, type_ in pysqlite_args:
util.coerce_kw_type(opts, key, type_, dest=pysqlite_opts)
if pysqlite_opts.get("uri", False):
uri_opts = dict(opts)
# here, we are actually separating the parameters that go to
# sqlite3/pysqlite vs. those that go the SQLite URI. What if
# two names conflict? again, this seems to be not the case right
# now, and in the case that new names are added to
# either side which overlap, again the sqlite3/pysqlite parameters
# can be passed through connect_args instead of in the URL.
# If SQLite native URIs add a parameter like "timeout" that
# we already have listed here for the python driver, then we need
# to adjust for that here.
for key, type_ in pysqlite_args:
uri_opts.pop(key, None)
filename = url.database
if uri_opts:
# sorting of keys is for unit test support
filename += "?" + (
"&".join(
"%s=%s" % (key, uri_opts[key])
for key in sorted(uri_opts)
)
)
else:
filename = url.database or ":memory:"
if filename != ":memory:":
filename = os.path.abspath(filename)
pysqlite_opts.setdefault(
"check_same_thread", not self._is_url_file_db(url)
)
return ([filename], pysqlite_opts)
def is_disconnect(self, e, connection, cursor):
return isinstance(
e, self.dbapi.ProgrammingError
) and "Cannot operate on a closed database." in str(e)
dialect = SQLiteDialect_pysqlite
class _SQLiteDialect_pysqlite_numeric(SQLiteDialect_pysqlite):
"""numeric dialect for testing only
internal use only. This dialect is **NOT** supported by SQLAlchemy
and may change at any time.
"""
supports_statement_cache = True
default_paramstyle = "numeric"
driver = "pysqlite_numeric"
_first_bind = ":1"
_not_in_statement_regexp = None
def __init__(self, *arg, **kw):
kw.setdefault("paramstyle", "numeric")
super().__init__(*arg, **kw)
def create_connect_args(self, url):
arg, opts = super().create_connect_args(url)
opts["factory"] = self._fix_sqlite_issue_99953()
return arg, opts
def _fix_sqlite_issue_99953(self):
import sqlite3
first_bind = self._first_bind
if self._not_in_statement_regexp:
nis = self._not_in_statement_regexp
def _test_sql(sql):
m = nis.search(sql)
assert not m, f"Found {nis.pattern!r} in {sql!r}"
else:
def _test_sql(sql):
pass
def _numeric_param_as_dict(parameters):
if parameters:
assert isinstance(parameters, tuple)
return {
str(idx): value for idx, value in enumerate(parameters, 1)
}
else:
return ()
class SQLiteFix99953Cursor(sqlite3.Cursor):
def execute(self, sql, parameters=()):
_test_sql(sql)
if first_bind in sql:
parameters = _numeric_param_as_dict(parameters)
return super().execute(sql, parameters)
def executemany(self, sql, parameters):
_test_sql(sql)
if first_bind in sql:
parameters = [
_numeric_param_as_dict(p) for p in parameters
]
return super().executemany(sql, parameters)
class SQLiteFix99953Connection(sqlite3.Connection):
def cursor(self, factory=None):
if factory is None:
factory = SQLiteFix99953Cursor
return super().cursor(factory=factory)
def execute(self, sql, parameters=()):
_test_sql(sql)
if first_bind in sql:
parameters = _numeric_param_as_dict(parameters)
return super().execute(sql, parameters)
def executemany(self, sql, parameters):
_test_sql(sql)
if first_bind in sql:
parameters = [
_numeric_param_as_dict(p) for p in parameters
]
return super().executemany(sql, parameters)
return SQLiteFix99953Connection
class _SQLiteDialect_pysqlite_dollar(_SQLiteDialect_pysqlite_numeric):
"""numeric dialect that uses $ for testing only
internal use only. This dialect is **NOT** supported by SQLAlchemy
and may change at any time.
"""
supports_statement_cache = True
default_paramstyle = "numeric_dollar"
driver = "pysqlite_dollar"
_first_bind = "$1"
_not_in_statement_regexp = re.compile(r"[^\d]:\d+")
def __init__(self, *arg, **kw):
kw.setdefault("paramstyle", "numeric_dollar")
super().__init__(*arg, **kw)

View file

@ -0,0 +1,145 @@
Rules for Migrating TypeEngine classes to 0.6
---------------------------------------------
1. the TypeEngine classes are used for:
a. Specifying behavior which needs to occur for bind parameters
or result row columns.
b. Specifying types that are entirely specific to the database
in use and have no analogue in the sqlalchemy.types package.
c. Specifying types where there is an analogue in sqlalchemy.types,
but the database in use takes vendor-specific flags for those
types.
d. If a TypeEngine class doesn't provide any of this, it should be
*removed* from the dialect.
2. the TypeEngine classes are *no longer* used for generating DDL. Dialects
now have a TypeCompiler subclass which uses the same visit_XXX model as
other compilers.
3. the "ischema_names" and "colspecs" dictionaries are now required members on
the Dialect class.
4. The names of types within dialects are now important. If a dialect-specific type
is a subclass of an existing generic type and is only provided for bind/result behavior,
the current mixed case naming can remain, i.e. _PGNumeric for Numeric - in this case,
end users would never need to use _PGNumeric directly. However, if a dialect-specific
type is specifying a type *or* arguments that are not present generically, it should
match the real name of the type on that backend, in uppercase. E.g. postgresql.INET,
mysql.ENUM, postgresql.ARRAY.
Or follow this handy flowchart:
is the type meant to provide bind/result is the type the same name as an
behavior to a generic type (i.e. MixedCase) ---- no ---> UPPERCASE type in types.py ?
type in types.py ? | |
| no yes
yes | |
| | does your type need special
| +<--- yes --- behavior or arguments ?
| | |
| | no
name the type using | |
_MixedCase, i.e. v V
_OracleBoolean. it name the type don't make a
stays private to the dialect identically as that type, make sure the dialect's
and is invoked *only* via within the DB, base.py imports the types.py
the colspecs dict. using UPPERCASE UPPERCASE name into its namespace
| (i.e. BIT, NCHAR, INTERVAL).
| Users can import it.
| |
v v
subclass the closest is the name of this type
MixedCase type types.py, identical to an UPPERCASE
i.e. <--- no ------- name in types.py ?
class _DateTime(types.DateTime),
class DATETIME2(types.DateTime), |
class BIT(types.TypeEngine). yes
|
v
the type should
subclass the
UPPERCASE
type in types.py
(i.e. class BLOB(types.BLOB))
Example 1. pysqlite needs bind/result processing for the DateTime type in types.py,
which applies to all DateTimes and subclasses. It's named _SLDateTime and
subclasses types.DateTime.
Example 2. MS-SQL has a TIME type which takes a non-standard "precision" argument
that is rendered within DDL. So it's named TIME in the MS-SQL dialect's base.py,
and subclasses types.TIME. Users can then say mssql.TIME(precision=10).
Example 3. MS-SQL dialects also need special bind/result processing for date
values. But its DATE type doesn't render DDL differently than that of a plain
DATE, i.e. it takes no special arguments. Therefore we are just adding behavior
to types.Date, so it's named _MSDate in the MS-SQL dialect's base.py, and subclasses
types.Date.
Example 4. MySQL has a SET type; there's no analogue for this in types.py. So
MySQL names it SET in the dialect's base.py, and it subclasses types.String, since
it ultimately deals with strings.
Example 5. PostgreSQL has a DATETIME type. The DBAPIs handle dates correctly,
and no special arguments are used in PG's DDL beyond what types.py provides.
The PostgreSQL dialect therefore imports types.DATETIME into its base.py.
Ideally one should be able to specify a schema using names imported completely from a
dialect, all matching the real name on that backend:
from sqlalchemy.dialects.postgresql import base as pg
t = Table('mytable', metadata,
Column('id', pg.INTEGER, primary_key=True),
Column('name', pg.VARCHAR(300)),
Column('inetaddr', pg.INET)
)
where above, the INTEGER and VARCHAR types are ultimately from sqlalchemy.types,
but the PG dialect makes them available in its own namespace.
5. "colspecs" now is a dictionary of generic or uppercased types from sqlalchemy.types
linked to types specified in the dialect. Again, if a type in the dialect does not
specify any special behavior for bind_processor() or result_processor() and does not
indicate a special type only available in this database, it must be *removed* from the
module and from this dictionary.
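As an illustrative sketch only (the dialect type names below are hypothetical),
a dialect's colspecs dictionary might look like:

    import sqlalchemy.types as sqltypes

    # hypothetical dialect-specific implementations
    class _MyDialectNumeric(sqltypes.Numeric):
        """bind/result behavior only; stays private to the dialect."""

    class MYTIME(sqltypes.TIME):
        """vendor-specific TIME accepting extra arguments."""

    colspecs = {
        sqltypes.Numeric: _MyDialectNumeric,
        sqltypes.TIME: MYTIME,
    }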
6. "ischema_names" indicates string descriptions of types as returned from the database
linked to TypeEngine classes.
a. The string name should be matched to the most specific type possible within
sqlalchemy.types, unless there is no matching type within sqlalchemy.types in which
case it points to a dialect type. *It doesn't matter* if the dialect has its
own subclass of that type with special bind/result behavior - reflect to the types.py
UPPERCASE type as much as possible. With very few exceptions, all types
should reflect to an UPPERCASE type.
b. If the dialect contains a matching dialect-specific type that takes extra arguments
which the generic one does not, then point to the dialect-specific type. E.g.
mssql.VARCHAR takes a "collation" parameter which should be preserved.
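For example, a sketch of such a dictionary (the entries here are illustrative
only) might look like:

    import sqlalchemy.types as sqltypes
    from sqlalchemy.dialects.postgresql import INET

    ischema_names = {
        # reflect to generic UPPERCASE types wherever possible
        "varchar": sqltypes.VARCHAR,
        "integer": sqltypes.INTEGER,
        # dialect-specific type with no generic analogue
        "inet": INET,
    }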
7. DDL, or what was formerly issued by "get_col_spec()", is now handled exclusively by
a subclass of compiler.GenericTypeCompiler.
a. your TypeCompiler class will receive generic and uppercase types from
sqlalchemy.types. Do not assume the presence of dialect-specific attributes on
these types.
b. the visit_UPPERCASE methods on GenericTypeCompiler should *not* be overridden with
methods that produce a different DDL name. Uppercase types don't do any kind of
"guessing" - if visit_TIMESTAMP is called, the DDL should render as TIMESTAMP in
all cases, regardless of whether or not that type is legal on the backend database.
c. the visit_UPPERCASE methods *should* be overridden with methods that add additional
arguments and flags to those types.
d. the visit_lowercase methods are overridden to provide an interpretation of a generic
type. E.g. visit_large_binary() might be overridden to say "return self.visit_BIT(type_)".
e. visit_lowercase methods should *never* render strings directly - it should always
be via calling a visit_UPPERCASE() method.
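A rough sketch of these conventions (the compiler class and method bodies below
are illustrative, not actual dialect code):

    from sqlalchemy.sql import compiler

    class MyDialectTypeCompiler(compiler.GenericTypeCompiler):
        def visit_TIME(self, type_, **kw):
            # add a vendor-specific argument, but always render the name TIME
            if getattr(type_, "precision", None) is not None:
                return "TIME(%d)" % type_.precision
            return "TIME"

        def visit_large_binary(self, type_, **kw):
            # lowercase visit methods delegate to an UPPERCASE method
            return self.visit_BLOB(type_, **kw)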

View file

@ -0,0 +1,62 @@
# engine/__init__.py
# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
"""SQL connections, SQL execution and high-level DB-API interface.
The engine package defines the basic components used to interface
DB-API modules with higher-level statement construction,
connection-management, execution and result contexts. The primary
"entry point" class into this package is the Engine and its public
constructor ``create_engine()``.
"""
from . import events as events
from . import util as util
from .base import Connection as Connection
from .base import Engine as Engine
from .base import NestedTransaction as NestedTransaction
from .base import RootTransaction as RootTransaction
from .base import Transaction as Transaction
from .base import TwoPhaseTransaction as TwoPhaseTransaction
from .create import create_engine as create_engine
from .create import create_pool_from_url as create_pool_from_url
from .create import engine_from_config as engine_from_config
from .cursor import CursorResult as CursorResult
from .cursor import ResultProxy as ResultProxy
from .interfaces import AdaptedConnection as AdaptedConnection
from .interfaces import BindTyping as BindTyping
from .interfaces import Compiled as Compiled
from .interfaces import Connectable as Connectable
from .interfaces import ConnectArgsType as ConnectArgsType
from .interfaces import ConnectionEventsTarget as ConnectionEventsTarget
from .interfaces import CreateEnginePlugin as CreateEnginePlugin
from .interfaces import Dialect as Dialect
from .interfaces import ExceptionContext as ExceptionContext
from .interfaces import ExecutionContext as ExecutionContext
from .interfaces import TypeCompiler as TypeCompiler
from .mock import create_mock_engine as create_mock_engine
from .reflection import Inspector as Inspector
from .reflection import ObjectKind as ObjectKind
from .reflection import ObjectScope as ObjectScope
from .result import ChunkedIteratorResult as ChunkedIteratorResult
from .result import FilterResult as FilterResult
from .result import FrozenResult as FrozenResult
from .result import IteratorResult as IteratorResult
from .result import MappingResult as MappingResult
from .result import MergedResult as MergedResult
from .result import Result as Result
from .result import result_tuple as result_tuple
from .result import ScalarResult as ScalarResult
from .result import TupleResult as TupleResult
from .row import BaseRow as BaseRow
from .row import Row as Row
from .row import RowMapping as RowMapping
from .url import make_url as make_url
from .url import URL as URL
from .util import connection_memoize as connection_memoize
from ..sql import ddl as ddl

View file

@ -0,0 +1,136 @@
# engine/_py_processors.py
# Copyright (C) 2010-2024 the SQLAlchemy authors and contributors
# <see AUTHORS file>
# Copyright (C) 2010 Gaetan de Menten gdementen@gmail.com
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
"""defines generic type conversion functions, as used in bind and result
processors.
They all share one common characteristic: None is passed through unchanged.
"""
from __future__ import annotations
import datetime
from datetime import date as date_cls
from datetime import datetime as datetime_cls
from datetime import time as time_cls
from decimal import Decimal
import typing
from typing import Any
from typing import Callable
from typing import Optional
from typing import Type
from typing import TypeVar
from typing import Union
_DT = TypeVar(
"_DT", bound=Union[datetime.datetime, datetime.time, datetime.date]
)
def str_to_datetime_processor_factory(
regexp: typing.Pattern[str], type_: Callable[..., _DT]
) -> Callable[[Optional[str]], Optional[_DT]]:
rmatch = regexp.match
# Even on python2.6 datetime.strptime is both slower than this code
# and it does not support microseconds.
has_named_groups = bool(regexp.groupindex)
def process(value: Optional[str]) -> Optional[_DT]:
if value is None:
return None
else:
try:
m = rmatch(value)
except TypeError as err:
raise ValueError(
"Couldn't parse %s string '%r' "
"- value is not a string." % (type_.__name__, value)
) from err
if m is None:
raise ValueError(
"Couldn't parse %s string: "
"'%s'" % (type_.__name__, value)
)
if has_named_groups:
groups = m.groupdict(0)
return type_(
**dict(
list(
zip(
iter(groups.keys()),
list(map(int, iter(groups.values()))),
)
)
)
)
else:
return type_(*list(map(int, m.groups(0))))
return process
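# Illustrative usage sketch (the regexp and input value below are assumptions,
# not the patterns actually used by any dialect):
#
#     import re, datetime
#     date_proc = str_to_datetime_processor_factory(
#         re.compile(r"(\d+)-(\d+)-(\d+)"), datetime.date
#     )
#     date_proc("2024-02-19")  # -> datetime.date(2024, 2, 19)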
def to_decimal_processor_factory(
target_class: Type[Decimal], scale: int
) -> Callable[[Optional[float]], Optional[Decimal]]:
fstring = "%%.%df" % scale
def process(value: Optional[float]) -> Optional[Decimal]:
if value is None:
return None
else:
return target_class(fstring % value)
return process
def to_float(value: Optional[Union[int, float]]) -> Optional[float]:
if value is None:
return None
else:
return float(value)
def to_str(value: Optional[Any]) -> Optional[str]:
if value is None:
return None
else:
return str(value)
def int_to_boolean(value: Optional[int]) -> Optional[bool]:
if value is None:
return None
else:
return bool(value)
def str_to_datetime(value: Optional[str]) -> Optional[datetime.datetime]:
if value is not None:
dt_value = datetime_cls.fromisoformat(value)
else:
dt_value = None
return dt_value
def str_to_time(value: Optional[str]) -> Optional[datetime.time]:
if value is not None:
dt_value = time_cls.fromisoformat(value)
else:
dt_value = None
return dt_value
def str_to_date(value: Optional[str]) -> Optional[datetime.date]:
if value is not None:
dt_value = date_cls.fromisoformat(value)
else:
dt_value = None
return dt_value

View file

@ -0,0 +1,128 @@
# engine/_py_row.py
# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
from __future__ import annotations
import operator
import typing
from typing import Any
from typing import Callable
from typing import Dict
from typing import Iterator
from typing import List
from typing import Mapping
from typing import Optional
from typing import Tuple
from typing import Type
if typing.TYPE_CHECKING:
from .result import _KeyType
from .result import _ProcessorsType
from .result import _RawRowType
from .result import _TupleGetterType
from .result import ResultMetaData
MD_INDEX = 0 # integer index in cursor.description
class BaseRow:
__slots__ = ("_parent", "_data", "_key_to_index")
_parent: ResultMetaData
_key_to_index: Mapping[_KeyType, int]
_data: _RawRowType
def __init__(
self,
parent: ResultMetaData,
processors: Optional[_ProcessorsType],
key_to_index: Mapping[_KeyType, int],
data: _RawRowType,
):
"""Row objects are constructed by CursorResult objects."""
object.__setattr__(self, "_parent", parent)
object.__setattr__(self, "_key_to_index", key_to_index)
if processors:
object.__setattr__(
self,
"_data",
tuple(
[
proc(value) if proc else value
for proc, value in zip(processors, data)
]
),
)
else:
object.__setattr__(self, "_data", tuple(data))
def __reduce__(self) -> Tuple[Callable[..., BaseRow], Tuple[Any, ...]]:
return (
rowproxy_reconstructor,
(self.__class__, self.__getstate__()),
)
def __getstate__(self) -> Dict[str, Any]:
return {"_parent": self._parent, "_data": self._data}
def __setstate__(self, state: Dict[str, Any]) -> None:
parent = state["_parent"]
object.__setattr__(self, "_parent", parent)
object.__setattr__(self, "_data", state["_data"])
object.__setattr__(self, "_key_to_index", parent._key_to_index)
def _values_impl(self) -> List[Any]:
return list(self)
def __iter__(self) -> Iterator[Any]:
return iter(self._data)
def __len__(self) -> int:
return len(self._data)
def __hash__(self) -> int:
return hash(self._data)
def __getitem__(self, key: Any) -> Any:
return self._data[key]
def _get_by_key_impl_mapping(self, key: str) -> Any:
try:
return self._data[self._key_to_index[key]]
except KeyError:
pass
self._parent._key_not_found(key, False)
def __getattr__(self, name: str) -> Any:
try:
return self._data[self._key_to_index[name]]
except KeyError:
pass
self._parent._key_not_found(name, True)
def _to_tuple_instance(self) -> Tuple[Any, ...]:
return self._data
# This reconstructor is necessary so that pickles with the Cy extension or
# without use the same Binary format.
def rowproxy_reconstructor(
cls: Type[BaseRow], state: Dict[str, Any]
) -> BaseRow:
obj = cls.__new__(cls)
obj.__setstate__(state)
return obj
def tuplegetter(*indexes: int) -> _TupleGetterType:
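    # illustrative examples of the two forms returned below:
    #   tuplegetter(2, 3, 4) -> operator.itemgetter(slice(2, 5))
    #   tuplegetter(1, 3)    -> operator.itemgetter(1, 3)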
if len(indexes) != 1:
for i in range(1, len(indexes)):
if indexes[i - 1] != indexes[i] - 1:
return operator.itemgetter(*indexes)
# slice form is faster but returns a list if input is list
return operator.itemgetter(slice(indexes[0], indexes[-1] + 1))

View file

@ -0,0 +1,74 @@
# engine/_py_util.py
# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
from __future__ import annotations
import typing
from typing import Any
from typing import Mapping
from typing import Optional
from typing import Tuple
from .. import exc
if typing.TYPE_CHECKING:
from .interfaces import _CoreAnyExecuteParams
from .interfaces import _CoreMultiExecuteParams
from .interfaces import _DBAPIAnyExecuteParams
from .interfaces import _DBAPIMultiExecuteParams
_no_tuple: Tuple[Any, ...] = ()
def _distill_params_20(
params: Optional[_CoreAnyExecuteParams],
) -> _CoreMultiExecuteParams:
if params is None:
return _no_tuple
# Assume list is more likely than tuple
elif isinstance(params, list) or isinstance(params, tuple):
# collections_abc.MutableSequence): # avoid abc.__instancecheck__
if params and not isinstance(params[0], (tuple, Mapping)):
raise exc.ArgumentError(
"List argument must consist only of tuples or dictionaries"
)
return params
elif isinstance(params, dict) or isinstance(
# only do immutabledict or abc.__instancecheck__ for Mapping after
# we've checked for plain dictionaries and would otherwise raise
params,
Mapping,
):
return [params]
else:
raise exc.ArgumentError("mapping or list expected for parameters")
def _distill_raw_params(
params: Optional[_DBAPIAnyExecuteParams],
) -> _DBAPIMultiExecuteParams:
if params is None:
return _no_tuple
elif isinstance(params, list):
# collections_abc.MutableSequence): # avoid abc.__instancecheck__
if params and not isinstance(params[0], (tuple, Mapping)):
raise exc.ArgumentError(
"List argument must consist only of tuples or dictionaries"
)
return params
elif isinstance(params, (tuple, dict)) or isinstance(
# only do abc.__instancecheck__ for Mapping after we've checked
# for plain dictionaries and would otherwise raise
params,
Mapping,
):
# cast("Union[List[Mapping[str, Any]], Tuple[Any, ...]]", [params])
return [params] # type: ignore
else:
raise exc.ArgumentError("mapping or sequence expected for parameters")

File diff suppressed because it is too large

View file

@ -0,0 +1,81 @@
# engine/characteristics.py
# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
from __future__ import annotations
import abc
import typing
from typing import Any
from typing import ClassVar
if typing.TYPE_CHECKING:
from .interfaces import DBAPIConnection
from .interfaces import Dialect
class ConnectionCharacteristic(abc.ABC):
"""An abstract base for an object that can set, get and reset a
per-connection characteristic, typically one that gets reset when the
connection is returned to the connection pool.
transaction isolation is the canonical example, and the
``IsolationLevelCharacteristic`` implementation provides this for the
``DefaultDialect``.
The ``ConnectionCharacteristic`` class should call upon the ``Dialect`` for
the implementation of each method. The object exists strictly to serve as
a dialect visitor that can be placed into the
``DefaultDialect.connection_characteristics`` dictionary where it will take
effect for calls to :meth:`_engine.Connection.execution_options` and
related APIs.
.. versionadded:: 1.4
"""
__slots__ = ()
transactional: ClassVar[bool] = False
@abc.abstractmethod
def reset_characteristic(
self, dialect: Dialect, dbapi_conn: DBAPIConnection
) -> None:
"""Reset the characteristic on the connection to its default value."""
@abc.abstractmethod
def set_characteristic(
self, dialect: Dialect, dbapi_conn: DBAPIConnection, value: Any
) -> None:
"""set characteristic on the connection to a given value."""
@abc.abstractmethod
def get_characteristic(
self, dialect: Dialect, dbapi_conn: DBAPIConnection
) -> Any:
"""Given a DBAPI connection, get the current value of the
characteristic.
"""
class IsolationLevelCharacteristic(ConnectionCharacteristic):
transactional: ClassVar[bool] = True
def reset_characteristic(
self, dialect: Dialect, dbapi_conn: DBAPIConnection
) -> None:
dialect.reset_isolation_level(dbapi_conn)
def set_characteristic(
self, dialect: Dialect, dbapi_conn: DBAPIConnection, value: Any
) -> None:
dialect._assert_and_set_isolation_level(dbapi_conn, value)
def get_characteristic(
self, dialect: Dialect, dbapi_conn: DBAPIConnection
) -> Any:
return dialect.get_isolation_level(dbapi_conn)
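# An illustrative sketch, not part of SQLAlchemy itself: a hypothetical
# characteristic following the pattern above. The "my_setting" name and the
# dialect methods referenced are assumptions for illustration only.
#
#     class MySettingCharacteristic(ConnectionCharacteristic):
#         transactional: ClassVar[bool] = False
#
#         def reset_characteristic(self, dialect, dbapi_conn):
#             dialect.reset_my_setting(dbapi_conn)
#
#         def set_characteristic(self, dialect, dbapi_conn, value):
#             dialect.set_my_setting(dbapi_conn, value)
#
#         def get_characteristic(self, dialect, dbapi_conn):
#             return dialect.get_my_setting(dbapi_conn)
#
#     # the dialect would then place an instance of this class in its
#     # connection_characteristics mapping under the key "my_setting".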

View file

@ -0,0 +1,864 @@
# engine/create.py
# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
from __future__ import annotations
import inspect
import typing
from typing import Any
from typing import Callable
from typing import cast
from typing import Dict
from typing import List
from typing import Optional
from typing import overload
from typing import Type
from typing import Union
from . import base
from . import url as _url
from .interfaces import DBAPIConnection
from .mock import create_mock_engine
from .. import event
from .. import exc
from .. import util
from ..pool import _AdhocProxiedConnection
from ..pool import ConnectionPoolEntry
from ..sql import compiler
from ..util import immutabledict
if typing.TYPE_CHECKING:
from .base import Engine
from .interfaces import _ExecuteOptions
from .interfaces import _ParamStyle
from .interfaces import IsolationLevel
from .url import URL
from ..log import _EchoFlagType
from ..pool import _CreatorFnType
from ..pool import _CreatorWRecFnType
from ..pool import _ResetStyleArgType
from ..pool import Pool
from ..util.typing import Literal
@overload
def create_engine(
url: Union[str, URL],
*,
connect_args: Dict[Any, Any] = ...,
convert_unicode: bool = ...,
creator: Union[_CreatorFnType, _CreatorWRecFnType] = ...,
echo: _EchoFlagType = ...,
echo_pool: _EchoFlagType = ...,
enable_from_linting: bool = ...,
execution_options: _ExecuteOptions = ...,
future: Literal[True],
hide_parameters: bool = ...,
implicit_returning: Literal[True] = ...,
insertmanyvalues_page_size: int = ...,
isolation_level: IsolationLevel = ...,
json_deserializer: Callable[..., Any] = ...,
json_serializer: Callable[..., Any] = ...,
label_length: Optional[int] = ...,
logging_name: str = ...,
max_identifier_length: Optional[int] = ...,
max_overflow: int = ...,
module: Optional[Any] = ...,
paramstyle: Optional[_ParamStyle] = ...,
pool: Optional[Pool] = ...,
poolclass: Optional[Type[Pool]] = ...,
pool_logging_name: str = ...,
pool_pre_ping: bool = ...,
pool_size: int = ...,
pool_recycle: int = ...,
pool_reset_on_return: Optional[_ResetStyleArgType] = ...,
pool_timeout: float = ...,
pool_use_lifo: bool = ...,
plugins: List[str] = ...,
query_cache_size: int = ...,
use_insertmanyvalues: bool = ...,
**kwargs: Any,
) -> Engine: ...
@overload
def create_engine(url: Union[str, URL], **kwargs: Any) -> Engine: ...
@util.deprecated_params(
strategy=(
"1.4",
"The :paramref:`_sa.create_engine.strategy` keyword is deprecated, "
"and the only argument accepted is 'mock'; please use "
":func:`.create_mock_engine` going forward. For general "
"customization of create_engine which may have been accomplished "
"using strategies, see :class:`.CreateEnginePlugin`.",
),
empty_in_strategy=(
"1.4",
"The :paramref:`_sa.create_engine.empty_in_strategy` keyword is "
"deprecated, and no longer has any effect. All IN expressions "
"are now rendered using "
'the "expanding parameter" strategy which renders a set of bound'
'expressions, or an "empty set" SELECT, at statement execution'
"time.",
),
implicit_returning=(
"2.0",
"The :paramref:`_sa.create_engine.implicit_returning` parameter "
"is deprecated and will be removed in a future release. ",
),
)
def create_engine(url: Union[str, _url.URL], **kwargs: Any) -> Engine:
"""Create a new :class:`_engine.Engine` instance.
The standard calling form is to send the :ref:`URL <database_urls>` as the
first positional argument, usually a string
that indicates database dialect and connection arguments::
engine = create_engine("postgresql+psycopg2://scott:tiger@localhost/test")
.. note::
Please review :ref:`database_urls` for general guidelines in composing
URL strings. In particular, special characters, such as those often
part of passwords, must be URL encoded to be properly parsed.
Additional keyword arguments may then follow it which
establish various options on the resulting :class:`_engine.Engine`
and its underlying :class:`.Dialect` and :class:`_pool.Pool`
constructs::
engine = create_engine("mysql+mysqldb://scott:tiger@hostname/dbname",
pool_recycle=3600, echo=True)
The string form of the URL is
``dialect[+driver]://user:password@host/dbname[?key=value..]``, where
``dialect`` is a database name such as ``mysql``, ``oracle``,
``postgresql``, etc., and ``driver`` the name of a DBAPI, such as
``psycopg2``, ``pyodbc``, ``cx_oracle``, etc. Alternatively,
the URL can be an instance of :class:`~sqlalchemy.engine.url.URL`.
``**kwargs`` takes a wide variety of options which are routed
towards their appropriate components. Arguments may be specific to
the :class:`_engine.Engine`, the underlying :class:`.Dialect`,
as well as the
:class:`_pool.Pool`. Specific dialects also accept keyword arguments that
are unique to that dialect. Here, we describe the parameters
that are common to most :func:`_sa.create_engine()` usage.
Once established, the newly resulting :class:`_engine.Engine` will
request a connection from the underlying :class:`_pool.Pool` once
:meth:`_engine.Engine.connect` is called, or a method which depends on it
such as :meth:`_engine.Engine.execute` is invoked. The
:class:`_pool.Pool` in turn
will establish the first actual DBAPI connection when this request
is received. The :func:`_sa.create_engine` call itself does **not**
establish any actual DBAPI connections directly.
.. seealso::
:doc:`/core/engines`
:doc:`/dialects/index`
:ref:`connections_toplevel`
:param connect_args: a dictionary of options which will be
passed directly to the DBAPI's ``connect()`` method as
additional keyword arguments. See the example
at :ref:`custom_dbapi_args`.
:param creator: a callable which returns a DBAPI connection.
This creation function will be passed to the underlying
connection pool and will be used to create all new database
connections. Usage of this function causes connection
parameters specified in the URL argument to be bypassed.
This hook is not as flexible as the newer
:meth:`_events.DialectEvents.do_connect` hook which allows complete
control over how a connection is made to the database, given the full
set of URL arguments and state beforehand.
.. seealso::
:meth:`_events.DialectEvents.do_connect` - event hook that allows
full control over DBAPI connection mechanics.
:ref:`custom_dbapi_args`
:param echo=False: if True, the Engine will log all statements
as well as a ``repr()`` of their parameter lists to the default log
handler, which defaults to ``sys.stdout`` for output. If set to the
string ``"debug"``, result rows will be printed to the standard output
as well. The ``echo`` attribute of ``Engine`` can be modified at any
time to turn logging on and off; direct control of logging is also
available using the standard Python ``logging`` module.
.. seealso::
:ref:`dbengine_logging` - further detail on how to configure
logging.
:param echo_pool=False: if True, the connection pool will log
informational output such as when connections are invalidated
as well as when connections are recycled to the default log handler,
which defaults to ``sys.stdout`` for output. If set to the string
``"debug"``, the logging will include pool checkouts and checkins.
Direct control of logging is also available using the standard Python
``logging`` module.
.. seealso::
:ref:`dbengine_logging` - further detail on how to configure
logging.
:param empty_in_strategy: No longer used; SQLAlchemy now uses
"empty set" behavior for IN in all cases.
:param enable_from_linting: defaults to True. Will emit a warning
if a given SELECT statement is found to have un-linked FROM elements
which would cause a cartesian product.
.. versionadded:: 1.4
.. seealso::
:ref:`change_4737`
:param execution_options: Dictionary execution options which will
be applied to all connections. See
:meth:`~sqlalchemy.engine.Connection.execution_options`
:param future: Use the 2.0 style :class:`_engine.Engine` and
:class:`_engine.Connection` API.
As of SQLAlchemy 2.0, this parameter is present for backwards
compatibility only and must remain at its default value of ``True``.
The :paramref:`_sa.create_engine.future` parameter will be
deprecated in a subsequent 2.x release and eventually removed.
.. versionadded:: 1.4
.. versionchanged:: 2.0 All :class:`_engine.Engine` objects are
"future" style engines and there is no longer a ``future=False``
mode of operation.
.. seealso::
:ref:`migration_20_toplevel`
:param hide_parameters: Boolean, when set to True, SQL statement parameters
will not be displayed in INFO logging nor will they be formatted into
the string representation of :class:`.StatementError` objects.
.. versionadded:: 1.3.8
.. seealso::
:ref:`dbengine_logging` - further detail on how to configure
logging.
:param implicit_returning=True: Legacy parameter that may only be set
to True. In SQLAlchemy 2.0, this parameter does nothing. In order to
disable "implicit returning" for statements invoked by the ORM,
configure this on a per-table basis using the
:paramref:`.Table.implicit_returning` parameter.
:param insertmanyvalues_page_size: number of rows to format into an
INSERT statement when the statement uses "insertmanyvalues" mode, which is
a paged form of bulk insert that is used for many backends when using
:term:`executemany` execution typically in conjunction with RETURNING.
Defaults to 1000, but may also be subject to dialect-specific limiting
factors which may override this value on a per-statement basis.
.. versionadded:: 2.0
.. seealso::
:ref:`engine_insertmanyvalues`
:ref:`engine_insertmanyvalues_page_size`
:paramref:`_engine.Connection.execution_options.insertmanyvalues_page_size`
:param isolation_level: optional string name of an isolation level
which will be set on all new connections unconditionally.
Isolation levels are typically some subset of the string names
``"SERIALIZABLE"``, ``"REPEATABLE READ"``,
``"READ COMMITTED"``, ``"READ UNCOMMITTED"`` and ``"AUTOCOMMIT"``
based on backend.
The :paramref:`_sa.create_engine.isolation_level` parameter is
in contrast to the
:paramref:`.Connection.execution_options.isolation_level`
execution option, which may be set on an individual
:class:`.Connection`, as well as the same parameter passed to
:meth:`.Engine.execution_options`, where it may be used to create
multiple engines with different isolation levels that share a common
connection pool and dialect.
.. versionchanged:: 2.0 The
:paramref:`_sa.create_engine.isolation_level`
parameter has been generalized to work on all dialects which support
the concept of isolation level, and is provided as a more succinct,
up front configuration switch in contrast to the execution option
which is more of an ad-hoc programmatic option.
.. seealso::
:ref:`dbapi_autocommit`
:param json_deserializer: for dialects that support the
:class:`_types.JSON`
datatype, this is a Python callable that will convert a JSON string
to a Python object. By default, the Python ``json.loads`` function is
used.
.. versionchanged:: 1.3.7 The SQLite dialect renamed this from
``_json_deserializer``.
:param json_serializer: for dialects that support the :class:`_types.JSON`
datatype, this is a Python callable that will render a given object
as JSON. By default, the Python ``json.dumps`` function is used.
.. versionchanged:: 1.3.7 The SQLite dialect renamed this from
``_json_serializer``.
:param label_length=None: optional integer value which limits
the size of dynamically generated column labels to that many
characters. If less than 6, labels are generated as
"_(counter)". If ``None``, the value of
``dialect.max_identifier_length``, which may be affected via the
:paramref:`_sa.create_engine.max_identifier_length` parameter,
is used instead. The value of
:paramref:`_sa.create_engine.label_length`
may not be larger than that of
:paramref:`_sa.create_engine.max_identifier_length`.
.. seealso::
:paramref:`_sa.create_engine.max_identifier_length`
:param logging_name: String identifier which will be used within
the "name" field of logging records generated within the
"sqlalchemy.engine" logger. Defaults to a hexstring of the
object's id.
.. seealso::
:ref:`dbengine_logging` - further detail on how to configure
logging.
:paramref:`_engine.Connection.execution_options.logging_token`
:param max_identifier_length: integer; override the max_identifier_length
determined by the dialect. if ``None`` or zero, has no effect. This
is the database's configured maximum number of characters that may be
used in a SQL identifier such as a table name, column name, or label
name. All dialects determine this value automatically, however in the
case of a new database version for which this value has changed but
SQLAlchemy's dialect has not been adjusted, the value may be passed
here.
.. versionadded:: 1.3.9
.. seealso::
:paramref:`_sa.create_engine.label_length`
:param max_overflow=10: the number of connections to allow in
connection pool "overflow", that is connections that can be
opened above and beyond the pool_size setting, which defaults
to five. This is only used with :class:`~sqlalchemy.pool.QueuePool`.
:param module=None: reference to a Python module object (the module
itself, not its string name). Specifies an alternate DBAPI module to
be used by the engine's dialect. Each sub-dialect references a
specific DBAPI which will be imported before first connect. This
parameter causes the import to be bypassed, and the given module to
be used instead. Can be used for testing of DBAPIs as well as to
inject "mock" DBAPI implementations into the :class:`_engine.Engine`.
:param paramstyle=None: The `paramstyle <https://legacy.python.org/dev/peps/pep-0249/#paramstyle>`_
to use when rendering bound parameters. This style defaults to the
one recommended by the DBAPI itself, which is retrieved from the
``.paramstyle`` attribute of the DBAPI. However, most DBAPIs accept
more than one paramstyle, and in particular it may be desirable
to change a "named" paramstyle into a "positional" one, or vice versa.
When this attribute is passed, it should be one of the values
``"qmark"``, ``"numeric"``, ``"named"``, ``"format"`` or
``"pyformat"``, and should correspond to a parameter style known
to be supported by the DBAPI in use.
:param pool=None: an already-constructed instance of
:class:`~sqlalchemy.pool.Pool`, such as a
:class:`~sqlalchemy.pool.QueuePool` instance. If non-None, this
pool will be used directly as the underlying connection pool
for the engine, bypassing whatever connection parameters are
present in the URL argument. For information on constructing
connection pools manually, see :ref:`pooling_toplevel`.
:param poolclass=None: a :class:`~sqlalchemy.pool.Pool`
subclass, which will be used to create a connection pool
instance using the connection parameters given in the URL. Note
this differs from ``pool`` in that you don't actually
instantiate the pool in this case, you just indicate what type
of pool to be used.
:param pool_logging_name: String identifier which will be used within
the "name" field of logging records generated within the
"sqlalchemy.pool" logger. Defaults to a hexstring of the object's
id.
.. seealso::
:ref:`dbengine_logging` - further detail on how to configure
logging.
:param pool_pre_ping: boolean, if True will enable the connection pool
"pre-ping" feature that tests connections for liveness upon
each checkout.
.. versionadded:: 1.2
.. seealso::
:ref:`pool_disconnects_pessimistic`
:param pool_size=5: the number of connections to keep open
inside the connection pool. This is used with
:class:`~sqlalchemy.pool.QueuePool` as
well as :class:`~sqlalchemy.pool.SingletonThreadPool`. With
:class:`~sqlalchemy.pool.QueuePool`, a ``pool_size`` setting
of 0 indicates no limit; to disable pooling, set ``poolclass`` to
:class:`~sqlalchemy.pool.NullPool` instead.
:param pool_recycle=-1: this setting causes the pool to recycle
connections after the given number of seconds has passed. It
defaults to -1, or no timeout. For example, setting to 3600
means connections will be recycled after one hour. Note that
MySQL in particular will disconnect automatically if no
activity is detected on a connection for eight hours (although
this is configurable with the MySQLDB connection itself and the
server configuration as well).
.. seealso::
:ref:`pool_setting_recycle`
:param pool_reset_on_return='rollback': set the
:paramref:`_pool.Pool.reset_on_return` parameter of the underlying
:class:`_pool.Pool` object, which can be set to the values
``"rollback"``, ``"commit"``, or ``None``.
.. seealso::
:ref:`pool_reset_on_return`
:param pool_timeout=30: number of seconds to wait before giving
up on getting a connection from the pool. This is only used
with :class:`~sqlalchemy.pool.QueuePool`. This can be a float but is
subject to the limitations of Python time functions which may not be
reliable in the tens of milliseconds.
.. note: don't use 30.0 above, it seems to break with the :param tag
:param pool_use_lifo=False: use LIFO (last-in-first-out) when retrieving
connections from :class:`.QueuePool` instead of FIFO
(first-in-first-out). Using LIFO, a server-side timeout scheme can
reduce the number of connections used during non-peak periods of
use. When planning for server-side timeouts, ensure that a recycle or
pre-ping strategy is in use to gracefully handle stale connections.
.. versionadded:: 1.3
.. seealso::
:ref:`pool_use_lifo`
:ref:`pool_disconnects`
:param plugins: string list of plugin names to load. See
:class:`.CreateEnginePlugin` for background.
.. versionadded:: 1.2.3
:param query_cache_size: size of the cache used to cache the SQL string
form of queries. Set to zero to disable caching.
The cache is pruned of its least recently used items when its size reaches
N * 1.5. Defaults to 500, meaning the cache will always store at least
500 SQL statements when filled, and will grow up to 750 items at which
point it is pruned back down to 500 by removing the 250 least recently
used items.
Caching is accomplished on a per-statement basis by generating a
cache key that represents the statement's structure, then generating
string SQL for the current dialect only if that key is not present
in the cache. All statements support caching; however, some features
such as an INSERT with a large set of parameters will intentionally
bypass the cache. SQL logging will indicate statistics for each
statement, whether or not it was pulled from the cache.
.. note:: some ORM functions related to unit-of-work persistence as well
as some attribute loading strategies will make use of individual
per-mapper caches outside of the main cache.
.. seealso::
:ref:`sql_caching`
.. versionadded:: 1.4
:param use_insertmanyvalues: True by default, use the "insertmanyvalues"
execution style for INSERT..RETURNING statements by default.
.. versionadded:: 2.0
.. seealso::
:ref:`engine_insertmanyvalues`
""" # noqa
if "strategy" in kwargs:
strat = kwargs.pop("strategy")
if strat == "mock":
# this case is deprecated
return create_mock_engine(url, **kwargs) # type: ignore
else:
raise exc.ArgumentError("unknown strategy: %r" % strat)
kwargs.pop("empty_in_strategy", None)
# create url.URL object
u = _url.make_url(url)
u, plugins, kwargs = u._instantiate_plugins(kwargs)
entrypoint = u._get_entrypoint()
_is_async = kwargs.pop("_is_async", False)
if _is_async:
dialect_cls = entrypoint.get_async_dialect_cls(u)
else:
dialect_cls = entrypoint.get_dialect_cls(u)
if kwargs.pop("_coerce_config", False):
def pop_kwarg(key: str, default: Optional[Any] = None) -> Any:
value = kwargs.pop(key, default)
if key in dialect_cls.engine_config_types:
value = dialect_cls.engine_config_types[key](value)
return value
else:
pop_kwarg = kwargs.pop # type: ignore
dialect_args = {}
# consume dialect arguments from kwargs
for k in util.get_cls_kwargs(dialect_cls):
if k in kwargs:
dialect_args[k] = pop_kwarg(k)
dbapi = kwargs.pop("module", None)
if dbapi is None:
dbapi_args = {}
if "import_dbapi" in dialect_cls.__dict__:
dbapi_meth = dialect_cls.import_dbapi
elif hasattr(dialect_cls, "dbapi") and inspect.ismethod(
dialect_cls.dbapi
):
util.warn_deprecated(
"The dbapi() classmethod on dialect classes has been "
"renamed to import_dbapi(). Implement an import_dbapi() "
f"classmethod directly on class {dialect_cls} to remove this "
"warning; the old .dbapi() classmethod may be maintained for "
"backwards compatibility.",
"2.0",
)
dbapi_meth = dialect_cls.dbapi
else:
dbapi_meth = dialect_cls.import_dbapi
for k in util.get_func_kwargs(dbapi_meth):
if k in kwargs:
dbapi_args[k] = pop_kwarg(k)
dbapi = dbapi_meth(**dbapi_args)
dialect_args["dbapi"] = dbapi
dialect_args.setdefault("compiler_linting", compiler.NO_LINTING)
enable_from_linting = kwargs.pop("enable_from_linting", True)
if enable_from_linting:
dialect_args["compiler_linting"] ^= compiler.COLLECT_CARTESIAN_PRODUCTS
for plugin in plugins:
plugin.handle_dialect_kwargs(dialect_cls, dialect_args)
# create dialect
dialect = dialect_cls(**dialect_args)
# assemble connection arguments
(cargs_tup, cparams) = dialect.create_connect_args(u)
cparams.update(pop_kwarg("connect_args", {}))
if "async_fallback" in cparams and util.asbool(cparams["async_fallback"]):
util.warn_deprecated(
"The async_fallback dialect argument is deprecated and will be "
"removed in SQLAlchemy 2.1.",
"2.0",
)
cargs = list(cargs_tup) # allow mutability
# look for existing pool or create
pool = pop_kwarg("pool", None)
if pool is None:
def connect(
connection_record: Optional[ConnectionPoolEntry] = None,
) -> DBAPIConnection:
if dialect._has_events:
for fn in dialect.dispatch.do_connect:
connection = cast(
DBAPIConnection,
fn(dialect, connection_record, cargs, cparams),
)
if connection is not None:
return connection
return dialect.connect(*cargs, **cparams)
creator = pop_kwarg("creator", connect)
poolclass = pop_kwarg("poolclass", None)
if poolclass is None:
poolclass = dialect.get_dialect_pool_class(u)
pool_args = {"dialect": dialect}
# consume pool arguments from kwargs, translating a few of
# the arguments
for k in util.get_cls_kwargs(poolclass):
tk = _pool_translate_kwargs.get(k, k)
if tk in kwargs:
pool_args[k] = pop_kwarg(tk)
for plugin in plugins:
plugin.handle_pool_kwargs(poolclass, pool_args)
pool = poolclass(creator, **pool_args)
else:
pool._dialect = dialect
# create engine.
if not pop_kwarg("future", True):
raise exc.ArgumentError(
"The 'future' parameter passed to "
"create_engine() may only be set to True."
)
engineclass = base.Engine
engine_args = {}
for k in util.get_cls_kwargs(engineclass):
if k in kwargs:
engine_args[k] = pop_kwarg(k)
# internal flags used by the test suite for instrumenting / proxying
# engines with mocks etc.
_initialize = kwargs.pop("_initialize", True)
# all kwargs should be consumed
if kwargs:
raise TypeError(
"Invalid argument(s) %s sent to create_engine(), "
"using configuration %s/%s/%s. Please check that the "
"keyword arguments are appropriate for this combination "
"of components."
% (
",".join("'%s'" % k for k in kwargs),
dialect.__class__.__name__,
pool.__class__.__name__,
engineclass.__name__,
)
)
engine = engineclass(pool, dialect, u, **engine_args)
if _initialize:
do_on_connect = dialect.on_connect_url(u)
if do_on_connect:
def on_connect(
dbapi_connection: DBAPIConnection,
connection_record: ConnectionPoolEntry,
) -> None:
assert do_on_connect is not None
do_on_connect(dbapi_connection)
event.listen(pool, "connect", on_connect)
builtin_on_connect = dialect._builtin_onconnect()
if builtin_on_connect:
event.listen(pool, "connect", builtin_on_connect)
def first_connect(
dbapi_connection: DBAPIConnection,
connection_record: ConnectionPoolEntry,
) -> None:
c = base.Connection(
engine,
connection=_AdhocProxiedConnection(
dbapi_connection, connection_record
),
_has_events=False,
# reconnecting will be a reentrant condition, so if the
# connection goes away, Connection is then closed
_allow_revalidate=False,
# dont trigger the autobegin sequence
# within the up front dialect checks
_allow_autobegin=False,
)
c._execution_options = util.EMPTY_DICT
try:
dialect.initialize(c)
finally:
# note that "invalidated" and "closed" are mutually
# exclusive in 1.4 Connection.
if not c.invalidated and not c.closed:
# transaction is rolled back otherwise, tested by
# test/dialect/postgresql/test_dialect.py
# ::MiscBackendTest::test_initial_transaction_state
dialect.do_rollback(c.connection)
# previously, the "first_connect" event was used here, which was then
# scaled back if the "on_connect" handler were present. now,
# since "on_connect" is virtually always present, just use
# "connect" event with once_unless_exception in all cases so that
# the connection event flow is consistent in all cases.
event.listen(
pool, "connect", first_connect, _once_unless_exception=True
)
dialect_cls.engine_created(engine)
if entrypoint is not dialect_cls:
entrypoint.engine_created(engine)
for plugin in plugins:
plugin.engine_created(engine)
return engine
def engine_from_config(
configuration: Dict[str, Any], prefix: str = "sqlalchemy.", **kwargs: Any
) -> Engine:
"""Create a new Engine instance using a configuration dictionary.
The dictionary is typically produced from a config file.
The keys of interest to ``engine_from_config()`` should be prefixed, e.g.
``sqlalchemy.url``, ``sqlalchemy.echo``, etc. The 'prefix' argument
indicates the prefix to be searched for. Each matching key (after the
prefix is stripped) is treated as though it were the corresponding keyword
argument to a :func:`_sa.create_engine` call.
The only required key is (assuming the default prefix) ``sqlalchemy.url``,
which provides the :ref:`database URL <database_urls>`.
A select set of keyword arguments will be "coerced" to their
expected type based on string values. The set of arguments
is extensible per-dialect using the ``engine_config_types`` accessor.
:param configuration: A dictionary (typically produced from a config file,
but this is not a requirement). Items whose keys start with the value
of 'prefix' will have that prefix stripped, and will then be passed to
:func:`_sa.create_engine`.
:param prefix: Prefix to match and then strip from keys
in 'configuration'.
:param kwargs: Each keyword argument to ``engine_from_config()`` itself
overrides the corresponding item taken from the 'configuration'
dictionary. Keyword arguments should *not* be prefixed.
"""
options = {
key[len(prefix) :]: configuration[key]
for key in configuration
if key.startswith(prefix)
}
options["_coerce_config"] = True
options.update(kwargs)
url = options.pop("url")
return create_engine(url, **options)
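# Illustrative usage sketch for engine_from_config(); the configuration
# values shown here are placeholders only:
#
#     config = {
#         "sqlalchemy.url": "sqlite:///example.db",
#         "sqlalchemy.echo": "true",
#         "sqlalchemy.pool_recycle": "3600",
#     }
#     engine = engine_from_config(config, prefix="sqlalchemy.")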
@overload
def create_pool_from_url(
url: Union[str, URL],
*,
poolclass: Optional[Type[Pool]] = ...,
logging_name: str = ...,
pre_ping: bool = ...,
size: int = ...,
recycle: int = ...,
reset_on_return: Optional[_ResetStyleArgType] = ...,
timeout: float = ...,
use_lifo: bool = ...,
**kwargs: Any,
) -> Pool: ...
@overload
def create_pool_from_url(url: Union[str, URL], **kwargs: Any) -> Pool: ...
def create_pool_from_url(url: Union[str, URL], **kwargs: Any) -> Pool:
"""Create a pool instance from the given url.
If ``poolclass`` is not provided the pool class used
is selected using the dialect specified in the URL.
The arguments passed to :func:`_sa.create_pool_from_url` are
identical to the pool argument passed to the :func:`_sa.create_engine`
function.
.. versionadded:: 2.0.10
"""
for key in _pool_translate_kwargs:
if key in kwargs:
kwargs[_pool_translate_kwargs[key]] = kwargs.pop(key)
engine = create_engine(url, **kwargs, _initialize=False)
return engine.pool
_pool_translate_kwargs = immutabledict(
{
"logging_name": "pool_logging_name",
"echo": "echo_pool",
"timeout": "pool_timeout",
"recycle": "pool_recycle",
"events": "pool_events", # deprecated
"reset_on_return": "pool_reset_on_return",
"pre_ping": "pool_pre_ping",
"use_lifo": "pool_use_lifo",
}
)

File diff suppressed because it is too large

File diff suppressed because it is too large


@ -0,0 +1,951 @@
# engine/events.py
# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
from __future__ import annotations
import typing
from typing import Any
from typing import Dict
from typing import Optional
from typing import Tuple
from typing import Type
from typing import Union
from .base import Connection
from .base import Engine
from .interfaces import ConnectionEventsTarget
from .interfaces import DBAPIConnection
from .interfaces import DBAPICursor
from .interfaces import Dialect
from .. import event
from .. import exc
from ..util.typing import Literal
if typing.TYPE_CHECKING:
from .interfaces import _CoreMultiExecuteParams
from .interfaces import _CoreSingleExecuteParams
from .interfaces import _DBAPIAnyExecuteParams
from .interfaces import _DBAPIMultiExecuteParams
from .interfaces import _DBAPISingleExecuteParams
from .interfaces import _ExecuteOptions
from .interfaces import ExceptionContext
from .interfaces import ExecutionContext
from .result import Result
from ..pool import ConnectionPoolEntry
from ..sql import Executable
from ..sql.elements import BindParameter
class ConnectionEvents(event.Events[ConnectionEventsTarget]):
"""Available events for
:class:`_engine.Connection` and :class:`_engine.Engine`.
The methods here define the name of an event as well as the names of
members that are passed to listener functions.
An event listener can be associated with any
:class:`_engine.Connection` or :class:`_engine.Engine`
class or instance, such as an :class:`_engine.Engine`, e.g.::
from sqlalchemy import event, create_engine
def before_cursor_execute(conn, cursor, statement, parameters, context,
executemany):
log.info("Received statement: %s", statement)
engine = create_engine('postgresql+psycopg2://scott:tiger@localhost/test')
event.listen(engine, "before_cursor_execute", before_cursor_execute)
or with a specific :class:`_engine.Connection`::
with engine.begin() as conn:
@event.listens_for(conn, 'before_cursor_execute')
def before_cursor_execute(conn, cursor, statement, parameters,
context, executemany):
log.info("Received statement: %s", statement)
When the methods are called with a `statement` parameter, such as in
:meth:`.after_cursor_execute` or :meth:`.before_cursor_execute`,
the statement is the exact SQL string that was prepared for transmission
to the DBAPI ``cursor`` in the connection's :class:`.Dialect`.
The :meth:`.before_execute` and :meth:`.before_cursor_execute`
events can also be established with the ``retval=True`` flag, which
allows modification of the statement and parameters to be sent
to the database. The :meth:`.before_cursor_execute` event is
particularly useful here to add ad-hoc string transformations, such
as comments, to all executions::
from sqlalchemy.engine import Engine
from sqlalchemy import event
@event.listens_for(Engine, "before_cursor_execute", retval=True)
def comment_sql_calls(conn, cursor, statement, parameters,
context, executemany):
statement = statement + " -- some comment"
return statement, parameters
.. note:: :class:`_events.ConnectionEvents` can be established on any
combination of :class:`_engine.Engine`, :class:`_engine.Connection`,
as well
as instances of each of those classes. Events across all
four scopes will fire off for a given instance of
:class:`_engine.Connection`. However, for performance reasons, the
:class:`_engine.Connection` object determines at instantiation time
whether or not its parent :class:`_engine.Engine` has event listeners
established. Event listeners added to the :class:`_engine.Engine`
class or to an instance of :class:`_engine.Engine`
*after* the instantiation
of a dependent :class:`_engine.Connection` instance will usually
*not* be available on that :class:`_engine.Connection` instance.
The newly
added listeners will instead take effect for
:class:`_engine.Connection`
instances created subsequent to those event listeners being
established on the parent :class:`_engine.Engine` class or instance.
:param retval=False: Applies to the :meth:`.before_execute` and
:meth:`.before_cursor_execute` events only. When True, the
user-defined event function must have a return value, which
is a tuple of parameters that replace the given statement
and parameters. See those methods for a description of
specific return arguments.
""" # noqa
_target_class_doc = "SomeEngine"
_dispatch_target = ConnectionEventsTarget
@classmethod
def _accept_with(
cls,
target: Union[ConnectionEventsTarget, Type[ConnectionEventsTarget]],
identifier: str,
) -> Optional[Union[ConnectionEventsTarget, Type[ConnectionEventsTarget]]]:
default_dispatch = super()._accept_with(target, identifier)
if default_dispatch is None and hasattr(
target, "_no_async_engine_events"
):
target._no_async_engine_events()
return default_dispatch
@classmethod
def _listen(
cls,
event_key: event._EventKey[ConnectionEventsTarget],
*,
retval: bool = False,
**kw: Any,
) -> None:
target, identifier, fn = (
event_key.dispatch_target,
event_key.identifier,
event_key._listen_fn,
)
target._has_events = True
if not retval:
if identifier == "before_execute":
orig_fn = fn
def wrap_before_execute( # type: ignore
conn, clauseelement, multiparams, params, execution_options
):
orig_fn(
conn,
clauseelement,
multiparams,
params,
execution_options,
)
return clauseelement, multiparams, params
fn = wrap_before_execute
elif identifier == "before_cursor_execute":
orig_fn = fn
def wrap_before_cursor_execute( # type: ignore
conn, cursor, statement, parameters, context, executemany
):
orig_fn(
conn,
cursor,
statement,
parameters,
context,
executemany,
)
return statement, parameters
fn = wrap_before_cursor_execute
elif retval and identifier not in (
"before_execute",
"before_cursor_execute",
):
raise exc.ArgumentError(
"Only the 'before_execute', "
"'before_cursor_execute' and 'handle_error' engine "
"event listeners accept the 'retval=True' "
"argument."
)
event_key.with_wrapper(fn).base_listen()
@event._legacy_signature(
"1.4",
["conn", "clauseelement", "multiparams", "params"],
lambda conn, clauseelement, multiparams, params, execution_options: (
conn,
clauseelement,
multiparams,
params,
),
)
def before_execute(
self,
conn: Connection,
clauseelement: Executable,
multiparams: _CoreMultiExecuteParams,
params: _CoreSingleExecuteParams,
execution_options: _ExecuteOptions,
) -> Optional[
Tuple[Executable, _CoreMultiExecuteParams, _CoreSingleExecuteParams]
]:
"""Intercept high level execute() events, receiving uncompiled
SQL constructs and other objects prior to rendering into SQL.
This event is good for debugging SQL compilation issues as well
as early manipulation of the parameters being sent to the database,
as the parameter lists will be in a consistent format here.
This event can be optionally established with the ``retval=True``
flag. The ``clauseelement``, ``multiparams``, and ``params``
arguments should be returned as a three-tuple in this case::
@event.listens_for(Engine, "before_execute", retval=True)
def before_execute(conn, clauseelement, multiparams, params):
# do something with clauseelement, multiparams, params
return clauseelement, multiparams, params
:param conn: :class:`_engine.Connection` object
:param clauseelement: SQL expression construct, :class:`.Compiled`
instance, or string statement passed to
:meth:`_engine.Connection.execute`.
:param multiparams: Multiple parameter sets, a list of dictionaries.
:param params: Single parameter set, a single dictionary.
:param execution_options: dictionary of execution
options passed along with the statement, if any. This is a merge
of all options that will be used, including those of the statement,
the connection, and those passed in to the method itself for
the 2.0 style of execution.
.. versionadded:: 1.4
.. seealso::
:meth:`.before_cursor_execute`
"""
@event._legacy_signature(
"1.4",
["conn", "clauseelement", "multiparams", "params", "result"],
lambda conn, clauseelement, multiparams, params, execution_options, result: ( # noqa
conn,
clauseelement,
multiparams,
params,
result,
),
)
def after_execute(
self,
conn: Connection,
clauseelement: Executable,
multiparams: _CoreMultiExecuteParams,
params: _CoreSingleExecuteParams,
execution_options: _ExecuteOptions,
result: Result[Any],
) -> None:
"""Intercept high level execute() events after execute.
:param conn: :class:`_engine.Connection` object
:param clauseelement: SQL expression construct, :class:`.Compiled`
instance, or string statement passed to
:meth:`_engine.Connection.execute`.
:param multiparams: Multiple parameter sets, a list of dictionaries.
:param params: Single parameter set, a single dictionary.
:param execution_options: dictionary of execution
options passed along with the statement, if any. This is a merge
of all options that will be used, including those of the statement,
the connection, and those passed in to the method itself for
the 2.0 style of execution.
.. versionadded:: 1.4
:param result: :class:`_engine.CursorResult` generated by the
execution.
"""
def before_cursor_execute(
self,
conn: Connection,
cursor: DBAPICursor,
statement: str,
parameters: _DBAPIAnyExecuteParams,
context: Optional[ExecutionContext],
executemany: bool,
) -> Optional[Tuple[str, _DBAPIAnyExecuteParams]]:
"""Intercept low-level cursor execute() events before execution,
receiving the string SQL statement and DBAPI-specific parameter list to
be invoked against a cursor.
This event is a good choice for logging as well as late modifications
to the SQL string. It's less ideal for parameter modifications except
for those which are specific to a target backend.
This event can be optionally established with the ``retval=True``
flag. The ``statement`` and ``parameters`` arguments should be
returned as a two-tuple in this case::
@event.listens_for(Engine, "before_cursor_execute", retval=True)
def before_cursor_execute(conn, cursor, statement,
parameters, context, executemany):
# do something with statement, parameters
return statement, parameters
See the example at :class:`_events.ConnectionEvents`.
:param conn: :class:`_engine.Connection` object
:param cursor: DBAPI cursor object
:param statement: string SQL statement, as to be passed to the DBAPI
:param parameters: Dictionary, tuple, or list of parameters being
passed to the ``execute()`` or ``executemany()`` method of the
DBAPI ``cursor``. In some cases may be ``None``.
:param context: :class:`.ExecutionContext` object in use. May
be ``None``.
:param executemany: boolean, if ``True``, this is an ``executemany()``
call, if ``False``, this is an ``execute()`` call.
.. seealso::
:meth:`.before_execute`
:meth:`.after_cursor_execute`
"""
def after_cursor_execute(
self,
conn: Connection,
cursor: DBAPICursor,
statement: str,
parameters: _DBAPIAnyExecuteParams,
context: Optional[ExecutionContext],
executemany: bool,
) -> None:
"""Intercept low-level cursor execute() events after execution.
:param conn: :class:`_engine.Connection` object
:param cursor: DBAPI cursor object. Will have results pending
if the statement was a SELECT, but these should not be consumed
as they will be needed by the :class:`_engine.CursorResult`.
:param statement: string SQL statement, as passed to the DBAPI
:param parameters: Dictionary, tuple, or list of parameters being
passed to the ``execute()`` or ``executemany()`` method of the
DBAPI ``cursor``. In some cases may be ``None``.
:param context: :class:`.ExecutionContext` object in use. May
be ``None``.
:param executemany: boolean, if ``True``, this is an ``executemany()``
call, if ``False``, this is an ``execute()`` call.
"""
@event._legacy_signature(
"2.0", ["conn", "branch"], converter=lambda conn: (conn, False)
)
def engine_connect(self, conn: Connection) -> None:
"""Intercept the creation of a new :class:`_engine.Connection`.
This event is called typically as the direct result of calling
the :meth:`_engine.Engine.connect` method.
It differs from the :meth:`_events.PoolEvents.connect` method, which
refers to the actual connection to a database at the DBAPI level;
a DBAPI connection may be pooled and reused for many operations.
In contrast, this event refers only to the production of a higher level
:class:`_engine.Connection` wrapper around such a DBAPI connection.
It also differs from the :meth:`_events.PoolEvents.checkout` event
in that it is specific to the :class:`_engine.Connection` object,
not the
DBAPI connection that :meth:`_events.PoolEvents.checkout` deals with,
although
this DBAPI connection is available here via the
:attr:`_engine.Connection.connection` attribute.
But note there can in fact
be multiple :meth:`_events.PoolEvents.checkout`
events within the lifespan
of a single :class:`_engine.Connection` object, if that
:class:`_engine.Connection`
is invalidated and re-established.
:param conn: :class:`_engine.Connection` object.
.. seealso::
:meth:`_events.PoolEvents.checkout`
the lower-level pool checkout event
for an individual DBAPI connection
"""
def set_connection_execution_options(
self, conn: Connection, opts: Dict[str, Any]
) -> None:
"""Intercept when the :meth:`_engine.Connection.execution_options`
method is called.
This method is called after the new :class:`_engine.Connection`
has been
produced, with the newly updated execution options collection, but
before the :class:`.Dialect` has acted upon any of those new options.
Note that this method is not called when a new
:class:`_engine.Connection`
is produced which is inheriting execution options from its parent
:class:`_engine.Engine`; to intercept this condition, use the
:meth:`_events.ConnectionEvents.engine_connect` event.
:param conn: The newly copied :class:`_engine.Connection` object
:param opts: dictionary of options that were passed to the
:meth:`_engine.Connection.execution_options` method.
This dictionary may be modified in place to affect the ultimate
options which take effect.
.. versionadded:: 2.0 the ``opts`` dictionary may be modified
in place.
.. seealso::
:meth:`_events.ConnectionEvents.set_engine_execution_options`
- event
which is called when :meth:`_engine.Engine.execution_options`
is called.
"""
def set_engine_execution_options(
self, engine: Engine, opts: Dict[str, Any]
) -> None:
"""Intercept when the :meth:`_engine.Engine.execution_options`
method is called.
The :meth:`_engine.Engine.execution_options` method produces a shallow
copy of the :class:`_engine.Engine` which stores the new options.
That new
:class:`_engine.Engine` is passed here.
A particular application of this
method is to add a :meth:`_events.ConnectionEvents.engine_connect`
event
handler to the given :class:`_engine.Engine`
which will perform some per-
:class:`_engine.Connection` task specific to these execution options.
:param engine: The newly copied :class:`_engine.Engine` object
:param opts: dictionary of options that were passed to the
:meth:`_engine.Engine.execution_options` method.
This dictionary may be modified in place to affect the ultimate
options which take effect.
.. versionadded:: 2.0 the ``opts`` dictionary may be modified
in place.
.. seealso::
:meth:`_events.ConnectionEvents.set_connection_execution_options`
- event
which is called when :meth:`_engine.Connection.execution_options`
is
called.
"""
def engine_disposed(self, engine: Engine) -> None:
"""Intercept when the :meth:`_engine.Engine.dispose` method is called.
The :meth:`_engine.Engine.dispose` method instructs the engine to
"dispose" of it's connection pool (e.g. :class:`_pool.Pool`), and
replaces it with a new one. Disposing of the old pool has the
effect that existing checked-in connections are closed. The new
pool does not establish any new connections until it is first used.
This event can be used to indicate that resources related to the
:class:`_engine.Engine` should also be cleaned up,
keeping in mind that the
:class:`_engine.Engine`
can still be used for new requests in which case
it re-acquires connection resources.
"""
def begin(self, conn: Connection) -> None:
"""Intercept begin() events.
:param conn: :class:`_engine.Connection` object
"""
def rollback(self, conn: Connection) -> None:
"""Intercept rollback() events, as initiated by a
:class:`.Transaction`.
Note that the :class:`_pool.Pool` also "auto-rolls back"
a DBAPI connection upon checkin, if the ``reset_on_return``
flag is set to its default value of ``'rollback'``.
To intercept this
rollback, use the :meth:`_events.PoolEvents.reset` hook.
:param conn: :class:`_engine.Connection` object
.. seealso::
:meth:`_events.PoolEvents.reset`
"""
def commit(self, conn: Connection) -> None:
"""Intercept commit() events, as initiated by a
:class:`.Transaction`.
Note that the :class:`_pool.Pool` may also "auto-commit"
a DBAPI connection upon checkin, if the ``reset_on_return``
flag is set to the value ``'commit'``. To intercept this
commit, use the :meth:`_events.PoolEvents.reset` hook.
:param conn: :class:`_engine.Connection` object
"""
def savepoint(self, conn: Connection, name: str) -> None:
"""Intercept savepoint() events.
:param conn: :class:`_engine.Connection` object
:param name: specified name used for the savepoint.
"""
def rollback_savepoint(
self, conn: Connection, name: str, context: None
) -> None:
"""Intercept rollback_savepoint() events.
:param conn: :class:`_engine.Connection` object
:param name: specified name used for the savepoint.
:param context: not used
"""
# TODO: deprecate "context"
def release_savepoint(
self, conn: Connection, name: str, context: None
) -> None:
"""Intercept release_savepoint() events.
:param conn: :class:`_engine.Connection` object
:param name: specified name used for the savepoint.
:param context: not used
"""
# TODO: deprecate "context"
def begin_twophase(self, conn: Connection, xid: Any) -> None:
"""Intercept begin_twophase() events.
:param conn: :class:`_engine.Connection` object
:param xid: two-phase XID identifier
"""
def prepare_twophase(self, conn: Connection, xid: Any) -> None:
"""Intercept prepare_twophase() events.
:param conn: :class:`_engine.Connection` object
:param xid: two-phase XID identifier
"""
def rollback_twophase(
self, conn: Connection, xid: Any, is_prepared: bool
) -> None:
"""Intercept rollback_twophase() events.
:param conn: :class:`_engine.Connection` object
:param xid: two-phase XID identifier
:param is_prepared: boolean, indicates if
:meth:`.TwoPhaseTransaction.prepare` was called.
"""
def commit_twophase(
self, conn: Connection, xid: Any, is_prepared: bool
) -> None:
"""Intercept commit_twophase() events.
:param conn: :class:`_engine.Connection` object
:param xid: two-phase XID identifier
:param is_prepared: boolean, indicates if
:meth:`.TwoPhaseTransaction.prepare` was called.
"""
class DialectEvents(event.Events[Dialect]):
"""event interface for execution-replacement functions.
These events allow direct instrumentation and replacement
of key dialect functions which interact with the DBAPI.
.. note::
:class:`.DialectEvents` hooks should be considered **semi-public**
and experimental.
These hooks are not for general use and are only for those situations
where intricate re-statement of DBAPI mechanics must be injected onto
an existing dialect. For general-use statement-interception events,
please use the :class:`_events.ConnectionEvents` interface.
.. seealso::
:meth:`_events.ConnectionEvents.before_cursor_execute`
:meth:`_events.ConnectionEvents.before_execute`
:meth:`_events.ConnectionEvents.after_cursor_execute`
:meth:`_events.ConnectionEvents.after_execute`
"""
_target_class_doc = "SomeEngine"
_dispatch_target = Dialect
@classmethod
def _listen(
cls,
event_key: event._EventKey[Dialect],
*,
retval: bool = False,
**kw: Any,
) -> None:
target = event_key.dispatch_target
target._has_events = True
event_key.base_listen()
@classmethod
def _accept_with(
cls,
target: Union[Engine, Type[Engine], Dialect, Type[Dialect]],
identifier: str,
) -> Optional[Union[Dialect, Type[Dialect]]]:
if isinstance(target, type):
if issubclass(target, Engine):
return Dialect
elif issubclass(target, Dialect):
return target
elif isinstance(target, Engine):
return target.dialect
elif isinstance(target, Dialect):
return target
elif isinstance(target, Connection) and identifier == "handle_error":
raise exc.InvalidRequestError(
"The handle_error() event hook as of SQLAlchemy 2.0 is "
"established on the Dialect, and may only be applied to the "
"Engine as a whole or to a specific Dialect as a whole, "
"not on a per-Connection basis."
)
elif hasattr(target, "_no_async_engine_events"):
target._no_async_engine_events()
else:
return None
def handle_error(
self, exception_context: ExceptionContext
) -> Optional[BaseException]:
r"""Intercept all exceptions processed by the
:class:`_engine.Dialect`, typically but not limited to those
emitted within the scope of a :class:`_engine.Connection`.
.. versionchanged:: 2.0 the :meth:`.DialectEvents.handle_error` event
is moved to the :class:`.DialectEvents` class, moved from the
:class:`.ConnectionEvents` class, so that it may also participate in
the "pre ping" operation configured with the
:paramref:`_sa.create_engine.pool_pre_ping` parameter. The event
remains registered by using the :class:`_engine.Engine` as the event
target, however note that using the :class:`_engine.Connection` as
an event target for :meth:`.DialectEvents.handle_error` is no longer
supported.
This includes all exceptions emitted by the DBAPI as well as
within SQLAlchemy's statement invocation process, including
encoding errors and other statement validation errors. Other areas
in which the event is invoked include transaction begin and end,
result row fetching, cursor creation.
Note that :meth:`.handle_error` may support new kinds of exceptions
and new calling scenarios at *any time*. Code which uses this
event must expect new calling patterns to be present in minor
releases.
To support the wide variety of members that correspond to an exception,
as well as to allow extensibility of the event without backwards
incompatibility, the sole argument received is an instance of
:class:`.ExceptionContext`. This object contains data members
representing detail about the exception.
Use cases supported by this hook include:
* read-only, low-level exception handling for logging and
debugging purposes
* Establishing whether a DBAPI connection error message indicates
that the database connection needs to be reconnected, including
for the "pre_ping" handler used by **some** dialects
* Establishing or disabling whether a connection or the owning
connection pool is invalidated or expired in response to a
specific exception
* exception re-writing
The hook is called while the cursor from the failed operation
(if any) is still open and accessible. Special cleanup operations
can be called on this cursor; SQLAlchemy will attempt to close
this cursor subsequent to this hook being invoked.
As of SQLAlchemy 2.0, the "pre_ping" handler enabled using the
:paramref:`_sa.create_engine.pool_pre_ping` parameter will also
participate in the :meth:`.handle_error` process, **for those dialects
that rely upon disconnect codes to detect database liveness**. Note
that some dialects such as psycopg, psycopg2, and most MySQL dialects
make use of a native ``ping()`` method supplied by the DBAPI which does
not make use of disconnect codes.
.. versionchanged:: 2.0.0 The :meth:`.DialectEvents.handle_error`
event hook participates in connection pool "pre-ping" operations.
Within this usage, the :attr:`.ExceptionContext.engine` attribute
will be ``None``, however the :class:`.Dialect` in use is always
available via the :attr:`.ExceptionContext.dialect` attribute.
.. versionchanged:: 2.0.5 Added :attr:`.ExceptionContext.is_pre_ping`
attribute which will be set to ``True`` when the
:meth:`.DialectEvents.handle_error` event hook is triggered within
a connection pool pre-ping operation.
.. versionchanged:: 2.0.5 An issue was repaired that allows for the
PostgreSQL ``psycopg`` and ``psycopg2`` drivers, as well as all
MySQL drivers, to properly participate in the
:meth:`.DialectEvents.handle_error` event hook during
connection pool "pre-ping" operations; previously, the
implementation was non-working for these drivers.
A handler function has two options for replacing
the SQLAlchemy-constructed exception into one that is user
defined. It can either raise this new exception directly, in
which case all further event listeners are bypassed and the
exception will be raised, after appropriate cleanup has taken
place::
@event.listens_for(Engine, "handle_error")
def handle_exception(context):
if isinstance(context.original_exception,
psycopg2.OperationalError) and \
"failed" in str(context.original_exception):
raise MySpecialException("failed operation")
.. warning:: Because the
:meth:`_events.DialectEvents.handle_error`
event specifically provides for exceptions to be re-thrown as
the ultimate exception raised by the failed statement,
**stack traces will be misleading** if the user-defined event
handler itself fails and throws an unexpected exception;
the stack trace may not illustrate the actual code line that
failed! It is advised to code carefully here and use
logging and/or inline debugging if unexpected exceptions are
occurring.
Alternatively, a "chained" style of event handling can be
used, by configuring the handler with the ``retval=True``
modifier and returning the new exception instance from the
function. In this case, event handling will continue onto the
next handler. The "chained" exception is available using
:attr:`.ExceptionContext.chained_exception`::
@event.listens_for(Engine, "handle_error", retval=True)
def handle_exception(context):
if context.chained_exception is not None and \
"special" in context.chained_exception.message:
return MySpecialException("failed",
cause=context.chained_exception)
Handlers that return ``None`` may be used within the chain; when
a handler returns ``None``, the previous exception instance,
if any, is maintained as the current exception that is passed onto the
next handler.
When a custom exception is raised or returned, SQLAlchemy raises
this new exception as-is; it is not wrapped by any SQLAlchemy
object. If the exception is not a subclass of
:class:`sqlalchemy.exc.StatementError`,
certain features may not be available; currently this includes
the ORM's feature of adding a detail hint about "autoflush" to
exceptions raised within the autoflush process.
:param context: an :class:`.ExceptionContext` object. See this
class for details on all available members.
.. seealso::
:ref:`pool_new_disconnect_codes`
"""
def do_connect(
self,
dialect: Dialect,
conn_rec: ConnectionPoolEntry,
cargs: Tuple[Any, ...],
cparams: Dict[str, Any],
) -> Optional[DBAPIConnection]:
"""Receive connection arguments before a connection is made.
This event is useful in that it allows the handler to manipulate the
cargs and/or cparams collections that control how the DBAPI
``connect()`` function will be called. ``cargs`` will always be a
Python list that can be mutated in-place, and ``cparams`` a Python
dictionary that may also be mutated::
e = create_engine("postgresql+psycopg2://user@host/dbname")
@event.listens_for(e, 'do_connect')
def receive_do_connect(dialect, conn_rec, cargs, cparams):
cparams["password"] = "some_password"
The event hook may also be used to override the call to ``connect()``
entirely, by returning a non-``None`` DBAPI connection object::
e = create_engine("postgresql+psycopg2://user@host/dbname")
@event.listens_for(e, 'do_connect')
def receive_do_connect(dialect, conn_rec, cargs, cparams):
return psycopg2.connect(*cargs, **cparams)
.. seealso::
:ref:`custom_dbapi_args`
"""
def do_executemany(
self,
cursor: DBAPICursor,
statement: str,
parameters: _DBAPIMultiExecuteParams,
context: ExecutionContext,
) -> Optional[Literal[True]]:
"""Receive a cursor to have executemany() called.
Return the value True to halt further events from invoking,
and to indicate that the cursor execution has already taken
place within the event handler.
"""
def do_execute_no_params(
self, cursor: DBAPICursor, statement: str, context: ExecutionContext
) -> Optional[Literal[True]]:
"""Receive a cursor to have execute() with no parameters called.
Return the value True to halt further events from invoking,
and to indicate that the cursor execution has already taken
place within the event handler.
"""
def do_execute(
self,
cursor: DBAPICursor,
statement: str,
parameters: _DBAPISingleExecuteParams,
context: ExecutionContext,
) -> Optional[Literal[True]]:
"""Receive a cursor to have execute() called.
Return the value True to halt further events from invoking,
and to indicate that the cursor execution has already taken
place within the event handler.
"""
def do_setinputsizes(
self,
inputsizes: Dict[BindParameter[Any], Any],
cursor: DBAPICursor,
statement: str,
parameters: _DBAPIAnyExecuteParams,
context: ExecutionContext,
) -> None:
"""Receive the setinputsizes dictionary for possible modification.
This event is emitted in the case where the dialect makes use of the
DBAPI ``cursor.setinputsizes()`` method which passes information about
parameter binding for a particular statement. The given
``inputsizes`` dictionary will contain :class:`.BindParameter` objects
as keys, linked to DBAPI-specific type objects as values; for
parameters that are not bound, they are added to the dictionary with
``None`` as the value, which means the parameter will not be included
in the ultimate setinputsizes call. The event may be used to inspect
and/or log the datatypes that are being bound, as well as to modify the
dictionary in place. Parameters can be added, modified, or removed
from this dictionary. Callers will typically want to inspect the
:attr:`.BindParameter.type` attribute of the given bind objects in
order to make decisions about the DBAPI object.
After the event, the ``inputsizes`` dictionary is converted into
an appropriate datastructure to be passed to ``cursor.setinputsizes``;
either a list for a positional bound parameter execution style,
or a dictionary of string parameter keys to DBAPI type objects for
a named bound parameter execution style.
The setinputsizes hook overall is only used for dialects which include
the flag ``use_setinputsizes=True``. This includes the cx_Oracle,
pg8000, asyncpg, and pyodbc dialects.
.. note::
For use with pyodbc, the ``use_setinputsizes`` flag
must be passed to the dialect, e.g.::
create_engine("mssql+pyodbc://...", use_setinputsizes=True)
.. seealso::
:ref:`mssql_pyodbc_setinputsizes`
.. versionadded:: 1.2.9
.. seealso::
:ref:`cx_oracle_setinputsizes`
"""
pass
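# A minimal sketch (assumed usage) of do_setinputsizes: log the DBAPI type
# chosen for each bound parameter.  Only dialects that use setinputsizes
# (e.g. cx_Oracle, pg8000, asyncpg, pyodbc) will ever trigger it.
from sqlalchemy import event
from sqlalchemy.engine import Engine

@event.listens_for(Engine, "do_setinputsizes")
def _log_inputsizes(inputsizes, cursor, statement, parameters, context):
    for bindparam, dbapitype in inputsizes.items():
        print(bindparam.key, bindparam.type, dbapitype)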

File diff suppressed because it is too large


@ -0,0 +1,131 @@
# engine/mock.py
# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
from __future__ import annotations
from operator import attrgetter
import typing
from typing import Any
from typing import Callable
from typing import cast
from typing import Optional
from typing import Type
from typing import Union
from . import url as _url
from .. import util
if typing.TYPE_CHECKING:
from .base import Engine
from .interfaces import _CoreAnyExecuteParams
from .interfaces import CoreExecuteOptionsParameter
from .interfaces import Dialect
from .url import URL
from ..sql.base import Executable
from ..sql.ddl import SchemaDropper
from ..sql.ddl import SchemaGenerator
from ..sql.schema import HasSchemaAttr
from ..sql.schema import SchemaItem
class MockConnection:
def __init__(self, dialect: Dialect, execute: Callable[..., Any]):
self._dialect = dialect
self._execute_impl = execute
engine: Engine = cast(Any, property(lambda s: s))
dialect: Dialect = cast(Any, property(attrgetter("_dialect")))
name: str = cast(Any, property(lambda s: s._dialect.name))
def connect(self, **kwargs: Any) -> MockConnection:
return self
def schema_for_object(self, obj: HasSchemaAttr) -> Optional[str]:
return obj.schema
def execution_options(self, **kw: Any) -> MockConnection:
return self
def _run_ddl_visitor(
self,
visitorcallable: Type[Union[SchemaGenerator, SchemaDropper]],
element: SchemaItem,
**kwargs: Any,
) -> None:
kwargs["checkfirst"] = False
visitorcallable(self.dialect, self, **kwargs).traverse_single(element)
def execute(
self,
obj: Executable,
parameters: Optional[_CoreAnyExecuteParams] = None,
execution_options: Optional[CoreExecuteOptionsParameter] = None,
) -> Any:
return self._execute_impl(obj, parameters)
def create_mock_engine(
url: Union[str, URL], executor: Any, **kw: Any
) -> MockConnection:
"""Create a "mock" engine used for echoing DDL.
This is a utility function used for debugging or storing the output of DDL
sequences as generated by :meth:`_schema.MetaData.create_all`
and related methods.
The function accepts a URL which is used only to determine the kind of
dialect to be used, as well as an "executor" callable function which
will receive a SQL expression object and parameters, which can then be
echoed or otherwise printed. The executor's return value is not handled,
nor does the engine allow regular string statements to be invoked, and
is therefore only useful for DDL that is sent to the database without
receiving any results.
E.g.::
from sqlalchemy import create_mock_engine
def dump(sql, *multiparams, **params):
print(sql.compile(dialect=engine.dialect))
engine = create_mock_engine('postgresql+psycopg2://', dump)
metadata.create_all(engine, checkfirst=False)
:param url: A string URL which typically needs to contain only the
database backend name.
:param executor: a callable which receives the arguments ``sql``,
``*multiparams`` and ``**params``. The ``sql`` parameter is typically
an instance of :class:`.ExecutableDDLElement`, which can then be compiled
into a string using :meth:`.ExecutableDDLElement.compile`.
.. versionadded:: 1.4 - the :func:`.create_mock_engine` function replaces
the previous "mock" engine strategy used with
:func:`_sa.create_engine`.
.. seealso::
:ref:`faq_ddl_as_string`
"""
# create url.URL object
u = _url.make_url(url)
dialect_cls = u.get_dialect()
dialect_args = {}
# consume dialect arguments from kwargs
for k in util.get_cls_kwargs(dialect_cls):
if k in kw:
dialect_args[k] = kw.pop(k)
# create dialect
dialect = dialect_cls(**dialect_args)
return MockConnection(dialect, executor)


@ -0,0 +1,61 @@
# engine/processors.py
# Copyright (C) 2010-2024 the SQLAlchemy authors and contributors
# <see AUTHORS file>
# Copyright (C) 2010 Gaetan de Menten gdementen@gmail.com
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
"""defines generic type conversion functions, as used in bind and result
processors.
They all share one common characteristic: None is passed through unchanged.
"""
from __future__ import annotations
import typing
from ._py_processors import str_to_datetime_processor_factory # noqa
from ..util._has_cy import HAS_CYEXTENSION
if typing.TYPE_CHECKING or not HAS_CYEXTENSION:
from ._py_processors import int_to_boolean as int_to_boolean
from ._py_processors import str_to_date as str_to_date
from ._py_processors import str_to_datetime as str_to_datetime
from ._py_processors import str_to_time as str_to_time
from ._py_processors import (
to_decimal_processor_factory as to_decimal_processor_factory,
)
from ._py_processors import to_float as to_float
from ._py_processors import to_str as to_str
else:
from sqlalchemy.cyextension.processors import (
DecimalResultProcessor,
)
from sqlalchemy.cyextension.processors import ( # noqa: F401
int_to_boolean as int_to_boolean,
)
from sqlalchemy.cyextension.processors import ( # noqa: F401,E501
str_to_date as str_to_date,
)
from sqlalchemy.cyextension.processors import ( # noqa: F401
str_to_datetime as str_to_datetime,
)
from sqlalchemy.cyextension.processors import ( # noqa: F401,E501
str_to_time as str_to_time,
)
from sqlalchemy.cyextension.processors import ( # noqa: F401,E501
to_float as to_float,
)
from sqlalchemy.cyextension.processors import ( # noqa: F401,E501
to_str as to_str,
)
def to_decimal_processor_factory(target_class, scale):
# Note that the scale argument is not taken into account for integer
# values in the C implementation while it is in the Python one.
# For example, the Python implementation might return
# Decimal('5.00000') whereas the C implementation will
# return Decimal('5'). These are equivalent of course.
return DecimalResultProcessor(target_class, "%%.%df" % scale).process
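# A minimal sketch of the resulting processor (illustrative; either the Python
# or the C implementation may be in use).  None passes through unchanged and
# numeric values come back as instances of the target class.
import decimal

_proc = to_decimal_processor_factory(decimal.Decimal, 2)
assert _proc(None) is None
assert _proc(5.0) == decimal.Decimal("5")  # Decimal('5.00') and Decimal('5') compare equal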

File diff suppressed because it is too large

File diff suppressed because it is too large


@ -0,0 +1,401 @@
# engine/row.py
# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
"""Define row constructs including :class:`.Row`."""
from __future__ import annotations
from abc import ABC
import collections.abc as collections_abc
import operator
import typing
from typing import Any
from typing import Callable
from typing import Dict
from typing import Generic
from typing import Iterator
from typing import List
from typing import Mapping
from typing import NoReturn
from typing import Optional
from typing import overload
from typing import Sequence
from typing import Tuple
from typing import TYPE_CHECKING
from typing import TypeVar
from typing import Union
from ..sql import util as sql_util
from ..util import deprecated
from ..util._has_cy import HAS_CYEXTENSION
if TYPE_CHECKING or not HAS_CYEXTENSION:
from ._py_row import BaseRow as BaseRow
else:
from sqlalchemy.cyextension.resultproxy import BaseRow as BaseRow
if TYPE_CHECKING:
from .result import _KeyType
from .result import _ProcessorsType
from .result import RMKeyView
_T = TypeVar("_T", bound=Any)
_TP = TypeVar("_TP", bound=Tuple[Any, ...])
class Row(BaseRow, Sequence[Any], Generic[_TP]):
"""Represent a single result row.
The :class:`.Row` object represents a row of a database result. It is
typically associated in the 1.x series of SQLAlchemy with the
:class:`_engine.CursorResult` object, however is also used by the ORM for
tuple-like results as of SQLAlchemy 1.4.
The :class:`.Row` object seeks to act as much like a Python named
tuple as possible. For mapping (i.e. dictionary) behavior on a row,
such as testing for containment of keys, refer to the :attr:`.Row._mapping`
attribute.
.. seealso::
:ref:`tutorial_selecting_data` - includes examples of selecting
rows from SELECT statements.
.. versionchanged:: 1.4
Renamed ``RowProxy`` to :class:`.Row`. :class:`.Row` is no longer a
"proxy" object in that it contains the final form of data within it,
and now acts mostly like a named tuple. Mapping-like functionality is
moved to the :attr:`.Row._mapping` attribute. See
:ref:`change_4710_core` for background on this change.
"""
__slots__ = ()
def __setattr__(self, name: str, value: Any) -> NoReturn:
raise AttributeError("can't set attribute")
def __delattr__(self, name: str) -> NoReturn:
raise AttributeError("can't delete attribute")
def _tuple(self) -> _TP:
"""Return a 'tuple' form of this :class:`.Row`.
At runtime, this method returns "self"; the :class:`.Row` object is
already a named tuple. However, at the typing level, if this
:class:`.Row` is typed, the "tuple" return type will be a :pep:`484`
``Tuple`` datatype that contains typing information about individual
elements, supporting typed unpacking and attribute access.
.. versionadded:: 2.0.19 - The :meth:`.Row._tuple` method supersedes
the previous :meth:`.Row.tuple` method, which is now underscored
to avoid name conflicts with column names in the same way as other
named-tuple methods on :class:`.Row`.
.. seealso::
:attr:`.Row._t` - shorthand attribute notation
:meth:`.Result.tuples`
"""
return self # type: ignore
@deprecated(
"2.0.19",
"The :meth:`.Row.tuple` method is deprecated in favor of "
":meth:`.Row._tuple`; all :class:`.Row` "
"methods and library-level attributes are intended to be underscored "
"to avoid name conflicts. Please use :meth:`Row._tuple`.",
)
def tuple(self) -> _TP:
"""Return a 'tuple' form of this :class:`.Row`.
.. versionadded:: 2.0
"""
return self._tuple()
@property
def _t(self) -> _TP:
"""A synonym for :meth:`.Row._tuple`.
.. versionadded:: 2.0.19 - The :attr:`.Row._t` attribute supersedes
the previous :attr:`.Row.t` attribute, which is now underscored
to avoid name conflicts with column names in the same way as other
named-tuple methods on :class:`.Row`.
.. seealso::
:attr:`.Result.t`
"""
return self # type: ignore
@property
@deprecated(
"2.0.19",
"The :attr:`.Row.t` attribute is deprecated in favor of "
":attr:`.Row._t`; all :class:`.Row` "
"methods and library-level attributes are intended to be underscored "
"to avoid name conflicts. Please use :attr:`Row._t`.",
)
def t(self) -> _TP:
"""A synonym for :meth:`.Row._tuple`.
.. versionadded:: 2.0
"""
return self._t
@property
def _mapping(self) -> RowMapping:
"""Return a :class:`.RowMapping` for this :class:`.Row`.
This object provides a consistent Python mapping (i.e. dictionary)
interface for the data contained within the row. The :class:`.Row`
by itself behaves like a named tuple.
.. seealso::
:attr:`.Row._fields`
.. versionadded:: 1.4
"""
return RowMapping(self._parent, None, self._key_to_index, self._data)
def _filter_on_values(
self, processor: Optional[_ProcessorsType]
) -> Row[Any]:
return Row(self._parent, processor, self._key_to_index, self._data)
if not TYPE_CHECKING:
def _special_name_accessor(name: str) -> Any:
"""Handle ambiguous names such as "count" and "index" """
@property
def go(self: Row) -> Any:
if self._parent._has_key(name):
return self.__getattr__(name)
else:
def meth(*arg: Any, **kw: Any) -> Any:
return getattr(collections_abc.Sequence, name)(
self, *arg, **kw
)
return meth
return go
count = _special_name_accessor("count")
index = _special_name_accessor("index")
def __contains__(self, key: Any) -> bool:
return key in self._data
def _op(self, other: Any, op: Callable[[Any, Any], bool]) -> bool:
return (
op(self._to_tuple_instance(), other._to_tuple_instance())
if isinstance(other, Row)
else op(self._to_tuple_instance(), other)
)
__hash__ = BaseRow.__hash__
if TYPE_CHECKING:
@overload
def __getitem__(self, index: int) -> Any: ...
@overload
def __getitem__(self, index: slice) -> Sequence[Any]: ...
def __getitem__(self, index: Union[int, slice]) -> Any: ...
def __lt__(self, other: Any) -> bool:
return self._op(other, operator.lt)
def __le__(self, other: Any) -> bool:
return self._op(other, operator.le)
def __ge__(self, other: Any) -> bool:
return self._op(other, operator.ge)
def __gt__(self, other: Any) -> bool:
return self._op(other, operator.gt)
def __eq__(self, other: Any) -> bool:
return self._op(other, operator.eq)
def __ne__(self, other: Any) -> bool:
return self._op(other, operator.ne)
def __repr__(self) -> str:
return repr(sql_util._repr_row(self))
@property
def _fields(self) -> Tuple[str, ...]:
"""Return a tuple of string keys as represented by this
:class:`.Row`.
The keys can represent the labels of the columns returned by a core
statement or the names of the orm classes returned by an orm
execution.
This attribute is analogous to the Python named tuple ``._fields``
attribute.
.. versionadded:: 1.4
.. seealso::
:attr:`.Row._mapping`
"""
return tuple([k for k in self._parent.keys if k is not None])
def _asdict(self) -> Dict[str, Any]:
"""Return a new dict which maps field names to their corresponding
values.
This method is analogous to the Python named tuple ``._asdict()``
method, and works by applying the ``dict()`` constructor to the
:attr:`.Row._mapping` attribute.
.. versionadded:: 1.4
.. seealso::
:attr:`.Row._mapping`
"""
return dict(self._mapping)
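# A minimal end-to-end sketch of Row access styles (named-tuple, index,
# mapping), assuming an in-memory SQLite database as a stand-in backend.
from sqlalchemy import create_engine, text

_engine = create_engine("sqlite://")
with _engine.connect() as _conn:
    _row = _conn.execute(text("SELECT 1 AS x, 'two' AS y")).first()
    assert _row.x == 1 and _row[1] == "two"        # attribute / index access
    assert _row._mapping["y"] == "two"             # mapping access
    assert _row._asdict() == {"x": 1, "y": "two"}
    assert _row._fields == ("x", "y")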
BaseRowProxy = BaseRow
RowProxy = Row
class ROMappingView(ABC):
__slots__ = ()
_items: Sequence[Any]
_mapping: Mapping["_KeyType", Any]
def __init__(
self, mapping: Mapping["_KeyType", Any], items: Sequence[Any]
):
self._mapping = mapping # type: ignore[misc]
self._items = items # type: ignore[misc]
def __len__(self) -> int:
return len(self._items)
def __repr__(self) -> str:
return "{0.__class__.__name__}({0._mapping!r})".format(self)
def __iter__(self) -> Iterator[Any]:
return iter(self._items)
def __contains__(self, item: Any) -> bool:
return item in self._items
def __eq__(self, other: Any) -> bool:
return list(other) == list(self)
def __ne__(self, other: Any) -> bool:
return list(other) != list(self)
class ROMappingKeysValuesView(
ROMappingView, typing.KeysView["_KeyType"], typing.ValuesView[Any]
):
__slots__ = ("_items",) # mapping slot is provided by KeysView
class ROMappingItemsView(ROMappingView, typing.ItemsView["_KeyType", Any]):
__slots__ = ("_items",) # mapping slot is provided by ItemsView
class RowMapping(BaseRow, typing.Mapping["_KeyType", Any]):
"""A ``Mapping`` that maps column names and objects to :class:`.Row`
values.
The :class:`.RowMapping` is available from a :class:`.Row` via the
:attr:`.Row._mapping` attribute, as well as from the iterable interface
provided by the :class:`.MappingResult` object returned by the
:meth:`_engine.Result.mappings` method.
:class:`.RowMapping` supplies Python mapping (i.e. dictionary) access to
the contents of the row. This includes support for testing of
containment of specific keys (string column names or objects), as well
as iteration of keys, values, and items::
for row in result:
if 'a' in row._mapping:
print("Column 'a': %s" % row._mapping['a'])
print("Column b: %s" % row._mapping[table.c.b])
.. versionadded:: 1.4 The :class:`.RowMapping` object replaces the
mapping-like access previously provided by a database result row,
which now seeks to behave mostly like a named tuple.
"""
__slots__ = ()
if TYPE_CHECKING:
def __getitem__(self, key: _KeyType) -> Any: ...
else:
__getitem__ = BaseRow._get_by_key_impl_mapping
def _values_impl(self) -> List[Any]:
return list(self._data)
def __iter__(self) -> Iterator[str]:
return (k for k in self._parent.keys if k is not None)
def __len__(self) -> int:
return len(self._data)
def __contains__(self, key: object) -> bool:
return self._parent._has_key(key)
def __repr__(self) -> str:
return repr(dict(self))
def items(self) -> ROMappingItemsView:
"""Return a view of key/value tuples for the elements in the
underlying :class:`.Row`.
"""
return ROMappingItemsView(
self, [(key, self[key]) for key in self.keys()]
)
def keys(self) -> RMKeyView:
"""Return a view of 'keys' for string column names represented
by the underlying :class:`.Row`.
"""
return self._parent.keys
def values(self) -> ROMappingKeysValuesView:
"""Return a view of values for the values represented in the
underlying :class:`.Row`.
"""
return ROMappingKeysValuesView(self, self._values_impl())


@ -0,0 +1,19 @@
# engine/strategies.py
# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
"""Deprecated mock engine strategy used by Alembic.
"""
from __future__ import annotations
from .mock import MockConnection # noqa
class MockEngineStrategy:
MockConnection = MockConnection


@ -0,0 +1,910 @@
# engine/url.py
# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
"""Provides the :class:`~sqlalchemy.engine.url.URL` class which encapsulates
information about a database connection specification.
The URL object is created automatically when
:func:`~sqlalchemy.engine.create_engine` is called with a string
argument; alternatively, the URL is a public-facing construct which can
be used directly and is also accepted directly by ``create_engine()``.
"""
from __future__ import annotations
import collections.abc as collections_abc
import re
from typing import Any
from typing import cast
from typing import Dict
from typing import Iterable
from typing import List
from typing import Mapping
from typing import NamedTuple
from typing import Optional
from typing import overload
from typing import Sequence
from typing import Tuple
from typing import Type
from typing import Union
from urllib.parse import parse_qsl
from urllib.parse import quote
from urllib.parse import quote_plus
from urllib.parse import unquote
from .interfaces import Dialect
from .. import exc
from .. import util
from ..dialects import plugins
from ..dialects import registry
class URL(NamedTuple):
"""
Represent the components of a URL used to connect to a database.
URLs are typically constructed from a fully formatted URL string, where the
:func:`.make_url` function is used internally by the
:func:`_sa.create_engine` function in order to parse the URL string into
its individual components, which are then used to construct a new
:class:`.URL` object. When parsing from a formatted URL string, the parsing
format generally follows
`RFC-1738 <https://www.ietf.org/rfc/rfc1738.txt>`_, with some exceptions.
A :class:`_engine.URL` object may also be produced directly, either by
using the :func:`.make_url` function with a fully formed URL string, or
by using the :meth:`_engine.URL.create` constructor in order
to construct a :class:`_engine.URL` programmatically given individual
fields. The resulting :class:`.URL` object may be passed directly to
:func:`_sa.create_engine` in place of a string argument, which will bypass
the usage of :func:`.make_url` within the engine's creation process.
.. versionchanged:: 1.4
The :class:`_engine.URL` object is now an immutable object. To
create a URL, use the :func:`_engine.make_url` or
:meth:`_engine.URL.create` function / method. To modify
a :class:`_engine.URL`, use methods like
:meth:`_engine.URL.set` and
:meth:`_engine.URL.update_query_dict` to return a new
:class:`_engine.URL` object with modifications. See notes for this
change at :ref:`change_5526`.
.. seealso::
:ref:`database_urls`
:class:`_engine.URL` contains the following attributes:
* :attr:`_engine.URL.drivername`: database backend and driver name, such as
``postgresql+psycopg2``
* :attr:`_engine.URL.username`: username string
* :attr:`_engine.URL.password`: password string
* :attr:`_engine.URL.host`: string hostname
* :attr:`_engine.URL.port`: integer port number
* :attr:`_engine.URL.database`: string database name
* :attr:`_engine.URL.query`: an immutable mapping representing the query
string. contains strings for keys and either strings or tuples of
strings for values.
"""
drivername: str
"""database backend and driver name, such as
``postgresql+psycopg2``
"""
username: Optional[str]
"username string"
password: Optional[str]
"""password, which is normally a string but may also be any
object that has a ``__str__()`` method."""
host: Optional[str]
"""hostname or IP number. May also be a data source name for some
drivers."""
port: Optional[int]
"""integer port number"""
database: Optional[str]
"""database name"""
query: util.immutabledict[str, Union[Tuple[str, ...], str]]
"""an immutable mapping representing the query string. contains strings
for keys and either strings or tuples of strings for values, e.g.::
>>> from sqlalchemy.engine import make_url
>>> url = make_url("postgresql+psycopg2://user:pass@host/dbname?alt_host=host1&alt_host=host2&ssl_cipher=%2Fpath%2Fto%2Fcrt")
>>> url.query
immutabledict({'alt_host': ('host1', 'host2'), 'ssl_cipher': '/path/to/crt'})
To create a mutable copy of this mapping, use the ``dict`` constructor::
mutable_query_opts = dict(url.query)
.. seealso::
:attr:`_engine.URL.normalized_query` - normalizes all values into sequences
for consistent processing
Methods for altering the contents of :attr:`_engine.URL.query`:
:meth:`_engine.URL.update_query_dict`
:meth:`_engine.URL.update_query_string`
:meth:`_engine.URL.update_query_pairs`
:meth:`_engine.URL.difference_update_query`
""" # noqa: E501
@classmethod
def create(
cls,
drivername: str,
username: Optional[str] = None,
password: Optional[str] = None,
host: Optional[str] = None,
port: Optional[int] = None,
database: Optional[str] = None,
query: Mapping[str, Union[Sequence[str], str]] = util.EMPTY_DICT,
) -> URL:
"""Create a new :class:`_engine.URL` object.
.. seealso::
:ref:`database_urls`
:param drivername: the name of the database backend. This name will
correspond to a module in sqlalchemy/databases or a third party
plug-in.
:param username: The user name.
:param password: database password. Is typically a string, but may
also be an object that can be stringified with ``str()``.
.. note:: The password string should **not** be URL encoded when
passed as an argument to :meth:`_engine.URL.create`; the string
should contain the password characters exactly as they would be
typed.
.. note:: A password-producing object will be stringified only
**once** per :class:`_engine.Engine` object. For dynamic password
generation per connect, see :ref:`engines_dynamic_tokens`.
:param host: The name of the host.
:param port: The port number.
:param database: The database name.
:param query: A dictionary of string keys to string values to be passed
to the dialect and/or the DBAPI upon connect. To specify non-string
parameters to a Python DBAPI directly, use the
:paramref:`_sa.create_engine.connect_args` parameter to
:func:`_sa.create_engine`. See also
:attr:`_engine.URL.normalized_query` for a dictionary that is
consistently string->list of string.
:return: new :class:`_engine.URL` object.
.. versionadded:: 1.4
The :class:`_engine.URL` object is now an **immutable named
tuple**. In addition, the ``query`` dictionary is also immutable.
To create a URL, use the :func:`_engine.url.make_url` or
:meth:`_engine.URL.create` function/ method. To modify a
:class:`_engine.URL`, use the :meth:`_engine.URL.set` and
:meth:`_engine.URL.update_query` methods.
"""
return cls(
cls._assert_str(drivername, "drivername"),
cls._assert_none_str(username, "username"),
password,
cls._assert_none_str(host, "host"),
cls._assert_port(port),
cls._assert_none_str(database, "database"),
cls._str_dict(query),
)
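# A minimal sketch of URL.create(); the connection details are illustrative
# assumptions.  The password is given unescaped and is quoted automatically
# when the URL is rendered as a string.
from sqlalchemy.engine import URL as _URL

_url = _URL.create(
    "postgresql+psycopg2",
    username="scott",
    password="tiger/with@specials",
    host="localhost",
    port=5432,
    database="test",
    query={"sslmode": "require"},
)
print(_url.render_as_string(hide_password=False))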
@classmethod
def _assert_port(cls, port: Optional[int]) -> Optional[int]:
if port is None:
return None
try:
return int(port)
except TypeError:
raise TypeError("Port argument must be an integer or None")
@classmethod
def _assert_str(cls, v: str, paramname: str) -> str:
if not isinstance(v, str):
raise TypeError("%s must be a string" % paramname)
return v
@classmethod
def _assert_none_str(
cls, v: Optional[str], paramname: str
) -> Optional[str]:
if v is None:
return v
return cls._assert_str(v, paramname)
@classmethod
def _str_dict(
cls,
dict_: Optional[
Union[
Sequence[Tuple[str, Union[Sequence[str], str]]],
Mapping[str, Union[Sequence[str], str]],
]
],
) -> util.immutabledict[str, Union[Tuple[str, ...], str]]:
if dict_ is None:
return util.EMPTY_DICT
@overload
def _assert_value(
val: str,
) -> str: ...
@overload
def _assert_value(
val: Sequence[str],
) -> Union[str, Tuple[str, ...]]: ...
def _assert_value(
val: Union[str, Sequence[str]],
) -> Union[str, Tuple[str, ...]]:
if isinstance(val, str):
return val
elif isinstance(val, collections_abc.Sequence):
return tuple(_assert_value(elem) for elem in val)
else:
raise TypeError(
"Query dictionary values must be strings or "
"sequences of strings"
)
def _assert_str(v: str) -> str:
if not isinstance(v, str):
raise TypeError("Query dictionary keys must be strings")
return v
dict_items: Iterable[Tuple[str, Union[Sequence[str], str]]]
if isinstance(dict_, collections_abc.Sequence):
dict_items = dict_
else:
dict_items = dict_.items()
return util.immutabledict(
{
_assert_str(key): _assert_value(
value,
)
for key, value in dict_items
}
)
def set(
self,
drivername: Optional[str] = None,
username: Optional[str] = None,
password: Optional[str] = None,
host: Optional[str] = None,
port: Optional[int] = None,
database: Optional[str] = None,
query: Optional[Mapping[str, Union[Sequence[str], str]]] = None,
) -> URL:
"""return a new :class:`_engine.URL` object with modifications.
Values are used if they are non-None. To set a value to ``None``
explicitly, use the :meth:`_engine.URL._replace` method adapted
from ``namedtuple``.
:param drivername: new drivername
:param username: new username
:param password: new password
:param host: new hostname
:param port: new port
:param query: new query parameters, passed a dict of string keys
referring to string or sequence of string values. Fully
replaces the previous list of arguments.
:return: new :class:`_engine.URL` object.
.. versionadded:: 1.4
.. seealso::
:meth:`_engine.URL.update_query_dict`
"""
kw: Dict[str, Any] = {}
if drivername is not None:
kw["drivername"] = drivername
if username is not None:
kw["username"] = username
if password is not None:
kw["password"] = password
if host is not None:
kw["host"] = host
if port is not None:
kw["port"] = port
if database is not None:
kw["database"] = database
if query is not None:
kw["query"] = query
return self._assert_replace(**kw)
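# A short illustrative sketch of set(); fields passed as non-None replace
# the existing values and everything else carries over unchanged (the URL
# string below is a placeholder):
#
#     >>> from sqlalchemy.engine import make_url
#     >>> url = make_url("postgresql+psycopg2://user:pass@host/dbname")
#     >>> str(url.set(host="replica1", port=5433))
#     'postgresql+psycopg2://user:pass@replica1:5433/dbname'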
def _assert_replace(self, **kw: Any) -> URL:
"""argument checks before calling _replace()"""
if "drivername" in kw:
self._assert_str(kw["drivername"], "drivername")
for name in "username", "host", "database":
if name in kw:
self._assert_none_str(kw[name], name)
if "port" in kw:
self._assert_port(kw["port"])
if "query" in kw:
kw["query"] = self._str_dict(kw["query"])
return self._replace(**kw)
def update_query_string(
self, query_string: str, append: bool = False
) -> URL:
"""Return a new :class:`_engine.URL` object with the :attr:`_engine.URL.query`
parameter dictionary updated by the given query string.
E.g.::
>>> from sqlalchemy.engine import make_url
>>> url = make_url("postgresql+psycopg2://user:pass@host/dbname")
>>> url = url.update_query_string("alt_host=host1&alt_host=host2&ssl_cipher=%2Fpath%2Fto%2Fcrt")
>>> str(url)
'postgresql+psycopg2://user:pass@host/dbname?alt_host=host1&alt_host=host2&ssl_cipher=%2Fpath%2Fto%2Fcrt'
:param query_string: a URL escaped query string, not including the
question mark.
:param append: if True, parameters in the existing query string will
not be removed; new parameters will be in addition to those present.
If left at its default of False, keys present in the given query
parameters will replace those of the existing query string.
.. versionadded:: 1.4
.. seealso::
:attr:`_engine.URL.query`
:meth:`_engine.URL.update_query_dict`
""" # noqa: E501
return self.update_query_pairs(parse_qsl(query_string), append=append)
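# An illustrative sketch of the append=True behavior of
# update_query_string(); values for a key already present accumulate
# rather than replace (URL and parameter names are placeholders):
#
#     >>> from sqlalchemy.engine import make_url
#     >>> url = make_url("postgresql+psycopg2://user:pass@host/dbname?alt_host=host1")
#     >>> str(url.update_query_string("alt_host=host2", append=True))
#     'postgresql+psycopg2://user:pass@host/dbname?alt_host=host1&alt_host=host2'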
def update_query_pairs(
self,
key_value_pairs: Iterable[Tuple[str, Union[str, List[str]]]],
append: bool = False,
) -> URL:
"""Return a new :class:`_engine.URL` object with the
:attr:`_engine.URL.query`
parameter dictionary updated by the given sequence of key/value pairs.
E.g.::
>>> from sqlalchemy.engine import make_url
>>> url = make_url("postgresql+psycopg2://user:pass@host/dbname")
>>> url = url.update_query_pairs([("alt_host", "host1"), ("alt_host", "host2"), ("ssl_cipher", "/path/to/crt")])
>>> str(url)
'postgresql+psycopg2://user:pass@host/dbname?alt_host=host1&alt_host=host2&ssl_cipher=%2Fpath%2Fto%2Fcrt'
:param key_value_pairs: A sequence of tuples containing two strings
each.
:param append: if True, parameters in the existing query string will
not be removed; new parameters will be in addition to those present.
If left at its default of False, keys present in the given query
parameters will replace those of the existing query string.
.. versionadded:: 1.4
.. seealso::
:attr:`_engine.URL.query`
:meth:`_engine.URL.difference_update_query`
:meth:`_engine.URL.set`
""" # noqa: E501
existing_query = self.query
new_keys: Dict[str, Union[str, List[str]]] = {}
for key, value in key_value_pairs:
if key in new_keys:
new_keys[key] = util.to_list(new_keys[key])
cast("List[str]", new_keys[key]).append(cast(str, value))
else:
new_keys[key] = (
list(value) if isinstance(value, (list, tuple)) else value
)
new_query: Mapping[str, Union[str, Sequence[str]]]
if append:
new_query = {}
for k in new_keys:
if k in existing_query:
new_query[k] = tuple(
util.to_list(existing_query[k])
+ util.to_list(new_keys[k])
)
else:
new_query[k] = new_keys[k]
new_query.update(
{
k: existing_query[k]
for k in set(existing_query).difference(new_keys)
}
)
else:
new_query = self.query.union(
{
k: tuple(v) if isinstance(v, list) else v
for k, v in new_keys.items()
}
)
return self.set(query=new_query)
def update_query_dict(
self,
query_parameters: Mapping[str, Union[str, List[str]]],
append: bool = False,
) -> URL:
"""Return a new :class:`_engine.URL` object with the
:attr:`_engine.URL.query` parameter dictionary updated by the given
dictionary.
The dictionary typically contains string keys and string values.
In order to represent a query parameter that is expressed multiple
times, pass a sequence of string values.
E.g.::
>>> from sqlalchemy.engine import make_url
>>> url = make_url("postgresql+psycopg2://user:pass@host/dbname")
>>> url = url.update_query_dict({"alt_host": ["host1", "host2"], "ssl_cipher": "/path/to/crt"})
>>> str(url)
'postgresql+psycopg2://user:pass@host/dbname?alt_host=host1&alt_host=host2&ssl_cipher=%2Fpath%2Fto%2Fcrt'
:param query_parameters: A dictionary with string keys and values
that are either strings, or sequences of strings.
:param append: if True, parameters in the existing query string will
not be removed; new parameters will be in addition to those present.
If left at its default of False, keys present in the given query
parameters will replace those of the existing query string.
.. versionadded:: 1.4
.. seealso::
:attr:`_engine.URL.query`
:meth:`_engine.URL.update_query_string`
:meth:`_engine.URL.update_query_pairs`
:meth:`_engine.URL.difference_update_query`
:meth:`_engine.URL.set`
""" # noqa: E501
return self.update_query_pairs(query_parameters.items(), append=append)
def difference_update_query(self, names: Iterable[str]) -> URL:
"""
Remove the given names from the :attr:`_engine.URL.query` dictionary,
returning the new :class:`_engine.URL`.
E.g.::
url = url.difference_update_query(['foo', 'bar'])
Equivalent to using :meth:`_engine.URL.set` as follows::
url = url.set(
query={
key: url.query[key]
for key in set(url.query).difference(['foo', 'bar'])
}
)
.. versionadded:: 1.4
.. seealso::
:attr:`_engine.URL.query`
:meth:`_engine.URL.update_query_dict`
:meth:`_engine.URL.set`
"""
if not set(names).intersection(self.query):
return self
return URL(
self.drivername,
self.username,
self.password,
self.host,
self.port,
self.database,
util.immutabledict(
{
key: self.query[key]
for key in set(self.query).difference(names)
}
),
)
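# An illustrative doctest-style sketch of difference_update_query();
# names not present in the query are silently ignored (the URL and keys
# are placeholders):
#
#     >>> from sqlalchemy.engine import make_url
#     >>> url = make_url("postgresql+psycopg2://user:pass@host/dbname?bar=2&foo=1&keep=3")
#     >>> str(url.difference_update_query(["foo", "bar", "missing"]))
#     'postgresql+psycopg2://user:pass@host/dbname?keep=3'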
@property
def normalized_query(self) -> Mapping[str, Sequence[str]]:
"""Return the :attr:`_engine.URL.query` dictionary with values normalized
into sequences.
As the :attr:`_engine.URL.query` dictionary may contain either
string values or sequences of string values to differentiate between
parameters that are specified multiple times in the query string,
code that needs to handle multiple parameters generically will wish
to use this attribute so that all parameters present are presented
as sequences. Inspiration is from Python's ``urllib.parse.parse_qs``
function. E.g.::
>>> from sqlalchemy.engine import make_url
>>> url = make_url("postgresql+psycopg2://user:pass@host/dbname?alt_host=host1&alt_host=host2&ssl_cipher=%2Fpath%2Fto%2Fcrt")
>>> url.query
immutabledict({'alt_host': ('host1', 'host2'), 'ssl_cipher': '/path/to/crt'})
>>> url.normalized_query
immutabledict({'alt_host': ('host1', 'host2'), 'ssl_cipher': ('/path/to/crt',)})
""" # noqa: E501
return util.immutabledict(
{
k: (v,) if not isinstance(v, tuple) else v
for k, v in self.query.items()
}
)
@util.deprecated(
"1.4",
"The :meth:`_engine.URL.__to_string__ method is deprecated and will "
"be removed in a future release. Please use the "
":meth:`_engine.URL.render_as_string` method.",
)
def __to_string__(self, hide_password: bool = True) -> str:
"""Render this :class:`_engine.URL` object as a string.
:param hide_password: Defaults to True. The password is not shown
in the string unless this is set to False.
"""
return self.render_as_string(hide_password=hide_password)
def render_as_string(self, hide_password: bool = True) -> str:
"""Render this :class:`_engine.URL` object as a string.
This method is used by the ``__str__()`` and ``__repr__()``
methods. Unlike those methods, it accepts the ``hide_password``
option directly.
:param hide_password: Defaults to True. The password is not shown
in the string unless this is set to False.
"""
s = self.drivername + "://"
if self.username is not None:
s += quote(self.username, safe=" +")
if self.password is not None:
s += ":" + (
"***"
if hide_password
else quote(str(self.password), safe=" +")
)
s += "@"
if self.host is not None:
if ":" in self.host:
s += f"[{self.host}]"
else:
s += self.host
if self.port is not None:
s += ":" + str(self.port)
if self.database is not None:
s += "/" + self.database
if self.query:
keys = list(self.query)
keys.sort()
s += "?" + "&".join(
f"{quote_plus(k)}={quote_plus(element)}"
for k in keys
for element in util.to_list(self.query[k])
)
return s
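# An illustrative sketch contrasting the default password masking with
# hide_password=False (the credentials are placeholders):
#
#     >>> from sqlalchemy.engine import make_url
#     >>> url = make_url("postgresql+psycopg2://user:secret@host/dbname")
#     >>> url.render_as_string()
#     'postgresql+psycopg2://user:***@host/dbname'
#     >>> url.render_as_string(hide_password=False)
#     'postgresql+psycopg2://user:secret@host/dbname'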
def __repr__(self) -> str:
return self.render_as_string()
def __copy__(self) -> URL:
return self.__class__.create(
self.drivername,
self.username,
self.password,
self.host,
self.port,
self.database,
# note this is an immutabledict of str-> str / tuple of str,
# also fully immutable. does not require deepcopy
self.query,
)
def __deepcopy__(self, memo: Any) -> URL:
return self.__copy__()
def __hash__(self) -> int:
return hash(str(self))
def __eq__(self, other: Any) -> bool:
return (
isinstance(other, URL)
and self.drivername == other.drivername
and self.username == other.username
and self.password == other.password
and self.host == other.host
and self.database == other.database
and self.query == other.query
and self.port == other.port
)
def __ne__(self, other: Any) -> bool:
return not self == other
def get_backend_name(self) -> str:
"""Return the backend name.
This is the name that corresponds to the database backend in
use, and is the portion of the :attr:`_engine.URL.drivername`
that is to the left of the plus sign.
"""
if "+" not in self.drivername:
return self.drivername
else:
return self.drivername.split("+")[0]
def get_driver_name(self) -> str:
"""Return the backend name.
This is the name that corresponds to the DBAPI driver in
use, and is the portion of the :attr:`_engine.URL.drivername`
that is to the right of the plus sign.
If the :attr:`_engine.URL.drivername` does not include a plus sign,
then the default :class:`_engine.Dialect` for this :class:`_engine.URL`
is imported in order to get the driver name.
"""
if "+" not in self.drivername:
return self.get_dialect().driver
else:
return self.drivername.split("+")[1]
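# An illustrative sketch of the backend/driver split around the "+" in
# the drivername (the URL is a placeholder):
#
#     >>> from sqlalchemy.engine import make_url
#     >>> url = make_url("postgresql+psycopg2://user:pass@host/dbname")
#     >>> url.get_backend_name()
#     'postgresql'
#     >>> url.get_driver_name()
#     'psycopg2'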
def _instantiate_plugins(
self, kwargs: Mapping[str, Any]
) -> Tuple[URL, List[Any], Dict[str, Any]]:
plugin_names = util.to_list(self.query.get("plugin", ()))
plugin_names += kwargs.get("plugins", [])
kwargs = dict(kwargs)
loaded_plugins = [
plugins.load(plugin_name)(self, kwargs)
for plugin_name in plugin_names
]
u = self.difference_update_query(["plugin", "plugins"])
for plugin in loaded_plugins:
new_u = plugin.update_url(u)
if new_u is not None:
u = new_u
kwargs.pop("plugins", None)
return u, loaded_plugins, kwargs
def _get_entrypoint(self) -> Type[Dialect]:
"""Return the "entry point" dialect class.
This is normally the dialect itself except in the case when the
returned class implements the get_dialect_cls() method.
"""
if "+" not in self.drivername:
name = self.drivername
else:
name = self.drivername.replace("+", ".")
cls = registry.load(name)
# check for legacy dialects that
# would return a module with 'dialect' as the
# actual class
if (
hasattr(cls, "dialect")
and isinstance(cls.dialect, type)
and issubclass(cls.dialect, Dialect)
):
return cls.dialect
else:
return cast("Type[Dialect]", cls)
def get_dialect(self, _is_async: bool = False) -> Type[Dialect]:
"""Return the SQLAlchemy :class:`_engine.Dialect` class corresponding
to this URL's driver name.
"""
entrypoint = self._get_entrypoint()
if _is_async:
dialect_cls = entrypoint.get_async_dialect_cls(self)
else:
dialect_cls = entrypoint.get_dialect_cls(self)
return dialect_cls
def translate_connect_args(
self, names: Optional[List[str]] = None, **kw: Any
) -> Dict[str, Any]:
r"""Translate url attributes into a dictionary of connection arguments.
Returns attributes of this url (`host`, `database`, `username`,
`password`, `port`) as a plain dictionary. The attribute names are
used as the keys by default. Unset or false attributes are omitted
from the final dictionary.
:param \**kw: Optional, alternate key names for url attributes.
:param names: Deprecated. Same purpose as the keyword-based alternate
names, but correlates the name to the original positionally.
"""
if names is not None:
util.warn_deprecated(
"The `URL.translate_connect_args.name`s parameter is "
"deprecated. Please pass the "
"alternate names as kw arguments.",
"1.4",
)
translated = {}
attribute_names = ["host", "database", "username", "password", "port"]
for sname in attribute_names:
if names:
name = names.pop(0)
elif sname in kw:
name = kw[sname]
else:
name = sname
if name is not None and getattr(self, sname, False):
if sname == "password":
translated[name] = str(getattr(self, sname))
else:
translated[name] = getattr(self, sname)
return translated
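# An illustrative sketch of translate_connect_args(), remapping the
# "database" and "username" attributes to DBAPI-style names; the URL and
# the exact dictionary ordering shown are indicative only:
#
#     >>> from sqlalchemy.engine import make_url
#     >>> url = make_url("postgresql+psycopg2://scott:tiger@localhost:5432/mydb")
#     >>> url.translate_connect_args(database="dbname", username="user")
#     {'host': 'localhost', 'dbname': 'mydb', 'user': 'scott', 'password': 'tiger', 'port': 5432}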
def make_url(name_or_url: Union[str, URL]) -> URL:
"""Given a string, produce a new URL instance.
The format of the URL generally follows `RFC-1738
<https://www.ietf.org/rfc/rfc1738.txt>`_, with some exceptions, including
that underscores, and not dashes or periods, are accepted within the
"scheme" portion.
If a :class:`.URL` object is passed, it is returned as is.
.. seealso::
:ref:`database_urls`
"""
if isinstance(name_or_url, str):
return _parse_url(name_or_url)
elif not isinstance(name_or_url, URL) and not hasattr(
name_or_url, "_sqla_is_testing_if_this_is_a_mock_object"
):
raise exc.ArgumentError(
f"Expected string or URL object, got {name_or_url!r}"
)
else:
return name_or_url
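# An illustrative sketch of make_url(): strings are parsed into URL
# objects, while an existing URL passes through unchanged (the SQLite URL
# is a placeholder):
#
#     >>> url = make_url("sqlite+pysqlite:///local.db")
#     >>> url.get_backend_name(), url.database
#     ('sqlite', 'local.db')
#     >>> make_url(url) is url
#     True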
def _parse_url(name: str) -> URL:
pattern = re.compile(
r"""
(?P<name>[\w\+]+)://
(?:
(?P<username>[^:/]*)
(?::(?P<password>[^@]*))?
@)?
(?:
(?:
\[(?P<ipv6host>[^/\?]+)\] |
(?P<ipv4host>[^/:\?]+)
)?
(?::(?P<port>[^/\?]*))?
)?
(?:/(?P<database>[^\?]*))?
(?:\?(?P<query>.*))?
""",
re.X,
)
m = pattern.match(name)
if m is not None:
components = m.groupdict()
query: Optional[Dict[str, Union[str, List[str]]]]
if components["query"] is not None:
query = {}
for key, value in parse_qsl(components["query"]):
if key in query:
query[key] = util.to_list(query[key])
cast("List[str]", query[key]).append(value)
else:
query[key] = value
else:
query = None
components["query"] = query
if components["username"] is not None:
components["username"] = unquote(components["username"])
if components["password"] is not None:
components["password"] = unquote(components["password"])
ipv4host = components.pop("ipv4host")
ipv6host = components.pop("ipv6host")
components["host"] = ipv4host or ipv6host
name = components.pop("name")
if components["port"]:
components["port"] = int(components["port"])
return URL.create(name, **components) # type: ignore
else:
raise exc.ArgumentError(
"Could not parse SQLAlchemy URL from string '%s'" % name
)

View file

@ -0,0 +1,166 @@
# engine/util.py
# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
from __future__ import annotations
import typing
from typing import Any
from typing import Callable
from typing import Optional
from typing import TypeVar
from .. import exc
from .. import util
from ..util._has_cy import HAS_CYEXTENSION
from ..util.typing import Protocol
if typing.TYPE_CHECKING or not HAS_CYEXTENSION:
from ._py_util import _distill_params_20 as _distill_params_20
from ._py_util import _distill_raw_params as _distill_raw_params
else:
from sqlalchemy.cyextension.util import ( # noqa: F401
_distill_params_20 as _distill_params_20,
)
from sqlalchemy.cyextension.util import ( # noqa: F401
_distill_raw_params as _distill_raw_params,
)
_C = TypeVar("_C", bound=Callable[[], Any])
def connection_memoize(key: str) -> Callable[[_C], _C]:
"""Decorator, memoize a function in a connection.info stash.
Only applicable to functions which take no arguments other than a
connection. The memo will be stored in ``connection.info[key]``.
"""
@util.decorator
def decorated(fn, self, connection): # type: ignore
connection = connection.connect()
try:
return connection.info[key]
except KeyError:
connection.info[key] = val = fn(self, connection)
return val
return decorated
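# An illustrative sketch of how connection_memoize() might be applied;
# the MyDialect class, the "_my_server_version" key and the query text
# are hypothetical, not part of SQLAlchemy's API:
#
#     >>> class MyDialect:
#     ...     @connection_memoize("_my_server_version")
#     ...     def _get_server_version(self, connection):
#     ...         return connection.exec_driver_sql("SELECT version()").scalar()
#
# The first call per connection runs the query; subsequent calls return the
# value cached in connection.info["_my_server_version"].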
class _TConsSubject(Protocol):
_trans_context_manager: Optional[TransactionalContext]
class TransactionalContext:
"""Apply Python context manager behavior to transaction objects.
Performs validation to ensure the subject of the transaction is not
used if the transaction were ended prematurely.
"""
__slots__ = ("_outer_trans_ctx", "_trans_subject", "__weakref__")
_trans_subject: Optional[_TConsSubject]
def _transaction_is_active(self) -> bool:
raise NotImplementedError()
def _transaction_is_closed(self) -> bool:
raise NotImplementedError()
def _rollback_can_be_called(self) -> bool:
"""indicates the object is in a state that is known to be acceptable
for rollback() to be called.
This does not necessarily mean that rollback() will succeed, or that it
will not raise an error; only that no state is currently detected which
indicates rollback() would fail or emit warnings.
It also does not mean that there's a transaction in progress, as
it is usually safe to call rollback() even if no transaction is
present.
.. versionadded:: 1.4.28
"""
raise NotImplementedError()
def _get_subject(self) -> _TConsSubject:
raise NotImplementedError()
def commit(self) -> None:
raise NotImplementedError()
def rollback(self) -> None:
raise NotImplementedError()
def close(self) -> None:
raise NotImplementedError()
@classmethod
def _trans_ctx_check(cls, subject: _TConsSubject) -> None:
trans_context = subject._trans_context_manager
if trans_context:
if not trans_context._transaction_is_active():
raise exc.InvalidRequestError(
"Can't operate on closed transaction inside context "
"manager. Please complete the context manager "
"before emitting further commands."
)
def __enter__(self) -> TransactionalContext:
subject = self._get_subject()
# none for outer transaction, may be non-None for nested
# savepoint, legacy nesting cases
trans_context = subject._trans_context_manager
self._outer_trans_ctx = trans_context
self._trans_subject = subject
subject._trans_context_manager = self
return self
def __exit__(self, type_: Any, value: Any, traceback: Any) -> None:
subject = getattr(self, "_trans_subject", None)
# simplistically we could assume that
# "subject._trans_context_manager is self". However, any calling
# code that is manipulating __exit__ directly would break this
# assumption. alembic context manager
# is an example of partial use that just calls __exit__ and
# not __enter__ at the moment. it's safe to assume this is being done
# in the wild also
out_of_band_exit = (
subject is None or subject._trans_context_manager is not self
)
if type_ is None and self._transaction_is_active():
try:
self.commit()
except:
with util.safe_reraise():
if self._rollback_can_be_called():
self.rollback()
finally:
if not out_of_band_exit:
assert subject is not None
subject._trans_context_manager = self._outer_trans_ctx
self._trans_subject = self._outer_trans_ctx = None
else:
try:
if not self._transaction_is_active():
if not self._transaction_is_closed():
self.close()
else:
if self._rollback_can_be_called():
self.rollback()
finally:
if not out_of_band_exit:
assert subject is not None
subject._trans_context_manager = self._outer_trans_ctx
self._trans_subject = self._outer_trans_ctx = None
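# An illustrative sketch of the behavior the __exit__() hook above provides
# to concrete subclasses such as the Connection-level Transaction objects;
# the in-memory SQLite engine is a placeholder:
#
#     >>> from sqlalchemy import create_engine, text
#     >>> engine = create_engine("sqlite://")
#     >>> with engine.connect() as conn:
#     ...     with conn.begin():
#     ...         _ = conn.execute(text("SELECT 1"))
#     ...     # commit() on a clean exit; rollback() if an exception propagates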
